🐳 add docker file
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
package aiservicelogic
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/svc"
|
||||
"schisandra-album-cloud-microservices/app/aisvc/rpc/pb"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
type CaffeClassificationLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
logx.Logger
|
||||
}
|
||||
|
||||
func NewCaffeClassificationLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CaffeClassificationLogic {
|
||||
return &CaffeClassificationLogic{
|
||||
ctx: ctx,
|
||||
svcCtx: svcCtx,
|
||||
Logger: logx.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
// CaffeClassification
|
||||
func (l *CaffeClassificationLogic) CaffeClassification(in *pb.CaffeClassificationRequest) (*pb.CaffeClassificationResponse, error) {
|
||||
// todo: add your logic here and delete this line
|
||||
|
||||
return &pb.CaffeClassificationResponse{}, nil
|
||||
}
|
@@ -43,8 +43,12 @@ func NewFaceRecognitionLogic(ctx context.Context, svcCtx *svc.ServiceContext) *F
|
||||
|
||||
// FaceRecognition 人脸识别
|
||||
func (l *FaceRecognitionLogic) FaceRecognition(in *pb.FaceRecognitionRequest) (*pb.FaceRecognitionResponse, error) {
|
||||
toJPEG, err := l.ConvertImageToJPEG(in.GetFace())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 提取人脸特征
|
||||
faceFeatures, err := l.svcCtx.FaceRecognizer.RecognizeSingle(in.GetFace())
|
||||
faceFeatures, err := l.svcCtx.FaceRecognizer.RecognizeSingle(toJPEG)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,7 +86,7 @@ func (l *FaceRecognitionLogic) FaceRecognition(in *pb.FaceRecognitionRequest) (*
|
||||
|
||||
// 人脸分类
|
||||
classify := l.svcCtx.FaceRecognizer.ClassifyThreshold(faceFeatures.Descriptor, 0.6)
|
||||
if classify >= 0 {
|
||||
if classify >= 0 && classify < len(ids) {
|
||||
return &pb.FaceRecognitionResponse{
|
||||
FaceId: int64(ids[classify]),
|
||||
}, nil
|
||||
@@ -92,6 +96,26 @@ func (l *FaceRecognitionLogic) FaceRecognition(in *pb.FaceRecognitionRequest) (*
|
||||
return l.saveNewFace(in, faceFeatures, hashKey)
|
||||
}
|
||||
|
||||
// ConvertImageToJPEG 将非 JPEG 格式的图片字节数据转换为 JPEG
|
||||
func (l *FaceRecognitionLogic) ConvertImageToJPEG(imageData []byte) ([]byte, error) {
|
||||
// 使用 image.Decode 解码图像数据
|
||||
img, _, err := image.Decode(bytes.NewReader(imageData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode image: %v", err)
|
||||
}
|
||||
|
||||
// 创建一个缓冲区来存储 JPEG 格式的数据
|
||||
var jpegBuffer bytes.Buffer
|
||||
|
||||
// 将图片编码为 JPEG 格式
|
||||
err = jpeg.Encode(&jpegBuffer, img, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode image to JPEG: %v", err)
|
||||
}
|
||||
|
||||
return jpegBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// 保存新的人脸样本到数据库和 Redis
|
||||
func (l *FaceRecognitionLogic) saveNewFace(in *pb.FaceRecognitionRequest, faceFeatures *face.Face, hashKey string) (*pb.FaceRecognitionResponse, error) {
|
||||
// 人脸有效性判断 (大小必须大于50)
|
||||
|
@@ -0,0 +1,79 @@
|
||||
package aiservicelogic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gocv.io/x/gocv"
|
||||
"image"
|
||||
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/svc"
|
||||
"schisandra-album-cloud-microservices/app/aisvc/rpc/pb"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
type TfClassificationLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
logx.Logger
|
||||
}
|
||||
|
||||
func NewTfClassificationLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TfClassificationLogic {
|
||||
return &TfClassificationLogic{
|
||||
ctx: ctx,
|
||||
svcCtx: svcCtx,
|
||||
Logger: logx.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
// TfClassification is a server endpoint to classify an image using TensorFlow.
|
||||
func (l *TfClassificationLogic) TfClassification(in *pb.TfClassificationRequest) (*pb.TfClassificationResponse, error) {
|
||||
className, source, err := l.ClassifyImage(in.GetImage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pb.TfClassificationResponse{
|
||||
Score: source,
|
||||
ClassName: className,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ClassifyImage 从字节数据分类图像,返回分类标签和最大概率值
|
||||
func (l *TfClassificationLogic) ClassifyImage(imageBytes []byte) (string, float32, error) {
|
||||
|
||||
// 解码字节数据为图像
|
||||
img, err := gocv.IMDecode(imageBytes, gocv.IMReadColor)
|
||||
if err != nil || img.Empty() {
|
||||
return "", 0, fmt.Errorf("failed to decode image: %v", err)
|
||||
}
|
||||
defer func(img *gocv.Mat) {
|
||||
_ = img.Close()
|
||||
}(&img)
|
||||
|
||||
// 将图像 Mat 转换为 224x224 blob,以便分类器分析
|
||||
blob := gocv.BlobFromImage(img, 1.0, image.Pt(224, 224), gocv.NewScalar(0, 0, 0, 0), true, false)
|
||||
|
||||
// 将 blob 输入分类器
|
||||
l.svcCtx.TfNet.SetInput(blob, "input")
|
||||
|
||||
// 运行网络的正向传递
|
||||
prob := l.svcCtx.TfNet.Forward("softmax2")
|
||||
|
||||
// 将结果重塑为 1x1000 矩阵
|
||||
probMat := prob.Reshape(1, 1)
|
||||
|
||||
// 确定最可能的分类
|
||||
_, maxVal, _, maxLoc := gocv.MinMaxLoc(probMat)
|
||||
|
||||
// 获取分类描述
|
||||
desc := ""
|
||||
if maxLoc.X < 1000 {
|
||||
desc = l.svcCtx.TfDesc[maxLoc.X]
|
||||
}
|
||||
|
||||
// 清理资源
|
||||
_ = blob.Close()
|
||||
_ = prob.Close()
|
||||
_ = probMat.Close()
|
||||
|
||||
return desc, maxVal, nil
|
||||
}
|
@@ -28,3 +28,15 @@ func (s *AiServiceServer) FaceRecognition(ctx context.Context, in *pb.FaceRecogn
|
||||
l := aiservicelogic.NewFaceRecognitionLogic(ctx, s.svcCtx)
|
||||
return l.FaceRecognition(in)
|
||||
}
|
||||
|
||||
// TfClassification
|
||||
func (s *AiServiceServer) TfClassification(ctx context.Context, in *pb.TfClassificationRequest) (*pb.TfClassificationResponse, error) {
|
||||
l := aiservicelogic.NewTfClassificationLogic(ctx, s.svcCtx)
|
||||
return l.TfClassification(in)
|
||||
}
|
||||
|
||||
// CaffeClassification
|
||||
func (s *AiServiceServer) CaffeClassification(ctx context.Context, in *pb.CaffeClassificationRequest) (*pb.CaffeClassificationResponse, error) {
|
||||
l := aiservicelogic.NewCaffeClassificationLogic(ctx, s.svcCtx)
|
||||
return l.CaffeClassification(in)
|
||||
}
|
||||
|
@@ -3,11 +3,14 @@ package svc
|
||||
import (
|
||||
"github.com/Kagami/go-face"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gocv.io/x/gocv"
|
||||
"schisandra-album-cloud-microservices/app/aisvc/model/mysql"
|
||||
"schisandra-album-cloud-microservices/app/aisvc/model/mysql/query"
|
||||
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/config"
|
||||
"schisandra-album-cloud-microservices/common/caffe_classifier"
|
||||
"schisandra-album-cloud-microservices/common/face_recognizer"
|
||||
"schisandra-album-cloud-microservices/common/redisx"
|
||||
"schisandra-album-cloud-microservices/common/tf_classifier"
|
||||
)
|
||||
|
||||
type ServiceContext struct {
|
||||
@@ -15,15 +18,25 @@ type ServiceContext struct {
|
||||
FaceRecognizer *face.Recognizer
|
||||
DB *query.Query
|
||||
RedisClient *redis.Client
|
||||
TfNet *gocv.Net
|
||||
TfDesc []string
|
||||
CaffeNet *gocv.Net
|
||||
CaffeDesc []string
|
||||
}
|
||||
|
||||
func NewServiceContext(c config.Config) *ServiceContext {
|
||||
redisClient := redisx.NewRedis(c.RedisConf.Host, c.RedisConf.Pass, c.RedisConf.DB)
|
||||
_, queryDB := mysql.NewMySQL(c.Mysql.DataSource, c.Mysql.MaxOpenConn, c.Mysql.MaxIdleConn, redisClient)
|
||||
tfClassifier, tfDesc := tf_classifier.NewTFClassifier()
|
||||
caffeClassifier, caffeDesc := caffe_classifier.NewCaffeClassifier()
|
||||
return &ServiceContext{
|
||||
Config: c,
|
||||
FaceRecognizer: face_recognizer.NewFaceRecognition(),
|
||||
DB: queryDB,
|
||||
RedisClient: redisClient,
|
||||
TfNet: tfClassifier,
|
||||
TfDesc: tfDesc,
|
||||
CaffeNet: caffeClassifier,
|
||||
CaffeDesc: caffeDesc,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user