🐛 fixed the issue that third-party login sessions were missing

This commit is contained in:
2024-12-20 01:19:29 +08:00
parent 49831fc4d0
commit 40d073db0f
27 changed files with 556 additions and 308 deletions

View File

@@ -6,6 +6,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
http2 "schisandra-album-cloud-microservices/app/core/api/common/http"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/oauth"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
@@ -21,7 +22,7 @@ func GiteeCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
}
l := oauth.NewGiteeCallbackLogic(r.Context(), svcCtx)
err := l.GiteeCallback(w, r, &req)
data, err := l.GiteeCallback(w, r, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -30,7 +31,7 @@ func GiteeCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
http2.OkHTML(w, data)
}
}
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
http2 "schisandra-album-cloud-microservices/app/core/api/common/http"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/oauth"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
@@ -21,7 +21,7 @@ func GithubCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
}
l := oauth.NewGithubCallbackLogic(r.Context(), svcCtx)
err := l.GithubCallback(w, r, &req)
data, err := l.GithubCallback(w, r, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -30,7 +30,7 @@ func GithubCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
http2.OkHTML(w, data)
}
}
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
http2 "schisandra-album-cloud-microservices/app/core/api/common/http"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/oauth"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
@@ -21,7 +21,7 @@ func QqCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
}
l := oauth.NewQqCallbackLogic(r.Context(), svcCtx)
err := l.QqCallback(w, r, &req)
data, err := l.QqCallback(w, r, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -30,7 +30,7 @@ func QqCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
http2.OkHTML(w, data)
}
}
}

View File

@@ -221,7 +221,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
[]rest.Middleware{serverCtx.SecurityHeadersMiddleware},
[]rest.Route{
{
Method: http.MethodGet,
Method: http.MethodPost,
Path: "/device",
Handler: user.GetUserDeviceHandler(serverCtx),
},

View File

@@ -1,20 +1,26 @@
package user
import (
"github.com/zeromicro/go-zero/core/logx"
"net/http"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/user"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
"schisandra-album-cloud-microservices/app/core/api/internal/types"
)
func GetUserDeviceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.UserDeviceRequest
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
}
l := user.NewGetUserDeviceLogic(r.Context(), svcCtx)
err := l.GetUserDevice(r)
resp, err := l.GetUserDevice(r, w, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -23,7 +29,7 @@ func GetUserDeviceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
httpx.OkJsonCtx(r.Context(), w, resp)
}
}
}

View File

@@ -77,30 +77,33 @@ func NewGiteeCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Git
}
}
func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) error {
func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
// 获取 token
tokenAuthUrl := l.GetGiteeTokenAuthUrl(req.Code)
token, err := l.GetGiteeToken(tokenAuthUrl)
if err != nil {
return err
return "", err
}
if token == nil {
return nil
return "", errors.New("get gitee token failed")
}
// 获取用户信息
userInfo, err := l.GetGiteeUserInfo(token)
if err != nil {
return err
return "", err
}
if userInfo == nil {
return "", errors.New("get gitee user info failed")
}
var giteeUser GiteeUser
marshal, err := json.Marshal(userInfo)
if err != nil {
return err
return "", err
}
if err = json.Unmarshal(marshal, &giteeUser); err != nil {
return err
return "", err
}
Id := strconv.Itoa(giteeUser.ID)
@@ -109,7 +112,7 @@ func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Reques
userSocial := l.svcCtx.DB.ScaAuthUserSocial
socialUser, err := tx.ScaAuthUserSocial.Where(userSocial.OpenID.Eq(Id), userSocial.Source.Eq(constant.OAuthSourceGitee)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
return "", err
}
if socialUser == nil {
@@ -128,7 +131,7 @@ func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Reques
err = tx.ScaAuthUser.Create(addUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
gitee := constant.OAuthSourceGitee
newSocialUser := &model.ScaAuthUserSocial{
@@ -139,56 +142,56 @@ func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Reques
err = tx.ScaAuthUserSocial.Create(newSocialUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if res, err := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User); !res || err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx); err != nil {
return err
data, err := HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
} else {
authUser := l.svcCtx.DB.ScaAuthUser
authUserInfo, err := tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
// HandleOauthLoginResponse 处理登录响应
func HandleOauthLoginResponse(scaAuthUser *model.ScaAuthUser, svcCtx *svc.ServiceContext, r *http.Request, w http.ResponseWriter, ctx context.Context) error {
func HandleOauthLoginResponse(scaAuthUser *model.ScaAuthUser, svcCtx *svc.ServiceContext, r *http.Request, w http.ResponseWriter, ctx context.Context) (string, error) {
data, err := user.HandleUserLogin(scaAuthUser, svcCtx, true, r, w, ctx)
if err != nil {
return err
return "", err
}
responseData := response.SuccessWithData(data)
formattedScript := fmt.Sprintf(Script, responseData, svcCtx.Config.Web.URL)
// 设置响应状态码和内容类型
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
// 写入响应内容
if _, writeErr := w.Write([]byte(formattedScript)); writeErr != nil {
return writeErr
marshalData, err := json.Marshal(responseData)
if err != nil {
return "", err
}
return nil
formattedScript := fmt.Sprintf(Script, marshalData, svcCtx.Config.Web.URL)
return formattedScript, nil
}
// GetGiteeTokenAuthUrl 获取Gitee token

View File

@@ -68,39 +68,39 @@ func NewGithubCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Gi
}
}
func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) error {
func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
// 获取 token
tokenAuthUrl := l.GetTokenAuthUrl(req.Code)
token, err := l.GetToken(tokenAuthUrl)
if err != nil {
return err
return "", err
}
if token == nil {
return nil
return "", errors.New("get github token failed")
}
// 获取用户信息
userInfo, err := l.GetUserInfo(token)
if err != nil {
return err
return "", err
}
if userInfo == nil {
return nil
return "", errors.New("get github user info failed")
}
// 处理用户信息
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
return err
return "", err
}
var gitHubUser GitHubUser
err = json.Unmarshal(userInfoBytes, &gitHubUser)
if err != nil {
return err
return "", err
}
Id := strconv.Itoa(gitHubUser.ID)
tx := l.svcCtx.DB.Begin()
@@ -108,7 +108,7 @@ func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Requ
userSocial := l.svcCtx.DB.ScaAuthUserSocial
socialUser, err := tx.ScaAuthUserSocial.Where(userSocial.OpenID.Eq(Id), userSocial.Source.Eq(constant.OAuthSourceGithub)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
return "", err
}
if socialUser == nil {
@@ -130,7 +130,7 @@ func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Requ
err = tx.ScaAuthUser.Create(addUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
githubUser := constant.OAuthSourceGithub
newSocialUser := &model.ScaAuthUserSocial{
@@ -141,37 +141,42 @@ func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Requ
err = tx.ScaAuthUserSocial.Create(newSocialUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if res, err := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User); !res || err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
} else {
authUser := l.svcCtx.DB.ScaAuthUser
authUserInfo, err := tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
// GetTokenAuthUrl 通过code获取token认证url

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/yitter/idgenerator-go/idgen"
"gorm.io/gorm"
@@ -65,39 +66,42 @@ func NewQqCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *QqCall
}
}
func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) error {
func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
tokenAuthUrl := l.GetQQTokenAuthUrl(req.Code)
token, err := l.GetQQToken(tokenAuthUrl)
if err != nil {
return err
return "", err
}
if token == nil {
return nil
return "", errors.New("get qq token failed")
}
// 通过 token 获取 openid
authQQme, err := l.GetQQUserOpenID(token)
if err != nil {
return err
return "", err
}
// 通过 token 和 openid 获取用户信息
userInfo, err := l.GetQQUserUserInfo(token, authQQme.OpenID)
if err != nil {
return err
return "", err
}
if userInfo == nil {
return "", errors.New("get qq user info failed")
}
// 处理用户信息
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
return err
return "", err
}
var qqUserInfo QQUserInfo
err = json.Unmarshal(userInfoBytes, &qqUserInfo)
if err != nil {
return err
return "", err
}
tx := l.svcCtx.DB.Begin()
@@ -105,7 +109,7 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
userSocial := l.svcCtx.DB.ScaAuthUserSocial
socialUser, err := tx.ScaAuthUserSocial.Where(userSocial.OpenID.Eq(authQQme.OpenID), userSocial.Source.Eq(constant.OAuthSourceQQ)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
return "", err
}
if socialUser == nil {
@@ -114,9 +118,10 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
uidStr := strconv.FormatInt(uid, 10)
male := constant.Male
avatarUrl := strings.Replace(qqUserInfo.FigureurlQq1, "http://", "https://", 1)
addUser := &model.ScaAuthUser{
UID: uidStr,
Avatar: qqUserInfo.FigureurlQq1,
Avatar: avatarUrl,
Username: authQQme.OpenID,
Nickname: qqUserInfo.Nickname,
Gender: male,
@@ -124,7 +129,7 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
err = tx.ScaAuthUser.Create(addUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
githubUser := constant.OAuthSourceQQ
@@ -136,37 +141,42 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
err = tx.ScaAuthUserSocial.Create(newSocialUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if res, err := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User); !res || err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
} else {
authUser := l.svcCtx.DB.ScaAuthUser
authUserInfo, err := tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
// GetQQTokenAuthUrl 通过code获取token认证url

View File

@@ -55,7 +55,7 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
key := strings.TrimPrefix(msg.EventKey, "qrscene_")
err = l.HandlerWechatLogin(msg.FromUserName, key, w, r)
@@ -66,10 +66,10 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
case models.CALLBACK_EVENT_UNSUBSCRIBE:
msg := models.EventUnSubscribe{}
err := event.ReadMessage(&msg)
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
return messages.NewText("ok")
@@ -78,7 +78,7 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
err = l.HandlerWechatLogin(msg.FromUserName, msg.EventKey, w, r)
if err != nil {
@@ -90,10 +90,10 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
case models2.CALLBACK_MSG_TYPE_TEXT:
msg := models.MessageText{}
err := event.ReadMessage(&msg)
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
}
return messages.NewText("ok")
@@ -205,5 +205,4 @@ func (l *WechatCallbackLogic) HandlerWechatLogin(openId string, clientId string,
return err
}
return nil
}

View File

@@ -34,11 +34,7 @@ func (l *RefreshTokenLogic) RefreshToken(r *http.Request) (resp *types.Response,
if err != nil {
return nil, err
}
refreshSessionToken, ok := session.Values["refresh_token"].(string)
if !ok {
return response.ErrorWithCode(403), nil
}
userId, ok := session.Values["uid"].(string)
userId, ok := session.Values["user_id"].(string)
if !ok {
return response.ErrorWithCode(403), nil
}
@@ -51,23 +47,20 @@ func (l *RefreshTokenLogic) RefreshToken(r *http.Request) (resp *types.Response,
if err != nil {
return nil, err
}
if redisTokenData.RefreshToken != refreshSessionToken {
return response.ErrorWithCode(403), nil
}
refreshToken, result := jwt.ParseRefreshToken(l.svcCtx.Config.Auth.AccessSecret, refreshSessionToken)
refreshToken, result := jwt.ParseRefreshToken(l.svcCtx.Config.Auth.AccessSecret, redisTokenData.RefreshToken)
if !result {
return response.ErrorWithCode(403), nil
}
accessToken := jwt.GenerateAccessToken(l.svcCtx.Config.Auth.AccessSecret, jwt.AccessJWTPayload{
UserID: refreshToken.UserID,
Type: constant.JWT_TYPE_ACCESS,
})
if accessToken == "" {
return response.ErrorWithCode(403), nil
}
redisToken := types.RedisToken{
AccessToken: accessToken,
RefreshToken: refreshSessionToken,
RefreshToken: redisTokenData.RefreshToken,
UID: refreshToken.UserID,
}
err = l.svcCtx.RedisClient.Set(l.ctx, constant.UserTokenPrefix+refreshToken.UserID, redisToken, time.Hour*24*7).Err()

View File

@@ -3,6 +3,7 @@ package user
import (
"context"
"errors"
"github.com/rbcervilla/redisstore/v9"
"net/http"
"time"
@@ -80,6 +81,7 @@ func HandleUserLogin(user *model.ScaAuthUser, svcCtx *svc.ServiceContext, autoLo
// 生成jwt token
accessToken := jwt.GenerateAccessToken(svcCtx.Config.Auth.AccessSecret, jwt.AccessJWTPayload{
UserID: user.UID,
Type: constant.JWT_TYPE_ACCESS,
})
var days time.Duration
if autoLogin {
@@ -89,6 +91,7 @@ func HandleUserLogin(user *model.ScaAuthUser, svcCtx *svc.ServiceContext, autoLo
}
refreshToken := jwt.GenerateRefreshToken(svcCtx.Config.Auth.AccessSecret, jwt.RefreshJWTPayload{
UserID: user.UID,
Type: constant.JWT_TYPE_REFRESH,
}, days)
data := types.LoginResponse{
AccessToken: accessToken,
@@ -108,16 +111,24 @@ func HandleUserLogin(user *model.ScaAuthUser, svcCtx *svc.ServiceContext, autoLo
if err != nil {
return nil, err
}
session, err := svcCtx.Session.Get(r, constant.SESSION_KEY)
if err != nil {
return nil, err
}
session.Values["refresh_token"] = refreshToken
session.Values["uid"] = user.UID
err = session.Save(r, w)
err = HandlerSession(r, w, user.UID, svcCtx.Session)
if err != nil {
return nil, err
}
return &data, nil
}
// HandlerSession is a function to set the user_id in the session
func HandlerSession(r *http.Request, w http.ResponseWriter, userID string, redisSession *redisstore.RedisStore) error {
session, err := redisSession.Get(r, constant.SESSION_KEY)
if err != nil {
return err
}
session.Values["user_id"] = userID
err = session.Save(r, w)
if err != nil {
return err
}
return nil
}

View File

@@ -3,18 +3,20 @@ package user
import (
"context"
"errors"
"net/http"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"github.com/mssola/useragent"
"github.com/zeromicro/go-zero/core/logx"
"gorm.io/gorm"
"schisandra-album-cloud-microservices/app/core/api/common/constant"
"net/http"
"schisandra-album-cloud-microservices/app/core/api/common/jwt"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/common/utils"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
"schisandra-album-cloud-microservices/app/core/api/repository/mysql/model"
"schisandra-album-cloud-microservices/app/core/api/repository/mysql/query"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
"schisandra-album-cloud-microservices/app/core/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type GetUserDeviceLogic struct {
@@ -31,20 +33,20 @@ func NewGetUserDeviceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Get
}
}
func (l *GetUserDeviceLogic) GetUserDevice(r *http.Request) error {
session, err := l.svcCtx.Session.Get(r, constant.SESSION_KEY)
if err != nil {
return err
}
uid, ok := session.Values["uid"].(string)
func (l *GetUserDeviceLogic) GetUserDevice(r *http.Request, w http.ResponseWriter, req *types.UserDeviceRequest) (resp *types.Response, err error) {
token, ok := jwt.ParseAccessToken(l.svcCtx.Config.Auth.AccessSecret, req.AccessToken)
if !ok {
return errors.New("user session not found")
return response.Error(), nil
}
if err = GetUserLoginDevice(uid, r, l.svcCtx.Ip2Region, l.svcCtx.DB); err != nil {
return err
err = HandlerSession(r, w, token.UserID, l.svcCtx.Session)
if err != nil {
return nil, err
}
return nil
err = GetUserLoginDevice(token.UserID, r, l.svcCtx.Ip2Region, l.svcCtx.DB)
if err != nil {
return nil, err
}
return response.Success(), nil
}
// GetUserLoginDevice 获取用户登录设备

View File

@@ -2,11 +2,10 @@ package middleware
import (
"net/http"
"schisandra-album-cloud-microservices/app/core/api/common/constant"
"github.com/casbin/casbin/v2"
"github.com/rbcervilla/redisstore/v9"
"schisandra-album-cloud-microservices/app/core/api/common/constant"
)
type CasbinVerifyMiddleware struct {
@@ -28,7 +27,7 @@ func (m *CasbinVerifyMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
userId, ok := session.Values["uid"].(string)
userId, ok := session.Values["user_id"].(string)
if !ok {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return

View File

@@ -118,3 +118,7 @@ type UploadRequest struct {
AccessToken string `json:"access_token"`
UserId string `json:"user_id"`
}
type UserDeviceRequest struct {
AccessToken string `json:"access_token"`
}