Files
schisandra-album-cloud-micr…/app/aisvc/rpc/internal/logic/aiservice/tf_classification_logic.go
2025-01-23 15:24:59 +08:00

80 lines
2.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package aiservicelogic
import (
"context"
"fmt"
"gocv.io/x/gocv"
"image"
"schisandra-album-cloud-microservices/app/aisvc/rpc/internal/svc"
"schisandra-album-cloud-microservices/app/aisvc/rpc/pb"
"github.com/zeromicro/go-zero/core/logx"
)
// TfClassificationLogic implements the TfClassification RPC. It carries the
// per-request context, the shared service dependencies and a context-aware
// logger.
type TfClassificationLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
	logx.Logger
}
// NewTfClassificationLogic returns a TfClassificationLogic bound to ctx and
// svcCtx. Its embedded logger is derived from ctx via logx.WithContext.
func NewTfClassificationLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TfClassificationLogic {
	logic := &TfClassificationLogic{
		Logger: logx.WithContext(ctx),
		svcCtx: svcCtx,
		ctx:    ctx,
	}
	return logic
}
// TfClassification is the RPC handler that classifies the image bytes carried
// in the request. It delegates to ClassifyImage and returns that method's
// class name and score. Any error from ClassifyImage is returned unchanged.
func (l *TfClassificationLogic) TfClassification(in *pb.TfClassificationRequest) (*pb.TfClassificationResponse, error) {
	name, score, err := l.ClassifyImage(in.GetImage())
	if err != nil {
		return nil, err
	}
	resp := &pb.TfClassificationResponse{
		ClassName: name,
		Score:     score,
	}
	return resp, nil
}
// ClassifyImage classifies an encoded image and returns the label of the
// highest-scoring network output together with that output's value.
//
// Processing steps:
//   - Decode imageBytes as a colour image.
//   - Convert it to a 224x224 blob: scale 1.0, zero mean, swapRB=true, no crop.
//   - Feed the blob to the "input" layer of svcCtx.TfNet.
//   - Read the result from the "softmax2" layer.
//
// If the winning index has no entry in svcCtx.TfDesc, the returned label is
// "" and the score is still returned.
//
// Errors are returned when the bytes cannot be decoded, when decoding yields
// an empty image, or when the network produces no output.
//
// NOTE(review): svcCtx.TfNet looks like a single network shared by all
// requests. OpenCV's dnn.Net is not safe for concurrent SetInput/Forward
// calls, so confirm that access is serialised elsewhere.
func (l *TfClassificationLogic) ClassifyImage(imageBytes []byte) (string, float32, error) {
	img, err := gocv.IMDecode(imageBytes, gocv.IMReadColor)
	if err != nil {
		// gocv hands back an unallocated zero Mat alongside the error,
		// so there is nothing to release here.
		return "", 0, fmt.Errorf("failed to decode image: %w", err)
	}
	// Register the Close before the Empty check so an empty Mat is released too.
	defer func() { _ = img.Close() }()
	if img.Empty() {
		return "", 0, fmt.Errorf("failed to decode image: decoded image is empty")
	}

	blob := gocv.BlobFromImage(img, 1.0, image.Pt(224, 224), gocv.NewScalar(0, 0, 0, 0), true, false)
	defer func() { _ = blob.Close() }()

	l.svcCtx.TfNet.SetInput(blob, "input")
	prob := l.svcCtx.TfNet.Forward("softmax2")
	defer func() { _ = prob.Close() }()
	if prob.Empty() {
		return "", 0, fmt.Errorf("failed to classify image: network produced no output")
	}

	// Reshape to 1 channel, 1 row so that MinMaxLoc's column index is the
	// class index.
	probMat := prob.Reshape(1, 1)
	defer func() { _ = probMat.Close() }()

	_, maxVal, _, maxLoc := gocv.MinMaxLoc(probMat)

	// Bound the lookup by the label table's actual length instead of a
	// hard-coded 1000, so a short or mismatched label list cannot panic.
	label := ""
	if idx := maxLoc.X; idx >= 0 && idx < len(l.svcCtx.TfDesc) {
		label = l.svcCtx.TfDesc[idx]
	}
	return label, maxVal, nil
}