/*
 * mobileNet/音频分类/audioClassifier.js
 * 2025-08-14 14:35:44 +08:00
 *
 * 264 lines
 * 8.1 KiB
 * JavaScript
 * Raw Blame History
 *
 * This file contains ambiguous Unicode characters
 *
 * This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
 */

// Module-level state and model instances shared by every function in this file.
let recognizer; // Base SpeechCommands recognizer, assigned in init()
let transferRecognizer; // Transfer-learning recognizer derived from `recognizer` in init()
const labels = []; // User-defined class labels (including background noise); also exported on window.AudioClassifier
// Background noise is treated as the first class; its internal name is _background_noise_
const BACKGROUND_NOISE_LABEL = '_background_noise_';
const BACKGROUND_NOISE_INDEX = 0; // Position of the background-noise entry in the local labels array only; not passed to collectExample
let isPredicting = false; // True while live prediction is considered active
let isRecording = false; // True while a recording batch is running; blocks overlapping recordings
const recordDuration = 1000; // Recording length of each sample, in milliseconds
let isModelTrainedFlag = false; // Manually maintained "model has been trained" state
let predictionStopFunction = null; // Holds whatever startPrediction() got back from awaiting transferRecognizer.listen(); stopPrediction() calls it only if it is a function
/**
 * Initializes the classifier: creates the base speech-commands recognizer,
 * waits for its model to load, and creates the transfer-learning recognizer
 * named 'my-custom-model' from it.
 *
 * The function may be called more than once. Each call replaces both
 * recognizer instances, so the "trained" flag is cleared (the new transfer
 * recognizer has not been trained) and the background-noise label is
 * registered at most once.
 *
 * @returns {Promise<void>} Resolves when both recognizers are ready; rejects
 *   with the underlying error if creation or model loading fails.
 */
async function init() {
  // 'BROWSER_FFT' selects the browser's built-in FFT processing.
  recognizer = speechCommands.create('BROWSER_FFT');
  await recognizer.ensureModelLoaded();
  transferRecognizer = recognizer.createTransfer('my-custom-model');

  // A replaced transfer recognizer starts untrained.
  isModelTrainedFlag = false;

  // Register the background-noise label only after the transfer recognizer
  // exists, only once, and at index 0 because checkTrainingReadiness()
  // treats the first entry as background noise.
  if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
    labels.unshift(BACKGROUND_NOISE_LABEL);
  }
}
/**
 * Records a batch of audio examples for one label, one after another.
 *
 * Each example is collected through transferRecognizer.collectExample() with
 * `amplitudeRequired: true` and a duration of `recordDuration` milliseconds.
 * Between two consecutive examples the function pauses for
 * max(200, recordDuration / 5) milliseconds so the samples are separated.
 * Only one batch may run at a time, guarded by the `isRecording` flag.
 *
 * @param {string} label - Label the examples are stored under
 * @param {number} countToRecord - Number of examples to record (default 5)
 * @returns {Promise<void>} Resolves when all examples are recorded; rejects
 *   if a batch is already running or if collecting an example fails.
 */
async function recordMultipleExamples(label, countToRecord = 5) {
  if (isRecording) {
    throw new Error('正在录制中,请等待当前录音完成');
  }

  const pauseBetweenSamples = () =>
    new Promise((done) => setTimeout(done, Math.max(200, recordDuration / 5)));

  isRecording = true;
  try {
    for (let sampleIndex = 0; sampleIndex < countToRecord; sampleIndex++) {
      await transferRecognizer.collectExample(label, {
        amplitudeRequired: true,
        durationMillis: recordDuration,
      });

      // No pause is needed after the final sample.
      const isLastSample = !(sampleIndex < countToRecord - 1);
      if (!isLastSample) {
        await pauseBetweenSamples();
      }
    }
  } finally {
    // Release the guard on success and on failure alike.
    isRecording = false;
  }
}
/**
 * Registers a new user-defined category in the local `labels` array.
 *
 * The name is stored exactly as given. Validation failures are always
 * reported as a rejected promise, never as a synchronous throw, so callers
 * can rely on `.catch()` / `await` alone.
 *
 * @param {string} categoryName - Name of the category to add
 * @returns {Promise<void>} Resolves once the label is stored. Rejects with an
 *   Error when the name is missing, empty or whitespace-only, or when it
 *   matches an existing label case-insensitively; rejects with a TypeError
 *   when the name is not a string.
 */
async function addCustomCategory(categoryName) {
  if (!categoryName) {
    throw new Error('类别名称不能为空');
  }
  // A non-string would otherwise blow up inside toLowerCase() below.
  if (typeof categoryName !== 'string') {
    throw new TypeError('类别名称必须是字符串');
  }
  // A name made only of whitespace is as unusable as an empty one.
  if (categoryName.trim() === '') {
    throw new Error('类别名称不能为空');
  }

  // Duplicate check is case-insensitive.
  const wanted = categoryName.toLowerCase();
  if (labels.some((label) => label.toLowerCase() === wanted)) {
    throw new Error(`类别 "${categoryName}" 已经存在`);
  }

  labels.push(categoryName);
}
/**
 * Reports whether enough data exists to start training.
 *
 * Training is considered ready when the background-noise class has at least
 * one example and at least one other label from `labels` has at least one
 * example.
 *
 * This is a pure yes/no check and never throws: if the example counts cannot
 * be read, the answer is `false`. That covers the transfer recognizer not
 * being initialized yet, and speech-commands' countExamples() raising an
 * error while no example has been collected.
 *
 * @returns {boolean} True if training can start
 */
function checkTrainingReadiness() {
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    // No readable counts means nothing has been collected yet.
    return false;
  }
  if (!exampleCounts) {
    return false;
  }

  const hasExamples = (labelName) => (exampleCounts[labelName] || 0) > 0;

  // Select custom labels by name rather than by position so the result does
  // not depend on where the background-noise entry sits in the array.
  const customLabels = labels.filter((labelName) => labelName !== BACKGROUND_NOISE_LABEL);

  return hasExamples(BACKGROUND_NOISE_LABEL) && customLabels.some((labelName) => hasExamples(labelName));
}
/**
 * Validates the collected examples and trains the transfer-learning model.
 *
 * Only labels from `labels` that have at least one example take part in the
 * checks. Training is refused when:
 *   - no example has been collected at all,
 *   - fewer than two labels have examples, or
 *   - any label that has examples has fewer than 5 of them.
 *
 * On success `isModelTrainedFlag` is set to true; if training itself fails
 * the flag is set to false and the original error is propagated.
 *
 * @param {Object} trainingConfig - Overrides for the default training options
 *   (epochs 50, batchSize 16, validationSplit 0.1, shuffle true,
 *   yieldEvery 'epoch'); passed on to transferRecognizer.train()
 * @returns {Promise<void>} Resolves when training finishes; rejects with an
 *   Error describing the failed precondition or with the training error.
 */
async function trainModel(trainingConfig = {}) {
  const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;

  // speech-commands' countExamples() raises an error while nothing has been
  // collected (and the call fails outright before init()). Both cases mean
  // "no samples", so they are reported through the checks below rather than
  // as a raw library exception.
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples() || {};
  } catch (error) {
    exampleCounts = {};
  }

  // Counts of the labels that actually have examples.
  const sampledCounts = labels
    .map((labelName) => exampleCounts[labelName])
    .filter((count) => count > 0);
  const totalExamples = sampledCounts.reduce((sum, count) => sum + count, 0);

  // Checked first so the most specific message is reachable.
  if (totalExamples === 0) {
    throw new Error('没有收集到任何训练样本');
  }
  if (sampledCounts.length < 2) {
    throw new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${sampledCounts.length} 个有效类别。`);
  }
  if (sampledCounts.some((count) => count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING)) {
    throw new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`);
  }

  // Caller-supplied options win over the defaults.
  const config = Object.assign(
    {
      epochs: 50,
      batchSize: 16,
      validationSplit: 0.1,
      shuffle: true,
      yieldEvery: 'epoch',
    },
    trainingConfig
  );

  try {
    await transferRecognizer.train(config);
  } catch (error) {
    isModelTrainedFlag = false;
    throw error;
  }
  isModelTrainedFlag = true;
}
/**
 * Starts live prediction with the trained transfer-learning model.
 *
 * For every result delivered by transferRecognizer.listen() the highest score
 * is looked up, mapped to its label from transferRecognizer.wordLabels(), and
 * handed to `onPrediction` as `{ label, score, scores, labels }`. The internal
 * background-noise label is shown as '背景噪音' in `label` and `labels`.
 *
 * The returned function stops prediction. In the speech-commands API,
 * listen() resolves without a value and streaming is ended with
 * stopListening(), so the stop function is built here; if listen() does
 * resolve to a function, that function is used instead.
 *
 * @param {Function} onPrediction - Callback receiving each prediction result
 * @param {Object} listenOptions - Overrides for the default listen options
 *   (includeEmbedding true, probabilityThreshold 0.75,
 *   suppressionTimeMillis 300, overlapFactor 0.5)
 * @returns {Promise<Function>} Resolves to a function that stops prediction.
 *   Rejects if prediction is already running, if the model is not trained,
 *   or if listen() fails (the running flag is then cleared again).
 */
async function startPrediction(onPrediction, listenOptions = {}) {
  if (isPredicting) {
    throw new Error('识别已经在进行中');
  }
  if (!isModelTrainedFlag) {
    throw new Error('模型尚未训练完成');
  }
  isPredicting = true;

  const options = Object.assign(
    {
      includeEmbedding: true,
      probabilityThreshold: 0.75,
      suppressionTimeMillis: 300,
      overlapFactor: 0.50,
    },
    listenOptions
  );

  // Replace the internal background-noise name with its user-facing text.
  const toDisplayLabel = (label) => (label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label);

  const handleResult = (result) => {
    // Results that arrive after prediction was stopped are dropped.
    if (!isPredicting) return;

    const classLabels = transferRecognizer.wordLabels();
    const scores = result.scores;
    const maxScore = Math.max(...scores);
    const predictedIndex = scores.indexOf(maxScore);

    if (typeof onPrediction === 'function') {
      onPrediction({
        label: toDisplayLabel(classLabels[predictedIndex]),
        score: maxScore,
        scores: scores,
        labels: classLabels.map(toDisplayLabel),
      });
    }
  };

  let nativeStop;
  try {
    nativeStop = await transferRecognizer.listen(handleResult, options);
  } catch (error) {
    // Without this reset a failed start (e.g. microphone unavailable) would
    // leave the module permanently reporting "already running".
    isPredicting = false;
    predictionStopFunction = null;
    throw error;
  }

  const stop = () => {
    isPredicting = false;
    if (predictionStopFunction === stop) {
      predictionStopFunction = null;
    }
    if (typeof nativeStop === 'function') {
      return nativeStop();
    }
    // Skip stopListening() when streaming has already ended, so calling the
    // stop function twice is harmless.
    if (typeof transferRecognizer.isListening === 'function' && !transferRecognizer.isListening()) {
      return undefined;
    }
    return transferRecognizer.stopListening();
  };

  predictionStopFunction = stop;
  return stop;
}
/**
 * Stops live prediction. Does nothing when prediction is not running.
 *
 * If a stop function is stored in `predictionStopFunction` it is called.
 * Otherwise the transfer recognizer is asked to stop directly through
 * stopListening() when it reports that it is still listening; without this
 * fallback the audio stream would keep running whenever no stop function
 * was stored.
 *
 * Stopping may complete asynchronously. A failure is logged to the console
 * rather than thrown, so this function stays synchronous.
 *
 * @returns {void}
 */
function stopPrediction() {
  if (!isPredicting) {
    return;
  }

  // Clear the state first so a failing stop cannot leave the module stuck.
  isPredicting = false;
  const storedStop = predictionStopFunction;
  predictionStopFunction = null;

  let pending;
  if (typeof storedStop === 'function') {
    pending = storedStop();
  } else if (
    transferRecognizer &&
    typeof transferRecognizer.isListening === 'function' &&
    typeof transferRecognizer.stopListening === 'function' &&
    transferRecognizer.isListening()
  ) {
    pending = transferRecognizer.stopListening();
  }

  // Avoid an unhandled rejection if the asynchronous stop fails.
  if (pending && typeof pending.catch === 'function') {
    pending.catch((error) => console.error('Failed to stop listening:', error));
  }
}
/**
 * Returns how many examples have been collected per label.
 *
 * The counts come from transferRecognizer.countExamples(). When they cannot
 * be read, an empty object is returned instead of an exception, so callers
 * can render "0 examples" without special-casing the initial state. That
 * covers the recognizer not being initialized, and speech-commands raising
 * an error while no example has been collected.
 *
 * @returns {Object} Map from label name to example count; `{}` when nothing
 *   has been collected yet
 */
function getExampleCounts() {
  try {
    return transferRecognizer.countExamples() || {};
  } catch (error) {
    // Deliberately treated as "no examples yet" rather than as a failure.
    return {};
  }
}
/**
 * Reports whether the model has been trained.
 * Returns the module-level `isModelTrainedFlag`, which trainModel() sets to
 * true after a successful training run and to false when training fails.
 * @returns {boolean} Current value of the trained flag
 */
function isModelTrained() {
return isModelTrainedFlag;
}
// Public interface: exposes the classifier functions on the global `window`
// object as `window.AudioClassifier`. `labels` is the same array instance the
// functions above read and modify, not a copy.
window.AudioClassifier = {
init,
recordMultipleExamples,
addCustomCategory,
checkTrainingReadiness,
trainModel,
startPrediction,
stopPrediction,
getExampleCounts,
isModelTrained,
labels
};