add face recognition

This commit is contained in:
2025-01-22 10:36:28 +08:00
parent eab806fb9b
commit c6af9a0461
47 changed files with 3621 additions and 454 deletions

View File

@@ -0,0 +1,355 @@
package aiservicelogic
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/Kagami/go-face"
"github.com/ccpwcn/kgo"
"github.com/zeromicro/go-zero/core/logx"
"image"
"image/jpeg"
"os"
"path/filepath"
"schisandra-album-cloud-microservices/app/aisvc/model/mysql/model"
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/svc"
"schisandra-album-cloud-microservices/app/aisvc/rpc/pb"
"schisandra-album-cloud-microservices/common/constant"
"strconv"
"sync"
"time"
)
type FaceRecognitionLogic struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
directoryCache sync.Map
wg sync.WaitGroup
mu sync.Mutex
}
func NewFaceRecognitionLogic(ctx context.Context, svcCtx *svc.ServiceContext) *FaceRecognitionLogic {
return &FaceRecognitionLogic{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
directoryCache: sync.Map{},
wg: sync.WaitGroup{},
mu: sync.Mutex{},
}
}
// FaceRecognition 人脸识别
func (l *FaceRecognitionLogic) FaceRecognition(in *pb.FaceRecognitionRequest) (*pb.FaceRecognitionResponse, error) {
// 提取人脸特征
faceFeatures, err := l.svcCtx.FaceRecognizer.RecognizeSingle(in.GetFace())
if err != nil {
return nil, err
}
if faceFeatures == nil {
return nil, nil
}
hashKey := fmt.Sprintf("user:%s:faces", in.GetUserId())
// 从 Redis 加载人脸数据
samples, ids, err := l.loadFacesFromRedisHash(hashKey)
if err != nil {
return nil, fmt.Errorf("failed to query Redis: %v", err)
}
// 如果缓存中没有数据,则查询数据库
if len(samples) == 0 {
samples, ids, err = l.loadExistingFaces(in.GetUserId())
if err != nil {
return nil, err
}
// 如果数据库也没有数据,直接保存当前人脸
if len(samples) == 0 || len(ids) == 0 {
return l.saveNewFace(in, faceFeatures, hashKey)
}
// 将数据写入 Redis
err = l.cacheFacesToRedisHash(hashKey, samples, ids)
if err != nil {
return nil, fmt.Errorf("failed to cache faces to Redis: %v", err)
}
}
// 设置人脸特征
l.svcCtx.FaceRecognizer.SetSamples(samples, ids)
// 人脸分类
classify := l.svcCtx.FaceRecognizer.ClassifyThreshold(faceFeatures.Descriptor, 0.6)
if classify >= 0 {
return &pb.FaceRecognitionResponse{
FaceId: int64(ids[classify]),
}, nil
}
// 如果未找到匹配的人脸,则保存为新样本
return l.saveNewFace(in, faceFeatures, hashKey)
}
// 保存新的人脸样本到数据库和 Redis
func (l *FaceRecognitionLogic) saveNewFace(in *pb.FaceRecognitionRequest, faceFeatures *face.Face, hashKey string) (*pb.FaceRecognitionResponse, error) {
// 人脸有效性判断 (大小必须大于50)
if !l.isFaceValid(faceFeatures.Rectangle) {
return nil, nil
}
// 保存人脸图片到本地
faceImagePath, err := l.saveCroppedFaceToLocal(in.GetFace(), faceFeatures.Rectangle, "face_samples", in.GetUserId())
if err != nil {
return nil, err
}
// 保存到数据库
storageFace, err := l.saveFaceToDatabase(in.GetUserId(), faceFeatures.Descriptor, faceImagePath)
if err != nil {
return nil, err
}
// 将新增数据写入 Redis
err = l.appendFaceToRedisHash(hashKey, storageFace.ID, faceFeatures.Descriptor)
if err != nil {
return nil, fmt.Errorf("failed to append face to Redis: %v", err)
}
return &pb.FaceRecognitionResponse{
FaceId: storageFace.ID,
}, nil
}
// 加载数据库中的已有人脸
func (l *FaceRecognitionLogic) loadExistingFaces(userId string) ([]face.Descriptor, []int32, error) {
if userId == "" {
return nil, nil, fmt.Errorf("user ID is required")
}
storageFace := l.svcCtx.DB.ScaStorageFace
existingFaces, err := storageFace.
Select(storageFace.FaceVector, storageFace.ID).
Where(storageFace.UserID.Eq(userId), storageFace.FaceType.Eq(constant.FaceTypeSample)).
Find()
if err != nil {
return nil, nil, err
}
if len(existingFaces) == 0 {
return nil, nil, nil
}
var samples []face.Descriptor
var ids []int32
// 使用并发处理每个数据
for _, existingFace := range existingFaces {
l.wg.Add(1)
go func(faceData *model.ScaStorageFace) {
defer l.wg.Done()
var descriptor face.Descriptor
if err = json.Unmarshal([]byte(faceData.FaceVector), &descriptor); err != nil {
l.Errorf("failed to unmarshal face vector: %v", err)
return
}
// 使用锁来保证并发访问时对切片的安全操作
l.mu.Lock()
samples = append(samples, descriptor)
ids = append(ids, int32(faceData.ID))
l.mu.Unlock()
}(existingFace)
}
l.wg.Wait()
return samples, ids, nil
}
const (
minFaceWidth = 50 // 最小允许的人脸宽度
minFaceHeight = 50 // 最小允许的人脸高度
)
// 判断人脸是否有效
func (l *FaceRecognitionLogic) isFaceValid(rect image.Rectangle) bool {
width := rect.Dx()
height := rect.Dy()
return width >= minFaceWidth && height >= minFaceHeight
}
// 保存人脸特征和路径到数据库
func (l *FaceRecognitionLogic) saveFaceToDatabase(userId string, descriptor face.Descriptor, faceImagePath string) (*model.ScaStorageFace, error) {
jsonBytes, err := json.Marshal(descriptor)
if err != nil {
return nil, err
}
storageFace := model.ScaStorageFace{
FaceVector: string(jsonBytes),
FaceImagePath: faceImagePath,
FaceType: constant.FaceTypeSample,
UserID: userId,
}
err = l.svcCtx.DB.ScaStorageFace.Create(&storageFace)
if err != nil {
return nil, err
}
return &storageFace, nil
}
func (l *FaceRecognitionLogic) saveCroppedFaceToLocal(faceImage []byte, rect image.Rectangle, baseSavePath string, userID string) (string, error) {
// 动态生成用户目录和时间分级目录
subDir := filepath.Join(baseSavePath, userID, time.Now().Format("2006/01")) // 格式:<baseSavePath>/<userID>/YYYY/MM
// 缓存目录检查,避免重复调用 os.MkdirAll
if !l.isDirectoryCached(subDir) {
if err := os.MkdirAll(subDir, os.ModePerm); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}
l.cacheDirectory(subDir) // 缓存已创建的目录路径
}
// 解码图像
img, _, err := image.Decode(bytes.NewReader(faceImage))
if err != nil {
return "", fmt.Errorf("image decode failed: %w", err)
}
// 获取图像边界
imgBounds := img.Bounds()
// 增加边距(比如 20 像素)
margin := 20
extendedRect := image.Rect(
max(rect.Min.X-margin, imgBounds.Min.X), // 确保不超出左边界
max(rect.Min.Y-margin, imgBounds.Min.Y), // 确保不超出上边界
min(rect.Max.X+margin, imgBounds.Max.X), // 确保不超出右边界
min(rect.Max.Y+margin, imgBounds.Max.Y), // 确保不超出下边界
)
// 裁剪图像
croppedImage := img.(interface {
SubImage(r image.Rectangle) image.Image
}).SubImage(extendedRect)
// 生成唯一文件名(时间戳 + UUID
fileName := fmt.Sprintf("%s_%s.jpg", time.Now().Format("20060102_150405"), kgo.SimpleUuid())
outputPath := filepath.Join(subDir, fileName)
// 写入文件
if err = l.writeImageToFile(outputPath, croppedImage); err != nil {
return "", err
}
return outputPath, nil
}
// 判断目录是否已缓存
func (l *FaceRecognitionLogic) isDirectoryCached(dir string) bool {
_, exists := l.directoryCache.Load(dir)
return exists
}
// 缓存目录
func (l *FaceRecognitionLogic) cacheDirectory(dir string) {
l.directoryCache.Store(dir, struct{}{})
}
// 将图像写入文件
func (l *FaceRecognitionLogic) writeImageToFile(path string, img image.Image) error {
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer func(file *os.File) {
_ = file.Close()
}(file)
if err = jpeg.Encode(file, img, nil); err != nil {
return fmt.Errorf("failed to encode and save image: %w", err)
}
return nil
}
// 从 Redis 的 Hash 中加载人脸数据
func (l *FaceRecognitionLogic) loadFacesFromRedisHash(hashKey string) ([]face.Descriptor, []int32, error) {
// 从 Redis 获取 Hash 的所有字段和值
data, err := l.svcCtx.RedisClient.HGetAll(l.ctx, hashKey).Result()
if err != nil {
return nil, nil, err
}
var samples []face.Descriptor
var ids []int32
for idStr, descriptorStr := range data {
var descriptor face.Descriptor
if err = json.Unmarshal([]byte(descriptorStr), &descriptor); err != nil {
return nil, nil, err
}
// 转换 ID 为 int32
id, err := parseInt32(idStr)
if err != nil {
return nil, nil, err
}
samples = append(samples, descriptor)
ids = append(ids, id)
}
return samples, ids, nil
}
// 将人脸数据写入 Redis 的 Hash
func (l *FaceRecognitionLogic) cacheFacesToRedisHash(hashKey string, samples []face.Descriptor, ids []int32) error {
// 开启事务
pipe := l.svcCtx.RedisClient.Pipeline()
for i := range samples {
descriptorData, err := json.Marshal(samples[i])
if err != nil {
return err
}
// 使用 HSET 设置 Hash 字段和值
pipe.HSet(l.ctx, hashKey, fmt.Sprintf("%d", ids[i]), descriptorData)
}
// 设置缓存过期时间
pipe.Expire(l.ctx, hashKey, 3600*time.Second)
_, err := pipe.Exec(l.ctx)
return err
}
// 将新增的人脸数据追加到 Redis 的 Hash
func (l *FaceRecognitionLogic) appendFaceToRedisHash(hashKey string, id int64, descriptor face.Descriptor) error {
descriptorData, err := json.Marshal(descriptor)
if err != nil {
return err
}
// 追加数据到 Hash
err = l.svcCtx.RedisClient.HSet(l.ctx, hashKey, fmt.Sprintf("%d", id), descriptorData).Err()
if err != nil {
return err
}
// 检查是否已设置过期时间
ttl, err := l.svcCtx.RedisClient.TTL(l.ctx, hashKey).Result()
if err != nil {
return err
}
// 如果未设置过期时间或已经过期,设置固定过期时间
if ttl < 0 {
err = l.svcCtx.RedisClient.Expire(l.ctx, hashKey, 3600*time.Second).Err()
if err != nil {
return err
}
}
return nil
}
// 辅助函数:字符串转换为 int32
func parseInt32(s string) (int32, error) {
var i int64
var err error
if i, err = strconv.ParseInt(s, 10, 32); err != nil {
return 0, err
}
return int32(i), nil
}