add casbin permission verification middleware

This commit is contained in:
landaiqing
2024-08-24 18:04:13 +08:00
parent 9330935822
commit 974a96b6e0
4 changed files with 81 additions and 82 deletions

View File

@@ -38,7 +38,7 @@ func (UserAPI) GetUserList(c *gin.Context) {
// @Router /api/auth/user/query_by_username [get]
func (UserAPI) QueryUserByUsername(c *gin.Context) {
username := c.Query("username")
user, _ := userService.QueryUserByUsername(username)
user := userService.QueryUserByUsername(username)
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c)
return
@@ -112,7 +112,7 @@ func (UserAPI) AddUser(c *gin.Context) {
return
}
username, _ := userService.QueryUserByUsername(addUserRequest.Username)
username := userService.QueryUserByUsername(addUserRequest.Username)
if !reflect.DeepEqual(username, model.ScaAuthUser{}) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "UsernameExists"), c)
return
@@ -169,59 +169,29 @@ func (UserAPI) AccountLogin(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountAndPasswordNotEmpty"), c)
return
}
isPhone := utils.IsPhone(account)
if isPhone {
user := userService.QueryUserByPhone(account)
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneNotRegister"), c)
return
} else {
verify := utils.Verify(*user.Password, password)
if verify {
handelUserLogin(user, accountLoginRequest.AutoLogin, c)
return
} else {
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c)
return
}
}
var user model.ScaAuthUser
if utils.IsPhone(account) {
user = userService.QueryUserByPhone(account)
} else if utils.IsEmail(account) {
user = userService.QueryUserByEmail(account)
} else if utils.IsUsername(account) {
user = userService.QueryUserByUsername(account)
} else {
result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountErrorFormat"), c)
return
}
isEmail := utils.IsEmail(account)
if isEmail {
user := userService.QueryUserByEmail(account)
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "EmailNotRegister"), c)
return
} else {
verify := utils.Verify(*user.Password, password)
if verify {
handelUserLogin(user, accountLoginRequest.AutoLogin, c)
return
} else {
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c)
return
}
}
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c)
return
}
isUsername := utils.IsUsername(account)
if isUsername {
user, _ := userService.QueryUserByUsername(account)
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "UsernameNotRegister"), c)
return
} else {
verify := utils.Verify(*user.Password, password)
if verify {
handelUserLogin(user, accountLoginRequest.AutoLogin, c)
return
} else {
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c)
return
}
}
if !utils.Verify(*user.Password, password) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c)
return
}
result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountErrorFormat"), c)
return
handelUserLogin(user, accountLoginRequest.AutoLogin, c)
}
// PhoneLogin 手机号登录/注册
@@ -320,8 +290,7 @@ func (UserAPI) PhoneLogin(c *gin.Context) {
// @Router /api/token/refresh [post]
func (UserAPI) RefreshHandler(c *gin.Context) {
request := dto.RefreshTokenRequest{}
err := c.ShouldBindJSON(&request)
if err != nil {
if err := c.ShouldBindJSON(&request); err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
@@ -336,32 +305,32 @@ func (UserAPI) RefreshHandler(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
if isUpd {
accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID})
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
token, err := redis.Get(constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID).Result()
if token == "" || err != nil {
global.LOG.Errorln(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
data := dto.ResponseData{
AccessToken: accessTokenString,
RefreshToken: refreshToken,
UID: parseRefreshToken.UserID,
}
fail := redis.Set("user:login:token:"+*parseRefreshToken.UserID, data, time.Hour*24*7).Err()
if fail != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
result.OkWithData(data, c)
if !isUpd {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID})
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID
token, err := redis.Get(tokenKey).Result()
if token == "" || err != nil {
global.LOG.Errorln(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
data := dto.ResponseData{
AccessToken: accessTokenString,
RefreshToken: refreshToken,
UID: parseRefreshToken.UserID,
}
if err := redis.Set(tokenKey, data, time.Hour*24*7).Err(); err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
result.OkWithData(data, c)
}
// handelUserLogin 处理用户登录

30
middleware/casbin.go Normal file
View File

@@ -0,0 +1,30 @@
package middleware
import (
"github.com/gin-gonic/gin"
"schisandra-cloud-album/global"
)
func CasbinMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
userId, ok := c.Get("userId")
if !ok {
global.LOG.Error("casbin middleware: userId not found")
c.Abort()
return
}
method := c.Request.Method
path := c.Request.URL.Path
ok, err := global.Casbin.Enforce(userId.(string), path, method)
if err != nil {
global.LOG.Error("casbin middleware: ", err)
c.Abort()
return
}
if !ok {
c.Abort()
return
}
c.Next()
}
}

View File

@@ -17,7 +17,7 @@ func UserRouter(router *gin.RouterGroup) {
userGroup.POST("/add", userApi.AddUser)
userGroup.POST("/reset_password", userApi.ResetPassword)
}
authGroup := router.Group("auth").Use(middleware.JWTAuthMiddleware())
authGroup := router.Group("auth").Use(middleware.JWTAuthMiddleware()).Use(middleware.CasbinMiddleware())
{
authGroup.GET("/user/list", userApi.GetUserList)
authGroup.GET("/user/query_by_uuid", userApi.QueryUserByUuid)

View File

@@ -14,13 +14,13 @@ func (UserService) GetUserList() []*model.ScaAuthUser {
}
// QueryUserByUsername 根据用户名查询用户
func (UserService) QueryUserByUsername(username string) (model.ScaAuthUser, error) {
func (UserService) QueryUserByUsername(username string) model.ScaAuthUser {
authUser := model.ScaAuthUser{}
err := global.DB.Where("username = ? and deleted = 0", username).First(&authUser).Error
if err != nil {
return model.ScaAuthUser{}, err
return model.ScaAuthUser{}
}
return authUser, nil
return authUser
}
// QueryUserByUuid 根据用户uuid查询用户