2025-08-14 14:35:44 +08:00

345 lines
15 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Global state and model instances.
let recognizer; // Base speech-commands recognizer (created in init() with 'BROWSER_FFT').
let transferRecognizer; // Transfer-learning recognizer created from `recognizer`; used to collect examples, train and listen.
const labels = []; // Local list of class labels: index 0 is the background-noise label (pushed by init()), followed by user-defined categories.
// Background noise is treated as the first class; its internal name is _background_noise_.
const BACKGROUND_NOISE_LABEL = '_background_noise_';
const BACKGROUND_NOISE_INDEX = 0; // Position of the background-noise label in the local `labels` array only; never passed to collectExample.
let isPredicting = false; // True while real-time recognition is running.
let isRecording = false; // True while a recording batch is in progress; prevents overlapping recordings / double clicks.
const recordDuration = 1000; // Per-example recording length in milliseconds (passed to collectExample as durationMillis and used for the inter-sample delay).
let isModelTrainedFlag = false; // Manually maintained "model has been trained" flag (set after a successful train()).
let predictionStopFunction = null; // Holds whatever transferRecognizer.listen() resolves to; the stop handler calls it only if it is a function.
// UI element references, looked up once when the script runs.
// NOTE(review): assumes this script executes after these elements exist in the DOM (e.g. deferred or at end of <body>) — confirm.
const statusDiv = document.getElementById('status');
const backgroundNoiseSampleCountSpan = document.getElementById('backgroundNoiseSampleCount');
const recordBackgroundNoiseBtn = document.getElementById('recordBackgroundNoiseBtn');
const categoryContainer = document.getElementById('categoryContainer');
const newCategoryNameInput = document.getElementById('newCategoryName');
const addCategoryBtn = document.getElementById('addCategoryBtn');
const trainModelBtn = document.getElementById('trainModelBtn');
const startPredictingBtn = document.getElementById('startPredictingBtn');
const stopPredictingBtn = document.getElementById('stopPredictingBtn');
const predictionResultDiv = document.getElementById('predictionResult');
// ======================= Initialisation =======================
/**
 * Loads the speech-commands base model and creates the transfer recognizer.
 *
 * On success the background-noise label becomes labels[0] and the recording
 * controls are enabled. On failure every control stays disabled and the error
 * is shown in the status area. In both cases the train/start/stop buttons are
 * disabled and the "trained" flag is reset.
 *
 * @returns {Promise<void>}
 */
async function init() {
  // Single place that decides which controls are usable after initialisation.
  const applyControlState = (modelReady) => {
    recordBackgroundNoiseBtn.disabled = !modelReady;
    addCategoryBtn.disabled = !modelReady;
    trainModelBtn.disabled = true;
    startPredictingBtn.disabled = true;
    stopPredictingBtn.disabled = true;
    isModelTrainedFlag = false;
  };

  statusDiv.innerText = '正在加载 TensorFlow.js 和 Speech Commands 模型...';
  try {
    // 'BROWSER_FFT' uses the browser's built-in FFT for feature extraction.
    recognizer = speechCommands.create('BROWSER_FFT');
    await recognizer.ensureModelLoaded();
    transferRecognizer = recognizer.createTransfer('my-custom-model');
    // Only register the background-noise label once the transfer recognizer exists;
    // the local array is used for UI mapping and result lookup.
    labels.push(BACKGROUND_NOISE_LABEL);
    statusDiv.innerText = '模型加载成功!你可以开始录制背景噪音和自定义声音样本了。';
    applyControlState(true);
  } catch (error) {
    statusDiv.innerText = `模型加载失败或麦克风无法访问: ${error.message}. 请检查麦克风权限和网络连接。`;
    console.error('初始化失败:', error);
    // Any failure keeps every control disabled until initialisation succeeds.
    applyControlState(false);
  }
}
// ======================= Shared batch-recording helper =======================
/**
 * Records `countToRecord` consecutive examples for `label` via
 * `transferRecognizer.collectExample`, updating the sample counter and status text.
 *
 * Only one batch runs at a time (guarded by the global `isRecording` flag).
 * While recording, the button is disabled and shows '正在录制...'. Afterwards its
 * ORIGINAL caption is restored, so the background-noise button and the per-category
 * buttons each keep their own text. The state is reset in a `finally`, so an
 * unexpected exception cannot leave `isRecording` stuck at true.
 *
 * If a sample fails, the batch stops and the error message stays visible in the
 * status area. The success summary is only written when every sample was recorded.
 * If recognition was started while this batch was running, the button stays disabled
 * and the train-readiness check is skipped, matching the stop handler, which
 * re-enables those controls.
 *
 * @param {string} label - Label passed to collectExample (BACKGROUND_NOISE_LABEL or a category name).
 * @param {HTMLElement} sampleCountSpanElement - Element that displays the sample count for `label`.
 * @param {HTMLButtonElement} buttonElement - The button that triggered this recording.
 * @param {number} [countToRecord=5] - Number of examples to record in this batch.
 * @returns {Promise<void>}
 */
async function recordMultipleExamples(label, sampleCountSpanElement, buttonElement, countToRecord = 5) {
  if (isRecording) {
    statusDiv.innerText = '请等待当前录音完成...';
    return;
  }

  // speech-commands' countExamples() throws while the dataset is completely empty
  // (e.g. the very first collection failed), so read the per-label count defensively.
  const readSampleCount = () => {
    try {
      return transferRecognizer.countExamples()[label] || 0;
    } catch (error) {
      return 0;
    }
  };

  isRecording = true;
  const originalButtonText = buttonElement.innerText;
  buttonElement.disabled = true;
  buttonElement.innerText = '正在录制...';

  let failed = false;
  try {
    for (let i = 0; i < countToRecord; i++) {
      statusDiv.innerText = `正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`;
      try {
        // NOTE(review): the speech-commands ExampleCollectionOptions documents durationSec /
        // durationMultiplier; confirm that `durationMillis` and `amplitudeRequired` are honoured.
        await transferRecognizer.collectExample(
          label,
          { amplitudeRequired: true, durationMillis: recordDuration }
        );
        sampleCountSpanElement.innerText = readSampleCount();
      } catch (error) {
        failed = true;
        statusDiv.innerText = `录制 "${label}" 样本失败: ${error.message}`;
        console.error(`录制 ${label} 样本失败:`, error);
        // Stop the rest of this batch after the first failure.
        break;
      }
      // Short pause between recordings so consecutive samples are better separated:
      // at least 200 ms, or 1/5 of the recording length.
      if (i < countToRecord - 1) {
        await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
      }
    }
  } finally {
    buttonElement.disabled = isPredicting;
    buttonElement.innerText = originalButtonText;
    isRecording = false;
  }

  if (!isPredicting) {
    checkTrainingReadiness();
  }
  // Keep the error text on failure; otherwise summarise how many samples exist for `label`.
  if (!failed) {
    statusDiv.innerText = `已为 "${label}" 收集了 ${readSampleCount()} 个样本。`;
  }
}
// ======================= Background-noise sample collection =======================
// Click handler: record a batch of 5 background-noise examples into the
// background-noise counter; the returned promise settles when the batch is done.
recordBackgroundNoiseBtn.onclick = () =>
  recordMultipleExamples(
    BACKGROUND_NOISE_LABEL,
    backgroundNoiseSampleCountSpan,
    recordBackgroundNoiseBtn,
    5
  );
// ======================= Custom category management and sample collection =======================
/**
 * Registers a user-defined sound category and renders its UI block: a title,
 * a sample counter (`sampleCount-<id>`) and a record button (`recordBtn-<id>`).
 *
 * The name must be non-empty and unique, compared case-insensitively against
 * `labels`, which also contains the background-noise label. The block is built
 * with createElement/textContent rather than innerHTML, so a name containing
 * markup (e.g. `<img src=x onerror=...>`) is displayed literally instead of
 * being parsed and executed.
 *
 * @param {string} categoryName - Category name typed by the user (the caller trims it).
 */
function addCustomCategory(categoryName) {
  if (!categoryName) {
    alert('类别名称不能为空!');
    return;
  }
  // Reject duplicates, including the internal background-noise label.
  const normalizedName = categoryName.toLowerCase();
  if (labels.some(label => label.toLowerCase() === normalizedName)) {
    alert(`类别 "${categoryName}" 已经存在!`);
    return;
  }

  // The local array drives UI logic and prediction-result lookup.
  labels.push(categoryName);
  // Index in `labels`, used only to build unique element ids (never passed to collectExample).
  const categoryId = labels.length - 1;

  const categoryBlock = document.createElement('div');
  categoryBlock.className = 'category-block';

  const title = document.createElement('h3');
  title.textContent = categoryName;

  const sampleCountParagraph = document.createElement('p');
  sampleCountParagraph.textContent = '样本数量: ';
  const sampleCountSpan = document.createElement('span');
  sampleCountSpan.id = `sampleCount-${categoryId}`;
  sampleCountSpan.textContent = '0';
  sampleCountParagraph.appendChild(sampleCountSpan);

  const recordBtn = document.createElement('button');
  recordBtn.id = `recordBtn-${categoryId}`;
  recordBtn.textContent = '录制样本';
  recordBtn.onclick = async () => {
    await recordMultipleExamples(categoryName, sampleCountSpan, recordBtn, 5);
  };

  categoryBlock.appendChild(title);
  categoryBlock.appendChild(sampleCountParagraph);
  categoryBlock.appendChild(recordBtn);
  categoryContainer.appendChild(categoryBlock);

  newCategoryNameInput.value = '';
  // A new (still empty) category may change whether training is possible.
  checkTrainingReadiness();
}
// Click handler for "add category": hand the trimmed input text to addCustomCategory,
// which validates it and renders the new category block.
addCategoryBtn.onclick = function handleAddCategoryClick() {
  const typedName = newCategoryNameInput.value.trim();
  addCustomCategory(typedName);
};
// ======================= Training readiness check =======================
/**
 * Enables the "train" button only when the background-noise label has at least
 * one example AND at least one user-defined category (labels[1..]) has examples.
 *
 * speech-commands' countExamples() throws while no example has been collected
 * yet. This is exactly the state right after the first category is added, before
 * anything is recorded. That case, and an uninitialised recognizer, is treated as
 * "no examples", so the exception does not escape into the click handlers that
 * call this function.
 */
function checkTrainingReadiness() {
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    // Expected before the first example exists: nothing is ready to train yet.
    exampleCounts = {};
  }

  const hasBackgroundNoise = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0;
  // labels[0] is the background-noise label; the rest are user categories.
  const hasCustomCategory = labels
    .slice(1)
    .some(customLabel => (exampleCounts[customLabel] || 0) > 0);

  trainModelBtn.disabled = !(hasBackgroundNoise && hasCustomCategory);
}
// ======================= Model training =======================
// Click handler for "train": validates the collected examples, then trains the
// transfer model and reports per-epoch progress in the status area.
trainModelBtn.onclick = async () => {
  const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;

  // Read the latest counts. speech-commands' countExamples() throws while the
  // dataset is empty; report that with the "no samples" message instead of
  // letting the rejection go unhandled.
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    console.warn('读取样本数量失败:', error);
    alert('没有收集到任何训练样本!请先录制样本。');
    return;
  }
  console.log('--- DEBUG: 训练开始前,各类别样本数量:', exampleCounts);

  // Local labels (background noise included) that have at least one example.
  const populatedLabels = labels.filter(labelName => (exampleCounts[labelName] || 0) > 0);
  const validClasses = populatedLabels.length;
  if (validClasses < 2) {
    alert(`训练需要至少 "背景噪音" (已存在) 和另一个自定义类别 (您需要添加并录制样本)。\n\n当前只有 ${validClasses} 个有效类别。`);
    return;
  }
  // Every class that has examples must reach the per-class minimum.
  const allClassesHaveEnoughSamples = populatedLabels.every(
    labelName => exampleCounts[labelName] >= MIN_SAMPLES_PER_CLASS_FOR_TRAINING
  );
  if (!allClassesHaveEnoughSamples) {
    alert(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。\n(当前某些类别样本不足,请检查!)\n\n建议每个类别多收集一些(例如 5-10 个)以获得更好的模型效果。`);
    return;
  }

  statusDiv.innerText = '模型训练中...请稍候。';
  trainModelBtn.disabled = true;
  startPredictingBtn.disabled = true;
  stopPredictingBtn.disabled = true;

  // Metrics can legitimately be 0, so check for a number rather than truthiness.
  const formatMetric = value => (typeof value === 'number' ? value.toFixed(4) : 'N/A');
  const trainingConfig = {
    epochs: 50,
    batchSize: 16,
    validationSplit: 0.1,
    shuffle: true,
    yieldEvery: 'epoch',
    // speech-commands' train() takes tf.js callbacks under `callback` (singular);
    // the former `callbacks` key was not read, so progress was never displayed.
    callback: {
      onEpochEnd: (epoch, logs) => {
        const { loss, acc } = logs || {};
        statusDiv.innerText = `训练 Epoch ${epoch + 1}/${trainingConfig.epochs}, Loss: ${formatMetric(loss)}, Accuracy: ${formatMetric(acc)}`;
      },
    },
  };

  try {
    await transferRecognizer.train(trainingConfig);
    statusDiv.innerText = '模型训练完成!你可以开始识别了。';
    predictionResultDiv.innerText = '训练完成,等待识别...';
    startPredictingBtn.disabled = false;
    // The trained state is tracked manually by this flag.
    isModelTrainedFlag = true;
  } catch (error) {
    statusDiv.innerText = `模型训练失败: ${error.message}. 这通常是由于样本数量过少,类别不均,或录音质量问题导致。请确保每个类别至少有 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本,并且多录制一些(例如 5-10 个)!`;
    console.error('训练失败:', error);
    isModelTrainedFlag = false;
  } finally {
    trainModelBtn.disabled = false;
  }
};
// ======================= Real-time prediction =======================
// Click handler for "start": begins continuous recognition with the trained
// transfer model and shows the top-scoring label of each result.
startPredictingBtn.onclick = async () => {
  if (isPredicting) {
    statusDiv.innerText = '识别已经在进行中...';
    return;
  }
  // The trained state is tracked by a manually maintained flag.
  if (!isModelTrainedFlag) {
    alert('模型尚未训练完成,请先训练模型!');
    return;
  }
  // Example collection and listening share the recognizer's audio stream; speech-commands
  // refuses to start listening while a collection is streaming, so wait for the batch.
  if (isRecording) {
    statusDiv.innerText = '请等待当前录音完成...';
    return;
  }

  // Toggles every control between the "predicting" and "idle" layouts.
  const setPredictingControls = predicting => {
    startPredictingBtn.disabled = predicting;
    stopPredictingBtn.disabled = !predicting;
    trainModelBtn.disabled = predicting;
    recordBackgroundNoiseBtn.disabled = predicting;
    addCategoryBtn.disabled = predicting;
    // Per-category record buttons: no new samples may be collected while predicting.
    document.querySelectorAll('.category-block button').forEach(btn => {
      btn.disabled = predicting;
    });
  };

  isPredicting = true;
  setPredictingControls(true);
  statusDiv.innerText = '正在开始识别... 请发出你训练过的声音。';
  predictionResultDiv.innerText = '等待识别结果...';

  try {
    // In speech-commands, listen() resolves once streaming has started and does not
    // return a stop function (stopping goes through stopListening()). The resolved value
    // is still stored so a stop handler that checks predictionStopFunction keeps working.
    predictionStopFunction = await transferRecognizer.listen(result => {
      if (!isPredicting) return;
      // result.scores is indexed like transferRecognizer.wordLabels().
      const classLabels = transferRecognizer.wordLabels();
      const scores = result.scores;
      const maxScore = Math.max(...scores);
      const predictedIndex = scores.indexOf(maxScore);
      let predictedLabel = classLabels[predictedIndex];
      // Show the internal background-noise label with a user-friendly name.
      if (predictedLabel === BACKGROUND_NOISE_LABEL) {
        predictedLabel = '背景噪音';
      }
      predictionResultDiv.innerText = `预测结果:${predictedLabel} (置信度: ${(maxScore * 100).toFixed(2)}%)`;
    }, {
      includeEmbedding: true,
      probabilityThreshold: 0.75,
      suppressionTimeMillis: 300,
      overlapFactor: 0.50,
    });
  } catch (error) {
    // listen() failed (e.g. microphone denied or stream busy): roll the UI back to idle
    // instead of leaving it stuck in the "predicting" state.
    console.error('启动识别失败:', error);
    isPredicting = false;
    predictionStopFunction = null;
    setPredictingControls(false);
    statusDiv.innerText = `启动识别失败: ${error.message}`;
    predictionResultDiv.innerText = '停止识别。';
  }
};
// Click handler for "stop": stops recognition and re-enables the training and
// recording controls.
stopPredictingBtn.onclick = async () => {
  if (!isPredicting) {
    return;
  }
  // Clear the flag first so the listen() callback ignores any result still in flight,
  // and so a second click during the await below is a no-op.
  isPredicting = false;
  try {
    if (typeof predictionStopFunction === 'function') {
      await predictionStopFunction();
    } else if (transferRecognizer.isListening()) {
      // speech-commands' listen() resolves to undefined rather than a stop function;
      // the stream is stopped with stopListening(), which rejects if nothing is streaming.
      await transferRecognizer.stopListening();
    }
  } catch (error) {
    console.warn('--- WARN: 停止监听失败:', error);
  } finally {
    // Drop the reference so it cannot be called twice.
    predictionStopFunction = null;
  }

  startPredictingBtn.disabled = false;
  stopPredictingBtn.disabled = true;
  trainModelBtn.disabled = false;
  recordBackgroundNoiseBtn.disabled = false;
  addCategoryBtn.disabled = false;
  // Re-enable the per-category record buttons unless a recording batch is running.
  if (!isRecording) {
    document.querySelectorAll('.category-block button').forEach(btn => {
      btn.disabled = false;
    });
  }
  statusDiv.innerText = '已停止识别。';
  predictionResultDiv.innerText = '停止识别。';
};
// ======================= Run on page load =======================
// Start model initialisation when the window's load event fires.
// Assigning window.onload replaces any onload handler assigned earlier.
window.onload = init;