Files
schisandra-album-cloud-micr…/app/auth/api/internal/logic/oauth/github_callback_logic.go
2024-12-24 18:01:54 +08:00

245 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package oauth
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strconv"

	"github.com/yitter/idgenerator-go/idgen"
	"github.com/zeromicro/go-zero/core/logx"
	"gorm.io/gorm"

	"schisandra-album-cloud-microservices/app/auth/api/internal/svc"
	"schisandra-album-cloud-microservices/app/auth/api/internal/types"
	model2 "schisandra-album-cloud-microservices/app/auth/model/mysql/model"
	"schisandra-album-cloud-microservices/common/constant"
)
// GithubCallbackLogic implements the GitHub OAuth callback: it turns an
// authorization code into a logged-in local user (see GithubCallback).
type GithubCallbackLogic struct {
	// Logger is the context-bound logger created in NewGithubCallbackLogic.
	logx.Logger
	// ctx is the context passed to NewGithubCallbackLogic; it is forwarded to
	// HandleOauthLoginResponse.
	ctx context.Context
	// svcCtx supplies the shared dependencies used here: DB, CasbinEnforcer
	// and Config (GitHub OAuth client credentials).
	svcCtx *svc.ServiceContext
}
// GitHubUser is a typed view of the JSON object returned by
// https://api.github.com/user, which GetUserInfo fetches. GithubCallback
// obtains it by re-encoding GetUserInfo's generic map and unmarshalling it into
// this struct. Within this file, only ID, Login, Name, AvatarURL, Blog and
// Email are read.
//
// Fields typed interface{} are left untyped, presumably because GitHub may
// send null or varying types for them (TODO confirm against the API schema).
type GitHubUser struct {
	AvatarURL         string      `json:"avatar_url"`
	Bio               interface{} `json:"bio"`
	Blog              string      `json:"blog"`
	Company           interface{} `json:"company"`
	CreatedAt         string      `json:"created_at"`
	Email             string      `json:"email"`
	EventsURL         string      `json:"events_url"`
	Followers         int         `json:"followers"`
	FollowersURL      string      `json:"followers_url"`
	Following         int         `json:"following"`
	FollowingURL      string      `json:"following_url"`
	GistsURL          string      `json:"gists_url"`
	GravatarID        string      `json:"gravatar_id"`
	Hireable          interface{} `json:"hireable"`
	HTMLURL           string      `json:"html_url"`
	// ID is the numeric GitHub account ID; GithubCallback formats it with
	// strconv.Itoa and stores it as the OpenID of the social binding.
	ID                int         `json:"id"`
	Location          interface{} `json:"location"`
	Login             string      `json:"login"`
	Name              string      `json:"name"`
	NodeID            string      `json:"node_id"`
	NotificationEmail interface{} `json:"notification_email"`
	OrganizationsURL  string      `json:"organizations_url"`
	PublicGists       int         `json:"public_gists"`
	PublicRepos       int         `json:"public_repos"`
	ReceivedEventsURL string      `json:"received_events_url"`
	ReposURL          string      `json:"repos_url"`
	SiteAdmin         bool        `json:"site_admin"`
	StarredURL        string      `json:"starred_url"`
	SubscriptionsURL  string      `json:"subscriptions_url"`
	TwitterUsername   interface{} `json:"twitter_username"`
	Type              string      `json:"type"`
	UpdatedAt         string      `json:"updated_at"`
	URL               string      `json:"url"`
}
// NewGithubCallbackLogic returns a GithubCallbackLogic whose embedded logger
// is bound to ctx and which reaches its dependencies through svcCtx.
func NewGithubCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GithubCallbackLogic {
	logic := new(GithubCallbackLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
// GithubCallback completes the GitHub OAuth login for the authorization code
// in req.Code. It returns the payload produced by HandleOauthLoginResponse.
//
// Flow:
//  1. Exchange the code for an access token and fetch the GitHub profile.
//  2. Look up the ScaAuthUserSocial binding for (GitHub user ID, OAuthSourceGithub).
//  3. With no binding, create a ScaAuthUser, the binding, and the default
//     Casbin role. Otherwise, load the ScaAuthUser the binding points at.
//  4. Build the login response, then commit.
//
// All database work shares one transaction. It is rolled back on every error
// path, including a failed binding lookup. The Casbin role assignment goes
// through CasbinEnforcer rather than the transaction, so it is not undone here
// if a later step fails.
func (l *GithubCallbackLogic) GithubCallback(r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
	gitHubUser, err := l.fetchGitHubUser(req.Code)
	if err != nil {
		return "", err
	}
	openID := strconv.Itoa(gitHubUser.ID)

	tx := l.svcCtx.DB.Begin()
	committed := false
	defer func() {
		// Roll back on every early return. Skipped once Commit has been attempted,
		// because the transaction is finished either way.
		if !committed {
			_ = tx.Rollback()
		}
	}()

	userSocial := l.svcCtx.DB.ScaAuthUserSocial
	socialUser, err := tx.ScaAuthUserSocial.
		Where(userSocial.OpenID.Eq(openID), userSocial.Source.Eq(constant.OAuthSourceGithub)).
		First()
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return "", err
	}

	var loginUser *model2.ScaAuthUser
	if socialUser == nil {
		// First login with this GitHub account: create the local user and binding.
		uidStr := strconv.FormatInt(idgen.NextId(), 10)
		loginUser = &model2.ScaAuthUser{
			UID:      uidStr,
			Avatar:   gitHubUser.AvatarURL,
			Username: gitHubUser.Login,
			Nickname: gitHubUser.Name,
			Blog:     gitHubUser.Blog,
			Email:    gitHubUser.Email,
			Gender:   constant.Male,
		}
		if err = tx.ScaAuthUser.Create(loginUser); err != nil {
			return "", err
		}
		if err = tx.ScaAuthUserSocial.Create(&model2.ScaAuthUserSocial{
			UserID: uidStr,
			OpenID: openID,
			Source: constant.OAuthSourceGithub,
		}); err != nil {
			return "", err
		}
		// NOTE(review): this policy change is outside tx and is not reverted if a
		// later step fails.
		added, roleErr := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User)
		if roleErr != nil {
			return "", roleErr
		}
		if !added {
			// Without this, the caller would receive ("", nil): an empty login with no error.
			return "", errors.New("add default role for github user failed")
		}
	} else {
		// Returning user: load the account the binding points at.
		authUser := l.svcCtx.DB.ScaAuthUser
		loginUser, err = tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// A dangling binding must not reach HandleOauthLoginResponse with a nil user.
			return "", errors.New("github account is bound to a non-existent user")
		}
		if err != nil {
			return "", err
		}
	}

	data, err := HandleOauthLoginResponse(loginUser, l.svcCtx, r, l.ctx)
	if err != nil {
		return "", err
	}
	committed = true
	if err = tx.Commit(); err != nil {
		return "", err
	}
	return data, nil
}

// fetchGitHubUser exchanges code for an access token and returns the GitHub
// profile of the token's owner. It fails when no access token is issued, or
// when the decoded profile has no user ID. This keeps callers from binding an
// account to OpenID "0".
func (l *GithubCallbackLogic) fetchGitHubUser(code string) (*GitHubUser, error) {
	token, err := l.GetToken(l.GetTokenAuthUrl(code))
	if err != nil {
		return nil, err
	}
	if token == nil || token.AccessToken == "" {
		return nil, errors.New("get github token failed")
	}

	userInfo, err := l.GetUserInfo(token)
	if err != nil {
		return nil, err
	}
	if userInfo == nil {
		return nil, errors.New("get github user info failed")
	}

	// GetUserInfo yields a generic map; round-trip it through JSON to obtain
	// the typed GitHubUser fields.
	raw, err := json.Marshal(userInfo)
	if err != nil {
		return nil, err
	}
	var user GitHubUser
	if err := json.Unmarshal(raw, &user); err != nil {
		return nil, err
	}
	if user.ID == 0 {
		return nil, errors.New("get github user info failed")
	}
	return &user, nil
}
// GetTokenAuthUrl builds the GitHub token-exchange URL for an authorization
// code, using the OAuth app credentials from l.svcCtx.Config.OAuth.Github.
//
// Every query value is URL-encoded. A code taken from the callback request
// that contains reserved characters (such as '&' or '=') therefore cannot
// inject extra parameters. url.Values.Encode sorts keys, so the parameter order
// (client_id, client_secret, code) matches the previous hand-built string.
func (l *GithubCallbackLogic) GetTokenAuthUrl(code string) string {
	query := url.Values{}
	query.Set("client_id", l.svcCtx.Config.OAuth.Github.ClientID)
	query.Set("client_secret", l.svcCtx.Config.OAuth.Github.ClientSecret)
	query.Set("code", code)
	return "https://github.com/login/oauth/access_token?" + query.Encode()
}
// GetToken requests url (as built by GetTokenAuthUrl) and decodes the JSON
// response body into a Token.
//
// The request carries l.ctx, so it is cancelled together with the caller's
// context. A non-2xx status is returned as an error. GitHub's OAuth
// documentation says some exchange failures (e.g. a bad code) are reported
// with HTTP 200 and an error object. In that case the returned Token has an
// empty AccessToken, and callers must treat that as failure.
func (l *GithubCallbackLogic) GetToken(url string) (*Token, error) {
	req, err := http.NewRequestWithContext(l.ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("accept", "application/json")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Closing the body releases the connection back to the transport's pool.
	defer res.Body.Close()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("github token request: unexpected status %s", res.Status)
	}

	var token Token
	if err := json.NewDecoder(res.Body).Decode(&token); err != nil {
		return nil, err
	}
	return &token, nil
}
// GetUserInfo fetches the profile of the user who owns token from
// https://api.github.com/user. It returns the decoded JSON object as a
// generic map.
//
// The request carries l.ctx, so it is cancelled together with the caller's
// context. A non-2xx response (e.g. 401 for an invalid token) is returned as an
// error instead of being decoded, so an error body is never mistaken for a user
// profile.
func (l *GithubCallbackLogic) GetUserInfo(token *Token) (map[string]interface{}, error) {
	const userInfoURL = "https://api.github.com/user"

	if token == nil {
		return nil, errors.New("github user info request: nil token")
	}

	req, err := http.NewRequestWithContext(l.ctx, http.MethodGet, userInfoURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("accept", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken))

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Closing the body releases the connection back to the transport's pool.
	defer res.Body.Close()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("github user info request: unexpected status %s", res.Status)
	}

	userInfo := make(map[string]interface{})
	if err := json.NewDecoder(res.Body).Decode(&userInfo); err != nil {
		return nil, err
	}
	return userInfo, nil
}