update qq/gitee/github oauth2 login

This commit is contained in:
landaiqing
2024-08-24 16:31:40 +08:00
parent 014abca8f8
commit 9330935822
45 changed files with 1243 additions and 642 deletions

View File

@@ -3,6 +3,8 @@ package api
import (
"schisandra-cloud-album/api/captcha_api"
"schisandra-cloud-album/api/oauth_api"
"schisandra-cloud-album/api/permission_api"
"schisandra-cloud-album/api/role_api"
"schisandra-cloud-album/api/sms_api"
"schisandra-cloud-album/api/user_api"
"schisandra-cloud-album/api/websocket_api"
@@ -10,11 +12,13 @@ import (
// Apis 统一导出的api
type Apis struct {
UserApi user_api.UserAPI
CaptchaApi captcha_api.CaptchaAPI
SmsApi sms_api.SmsAPI
OAuthApi oauth_api.OAuthAPI
WebsocketApi websocket_api.WebsocketAPI
UserApi user_api.UserAPI
CaptchaApi captcha_api.CaptchaAPI
SmsApi sms_api.SmsAPI
OAuthApi oauth_api.OAuthAPI
WebsocketApi websocket_api.WebsocketAPI
RoleApi role_api.RoleAPI
PermissionApi permission_api.PermissionAPI
}
// Api new函数实例化实例化完成后会返回结构体地指针类型

View File

@@ -142,100 +142,112 @@ func (OAuthAPI) GiteeCallback(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
// 通过 code, 获取 token
var tokenAuthUrl = GetGiteeTokenAuthUrl(code)
var token *Token
if token, err = GetGiteeToken(tokenAuthUrl); err != nil {
global.LOG.Error(err)
return
}
// 通过token获取用户信息
var userInfo map[string]interface{}
if userInfo, err = GetGiteeUserInfo(token); err != nil {
global.LOG.Error(err)
return
}
// 异步获取 token
var tokenChan = make(chan *Token)
var errChan = make(chan error)
go func() {
var tokenAuthUrl = GetGiteeTokenAuthUrl(code)
token, err := GetGiteeToken(tokenAuthUrl)
if err != nil {
errChan <- err
return
}
tokenChan <- token
}()
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
global.LOG.Error(err)
return
}
var giteeUser GiteeUser
err = json.Unmarshal(userInfoBytes, &giteeUser)
if err != nil {
global.LOG.Error(err)
return
}
// 异步获取用户信息
var userInfoChan = make(chan map[string]interface{})
go func() {
token := <-tokenChan
if token == nil {
errChan <- errors.New("failed to get token")
return
}
userInfo, err := GetGiteeUserInfo(token)
if err != nil {
errChan <- err
return
}
userInfoChan <- userInfo
}()
Id := strconv.Itoa(giteeUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGitee)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &giteeUser.Login,
Nickname: &giteeUser.Name,
Avatar: &giteeUser.AvatarURL,
Blog: &giteeUser.Blog,
Email: &giteeUser.Email,
}
addUser, err := userService.AddUser(user)
// 等待结果
select {
case err = <-errChan:
global.LOG.Error(err)
return
case userInfo := <-userInfoChan:
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
global.LOG.Error(err)
return
}
gitee := enum.OAuthSourceGitee
userSocial = model.ScaAuthUserSocial{
UserID: &addUser.ID,
UUID: &Id,
Source: &gitee,
}
err = userSocialService.AddUserSocial(userSocial)
var giteeUser GiteeUser
err = json.Unmarshal(userInfoBytes, &giteeUser)
if err != nil {
global.LOG.Error(err)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
Id := strconv.Itoa(giteeUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGitee)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
db := global.DB
tx := db.Begin() // 开始事务
if tx.Error != nil {
global.LOG.Error(tx.Error)
return
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &giteeUser.Login,
Nickname: &giteeUser.Name,
Avatar: &giteeUser.AvatarURL,
Blog: &giteeUser.Blog,
Email: &giteeUser.Email,
}
addUser, err := userService.AddUser(user)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
gitee := enum.OAuthSourceGitee
userSocial = model.ScaAuthUserSocial{
UserID: &uidStr,
UUID: &Id,
Source: &gitee,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
HandleLoginResponse(c, *addUser.UID)
} else {
HandleLoginResponse(c, *userSocial.UserID)
}
err = userRoleService.AddUserRole(userRole)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(addUser)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
} else {
user, err := userService.QueryUserById(userSocial.UserID)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(user)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
return
}
return
}

View File

@@ -148,99 +148,116 @@ func (OAuthAPI) Callback(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
// 通过 code, 获取 token
var tokenAuthUrl = GetTokenAuthUrl(code)
var token *Token
if token, err = GetToken(tokenAuthUrl); err != nil {
global.LOG.Error(err)
return
}
// 通过token获取用户信息
var userInfo map[string]interface{}
if userInfo, err = GetUserInfo(token); err != nil {
// 使用channel来接收异步操作的结果
tokenChan := make(chan *Token)
userInfoChan := make(chan map[string]interface{})
errChan := make(chan error)
// 异步获取token
go func() {
var tokenAuthUrl = GetTokenAuthUrl(code)
token, err := GetToken(tokenAuthUrl)
if err != nil {
errChan <- err
return
}
tokenChan <- token
}()
// 异步获取用户信息
go func() {
token := <-tokenChan
if token == nil {
return
}
userInfo, err := GetUserInfo(token)
if err != nil {
errChan <- err
return
}
userInfoChan <- userInfo
}()
select {
case err = <-errChan:
global.LOG.Error(err)
return
}
//json 转 struct
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
global.LOG.Error(err)
case userInfo := <-userInfoChan:
if userInfo == nil {
global.LOG.Error(<-errChan)
return
}
// 继续处理用户信息
userInfoBytes, err := json.Marshal(<-userInfoChan)
if err != nil {
global.LOG.Error(err)
return
}
var gitHubUser GitHubUser
err = json.Unmarshal(userInfoBytes, &gitHubUser)
if err != nil {
global.LOG.Error(err)
return
}
Id := strconv.Itoa(gitHubUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGithub)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
db := global.DB
tx := db.Begin() // 开始事务
if tx.Error != nil {
global.LOG.Error(tx.Error)
return
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &gitHubUser.Login,
Nickname: &gitHubUser.Name,
Avatar: &gitHubUser.AvatarURL,
Blog: &gitHubUser.Blog,
Email: &gitHubUser.Email,
}
addUser, err := userService.AddUser(user)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
github := enum.OAuthSourceGithub
userSocial = model.ScaAuthUserSocial{
UserID: &uidStr,
UUID: &Id,
Source: &github,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
HandleLoginResponse(c, *addUser.UID)
} else {
HandleLoginResponse(c, *userSocial.UserID)
}
return
}
var gitHubUser GitHubUser
err = json.Unmarshal(userInfoBytes, &gitHubUser)
if err != nil {
global.LOG.Error(err)
return
}
Id := strconv.Itoa(gitHubUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGithub)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &gitHubUser.Login,
Nickname: &gitHubUser.Name,
Avatar: &gitHubUser.AvatarURL,
Blog: &gitHubUser.Blog,
Email: &gitHubUser.Email,
}
addUser, err := userService.AddUser(user)
if err != nil {
global.LOG.Error(err)
return
}
github := enum.OAuthSourceGithub
userSocial = model.ScaAuthUserSocial{
UserID: &addUser.ID,
UUID: &Id,
Source: &github,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
global.LOG.Error(err)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
}
err = userRoleService.AddUserRole(userRole)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(addUser)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
} else {
user, err := userService.QueryUserById(userSocial.UserID)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(user)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
}
return
}

View File

@@ -2,10 +2,13 @@ package oauth_api
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"schisandra-cloud-album/api/user_api/dto"
"schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis"
"schisandra-cloud-album/model"
"schisandra-cloud-album/global"
"schisandra-cloud-album/service"
"schisandra-cloud-album/utils"
"time"
@@ -14,11 +17,7 @@ import (
type OAuthAPI struct{}
var userService = service.Service.UserService
var userRoleService = service.Service.UserRoleService
var userSocialService = service.Service.UserSocialService
var rolePermissionService = service.Service.RolePermissionService
var permissionServiceService = service.Service.PermissionService
var roleService = service.Service.RoleService
type Token struct {
AccessToken string `json:"access_token"`
@@ -31,50 +30,82 @@ var script = `
</script>
`
func HandleLoginResponse(c *gin.Context, uid string) {
res, data := HandelUserLogin(uid)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
return
}
// HandelUserLogin 处理用户登录
func HandelUserLogin(user model.ScaAuthUser) (bool, map[string]interface{}) {
ids, err := userRoleService.GetUserRoleIdsByUserId(user.ID)
if err != nil {
func HandelUserLogin(userId string) (bool, map[string]interface{}) {
// 使用goroutine生成accessToken
accessTokenChan := make(chan string)
errChan := make(chan error)
go func() {
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: &userId})
if err != nil {
errChan <- err
return
}
accessTokenChan <- accessToken
}()
// 使用goroutine生成refreshToken
refreshTokenChan := make(chan string)
expiresAtChan := make(chan int64)
go func() {
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
refreshTokenChan <- refreshToken
expiresAtChan <- expiresAt
}()
// 等待accessToken和refreshToken生成完成
var accessToken string
var refreshToken string
var expiresAt int64
var err error
select {
case accessToken = <-accessTokenChan:
case err = <-errChan:
global.LOG.Error(err)
return false, nil
}
permissionIds := rolePermissionService.QueryPermissionIdsByRoleId(ids)
permissions, err := permissionServiceService.GetPermissionsByIds(permissionIds)
if err != nil {
return false, nil
select {
case refreshToken = <-refreshTokenChan:
case expiresAt = <-expiresAtChan:
}
serializedPermissions, err := json.Marshal(permissions)
if err != nil {
return false, nil
}
wrong := redis.Set(constant.UserAuthPermissionRedisKey+*user.UID, serializedPermissions, 0).Err()
if wrong != nil {
return false, nil
}
roleList, err := roleService.GetRoleListByIds(ids)
if err != nil {
return false, nil
}
serializedRoleList, err := json.Marshal(roleList)
if err != nil {
return false, nil
}
er := redis.Set(constant.UserAuthRoleRedisKey+*user.UID, serializedRoleList, 0).Err()
if er != nil {
return false, nil
}
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID, RoleID: ids})
if err != nil {
return false, nil
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID, RoleID: ids}, time.Hour*24*7)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
UID: user.UID,
UID: &userId,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, time.Hour*24*7).Err()
if fail != nil {
// 使用goroutine将数据存入redis
redisErrChan := make(chan error)
go func() {
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err()
if fail != nil {
redisErrChan <- fail
return
}
redisErrChan <- nil
}()
// 等待redis操作完成
redisErr := <-redisErrChan
if redisErr != nil {
global.LOG.Error(redisErr)
return false, nil
}
responseData := map[string]interface{}{

View File

@@ -173,23 +173,61 @@ func (OAuthAPI) QQCallback(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
// 通过 code, 获取 token
var tokenAuthUrl = GetQQTokenAuthUrl(code)
tokenChan := make(chan *QQToken)
errChan := make(chan error)
go func() {
token, err := GetQQToken(tokenAuthUrl)
if err != nil {
errChan <- err
return
}
tokenChan <- token
}()
var token *QQToken
if token, err = GetQQToken(tokenAuthUrl); err != nil {
select {
case token = <-tokenChan:
case err = <-errChan:
global.LOG.Error(err)
return
}
// 通过 token获取 openid
authQQme, err := GetQQUserOpenID(token)
if err != nil {
openIDChan := make(chan *AuthQQme)
errChan = make(chan error)
go func() {
authQQme, err := GetQQUserOpenID(token)
if err != nil {
errChan <- err
return
}
openIDChan <- authQQme
}()
var authQQme *AuthQQme
select {
case authQQme = <-openIDChan:
case err = <-errChan:
global.LOG.Error(err)
return
}
// 通过token获取用户信息
userInfoChan := make(chan map[string]interface{})
errChan = make(chan error)
go func() {
userInfo, err := GetQQUserUserInfo(token, authQQme.OpenID)
if err != nil {
errChan <- err
return
}
userInfoChan <- userInfo
}()
var userInfo map[string]interface{}
if userInfo, err = GetQQUserUserInfo(token, authQQme.OpenID); err != nil {
select {
case userInfo = <-userInfoChan:
case err = <-errChan:
global.LOG.Error(err)
return
}
@@ -205,8 +243,20 @@ func (OAuthAPI) QQCallback(c *gin.Context) {
global.LOG.Error(err)
return
}
userSocial, err := userSocialService.QueryUserSocialByOpenID(authQQme.OpenID, enum.OAuthSourceQQ)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
db := global.DB
tx := db.Begin() // 开始事务
if tx.Error != nil {
global.LOG.Error(tx.Error)
return
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
@@ -221,59 +271,36 @@ func (OAuthAPI) QQCallback(c *gin.Context) {
}
addUser, err := userService.AddUser(user)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
qq := enum.OAuthSourceQQ
userSocial = model.ScaAuthUserSocial{
UserID: &addUser.ID,
UserID: &uidStr,
OpenID: &authQQme.OpenID,
Source: &qq,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
}
err = userRoleService.AddUserRole(userRole)
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(addUser)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
if err := tx.Commit().Error; err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
HandleLoginResponse(c, *addUser.UID)
return
} else {
user, err := userService.QueryUserById(userSocial.UserID)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(user)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
return
HandleLoginResponse(c, *userSocial.UserID)
}
}

View File

@@ -26,9 +26,12 @@ import (
"schisandra-cloud-album/utils"
"strconv"
"strings"
"sync"
"time"
)
var mu sync.Mutex
// GenerateClientId 生成客户端ID
// @Summary 生成客户端ID
// @Description 生成客户端ID
@@ -44,6 +47,9 @@ func (OAuthAPI) GenerateClientId(c *gin.Context) {
if ip == "" {
ip = c.ClientIP()
}
// 加锁
mu.Lock()
defer mu.Unlock()
// 从Redis获取客户端ID
clientId := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
@@ -70,10 +76,7 @@ func (OAuthAPI) GenerateClientId(c *gin.Context) {
// @Router /api/oauth/callback_notify [POST]
func (OAuthAPI) CallbackNotify(c *gin.Context) {
rs, err := global.Wechat.Server.Notify(c.Request, func(event contract.EventInterface) interface{} {
fmt.Dump("event", event)
switch event.GetMsgType() {
case models2.CALLBACK_MSG_TYPE_EVENT:
switch event.GetEvent() {
case models.CALLBACK_EVENT_SUBSCRIBE:
@@ -122,7 +125,6 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) {
println(err.Error())
return "error"
}
fmt.Dump(msg)
}
return messages.NewText("ok")
@@ -177,6 +179,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
data := response.ResponseQRCodeCreate{}
err := json.Unmarshal([]byte(qrcode), &data)
if err != nil {
global.LOG.Error(err)
return
}
result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c)
@@ -184,17 +187,20 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
}
data, err := global.Wechat.QRCode.Temporary(c.Request.Context(), clientId, 30*24*3600)
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
serializedData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
wrong := redis.Set(constant.UserLoginQrcodeRedisKey+ip+":"+clientId, serializedData, time.Hour*24*30).Err()
if wrong != nil {
global.LOG.Error(wrong)
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
@@ -208,45 +214,98 @@ func wechatLoginHandler(openId string, clientId string) bool {
}
authUserSocial, err := userSocialService.QueryUserSocialByOpenID(openId, enum.OAuthSourceWechat)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
tx := global.DB.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
createUser := model.ScaAuthUser{
UID: &uidStr,
Username: &openId,
}
addUser, err := userService.AddUser(createUser)
if err != nil {
// 异步添加用户
addUserChan := make(chan *model.ScaAuthUser, 1)
errChan := make(chan error, 1)
go func() {
addUser, err := userService.AddUser(createUser)
if err != nil {
errChan <- err
return
}
addUserChan <- &addUser
}()
var addUser *model.ScaAuthUser
select {
case addUser = <-addUserChan:
case err := <-errChan:
tx.Rollback()
global.LOG.Error(err)
return false
}
wechat := enum.OAuthSourceWechat
userSocial := model.ScaAuthUserSocial{
UserID: &addUser.ID,
UserID: &uidStr,
OpenID: &openId,
Source: &wechat,
}
wrong := userSocialService.AddUserSocial(userSocial)
if wrong != nil {
return false
// 异步添加用户社交信息
wrongChan := make(chan error, 1)
go func() {
wrong := userSocialService.AddUserSocial(userSocial)
wrongChan <- wrong
}()
select {
case wrong := <-wrongChan:
if wrong != nil {
tx.Rollback()
global.LOG.Error(wrong)
return false
}
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
// 异步添加角色
roleErrChan := make(chan error, 1)
go func() {
_, err := global.Casbin.AddRoleForUser(uidStr, enum.User)
roleErrChan <- err
}()
select {
case err := <-roleErrChan:
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return false
}
}
e := userRoleService.AddUserRole(userRole)
if e != nil {
return false
}
res := handelUserLogin(addUser, clientId)
if !res {
return false
// 异步处理用户登录
resChan := make(chan bool, 1)
go func() {
res := handelUserLogin(*addUser.UID, clientId)
resChan <- res
}()
select {
case res := <-resChan:
if !res {
tx.Rollback()
return false
}
}
tx.Commit()
return true
} else {
user, err := userService.QueryUserById(authUserSocial.UserID)
if err != nil {
return false
}
res := handelUserLogin(user, clientId)
res := handelUserLogin(*authUserSocial.UserID, clientId)
if !res {
return false
}
@@ -255,70 +314,47 @@ func wechatLoginHandler(openId string, clientId string) bool {
}
// handelUserLogin 处理用户登录
func handelUserLogin(user model.ScaAuthUser, clientId string) bool {
ids, err := userRoleService.GetUserRoleIdsByUserId(user.ID)
if err != nil {
return false
}
permissionIds := rolePermissionService.QueryPermissionIdsByRoleId(ids)
permissions, err := permissionServiceService.GetPermissionsByIds(permissionIds)
if err != nil {
return false
}
serializedPermissions, err := json.Marshal(permissions)
if err != nil {
return false
}
wrong := redis.Set(constant.UserAuthPermissionRedisKey+*user.UID, serializedPermissions, 0).Err()
if wrong != nil {
return false
}
roleList, err := roleService.GetRoleListByIds(ids)
if err != nil {
return false
}
serializedRoleList, err := json.Marshal(roleList)
if err != nil {
return false
}
er := redis.Set(constant.UserAuthRoleRedisKey+*user.UID, serializedRoleList, 0).Err()
if er != nil {
return false
}
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID, RoleID: ids})
if err != nil {
return false
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID, RoleID: ids}, time.Hour*24*7)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
UID: user.UID,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, time.Hour*24*7).Err()
if fail != nil {
return false
}
responseData := map[string]interface{}{
"code": 0,
"message": "success",
"data": data,
"success": true,
}
tokenData, err := json.Marshal(responseData)
if err != nil {
return false
}
// gws方式发送消息
err = websocket_api.Handler.SendMessageToClient(clientId, tokenData)
if err != nil {
return false
}
// gorilla websocket方式发送消息
//res := websocket_api.SendMessageData(clientId, responseData)
//if !res {
// return false
//}
return true
func handelUserLogin(userId string, clientId string) bool {
resultChan := make(chan bool, 1)
go func() {
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: &userId})
if err != nil {
resultChan <- false
return
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
UID: &userId,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err()
if fail != nil {
resultChan <- false
return
}
responseData := map[string]interface{}{
"code": 0,
"message": "success",
"data": data,
"success": true,
}
tokenData, err := json.Marshal(responseData)
if err != nil {
resultChan <- false
return
}
// gws方式发送消息
err = websocket_api.Handler.SendMessageToClient(clientId, tokenData)
if err != nil {
global.LOG.Error(err)
resultChan <- false
return
}
resultChan <- true
}()
return <-resultChan
}

View File

@@ -0,0 +1,11 @@
package dto
import "schisandra-cloud-album/model"
type AddPermissionRequestDto struct {
Permissions []model.ScaAuthPermission `json:"permissions"`
}
type AddPermissionToRoleRequestDto struct {
RoleKey string `json:"role_key"`
Permissions []model.ScaAuthPermission `json:"permissions"`
}

View File

@@ -0,0 +1,3 @@
package permission_api
type PermissionAPI struct{}

View File

@@ -0,0 +1,36 @@
package permission_api
import (
ginI18n "github.com/gin-contrib/i18n"
"github.com/gin-gonic/gin"
"schisandra-cloud-album/api/permission_api/dto"
"schisandra-cloud-album/common/result"
"schisandra-cloud-album/global"
"schisandra-cloud-album/service"
)
var permissionService = service.Service.PermissionService
// AddPermissions 批量添加权限
// @Summary 批量添加权限
// @Description 批量添加权限
// @Tags 权限管理
// @Accept json
// @Produce json
// @Param permissions body dto.AddPermissionRequestDto true "权限列表"
// @Router /api/auth/permission/add [post]
func (PermissionAPI) AddPermissions(c *gin.Context) {
addPermissionRequestDto := dto.AddPermissionRequestDto{}
err := c.ShouldBindJSON(&addPermissionRequestDto)
if err != nil {
return
}
err = permissionService.CreatePermissions(addPermissionRequestDto.Permissions)
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "CreatedFailed"), c)
return
}
result.OkWithMessage(ginI18n.MustGetMessage(c, "CreatedSuccess"), c)
return
}

View File

@@ -0,0 +1,11 @@
package dto
type RoleRequestDto struct {
RoleName string `json:"role_name"`
RoleKey string `json:"role_key"`
}
type AddRoleToUserRequestDto struct {
Uid string `json:"uid"`
RoleKey string `json:"role_key"`
}

View File

@@ -1 +1,64 @@
package role_api
import (
ginI18n "github.com/gin-contrib/i18n"
"github.com/gin-gonic/gin"
"schisandra-cloud-album/api/role_api/dto"
"schisandra-cloud-album/common/result"
"schisandra-cloud-album/global"
"schisandra-cloud-album/model"
"schisandra-cloud-album/service"
)
var roleService = service.Service.RoleService
// CreateRole 创建角色
// @Summary 创建角色
// @Description 创建角色
// @Tags 角色
// @Accept json
// @Produce json
// @Param roleRequestDto body dto.RoleRequestDto true "角色信息"
// @Router /api/auth/role/create [post]
func (RoleAPI) CreateRole(c *gin.Context) {
roleRequestDto := dto.RoleRequestDto{}
err := c.ShouldBindJSON(&roleRequestDto)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CreatedFailed"), c)
return
}
role := model.ScaAuthRole{
RoleName: roleRequestDto.RoleName,
RoleKey: roleRequestDto.RoleKey,
}
err = roleService.AddRole(role)
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "CreatedFailed"), c)
return
}
result.OkWithMessage(ginI18n.MustGetMessage(c, "CreatedSuccess"), c)
}
// AddRoleToUser 给指定用户添加角色
// @Summary 给指定用户添加角色
// @Description 给指定用户添加角色
// @Tags 角色
// @Accept json
// @Produce json
// @Param addRoleToUserRequestDto body dto.AddRoleToUserRequestDto true "给指定用户添加角色"
// @Router /api/auth/role/add_role_to_user [post]
func (RoleAPI) AddRoleToUser(c *gin.Context) {
addRoleToUserRequestDto := dto.AddRoleToUserRequestDto{}
err := c.ShouldBindJSON(&addRoleToUserRequestDto)
if err != nil {
global.LOG.Error(err)
return
}
user, err := global.Casbin.AddRoleForUser(addRoleToUserRequestDto.Uid, addRoleToUserRequestDto.RoleKey)
if err != nil {
global.LOG.Error(err)
return
}
result.OkWithData(user, c)
}

View File

@@ -1,7 +1,6 @@
package user_api
import (
"encoding/json"
ginI18n "github.com/gin-contrib/i18n"
"github.com/gin-gonic/gin"
"github.com/yitter/idgenerator-go/idgen"
@@ -20,10 +19,6 @@ import (
)
var userService = service.Service.UserService
var userRoleService = service.Service.UserRoleService
var rolePermissionService = service.Service.RolePermissionService
var permissionServiceService = service.Service.PermissionService
var roleService = service.Service.RoleService
// GetUserList
// @Summary 获取所有用户列表
@@ -146,12 +141,8 @@ func (UserAPI) AddUser(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "AddUserError"), c)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
}
e := userRoleService.AddUserRole(userRole)
if e != nil {
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "AddUserRoleError"), c)
return
}
@@ -258,11 +249,30 @@ func (UserAPI) PhoneLogin(c *gin.Context) {
return
}
user := userService.QueryUserByPhone(phone)
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
// 未注册
// 异步查询用户信息
userChan := make(chan *model.ScaAuthUser)
go func() {
user := userService.QueryUserByPhone(phone)
userChan <- &user
}()
// 异步获取验证码
codeChan := make(chan string)
go func() {
code := redis.Get(constant.UserLoginSmsRedisKey + phone)
if code == nil {
codeChan <- ""
} else {
codeChan <- code.Val()
}
}()
user := <-userChan
code := <-codeChan
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
// 未注册
if code == "" {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c)
return
} else {
@@ -277,35 +287,29 @@ func (UserAPI) PhoneLogin(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "RegisterUserError"), c)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
}
e := userRoleService.AddUserRole(userRole)
if e != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "RegisterUserError"), c)
return
}
handelUserLogin(addUser, request.AutoLogin, c)
return
}
} else {
code := redis.Get(constant.UserLoginSmsRedisKey + phone)
if code == nil {
if code == "" {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c)
return
} else {
if captcha != code.Val() {
if captcha != code {
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaError"), c)
return
} else {
handelUserLogin(user, request.AutoLogin, c)
handelUserLogin(*user, request.AutoLogin, c)
return
}
}
}
}
// RefreshHandler 刷新token
@@ -333,7 +337,7 @@ func (UserAPI) RefreshHandler(c *gin.Context) {
return
}
if isUpd {
accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID, RoleID: parseRefreshToken.RoleID})
accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID})
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
@@ -362,42 +366,7 @@ func (UserAPI) RefreshHandler(c *gin.Context) {
// handelUserLogin 处理用户登录
func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) {
ids, err := userRoleService.GetUserRoleIdsByUserId(user.ID)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
}
permissionIds := rolePermissionService.QueryPermissionIdsByRoleId(ids)
permissions, err := permissionServiceService.GetPermissionsByIds(permissionIds)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
}
serializedPermissions, err := json.Marshal(permissions)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
}
wrong := redis.Set(constant.UserAuthPermissionRedisKey+*user.UID, serializedPermissions, 0).Err()
if wrong != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
}
roleList, err := roleService.GetRoleListByIds(ids)
if err != nil {
return
}
serializedRoleList, err := json.Marshal(roleList)
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
}
er := redis.Set(constant.UserAuthRoleRedisKey+*user.UID, serializedRoleList, 0).Err()
if er != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
}
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID, RoleID: ids})
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID})
if err != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
return
@@ -408,7 +377,7 @@ func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) {
} else {
days = time.Hour * 24 * 1
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID, RoleID: ids}, days)
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID}, days)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,