✨ add comment message push
This commit is contained in:
@@ -3,8 +3,12 @@ package comment_api
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/acmestack/gorm-plus/gplus"
|
||||||
"io"
|
"io"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"schisandra-cloud-album/api/comment_api/dto"
|
||||||
|
"schisandra-cloud-album/global"
|
||||||
|
"schisandra-cloud-album/model"
|
||||||
"schisandra-cloud-album/service"
|
"schisandra-cloud-album/service"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -135,3 +139,88 @@ func getMimeType(data []byte) string {
|
|||||||
|
|
||||||
return "application/octet-stream" // 默认类型
|
return "application/octet-stream" // 默认类型
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 点赞
|
||||||
|
var likeChannel = make(chan dto.CommentLikeRequest, 100)
|
||||||
|
var cancelLikeChannel = make(chan dto.CommentLikeRequest, 100) // 取消点赞
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go likeConsumer() // 启动消费者
|
||||||
|
go cancelLikeConsumer() // 启动消费者
|
||||||
|
}
|
||||||
|
func likeConsumer() {
|
||||||
|
for likeRequest := range likeChannel {
|
||||||
|
processLike(likeRequest) // 处理点赞
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func cancelLikeConsumer() {
|
||||||
|
for cancelLikeRequest := range cancelLikeChannel {
|
||||||
|
processCancelLike(cancelLikeRequest) // 处理取消点赞
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func processLike(likeRequest dto.CommentLikeRequest) {
|
||||||
|
mx.Lock()
|
||||||
|
defer mx.Unlock()
|
||||||
|
|
||||||
|
likes := model.ScaCommentLikes{
|
||||||
|
CommentId: likeRequest.CommentId,
|
||||||
|
UserId: likeRequest.UserID,
|
||||||
|
TopicId: likeRequest.TopicId,
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := global.DB.Begin()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
res := global.DB.Create(&likes) // 假设这是插入数据库的方法
|
||||||
|
if res.Error != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
global.LOG.Errorln(res.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步更新点赞计数
|
||||||
|
go func() {
|
||||||
|
if err := commentReplyService.UpdateCommentLikesCount(likeRequest.CommentId, likeRequest.TopicId); err != nil {
|
||||||
|
global.LOG.Errorln(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
func processCancelLike(cancelLikeRequest dto.CommentLikeRequest) {
|
||||||
|
mx.Lock()
|
||||||
|
defer mx.Unlock()
|
||||||
|
|
||||||
|
tx := global.DB.Begin()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
query, u := gplus.NewQuery[model.ScaCommentLikes]()
|
||||||
|
query.Eq(&u.CommentId, cancelLikeRequest.CommentId).
|
||||||
|
Eq(&u.UserId, cancelLikeRequest.UserID).
|
||||||
|
Eq(&u.TopicId, cancelLikeRequest.TopicId)
|
||||||
|
|
||||||
|
res := gplus.Delete[model.ScaCommentLikes](query)
|
||||||
|
if res.Error != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return // 返回错误而非打印
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步更新点赞计数
|
||||||
|
go func() {
|
||||||
|
if err := commentReplyService.DecrementCommentLikesCount(cancelLikeRequest.CommentId, cancelLikeRequest.TopicId); err != nil {
|
||||||
|
global.LOG.Errorln(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -763,35 +763,13 @@ func (CommentAPI) CommentLikes(c *gin.Context) {
|
|||||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mx.Lock()
|
|
||||||
defer mx.Unlock()
|
// 将点赞请求发送到 channel 中
|
||||||
likes := model.ScaCommentLikes{
|
likeChannel <- dto.CommentLikeRequest{
|
||||||
CommentId: likeRequest.CommentId,
|
CommentId: likeRequest.CommentId,
|
||||||
UserId: likeRequest.UserID,
|
UserID: likeRequest.UserID,
|
||||||
TopicId: likeRequest.TopicId,
|
TopicId: likeRequest.TopicId,
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := global.DB.Begin()
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
res := gplus.Insert(&likes)
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
global.LOG.Errorln(res.Error)
|
|
||||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 更新点赞计数
|
|
||||||
if err = commentReplyService.UpdateCommentLikesCount(likeRequest.CommentId, likeRequest.TopicId); err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
global.LOG.Errorln(err)
|
|
||||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeFailed"), c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeSuccess"), c)
|
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeSuccess"), c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -806,36 +784,16 @@ func (CommentAPI) CommentLikes(c *gin.Context) {
|
|||||||
// @Router /auth/comment/cancel_like [post]
|
// @Router /auth/comment/cancel_like [post]
|
||||||
func (CommentAPI) CancelCommentLikes(c *gin.Context) {
|
func (CommentAPI) CancelCommentLikes(c *gin.Context) {
|
||||||
likeRequest := dto.CommentLikeRequest{}
|
likeRequest := dto.CommentLikeRequest{}
|
||||||
err := c.ShouldBindJSON(&likeRequest)
|
if err := c.ShouldBindJSON(&likeRequest); err != nil {
|
||||||
if err != nil {
|
|
||||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mx.Lock()
|
// 将取消点赞请求发送到 channel
|
||||||
defer mx.Unlock()
|
cancelLikeChannel <- dto.CommentLikeRequest{
|
||||||
tx := global.DB.Begin()
|
CommentId: likeRequest.CommentId,
|
||||||
defer func() {
|
UserID: likeRequest.UserID,
|
||||||
if r := recover(); r != nil {
|
TopicId: likeRequest.TopicId,
|
||||||
tx.Rollback()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
query, u := gplus.NewQuery[model.ScaCommentLikes]()
|
|
||||||
query.Eq(&u.CommentId, likeRequest.CommentId).Eq(&u.UserId, likeRequest.UserID).Eq(&u.TopicId, likeRequest.TopicId)
|
|
||||||
res := gplus.Delete[model.ScaCommentLikes](query)
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
global.LOG.Errorln(res.Error)
|
|
||||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelFailed"), c)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
err = commentReplyService.DecrementCommentLikesCount(likeRequest.CommentId, likeRequest.TopicId)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
global.LOG.Errorln(err)
|
|
||||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelFailed"), c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelSuccess"), c)
|
result.OkWithMessage(ginI18n.MustGetMessage(c, "CommentLikeCancelSuccess"), c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package websocket_api
|
package websocket_api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/lxzan/gws"
|
"github.com/lxzan/gws"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"schisandra-cloud-album/common/constant"
|
||||||
|
"schisandra-cloud-album/common/redis"
|
||||||
"schisandra-cloud-album/global"
|
"schisandra-cloud-album/global"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -14,6 +17,11 @@ const (
|
|||||||
HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间
|
HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WebSocket struct {
|
||||||
|
gws.BuiltinEventHandler
|
||||||
|
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
|
||||||
|
}
|
||||||
|
|
||||||
var Handler = NewWebSocket()
|
var Handler = NewWebSocket()
|
||||||
|
|
||||||
// NewGWSServer 创建websocket服务
|
// NewGWSServer 创建websocket服务
|
||||||
@@ -33,6 +41,10 @@ func (WebsocketAPI) NewGWSServer(c *gin.Context) {
|
|||||||
Enabled: true, // 开启压缩
|
Enabled: true, // 开启压缩
|
||||||
},
|
},
|
||||||
Authorize: func(r *http.Request, session gws.SessionStorage) bool {
|
Authorize: func(r *http.Request, session gws.SessionStorage) bool {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin != global.CONFIG.System.WebURL() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
var clientId = r.URL.Query().Get("client_id")
|
var clientId = r.URL.Query().Get("client_id")
|
||||||
if clientId == "" {
|
if clientId == "" {
|
||||||
return false
|
return false
|
||||||
@@ -49,6 +61,8 @@ func (WebsocketAPI) NewGWSServer(c *gin.Context) {
|
|||||||
socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC
|
socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MustLoad 从session中加载数据
|
||||||
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
|
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
|
||||||
if value, exist := session.Load(key); exist {
|
if value, exist := session.Load(key); exist {
|
||||||
v = value.(T)
|
v = value.(T)
|
||||||
@@ -56,26 +70,23 @@ func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewWebSocket 创建WebSocket实例
|
||||||
func NewWebSocket() *WebSocket {
|
func NewWebSocket() *WebSocket {
|
||||||
return &WebSocket{
|
return &WebSocket{
|
||||||
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
|
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type WebSocket struct {
|
// OnOpen 连接建立
|
||||||
gws.BuiltinEventHandler
|
|
||||||
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocket) OnOpen(socket *gws.Conn) {
|
func (c *WebSocket) OnOpen(socket *gws.Conn) {
|
||||||
name := MustLoad[string](socket.Session(), "client_id")
|
clientId := MustLoad[string](socket.Session(), "client_id")
|
||||||
if conn, ok := c.sessions.Load(name); ok {
|
c.sessions.Store(clientId, socket)
|
||||||
conn.WriteClose(1000, []byte("connection is replaced"))
|
// 订阅该用户的频道
|
||||||
}
|
go c.subscribeUserChannel(clientId)
|
||||||
c.sessions.Store(name, socket)
|
fmt.Printf("websocket client %s connected\n", clientId)
|
||||||
global.LOG.Printf("%s connected\n", name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnClose 关闭连接
|
||||||
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
||||||
name := MustLoad[string](socket.Session(), "client_id")
|
name := MustLoad[string](socket.Session(), "client_id")
|
||||||
sharding := c.sessions.GetSharding(name)
|
sharding := c.sessions.GetSharding(name)
|
||||||
@@ -86,17 +97,20 @@ func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
|||||||
global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
|
global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnPing 处理客户端的Ping消息
|
||||||
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
|
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
|
||||||
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
|
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
|
||||||
_ = socket.WritePong(payload)
|
_ = socket.WritePong(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnPong 处理客户端的Pong消息
|
||||||
func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
|
func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
|
||||||
|
|
||||||
|
// OnMessage 接受消息
|
||||||
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
|
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
|
||||||
defer message.Close()
|
defer message.Close()
|
||||||
name := MustLoad[string](socket.Session(), "client_id")
|
clientId := MustLoad[string](socket.Session(), "client_id")
|
||||||
if conn, ok := c.sessions.Load(name); ok {
|
if conn, ok := c.sessions.Load(clientId); ok {
|
||||||
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
|
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -109,3 +123,60 @@ func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("client %s not found", clientId)
|
return fmt.Errorf("client %s not found", clientId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendMessageToUser 发送消息到指定用户的 Redis 频道
|
||||||
|
func (c *WebSocket) SendMessageToUser(clientId string, message []byte) error {
|
||||||
|
if _, ok := c.sessions.Load(clientId); ok {
|
||||||
|
return redis.Publish(clientId, message).Err()
|
||||||
|
} else {
|
||||||
|
return redis.LPush(constant.CommentOfflineMessageRedisKey+clientId, message).Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订阅用户频道
|
||||||
|
func (c *WebSocket) subscribeUserChannel(clientId string) {
|
||||||
|
conn, ok := c.sessions.Load(clientId)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取离线消息
|
||||||
|
messages, err := redis.LRange(constant.CommentOfflineMessageRedisKey+clientId, 0, -1).Result()
|
||||||
|
if err != nil {
|
||||||
|
global.LOG.Printf("Error loading offline messages for user %s: %v\n", clientId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 逐条发送离线消息
|
||||||
|
for _, msg := range messages {
|
||||||
|
if writeErr := conn.WriteMessage(gws.OpcodeText, []byte(msg)); writeErr != nil {
|
||||||
|
global.LOG.Printf("Error writing offline message to user %s: %v\n", clientId, writeErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空离线消息列表
|
||||||
|
if delErr := redis.Del(constant.CommentOfflineMessageRedisKey + clientId); delErr.Err() != nil {
|
||||||
|
global.LOG.Printf("Error clearing offline messages for user %s: %v\n", clientId, delErr.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
pubsub := redis.Subscribe(clientId)
|
||||||
|
defer func() {
|
||||||
|
if closeErr := pubsub.Close(); closeErr != nil {
|
||||||
|
global.LOG.Printf("Error closing pubsub for user %s: %v\n", clientId, closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, waitErr := pubsub.ReceiveMessage(context.Background())
|
||||||
|
if waitErr != nil {
|
||||||
|
global.LOG.Printf("Error receiving message for user %s: %v\n", clientId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if writeErr := conn.WriteMessage(gws.OpcodeText, []byte(msg.Payload)); writeErr != nil {
|
||||||
|
global.LOG.Printf("Error writing message to user %s: %v\n", clientId, writeErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,5 +12,6 @@ const (
|
|||||||
|
|
||||||
// 登录之后
|
// 登录之后
|
||||||
const (
|
const (
|
||||||
CommentSubmitCaptchaRedisKey = "comment:submit:captcha:"
|
CommentSubmitCaptchaRedisKey = "comment:submit:captcha:"
|
||||||
|
CommentOfflineMessageRedisKey = "comment:offline:message:"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -394,3 +394,16 @@ func ZRem(key string, members ...interface{}) *redis.IntCmd {
|
|||||||
func ZRemRangeByRank(key string, start, stop int64) *redis.IntCmd {
|
func ZRemRangeByRank(key string, start, stop int64) *redis.IntCmd {
|
||||||
return global.REDIS.ZRemRangeByRank(ctx, key, start, stop)
|
return global.REDIS.ZRemRangeByRank(ctx, key, start, stop)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Publish 发布消息到redis
|
||||||
|
// channel是发布的目标信道
|
||||||
|
// payload是要发布的消息内容
|
||||||
|
func Publish(channel string, payload interface{}) *redis.IntCmd {
|
||||||
|
return global.REDIS.Publish(ctx, channel, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe 订阅redis消息
|
||||||
|
// channels是要订阅的信道列表
|
||||||
|
func Subscribe(channels ...string) *redis.PubSub {
|
||||||
|
return global.REDIS.Subscribe(ctx, channels...)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user