diff --git a/api/comment_api/comment.go b/api/comment_api/comment.go index 7091e48..4d9f53a 100644 --- a/api/comment_api/comment.go +++ b/api/comment_api/comment.go @@ -3,8 +3,12 @@ package comment_api import ( "encoding/base64" "errors" + "github.com/acmestack/gorm-plus/gplus" "io" "regexp" + "schisandra-cloud-album/api/comment_api/dto" + "schisandra-cloud-album/global" + "schisandra-cloud-album/model" "schisandra-cloud-album/service" "strings" "sync" @@ -135,3 +139,88 @@ func getMimeType(data []byte) string { return "application/octet-stream" // 默认类型 } + +// 点赞 +var likeChannel = make(chan dto.CommentLikeRequest, 100) +var cancelLikeChannel = make(chan dto.CommentLikeRequest, 100) // 取消点赞 + +func init() { + go likeConsumer() // 启动消费者 + go cancelLikeConsumer() // 启动消费者 +} +func likeConsumer() { + for likeRequest := range likeChannel { + processLike(likeRequest) // 处理点赞 + } +} +func cancelLikeConsumer() { + for cancelLikeRequest := range cancelLikeChannel { + processCancelLike(cancelLikeRequest) // 处理取消点赞 + } +} + +func processLike(likeRequest dto.CommentLikeRequest) { + mx.Lock() + defer mx.Unlock() + + likes := model.ScaCommentLikes{ + CommentId: likeRequest.CommentId, + UserId: likeRequest.UserID, + TopicId: likeRequest.TopicId, + } + + tx := global.DB.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + res := global.DB.Create(&likes) // 假设这是插入数据库的方法 + if res.Error != nil { + tx.Rollback() + global.LOG.Errorln(res.Error) + return + } + + // 异步更新点赞计数 + go func() { + if err := commentReplyService.UpdateCommentLikesCount(likeRequest.CommentId, likeRequest.TopicId); err != nil { + global.LOG.Errorln(err) + } + }() + + tx.Commit() +} +func processCancelLike(cancelLikeRequest dto.CommentLikeRequest) { + mx.Lock() + defer mx.Unlock() + + tx := global.DB.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + query, u := gplus.NewQuery[model.ScaCommentLikes]() + query.Eq(&u.CommentId, cancelLikeRequest.CommentId). + Eq(&u.UserId, cancelLikeRequest.UserID). + Eq(&u.TopicId, cancelLikeRequest.TopicId) + + res := gplus.Delete[model.ScaCommentLikes](query) + if res.Error != nil { + tx.Rollback() + return // 返回错误而非打印 + } + + // 异步更新点赞计数 + go func() { + if err := commentReplyService.DecrementCommentLikesCount(cancelLikeRequest.CommentId, cancelLikeRequest.TopicId); err != nil { + global.LOG.Errorln(err) + } + }() + + tx.Commit() + return +} diff --git a/api/comment_api/comment_api.go b/api/comment_api/comment_api.go index 78ac5eb..c98cb99 100644 --- a/api/comment_api/comment_api.go +++ b/api/comment_api/comment_api.go @@ -763,35 +763,13 @@ func (CommentAPI) CommentLikes(c *gin.Context) { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - mx.Lock() - defer mx.Unlock() - likes := model.ScaCommentLikes{ + + // 将点赞请求发送到 channel 中 + likeChannel <- dto.CommentLikeRequest{ CommentId: likeRequest.CommentId, - UserId: likeRequest.UserID, + UserID: likeRequest.UserID, TopicId: likeRequest.TopicId, } - - tx := global.DB.Begin() - defer func() { - if r := recover(); r != nil { - tx.Rollback() - } - }() - res := gplus.Insert(&likes) - if res.Error != nil { - tx.Rollback() - global.LOG.Errorln(res.Error) - result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c) - return - } - // 更新点赞计数 - if err = commentReplyService.UpdateCommentLikesCount(likeRequest.CommentId, likeRequest.TopicId); err != nil { - tx.Rollback() - global.LOG.Errorln(err) - result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c) - return - } - tx.Commit() result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeSuccess"), c) return } @@ -806,36 +784,16 @@ func (CommentAPI) CommentLikes(c *gin.Context) { // @Router /auth/comment/cancel_like [post] func (CommentAPI) CancelCommentLikes(c *gin.Context) { likeRequest := dto.CommentLikeRequest{} - err := c.ShouldBindJSON(&likeRequest) - if err != nil { + if err := c.ShouldBindJSON(&likeRequest); err != nil { result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c) return } - mx.Lock() - defer mx.Unlock() - tx := global.DB.Begin() - defer func() { - if r := recover(); r != nil { - tx.Rollback() - } - }() - query, u := gplus.NewQuery[model.ScaCommentLikes]() - query.Eq(&u.CommentId, likeRequest.CommentId).Eq(&u.UserId, likeRequest.UserID).Eq(&u.TopicId, likeRequest.TopicId) - res := gplus.Delete[model.ScaCommentLikes](query) - if res.Error != nil { - tx.Rollback() - global.LOG.Errorln(res.Error) - result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelFailed"), c) - return + // 将取消点赞请求发送到 channel + cancelLikeChannel <- dto.CommentLikeRequest{ + CommentId: likeRequest.CommentId, + UserID: likeRequest.UserID, + TopicId: likeRequest.TopicId, } - err = commentReplyService.DecrementCommentLikesCount(likeRequest.CommentId, likeRequest.TopicId) - if err != nil { - tx.Rollback() - global.LOG.Errorln(err) - result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelFailed"), c) - return - } - tx.Commit() result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelSuccess"), c) return } diff --git a/api/websocket_api/gws_api.go b/api/websocket_api/gws_api.go index 1187ad0..f8fd69c 100644 --- a/api/websocket_api/gws_api.go +++ b/api/websocket_api/gws_api.go @@ -1,10 +1,13 @@ package websocket_api import ( + "context" "fmt" "github.com/gin-gonic/gin" "github.com/lxzan/gws" "net/http" + "schisandra-cloud-album/common/constant" + "schisandra-cloud-album/common/redis" "schisandra-cloud-album/global" "time" ) @@ -14,6 +17,11 @@ const ( HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间 ) +type WebSocket struct { + gws.BuiltinEventHandler + sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突 +} + var Handler = NewWebSocket() // NewGWSServer 创建websocket服务 @@ -33,6 +41,10 @@ func (WebsocketAPI) NewGWSServer(c *gin.Context) { Enabled: true, // 开启压缩 }, Authorize: func(r *http.Request, session gws.SessionStorage) bool { + origin := r.Header.Get("Origin") + if origin != global.CONFIG.System.WebURL() { + return false + } var clientId = r.URL.Query().Get("client_id") if clientId == "" { return false @@ -49,6 +61,8 @@ func (WebsocketAPI) NewGWSServer(c *gin.Context) { socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC }() } + +// MustLoad 从session中加载数据 func MustLoad[T any](session gws.SessionStorage, key string) (v T) { if value, exist := session.Load(key); exist { v = value.(T) @@ -56,26 +70,23 @@ func MustLoad[T any](session gws.SessionStorage, key string) (v T) { return } +// NewWebSocket 创建WebSocket实例 func NewWebSocket() *WebSocket { return &WebSocket{ sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128), } } -type WebSocket struct { - gws.BuiltinEventHandler - sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突 -} - +// OnOpen 连接建立 func (c *WebSocket) OnOpen(socket *gws.Conn) { - name := MustLoad[string](socket.Session(), "client_id") - if conn, ok := c.sessions.Load(name); ok { - conn.WriteClose(1000, []byte("connection is replaced")) - } - c.sessions.Store(name, socket) - global.LOG.Printf("%s connected\n", name) + clientId := MustLoad[string](socket.Session(), "client_id") + c.sessions.Store(clientId, socket) + // 订阅该用户的频道 + go c.subscribeUserChannel(clientId) + fmt.Printf("websocket client %s connected\n", clientId) } +// OnClose 关闭连接 func (c *WebSocket) OnClose(socket *gws.Conn, err error) { name := MustLoad[string](socket.Session(), "client_id") sharding := c.sessions.GetSharding(name) @@ -86,17 +97,20 @@ func (c *WebSocket) OnClose(socket *gws.Conn, err error) { global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error()) } +// OnPing 处理客户端的Ping消息 func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) { _ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout)) _ = socket.WritePong(payload) } +// OnPong 处理客户端的Pong消息 func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {} +// OnMessage 接受消息 func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) { defer message.Close() - name := MustLoad[string](socket.Session(), "client_id") - if conn, ok := c.sessions.Load(name); ok { + clientId := MustLoad[string](socket.Session(), "client_id") + if conn, ok := c.sessions.Load(clientId); ok { _ = conn.WriteMessage(gws.OpcodeText, message.Bytes()) } } @@ -109,3 +123,60 @@ func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error { } return fmt.Errorf("client %s not found", clientId) } + +// SendMessageToUser 发送消息到指定用户的 Redis 频道 +func (c *WebSocket) SendMessageToUser(clientId string, message []byte) error { + if _, ok := c.sessions.Load(clientId); ok { + return redis.Publish(clientId, message).Err() + } else { + return redis.LPush(constant.CommentOfflineMessageRedisKey+clientId, message).Err() + } +} + +// 订阅用户频道 +func (c *WebSocket) subscribeUserChannel(clientId string) { + conn, ok := c.sessions.Load(clientId) + if !ok { + return + } + + // 获取离线消息 + messages, err := redis.LRange(constant.CommentOfflineMessageRedisKey+clientId, 0, -1).Result() + if err != nil { + global.LOG.Printf("Error loading offline messages for user %s: %v\n", clientId, err) + return + } + + // 逐条发送离线消息 + for _, msg := range messages { + if writeErr := conn.WriteMessage(gws.OpcodeText, []byte(msg)); writeErr != nil { + global.LOG.Printf("Error writing offline message to user %s: %v\n", clientId, writeErr) + return + } + } + + // 清空离线消息列表 + if delErr := redis.Del(constant.CommentOfflineMessageRedisKey + clientId); delErr.Err() != nil { + global.LOG.Printf("Error clearing offline messages for user %s: %v\n", clientId, delErr.Err()) + } + + pubsub := redis.Subscribe(clientId) + defer func() { + if closeErr := pubsub.Close(); closeErr != nil { + global.LOG.Printf("Error closing pubsub for user %s: %v\n", clientId, closeErr) + } + }() + + for { + msg, waitErr := pubsub.ReceiveMessage(context.Background()) + if waitErr != nil { + global.LOG.Printf("Error receiving message for user %s: %v\n", clientId, err) + return + } + + if writeErr := conn.WriteMessage(gws.OpcodeText, []byte(msg.Payload)); writeErr != nil { + global.LOG.Printf("Error writing message to user %s: %v\n", clientId, writeErr) + return + } + } +} diff --git a/common/constant/redis_key.go b/common/constant/redis_key.go index f139c2c..723c48b 100644 --- a/common/constant/redis_key.go +++ b/common/constant/redis_key.go @@ -12,5 +12,6 @@ const ( // 登录之后 const ( - CommentSubmitCaptchaRedisKey = "comment:submit:captcha:" + CommentSubmitCaptchaRedisKey = "comment:submit:captcha:" + CommentOfflineMessageRedisKey = "comment:offline:message:" ) diff --git a/common/redis/redis.go b/common/redis/redis.go index e0ce4cb..fae4d25 100644 --- a/common/redis/redis.go +++ b/common/redis/redis.go @@ -394,3 +394,16 @@ func ZRem(key string, members ...interface{}) *redis.IntCmd { func ZRemRangeByRank(key string, start, stop int64) *redis.IntCmd { return global.REDIS.ZRemRangeByRank(ctx, key, start, stop) } + +// Publish 发布消息到redis +// channel是发布的目标信道 +// payload是要发布的消息内容 +func Publish(channel string, payload interface{}) *redis.IntCmd { + return global.REDIS.Publish(ctx, channel, payload) +} + +// Subscribe 订阅redis消息 +// channels是要订阅的信道列表 +func Subscribe(channels ...string) *redis.PubSub { + return global.REDIS.Subscribe(ctx, channels...) +}