diff --git a/api/oauth_api/gitee_api.go b/api/oauth_api/gitee_api.go index bb0ac0b..f7fbd47 100644 --- a/api/oauth_api/gitee_api.go +++ b/api/oauth_api/gitee_api.go @@ -170,7 +170,7 @@ func (OAuthAPI) GiteeCallback(c *gin.Context) { } Id := strconv.Itoa(giteeUser.ID) - userSocial, err := userSocialService.QueryUserSocialByUUID(Id) + userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGitee) if errors.Is(err, gorm.ErrRecordNotFound) { // 第一次登录,创建用户 uid := idgen.NextId() @@ -220,7 +220,7 @@ func (OAuthAPI) GiteeCallback(c *gin.Context) { formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web) c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript)) } else { - user, err := userService.QueryUserByUsername(giteeUser.Login) + user, err := userService.QueryUserById(userSocial.UserID) if err != nil { global.LOG.Error(err) return diff --git a/api/oauth_api/github_api.go b/api/oauth_api/github_api.go index b222f33..e3e304a 100644 --- a/api/oauth_api/github_api.go +++ b/api/oauth_api/github_api.go @@ -175,7 +175,7 @@ func (OAuthAPI) Callback(c *gin.Context) { return } Id := strconv.Itoa(gitHubUser.ID) - userSocial, err := userSocialService.QueryUserSocialByUUID(Id) + userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGithub) if errors.Is(err, gorm.ErrRecordNotFound) { // 第一次登录,创建用户 uid := idgen.NextId() @@ -225,7 +225,7 @@ func (OAuthAPI) Callback(c *gin.Context) { formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web) c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript)) } else { - user, err := userService.QueryUserByUsername(gitHubUser.Login) + user, err := userService.QueryUserById(userSocial.UserID) if err != nil { global.LOG.Error(err) return diff --git a/api/oauth_api/wechat_api.go b/api/oauth_api/wechat_api.go index cd257f0..b8bd415 100644 --- a/api/oauth_api/wechat_api.go +++ b/api/oauth_api/wechat_api.go @@ -201,7 +201,7 @@ func wechatLoginHandler(openId string, clientId string) bool { if openId == "" { return false } - authUserSocial, err := userSocialService.QueryUserSocialByOpenID(openId) + authUserSocial, err := userSocialService.QueryUserSocialByOpenID(openId, enum.OAuthSourceWechat) if errors.Is(err, gorm.ErrRecordNotFound) { uid := idgen.NextId() uidStr := strconv.FormatInt(uid, 10) diff --git a/api/user_api/user_api.go b/api/user_api/user_api.go index 84b52e0..26dd9cd 100644 --- a/api/user_api/user_api.go +++ b/api/user_api/user_api.go @@ -339,8 +339,10 @@ func (UserAPI) RefreshHandler(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) return } - token := redis.Get(constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID).Val() - if token == "" { + + token, err := redis.Get(constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID).Result() + if token == "" || err != nil { + global.LOG.Errorln(err) result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) return } diff --git a/service/user_social_service/user_social_service.go b/service/user_social_service/user_social_service.go index 4cc1139..d50f7e7 100644 --- a/service/user_social_service/user_social_service.go +++ b/service/user_social_service/user_social_service.go @@ -17,10 +17,10 @@ func (UserSocialService) AddUserSocial(user model.ScaAuthUserSocial) error { return nil } -// QueryUserSocialByOpenID 根据openID查询用户信息 -func (UserSocialService) QueryUserSocialByOpenID(openID string) (model.ScaAuthUserSocial, error) { +// QueryUserSocialByOpenID 根据openID和source查询用户信息 +func (UserSocialService) QueryUserSocialByOpenID(openID string, source string) (model.ScaAuthUserSocial, error) { var user model.ScaAuthUserSocial - result := global.DB.Where("open_id = ? and deleted = 0", openID).First(&user) + result := global.DB.Where("open_id = ? and source = ? and deleted = 0", openID, source).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return model.ScaAuthUserSocial{}, result.Error @@ -30,10 +30,10 @@ func (UserSocialService) QueryUserSocialByOpenID(openID string) (model.ScaAuthUs return user, nil } -// QueryUserSocialByUUID 根据uuid查询用户信息 -func (UserSocialService) QueryUserSocialByUUID(openID string) (model.ScaAuthUserSocial, error) { +// QueryUserSocialByUUID 根据uuid和source查询用户信息 +func (UserSocialService) QueryUserSocialByUUID(openID string, source string) (model.ScaAuthUserSocial, error) { var user model.ScaAuthUserSocial - result := global.DB.Where("uuid = ? and deleted = 0", openID).First(&user) + result := global.DB.Where("uuid = ? and source = ? and deleted = 0", openID, source).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return model.ScaAuthUserSocial{}, result.Error diff --git a/tmp/schisandra_cloud_album_linux_linux b/tmp/schisandra_cloud_album_linux_linux new file mode 100644 index 0000000..e0c3dd3 Binary files /dev/null and b/tmp/schisandra_cloud_album_linux_linux differ diff --git a/utils/jwt.go b/utils/jwt.go index 32cd909..18c8150 100644 --- a/utils/jwt.go +++ b/utils/jwt.go @@ -18,8 +18,11 @@ type AccessJWTPayload struct { RoleID []*int64 `json:"role_id"` Type *string `json:"type" default:"access"` } -type JWTClaims struct { +type AccessJWTClaims struct { AccessJWTPayload + jwt.RegisteredClaims +} +type RefreshJWTClaims struct { RefreshJWTPayload jwt.RegisteredClaims } @@ -29,7 +32,7 @@ var MySecret []byte // GenerateAccessToken generates a JWT token with the given payload func GenerateAccessToken(payload AccessJWTPayload) (string, error) { MySecret = []byte(global.CONFIG.JWT.Secret) - claims := JWTClaims{ + claims := AccessJWTClaims{ AccessJWTPayload: payload, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 2)), @@ -53,10 +56,10 @@ func GenerateAccessToken(payload AccessJWTPayload) (string, error) { // GenerateRefreshToken generates a JWT token with the given payload, and returns the accessToken and refreshToken func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) (string, int64) { MySecret = []byte(global.CONFIG.JWT.Secret) - refreshClaims := JWTClaims{ + refreshClaims := RefreshJWTClaims{ RefreshJWTPayload: payload, RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(days)), // 7天 + ExpiresAt: jwt.NewNumericDate(time.Now().Add(days)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()), Issuer: global.CONFIG.JWT.Issuer, @@ -84,14 +87,14 @@ func ParseAccessToken(tokenString string) (*AccessJWTPayload, bool, error) { global.LOG.Error(err) return nil, false, err } - token, err := jwt.ParseWithClaims(string(plaintext), &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(string(plaintext), &AccessJWTClaims{}, func(token *jwt.Token) (interface{}, error) { return MySecret, nil }) if err != nil { global.LOG.Error(err) return nil, false, err } - if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + if claims, ok := token.Claims.(*AccessJWTClaims); ok && token.Valid { return &claims.AccessJWTPayload, true, nil } return nil, false, err @@ -105,14 +108,14 @@ func ParseRefreshToken(tokenString string) (*RefreshJWTPayload, bool, error) { global.LOG.Error(err) return nil, false, err } - token, err := jwt.ParseWithClaims(string(plaintext), &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(string(plaintext), &RefreshJWTClaims{}, func(token *jwt.Token) (interface{}, error) { return MySecret, nil }) if err != nil { global.LOG.Error(err) return nil, false, err } - if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + if claims, ok := token.Claims.(*RefreshJWTClaims); ok && token.Valid { return &claims.RefreshJWTPayload, true, nil } return nil, false, err