✨ update websocket
This commit is contained in:
@@ -9,20 +9,22 @@ import (
|
|||||||
"schisandra-cloud-album/controller/role_controller"
|
"schisandra-cloud-album/controller/role_controller"
|
||||||
"schisandra-cloud-album/controller/sms_controller"
|
"schisandra-cloud-album/controller/sms_controller"
|
||||||
"schisandra-cloud-album/controller/user_controller"
|
"schisandra-cloud-album/controller/user_controller"
|
||||||
"schisandra-cloud-album/controller/websocket_controller"
|
"schisandra-cloud-album/controller/websocket_controller/message_ws_controller"
|
||||||
|
"schisandra-cloud-album/controller/websocket_controller/qr_ws_controller"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Controllers 统一导出的控制器接口
|
// Controllers 统一导出的控制器接口
|
||||||
type Controllers struct {
|
type Controllers struct {
|
||||||
UserController user_controller.UserController
|
UserController user_controller.UserController
|
||||||
CaptchaController captcha_controller.CaptchaController
|
CaptchaController captcha_controller.CaptchaController
|
||||||
SmsController sms_controller.SmsController
|
SmsController sms_controller.SmsController
|
||||||
OAuthController oauth_controller.OAuthController
|
OAuthController oauth_controller.OAuthController
|
||||||
WebsocketController websocket_controller.WebsocketController
|
QrWebsocketController qr_ws_controller.QrWebsocketController
|
||||||
RoleController role_controller.RoleController
|
MessageWebsocketController message_ws_controller.MessageWebsocketController
|
||||||
PermissionController permission_controller.PermissionController
|
RoleController role_controller.RoleController
|
||||||
ClientController client_controller.ClientController
|
PermissionController permission_controller.PermissionController
|
||||||
CommonController comment_controller.CommentController
|
ClientController client_controller.ClientController
|
||||||
|
CommonController comment_controller.CommentController
|
||||||
}
|
}
|
||||||
|
|
||||||
// Controller new函数实例化,实例化完成后会返回结构体地指针类型
|
// Controller new函数实例化,实例化完成后会返回结构体地指针类型
|
||||||
|
@@ -19,7 +19,7 @@ import (
|
|||||||
"schisandra-cloud-album/common/randomname"
|
"schisandra-cloud-album/common/randomname"
|
||||||
"schisandra-cloud-album/common/redis"
|
"schisandra-cloud-album/common/redis"
|
||||||
"schisandra-cloud-album/common/result"
|
"schisandra-cloud-album/common/result"
|
||||||
"schisandra-cloud-album/controller/websocket_controller"
|
"schisandra-cloud-album/controller/websocket_controller/qr_ws_controller"
|
||||||
"schisandra-cloud-album/global"
|
"schisandra-cloud-album/global"
|
||||||
"schisandra-cloud-album/model"
|
"schisandra-cloud-album/model"
|
||||||
"schisandra-cloud-album/utils"
|
"schisandra-cloud-album/utils"
|
||||||
@@ -286,7 +286,7 @@ func handelUserLogin(userId string, clientId string, c *gin.Context) bool {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// gws方式发送消息
|
// gws方式发送消息
|
||||||
err = websocket_controller.Handler.SendMessageToClient(clientId, tokenData)
|
err = qr_ws_controller.Handler.SendMessageToClient(clientId, tokenData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.LOG.Error(err)
|
global.LOG.Error(err)
|
||||||
resultChan <- false
|
resultChan <- false
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
package websocket_controller
|
package message_ws_controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,10 +9,11 @@ import (
|
|||||||
"schisandra-cloud-album/common/constant"
|
"schisandra-cloud-album/common/constant"
|
||||||
"schisandra-cloud-album/common/redis"
|
"schisandra-cloud-album/common/redis"
|
||||||
"schisandra-cloud-album/global"
|
"schisandra-cloud-album/global"
|
||||||
|
"schisandra-cloud-album/utils"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WebsocketController struct {
|
type MessageWebsocketController struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,16 +26,15 @@ type WebSocket struct {
|
|||||||
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
|
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
|
||||||
}
|
}
|
||||||
|
|
||||||
var Handler = NewWebSocket()
|
var MessageHandler = MessageWebSocket()
|
||||||
|
|
||||||
// NewGWSServer 创建websocket服务
|
// MessageWSController 创建websocket服务
|
||||||
// @Summary 创建websocket服务
|
// @Summary 创建websocket服务
|
||||||
// @Description 创建websocket服务
|
// @Description 创建websocket服务
|
||||||
// @Tags websocket
|
// @Tags websocket
|
||||||
// @Router /controller/ws/gws [get]
|
// @Router /controller/ws/gws [get]
|
||||||
func (WebsocketController) NewGWSServer(c *gin.Context) {
|
func (MessageWebsocketController) MessageWSController(c *gin.Context) {
|
||||||
|
upgrader := gws.NewUpgrader(MessageHandler, &gws.ServerOption{
|
||||||
upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{
|
|
||||||
HandshakeTimeout: 5 * time.Second, // 握手超时时间
|
HandshakeTimeout: 5 * time.Second, // 握手超时时间
|
||||||
ReadBufferSize: 1024, // 读缓冲区大小
|
ReadBufferSize: 1024, // 读缓冲区大小
|
||||||
ParallelEnabled: true, // 开启并行消息处理
|
ParallelEnabled: true, // 开启并行消息处理
|
||||||
@@ -48,11 +48,19 @@ func (WebsocketController) NewGWSServer(c *gin.Context) {
|
|||||||
if origin != global.CONFIG.System.WebURL() {
|
if origin != global.CONFIG.System.WebURL() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
var clientId = r.URL.Query().Get("client_id")
|
var clientId = r.URL.Query().Get("user_id")
|
||||||
if clientId == "" {
|
if clientId == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
session.Store("client_id", clientId)
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
accessToken, b, err := utils.ParseAccessToken(token)
|
||||||
|
if err != nil || !b || *accessToken.UserID != clientId {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
session.Store("user_id", clientId)
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -73,8 +81,8 @@ func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWebSocket 创建WebSocket实例
|
// MessageWebSocket 创建WebSocket实例
|
||||||
func NewWebSocket() *WebSocket {
|
func MessageWebSocket() *WebSocket {
|
||||||
return &WebSocket{
|
return &WebSocket{
|
||||||
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
|
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
|
||||||
}
|
}
|
||||||
@@ -82,22 +90,22 @@ func NewWebSocket() *WebSocket {
|
|||||||
|
|
||||||
// OnOpen 连接建立
|
// OnOpen 连接建立
|
||||||
func (c *WebSocket) OnOpen(socket *gws.Conn) {
|
func (c *WebSocket) OnOpen(socket *gws.Conn) {
|
||||||
clientId := MustLoad[string](socket.Session(), "client_id")
|
clientId := MustLoad[string](socket.Session(), "user_id")
|
||||||
c.sessions.Store(clientId, socket)
|
c.sessions.Store(clientId, socket)
|
||||||
// 订阅该用户的频道
|
// 订阅该用户的频道
|
||||||
go c.subscribeUserChannel(clientId)
|
go c.subscribeUserChannel(clientId)
|
||||||
fmt.Printf("websocket client %s connected\n", clientId)
|
//fmt.Printf("websocket client %s connected\n", clientId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnClose 关闭连接
|
// OnClose 关闭连接
|
||||||
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
||||||
name := MustLoad[string](socket.Session(), "client_id")
|
name := MustLoad[string](socket.Session(), "user_id")
|
||||||
sharding := c.sessions.GetSharding(name)
|
sharding := c.sessions.GetSharding(name)
|
||||||
c.sessions.Delete(name)
|
c.sessions.Delete(name)
|
||||||
sharding.Lock()
|
sharding.Lock()
|
||||||
defer sharding.Unlock()
|
defer sharding.Unlock()
|
||||||
|
|
||||||
global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
|
//global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPing 处理客户端的Ping消息
|
// OnPing 处理客户端的Ping消息
|
||||||
@@ -112,7 +120,7 @@ func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
|
|||||||
// OnMessage 接受消息
|
// OnMessage 接受消息
|
||||||
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
|
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
|
||||||
defer message.Close()
|
defer message.Close()
|
||||||
clientId := MustLoad[string](socket.Session(), "client_id")
|
clientId := MustLoad[string](socket.Session(), "user_id")
|
||||||
if conn, ok := c.sessions.Load(clientId); ok {
|
if conn, ok := c.sessions.Load(clientId); ok {
|
||||||
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
|
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
|
||||||
}
|
}
|
@@ -0,0 +1,131 @@
|
|||||||
|
package qr_ws_controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/lxzan/gws"
|
||||||
|
"net/http"
|
||||||
|
"schisandra-cloud-album/common/constant"
|
||||||
|
"schisandra-cloud-album/common/redis"
|
||||||
|
"schisandra-cloud-album/global"
|
||||||
|
"schisandra-cloud-album/utils"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type QrWebsocketController struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
PingInterval = 5 * time.Second // 客户端心跳间隔
|
||||||
|
HeartbeatWaitTimeout = 10 * time.Second // 心跳等待超时时间
|
||||||
|
)
|
||||||
|
|
||||||
|
type WebSocket struct {
|
||||||
|
gws.BuiltinEventHandler
|
||||||
|
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
|
||||||
|
}
|
||||||
|
|
||||||
|
var Handler = NewWebSocket()
|
||||||
|
|
||||||
|
// QrWebsocket 创建websocket服务
|
||||||
|
// @Summary 创建websocket服务
|
||||||
|
// @Description 创建websocket服务
|
||||||
|
// @Tags websocket
|
||||||
|
// @Router /controller/ws/gws [get]
|
||||||
|
func (QrWebsocketController) QrWebsocket(c *gin.Context) {
|
||||||
|
|
||||||
|
upgrader := gws.NewUpgrader(Handler, &gws.ServerOption{
|
||||||
|
HandshakeTimeout: 5 * time.Second, // 握手超时时间
|
||||||
|
ReadBufferSize: 1024, // 读缓冲区大小
|
||||||
|
ParallelEnabled: true, // 开启并行消息处理
|
||||||
|
Recovery: gws.Recovery, // 开启异常恢复
|
||||||
|
CheckUtf8Enabled: false, // 关闭UTF8校验
|
||||||
|
PermessageDeflate: gws.PermessageDeflate{
|
||||||
|
Enabled: true, // 开启压缩
|
||||||
|
},
|
||||||
|
Authorize: func(r *http.Request, session gws.SessionStorage) bool {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin != global.CONFIG.System.WebURL() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var clientId = r.URL.Query().Get("client_id")
|
||||||
|
if clientId == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ip := utils.GetClientIP(c)
|
||||||
|
exists := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
|
||||||
|
if clientId != exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
session.Store("client_id", clientId)
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
socket, err := upgrader.Upgrade(c.Writer, c.Request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustLoad 从session中加载数据
|
||||||
|
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
|
||||||
|
if value, exist := session.Load(key); exist {
|
||||||
|
v = value.(T)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebSocket 创建WebSocket实例
|
||||||
|
func NewWebSocket() *WebSocket {
|
||||||
|
return &WebSocket{
|
||||||
|
sessions: gws.NewConcurrentMap[string, *gws.Conn](64, 128),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnOpen 连接建立
|
||||||
|
func (c *WebSocket) OnOpen(socket *gws.Conn) {
|
||||||
|
clientId := MustLoad[string](socket.Session(), "client_id")
|
||||||
|
c.sessions.Store(clientId, socket)
|
||||||
|
//fmt.Printf("websocket client %s connected\n", clientId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnClose 关闭连接
|
||||||
|
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
|
||||||
|
name := MustLoad[string](socket.Session(), "client_id")
|
||||||
|
sharding := c.sessions.GetSharding(name)
|
||||||
|
c.sessions.Delete(name)
|
||||||
|
sharding.Lock()
|
||||||
|
defer sharding.Unlock()
|
||||||
|
|
||||||
|
//global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPing 处理客户端的Ping消息
|
||||||
|
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
|
||||||
|
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
|
||||||
|
_ = socket.WritePong(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPong 处理客户端的Pong消息
|
||||||
|
func (c *WebSocket) OnPong(_ *gws.Conn, _ []byte) {}
|
||||||
|
|
||||||
|
// OnMessage 接受消息
|
||||||
|
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
|
||||||
|
defer message.Close()
|
||||||
|
clientId := MustLoad[string](socket.Session(), "client_id")
|
||||||
|
if conn, ok := c.sessions.Load(clientId); ok {
|
||||||
|
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessageToClient 向指定客户端发送消息
|
||||||
|
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
|
||||||
|
conn, ok := c.sessions.Load(clientId)
|
||||||
|
if ok {
|
||||||
|
return conn.WriteMessage(gws.OpcodeText, message)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("client %s not found", clientId)
|
||||||
|
}
|
@@ -5,12 +5,13 @@ import (
|
|||||||
"schisandra-cloud-album/controller"
|
"schisandra-cloud-album/controller"
|
||||||
)
|
)
|
||||||
|
|
||||||
var websocketAPI = controller.Controller.WebsocketController
|
var qrWebsocketAPI = controller.Controller.QrWebsocketController
|
||||||
|
var messageWebsocketAPI = controller.Controller.MessageWebsocketController
|
||||||
|
|
||||||
func WebsocketRouter(router *gin.RouterGroup) {
|
func WebsocketRouter(router *gin.RouterGroup) {
|
||||||
group := router.Group("/ws")
|
group := router.Group("/ws")
|
||||||
{
|
{
|
||||||
group.GET("/gws", websocketAPI.NewGWSServer)
|
group.GET("/qr_ws", qrWebsocketAPI.QrWebsocket)
|
||||||
|
group.GET("/message_ws", messageWebsocketAPI.MessageWSController)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"schisandra-cloud-album/global"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,37 +18,41 @@ func GenerateAvatar(userId string) (baseImg string) {
|
|||||||
|
|
||||||
path := "https://api.multiavatar.com/" + userId + ".png"
|
path := "https://api.multiavatar.com/" + userId + ".png"
|
||||||
|
|
||||||
//// 创建请求
|
// 创建请求
|
||||||
//request, err := http.NewRequest("GET", path, nil)
|
request, err := http.NewRequest("GET", path, nil)
|
||||||
//if err != nil {
|
if err != nil {
|
||||||
// return "", errors.New("image request error")
|
global.LOG.Error(err)
|
||||||
//}
|
return ""
|
||||||
//
|
}
|
||||||
//// 发送请求并获取响应
|
|
||||||
//respImg, err := client.Do(request)
|
// 发送请求并获取响应
|
||||||
//if err != nil {
|
respImg, err := client.Do(request)
|
||||||
// return "", errors.New("failed to fetch image")
|
if err != nil {
|
||||||
//}
|
global.LOG.Error(err)
|
||||||
//defer func(Body io.ReadCloser) {
|
return ""
|
||||||
// err := Body.Close()
|
}
|
||||||
// if err != nil {
|
defer func(Body io.ReadCloser) {
|
||||||
// return
|
err = Body.Close()
|
||||||
// }
|
if err != nil {
|
||||||
//}(respImg.Body)
|
global.LOG.Error(err)
|
||||||
//
|
return
|
||||||
//// 读取图片数据
|
}
|
||||||
//imgByte, err := io.ReadAll(respImg.Body)
|
}(respImg.Body)
|
||||||
//if err != nil {
|
|
||||||
// return "", errors.New("failed to read image data")
|
// 读取图片数据
|
||||||
//}
|
imgByte, err := io.ReadAll(respImg.Body)
|
||||||
//
|
if err != nil {
|
||||||
//// 判断文件类型,生成一个前缀
|
global.LOG.Error(err)
|
||||||
//mimeType := http.DetectContentType(imgByte)
|
return ""
|
||||||
//switch mimeType {
|
}
|
||||||
//case "image/png":
|
|
||||||
// baseImg = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgByte)
|
// 判断文件类型,生成一个前缀
|
||||||
//default:
|
mimeType := http.DetectContentType(imgByte)
|
||||||
// return "", errors.New("unsupported image type")
|
switch mimeType {
|
||||||
//}
|
case "image/png":
|
||||||
return path
|
baseImg = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgByte)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return baseImg
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user