69 lines
2.5 KiB
TypeScript
69 lines
2.5 KiB
TypeScript
import * as tf from '@tensorflow/tfjs';
|
|
import '@tensorflow/tfjs-backend-webgl';
|
|
|
|
async function loadModelFromIndexedDBOrUrl(modelName: string, modelUrl: string) {
|
|
let model: tf.GraphModel;
|
|
tf.setBackend('webgl');
|
|
try {
|
|
// 尝试从 IndexedDB 加载模型
|
|
model = await tf.loadGraphModel(`indexeddb://${modelName}-model`);
|
|
console.log("模型成功从 IndexedDB 加载");
|
|
} catch (_error) {
|
|
console.log("从 URL 下载模型...");
|
|
// 如果 IndexedDB 中没有模型,则从 URL 加载并保存到 IndexedDB
|
|
model = await tf.loadGraphModel(modelUrl);
|
|
await model.save(`indexeddb://${modelName}-model`);
|
|
console.log("模型已从 URL 下载并保存到 IndexedDB");
|
|
}
|
|
return model;
|
|
}
|
|
|
|
// 封装处理图像和推理的工具函数
|
|
export async function loadAnimeClassifierModel() {
|
|
const modelName = 'anime_classifier';
|
|
const modelUrl = '/tfjs/anime_classifier/model.json';
|
|
return await loadModelFromIndexedDBOrUrl(modelName, modelUrl);
|
|
}
|
|
|
|
// 处理图片并进行推理
|
|
export async function animePredictImage(imageElement) {
|
|
|
|
const model: tf.GraphModel = await loadAnimeClassifierModel();
|
|
// 将图片转换为张量
|
|
const tensor = tf.browser.fromPixels(imageElement, 3).toFloat();
|
|
const resized = tf.image.resizeBilinear(tensor, [224, 224]); // 调整图片大小为模型输入大小
|
|
const input = resized.expandDims(0); // 增加批次维度
|
|
|
|
// 进行推理
|
|
const prediction: any = model.predict(input);
|
|
|
|
// 获取预测结果并返回
|
|
const resultArray = await prediction.array();
|
|
const result = resultArray[0]; // 获取预测结果数组
|
|
return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
|
|
}
|
|
|
|
// export async function animePredictImage(width: number, height: number, uint8Array: Uint8Array) {
|
|
// const model: tf.GraphModel = await loadModel();
|
|
//
|
|
// // 将 Uint8Array 转换为 Tensor
|
|
// const tensor = tf.tensor3d(uint8Array, [height, width, 3], 'int32').toFloat();
|
|
//
|
|
// // 调整图片大小为模型输入大小
|
|
// const resized = tf.image.resizeBilinear(tensor, [224, 224]);
|
|
//
|
|
// // 增加批次维度
|
|
// const input = resized.expandDims(0);
|
|
//
|
|
// // 进行推理
|
|
// const prediction: any = model.predict(input);
|
|
//
|
|
// // 获取预测结果并返回
|
|
// const resultArray = await prediction.array();
|
|
// const result = resultArray[0]; // 获取预测结果数组
|
|
// return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
|
|
// }
|
|
|
|
|
|
|