From 31eabd4e553f9f38affacc5efa6ca16583081aa7 Mon Sep 17 00:00:00 2001 From: landaiqing <3517283258@qq.com> Date: Thu, 17 Oct 2024 23:41:40 +0800 Subject: [PATCH] :sparkles: update websocket --- controller/controller.go | 22 +-- .../oauth_controller/wechat_controller.go | 4 +- .../message_ws_controller.go} | 40 +++--- .../qr_ws_controller/qr_ws_controller.go | 131 ++++++++++++++++++ router/modules/websocket_router.go | 7 +- utils/generate_avatar.go | 73 +++++----- 6 files changed, 213 insertions(+), 64 deletions(-) rename controller/websocket_controller/{gws_controller.go => message_ws_controller/message_ws_controller.go} (82%) create mode 100644 controller/websocket_controller/qr_ws_controller/qr_ws_controller.go diff --git a/controller/controller.go b/controller/controller.go index 57d21cb..bba58f2 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -9,20 +9,22 @@ import ( "schisandra-cloud-album/controller/role_controller" "schisandra-cloud-album/controller/sms_controller" "schisandra-cloud-album/controller/user_controller" - "schisandra-cloud-album/controller/websocket_controller" + "schisandra-cloud-album/controller/websocket_controller/message_ws_controller" + "schisandra-cloud-album/controller/websocket_controller/qr_ws_controller" ) // Controllers 统一导出的控制器接口 type Controllers struct { - UserController user_controller.UserController - CaptchaController captcha_controller.CaptchaController - SmsController sms_controller.SmsController - OAuthController oauth_controller.OAuthController - WebsocketController websocket_controller.WebsocketController - RoleController role_controller.RoleController - PermissionController permission_controller.PermissionController - ClientController client_controller.ClientController - CommonController comment_controller.CommentController + UserController user_controller.UserController + CaptchaController captcha_controller.CaptchaController + SmsController sms_controller.SmsController + OAuthController oauth_controller.OAuthController + QrWebsocketController qr_ws_controller.QrWebsocketController + MessageWebsocketController message_ws_controller.MessageWebsocketController + RoleController role_controller.RoleController + PermissionController permission_controller.PermissionController + ClientController client_controller.ClientController + CommonController comment_controller.CommentController } // Controller new函数实例化,实例化完成后会返回结构体地指针类型 diff --git a/controller/oauth_controller/wechat_controller.go b/controller/oauth_controller/wechat_controller.go index f34bbac..22f0553 100644 --- a/controller/oauth_controller/wechat_controller.go +++ b/controller/oauth_controller/wechat_controller.go @@ -19,7 +19,7 @@ import ( "schisandra-cloud-album/common/randomname" "schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/result" - "schisandra-cloud-album/controller/websocket_controller" + "schisandra-cloud-album/controller/websocket_controller/qr_ws_controller" "schisandra-cloud-album/global" "schisandra-cloud-album/model" "schisandra-cloud-album/utils" @@ -286,7 +286,7 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool { return } // gws方式发送消息 - err = websocket_controller.Handler.SendMessageToClient(clientId, tokenData) + err = qr_ws_controller.Handler.SendMessageToClient(clientId, tokenData) if err != nil { global.LOG.Error(err) resultChan <- false diff --git a/controller/websocket_controller/gws_controller.go b/controller/websocket_controller/message_ws_controller/message_ws_controller.go similarity index 82% rename from controller/websocket_controller/gws_controller.go rename to controller/websocket_controller/message_ws_controller/message_ws_controller.go index de2b2a1..b9f083e 100644 --- a/controller/websocket_controller/gws_controller.go +++ b/controller/websocket_controller/message_ws_controller/message_ws_controller.go @@ -1,4 +1,4 @@ -package websocket_controller +package message_ws_controller import ( "context" @@ -9,10 +9,11 @@ import ( "schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/redis" "schisandra-cloud-album/global" + "schisandra-cloud-album/utils" "time" ) -type WebsocketController struct { +type MessageWebsocketController struct { } const ( @@ -25,16 +26,15 @@ type WebSocket struct { sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突 } -var Handler = NewWebSocket() +var MessageHandler = MessageWebSocket() -// NewGWSServer 创建websocket服务 +// MessageWSController 创建websocket服务 // @Summary 创建websocket服务 // @Description 创建websocket服务 // @Tags websocket // @Router /controller/ws/gws [get] -func (WebsocketController) NewGWSServer(c *gin.Context) { - - upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{ +func (MessageWebsocketController) MessageWSController(c *gin.Context) { + upgrader := gws.NewUpgrader(MessageHandler, &gws.ServerOption{ HandshakeTimeout: 5 * time.Second, // 握手超时时间 ReadBufferSize: 1024, // 读缓冲区大小 ParallelEnabled: true, // 开启并行消息处理 @@ -48,11 +48,19 @@ func (WebsocketController) NewGWSServer(c *gin.Context) { if origin != global.CONFIG.System.WebURL() { return false } - var clientId = r.URL.Query().Get("client_id") + var clientId = r.URL.Query().Get("user_id") if clientId == "" { return false } - session.Store("client_id", clientId) + token := r.URL.Query().Get("token") + if token == "" { + return false + } + accessToken, b, err := utils.ParseAccessToken(token) + if err != nil || !b || *accessToken.UserID != clientId { + return false + } + session.Store("user_id", clientId) return true }, }) @@ -73,8 +81,8 @@ func MustLoad[T any](session gws.SessionStorage, key string) (v T) { return } -// NewWebSocket 创建WebSocket实例 -func NewWebSocket() *WebSocket { +// MessageWebSocket 创建WebSocket实例 +func MessageWebSocket() *WebSocket { return &WebSocket{ sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128), } @@ -82,22 +90,22 @@ func NewWebSocket() *WebSocket { // OnOpen 连接建立 func (c *WebSocket) OnOpen(socket *gws.Conn) { - clientId := MustLoad[string](socket.Session(), "client_id") + clientId := MustLoad[string](socket.Session(), "user_id") c.sessions.Store(clientId, socket) // 订阅该用户的频道 go c.subscribeUserChannel(clientId) - fmt.Printf("websocket client %s connected\n", clientId) + //fmt.Printf("websocket client %s connected\n", clientId) } // OnClose 关闭连接 func (c *WebSocket) OnClose(socket *gws.Conn, err error) { - name := MustLoad[string](socket.Session(), "client_id") + name := MustLoad[string](socket.Session(), "user_id") sharding := c.sessions.GetSharding(name) c.sessions.Delete(name) sharding.Lock() defer sharding.Unlock() - global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error()) + //global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error()) } // OnPing 处理客户端的Ping消息 @@ -112,7 +120,7 @@ func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {} // OnMessage 接受消息 func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) { defer message.Close() - clientId := MustLoad[string](socket.Session(), "client_id") + clientId := MustLoad[string](socket.Session(), "user_id") if conn, ok := c.sessions.Load(clientId); ok { _ = conn.WriteMessage(gws.OpcodeText, message.Bytes()) } diff --git a/controller/websocket_controller/qr_ws_controller/qr_ws_controller.go b/controller/websocket_controller/qr_ws_controller/qr_ws_controller.go new file mode 100644 index 0000000..aa7dcd5 --- /dev/null +++ b/controller/websocket_controller/qr_ws_controller/qr_ws_controller.go @@ -0,0 +1,131 @@ +package qr_ws_controller + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/lxzan/gws" + "net/http" + "schisandra-cloud-album/common/constant" + "schisandra-cloud-album/common/redis" + "schisandra-cloud-album/global" + "schisandra-cloud-album/utils" + "time" +) + +type QrWebsocketController struct { +} + +const ( + PingInterval = 5 * time.Second // 客户端心跳间隔 + HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间 +) + +type WebSocket struct { + gws.BuiltinEventHandler + sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突 +} + +var Handler = NewWebSocket() + +// QrWebsocket 创建websocket服务 +// @Summary 创建websocket服务 +// @Description 创建websocket服务 +// @Tags websocket +// @Router /controller/ws/gws [get] +func (QrWebsocketController) QrWebsocket(c *gin.Context) { + + upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{ + HandshakeTimeout: 5 * time.Second, // 握手超时时间 + ReadBufferSize: 1024, // 读缓冲区大小 + ParallelEnabled: true, // 开启并行消息处理 + Recovery: gws.Recovery, // 开启异常恢复 + CheckUtf8Enabled: false, // 关闭UTF8校验 + PermessageDeflate: gws.PermessageDeflate{ + Enabled: true, // 开启压缩 + }, + Authorize: func(r *http.Request, session gws.SessionStorage) bool { + origin := r.Header.Get("Origin") + if origin != global.CONFIG.System.WebURL() { + return false + } + var clientId = r.URL.Query().Get("client_id") + if clientId == "" { + return false + } + ip := utils.GetClientIP(c) + exists := redis.Get(constant.UserLoginClientRedisKey + ip).Val() + if clientId != exists { + return false + } + session.Store("client_id", clientId) + return true + }, + }) + socket, err := upgrader.Upgrade(c.Writer, c.Request) + if err != nil { + return + } + go func() { + socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC + }() +} + +// MustLoad 从session中加载数据 +func MustLoad[T any](session gws.SessionStorage, key string) (v T) { + if value, exist := session.Load(key); exist { + v = value.(T) + } + return +} + +// NewWebSocket 创建WebSocket实例 +func NewWebSocket() *WebSocket { + return &WebSocket{ + sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128), + } +} + +// OnOpen 连接建立 +func (c *WebSocket) OnOpen(socket *gws.Conn) { + clientId := MustLoad[string](socket.Session(), "client_id") + c.sessions.Store(clientId, socket) + //fmt.Printf("websocket client %s connected\n", clientId) +} + +// OnClose 关闭连接 +func (c *WebSocket) OnClose(socket *gws.Conn, err error) { + name := MustLoad[string](socket.Session(), "client_id") + sharding := c.sessions.GetSharding(name) + c.sessions.Delete(name) + sharding.Lock() + defer sharding.Unlock() + + //global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error()) +} + +// OnPing 处理客户端的Ping消息 +func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) { + _ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout)) + _ = socket.WritePong(payload) +} + +// OnPong 处理客户端的Pong消息 +func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {} + +// OnMessage 接受消息 +func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) { + defer message.Close() + clientId := MustLoad[string](socket.Session(), "client_id") + if conn, ok := c.sessions.Load(clientId); ok { + _ = conn.WriteMessage(gws.OpcodeText, message.Bytes()) + } +} + +// SendMessageToClient 向指定客户端发送消息 +func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error { + conn, ok := c.sessions.Load(clientId) + if ok { + return conn.WriteMessage(gws.OpcodeText, message) + } + return fmt.Errorf("client %s not found", clientId) +} diff --git a/router/modules/websocket_router.go b/router/modules/websocket_router.go index fefaf9e..e062d28 100644 --- a/router/modules/websocket_router.go +++ b/router/modules/websocket_router.go @@ -5,12 +5,13 @@ import ( "schisandra-cloud-album/controller" ) -var websocketAPI = controller.Controller.WebsocketController +var qrWebsocketAPI = controller.Controller.QrWebsocketController +var messageWebsocketAPI = controller.Controller.MessageWebsocketController func WebsocketRouter(router *gin.RouterGroup) { group := router.Group("/ws") { - group.GET("/gws", websocketAPI.NewGWSServer) + group.GET("/qr_ws", qrWebsocketAPI.QrWebsocket) + group.GET("/message_ws", messageWebsocketAPI.MessageWSController) } - } diff --git a/utils/generate_avatar.go b/utils/generate_avatar.go index 51269ab..1c7010b 100644 --- a/utils/generate_avatar.go +++ b/utils/generate_avatar.go @@ -1,7 +1,10 @@ package utils import ( + "encoding/base64" + "io" "net/http" + "schisandra-cloud-album/global" "time" ) @@ -15,37 +18,41 @@ func GenerateAvatar(userId string) (baseImg string) { path := "https://api.multiavatar.com/" + userId + ".png" - //// 创建请求 - //request, err := http.NewRequest("GET", path, nil) - //if err != nil { - // return "", errors.New("image request error") - //} - // - //// 发送请求并获取响应 - //respImg, err := client.Do(request) - //if err != nil { - // return "", errors.New("failed to fetch image") - //} - //defer func(Body io.ReadCloser) { - // err := Body.Close() - // if err != nil { - // return - // } - //}(respImg.Body) - // - //// 读取图片数据 - //imgByte, err := io.ReadAll(respImg.Body) - //if err != nil { - // return "", errors.New("failed to read image data") - //} - // - //// 判断文件类型,生成一个前缀 - //mimeType := http.DetectContentType(imgByte) - //switch mimeType { - //case "image/png": - // baseImg = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgByte) - //default: - // return "", errors.New("unsupported image type") - //} - return path + // 创建请求 + request, err := http.NewRequest("GET", path, nil) + if err != nil { + global.LOG.Error(err) + return "" + } + + // 发送请求并获取响应 + respImg, err := client.Do(request) + if err != nil { + global.LOG.Error(err) + return "" + } + defer func(Body io.ReadCloser) { + err = Body.Close() + if err != nil { + global.LOG.Error(err) + return + } + }(respImg.Body) + + // 读取图片数据 + imgByte, err := io.ReadAll(respImg.Body) + if err != nil { + global.LOG.Error(err) + return "" + } + + // 判断文件类型,生成一个前缀 + mimeType := http.DetectContentType(imgByte) + switch mimeType { + case "image/png": + baseImg = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgByte) + default: + return "" + } + return baseImg }