update qq/gitee/github oauth2 login

This commit is contained in:
landaiqing
2024-08-24 16:31:40 +08:00
parent 014abca8f8
commit 9330935822
45 changed files with 1243 additions and 642 deletions

View File

@@ -142,100 +142,112 @@ func (OAuthAPI) GiteeCallback(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
// 通过 code, 获取 token
var tokenAuthUrl = GetGiteeTokenAuthUrl(code)
var token *Token
if token, err = GetGiteeToken(tokenAuthUrl); err != nil {
global.LOG.Error(err)
return
}
// 通过token获取用户信息
var userInfo map[string]interface{}
if userInfo, err = GetGiteeUserInfo(token); err != nil {
global.LOG.Error(err)
return
}
// 异步获取 token
var tokenChan = make(chan *Token)
var errChan = make(chan error)
go func() {
var tokenAuthUrl = GetGiteeTokenAuthUrl(code)
token, err := GetGiteeToken(tokenAuthUrl)
if err != nil {
errChan <- err
return
}
tokenChan <- token
}()
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
global.LOG.Error(err)
return
}
var giteeUser GiteeUser
err = json.Unmarshal(userInfoBytes, &giteeUser)
if err != nil {
global.LOG.Error(err)
return
}
// 异步获取用户信息
var userInfoChan = make(chan map[string]interface{})
go func() {
token := <-tokenChan
if token == nil {
errChan <- errors.New("failed to get token")
return
}
userInfo, err := GetGiteeUserInfo(token)
if err != nil {
errChan <- err
return
}
userInfoChan <- userInfo
}()
Id := strconv.Itoa(giteeUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGitee)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &giteeUser.Login,
Nickname: &giteeUser.Name,
Avatar: &giteeUser.AvatarURL,
Blog: &giteeUser.Blog,
Email: &giteeUser.Email,
}
addUser, err := userService.AddUser(user)
// 等待结果
select {
case err = <-errChan:
global.LOG.Error(err)
return
case userInfo := <-userInfoChan:
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
global.LOG.Error(err)
return
}
gitee := enum.OAuthSourceGitee
userSocial = model.ScaAuthUserSocial{
UserID: &addUser.ID,
UUID: &Id,
Source: &gitee,
}
err = userSocialService.AddUserSocial(userSocial)
var giteeUser GiteeUser
err = json.Unmarshal(userInfoBytes, &giteeUser)
if err != nil {
global.LOG.Error(err)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
Id := strconv.Itoa(giteeUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGitee)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
db := global.DB
tx := db.Begin() // 开始事务
if tx.Error != nil {
global.LOG.Error(tx.Error)
return
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &giteeUser.Login,
Nickname: &giteeUser.Name,
Avatar: &giteeUser.AvatarURL,
Blog: &giteeUser.Blog,
Email: &giteeUser.Email,
}
addUser, err := userService.AddUser(user)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
gitee := enum.OAuthSourceGitee
userSocial = model.ScaAuthUserSocial{
UserID: &uidStr,
UUID: &Id,
Source: &gitee,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
HandleLoginResponse(c, *addUser.UID)
} else {
HandleLoginResponse(c, *userSocial.UserID)
}
err = userRoleService.AddUserRole(userRole)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(addUser)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
} else {
user, err := userService.QueryUserById(userSocial.UserID)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(user)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
return
}
return
}

View File

@@ -148,99 +148,116 @@ func (OAuthAPI) Callback(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
// 通过 code, 获取 token
var tokenAuthUrl = GetTokenAuthUrl(code)
var token *Token
if token, err = GetToken(tokenAuthUrl); err != nil {
global.LOG.Error(err)
return
}
// 通过token获取用户信息
var userInfo map[string]interface{}
if userInfo, err = GetUserInfo(token); err != nil {
// 使用channel来接收异步操作的结果
tokenChan := make(chan *Token)
userInfoChan := make(chan map[string]interface{})
errChan := make(chan error)
// 异步获取token
go func() {
var tokenAuthUrl = GetTokenAuthUrl(code)
token, err := GetToken(tokenAuthUrl)
if err != nil {
errChan <- err
return
}
tokenChan <- token
}()
// 异步获取用户信息
go func() {
token := <-tokenChan
if token == nil {
return
}
userInfo, err := GetUserInfo(token)
if err != nil {
errChan <- err
return
}
userInfoChan <- userInfo
}()
select {
case err = <-errChan:
global.LOG.Error(err)
return
}
//json 转 struct
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
global.LOG.Error(err)
case userInfo := <-userInfoChan:
if userInfo == nil {
global.LOG.Error(<-errChan)
return
}
// 继续处理用户信息
userInfoBytes, err := json.Marshal(<-userInfoChan)
if err != nil {
global.LOG.Error(err)
return
}
var gitHubUser GitHubUser
err = json.Unmarshal(userInfoBytes, &gitHubUser)
if err != nil {
global.LOG.Error(err)
return
}
Id := strconv.Itoa(gitHubUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGithub)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
db := global.DB
tx := db.Begin() // 开始事务
if tx.Error != nil {
global.LOG.Error(tx.Error)
return
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &gitHubUser.Login,
Nickname: &gitHubUser.Name,
Avatar: &gitHubUser.AvatarURL,
Blog: &gitHubUser.Blog,
Email: &gitHubUser.Email,
}
addUser, err := userService.AddUser(user)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
github := enum.OAuthSourceGithub
userSocial = model.ScaAuthUserSocial{
UserID: &uidStr,
UUID: &Id,
Source: &github,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
HandleLoginResponse(c, *addUser.UID)
} else {
HandleLoginResponse(c, *userSocial.UserID)
}
return
}
var gitHubUser GitHubUser
err = json.Unmarshal(userInfoBytes, &gitHubUser)
if err != nil {
global.LOG.Error(err)
return
}
Id := strconv.Itoa(gitHubUser.ID)
userSocial, err := userSocialService.QueryUserSocialByUUID(Id, enum.OAuthSourceGithub)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
user := model.ScaAuthUser{
UID: &uidStr,
Username: &gitHubUser.Login,
Nickname: &gitHubUser.Name,
Avatar: &gitHubUser.AvatarURL,
Blog: &gitHubUser.Blog,
Email: &gitHubUser.Email,
}
addUser, err := userService.AddUser(user)
if err != nil {
global.LOG.Error(err)
return
}
github := enum.OAuthSourceGithub
userSocial = model.ScaAuthUserSocial{
UserID: &addUser.ID,
UUID: &Id,
Source: &github,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
global.LOG.Error(err)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
}
err = userRoleService.AddUserRole(userRole)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(addUser)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
} else {
user, err := userService.QueryUserById(userSocial.UserID)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(user)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
}
return
}

View File

@@ -2,10 +2,13 @@ package oauth_api
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"schisandra-cloud-album/api/user_api/dto"
"schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis"
"schisandra-cloud-album/model"
"schisandra-cloud-album/global"
"schisandra-cloud-album/service"
"schisandra-cloud-album/utils"
"time"
@@ -14,11 +17,7 @@ import (
type OAuthAPI struct{}
var userService = service.Service.UserService
var userRoleService = service.Service.UserRoleService
var userSocialService = service.Service.UserSocialService
var rolePermissionService = service.Service.RolePermissionService
var permissionServiceService = service.Service.PermissionService
var roleService = service.Service.RoleService
type Token struct {
AccessToken string `json:"access_token"`
@@ -31,50 +30,82 @@ var script = `
</script>
`
func HandleLoginResponse(c *gin.Context, uid string) {
res, data := HandelUserLogin(uid)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
return
}
// HandelUserLogin 处理用户登录
func HandelUserLogin(user model.ScaAuthUser) (bool, map[string]interface{}) {
ids, err := userRoleService.GetUserRoleIdsByUserId(user.ID)
if err != nil {
func HandelUserLogin(userId string) (bool, map[string]interface{}) {
// 使用goroutine生成accessToken
accessTokenChan := make(chan string)
errChan := make(chan error)
go func() {
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: &userId})
if err != nil {
errChan <- err
return
}
accessTokenChan <- accessToken
}()
// 使用goroutine生成refreshToken
refreshTokenChan := make(chan string)
expiresAtChan := make(chan int64)
go func() {
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
refreshTokenChan <- refreshToken
expiresAtChan <- expiresAt
}()
// 等待accessToken和refreshToken生成完成
var accessToken string
var refreshToken string
var expiresAt int64
var err error
select {
case accessToken = <-accessTokenChan:
case err = <-errChan:
global.LOG.Error(err)
return false, nil
}
permissionIds := rolePermissionService.QueryPermissionIdsByRoleId(ids)
permissions, err := permissionServiceService.GetPermissionsByIds(permissionIds)
if err != nil {
return false, nil
select {
case refreshToken = <-refreshTokenChan:
case expiresAt = <-expiresAtChan:
}
serializedPermissions, err := json.Marshal(permissions)
if err != nil {
return false, nil
}
wrong := redis.Set(constant.UserAuthPermissionRedisKey+*user.UID, serializedPermissions, 0).Err()
if wrong != nil {
return false, nil
}
roleList, err := roleService.GetRoleListByIds(ids)
if err != nil {
return false, nil
}
serializedRoleList, err := json.Marshal(roleList)
if err != nil {
return false, nil
}
er := redis.Set(constant.UserAuthRoleRedisKey+*user.UID, serializedRoleList, 0).Err()
if er != nil {
return false, nil
}
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID, RoleID: ids})
if err != nil {
return false, nil
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID, RoleID: ids}, time.Hour*24*7)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
UID: user.UID,
UID: &userId,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, time.Hour*24*7).Err()
if fail != nil {
// 使用goroutine将数据存入redis
redisErrChan := make(chan error)
go func() {
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err()
if fail != nil {
redisErrChan <- fail
return
}
redisErrChan <- nil
}()
// 等待redis操作完成
redisErr := <-redisErrChan
if redisErr != nil {
global.LOG.Error(redisErr)
return false, nil
}
responseData := map[string]interface{}{

View File

@@ -173,23 +173,61 @@ func (OAuthAPI) QQCallback(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
// 通过 code, 获取 token
var tokenAuthUrl = GetQQTokenAuthUrl(code)
tokenChan := make(chan *QQToken)
errChan := make(chan error)
go func() {
token, err := GetQQToken(tokenAuthUrl)
if err != nil {
errChan <- err
return
}
tokenChan <- token
}()
var token *QQToken
if token, err = GetQQToken(tokenAuthUrl); err != nil {
select {
case token = <-tokenChan:
case err = <-errChan:
global.LOG.Error(err)
return
}
// 通过 token获取 openid
authQQme, err := GetQQUserOpenID(token)
if err != nil {
openIDChan := make(chan *AuthQQme)
errChan = make(chan error)
go func() {
authQQme, err := GetQQUserOpenID(token)
if err != nil {
errChan <- err
return
}
openIDChan <- authQQme
}()
var authQQme *AuthQQme
select {
case authQQme = <-openIDChan:
case err = <-errChan:
global.LOG.Error(err)
return
}
// 通过token获取用户信息
userInfoChan := make(chan map[string]interface{})
errChan = make(chan error)
go func() {
userInfo, err := GetQQUserUserInfo(token, authQQme.OpenID)
if err != nil {
errChan <- err
return
}
userInfoChan <- userInfo
}()
var userInfo map[string]interface{}
if userInfo, err = GetQQUserUserInfo(token, authQQme.OpenID); err != nil {
select {
case userInfo = <-userInfoChan:
case err = <-errChan:
global.LOG.Error(err)
return
}
@@ -205,8 +243,20 @@ func (OAuthAPI) QQCallback(c *gin.Context) {
global.LOG.Error(err)
return
}
userSocial, err := userSocialService.QueryUserSocialByOpenID(authQQme.OpenID, enum.OAuthSourceQQ)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
db := global.DB
tx := db.Begin() // 开始事务
if tx.Error != nil {
global.LOG.Error(tx.Error)
return
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 第一次登录,创建用户
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
@@ -221,59 +271,36 @@ func (OAuthAPI) QQCallback(c *gin.Context) {
}
addUser, err := userService.AddUser(user)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
qq := enum.OAuthSourceQQ
userSocial = model.ScaAuthUserSocial{
UserID: &addUser.ID,
UserID: &uidStr,
OpenID: &authQQme.OpenID,
Source: &qq,
}
err = userSocialService.AddUserSocial(userSocial)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
}
err = userRoleService.AddUserRole(userRole)
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(addUser)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
if err := tx.Commit().Error; err != nil {
tx.Rollback()
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
HandleLoginResponse(c, *addUser.UID)
return
} else {
user, err := userService.QueryUserById(userSocial.UserID)
if err != nil {
global.LOG.Error(err)
return
}
res, data := HandelUserLogin(user)
if !res {
return
}
tokenData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
return
}
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.Web)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
return
HandleLoginResponse(c, *userSocial.UserID)
}
}

View File

@@ -26,9 +26,12 @@ import (
"schisandra-cloud-album/utils"
"strconv"
"strings"
"sync"
"time"
)
var mu sync.Mutex
// GenerateClientId 生成客户端ID
// @Summary 生成客户端ID
// @Description 生成客户端ID
@@ -44,6 +47,9 @@ func (OAuthAPI) GenerateClientId(c *gin.Context) {
if ip == "" {
ip = c.ClientIP()
}
// 加锁
mu.Lock()
defer mu.Unlock()
// 从Redis获取客户端ID
clientId := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
@@ -70,10 +76,7 @@ func (OAuthAPI) GenerateClientId(c *gin.Context) {
// @Router /api/oauth/callback_notify [POST]
func (OAuthAPI) CallbackNotify(c *gin.Context) {
rs, err := global.Wechat.Server.Notify(c.Request, func(event contract.EventInterface) interface{} {
fmt.Dump("event", event)
switch event.GetMsgType() {
case models2.CALLBACK_MSG_TYPE_EVENT:
switch event.GetEvent() {
case models.CALLBACK_EVENT_SUBSCRIBE:
@@ -122,7 +125,6 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) {
println(err.Error())
return "error"
}
fmt.Dump(msg)
}
return messages.NewText("ok")
@@ -177,6 +179,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
data := response.ResponseQRCodeCreate{}
err := json.Unmarshal([]byte(qrcode), &data)
if err != nil {
global.LOG.Error(err)
return
}
result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c)
@@ -184,17 +187,20 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
}
data, err := global.Wechat.QRCode.Temporary(c.Request.Context(), clientId, 30*24*3600)
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
serializedData, err := json.Marshal(data)
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
wrong := redis.Set(constant.UserLoginQrcodeRedisKey+ip+":"+clientId, serializedData, time.Hour*24*30).Err()
if wrong != nil {
global.LOG.Error(wrong)
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
@@ -208,45 +214,98 @@ func wechatLoginHandler(openId string, clientId string) bool {
}
authUserSocial, err := userSocialService.QueryUserSocialByOpenID(openId, enum.OAuthSourceWechat)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
tx := global.DB.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
uid := idgen.NextId()
uidStr := strconv.FormatInt(uid, 10)
createUser := model.ScaAuthUser{
UID: &uidStr,
Username: &openId,
}
addUser, err := userService.AddUser(createUser)
if err != nil {
// 异步添加用户
addUserChan := make(chan *model.ScaAuthUser, 1)
errChan := make(chan error, 1)
go func() {
addUser, err := userService.AddUser(createUser)
if err != nil {
errChan <- err
return
}
addUserChan <- &addUser
}()
var addUser *model.ScaAuthUser
select {
case addUser = <-addUserChan:
case err := <-errChan:
tx.Rollback()
global.LOG.Error(err)
return false
}
wechat := enum.OAuthSourceWechat
userSocial := model.ScaAuthUserSocial{
UserID: &addUser.ID,
UserID: &uidStr,
OpenID: &openId,
Source: &wechat,
}
wrong := userSocialService.AddUserSocial(userSocial)
if wrong != nil {
return false
// 异步添加用户社交信息
wrongChan := make(chan error, 1)
go func() {
wrong := userSocialService.AddUserSocial(userSocial)
wrongChan <- wrong
}()
select {
case wrong := <-wrongChan:
if wrong != nil {
tx.Rollback()
global.LOG.Error(wrong)
return false
}
}
userRole := model.ScaAuthUserRole{
UserID: uidStr,
RoleID: enum.User,
// 异步添加角色
roleErrChan := make(chan error, 1)
go func() {
_, err := global.Casbin.AddRoleForUser(uidStr, enum.User)
roleErrChan <- err
}()
select {
case err := <-roleErrChan:
if err != nil {
tx.Rollback()
global.LOG.Error(err)
return false
}
}
e := userRoleService.AddUserRole(userRole)
if e != nil {
return false
}
res := handelUserLogin(addUser, clientId)
if !res {
return false
// 异步处理用户登录
resChan := make(chan bool, 1)
go func() {
res := handelUserLogin(*addUser.UID, clientId)
resChan <- res
}()
select {
case res := <-resChan:
if !res {
tx.Rollback()
return false
}
}
tx.Commit()
return true
} else {
user, err := userService.QueryUserById(authUserSocial.UserID)
if err != nil {
return false
}
res := handelUserLogin(user, clientId)
res := handelUserLogin(*authUserSocial.UserID, clientId)
if !res {
return false
}
@@ -255,70 +314,47 @@ func wechatLoginHandler(openId string, clientId string) bool {
}
// handelUserLogin 处理用户登录
func handelUserLogin(user model.ScaAuthUser, clientId string) bool {
ids, err := userRoleService.GetUserRoleIdsByUserId(user.ID)
if err != nil {
return false
}
permissionIds := rolePermissionService.QueryPermissionIdsByRoleId(ids)
permissions, err := permissionServiceService.GetPermissionsByIds(permissionIds)
if err != nil {
return false
}
serializedPermissions, err := json.Marshal(permissions)
if err != nil {
return false
}
wrong := redis.Set(constant.UserAuthPermissionRedisKey+*user.UID, serializedPermissions, 0).Err()
if wrong != nil {
return false
}
roleList, err := roleService.GetRoleListByIds(ids)
if err != nil {
return false
}
serializedRoleList, err := json.Marshal(roleList)
if err != nil {
return false
}
er := redis.Set(constant.UserAuthRoleRedisKey+*user.UID, serializedRoleList, 0).Err()
if er != nil {
return false
}
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID, RoleID: ids})
if err != nil {
return false
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID, RoleID: ids}, time.Hour*24*7)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
UID: user.UID,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, time.Hour*24*7).Err()
if fail != nil {
return false
}
responseData := map[string]interface{}{
"code": 0,
"message": "success",
"data": data,
"success": true,
}
tokenData, err := json.Marshal(responseData)
if err != nil {
return false
}
// gws方式发送消息
err = websocket_api.Handler.SendMessageToClient(clientId, tokenData)
if err != nil {
return false
}
// gorilla websocket方式发送消息
//res := websocket_api.SendMessageData(clientId, responseData)
//if !res {
// return false
//}
return true
func handelUserLogin(userId string, clientId string) bool {
resultChan := make(chan bool, 1)
go func() {
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: &userId})
if err != nil {
resultChan <- false
return
}
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
data := dto.ResponseData{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
UID: &userId,
}
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err()
if fail != nil {
resultChan <- false
return
}
responseData := map[string]interface{}{
"code": 0,
"message": "success",
"data": data,
"success": true,
}
tokenData, err := json.Marshal(responseData)
if err != nil {
resultChan <- false
return
}
// gws方式发送消息
err = websocket_api.Handler.SendMessageToClient(clientId, tokenData)
if err != nil {
global.LOG.Error(err)
resultChan <- false
return
}
resultChan <- true
}()
return <-resultChan
}