From 372835296d430d7889b9e1b007cbc11a36701389 Mon Sep 17 00:00:00 2001 From: landaiqing <3517283258@qq.com> Date: Thu, 26 Sep 2024 01:13:06 +0800 Subject: [PATCH] :sparkles: add comment verification --- api/captcha_api/captcha_api.go | 120 ++++++++++++++--------------- api/comment_api/comment_api.go | 33 ++++++-- api/comment_api/dto/request_dto.go | 7 ++ api/oauth_api/oauth.go | 25 +++--- api/oauth_api/wechat_api.go | 19 +++-- api/user_api/user_api.go | 2 + common/constant/redis_key.go | 6 ++ core/captcha.go | 1 + i18n/language/en.toml | 4 +- i18n/language/zh.toml | 4 +- router/modules/captcha_router.go | 7 ++ router/router.go | 9 ++- utils/check_captcha.go | 66 ++++++++++++++++ utils/session.go | 23 ++++-- 14 files changed, 228 insertions(+), 98 deletions(-) create mode 100644 utils/check_captcha.go diff --git a/api/captcha_api/captcha_api.go b/api/captcha_api/captcha_api.go index 7c02e52..7d2fdac 100644 --- a/api/captcha_api/captcha_api.go +++ b/api/captcha_api/captcha_api.go @@ -239,12 +239,12 @@ func (CaptchaAPI) GenerateClickShapeCaptcha(c *gin.Context) { result.OkWithData(bt, c) } -// GenerateSlideBasicCaptData 生成点击形状基础验证码 -// @Summary 生成点击形状基础验证码 -// @Description 生成点击形状基础验证码 -// @Tags 点击形状验证码 +// GenerateSlideBasicCaptData 滑块基础验证码 +// @Summary 滑块基础验证码 +// @Description 滑块基础验证码 +// @Tags 滑块基础验证码 // @Success 200 {string} json -// @Router /api/captcha/shape/check [get] +// @Router /api/captcha/slide/generate [get] func (CaptchaAPI) GenerateSlideBasicCaptData(c *gin.Context) { captData, err := global.SlideCaptcha.Generate() if err != nil { @@ -266,70 +266,27 @@ func (CaptchaAPI) GenerateSlideBasicCaptData(c *gin.Context) { return } key := helper.StringToMD5(string(dotsByte)) - err = redis.Set(key, dotsByte, time.Minute).Err() + err = redis.Set(constant.CommentSubmitCaptchaRedisKey+key, dotsByte, time.Minute).Err() if err != nil { result.FailWithNull(c) return } bt := map[string]interface{}{ - "key": key, - "image": masterImageBase64, - "tile": tileImageBase64, - "tile_width": blockData.Width, - "tile_height": blockData.Height, - "tile_x": blockData.TileX, - "tile_y": blockData.TileY, + "key": key, + "image": masterImageBase64, + "thumb": tileImageBase64, + "thumb_width": blockData.Width, + "thumb_height": blockData.Height, + "thumb_x": blockData.TileX, + "thumb_y": blockData.TileY, } result.OkWithData(bt, c) } -// CheckSlideData 验证点击形状验证码 -// @Summary 验证点击形状验证码 -// @Description 验证点击形状验证码 -// @Tags 点击形状验证码 -// @Param point query string true "点击坐标" -// @Param key query string true "验证码key" -// @Success 200 {string} json -// @Router /api/captcha/shape/slide/check [get] -func (CaptchaAPI) CheckSlideData(c *gin.Context) { - point := c.Query("point") - key := c.Query("key") - if point == "" || key == "" { - result.FailWithNull(c) - return - } - - cacheDataByte, err := redis.Get(key).Bytes() - if len(cacheDataByte) == 0 || err != nil { - result.FailWithNull(c) - return - } - src := strings.Split(point, ",") - - var dct *slide.Block - if err := json.Unmarshal(cacheDataByte, &dct); err != nil { - result.FailWithNull(c) - return - } - - chkRet := false - if 2 == len(src) { - sx, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[0]), 64) - sy, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[1]), 64) - chkRet = slide.CheckPoint(int64(sx), int64(sy), int64(dct.X), int64(dct.Y), 4) - } - - if chkRet { - result.OkWithMessage("success", c) - return - } - result.FailWithMessage("fail", c) -} - -// GenerateSlideRegionCaptData 生成点击形状验证码 -// @Summary 生成点击形状验证码 -// @Description 生成点击形状验证码 -// @Tags 点击形状验证码 +// GenerateSlideRegionCaptData 生成滑动区域形状验证码 +// @Summary 生成滑动区域形状验证码 +// @Description 生成滑动区域形状验证码 +// @Tags 生成滑动区域形状验证码 // @Success 200 {string} json // @Router /api/captcha/shape/slide/region/get [get] func (CaptchaAPI) GenerateSlideRegionCaptData(c *gin.Context) { @@ -371,3 +328,46 @@ func (CaptchaAPI) GenerateSlideRegionCaptData(c *gin.Context) { } result.OkWithData(bt, c) } + +// CheckSlideData 验证滑动验证码 +// @Summary 验证滑动验证码 +// @Description 验证滑动验证码 +// @Tags 验证滑动验证码 +// @Param point query string true "点击坐标" +// @Param key query string true "验证码key" +// @Success 200 {string} json +// @Router /api/captcha/shape/slide/check [get] +func (CaptchaAPI) CheckSlideData(c *gin.Context) { + point := c.Query("point") + key := c.Query("key") + if point == "" || key == "" { + result.FailWithNull(c) + return + } + + cacheDataByte, err := redis.Get(key).Bytes() + if len(cacheDataByte) == 0 || err != nil { + result.FailWithNull(c) + return + } + src := strings.Split(point, ",") + + var dct *slide.Block + if err := json.Unmarshal(cacheDataByte, &dct); err != nil { + result.FailWithNull(c) + return + } + + chkRet := false + if 2 == len(src) { + sx, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[0]), 64) + sy, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[1]), 64) + chkRet = slide.CheckPoint(int64(sx), int64(sy), int64(dct.X), int64(dct.Y), 4) + } + + if chkRet { + result.OkWithMessage("success", c) + return + } + result.FailWithMessage("fail", c) +} diff --git a/api/comment_api/comment_api.go b/api/comment_api/comment_api.go index 17cc3e5..78ac5eb 100644 --- a/api/comment_api/comment_api.go +++ b/api/comment_api/comment_api.go @@ -30,9 +30,15 @@ import ( func (CommentAPI) CommentSubmit(c *gin.Context) { commentRequest := dto.CommentRequest{} if err := c.ShouldBindJSON(&commentRequest); err != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } + // 验证校验 + res := utils.CheckSlideData(commentRequest.Point, commentRequest.Key) + if !res { + result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c) + return + } + if len(commentRequest.Images) > 3 { result.FailWithMessage(ginI18n.MustGetMessage(c, "TooManyImages"), c) return @@ -148,7 +154,12 @@ func (CommentAPI) ReplySubmit(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - + // 验证校验 + res := utils.CheckSlideData(replyCommentRequest.Point, replyCommentRequest.Key) + if !res { + result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c) + return + } if len(replyCommentRequest.Images) > 3 { result.FailWithMessage(ginI18n.MustGetMessage(c, "TooManyImages"), c) return @@ -270,7 +281,12 @@ func (CommentAPI) ReplyReplySubmit(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - + // 验证校验 + res := utils.CheckSlideData(replyReplyRequest.Point, replyReplyRequest.Key) + if !res { + result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c) + return + } if len(replyReplyRequest.Images) > 3 { result.FailWithMessage(ginI18n.MustGetMessage(c, "TooManyImages"), c) return @@ -394,7 +410,12 @@ func (CommentAPI) CommentList(c *gin.Context) { // 查询评论列表 query, u := gplus.NewQuery[model.ScaCommentReply]() page := gplus.NewPage[model.ScaCommentReply](commentListRequest.Page, commentListRequest.Size) - query.Eq(&u.TopicId, commentListRequest.TopicId).Eq(&u.CommentType, enum.COMMENT).OrderByDesc(&u.CommentOrder).OrderByDesc(&u.Likes).OrderByDesc(&u.ReplyCount).OrderByDesc(&u.CreatedTime) + if commentListRequest.IsHot { + query.OrderByDesc(&u.CommentOrder).OrderByDesc(&u.Likes).OrderByDesc(&u.ReplyCount) + } else { + query.OrderByDesc(&u.CommentOrder).OrderByDesc(&u.CreatedTime) + } + query.Eq(&u.TopicId, commentListRequest.TopicId).Eq(&u.CommentType, enum.COMMENT) page, pageDB := gplus.SelectPage(page, query) if pageDB.Error != nil { global.LOG.Errorln(pageDB.Error) @@ -455,7 +476,7 @@ func (CommentAPI) CommentList(c *gin.Context) { // 查询评论图片信息 go func() { defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) // 设置超时,2秒 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // 设置超时,2秒 defer cancel() cursor, err := global.MongoDB.Database(global.CONFIG.MongoDB.DB).Collection("comment_images").Find(ctx, bson.M{"comment_id": bson.M{"$in": commentIds}}) @@ -563,7 +584,7 @@ func (CommentAPI) ReplyList(c *gin.Context) { } query, u := gplus.NewQuery[model.ScaCommentReply]() page := gplus.NewPage[model.ScaCommentReply](replyListRequest.Page, replyListRequest.Size) - query.Eq(&u.TopicId, replyListRequest.TopicId).Eq(&u.ReplyId, replyListRequest.CommentId).Eq(&u.CommentType, enum.REPLY).OrderByDesc(&u.CommentOrder).OrderByDesc(&u.Likes).OrderByDesc(&u.CreatedTime) + query.Eq(&u.TopicId, replyListRequest.TopicId).Eq(&u.ReplyId, replyListRequest.CommentId).Eq(&u.CommentType, enum.REPLY).OrderByDesc(&u.Likes).OrderByAsc(&u.CreatedTime) page, pageDB := gplus.SelectPage(page, query) if pageDB.Error != nil { global.LOG.Errorln(pageDB.Error) diff --git a/api/comment_api/dto/request_dto.go b/api/comment_api/dto/request_dto.go index dba1bb7..c2ad2a4 100644 --- a/api/comment_api/dto/request_dto.go +++ b/api/comment_api/dto/request_dto.go @@ -6,6 +6,8 @@ type CommentRequest struct { UserID string `json:"user_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"` Author string `json:"author" binding:"required"` + Key string `json:"key" binding:"required"` + Point []int64 `json:"point" binding:"required"` } type ReplyCommentRequest struct { Content string `json:"content" binding:"required"` @@ -15,6 +17,8 @@ type ReplyCommentRequest struct { ReplyId int64 `json:"reply_id" binding:"required"` ReplyUser string `json:"reply_user" binding:"required"` Author string `json:"author" binding:"required"` + Key string `json:"key" binding:"required"` + Point []int64 `json:"point" binding:"required"` } type ReplyReplyRequest struct { @@ -26,6 +30,8 @@ type ReplyReplyRequest struct { ReplyId int64 `json:"reply_id" binding:"required"` ReplyUser string `json:"reply_user" binding:"required"` Author string `json:"author" binding:"required"` + Key string `json:"key" binding:"required"` + Point []int64 `json:"point" binding:"required"` } type CommentListRequest struct { @@ -33,6 +39,7 @@ type CommentListRequest struct { TopicId string `json:"topic_id" binding:"required"` Page int `json:"page" default:"1"` Size int `json:"size" default:"5"` + IsHot bool `json:"is_hot" default:"true"` } type ReplyListRequest struct { UserID string `json:"user_id" binding:"required"` diff --git a/api/oauth_api/oauth.go b/api/oauth_api/oauth.go index 0f19757..42435a0 100644 --- a/api/oauth_api/oauth.go +++ b/api/oauth_api/oauth.go @@ -11,6 +11,7 @@ import ( "schisandra-cloud-album/api/user_api/dto" "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/redis" + "schisandra-cloud-album/common/result" "schisandra-cloud-album/global" "schisandra-cloud-album/model" "schisandra-cloud-album/service" @@ -39,10 +40,11 @@ var script = ` ` func HandleLoginResponse(c *gin.Context, uid string) { - res, data := HandelUserLogin(uid) + res, data := HandelUserLogin(uid, c) if !res { return } + tokenData, err := json.Marshal(data) if err != nil { global.LOG.Error(err) @@ -55,7 +57,7 @@ func HandleLoginResponse(c *gin.Context, uid string) { } // HandelUserLogin 处理用户登录 -func HandelUserLogin(userId string) (bool, map[string]interface{}) { +func HandelUserLogin(userId string, c *gin.Context) (bool, result.Response) { // 使用goroutine生成accessToken accessTokenChan := make(chan string) errChan := make(chan error) @@ -86,7 +88,7 @@ func HandelUserLogin(userId string) (bool, map[string]interface{}) { case accessToken = <-accessTokenChan: case err = <-errChan: global.LOG.Error(err) - return false, nil + return false, result.Response{} } select { case refreshToken = <-refreshTokenChan: @@ -99,7 +101,10 @@ func HandelUserLogin(userId string) (bool, map[string]interface{}) { ExpiresAt: expiresAt, UID: &userId, } - + wrong := utils.SetSession(c, "user", data) + if wrong != nil { + return false, result.Response{} + } // 使用goroutine将数据存入redis redisErrChan := make(chan error) go func() { @@ -115,13 +120,13 @@ func HandelUserLogin(userId string) (bool, map[string]interface{}) { redisErr := <-redisErrChan if redisErr != nil { global.LOG.Error(redisErr) - return false, nil + return false, result.Response{} } - responseData := map[string]interface{}{ - "code": 0, - "message": "success", - "data": data, - "success": true, + responseData := result.Response{ + Data: data, + Message: "success", + Code: 200, + Success: true, } return true, responseData } diff --git a/api/oauth_api/wechat_api.go b/api/oauth_api/wechat_api.go index 39c0df4..bfbc1b6 100644 --- a/api/oauth_api/wechat_api.go +++ b/api/oauth_api/wechat_api.go @@ -1,6 +1,7 @@ package oauth_api import ( + "encoding/gob" "encoding/json" "errors" "github.com/ArtisanCloud/PowerLibs/v3/http/helper" @@ -47,7 +48,7 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) { return "error" } key := strings.TrimPrefix(msg.EventKey, "qrscene_") - res := wechatLoginHandler(msg.FromUserName, key) + res := wechatLoginHandler(msg.FromUserName, key, c) if !res { return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed")) } @@ -69,7 +70,7 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) { println(err.Error()) return "error" } - res := wechatLoginHandler(msg.FromUserName, msg.EventKey) + res := wechatLoginHandler(msg.FromUserName, msg.EventKey, c) if !res { return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed")) } @@ -165,7 +166,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) { } // wechatLoginHandler 微信登录处理 -func wechatLoginHandler(openId string, clientId string) bool { +func wechatLoginHandler(openId string, clientId string, c *gin.Context) bool { if openId == "" { return false } @@ -257,7 +258,7 @@ func wechatLoginHandler(openId string, clientId string) bool { // 异步处理用户登录 resChan := make(chan bool, 1) go func() { - res := handelUserLogin(*addUser.UID, clientId) + res := handelUserLogin(*addUser.UID, clientId, c) resChan <- res }() @@ -271,7 +272,7 @@ func wechatLoginHandler(openId string, clientId string) bool { tx.Commit() return true } else { - res := handelUserLogin(*authUserSocial.UserID, clientId) + res := handelUserLogin(*authUserSocial.UserID, clientId, c) if !res { return false } @@ -280,7 +281,7 @@ func wechatLoginHandler(openId string, clientId string) bool { } // handelUserLogin 处理用户登录 -func handelUserLogin(userId string, clientId string) bool { +func handelUserLogin(userId string, clientId string, c *gin.Context) bool { resultChan := make(chan bool, 1) go func() { @@ -312,6 +313,12 @@ func handelUserLogin(userId string, clientId string) bool { resultChan <- false return } + gob.Register(dto.ResponseData{}) + wrong := utils.SetSession(c, "user", data) + if wrong != nil { + resultChan <- false + return + } // gws方式发送消息 err = websocket_api.Handler.SendMessageToClient(clientId, tokenData) if err != nil { diff --git a/api/user_api/user_api.go b/api/user_api/user_api.go index 21e2cd9..dd36e05 100644 --- a/api/user_api/user_api.go +++ b/api/user_api/user_api.go @@ -1,6 +1,7 @@ package user_api import ( + "encoding/gob" "errors" ginI18n "github.com/gin-contrib/i18n" "github.com/gin-gonic/gin" @@ -332,6 +333,7 @@ func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c) return } + gob.Register(dto.ResponseData{}) err = utils.SetSession(c, "user", data) if err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c) diff --git a/common/constant/redis_key.go b/common/constant/redis_key.go index 624ec41..f139c2c 100644 --- a/common/constant/redis_key.go +++ b/common/constant/redis_key.go @@ -1,6 +1,7 @@ package constant const ( + // 登录相关的redis key UserLoginSmsRedisKey = "user:sms:" UserLoginTokenRedisKey = "user:token:" UserLoginCaptchaRedisKey = "user:captcha:" @@ -8,3 +9,8 @@ const ( UserLoginQrcodeRedisKey = "user:qrcode:" UserSessionRedisKey = "user:session:" ) + +// 登录之后 +const ( + CommentSubmitCaptchaRedisKey = "comment:submit:captcha:" +) diff --git a/core/captcha.go b/core/captcha.go index 005067b..b9c05fa 100644 --- a/core/captcha.go +++ b/core/captcha.go @@ -16,6 +16,7 @@ import ( func InitCaptcha() { initRotateCaptcha() + initsSlideCaptcha() } // initTextCaptcha 初始化点选验证码 diff --git a/i18n/language/en.toml b/i18n/language/en.toml index 007d2e2..9e22b73 100644 --- a/i18n/language/en.toml +++ b/i18n/language/en.toml @@ -1,4 +1,3 @@ -welcome = "hello" NotFoundUser = "User not found" DeletedSuccess = "deleted successfully!" DeletedFailed = "delete failed!" @@ -73,4 +72,5 @@ ImageSaveError = "image save error!" CommentLikeSuccess = "comment like success!" CommentLikeFailed = "comment like failed!" CommentDislikeSuccess = "comment dislike success!" -CommentDislikeFailed = "comment dislike failed!" \ No newline at end of file +CommentDislikeFailed = "comment dislike failed!" +CaptchaVerifyError = "captcha error!" \ No newline at end of file diff --git a/i18n/language/zh.toml b/i18n/language/zh.toml index dd5a0ae..dc0ec1f 100644 --- a/i18n/language/zh.toml +++ b/i18n/language/zh.toml @@ -1,4 +1,3 @@ -welcome = "欢迎" NotFoundUser = "未找到用户!" DeletedSuccess = "删除成功!" DeletedFailed = "删除失败!" @@ -73,4 +72,5 @@ ImageSaveError = "图片保存错误!" CommentLikeSuccess = "评论点赞成功!" CommentLikeFailed = "评论点赞失败!" CommentDislikeSuccess = "评论取消点赞成功!" -CommentDislikeFailed = "评论取消点赞失败!" \ No newline at end of file +CommentDislikeFailed = "评论取消点赞失败!" +CaptchaVerifyError = "验证失败!" diff --git a/router/modules/captcha_router.go b/router/modules/captcha_router.go index 18754dd..c97c12c 100644 --- a/router/modules/captcha_router.go +++ b/router/modules/captcha_router.go @@ -9,6 +9,13 @@ var captchaApi = api.Api.CaptchaApi func CaptchaRouter(router *gin.RouterGroup) { group := router.Group("/captcha") + group.GET("/rotate/get", captchaApi.GenerateRotateCaptcha) group.POST("/rotate/check", captchaApi.CheckRotateData) } + +// CaptchaRouterAuth 需要鉴权的路由 +func CaptchaRouterAuth(router *gin.RouterGroup) { + group := router.Group("/captcha") + group.GET("/slide/generate", captchaApi.GenerateSlideBasicCaptData) +} diff --git a/router/router.go b/router/router.go index 8062ba7..1e0b03c 100644 --- a/router/router.go +++ b/router/router.go @@ -52,10 +52,11 @@ func InitRouter() *gin.Engine { middleware.CasbinMiddleware(), ) { - modules.UserRouterAuth(authGroup) // 注册鉴权路由 - modules.RoleRouter(authGroup) // 注册角色路由 - modules.PermissionRouter(authGroup) // 注册权限路由 - modules.CommentRouter(authGroup) // 注册评论路由 + modules.UserRouterAuth(authGroup) // 注册鉴权路由 + modules.RoleRouter(authGroup) // 注册角色路由 + modules.PermissionRouter(authGroup) // 注册权限路由 + modules.CommentRouter(authGroup) // 注册评论路由 + modules.CaptchaRouterAuth(authGroup) // 注册验证码路由 } return router diff --git a/utils/check_captcha.go b/utils/check_captcha.go new file mode 100644 index 0000000..8efbbf9 --- /dev/null +++ b/utils/check_captcha.go @@ -0,0 +1,66 @@ +package utils + +import ( + "encoding/json" + "fmt" + "github.com/wenlng/go-captcha/v2/rotate" + "github.com/wenlng/go-captcha/v2/slide" + "schisandra-cloud-album/common/constant" + "schisandra-cloud-album/common/redis" + "schisandra-cloud-album/global" + "strconv" +) + +// CheckSlideData 校验滑动验证码 +func CheckSlideData(point []int64, key string) bool { + if point == nil || key == "" { + return false + } + cacheDataByte, err := redis.Get(constant.CommentSubmitCaptchaRedisKey + key).Bytes() + if len(cacheDataByte) == 0 || err != nil { + return false + } + var dct *slide.Block + if err = json.Unmarshal(cacheDataByte, &dct); err != nil { + return false + } + + chkRet := false + if 2 == len(point) { + sx, _ := strconv.ParseFloat(fmt.Sprintf("%v", point[0]), 64) + sy, _ := strconv.ParseFloat(fmt.Sprintf("%v", point[1]), 64) + chkRet = slide.CheckPoint(int64(sx), int64(sy), int64(dct.X), int64(dct.Y), 4) + } + if chkRet { + return true + } + return false +} + +// CheckRotateData 校验旋转验证码 +func CheckRotateData(angle string, key string) bool { + if angle == "" || key == "" { + return false + } + cacheDataByte, err := redis.Get(constant.UserLoginCaptchaRedisKey + key).Bytes() + if err != nil || len(cacheDataByte) == 0 { + global.LOG.Error(err) + return false + } + var dct *rotate.Block + if err = json.Unmarshal(cacheDataByte, &dct); err != nil { + global.LOG.Error(err) + return false + } + sAngle, err := strconv.ParseFloat(fmt.Sprintf("%v", angle), 64) + if err != nil { + global.LOG.Error(err) + return false + } + chkRet := rotate.CheckAngle(int64(sAngle), int64(dct.Angle), 2) + if chkRet { + return true + } + return false + +} diff --git a/utils/session.go b/utils/session.go index dd967e7..f5aac09 100644 --- a/utils/session.go +++ b/utils/session.go @@ -1,20 +1,25 @@ package utils import ( - "encoding/gob" "encoding/json" "github.com/gin-gonic/gin" + "schisandra-cloud-album/api/user_api/dto" "schisandra-cloud-album/global" ) +// SetSession sets session data with key and data func SetSession(c *gin.Context, key string, data interface{}) error { session, err := global.Session.Get(c.Request, key) if err != nil { global.LOG.Error("SetSession failed: ", err) return err } - gob.Register(data) - session.Values[key] = data + jsonData, err := json.Marshal(data) + if err != nil { + global.LOG.Error("SetSession failed: ", err) + return err + } + session.Values[key] = jsonData err = session.Save(c.Request, c.Writer) if err != nil { global.LOG.Error("SetSession failed: ", err) @@ -23,26 +28,28 @@ func SetSession(c *gin.Context, key string, data interface{}) error { return nil } -func GetSession(c *gin.Context, key string) interface{} { +// GetSession gets session data with key +func GetSession(c *gin.Context, key string) dto.ResponseData { session, err := global.Session.Get(c.Request, key) if err != nil { global.LOG.Error("GetSession failed: ", err) - return nil + return dto.ResponseData{} } jsonData, ok := session.Values[key] if !ok { global.LOG.Error("GetSession failed: ", "key not found") - return nil + return dto.ResponseData{} } - var data interface{} + var data dto.ResponseData err = json.Unmarshal(jsonData.([]byte), &data) if err != nil { global.LOG.Error("GetSession failed: ", err) - return nil + return dto.ResponseData{} } return data } +// DelSession deletes session data with key func DelSession(c *gin.Context, key string) { session, err := global.Session.Get(c.Request, key) if err != nil {