✨ use websocket
This commit is contained in:
10
api/api.go
10
api/api.go
@@ -5,14 +5,16 @@ import (
|
||||
"schisandra-cloud-album/api/oauth_api"
|
||||
"schisandra-cloud-album/api/sms_api"
|
||||
"schisandra-cloud-album/api/user_api"
|
||||
"schisandra-cloud-album/api/websocket_api"
|
||||
)
|
||||
|
||||
// Apis 统一导出的api
|
||||
type Apis struct {
|
||||
UserApi user_api.UserAPI
|
||||
CaptchaApi captcha_api.CaptchaAPI
|
||||
SmsApi sms_api.SmsAPI
|
||||
OAuthApi oauth_api.OAuthAPI
|
||||
UserApi user_api.UserAPI
|
||||
CaptchaApi captcha_api.CaptchaAPI
|
||||
SmsApi sms_api.SmsAPI
|
||||
OAuthApi oauth_api.OAuthAPI
|
||||
WebsocketApi websocket_api.WebsocketAPI
|
||||
}
|
||||
|
||||
// Api new函数实例化,实例化完成后会返回结构体地指针类型
|
||||
|
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"gorm.io/gorm"
|
||||
"schisandra-cloud-album/api/user_api/dto"
|
||||
"schisandra-cloud-album/api/websocket_api"
|
||||
"schisandra-cloud-album/common/constant"
|
||||
"schisandra-cloud-album/common/enum"
|
||||
"schisandra-cloud-album/common/redis"
|
||||
@@ -44,6 +45,11 @@ var roleService = service.Service.RoleService
|
||||
// @Router /api/oauth/generate_client_id [get]
|
||||
func (OAuthAPI) GenerateClientId(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
clientId := redis.Get(constant.UserLoginClientRedisKey + ip).Val()
|
||||
if clientId != "" {
|
||||
result.OkWithData(clientId, c)
|
||||
return
|
||||
}
|
||||
v1 := uuid.NewV1()
|
||||
redis.Set(constant.UserLoginClientRedisKey+ip, v1.String(), 0)
|
||||
result.OkWithData(v1.String(), c)
|
||||
@@ -74,9 +80,9 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) {
|
||||
key := strings.TrimPrefix(msg.EventKey, "qrscene_")
|
||||
res := wechatLoginHandler(msg.FromUserName, key)
|
||||
if !res {
|
||||
return messages.NewText("登录失败")
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed"))
|
||||
}
|
||||
return messages.NewText("登录成功")
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginSuccess"))
|
||||
|
||||
case models.CALLBACK_EVENT_UNSUBSCRIBE:
|
||||
msg := models.EventUnSubscribe{}
|
||||
@@ -97,9 +103,9 @@ func (OAuthAPI) CallbackNotify(c *gin.Context) {
|
||||
}
|
||||
res := wechatLoginHandler(msg.FromUserName, msg.EventKey)
|
||||
if !res {
|
||||
return messages.NewText("登录失败")
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginFailed"))
|
||||
}
|
||||
return messages.NewText("登录成功")
|
||||
return messages.NewText(ginI18n.MustGetMessage(c, "LoginSuccess"))
|
||||
|
||||
}
|
||||
|
||||
@@ -147,11 +153,12 @@ func (OAuthAPI) CallbackVerify(c *gin.Context) {
|
||||
// @Router /api/oauth/get_temp_qrcode [get]
|
||||
func (OAuthAPI) GetTempQrCode(c *gin.Context) {
|
||||
clientId := c.Query("client_id")
|
||||
ip := c.ClientIP()
|
||||
if clientId == "" {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
qrcode := redis.Get(constant.UserLoginQrcodeRedisKey + clientId).Val()
|
||||
qrcode := redis.Get(constant.UserLoginQrcodeRedisKey + ip + ":" + clientId).Val()
|
||||
|
||||
if qrcode != "" {
|
||||
data := response.ResponseQRCodeCreate{}
|
||||
@@ -172,7 +179,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
|
||||
return
|
||||
}
|
||||
wrong := redis.Set(constant.UserLoginQrcodeRedisKey+clientId, serializedData, time.Hour*24*30).Err()
|
||||
wrong := redis.Set(constant.UserLoginQrcodeRedisKey+ip+":"+clientId, serializedData, time.Hour*24*30).Err()
|
||||
|
||||
if wrong != nil {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "QRCodeGetFailed"), c)
|
||||
@@ -181,6 +188,7 @@ func (OAuthAPI) GetTempQrCode(c *gin.Context) {
|
||||
result.OK(ginI18n.MustGetMessage(c, "QRCodeGetSuccess"), data.Url, c)
|
||||
}
|
||||
|
||||
// wechatLoginHandler 微信登录处理
|
||||
func wechatLoginHandler(openId string, clientId string) bool {
|
||||
if openId == "" {
|
||||
return false
|
||||
@@ -277,5 +285,9 @@ func handelUserLogin(user model.ScaAuthUser, clientId string) bool {
|
||||
if fail != nil || w != nil {
|
||||
return false
|
||||
}
|
||||
res := websocket_api.SendMessageData(clientId, data)
|
||||
if !res {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
ginI18n "github.com/gin-contrib/i18n"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/wumansgy/goEncrypt/aes"
|
||||
"github.com/yitter/idgenerator-go/idgen"
|
||||
"reflect"
|
||||
"schisandra-cloud-album/api/user_api/dto"
|
||||
@@ -328,13 +327,7 @@ func (UserAPI) RefreshHandler(c *gin.Context) {
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "ParamsError"), c)
|
||||
return
|
||||
}
|
||||
plaintext, err := aes.AesCtrDecryptByHex(refreshToken, []byte(global.CONFIG.Encrypt.Key), []byte(global.CONFIG.Encrypt.IV))
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
|
||||
return
|
||||
}
|
||||
parseRefreshToken, isUpd, err := utils.ParseToken(string(plaintext))
|
||||
parseRefreshToken, isUpd, err := utils.ParseToken(refreshToken)
|
||||
if err != nil {
|
||||
global.LOG.Errorln(err)
|
||||
result.FailWithMessage(ginI18n.MustGetMessage(c, "LoginExpired"), c)
|
||||
|
4
api/websocket_api/websocket.go
Normal file
4
api/websocket_api/websocket.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package websocket_api
|
||||
|
||||
type WebsocketAPI struct {
|
||||
}
|
196
api/websocket_api/websocket_api.go
Normal file
196
api/websocket_api/websocket_api.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package websocket_api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"log"
|
||||
"net/http"
|
||||
"schisandra-cloud-album/global"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// 消息通道
|
||||
msg = make(map[string]chan interface{})
|
||||
// websocket客户端链接池
|
||||
client = make(map[string]*websocket.Conn)
|
||||
// 互斥锁,防止程序对统一资源同时进行读写
|
||||
mux sync.Mutex
|
||||
)
|
||||
|
||||
// NewSocketClient 建立websocket长链接接口处理函数
|
||||
func (WebsocketAPI) NewSocketClient(context *gin.Context) {
|
||||
id := context.Query("client_id")
|
||||
global.LOG.Println(id + "websocket链接")
|
||||
// 升级为websocket长链接
|
||||
WsHandler(context.Writer, context.Request, id)
|
||||
}
|
||||
|
||||
// DeleteClient api:/deleteClient接口处理函数
|
||||
func (WebsocketAPI) DeleteClient(context *gin.Context) {
|
||||
id := context.Query("client_id")
|
||||
// 关闭websocket链接
|
||||
conn, exist := getClient(id)
|
||||
if exist {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
deleteClient(id)
|
||||
} else {
|
||||
context.JSON(http.StatusOK, gin.H{
|
||||
"mesg": "未找到该客户端",
|
||||
})
|
||||
}
|
||||
// 关闭其消息通道
|
||||
_, exist = getMsgChannel(id)
|
||||
if exist {
|
||||
deletemsgChannel(id)
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessageData 发送消息接口处理函数
|
||||
func SendMessageData(clientId string, data interface{}) bool {
|
||||
m, exist := getMsgChannel(clientId)
|
||||
if !exist {
|
||||
log.Println("未找到该客户端的消息通道")
|
||||
return false
|
||||
}
|
||||
// 向消息通道发送消息
|
||||
select {
|
||||
case m <- data:
|
||||
global.LOG.Println("发送消息给客户端:" + clientId)
|
||||
return true
|
||||
default:
|
||||
global.LOG.Println("消息通道已满,消息发送失败")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// websocket Upgrader
|
||||
var wsupgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
// 取消ws跨域校验
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// WsHandler 处理ws请求
|
||||
func WsHandler(w http.ResponseWriter, r *http.Request, id string) {
|
||||
var conn *websocket.Conn
|
||||
var err error
|
||||
var exist bool
|
||||
// 创建一个定时器用于服务端心跳
|
||||
//pingTicker := time.NewTicker(time.Second * 10)
|
||||
conn, err = wsupgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
global.LOG.Println(err)
|
||||
return
|
||||
}
|
||||
// 把与客户端的链接添加到客户端链接池中
|
||||
addClient(id, conn)
|
||||
|
||||
// 获取该客户端的消息通道
|
||||
m, exist := getMsgChannel(id)
|
||||
if !exist {
|
||||
m = make(chan interface{})
|
||||
addMsgChannel(id, m)
|
||||
}
|
||||
// 设置客户端关闭ws链接回调函数
|
||||
conn.SetCloseHandler(func(code int, text string) error {
|
||||
deleteClient(id)
|
||||
return nil
|
||||
})
|
||||
defer conn.Close()
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case content, _ := <-m:
|
||||
// 从消息通道接收消息,然后推送给前端
|
||||
err = conn.WriteJSON(content)
|
||||
if err != nil {
|
||||
global.LOG.Error(err)
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
deleteClient(id)
|
||||
break
|
||||
}
|
||||
//case <-pingTicker.C:
|
||||
// // 服务端心跳:每20秒ping一次客户端,查看其是否在线
|
||||
// err := conn.SetWriteDeadline(time.Now().Add(time.Second * 20))
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// err = conn.WriteMessage(websocket.PongMessage, []byte("pong"))
|
||||
// if err != nil {
|
||||
// log.Println("send pong err:", err)
|
||||
// err := conn.Close()
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// deleteClient(id)
|
||||
// return
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 将客户端添加到客户端链接池
|
||||
func addClient(id string, conn *websocket.Conn) {
|
||||
mux.Lock()
|
||||
client[id] = conn
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// 获取指定客户端链接
|
||||
func getClient(id string) (conn *websocket.Conn, exist bool) {
|
||||
mux.Lock()
|
||||
conn, exist = client[id]
|
||||
mux.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// 删除客户端链接
|
||||
func deleteClient(id string) {
|
||||
mux.Lock()
|
||||
delete(client, id)
|
||||
log.Println(id + "websocket退出")
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// 添加用户消息通道
|
||||
func addMsgChannel(id string, m chan interface{}) {
|
||||
mux.Lock()
|
||||
msg[id] = m
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// 获取指定用户消息通道
|
||||
func getMsgChannel(id string) (m chan interface{}, exist bool) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
m, exist = msg[id]
|
||||
return
|
||||
}
|
||||
|
||||
// 删除指定消息通道
|
||||
func deletemsgChannel(id string) {
|
||||
mux.Lock()
|
||||
if m, ok := msg[id]; ok {
|
||||
close(m)
|
||||
delete(msg, id)
|
||||
}
|
||||
mux.Unlock()
|
||||
}
|
Reference in New Issue
Block a user