Files
schisandra-album-cloud-micr…/app/auth/api/internal/logic/websocket/message_websocket_logic.go
2024-12-24 00:38:41 +08:00

127 lines
3.6 KiB
Go

package websocket
import (
"context"
"fmt"
"net/http"
"schisandra-album-cloud-microservices/common/jwt"
"time"
"github.com/lxzan/gws"
"github.com/zeromicro/go-zero/core/logx"
"schisandra-album-cloud-microservices/app/auth/api/internal/svc"
)
type MessageWebsocketLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewMessageWebsocketLogic(ctx context.Context, svcCtx *svc.ServiceContext) *MessageWebsocketLogic {
return &MessageWebsocketLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
type MessageWebSocket struct {
ctx context.Context
gws.BuiltinEventHandler
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
}
var MessageWebSocketHandler = NewMessageWebSocket()
func (l *MessageWebsocketLogic) MessageWebsocket(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Sec-Websocket-Protocol")
accessToken, res := jwt.ParseAccessToken(l.svcCtx.Config.Auth.AccessSecret, token)
if !res {
return
}
upgrader := gws.NewUpgrader(MessageWebSocketHandler, &gws.ServerOption{
HandshakeTimeout: 5 * time.Second, // 握手超时时间
ReadBufferSize: 1024, // 读缓冲区大小
ParallelEnabled: true, // 开启并行消息处理
Recovery: gws.Recovery, // 开启异常恢复
CheckUtf8Enabled: false, // 关闭UTF8校验
PermessageDeflate: gws.PermessageDeflate{
Enabled: true, // 开启压缩
},
Authorize: func(r *http.Request, session gws.SessionStorage) bool {
var clientId = r.URL.Query().Get("user_id")
if clientId == "" {
return false
}
if accessToken.UserID != clientId {
return false
}
//token := r.URL.Query().Get("token")
//if token == "" {
// return false
//}
//accessToken, res := jwt.ParseAccessToken(l.svcCtx.Config.Auth.AccessSecret, token)
//if !res || accessToken.UserID != clientId {
// return false
//}
session.Store("user_id", clientId)
return true
},
SubProtocols: []string{token},
})
socket, err := upgrader.Upgrade(w, r)
if err != nil {
panic(err)
}
go func() {
socket.ReadLoop()
}()
}
// NewMessageWebSocket 创建WebSocket实例
func NewMessageWebSocket() *MessageWebSocket {
return &MessageWebSocket{
ctx: context.Background(),
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
}
}
// OnOpen 连接建立
func (c *MessageWebSocket) OnOpen(socket *gws.Conn) {
clientId := MustLoad[string](socket.Session(), "user_id")
c.sessions.Store(clientId, socket)
// 订阅该用户的频道
fmt.Printf("websocket client %s connected\n", clientId)
}
// OnClose 关闭连接
func (c *MessageWebSocket) OnClose(socket *gws.Conn, err error) {
name := MustLoad[string](socket.Session(), "user_id")
sharding := c.sessions.GetSharding(name)
c.sessions.Delete(name)
sharding.Lock()
defer sharding.Unlock()
fmt.Printf("websocket closed, name=%s, msg=%s\n", name, err.Error())
}
// OnPing 处理客户端的Ping消息
func (c *MessageWebSocket) OnPing(socket *gws.Conn, payload []byte) {
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
_ = socket.WritePong(payload)
}
// OnPong 处理客户端的Pong消息
func (c *MessageWebSocket) OnPong(_ *gws.Conn, _ []byte) {}
// OnMessage 接受消息
func (c *MessageWebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
defer message.Close()
clientId := MustLoad[string](socket.Session(), "user_id")
if conn, ok := c.sessions.Load(clientId); ok {
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
}
// fmt.Printf("received message from client %s\n", message.Data)
}