diff --git a/api/oauth_api/gitee_api.go b/api/oauth_api/gitee_api.go index f7fbd47..542ab84 100644 --- a/api/oauth_api/gitee_api.go +++ b/api/oauth_api/gitee_api.go @@ -171,7 +171,7 @@ func (OAuthAPI) GiteeCallback(c *gin.Context) { Id := strconv.Itoa(giteeUser.ID) userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGitee) - if errors.Is(err, gorm.ErrRecordNotFound) { + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { // 第一次登录,创建用户 uid := idgen.NextId() uidStr := strconv.FormatInt(uid, 10) diff --git a/api/oauth_api/github_api.go b/api/oauth_api/github_api.go index e3e304a..a2a4099 100644 --- a/api/oauth_api/github_api.go +++ b/api/oauth_api/github_api.go @@ -176,7 +176,7 @@ func (OAuthAPI) Callback(c *gin.Context) { } Id := strconv.Itoa(gitHubUser.ID) userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGithub) - if errors.Is(err, gorm.ErrRecordNotFound) { + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { // 第一次登录,创建用户 uid := idgen.NextId() uidStr := strconv.FormatInt(uid, 10) diff --git a/api/oauth_api/qq_api.go b/api/oauth_api/qq_api.go index c93b453..bc20967 100644 --- a/api/oauth_api/qq_api.go +++ b/api/oauth_api/qq_api.go @@ -2,18 +2,52 @@ package oauth_api import ( "encoding/json" + "errors" "fmt" ginI18n "github.com/gin-contrib/i18n" "github.com/gin-gonic/gin" + "github.com/yitter/idgenerator-go/idgen" + "gorm.io/gorm" "net/http" + "schisandra-cloud-album/common/enum" "schisandra-cloud-album/common/result" "schisandra-cloud-album/global" + "schisandra-cloud-album/model" + "strconv" ) type AuthQQme struct { ClientID string `json:"client_id"` OpenID string `json:"openid"` } +type QQToken struct { + AccessToken string `json:"access_token"` + ExpireIn string `json:"expire_in"` + RefreshToken string `json:"refresh_token"` +} + +type QQUserInfo struct { + City string `json:"city"` + Figureurl string `json:"figureurl"` + Figureurl1 string `json:"figureurl_1"` + Figureurl2 string `json:"figureurl_2"` + FigureurlQq string `json:"figureurl_qq"` + FigureurlQq1 string `json:"figureurl_qq_1"` + FigureurlQq2 string `json:"figureurl_qq_2"` + Gender string `json:"gender"` + GenderType int `json:"gender_type"` + IsLost int `json:"is_lost"` + IsYellowVip string `json:"is_yellow_vip"` + IsYellowYearVip string `json:"is_yellow_year_vip"` + Level string `json:"level"` + Msg string `json:"msg"` + Nickname string `json:"nickname"` + Province string `json:"province"` + Ret int `json:"ret"` + Vip string `json:"vip"` + Year string `json:"year"` + YellowVipLevel string `json:"yellow_vip_level"` +} // GetQQRedirectUrl 获取登录地址 // @Summary 获取QQ登录地址 @@ -37,87 +71,89 @@ func GetQQTokenAuthUrl(code string) string { clientSecret := global.CONFIG.OAuth.QQ.ClientSecret redirectURI := global.CONFIG.OAuth.QQ.RedirectURI return fmt.Sprintf( - "https://graph.qq.com/oauth2.0/token?grant_type=authorization_code&client_id=%s&client_secret=%s&code=%s&redirect_uri=%s", + "https://graph.qq.com/oauth2.0/token?grant_type=authorization_code&client_id=%s&client_secret=%s&code=%s&redirect_uri=%s&fmt=json", clientId, clientSecret, code, redirectURI, ) } // GetQQToken 获取 token -func GetQQToken(url string) (*Token, error) { +func GetQQToken(url string) (*QQToken, error) { // 形成请求 var req *http.Request var err error if req, err = http.NewRequest(http.MethodGet, url, nil); err != nil { + global.LOG.Error(err) return nil, err } - req.Header.Set("accept", "application/json") // 发送请求并获得响应 var httpClient = http.Client{} var res *http.Response if res, err = httpClient.Do(req); err != nil { + global.LOG.Error(err) return nil, err } - - // 将响应体解析为 token,并返回 - var token Token + //将响应体解析为 token,并返回 + var token QQToken if err = json.NewDecoder(res.Body).Decode(&token); err != nil { + global.LOG.Error(err) return nil, err } return &token, nil } // GetQQUserOpenID 获取用户 openid -func GetQQUserOpenID(token *Token) (*AuthQQme, error) { +func GetQQUserOpenID(token *QQToken) (*AuthQQme, error) { // 形成请求 - var userInfoUrl = "https://graph.qq.com/oauth2.0/me" // github用户信息获取接口 + var userInfoUrl = "https://graph.qq.com/oauth2.0/me?access_token=" + token.AccessToken + "&fmt=json" var req *http.Request var err error if req, err = http.NewRequest(http.MethodGet, userInfoUrl, nil); err != nil { + global.LOG.Error(err) return nil, err } - req.Header.Set("accept", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken)) // 发送请求并获取响应 var client = http.Client{} var res *http.Response if res, err = client.Do(req); err != nil { + global.LOG.Error(err) return nil, err } // 将响应体解析为 AuthQQme,并返回 var authQQme AuthQQme if err = json.NewDecoder(res.Body).Decode(&authQQme); err != nil { + global.LOG.Error(err) return nil, err } return &authQQme, nil } // GetQQUserUserInfo 获取用户信息 -func GetQQUserUserInfo(token *Token, openId string) (map[string]interface{}, error) { +func GetQQUserUserInfo(token *QQToken, openId string) (map[string]interface{}, error) { clientId := global.CONFIG.OAuth.QQ.ClientID // 形成请求 - var userInfoUrl = "https://graph.qq.com/user/get_user_info??access_token=" + token.AccessToken + "&oauth_consumer_key=" + clientId + "&openid=" + openId + var userInfoUrl = "https://graph.qq.com/user/get_user_info?access_token=" + token.AccessToken + "&oauth_consumer_key=" + clientId + "&openid=" + openId var req *http.Request var err error if req, err = http.NewRequest(http.MethodGet, userInfoUrl, nil); err != nil { return nil, err } - req.Header.Set("accept", "application/json") - //req.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken)) // 发送请求并获取响应 var client = http.Client{} var res *http.Response if res, err = client.Do(req); err != nil { + global.LOG.Error(err) return nil, err } // 将响应的数据写入 userInfo 中,并返回 var userInfo = make(map[string]interface{}) if err = json.NewDecoder(res.Body).Decode(&userInfo); err != nil { + global.LOG.Error(err) return nil, err } return userInfo, nil @@ -139,13 +175,15 @@ func (OAuthAPI) QQCallback(c *gin.Context) { } // 通过 code, 获取 token var tokenAuthUrl = GetQQTokenAuthUrl(code) - var token *Token + var token *QQToken if token, err = GetQQToken(tokenAuthUrl); err != nil { global.LOG.Error(err) return } + // 通过 token,获取 openid authQQme, err := GetQQUserOpenID(token) if err != nil { + global.LOG.Error(err) return } @@ -155,6 +193,87 @@ func (OAuthAPI) QQCallback(c *gin.Context) { global.LOG.Error(err) return } - result.OkWithData(userInfo, c) - return + + userInfoBytes, err := json.Marshal(userInfo) + if err != nil { + global.LOG.Error(err) + return + } + var qqUserInfo QQUserInfo + err = json.Unmarshal(userInfoBytes, &qqUserInfo) + if err != nil { + global.LOG.Error(err) + return + } + userSocial, err := userSocialService.QueryUserSocialByOpenID(authQQme.OpenID, enum.OAuthSourceQQ) + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { + // 第一次登录,创建用户 + uid := idgen.NextId() + uidStr := strconv.FormatInt(uid, 10) + location := qqUserInfo.Province + "|" + qqUserInfo.City + user := model.ScaAuthUser{ + UID: &uidStr, + Username: &authQQme.OpenID, + Nickname: &qqUserInfo.Nickname, + Avatar: &qqUserInfo.FigureurlQq1, + Gender: &qqUserInfo.Gender, + Location: &location, + } + addUser, err := userService.AddUser(user) + if err != nil { + global.LOG.Error(err) + return + } + qq := enum.OAuthSourceQQ + userSocial = model.ScaAuthUserSocial{ + UserID: &addUser.ID, + OpenID: &authQQme.OpenID, + Source: &qq, + } + err = userSocialService.AddUserSocial(userSocial) + if err != nil { + global.LOG.Error(err) + return + } + userRole := model.ScaAuthUserRole{ + UserID: addUser.ID, + RoleID: enum.User, + } + err = userRoleService.AddUserRole(userRole) + if err != nil { + global.LOG.Error(err) + return + } + res, data := HandelUserLogin(addUser) + if !res { + return + } + tokenData, err := json.Marshal(data) + if err != nil { + global.LOG.Error(err) + return + } + formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web) + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript)) + return + } else { + user, err := userService.QueryUserById(userSocial.UserID) + if err != nil { + global.LOG.Error(err) + return + } + res, data := HandelUserLogin(user) + if !res { + return + } + tokenData, err := json.Marshal(data) + if err != nil { + global.LOG.Error(err) + return + } + formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web) + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript)) + return + } + } diff --git a/api/oauth_api/wechat_api.go b/api/oauth_api/wechat_api.go index b8bd415..a527ba5 100644 --- a/api/oauth_api/wechat_api.go +++ b/api/oauth_api/wechat_api.go @@ -202,7 +202,7 @@ func wechatLoginHandler(openId string, clientId string) bool { return false } authUserSocial, err := userSocialService.QueryUserSocialByOpenID(openId, enum.OAuthSourceWechat) - if errors.Is(err, gorm.ErrRecordNotFound) { + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { uid := idgen.NextId() uidStr := strconv.FormatInt(uid, 10) createUser := model.ScaAuthUser{ diff --git a/api/websocket_api/gws_api.go b/api/websocket_api/gws_api.go index 8715eee..eb63658 100644 --- a/api/websocket_api/gws_api.go +++ b/api/websocket_api/gws_api.go @@ -23,7 +23,7 @@ func (WebsocketAPI) NewGWSServer(c *gin.Context) { ReadBufferSize: 1024, // 读缓冲区大小 ParallelEnabled: true, // 开启并行消息处理 Recovery: gws.Recovery, // 开启异常恢复 - CheckUtf8Enabled: true, // 开启UTF8校验 + CheckUtf8Enabled: false, // 关闭UTF8校验 PermessageDeflate: gws.PermessageDeflate{ Enabled: true, // 开启压缩 }, diff --git a/router/modules/oauth_router.go b/router/modules/oauth_router.go index 7ffc65f..a5534b4 100644 --- a/router/modules/oauth_router.go +++ b/router/modules/oauth_router.go @@ -14,8 +14,8 @@ func OauthRouter(router *gin.RouterGroup) { { wechatRouter.GET("/generate_client_id", oauth.GenerateClientId) wechatRouter.GET("/get_temp_qrcode", oauth.GetTempQrCode) - //wechatRouter.GET("/callback", oauth.CallbackVerify) - wechatRouter.POST("/callback", oauth.CallbackNotify) + wechatRouter.GET("/callback", oauth.CallbackVerify) + //wechatRouter.POST("/callback", oauth.CallbackNotify) } githubRouter := group.Group("/github") {