Fix crash at startup if TensorFlow is not supported (#8984)
* Lazy loading tensorflow
* CHANGELOG
* Check CPU flags
parent 1557d0afb8
commit 660781afd9
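In short, the change gates the TensorFlow-backed model behind a runtime CPU-capability check and a lazy dynamic import. A minimal standalone sketch of that pattern, for orientation only: the flag list and the `systeminformation` / `@tensorflow/tfjs-node` calls mirror the diff below, while the helper name `loadTensorFlowIfSupported` is illustrative and not part of the commit.

```ts
import si from 'systeminformation';

const REQUIRED_CPU_FLAGS = ['avx2', 'fma'];

// Import @tensorflow/tfjs-node only when the CPU advertises the flags its
// prebuilt native binary needs; return null instead of crashing otherwise.
async function loadTensorFlowIfSupported(): Promise<typeof import('@tensorflow/tfjs-node') | null> {
	const flags = (await si.cpuFlags()).split(/\s+/);
	if (!REQUIRED_CPU_FLAGS.every(required => flags.includes(required))) return null;
	return await import('@tensorflow/tfjs-node');
}
```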
CHANGELOG.md
@@ -9,6 +9,13 @@
 You should also include the user name that made the change.
 -->
 
+## 12.x.x (unreleased)
+
+### Improvements
+
+### Bugfixes
+- Server: Fix crash at startup if TensorFlow is not supported @mei23
+
 ## 12.112.3 (2022/07/09)
 
 ### Improvements
@@ -2,19 +2,34 @@ import * as fs from 'node:fs';
 import { fileURLToPath } from 'node:url';
 import { dirname } from 'node:path';
 import * as nsfw from 'nsfwjs';
-import * as tf from '@tensorflow/tfjs-node';
+import si from 'systeminformation';
 
 const _filename = fileURLToPath(import.meta.url);
 const _dirname = dirname(_filename);
 
+const REQUIRED_CPU_FLAGS = ['avx2', 'fma'];
+let isSupportedCpu: undefined | boolean = undefined;
+
 let model: nsfw.NSFWJS;
 
 export async function detectSensitive(path: string): Promise<nsfw.predictionType[] | null> {
 	try {
+		if (isSupportedCpu === undefined) {
+			const cpuFlags = await getCpuFlags();
+			isSupportedCpu = REQUIRED_CPU_FLAGS.every(required => cpuFlags.includes(required));
+		}
+
+		if (!isSupportedCpu) {
+			console.error('This CPU cannot use TensorFlow.');
+			return null;
+		}
+
+		const tf = await import('@tensorflow/tfjs-node');
+
 		if (model == null) model = await nsfw.load(`file://${_dirname}/../../nsfw-model/`, { size: 299 });
 
 		const buffer = await fs.promises.readFile(path);
-		const image = await tf.node.decodeImage(buffer, 3) as tf.Tensor3D;
+		const image = await tf.node.decodeImage(buffer, 3) as any;
 		try {
 			const predictions = await model.classify(image);
 			return predictions;
@@ -26,3 +41,8 @@ export async function detectSensitive(path: string): Promise<nsfw.predictionType
 		return null;
 	}
 }
+
+async function getCpuFlags(): Promise<string[]> {
+	const str = await si.cpuFlags();
+	return str.split(/\s+/);
+}
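For callers, nothing changes except that `detectSensitive()` can now resolve to `null` on hosts where TensorFlow is unavailable instead of crashing the process at startup. A hedged usage sketch; the import path, the 0.75 threshold, and which classes to treat as sensitive are illustrative assumptions, not part of this commit (the class names come from the nsfwjs model):

```ts
import { detectSensitive } from './detect-sensitive.js'; // path is an assumption

async function isProbablySensitive(path: string): Promise<boolean> {
	const predictions = await detectSensitive(path);
	// null means detection is unavailable on this host (unsupported CPU or
	// TensorFlow failed to load); treat the file as not flagged rather than fail.
	if (predictions == null) return false;
	return predictions.some(p =>
		(p.className === 'Porn' || p.className === 'Hentai') && p.probability > 0.75);
}
```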