diff --git a/api/user_api/user_api.go b/api/user_api/user_api.go index 2b7de9b..8302617 100644 --- a/api/user_api/user_api.go +++ b/api/user_api/user_api.go @@ -38,7 +38,7 @@ func (UserAPI) GetUserList(c *gin.Context) { // @Router /api/auth/user/query_by_username [get] func (UserAPI) QueryUserByUsername(c *gin.Context) { username := c.Query("username") - user, _ := userService.QueryUserByUsername(username) + user := userService.QueryUserByUsername(username) if reflect.DeepEqual(user, model.ScaAuthUser{}) { result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c) return @@ -112,7 +112,7 @@ func (UserAPI) AddUser(c *gin.Context) { return } - username, _ := userService.QueryUserByUsername(addUserRequest.Username) + username := userService.QueryUserByUsername(addUserRequest.Username) if !reflect.DeepEqual(username, model.ScaAuthUser{}) { result.FailWithMessage(ginI18n.MustGetMessage(c, "UsernameExists"), c) return @@ -169,59 +169,29 @@ func (UserAPI) AccountLogin(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountAndPasswordNotEmpty"), c) return } - isPhone := utils.IsPhone(account) - if isPhone { - user := userService.QueryUserByPhone(account) - if reflect.DeepEqual(user, model.ScaAuthUser{}) { - result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneNotRegister"), c) - return - } else { - verify := utils.Verify(*user.Password, password) - if verify { - handelUserLogin(user, accountLoginRequest.AutoLogin, c) - return - } else { - result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c) - return - } - } + + var user model.ScaAuthUser + if utils.IsPhone(account) { + user = userService.QueryUserByPhone(account) + } else if utils.IsEmail(account) { + user = userService.QueryUserByEmail(account) + } else if utils.IsUsername(account) { + user = userService.QueryUserByUsername(account) + } else { + result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountErrorFormat"), c) + return } - isEmail := utils.IsEmail(account) - if isEmail { - user := userService.QueryUserByEmail(account) - if reflect.DeepEqual(user, model.ScaAuthUser{}) { - result.FailWithMessage(ginI18n.MustGetMessage(c, "EmailNotRegister"), c) - return - } else { - verify := utils.Verify(*user.Password, password) - if verify { - handelUserLogin(user, accountLoginRequest.AutoLogin, c) - return - } else { - result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c) - return - } - } + + if reflect.DeepEqual(user, model.ScaAuthUser{}) { + result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c) + return } - isUsername := utils.IsUsername(account) - if isUsername { - user, _ := userService.QueryUserByUsername(account) - if reflect.DeepEqual(user, model.ScaAuthUser{}) { - result.FailWithMessage(ginI18n.MustGetMessage(c, "UsernameNotRegister"), c) - return - } else { - verify := utils.Verify(*user.Password, password) - if verify { - handelUserLogin(user, accountLoginRequest.AutoLogin, c) - return - } else { - result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c) - return - } - } + + if !utils.Verify(*user.Password, password) { + result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c) + return } - result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountErrorFormat"), c) - return + handelUserLogin(user, accountLoginRequest.AutoLogin, c) } // PhoneLogin 手机号登录/注册 @@ -320,8 +290,7 @@ func (UserAPI) PhoneLogin(c *gin.Context) { // @Router /api/token/refresh [post] func (UserAPI) RefreshHandler(c *gin.Context) { request := dto.RefreshTokenRequest{} - err := c.ShouldBindJSON(&request) - if err != nil { + if err := c.ShouldBindJSON(&request); err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } @@ -336,32 +305,32 @@ func (UserAPI) RefreshHandler(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) return } - if isUpd { - accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID}) - if err != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) - return - } - - token, err := redis.Get(constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID).Result() - if token == "" || err != nil { - global.LOG.Errorln(err) - result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) - return - } - data := dto.ResponseData{ - AccessToken: accessTokenString, - RefreshToken: refreshToken, - UID: parseRefreshToken.UserID, - } - fail := redis.Set("user:login:token:"+*parseRefreshToken.UserID, data, time.Hour*24*7).Err() - if fail != nil { - result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) - return - } - result.OkWithData(data, c) + if !isUpd { + result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) return } + accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID}) + if err != nil { + result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) + return + } + tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID + token, err := redis.Get(tokenKey).Result() + if token == "" || err != nil { + global.LOG.Errorln(err) + result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) + return + } + data := dto.ResponseData{ + AccessToken: accessTokenString, + RefreshToken: refreshToken, + UID: parseRefreshToken.UserID, + } + if err := redis.Set(tokenKey, data, time.Hour*24*7).Err(); err != nil { + result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c) + return + } + result.OkWithData(data, c) } // handelUserLogin 处理用户登录 diff --git a/middleware/casbin.go b/middleware/casbin.go new file mode 100644 index 0000000..85d7883 --- /dev/null +++ b/middleware/casbin.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "schisandra-cloud-album/global" +) + +func CasbinMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + userId, ok := c.Get("userId") + if !ok { + global.LOG.Error("casbin middleware: userId not found") + c.Abort() + return + } + method := c.Request.Method + path := c.Request.URL.Path + ok, err := global.Casbin.Enforce(userId.(string), path, method) + if err != nil { + global.LOG.Error("casbin middleware: ", err) + c.Abort() + return + } + if !ok { + c.Abort() + return + } + c.Next() + } +} diff --git a/router/modules/user_router.go b/router/modules/user_router.go index 91d312e..09ad6bb 100644 --- a/router/modules/user_router.go +++ b/router/modules/user_router.go @@ -17,7 +17,7 @@ func UserRouter(router *gin.RouterGroup) { userGroup.POST("/add", userApi.AddUser) userGroup.POST("/reset_password", userApi.ResetPassword) } - authGroup := router.Group("auth").Use(middleware.JWTAuthMiddleware()) + authGroup := router.Group("auth").Use(middleware.JWTAuthMiddleware()).Use(middleware.CasbinMiddleware()) { authGroup.GET("/user/list", userApi.GetUserList) authGroup.GET("/user/query_by_uuid", userApi.QueryUserByUuid) diff --git a/service/user_service/user_service.go b/service/user_service/user_service.go index eab15c4..1cdcfe0 100644 --- a/service/user_service/user_service.go +++ b/service/user_service/user_service.go @@ -14,13 +14,13 @@ func (UserService) GetUserList() []*model.ScaAuthUser { } // QueryUserByUsername 根据用户名查询用户 -func (UserService) QueryUserByUsername(username string) (model.ScaAuthUser, error) { +func (UserService) QueryUserByUsername(username string) model.ScaAuthUser { authUser := model.ScaAuthUser{} err := global.DB.Where("username = ? and deleted = 0", username).First(&authUser).Error if err != nil { - return model.ScaAuthUser{}, err + return model.ScaAuthUser{} } - return authUser, nil + return authUser } // QueryUserByUuid 根据用户uuid查询用户