From 6e47e7251434a931c12d707c0bbffd2bda9d4c53 Mon Sep 17 00:00:00 2001 From: landaiqing <3517283258@qq.com> Date: Sun, 25 Aug 2024 00:06:00 +0800 Subject: [PATCH] :zap: login performance optimization --- api/captcha_api/captcha_api.go | 39 ++++--- api/oauth_api/wechat_api.go | 39 +++---- api/user_api/user_api.go | 191 +++++++++++++++------------------ i18n/language/en.toml | 1 + i18n/language/zh.toml | 1 + router/modules/user_router.go | 1 - utils/get_ip.go | 15 +++ utils/match.go | 7 ++ 8 files changed, 149 insertions(+), 145 deletions(-) create mode 100644 utils/get_ip.go diff --git a/api/captcha_api/captcha_api.go b/api/captcha_api/captcha_api.go index adba39f..7c02e52 100644 --- a/api/captcha_api/captcha_api.go +++ b/api/captcha_api/captcha_api.go @@ -30,32 +30,37 @@ func (CaptchaAPI) GenerateRotateCaptcha(c *gin.Context) { captchaData, err := global.RotateCaptcha.Generate() if err != nil { global.LOG.Fatalln(err) + result.FailWithNull(c) + return } blockData := captchaData.GetData() if blockData == nil { result.FailWithNull(c) return } - var masterImageBase64, thumbImageBase64 string - masterImageBase64 = captchaData.GetMasterImage().ToBase64() - thumbImageBase64 = captchaData.GetThumbImage().ToBase64() + + masterImageBase64 := captchaData.GetMasterImage().ToBase64() + thumbImageBase64 := captchaData.GetThumbImage().ToBase64() dotsByte, err := json.Marshal(blockData) if err != nil { + global.LOG.Fatalln(err) result.FailWithNull(c) return } + key := helper.StringToMD5(string(dotsByte)) err = redis.Set(constant.UserLoginCaptchaRedisKey+key, dotsByte, time.Minute).Err() if err != nil { + global.LOG.Fatalln(err) result.FailWithNull(c) return } - bt := map[string]interface{}{ + + result.OkWithData(map[string]interface{}{ "key": key, "image": masterImageBase64, "thumb": thumbImageBase64, - } - result.OkWithData(bt, c) + }, c) } // CheckRotateData 验证旋转验证码 @@ -67,30 +72,36 @@ func (CaptchaAPI) GenerateRotateCaptcha(c *gin.Context) { // @Success 200 {string} json // @Router /api/captcha/rotate/check [post] func (CaptchaAPI) CheckRotateData(c *gin.Context) { - rotateRequest := dto.RotateCaptchaRequest{} - err := c.ShouldBindJSON(&rotateRequest) - angle := rotateRequest.Angle - key := rotateRequest.Key - if err != nil { + var rotateRequest dto.RotateCaptchaRequest + if err := c.ShouldBindJSON(&rotateRequest); err != nil { result.FailWithNull(c) return } - cacheDataByte, err := redis.Get(constant.UserLoginCaptchaRedisKey + key).Bytes() - if len(cacheDataByte) == 0 || err != nil { + + cacheDataByte, err := redis.Get(constant.UserLoginCaptchaRedisKey + rotateRequest.Key).Bytes() + if err != nil || len(cacheDataByte) == 0 { result.FailWithCodeAndMessage(1011, ginI18n.MustGetMessage(c, "CaptchaExpired"), c) return } + var dct *rotate.Block if err := json.Unmarshal(cacheDataByte, &dct); err != nil { result.FailWithNull(c) return } - sAngle, _ := strconv.ParseFloat(fmt.Sprintf("%v", angle), 64) + + sAngle, err := strconv.ParseFloat(fmt.Sprintf("%v", rotateRequest.Angle), 64) + if err != nil { + result.FailWithNull(c) + return + } + chkRet := rotate.CheckAngle(int64(sAngle), int64(dct.Angle), 2) if chkRet { result.OkWithMessage("success", c) return } + result.FailWithMessage("fail", c) } diff --git a/api/oauth_api/wechat_api.go b/api/oauth_api/wechat_api.go index 0f0dae9..ae9303f 100644 --- a/api/oauth_api/wechat_api.go +++ b/api/oauth_api/wechat_api.go @@ -40,13 +40,7 @@ var mu sync.Mutex // @Router /api/oauth/generate_client_id [get] func (OAuthAPI) GenerateClientId(c *gin.Context) { // 获取客户端IP - ip := c.GetHeader("X-Real-IP") - if ip == "" { - ip = c.GetHeader("X-Forwarded-For") - } - if ip == "" { - ip = c.ClientIP() - } + ip := utils.GetClientIP(c) // 加锁 mu.Lock() defer mu.Unlock() @@ -161,49 +155,48 @@ func (OAuthAPI) CallbackVerify(c *gin.Context) { // @Router /api/oauth/get_temp_qrcode [get] func (OAuthAPI) GetTempQrCode(c *gin.Context) { clientId := c.Query("client_id") - // 获取客户端IP - ip := c.GetHeader("X-Real-IP") - if ip == "" { - ip = c.GetHeader("X-Forwarded-For") - } - if ip == "" { - ip = c.ClientIP() - } if clientId == "" { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - qrcode := redis.Get(constant.UserLoginQrcodeRedisKey + ip + ":" + clientId).Val() + ip := utils.GetClientIP(c) // 使用工具函数获取客户端IP + key := constant.UserLoginQrcodeRedisKey + ip + ":" + clientId + + // 从Redis获取二维码数据 + qrcode := redis.Get(key).Val() if qrcode != "" { - data := response.ResponseQRCodeCreate{} - err := json.Unmarshal([]byte(qrcode), &data) - if err != nil { + data := new(response.ResponseQRCodeCreate) + if err := json.Unmarshal([]byte(qrcode), data); err != nil { global.LOG.Error(err) + result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c) return } result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c) return } + + // 生成临时二维码 data, err := global.Wechat.QRCode.Temporary(c.Request.Context(), clientId, 30*24*3600) if err != nil { global.LOG.Error(err) result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c) return } + + // 序列化数据并存储到Redis serializedData, err := json.Marshal(data) if err != nil { global.LOG.Error(err) result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c) return } - wrong := redis.Set(constant.UserLoginQrcodeRedisKey+ip+":"+clientId, serializedData, time.Hour*24*30).Err() - - if wrong != nil { - global.LOG.Error(wrong) + if err := redis.Set(key, serializedData, time.Hour*24*30).Err(); err != nil { + global.LOG.Error(err) result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c) return } + result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c) } diff --git a/api/user_api/user_api.go b/api/user_api/user_api.go index 8302617..8f3317c 100644 --- a/api/user_api/user_api.go +++ b/api/user_api/user_api.go @@ -98,58 +98,6 @@ func (UserAPI) QueryUserByPhone(c *gin.Context) { result.OkWithData(user, c) } -// AddUser 添加用户 -// @Summary 添加用户 -// @Tags 用户模块 -// @Param user body dto.AddUserRequest true "用户信息" -// @Success 200 {string} json -// @Router /api/user/add [post] -func (UserAPI) AddUser(c *gin.Context) { - addUserRequest := dto.AddUserRequest{} - err := c.ShouldBindJSON(&addUserRequest) - if err != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) - return - } - - username := userService.QueryUserByUsername(addUserRequest.Username) - if !reflect.DeepEqual(username, model.ScaAuthUser{}) { - result.FailWithMessage(ginI18n.MustGetMessage(c, "UsernameExists"), c) - return - } - - phone := userService.QueryUserByPhone(addUserRequest.Phone) - if !reflect.DeepEqual(phone, model.ScaAuthUser{}) { - result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneExists"), c) - return - } - encrypt, err := utils.Encrypt(addUserRequest.Password) - if err != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "AddUserError"), c) - return - } - uid := idgen.NextId() - uidStr := strconv.FormatInt(uid, 10) - user := model.ScaAuthUser{ - UID: &uidStr, - Username: &addUserRequest.Username, - Password: &encrypt, - Phone: &addUserRequest.Phone, - } - _, err = userService.AddUser(user) - if err != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "AddUserError"), c) - return - } - _, err = global.Casbin.AddRoleForUser(uidStr, enum.User) - if err != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "AddUserRoleError"), c) - return - } - result.OkWithMessage(ginI18n.MustGetMessage(c, "AddUserSuccess"), c) - return -} - // AccountLogin 账号登录 // @Summary 账号登录 // @Tags 用户模块 @@ -219,30 +167,11 @@ func (UserAPI) PhoneLogin(c *gin.Context) { return } - // 异步查询用户信息 - userChan := make(chan *model.ScaAuthUser) - go func() { - user := userService.QueryUserByPhone(phone) - userChan <- &user - }() - - // 异步获取验证码 - codeChan := make(chan string) - go func() { - code := redis.Get(constant.UserLoginSmsRedisKey + phone) - if code == nil { - codeChan <- "" - } else { - codeChan <- code.Val() - } - }() - - user := <-userChan - code := <-codeChan - + user := userService.QueryUserByPhone(phone) if reflect.DeepEqual(user, model.ScaAuthUser{}) { // 未注册 - if code == "" { + code := redis.Get(constant.UserLoginSmsRedisKey + phone) + if code == nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c) return } else { @@ -263,23 +192,31 @@ func (UserAPI) PhoneLogin(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "RegisterUserError"), c) return } + err = global.Casbin.SavePolicy() + if err != nil { + result.FailWithMessage(ginI18n.MustGetMessage(c, "RegisterUserError"), c) + return + } handelUserLogin(addUser, request.AutoLogin, c) return } } else { - if code == "" { + code := redis.Get(constant.UserLoginSmsRedisKey + phone) + if code == nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c) return } else { - if captcha != code { + if captcha != code.Val() { result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaError"), c) return } else { - handelUserLogin(*user, request.AutoLogin, c) + handelUserLogin(user, request.AutoLogin, c) return } } + } + } // RefreshHandler 刷新token @@ -300,15 +237,11 @@ func (UserAPI) RefreshHandler(c *gin.Context) { return } parseRefreshToken, isUpd, err := utils.ParseRefreshToken(refreshToken) - if err != nil { + if err != nil || !isUpd { global.LOG.Errorln(err) result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) return } - if !isUpd { - result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) - return - } accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID}) if err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) @@ -316,7 +249,7 @@ func (UserAPI) RefreshHandler(c *gin.Context) { } tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID token, err := redis.Get(tokenKey).Result() - if token == "" || err != nil { + if err != nil || token == "" { global.LOG.Errorln(err) result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) return @@ -335,17 +268,24 @@ func (UserAPI) RefreshHandler(c *gin.Context) { // handelUserLogin 处理用户登录 func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) { + // 检查 user.UID 是否为 nil + if user.UID == nil { + result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) + return + } accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID}) if err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c) return } + var days time.Duration if autoLogin { - days = time.Hour * 24 * 7 + days = 7 * 24 * time.Hour } else { - days = time.Hour * 24 * 1 + days = 24 * time.Hour } + refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID}, days) data := dto.ResponseData{ AccessToken: accessToken, @@ -353,13 +293,14 @@ func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) { ExpiresAt: expiresAt, UID: user.UID, } - fail := redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, time.Hour*24*1).Err() - if fail != nil { + + err = redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, days).Err() + if err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c) return } + result.OkWithData(data, c) - return } // ResetPassword 重置密码 @@ -369,50 +310,86 @@ func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) { // @Success 200 {string} json // @Router /api/user/reset_password [post] func (UserAPI) ResetPassword(c *gin.Context) { - resetPasswordRequest := dto.ResetPasswordRequest{} - err := c.ShouldBindJSON(&resetPasswordRequest) - if err != nil { + var resetPasswordRequest dto.ResetPasswordRequest + if err := c.ShouldBindJSON(&resetPasswordRequest); err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } + phone := resetPasswordRequest.Phone captcha := resetPasswordRequest.Captcha password := resetPasswordRequest.Password repassword := resetPasswordRequest.Repassword + if phone == "" || captcha == "" || password == "" || repassword == "" { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - isPhone := utils.IsPhone(phone) - if !isPhone { - result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneErrorFormat"), c) + + if !utils.IsPhone(phone) { + result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneError"), c) return } - code := redis.Get(constant.UserLoginSmsRedisKey + phone) - if code == nil { + + if password != repassword { + result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordNotSame"), c) + return + } + + if !utils.IsPassword(password) { + result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c) + return + } + + // 使用事务确保验证码检查和密码更新的原子性 + tx := global.DB.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + if err := tx.Error; err != nil { + result.FailWithMessage(ginI18n.MustGetMessage(c, "DatabaseError"), c) + return + } + + code := redis.Get(constant.UserLoginSmsRedisKey + phone).Val() + if code == "" { result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c) return - } else { - if captcha != code.Val() { - result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaError"), c) - return - } } + + if captcha != code { + result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaError"), c) + return + } + + // 验证码检查通过后立即删除或标记为已使用 + if err := redis.Del(constant.UserLoginSmsRedisKey + phone).Err(); err != nil { + tx.Rollback() + result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError"), c) + return + } + user := userService.QueryUserByPhone(phone) if reflect.DeepEqual(user, model.ScaAuthUser{}) { result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneNotRegister"), c) return } + encrypt, err := utils.Encrypt(password) if err != nil { + result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError")+": "+err.Error(), c) + return + } + + if err := userService.UpdateUser(phone, encrypt); err != nil { + tx.Rollback() result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError"), c) return } - wrong := userService.UpdateUser(phone, encrypt) - if wrong != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError"), c) - return - } + + tx.Commit() result.OkWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordSuccess"), c) - return } diff --git a/i18n/language/en.toml b/i18n/language/en.toml index 960e799..1e97279 100644 --- a/i18n/language/en.toml +++ b/i18n/language/en.toml @@ -52,3 +52,4 @@ ResetPasswordSuccess = "reset password success!" QRCodeGetFailed = "qr code get failed!" QRCodeGetSuccess = "qr code get successfully!" QRCodeExpired = "qr code expired!" +InternalError = "internal error!" diff --git a/i18n/language/zh.toml b/i18n/language/zh.toml index 7a4e828..bdc8b80 100644 --- a/i18n/language/zh.toml +++ b/i18n/language/zh.toml @@ -52,4 +52,5 @@ ResetPasswordSuccess = "重置密码成功!" QRCodeGetFailed = "获取二维码失败!" QRCodeGetSuccess = "获取二维码成功!" QRCodeExpired = "二维码已过期!" +InternalError = "内部错误!" diff --git a/router/modules/user_router.go b/router/modules/user_router.go index 09ad6bb..44c51bf 100644 --- a/router/modules/user_router.go +++ b/router/modules/user_router.go @@ -14,7 +14,6 @@ func UserRouter(router *gin.RouterGroup) { { userGroup.POST("/login", userApi.AccountLogin) userGroup.POST("/phone_login", userApi.PhoneLogin) - userGroup.POST("/add", userApi.AddUser) userGroup.POST("/reset_password", userApi.ResetPassword) } authGroup := router.Group("auth").Use(middleware.JWTAuthMiddleware()).Use(middleware.CasbinMiddleware()) diff --git a/utils/get_ip.go b/utils/get_ip.go new file mode 100644 index 0000000..685a91c --- /dev/null +++ b/utils/get_ip.go @@ -0,0 +1,15 @@ +package utils + +import "github.com/gin-gonic/gin" + +// GetClientIP 工具函数,获取客户端IP +func GetClientIP(c *gin.Context) string { + ip := c.GetHeader("X-Real-IP") + if ip == "" { + ip = c.GetHeader("X-Forwarded-For") + } + if ip == "" { + ip = c.ClientIP() + } + return ip +} diff --git a/utils/match.go b/utils/match.go index c954b30..a299b1e 100644 --- a/utils/match.go +++ b/utils/match.go @@ -30,3 +30,10 @@ func IsUsername(username string) bool { match, _ := regexp.MatchString(phoneRegex, username) return match } + +// IsPassword 密码的正则表达式 +func IsPassword(password string) bool { + phoneRegex := `^(?=.*[A-Za-z])(?=.*\d)(?=.*[@$!%*#?&])[A-Za-z\d@$!%*#?&]{6,18}$` + match, _ := regexp.MatchString(phoneRegex, password) + return match +}