diff --git a/common/types/types.go b/common/types/types.go new file mode 100644 index 0000000..ef70f89 --- /dev/null +++ b/common/types/types.go @@ -0,0 +1,36 @@ +package types + +import ( + "encoding/json" +) + +// ResponseData 返回数据 +type ResponseData struct { + AccessToken string `json:"access_token"` + UID *string `json:"uid"` + Username string `json:"username,omitempty"` + Nickname string `json:"nickname"` + Avatar string `json:"avatar"` + Status int64 `json:"status"` +} + +func (res ResponseData) MarshalBinary() ([]byte, error) { + return json.Marshal(res) +} + +func (res ResponseData) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, &res) +} + +type RedisToken struct { + AccessToken string `json:"access_token"` + UID string `json:"uid"` +} + +func (res RedisToken) MarshalBinary() ([]byte, error) { + return json.Marshal(res) +} + +func (res RedisToken) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, &res) +} diff --git a/controller/comment_controller/comment_controller.go b/controller/comment_controller/comment_controller.go index ccf7f6c..122f61c 100644 --- a/controller/comment_controller/comment_controller.go +++ b/controller/comment_controller/comment_controller.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/mssola/useragent" + "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/enum" "schisandra-cloud-album/common/result" "schisandra-cloud-album/global" @@ -62,7 +63,8 @@ func (CommentController) CommentSubmit(c *gin.Context) { browser, _ := ua.Browser() operatingSystem := ua.OS() isAuthor := 0 - if commentRequest.UserID == commentRequest.Author { + uid := utils.GetSession(c, constant.SessionKey).UID + if uid == commentRequest.Author { isAuthor = 1 } xssFilterContent := utils.XssFilter(commentRequest.Content) @@ -71,9 +73,10 @@ func (CommentController) CommentSubmit(c *gin.Context) { return } commentContent := global.SensitiveManager.Replace(xssFilterContent, '*') + commentReply := model.ScaCommentReply{ Content: commentContent, - UserId: commentRequest.UserID, + UserId: uid, TopicId: commentRequest.TopicId, TopicType: enum.CommentTopicType, CommentType: enum.COMMENT, @@ -84,7 +87,7 @@ func (CommentController) CommentSubmit(c *gin.Context) { OperatingSystem: operatingSystem, Agent: userAgent, } - commentId, response := commentReplyService.SubmitCommentService(&commentReply, commentRequest.TopicId, commentRequest.UserID, commentRequest.Images) + commentId, response := commentReplyService.SubmitCommentService(&commentReply, commentRequest.TopicId, uid, commentRequest.Images) if !response { result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c) return @@ -92,7 +95,7 @@ func (CommentController) CommentSubmit(c *gin.Context) { responseData := model.ScaCommentReply{ Id: commentId, Content: commentContent, - UserId: commentRequest.UserID, + UserId: uid, TopicId: commentRequest.TopicId, Author: isAuthor, Location: location, @@ -147,7 +150,8 @@ func (CommentController) ReplySubmit(c *gin.Context) { browser, _ := ua.Browser() operatingSystem := ua.OS() isAuthor := 0 - if replyCommentRequest.UserID == replyCommentRequest.Author { + uid := utils.GetSession(c, constant.SessionKey).UID + if uid == replyCommentRequest.Author { isAuthor = 1 } xssFilterContent := utils.XssFilter(replyCommentRequest.Content) @@ -158,7 +162,7 @@ func (CommentController) ReplySubmit(c *gin.Context) { commentContent := global.SensitiveManager.Replace(xssFilterContent, '*') commentReply := model.ScaCommentReply{ Content: commentContent, - UserId: replyCommentRequest.UserID, + UserId: uid, TopicId: replyCommentRequest.TopicId, TopicType: enum.CommentTopicType, CommentType: enum.REPLY, @@ -171,7 +175,7 @@ func (CommentController) ReplySubmit(c *gin.Context) { OperatingSystem: operatingSystem, Agent: userAgent, } - commentReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyCommentRequest.TopicId, replyCommentRequest.UserID, replyCommentRequest.Images) + commentReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyCommentRequest.TopicId, uid, replyCommentRequest.Images) if !response { result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c) return @@ -179,7 +183,7 @@ func (CommentController) ReplySubmit(c *gin.Context) { responseData := model.ScaCommentReply{ Id: commentReplyId, Content: commentContent, - UserId: replyCommentRequest.UserID, + UserId: uid, TopicId: replyCommentRequest.TopicId, ReplyId: replyCommentRequest.ReplyId, ReplyUser: replyCommentRequest.ReplyUser, @@ -236,7 +240,8 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) { browser, _ := ua.Browser() operatingSystem := ua.OS() isAuthor := 0 - if replyReplyRequest.UserID == replyReplyRequest.Author { + uid := utils.GetSession(c, constant.SessionKey).UID + if uid == replyReplyRequest.Author { isAuthor = 1 } xssFilterContent := utils.XssFilter(replyReplyRequest.Content) @@ -247,7 +252,7 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) { commentContent := global.SensitiveManager.Replace(xssFilterContent, '*') commentReply := model.ScaCommentReply{ Content: commentContent, - UserId: replyReplyRequest.UserID, + UserId: uid, TopicId: replyReplyRequest.TopicId, TopicType: enum.CommentTopicType, CommentType: enum.REPLY, @@ -261,7 +266,7 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) { OperatingSystem: operatingSystem, Agent: userAgent, } - commentReplyReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyReplyRequest.TopicId, replyReplyRequest.UserID, replyReplyRequest.Images) + commentReplyReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyReplyRequest.TopicId, uid, replyReplyRequest.Images) if !response { result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c) return @@ -269,7 +274,7 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) { responseData := model.ScaCommentReply{ Id: commentReplyReplyId, Content: commentContent, - UserId: replyReplyRequest.UserID, + UserId: uid, TopicId: replyReplyRequest.TopicId, ReplyTo: replyReplyRequest.ReplyTo, ReplyId: replyReplyRequest.ReplyId, @@ -299,7 +304,8 @@ func (CommentController) CommentList(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - response := commentReplyService.GetCommentListService(commentListRequest.UserID, commentListRequest.TopicId, commentListRequest.Page, commentListRequest.Size, commentListRequest.IsHot) + uid := utils.GetSession(c, constant.SessionKey).UID + response := commentReplyService.GetCommentListService(uid, commentListRequest.TopicId, commentListRequest.Page, commentListRequest.Size, commentListRequest.IsHot) result.OkWithData(response, c) return } @@ -319,7 +325,8 @@ func (CommentController) ReplyList(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - response := commentReplyService.GetCommentReplyListService(replyListRequest.UserID, replyListRequest.TopicId, replyListRequest.CommentId, replyListRequest.Page, replyListRequest.Size) + uid := utils.GetSession(c, constant.SessionKey).UID + response := commentReplyService.GetCommentReplyListService(uid, replyListRequest.TopicId, replyListRequest.CommentId, replyListRequest.Page, replyListRequest.Size) result.OkWithData(response, c) return } @@ -339,7 +346,8 @@ func (CommentController) CommentLikes(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - res := commentReplyService.CommentLikeService(likeRequest.UserID, likeRequest.CommentId, likeRequest.TopicId) + uid := utils.GetSession(c, constant.SessionKey).UID + res := commentReplyService.CommentLikeService(uid, likeRequest.CommentId, likeRequest.TopicId) if !res { result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c) return @@ -362,7 +370,8 @@ func (CommentController) CancelCommentLikes(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - res := commentReplyService.CommentDislikeService(cancelLikeRequest.UserID, cancelLikeRequest.CommentId, cancelLikeRequest.TopicId) + uid := utils.GetSession(c, constant.SessionKey).UID + res := commentReplyService.CommentDislikeService(uid, cancelLikeRequest.CommentId, cancelLikeRequest.TopicId) if !res { result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentDislikeFailed"), c) return diff --git a/controller/comment_controller/request_param.go b/controller/comment_controller/request_param.go index 516aec1..2fa3c41 100644 --- a/controller/comment_controller/request_param.go +++ b/controller/comment_controller/request_param.go @@ -4,43 +4,43 @@ package comment_controller type CommentRequest struct { Content string `json:"content" binding:"required"` Images []string `json:"images"` - UserID string `json:"user_id" binding:"required"` - TopicId string `json:"topic_id" binding:"required"` - Author string `json:"author" binding:"required"` - Key string `json:"key" binding:"required"` - Point []int64 `json:"point" binding:"required"` + // UserID string `json:"user_id" binding:"required"` + TopicId string `json:"topic_id" binding:"required"` + Author string `json:"author" binding:"required"` + Key string `json:"key" binding:"required"` + Point []int64 `json:"point" binding:"required"` } // ReplyCommentRequest 回复评论请求参数 type ReplyCommentRequest struct { - Content string `json:"content" binding:"required"` - Images []string `json:"images"` - UserID string `json:"user_id" binding:"required"` - TopicId string `json:"topic_id" binding:"required"` - ReplyId int64 `json:"reply_id" binding:"required"` - ReplyUser string `json:"reply_user" binding:"required"` - Author string `json:"author" binding:"required"` - Key string `json:"key" binding:"required"` - Point []int64 `json:"point" binding:"required"` + Content string `json:"content" binding:"required"` + Images []string `json:"images"` + // UserID string `json:"user_id" binding:"required"` + TopicId string `json:"topic_id" binding:"required"` + ReplyId int64 `json:"reply_id" binding:"required"` + ReplyUser string `json:"reply_user" binding:"required"` + Author string `json:"author" binding:"required"` + Key string `json:"key" binding:"required"` + Point []int64 `json:"point" binding:"required"` } // ReplyReplyRequest 回复回复请求参数 type ReplyReplyRequest struct { - Content string `json:"content" binding:"required"` - Images []string `json:"images"` - UserID string `json:"user_id" binding:"required"` - TopicId string `json:"topic_id" binding:"required"` - ReplyTo int64 `json:"reply_to" binding:"required"` - ReplyId int64 `json:"reply_id" binding:"required"` - ReplyUser string `json:"reply_user" binding:"required"` - Author string `json:"author" binding:"required"` - Key string `json:"key" binding:"required"` - Point []int64 `json:"point" binding:"required"` + Content string `json:"content" binding:"required"` + Images []string `json:"images"` + // UserID string `json:"user_id" binding:"required"` + TopicId string `json:"topic_id" binding:"required"` + ReplyTo int64 `json:"reply_to" binding:"required"` + ReplyId int64 `json:"reply_id" binding:"required"` + ReplyUser string `json:"reply_user" binding:"required"` + Author string `json:"author" binding:"required"` + Key string `json:"key" binding:"required"` + Point []int64 `json:"point" binding:"required"` } // CommentListRequest 评论列表请求参数 type CommentListRequest struct { - UserID string `json:"user_id" binding:"required"` + // UserID string `json:"user_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"` Page int `json:"page" default:"1"` Size int `json:"size" default:"5"` @@ -49,7 +49,7 @@ type CommentListRequest struct { // ReplyListRequest 回复列表请求参数 type ReplyListRequest struct { - UserID string `json:"user_id" binding:"required"` + // UserID string `json:"user_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"` CommentId int64 `json:"comment_id" binding:"required"` Page int `json:"page" default:"1"` @@ -60,5 +60,5 @@ type ReplyListRequest struct { type CommentLikeRequest struct { TopicId string `json:"topic_id" binding:"required"` CommentId int64 `json:"comment_id" binding:"required"` - UserID string `json:"user_id" binding:"required"` + // UserID string `json:"user_id" binding:"required"` } diff --git a/controller/oauth_controller/oauth.go b/controller/oauth_controller/oauth.go index 1b75cca..5c5686e 100644 --- a/controller/oauth_controller/oauth.go +++ b/controller/oauth_controller/oauth.go @@ -3,16 +3,19 @@ package oauth_controller import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/result" + "schisandra-cloud-album/common/types" "schisandra-cloud-album/global" "schisandra-cloud-album/service/impl" "schisandra-cloud-album/utils" - "sync" - "time" ) type OAuthController struct{} @@ -36,7 +39,6 @@ func HandleLoginResponse(c *gin.Context, uid string) { user := userService.QueryUserByUuidService(&uid) var accessToken, refreshToken string - var expiresAt int64 var err error var wg sync.WaitGroup var accessTokenErr error @@ -52,7 +54,7 @@ func HandleLoginResponse(c *gin.Context, uid string) { // 使用goroutine生成refreshToken go func() { defer wg.Done() // 完成时减少计数器 - refreshToken, expiresAt = utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &uid}, time.Hour*24*7) + refreshToken = utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &uid}, time.Hour*24*7) }() // 等待两个协程完成 @@ -64,29 +66,28 @@ func HandleLoginResponse(c *gin.Context, uid string) { return } - data := ResponseData{ - AccessToken: accessToken, - RefreshToken: refreshToken, - ExpiresAt: expiresAt, - UID: &uid, - UserInfo: UserInfo{ - Username: user.Username, - Nickname: user.Nickname, - Avatar: user.Avatar, - Email: user.Email, - Phone: user.Phone, - Gender: user.Gender, - Status: user.Status, - CreateAt: *user.CreatedTime, - }, + data := types.ResponseData{ + AccessToken: accessToken, + UID: &uid, + Username: user.Username, + Nickname: user.Nickname, + Avatar: user.Avatar, + Status: user.Status, } - - if err = utils.SetSession(c, constant.SessionKey, data); err != nil { + // 设置session + sessionData := utils.SessionData{ + RefreshToken: refreshToken, + UID: uid, + } + if err = utils.SetSession(c, constant.SessionKey, sessionData); err != nil { return } - + redisTokenData := types.RedisToken{ + AccessToken: accessToken, + UID: uid, + } // 将数据存入redis - if err = redis.Set(constant.UserLoginTokenRedisKey+uid, data, time.Hour*24*7).Err(); err != nil { + if err = redis.Set(constant.UserLoginTokenRedisKey+uid, redisTokenData, time.Hour*24*7).Err(); err != nil { global.LOG.Error(err) return } diff --git a/controller/oauth_controller/request_param.go b/controller/oauth_controller/request_param.go deleted file mode 100644 index 8c86cd4..0000000 --- a/controller/oauth_controller/request_param.go +++ /dev/null @@ -1,33 +0,0 @@ -package oauth_controller - -import ( - "encoding/json" - "time" -) - -// ResponseData 返回数据 -type ResponseData struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresAt int64 `json:"expires_at"` - UID *string `json:"uid"` - UserInfo UserInfo `json:"user_info"` -} -type UserInfo struct { - Username string `json:"username,omitempty"` - Nickname string `json:"nickname"` - Avatar string `json:"avatar"` - Phone string `json:"phone,omitempty"` - Email string `json:"email,omitempty"` - Gender string `json:"gender"` - Status int64 `json:"status"` - CreateAt time.Time `json:"create_at"` -} - -func (res ResponseData) MarshalBinary() ([]byte, error) { - return json.Marshal(res) -} - -func (res ResponseData) UnmarshalBinary(data []byte) error { - return json.Unmarshal(data, &res) -} diff --git a/controller/oauth_controller/wechat_controller.go b/controller/oauth_controller/wechat_controller.go index 22f0553..e9ce2cb 100644 --- a/controller/oauth_controller/wechat_controller.go +++ b/controller/oauth_controller/wechat_controller.go @@ -1,9 +1,12 @@ package oauth_controller import ( - "encoding/gob" "encoding/json" "errors" + "strconv" + "strings" + "time" + "github.com/ArtisanCloud/PowerLibs/v3/http/helper" "github.com/ArtisanCloud/PowerWeChat/v3/src/basicService/qrCode/response" "github.com/ArtisanCloud/PowerWeChat/v3/src/kernel/contract" @@ -14,18 +17,17 @@ import ( "github.com/gin-gonic/gin" "github.com/yitter/idgenerator-go/idgen" "gorm.io/gorm" + "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/enum" "schisandra-cloud-album/common/randomname" "schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/result" + "schisandra-cloud-album/common/types" "schisandra-cloud-album/controller/websocket_controller/qr_ws_controller" "schisandra-cloud-album/global" "schisandra-cloud-album/model" "schisandra-cloud-album/utils" - "strconv" - "strings" - "time" ) // CallbackNotify 微信回调 @@ -246,24 +248,20 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool { resultChan <- false return } - refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7) - data := ResponseData{ - AccessToken: accessToken, - RefreshToken: refreshToken, - ExpiresAt: expiresAt, - UID: &userId, - UserInfo: UserInfo{ - Username: user.Username, - Nickname: user.Nickname, - Avatar: user.Avatar, - Gender: user.Gender, - Phone: user.Phone, - Email: user.Email, - CreateAt: *user.CreatedTime, - Status: user.Status, - }, + refreshToken := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7) + data := types.ResponseData{ + AccessToken: accessToken, + UID: &userId, + Username: user.Username, + Nickname: user.Nickname, + Avatar: user.Avatar, + Status: user.Status, } - fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err() + redisTokenData := types.RedisToken{ + AccessToken: accessToken, + UID: userId, + } + fail := redis.Set(constant.UserLoginTokenRedisKey+userId, redisTokenData, time.Hour*24*7).Err() if fail != nil { resultChan <- false return @@ -279,8 +277,11 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool { resultChan <- false return } - gob.Register(ResponseData{}) - wrong := utils.SetSession(c, constant.SessionKey, data) + sessionData := utils.SessionData{ + RefreshToken: refreshToken, + UID: userId, + } + wrong := utils.SetSession(c, constant.SessionKey, sessionData) if wrong != nil { resultChan <- false return diff --git a/controller/user_controller/request_param.go b/controller/user_controller/request_param.go index 49a74c8..0a07ae4 100644 --- a/controller/user_controller/request_param.go +++ b/controller/user_controller/request_param.go @@ -1,10 +1,5 @@ package user_controller -// RefreshTokenRequest 刷新token请求 -type RefreshTokenRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` -} - // PhoneLoginRequest 手机号登录请求 type PhoneLoginRequest struct { Phone string `json:"phone" binding:"required"` diff --git a/controller/user_controller/user_controller.go b/controller/user_controller/user_controller.go index 0c760ad..8a3396e 100644 --- a/controller/user_controller/user_controller.go +++ b/controller/user_controller/user_controller.go @@ -4,7 +4,6 @@ import ( "errors" "reflect" "strconv" - "sync" ginI18n "github.com/gin-contrib/i18n" "github.com/gin-gonic/gin" @@ -25,7 +24,6 @@ import ( type UserController struct{} -var mu sync.Mutex var userService = impl.UserServiceImpl{} var userDeviceService = impl.UserDeviceServiceImpl{} @@ -244,12 +242,8 @@ func (UserController) PhoneLogin(c *gin.Context) { // @Success 200 {string} json // @Router /controller/token/refresh [post] func (UserController) RefreshHandler(c *gin.Context) { - request := RefreshTokenRequest{} - if err := c.ShouldBindJSON(&request); err != nil { - global.LOG.Error(err) - return - } - data, res := userService.RefreshTokenService(request.RefreshToken) + session := utils.GetSession(c, constant.SessionKey) + data, res := userService.RefreshTokenService(c, session.RefreshToken) if !res { result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "LoginExpired"), c) return @@ -343,7 +337,7 @@ func (UserController) ResetPassword(c *gin.Context) { // @Success 200 {string} json // @Router /controller/auth/user/logout [post] func (UserController) Logout(c *gin.Context) { - userId := c.Query("user_id") + userId := utils.GetSession(c, constant.SessionKey).UID if userId == "" { global.LOG.Errorln("userId is empty") result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) @@ -371,7 +365,7 @@ func (UserController) Logout(c *gin.Context) { // GetUserLoginDevice 获取用户登录设备 func (UserController) GetUserLoginDevice(c *gin.Context) { - userId := c.Query("user_id") + userId := utils.GetSession(c, constant.SessionKey).UID if userId == "" { return } diff --git a/core/session.go b/core/session.go index 1b80147..2444f91 100644 --- a/core/session.go +++ b/core/session.go @@ -2,10 +2,12 @@ package core import ( "context" + "net/http" + "github.com/gorilla/sessions" "github.com/rbcervilla/redisstore/v9" "github.com/redis/go-redis/v9" - "net/http" + "schisandra-cloud-album/common/constant" "schisandra-cloud-album/global" ) @@ -20,7 +22,7 @@ func InitSession(client *redis.Client) { store.KeyPrefix(constant.UserSessionRedisKey) store.Options(sessions.Options{ Path: "/", - //Domain: global.CONFIG.System.Web, + // Domain: global.CONFIG.System.Web, MaxAge: 86400 * 7, HttpOnly: true, Secure: true, diff --git a/middleware/casbin.go b/middleware/casbin.go index b5ac65f..c6476a0 100644 --- a/middleware/casbin.go +++ b/middleware/casbin.go @@ -10,7 +10,7 @@ import ( func CasbinMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - userIdAny, exists := c.Get("userId") + userIdAny, exists := c.Get("user_id") if !exists { global.LOG.Error("casbin middleware: userId not found") result.FailWithMessage(ginI18n.MustGetMessage(c, "PermissionDenied"), c) diff --git a/middleware/jwt.go b/middleware/jwt.go index 3091331..2667a91 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -10,17 +10,11 @@ import ( "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/result" + "schisandra-cloud-album/common/types" "schisandra-cloud-album/global" "schisandra-cloud-album/utils" ) -type TokenData struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresAt int64 `json:"expires_at"` - UID *string `json:"uid"` -} - func JWTAuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 默认Token放在请求头Authorization的Bearer中,并以空格隔开 @@ -51,7 +45,7 @@ func JWTAuthMiddleware() gin.HandlerFunc { c.Abort() return } - tokenResult := TokenData{} + tokenResult := types.RedisToken{} err = json.Unmarshal([]byte(token), &tokenResult) if err != nil { result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) @@ -63,7 +57,13 @@ func JWTAuthMiddleware() gin.HandlerFunc { c.Abort() return } - c.Set("userId", parseToken.UserID) + uid := utils.GetSession(c, constant.SessionKey).UID + if uid != *parseToken.UserID { + result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) + c.Abort() + return + } + c.Set("user_id", parseToken.UserID) global.DB.Set("user_id", parseToken.UserID) // 全局变量中设置用户ID c.Next() } diff --git a/middleware/session_check.go b/middleware/session_check.go deleted file mode 100644 index 857811e..0000000 --- a/middleware/session_check.go +++ /dev/null @@ -1,41 +0,0 @@ -package middleware - -import ( - ginI18n "github.com/gin-contrib/i18n" - "github.com/gin-gonic/gin" - - "schisandra-cloud-album/common/constant" - "schisandra-cloud-album/common/result" - "schisandra-cloud-album/utils" -) - -// SessionCheckMiddleware session检查中间件 -func SessionCheckMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - session := utils.GetSession(c, constant.SessionKey) - if session == nil { - result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) - c.Abort() - return - } - - userIdAny, exists := c.Get("userId") - if !exists { - result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) - c.Abort() - return - } - userId, ok := userIdAny.(*string) - if !ok { - result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) - c.Abort() - return - } - if *userId != *session.UID { - result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) - c.Abort() - return - } - c.Next() - } -} diff --git a/router/router.go b/router/router.go index b3718c0..b5bbe40 100644 --- a/router/router.go +++ b/router/router.go @@ -52,7 +52,6 @@ func InitRouter() *gin.Engine { middleware.SecurityHeaders(), middleware.JWTAuthMiddleware(), middleware.CasbinMiddleware(), - middleware.SessionCheckMiddleware(), middleware.VerifySignature(), ) { diff --git a/service/impl/user_service_impl.go b/service/impl/user_service_impl.go index bf62085..481a0be 100644 --- a/service/impl/user_service_impl.go +++ b/service/impl/user_service_impl.go @@ -1,20 +1,21 @@ package impl import ( - "encoding/gob" - "encoding/json" "errors" + "sync" + "time" + "github.com/gin-gonic/gin" "github.com/mssola/useragent" "gorm.io/gorm" + "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/redis" + "schisandra-cloud-album/common/types" "schisandra-cloud-album/dao/impl" "schisandra-cloud-album/global" "schisandra-cloud-album/model" "schisandra-cloud-album/utils" - "sync" - "time" ) var userDao = impl.UserDaoImpl{} @@ -23,33 +24,6 @@ type UserServiceImpl struct{} var mu = &sync.Mutex{} -// ResponseData 返回数据 -type ResponseData struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresAt int64 `json:"expires_at"` - UID *string `json:"uid"` - UserInfo UserInfo `json:"user_info"` -} -type UserInfo struct { - Username string `json:"username,omitempty"` - Nickname string `json:"nickname"` - Avatar string `json:"avatar"` - Phone string `json:"phone,omitempty"` - Email string `json:"email,omitempty"` - Gender string `json:"gender"` - Status int64 `json:"status"` - CreateAt time.Time `json:"create_at"` -} - -func (res ResponseData) MarshalBinary() ([]byte, error) { - return json.Marshal(res) -} - -func (res ResponseData) UnmarshalBinary(data []byte) error { - return json.Unmarshal(data, &res) -} - // GetUserListService 返回用户列表 func (UserServiceImpl) GetUserListService() []*model.ScaAuthUser { return userDao.GetUserList() @@ -95,36 +69,34 @@ func (UserServiceImpl) UpdateUserService(phone, encrypt string) error { } // RefreshTokenService 刷新用户token -func (UserServiceImpl) RefreshTokenService(refreshToken string) (*ResponseData, bool) { +func (UserServiceImpl) RefreshTokenService(c *gin.Context, refreshToken string) (string, bool) { parseRefreshToken, isUpd, err := utils.ParseRefreshToken(refreshToken) if err != nil || !isUpd { global.LOG.Errorln(err) - return nil, false + return "", false } accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID}) if err != nil { - return nil, false + return "", false } tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID - token, err := redis.Get(tokenKey).Result() - if err != nil || token == "" { + session := utils.GetSession(c, constant.SessionKey) + if session.RefreshToken == "" { + return "", false + } + redisTokenData := types.RedisToken{ + AccessToken: accessTokenString, + UID: *parseRefreshToken.UserID, + } + if err = redis.Set(tokenKey, redisTokenData, time.Hour*24*7).Err(); err != nil { global.LOG.Errorln(err) - return nil, false + return "", false } - data := ResponseData{ - AccessToken: accessTokenString, - RefreshToken: refreshToken, - UID: parseRefreshToken.UserID, - } - if err = redis.Set(tokenKey, data, time.Hour*24*7).Err(); err != nil { - global.LOG.Errorln(err) - return nil, false - } - return &data, true + return accessTokenString, true } // HandelUserLogin 处理用户登录 -func (UserServiceImpl) HandelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) (*ResponseData, bool) { +func (UserServiceImpl) HandelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) (*types.ResponseData, bool) { // 检查 user.UID 是否为 nil if user.UID == "" { return nil, false @@ -143,30 +115,28 @@ func (UserServiceImpl) HandelUserLogin(user model.ScaAuthUser, autoLogin bool, c days = time.Minute * 30 } - refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &user.UID}, days) - data := ResponseData{ - AccessToken: accessToken, - RefreshToken: refreshToken, - ExpiresAt: expiresAt, - UID: &user.UID, - UserInfo: UserInfo{ - Username: user.Username, - Nickname: user.Nickname, - Avatar: user.Avatar, - Phone: user.Phone, - Email: user.Email, - Gender: user.Gender, - Status: user.Status, - CreateAt: *user.CreatedTime, - }, + refreshToken := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &user.UID}, days) + data := types.ResponseData{ + AccessToken: accessToken, + UID: &user.UID, + Username: user.Username, + Nickname: user.Nickname, + Avatar: user.Avatar, + Status: user.Status, } - - err = redis.Set(constant.UserLoginTokenRedisKey+user.UID, data, days).Err() + redisTokenData := types.RedisToken{ + AccessToken: accessToken, + UID: user.UID, + } + err = redis.Set(constant.UserLoginTokenRedisKey+user.UID, redisTokenData, days).Err() if err != nil { return nil, false } - gob.Register(ResponseData{}) - err = utils.SetSession(c, constant.SessionKey, data) + sessionData := utils.SessionData{ + RefreshToken: refreshToken, + UID: user.UID, + } + err = utils.SetSession(c, constant.SessionKey, sessionData) if err != nil { return nil, false } diff --git a/utils/jwt.go b/utils/jwt.go index 17e5866..05904b0 100644 --- a/utils/jwt.go +++ b/utils/jwt.go @@ -33,7 +33,7 @@ func GenerateAccessToken(payload AccessJWTPayload) (string, error) { claims := AccessJWTClaims{ AccessJWTPayload: payload, RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 30)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 15)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()), }, @@ -52,7 +52,7 @@ func GenerateAccessToken(payload AccessJWTPayload) (string, error) { } // GenerateRefreshToken generates a JWT token with the given payload, and returns the accessToken and refreshToken -func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) (string, int64) { +func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) string { MySecret = []byte(global.CONFIG.JWT.Secret) refreshClaims := RefreshJWTClaims{ RefreshJWTPayload: payload, @@ -67,14 +67,14 @@ func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) (string refreshTokenString, err := refreshToken.SignedString(MySecret) if err != nil { global.LOG.Error(err) - return "", 0 + return "" } // refreshTokenEncrypted, err := aes.AesCtrEncryptHex([]byte(refreshTokenString), []byte(global.CONFIG.Encrypt.Key), []byte(global.CONFIG.Encrypt.IV)) // if err != nil { // fmt.Println(err) // return "", 0 // } - return refreshTokenString, refreshClaims.ExpiresAt.Time.Unix() + return refreshTokenString } // ParseAccessToken parses a JWT token and returns the payload diff --git a/utils/session.go b/utils/session.go index d1790e2..61de7c0 100644 --- a/utils/session.go +++ b/utils/session.go @@ -1,35 +1,23 @@ package utils import ( + "encoding/gob" "encoding/json" - "time" "github.com/gin-gonic/gin" "schisandra-cloud-album/global" ) -// ResponseData 返回数据 -type ResponseData struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresAt int64 `json:"expires_at"` - UID *string `json:"uid"` - UserInfo UserInfo `json:"user_info"` -} -type UserInfo struct { - Username string `json:"username,omitempty"` - Nickname string `json:"nickname"` - Avatar string `json:"avatar"` - Phone string `json:"phone,omitempty"` - Email string `json:"email,omitempty"` - Gender string `json:"gender"` - Status int64 `json:"status"` - CreateAt time.Time `json:"create_at"` +// SessionData 返回数据 +type SessionData struct { + RefreshToken string `json:"refresh_token"` + UID string `json:"uid"` } // SetSession sets session data with key and data -func SetSession(c *gin.Context, key string, data interface{}) error { +func SetSession(c *gin.Context, key string, data SessionData) error { + gob.Register(SessionData{}) session, err := global.Session.Get(c.Request, key) if err != nil { global.LOG.Error("SetSession failed: ", err) @@ -50,24 +38,24 @@ func SetSession(c *gin.Context, key string, data interface{}) error { } // GetSession gets session data with key -func GetSession(c *gin.Context, key string) *ResponseData { +func GetSession(c *gin.Context, key string) SessionData { session, err := global.Session.Get(c.Request, key) if err != nil { global.LOG.Error("GetSession failed: ", err) - return nil + return SessionData{} } jsonData, ok := session.Values[key] if !ok { global.LOG.Error("GetSession failed: ", "key not found") - return nil + return SessionData{} } - data := ResponseData{} + data := SessionData{} err = json.Unmarshal(jsonData.([]byte), &data) if err != nil { global.LOG.Error("GetSession failed: ", err) - return nil + return SessionData{} } - return &data + return data } // DelSession deletes session data with key