complete model integration

This commit is contained in:
2025-01-13 19:30:29 +08:00
parent fed52e66f9
commit 9356c00815
66 changed files with 2035 additions and 821 deletions

View File

@@ -0,0 +1,19 @@
import * as tf from '@tensorflow/tfjs';
export async function imageToUint8Array(imageElement) {
// 创建一个 TensorFlow.js 图像张量
const tensor = tf.browser.fromPixels(imageElement, 3).toFloat();
// 获取图像的宽度和高度
const width = imageElement.width;
const height = imageElement.height;
// 将张量转为 Uint8ArrayRGB 格式,值范围从 0 到 255
const uint8Array = await tensor.data();
return {
width,
height,
uint8Array
};
}

View File

@@ -0,0 +1,10 @@
export function imageToBase64(img) {
const canvas = document.createElement('canvas');
canvas.width = img.width;
canvas.height = img.height;
const ctx = canvas.getContext('2d');
if (ctx) {
ctx.drawImage(img, 0, 0);
}
return canvas.toDataURL('image/png'); // 或者 'image/jpeg'
}

View File

@@ -2,21 +2,23 @@ import * as nsfwjs from "nsfwjs";
import {NSFWJS} from "nsfwjs";
import * as tf from "@tensorflow/tfjs";
/**
* Initializes the NSFWJS model and returns it.
*/
let isInit: boolean = false;
const initNSFWJs = async (): Promise<NSFWJS> => {
tf.enableProdMode();
if (!isInit) {
const initialLoad: nsfwjs.NSFWJS = await nsfwjs.load("/nsfw/mobilenet_v2_mid/", {
size: 224,
type: "graph"
});
await initialLoad.model.save("indexeddb://nsfwjs-model");
isInit = true;
let nsfwModelCache: NSFWJS | null = null; // 缓存模型实例
// 如果模型已经加载,则直接返回缓存
try {
// 首先尝试从 IndexedDB 加载模型
nsfwModelCache = await nsfwjs.load("indexeddb://nsfwjs-model", {size: 224, type: "graph"});
console.log("NSFWJS 模型成功从 IndexedDB 加载");
} catch (_error) {
console.warn("IndexedDB 中未找到模型,正在从网络加载...");
// 如果 IndexedDB 加载失败,从 URL 加载模型并保存到 IndexedDB
nsfwModelCache = await nsfwjs.load("/nsfw/mobilenet_v2_mid/", {size: 224, type: "graph"});
await nsfwModelCache.model.save("indexeddb://nsfwjs-model");
console.log("NSFWJS 模型已从网络加载并保存到 IndexedDB");
}
return await nsfwjs.load("indexeddb://nsfwjs-model", {size: 224, type: "graph"});
return nsfwModelCache;
};
/**
* Predicts the NSFW score of an image using the NSFWJS model.

View File

@@ -0,0 +1,16 @@
import * as faceapi from '@vladmandic/face-api/dist/face-api.esm-nobundle.js';
export async function loadModel() {
const modelsPath = `/tfjs/face_api/model`;
// 面部识别模型
await faceapi.nets.faceRecognitionNet.load(modelsPath);
}
export async function faceSimilarity(img1: HTMLImageElement, img2: HTMLImageElement) {
const descriptor1 = await faceapi.computeFaceDescriptor(img1);
const descriptor2 = await faceapi.computeFaceDescriptor(img2);
if (descriptor1 instanceof Float32Array && descriptor2 instanceof Float32Array) {
return faceapi.euclideanDistance(descriptor1, descriptor2).toFixed(2);
}
return -1;
}

View File

@@ -1,23 +1,68 @@
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';
// 封装处理图像和推理的工具函数
export async function loadModel(modelPath) {
const model = await tf.loadGraphModel(modelPath);
console.log('Model Loaded');
async function loadModelFromIndexedDBOrUrl(modelName: string, modelUrl: string) {
let model: tf.GraphModel;
tf.setBackend('webgl');
try {
// 尝试从 IndexedDB 加载模型
model = await tf.loadGraphModel(`indexeddb://${modelName}-model`);
console.log("模型成功从 IndexedDB 加载");
} catch (_error) {
console.log("从 URL 下载模型...");
// 如果 IndexedDB 中没有模型,则从 URL 加载并保存到 IndexedDB
model = await tf.loadGraphModel(modelUrl);
await model.save(`indexeddb://${modelName}-model`);
console.log("模型已从 URL 下载并保存到 IndexedDB");
}
return model;
}
// 封装处理图像和推理的工具函数
export async function loadAnimeClassifierModel() {
const modelName = 'anime_classifier';
const modelUrl = '/tfjs/anime_classifier/model.json';
return await loadModelFromIndexedDBOrUrl(modelName, modelUrl);
}
// 处理图片并进行推理
export async function predictImage(model, imageElement) {
export async function animePredictImage(imageElement) {
const model: tf.GraphModel = await loadAnimeClassifierModel();
// 将图片转换为张量
const tensor = tf.browser.fromPixels(imageElement).toFloat();
const tensor = tf.browser.fromPixels(imageElement, 3).toFloat();
const resized = tf.image.resizeBilinear(tensor, [224, 224]); // 调整图片大小为模型输入大小
const input = resized.expandDims(0); // 增加批次维度
// 进行推理
const prediction = model.predict(input);
const prediction: any = model.predict(input);
// 获取预测结果并返回
const resultArray = await prediction.array();
return resultArray[0]; // 返回第一项的预测结果
const result = resultArray[0]; // 获取预测结果数组
return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
}
// export async function animePredictImage(width: number, height: number, uint8Array: Uint8Array) {
// const model: tf.GraphModel = await loadModel();
//
// // 将 Uint8Array 转换为 Tensor
// const tensor = tf.tensor3d(uint8Array, [height, width, 3], 'int32').toFloat();
//
// // 调整图片大小为模型输入大小
// const resized = tf.image.resizeBilinear(tensor, [224, 224]);
//
// // 增加批次维度
// const input = resized.expandDims(0);
//
// // 进行推理
// const prediction: any = model.predict(input);
//
// // 获取预测结果并返回
// const resultArray = await prediction.array();
// const result = resultArray[0]; // 获取预测结果数组
// return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
// }

View File

@@ -0,0 +1,65 @@
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';
async function loadModelFromIndexedDBOrUrl(modelName: string, modelUrl: string) {
let model: tf.LayersModel;
tf.setBackend('webgl');
try {
// 尝试从 IndexedDB 加载模型
model = await tf.loadLayersModel(`indexeddb://${modelName}-model`);
console.log("模型成功从 IndexedDB 加载");
} catch (_error) {
console.log("从 URL 下载模型...");
// 如果 IndexedDB 中没有模型,则从 URL 加载并保存到 IndexedDB
model = await tf.loadLayersModel(modelUrl);
await model.save(`indexeddb://${modelName}-model`);
console.log("模型已从 URL 下载并保存到 IndexedDB");
}
return model;
}
// 封装处理图像和推理的工具函数
export async function loadAnimeClassifierProModel() {
const modelName = 'anime_classifier2';
const modelUrl = '/tfjs/anime_classifier2/model.json';
return await loadModelFromIndexedDBOrUrl(modelName, modelUrl);
}
// 处理图片并进行推理
export async function animePredictImagePro(imageElement) {
const model: any = await loadAnimeClassifierProModel();
// 将图片转换为张量
const tensor = tf.browser.fromPixels(imageElement).toFloat();
const imageResized = tf.image.resizeBilinear(tensor, [224, 224]);
const imageReshaped = imageResized.reshape([1, 224, 224, 3]);
const imageNormalized = imageReshaped.div(255);
// 进行推理
const prediction: any = model.predict(imageNormalized);
const predictedClass = tf.argMax(prediction, 1).dataSync()[0];
// const predictedClassConfidence = await prediction.dataSync()[predictedClass].toFixed(2);
// console.log(`预测结果: ${predictedClassName}(${predictedClassConfidence})`);
return ['Anime', 'Furry', 'Neutral'][predictedClass];
}
// export async function animePredictImagePro(width: number, height: number, uint8Array: Uint8Array) {
//
// const model: any = await loadModel();
// // 将图片转换为张量
// const tensor = tf.tensor3d(uint8Array, [height, width, 3], 'int32').toFloat();
// const imageResized = tf.image.resizeBilinear(tensor, [224, 224]);
// const imageReshaped = imageResized.reshape([1, 224, 224, 3]);
// const imageNormalized = imageReshaped.div(255);
//
// // 进行推理
// const prediction: any = model.predict(imageNormalized);
//
//
// const predictedClass = tf.argMax(prediction, 1).dataSync()[0];
// // const predictedClassConfidence = await prediction.dataSync()[predictedClass].toFixed(2);
// // console.log(`预测结果: ${predictedClassName}(${predictedClassConfidence})`);
// return ['Anime', 'Furry', 'Neutral'][predictedClass];
// }

View File

@@ -0,0 +1,42 @@
import '@mediapipe/face_detection';
import '@tensorflow/tfjs-core';
// Register WebGL backend.
import '@tensorflow/tfjs-backend-webgl';
import * as faceDetection from '@tensorflow-models/face-detection';
import * as faceLandmarksDetection from '@tensorflow-models/face-landmarks-detection';
import '@mediapipe/face_mesh';
/**
* 检测人脸
* @param image
*/
export async function detectFaces(image: HTMLImageElement) {
const model = faceDetection.SupportedModels.MediaPipeFaceDetector;
const detectorConfig: any = {
runtime: 'tfjs',
maxFaces: 1,
modelType: 'short', //'short'|'full'
};
const detector = await faceDetection.createDetector(model, detectorConfig);
const estimationConfig = {flipHorizontal: false};
const faces = await detector.estimateFaces(image, estimationConfig);
return faces;
}
/**
* 检测人脸特征点
* @param image
*/
export async function detectionFaceLandmarks(image: HTMLImageElement) {
const model = faceLandmarksDetection.SupportedModels.MediaPipeFaceMesh;
const detectorConfig: any = {
runtime: 'tfjs',
maxFaces: 1,
refineLandmarks: false,
};
const detector = await faceLandmarksDetection.createDetector(model, detectorConfig);
const estimationConfig = {flipHorizontal: false};
const faces = await detector.estimateFaces(image, estimationConfig);
return faces;
}

View File

@@ -0,0 +1,28 @@
import * as faceapi from '@vladmandic/face-api/dist/face-api.esm-nobundle.js';
import '@tensorflow/tfjs-backend-webgl';
import * as tf from '@tensorflow/tfjs';
export async function loadFaceExtractorModel() {
tf.setBackend('webgl'); // set webgl backend
// 模型文件访问路径
const modelsPath = `/tfjs/face_api/model/ssd_mobilenetv1_model-weights_manifest.json`;
// 模型参数-ssdMobilenetv1
await faceapi.nets.ssdMobilenetv1.load(modelsPath);
return new faceapi.SsdMobilenetv1Options({
minConfidence: 0.5, // 0 ~ 1
maxResults: 50, // 0 ~ 100
});
}
export async function fnDetectFace(img: HTMLImageElement) {
const options = await loadFaceExtractorModel();
const detections = await faceapi.detectSingleFace(
img,
options,
);
if (!detections) {
return null;
}
const faceImages = await faceapi.extractFaces(img, [detections]);
return faceImages[0].toDataURL('image/png');
}

View File

@@ -0,0 +1,54 @@
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';
export async function loadLandscapeRecognitionModel() {
const modelName = 'landscape_recognition';
const modelUrl = '/tfjs/landscape_recognition/model.json';
let model: tf.LayersModel;
tf.setBackend('webgl');
try {
// 尝试从 IndexedDB 加载模型
model = await tf.loadLayersModel(`indexeddb://${modelName}-model`);
console.log("模型成功从 IndexedDB 加载");
} catch (_error) {
console.log("从 URL 下载模型...");
// 如果 IndexedDB 中没有模型,则从 URL 加载模型
model = await tf.loadLayersModel(modelUrl);
await model.save(`indexeddb://${modelName}-model`);
console.log("模型已从 URL 下载并保存到 IndexedDB");
}
return model;
}
export const predictLandscape = async (imgElement) => {
if (!imgElement) return;
const model = await loadLandscapeRecognitionModel();
const img = tf.cast(tf.browser.fromPixels(imgElement), 'float32').resizeBilinear([150, 150]);
const offset = tf.scalar(127.5);
const normalized = img.sub(offset).div(offset);
const batched = normalized.reshape([1, 150, 150, 3]);
const results: any = model.predict(batched);
return getCategory(results.dataSync().indexOf(results.max().dataSync()[0]));
};
const getCategory = (index: number) => {
switch (index) {
case 0:
return "building";
case 1:
return "forest";
case 2:
return "glacier";
case 3:
return "mountain";
case 4:
return "sea";
case 5:
return "street";
default:
return "none";
}
};

View File

@@ -1,6 +1,7 @@
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import * as cocoSsd from '@tensorflow-models/coco-ssd';
import '@tensorflow/tfjs-backend-webgl';
// 确保 TensorFlow.js 已准备好并设置后端
async function initializeTensorFlow(backend = "webgl") {
@@ -49,36 +50,43 @@ export async function loadMobileNet(image) {
return await model.classify(image, 3);
}
// 加载 COCO SSD 模型的工具函数
export async function loadCocoSsd(image) {
// 工具函数:加载或缓存模型
export async function loadCocoSsdModel() {
const modelName = "cocoSsd-model";
const modelUrl = '/tfjs/mobilenet/ssd-mobilenet-v2-tfjs-default-v1/model.json';
// 初始化 TensorFlow.js
if (!(await initializeTensorFlow())) {
return;
}
let model;
try {
// 尝试从 IndexedDB 加载模型
model = await cocoSsd.load({
base: 'mobilenet_v2',
modelUrl: `indexeddb://${modelName}`,
});
console.log("COCO SSD model loaded from IndexedDB successfully");
console.log(`${modelName} loaded from IndexedDB successfully`);
} catch (_error) {
console.log("Downloading COCO SSD model...");
console.log(`Downloading ${modelName}...`);
// 如果 IndexedDB 中没有模型则从 URL 加载并保存到 IndexedDB
model = await cocoSsd.load({
base: 'mobilenet_v2',
modelUrl: modelUrl,
});
const Model = await tf.loadGraphModel(modelUrl);
await Model.save(`indexeddb://${modelName}`);
console.log("COCO SSD model downloaded and saved to IndexedDB");
const graphModel = await tf.loadGraphModel(modelUrl);
await graphModel.save(`indexeddb://${modelName}`);
console.log(`${modelName} downloaded and saved to IndexedDB`);
}
return model;
}
// 加载 COCO SSD 模型的工具函数
// 使用提取的加载模型工具函数
export async function cocoSsdPredict(image) {
// 初始化 TensorFlow.js
tf.setBackend('webgl');
if (!(await initializeTensorFlow())) {
return [];
}
// 加载模型
const model = await loadCocoSsdModel();
// 使用模型进行检测
return await model.detect(image);
}