// 全局变量和模型实例 let recognizer; let transferRecognizer; // labels 现在将根据导入的数据动态重建,但仍需初始化 let labels = []; const BACKGROUND_NOISE_LABEL = '_background_noise_'; let isPredicting = false; let isRecording = false; const recordDuration = 1000; let isModelTrainedFlag = false; let predictionStopFunction = null; // UI 元素引用 (已更新) const statusDiv = document.getElementById('status'); const backgroundNoiseSampleCountSpan = document.getElementById('backgroundNoiseSampleCount'); const recordBackgroundNoiseBtn = document.getElementById('recordBackgroundNoiseBtn'); const categoryContainer = document.getElementById('categoryContainer'); const newCategoryNameInput = document.getElementById('newCategoryName'); const addCategoryBtn = document.getElementById('addCategoryBtn'); const trainModelBtn = document.getElementById('trainModelBtn'); const startPredictingBtn = document.getElementById('startPredictingBtn'); const stopPredictingBtn = document.getElementById('stopPredictingBtn'); const predictionResultDiv = document.getElementById('predictionResult'); // ===== 新增UI元素引用 ===== const exportModelBtn = document.getElementById('exportModelBtn'); const importModelBtn = document.getElementById('importModelBtn'); const importFileInput = document.getElementById('importFileInput'); // ======================= 初始化函数 ======================= async function init() { statusDiv.innerText = '正在加载 TensorFlow.js 和 Speech Commands 模型...'; try { recognizer = speechCommands.create('BROWSER_FFT'); await recognizer.ensureModelLoaded(); transferRecognizer = recognizer.createTransfer('my-custom-model'); // 初始化时清空并设置背景噪音标签 labels = [BACKGROUND_NOISE_LABEL]; statusDiv.innerText = '模型加载成功!你可以开始录制、或导入已有的样本数据。'; recordBackgroundNoiseBtn.disabled = false; addCategoryBtn.disabled = false; importModelBtn.disabled = false; // 允许导入 exportModelBtn.disabled = true; // 尚无数据,默认禁用导出按钮 trainModelBtn.disabled = true; startPredictingBtn.disabled = true; stopPredictingBtn.disabled = true; isModelTrainedFlag = false; // --- 修正之处:移除此处对 checkTrainingReadiness() 的调用 --- // checkTrainingReadiness(); // <--- 移除这一行! } catch (error) { statusDiv.innerText = `模型加载失败或麦克风无法访问: ${error.message}. 请检查麦克风权限和网络连接。`; console.error('初始化失败:', error); // 禁用所有按钮 const buttons = document.querySelectorAll('button'); buttons.forEach(btn => btn.disabled = true); isModelTrainedFlag = false; } } // ======================= 批量录制样本的通用函数 ======================= async function recordMultipleExamples(label, sampleCountSpanElement, buttonElement, countToRecord = 5) { if (isRecording) { statusDiv.innerText = '请等待当前录音完成...'; return; } isRecording = true; buttonElement.disabled = true; buttonElement.innerText = '正在录制...'; for (let i = 0; i < countToRecord; i++) { statusDiv.innerText = `正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`; try { await transferRecognizer.collectExample( label, { amplitudeRequired: true, durationMillis: recordDuration } ); const exampleCounts = transferRecognizer.countExamples(); sampleCountSpanElement.innerText = exampleCounts[label] || 0; if (i < countToRecord - 1) { await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5))); } } catch (error) { statusDiv.innerText = `录制 "${label}" 样本失败: ${error.message}`; console.error(`录制 ${label} 样本失败:`, error); // 如果某个样本录制失败,则停止当前批次的录制 break; } } buttonElement.disabled = false; buttonElement.innerText = '录制样本'; isRecording = false; checkTrainingReadiness(); // 录制完成后检查训练就绪状态 statusDiv.innerText = `已为 "${label}" 收集了 ${transferRecognizer.countExamples()[label] || 0} 个样本。`; } // ============== 背景噪音样本收集 (无修改) ================== recordBackgroundNoiseBtn.onclick = async () => { await recordMultipleExamples(BACKGROUND_NOISE_LABEL, backgroundNoiseSampleCountSpan, recordBackgroundNoiseBtn, 5); }; // ======================= 自定义类别管理 ======================= // 添加新类别(用户手动添加) function addCustomCategory(categoryName) { if (!categoryName) { alert('类别名称不能为空!'); return; } // 检查是否与现有标签重复(包括背景噪音,尽管背景噪音不会由用户输入) if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) { alert(`类别 "${categoryName}" 已经存在!`); return; } // 将标签添加到本地数组以供 UI 逻辑和后续预测结果查找使用 labels.push(categoryName); // 创建UI时样本数量为0 createCategoryUI(categoryName, 0); newCategoryNameInput.value = ''; // 清空输入框 checkTrainingReadiness(); // 添加新类别后检查训练就绪状态 } // 添加自定义类别按钮点击事件 addCategoryBtn.onclick = () => { addCustomCategory(newCategoryNameInput.value.trim()); }; // 创建类别UI的辅助函数(用于手动添加和导入后重建) function createCategoryUI(categoryName, sampleCount) { // categoryId 此时仅用于生成唯一的 ID,不直接传给 collectExample const categoryId = labels.indexOf(categoryName); const categoryBlock = document.createElement('div'); categoryBlock.className = 'category-block'; // 添加一个ID以便后续删除或识别 categoryBlock.id = `category-block-${encodeURIComponent(categoryName)}`; categoryBlock.innerHTML = `

${categoryName}

样本数量: ${sampleCount}

`; categoryContainer.appendChild(categoryBlock); // 绑定录音按钮事件 const recordBtn = document.getElementById(`recordBtn-${categoryId}`); const sampleCountSpan = document.getElementById(`sampleCount-${categoryId}`); recordBtn.onclick = async () => { await recordMultipleExamples(categoryName, sampleCountSpan, recordBtn, 5); }; } // ======================= 状态检查 ======================= function checkTrainingReadiness() { const exampleCounts = transferRecognizer.countExamples(); // 检查是否有任何样本,以决定是否启用“导出”按钮 const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0); exportModelBtn.disabled = totalSamples === 0; let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0; let customCategoriesReady = 0; // 遍历本地 labels 数组,检查每个自定义类别是否有样本 // 从索引 1 开始,因为 0 是背景噪音 for (let i = 1; i < labels.length; i++) { const customLabel = labels[i]; if ((exampleCounts[customLabel] || 0) > 0) { customCategoriesReady++; } } // 必须有背景噪音样本,并且至少一个自定义类别有样本 if (backgroundNoiseReady && customCategoriesReady >= 1) { trainModelBtn.disabled = false; } else { trainModelBtn.disabled = true; } } // ======================= 模型训练 (无修改) ======================= trainModelBtn.onclick = async () => { const exampleCounts = transferRecognizer.countExamples(); console.log('--- DEBUG: 训练开始前,各类别样本数量:', exampleCounts); let totalExamples = 0; let validClasses = 0; const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5; let allClassesHaveEnoughSamples = true; // 统计所有类别的有效样本数,并检查每个类别是否达到`isTrained`的最低要求 for (const labelName of labels) { // 遍历所有标签(包括背景噪音) if (exampleCounts[labelName] && exampleCounts[labelName] > 0) { totalExamples += exampleCounts[labelName]; validClasses++; if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) { allClassesHaveEnoughSamples = false; } } } // 更明确的样本数量检查提示 if (validClasses < 2) { alert(`训练需要至少 "背景噪音" (已存在) 和另一个自定义类别 (您需要添加并录制样本)。\n\n当前只有 ${validClasses} 个有效类别。`); return; } if (!allClassesHaveEnoughSamples) { alert(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。\n(当前某些类别样本不足,请检查!)\n\n建议每个类别多收集一些(例如 5-10 个)以获得更好的模型效果。`); return; } if (totalExamples === 0) { // 额外的安全检查,理论上会被上面的validClasses捕捉 alert('没有收集到任何训练样本!请先录制样本。'); return; } statusDiv.innerText = '模型训练中...请稍候。'; trainModelBtn.disabled = true; startPredictingBtn.disabled = true; stopPredictingBtn.disabled = true; const trainingConfig = { epochs: 50, batchSize: 16, validationSplit: 0.1, shuffle: true, yieldEvery: 'epoch', callbacks: { onEpochEnd: (epoch, logs) => { statusDiv.innerText = `训练 Epoch ${epoch + 1}/${trainingConfig.epochs}, Loss: ${logs.loss ? logs.loss.toFixed(4) : 'N/A'}, Accuracy: ${logs.acc ? logs.acc.toFixed(4) : 'N/A'}`; } } }; try { await transferRecognizer.train(trainingConfig); statusDiv.innerText = '模型训练完成!你可以开始识别了。'; predictionResultDiv.innerText = '训练完成,等待识别...'; startPredictingBtn.disabled = false; // 训练成功后,手动设置状态标志 isModelTrainedFlag = true; console.log('--- DEBUG: 训练成功完成,此时 transferRecognizer.isTrained 为:', transferRecognizer.isTrained); } catch (error) { statusDiv.innerText = `模型训练失败: ${error.message}. 这通常是由于样本数量过少,类别不均,或录音质量问题导致。请确保每个类别至少有 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本,并且多录制一些(例如 5-10 个)!`; console.error('训练失败:', error); // 训练失败时重置状态 isModelTrainedFlag = false; } finally { trainModelBtn.disabled = false; } }; // ======================= 实时预测 (无修改) ======================= startPredictingBtn.onclick = async () => { console.log('--- DEBUG: 点击开始识别时, isModelTrainedFlag 为:', isModelTrainedFlag); if (isPredicting) { statusDiv.innerText = '识别已经在进行中...'; return; } // 使用自定义标志进行判断 if (!isModelTrainedFlag) { alert('模型尚未训练完成,请先训练模型!'); return; } isPredicting = true; startPredictingBtn.disabled = true; stopPredictingBtn.disabled = false; trainModelBtn.disabled = true; recordBackgroundNoiseBtn.disabled = true; addCategoryBtn.disabled = true; // 禁用所有录制按钮 (确保在预测时不能添加新样本) document.querySelectorAll('.category-block button').forEach(btn => btn.disabled = true); statusDiv.innerText = '正在开始识别... 请发出你训练过的声音。'; predictionResultDiv.innerText = '等待识别结果...'; predictionStopFunction = await transferRecognizer.listen(result => { if (!isPredicting) return; // `transferRecognizer.wordLabels()` 会返回 transferRecognizer 内部按顺序排列的所有标签名称。 // `result.scores` 的索引会与 `transferRecognizer.wordLabels()` 的索引对应。 const classLabels = transferRecognizer.wordLabels(); const scores = result.scores; const maxScore = Math.max(...scores); const predictedIndex = scores.indexOf(maxScore); let predictedLabel = classLabels[predictedIndex]; // 从 transferRecognizer 的内部标签列表中获取 // 如果预测结果是内部的背景噪音标签,转换成用户友好的显示 if (predictedLabel === BACKGROUND_NOISE_LABEL) { predictedLabel = '背景噪音'; } predictionResultDiv.innerText = `预测结果:${predictedLabel} (置信度: ${(maxScore * 100).toFixed(2)}%)`; }, { includeEmbedding: true, probabilityThreshold: 0.75, suppressionTimeMillis: 300, overlapFactor: 0.50, }); console.log('--- DEBUG: predictionStopFunction 赋值后:', predictionStopFunction); console.log('--- DEBUG: typeof predictionStopFunction 赋值后:', typeof predictionStopFunction); }; stopPredictingBtn.onclick = () => { if (isPredicting) { if (typeof predictionStopFunction === 'function') { predictionStopFunction(); predictionStopFunction = null; } else { console.warn('--- WARN: predictionStopFunction 不是一个函数,无法停止监听。'); } isPredicting = false; startPredictingBtn.disabled = false; stopPredictingBtn.disabled = true; trainModelBtn.disabled = false; recordBackgroundNoiseBtn.disabled = false; addCategoryBtn.disabled = false; // 重新启用所有录制按钮 (只有在不是正在录音状态时才启用) document.querySelectorAll('.category-block button').forEach(btn => { if (!isRecording) { btn.disabled = false; } }); statusDiv.innerText = '已停止识别。'; predictionResultDiv.innerText = '停止识别。'; } }; // ======================= 新增:模型导出功能 ======================= exportModelBtn.onclick = async () => { try { // 序列化所有收集到的样本数据 const serializedExamples = transferRecognizer.serializeExamples(); // 创建一个 Blob 对象 const blob = new Blob([serializedExamples], { type: 'application/octet-stream' }); // 创建一个下载链接 const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; // 定制文件名,包含日期和时间 const now = new Date(); const filename = `speech_commands_data_${now.getFullYear()}${(now.getMonth()+1).toString().padStart(2, '0')}${now.getDate().toString().padStart(2, '0')}_${now.getHours().toString().padStart(2, '0')}${now.getMinutes().toString().padStart(2, '0')}${now.getSeconds().toString().padStart(2, '0')}.bin`; a.download = filename; // 模拟点击下载 document.body.appendChild(a); a.click(); document.body.removeChild(a); URL.revokeObjectURL(url); // 释放内存 statusDiv.innerText = `数据已成功导出为 "${filename}"。`; } catch (error) { statusDiv.innerText = `导出数据失败: ${error.message}`; console.error('导出数据失败:', error); alert('导出数据失败。请确保您已录制至少一个样本!'); } }; // ======================= 新增:模型导入功能 ======================= importModelBtn.onclick = () => { // 触发隐藏的文件输入框点击事件 importFileInput.click(); }; importFileInput.onchange = async (event) => { const file = event.target.files[0]; if (!file) { statusDiv.innerText = '未选择文件。'; return; } if (!file.name.endsWith('.bin')) { alert('请选择后缀名为 .bin 的文件!'); statusDiv.innerText = '文件格式不正确,请选择 .bin 文件。'; // 清空文件输入,以便用户可以选择其他文件 importFileInput.value = ''; return; } statusDiv.innerText = `正在导入文件 "${file.name}"...`; const reader = new FileReader(); reader.onload = async (e) => { try { const dataBuffer = e.target.result; // 获取文件的 ArrayBuffer 内容 // 清除当前的 transferRecognizer 中的所有样本 // SpeechCommands库中没有直接的clearExamples方法, // 最简单的做法是重新创建一个 transferRecognizer 实例。 // 但更好的做法是先尝试loadExamples,如果需要重置,再做。 // 假设导入是“覆盖”现有样本的。 // TODO: 这里可以考虑增加用户确认是否清除现有样本的提示 // 导入样本。这会自动更新 internal model await transferRecognizer.loadExamples(dataBuffer); // 成功导入后,刷新UI await syncUIWithLoadedData(); statusDiv.innerText = `文件 "${file.name}" 导入成功!`; alert(`已成功导入 ${transferRecognizer.countExamples()._numExamples_ || 0} 个样本!`); } catch (error) { statusDiv.innerText = `导入数据失败: ${error.message}. 确保存储的是有效的模型样本数据。`; console.error('导入数据失败:', error); alert(`导入数据失败。请检查文件是否损坏或格式不正确。\n错误: ${error.message}`); } finally { // 清空文件输入,以便下次选择相同文件也能触发 onchange importFileInput.value = ''; } }; reader.onerror = (error) => { statusDiv.innerText = `读取文件失败: ${error.message}`; console.error('文件读取失败:', error); alert('文件读取失败。'); importFileInput.value = ''; }; reader.readAsArrayBuffer(file); // 以 ArrayBuffer 格式读取文件 }; // ======================= 新增辅助函数:导入后同步UI ======================= async function syncUIWithLoadedData() { // 清空现有除了背景噪音以外的类别块 // 遍历所有子元素,从后向前删除,避免索引问题 while (categoryContainer.firstChild) { categoryContainer.removeChild(categoryContainer.firstChild); } // 重置全局 labels 数组,只保留背景噪音 labels = [BACKGROUND_NOISE_LABEL]; // 获取导入后的样本计数 const exampleCounts = transferRecognizer.countExamples(); console.log('--- DEBUG: 导入后样本数量:', exampleCounts); // 更新背景噪音样本数量 backgroundNoiseSampleCountSpan.innerText = exampleCounts[BACKGROUND_NOISE_LABEL] || 0; // 重新构建自定义类别 UI for (const label of Object.keys(exampleCounts)) { if (label === BACKGROUND_NOISE_LABEL || label === '_version_' || label === '_numExamples_') { continue; // 跳过背景噪音和内部元数据标签 } // 将导入的自定义标签添加到我们的 labels 数组 if (!labels.includes(label)) { labels.push(label); } // 根据导入的数据创建 UI createCategoryUI(label, exampleCounts[label]); } // 重置模型的训练状态 isModelTrainedFlag = false; trainModelBtn.disabled = true; // 训练按钮默认禁用,等待 checkTrainingReadiness 启用 startPredictingBtn.disabled = true; // 预测按钮禁用 stopPredictingBtn.disabled = true; // 停止按钮禁用 // 检查训练就绪状态(现在有样本了,这个调用是安全的) checkTrainingReadiness(); } // ======================= 页面加载时执行 ======================= window.onload = init;