197 lines
4.2 KiB
Go
197 lines
4.2 KiB
Go
package websocket_api
|
||
|
||
import (
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/gorilla/websocket"
|
||
"log"
|
||
"net/http"
|
||
"schisandra-cloud-album/global"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
var (
|
||
// 消息通道
|
||
msg = make(map[string]chan interface{})
|
||
// websocket客户端链接池
|
||
client = make(map[string]*websocket.Conn)
|
||
// 互斥锁,防止程序对统一资源同时进行读写
|
||
mux sync.Mutex
|
||
)
|
||
|
||
// NewSocketClient 建立websocket长链接接口处理函数
|
||
func (WebsocketAPI) NewSocketClient(context *gin.Context) {
|
||
id := context.Query("client_id")
|
||
global.LOG.Println(id + "websocket链接")
|
||
// 升级为websocket长链接
|
||
WsHandler(context.Writer, context.Request, id)
|
||
}
|
||
|
||
// DeleteClient api:/deleteClient接口处理函数
|
||
func (WebsocketAPI) DeleteClient(context *gin.Context) {
|
||
id := context.Query("client_id")
|
||
// 关闭websocket链接
|
||
conn, exist := getClient(id)
|
||
if exist {
|
||
err := conn.Close()
|
||
if err != nil {
|
||
return
|
||
}
|
||
deleteClient(id)
|
||
} else {
|
||
context.JSON(http.StatusOK, gin.H{
|
||
"mesg": "未找到该客户端",
|
||
})
|
||
}
|
||
// 关闭其消息通道
|
||
_, exist = getMsgChannel(id)
|
||
if exist {
|
||
deletemsgChannel(id)
|
||
}
|
||
}
|
||
|
||
// SendMessageData 发送消息接口处理函数
|
||
func SendMessageData(clientId string, data interface{}) bool {
|
||
m, exist := getMsgChannel(clientId)
|
||
if !exist {
|
||
log.Println("未找到该客户端的消息通道")
|
||
return false
|
||
}
|
||
// 向消息通道发送消息
|
||
select {
|
||
case m <- data:
|
||
global.LOG.Println("发送消息给客户端:" + clientId)
|
||
return true
|
||
default:
|
||
global.LOG.Println("消息通道已满,消息发送失败")
|
||
}
|
||
return false
|
||
}
|
||
|
||
// websocket Upgrader
|
||
var wsupgrader = websocket.Upgrader{
|
||
ReadBufferSize: 1024,
|
||
WriteBufferSize: 1024,
|
||
HandshakeTimeout: 5 * time.Second,
|
||
// 取消ws跨域校验
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true
|
||
},
|
||
}
|
||
|
||
// WsHandler 处理ws请求
|
||
func WsHandler(w http.ResponseWriter, r *http.Request, id string) {
|
||
var conn *websocket.Conn
|
||
var err error
|
||
var exist bool
|
||
// 创建一个定时器用于服务端心跳
|
||
//pingTicker := time.NewTicker(time.Second * 10)
|
||
conn, err = wsupgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
global.LOG.Println(err)
|
||
return
|
||
}
|
||
// 把与客户端的链接添加到客户端链接池中
|
||
addClient(id, conn)
|
||
|
||
// 获取该客户端的消息通道
|
||
m, exist := getMsgChannel(id)
|
||
if !exist {
|
||
m = make(chan interface{})
|
||
addMsgChannel(id, m)
|
||
}
|
||
// 设置客户端关闭ws链接回调函数
|
||
conn.SetCloseHandler(func(code int, text string) error {
|
||
deleteClient(id)
|
||
return nil
|
||
})
|
||
defer conn.Close()
|
||
for {
|
||
_, _, err := conn.ReadMessage()
|
||
if err != nil {
|
||
fmt.Println(err)
|
||
return
|
||
}
|
||
select {
|
||
case content, _ := <-m:
|
||
// 从消息通道接收消息,然后推送给前端
|
||
err = conn.WriteJSON(content)
|
||
if err != nil {
|
||
global.LOG.Error(err)
|
||
err := conn.Close()
|
||
if err != nil {
|
||
return
|
||
}
|
||
deleteClient(id)
|
||
break
|
||
}
|
||
//case <-pingTicker.C:
|
||
// // 服务端心跳:每20秒ping一次客户端,查看其是否在线
|
||
// err := conn.SetWriteDeadline(time.Now().Add(time.Second * 20))
|
||
// if err != nil {
|
||
// return
|
||
// }
|
||
// err = conn.WriteMessage(websocket.PongMessage, []byte("pong"))
|
||
// if err != nil {
|
||
// log.Println("send pong err:", err)
|
||
// err := conn.Close()
|
||
// if err != nil {
|
||
// return
|
||
// }
|
||
// deleteClient(id)
|
||
// return
|
||
// }
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
// 将客户端添加到客户端链接池
|
||
func addClient(id string, conn *websocket.Conn) {
|
||
mux.Lock()
|
||
client[id] = conn
|
||
mux.Unlock()
|
||
}
|
||
|
||
// 获取指定客户端链接
|
||
func getClient(id string) (conn *websocket.Conn, exist bool) {
|
||
mux.Lock()
|
||
conn, exist = client[id]
|
||
mux.Unlock()
|
||
return
|
||
}
|
||
|
||
// 删除客户端链接
|
||
func deleteClient(id string) {
|
||
mux.Lock()
|
||
delete(client, id)
|
||
log.Println(id + "websocket退出")
|
||
mux.Unlock()
|
||
}
|
||
|
||
// 添加用户消息通道
|
||
func addMsgChannel(id string, m chan interface{}) {
|
||
mux.Lock()
|
||
msg[id] = m
|
||
mux.Unlock()
|
||
}
|
||
|
||
// 获取指定用户消息通道
|
||
func getMsgChannel(id string) (m chan interface{}, exist bool) {
|
||
mux.Lock()
|
||
defer mux.Unlock()
|
||
m, exist = msg[id]
|
||
return
|
||
}
|
||
|
||
// 删除指定消息通道
|
||
func deletemsgChannel(id string) {
|
||
mux.Lock()
|
||
if m, ok := msg[id]; ok {
|
||
close(m)
|
||
delete(msg, id)
|
||
}
|
||
mux.Unlock()
|
||
}
|