🎨 update

This commit is contained in:
landaiqing
2024-11-05 17:24:11 +08:00
parent a153e0345a
commit 0b22d9800c
16 changed files with 210 additions and 289 deletions

36
common/types/types.go Normal file
View File

@@ -0,0 +1,36 @@
package types
import (
"encoding/json"
)
// ResponseData 返回数据
type ResponseData struct {
AccessToken string `json:"access_token"`
UID *string `json:"uid"`
Username string `json:"username,omitempty"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Status int64 `json:"status"`
}
func (res ResponseData) MarshalBinary() ([]byte, error) {
return json.Marshal(res)
}
func (res ResponseData) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, &res)
}
type RedisToken struct {
AccessToken string `json:"access_token"`
UID string `json:"uid"`
}
func (res RedisToken) MarshalBinary() ([]byte, error) {
return json.Marshal(res)
}
func (res RedisToken) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, &res)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mssola/useragent" "github.com/mssola/useragent"
"schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/enum" "schisandra-cloud-album/common/enum"
"schisandra-cloud-album/common/result" "schisandra-cloud-album/common/result"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
@@ -62,7 +63,8 @@ func (CommentController) CommentSubmit(c *gin.Context) {
browser, _ := ua.Browser() browser, _ := ua.Browser()
operatingSystem := ua.OS() operatingSystem := ua.OS()
isAuthor := 0 isAuthor := 0
if commentRequest.UserID == commentRequest.Author { uid := utils.GetSession(c, constant.SessionKey).UID
if uid == commentRequest.Author {
isAuthor = 1 isAuthor = 1
} }
xssFilterContent := utils.XssFilter(commentRequest.Content) xssFilterContent := utils.XssFilter(commentRequest.Content)
@@ -71,9 +73,10 @@ func (CommentController) CommentSubmit(c *gin.Context) {
return return
} }
commentContent := global.SensitiveManager.Replace(xssFilterContent, '*') commentContent := global.SensitiveManager.Replace(xssFilterContent, '*')
commentReply := model.ScaCommentReply{ commentReply := model.ScaCommentReply{
Content: commentContent, Content: commentContent,
UserId: commentRequest.UserID, UserId: uid,
TopicId: commentRequest.TopicId, TopicId: commentRequest.TopicId,
TopicType: enum.CommentTopicType, TopicType: enum.CommentTopicType,
CommentType: enum.COMMENT, CommentType: enum.COMMENT,
@@ -84,7 +87,7 @@ func (CommentController) CommentSubmit(c *gin.Context) {
OperatingSystem: operatingSystem, OperatingSystem: operatingSystem,
Agent: userAgent, Agent: userAgent,
} }
commentId, response := commentReplyService.SubmitCommentService(&commentReply, commentRequest.TopicId, commentRequest.UserID, commentRequest.Images) commentId, response := commentReplyService.SubmitCommentService(&commentReply, commentRequest.TopicId, uid, commentRequest.Images)
if !response { if !response {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
return return
@@ -92,7 +95,7 @@ func (CommentController) CommentSubmit(c *gin.Context) {
responseData := model.ScaCommentReply{ responseData := model.ScaCommentReply{
Id: commentId, Id: commentId,
Content: commentContent, Content: commentContent,
UserId: commentRequest.UserID, UserId: uid,
TopicId: commentRequest.TopicId, TopicId: commentRequest.TopicId,
Author: isAuthor, Author: isAuthor,
Location: location, Location: location,
@@ -147,7 +150,8 @@ func (CommentController) ReplySubmit(c *gin.Context) {
browser, _ := ua.Browser() browser, _ := ua.Browser()
operatingSystem := ua.OS() operatingSystem := ua.OS()
isAuthor := 0 isAuthor := 0
if replyCommentRequest.UserID == replyCommentRequest.Author { uid := utils.GetSession(c, constant.SessionKey).UID
if uid == replyCommentRequest.Author {
isAuthor = 1 isAuthor = 1
} }
xssFilterContent := utils.XssFilter(replyCommentRequest.Content) xssFilterContent := utils.XssFilter(replyCommentRequest.Content)
@@ -158,7 +162,7 @@ func (CommentController) ReplySubmit(c *gin.Context) {
commentContent := global.SensitiveManager.Replace(xssFilterContent, '*') commentContent := global.SensitiveManager.Replace(xssFilterContent, '*')
commentReply := model.ScaCommentReply{ commentReply := model.ScaCommentReply{
Content: commentContent, Content: commentContent,
UserId: replyCommentRequest.UserID, UserId: uid,
TopicId: replyCommentRequest.TopicId, TopicId: replyCommentRequest.TopicId,
TopicType: enum.CommentTopicType, TopicType: enum.CommentTopicType,
CommentType: enum.REPLY, CommentType: enum.REPLY,
@@ -171,7 +175,7 @@ func (CommentController) ReplySubmit(c *gin.Context) {
OperatingSystem: operatingSystem, OperatingSystem: operatingSystem,
Agent: userAgent, Agent: userAgent,
} }
commentReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyCommentRequest.TopicId, replyCommentRequest.UserID, replyCommentRequest.Images) commentReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyCommentRequest.TopicId, uid, replyCommentRequest.Images)
if !response { if !response {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
return return
@@ -179,7 +183,7 @@ func (CommentController) ReplySubmit(c *gin.Context) {
responseData := model.ScaCommentReply{ responseData := model.ScaCommentReply{
Id: commentReplyId, Id: commentReplyId,
Content: commentContent, Content: commentContent,
UserId: replyCommentRequest.UserID, UserId: uid,
TopicId: replyCommentRequest.TopicId, TopicId: replyCommentRequest.TopicId,
ReplyId: replyCommentRequest.ReplyId, ReplyId: replyCommentRequest.ReplyId,
ReplyUser: replyCommentRequest.ReplyUser, ReplyUser: replyCommentRequest.ReplyUser,
@@ -236,7 +240,8 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) {
browser, _ := ua.Browser() browser, _ := ua.Browser()
operatingSystem := ua.OS() operatingSystem := ua.OS()
isAuthor := 0 isAuthor := 0
if replyReplyRequest.UserID == replyReplyRequest.Author { uid := utils.GetSession(c, constant.SessionKey).UID
if uid == replyReplyRequest.Author {
isAuthor = 1 isAuthor = 1
} }
xssFilterContent := utils.XssFilter(replyReplyRequest.Content) xssFilterContent := utils.XssFilter(replyReplyRequest.Content)
@@ -247,7 +252,7 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) {
commentContent := global.SensitiveManager.Replace(xssFilterContent, '*') commentContent := global.SensitiveManager.Replace(xssFilterContent, '*')
commentReply := model.ScaCommentReply{ commentReply := model.ScaCommentReply{
Content: commentContent, Content: commentContent,
UserId: replyReplyRequest.UserID, UserId: uid,
TopicId: replyReplyRequest.TopicId, TopicId: replyReplyRequest.TopicId,
TopicType: enum.CommentTopicType, TopicType: enum.CommentTopicType,
CommentType: enum.REPLY, CommentType: enum.REPLY,
@@ -261,7 +266,7 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) {
OperatingSystem: operatingSystem, OperatingSystem: operatingSystem,
Agent: userAgent, Agent: userAgent,
} }
commentReplyReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyReplyRequest.TopicId, replyReplyRequest.UserID, replyReplyRequest.Images) commentReplyReplyId, response := commentReplyService.SubmitCommentService(&commentReply, replyReplyRequest.TopicId, uid, replyReplyRequest.Images)
if !response { if !response {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
return return
@@ -269,7 +274,7 @@ func (CommentController) ReplyReplySubmit(c *gin.Context) {
responseData := model.ScaCommentReply{ responseData := model.ScaCommentReply{
Id: commentReplyReplyId, Id: commentReplyReplyId,
Content: commentContent, Content: commentContent,
UserId: replyReplyRequest.UserID, UserId: uid,
TopicId: replyReplyRequest.TopicId, TopicId: replyReplyRequest.TopicId,
ReplyTo: replyReplyRequest.ReplyTo, ReplyTo: replyReplyRequest.ReplyTo,
ReplyId: replyReplyRequest.ReplyId, ReplyId: replyReplyRequest.ReplyId,
@@ -299,7 +304,8 @@ func (CommentController) CommentList(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return return
} }
response := commentReplyService.GetCommentListService(commentListRequest.UserID, commentListRequest.TopicId, commentListRequest.Page, commentListRequest.Size, commentListRequest.IsHot) uid := utils.GetSession(c, constant.SessionKey).UID
response := commentReplyService.GetCommentListService(uid, commentListRequest.TopicId, commentListRequest.Page, commentListRequest.Size, commentListRequest.IsHot)
result.OkWithData(response, c) result.OkWithData(response, c)
return return
} }
@@ -319,7 +325,8 @@ func (CommentController) ReplyList(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return return
} }
response := commentReplyService.GetCommentReplyListService(replyListRequest.UserID, replyListRequest.TopicId, replyListRequest.CommentId, replyListRequest.Page, replyListRequest.Size) uid := utils.GetSession(c, constant.SessionKey).UID
response := commentReplyService.GetCommentReplyListService(uid, replyListRequest.TopicId, replyListRequest.CommentId, replyListRequest.Page, replyListRequest.Size)
result.OkWithData(response, c) result.OkWithData(response, c)
return return
} }
@@ -339,7 +346,8 @@ func (CommentController) CommentLikes(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return return
} }
res := commentReplyService.CommentLikeService(likeRequest.UserID, likeRequest.CommentId, likeRequest.TopicId) uid := utils.GetSession(c, constant.SessionKey).UID
res := commentReplyService.CommentLikeService(uid, likeRequest.CommentId, likeRequest.TopicId)
if !res { if !res {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c)
return return
@@ -362,7 +370,8 @@ func (CommentController) CancelCommentLikes(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return return
} }
res := commentReplyService.CommentDislikeService(cancelLikeRequest.UserID, cancelLikeRequest.CommentId, cancelLikeRequest.TopicId) uid := utils.GetSession(c, constant.SessionKey).UID
res := commentReplyService.CommentDislikeService(uid, cancelLikeRequest.CommentId, cancelLikeRequest.TopicId)
if !res { if !res {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentDislikeFailed"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentDislikeFailed"), c)
return return

View File

@@ -4,43 +4,43 @@ package comment_controller
type CommentRequest struct { type CommentRequest struct {
Content string `json:"content" binding:"required"` Content string `json:"content" binding:"required"`
Images []string `json:"images"` Images []string `json:"images"`
UserID string `json:"user_id" binding:"required"` // UserID string `json:"user_id" binding:"required"`
TopicId string `json:"topic_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"`
Author string `json:"author" binding:"required"` Author string `json:"author" binding:"required"`
Key string `json:"key" binding:"required"` Key string `json:"key" binding:"required"`
Point []int64 `json:"point" binding:"required"` Point []int64 `json:"point" binding:"required"`
} }
// ReplyCommentRequest 回复评论请求参数 // ReplyCommentRequest 回复评论请求参数
type ReplyCommentRequest struct { type ReplyCommentRequest struct {
Content string `json:"content" binding:"required"` Content string `json:"content" binding:"required"`
Images []string `json:"images"` Images []string `json:"images"`
UserID string `json:"user_id" binding:"required"` // UserID string `json:"user_id" binding:"required"`
TopicId string `json:"topic_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"`
ReplyId int64 `json:"reply_id" binding:"required"` ReplyId int64 `json:"reply_id" binding:"required"`
ReplyUser string `json:"reply_user" binding:"required"` ReplyUser string `json:"reply_user" binding:"required"`
Author string `json:"author" binding:"required"` Author string `json:"author" binding:"required"`
Key string `json:"key" binding:"required"` Key string `json:"key" binding:"required"`
Point []int64 `json:"point" binding:"required"` Point []int64 `json:"point" binding:"required"`
} }
// ReplyReplyRequest 回复回复请求参数 // ReplyReplyRequest 回复回复请求参数
type ReplyReplyRequest struct { type ReplyReplyRequest struct {
Content string `json:"content" binding:"required"` Content string `json:"content" binding:"required"`
Images []string `json:"images"` Images []string `json:"images"`
UserID string `json:"user_id" binding:"required"` // UserID string `json:"user_id" binding:"required"`
TopicId string `json:"topic_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"`
ReplyTo int64 `json:"reply_to" binding:"required"` ReplyTo int64 `json:"reply_to" binding:"required"`
ReplyId int64 `json:"reply_id" binding:"required"` ReplyId int64 `json:"reply_id" binding:"required"`
ReplyUser string `json:"reply_user" binding:"required"` ReplyUser string `json:"reply_user" binding:"required"`
Author string `json:"author" binding:"required"` Author string `json:"author" binding:"required"`
Key string `json:"key" binding:"required"` Key string `json:"key" binding:"required"`
Point []int64 `json:"point" binding:"required"` Point []int64 `json:"point" binding:"required"`
} }
// CommentListRequest 评论列表请求参数 // CommentListRequest 评论列表请求参数
type CommentListRequest struct { type CommentListRequest struct {
UserID string `json:"user_id" binding:"required"` // UserID string `json:"user_id" binding:"required"`
TopicId string `json:"topic_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"`
Page int `json:"page" default:"1"` Page int `json:"page" default:"1"`
Size int `json:"size" default:"5"` Size int `json:"size" default:"5"`
@@ -49,7 +49,7 @@ type CommentListRequest struct {
// ReplyListRequest 回复列表请求参数 // ReplyListRequest 回复列表请求参数
type ReplyListRequest struct { type ReplyListRequest struct {
UserID string `json:"user_id" binding:"required"` // UserID string `json:"user_id" binding:"required"`
TopicId string `json:"topic_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"`
CommentId int64 `json:"comment_id" binding:"required"` CommentId int64 `json:"comment_id" binding:"required"`
Page int `json:"page" default:"1"` Page int `json:"page" default:"1"`
@@ -60,5 +60,5 @@ type ReplyListRequest struct {
type CommentLikeRequest struct { type CommentLikeRequest struct {
TopicId string `json:"topic_id" binding:"required"` TopicId string `json:"topic_id" binding:"required"`
CommentId int64 `json:"comment_id" binding:"required"` CommentId int64 `json:"comment_id" binding:"required"`
UserID string `json:"user_id" binding:"required"` // UserID string `json:"user_id" binding:"required"`
} }

View File

@@ -3,16 +3,19 @@ package oauth_controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"net/http" "net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/redis"
"schisandra-cloud-album/common/result" "schisandra-cloud-album/common/result"
"schisandra-cloud-album/common/types"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
"schisandra-cloud-album/service/impl" "schisandra-cloud-album/service/impl"
"schisandra-cloud-album/utils" "schisandra-cloud-album/utils"
"sync"
"time"
) )
type OAuthController struct{} type OAuthController struct{}
@@ -36,7 +39,6 @@ func HandleLoginResponse(c *gin.Context, uid string) {
user := userService.QueryUserByUuidService(&uid) user := userService.QueryUserByUuidService(&uid)
var accessToken, refreshToken string var accessToken, refreshToken string
var expiresAt int64
var err error var err error
var wg sync.WaitGroup var wg sync.WaitGroup
var accessTokenErr error var accessTokenErr error
@@ -52,7 +54,7 @@ func HandleLoginResponse(c *gin.Context, uid string) {
// 使用goroutine生成refreshToken // 使用goroutine生成refreshToken
go func() { go func() {
defer wg.Done() // 完成时减少计数器 defer wg.Done() // 完成时减少计数器
refreshToken, expiresAt = utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &uid}, time.Hour*24*7) refreshToken = utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &uid}, time.Hour*24*7)
}() }()
// 等待两个协程完成 // 等待两个协程完成
@@ -64,29 +66,28 @@ func HandleLoginResponse(c *gin.Context, uid string) {
return return
} }
data := ResponseData{ data := types.ResponseData{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, UID: &uid,
ExpiresAt: expiresAt, Username: user.Username,
UID: &uid, Nickname: user.Nickname,
UserInfo: UserInfo{ Avatar: user.Avatar,
Username: user.Username, Status: user.Status,
Nickname: user.Nickname,
Avatar: user.Avatar,
Email: user.Email,
Phone: user.Phone,
Gender: user.Gender,
Status: user.Status,
CreateAt: *user.CreatedTime,
},
} }
// 设置session
if err = utils.SetSession(c, constant.SessionKey, data); err != nil { sessionData := utils.SessionData{
RefreshToken: refreshToken,
UID: uid,
}
if err = utils.SetSession(c, constant.SessionKey, sessionData); err != nil {
return return
} }
redisTokenData := types.RedisToken{
AccessToken: accessToken,
UID: uid,
}
// 将数据存入redis // 将数据存入redis
if err = redis.Set(constant.UserLoginTokenRedisKey+uid, data, time.Hour*24*7).Err(); err != nil { if err = redis.Set(constant.UserLoginTokenRedisKey+uid, redisTokenData, time.Hour*24*7).Err(); err != nil {
global.LOG.Error(err) global.LOG.Error(err)
return return
} }

View File

@@ -1,33 +0,0 @@
package oauth_controller
import (
"encoding/json"
"time"
)
// ResponseData 返回数据
type ResponseData struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt int64 `json:"expires_at"`
UID *string `json:"uid"`
UserInfo UserInfo `json:"user_info"`
}
type UserInfo struct {
Username string `json:"username,omitempty"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Phone string `json:"phone,omitempty"`
Email string `json:"email,omitempty"`
Gender string `json:"gender"`
Status int64 `json:"status"`
CreateAt time.Time `json:"create_at"`
}
func (res ResponseData) MarshalBinary() ([]byte, error) {
return json.Marshal(res)
}
func (res ResponseData) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, &res)
}

View File

@@ -1,9 +1,12 @@
package oauth_controller package oauth_controller
import ( import (
"encoding/gob"
"encoding/json" "encoding/json"
"errors" "errors"
"strconv"
"strings"
"time"
"github.com/ArtisanCloud/PowerLibs/v3/http/helper" "github.com/ArtisanCloud/PowerLibs/v3/http/helper"
"github.com/ArtisanCloud/PowerWeChat/v3/src/basicService/qrCode/response" "github.com/ArtisanCloud/PowerWeChat/v3/src/basicService/qrCode/response"
"github.com/ArtisanCloud/PowerWeChat/v3/src/kernel/contract" "github.com/ArtisanCloud/PowerWeChat/v3/src/kernel/contract"
@@ -14,18 +17,17 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/yitter/idgenerator-go/idgen" "github.com/yitter/idgenerator-go/idgen"
"gorm.io/gorm" "gorm.io/gorm"
"schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/enum" "schisandra-cloud-album/common/enum"
"schisandra-cloud-album/common/randomname" "schisandra-cloud-album/common/randomname"
"schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/redis"
"schisandra-cloud-album/common/result" "schisandra-cloud-album/common/result"
"schisandra-cloud-album/common/types"
"schisandra-cloud-album/controller/websocket_controller/qr_ws_controller" "schisandra-cloud-album/controller/websocket_controller/qr_ws_controller"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
"schisandra-cloud-album/model" "schisandra-cloud-album/model"
"schisandra-cloud-album/utils" "schisandra-cloud-album/utils"
"strconv"
"strings"
"time"
) )
// CallbackNotify 微信回调 // CallbackNotify 微信回调
@@ -246,24 +248,20 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool {
resultChan <- false resultChan <- false
return return
} }
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7) refreshToken := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
data := ResponseData{ data := types.ResponseData{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, UID: &userId,
ExpiresAt: expiresAt, Username: user.Username,
UID: &userId, Nickname: user.Nickname,
UserInfo: UserInfo{ Avatar: user.Avatar,
Username: user.Username, Status: user.Status,
Nickname: user.Nickname,
Avatar: user.Avatar,
Gender: user.Gender,
Phone: user.Phone,
Email: user.Email,
CreateAt: *user.CreatedTime,
Status: user.Status,
},
} }
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err() redisTokenData := types.RedisToken{
AccessToken: accessToken,
UID: userId,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, redisTokenData, time.Hour*24*7).Err()
if fail != nil { if fail != nil {
resultChan <- false resultChan <- false
return return
@@ -279,8 +277,11 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool {
resultChan <- false resultChan <- false
return return
} }
gob.Register(ResponseData{}) sessionData := utils.SessionData{
wrong := utils.SetSession(c, constant.SessionKey, data) RefreshToken: refreshToken,
UID: userId,
}
wrong := utils.SetSession(c, constant.SessionKey, sessionData)
if wrong != nil { if wrong != nil {
resultChan <- false resultChan <- false
return return

View File

@@ -1,10 +1,5 @@
package user_controller package user_controller
// RefreshTokenRequest 刷新token请求
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
// PhoneLoginRequest 手机号登录请求 // PhoneLoginRequest 手机号登录请求
type PhoneLoginRequest struct { type PhoneLoginRequest struct {
Phone string `json:"phone" binding:"required"` Phone string `json:"phone" binding:"required"`

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"reflect" "reflect"
"strconv" "strconv"
"sync"
ginI18n "github.com/gin-contrib/i18n" ginI18n "github.com/gin-contrib/i18n"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -25,7 +24,6 @@ import (
type UserController struct{} type UserController struct{}
var mu sync.Mutex
var userService = impl.UserServiceImpl{} var userService = impl.UserServiceImpl{}
var userDeviceService = impl.UserDeviceServiceImpl{} var userDeviceService = impl.UserDeviceServiceImpl{}
@@ -244,12 +242,8 @@ func (UserController) PhoneLogin(c *gin.Context) {
// @Success 200 {string} json // @Success 200 {string} json
// @Router /controller/token/refresh [post] // @Router /controller/token/refresh [post]
func (UserController) RefreshHandler(c *gin.Context) { func (UserController) RefreshHandler(c *gin.Context) {
request := RefreshTokenRequest{} session := utils.GetSession(c, constant.SessionKey)
if err := c.ShouldBindJSON(&request); err != nil { data, res := userService.RefreshTokenService(c, session.RefreshToken)
global.LOG.Error(err)
return
}
data, res := userService.RefreshTokenService(request.RefreshToken)
if !res { if !res {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "LoginExpired"), c) result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "LoginExpired"), c)
return return
@@ -343,7 +337,7 @@ func (UserController) ResetPassword(c *gin.Context) {
// @Success 200 {string} json // @Success 200 {string} json
// @Router /controller/auth/user/logout [post] // @Router /controller/auth/user/logout [post]
func (UserController) Logout(c *gin.Context) { func (UserController) Logout(c *gin.Context) {
userId := c.Query("user_id") userId := utils.GetSession(c, constant.SessionKey).UID
if userId == "" { if userId == "" {
global.LOG.Errorln("userId is empty") global.LOG.Errorln("userId is empty")
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
@@ -371,7 +365,7 @@ func (UserController) Logout(c *gin.Context) {
// GetUserLoginDevice 获取用户登录设备 // GetUserLoginDevice 获取用户登录设备
func (UserController) GetUserLoginDevice(c *gin.Context) { func (UserController) GetUserLoginDevice(c *gin.Context) {
userId := c.Query("user_id") userId := utils.GetSession(c, constant.SessionKey).UID
if userId == "" { if userId == "" {
return return
} }

View File

@@ -2,10 +2,12 @@ package core
import ( import (
"context" "context"
"net/http"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/rbcervilla/redisstore/v9" "github.com/rbcervilla/redisstore/v9"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"net/http"
"schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/constant"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
) )
@@ -20,7 +22,7 @@ func InitSession(client *redis.Client) {
store.KeyPrefix(constant.UserSessionRedisKey) store.KeyPrefix(constant.UserSessionRedisKey)
store.Options(sessions.Options{ store.Options(sessions.Options{
Path: "/", Path: "/",
//Domain: global.CONFIG.System.Web, // Domain: global.CONFIG.System.Web,
MaxAge: 86400 * 7, MaxAge: 86400 * 7,
HttpOnly: true, HttpOnly: true,
Secure: true, Secure: true,

View File

@@ -10,7 +10,7 @@ import (
func CasbinMiddleware() gin.HandlerFunc { func CasbinMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
userIdAny, exists := c.Get("userId") userIdAny, exists := c.Get("user_id")
if !exists { if !exists {
global.LOG.Error("casbin middleware: userId not found") global.LOG.Error("casbin middleware: userId not found")
result.FailWithMessage(ginI18n.MustGetMessage(c, "PermissionDenied"), c) result.FailWithMessage(ginI18n.MustGetMessage(c, "PermissionDenied"), c)

View File

@@ -10,17 +10,11 @@ import (
"schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/redis"
"schisandra-cloud-album/common/result" "schisandra-cloud-album/common/result"
"schisandra-cloud-album/common/types"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
"schisandra-cloud-album/utils" "schisandra-cloud-album/utils"
) )
type TokenData struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt int64 `json:"expires_at"`
UID *string `json:"uid"`
}
func JWTAuthMiddleware() gin.HandlerFunc { func JWTAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 默认Token放在请求头Authorization的Bearer中并以空格隔开 // 默认Token放在请求头Authorization的Bearer中并以空格隔开
@@ -51,7 +45,7 @@ func JWTAuthMiddleware() gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
tokenResult := TokenData{} tokenResult := types.RedisToken{}
err = json.Unmarshal([]byte(token), &tokenResult) err = json.Unmarshal([]byte(token), &tokenResult)
if err != nil { if err != nil {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c) result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c)
@@ -63,7 +57,13 @@ func JWTAuthMiddleware() gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
c.Set("userId", parseToken.UserID) uid := utils.GetSession(c, constant.SessionKey).UID
if uid != *parseToken.UserID {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c)
c.Abort()
return
}
c.Set("user_id", parseToken.UserID)
global.DB.Set("user_id", parseToken.UserID) // 全局变量中设置用户ID global.DB.Set("user_id", parseToken.UserID) // 全局变量中设置用户ID
c.Next() c.Next()
} }

View File

@@ -1,41 +0,0 @@
package middleware
import (
ginI18n "github.com/gin-contrib/i18n"
"github.com/gin-gonic/gin"
"schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/result"
"schisandra-cloud-album/utils"
)
// SessionCheckMiddleware session检查中间件
func SessionCheckMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
session := utils.GetSession(c, constant.SessionKey)
if session == nil {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c)
c.Abort()
return
}
userIdAny, exists := c.Get("userId")
if !exists {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c)
c.Abort()
return
}
userId, ok := userIdAny.(*string)
if !ok {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c)
c.Abort()
return
}
if *userId != *session.UID {
result.FailWithCodeAndMessage(403, ginI18n.MustGetMessage(c, "AuthVerifyExpired"), c)
c.Abort()
return
}
c.Next()
}
}

View File

@@ -52,7 +52,6 @@ func InitRouter() *gin.Engine {
middleware.SecurityHeaders(), middleware.SecurityHeaders(),
middleware.JWTAuthMiddleware(), middleware.JWTAuthMiddleware(),
middleware.CasbinMiddleware(), middleware.CasbinMiddleware(),
middleware.SessionCheckMiddleware(),
middleware.VerifySignature(), middleware.VerifySignature(),
) )
{ {

View File

@@ -1,20 +1,21 @@
package impl package impl
import ( import (
"encoding/gob"
"encoding/json"
"errors" "errors"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mssola/useragent" "github.com/mssola/useragent"
"gorm.io/gorm" "gorm.io/gorm"
"schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/redis"
"schisandra-cloud-album/common/types"
"schisandra-cloud-album/dao/impl" "schisandra-cloud-album/dao/impl"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
"schisandra-cloud-album/model" "schisandra-cloud-album/model"
"schisandra-cloud-album/utils" "schisandra-cloud-album/utils"
"sync"
"time"
) )
var userDao = impl.UserDaoImpl{} var userDao = impl.UserDaoImpl{}
@@ -23,33 +24,6 @@ type UserServiceImpl struct{}
var mu = &sync.Mutex{} var mu = &sync.Mutex{}
// ResponseData 返回数据
type ResponseData struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt int64 `json:"expires_at"`
UID *string `json:"uid"`
UserInfo UserInfo `json:"user_info"`
}
type UserInfo struct {
Username string `json:"username,omitempty"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Phone string `json:"phone,omitempty"`
Email string `json:"email,omitempty"`
Gender string `json:"gender"`
Status int64 `json:"status"`
CreateAt time.Time `json:"create_at"`
}
func (res ResponseData) MarshalBinary() ([]byte, error) {
return json.Marshal(res)
}
func (res ResponseData) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, &res)
}
// GetUserListService 返回用户列表 // GetUserListService 返回用户列表
func (UserServiceImpl) GetUserListService() []*model.ScaAuthUser { func (UserServiceImpl) GetUserListService() []*model.ScaAuthUser {
return userDao.GetUserList() return userDao.GetUserList()
@@ -95,36 +69,34 @@ func (UserServiceImpl) UpdateUserService(phone, encrypt string) error {
} }
// RefreshTokenService 刷新用户token // RefreshTokenService 刷新用户token
func (UserServiceImpl) RefreshTokenService(refreshToken string) (*ResponseData, bool) { func (UserServiceImpl) RefreshTokenService(c *gin.Context, refreshToken string) (string, bool) {
parseRefreshToken, isUpd, err := utils.ParseRefreshToken(refreshToken) parseRefreshToken, isUpd, err := utils.ParseRefreshToken(refreshToken)
if err != nil || !isUpd { if err != nil || !isUpd {
global.LOG.Errorln(err) global.LOG.Errorln(err)
return nil, false return "", false
} }
accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID}) accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID})
if err != nil { if err != nil {
return nil, false return "", false
} }
tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID
token, err := redis.Get(tokenKey).Result() session := utils.GetSession(c, constant.SessionKey)
if err != nil || token == "" { if session.RefreshToken == "" {
return "", false
}
redisTokenData := types.RedisToken{
AccessToken: accessTokenString,
UID: *parseRefreshToken.UserID,
}
if err = redis.Set(tokenKey, redisTokenData, time.Hour*24*7).Err(); err != nil {
global.LOG.Errorln(err) global.LOG.Errorln(err)
return nil, false return "", false
} }
data := ResponseData{ return accessTokenString, true
AccessToken: accessTokenString,
RefreshToken: refreshToken,
UID: parseRefreshToken.UserID,
}
if err = redis.Set(tokenKey, data, time.Hour*24*7).Err(); err != nil {
global.LOG.Errorln(err)
return nil, false
}
return &data, true
} }
// HandelUserLogin 处理用户登录 // HandelUserLogin 处理用户登录
func (UserServiceImpl) HandelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) (*ResponseData, bool) { func (UserServiceImpl) HandelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) (*types.ResponseData, bool) {
// 检查 user.UID 是否为 nil // 检查 user.UID 是否为 nil
if user.UID == "" { if user.UID == "" {
return nil, false return nil, false
@@ -143,30 +115,28 @@ func (UserServiceImpl) HandelUserLogin(user model.ScaAuthUser, autoLogin bool, c
days = time.Minute * 30 days = time.Minute * 30
} }
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &user.UID}, days) refreshToken := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &user.UID}, days)
data := ResponseData{ data := types.ResponseData{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, UID: &user.UID,
ExpiresAt: expiresAt, Username: user.Username,
UID: &user.UID, Nickname: user.Nickname,
UserInfo: UserInfo{ Avatar: user.Avatar,
Username: user.Username, Status: user.Status,
Nickname: user.Nickname,
Avatar: user.Avatar,
Phone: user.Phone,
Email: user.Email,
Gender: user.Gender,
Status: user.Status,
CreateAt: *user.CreatedTime,
},
} }
redisTokenData := types.RedisToken{
err = redis.Set(constant.UserLoginTokenRedisKey+user.UID, data, days).Err() AccessToken: accessToken,
UID: user.UID,
}
err = redis.Set(constant.UserLoginTokenRedisKey+user.UID, redisTokenData, days).Err()
if err != nil { if err != nil {
return nil, false return nil, false
} }
gob.Register(ResponseData{}) sessionData := utils.SessionData{
err = utils.SetSession(c, constant.SessionKey, data) RefreshToken: refreshToken,
UID: user.UID,
}
err = utils.SetSession(c, constant.SessionKey, sessionData)
if err != nil { if err != nil {
return nil, false return nil, false
} }

View File

@@ -33,7 +33,7 @@ func GenerateAccessToken(payload AccessJWTPayload) (string, error) {
claims := AccessJWTClaims{ claims := AccessJWTClaims{
AccessJWTPayload: payload, AccessJWTPayload: payload,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 30)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 15)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()),
}, },
@@ -52,7 +52,7 @@ func GenerateAccessToken(payload AccessJWTPayload) (string, error) {
} }
// GenerateRefreshToken generates a JWT token with the given payload, and returns the accessToken and refreshToken // GenerateRefreshToken generates a JWT token with the given payload, and returns the accessToken and refreshToken
func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) (string, int64) { func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) string {
MySecret = []byte(global.CONFIG.JWT.Secret) MySecret = []byte(global.CONFIG.JWT.Secret)
refreshClaims := RefreshJWTClaims{ refreshClaims := RefreshJWTClaims{
RefreshJWTPayload: payload, RefreshJWTPayload: payload,
@@ -67,14 +67,14 @@ func GenerateRefreshToken(payload RefreshJWTPayload, days time.Duration) (string
refreshTokenString, err := refreshToken.SignedString(MySecret) refreshTokenString, err := refreshToken.SignedString(MySecret)
if err != nil { if err != nil {
global.LOG.Error(err) global.LOG.Error(err)
return "", 0 return ""
} }
// refreshTokenEncrypted, err := aes.AesCtrEncryptHex([]byte(refreshTokenString), []byte(global.CONFIG.Encrypt.Key), []byte(global.CONFIG.Encrypt.IV)) // refreshTokenEncrypted, err := aes.AesCtrEncryptHex([]byte(refreshTokenString), []byte(global.CONFIG.Encrypt.Key), []byte(global.CONFIG.Encrypt.IV))
// if err != nil { // if err != nil {
// fmt.Println(err) // fmt.Println(err)
// return "", 0 // return "", 0
// } // }
return refreshTokenString, refreshClaims.ExpiresAt.Time.Unix() return refreshTokenString
} }
// ParseAccessToken parses a JWT token and returns the payload // ParseAccessToken parses a JWT token and returns the payload

View File

@@ -1,35 +1,23 @@
package utils package utils
import ( import (
"encoding/gob"
"encoding/json" "encoding/json"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
) )
// ResponseData 返回数据 // SessionData 返回数据
type ResponseData struct { type SessionData struct {
AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"`
RefreshToken string `json:"refresh_token"` UID string `json:"uid"`
ExpiresAt int64 `json:"expires_at"`
UID *string `json:"uid"`
UserInfo UserInfo `json:"user_info"`
}
type UserInfo struct {
Username string `json:"username,omitempty"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Phone string `json:"phone,omitempty"`
Email string `json:"email,omitempty"`
Gender string `json:"gender"`
Status int64 `json:"status"`
CreateAt time.Time `json:"create_at"`
} }
// SetSession sets session data with key and data // SetSession sets session data with key and data
func SetSession(c *gin.Context, key string, data interface{}) error { func SetSession(c *gin.Context, key string, data SessionData) error {
gob.Register(SessionData{})
session, err := global.Session.Get(c.Request, key) session, err := global.Session.Get(c.Request, key)
if err != nil { if err != nil {
global.LOG.Error("SetSession failed: ", err) global.LOG.Error("SetSession failed: ", err)
@@ -50,24 +38,24 @@ func SetSession(c *gin.Context, key string, data interface{}) error {
} }
// GetSession gets session data with key // GetSession gets session data with key
func GetSession(c *gin.Context, key string) *ResponseData { func GetSession(c *gin.Context, key string) SessionData {
session, err := global.Session.Get(c.Request, key) session, err := global.Session.Get(c.Request, key)
if err != nil { if err != nil {
global.LOG.Error("GetSession failed: ", err) global.LOG.Error("GetSession failed: ", err)
return nil return SessionData{}
} }
jsonData, ok := session.Values[key] jsonData, ok := session.Values[key]
if !ok { if !ok {
global.LOG.Error("GetSession failed: ", "key not found") global.LOG.Error("GetSession failed: ", "key not found")
return nil return SessionData{}
} }
data := ResponseData{} data := SessionData{}
err = json.Unmarshal(jsonData.([]byte), &data) err = json.Unmarshal(jsonData.([]byte), &data)
if err != nil { if err != nil {
global.LOG.Error("GetSession failed: ", err) global.LOG.Error("GetSession failed: ", err)
return nil return SessionData{}
} }
return &data return data
} }
// DelSession deletes session data with key // DelSession deletes session data with key