🎨 update project structure
This commit is contained in:
3
controller/captcha_controller/captcha.go
Normal file
3
controller/captcha_controller/captcha.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package captcha_controller
|
||||
|
||||
type CaptchaController struct{}
|
372
controller/captcha_controller/captcha_controller.go
Normal file
372
controller/captcha_controller/captcha_controller.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package captcha_controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/wenlng/go-captcha-assets/helper"
|
||||
"github.com/wenlng/go-captcha/v2/click"
|
||||
"github.com/wenlng/go-captcha/v2/rotate"
|
||||
"github.com/wenlng/go-captcha/v2/slide"
|
||||
"log"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateRotateCaptcha 生成旋转验证码
|
||||
// @Summary 生成旋转验证码
|
||||
// @Description 生成旋转验证码
|
||||
// @Tags 旋转验证码
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/rotate/get [get]
|
||||
func (CaptchaController) GenerateRotateCaptcha(c *gin.Context) {
|
||||
captchaData, err := global.RotateCaptcha.Generate()
|
||||
if err != nil {
|
||||
global.LOG.Fatalln(err)
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
blockData := captchaData.GetData()
|
||||
if blockData == nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
masterImageBase64 := captchaData.GetMasterImage().ToBase64()
|
||||
thumbImageBase64 := captchaData.GetThumbImage().ToBase64()
|
||||
dotsByte, err := json.Marshal(blockData)
|
||||
if err != nil {
|
||||
global.LOG.Fatalln(err)
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
key := helper.StringToMD5(string(dotsByte))
|
||||
err = redis.Set(constant.UserLoginCaptchaRedisKey+key, dotsByte, time.Minute).Err()
|
||||
if err != nil {
|
||||
global.LOG.Fatalln(err)
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
result.OkWithData(map[string]interface{}{
|
||||
"key": key,
|
||||
"image": masterImageBase64,
|
||||
"thumb": thumbImageBase64,
|
||||
}, c)
|
||||
}
|
||||
|
||||
// CheckRotateData 验证旋转验证码
|
||||
// @Summary 验证旋转验证码
|
||||
// @Description 验证旋转验证码
|
||||
// @Tags 旋转验证码
|
||||
// @Param angle query string true "验证码角度"
|
||||
// @Param key query string true "验证码key"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/rotate/check [post]
|
||||
func (CaptchaController) CheckRotateData(c *gin.Context) {
|
||||
var rotateRequest RotateCaptchaRequest
|
||||
if err := c.ShouldBindJSON(&rotateRequest); err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheDataByte, err := redis.Get(constant.UserLoginCaptchaRedisKey + rotateRequest.Key).Bytes()
|
||||
if err != nil || len(cacheDataByte) == 0 {
|
||||
result.FailWithCodeAndMessage(1011, ginI18n.MustGetMessage(c, "CaptchaExpired"), c)
|
||||
return
|
||||
}
|
||||
|
||||
var dct *rotate.Block
|
||||
if err := json.Unmarshal(cacheDataByte, &dct); err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
sAngle, err := strconv.ParseFloat(fmt.Sprintf("%v", rotateRequest.Angle), 64)
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
chkRet := rotate.CheckAngle(int64(sAngle), int64(dct.Angle), 2)
|
||||
if chkRet {
|
||||
result.OkWithMessage("success", c)
|
||||
return
|
||||
}
|
||||
|
||||
result.FailWithMessage("fail", c)
|
||||
}
|
||||
|
||||
// GenerateBasicTextCaptcha 生成基础文字验证码
|
||||
// @Summary 生成基础文字验证码
|
||||
// @Description 生成基础文字验证码
|
||||
// @Tags 基础文字验证码
|
||||
// @Param type query string true "验证码类型"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/text/get [get]
|
||||
func (CaptchaController) GenerateBasicTextCaptcha(c *gin.Context) {
|
||||
var capt click.Captcha
|
||||
if c.Query("type") == "light" {
|
||||
capt = global.LightTextCaptcha
|
||||
} else {
|
||||
capt = global.TextCaptcha
|
||||
}
|
||||
captData, err := capt.Generate()
|
||||
if err != nil {
|
||||
global.LOG.Fatalln(err)
|
||||
}
|
||||
dotData := captData.GetData()
|
||||
if dotData == nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
var masterImageBase64, thumbImageBase64 string
|
||||
masterImageBase64 = captData.GetMasterImage().ToBase64()
|
||||
thumbImageBase64 = captData.GetThumbImage().ToBase64()
|
||||
|
||||
dotsByte, err := json.Marshal(dotData)
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
key := helper.StringToMD5(string(dotsByte))
|
||||
err = redis.Set("user:login:client:"+key, dotsByte, time.Minute).Err()
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
bt := map[string]interface{}{
|
||||
"key": key,
|
||||
"image": masterImageBase64,
|
||||
"thumb": thumbImageBase64,
|
||||
}
|
||||
result.OkWithData(bt, c)
|
||||
}
|
||||
|
||||
// CheckClickData 验证基础文字验证码
|
||||
// @Summary 验证基础文字验证码
|
||||
// @Description 验证基础文字验证码
|
||||
// @Tags 基础文字验证码
|
||||
// @Param captcha query string true "验证码"
|
||||
// @Param key query string true "验证码key"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/text/check [get]
|
||||
func (CaptchaController) CheckClickData(c *gin.Context) {
|
||||
dots := c.Query("dots")
|
||||
key := c.Query("key")
|
||||
if dots == "" || key == "" {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
cacheDataByte, err := redis.Get("user:login:client:" + key).Bytes()
|
||||
if len(cacheDataByte) == 0 || err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
src := strings.Split(dots, ",")
|
||||
|
||||
var dct map[int]*click.Dot
|
||||
if err := json.Unmarshal(cacheDataByte, &dct); err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
chkRet := false
|
||||
if (len(dct) * 2) == len(src) {
|
||||
for i := 0; i < len(dct); i++ {
|
||||
dot := dct[i]
|
||||
j := i * 2
|
||||
k := i*2 + 1
|
||||
sx, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[j]), 64)
|
||||
sy, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[k]), 64)
|
||||
|
||||
chkRet = click.CheckPoint(int64(sx), int64(sy), int64(dot.X), int64(dot.Y), int64(dot.Width), int64(dot.Height), 0)
|
||||
if !chkRet {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if chkRet {
|
||||
result.OkWithMessage("success", c)
|
||||
return
|
||||
}
|
||||
result.FailWithMessage("fail", c)
|
||||
}
|
||||
|
||||
// GenerateClickShapeCaptcha 生成点击形状验证码
|
||||
// @Summary 生成点击形状验证码
|
||||
// @Description 生成点击形状验证码
|
||||
// @Tags 点击形状验证码
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/shape/get [get]
|
||||
func (CaptchaController) GenerateClickShapeCaptcha(c *gin.Context) {
|
||||
captData, err := global.ClickShapeCaptcha.Generate()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
dotData := captData.GetData()
|
||||
if dotData == nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
var masterImageBase64, thumbImageBase64 string
|
||||
masterImageBase64 = captData.GetMasterImage().ToBase64()
|
||||
thumbImageBase64 = captData.GetThumbImage().ToBase64()
|
||||
|
||||
dotsByte, err := json.Marshal(dotData)
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
key := helper.StringToMD5(string(dotsByte))
|
||||
err = redis.Set(key, dotsByte, time.Minute).Err()
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
bt := map[string]interface{}{
|
||||
"key": key,
|
||||
"image": masterImageBase64,
|
||||
"thumb": thumbImageBase64,
|
||||
}
|
||||
result.OkWithData(bt, c)
|
||||
}
|
||||
|
||||
// GenerateSlideBasicCaptData 滑块基础验证码
|
||||
// @Summary 滑块基础验证码
|
||||
// @Description 滑块基础验证码
|
||||
// @Tags 滑块基础验证码
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/slide/generate [get]
|
||||
func (CaptchaController) GenerateSlideBasicCaptData(c *gin.Context) {
|
||||
captData, err := global.SlideCaptcha.Generate()
|
||||
if err != nil {
|
||||
global.LOG.Fatalln(err)
|
||||
}
|
||||
blockData := captData.GetData()
|
||||
if blockData == nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
var masterImageBase64, tileImageBase64 string
|
||||
masterImageBase64 = captData.GetMasterImage().ToBase64()
|
||||
|
||||
tileImageBase64 = captData.GetTileImage().ToBase64()
|
||||
|
||||
dotsByte, err := json.Marshal(blockData)
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
key := helper.StringToMD5(string(dotsByte))
|
||||
err = redis.Set(constant.CommentSubmitCaptchaRedisKey+key, dotsByte, time.Minute).Err()
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
bt := map[string]interface{}{
|
||||
"key": key,
|
||||
"image": masterImageBase64,
|
||||
"thumb": tileImageBase64,
|
||||
"thumb_width": blockData.Width,
|
||||
"thumb_height": blockData.Height,
|
||||
"thumb_x": blockData.TileX,
|
||||
"thumb_y": blockData.TileY,
|
||||
}
|
||||
result.OkWithData(bt, c)
|
||||
}
|
||||
|
||||
// GenerateSlideRegionCaptData 生成滑动区域形状验证码
|
||||
// @Summary 生成滑动区域形状验证码
|
||||
// @Description 生成滑动区域形状验证码
|
||||
// @Tags 生成滑动区域形状验证码
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/shape/slide/region/get [get]
|
||||
func (CaptchaController) GenerateSlideRegionCaptData(c *gin.Context) {
|
||||
captData, err := global.SlideRegionCaptcha.Generate()
|
||||
if err != nil {
|
||||
global.LOG.Fatalln(err)
|
||||
}
|
||||
|
||||
blockData := captData.GetData()
|
||||
if blockData == nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
var masterImageBase64, tileImageBase64 string
|
||||
masterImageBase64 = captData.GetMasterImage().ToBase64()
|
||||
tileImageBase64 = captData.GetTileImage().ToBase64()
|
||||
|
||||
blockByte, err := json.Marshal(blockData)
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
key := helper.StringToMD5(string(blockByte))
|
||||
err = redis.Set(key, blockByte, time.Minute).Err()
|
||||
if err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
bt := map[string]interface{}{
|
||||
"code": 0,
|
||||
"key": key,
|
||||
"image": masterImageBase64,
|
||||
"tile": tileImageBase64,
|
||||
"tile_width": blockData.Width,
|
||||
"tile_height": blockData.Height,
|
||||
"tile_x": blockData.TileX,
|
||||
"tile_y": blockData.TileY,
|
||||
}
|
||||
result.OkWithData(bt, c)
|
||||
}
|
||||
|
||||
// CheckSlideData 验证滑动验证码
|
||||
// @Summary 验证滑动验证码
|
||||
// @Description 验证滑动验证码
|
||||
// @Tags 验证滑动验证码
|
||||
// @Param point query string true "点击坐标"
|
||||
// @Param key query string true "验证码key"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/captcha/shape/slide/check [get]
|
||||
func (CaptchaController) CheckSlideData(c *gin.Context) {
|
||||
point := c.Query("point")
|
||||
key := c.Query("key")
|
||||
if point == "" || key == "" {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheDataByte, err := redis.Get(key).Bytes()
|
||||
if len(cacheDataByte) == 0 || err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
src := strings.Split(point, ",")
|
||||
|
||||
var dct *slide.Block
|
||||
if err := json.Unmarshal(cacheDataByte, &dct); err != nil {
|
||||
result.FailWithNull(c)
|
||||
return
|
||||
}
|
||||
|
||||
chkRet := false
|
||||
if 2 == len(src) {
|
||||
sx, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[0]), 64)
|
||||
sy, _ := strconv.ParseFloat(fmt.Sprintf("%v", src[1]), 64)
|
||||
chkRet = slide.CheckPoint(int64(sx), int64(sy), int64(dct.X), int64(dct.Y), 4)
|
||||
}
|
||||
|
||||
if chkRet {
|
||||
result.OkWithMessage("success", c)
|
||||
return
|
||||
}
|
||||
result.FailWithMessage("fail", c)
|
||||
}
|
6
controller/captcha_controller/request_param.go
Normal file
6
controller/captcha_controller/request_param.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package captcha_controller
|
||||
|
||||
type RotateCaptchaRequest struct {
|
||||
Angle int `json:"angle" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
}
|
7
controller/client_controller/client.go
Normal file
7
controller/client_controller/client.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package client_controller
|
||||
|
||||
import "sync"
|
||||
|
||||
type ClientController struct{}
|
||||
|
||||
var mu sync.Mutex
|
46
controller/client_controller/client_controller.go
Normal file
46
controller/client_controller/client_controller.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package client_controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateClientId 生成客户端ID
|
||||
// @Summary 生成客户端ID
|
||||
// @Description 生成客户端ID
|
||||
// @Tags 微信公众号
|
||||
// @Produce json
|
||||
// @Router /controller/oauth/generate_client_id [get]
|
||||
func (ClientController) GenerateClientId(c *gin.Context) {
|
||||
// 获取客户端IP
|
||||
ip := utils.GetClientIP(c)
|
||||
// 加锁
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// 从Redis获取客户端ID
|
||||
clientId := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
|
||||
if clientId != "" {
|
||||
result.OkWithData(clientId, c)
|
||||
return
|
||||
}
|
||||
// 生成新的客户端ID
|
||||
uid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
err = redis.Set(constant.UserLoginClientRedisKey+ip, uid.String(), time.Hour*24*7).Err()
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
result.OkWithData(uid.String(), c)
|
||||
return
|
||||
}
|
57
controller/comment_controller/comment.go
Normal file
57
controller/comment_controller/comment.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package comment_controller
|
||||
|
||||
import (
|
||||
"schisandra-cloud-album/service/impl"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CommentController struct{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mx sync.Mutex
|
||||
var commentReplyService = impl.CommentReplyServiceImpl{}
|
||||
|
||||
// CommentImages 评论图片
|
||||
type CommentImages struct {
|
||||
TopicId string `json:"topic_id" bson:"topic_id" required:"true"`
|
||||
CommentId int64 `json:"comment_id" bson:"comment_id" required:"true"`
|
||||
UserId string `json:"user_id" bson:"user_id" required:"true"`
|
||||
Images [][]byte `json:"images" bson:"images" required:"true"`
|
||||
CreatedAt string `json:"created_at" bson:"created_at" required:"true"`
|
||||
}
|
||||
|
||||
// CommentContent 评论内容
|
||||
type CommentContent struct {
|
||||
NickName string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
Level int `json:"level,omitempty"`
|
||||
Id int64 `json:"id"`
|
||||
UserId string `json:"user_id"`
|
||||
TopicId string `json:"topic_id"`
|
||||
Content string `json:"content"`
|
||||
ReplyTo int64 `json:"reply_to,omitempty"`
|
||||
ReplyId int64 `json:"reply_id,omitempty"`
|
||||
ReplyUser string `json:"reply_user,omitempty"`
|
||||
ReplyUsername string `json:"reply_username,omitempty"`
|
||||
Author int `json:"author"`
|
||||
Likes int64 `json:"likes"`
|
||||
ReplyCount int64 `json:"reply_count"`
|
||||
CreatedTime time.Time `json:"created_time"`
|
||||
Location string `json:"location"`
|
||||
Browser string `json:"browser"`
|
||||
OperatingSystem string `json:"operating_system"`
|
||||
IsLiked bool `json:"is_liked" default:"false"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
// CommentResponse 评论返回值
|
||||
type CommentResponse struct {
|
||||
Size int `json:"size"`
|
||||
Total int64 `json:"total"`
|
||||
Current int `json:"current"`
|
||||
Comments []CommentContent `json:"comments"`
|
||||
}
|
||||
|
||||
var likeChannel = make(chan CommentLikeRequest, 1000)
|
||||
var cancelLikeChannel = make(chan CommentLikeRequest, 1000)
|
853
controller/comment_controller/comment_controller.go
Normal file
853
controller/comment_controller/comment_controller.go
Normal file
@@ -0,0 +1,853 @@
|
||||
package comment_controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/acmestack/gorm-plus/gplus"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mssola/useragent"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"schisandra-cloud-album/mq"
|
||||
"schisandra-cloud-album/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CommentSubmit 提交评论
|
||||
// @Summary 提交评论
|
||||
// @Description 提交评论
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param comment_request body CommentRequest true "评论请求"
|
||||
// @Router /auth/comment/submit [post]
|
||||
func (CommentController) CommentSubmit(c *gin.Context) {
|
||||
commentRequest := CommentRequest{}
|
||||
if err := c.ShouldBindJSON(&commentRequest); err != nil {
|
||||
return
|
||||
}
|
||||
// 验证校验
|
||||
res := utils.CheckSlideData(commentRequest.Point, commentRequest.Key)
|
||||
if !res {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
if len(commentRequest.Images) > 3 {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "TooManyImages"), c)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
global.LOG.Errorln("user-agent is empty")
|
||||
return
|
||||
}
|
||||
ua := useragent.New(userAgent)
|
||||
|
||||
ip := utils.GetClientIP(c)
|
||||
location, err := global.IP2Location.SearchByStr(ip)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
location = utils.RemoveZeroAndAdjust(location)
|
||||
|
||||
browser, _ := ua.Browser()
|
||||
operatingSystem := ua.OS()
|
||||
isAuthor := 0
|
||||
if commentRequest.UserID == commentRequest.Author {
|
||||
isAuthor = 1
|
||||
}
|
||||
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
commentReply := model.ScaCommentReply{
|
||||
Content: commentRequest.Content,
|
||||
UserId: commentRequest.UserID,
|
||||
TopicId: commentRequest.TopicId,
|
||||
TopicType: enum.CommentTopicType,
|
||||
CommentType: enum.COMMENT,
|
||||
Author: isAuthor,
|
||||
CommentIp: ip,
|
||||
Location: location,
|
||||
Browser: browser,
|
||||
OperatingSystem: operatingSystem,
|
||||
Agent: userAgent,
|
||||
}
|
||||
// 使用 goroutine 进行异步评论保存
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
errCh <- commentReplyService.CreateCommentReplyService(&commentReply)
|
||||
}()
|
||||
|
||||
// 等待评论回复的创建
|
||||
if err = <-errCh; err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// 处理图片异步上传
|
||||
if len(commentRequest.Images) > 0 {
|
||||
imagesDataCh := make(chan [][]byte)
|
||||
go func() {
|
||||
imagesData, err := processImages(commentRequest.Images)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
imagesDataCh <- nil // 发送失败信号
|
||||
return
|
||||
}
|
||||
imagesDataCh <- imagesData // 发送处理成功的数据
|
||||
}()
|
||||
|
||||
imagesData := <-imagesDataCh
|
||||
if imagesData == nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
commentImages := CommentImages{
|
||||
TopicId: commentRequest.TopicId,
|
||||
CommentId: commentReply.Id,
|
||||
UserId: commentRequest.UserID,
|
||||
Images: imagesData,
|
||||
CreatedAt: time.Now().Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
|
||||
// 使用 goroutine 进行异步图片保存
|
||||
go func() {
|
||||
if _, err = global.MongoDB.Database(global.CONFIG.MongoDB.DB).Collection("comment_images").InsertOne(context.Background(), commentImages); err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
tx.Commit()
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitSuccess"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// ReplySubmit 提交回复
|
||||
// @Summary 提交回复
|
||||
// @Description 提交回复
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param reply_comment_request body ReplyCommentRequest true "回复评论请求"
|
||||
// @Router /auth/reply/submit [post]
|
||||
func (CommentController) ReplySubmit(c *gin.Context) {
|
||||
replyCommentRequest := ReplyCommentRequest{}
|
||||
if err := c.ShouldBindJSON(&replyCommentRequest); err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
// 验证校验
|
||||
res := utils.CheckSlideData(replyCommentRequest.Point, replyCommentRequest.Key)
|
||||
if !res {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
if len(replyCommentRequest.Images) > 3 {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "TooManyImages"), c)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
global.LOG.Errorln("user-agent is empty")
|
||||
return
|
||||
}
|
||||
|
||||
ua := useragent.New(userAgent)
|
||||
ip := utils.GetClientIP(c)
|
||||
location, err := global.IP2Location.SearchByStr(ip)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
location = utils.RemoveZeroAndAdjust(location)
|
||||
|
||||
browser, _ := ua.Browser()
|
||||
operatingSystem := ua.OS()
|
||||
isAuthor := 0
|
||||
if replyCommentRequest.UserID == replyCommentRequest.Author {
|
||||
isAuthor = 1
|
||||
}
|
||||
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
commentReply := model.ScaCommentReply{
|
||||
Content: replyCommentRequest.Content,
|
||||
UserId: replyCommentRequest.UserID,
|
||||
TopicId: replyCommentRequest.TopicId,
|
||||
TopicType: enum.CommentTopicType,
|
||||
CommentType: enum.REPLY,
|
||||
ReplyId: replyCommentRequest.ReplyId,
|
||||
ReplyUser: replyCommentRequest.ReplyUser,
|
||||
Author: isAuthor,
|
||||
CommentIp: ip,
|
||||
Location: location,
|
||||
Browser: browser,
|
||||
OperatingSystem: operatingSystem,
|
||||
Agent: userAgent,
|
||||
}
|
||||
// 使用 goroutine 进行异步评论保存
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
|
||||
errCh <- commentReplyService.CreateCommentReplyService(&commentReply)
|
||||
}()
|
||||
go func() {
|
||||
|
||||
errCh <- commentReplyService.UpdateCommentReplyCountService(replyCommentRequest.ReplyId)
|
||||
}()
|
||||
// 等待评论回复的创建
|
||||
if err = <-errCh; err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// 处理图片异步上传
|
||||
if len(replyCommentRequest.Images) > 0 {
|
||||
imagesDataCh := make(chan [][]byte)
|
||||
go func() {
|
||||
imagesData, err := processImages(replyCommentRequest.Images)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
imagesDataCh <- nil // 发送失败信号
|
||||
return
|
||||
}
|
||||
imagesDataCh <- imagesData // 发送处理成功的数据
|
||||
}()
|
||||
|
||||
imagesData := <-imagesDataCh
|
||||
if imagesData == nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
commentImages := CommentImages{
|
||||
TopicId: replyCommentRequest.TopicId,
|
||||
CommentId: commentReply.Id,
|
||||
UserId: replyCommentRequest.UserID,
|
||||
Images: imagesData,
|
||||
CreatedAt: time.Now().Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
|
||||
// 使用 goroutine 进行异步图片保存
|
||||
go func() {
|
||||
if _, err = global.MongoDB.Database(global.CONFIG.MongoDB.DB).Collection("comment_images").InsertOne(context.Background(), commentImages); err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
tx.Commit()
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitSuccess"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// ReplyReplySubmit 提交回复的回复
|
||||
// @Summary 提交回复的回复
|
||||
// @Description 提交回复的回复
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param reply_reply_request body ReplyReplyRequest true "回复回复请求"
|
||||
// @Router /auth/reply/reply/submit [post]
|
||||
func (CommentController) ReplyReplySubmit(c *gin.Context) {
|
||||
replyReplyRequest := ReplyReplyRequest{}
|
||||
if err := c.ShouldBindJSON(&replyReplyRequest); err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
// 验证校验
|
||||
res := utils.CheckSlideData(replyReplyRequest.Point, replyReplyRequest.Key)
|
||||
if !res {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
if len(replyReplyRequest.Images) > 3 {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "TooManyImages"), c)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
global.LOG.Errorln("user-agent is empty")
|
||||
return
|
||||
}
|
||||
|
||||
ua := useragent.New(userAgent)
|
||||
ip := utils.GetClientIP(c)
|
||||
location, err := global.IP2Location.SearchByStr(ip)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
location = utils.RemoveZeroAndAdjust(location)
|
||||
|
||||
browser, _ := ua.Browser()
|
||||
operatingSystem := ua.OS()
|
||||
isAuthor := 0
|
||||
if replyReplyRequest.UserID == replyReplyRequest.Author {
|
||||
isAuthor = 1
|
||||
}
|
||||
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
commentReply := model.ScaCommentReply{
|
||||
Content: replyReplyRequest.Content,
|
||||
UserId: replyReplyRequest.UserID,
|
||||
TopicId: replyReplyRequest.TopicId,
|
||||
TopicType: enum.CommentTopicType,
|
||||
CommentType: enum.REPLY,
|
||||
ReplyTo: replyReplyRequest.ReplyTo,
|
||||
ReplyId: replyReplyRequest.ReplyId,
|
||||
ReplyUser: replyReplyRequest.ReplyUser,
|
||||
Author: isAuthor,
|
||||
CommentIp: ip,
|
||||
Location: location,
|
||||
Browser: browser,
|
||||
OperatingSystem: operatingSystem,
|
||||
Agent: userAgent,
|
||||
}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
errCh <- commentReplyService.CreateCommentReplyService(&commentReply)
|
||||
}()
|
||||
go func() {
|
||||
errCh <- commentReplyService.UpdateCommentReplyCountService(replyReplyRequest.ReplyId)
|
||||
}()
|
||||
|
||||
if err = <-errCh; err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
if len(replyReplyRequest.Images) > 0 {
|
||||
imagesDataCh := make(chan [][]byte)
|
||||
go func() {
|
||||
imagesData, err := processImages(replyReplyRequest.Images)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
imagesDataCh <- nil
|
||||
return
|
||||
}
|
||||
imagesDataCh <- imagesData
|
||||
}()
|
||||
|
||||
imagesData := <-imagesDataCh
|
||||
if imagesData == nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitFailed"), c)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
commentImages := CommentImages{
|
||||
TopicId: replyReplyRequest.TopicId,
|
||||
CommentId: commentReply.Id,
|
||||
UserId: replyReplyRequest.UserID,
|
||||
Images: imagesData,
|
||||
CreatedAt: time.Now().Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
|
||||
// 处理图片保存
|
||||
go func() {
|
||||
if _, err = global.MongoDB.Database(global.CONFIG.MongoDB.DB).Collection("comment_images").InsertOne(context.Background(), commentImages); err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
tx.Commit()
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentSubmitSuccess"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// CommentList 获取评论列表
|
||||
// @Summary 获取评论列表
|
||||
// @Description 获取评论列表
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param comment_list_request body CommentListRequest true "评论列表请求"
|
||||
// @Router /auth/comment/list [post]
|
||||
func (CommentController) CommentList(c *gin.Context) {
|
||||
commentListRequest := CommentListRequest{}
|
||||
err := c.ShouldBindJSON(&commentListRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
// 查询评论列表
|
||||
query, u := gplus.NewQuery[model.ScaCommentReply]()
|
||||
page := gplus.NewPage[model.ScaCommentReply](commentListRequest.Page, commentListRequest.Size)
|
||||
if commentListRequest.IsHot {
|
||||
query.OrderByDesc(&u.CommentOrder).OrderByDesc(&u.Likes).OrderByDesc(&u.ReplyCount)
|
||||
} else {
|
||||
query.OrderByDesc(&u.CommentOrder).OrderByDesc(&u.CreatedTime)
|
||||
}
|
||||
query.Eq(&u.TopicId, commentListRequest.TopicId).Eq(&u.CommentType, enum.COMMENT)
|
||||
page, pageDB := gplus.SelectPage(page, query)
|
||||
if pageDB.Error != nil {
|
||||
global.LOG.Errorln(pageDB.Error)
|
||||
return
|
||||
}
|
||||
if len(page.Records) == 0 {
|
||||
result.OkWithData(CommentResponse{Comments: []CommentContent{}, Size: page.Size, Current: page.Current, Total: page.Total}, c)
|
||||
return
|
||||
}
|
||||
|
||||
userIds := make([]string, 0, len(page.Records))
|
||||
commentIds := make([]int64, 0, len(page.Records))
|
||||
for _, comment := range page.Records {
|
||||
userIds = append(userIds, comment.UserId)
|
||||
commentIds = append(commentIds, comment.Id)
|
||||
}
|
||||
|
||||
// 结果存储
|
||||
userInfoMap := make(map[string]model.ScaAuthUser)
|
||||
likeMap := make(map[int64]bool)
|
||||
commentImagesMap := make(map[int64]CommentImages)
|
||||
|
||||
// 使用 WaitGroup 等待协程完成
|
||||
wg.Add(3)
|
||||
|
||||
// 查询评论用户信息
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
queryUser, n := gplus.NewQuery[model.ScaAuthUser]()
|
||||
queryUser.Select(&n.UID, &n.Avatar, &n.Nickname).In(&n.UID, userIds)
|
||||
userInfos, userInfosDB := gplus.SelectList(queryUser)
|
||||
if userInfosDB.Error != nil {
|
||||
global.LOG.Errorln(userInfosDB.Error)
|
||||
return
|
||||
}
|
||||
for _, userInfo := range userInfos {
|
||||
userInfoMap[*userInfo.UID] = *userInfo
|
||||
}
|
||||
}()
|
||||
|
||||
// 查询评论点赞状态
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if len(page.Records) > 0 {
|
||||
queryLike, l := gplus.NewQuery[model.ScaCommentLikes]()
|
||||
queryLike.Eq(&l.TopicId, commentListRequest.TopicId).Eq(&l.UserId, commentListRequest.UserID).In(&l.CommentId, commentIds)
|
||||
likes, likesDB := gplus.SelectList(queryLike)
|
||||
if likesDB.Error != nil {
|
||||
global.LOG.Errorln(likesDB.Error)
|
||||
return
|
||||
}
|
||||
for _, like := range likes {
|
||||
likeMap[like.CommentId] = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 查询评论图片信息
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // 设置超时,2秒
|
||||
defer cancel()
|
||||
|
||||
cursor, err := global.MongoDB.Database(global.CONFIG.MongoDB.DB).Collection("comment_images").Find(ctx, bson.M{"comment_id": bson.M{"$in": commentIds}})
|
||||
if err != nil {
|
||||
global.LOG.Errorf("Failed to get images for comments: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(cursor *mongo.Cursor, ctx context.Context) {
|
||||
err := cursor.Close(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}(cursor, ctx)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
var commentImages CommentImages
|
||||
if err = cursor.Decode(&commentImages); err != nil {
|
||||
global.LOG.Errorf("Failed to decode comment images: %v", err)
|
||||
continue
|
||||
}
|
||||
commentImagesMap[commentImages.CommentId] = commentImages
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待所有查询完成
|
||||
wg.Wait()
|
||||
commentChannel := make(chan CommentContent, len(page.Records))
|
||||
|
||||
for _, comment := range page.Records {
|
||||
wg.Add(1)
|
||||
go func(comment model.ScaCommentReply) {
|
||||
defer wg.Done()
|
||||
// 将图片转换为base64
|
||||
var imagesBase64 []string
|
||||
if imgData, ok := commentImagesMap[comment.Id]; ok {
|
||||
// 将图片转换为base64
|
||||
for _, img := range imgData.Images {
|
||||
mimeType := getMimeType(img)
|
||||
base64Img := base64.StdEncoding.EncodeToString(img)
|
||||
base64WithPrefix := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Img)
|
||||
imagesBase64 = append(imagesBase64, base64WithPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
userInfo, exist := userInfoMap[comment.UserId]
|
||||
if !exist {
|
||||
global.LOG.Errorf("Failed to get user info for comment: %s", comment.UserId)
|
||||
return
|
||||
}
|
||||
commentContent := CommentContent{
|
||||
Avatar: *userInfo.Avatar,
|
||||
NickName: *userInfo.Nickname,
|
||||
Id: comment.Id,
|
||||
UserId: comment.UserId,
|
||||
TopicId: comment.TopicId,
|
||||
Content: comment.Content,
|
||||
ReplyCount: comment.ReplyCount,
|
||||
Likes: comment.Likes,
|
||||
CreatedTime: comment.CreatedTime,
|
||||
Author: comment.Author,
|
||||
Location: comment.Location,
|
||||
Browser: comment.Browser,
|
||||
OperatingSystem: comment.OperatingSystem,
|
||||
Images: imagesBase64,
|
||||
IsLiked: likeMap[comment.Id],
|
||||
}
|
||||
commentChannel <- commentContent
|
||||
}(*comment)
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(commentChannel)
|
||||
}()
|
||||
|
||||
var commentsWithImages []CommentContent
|
||||
for commentContent := range commentChannel {
|
||||
commentsWithImages = append(commentsWithImages, commentContent)
|
||||
}
|
||||
|
||||
response := CommentResponse{
|
||||
Comments: commentsWithImages,
|
||||
Size: page.Size,
|
||||
Current: page.Current,
|
||||
Total: page.Total,
|
||||
}
|
||||
result.OkWithData(response, c)
|
||||
return
|
||||
}
|
||||
|
||||
// ReplyList 获取回复列表
|
||||
// @Summary 获取回复列表
|
||||
// @Description 获取回复列表
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param reply_list_request body ReplyListRequest true "回复列表请求"
|
||||
// @Router /auth/reply/list [post]
|
||||
func (CommentController) ReplyList(c *gin.Context) {
|
||||
replyListRequest := ReplyListRequest{}
|
||||
err := c.ShouldBindJSON(&replyListRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
query, u := gplus.NewQuery[model.ScaCommentReply]()
|
||||
page := gplus.NewPage[model.ScaCommentReply](replyListRequest.Page, replyListRequest.Size)
|
||||
query.Eq(&u.TopicId, replyListRequest.TopicId).Eq(&u.ReplyId, replyListRequest.CommentId).Eq(&u.CommentType, enum.REPLY).OrderByDesc(&u.Likes).OrderByAsc(&u.CreatedTime)
|
||||
page, pageDB := gplus.SelectPage(page, query)
|
||||
if pageDB.Error != nil {
|
||||
global.LOG.Errorln(pageDB.Error)
|
||||
return
|
||||
}
|
||||
if len(page.Records) == 0 {
|
||||
result.OkWithData(CommentResponse{Comments: []CommentContent{}, Size: page.Size, Current: page.Current, Total: page.Total}, c)
|
||||
return
|
||||
}
|
||||
|
||||
userIdsSet := make(map[string]struct{}) // 使用集合去重用户 ID
|
||||
commentIds := make([]int64, 0, len(page.Records))
|
||||
// 收集用户 ID 和评论 ID
|
||||
for _, comment := range page.Records {
|
||||
userIdsSet[comment.UserId] = struct{}{} // 去重
|
||||
commentIds = append(commentIds, comment.Id)
|
||||
if comment.ReplyUser != "" {
|
||||
userIdsSet[comment.ReplyUser] = struct{}{} // 去重
|
||||
}
|
||||
}
|
||||
// 将用户 ID 转换为切片
|
||||
userIds := make([]string, 0, len(userIdsSet))
|
||||
for userId := range userIdsSet {
|
||||
userIds = append(userIds, userId)
|
||||
}
|
||||
|
||||
likeMap := make(map[int64]bool, len(page.Records))
|
||||
commentImagesMap := make(map[int64]CommentImages)
|
||||
userInfoMap := make(map[string]model.ScaAuthUser, len(userIds))
|
||||
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// 查询评论用户信息
|
||||
queryUser, n := gplus.NewQuery[model.ScaAuthUser]()
|
||||
queryUser.Select(&n.UID, &n.Avatar, &n.Nickname).In(&n.UID, userIds)
|
||||
userInfos, userInfosDB := gplus.SelectList(queryUser)
|
||||
if userInfosDB.Error != nil {
|
||||
global.LOG.Errorln(userInfosDB.Error)
|
||||
return
|
||||
}
|
||||
for _, userInfo := range userInfos {
|
||||
userInfoMap[*userInfo.UID] = *userInfo
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// 查询评论点赞状态
|
||||
|
||||
if len(page.Records) > 0 {
|
||||
queryLike, l := gplus.NewQuery[model.ScaCommentLikes]()
|
||||
queryLike.Eq(&l.TopicId, replyListRequest.TopicId).Eq(&l.UserId, replyListRequest.UserID).In(&l.CommentId, commentIds)
|
||||
likes, likesDB := gplus.SelectList(queryLike)
|
||||
if likesDB.Error != nil {
|
||||
global.LOG.Errorln(likesDB.Error)
|
||||
return
|
||||
}
|
||||
for _, like := range likes {
|
||||
likeMap[like.CommentId] = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) // 设置超时,2秒
|
||||
defer cancel()
|
||||
cursor, err := global.MongoDB.Database(global.CONFIG.MongoDB.DB).Collection("comment_images").Find(ctx, bson.M{"comment_id": bson.M{"$in": commentIds}})
|
||||
if err != nil {
|
||||
global.LOG.Errorf("Failed to get images for comments: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(cursor *mongo.Cursor, ctx context.Context) {
|
||||
warn := cursor.Close(ctx)
|
||||
if warn != nil {
|
||||
return
|
||||
}
|
||||
}(cursor, ctx)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
var commentImages CommentImages
|
||||
if e := cursor.Decode(&commentImages); e != nil {
|
||||
global.LOG.Errorf("Failed to decode comment images: %v", e)
|
||||
continue
|
||||
}
|
||||
commentImagesMap[commentImages.CommentId] = commentImages
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
replyChannel := make(chan CommentContent, len(page.Records)) // 使用通道传递回复内容
|
||||
|
||||
for _, reply := range page.Records {
|
||||
wg.Add(1)
|
||||
go func(reply model.ScaCommentReply) {
|
||||
defer wg.Done()
|
||||
|
||||
var imagesBase64 []string
|
||||
if imgData, ok := commentImagesMap[reply.Id]; ok {
|
||||
// 将图片转换为base64
|
||||
for _, img := range imgData.Images {
|
||||
mimeType := getMimeType(img)
|
||||
base64Img := base64.StdEncoding.EncodeToString(img)
|
||||
base64WithPrefix := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Img)
|
||||
imagesBase64 = append(imagesBase64, base64WithPrefix)
|
||||
}
|
||||
}
|
||||
userInfo, exist := userInfoMap[reply.UserId]
|
||||
if !exist {
|
||||
global.LOG.Errorf("Failed to get user info for comment: %s", reply.UserId)
|
||||
return
|
||||
}
|
||||
replyUserInfo, e := userInfoMap[reply.ReplyUser]
|
||||
if !e {
|
||||
global.LOG.Errorf("Failed to get reply user info for comment: %s", reply.ReplyUser)
|
||||
return
|
||||
}
|
||||
commentContent := CommentContent{
|
||||
Avatar: *userInfo.Avatar,
|
||||
NickName: *userInfo.Nickname,
|
||||
Id: reply.Id,
|
||||
UserId: reply.UserId,
|
||||
TopicId: reply.TopicId,
|
||||
Content: reply.Content,
|
||||
ReplyUsername: *replyUserInfo.Nickname,
|
||||
ReplyCount: reply.ReplyCount,
|
||||
Likes: reply.Likes,
|
||||
CreatedTime: reply.CreatedTime,
|
||||
ReplyUser: reply.ReplyUser,
|
||||
ReplyId: reply.ReplyId,
|
||||
ReplyTo: reply.ReplyTo,
|
||||
Author: reply.Author,
|
||||
Location: reply.Location,
|
||||
Browser: reply.Browser,
|
||||
OperatingSystem: reply.OperatingSystem,
|
||||
Images: imagesBase64,
|
||||
IsLiked: likeMap[reply.Id],
|
||||
}
|
||||
replyChannel <- commentContent // 发送到通道
|
||||
}(*reply)
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(replyChannel) // 关闭通道
|
||||
}()
|
||||
|
||||
var repliesWithImages []CommentContent
|
||||
for replyContent := range replyChannel {
|
||||
repliesWithImages = append(repliesWithImages, replyContent) // 从通道获取回复内容
|
||||
}
|
||||
|
||||
response := CommentResponse{
|
||||
Comments: repliesWithImages,
|
||||
Size: page.Size,
|
||||
Current: page.Current,
|
||||
Total: page.Total,
|
||||
}
|
||||
result.OkWithData(response, c)
|
||||
return
|
||||
}
|
||||
|
||||
// CommentLikes 点赞评论
|
||||
// @Summary 点赞评论
|
||||
// @Description 点赞评论
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param comment_like_request body CommentLikeRequest true "点赞请求"
|
||||
// @Router /auth/comment/like [post]
|
||||
func (CommentController) CommentLikes(c *gin.Context) {
|
||||
likeRequest := CommentLikeRequest{}
|
||||
err := c.ShouldBindJSON(&likeRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
|
||||
likes := model.ScaCommentLikes{
|
||||
CommentId: likeRequest.CommentId,
|
||||
UserId: likeRequest.UserID,
|
||||
TopicId: likeRequest.TopicId,
|
||||
}
|
||||
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
res := global.DB.Create(&likes) // 假设这是插入数据库的方法
|
||||
if res.Error != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Errorln(res.Error)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步更新点赞计数
|
||||
go func() {
|
||||
if err = commentReplyService.UpdateCommentLikesCountService(likeRequest.CommentId, likeRequest.TopicId); err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
}
|
||||
}()
|
||||
marshal, err := json.Marshal(likes)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
mq.CommentLikeProducer(marshal)
|
||||
|
||||
tx.Commit()
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeSuccess"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// CancelCommentLikes 取消点赞评论
|
||||
// @Summary 取消点赞评论
|
||||
// @Description 取消点赞评论
|
||||
// @Tags 评论
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param comment_like_request body CommentLikeRequest true "取消点赞请求"
|
||||
// @Router /auth/comment/cancel_like [post]
|
||||
func (CommentController) CancelCommentLikes(c *gin.Context) {
|
||||
cancelLikeRequest := CommentLikeRequest{}
|
||||
if err := c.ShouldBindJSON(&cancelLikeRequest); err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
mx.Lock()
|
||||
defer mx.Unlock()
|
||||
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
query, u := gplus.NewQuery[model.ScaCommentLikes]()
|
||||
query.Eq(&u.CommentId, cancelLikeRequest.CommentId).
|
||||
Eq(&u.UserId, cancelLikeRequest.UserID).
|
||||
Eq(&u.TopicId, cancelLikeRequest.TopicId)
|
||||
|
||||
res := gplus.Delete[model.ScaCommentLikes](query)
|
||||
if res.Error != nil {
|
||||
tx.Rollback()
|
||||
return // 返回错误而非打印
|
||||
}
|
||||
|
||||
// 异步更新点赞计数
|
||||
go func() {
|
||||
if err := commentReplyService.DecrementCommentLikesCountService(cancelLikeRequest.CommentId, cancelLikeRequest.TopicId); err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
}
|
||||
}()
|
||||
tx.Commit()
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelSuccess"), c)
|
||||
return
|
||||
}
|
87
controller/comment_controller/handler.go
Normal file
87
controller/comment_controller/handler.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package comment_controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// base64ToBytes 将base64字符串转换为字节数组
|
||||
func base64ToBytes(base64Str string) ([]byte, error) {
|
||||
reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64Str))
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to decode base64 string")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// processImages 处理图片,将 base64 字符串转换为字节数组
|
||||
func processImages(images []string) ([][]byte, error) {
|
||||
var imagesData [][]byte
|
||||
dataChan := make(chan []byte, len(images)) // 创建一个带缓冲的 channel
|
||||
re := regexp.MustCompile(`^data:image/\w+;base64,`)
|
||||
|
||||
for _, img := range images {
|
||||
wg.Add(1) // 增加 WaitGroup 的计数
|
||||
go func(img string) {
|
||||
defer wg.Done() // 函数结束时减少计数
|
||||
|
||||
imgWithoutPrefix := re.ReplaceAllString(img, "")
|
||||
data, err := base64ToBytes(imgWithoutPrefix)
|
||||
if err != nil {
|
||||
return // 出错时直接返回
|
||||
}
|
||||
dataChan <- data // 将结果发送到 channel
|
||||
}(img)
|
||||
}
|
||||
|
||||
wg.Wait() // 等待所有 goroutine 完成
|
||||
close(dataChan) // 关闭 channel
|
||||
|
||||
for data := range dataChan { // 收集所有结果
|
||||
imagesData = append(imagesData, data)
|
||||
}
|
||||
|
||||
return imagesData, nil
|
||||
}
|
||||
|
||||
// getMimeType 获取 MIME 类型
|
||||
func getMimeType(data []byte) string {
|
||||
if len(data) < 4 {
|
||||
return "application/octet-stream" // 默认类型
|
||||
}
|
||||
|
||||
// 判断 JPEG
|
||||
if data[0] == 0xFF && data[1] == 0xD8 {
|
||||
return "image/jpeg"
|
||||
}
|
||||
|
||||
// 判断 PNG
|
||||
if len(data) >= 8 && data[0] == 0x89 && data[1] == 0x50 && data[2] == 0x4E && data[3] == 0x47 &&
|
||||
data[4] == 0x0D && data[5] == 0x0A && data[6] == 0x1A && data[7] == 0x0A {
|
||||
return "image/png"
|
||||
}
|
||||
|
||||
// 判断 GIF
|
||||
if len(data) >= 6 && data[0] == 'G' && data[1] == 'I' && data[2] == 'F' {
|
||||
return "image/gif"
|
||||
}
|
||||
// 判断 WEBP
|
||||
if len(data) >= 12 && data[0] == 0x52 && data[1] == 0x49 && data[2] == 0x46 && data[3] == 0x46 &&
|
||||
data[8] == 0x57 && data[9] == 0x45 && data[10] == 0x42 && data[11] == 0x50 {
|
||||
return "image/webp"
|
||||
}
|
||||
// 判断svg
|
||||
if len(data) >= 4 && data[0] == '<' && data[1] == '?' && data[2] == 'x' && data[3] == 'm' {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
// 判断JPG
|
||||
if len(data) >= 3 && data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF {
|
||||
return "image/jpeg"
|
||||
}
|
||||
|
||||
return "application/octet-stream" // 默认类型
|
||||
}
|
55
controller/comment_controller/request_param.go
Normal file
55
controller/comment_controller/request_param.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package comment_controller
|
||||
|
||||
type CommentRequest struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
Images []string `json:"images"`
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
TopicId string `json:"topic_id" binding:"required"`
|
||||
Author string `json:"author" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
Point []int64 `json:"point" binding:"required"`
|
||||
}
|
||||
type ReplyCommentRequest struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
Images []string `json:"images"`
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
TopicId string `json:"topic_id" binding:"required"`
|
||||
ReplyId int64 `json:"reply_id" binding:"required"`
|
||||
ReplyUser string `json:"reply_user" binding:"required"`
|
||||
Author string `json:"author" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
Point []int64 `json:"point" binding:"required"`
|
||||
}
|
||||
|
||||
type ReplyReplyRequest struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
Images []string `json:"images"`
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
TopicId string `json:"topic_id" binding:"required"`
|
||||
ReplyTo int64 `json:"reply_to" binding:"required"`
|
||||
ReplyId int64 `json:"reply_id" binding:"required"`
|
||||
ReplyUser string `json:"reply_user" binding:"required"`
|
||||
Author string `json:"author" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
Point []int64 `json:"point" binding:"required"`
|
||||
}
|
||||
|
||||
type CommentListRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
TopicId string `json:"topic_id" binding:"required"`
|
||||
Page int `json:"page" default:"1"`
|
||||
Size int `json:"size" default:"5"`
|
||||
IsHot bool `json:"is_hot" default:"true"`
|
||||
}
|
||||
type ReplyListRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
TopicId string `json:"topic_id" binding:"required"`
|
||||
CommentId int64 `json:"comment_id" binding:"required"`
|
||||
Page int `json:"page" default:"1"`
|
||||
Size int `json:"size" default:"5"`
|
||||
}
|
||||
type CommentLikeRequest struct {
|
||||
TopicId string `json:"topic_id" binding:"required"`
|
||||
CommentId int64 `json:"comment_id" binding:"required"`
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
}
|
29
controller/controller.go
Normal file
29
controller/controller.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"schisandra-cloud-album/controller/captcha_controller"
|
||||
"schisandra-cloud-album/controller/client_controller"
|
||||
"schisandra-cloud-album/controller/comment_controller"
|
||||
"schisandra-cloud-album/controller/oauth_controller"
|
||||
"schisandra-cloud-album/controller/permission_controller"
|
||||
"schisandra-cloud-album/controller/role_controller"
|
||||
"schisandra-cloud-album/controller/sms_controller"
|
||||
"schisandra-cloud-album/controller/user_controller"
|
||||
"schisandra-cloud-album/controller/websocket_controller"
|
||||
)
|
||||
|
||||
// Controllers 统一导出的控制器接口
|
||||
type Controllers struct {
|
||||
UserController user_controller.UserController
|
||||
CaptchaController captcha_controller.CaptchaController
|
||||
SmsController sms_controller.SmsController
|
||||
OAuthController oauth_controller.OAuthController
|
||||
WebsocketController websocket_controller.WebsocketController
|
||||
RoleController role_controller.RoleController
|
||||
PermissionController permission_controller.PermissionController
|
||||
ClientController client_controller.ClientController
|
||||
CommonController comment_controller.CommentController
|
||||
}
|
||||
|
||||
// Controller new函数实例化,实例化完成后会返回结构体地指针类型
|
||||
var Controller = new(Controllers)
|
254
controller/oauth_controller/gitee_controller.go
Normal file
254
controller/oauth_controller/gitee_controller.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package oauth_controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GiteeUser struct {
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
Bio string `json:"bio"`
|
||||
Blog string `json:"blog"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Email string `json:"email"`
|
||||
EventsURL string `json:"events_url"`
|
||||
Followers int `json:"followers"`
|
||||
FollowersURL string `json:"followers_url"`
|
||||
Following int `json:"following"`
|
||||
FollowingURL string `json:"following_url"`
|
||||
GistsURL string `json:"gists_url"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
ID int `json:"id"`
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
OrganizationsURL string `json:"organizations_url"`
|
||||
PublicGists int `json:"public_gists"`
|
||||
PublicRepos int `json:"public_repos"`
|
||||
ReceivedEventsURL string `json:"received_events_url"`
|
||||
Remark string `json:"remark"`
|
||||
ReposURL string `json:"repos_url"`
|
||||
Stared int `json:"stared"`
|
||||
StarredURL string `json:"starred_url"`
|
||||
SubscriptionsURL string `json:"subscriptions_url"`
|
||||
Type string `json:"type"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
URL string `json:"url"`
|
||||
Watched int `json:"watched"`
|
||||
Weibo interface{} `json:"weibo"`
|
||||
}
|
||||
|
||||
// GetGiteeRedirectUrl 获取Gitee登录地址
|
||||
// @Summary 获取Gitee登录地址
|
||||
// @Description 获取Gitee登录地址
|
||||
// @Tags Gitee OAuth
|
||||
// @Produce json
|
||||
// @Success 200 {string} string "登录地址"
|
||||
// @Router /controller/oauth/gitee/get_url [get]
|
||||
func (OAuthController) GetGiteeRedirectUrl(c *gin.Context) {
|
||||
clientID := global.CONFIG.OAuth.Gitee.ClientID
|
||||
redirectURI := global.CONFIG.OAuth.Gitee.RedirectURI
|
||||
url := "https://gitee.com/oauth/authorize?client_id=" + clientID + "&redirect_uri=" + redirectURI + "&response_type=code"
|
||||
result.OkWithData(url, c)
|
||||
return
|
||||
}
|
||||
|
||||
// GetGiteeTokenAuthUrl 获取Gitee token
|
||||
func GetGiteeTokenAuthUrl(code string) string {
|
||||
clientId := global.CONFIG.OAuth.Gitee.ClientID
|
||||
clientSecret := global.CONFIG.OAuth.Gitee.ClientSecret
|
||||
redirectURI := global.CONFIG.OAuth.Gitee.RedirectURI
|
||||
return fmt.Sprintf(
|
||||
"https://gitee.com/oauth/token?grant_type=authorization_code&code=%s&client_id=%s&redirect_uri=%s&client_secret=%s",
|
||||
code, clientId, redirectURI, clientSecret,
|
||||
)
|
||||
}
|
||||
|
||||
// GetGiteeToken 获取 token
|
||||
func GetGiteeToken(url string) (*Token, error) {
|
||||
|
||||
// 形成请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodPost, url, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("accept", "application/json")
|
||||
|
||||
// 发送请求并获得响应
|
||||
var httpClient = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = httpClient.Do(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将响应体解析为 token,并返回
|
||||
var token Token
|
||||
if err = json.NewDecoder(res.Body).Decode(&token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetGiteeUserInfo 获取用户信息
|
||||
func GetGiteeUserInfo(token *Token) (map[string]interface{}, error) {
|
||||
|
||||
// 形成请求
|
||||
var userInfoUrl = "https://gitee.com/api/v5/user" // github用户信息获取接口
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodGet, userInfoUrl, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("accept", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken))
|
||||
// 发送请求并获取响应
|
||||
var client = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = client.Do(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将响应的数据写入 userInfo 中,并返回
|
||||
var userInfo = make(map[string]interface{})
|
||||
if err = json.NewDecoder(res.Body).Decode(&userInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// GiteeCallback 处理Gitee回调
|
||||
// @Summary 处理Gitee回调
|
||||
// @Description 处理Gitee回调
|
||||
// @Tags Gitee OAuth
|
||||
// @Produce json
|
||||
// @Router /controller/oauth/gitee/callback [get]
|
||||
func (OAuthController) GiteeCallback(c *gin.Context) {
|
||||
var err error
|
||||
// 获取 code
|
||||
var code = c.Query("code")
|
||||
if code == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步获取 token
|
||||
var tokenChan = make(chan *Token)
|
||||
var errChan = make(chan error)
|
||||
go func() {
|
||||
var tokenAuthUrl = GetGiteeTokenAuthUrl(code)
|
||||
token, err := GetGiteeToken(tokenAuthUrl)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
tokenChan <- token
|
||||
}()
|
||||
|
||||
// 异步获取用户信息
|
||||
var userInfoChan = make(chan map[string]interface{})
|
||||
go func() {
|
||||
token := <-tokenChan
|
||||
if token == nil {
|
||||
errChan <- errors.New("failed to get token")
|
||||
return
|
||||
}
|
||||
userInfo, err := GetGiteeUserInfo(token)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
userInfoChan <- userInfo
|
||||
}()
|
||||
|
||||
// 等待结果
|
||||
select {
|
||||
case err = <-errChan:
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
case userInfo := <-userInfoChan:
|
||||
userInfoBytes, err := json.Marshal(userInfo)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
var giteeUser GiteeUser
|
||||
err = json.Unmarshal(userInfoBytes, &giteeUser)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
Id := strconv.Itoa(giteeUser.ID)
|
||||
userSocial, err := userSocialService.QueryUserSocialByOpenIDService(Id, enum.OAuthSourceGitee)
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
db := global.DB
|
||||
tx := db.Begin() // 开始事务
|
||||
if tx.Error != nil {
|
||||
global.LOG.Error(tx.Error)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
// 第一次登录,创建用户
|
||||
uid := idgen.NextId()
|
||||
uidStr := strconv.FormatInt(uid, 10)
|
||||
user := model.ScaAuthUser{
|
||||
UID: &uidStr,
|
||||
Username: &giteeUser.Login,
|
||||
Nickname: &giteeUser.Name,
|
||||
Avatar: &giteeUser.AvatarURL,
|
||||
Blog: &giteeUser.Blog,
|
||||
Email: &giteeUser.Email,
|
||||
Gender: &enum.Male,
|
||||
}
|
||||
addUser, err := userService.AddUserService(user)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
gitee := enum.OAuthSourceGitee
|
||||
userSocial = model.ScaAuthUserSocial{
|
||||
UserID: &uidStr,
|
||||
OpenID: &Id,
|
||||
Source: &gitee,
|
||||
}
|
||||
err = userSocialService.AddUserSocialService(userSocial)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
HandleLoginResponse(c, *addUser.UID)
|
||||
} else {
|
||||
HandleLoginResponse(c, *userSocial.UserID)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
264
controller/oauth_controller/github_controller.go
Normal file
264
controller/oauth_controller/github_controller.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package oauth_controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type GitHubUser struct {
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
Bio interface{} `json:"bio"`
|
||||
Blog string `json:"blog"`
|
||||
Company interface{} `json:"company"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Email string `json:"email"`
|
||||
EventsURL string `json:"events_url"`
|
||||
Followers int `json:"followers"`
|
||||
FollowersURL string `json:"followers_url"`
|
||||
Following int `json:"following"`
|
||||
FollowingURL string `json:"following_url"`
|
||||
GistsURL string `json:"gists_url"`
|
||||
GravatarID string `json:"gravatar_id"`
|
||||
Hireable interface{} `json:"hireable"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
ID int `json:"id"`
|
||||
Location interface{} `json:"location"`
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
NodeID string `json:"node_id"`
|
||||
NotificationEmail interface{} `json:"notification_email"`
|
||||
OrganizationsURL string `json:"organizations_url"`
|
||||
PublicGists int `json:"public_gists"`
|
||||
PublicRepos int `json:"public_repos"`
|
||||
ReceivedEventsURL string `json:"received_events_url"`
|
||||
ReposURL string `json:"repos_url"`
|
||||
SiteAdmin bool `json:"site_admin"`
|
||||
StarredURL string `json:"starred_url"`
|
||||
SubscriptionsURL string `json:"subscriptions_url"`
|
||||
TwitterUsername interface{} `json:"twitter_username"`
|
||||
Type string `json:"type"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// GetRedirectUrl 获取github登录url
|
||||
// @Summary 获取github登录url
|
||||
// @Description 获取github登录url
|
||||
// @Tags Github OAuth
|
||||
// @Produce json
|
||||
// @Success 200 {string} string "登录url"
|
||||
// @Router /controller/oauth/github/get_url [get]
|
||||
func (OAuthController) GetRedirectUrl(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
clientId := global.CONFIG.OAuth.Github.ClientID
|
||||
redirectUrl := global.CONFIG.OAuth.Github.RedirectURI
|
||||
url := "https://github.com/login/oauth/authorize?client_id=" + clientId + "&redirect_uri=" + redirectUrl + "&state=" + state
|
||||
result.OkWithData(url, c)
|
||||
return
|
||||
}
|
||||
|
||||
// GetTokenAuthUrl 通过code获取token认证url
|
||||
func GetTokenAuthUrl(code string) string {
|
||||
clientId := global.CONFIG.OAuth.Github.ClientID
|
||||
clientSecret := global.CONFIG.OAuth.Github.ClientSecret
|
||||
return fmt.Sprintf(
|
||||
"https://github.com/login/oauth/access_token?client_id=%s&client_secret=%s&code=%s",
|
||||
clientId, clientSecret, code,
|
||||
)
|
||||
}
|
||||
|
||||
// GetToken 获取 token
|
||||
func GetToken(url string) (*Token, error) {
|
||||
|
||||
// 形成请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodGet, url, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("accept", "application/json")
|
||||
|
||||
// 发送请求并获得响应
|
||||
var httpClient = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = httpClient.Do(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将响应体解析为 token,并返回
|
||||
var token Token
|
||||
if err = json.NewDecoder(res.Body).Decode(&token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func GetUserInfo(token *Token) (map[string]interface{}, error) {
|
||||
|
||||
// 形成请求
|
||||
var userInfoUrl = "https://api.github.com/user" // github用户信息获取接口
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodGet, userInfoUrl, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("accept", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken))
|
||||
|
||||
// 发送请求并获取响应
|
||||
var client = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = client.Do(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将响应的数据写入 userInfo 中,并返回
|
||||
var userInfo = make(map[string]interface{})
|
||||
if err = json.NewDecoder(res.Body).Decode(&userInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// Callback 登录回调函数
|
||||
// @Summary 登录回调函数
|
||||
// @Description 登录回调函数
|
||||
// @Tags Github OAuth
|
||||
// @Produce json
|
||||
// @Param code query string true "code"
|
||||
// @Success 200 {string} string "登录成功"
|
||||
// @Router /controller/oauth/github/callback [get]
|
||||
func (OAuthController) Callback(c *gin.Context) {
|
||||
var err error
|
||||
// 获取 code
|
||||
var code = c.Query("code")
|
||||
if code == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用channel来接收异步操作的结果
|
||||
tokenChan := make(chan *Token)
|
||||
userInfoChan := make(chan map[string]interface{})
|
||||
errChan := make(chan error)
|
||||
|
||||
// 异步获取token
|
||||
go func() {
|
||||
var tokenAuthUrl = GetTokenAuthUrl(code)
|
||||
token, err := GetToken(tokenAuthUrl)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
tokenChan <- token
|
||||
}()
|
||||
|
||||
// 异步获取用户信息
|
||||
go func() {
|
||||
token := <-tokenChan
|
||||
if token == nil {
|
||||
return
|
||||
}
|
||||
userInfo, err := GetUserInfo(token)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
userInfoChan <- userInfo
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-errChan:
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
case userInfo := <-userInfoChan:
|
||||
if userInfo == nil {
|
||||
global.LOG.Error(<-errChan)
|
||||
return
|
||||
}
|
||||
// 继续处理用户信息
|
||||
userInfoBytes, err := json.Marshal(<-userInfoChan)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
var gitHubUser GitHubUser
|
||||
err = json.Unmarshal(userInfoBytes, &gitHubUser)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
Id := strconv.Itoa(gitHubUser.ID)
|
||||
userSocial, err := userSocialService.QueryUserSocialByOpenIDService(Id, enum.OAuthSourceGithub)
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
db := global.DB
|
||||
tx := db.Begin() // 开始事务
|
||||
if tx.Error != nil {
|
||||
global.LOG.Error(tx.Error)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
// 第一次登录,创建用户
|
||||
uid := idgen.NextId()
|
||||
uidStr := strconv.FormatInt(uid, 10)
|
||||
user := model.ScaAuthUser{
|
||||
UID: &uidStr,
|
||||
Username: &gitHubUser.Login,
|
||||
Nickname: &gitHubUser.Name,
|
||||
Avatar: &gitHubUser.AvatarURL,
|
||||
Blog: &gitHubUser.Blog,
|
||||
Email: &gitHubUser.Email,
|
||||
Gender: &enum.Male,
|
||||
}
|
||||
addUser, err := userService.AddUserService(user)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
github := enum.OAuthSourceGithub
|
||||
userSocial = model.ScaAuthUserSocial{
|
||||
UserID: &uidStr,
|
||||
OpenID: &Id,
|
||||
Source: &github,
|
||||
}
|
||||
err = userSocialService.AddUserSocialService(userSocial)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
HandleLoginResponse(c, *addUser.UID)
|
||||
} else {
|
||||
HandleLoginResponse(c, *userSocial.UserID)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
193
controller/oauth_controller/oauth.go
Normal file
193
controller/oauth_controller/oauth.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package oauth_controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mssola/useragent"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"schisandra-cloud-album/service/impl"
|
||||
"schisandra-cloud-album/utils"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
type OAuthController struct{}
|
||||
|
||||
var userSocialService = impl.UserSocialServiceImpl{}
|
||||
var userService = impl.UserServiceImpl{}
|
||||
var userDeviceService = impl.UserDeviceServiceImpl{}
|
||||
|
||||
type Token struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
var script = `
|
||||
<script>
|
||||
window.opener.postMessage('%s', '%s');
|
||||
window.close();
|
||||
</script>
|
||||
`
|
||||
|
||||
func HandleLoginResponse(c *gin.Context, uid string) {
|
||||
res, data := HandelUserLogin(uid, c)
|
||||
if !res {
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
formattedScript := fmt.Sprintf(script, tokenData, global.CONFIG.System.WebURL())
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(formattedScript))
|
||||
return
|
||||
}
|
||||
|
||||
// HandelUserLogin 处理用户登录
|
||||
func HandelUserLogin(userId string, c *gin.Context) (bool, result.Response) {
|
||||
// 使用goroutine生成accessToken
|
||||
accessTokenChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: &userId})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
accessTokenChan <- accessToken
|
||||
}()
|
||||
|
||||
// 使用goroutine生成refreshToken
|
||||
refreshTokenChan := make(chan string)
|
||||
expiresAtChan := make(chan int64)
|
||||
go func() {
|
||||
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
|
||||
refreshTokenChan <- refreshToken
|
||||
expiresAtChan <- expiresAt
|
||||
}()
|
||||
|
||||
// 等待accessToken和refreshToken生成完成
|
||||
var accessToken string
|
||||
var refreshToken string
|
||||
var expiresAt int64
|
||||
var err error
|
||||
select {
|
||||
case accessToken = <-accessTokenChan:
|
||||
case err = <-errChan:
|
||||
global.LOG.Error(err)
|
||||
return false, result.Response{}
|
||||
}
|
||||
select {
|
||||
case refreshToken = <-refreshTokenChan:
|
||||
case expiresAt = <-expiresAtChan:
|
||||
}
|
||||
|
||||
data := ResponseData{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
UID: &userId,
|
||||
}
|
||||
wrong := utils.SetSession(c, constant.SessionKey, data)
|
||||
if wrong != nil {
|
||||
return false, result.Response{}
|
||||
}
|
||||
// 使用goroutine将数据存入redis
|
||||
redisErrChan := make(chan error)
|
||||
go func() {
|
||||
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err()
|
||||
if fail != nil {
|
||||
redisErrChan <- fail
|
||||
return
|
||||
}
|
||||
redisErrChan <- nil
|
||||
}()
|
||||
|
||||
// 等待redis操作完成
|
||||
redisErr := <-redisErrChan
|
||||
if redisErr != nil {
|
||||
global.LOG.Error(redisErr)
|
||||
return false, result.Response{}
|
||||
}
|
||||
responseData := result.Response{
|
||||
Data: data,
|
||||
Message: "success",
|
||||
Code: 200,
|
||||
Success: true,
|
||||
}
|
||||
return true, responseData
|
||||
}
|
||||
|
||||
// GetUserLoginDevice 获取用户登录设备
|
||||
func (OAuthController) GetUserLoginDevice(c *gin.Context) {
|
||||
userId := c.Query("user_id")
|
||||
if userId == "" {
|
||||
return
|
||||
}
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
global.LOG.Errorln("user-agent is empty")
|
||||
return
|
||||
}
|
||||
ua := useragent.New(userAgent)
|
||||
|
||||
ip := utils.GetClientIP(c)
|
||||
location, err := global.IP2Location.SearchByStr(ip)
|
||||
location = utils.RemoveZeroAndAdjust(location)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
isBot := ua.Bot()
|
||||
browser, browserVersion := ua.Browser()
|
||||
os := ua.OS()
|
||||
mobile := ua.Mobile()
|
||||
mozilla := ua.Mozilla()
|
||||
platform := ua.Platform()
|
||||
engine, engineVersion := ua.Engine()
|
||||
device := model.ScaAuthUserDevice{
|
||||
UserID: &userId,
|
||||
IP: &ip,
|
||||
Location: &location,
|
||||
Agent: userAgent,
|
||||
Browser: &browser,
|
||||
BrowserVersion: &browserVersion,
|
||||
OperatingSystem: &os,
|
||||
Mobile: &mobile,
|
||||
Bot: &isBot,
|
||||
Mozilla: &mozilla,
|
||||
Platform: &platform,
|
||||
EngineName: &engine,
|
||||
EngineVersion: &engineVersion,
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
userDevice, err := userDeviceService.GetUserDeviceByUIDIPAgentService(userId, ip, userAgent)
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
err = userDeviceService.AddUserDeviceService(&device)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
return
|
||||
} else {
|
||||
err := userDeviceService.UpdateUserDeviceService(userDevice.ID, &device)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
304
controller/oauth_controller/qq_controller.go
Normal file
304
controller/oauth_controller/qq_controller.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package oauth_controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type AuthQQme struct {
|
||||
ClientID string `json:"client_id"`
|
||||
OpenID string `json:"openid"`
|
||||
}
|
||||
type QQToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpireIn string `json:"expire_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type QQUserInfo struct {
|
||||
City string `json:"city"`
|
||||
Figureurl string `json:"figureurl"`
|
||||
Figureurl1 string `json:"figureurl_1"`
|
||||
Figureurl2 string `json:"figureurl_2"`
|
||||
FigureurlQq string `json:"figureurl_qq"`
|
||||
FigureurlQq1 string `json:"figureurl_qq_1"`
|
||||
FigureurlQq2 string `json:"figureurl_qq_2"`
|
||||
Gender string `json:"gender"`
|
||||
GenderType int `json:"gender_type"`
|
||||
IsLost int `json:"is_lost"`
|
||||
IsYellowVip string `json:"is_yellow_vip"`
|
||||
IsYellowYearVip string `json:"is_yellow_year_vip"`
|
||||
Level string `json:"level"`
|
||||
Msg string `json:"msg"`
|
||||
Nickname string `json:"nickname"`
|
||||
Province string `json:"province"`
|
||||
Ret int `json:"ret"`
|
||||
Vip string `json:"vip"`
|
||||
Year string `json:"year"`
|
||||
YellowVipLevel string `json:"yellow_vip_level"`
|
||||
}
|
||||
|
||||
// GetQQRedirectUrl 获取登录地址
|
||||
// @Summary 获取QQ登录地址
|
||||
// @Description 获取QQ登录地址
|
||||
// @Tags QQ OAuth
|
||||
// @Produce json
|
||||
// @Success 200 {string} string "登录地址"
|
||||
// @Router /controller/oauth/qq/get_url [get]
|
||||
func (OAuthController) GetQQRedirectUrl(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
clientId := global.CONFIG.OAuth.QQ.ClientID
|
||||
redirectURI := global.CONFIG.OAuth.QQ.RedirectURI
|
||||
url := "https://graph.qq.com/oauth2.0/authorize?response_type=code&client_id=" + clientId + "&redirect_uri=" + redirectURI + "&state=" + state
|
||||
result.OkWithData(url, c)
|
||||
return
|
||||
}
|
||||
|
||||
// GetQQTokenAuthUrl 通过code获取token认证url
|
||||
func GetQQTokenAuthUrl(code string) string {
|
||||
clientId := global.CONFIG.OAuth.QQ.ClientID
|
||||
clientSecret := global.CONFIG.OAuth.QQ.ClientSecret
|
||||
redirectURI := global.CONFIG.OAuth.QQ.RedirectURI
|
||||
return fmt.Sprintf(
|
||||
"https://graph.qq.com/oauth2.0/token?grant_type=authorization_code&client_id=%s&client_secret=%s&code=%s&redirect_uri=%s&fmt=json",
|
||||
clientId, clientSecret, code, redirectURI,
|
||||
)
|
||||
}
|
||||
|
||||
// GetQQToken 获取 token
|
||||
func GetQQToken(url string) (*QQToken, error) {
|
||||
|
||||
// 形成请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodGet, url, nil); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 发送请求并获得响应
|
||||
var httpClient = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = httpClient.Do(req); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
//将响应体解析为 token,并返回
|
||||
var token QQToken
|
||||
if err = json.NewDecoder(res.Body).Decode(&token); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetQQUserOpenID 获取用户 openid
|
||||
func GetQQUserOpenID(token *QQToken) (*AuthQQme, error) {
|
||||
|
||||
// 形成请求
|
||||
var userInfoUrl = "https://graph.qq.com/oauth2.0/me?access_token=" + token.AccessToken + "&fmt=json"
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodGet, userInfoUrl, nil); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
// 发送请求并获取响应
|
||||
var client = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = client.Do(req); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将响应体解析为 AuthQQme,并返回
|
||||
var authQQme AuthQQme
|
||||
if err = json.NewDecoder(res.Body).Decode(&authQQme); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
return &authQQme, nil
|
||||
}
|
||||
|
||||
// GetQQUserUserInfo 获取用户信息
|
||||
func GetQQUserUserInfo(token *QQToken, openId string) (map[string]interface{}, error) {
|
||||
|
||||
clientId := global.CONFIG.OAuth.QQ.ClientID
|
||||
// 形成请求
|
||||
var userInfoUrl = "https://graph.qq.com/user/get_user_info?access_token=" + token.AccessToken + "&oauth_consumer_key=" + clientId + "&openid=" + openId
|
||||
var req *http.Request
|
||||
var err error
|
||||
if req, err = http.NewRequest(http.MethodGet, userInfoUrl, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 发送请求并获取响应
|
||||
var client = http.Client{}
|
||||
var res *http.Response
|
||||
if res, err = client.Do(req); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将响应的数据写入 userInfo 中,并返回
|
||||
var userInfo = make(map[string]interface{})
|
||||
if err = json.NewDecoder(res.Body).Decode(&userInfo); err != nil {
|
||||
global.LOG.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// QQCallback QQ登录回调
|
||||
// @Summary QQ登录回调
|
||||
// @Description QQ登录回调
|
||||
// @Tags QQ OAuth
|
||||
// @Produce json
|
||||
// @Router /controller/oauth/qq/callback [get]
|
||||
func (OAuthController) QQCallback(c *gin.Context) {
|
||||
var err error
|
||||
// 获取 code
|
||||
var code = c.Query("code")
|
||||
if code == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 通过 code, 获取 token
|
||||
var tokenAuthUrl = GetQQTokenAuthUrl(code)
|
||||
tokenChan := make(chan *QQToken)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
token, err := GetQQToken(tokenAuthUrl)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
tokenChan <- token
|
||||
}()
|
||||
var token *QQToken
|
||||
select {
|
||||
case token = <-tokenChan:
|
||||
case err = <-errChan:
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 通过 token,获取 openid
|
||||
openIDChan := make(chan *AuthQQme)
|
||||
errChan = make(chan error)
|
||||
go func() {
|
||||
authQQme, err := GetQQUserOpenID(token)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
openIDChan <- authQQme
|
||||
}()
|
||||
var authQQme *AuthQQme
|
||||
select {
|
||||
case authQQme = <-openIDChan:
|
||||
case err = <-errChan:
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 通过token,获取用户信息
|
||||
userInfoChan := make(chan map[string]interface{})
|
||||
errChan = make(chan error)
|
||||
go func() {
|
||||
userInfo, err := GetQQUserUserInfo(token, authQQme.OpenID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
userInfoChan <- userInfo
|
||||
}()
|
||||
var userInfo map[string]interface{}
|
||||
select {
|
||||
case userInfo = <-userInfoChan:
|
||||
case err = <-errChan:
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfoBytes, err := json.Marshal(userInfo)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
var qqUserInfo QQUserInfo
|
||||
err = json.Unmarshal(userInfoBytes, &qqUserInfo)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
userSocial, err := userSocialService.QueryUserSocialByOpenIDService(authQQme.OpenID, enum.OAuthSourceQQ)
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
db := global.DB
|
||||
tx := db.Begin() // 开始事务
|
||||
if tx.Error != nil {
|
||||
global.LOG.Error(tx.Error)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
// 第一次登录,创建用户
|
||||
uid := idgen.NextId()
|
||||
uidStr := strconv.FormatInt(uid, 10)
|
||||
user := model.ScaAuthUser{
|
||||
UID: &uidStr,
|
||||
Username: &authQQme.OpenID,
|
||||
Nickname: &qqUserInfo.Nickname,
|
||||
Avatar: &qqUserInfo.FigureurlQq1,
|
||||
Gender: &qqUserInfo.Gender,
|
||||
}
|
||||
addUser, err := userService.AddUserService(user)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
qq := enum.OAuthSourceQQ
|
||||
userSocial = model.ScaAuthUserSocial{
|
||||
UserID: &uidStr,
|
||||
OpenID: &authQQme.OpenID,
|
||||
Source: &qq,
|
||||
}
|
||||
err = userSocialService.AddUserSocialService(userSocial)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
HandleLoginResponse(c, *addUser.UID)
|
||||
return
|
||||
} else {
|
||||
HandleLoginResponse(c, *userSocial.UserID)
|
||||
}
|
||||
}
|
19
controller/oauth_controller/request_param.go
Normal file
19
controller/oauth_controller/request_param.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package oauth_controller
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ResponseData 返回数据
|
||||
type ResponseData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
UID *string `json:"uid"`
|
||||
}
|
||||
|
||||
func (res ResponseData) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
func (res ResponseData) UnmarshalBinary(data []byte) error {
|
||||
return json.Unmarshal(data, &res)
|
||||
}
|
332
controller/oauth_controller/wechat_controller.go
Normal file
332
controller/oauth_controller/wechat_controller.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package oauth_controller
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/ArtisanCloud/PowerLibs/v3/http/helper"
|
||||
"github.com/ArtisanCloud/PowerWeChat/v3/src/basicService/qrCode/response"
|
||||
"github.com/ArtisanCloud/PowerWeChat/v3/src/kernel/contract"
|
||||
"github.com/ArtisanCloud/PowerWeChat/v3/src/kernel/messages"
|
||||
models2 "github.com/ArtisanCloud/PowerWeChat/v3/src/kernel/models"
|
||||
"github.com/ArtisanCloud/PowerWeChat/v3/src/officialAccount/server/handlers/models"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"gorm.io/gorm"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/randomname"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/controller/websocket_controller"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"schisandra-cloud-album/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CallbackNotify 微信回调
|
||||
// @Summary 微信回调
|
||||
// @Tags 微信公众号
|
||||
// @Description 微信回调
|
||||
// @Produce json
|
||||
// @Router /controller/oauth/callback_notify [POST]
|
||||
func (OAuthController) CallbackNotify(c *gin.Context) {
|
||||
rs, err := global.Wechat.Server.Notify(c.Request, func(event contract.EventInterface) interface{} {
|
||||
switch event.GetMsgType() {
|
||||
case models2.CALLBACK_MSG_TYPE_EVENT:
|
||||
switch event.GetEvent() {
|
||||
case models.CALLBACK_EVENT_SUBSCRIBE:
|
||||
msg := models.EventSubscribe{}
|
||||
err := event.ReadMessage(&msg)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return "error"
|
||||
}
|
||||
key := strings.TrimPrefix(msg.EventKey, "qrscene_")
|
||||
res := wechatLoginHandler(msg.FromUserName, key, c)
|
||||
if !res {
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed"))
|
||||
}
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginSuccess"))
|
||||
|
||||
case models.CALLBACK_EVENT_UNSUBSCRIBE:
|
||||
msg := models.EventUnSubscribe{}
|
||||
err := event.ReadMessage(&msg)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return "error"
|
||||
}
|
||||
return messages.NewText("ok")
|
||||
|
||||
case models.CALLBACK_EVENT_SCAN:
|
||||
msg := models.EventScan{}
|
||||
err := event.ReadMessage(&msg)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return "error"
|
||||
}
|
||||
res := wechatLoginHandler(msg.FromUserName, msg.EventKey, c)
|
||||
if !res {
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed"))
|
||||
}
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginSuccess"))
|
||||
|
||||
}
|
||||
|
||||
case models2.CALLBACK_MSG_TYPE_TEXT:
|
||||
msg := models.MessageText{}
|
||||
err := event.ReadMessage(&msg)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return "error"
|
||||
}
|
||||
}
|
||||
return messages.NewText("ok")
|
||||
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = helper.HttpResponseSend(rs, c.Writer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// CallbackVerify 微信回调验证
|
||||
// @Summary 微信回调验证
|
||||
// @Tags 微信公众号
|
||||
// @Description 微信回调验证
|
||||
// @Produce json
|
||||
// @Router /controller/oauth/callback_verify [get]
|
||||
func (OAuthController) CallbackVerify(c *gin.Context) {
|
||||
rs, err := global.Wechat.Server.VerifyURL(c.Request)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = helper.HttpResponseSend(rs, c.Writer)
|
||||
}
|
||||
|
||||
// GetTempQrCode 获取临时二维码
|
||||
// @Summary 获取临时二维码
|
||||
// @Tags 微信公众号
|
||||
// @Description 获取临时二维码
|
||||
// @Produce json
|
||||
// @Param client_id query string true "客户端ID"
|
||||
// @Router /controller/oauth/get_temp_qrcode [get]
|
||||
func (OAuthController) GetTempQrCode(c *gin.Context) {
|
||||
clientId := c.Query("client_id")
|
||||
if clientId == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
ip := utils.GetClientIP(c) // 使用工具函数获取客户端IP
|
||||
key := constant.UserLoginQrcodeRedisKey + ip
|
||||
|
||||
// 从Redis获取二维码数据
|
||||
qrcode := redis.Get(key).Val()
|
||||
if qrcode != "" {
|
||||
data := new(response.ResponseQRCodeCreate)
|
||||
if err := json.Unmarshal([]byte(qrcode), data); err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成临时二维码
|
||||
data, err := global.Wechat.QRCode.Temporary(c.Request.Context(), clientId, 7*24*3600)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 序列化数据并存储到Redis
|
||||
serializedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
|
||||
return
|
||||
}
|
||||
if err := redis.Set(key, serializedData, time.Hour*24*7).Err(); err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
|
||||
return
|
||||
}
|
||||
|
||||
result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c)
|
||||
}
|
||||
|
||||
// wechatLoginHandler 微信登录处理
|
||||
func wechatLoginHandler(openId string, clientId string, c *gin.Context) bool {
|
||||
if openId == "" {
|
||||
return false
|
||||
}
|
||||
authUserSocial, err := userSocialService.QueryUserSocialByOpenIDService(openId, enum.OAuthSourceWechat)
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
uid := idgen.NextId()
|
||||
uidStr := strconv.FormatInt(uid, 10)
|
||||
avatar, err := utils.GenerateAvatar(uidStr)
|
||||
name := randomname.GenerateName()
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return false
|
||||
}
|
||||
createUser := model.ScaAuthUser{
|
||||
UID: &uidStr,
|
||||
Username: &openId,
|
||||
Avatar: &avatar,
|
||||
Nickname: &name,
|
||||
Gender: &enum.Male,
|
||||
}
|
||||
|
||||
// 异步添加用户
|
||||
addUserChan := make(chan *model.ScaAuthUser, 1)
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
addUser, err := userService.AddUserService(createUser)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
addUserChan <- addUser
|
||||
}()
|
||||
|
||||
var addUser *model.ScaAuthUser
|
||||
select {
|
||||
case addUser = <-addUserChan:
|
||||
case err := <-errChan:
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return false
|
||||
}
|
||||
|
||||
wechat := enum.OAuthSourceWechat
|
||||
userSocial := model.ScaAuthUserSocial{
|
||||
UserID: &uidStr,
|
||||
OpenID: &openId,
|
||||
Source: &wechat,
|
||||
}
|
||||
|
||||
// 异步添加用户社交信息
|
||||
wrongChan := make(chan error, 1)
|
||||
go func() {
|
||||
wrong := userSocialService.AddUserSocialService(userSocial)
|
||||
wrongChan <- wrong
|
||||
}()
|
||||
|
||||
select {
|
||||
case wrong := <-wrongChan:
|
||||
if wrong != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(wrong)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 异步添加角色
|
||||
roleErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := global.Casbin.AddRoleForUser(uidStr, enum.User)
|
||||
roleErrChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-roleErrChan:
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
global.LOG.Error(err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 异步处理用户登录
|
||||
resChan := make(chan bool, 1)
|
||||
go func() {
|
||||
res := handelUserLogin(*addUser.UID, clientId, c)
|
||||
resChan <- res
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-resChan:
|
||||
if !res {
|
||||
tx.Rollback()
|
||||
return false
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
return true
|
||||
} else {
|
||||
res := handelUserLogin(*authUserSocial.UserID, clientId, c)
|
||||
if !res {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// handelUserLogin 处理用户登录
|
||||
func handelUserLogin(userId string, clientId string, c *gin.Context) bool {
|
||||
resultChan := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: &userId})
|
||||
if err != nil {
|
||||
resultChan <- false
|
||||
return
|
||||
}
|
||||
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: &userId}, time.Hour*24*7)
|
||||
data := ResponseData{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
UID: &userId,
|
||||
}
|
||||
fail := redis.Set(constant.UserLoginTokenRedisKey+userId, data, time.Hour*24*7).Err()
|
||||
if fail != nil {
|
||||
resultChan <- false
|
||||
return
|
||||
}
|
||||
responseData := result.Response{
|
||||
Data: data,
|
||||
Message: "success",
|
||||
Code: 200,
|
||||
Success: true,
|
||||
}
|
||||
tokenData, err := json.Marshal(responseData)
|
||||
if err != nil {
|
||||
resultChan <- false
|
||||
return
|
||||
}
|
||||
gob.Register(ResponseData{})
|
||||
wrong := utils.SetSession(c, constant.SessionKey, data)
|
||||
if wrong != nil {
|
||||
resultChan <- false
|
||||
return
|
||||
}
|
||||
// gws方式发送消息
|
||||
err = websocket_controller.Handler.SendMessageToClient(clientId, tokenData)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
resultChan <- false
|
||||
return
|
||||
}
|
||||
resultChan <- true
|
||||
}()
|
||||
|
||||
return <-resultChan
|
||||
}
|
3
controller/permission_controller/permission.go
Normal file
3
controller/permission_controller/permission.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package permission_controller
|
||||
|
||||
type PermissionController struct{}
|
86
controller/permission_controller/permission_controller.go
Normal file
86
controller/permission_controller/permission_controller.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package permission_controller
|
||||
|
||||
import (
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/service/impl"
|
||||
)
|
||||
|
||||
var permissionService = impl.PermissionServiceImpl{}
|
||||
|
||||
// AddPermissions 批量添加权限
|
||||
// @Summary 批量添加权限
|
||||
// @Description 批量添加权限
|
||||
// @Tags 权限管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param permissions body AddPermissionRequest true "权限列表"
|
||||
// @Router /controller/auth/permission/add [post]
|
||||
func (PermissionController) AddPermissions(c *gin.Context) {
|
||||
addPermissionRequest := AddPermissionRequest{}
|
||||
err := c.ShouldBind(&addPermissionRequest.Permissions)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
err = permissionService.CreatePermissionsService(addPermissionRequest.Permissions)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CreatedFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CreatedSuccess"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// AssignPermissionsToRole 给指定角色分配权限
|
||||
// @Summary 给指定角色分配权限
|
||||
// @Description 给指定角色分配权限
|
||||
// @Tags 权限管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param permissions body AddPermissionToRoleRequest true "权限列表"
|
||||
// @Router /controller/auth/permission/assign [post]
|
||||
func (PermissionController) AssignPermissionsToRole(c *gin.Context) {
|
||||
permissionToRoleRequest := AddPermissionToRoleRequest{}
|
||||
|
||||
err := c.ShouldBind(&permissionToRoleRequest)
|
||||
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "AssignFailed"), c)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := global.Casbin.AddPolicy(permissionToRoleRequest.RoleKey, permissionToRoleRequest.Permission, permissionToRoleRequest.Method)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "AssignFailed"), c)
|
||||
return
|
||||
}
|
||||
if policy == false {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "AssignFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "AssignSuccess"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUserPermissions 获取用户角色权限
|
||||
func (PermissionController) GetUserPermissions(c *gin.Context) {
|
||||
getPermissionRequest := GetPermissionRequest{}
|
||||
err := c.ShouldBindJSON(&getPermissionRequest)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
data, err := global.Casbin.GetImplicitRolesForUser(getPermissionRequest.UserId)
|
||||
if err != nil {
|
||||
result.FailWithMessage("Get user permissions failed", c)
|
||||
return
|
||||
}
|
||||
result.OkWithData(data, c)
|
||||
return
|
||||
}
|
18
controller/permission_controller/request_param.go
Normal file
18
controller/permission_controller/request_param.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package permission_controller
|
||||
|
||||
import "schisandra-cloud-album/model"
|
||||
|
||||
// AddPermissionToRoleRequest 添加权限请求
|
||||
type AddPermissionRequest struct {
|
||||
Permissions []model.ScaAuthPermission `form:"permissions[]" json:"permissions"`
|
||||
}
|
||||
|
||||
// AddPermissionToRoleRequest 添加权限到角色请求
|
||||
type AddPermissionToRoleRequest struct {
|
||||
RoleKey string `json:"role_key"`
|
||||
Permission string `json:"permission"`
|
||||
Method string `json:"method"`
|
||||
}
|
||||
type GetPermissionRequest struct {
|
||||
UserId string `json:"user_id" binding:"required"`
|
||||
}
|
11
controller/role_controller/request_param.go
Normal file
11
controller/role_controller/request_param.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package role_controller
|
||||
|
||||
type RoleRequest struct {
|
||||
RoleName string `json:"role_name" binding:"required"`
|
||||
RoleKey string `json:"role_key" binding:"required"`
|
||||
}
|
||||
|
||||
type AddRoleToUserRequest struct {
|
||||
Uid string `json:"uid" binding:"required"`
|
||||
RoleKey string `json:"role_key" binding:"required"`
|
||||
}
|
3
controller/role_controller/role.go
Normal file
3
controller/role_controller/role.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package role_controller
|
||||
|
||||
type RoleController struct{}
|
63
controller/role_controller/role_controller.go
Normal file
63
controller/role_controller/role_controller.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package role_controller
|
||||
|
||||
import (
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"schisandra-cloud-album/service/impl"
|
||||
)
|
||||
|
||||
var roleService = impl.RoleServiceImpl{}
|
||||
|
||||
// CreateRole 创建角色
|
||||
// @Summary 创建角色
|
||||
// @Description 创建角色
|
||||
// @Tags 角色
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param roleRequestDto body RoleRequest true "角色信息"
|
||||
// @Router /controller/auth/role/create [post]
|
||||
func (RoleController) CreateRole(c *gin.Context) {
|
||||
roleRequest := RoleRequest{}
|
||||
err := c.ShouldBindJSON(&roleRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CreatedFailed"), c)
|
||||
return
|
||||
}
|
||||
role := model.ScaAuthRole{
|
||||
RoleName: roleRequest.RoleName,
|
||||
RoleKey: roleRequest.RoleKey,
|
||||
}
|
||||
err = roleService.AddRoleService(role)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CreatedFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CreatedSuccess"), c)
|
||||
}
|
||||
|
||||
// AddRoleToUser 给指定用户添加角色
|
||||
// @Summary 给指定用户添加角色
|
||||
// @Description 给指定用户添加角色
|
||||
// @Tags 角色
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param addRoleToUserRequestDto body AddRoleToUserRequest true "给指定用户添加角色"
|
||||
// @Router /controller/auth/role/add_role_to_user [post]
|
||||
func (RoleController) AddRoleToUser(c *gin.Context) {
|
||||
addRoleToUserRequest := AddRoleToUserRequest{}
|
||||
err := c.ShouldBindJSON(&addRoleToUserRequest)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
user, err := global.Casbin.AddRoleForUser(addRoleToUserRequest.Uid, addRoleToUserRequest.RoleKey)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
return
|
||||
}
|
||||
result.OkWithData(user, c)
|
||||
}
|
7
controller/sms_controller/request_param.go
Normal file
7
controller/sms_controller/request_param.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package sms_controller
|
||||
|
||||
type SmsRequest struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
Angle int64 `json:"angle" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
}
|
3
controller/sms_controller/sms.go
Normal file
3
controller/sms_controller/sms.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package sms_controller
|
||||
|
||||
type SmsController struct{}
|
162
controller/sms_controller/sms_controller.go
Normal file
162
controller/sms_controller/sms_controller.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package sms_controller
|
||||
|
||||
import (
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
gosms "github.com/pkg6/go-sms"
|
||||
"github.com/pkg6/go-sms/gateways"
|
||||
"github.com/pkg6/go-sms/gateways/aliyun"
|
||||
"github.com/pkg6/go-sms/gateways/smsbao"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SendMessageByAli 发送短信验证码
|
||||
// @Summary 发送短信验证码
|
||||
// @Description 发送短信验证码
|
||||
// @Tags 短信验证码
|
||||
// @Produce json
|
||||
// @Param phone query string true "手机号"
|
||||
// @Router /controller/sms/ali/send [get]
|
||||
func (SmsController) SendMessageByAli(c *gin.Context) {
|
||||
smsRequest := SmsRequest{}
|
||||
err := c.ShouldBindJSON(&smsRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendFailed"), c)
|
||||
return
|
||||
}
|
||||
checkRotateData := utils.CheckRotateData(smsRequest.Angle, smsRequest.Key)
|
||||
if !checkRotateData {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
isPhone := utils.IsPhone(smsRequest.Phone)
|
||||
if !isPhone {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneErrorFormat"), c)
|
||||
return
|
||||
}
|
||||
val := redis.Get(constant.UserLoginSmsRedisKey + smsRequest.Phone).Val()
|
||||
if val != "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaTooOften"), c)
|
||||
return
|
||||
}
|
||||
sms := gosms.NewParser(gateways.Gateways{
|
||||
ALiYun: aliyun.ALiYun{
|
||||
Host: global.CONFIG.SMS.Ali.Host,
|
||||
AccessKeyId: global.CONFIG.SMS.Ali.AccessKeyID,
|
||||
AccessKeySecret: global.CONFIG.SMS.Ali.AccessKeySecret,
|
||||
},
|
||||
})
|
||||
code := utils.GenValidateCode(6)
|
||||
wrong := redis.Set(constant.UserLoginSmsRedisKey+smsRequest.Phone, code, time.Minute).Err()
|
||||
if wrong != nil {
|
||||
global.LOG.Error(wrong)
|
||||
return
|
||||
}
|
||||
_, err = sms.Send(smsRequest.Phone, gosms.MapStringAny{
|
||||
"content": "您的验证码是:****。请不要把验证码泄露给其他人。",
|
||||
"template": global.CONFIG.SMS.Ali.TemplateID,
|
||||
//"signName": global.CONFIG.SMS.Ali.Signature,
|
||||
"data": gosms.MapStrings{
|
||||
"code": code,
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendSuccess"), c)
|
||||
|
||||
}
|
||||
|
||||
// SendMessageBySmsBao 短信宝发送短信验证码
|
||||
// @Summary 短信宝发送短信验证码
|
||||
// @Description 短信宝发送短信验证码
|
||||
// @Tags 短信验证码
|
||||
// @Produce json
|
||||
// @Param phone query string true "手机号"
|
||||
// @Router /controller/sms/smsbao/send [post]
|
||||
func (SmsController) SendMessageBySmsBao(c *gin.Context) {
|
||||
smsRequest := SmsRequest{}
|
||||
err := c.ShouldBindJSON(&smsRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendFailed"), c)
|
||||
return
|
||||
}
|
||||
checkRotateData := utils.CheckRotateData(smsRequest.Angle, smsRequest.Key)
|
||||
if !checkRotateData {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
isPhone := utils.IsPhone(smsRequest.Phone)
|
||||
if !isPhone {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneErrorFormat"), c)
|
||||
return
|
||||
}
|
||||
val := redis.Get(constant.UserLoginSmsRedisKey + smsRequest.Phone).Val()
|
||||
if val != "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaTooOften"), c)
|
||||
return
|
||||
}
|
||||
sms := gosms.NewParser(gateways.Gateways{
|
||||
SmsBao: smsbao.SmsBao{
|
||||
User: global.CONFIG.SMS.SmsBao.User,
|
||||
Password: global.CONFIG.SMS.SmsBao.Password,
|
||||
},
|
||||
})
|
||||
code := utils.GenValidateCode(6)
|
||||
wrong := redis.Set(constant.UserLoginSmsRedisKey+smsRequest.Phone, code, time.Minute).Err()
|
||||
if wrong != nil {
|
||||
global.LOG.Error(wrong)
|
||||
return
|
||||
}
|
||||
_, err = sms.Send(smsRequest.Phone, gosms.MapStringAny{
|
||||
"content": "您的验证码是:" + code + "。请不要把验证码泄露给其他人。",
|
||||
}, nil)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendSuccess"), c)
|
||||
}
|
||||
|
||||
// SendMessageTest 发送测试短信验证码
|
||||
// @Summary 发送测试短信验证码
|
||||
// @Description 发送测试短信验证码
|
||||
// @Tags 短信验证码
|
||||
// @Produce json
|
||||
// @Param phone query string true "手机号"
|
||||
// @Router /controller/sms/test/send [post]
|
||||
func (SmsController) SendMessageTest(c *gin.Context) {
|
||||
smsRequest := SmsRequest{}
|
||||
err := c.ShouldBindJSON(&smsRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendFailed"), c)
|
||||
return
|
||||
}
|
||||
checkRotateData := utils.CheckRotateData(smsRequest.Angle, smsRequest.Key)
|
||||
if !checkRotateData {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
isPhone := utils.IsPhone(smsRequest.Phone)
|
||||
if !isPhone {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneError"), c)
|
||||
return
|
||||
}
|
||||
code := utils.GenValidateCode(6)
|
||||
err = redis.Set(constant.UserLoginSmsRedisKey+smsRequest.Phone, code, time.Minute).Err()
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CaptchaSendSuccess"), c)
|
||||
|
||||
}
|
134
controller/user_controller/handler.go
Normal file
134
controller/user_controller/handler.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package user_controller
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mssola/useragent"
|
||||
"gorm.io/gorm"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"schisandra-cloud-album/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getUserLoginDevice 获取用户登录设备
|
||||
func getUserLoginDevice(user model.ScaAuthUser, c *gin.Context) bool {
|
||||
|
||||
// 检查user.UID是否为空
|
||||
if user.UID == nil {
|
||||
global.LOG.Errorln("user.UID is nil")
|
||||
return false
|
||||
}
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
global.LOG.Errorln("user-agent is empty")
|
||||
return false
|
||||
}
|
||||
ua := useragent.New(userAgent)
|
||||
|
||||
ip := utils.GetClientIP(c)
|
||||
location, err := global.IP2Location.SearchByStr(ip)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return false
|
||||
}
|
||||
location = utils.RemoveZeroAndAdjust(location)
|
||||
|
||||
isBot := ua.Bot()
|
||||
browser, browserVersion := ua.Browser()
|
||||
os := ua.OS()
|
||||
mobile := ua.Mobile()
|
||||
mozilla := ua.Mozilla()
|
||||
platform := ua.Platform()
|
||||
engine, engineVersion := ua.Engine()
|
||||
|
||||
device := model.ScaAuthUserDevice{
|
||||
UserID: user.UID,
|
||||
IP: &ip,
|
||||
Location: &location,
|
||||
Agent: userAgent,
|
||||
Browser: &browser,
|
||||
BrowserVersion: &browserVersion,
|
||||
OperatingSystem: &os,
|
||||
Mobile: &mobile,
|
||||
Bot: &isBot,
|
||||
Mozilla: &mozilla,
|
||||
Platform: &platform,
|
||||
EngineName: &engine,
|
||||
EngineVersion: &engineVersion,
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
userDevice, err := userDeviceService.GetUserDeviceByUIDIPAgentService(*user.UID, ip, userAgent)
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
err = userDeviceService.AddUserDeviceService(&device)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return false
|
||||
}
|
||||
} else if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return false
|
||||
} else {
|
||||
err := userDeviceService.UpdateUserDeviceService(userDevice.ID, &device)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handelUserLogin 处理用户登录
|
||||
func handelUserLogin(user model.ScaAuthUser, autoLogin bool, c *gin.Context) {
|
||||
// 检查 user.UID 是否为 nil
|
||||
if user.UID == nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
if !getUserLoginDevice(user, c) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
|
||||
return
|
||||
}
|
||||
accessToken, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: user.UID})
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
|
||||
return
|
||||
}
|
||||
|
||||
var days time.Duration
|
||||
if autoLogin {
|
||||
days = 7 * 24 * time.Hour
|
||||
} else {
|
||||
days = time.Minute * 30
|
||||
}
|
||||
|
||||
refreshToken, expiresAt := utils.GenerateRefreshToken(utils.RefreshJWTPayload{UserID: user.UID}, days)
|
||||
data := ResponseData{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
UID: user.UID,
|
||||
}
|
||||
|
||||
err = redis.Set(constant.UserLoginTokenRedisKey+*user.UID, data, days).Err()
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
|
||||
return
|
||||
}
|
||||
gob.Register(ResponseData{})
|
||||
err = utils.SetSession(c, constant.SessionKey, data)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithData(data, c)
|
||||
}
|
55
controller/user_controller/request_param.go
Normal file
55
controller/user_controller/request_param.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package user_controller
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// RefreshTokenRequest 刷新token请求
|
||||
type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
}
|
||||
|
||||
// PhoneLoginRequest 手机号登录请求
|
||||
type PhoneLoginRequest struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
Captcha string `json:"captcha" binding:"required"`
|
||||
AutoLogin bool `json:"auto_login" binding:"required"`
|
||||
}
|
||||
|
||||
// AccountLoginRequest 账号登录请求
|
||||
type AccountLoginRequest struct {
|
||||
Account string `json:"account" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
AutoLogin bool `json:"auto_login" binding:"required"`
|
||||
Angle int64 `json:"angle" binding:"required"`
|
||||
Key string `json:"key" binding:"required"`
|
||||
}
|
||||
|
||||
// AddUserRequest 新增用户请求
|
||||
type AddUserRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
Captcha string `json:"captcha" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Repassword string `json:"repassword" binding:"required"`
|
||||
}
|
||||
|
||||
// ResponseData 返回数据
|
||||
type ResponseData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
UID *string `json:"uid"`
|
||||
}
|
||||
|
||||
func (res ResponseData) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
func (res ResponseData) UnmarshalBinary(data []byte) error {
|
||||
return json.Unmarshal(data, &res)
|
||||
}
|
12
controller/user_controller/user.go
Normal file
12
controller/user_controller/user.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package user_controller
|
||||
|
||||
import (
|
||||
"schisandra-cloud-album/service/impl"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type UserController struct{}
|
||||
|
||||
var mu sync.Mutex
|
||||
var userService = impl.UserServiceImpl{}
|
||||
var userDeviceService = impl.UserDeviceServiceImpl{}
|
407
controller/user_controller/user_controller.go
Normal file
407
controller/user_controller/user_controller.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package user_controller
|
||||
|
||||
import (
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"gorm.io/gorm"
|
||||
"reflect"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/randomname"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/common/result"
|
||||
"schisandra-cloud-album/global"
|
||||
"schisandra-cloud-album/model"
|
||||
"schisandra-cloud-album/utils"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetUserList
|
||||
// @Summary 获取所有用户列表
|
||||
// @Tags 用户模块
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/auth/user/List [get]
|
||||
func (UserController) GetUserList(c *gin.Context) {
|
||||
userList := userService.GetUserListService()
|
||||
result.OkWithData(userList, c)
|
||||
}
|
||||
|
||||
// QueryUserByUsername
|
||||
// @Summary 根据用户名查询用户
|
||||
// @Tags 用户模块
|
||||
// @Param username query string true "用户名"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/auth/user/query_by_username [get]
|
||||
func (UserController) QueryUserByUsername(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
user := userService.QueryUserByUsernameService(username)
|
||||
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithData(user, c)
|
||||
}
|
||||
|
||||
// QueryUserByUuid
|
||||
// @Summary 根据uuid查询用户
|
||||
// @Tags 用户模块
|
||||
// @Param uid query string true "用户uid"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/auth/user/query_by_uid [get]
|
||||
func (UserController) QueryUserByUuid(c *gin.Context) {
|
||||
uid := c.Query("uid")
|
||||
user := userService.QueryUserByUuidService(&uid)
|
||||
if user.ID == 0 {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithData(user, c)
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
// @Summary 删除用户
|
||||
// @Tags 用户模块
|
||||
// @Param uid query string true "用户uid"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/auth/user/delete [delete]
|
||||
func (UserController) DeleteUser(c *gin.Context) {
|
||||
uid := c.Query("uid")
|
||||
err := userService.DeleteUserService(uid)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "DeletedFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "DeletedSuccess"), c)
|
||||
}
|
||||
|
||||
// QueryUserByPhone 根据手机号查询用户
|
||||
// @Summary 根据手机号查询用户
|
||||
// @Tags 用户模块
|
||||
// @Param phone query string true "手机号"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/auth/user/query_by_phone [get]
|
||||
func (UserController) QueryUserByPhone(c *gin.Context) {
|
||||
phone := c.Query("phone")
|
||||
user := userService.QueryUserByPhoneService(phone)
|
||||
if user.ID == 0 {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithData(user, c)
|
||||
}
|
||||
|
||||
// AccountLogin 账号登录
|
||||
// @Summary 账号登录
|
||||
// @Tags 用户模块
|
||||
// @Param user body AccountLoginRequest true "用户信息"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/user/login [post]
|
||||
func (UserController) AccountLogin(c *gin.Context) {
|
||||
accountLoginRequest := AccountLoginRequest{}
|
||||
err := c.ShouldBindJSON(&accountLoginRequest)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
rotateData := utils.CheckRotateData(accountLoginRequest.Angle, accountLoginRequest.Key)
|
||||
if !rotateData {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaVerifyError"), c)
|
||||
return
|
||||
}
|
||||
account := accountLoginRequest.Account
|
||||
password := accountLoginRequest.Password
|
||||
|
||||
var user model.ScaAuthUser
|
||||
if utils.IsPhone(account) {
|
||||
user = userService.QueryUserByPhoneService(account)
|
||||
} else if utils.IsEmail(account) {
|
||||
user = userService.QueryUserByEmailService(account)
|
||||
} else if utils.IsUsername(account) {
|
||||
user = userService.QueryUserByUsernameService(account)
|
||||
} else {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "AccountErrorFormat"), c)
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "NotFoundUser"), c)
|
||||
return
|
||||
}
|
||||
|
||||
if !utils.Verify(*user.Password, password) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c)
|
||||
return
|
||||
}
|
||||
handelUserLogin(user, accountLoginRequest.AutoLogin, c)
|
||||
}
|
||||
|
||||
// PhoneLogin 手机号登录/注册
|
||||
// @Summary 手机号登录/注册
|
||||
// @Tags 用户模块
|
||||
// @Param user body PhoneLoginRequest true "用户信息"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/user/phone_login [post]
|
||||
func (UserController) PhoneLogin(c *gin.Context) {
|
||||
request := PhoneLoginRequest{}
|
||||
err := c.ShouldBind(&request)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
phone := request.Phone
|
||||
captcha := request.Captcha
|
||||
autoLogin := request.AutoLogin
|
||||
if !utils.IsPhone(phone) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneErrorFormat"), c)
|
||||
return
|
||||
}
|
||||
|
||||
userChan := make(chan model.ScaAuthUser)
|
||||
go func() {
|
||||
user := userService.QueryUserByPhoneService(phone)
|
||||
userChan <- user
|
||||
}()
|
||||
|
||||
user := <-userChan
|
||||
close(userChan)
|
||||
|
||||
if user.ID == 0 {
|
||||
// 未注册
|
||||
codeChan := make(chan *string)
|
||||
go func() {
|
||||
code := redis.Get(constant.UserLoginSmsRedisKey + phone).Val()
|
||||
codeChan <- &code
|
||||
}()
|
||||
|
||||
code := <-codeChan
|
||||
close(codeChan)
|
||||
|
||||
if code == nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c)
|
||||
return
|
||||
}
|
||||
|
||||
uid := idgen.NextId()
|
||||
uidStr := strconv.FormatInt(uid, 10)
|
||||
|
||||
avatar, err := utils.GenerateAvatar(uidStr)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
return
|
||||
}
|
||||
name := randomname.GenerateName()
|
||||
createUser := model.ScaAuthUser{
|
||||
UID: &uidStr,
|
||||
Phone: &phone,
|
||||
Avatar: &avatar,
|
||||
Nickname: &name,
|
||||
Gender: &enum.Male,
|
||||
}
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
err := global.DB.Transaction(func(tx *gorm.DB) error {
|
||||
addUser, err := userService.AddUserService(createUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = global.Casbin.AddRoleForUser(uidStr, enum.User)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handelUserLogin(*addUser, autoLogin, c)
|
||||
return nil
|
||||
})
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = <-errChan
|
||||
close(errChan)
|
||||
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "RegisterUserError"), c)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
codeChan := make(chan string)
|
||||
go func() {
|
||||
code := redis.Get(constant.UserLoginSmsRedisKey + phone).Val()
|
||||
codeChan <- code
|
||||
}()
|
||||
|
||||
code := <-codeChan
|
||||
close(codeChan)
|
||||
|
||||
if code == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c)
|
||||
return
|
||||
}
|
||||
if captcha != code {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaError"), c)
|
||||
return
|
||||
}
|
||||
handelUserLogin(user, autoLogin, c)
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshHandler 刷新token
|
||||
// @Summary 刷新token
|
||||
// @Tags 用户模块
|
||||
// @Param refresh_token query string true "刷新token"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/token/refresh [post]
|
||||
func (UserController) RefreshHandler(c *gin.Context) {
|
||||
request := RefreshTokenRequest{}
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
refreshToken := request.RefreshToken
|
||||
parseRefreshToken, isUpd, err := utils.ParseRefreshToken(refreshToken)
|
||||
if err != nil || !isUpd {
|
||||
global.LOG.Errorln(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
|
||||
return
|
||||
}
|
||||
accessTokenString, err := utils.GenerateAccessToken(utils.AccessJWTPayload{UserID: parseRefreshToken.UserID})
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
|
||||
return
|
||||
}
|
||||
tokenKey := constant.UserLoginTokenRedisKey + *parseRefreshToken.UserID
|
||||
token, err := redis.Get(tokenKey).Result()
|
||||
if err != nil || token == "" {
|
||||
global.LOG.Errorln(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
|
||||
return
|
||||
}
|
||||
data := ResponseData{
|
||||
AccessToken: accessTokenString,
|
||||
RefreshToken: refreshToken,
|
||||
UID: parseRefreshToken.UserID,
|
||||
}
|
||||
if err := redis.Set(tokenKey, data, time.Hour*24*7).Err(); err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithData(data, c)
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// @Summary 重置密码
|
||||
// @Tags 用户模块
|
||||
// @Param user body ResetPasswordRequest true "用户信息"
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/user/reset_password [post]
|
||||
func (UserController) ResetPassword(c *gin.Context) {
|
||||
var resetPasswordRequest ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&resetPasswordRequest); err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
phone := resetPasswordRequest.Phone
|
||||
captcha := resetPasswordRequest.Captcha
|
||||
password := resetPasswordRequest.Password
|
||||
repassword := resetPasswordRequest.Repassword
|
||||
if !utils.IsPhone(phone) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
if password != repassword {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordNotSame"), c)
|
||||
return
|
||||
}
|
||||
|
||||
if !utils.IsPassword(password) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PasswordError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用事务确保验证码检查和密码更新的原子性
|
||||
tx := global.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := tx.Error; err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "DatabaseError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
code := redis.Get(constant.UserLoginSmsRedisKey + phone).Val()
|
||||
if code == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaExpired"), c)
|
||||
return
|
||||
}
|
||||
|
||||
if captcha != code {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CaptchaError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证码检查通过后立即删除或标记为已使用
|
||||
if err := redis.Del(constant.UserLoginSmsRedisKey + phone).Err(); err != nil {
|
||||
tx.Rollback()
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
user := userService.QueryUserByPhoneService(phone)
|
||||
if reflect.DeepEqual(user, model.ScaAuthUser{}) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "PhoneNotRegister"), c)
|
||||
return
|
||||
}
|
||||
|
||||
encrypt, err := utils.Encrypt(password)
|
||||
if err != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError")+": "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := userService.UpdateUserService(phone, encrypt); err != nil {
|
||||
tx.Rollback()
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "ResetPasswordSuccess"), c)
|
||||
}
|
||||
|
||||
// Logout 退出登录
|
||||
// @Summary 退出登录
|
||||
// @Tags 用户模块
|
||||
// @Success 200 {string} json
|
||||
// @Router /controller/auth/user/logout [post]
|
||||
func (UserController) Logout(c *gin.Context) {
|
||||
userId := c.Query("user_id")
|
||||
if userId == "" {
|
||||
global.LOG.Errorln("userId is empty")
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
|
||||
tokenKey := constant.UserLoginTokenRedisKey + userId
|
||||
del := redis.Del(tokenKey)
|
||||
|
||||
if del.Err() != nil {
|
||||
global.LOG.Errorln(del.Err())
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LogoutFailed"), c)
|
||||
return
|
||||
}
|
||||
ip := utils.GetClientIP(c)
|
||||
key := constant.UserLoginClientRedisKey + ip
|
||||
del = redis.Del(key)
|
||||
if del.Err() != nil {
|
||||
global.LOG.Errorln(del.Err())
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LogoutFailed"), c)
|
||||
return
|
||||
}
|
||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "LogoutSuccess"), c)
|
||||
}
|
182
controller/websocket_controller/gws_controller.go
Normal file
182
controller/websocket_controller/gws_controller.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package websocket_controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lxzan/gws"
|
||||
"net/http"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
"schisandra-cloud-album/global"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
PingInterval = 5 * time.Second // 客户端心跳间隔
|
||||
HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间
|
||||
)
|
||||
|
||||
type WebSocket struct {
|
||||
gws.BuiltinEventHandler
|
||||
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
|
||||
}
|
||||
|
||||
var Handler = NewWebSocket()
|
||||
|
||||
// NewGWSServer 创建websocket服务
|
||||
// @Summary 创建websocket服务
|
||||
// @Description 创建websocket服务
|
||||
// @Tags websocket
|
||||
// @Router /controller/ws/gws [get]
|
||||
func (WebsocketController) NewGWSServer(c *gin.Context) {
|
||||
|
||||
upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{
|
||||
HandshakeTimeout: 5 * time.Second, // 握手超时时间
|
||||
ReadBufferSize: 1024, // 读缓冲区大小
|
||||
ParallelEnabled: true, // 开启并行消息处理
|
||||
Recovery: gws.Recovery, // 开启异常恢复
|
||||
CheckUtf8Enabled: false, // 关闭UTF8校验
|
||||
PermessageDeflate: gws.PermessageDeflate{
|
||||
Enabled: true, // 开启压缩
|
||||
},
|
||||
Authorize: func(r *http.Request, session gws.SessionStorage) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != global.CONFIG.System.WebURL() {
|
||||
return false
|
||||
}
|
||||
var clientId = r.URL.Query().Get("client_id")
|
||||
if clientId == "" {
|
||||
return false
|
||||
}
|
||||
session.Store("client_id", clientId)
|
||||
return true
|
||||
},
|
||||
})
|
||||
socket, err := upgrader.Upgrade(c.Writer, c.Request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC
|
||||
}()
|
||||
}
|
||||
|
||||
// MustLoad 从session中加载数据
|
||||
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
|
||||
if value, exist := session.Load(key); exist {
|
||||
v = value.(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NewWebSocket 创建WebSocket实例
|
||||
func NewWebSocket() *WebSocket {
|
||||
return &WebSocket{
|
||||
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
|
||||
}
|
||||
}
|
||||
|
||||
// OnOpen 连接建立
|
||||
func (c *WebSocket) OnOpen(socket *gws.Conn) {
|
||||
clientId := MustLoad[string](socket.Session(), "client_id")
|
||||
c.sessions.Store(clientId, socket)
|
||||
// 订阅该用户的频道
|
||||
go c.subscribeUserChannel(clientId)
|
||||
fmt.Printf("websocket client %s connected\n", clientId)
|
||||
}
|
||||
|
||||
// OnClose 关闭连接
|
||||
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
||||
name := MustLoad[string](socket.Session(), "client_id")
|
||||
sharding := c.sessions.GetSharding(name)
|
||||
c.sessions.Delete(name)
|
||||
sharding.Lock()
|
||||
defer sharding.Unlock()
|
||||
|
||||
global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
|
||||
}
|
||||
|
||||
// OnPing 处理客户端的Ping消息
|
||||
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
|
||||
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
|
||||
_ = socket.WritePong(payload)
|
||||
}
|
||||
|
||||
// OnPong 处理客户端的Pong消息
|
||||
func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
|
||||
|
||||
// OnMessage 接受消息
|
||||
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
|
||||
defer message.Close()
|
||||
clientId := MustLoad[string](socket.Session(), "client_id")
|
||||
if conn, ok := c.sessions.Load(clientId); ok {
|
||||
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessageToClient 向指定客户端发送消息
|
||||
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
|
||||
conn, ok := c.sessions.Load(clientId)
|
||||
if ok {
|
||||
return conn.WriteMessage(gws.OpcodeText, message)
|
||||
}
|
||||
return fmt.Errorf("client %s not found", clientId)
|
||||
}
|
||||
|
||||
// SendMessageToUser 发送消息到指定用户的 Redis 频道
|
||||
func (c *WebSocket) SendMessageToUser(clientId string, message []byte) error {
|
||||
if _, ok := c.sessions.Load(clientId); ok {
|
||||
return redis.Publish(clientId, message).Err()
|
||||
} else {
|
||||
return redis.LPush(constant.CommentOfflineMessageRedisKey+clientId, message).Err()
|
||||
}
|
||||
}
|
||||
|
||||
// 订阅用户频道
|
||||
func (c *WebSocket) subscribeUserChannel(clientId string) {
|
||||
conn, ok := c.sessions.Load(clientId)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取离线消息
|
||||
messages, err := redis.LRange(constant.CommentOfflineMessageRedisKey+clientId, 0, -1).Result()
|
||||
if err != nil {
|
||||
global.LOG.Printf("Error loading offline messages for user %s: %v\n", clientId, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 逐条发送离线消息
|
||||
for _, msg := range messages {
|
||||
if writeErr := conn.WriteMessage(gws.OpcodeText, []byte(msg)); writeErr != nil {
|
||||
global.LOG.Printf("Error writing offline message to user %s: %v\n", clientId, writeErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 清空离线消息列表
|
||||
if delErr := redis.Del(constant.CommentOfflineMessageRedisKey + clientId); delErr.Err() != nil {
|
||||
global.LOG.Printf("Error clearing offline messages for user %s: %v\n", clientId, delErr.Err())
|
||||
}
|
||||
|
||||
pubsub := redis.Subscribe(clientId)
|
||||
defer func() {
|
||||
if closeErr := pubsub.Close(); closeErr != nil {
|
||||
global.LOG.Printf("Error closing pubsub for user %s: %v\n", clientId, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
msg, waitErr := pubsub.ReceiveMessage(context.Background())
|
||||
if waitErr != nil {
|
||||
global.LOG.Printf("Error receiving message for user %s: %v\n", clientId, err)
|
||||
return
|
||||
}
|
||||
|
||||
if writeErr := conn.WriteMessage(gws.OpcodeText, []byte(msg.Payload)); writeErr != nil {
|
||||
global.LOG.Printf("Error writing message to user %s: %v\n", clientId, writeErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
4
controller/websocket_controller/websocket.go
Normal file
4
controller/websocket_controller/websocket.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package websocket_controller
|
||||
|
||||
type WebsocketController struct {
|
||||
}
|
Reference in New Issue
Block a user