diff --git a/音频分类/README.md b/音频分类/README.md index 3c9d69d..2b52004 100644 --- a/音频分类/README.md +++ b/音频分类/README.md @@ -33,6 +33,27 @@ **注:直接打开index.html文件会需要重复授权麦克风权限,请使用live server插件开启本地服务器可以解决** +## 提取主要功能 +audioClassifier.js文件中主要实现了以下功能: + +* 录制背景噪音样本:通过录制音频文件,将其转换成频谱图,并传入模型进行训练。 +* 训练模型:将背景噪音样本和自定义声音样本合并,并训练模型。 +* 实时识别:实时从麦克风输入音频,将其转换成频谱图,并传入模型进行识别。 + +## 目录结构 + +``` +. +├── README.md +├── script.js +├── voice.html +├── audioClassifier.js +├── speech-commands(js文件仓库,不需要关注) +│ └── ... +└── +``` + + ## 音频切片 diff --git a/音频分类/audioClassifier.js b/音频分类/audioClassifier.js new file mode 100644 index 0000000..e716d28 --- /dev/null +++ b/音频分类/audioClassifier.js @@ -0,0 +1,264 @@ +// 全局变量和模型实例 +let recognizer; // 基础的 SpeechCommands recognizer +let transferRecognizer; // 用于迁移学习的 recognizer +const labels = []; // 用户定义的类别标签数组 (包括背景噪音) +// 将背景噪音定义为第一个类别,其内部名称为 _background_noise_ +const BACKGROUND_NOISE_LABEL = '_background_noise_'; +const BACKGROUND_NOISE_INDEX = 0; // 仅用于本地 labels 数组索引,不直接用于collectExample + +let isPredicting = false; // 预测状态标志 +let isRecording = false; // 录音状态标志,防止重复点击 +const recordDuration = 1000; // 每个样本的录音时长 (毫秒) +let isModelTrainedFlag = false; // 手动维护模型训练状态 +let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数 + +/** + * 初始化函数 - 加载模型和创建迁移学习模型 + * @returns {Promise} + */ +async function init() { + try { + recognizer = speechCommands.create( + 'BROWSER_FFT' // 使用浏览器内置的 FFT 处理,性能更好 + ); + + await recognizer.ensureModelLoaded(); + transferRecognizer = recognizer.createTransfer('my-custom-model'); + + // 只有在 transferRecognizer 创建成功后,才将背景噪音标签加入我们的 local labels 数组 + labels.push(BACKGROUND_NOISE_LABEL); + + return Promise.resolve(); + } catch (error) { + return Promise.reject(error); + } +} + +/** + * 批量录制样本的通用函数 + * @param {string} label - 标签名称 + * @param {number} countToRecord - 要录制的样本数量 + * @returns {Promise} + */ +async function recordMultipleExamples(label, countToRecord = 5) { + if (isRecording) { + return Promise.reject(new Error('正在录制中,请等待当前录音完成')); + } + + isRecording = true; + + for (let i = 0; i < countToRecord; i++) { + try { + await transferRecognizer.collectExample( + label, + { amplitudeRequired: true, durationMillis: recordDuration } + ); + // 在每次录音之间增加短暂延迟,以便更好地分离样本 + if (i < countToRecord - 1) { + await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5))); + } + } catch (error) { + isRecording = false; + return Promise.reject(error); + } + } + + isRecording = false; + return Promise.resolve(); +} + +/** + * 添加自定义类别 + * @param {string} categoryName - 类别名称 + */ +function addCustomCategory(categoryName) { + if (!categoryName) { + return Promise.reject(new Error('类别名称不能为空')); + } + + // 检查是否与现有标签重复 + if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) { + return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`)); + } + + // 将标签添加到本地数组 + labels.push(categoryName); + return Promise.resolve(); +} + +/** + * 检查训练就绪状态 + * @returns {boolean} 是否可以开始训练 + */ +function checkTrainingReadiness() { + const exampleCounts = transferRecognizer.countExamples(); + + let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0; + + let customCategoriesReady = 0; + // 遍历本地 labels 数组,检查每个自定义类别是否有样本 + for (let i = 1; i < labels.length; i++) { // 从索引 1 开始,因为 0 是背景噪音 + const customLabel = labels[i]; + if ((exampleCounts[customLabel] || 0) > 0) { + customCategoriesReady++; + } + } + + // 必须有背景噪音样本,并且至少一个自定义类别有样本 + return backgroundNoiseReady && customCategoriesReady >= 1; +} + +/** + * 模型训练函数 + * @param {Object} trainingConfig - 训练配置参数 + * @returns {Promise} + */ +async function trainModel(trainingConfig = {}) { + const exampleCounts = transferRecognizer.countExamples(); + + let totalExamples = 0; + let validClasses = 0; + const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5; + let allClassesHaveEnoughSamples = true; + + // 统计所有类别的有效样本数,并检查每个类别是否达到最低要求 + for (const labelName of labels) { + if (exampleCounts[labelName] && exampleCounts[labelName] > 0) { + totalExamples += exampleCounts[labelName]; + validClasses++; + if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) { + allClassesHaveEnoughSamples = false; + } + } + } + + if (validClasses < 2) { + return Promise.reject(new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`)); + } + + if (!allClassesHaveEnoughSamples) { + return Promise.reject(new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`)); + } + + if (totalExamples === 0) { + return Promise.reject(new Error('没有收集到任何训练样本')); + } + + const defaultConfig = { + epochs: 50, + batchSize: 16, + validationSplit: 0.1, + shuffle: true, + yieldEvery: 'epoch' + }; + + const config = Object.assign(defaultConfig, trainingConfig); + + try { + await transferRecognizer.train(config); + isModelTrainedFlag = true; + return Promise.resolve(); + } catch (error) { + isModelTrainedFlag = false; + return Promise.reject(error); + } +} + +/** + * 开始实时预测 + * @param {Function} onPrediction - 预测结果回调函数 + * @param {Object} listenOptions - 监听选项 + * @returns {Promise} 停止预测的函数 + */ +async function startPrediction(onPrediction, listenOptions = {}) { + if (isPredicting) { + return Promise.reject(new Error('识别已经在进行中')); + } + + if (!isModelTrainedFlag) { + return Promise.reject(new Error('模型尚未训练完成')); + } + + isPredicting = true; + + const defaultOptions = { + includeEmbedding: true, + probabilityThreshold: 0.75, + suppressionTimeMillis: 300, + overlapFactor: 0.50, + }; + + const options = Object.assign(defaultOptions, listenOptions); + + predictionStopFunction = await transferRecognizer.listen(result => { + if (!isPredicting) return; + + const classLabels = transferRecognizer.wordLabels(); + const scores = result.scores; + const maxScore = Math.max(...scores); + const predictedIndex = scores.indexOf(maxScore); + + let predictedLabel = classLabels[predictedIndex]; + // 如果预测结果是内部的背景噪音标签,转换成用户友好的显示 + if (predictedLabel === BACKGROUND_NOISE_LABEL) { + predictedLabel = '背景噪音'; + } + + // 调用回调函数返回预测结果 + if (typeof onPrediction === 'function') { + onPrediction({ + label: predictedLabel, + score: maxScore, + scores: scores, + labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) + }); + } + + }, options); + + return Promise.resolve(predictionStopFunction); +} + +/** + * 停止实时预测 + */ +function stopPrediction() { + if (isPredicting) { + if (typeof predictionStopFunction === 'function') { + predictionStopFunction(); + predictionStopFunction = null; + } + + isPredicting = false; + } +} + +/** + * 获取各类别样本数量 + * @returns {Object} 各类别样本数量统计 + */ +function getExampleCounts() { + return transferRecognizer.countExamples(); +} + +/** + * 获取模型是否已训练的状态 + * @returns {boolean} 模型是否已训练 + */ +function isModelTrained() { + return isModelTrainedFlag; +} + +// 导出公共接口 +window.AudioClassifier = { + init, + recordMultipleExamples, + addCustomCategory, + checkTrainingReadiness, + trainModel, + startPrediction, + stopPrediction, + getExampleCounts, + isModelTrained, + labels +}; \ No newline at end of file diff --git a/音频分类/script.js b/音频分类/script.js index 316bb53..5ef92fe 100644 --- a/音频分类/script.js +++ b/音频分类/script.js @@ -63,6 +63,7 @@ async function init() { } // ======================= 批量录制样本的通用函数 ======================= +// recordMultipleExamples传入 label, 样本数量显示元素, 按钮元素, 一次录制的样本数量 async function recordMultipleExamples(label, sampleCountSpanElement, buttonElement, countToRecord = 5) { // 默认一次录制5个样本 if (isRecording) { statusDiv.innerText = '请等待当前录音完成...'; @@ -102,6 +103,7 @@ async function recordMultipleExamples(label, sampleCountSpanElement, buttonEleme } // ======================= 背景噪音样本收集 ======================= +// 按钮点击事件 recordBackgroundNoiseBtn.onclick = async () => { await recordMultipleExamples(BACKGROUND_NOISE_LABEL, backgroundNoiseSampleCountSpan, recordBackgroundNoiseBtn, 5); };