🐛 fixed the issue that third-party login sessions were missing

This commit is contained in:
2024-12-20 01:19:29 +08:00
parent 49831fc4d0
commit 40d073db0f
27 changed files with 556 additions and 308 deletions

View File

@@ -0,0 +1,6 @@
package constant
const (
JWT_TYPE_ACCESS string = "access"
JWT_TYPE_REFRESH string = "refresh"
)

View File

@@ -0,0 +1,19 @@
package errors
import "fmt"
// CodeMsg is a struct that contains a code and a message.
// It implements the error interface.
type CodeMsg struct {
Code int
Msg string
}
func (c *CodeMsg) Error() string {
return fmt.Sprintf("code: %d, msg: %s", c.Code, c.Msg)
}
// New creates a new CodeMsg.
func New(code int, msg string) error {
return &CodeMsg{Code: code, Msg: msg}
}

View File

@@ -0,0 +1,85 @@
package http
import (
"context"
"encoding/xml"
"net/http"
"schisandra-album-cloud-microservices/app/core/api/common/errors"
"github.com/zeromicro/go-zero/rest/httpx"
"google.golang.org/grpc/status"
)
// BaseResponse is the base response struct.
type BaseResponse[T any] struct {
// Code represents the business code, not the http status code.
Code int `json:"code" xml:"code"`
// Msg represents the business message, if Code = BusinessCodeOK,
// and Msg is empty, then the Msg will be set to BusinessMsgOk.
Msg string `json:"msg" xml:"msg"`
// Data represents the business data.
Data T `json:"data,omitempty" xml:"data,omitempty"`
}
type baseXmlResponse[T any] struct {
XMLName xml.Name `xml:"xml"`
Version string `xml:"version,attr"`
Encoding string `xml:"encoding,attr"`
BaseResponse[T]
}
// JsonBaseResponse writes v into w with http.StatusOK.
func JsonBaseResponse(w http.ResponseWriter, v any) {
httpx.OkJson(w, wrapBaseResponse(v))
}
// JsonBaseResponseCtx writes v into w with http.StatusOK.
func JsonBaseResponseCtx(ctx context.Context, w http.ResponseWriter, v any) {
httpx.OkJsonCtx(ctx, w, wrapBaseResponse(v))
}
// XmlBaseResponse writes v into w with http.StatusOK.
func XmlBaseResponse(w http.ResponseWriter, v any) {
OkXml(w, wrapXmlBaseResponse(v))
}
// XmlBaseResponseCtx writes v into w with http.StatusOK.
func XmlBaseResponseCtx(ctx context.Context, w http.ResponseWriter, v any) {
OkXmlCtx(ctx, w, wrapXmlBaseResponse(v))
}
func wrapXmlBaseResponse(v any) baseXmlResponse[any] {
base := wrapBaseResponse(v)
return baseXmlResponse[any]{
Version: xmlVersion,
Encoding: xmlEncoding,
BaseResponse: base,
}
}
func wrapBaseResponse(v any) BaseResponse[any] {
var resp BaseResponse[any]
switch data := v.(type) {
case *errors.CodeMsg:
resp.Code = data.Code
resp.Msg = data.Msg
case errors.CodeMsg:
resp.Code = data.Code
resp.Msg = data.Msg
case *status.Status:
resp.Code = int(data.Code())
resp.Msg = data.Message()
case interface{ GRPCStatus() *status.Status }:
resp.Code = int(data.GRPCStatus().Code())
resp.Msg = data.GRPCStatus().Message()
case error:
resp.Code = BusinessCodeError
resp.Msg = data.Error()
default:
resp.Code = BusinessCodeOK
resp.Msg = BusinessMsgOk
resp.Data = v
}
return resp
}

View File

@@ -0,0 +1,100 @@
package http
import (
"context"
"encoding/xml"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
"net/http"
)
// OkXml writes v into w with 200 OK.
func OkXml(w http.ResponseWriter, v any) {
WriteXml(w, http.StatusOK, v)
}
// OkXmlCtx writes v into w with 200 OK.
func OkXmlCtx(ctx context.Context, w http.ResponseWriter, v any) {
WriteXmlCtx(ctx, w, http.StatusOK, v)
}
// WriteXml writes v as xml string into w with code.
func WriteXml(w http.ResponseWriter, code int, v any) {
if err := doWriteXml(w, code, v); err != nil {
logx.Error(err)
}
}
// WriteXmlCtx writes v as xml string into w with code.
func WriteXmlCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
if err := doWriteXml(w, code, v); err != nil {
logx.WithContext(ctx).Error(err)
}
}
func doWriteXml(w http.ResponseWriter, code int, v any) error {
bs, err := xml.Marshal(v)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return fmt.Errorf("marshal xml failed, error: %w", err)
}
w.Header().Set(httpx.ContentType, XmlContentType)
w.WriteHeader(code)
if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here.
if err != http.ErrHandlerTimeout {
return fmt.Errorf("write response failed, error: %w", err)
}
} else if n < len(bs) {
return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
return nil
}
// OkHTML writes v into w with 200 OK.
func OkHTML(w http.ResponseWriter, v string) {
WriteHTML(w, http.StatusOK, v)
}
// OkHTMLCtx writes v into w with 200 OK.
func OkHTMLCtx(ctx context.Context, w http.ResponseWriter, v string) {
WriteHTMLCtx(ctx, w, http.StatusOK, v)
}
// WriteHTML writes v as HTML string into w with code.
func WriteHTML(w http.ResponseWriter, code int, v string) {
if err := doWriteHTML(w, code, v); err != nil {
logx.Error(err)
}
}
// WriteHTMLCtx writes v as HTML string into w with code.
func WriteHTMLCtx(ctx context.Context, w http.ResponseWriter, code int, v string) {
if err := doWriteHTML(w, code, v); err != nil {
logx.WithContext(ctx).Error(err)
}
}
func doWriteHTML(w http.ResponseWriter, code int, v string) error {
w.Header().Set(httpx.ContentType, HTMLContentType)
w.WriteHeader(code)
bs := []byte(v)
if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here.
if err != http.ErrHandlerTimeout {
return fmt.Errorf("write response failed, error: %w", err)
}
} else if n < len(bs) {
return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
return nil
}

View File

@@ -0,0 +1,19 @@
package http
const (
xmlVersion = "1.0"
xmlEncoding = "UTF-8"
// BusinessCodeOK represents the business code for success.
BusinessCodeOK = 0
// BusinessMsgOk represents the business message for success.
BusinessMsgOk = "ok"
// BusinessCodeError represents the business code for error.
BusinessCodeError = -1
// XmlContentType represents the content type for xml.
XmlContentType = "application/xml"
// HTMLContentType represents the content type for html.
HTMLContentType = "text/html;charset=utf-8"
)

View File

@@ -8,7 +8,7 @@ import (
type AccessJWTPayload struct {
UserID string `json:"user_id"`
Type string `json:"type" default:"access"`
Type string `json:"type"`
}
type AccessJWTClaims struct {
AccessJWTPayload
@@ -19,7 +19,7 @@ func GenerateAccessToken(secret string, payload AccessJWTPayload) string {
claims := AccessJWTClaims{
AccessJWTPayload: payload,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 15)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 30)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},

View File

@@ -8,7 +8,7 @@ import (
type RefreshJWTPayload struct {
UserID string `json:"user_id"`
Type string `json:"type" default:"refresh"`
Type string `json:"type"`
}
type RefreshJWTClaims struct {
RefreshJWTPayload

View File

@@ -31,6 +31,9 @@ type (
Password string `json:"password"`
Repassword string `json:"repassword"`
}
UserDeviceRequest {
AccessToken string `json:"access_token"`
}
// 登录响应参数
LoginResponse {
AccessToken string `json:"access_token"`
@@ -168,7 +171,7 @@ service core {
post /reset/password (ResetPasswordRequest) returns (Response)
@handler getUserDevice
get /device
post /device (UserDeviceRequest) returns (Response)
}
@server (
@@ -238,13 +241,13 @@ service core {
get /qq/url (OAuthRequest) returns (Response)
@handler giteeCallback
get /gitee/callback (OAuthCallbackRequest)
get /gitee/callback (OAuthCallbackRequest) returns (string)
@handler githubCallback
get /github/callback (OAuthCallbackRequest)
get /github/callback (OAuthCallbackRequest) returns (string)
@handler qqCallback
get /qq/callback (OAuthCallbackRequest)
get /qq/callback (OAuthCallbackRequest) returns (string)
@handler wechatCallback
get /wechat/callback

View File

@@ -29,7 +29,7 @@ Log:
# 日期格式化
TimeFormat:
# 日志在文件输出模式下,日志输出路径
Path: logs
Path: logs/system
# 日志输出级别 debug,info,error,severe
Level: debug
# 日志长度限制,打印单个日志的时候会对日志进行裁剪,只有对 content 进行裁剪
@@ -37,7 +37,7 @@ Log:
# 是否压缩日志
Compress: true
# 是否开启 stat 日志go-zero 版本大于等于1.5.0才支持
Stat: true
Stat: false
# 日志保留天数,只有在文件模式才会生效
KeepDays: 7
# 堆栈打印冷却时间
@@ -51,7 +51,7 @@ Log:
# 文件名日期格式
FileTimeFormat:
Web:
URL: https://www.landaiqing.cn
URL: http://localhost:5173/
# 启用中间件
Middlewares:
# 访问日志中间件

View File

@@ -6,6 +6,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
http2 "schisandra-album-cloud-microservices/app/core/api/common/http"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/oauth"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
@@ -21,7 +22,7 @@ func GiteeCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
}
l := oauth.NewGiteeCallbackLogic(r.Context(), svcCtx)
err := l.GiteeCallback(w, r, &req)
data, err := l.GiteeCallback(w, r, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -30,7 +31,7 @@ func GiteeCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
http2.OkHTML(w, data)
}
}
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
http2 "schisandra-album-cloud-microservices/app/core/api/common/http"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/oauth"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
@@ -21,7 +21,7 @@ func GithubCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
}
l := oauth.NewGithubCallbackLogic(r.Context(), svcCtx)
err := l.GithubCallback(w, r, &req)
data, err := l.GithubCallback(w, r, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -30,7 +30,7 @@ func GithubCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
http2.OkHTML(w, data)
}
}
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
http2 "schisandra-album-cloud-microservices/app/core/api/common/http"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/oauth"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
@@ -21,7 +21,7 @@ func QqCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
}
l := oauth.NewQqCallbackLogic(r.Context(), svcCtx)
err := l.QqCallback(w, r, &req)
data, err := l.QqCallback(w, r, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -30,7 +30,7 @@ func QqCallbackHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
http2.OkHTML(w, data)
}
}
}

View File

@@ -221,7 +221,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
[]rest.Middleware{serverCtx.SecurityHeadersMiddleware},
[]rest.Route{
{
Method: http.MethodGet,
Method: http.MethodPost,
Path: "/device",
Handler: user.GetUserDeviceHandler(serverCtx),
},

View File

@@ -1,20 +1,26 @@
package user
import (
"github.com/zeromicro/go-zero/core/logx"
"net/http"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpx"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/internal/logic/user"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
"schisandra-album-cloud-microservices/app/core/api/internal/types"
)
func GetUserDeviceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.UserDeviceRequest
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
}
l := user.NewGetUserDeviceLogic(r.Context(), svcCtx)
err := l.GetUserDevice(r)
resp, err := l.GetUserDevice(r, w, &req)
if err != nil {
logx.Error(err)
httpx.WriteJsonCtx(
@@ -23,7 +29,7 @@ func GetUserDeviceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
http.StatusInternalServerError,
response.ErrorWithI18n(r.Context(), "system.error"))
} else {
httpx.Ok(w)
httpx.OkJsonCtx(r.Context(), w, resp)
}
}
}

View File

@@ -77,30 +77,33 @@ func NewGiteeCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Git
}
}
func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) error {
func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
// 获取 token
tokenAuthUrl := l.GetGiteeTokenAuthUrl(req.Code)
token, err := l.GetGiteeToken(tokenAuthUrl)
if err != nil {
return err
return "", err
}
if token == nil {
return nil
return "", errors.New("get gitee token failed")
}
// 获取用户信息
userInfo, err := l.GetGiteeUserInfo(token)
if err != nil {
return err
return "", err
}
if userInfo == nil {
return "", errors.New("get gitee user info failed")
}
var giteeUser GiteeUser
marshal, err := json.Marshal(userInfo)
if err != nil {
return err
return "", err
}
if err = json.Unmarshal(marshal, &giteeUser); err != nil {
return err
return "", err
}
Id := strconv.Itoa(giteeUser.ID)
@@ -109,7 +112,7 @@ func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Reques
userSocial := l.svcCtx.DB.ScaAuthUserSocial
socialUser, err := tx.ScaAuthUserSocial.Where(userSocial.OpenID.Eq(Id), userSocial.Source.Eq(constant.OAuthSourceGitee)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
return "", err
}
if socialUser == nil {
@@ -128,7 +131,7 @@ func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Reques
err = tx.ScaAuthUser.Create(addUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
gitee := constant.OAuthSourceGitee
newSocialUser := &model.ScaAuthUserSocial{
@@ -139,56 +142,56 @@ func (l *GiteeCallbackLogic) GiteeCallback(w http.ResponseWriter, r *http.Reques
err = tx.ScaAuthUserSocial.Create(newSocialUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if res, err := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User); !res || err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx); err != nil {
return err
data, err := HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
} else {
authUser := l.svcCtx.DB.ScaAuthUser
authUserInfo, err := tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
// HandleOauthLoginResponse 处理登录响应
func HandleOauthLoginResponse(scaAuthUser *model.ScaAuthUser, svcCtx *svc.ServiceContext, r *http.Request, w http.ResponseWriter, ctx context.Context) error {
func HandleOauthLoginResponse(scaAuthUser *model.ScaAuthUser, svcCtx *svc.ServiceContext, r *http.Request, w http.ResponseWriter, ctx context.Context) (string, error) {
data, err := user.HandleUserLogin(scaAuthUser, svcCtx, true, r, w, ctx)
if err != nil {
return err
return "", err
}
responseData := response.SuccessWithData(data)
formattedScript := fmt.Sprintf(Script, responseData, svcCtx.Config.Web.URL)
// 设置响应状态码和内容类型
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
// 写入响应内容
if _, writeErr := w.Write([]byte(formattedScript)); writeErr != nil {
return writeErr
marshalData, err := json.Marshal(responseData)
if err != nil {
return "", err
}
return nil
formattedScript := fmt.Sprintf(Script, marshalData, svcCtx.Config.Web.URL)
return formattedScript, nil
}
// GetGiteeTokenAuthUrl 获取Gitee token

View File

@@ -68,39 +68,39 @@ func NewGithubCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Gi
}
}
func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) error {
func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
// 获取 token
tokenAuthUrl := l.GetTokenAuthUrl(req.Code)
token, err := l.GetToken(tokenAuthUrl)
if err != nil {
return err
return "", err
}
if token == nil {
return nil
return "", errors.New("get github token failed")
}
// 获取用户信息
userInfo, err := l.GetUserInfo(token)
if err != nil {
return err
return "", err
}
if userInfo == nil {
return nil
return "", errors.New("get github user info failed")
}
// 处理用户信息
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
return err
return "", err
}
var gitHubUser GitHubUser
err = json.Unmarshal(userInfoBytes, &gitHubUser)
if err != nil {
return err
return "", err
}
Id := strconv.Itoa(gitHubUser.ID)
tx := l.svcCtx.DB.Begin()
@@ -108,7 +108,7 @@ func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Requ
userSocial := l.svcCtx.DB.ScaAuthUserSocial
socialUser, err := tx.ScaAuthUserSocial.Where(userSocial.OpenID.Eq(Id), userSocial.Source.Eq(constant.OAuthSourceGithub)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
return "", err
}
if socialUser == nil {
@@ -130,7 +130,7 @@ func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Requ
err = tx.ScaAuthUser.Create(addUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
githubUser := constant.OAuthSourceGithub
newSocialUser := &model.ScaAuthUserSocial{
@@ -141,37 +141,42 @@ func (l *GithubCallbackLogic) GithubCallback(w http.ResponseWriter, r *http.Requ
err = tx.ScaAuthUserSocial.Create(newSocialUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if res, err := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User); !res || err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
} else {
authUser := l.svcCtx.DB.ScaAuthUser
authUserInfo, err := tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
// GetTokenAuthUrl 通过code获取token认证url

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/yitter/idgenerator-go/idgen"
"gorm.io/gorm"
@@ -65,39 +66,42 @@ func NewQqCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *QqCall
}
}
func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) error {
func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req *types.OAuthCallbackRequest) (string, error) {
tokenAuthUrl := l.GetQQTokenAuthUrl(req.Code)
token, err := l.GetQQToken(tokenAuthUrl)
if err != nil {
return err
return "", err
}
if token == nil {
return nil
return "", errors.New("get qq token failed")
}
// 通过 token 获取 openid
authQQme, err := l.GetQQUserOpenID(token)
if err != nil {
return err
return "", err
}
// 通过 token 和 openid 获取用户信息
userInfo, err := l.GetQQUserUserInfo(token, authQQme.OpenID)
if err != nil {
return err
return "", err
}
if userInfo == nil {
return "", errors.New("get qq user info failed")
}
// 处理用户信息
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
return err
return "", err
}
var qqUserInfo QQUserInfo
err = json.Unmarshal(userInfoBytes, &qqUserInfo)
if err != nil {
return err
return "", err
}
tx := l.svcCtx.DB.Begin()
@@ -105,7 +109,7 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
userSocial := l.svcCtx.DB.ScaAuthUserSocial
socialUser, err := tx.ScaAuthUserSocial.Where(userSocial.OpenID.Eq(authQQme.OpenID), userSocial.Source.Eq(constant.OAuthSourceQQ)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
return "", err
}
if socialUser == nil {
@@ -114,9 +118,10 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
uidStr := strconv.FormatInt(uid, 10)
male := constant.Male
avatarUrl := strings.Replace(qqUserInfo.FigureurlQq1, "http://", "https://", 1)
addUser := &model.ScaAuthUser{
UID: uidStr,
Avatar: qqUserInfo.FigureurlQq1,
Avatar: avatarUrl,
Username: authQQme.OpenID,
Nickname: qqUserInfo.Nickname,
Gender: male,
@@ -124,7 +129,7 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
err = tx.ScaAuthUser.Create(addUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
githubUser := constant.OAuthSourceQQ
@@ -136,37 +141,42 @@ func (l *QqCallbackLogic) QqCallback(w http.ResponseWriter, r *http.Request, req
err = tx.ScaAuthUserSocial.Create(newSocialUser)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if res, err := l.svcCtx.CasbinEnforcer.AddRoleForUser(uidStr, constant.User); !res || err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(addUser, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
} else {
authUser := l.svcCtx.DB.ScaAuthUser
authUserInfo, err := tx.ScaAuthUser.Where(authUser.UID.Eq(socialUser.UserID)).First()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
_ = tx.Rollback()
return err
return "", err
}
if err = HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx); err != nil {
data, err := HandleOauthLoginResponse(authUserInfo, l.svcCtx, r, w, l.ctx)
if err != nil {
_ = tx.Rollback()
return err
return "", err
}
if err = tx.Commit(); err != nil {
return "", err
}
return data, nil
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
// GetQQTokenAuthUrl 通过code获取token认证url

View File

@@ -55,7 +55,7 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
key := strings.TrimPrefix(msg.EventKey, "qrscene_")
err = l.HandlerWechatLogin(msg.FromUserName, key, w, r)
@@ -66,10 +66,10 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
case models.CALLBACK_EVENT_UNSUBSCRIBE:
msg := models.EventUnSubscribe{}
err := event.ReadMessage(&msg)
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
return messages.NewText("ok")
@@ -78,7 +78,7 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
err = l.HandlerWechatLogin(msg.FromUserName, msg.EventKey, w, r)
if err != nil {
@@ -90,10 +90,10 @@ func (l *WechatCallbackLogic) WechatCallback(w http.ResponseWriter, r *http.Requ
case models2.CALLBACK_MSG_TYPE_TEXT:
msg := models.MessageText{}
err := event.ReadMessage(&msg)
err = event.ReadMessage(&msg)
if err != nil {
println(err.Error())
return "error"
return err
}
}
return messages.NewText("ok")
@@ -205,5 +205,4 @@ func (l *WechatCallbackLogic) HandlerWechatLogin(openId string, clientId string,
return err
}
return nil
}

View File

@@ -34,11 +34,7 @@ func (l *RefreshTokenLogic) RefreshToken(r *http.Request) (resp *types.Response,
if err != nil {
return nil, err
}
refreshSessionToken, ok := session.Values["refresh_token"].(string)
if !ok {
return response.ErrorWithCode(403), nil
}
userId, ok := session.Values["uid"].(string)
userId, ok := session.Values["user_id"].(string)
if !ok {
return response.ErrorWithCode(403), nil
}
@@ -51,23 +47,20 @@ func (l *RefreshTokenLogic) RefreshToken(r *http.Request) (resp *types.Response,
if err != nil {
return nil, err
}
if redisTokenData.RefreshToken != refreshSessionToken {
return response.ErrorWithCode(403), nil
}
refreshToken, result := jwt.ParseRefreshToken(l.svcCtx.Config.Auth.AccessSecret, refreshSessionToken)
refreshToken, result := jwt.ParseRefreshToken(l.svcCtx.Config.Auth.AccessSecret, redisTokenData.RefreshToken)
if !result {
return response.ErrorWithCode(403), nil
}
accessToken := jwt.GenerateAccessToken(l.svcCtx.Config.Auth.AccessSecret, jwt.AccessJWTPayload{
UserID: refreshToken.UserID,
Type: constant.JWT_TYPE_ACCESS,
})
if accessToken == "" {
return response.ErrorWithCode(403), nil
}
redisToken := types.RedisToken{
AccessToken: accessToken,
RefreshToken: refreshSessionToken,
RefreshToken: redisTokenData.RefreshToken,
UID: refreshToken.UserID,
}
err = l.svcCtx.RedisClient.Set(l.ctx, constant.UserTokenPrefix+refreshToken.UserID, redisToken, time.Hour*24*7).Err()

View File

@@ -3,6 +3,7 @@ package user
import (
"context"
"errors"
"github.com/rbcervilla/redisstore/v9"
"net/http"
"time"
@@ -80,6 +81,7 @@ func HandleUserLogin(user *model.ScaAuthUser, svcCtx *svc.ServiceContext, autoLo
// 生成jwt token
accessToken := jwt.GenerateAccessToken(svcCtx.Config.Auth.AccessSecret, jwt.AccessJWTPayload{
UserID: user.UID,
Type: constant.JWT_TYPE_ACCESS,
})
var days time.Duration
if autoLogin {
@@ -89,6 +91,7 @@ func HandleUserLogin(user *model.ScaAuthUser, svcCtx *svc.ServiceContext, autoLo
}
refreshToken := jwt.GenerateRefreshToken(svcCtx.Config.Auth.AccessSecret, jwt.RefreshJWTPayload{
UserID: user.UID,
Type: constant.JWT_TYPE_REFRESH,
}, days)
data := types.LoginResponse{
AccessToken: accessToken,
@@ -108,16 +111,24 @@ func HandleUserLogin(user *model.ScaAuthUser, svcCtx *svc.ServiceContext, autoLo
if err != nil {
return nil, err
}
session, err := svcCtx.Session.Get(r, constant.SESSION_KEY)
if err != nil {
return nil, err
}
session.Values["refresh_token"] = refreshToken
session.Values["uid"] = user.UID
err = session.Save(r, w)
err = HandlerSession(r, w, user.UID, svcCtx.Session)
if err != nil {
return nil, err
}
return &data, nil
}
// HandlerSession is a function to set the user_id in the session
func HandlerSession(r *http.Request, w http.ResponseWriter, userID string, redisSession *redisstore.RedisStore) error {
session, err := redisSession.Get(r, constant.SESSION_KEY)
if err != nil {
return err
}
session.Values["user_id"] = userID
err = session.Save(r, w)
if err != nil {
return err
}
return nil
}

View File

@@ -3,18 +3,20 @@ package user
import (
"context"
"errors"
"net/http"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"github.com/mssola/useragent"
"github.com/zeromicro/go-zero/core/logx"
"gorm.io/gorm"
"schisandra-album-cloud-microservices/app/core/api/common/constant"
"net/http"
"schisandra-album-cloud-microservices/app/core/api/common/jwt"
"schisandra-album-cloud-microservices/app/core/api/common/response"
"schisandra-album-cloud-microservices/app/core/api/common/utils"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
"schisandra-album-cloud-microservices/app/core/api/repository/mysql/model"
"schisandra-album-cloud-microservices/app/core/api/repository/mysql/query"
"schisandra-album-cloud-microservices/app/core/api/internal/svc"
"schisandra-album-cloud-microservices/app/core/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type GetUserDeviceLogic struct {
@@ -31,20 +33,20 @@ func NewGetUserDeviceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Get
}
}
func (l *GetUserDeviceLogic) GetUserDevice(r *http.Request) error {
session, err := l.svcCtx.Session.Get(r, constant.SESSION_KEY)
if err != nil {
return err
}
uid, ok := session.Values["uid"].(string)
func (l *GetUserDeviceLogic) GetUserDevice(r *http.Request, w http.ResponseWriter, req *types.UserDeviceRequest) (resp *types.Response, err error) {
token, ok := jwt.ParseAccessToken(l.svcCtx.Config.Auth.AccessSecret, req.AccessToken)
if !ok {
return errors.New("user session not found")
return response.Error(), nil
}
if err = GetUserLoginDevice(uid, r, l.svcCtx.Ip2Region, l.svcCtx.DB); err != nil {
return err
err = HandlerSession(r, w, token.UserID, l.svcCtx.Session)
if err != nil {
return nil, err
}
return nil
err = GetUserLoginDevice(token.UserID, r, l.svcCtx.Ip2Region, l.svcCtx.DB)
if err != nil {
return nil, err
}
return response.Success(), nil
}
// GetUserLoginDevice 获取用户登录设备

View File

@@ -2,11 +2,10 @@ package middleware
import (
"net/http"
"schisandra-album-cloud-microservices/app/core/api/common/constant"
"github.com/casbin/casbin/v2"
"github.com/rbcervilla/redisstore/v9"
"schisandra-album-cloud-microservices/app/core/api/common/constant"
)
type CasbinVerifyMiddleware struct {
@@ -28,7 +27,7 @@ func (m *CasbinVerifyMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
userId, ok := session.Values["uid"].(string)
userId, ok := session.Values["user_id"].(string)
if !ok {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return

View File

@@ -118,3 +118,7 @@ type UploadRequest struct {
AccessToken string `json:"access_token"`
UserId string `json:"user_id"`
}
type UserDeviceRequest struct {
AccessToken string `json:"access_token"`
}

View File

@@ -19,7 +19,7 @@ type ScaAuthUser struct {
Username string `gorm:"column:username;type:varchar(32);comment:用户名" json:"username"` // 用户名
Nickname string `gorm:"column:nickname;type:varchar(32);comment:昵称" json:"nickname"` // 昵称
Email string `gorm:"column:email;type:varchar(32);comment:邮箱" json:"email"` // 邮箱
Phone string `gorm:"column:phone;type:varchar(32);uniqueIndex:phone,priority:1;comment:电话" json:"phone"` // 电话
Phone string `gorm:"column:phone;type:varchar(32);comment:电话" json:"phone"` // 电话
Password string `gorm:"column:password;type:varchar(64);comment:密码" json:"password"` // 密码
Gender int64 `gorm:"column:gender;type:tinyint;comment:性别" json:"gender"` // 性别
Avatar string `gorm:"column:avatar;type:longtext;comment:头像" json:"avatar"` // 头像

View File

@@ -16,11 +16,12 @@ func NewWechatPublic(appId, appSecret, token, aesKey, addr, pass string, db int)
AESKey: aesKey,
Log: officialAccount.Log{
Level: "error",
Stdout: true,
File: "/logs/wechat/wechat_official.log",
Stdout: false,
},
ResponseType: os.Getenv("response_type"),
HttpDebug: true,
Debug: true,
HttpDebug: false,
Debug: false,
Cache: kernel.NewRedisClient(&kernel.UniversalOptions{
Addrs: []string{addr},
Password: pass,