// 全局变量和模型实例 let recognizer; // 基础的 SpeechCommands recognizer let transferRecognizer; // 用于迁移学习的 recognizer const labels = []; // 用户定义的类别标签数组 (包括背景噪音) // 将背景噪音定义为第一个类别,其内部名称为 _background_noise_ const BACKGROUND_NOISE_LABEL = '_background_noise_'; const BACKGROUND_NOISE_INDEX = 0; // 仅用于本地 labels 数组索引,不直接用于collectExample let isPredicting = false; // 预测状态标志 let isRecording = false; // 录音状态标志,防止重复点击 const recordDuration = 1000; // 每个样本的录音时长 (毫秒) let isModelTrainedFlag = false; // 手动维护模型训练状态 let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数 /** * 初始化函数 - 加载模型和创建迁移学习模型 * @returns {Promise} */ async function init() { try { recognizer = speechCommands.create( 'BROWSER_FFT' // 使用浏览器内置的 FFT 处理,性能更好 ); await recognizer.ensureModelLoaded(); transferRecognizer = recognizer.createTransfer('my-custom-model'); // 只有在 transferRecognizer 创建成功后,才将背景噪音标签加入我们的 local labels 数组 labels.push(BACKGROUND_NOISE_LABEL); return Promise.resolve(); } catch (error) { return Promise.reject(error); } } /** * 批量录制样本的通用函数 * @param {string} label - 标签名称 * @param {number} countToRecord - 要录制的样本数量 * @returns {Promise} */ async function recordMultipleExamples(label, countToRecord = 5) { if (isRecording) { return Promise.reject(new Error('正在录制中,请等待当前录音完成')); } isRecording = true; for (let i = 0; i < countToRecord; i++) { try { await transferRecognizer.collectExample( label, { amplitudeRequired: true, durationMillis: recordDuration } ); // 在每次录音之间增加短暂延迟,以便更好地分离样本 if (i < countToRecord - 1) { await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5))); } } catch (error) { isRecording = false; return Promise.reject(error); } } isRecording = false; return Promise.resolve(); } /** * 添加自定义类别 * @param {string} categoryName - 类别名称 */ function addCustomCategory(categoryName) { if (!categoryName) { return Promise.reject(new Error('类别名称不能为空')); } // 检查是否与现有标签重复 if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) { return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`)); } // 将标签添加到本地数组 labels.push(categoryName); return Promise.resolve(); } /** * 检查训练就绪状态 * @returns {boolean} 是否可以开始训练 */ function checkTrainingReadiness() { const exampleCounts = transferRecognizer.countExamples(); let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0; let customCategoriesReady = 0; // 遍历本地 labels 数组,检查每个自定义类别是否有样本 for (let i = 1; i < labels.length; i++) { // 从索引 1 开始,因为 0 是背景噪音 const customLabel = labels[i]; if ((exampleCounts[customLabel] || 0) > 0) { customCategoriesReady++; } } // 必须有背景噪音样本,并且至少一个自定义类别有样本 return backgroundNoiseReady && customCategoriesReady >= 1; } /** * 模型训练函数 * @param {Object} trainingConfig - 训练配置参数 * @returns {Promise} */ async function trainModel(trainingConfig = {}) { const exampleCounts = transferRecognizer.countExamples(); let totalExamples = 0; let validClasses = 0; const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5; let allClassesHaveEnoughSamples = true; // 统计所有类别的有效样本数,并检查每个类别是否达到最低要求 for (const labelName of labels) { if (exampleCounts[labelName] && exampleCounts[labelName] > 0) { totalExamples += exampleCounts[labelName]; validClasses++; if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) { allClassesHaveEnoughSamples = false; } } } if (validClasses < 2) { return Promise.reject(new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`)); } if (!allClassesHaveEnoughSamples) { return Promise.reject(new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`)); } if (totalExamples === 0) { return Promise.reject(new Error('没有收集到任何训练样本')); } const defaultConfig = { epochs: 50, batchSize: 16, validationSplit: 0.1, shuffle: true, yieldEvery: 'epoch' }; const config = Object.assign(defaultConfig, trainingConfig); try { await transferRecognizer.train(config); isModelTrainedFlag = true; return Promise.resolve(); } catch (error) { isModelTrainedFlag = false; return Promise.reject(error); } } /** * 开始实时预测 * @param {Function} onPrediction - 预测结果回调函数 * @param {Object} listenOptions - 监听选项 * @returns {Promise} 停止预测的函数 */ async function startPrediction(onPrediction, listenOptions = {}) { if (isPredicting) { return Promise.reject(new Error('识别已经在进行中')); } if (!isModelTrainedFlag) { return Promise.reject(new Error('模型尚未训练完成')); } isPredicting = true; const defaultOptions = { includeEmbedding: true, probabilityThreshold: 0.75, suppressionTimeMillis: 300, overlapFactor: 0.50, }; const options = Object.assign(defaultOptions, listenOptions); predictionStopFunction = await transferRecognizer.listen(result => { if (!isPredicting) return; const classLabels = transferRecognizer.wordLabels(); const scores = result.scores; const maxScore = Math.max(...scores); const predictedIndex = scores.indexOf(maxScore); let predictedLabel = classLabels[predictedIndex]; // 如果预测结果是内部的背景噪音标签,转换成用户友好的显示 if (predictedLabel === BACKGROUND_NOISE_LABEL) { predictedLabel = '背景噪音'; } // 调用回调函数返回预测结果 if (typeof onPrediction === 'function') { onPrediction({ label: predictedLabel, score: maxScore, scores: scores, labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) }); } }, options); return Promise.resolve(predictionStopFunction); } /** * 停止实时预测 */ function stopPrediction() { if (isPredicting) { if (typeof predictionStopFunction === 'function') { predictionStopFunction(); predictionStopFunction = null; } isPredicting = false; } } /** * 获取各类别样本数量 * @returns {Object} 各类别样本数量统计 */ function getExampleCounts() { return transferRecognizer.countExamples(); } /** * 获取模型是否已训练的状态 * @returns {boolean} 模型是否已训练 */ function isModelTrained() { return isModelTrainedFlag; } // 导出公共接口 window.AudioClassifier = { init, recordMultipleExamples, addCustomCategory, checkTrainingReadiness, trainModel, startPrediction, stopPrediction, getExampleCounts, isModelTrained, labels };