🐛 fix session bug

This commit is contained in:
landaiqing
2024-11-17 20:02:59 +08:00
parent 34c4690f80
commit 78a162a19a
72 changed files with 1304 additions and 453 deletions

View File

@@ -36,7 +36,7 @@ func NewAccountLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Acco
func (l *AccountLoginLogic) AccountLogin(w http.ResponseWriter, r *http.Request, req *types.AccountLoginRequest) (resp *types.Response, err error) {
verifyResult := verify.VerifyRotateCaptcha(l.ctx, l.svcCtx.RedisClient, req.Angle, req.Key)
if !verifyResult {
return response.ErrorWithI18n(l.ctx, "captcha.verificationFailure", "验证失败!"), nil
return response.ErrorWithI18n(l.ctx, "captcha.verificationFailure"), nil
}
var user *ent.ScaAuthUser
var query *ent.ScaAuthUserQuery
@@ -49,23 +49,27 @@ func (l *AccountLoginLogic) AccountLogin(w http.ResponseWriter, r *http.Request,
case utils.IsUsername(req.Account):
query = l.svcCtx.MySQLClient.ScaAuthUser.Query().Where(scaauthuser.UsernameEQ(req.Account), scaauthuser.DeletedEQ(0))
default:
return response.ErrorWithI18n(l.ctx, "login.invalidAccount", "无效账号!"), nil
return response.ErrorWithI18n(l.ctx, "login.invalidAccount"), nil
}
user, err = query.First(l.ctx)
if err != nil {
if ent.IsNotFound(err) {
return response.ErrorWithI18n(l.ctx, "login.userNotRegistered", "用户未注册!"), nil
return response.ErrorWithI18n(l.ctx, "login.userNotRegistered"), nil
}
return nil, err
}
if !utils.Verify(user.Password, req.Password) {
return response.ErrorWithI18n(l.ctx, "login.invalidPassword", "密码错误!"), nil
return response.ErrorWithI18n(l.ctx, "login.invalidPassword"), nil
}
data, result := HandleUserLogin(user, l.svcCtx, req.AutoLogin, r, w, l.ctx)
if !result {
return response.ErrorWithI18n(l.ctx, "login.loginFailed", "登录失败!"), nil
return response.ErrorWithI18n(l.ctx, "login.loginFailed"), nil
}
// 记录用户登录设备
if !GetUserLoginDevice(user.UID, r, l.svcCtx.Ip2Region, l.svcCtx.MySQLClient, l.ctx) {
return response.ErrorWithI18n(l.ctx, "login.loginFailed"), nil
}
return response.SuccessWithData(data), nil
}
@@ -95,32 +99,26 @@ func HandleUserLogin(user *ent.ScaAuthUser, svcCtx *svc.ServiceContext, autoLogi
}
redisToken := types.RedisToken{
AccessToken: accessToken,
UID: user.UID,
AccessToken: accessToken,
RefreshToken: refreshToken,
UID: user.UID,
}
err := svcCtx.RedisClient.Set(ctx, constant.UserTokenPrefix+user.UID, redisToken, days).Err()
if err != nil {
logc.Error(ctx, err)
return nil, false
}
sessionData := types.SessionData{
RefreshToken: refreshToken,
UID: user.UID,
}
session, err := svcCtx.Session.Get(r, constant.SESSION_KEY)
if err != nil {
logc.Error(ctx, err)
return nil, false
}
session.Values[constant.SESSION_KEY] = sessionData
session.Values["refresh_token"] = refreshToken
session.Values["uid"] = user.UID
err = session.Save(r, w)
if err != nil {
return nil, false
}
// 记录用户登录设备
if !GetUserLoginDevice(user.UID, r, svcCtx.Ip2Region, svcCtx.MySQLClient, ctx) {
return nil, false
}
return &data, true
}

View File

@@ -39,7 +39,7 @@ func (l *GetUserDeviceLogic) GetUserDevice(r *http.Request) error {
}
sessionData, ok := session.Values[constant.SESSION_KEY]
if !ok {
return errors.New("User not found or device not found")
return errors.New("user session not found")
}
var data types.SessionData
err = json.Unmarshal(sessionData.([]byte), &data)
@@ -49,7 +49,7 @@ func (l *GetUserDeviceLogic) GetUserDevice(r *http.Request) error {
res := GetUserLoginDevice(data.UID, r, l.svcCtx.Ip2Region, l.svcCtx.MySQLClient, l.ctx)
if !res {
return errors.New("User not found or device not found")
return errors.New("user device not found")
}
return nil
}

View File

@@ -34,19 +34,19 @@ func NewPhoneLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PhoneL
func (l *PhoneLoginLogic) PhoneLogin(r *http.Request, w http.ResponseWriter, req *types.PhoneLoginRequest) (resp *types.Response, err error) {
if !utils.IsPhone(req.Phone) {
return response.ErrorWithI18n(l.ctx, "login.phoneFormatError", "手机号格式错误"), nil
return response.ErrorWithI18n(l.ctx, "login.phoneFormatError"), nil
}
code := l.svcCtx.RedisClient.Get(l.ctx, constant.UserSmsRedisPrefix+req.Phone).Val()
if code == "" {
return response.ErrorWithI18n(l.ctx, "login.captchaExpired", "验证码已过期"), nil
return response.ErrorWithI18n(l.ctx, "login.captchaExpired"), nil
}
if req.Captcha != code {
return response.ErrorWithI18n(l.ctx, "login.captchaError", "验证码错误"), nil
return response.ErrorWithI18n(l.ctx, "login.captchaError"), nil
}
user, err := l.svcCtx.MySQLClient.ScaAuthUser.Query().Where(scaauthuser.Phone(req.Phone), scaauthuser.Deleted(0)).First(l.ctx)
tx, wrong := l.svcCtx.MySQLClient.Tx(l.ctx)
if wrong != nil {
return response.ErrorWithI18n(l.ctx, "login.loginFailed", "登录失败"), err
return response.ErrorWithI18n(l.ctx, "login.loginFailed"), err
}
if ent.IsNotFound(err) {
uid := idgen.NextId()
@@ -64,17 +64,21 @@ func (l *PhoneLoginLogic) PhoneLogin(r *http.Request, w http.ResponseWriter, req
Save(l.ctx)
if fault != nil {
err = tx.Rollback()
return response.ErrorWithI18n(l.ctx, "login.registerError", "注册失败"), err
return response.ErrorWithI18n(l.ctx, "login.registerError"), err
}
_, err = l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User)
if err != nil {
err = tx.Rollback()
return response.ErrorWithI18n(l.ctx, "login.registerError", "注册失败"), err
return response.ErrorWithI18n(l.ctx, "login.registerError"), err
}
data, result := HandleUserLogin(addUser, l.svcCtx, req.AutoLogin, r, w, l.ctx)
if !result {
err = tx.Rollback()
return response.ErrorWithI18n(l.ctx, "login.registerError", "注册失败"), err
return response.ErrorWithI18n(l.ctx, "login.registerError"), err
}
// 记录用户登录设备
if !GetUserLoginDevice(addUser.UID, r, l.svcCtx.Ip2Region, l.svcCtx.MySQLClient, l.ctx) {
return response.ErrorWithI18n(l.ctx, "login.registerError"), nil
}
err = tx.Commit()
if err != nil {
@@ -85,7 +89,11 @@ func (l *PhoneLoginLogic) PhoneLogin(r *http.Request, w http.ResponseWriter, req
data, result := HandleUserLogin(user, l.svcCtx, req.AutoLogin, r, w, l.ctx)
if !result {
err = tx.Rollback()
return response.ErrorWithI18n(l.ctx, "login.loginFailed", "登录失败"), err
return response.ErrorWithI18n(l.ctx, "login.loginFailed"), err
}
// 记录用户登录设备
if !GetUserLoginDevice(user.UID, r, l.svcCtx.Ip2Region, l.svcCtx.MySQLClient, l.ctx) {
return response.ErrorWithI18n(l.ctx, "login.loginFailed"), nil
}
err = tx.Commit()
if err != nil {
@@ -93,6 +101,6 @@ func (l *PhoneLoginLogic) PhoneLogin(r *http.Request, w http.ResponseWriter, req
}
return response.SuccessWithData(data), nil
} else {
return response.ErrorWithI18n(l.ctx, "login.loginFailed", "登录失败"), nil
return response.ErrorWithI18n(l.ctx, "login.loginFailed"), nil
}
}

View File

@@ -34,28 +34,41 @@ func (l *RefreshTokenLogic) RefreshToken(r *http.Request) (resp *types.Response,
if err != nil {
return response.ErrorWithCode(403), err
}
sessionData, ok := session.Values[constant.SESSION_KEY]
refreshSessionToken, ok := session.Values["refresh_token"].(string)
if !ok {
return response.ErrorWithCode(403), err
return response.ErrorWithCode(403), nil
}
data := types.SessionData{}
err = json.Unmarshal(sessionData.([]byte), &data)
userId, ok := session.Values["uid"].(string)
if !ok {
return response.ErrorWithCode(403), nil
}
tokenData := l.svcCtx.RedisClient.Get(l.ctx, constant.UserTokenPrefix+userId).Val()
if tokenData == "" {
return response.ErrorWithCode(403), nil
}
redisTokenData := types.RedisToken{}
err = json.Unmarshal([]byte(tokenData), &redisTokenData)
if err != nil {
return response.ErrorWithCode(403), err
}
refreshToken, result := jwt.ParseRefreshToken(l.svcCtx.Config.Auth.AccessSecret, data.RefreshToken)
if redisTokenData.RefreshToken != refreshSessionToken {
return response.ErrorWithCode(403), nil
}
refreshToken, result := jwt.ParseRefreshToken(l.svcCtx.Config.Auth.AccessSecret, refreshSessionToken)
if !result {
return response.ErrorWithCode(403), err
return response.ErrorWithCode(403), nil
}
accessToken := jwt.GenerateAccessToken(l.svcCtx.Config.Auth.AccessSecret, jwt.AccessJWTPayload{
UserID: refreshToken.UserID,
})
if accessToken == "" {
return response.ErrorWithCode(403), err
return response.ErrorWithCode(403), nil
}
redisToken := types.RedisToken{
AccessToken: accessToken,
UID: refreshToken.UserID,
AccessToken: accessToken,
RefreshToken: refreshSessionToken,
UID: refreshToken.UserID,
}
err = l.svcCtx.RedisClient.Set(l.ctx, constant.UserTokenPrefix+refreshToken.UserID, redisToken, time.Hour*24*7).Err()
if err != nil {

View File

@@ -30,36 +30,36 @@ func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Res
func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (resp *types.Response, err error) {
if !utils.IsPhone(req.Phone) {
return response.ErrorWithI18n(l.ctx, "login.phoneFormatError", "手机号格式错误"), nil
return response.ErrorWithI18n(l.ctx, "login.phoneFormatError"), nil
}
if req.Password != req.Repassword {
return response.ErrorWithI18n(l.ctx, "login.passwordNotMatch", "两次密码输入不一致"), nil
return response.ErrorWithI18n(l.ctx, "login.passwordNotMatch"), nil
}
if !utils.IsPassword(req.Password) {
return response.ErrorWithI18n(l.ctx, "login.passwordFormatError", "密码格式错误"), nil
return response.ErrorWithI18n(l.ctx, "login.passwordFormatError"), nil
}
code := l.svcCtx.RedisClient.Get(l.ctx, constant.UserSmsRedisPrefix+req.Phone).Val()
if code == "" {
return response.ErrorWithI18n(l.ctx, "login.captchaExpired", "验证码已过期"), nil
return response.ErrorWithI18n(l.ctx, "login.captchaExpired"), nil
}
if req.Captcha != code {
return response.ErrorWithI18n(l.ctx, "login.captchaError", "验证码错误"), nil
return response.ErrorWithI18n(l.ctx, "login.captchaError"), nil
}
// 验证码检查通过后立即删除或标记为已使用
if err = l.svcCtx.RedisClient.Del(l.ctx, constant.UserSmsRedisPrefix+req.Phone).Err(); err != nil {
return response.ErrorWithI18n(l.ctx, "login.captchaError", "验证码错误"), nil
return response.ErrorWithI18n(l.ctx, "login.captchaError"), err
}
user, err := l.svcCtx.MySQLClient.ScaAuthUser.Query().Where(scaauthuser.Phone(req.Phone), scaauthuser.Deleted(constant.NotDeleted)).First(l.ctx)
if err != nil && ent.IsNotFound(err) {
return response.ErrorWithI18n(l.ctx, "login.userNotRegistered", "用户未注册"), nil
return response.ErrorWithI18n(l.ctx, "login.userNotRegistered"), err
}
encrypt, err := utils.Encrypt(req.Password)
if err != nil {
return response.ErrorWithI18n(l.ctx, "login.resetPasswordError", "重置密码失败"), nil
return response.ErrorWithI18n(l.ctx, "login.resetPasswordError"), err
}
err = user.Update().SetPassword(encrypt).Exec(l.ctx)
if err != nil {
return response.ErrorWithI18n(l.ctx, "login.resetPasswordError", "重置密码失败"), err
return response.ErrorWithI18n(l.ctx, "login.resetPasswordError"), err
}
return response.Success(), nil
}