Files
schisandra-album-cloud-micr…/app/aisvc/rpc/internal/logic/aiservice/face_recognition_logic.go
2025-01-22 10:36:28 +08:00

356 lines
9.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package aiservicelogic
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/Kagami/go-face"
"github.com/ccpwcn/kgo"
"github.com/zeromicro/go-zero/core/logx"
"image"
"image/jpeg"
"os"
"path/filepath"
"schisandra-album-cloud-microservices/app/aisvc/model/mysql/model"
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/svc"
"schisandra-album-cloud-microservices/app/aisvc/rpc/pb"
"schisandra-album-cloud-microservices/common/constant"
"strconv"
"sync"
"time"
)
type FaceRecognitionLogic struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
directoryCache sync.Map
wg sync.WaitGroup
mu sync.Mutex
}
func NewFaceRecognitionLogic(ctx context.Context, svcCtx *svc.ServiceContext) *FaceRecognitionLogic {
return &FaceRecognitionLogic{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
directoryCache: sync.Map{},
wg: sync.WaitGroup{},
mu: sync.Mutex{},
}
}
// FaceRecognition 人脸识别
func (l *FaceRecognitionLogic) FaceRecognition(in *pb.FaceRecognitionRequest) (*pb.FaceRecognitionResponse, error) {
// 提取人脸特征
faceFeatures, err := l.svcCtx.FaceRecognizer.RecognizeSingle(in.GetFace())
if err != nil {
return nil, err
}
if faceFeatures == nil {
return nil, nil
}
hashKey := fmt.Sprintf("user:%s:faces", in.GetUserId())
// 从 Redis 加载人脸数据
samples, ids, err := l.loadFacesFromRedisHash(hashKey)
if err != nil {
return nil, fmt.Errorf("failed to query Redis: %v", err)
}
// 如果缓存中没有数据,则查询数据库
if len(samples) == 0 {
samples, ids, err = l.loadExistingFaces(in.GetUserId())
if err != nil {
return nil, err
}
// 如果数据库也没有数据,直接保存当前人脸
if len(samples) == 0 || len(ids) == 0 {
return l.saveNewFace(in, faceFeatures, hashKey)
}
// 将数据写入 Redis
err = l.cacheFacesToRedisHash(hashKey, samples, ids)
if err != nil {
return nil, fmt.Errorf("failed to cache faces to Redis: %v", err)
}
}
// 设置人脸特征
l.svcCtx.FaceRecognizer.SetSamples(samples, ids)
// 人脸分类
classify := l.svcCtx.FaceRecognizer.ClassifyThreshold(faceFeatures.Descriptor, 0.6)
if classify >= 0 {
return &pb.FaceRecognitionResponse{
FaceId: int64(ids[classify]),
}, nil
}
// 如果未找到匹配的人脸,则保存为新样本
return l.saveNewFace(in, faceFeatures, hashKey)
}
// 保存新的人脸样本到数据库和 Redis
func (l *FaceRecognitionLogic) saveNewFace(in *pb.FaceRecognitionRequest, faceFeatures *face.Face, hashKey string) (*pb.FaceRecognitionResponse, error) {
// 人脸有效性判断 (大小必须大于50)
if !l.isFaceValid(faceFeatures.Rectangle) {
return nil, nil
}
// 保存人脸图片到本地
faceImagePath, err := l.saveCroppedFaceToLocal(in.GetFace(), faceFeatures.Rectangle, "face_samples", in.GetUserId())
if err != nil {
return nil, err
}
// 保存到数据库
storageFace, err := l.saveFaceToDatabase(in.GetUserId(), faceFeatures.Descriptor, faceImagePath)
if err != nil {
return nil, err
}
// 将新增数据写入 Redis
err = l.appendFaceToRedisHash(hashKey, storageFace.ID, faceFeatures.Descriptor)
if err != nil {
return nil, fmt.Errorf("failed to append face to Redis: %v", err)
}
return &pb.FaceRecognitionResponse{
FaceId: storageFace.ID,
}, nil
}
// 加载数据库中的已有人脸
func (l *FaceRecognitionLogic) loadExistingFaces(userId string) ([]face.Descriptor, []int32, error) {
if userId == "" {
return nil, nil, fmt.Errorf("user ID is required")
}
storageFace := l.svcCtx.DB.ScaStorageFace
existingFaces, err := storageFace.
Select(storageFace.FaceVector, storageFace.ID).
Where(storageFace.UserID.Eq(userId), storageFace.FaceType.Eq(constant.FaceTypeSample)).
Find()
if err != nil {
return nil, nil, err
}
if len(existingFaces) == 0 {
return nil, nil, nil
}
var samples []face.Descriptor
var ids []int32
// 使用并发处理每个数据
for _, existingFace := range existingFaces {
l.wg.Add(1)
go func(faceData *model.ScaStorageFace) {
defer l.wg.Done()
var descriptor face.Descriptor
if err = json.Unmarshal([]byte(faceData.FaceVector), &descriptor); err != nil {
l.Errorf("failed to unmarshal face vector: %v", err)
return
}
// 使用锁来保证并发访问时对切片的安全操作
l.mu.Lock()
samples = append(samples, descriptor)
ids = append(ids, int32(faceData.ID))
l.mu.Unlock()
}(existingFace)
}
l.wg.Wait()
return samples, ids, nil
}
const (
minFaceWidth = 50 // 最小允许的人脸宽度
minFaceHeight = 50 // 最小允许的人脸高度
)
// 判断人脸是否有效
func (l *FaceRecognitionLogic) isFaceValid(rect image.Rectangle) bool {
width := rect.Dx()
height := rect.Dy()
return width >= minFaceWidth && height >= minFaceHeight
}
// 保存人脸特征和路径到数据库
func (l *FaceRecognitionLogic) saveFaceToDatabase(userId string, descriptor face.Descriptor, faceImagePath string) (*model.ScaStorageFace, error) {
jsonBytes, err := json.Marshal(descriptor)
if err != nil {
return nil, err
}
storageFace := model.ScaStorageFace{
FaceVector: string(jsonBytes),
FaceImagePath: faceImagePath,
FaceType: constant.FaceTypeSample,
UserID: userId,
}
err = l.svcCtx.DB.ScaStorageFace.Create(&storageFace)
if err != nil {
return nil, err
}
return &storageFace, nil
}
func (l *FaceRecognitionLogic) saveCroppedFaceToLocal(faceImage []byte, rect image.Rectangle, baseSavePath string, userID string) (string, error) {
// 动态生成用户目录和时间分级目录
subDir := filepath.Join(baseSavePath, userID, time.Now().Format("2006/01")) // 格式:<baseSavePath>/<userID>/YYYY/MM
// 缓存目录检查,避免重复调用 os.MkdirAll
if !l.isDirectoryCached(subDir) {
if err := os.MkdirAll(subDir, os.ModePerm); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}
l.cacheDirectory(subDir) // 缓存已创建的目录路径
}
// 解码图像
img, _, err := image.Decode(bytes.NewReader(faceImage))
if err != nil {
return "", fmt.Errorf("image decode failed: %w", err)
}
// 获取图像边界
imgBounds := img.Bounds()
// 增加边距(比如 20 像素)
margin := 20
extendedRect := image.Rect(
max(rect.Min.X-margin, imgBounds.Min.X), // 确保不超出左边界
max(rect.Min.Y-margin, imgBounds.Min.Y), // 确保不超出上边界
min(rect.Max.X+margin, imgBounds.Max.X), // 确保不超出右边界
min(rect.Max.Y+margin, imgBounds.Max.Y), // 确保不超出下边界
)
// 裁剪图像
croppedImage := img.(interface {
SubImage(r image.Rectangle) image.Image
}).SubImage(extendedRect)
// 生成唯一文件名(时间戳 + UUID
fileName := fmt.Sprintf("%s_%s.jpg", time.Now().Format("20060102_150405"), kgo.SimpleUuid())
outputPath := filepath.Join(subDir, fileName)
// 写入文件
if err = l.writeImageToFile(outputPath, croppedImage); err != nil {
return "", err
}
return outputPath, nil
}
// 判断目录是否已缓存
func (l *FaceRecognitionLogic) isDirectoryCached(dir string) bool {
_, exists := l.directoryCache.Load(dir)
return exists
}
// 缓存目录
func (l *FaceRecognitionLogic) cacheDirectory(dir string) {
l.directoryCache.Store(dir, struct{}{})
}
// 将图像写入文件
func (l *FaceRecognitionLogic) writeImageToFile(path string, img image.Image) error {
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer func(file *os.File) {
_ = file.Close()
}(file)
if err = jpeg.Encode(file, img, nil); err != nil {
return fmt.Errorf("failed to encode and save image: %w", err)
}
return nil
}
// 从 Redis 的 Hash 中加载人脸数据
func (l *FaceRecognitionLogic) loadFacesFromRedisHash(hashKey string) ([]face.Descriptor, []int32, error) {
// 从 Redis 获取 Hash 的所有字段和值
data, err := l.svcCtx.RedisClient.HGetAll(l.ctx, hashKey).Result()
if err != nil {
return nil, nil, err
}
var samples []face.Descriptor
var ids []int32
for idStr, descriptorStr := range data {
var descriptor face.Descriptor
if err = json.Unmarshal([]byte(descriptorStr), &descriptor); err != nil {
return nil, nil, err
}
// 转换 ID 为 int32
id, err := parseInt32(idStr)
if err != nil {
return nil, nil, err
}
samples = append(samples, descriptor)
ids = append(ids, id)
}
return samples, ids, nil
}
// 将人脸数据写入 Redis 的 Hash
func (l *FaceRecognitionLogic) cacheFacesToRedisHash(hashKey string, samples []face.Descriptor, ids []int32) error {
// 开启事务
pipe := l.svcCtx.RedisClient.Pipeline()
for i := range samples {
descriptorData, err := json.Marshal(samples[i])
if err != nil {
return err
}
// 使用 HSET 设置 Hash 字段和值
pipe.HSet(l.ctx, hashKey, fmt.Sprintf("%d", ids[i]), descriptorData)
}
// 设置缓存过期时间
pipe.Expire(l.ctx, hashKey, 3600*time.Second)
_, err := pipe.Exec(l.ctx)
return err
}
// 将新增的人脸数据追加到 Redis 的 Hash
func (l *FaceRecognitionLogic) appendFaceToRedisHash(hashKey string, id int64, descriptor face.Descriptor) error {
descriptorData, err := json.Marshal(descriptor)
if err != nil {
return err
}
// 追加数据到 Hash
err = l.svcCtx.RedisClient.HSet(l.ctx, hashKey, fmt.Sprintf("%d", id), descriptorData).Err()
if err != nil {
return err
}
// 检查是否已设置过期时间
ttl, err := l.svcCtx.RedisClient.TTL(l.ctx, hashKey).Result()
if err != nil {
return err
}
// 如果未设置过期时间或已经过期,设置固定过期时间
if ttl < 0 {
err = l.svcCtx.RedisClient.Expire(l.ctx, hashKey, 3600*time.Second).Err()
if err != nil {
return err
}
}
return nil
}
// 辅助函数:字符串转换为 int32
func parseInt32(s string) (int32, error) {
var i int64
var err error
if i, err = strconv.ParseInt(s, 10, 32); err != nil {
return 0, err
}
return int32(i), nil
}