264 lines
8.1 KiB
JavaScript
264 lines
8.1 KiB
JavaScript
// ---- Module-level constants and state shared by every function in this file ----

// Internal name of the background-noise class; it is treated as the first class.
const BACKGROUND_NOISE_LABEL = '_background_noise_';

// Position of the background-noise entry in the local `labels` array only;
// it is not passed to collectExample.
const BACKGROUND_NOISE_INDEX = 0;

// Recording length of each sample, in milliseconds.
const recordDuration = 1000;

// User-defined class labels (the background-noise label included).
const labels = [];

// Base SpeechCommands recognizer.
let recognizer;

// Recognizer used for transfer learning.
let transferRecognizer;

// True while live prediction is running.
let isPredicting = false;

// True while samples are being recorded; guards against repeated clicks.
let isRecording = false;

// Manually maintained "the model has been trained" flag.
let isModelTrainedFlag = false;

// Handle stored by startPrediction for stopping the listener.
let predictionStopFunction = null;
|
||
|
||
/**
 * Initialise the classifier: create the base recognizer, load its model and
 * create the transfer-learning recognizer named 'my-custom-model'.
 *
 * Calling this again replaces both recognizers. The background-noise label is
 * registered only once, and the "trained" flag is cleared because a freshly
 * created transfer recognizer has not been trained by this module.
 *
 * @returns {Promise<void>} Resolves once the transfer recognizer exists.
 *   Rejects with whatever error model creation or loading raises.
 */
async function init() {
  // 'BROWSER_FFT' selects the recognizer variant that relies on the browser's
  // own FFT processing.
  recognizer = speechCommands.create('BROWSER_FFT');

  await recognizer.ensureModelLoaded();
  transferRecognizer = recognizer.createTransfer('my-custom-model');

  // The new transfer recognizer starts untrained.
  isModelTrainedFlag = false;

  // Register the background-noise label only after the transfer recognizer
  // was created, and never twice when init() is called again.
  if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
    labels.push(BACKGROUND_NOISE_LABEL);
  }
}
|
||
|
||
/**
 * Record several samples in a row for one label.
 *
 * Only one recording session may run at a time; the `isRecording` flag is
 * always cleared when the session ends, whether it succeeded or failed.
 *
 * @param {string} label - Label the samples are collected for.
 * @param {number} [countToRecord=5] - Number of samples to record.
 * @returns {Promise<void>} Resolves after the last sample was collected.
 *   Rejects if a session is already running or if collecting a sample fails.
 */
async function recordMultipleExamples(label, countToRecord = 5) {
  if (isRecording) {
    throw new Error('正在录制中,请等待当前录音完成');
  }

  isRecording = true;
  try {
    // Short pause between two recordings so consecutive samples are separated.
    const pauseBetweenSamplesMs = Math.max(200, recordDuration / 5);

    for (let i = 0; i < countToRecord; i++) {
      // NOTE(review): confirm that `amplitudeRequired` and `durationMillis` are
      // option names collectExample() actually honours — TODO verify against
      // the speech-commands ExampleCollectionOptions documentation.
      await transferRecognizer.collectExample(label, {
        amplitudeRequired: true,
        durationMillis: recordDuration,
      });

      if (i < countToRecord - 1) {
        await new Promise((resolve) => setTimeout(resolve, pauseBetweenSamplesMs));
      }
    }
  } finally {
    // Single place that releases the guard, on success and on failure alike.
    isRecording = false;
  }
}
|
||
|
||
/**
 * Register a new custom category in the local `labels` array.
 *
 * Every failure is reported through the returned promise; the function never
 * throws synchronously, even for non-string input.
 *
 * @param {string} categoryName - Name of the category to add. It is stored
 *   exactly as given.
 * @returns {Promise<void>} Resolves once the name was added. Rejects when the
 *   name is missing, not a string, blank, or already present (the duplicate
 *   check ignores letter case).
 */
function addCustomCategory(categoryName) {
  // Non-strings and whitespace-only names are treated like an empty name.
  if (typeof categoryName !== 'string' || categoryName.trim() === '') {
    return Promise.reject(new Error('类别名称不能为空'));
  }

  // Case-insensitive duplicate check against every known label.
  const wanted = categoryName.toLowerCase();
  if (labels.some((label) => label.toLowerCase() === wanted)) {
    return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`));
  }

  labels.push(categoryName);
  return Promise.resolve();
}
|
||
|
||
/**
 * Tell whether enough data exists to start training.
 *
 * Training is considered possible when the background-noise class has at
 * least one sample and at least one custom class in `labels` has at least one
 * sample.
 *
 * @returns {boolean} true when training can start; false otherwise, including
 *   when the transfer recognizer does not exist yet or the sample counts
 *   cannot be read.
 */
function checkTrainingReadiness() {
  if (!transferRecognizer) {
    return false;
  }

  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    // countExamples() can throw (speech-commands does so while no example has
    // been collected); for a readiness probe that simply means "not ready".
    return false;
  }

  const hasSamples = (label) => (exampleCounts[label] || 0) > 0;

  if (!hasSamples(BACKGROUND_NOISE_LABEL)) {
    return false;
  }

  // Custom classes are identified by name rather than by array position.
  return labels.some((label) => label !== BACKGROUND_NOISE_LABEL && hasSamples(label));
}
|
||
|
||
/**
 * Train the transfer-learning model on the collected samples.
 *
 * Before training, the samples of every label in `labels` are tallied and the
 * call is rejected when nothing was collected, when fewer than two classes
 * have samples, or when a class that has samples has fewer than five.
 *
 * @param {Object} [trainingConfig={}] - Overrides merged over the defaults
 *   (epochs 50, batchSize 16, validationSplit 0.1, shuffle true,
 *   yieldEvery 'epoch') and passed to transferRecognizer.train().
 * @returns {Promise<void>} Resolves when training finished; the trained flag
 *   is then true. Rejects on failed validation or when training fails; a
 *   training failure sets the trained flag to false.
 */
async function trainModel(trainingConfig = {}) {
  const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;

  let exampleCounts = {};
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    // A missing recognizer is a usage error and must surface unchanged.
    if (transferRecognizer == null) {
      throw error;
    }
    // Otherwise countExamples() failed on its own (speech-commands throws
    // while no example has been collected): treat it as an empty tally so the
    // validation below reports it with this module's own message.
  }

  let totalExamples = 0;
  let validClasses = 0;
  let allClassesHaveEnoughSamples = true;

  // Tally the samples per known label and check the per-class minimum.
  for (const labelName of labels) {
    const count = exampleCounts[labelName] || 0;
    if (count > 0) {
      totalExamples += count;
      validClasses++;
      if (count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
        allClassesHaveEnoughSamples = false;
      }
    }
  }

  // Checked first: after the class-count check this case could never be
  // reached, so the dedicated message was dead code.
  if (totalExamples === 0) {
    throw new Error('没有收集到任何训练样本');
  }

  if (validClasses < 2) {
    throw new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`);
  }

  if (!allClassesHaveEnoughSamples) {
    throw new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`);
  }

  // Merge into a fresh object so the defaults themselves are never mutated.
  const config = Object.assign(
    {
      epochs: 50,
      batchSize: 16,
      validationSplit: 0.1,
      shuffle: true,
      yieldEvery: 'epoch',
    },
    trainingConfig
  );

  try {
    await transferRecognizer.train(config);
  } catch (error) {
    isModelTrainedFlag = false;
    throw error;
  }
  isModelTrainedFlag = true;
}
|
||
|
||
/**
 * Start live prediction with the trained transfer model.
 *
 * @param {Function} onPrediction - Called for every recognition result with
 *   `{ label, score, scores, labels }`. The internal background-noise label is
 *   replaced by its display name in `label` and `labels`.
 * @param {Object} [listenOptions={}] - Overrides merged over the default
 *   listen options (includeEmbedding true, probabilityThreshold 0.75,
 *   suppressionTimeMillis 300, overlapFactor 0.5).
 * @returns {Promise<Function>} Resolves with a function that stops listening.
 *   Rejects when prediction is already running, when the model is not trained,
 *   or when listen() fails; in the last case the running flag is cleared so a
 *   later call can retry.
 */
async function startPrediction(onPrediction, listenOptions = {}) {
  if (isPredicting) {
    throw new Error('识别已经在进行中');
  }

  if (!isModelTrainedFlag) {
    throw new Error('模型尚未训练完成');
  }

  const options = Object.assign(
    {
      includeEmbedding: true,
      probabilityThreshold: 0.75,
      suppressionTimeMillis: 300,
      overlapFactor: 0.50,
    },
    listenOptions
  );

  // Pin the recognizer that is actually listening, so the callback and the
  // stop function keep addressing the same instance.
  const activeRecognizer = transferRecognizer;

  // Show the internal background-noise label under its display name.
  const toDisplayLabel = (label) => (label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label);

  isPredicting = true;
  try {
    const listenResult = await activeRecognizer.listen((result) => {
      // Results arriving after stopPrediction() are dropped.
      if (!isPredicting) return;

      const classLabels = activeRecognizer.wordLabels();
      const scores = result.scores;
      const maxScore = Math.max(...scores);
      const predictedIndex = scores.indexOf(maxScore);

      if (typeof onPrediction === 'function') {
        onPrediction({
          label: toDisplayLabel(classLabels[predictedIndex]),
          score: maxScore,
          scores: scores,
          labels: classLabels.map(toDisplayLabel),
        });
      }
    }, options);

    // speech-commands' listen() resolves without a value; in that case the
    // stop handle has to call stopListening() on the recognizer. A function
    // resolved by listen() is still used as-is.
    predictionStopFunction = typeof listenResult === 'function'
      ? listenResult
      : () => activeRecognizer.stopListening();
  } catch (error) {
    // Do not stay stuck in the "predicting" state when listening never started.
    isPredicting = false;
    predictionStopFunction = null;
    throw error;
  }

  return predictionStopFunction;
}
|
||
|
||
/**
 * Stop live prediction. Does nothing when prediction is not running.
 *
 * The stored stop function is used when there is one. Otherwise, if the
 * transfer recognizer reports that it is still listening, its stopListening()
 * is called so the audio stream is really shut down.
 *
 * @returns {Promise<void>} Resolves once the stop request has completed. It
 *   never rejects: a failure while stopping is logged with console.error.
 */
function stopPrediction() {
  if (!isPredicting) {
    return Promise.resolve();
  }

  const stop = predictionStopFunction;
  // Reset the state first so it stays consistent even if stopping fails.
  predictionStopFunction = null;
  isPredicting = false;

  let pending;
  try {
    if (typeof stop === 'function') {
      pending = stop();
    } else if (
      transferRecognizer &&
      typeof transferRecognizer.isListening === 'function' &&
      transferRecognizer.isListening()
    ) {
      // No stop handle was stored: stop the recognizer directly.
      pending = transferRecognizer.stopListening();
    }
  } catch (error) {
    pending = Promise.reject(error);
  }

  return Promise.resolve(pending).then(
    () => undefined,
    (error) => {
      console.error('Failed to stop listening:', error);
    }
  );
}
|
||
|
||
/**
 * Return how many samples were collected for each class.
 *
 * @returns {Object} Map from label to sample count as reported by
 *   transferRecognizer.countExamples(). An empty object is returned when the
 *   transfer recognizer does not exist yet or when the counts cannot be read.
 */
function getExampleCounts() {
  if (!transferRecognizer) {
    return {};
  }

  try {
    return transferRecognizer.countExamples();
  } catch (error) {
    // countExamples() can throw (speech-commands does so while no example has
    // been collected); callers of this getter expect a plain tally instead.
    return {};
  }
}
|
||
|
||
/**
 * Report the manually maintained training state.
 *
 * @returns {boolean} Current value of the module's trained flag.
 */
function isModelTrained() {
  const trained = isModelTrainedFlag;
  return trained;
}
|
||
|
||
// Public interface, published on the global `window` object.
window.AudioClassifier = {
  // Setup
  init: init,

  // Data collection
  addCustomCategory: addCustomCategory,
  recordMultipleExamples: recordMultipleExamples,
  getExampleCounts: getExampleCounts,

  // Training
  checkTrainingReadiness: checkTrainingReadiness,
  trainModel: trainModel,
  isModelTrained: isModelTrained,

  // Live prediction
  startPrediction: startPrediction,
  stopPrediction: stopPrediction,

  // The live `labels` array itself (not a copy).
  labels: labels,
};