use websocket

This commit is contained in:
landaiqing
2024-08-17 20:01:03 +08:00
parent e8fbff7e7f
commit 57964d39af
11 changed files with 280 additions and 25 deletions

View File

@@ -5,14 +5,16 @@ import (
"schisandra-cloud-album/api/oauth_api"
"schisandra-cloud-album/api/sms_api"
"schisandra-cloud-album/api/user_api"
"schisandra-cloud-album/api/websocket_api"
)
// Apis 统一导出的api
type Apis struct {
UserApi user_api.UserAPI
CaptchaApi captcha_api.CaptchaAPI
SmsApi sms_api.SmsAPI
OAuthApi oauth_api.OAuthAPI
UserApi user_api.UserAPI
CaptchaApi captcha_api.CaptchaAPI
SmsApi sms_api.SmsAPI
OAuthApi oauth_api.OAuthAPI
WebsocketApi websocket_api.WebsocketAPI
}
// Api new函数实例化实例化完成后会返回结构体地指针类型

View File

@@ -16,6 +16,7 @@ import (
"github.com/yitter/idgenerator-go/idgen"
"gorm.io/gorm"
"schisandra-cloud-album/api/user_api/dto"
"schisandra-cloud-album/api/websocket_api"
"schisandra-cloud-album/common/constant"
"schisandra-cloud-album/common/enum"
"schisandra-cloud-album/common/redis"
@@ -44,6 +45,11 @@ var roleService = service.Service.RoleService
// @Router /api/oauth/generate_client_id [get]
func (OAuthAPI) GenerateClientId(c *gin.Context) {
ip := c.ClientIP()
clientId := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
if clientId != "" {
result.OkWithData(clientId, c)
return
}
v1 := uuid.NewV1()
redis.Set(constant.UserLoginClientRedisKey+ip, v1.String(), 0)
result.OkWithData(v1.String(), c)
@@ -74,9 +80,9 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) {
key := strings.TrimPrefix(msg.EventKey, "qrscene_")
res := wechatLoginHandler(msg.FromUserName, key)
if !res {
return messages.NewText("登录失败")
return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed"))
}
return messages.NewText("登录成功")
return messages.NewText(ginI18n.MustGetMessage(c, "LoginSuccess"))
case models.CALLBACK_EVENT_UNSUBSCRIBE:
msg := models.EventUnSubscribe{}
@@ -97,9 +103,9 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) {
}
res := wechatLoginHandler(msg.FromUserName, msg.EventKey)
if !res {
return messages.NewText("登录失败")
return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed"))
}
return messages.NewText("登录成功")
return messages.NewText(ginI18n.MustGetMessage(c, "LoginSuccess"))
}
@@ -147,11 +153,12 @@ func (OAuthAPI) CallbackVerify(c *gin.Context) {
// @Router /api/oauth/get_temp_qrcode [get]
func (OAuthAPI) GetTempQrCode(c *gin.Context) {
clientId := c.Query("client_id")
ip := c.ClientIP()
if clientId == "" {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
qrcode := redis.Get(constant.UserLoginQrcodeRedisKey + clientId).Val()
qrcode := redis.Get(constant.UserLoginQrcodeRedisKey + ip + ":" + clientId).Val()
if qrcode != "" {
data := response.ResponseQRCodeCreate{}
@@ -172,7 +179,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
return
}
wrong := redis.Set(constant.UserLoginQrcodeRedisKey+clientId, serializedData, time.Hour*24*30).Err()
wrong := redis.Set(constant.UserLoginQrcodeRedisKey+ip+":"+clientId, serializedData, time.Hour*24*30).Err()
if wrong != nil {
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
@@ -181,6 +188,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c)
}
// wechatLoginHandler 微信登录处理
func wechatLoginHandler(openId string, clientId string) bool {
if openId == "" {
return false
@@ -277,5 +285,9 @@ func handelUserLogin(user model.ScaAuthUser, clientId string) bool {
if fail != nil || w != nil {
return false
}
res := websocket_api.SendMessageData(clientId, data)
if !res {
return false
}
return true
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
ginI18n "github.com/gin-contrib/i18n"
"github.com/gin-gonic/gin"
"github.com/wumansgy/goEncrypt/aes"
"github.com/yitter/idgenerator-go/idgen"
"reflect"
"schisandra-cloud-album/api/user_api/dto"
@@ -328,13 +327,7 @@ func (UserAPI) RefreshHandler(c *gin.Context) {
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
return
}
plaintext, err := aes.AesCtrDecryptByHex(refreshToken, []byte(global.CONFIG.Encrypt.Key), []byte(global.CONFIG.Encrypt.IV))
if err != nil {
global.LOG.Error(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
return
}
parseRefreshToken, isUpd, err := utils.ParseToken(string(plaintext))
parseRefreshToken, isUpd, err := utils.ParseToken(refreshToken)
if err != nil {
global.LOG.Errorln(err)
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)

View File

@@ -0,0 +1,4 @@
package websocket_api
type WebsocketAPI struct {
}

View File

@@ -0,0 +1,196 @@
package websocket_api
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"log"
"net/http"
"schisandra-cloud-album/global"
"sync"
"time"
)
var (
// 消息通道
msg = make(map[string]chan interface{})
// websocket客户端链接池
client = make(map[string]*websocket.Conn)
// 互斥锁,防止程序对统一资源同时进行读写
mux sync.Mutex
)
// NewSocketClient 建立websocket长链接接口处理函数
func (WebsocketAPI) NewSocketClient(context *gin.Context) {
id := context.Query("client_id")
global.LOG.Println(id + "websocket链接")
// 升级为websocket长链接
WsHandler(context.Writer, context.Request, id)
}
// DeleteClient api:/deleteClient接口处理函数
func (WebsocketAPI) DeleteClient(context *gin.Context) {
id := context.Query("client_id")
// 关闭websocket链接
conn, exist := getClient(id)
if exist {
err := conn.Close()
if err != nil {
return
}
deleteClient(id)
} else {
context.JSON(http.StatusOK, gin.H{
"mesg": "未找到该客户端",
})
}
// 关闭其消息通道
_, exist = getMsgChannel(id)
if exist {
deletemsgChannel(id)
}
}
// SendMessageData 发送消息接口处理函数
func SendMessageData(clientId string, data interface{}) bool {
m, exist := getMsgChannel(clientId)
if !exist {
log.Println("未找到该客户端的消息通道")
return false
}
// 向消息通道发送消息
select {
case m <- data:
global.LOG.Println("发送消息给客户端:" + clientId)
return true
default:
global.LOG.Println("消息通道已满,消息发送失败")
}
return false
}
// websocket Upgrader
var wsupgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 5 * time.Second,
// 取消ws跨域校验
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// WsHandler 处理ws请求
func WsHandler(w http.ResponseWriter, r *http.Request, id string) {
var conn *websocket.Conn
var err error
var exist bool
// 创建一个定时器用于服务端心跳
//pingTicker := time.NewTicker(time.Second * 10)
conn, err = wsupgrader.Upgrade(w, r, nil)
if err != nil {
global.LOG.Println(err)
return
}
// 把与客户端的链接添加到客户端链接池中
addClient(id, conn)
// 获取该客户端的消息通道
m, exist := getMsgChannel(id)
if !exist {
m = make(chan interface{})
addMsgChannel(id, m)
}
// 设置客户端关闭ws链接回调函数
conn.SetCloseHandler(func(code int, text string) error {
deleteClient(id)
return nil
})
defer conn.Close()
for {
_, _, err := conn.ReadMessage()
if err != nil {
fmt.Println(err)
return
}
select {
case content, _ := <-m:
// 从消息通道接收消息,然后推送给前端
err = conn.WriteJSON(content)
if err != nil {
global.LOG.Error(err)
err := conn.Close()
if err != nil {
return
}
deleteClient(id)
break
}
//case <-pingTicker.C:
// // 服务端心跳:每20秒ping一次客户端查看其是否在线
// err := conn.SetWriteDeadline(time.Now().Add(time.Second * 20))
// if err != nil {
// return
// }
// err = conn.WriteMessage(websocket.PongMessage, []byte("pong"))
// if err != nil {
// log.Println("send pong err:", err)
// err := conn.Close()
// if err != nil {
// return
// }
// deleteClient(id)
// return
// }
}
}
}
// 将客户端添加到客户端链接池
func addClient(id string, conn *websocket.Conn) {
mux.Lock()
client[id] = conn
mux.Unlock()
}
// 获取指定客户端链接
func getClient(id string) (conn *websocket.Conn, exist bool) {
mux.Lock()
conn, exist = client[id]
mux.Unlock()
return
}
// 删除客户端链接
func deleteClient(id string) {
mux.Lock()
delete(client, id)
log.Println(id + "websocket退出")
mux.Unlock()
}
// 添加用户消息通道
func addMsgChannel(id string, m chan interface{}) {
mux.Lock()
msg[id] = m
mux.Unlock()
}
// 获取指定用户消息通道
func getMsgChannel(id string) (m chan interface{}, exist bool) {
mux.Lock()
defer mux.Unlock()
m, exist = msg[id]
return
}
// 删除指定消息通道
func deletemsgChannel(id string) {
mux.Lock()
if m, ok := msg[id]; ok {
close(m)
delete(msg, id)
}
mux.Unlock()
}