✨ complete model integration
This commit is contained in:
19
src/utils/imageUtils/imageElementToUint8Array.ts
Normal file
19
src/utils/imageUtils/imageElementToUint8Array.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
|
||||
export async function imageToUint8Array(imageElement) {
|
||||
// 创建一个 TensorFlow.js 图像张量
|
||||
const tensor = tf.browser.fromPixels(imageElement, 3).toFloat();
|
||||
|
||||
// 获取图像的宽度和高度
|
||||
const width = imageElement.width;
|
||||
const height = imageElement.height;
|
||||
|
||||
// 将张量转为 Uint8Array(RGB 格式,值范围从 0 到 255)
|
||||
const uint8Array = await tensor.data();
|
||||
|
||||
return {
|
||||
width,
|
||||
height,
|
||||
uint8Array
|
||||
};
|
||||
}
|
10
src/utils/imageUtils/imgToBase64.ts
Normal file
10
src/utils/imageUtils/imgToBase64.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
export function imageToBase64(img) {
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = img.width;
|
||||
canvas.height = img.height;
|
||||
const ctx = canvas.getContext('2d');
|
||||
if (ctx) {
|
||||
ctx.drawImage(img, 0, 0);
|
||||
}
|
||||
return canvas.toDataURL('image/png'); // 或者 'image/jpeg'
|
||||
}
|
@@ -2,21 +2,23 @@ import * as nsfwjs from "nsfwjs";
|
||||
import {NSFWJS} from "nsfwjs";
|
||||
import * as tf from "@tensorflow/tfjs";
|
||||
|
||||
/**
|
||||
* Initializes the NSFWJS model and returns it.
|
||||
*/
|
||||
let isInit: boolean = false;
|
||||
|
||||
const initNSFWJs = async (): Promise<NSFWJS> => {
|
||||
tf.enableProdMode();
|
||||
if (!isInit) {
|
||||
const initialLoad: nsfwjs.NSFWJS = await nsfwjs.load("/nsfw/mobilenet_v2_mid/", {
|
||||
size: 224,
|
||||
type: "graph"
|
||||
});
|
||||
await initialLoad.model.save("indexeddb://nsfwjs-model");
|
||||
isInit = true;
|
||||
let nsfwModelCache: NSFWJS | null = null; // 缓存模型实例
|
||||
// 如果模型已经加载,则直接返回缓存
|
||||
try {
|
||||
// 首先尝试从 IndexedDB 加载模型
|
||||
nsfwModelCache = await nsfwjs.load("indexeddb://nsfwjs-model", {size: 224, type: "graph"});
|
||||
console.log("NSFWJS 模型成功从 IndexedDB 加载");
|
||||
} catch (_error) {
|
||||
console.warn("IndexedDB 中未找到模型,正在从网络加载...");
|
||||
// 如果 IndexedDB 加载失败,从 URL 加载模型并保存到 IndexedDB
|
||||
nsfwModelCache = await nsfwjs.load("/nsfw/mobilenet_v2_mid/", {size: 224, type: "graph"});
|
||||
await nsfwModelCache.model.save("indexeddb://nsfwjs-model");
|
||||
console.log("NSFWJS 模型已从网络加载并保存到 IndexedDB");
|
||||
}
|
||||
return await nsfwjs.load("indexeddb://nsfwjs-model", {size: 224, type: "graph"});
|
||||
|
||||
return nsfwModelCache;
|
||||
};
|
||||
/**
|
||||
* Predicts the NSFW score of an image using the NSFWJS model.
|
||||
|
16
src/utils/tfjs/BBT_face_similarity.ts
Normal file
16
src/utils/tfjs/BBT_face_similarity.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import * as faceapi from '@vladmandic/face-api/dist/face-api.esm-nobundle.js';
|
||||
|
||||
export async function loadModel() {
|
||||
const modelsPath = `/tfjs/face_api/model`;
|
||||
// 面部识别模型
|
||||
await faceapi.nets.faceRecognitionNet.load(modelsPath);
|
||||
}
|
||||
|
||||
export async function faceSimilarity(img1: HTMLImageElement, img2: HTMLImageElement) {
|
||||
const descriptor1 = await faceapi.computeFaceDescriptor(img1);
|
||||
const descriptor2 = await faceapi.computeFaceDescriptor(img2);
|
||||
if (descriptor1 instanceof Float32Array && descriptor2 instanceof Float32Array) {
|
||||
return faceapi.euclideanDistance(descriptor1, descriptor2).toFixed(2);
|
||||
}
|
||||
return -1;
|
||||
}
|
@@ -1,23 +1,68 @@
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
|
||||
// 封装处理图像和推理的工具函数
|
||||
export async function loadModel(modelPath) {
|
||||
const model = await tf.loadGraphModel(modelPath);
|
||||
console.log('Model Loaded');
|
||||
async function loadModelFromIndexedDBOrUrl(modelName: string, modelUrl: string) {
|
||||
let model: tf.GraphModel;
|
||||
tf.setBackend('webgl');
|
||||
try {
|
||||
// 尝试从 IndexedDB 加载模型
|
||||
model = await tf.loadGraphModel(`indexeddb://${modelName}-model`);
|
||||
console.log("模型成功从 IndexedDB 加载");
|
||||
} catch (_error) {
|
||||
console.log("从 URL 下载模型...");
|
||||
// 如果 IndexedDB 中没有模型,则从 URL 加载并保存到 IndexedDB
|
||||
model = await tf.loadGraphModel(modelUrl);
|
||||
await model.save(`indexeddb://${modelName}-model`);
|
||||
console.log("模型已从 URL 下载并保存到 IndexedDB");
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
// 封装处理图像和推理的工具函数
|
||||
export async function loadAnimeClassifierModel() {
|
||||
const modelName = 'anime_classifier';
|
||||
const modelUrl = '/tfjs/anime_classifier/model.json';
|
||||
return await loadModelFromIndexedDBOrUrl(modelName, modelUrl);
|
||||
}
|
||||
|
||||
// 处理图片并进行推理
|
||||
export async function predictImage(model, imageElement) {
|
||||
export async function animePredictImage(imageElement) {
|
||||
|
||||
const model: tf.GraphModel = await loadAnimeClassifierModel();
|
||||
// 将图片转换为张量
|
||||
const tensor = tf.browser.fromPixels(imageElement).toFloat();
|
||||
const tensor = tf.browser.fromPixels(imageElement, 3).toFloat();
|
||||
const resized = tf.image.resizeBilinear(tensor, [224, 224]); // 调整图片大小为模型输入大小
|
||||
const input = resized.expandDims(0); // 增加批次维度
|
||||
|
||||
// 进行推理
|
||||
const prediction = model.predict(input);
|
||||
const prediction: any = model.predict(input);
|
||||
|
||||
// 获取预测结果并返回
|
||||
const resultArray = await prediction.array();
|
||||
return resultArray[0]; // 返回第一项的预测结果
|
||||
const result = resultArray[0]; // 获取预测结果数组
|
||||
return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
|
||||
}
|
||||
|
||||
// export async function animePredictImage(width: number, height: number, uint8Array: Uint8Array) {
|
||||
// const model: tf.GraphModel = await loadModel();
|
||||
//
|
||||
// // 将 Uint8Array 转换为 Tensor
|
||||
// const tensor = tf.tensor3d(uint8Array, [height, width, 3], 'int32').toFloat();
|
||||
//
|
||||
// // 调整图片大小为模型输入大小
|
||||
// const resized = tf.image.resizeBilinear(tensor, [224, 224]);
|
||||
//
|
||||
// // 增加批次维度
|
||||
// const input = resized.expandDims(0);
|
||||
//
|
||||
// // 进行推理
|
||||
// const prediction: any = model.predict(input);
|
||||
//
|
||||
// // 获取预测结果并返回
|
||||
// const resultArray = await prediction.array();
|
||||
// const result = resultArray[0]; // 获取预测结果数组
|
||||
// return result.indexOf(1) === 0 ? 'Anime' : 'Neutral';
|
||||
// }
|
||||
|
||||
|
||||
|
||||
|
65
src/utils/tfjs/anime_classifier_pro.ts
Normal file
65
src/utils/tfjs/anime_classifier_pro.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
|
||||
async function loadModelFromIndexedDBOrUrl(modelName: string, modelUrl: string) {
|
||||
let model: tf.LayersModel;
|
||||
tf.setBackend('webgl');
|
||||
try {
|
||||
// 尝试从 IndexedDB 加载模型
|
||||
model = await tf.loadLayersModel(`indexeddb://${modelName}-model`);
|
||||
console.log("模型成功从 IndexedDB 加载");
|
||||
} catch (_error) {
|
||||
console.log("从 URL 下载模型...");
|
||||
// 如果 IndexedDB 中没有模型,则从 URL 加载并保存到 IndexedDB
|
||||
model = await tf.loadLayersModel(modelUrl);
|
||||
await model.save(`indexeddb://${modelName}-model`);
|
||||
console.log("模型已从 URL 下载并保存到 IndexedDB");
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
// 封装处理图像和推理的工具函数
|
||||
export async function loadAnimeClassifierProModel() {
|
||||
const modelName = 'anime_classifier2';
|
||||
const modelUrl = '/tfjs/anime_classifier2/model.json';
|
||||
return await loadModelFromIndexedDBOrUrl(modelName, modelUrl);
|
||||
}
|
||||
|
||||
// 处理图片并进行推理
|
||||
export async function animePredictImagePro(imageElement) {
|
||||
|
||||
const model: any = await loadAnimeClassifierProModel();
|
||||
// 将图片转换为张量
|
||||
const tensor = tf.browser.fromPixels(imageElement).toFloat();
|
||||
const imageResized = tf.image.resizeBilinear(tensor, [224, 224]);
|
||||
const imageReshaped = imageResized.reshape([1, 224, 224, 3]);
|
||||
const imageNormalized = imageReshaped.div(255);
|
||||
|
||||
// 进行推理
|
||||
const prediction: any = model.predict(imageNormalized);
|
||||
|
||||
|
||||
const predictedClass = tf.argMax(prediction, 1).dataSync()[0];
|
||||
// const predictedClassConfidence = await prediction.dataSync()[predictedClass].toFixed(2);
|
||||
// console.log(`预测结果: ${predictedClassName}(${predictedClassConfidence})`);
|
||||
return ['Anime', 'Furry', 'Neutral'][predictedClass];
|
||||
}
|
||||
|
||||
// export async function animePredictImagePro(width: number, height: number, uint8Array: Uint8Array) {
|
||||
//
|
||||
// const model: any = await loadModel();
|
||||
// // 将图片转换为张量
|
||||
// const tensor = tf.tensor3d(uint8Array, [height, width, 3], 'int32').toFloat();
|
||||
// const imageResized = tf.image.resizeBilinear(tensor, [224, 224]);
|
||||
// const imageReshaped = imageResized.reshape([1, 224, 224, 3]);
|
||||
// const imageNormalized = imageReshaped.div(255);
|
||||
//
|
||||
// // 进行推理
|
||||
// const prediction: any = model.predict(imageNormalized);
|
||||
//
|
||||
//
|
||||
// const predictedClass = tf.argMax(prediction, 1).dataSync()[0];
|
||||
// // const predictedClassConfidence = await prediction.dataSync()[predictedClass].toFixed(2);
|
||||
// // console.log(`预测结果: ${predictedClassName}(${predictedClassConfidence})`);
|
||||
// return ['Anime', 'Furry', 'Neutral'][predictedClass];
|
||||
// }
|
42
src/utils/tfjs/face_detection.ts
Normal file
42
src/utils/tfjs/face_detection.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import '@mediapipe/face_detection';
|
||||
import '@tensorflow/tfjs-core';
|
||||
// Register WebGL backend.
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
import * as faceDetection from '@tensorflow-models/face-detection';
|
||||
import * as faceLandmarksDetection from '@tensorflow-models/face-landmarks-detection';
|
||||
import '@mediapipe/face_mesh';
|
||||
|
||||
/**
|
||||
* 检测人脸
|
||||
* @param image
|
||||
*/
|
||||
export async function detectFaces(image: HTMLImageElement) {
|
||||
const model = faceDetection.SupportedModels.MediaPipeFaceDetector;
|
||||
const detectorConfig: any = {
|
||||
runtime: 'tfjs',
|
||||
maxFaces: 1,
|
||||
modelType: 'short', //'short'|'full'
|
||||
};
|
||||
const detector = await faceDetection.createDetector(model, detectorConfig);
|
||||
const estimationConfig = {flipHorizontal: false};
|
||||
const faces = await detector.estimateFaces(image, estimationConfig);
|
||||
return faces;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检测人脸特征点
|
||||
* @param image
|
||||
*/
|
||||
export async function detectionFaceLandmarks(image: HTMLImageElement) {
|
||||
const model = faceLandmarksDetection.SupportedModels.MediaPipeFaceMesh;
|
||||
const detectorConfig: any = {
|
||||
runtime: 'tfjs',
|
||||
maxFaces: 1,
|
||||
refineLandmarks: false,
|
||||
};
|
||||
const detector = await faceLandmarksDetection.createDetector(model, detectorConfig);
|
||||
const estimationConfig = {flipHorizontal: false};
|
||||
const faces = await detector.estimateFaces(image, estimationConfig);
|
||||
return faces;
|
||||
|
||||
}
|
28
src/utils/tfjs/face_extraction.ts
Normal file
28
src/utils/tfjs/face_extraction.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import * as faceapi from '@vladmandic/face-api/dist/face-api.esm-nobundle.js';
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
|
||||
export async function loadFaceExtractorModel() {
|
||||
tf.setBackend('webgl'); // set webgl backend
|
||||
// 模型文件访问路径
|
||||
const modelsPath = `/tfjs/face_api/model/ssd_mobilenetv1_model-weights_manifest.json`;
|
||||
// 模型参数-ssdMobilenetv1
|
||||
await faceapi.nets.ssdMobilenetv1.load(modelsPath);
|
||||
return new faceapi.SsdMobilenetv1Options({
|
||||
minConfidence: 0.5, // 0 ~ 1
|
||||
maxResults: 50, // 0 ~ 100
|
||||
});
|
||||
}
|
||||
|
||||
export async function fnDetectFace(img: HTMLImageElement) {
|
||||
const options = await loadFaceExtractorModel();
|
||||
const detections = await faceapi.detectSingleFace(
|
||||
img,
|
||||
options,
|
||||
);
|
||||
if (!detections) {
|
||||
return null;
|
||||
}
|
||||
const faceImages = await faceapi.extractFaces(img, [detections]);
|
||||
return faceImages[0].toDataURL('image/png');
|
||||
}
|
54
src/utils/tfjs/landscape_recognition.ts
Normal file
54
src/utils/tfjs/landscape_recognition.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
|
||||
export async function loadLandscapeRecognitionModel() {
|
||||
const modelName = 'landscape_recognition';
|
||||
const modelUrl = '/tfjs/landscape_recognition/model.json';
|
||||
let model: tf.LayersModel;
|
||||
tf.setBackend('webgl');
|
||||
try {
|
||||
// 尝试从 IndexedDB 加载模型
|
||||
model = await tf.loadLayersModel(`indexeddb://${modelName}-model`);
|
||||
console.log("模型成功从 IndexedDB 加载");
|
||||
} catch (_error) {
|
||||
console.log("从 URL 下载模型...");
|
||||
// 如果 IndexedDB 中没有模型,则从 URL 加载模型
|
||||
model = await tf.loadLayersModel(modelUrl);
|
||||
await model.save(`indexeddb://${modelName}-model`);
|
||||
console.log("模型已从 URL 下载并保存到 IndexedDB");
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
export const predictLandscape = async (imgElement) => {
|
||||
if (!imgElement) return;
|
||||
const model = await loadLandscapeRecognitionModel();
|
||||
const img = tf.cast(tf.browser.fromPixels(imgElement), 'float32').resizeBilinear([150, 150]);
|
||||
|
||||
const offset = tf.scalar(127.5);
|
||||
const normalized = img.sub(offset).div(offset);
|
||||
const batched = normalized.reshape([1, 150, 150, 3]);
|
||||
|
||||
const results: any = model.predict(batched);
|
||||
return getCategory(results.dataSync().indexOf(results.max().dataSync()[0]));
|
||||
};
|
||||
|
||||
const getCategory = (index: number) => {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return "building";
|
||||
case 1:
|
||||
return "forest";
|
||||
case 2:
|
||||
return "glacier";
|
||||
case 3:
|
||||
return "mountain";
|
||||
case 4:
|
||||
return "sea";
|
||||
case 5:
|
||||
return "street";
|
||||
default:
|
||||
return "none";
|
||||
}
|
||||
};
|
@@ -1,6 +1,7 @@
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
import * as mobilenet from '@tensorflow-models/mobilenet';
|
||||
import * as cocoSsd from '@tensorflow-models/coco-ssd';
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
|
||||
// 确保 TensorFlow.js 已准备好并设置后端
|
||||
async function initializeTensorFlow(backend = "webgl") {
|
||||
@@ -49,36 +50,43 @@ export async function loadMobileNet(image) {
|
||||
return await model.classify(image, 3);
|
||||
}
|
||||
|
||||
// 加载 COCO SSD 模型的工具函数
|
||||
export async function loadCocoSsd(image) {
|
||||
// 工具函数:加载或缓存模型
|
||||
export async function loadCocoSsdModel() {
|
||||
const modelName = "cocoSsd-model";
|
||||
const modelUrl = '/tfjs/mobilenet/ssd-mobilenet-v2-tfjs-default-v1/model.json';
|
||||
|
||||
// 初始化 TensorFlow.js
|
||||
if (!(await initializeTensorFlow())) {
|
||||
return;
|
||||
}
|
||||
|
||||
let model;
|
||||
|
||||
try {
|
||||
// 尝试从 IndexedDB 加载模型
|
||||
model = await cocoSsd.load({
|
||||
base: 'mobilenet_v2',
|
||||
modelUrl: `indexeddb://${modelName}`,
|
||||
});
|
||||
console.log("COCO SSD model loaded from IndexedDB successfully");
|
||||
console.log(`${modelName} loaded from IndexedDB successfully`);
|
||||
} catch (_error) {
|
||||
console.log("Downloading COCO SSD model...");
|
||||
console.log(`Downloading ${modelName}...`);
|
||||
// 如果 IndexedDB 中没有模型则从 URL 加载并保存到 IndexedDB
|
||||
model = await cocoSsd.load({
|
||||
base: 'mobilenet_v2',
|
||||
modelUrl: modelUrl,
|
||||
});
|
||||
const Model = await tf.loadGraphModel(modelUrl);
|
||||
await Model.save(`indexeddb://${modelName}`);
|
||||
console.log("COCO SSD model downloaded and saved to IndexedDB");
|
||||
const graphModel = await tf.loadGraphModel(modelUrl);
|
||||
await graphModel.save(`indexeddb://${modelName}`);
|
||||
console.log(`${modelName} downloaded and saved to IndexedDB`);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
// 加载 COCO SSD 模型的工具函数
|
||||
// 使用提取的加载模型工具函数
|
||||
export async function cocoSsdPredict(image) {
|
||||
// 初始化 TensorFlow.js
|
||||
tf.setBackend('webgl');
|
||||
if (!(await initializeTensorFlow())) {
|
||||
return [];
|
||||
}
|
||||
// 加载模型
|
||||
const model = await loadCocoSsdModel();
|
||||
// 使用模型进行检测
|
||||
return await model.detect(image);
|
||||
}
|
||||
|
Reference in New Issue
Block a user