377 lines
10 KiB
Go
377 lines
10 KiB
Go
package aiservicelogic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/Kagami/go-face"
|
|
"github.com/ccpwcn/kgo"
|
|
"github.com/minio/minio-go/v7"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
"image"
|
|
"image/jpeg"
|
|
_ "image/png"
|
|
"path"
|
|
"schisandra-album-cloud-microservices/app/aisvc/model/mysql/model"
|
|
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/svc"
|
|
"schisandra-album-cloud-microservices/app/aisvc/rpc/pb"
|
|
"schisandra-album-cloud-microservices/common/constant"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type FaceRecognitionLogic struct {
|
|
ctx context.Context
|
|
svcCtx *svc.ServiceContext
|
|
logx.Logger
|
|
directoryCache sync.Map
|
|
wg sync.WaitGroup
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func NewFaceRecognitionLogic(ctx context.Context, svcCtx *svc.ServiceContext) *FaceRecognitionLogic {
|
|
return &FaceRecognitionLogic{
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
Logger: logx.WithContext(ctx),
|
|
directoryCache: sync.Map{},
|
|
wg: sync.WaitGroup{},
|
|
mu: sync.Mutex{},
|
|
}
|
|
}
|
|
|
|
// FaceRecognition 人脸识别
|
|
func (l *FaceRecognitionLogic) FaceRecognition(in *pb.FaceRecognitionRequest) (*pb.FaceRecognitionResponse, error) {
|
|
toJPEG, err := l.ConvertImageToJPEG(in.GetFace())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if toJPEG == nil {
|
|
return nil, nil
|
|
}
|
|
// 提取人脸特征
|
|
faceFeatures, err := l.svcCtx.FaceRecognizer.RecognizeSingle(toJPEG)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if faceFeatures == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
hashKey := constant.FaceVectorPrefix + in.GetUserId()
|
|
// 从 Redis 加载人脸数据
|
|
samples, ids, err := l.loadFacesFromRedisHash(hashKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query Redis: %v", err)
|
|
}
|
|
// 如果缓存中没有数据,则查询数据库
|
|
if len(samples) == 0 {
|
|
samples, ids, err = l.loadExistingFaces(in.GetUserId())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 如果数据库也没有数据,直接保存当前人脸
|
|
if len(samples) == 0 || len(ids) == 0 {
|
|
return l.saveNewFace(in, faceFeatures, hashKey)
|
|
}
|
|
|
|
// 将数据写入 Redis
|
|
err = l.cacheFacesToRedisHash(hashKey, samples, ids)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to cache faces to Redis: %v", err)
|
|
}
|
|
}
|
|
|
|
// 设置人脸特征
|
|
l.svcCtx.FaceRecognizer.SetSamples(samples, ids)
|
|
|
|
// 人脸分类
|
|
classify := l.svcCtx.FaceRecognizer.ClassifyThreshold(faceFeatures.Descriptor, 0.3)
|
|
if classify > 0 {
|
|
return &pb.FaceRecognitionResponse{
|
|
FaceId: int64(classify),
|
|
}, nil
|
|
}
|
|
|
|
// 如果未找到匹配的人脸,则保存为新样本
|
|
return l.saveNewFace(in, faceFeatures, hashKey)
|
|
}
|
|
|
|
func (l *FaceRecognitionLogic) ConvertImageToJPEG(imageData []byte) ([]byte, error) {
|
|
|
|
// 解码图片
|
|
img, format, err := image.Decode(bytes.NewReader(imageData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode image: %v", err)
|
|
}
|
|
|
|
// 如果已经是 JPEG 格式,则直接返回原数据
|
|
if format == "jpeg" {
|
|
return imageData, nil
|
|
}
|
|
|
|
// 如果是 PNG 格式,则转换为 JPEG
|
|
var jpegBuffer bytes.Buffer
|
|
err = jpeg.Encode(&jpegBuffer, img, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode image to JPEG: %v", err)
|
|
}
|
|
|
|
return jpegBuffer.Bytes(), nil
|
|
}
|
|
|
|
// 保存新的人脸样本到数据库和 Redis
|
|
func (l *FaceRecognitionLogic) saveNewFace(in *pb.FaceRecognitionRequest, faceFeatures *face.Face, hashKey string) (*pb.FaceRecognitionResponse, error) {
|
|
// 人脸有效性判断 (大小必须大于50)
|
|
if !l.isFaceValid(faceFeatures.Rectangle) {
|
|
return nil, nil
|
|
}
|
|
|
|
// 保存人脸图片到本地
|
|
faceImagePath, err := l.saveCroppedFaceToLocal(in.GetFace(), faceFeatures.Rectangle, in.GetUserId())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 保存到数据库
|
|
storageFace, err := l.saveFaceToDatabase(in.GetUserId(), faceFeatures.Descriptor, faceImagePath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 将新增数据写入 Redis
|
|
err = l.appendFaceToRedisHash(hashKey, storageFace.ID, faceFeatures.Descriptor)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to append face to Redis: %v", err)
|
|
}
|
|
|
|
return &pb.FaceRecognitionResponse{
|
|
FaceId: storageFace.ID,
|
|
}, nil
|
|
}
|
|
|
|
// 加载数据库中的已有人脸
|
|
func (l *FaceRecognitionLogic) loadExistingFaces(userId string) ([]face.Descriptor, []int32, error) {
|
|
if userId == "" {
|
|
return nil, nil, fmt.Errorf("user ID is required")
|
|
}
|
|
storageFace := l.svcCtx.DB.ScaStorageFace
|
|
existingFaces, err := storageFace.
|
|
Select(storageFace.FaceVector, storageFace.ID).
|
|
Where(storageFace.UserID.Eq(userId)).
|
|
Find()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if len(existingFaces) == 0 {
|
|
return nil, nil, nil
|
|
}
|
|
|
|
var samples []face.Descriptor
|
|
var ids []int32
|
|
// 使用并发处理每个数据
|
|
for _, existingFace := range existingFaces {
|
|
l.wg.Add(1)
|
|
go func(faceData *model.ScaStorageFace) {
|
|
defer l.wg.Done()
|
|
|
|
var descriptor face.Descriptor
|
|
if err = json.Unmarshal([]byte(faceData.FaceVector), &descriptor); err != nil {
|
|
l.Errorf("failed to unmarshal face vector: %v", err)
|
|
return
|
|
}
|
|
// 使用锁来保证并发访问时对切片的安全操作
|
|
l.mu.Lock()
|
|
samples = append(samples, descriptor)
|
|
ids = append(ids, int32(faceData.ID))
|
|
l.mu.Unlock()
|
|
}(existingFace)
|
|
}
|
|
l.wg.Wait()
|
|
return samples, ids, nil
|
|
}
|
|
|
|
// Minimum accepted face size in pixels, checked by isFaceValid.
const (
	minFaceWidth  = 50 // minimum face width
	minFaceHeight = 50 // minimum face height
)

// Display states stored in a face record's FaceShow field.
const (
	FaceShow = 0 // face is shown
	FaceHide = 1 // face is hidden
)
|
|
|
|
// 判断人脸是否有效
|
|
func (l *FaceRecognitionLogic) isFaceValid(rect image.Rectangle) bool {
|
|
width := rect.Dx()
|
|
height := rect.Dy()
|
|
return width >= minFaceWidth && height >= minFaceHeight
|
|
}
|
|
|
|
// 保存人脸特征和路径到数据库
|
|
func (l *FaceRecognitionLogic) saveFaceToDatabase(userId string, descriptor face.Descriptor, faceImagePath string) (*model.ScaStorageFace, error) {
|
|
jsonBytes, err := json.Marshal(descriptor)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
storageFace := model.ScaStorageFace{
|
|
FaceVector: string(jsonBytes),
|
|
FaceImagePath: faceImagePath,
|
|
UserID: userId,
|
|
FaceShow: FaceShow,
|
|
}
|
|
err = l.svcCtx.DB.ScaStorageFace.Create(&storageFace)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &storageFace, nil
|
|
}
|
|
|
|
func (l *FaceRecognitionLogic) saveCroppedFaceToLocal(faceImage []byte, rect image.Rectangle, userID string) (string, error) {
|
|
objectKey := path.Join(
|
|
userID,
|
|
time.Now().Format("2006/01"), // 按年/月划分目录
|
|
fmt.Sprintf("%s_%s.jpg", time.Now().Format("20060102150405"), kgo.SimpleUuid()),
|
|
)
|
|
|
|
// 解码图像
|
|
img, _, err := image.Decode(bytes.NewReader(faceImage))
|
|
if err != nil {
|
|
return "", fmt.Errorf("image decode failed: %w", err)
|
|
}
|
|
|
|
// 获取图像边界
|
|
imgBounds := img.Bounds()
|
|
// 增加边距(比如 20 像素)
|
|
margin := 20
|
|
extendedRect := image.Rect(
|
|
max(rect.Min.X-margin, imgBounds.Min.X), // 确保不超出左边界
|
|
max(rect.Min.Y-margin, imgBounds.Min.Y), // 确保不超出上边界
|
|
min(rect.Max.X+margin, imgBounds.Max.X), // 确保不超出右边界
|
|
min(rect.Max.Y+margin, imgBounds.Max.Y), // 确保不超出下边界
|
|
)
|
|
// 裁剪图像
|
|
croppedImage := img.(interface {
|
|
SubImage(r image.Rectangle) image.Image
|
|
}).SubImage(extendedRect)
|
|
|
|
// 将图像编码为JPEG字节流
|
|
var buf bytes.Buffer
|
|
if err = jpeg.Encode(&buf, croppedImage, nil); err != nil {
|
|
return "", fmt.Errorf("failed to encode image to JPEG: %w", err)
|
|
}
|
|
exists, err := l.svcCtx.MinioClient.BucketExists(l.ctx, constant.FaceBucketName)
|
|
if err != nil || !exists {
|
|
err = l.svcCtx.MinioClient.MakeBucket(l.ctx, constant.FaceBucketName, minio.MakeBucketOptions{Region: "us-east-1", ObjectLocking: true})
|
|
if err != nil {
|
|
logx.Errorf("Failed to create MinIO bucket: %v", err)
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
// 上传到MinIO
|
|
_, err = l.svcCtx.MinioClient.PutObject(
|
|
l.ctx,
|
|
constant.FaceBucketName,
|
|
objectKey,
|
|
bytes.NewReader(buf.Bytes()),
|
|
int64(buf.Len()),
|
|
minio.PutObjectOptions{
|
|
ContentType: "image/jpeg",
|
|
},
|
|
)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to upload image to MinIO: %w", err)
|
|
}
|
|
return objectKey, nil
|
|
}
|
|
|
|
// 从 Redis 的 Hash 中加载人脸数据
|
|
func (l *FaceRecognitionLogic) loadFacesFromRedisHash(hashKey string) ([]face.Descriptor, []int32, error) {
|
|
// 从 Redis 获取 Hash 的所有字段和值
|
|
data, err := l.svcCtx.RedisClient.HGetAll(l.ctx, hashKey).Result()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
var samples []face.Descriptor
|
|
var ids []int32
|
|
for idStr, descriptorStr := range data {
|
|
var descriptor face.Descriptor
|
|
if err = json.Unmarshal([]byte(descriptorStr), &descriptor); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 转换 ID 为 int32
|
|
id, err := parseInt32(idStr)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
samples = append(samples, descriptor)
|
|
ids = append(ids, id)
|
|
}
|
|
return samples, ids, nil
|
|
}
|
|
|
|
// 将人脸数据写入 Redis 的 Hash
|
|
func (l *FaceRecognitionLogic) cacheFacesToRedisHash(hashKey string, samples []face.Descriptor, ids []int32) error {
|
|
// 开启事务
|
|
pipe := l.svcCtx.RedisClient.Pipeline()
|
|
|
|
for i := range samples {
|
|
descriptorData, err := json.Marshal(samples[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 使用 HSET 设置 Hash 字段和值
|
|
pipe.HSet(l.ctx, hashKey, fmt.Sprintf("%d", ids[i]), descriptorData)
|
|
}
|
|
|
|
// 设置缓存过期时间
|
|
pipe.Expire(l.ctx, hashKey, 3600*time.Second)
|
|
|
|
_, err := pipe.Exec(l.ctx)
|
|
return err
|
|
}
|
|
|
|
// 将新增的人脸数据追加到 Redis 的 Hash
|
|
func (l *FaceRecognitionLogic) appendFaceToRedisHash(hashKey string, id int64, descriptor face.Descriptor) error {
|
|
descriptorData, err := json.Marshal(descriptor)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 追加数据到 Hash
|
|
err = l.svcCtx.RedisClient.HSet(l.ctx, hashKey, fmt.Sprintf("%d", id), descriptorData).Err()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// 检查是否已设置过期时间
|
|
ttl, err := l.svcCtx.RedisClient.TTL(l.ctx, hashKey).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 如果未设置过期时间或已经过期,设置固定过期时间
|
|
if ttl < 0 {
|
|
err = l.svcCtx.RedisClient.Expire(l.ctx, hashKey, 3600*time.Second).Err()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// parseInt32 parses s as a base-10 integer that must fit in an int32. On any
// parse or range error it returns 0 and the error from strconv.ParseInt.
func parseInt32(s string) (int32, error) {
	v, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return 0, err
	}
	return int32(v), nil
}
|