add nsfw

This commit is contained in:
landaiqing
2024-10-02 18:02:36 +08:00
parent 40f2f0b2de
commit af1f57018b
19 changed files with 353 additions and 9 deletions

51
src/utils/nsfw/nsfw.ts Normal file
View File

@@ -0,0 +1,51 @@
import * as nsfwjs from "nsfwjs";
import {NSFWJS} from "nsfwjs";
import * as tf from "@tensorflow/tfjs";
/**
* Initializes the NSFWJS model and returns it.
*/
let isInit: boolean = false;
const initNSFWJs = async (): Promise<NSFWJS> => {
tf.enableProdMode();
if (!isInit) {
const initialLoad = await nsfwjs.load("/nsfw/model/mobilenet_v2_mid/model.json", {size: 224, type: "graph"});
await initialLoad.model.save("indexeddb://nsfwjs-model");
isInit = true;
}
return await nsfwjs.load("indexeddb://nsfwjs-model", {size: 224, type: "graph"});
};
/**
* Predicts the NSFW score of an image using the NSFWJS model.
* @param model
* @param image
*/
const predictNSFW = async (model: NSFWJS, image: tf.Tensor3D | ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement): Promise<boolean> => {
const predictions = await model.classify(image, 5);
console.log(predictions);
// 定义阈值与对应的类别
const thresholds = {
'Porn': 0.5,
'Hentai': 0.3,
'Sexy': 0.5
};
// 使用一个变量来确定是否为色情内容
let isNSFW: boolean = false;
// 遍历预测结果,并检查是否满足阈值
for (const prediction of predictions) {
const className = prediction.className;
const probability = prediction.probability;
// 检查预测类别是否在阈值对象中
if (thresholds[className] !== undefined && probability >= thresholds[className]) {
isNSFW = true;
break; // 早期退出,如果满足任一条件
}
}
return isNSFW; // 返回是否为色情图片
};
export {initNSFWJs, predictNSFW};