493 lines
20 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Global state and model instances.
// Base Speech Commands recognizer; assigned in init().
let recognizer;
// Transfer-learning recognizer created from `recognizer` in init(); it holds the
// collected examples and is the object that gets trained and used for prediction.
let transferRecognizer;
// Local list of category names. Index 0 is BACKGROUND_NOISE_LABEL (set in init() and
// syncUIWithLoadedData()); custom categories follow. The list is rebuilt from imported
// data, but it still needs an initial value here.
let labels = [];
// Label under which background-noise examples are collected.
const BACKGROUND_NOISE_LABEL = '_background_noise_';
// True while streaming recognition is active (toggled by the start/stop prediction handlers).
let isPredicting = false;
// True while recordMultipleExamples() is recording a batch; prevents overlapping recordings.
let isRecording = false;
// Length of each recorded example in milliseconds (passed as durationMillis to collectExample()).
const recordDuration = 1000;
// Set to true only after a successful transferRecognizer.train(); gates the start-prediction button.
let isModelTrainedFlag = false;
// Value stored by the start-prediction handler for stopping recognition; the stop handler
// only calls it when it is a function.
let predictionStopFunction = null;
// UI element references
const statusDiv = document.getElementById('status');
const backgroundNoiseSampleCountSpan = document.getElementById('backgroundNoiseSampleCount');
const recordBackgroundNoiseBtn = document.getElementById('recordBackgroundNoiseBtn');
const categoryContainer = document.getElementById('categoryContainer');
const newCategoryNameInput = document.getElementById('newCategoryName');
const addCategoryBtn = document.getElementById('addCategoryBtn');
const trainModelBtn = document.getElementById('trainModelBtn');
const startPredictingBtn = document.getElementById('startPredictingBtn');
const stopPredictingBtn = document.getElementById('stopPredictingBtn');
const predictionResultDiv = document.getElementById('predictionResult');
// ===== Import/export UI element references =====
const exportModelBtn = document.getElementById('exportModelBtn');
const importModelBtn = document.getElementById('importModelBtn');
const importFileInput = document.getElementById('importFileInput');
// ======================= Initialization =======================
/**
 * Loads the Speech Commands base model, creates the transfer recognizer and
 * puts the controls into their initial state.
 * On failure the error is shown in statusDiv and every <button> is disabled.
 * @returns {Promise<void>}
 */
async function init() {
  statusDiv.innerText = '正在加载 TensorFlow.js 和 Speech Commands 模型...';
  try {
    recognizer = speechCommands.create('BROWSER_FFT');
    await recognizer.ensureModelLoaded();
    transferRecognizer = recognizer.createTransfer('my-custom-model');

    // Start over with only the background-noise label.
    labels = [BACKGROUND_NOISE_LABEL];
    statusDiv.innerText = '模型加载成功!你可以开始录制、或导入已有的样本数据。';

    // Recording, adding categories and importing are available right away.
    // Export, training and prediction need data (or a trained model) first.
    const enabledAtStart = [recordBackgroundNoiseBtn, addCategoryBtn, importModelBtn];
    const disabledAtStart = [exportModelBtn, trainModelBtn, startPredictingBtn, stopPredictingBtn];
    enabledAtStart.forEach((btn) => {
      btn.disabled = false;
    });
    disabledAtStart.forEach((btn) => {
      btn.disabled = true;
    });
    isModelTrainedFlag = false;
    // checkTrainingReadiness() is intentionally not called here: no examples exist yet.
  } catch (error) {
    statusDiv.innerText = `模型加载失败或麦克风无法访问: ${error.message}. 请检查麦克风权限和网络连接。`;
    console.error('初始化失败:', error);
    // Nothing on the page can work without the model, so lock every button.
    for (const btn of document.querySelectorAll('button')) {
      btn.disabled = true;
    }
    isModelTrainedFlag = false;
  }
}
// ======================= Batch recording helper =======================
/**
 * Records `countToRecord` consecutive examples for `label` with
 * transferRecognizer.collectExample() and updates the given counter element
 * after each one. The batch stops at the first failure, and that error message
 * stays visible in statusDiv (it is no longer overwritten by the summary line).
 *
 * @param {string} label - Label passed to collectExample().
 * @param {HTMLElement} sampleCountSpanElement - Element showing this label's example count.
 * @param {HTMLButtonElement} buttonElement - Record button; disabled and relabeled while recording.
 * @param {number} [countToRecord=5] - Number of examples to record in this batch.
 * @returns {Promise<void>}
 */
async function recordMultipleExamples(label, sampleCountSpanElement, buttonElement, countToRecord = 5) {
  if (isRecording) {
    statusDiv.innerText = '请等待当前录音完成...';
    return;
  }
  isRecording = true;
  buttonElement.disabled = true;
  buttonElement.innerText = '正在录制...';

  let recordingFailed = false;
  try {
    for (let i = 0; i < countToRecord; i++) {
      statusDiv.innerText = `正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`;
      try {
        await transferRecognizer.collectExample(
          label,
          { amplitudeRequired: true, durationMillis: recordDuration },
        );
        sampleCountSpanElement.innerText = transferRecognizer.countExamples()[label] || 0;
        if (i < countToRecord - 1) {
          // Short pause between consecutive recordings.
          await new Promise((resolve) => setTimeout(resolve, Math.max(200, recordDuration / 5)));
        }
      } catch (error) {
        statusDiv.innerText = `录制 "${label}" 样本失败: ${error.message}`;
        console.error(`录制 ${label} 样本失败:`, error);
        recordingFailed = true;
        // Stop the rest of this batch after the first failure.
        break;
      }
    }
  } finally {
    // If prediction started meanwhile, keep the button disabled; the stop handler re-enables it.
    buttonElement.disabled = isPredicting;
    buttonElement.innerText = '录制样本';
    isRecording = false;
  }

  checkTrainingReadiness();

  // Only report the summary on success so a failure message is not overwritten.
  if (!recordingFailed) {
    let recordedCount;
    try {
      recordedCount = transferRecognizer.countExamples()[label] || 0;
    } catch {
      // countExamples() throws while the dataset is empty (e.g. countToRecord === 0).
      recordedCount = 0;
    }
    statusDiv.innerText = `已为 "${label}" 收集了 ${recordedCount} 个样本。`;
  }
}
// ============== Background-noise sample collection ==================
// Records a batch of 5 background-noise examples and updates the background-noise counter.
recordBackgroundNoiseBtn.onclick = () =>
  recordMultipleExamples(BACKGROUND_NOISE_LABEL, backgroundNoiseSampleCountSpan, recordBackgroundNoiseBtn, 5);
// ======================= Custom category management =======================
/**
 * Registers a user-defined category: it validates the name, appends it to
 * `labels`, renders its UI block with a count of 0, clears the name input and
 * refreshes the train/export button state.
 * Names are compared case-insensitively against every existing label, including
 * the background-noise label.
 * @param {string} categoryName - Trimmed name typed by the user.
 */
function addCustomCategory(categoryName) {
  if (!categoryName) {
    alert('类别名称不能为空!');
    return;
  }

  const wanted = categoryName.toLowerCase();
  const isDuplicate = labels.map((existing) => existing.toLowerCase()).includes(wanted);
  if (isDuplicate) {
    alert(`类别 "${categoryName}" 已经存在!`);
    return;
  }

  // The local labels array drives the UI and later prediction lookups.
  labels.push(categoryName);
  createCategoryUI(categoryName, 0);
  newCategoryNameInput.value = '';
  checkTrainingReadiness();
}
// "Add category" button: register the (trimmed) name currently typed in the input box.
addCategoryBtn.onclick = () => {
  const requestedName = newCategoryNameInput.value.trim();
  addCustomCategory(requestedName);
};
// Builds the UI block for one category (used for manual adds and for rebuilding after an import).
/**
 * Appends a `.category-block` element for `categoryName` to categoryContainer,
 * with a title, a sample counter and a "record" button wired to
 * recordMultipleExamples().
 *
 * The name is inserted with textContent rather than innerHTML. Category names
 * come from user input and from imported files, so building HTML from them
 * would allow markup/script injection.
 *
 * @param {string} categoryName - Label name; must already be present in `labels`.
 * @param {number} sampleCount - Initial value shown in the sample counter.
 */
function createCategoryUI(categoryName, sampleCount) {
  // Position in `labels`; only used to build unique element ids (not passed to collectExample).
  const categoryId = labels.indexOf(categoryName);

  const categoryBlock = document.createElement('div');
  categoryBlock.className = 'category-block';
  // Id lets the block be located or removed later.
  categoryBlock.id = `category-block-${encodeURIComponent(categoryName)}`;

  const title = document.createElement('h3');
  title.textContent = categoryName;

  const countParagraph = document.createElement('p');
  const sampleCountSpan = document.createElement('span');
  sampleCountSpan.id = `sampleCount-${categoryId}`;
  sampleCountSpan.textContent = String(sampleCount);
  countParagraph.append('样本数量: ', sampleCountSpan);

  const recordBtn = document.createElement('button');
  recordBtn.id = `recordBtn-${categoryId}`;
  recordBtn.textContent = '录制样本';
  // Keep direct references instead of re-querying by id after insertion.
  recordBtn.onclick = async () => {
    await recordMultipleExamples(categoryName, sampleCountSpan, recordBtn, 5);
  };

  categoryBlock.append(title, countParagraph, recordBtn);
  categoryContainer.appendChild(categoryBlock);
}
// ======================= Readiness check =======================
/**
 * Refreshes the enabled state of the export and train buttons from the
 * current example counts:
 * - export is enabled when at least one example exists;
 * - train is enabled when the background-noise label and at least one custom
 *   label (labels[1..]) each have one or more examples.
 * Safe to call before any example has been collected (e.g. from
 * addCustomCategory() before the first recording).
 */
function checkTrainingReadiness() {
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    // speech-commands' countExamples() throws while no examples have been
    // collected; treat that as "no examples" instead of aborting the caller.
    console.debug('checkTrainingReadiness: no examples yet:', error);
    exampleCounts = {};
  }

  const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
  exportModelBtn.disabled = totalSamples === 0;

  const backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0;
  // labels[0] is the background-noise label; custom categories start at index 1.
  const anyCustomCategoryReady = labels
    .slice(1)
    .some((customLabel) => (exampleCounts[customLabel] || 0) > 0);

  trainModelBtn.disabled = !(backgroundNoiseReady && anyCustomCategoryReady);
}
// ======================= Model training =======================
/**
 * Validates the collected examples and trains the transfer-learning model.
 *
 * Training is refused (with an alert) unless at least two labels from `labels`
 * have examples and every label that has examples has at least
 * MIN_SAMPLES_PER_CLASS_FOR_TRAINING of them. Progress is written to statusDiv
 * after every epoch. On success isModelTrainedFlag is set and the
 * start-prediction button is enabled; on failure the flag is cleared.
 */
trainModelBtn.onclick = async () => {
  // Examples added mid-training would change the dataset under the trainer.
  if (isRecording) {
    statusDiv.innerText = '请等待当前录音完成...';
    return;
  }

  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch {
    // countExamples() throws while no examples exist; the checks below report that case.
    exampleCounts = {};
  }
  console.log('--- DEBUG: 训练开始前,各类别样本数量:', exampleCounts);

  const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
  let validClasses = 0;
  let allClassesHaveEnoughSamples = true;
  // Walk every label (including background noise); only labels with examples count as classes.
  for (const labelName of labels) {
    const count = exampleCounts[labelName] || 0;
    if (count > 0) {
      validClasses++;
      if (count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
        allClassesHaveEnoughSamples = false;
      }
    }
  }

  // validClasses >= 2 already implies at least one example, so no separate "zero examples" check is needed.
  if (validClasses < 2) {
    alert(`训练需要至少 "背景噪音" (已存在) 和另一个自定义类别 (您需要添加并录制样本)。\n\n当前只有 ${validClasses} 个有效类别。`);
    return;
  }
  if (!allClassesHaveEnoughSamples) {
    alert(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。\n(当前某些类别样本不足,请检查!)\n\n建议每个类别多收集一些(例如 5-10 个)以获得更好的模型效果。`);
    return;
  }

  statusDiv.innerText = '模型训练中...请稍候。';
  trainModelBtn.disabled = true;
  startPredictingBtn.disabled = true;
  stopPredictingBtn.disabled = true;

  const EPOCHS = 50;
  // Format a metric; 0 is a valid value and must not be shown as "N/A".
  const formatMetric = (value) => (typeof value === 'number' ? value.toFixed(4) : 'N/A');
  const trainingConfig = {
    epochs: EPOCHS,
    batchSize: 16,
    validationSplit: 0.1,
    // speech-commands' train() takes a single tf.js callback object under `callback`
    // (singular); a `callbacks` key is not read, so progress would never be displayed.
    callback: {
      onEpochEnd: (epoch, logs) => {
        statusDiv.innerText = `训练 Epoch ${epoch + 1}/${EPOCHS}, Loss: ${formatMetric(logs.loss)}, Accuracy: ${formatMetric(logs.acc)}`;
      },
    },
  };

  try {
    await transferRecognizer.train(trainingConfig);
    statusDiv.innerText = '模型训练完成!你可以开始识别了。';
    predictionResultDiv.innerText = '训练完成,等待识别...';
    startPredictingBtn.disabled = false;
    // The app tracks "trained" itself rather than relying on the recognizer.
    isModelTrainedFlag = true;
    console.log('--- DEBUG: 训练成功完成,此时 transferRecognizer.isTrained 为:', transferRecognizer.isTrained);
  } catch (error) {
    statusDiv.innerText = `模型训练失败: ${error.message}. 这通常是由于样本数量过少,类别不均,或录音质量问题导致。请确保每个类别至少有 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本,并且多录制一些(例如 5-10 个)!`;
    console.error('训练失败:', error);
    isModelTrainedFlag = false;
  } finally {
    trainModelBtn.disabled = false;
  }
};
// ======================= Real-time prediction =======================
/**
 * Starts streaming recognition with the trained transfer model.
 *
 * Training and recording controls are disabled while listening.
 * transferRecognizer.listen() resolves once streaming has started and does not
 * return a stop function. A closure over transferRecognizer.stopListening() is
 * therefore stored in predictionStopFunction for the stop handler.
 * If starting fails, the controls are restored and the error is shown.
 */
startPredictingBtn.onclick = async () => {
  console.log('--- DEBUG: 点击开始识别时, isModelTrainedFlag 为:', isModelTrainedFlag);
  if (isPredicting) {
    statusDiv.innerText = '识别已经在进行中...';
    return;
  }
  // collectExample() is using the audio stream while a batch is being recorded.
  if (isRecording) {
    statusDiv.innerText = '请等待当前录音完成...';
    return;
  }
  if (!isModelTrainedFlag) {
    alert('模型尚未训练完成,请先训练模型!');
    return;
  }

  isPredicting = true;
  startPredictingBtn.disabled = true;
  // Stop stays disabled until listen() has actually started streaming.
  stopPredictingBtn.disabled = true;
  trainModelBtn.disabled = true;
  recordBackgroundNoiseBtn.disabled = true;
  addCategoryBtn.disabled = true;
  // No new samples may be recorded while predicting.
  document.querySelectorAll('.category-block button').forEach((btn) => {
    btn.disabled = true;
  });
  statusDiv.innerText = '正在开始识别... 请发出你训练过的声音。';
  predictionResultDiv.innerText = '等待识别结果...';

  try {
    await transferRecognizer.listen((result) => {
      if (!isPredicting) return;
      // result.scores is indexed in the same order as transferRecognizer.wordLabels().
      const classLabels = transferRecognizer.wordLabels();
      const { scores } = result;
      const maxScore = Math.max(...scores);
      const predictedIndex = scores.indexOf(maxScore);
      let predictedLabel = classLabels[predictedIndex];
      // Show the internal background-noise label in a user-friendly form.
      if (predictedLabel === BACKGROUND_NOISE_LABEL) {
        predictedLabel = '背景噪音';
      }
      predictionResultDiv.innerText = `预测结果:${predictedLabel} (置信度: ${(maxScore * 100).toFixed(2)}%)`;
    }, {
      includeEmbedding: true,
      probabilityThreshold: 0.75,
      suppressionTimeMillis: 300,
      overlapFactor: 0.50,
    });
    predictionStopFunction = () => transferRecognizer.stopListening();
    stopPredictingBtn.disabled = false;
  } catch (error) {
    // Starting failed (e.g. microphone unavailable): undo the "predicting" state.
    isPredicting = false;
    predictionStopFunction = null;
    startPredictingBtn.disabled = false;
    stopPredictingBtn.disabled = true;
    trainModelBtn.disabled = false;
    recordBackgroundNoiseBtn.disabled = false;
    addCategoryBtn.disabled = false;
    document.querySelectorAll('.category-block button').forEach((btn) => {
      btn.disabled = false;
    });
    statusDiv.innerText = `启动识别失败: ${error.message}`;
    console.error('启动识别失败:', error);
  }
};
/**
 * Stops streaming recognition and re-enables the training/recording controls.
 *
 * Uses predictionStopFunction when it is callable. Otherwise it calls
 * transferRecognizer.stopListening() directly, because listen() resolves
 * without returning a stop function. A failure while stopping is logged, and
 * the UI is still restored.
 */
stopPredictingBtn.onclick = async () => {
  if (!isPredicting) return;

  // Clear the flag first so the listen() callback ignores results that arrive
  // while stopping, and disable the button so a second click cannot overlap.
  isPredicting = false;
  stopPredictingBtn.disabled = true;

  try {
    if (typeof predictionStopFunction === 'function') {
      await predictionStopFunction();
    } else {
      await transferRecognizer.stopListening();
    }
  } catch (error) {
    console.warn('--- WARN: 停止监听失败:', error);
  } finally {
    predictionStopFunction = null;
  }

  startPredictingBtn.disabled = false;
  trainModelBtn.disabled = false;
  recordBackgroundNoiseBtn.disabled = false;
  addCategoryBtn.disabled = false;
  // Re-enable the per-category record buttons unless a recording batch is still running.
  if (!isRecording) {
    document.querySelectorAll('.category-block button').forEach((btn) => {
      btn.disabled = false;
    });
  }
  statusDiv.innerText = '已停止识别。';
  predictionResultDiv.innerText = '停止识别。';
};
// ======================= Export collected examples =======================
/**
 * Serializes every collected example with transferRecognizer.serializeExamples()
 * and downloads it as "speech_commands_data_YYYYMMDD_HHMMSS.bin" (local time).
 *
 * The object URL is revoked after a delay rather than right after click().
 * Revoking synchronously can abort the download in some browsers.
 */
exportModelBtn.onclick = async () => {
  // How long to keep the object URL alive so the browser can start the download.
  const OBJECT_URL_REVOKE_DELAY_MS = 10000;
  let url = null;
  try {
    const serializedExamples = transferRecognizer.serializeExamples();
    const blob = new Blob([serializedExamples], { type: 'application/octet-stream' });
    url = URL.createObjectURL(blob);

    const now = new Date();
    const pad2 = (value) => String(value).padStart(2, '0');
    const datePart = `${now.getFullYear()}${pad2(now.getMonth() + 1)}${pad2(now.getDate())}`;
    const timePart = `${pad2(now.getHours())}${pad2(now.getMinutes())}${pad2(now.getSeconds())}`;
    const filename = `speech_commands_data_${datePart}_${timePart}.bin`;

    // Trigger the download through a temporary <a download> element.
    const a = document.createElement('a');
    a.href = url;
    a.download = filename;
    document.body.appendChild(a);
    a.click();
    document.body.removeChild(a);

    statusDiv.innerText = `数据已成功导出为 "${filename}"。`;
  } catch (error) {
    statusDiv.innerText = `导出数据失败: ${error.message}`;
    console.error('导出数据失败:', error);
    alert('导出数据失败。请确保您已录制至少一个样本!');
  } finally {
    if (url !== null) {
      const objectUrl = url;
      setTimeout(() => URL.revokeObjectURL(objectUrl), OBJECT_URL_REVOKE_DELAY_MS);
    }
  }
};
// ======================= Import previously exported examples =======================
// Importing is refused while recognition or a recording batch is active.
// The import rebuilds the category UI and resets the trained state; mid-listen,
// that would disable the stop button while listening continues.
importModelBtn.onclick = () => {
  if (isPredicting || isRecording) {
    alert('请先停止识别或等待录音完成,再导入数据。');
    return;
  }
  // Open the hidden file picker; the import itself runs in importFileInput.onchange.
  importFileInput.click();
};

/**
 * Loads a .bin file produced by the export button into transferRecognizer and
 * rebuilds the category UI from the resulting example counts.
 */
importFileInput.onchange = async (event) => {
  const file = event.target.files[0];
  if (!file) {
    statusDiv.innerText = '未选择文件。';
    return;
  }
  // Case-insensitive so that e.g. "DATA.BIN" is accepted as well.
  if (!file.name.toLowerCase().endsWith('.bin')) {
    alert('请选择后缀名为 .bin 的文件!');
    statusDiv.innerText = '文件格式不正确,请选择 .bin 文件。';
    // Reset the input so the user can pick another file.
    importFileInput.value = '';
    return;
  }
  // Re-check: state may have changed while the file dialog was open.
  if (isPredicting || isRecording) {
    alert('请先停止识别或等待录音完成,再导入数据。');
    importFileInput.value = '';
    return;
  }

  // Total number of collected examples. countExamples() throws while the
  // dataset is empty, so that case counts as 0.
  const countAllExamples = () => {
    try {
      return Object.values(transferRecognizer.countExamples()).reduce((sum, n) => sum + n, 0);
    } catch {
      return 0;
    }
  };

  statusDiv.innerText = `正在导入文件 "${file.name}"...`;

  let dataBuffer;
  try {
    dataBuffer = await file.arrayBuffer();
  } catch (error) {
    statusDiv.innerText = `读取文件失败: ${error.message}`;
    console.error('文件读取失败:', error);
    alert('文件读取失败。');
    importFileInput.value = '';
    return;
  }

  try {
    const countBefore = countAllExamples();
    // NOTE(review): speech-commands' loadExamples(serialized, clearExisting = false)
    // adds the imported examples to the existing ones unless clearExisting is true.
    // The merging behavior is kept here; pass `true` to replace the existing samples.
    // TODO: consider asking the user whether existing samples should be cleared first.
    await transferRecognizer.loadExamples(dataBuffer);
    // Rebuild the category UI and labels from the new counts.
    await syncUIWithLoadedData();
    const importedCount = countAllExamples() - countBefore;
    statusDiv.innerText = `文件 "${file.name}" 导入成功!`;
    alert(`已成功导入 ${importedCount} 个样本!`);
  } catch (error) {
    statusDiv.innerText = `导入数据失败: ${error.message}. 确保存储的是有效的模型样本数据。`;
    console.error('导入数据失败:', error);
    alert(`导入数据失败。请检查文件是否损坏或格式不正确。\n错误: ${error.message}`);
  } finally {
    // Clear the input so choosing the same file again still fires onchange.
    importFileInput.value = '';
  }
};
// ======================= Sync the UI after an import =======================
/**
 * Rebuilds the category UI and the `labels` array from
 * transferRecognizer.countExamples():
 * - every child of categoryContainer is removed;
 * - `labels` is reset to [BACKGROUND_NOISE_LABEL] and every other label found in
 *   the counts (skipping '_version_' / '_numExamples_') is appended and rendered;
 * - the background-noise counter is updated;
 * - the trained state is cleared, the train/predict buttons are disabled, and
 *   checkTrainingReadiness() decides whether training can be re-enabled.
 */
async function syncUIWithLoadedData() {
  // Remove every rendered category block (all children of categoryContainer).
  while (categoryContainer.lastChild) {
    categoryContainer.removeChild(categoryContainer.lastChild);
  }

  labels = [BACKGROUND_NOISE_LABEL];

  const exampleCounts = transferRecognizer.countExamples();
  console.log('--- DEBUG: 导入后样本数量:', exampleCounts);

  backgroundNoiseSampleCountSpan.innerText = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;

  // Keys that are not user categories.
  const nonCategoryKeys = new Set([BACKGROUND_NOISE_LABEL, '_version_', '_numExamples_']);
  Object.keys(exampleCounts)
    .filter((labelName) => !nonCategoryKeys.has(labelName))
    .forEach((labelName) => {
      // The label must be in `labels` before createCategoryUI() derives its element ids.
      if (!labels.includes(labelName)) {
        labels.push(labelName);
      }
      createCategoryUI(labelName, exampleCounts[labelName]);
    });

  // Imported data invalidates any previously trained model.
  isModelTrainedFlag = false;
  trainModelBtn.disabled = true;
  startPredictingBtn.disabled = true;
  stopPredictingBtn.disabled = true;
  checkTrainingReadiness();
}
// ======================= Run on page load =======================
// Start initialization once the page has finished loading. Assigning
// window.onload replaces any onload handler assigned earlier.
window.onload = init;