update websocket

This commit is contained in:
landaiqing
2024-10-17 23:41:40 +08:00
parent b5d88a7ccd
commit 31eabd4e55
6 changed files with 213 additions and 64 deletions

View File

@@ -9,20 +9,22 @@ import (
"schisandra-cloud-album/controller/role_controller" "schisandra-cloud-album/controller/role_controller"
"schisandra-cloud-album/controller/sms_controller" "schisandra-cloud-album/controller/sms_controller"
"schisandra-cloud-album/controller/user_controller" "schisandra-cloud-album/controller/user_controller"
"schisandra-cloud-album/controller/websocket_controller" "schisandra-cloud-album/controller/websocket_controller/message_ws_controller"
"schisandra-cloud-album/controller/websocket_controller/qr_ws_controller"
) )
// Controllers 统一导出的控制器接口 // Controllers 统一导出的控制器接口
type Controllers struct { type Controllers struct {
UserController user_controller.UserController UserController user_controller.UserController
CaptchaController captcha_controller.CaptchaController CaptchaController captcha_controller.CaptchaController
SmsController sms_controller.SmsController SmsController sms_controller.SmsController
OAuthController oauth_controller.OAuthController OAuthController oauth_controller.OAuthController
WebsocketController websocket_controller.WebsocketController QrWebsocketController qr_ws_controller.QrWebsocketController
RoleController role_controller.RoleController MessageWebsocketController message_ws_controller.MessageWebsocketController
PermissionController permission_controller.PermissionController RoleController role_controller.RoleController
ClientController client_controller.ClientController PermissionController permission_controller.PermissionController
CommonController comment_controller.CommentController ClientController client_controller.ClientController
CommonController comment_controller.CommentController
} }
// Controller new函数实例化实例化完成后会返回结构体地指针类型 // Controller new函数实例化实例化完成后会返回结构体地指针类型

View File

@@ -19,7 +19,7 @@ import (
"schisandra-cloud-album/common/randomname" "schisandra-cloud-album/common/randomname"
"schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/redis"
"schisandra-cloud-album/common/result" "schisandra-cloud-album/common/result"
"schisandra-cloud-album/controller/websocket_controller" "schisandra-cloud-album/controller/websocket_controller/qr_ws_controller"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
"schisandra-cloud-album/model" "schisandra-cloud-album/model"
"schisandra-cloud-album/utils" "schisandra-cloud-album/utils"
@@ -286,7 +286,7 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool {
return return
} }
// gws方式发送消息 // gws方式发送消息
err = websocket_controller.Handler.SendMessageToClient(clientId, tokenData) err = qr_ws_controller.Handler.SendMessageToClient(clientId, tokenData)
if err != nil { if err != nil {
global.LOG.Error(err) global.LOG.Error(err)
resultChan <- false resultChan <- false

View File

@@ -1,4 +1,4 @@
package websocket_controller package message_ws_controller
import ( import (
"context" "context"
@@ -9,10 +9,11 @@ import (
"schisandra-cloud-album/common/constant" "schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis" "schisandra-cloud-album/common/redis"
"schisandra-cloud-album/global" "schisandra-cloud-album/global"
"schisandra-cloud-album/utils"
"time" "time"
) )
type WebsocketController struct { type MessageWebsocketController struct {
} }
const ( const (
@@ -25,16 +26,15 @@ type WebSocket struct {
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突 sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
} }
var Handler = NewWebSocket() var MessageHandler = MessageWebSocket()
// NewGWSServer 创建websocket服务 // MessageWSController 创建websocket服务
// @Summary 创建websocket服务 // @Summary 创建websocket服务
// @Description 创建websocket服务 // @Description 创建websocket服务
// @Tags websocket // @Tags websocket
// @Router /controller/ws/gws [get] // @Router /controller/ws/gws [get]
func (WebsocketController) NewGWSServer(c *gin.Context) { func (MessageWebsocketController) MessageWSController(c *gin.Context) {
upgrader := gws.NewUpgrader(MessageHandler, &gws.ServerOption{
upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{
HandshakeTimeout: 5 * time.Second, // 握手超时时间 HandshakeTimeout: 5 * time.Second, // 握手超时时间
ReadBufferSize: 1024, // 读缓冲区大小 ReadBufferSize: 1024, // 读缓冲区大小
ParallelEnabled: true, // 开启并行消息处理 ParallelEnabled: true, // 开启并行消息处理
@@ -48,11 +48,19 @@ func (WebsocketController) NewGWSServer(c *gin.Context) {
if origin != global.CONFIG.System.WebURL() { if origin != global.CONFIG.System.WebURL() {
return false return false
} }
var clientId = r.URL.Query().Get("client_id") var clientId = r.URL.Query().Get("user_id")
if clientId == "" { if clientId == "" {
return false return false
} }
session.Store("client_id", clientId) token := r.URL.Query().Get("token")
if token == "" {
return false
}
accessToken, b, err := utils.ParseAccessToken(token)
if err != nil || !b || *accessToken.UserID != clientId {
return false
}
session.Store("user_id", clientId)
return true return true
}, },
}) })
@@ -73,8 +81,8 @@ func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
return return
} }
// NewWebSocket 创建WebSocket实例 // MessageWebSocket 创建WebSocket实例
func NewWebSocket() *WebSocket { func MessageWebSocket() *WebSocket {
return &WebSocket{ return &WebSocket{
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128), sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
} }
@@ -82,22 +90,22 @@ func NewWebSocket() *WebSocket {
// OnOpen 连接建立 // OnOpen 连接建立
func (c *WebSocket) OnOpen(socket *gws.Conn) { func (c *WebSocket) OnOpen(socket *gws.Conn) {
clientId := MustLoad[string](socket.Session(), "client_id") clientId := MustLoad[string](socket.Session(), "user_id")
c.sessions.Store(clientId, socket) c.sessions.Store(clientId, socket)
// 订阅该用户的频道 // 订阅该用户的频道
go c.subscribeUserChannel(clientId) go c.subscribeUserChannel(clientId)
fmt.Printf("websocket client %s connected\n", clientId) //fmt.Printf("websocket client %s connected\n", clientId)
} }
// OnClose 关闭连接 // OnClose 关闭连接
func (c *WebSocket) OnClose(socket *gws.Conn, err error) { func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
name := MustLoad[string](socket.Session(), "client_id") name := MustLoad[string](socket.Session(), "user_id")
sharding := c.sessions.GetSharding(name) sharding := c.sessions.GetSharding(name)
c.sessions.Delete(name) c.sessions.Delete(name)
sharding.Lock() sharding.Lock()
defer sharding.Unlock() defer sharding.Unlock()
global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error()) //global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
} }
// OnPing 处理客户端的Ping消息 // OnPing 处理客户端的Ping消息
@@ -112,7 +120,7 @@ func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
// OnMessage 接受消息 // OnMessage 接受消息
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) { func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
defer message.Close() defer message.Close()
clientId := MustLoad[string](socket.Session(), "client_id") clientId := MustLoad[string](socket.Session(), "user_id")
if conn, ok := c.sessions.Load(clientId); ok { if conn, ok := c.sessions.Load(clientId); ok {
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes()) _ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
} }

View File

@@ -0,0 +1,131 @@
package qr_ws_controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/lxzan/gws"
"net/http"
"schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/redis"
"schisandra-cloud-album/global"
"schisandra-cloud-album/utils"
"time"
)
type QrWebsocketController struct {
}
const (
PingInterval = 5 * time.Second // 客户端心跳间隔
HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间
)
type WebSocket struct {
gws.BuiltinEventHandler
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
}
var Handler = NewWebSocket()
// QrWebsocket 创建websocket服务
// @Summary 创建websocket服务
// @Description 创建websocket服务
// @Tags websocket
// @Router /controller/ws/gws [get]
func (QrWebsocketController) QrWebsocket(c *gin.Context) {
upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{
HandshakeTimeout: 5 * time.Second, // 握手超时时间
ReadBufferSize: 1024, // 读缓冲区大小
ParallelEnabled: true, // 开启并行消息处理
Recovery: gws.Recovery, // 开启异常恢复
CheckUtf8Enabled: false, // 关闭UTF8校验
PermessageDeflate: gws.PermessageDeflate{
Enabled: true, // 开启压缩
},
Authorize: func(r *http.Request, session gws.SessionStorage) bool {
origin := r.Header.Get("Origin")
if origin != global.CONFIG.System.WebURL() {
return false
}
var clientId = r.URL.Query().Get("client_id")
if clientId == "" {
return false
}
ip := utils.GetClientIP(c)
exists := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
if clientId != exists {
return false
}
session.Store("client_id", clientId)
return true
},
})
socket, err := upgrader.Upgrade(c.Writer, c.Request)
if err != nil {
return
}
go func() {
socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC
}()
}
// MustLoad 从session中加载数据
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
if value, exist := session.Load(key); exist {
v = value.(T)
}
return
}
// NewWebSocket 创建WebSocket实例
func NewWebSocket() *WebSocket {
return &WebSocket{
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
}
}
// OnOpen 连接建立
func (c *WebSocket) OnOpen(socket *gws.Conn) {
clientId := MustLoad[string](socket.Session(), "client_id")
c.sessions.Store(clientId, socket)
//fmt.Printf("websocket client %s connected\n", clientId)
}
// OnClose 关闭连接
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
name := MustLoad[string](socket.Session(), "client_id")
sharding := c.sessions.GetSharding(name)
c.sessions.Delete(name)
sharding.Lock()
defer sharding.Unlock()
//global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
}
// OnPing 处理客户端的Ping消息
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
_ = socket.WritePong(payload)
}
// OnPong 处理客户端的Pong消息
func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
// OnMessage 接受消息
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
defer message.Close()
clientId := MustLoad[string](socket.Session(), "client_id")
if conn, ok := c.sessions.Load(clientId); ok {
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
}
}
// SendMessageToClient 向指定客户端发送消息
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
conn, ok := c.sessions.Load(clientId)
if ok {
return conn.WriteMessage(gws.OpcodeText, message)
}
return fmt.Errorf("client %s not found", clientId)
}

View File

@@ -5,12 +5,13 @@ import (
"schisandra-cloud-album/controller" "schisandra-cloud-album/controller"
) )
var websocketAPI = controller.Controller.WebsocketController var qrWebsocketAPI = controller.Controller.QrWebsocketController
var messageWebsocketAPI = controller.Controller.MessageWebsocketController
func WebsocketRouter(router *gin.RouterGroup) { func WebsocketRouter(router *gin.RouterGroup) {
group := router.Group("/ws") group := router.Group("/ws")
{ {
group.GET("/gws", websocketAPI.NewGWSServer) group.GET("/qr_ws", qrWebsocketAPI.QrWebsocket)
group.GET("/message_ws", messageWebsocketAPI.MessageWSController)
} }
} }

View File

@@ -1,7 +1,10 @@
package utils package utils
import ( import (
"encoding/base64"
"io"
"net/http" "net/http"
"schisandra-cloud-album/global"
"time" "time"
) )
@@ -15,37 +18,41 @@ func GenerateAvatar(userId string) (baseImg string) {
path := "https://api.multiavatar.com/" + userId + ".png" path := "https://api.multiavatar.com/" + userId + ".png"
//// 创建请求 // 创建请求
//request, err := http.NewRequest("GET", path, nil) request, err := http.NewRequest("GET", path, nil)
//if err != nil { if err != nil {
// return "", errors.New("image request error") global.LOG.Error(err)
//} return ""
// }
//// 发送请求并获取响应
//respImg, err := client.Do(request) // 发送请求并获取响应
//if err != nil { respImg, err := client.Do(request)
// return "", errors.New("failed to fetch image") if err != nil {
//} global.LOG.Error(err)
//defer func(Body io.ReadCloser) { return ""
// err := Body.Close() }
// if err != nil { defer func(Body io.ReadCloser) {
// return err = Body.Close()
// } if err != nil {
//}(respImg.Body) global.LOG.Error(err)
// return
//// 读取图片数据 }
//imgByte, err := io.ReadAll(respImg.Body) }(respImg.Body)
//if err != nil {
// return "", errors.New("failed to read image data") // 读取图片数据
//} imgByte, err := io.ReadAll(respImg.Body)
// if err != nil {
//// 判断文件类型,生成一个前缀 global.LOG.Error(err)
//mimeType := http.DetectContentType(imgByte) return ""
//switch mimeType { }
//case "image/png":
// baseImg = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgByte) // 判断文件类型,生成一个前缀
//default: mimeType := http.DetectContentType(imgByte)
// return "", errors.New("unsupported image type") switch mimeType {
//} case "image/png":
return path baseImg = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgByte)
default:
return ""
}
return baseImg
} }