Files
schisandra-cloud-album-front/src/utils/tfjs/anime_classifier.ts
2025-01-13 19:30:29 +08:00

69 lines
2.5 KiB
TypeScript

import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';
async function loadModelFromIndexedDBOrUrl(modelName: string, modelUrl: string) {
let model: tf.GraphModel;
tf.setBackend('webgl');
try {
// 尝试从 IndexedDB 加载模型
model = await tf.loadGraphModel(`indexeddb://${modelName}-model`);
console.log("模型成功从 IndexedDB 加载");
} catch (_error) {
console.log("从 URL 下载模型...");
// 如果 IndexedDB 中没有模型,则从 URL 加载并保存到 IndexedDB
model = await tf.loadGraphModel(modelUrl);
await model.save(`indexeddb://${modelName}-model`);
console.log("模型已从 URL 下载并保存到 IndexedDB");
}
return model;
}
// 封装处理图像和推理的工具函数
export async function loadAnimeClassifierModel() {
const modelName = 'anime_classifier';
const modelUrl = '/tfjs/anime_classifier/model.json';
return await loadModelFromIndexedDBOrUrl(modelName, modelUrl);
}
// 处理图片并进行推理
export async function animePredictImage(imageElement) {
const model: tf.GraphModel = await loadAnimeClassifierModel();
// 将图片转换为张量
const tensor = tf.browser.fromPixels(imageElement, 3).toFloat();
const resized = tf.image.resizeBilinear(tensor, [224, 224]); // 调整图片大小为模型输入大小
const input = resized.expandDims(0); // 增加批次维度
// 进行推理
const prediction: any = model.predict(input);
// 获取预测结果并返回
const resultArray = await prediction.array();
const result = resultArray[0]; // 获取预测结果数组
return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
}
// export async function animePredictImage(width: number, height: number, uint8Array: Uint8Array) {
// const model: tf.GraphModel = await loadModel();
//
// // 将 Uint8Array 转换为 Tensor
// const tensor = tf.tensor3d(uint8Array, [height, width, 3], 'int32').toFloat();
//
// // 调整图片大小为模型输入大小
// const resized = tf.image.resizeBilinear(tensor, [224, 224]);
//
// // 增加批次维度
// const input = resized.expandDims(0);
//
// // 进行推理
// const prediction: any = model.predict(input);
//
// // 获取预测结果并返回
// const resultArray = await prediction.array();
// const result = resultArray[0]; // 获取预测结果数组
// return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
// }