// Listing header from the code viewer: 493 lines, 20 KiB, JavaScript.
// ======================= Module-wide state =======================

// Label under which background-noise samples are collected.
const BACKGROUND_NOISE_LABEL = '_background_noise_';

// Passed as `durationMillis` when a single example is collected.
const recordDuration = 1000;

// Recognizer instances; both are assigned inside init().
let recognizer;
let transferRecognizer;

// Known labels. Starts empty; init() and the post-import sync rebuild it.
let labels = [];

// Runtime flags toggled by the recording, training and prediction handlers.
let isPredicting = false;
let isRecording = false;
let isModelTrainedFlag = false;

// Whatever the "start predicting" handler stores for stopping; null when unset.
let predictionStopFunction = null;

// ======================= DOM element references =======================
// Each name below is bound to the element whose id appears at the same
// position in the id list next to it.

// Status line and prediction output.
const [statusDiv, predictionResultDiv] = [
  'status',
  'predictionResult',
].map((id) => document.getElementById(id));

// Background-noise recording controls.
const [backgroundNoiseSampleCountSpan, recordBackgroundNoiseBtn] = [
  'backgroundNoiseSampleCount',
  'recordBackgroundNoiseBtn',
].map((id) => document.getElementById(id));

// Custom-category management controls.
const [categoryContainer, newCategoryNameInput, addCategoryBtn] = [
  'categoryContainer',
  'newCategoryName',
  'addCategoryBtn',
].map((id) => document.getElementById(id));

// Training and prediction controls.
const [trainModelBtn, startPredictingBtn, stopPredictingBtn] = [
  'trainModelBtn',
  'startPredictingBtn',
  'stopPredictingBtn',
].map((id) => document.getElementById(id));

// Export / import controls.
const [exportModelBtn, importModelBtn, importFileInput] = [
  'exportModelBtn',
  'importModelBtn',
  'importFileInput',
].map((id) => document.getElementById(id));

// ======================= Initialization =======================
/**
 * Creates the base Speech Commands recognizer, waits for its model to load,
 * derives the transfer recognizer from it and puts the page controls into
 * their starting state.
 *
 * On any failure the error is shown in the status area, logged to the
 * console, and every button on the page is disabled.
 *
 * @returns {Promise<void>}
 */
async function init() {
  statusDiv.innerText = '正在加载 TensorFlow.js 和 Speech Commands 模型...';

  try {
    recognizer = speechCommands.create('BROWSER_FFT');
    await recognizer.ensureModelLoaded();
    transferRecognizer = recognizer.createTransfer('my-custom-model');

    // Start from a label list that only holds the background-noise entry.
    labels = [BACKGROUND_NOISE_LABEL];

    statusDiv.innerText = '模型加载成功!你可以开始录制、或导入已有的样本数据。';

    // Usable immediately: recording, adding categories, importing data.
    for (const control of [recordBackgroundNoiseBtn, addCategoryBtn, importModelBtn]) {
      control.disabled = false;
    }
    // Off until there is data (export, train) or a trained model (predict).
    for (const control of [exportModelBtn, trainModelBtn, startPredictingBtn, stopPredictingBtn]) {
      control.disabled = true;
    }
    isModelTrainedFlag = false;

    // checkTrainingReadiness() is deliberately not invoked during init.
  } catch (error) {
    statusDiv.innerText = `模型加载失败或麦克风无法访问: ${error.message}. 请检查麦克风权限和网络连接。`;
    console.error('初始化失败:', error);
    // Nothing on the page can work without the model, so lock every button.
    for (const anyButton of document.querySelectorAll('button')) {
      anyButton.disabled = true;
    }
    isModelTrainedFlag = false;
  }
}

// ======================= Batch sample recording =======================
/**
 * Records up to `countToRecord` examples for `label`, one after another,
 * updating the sample counter and the status line as it goes.
 *
 * If one example fails, the batch stops and the failure message stays in the
 * status line (it is no longer overwritten by the summary). The button and the
 * `isRecording` flag are always restored, and errors thrown while counting
 * examples or re-checking training readiness are logged instead of escaping
 * as an unhandled rejection.
 *
 * @param {string} label - Label the examples are collected under.
 * @param {{innerText: *}} sampleCountSpanElement - Element showing the count for `label`.
 * @param {{disabled: boolean, innerText: string}} buttonElement - Button that started the batch.
 * @param {number} [countToRecord=5] - Number of examples to attempt.
 * @returns {Promise<void>}
 */
async function recordMultipleExamples(label, sampleCountSpanElement, buttonElement, countToRecord = 5) {
  if (isRecording) {
    statusDiv.innerText = '请等待当前录音完成...';
    return;
  }

  // The speech-commands countExamples() throws while nothing has been
  // collected yet, so an unreadable count is treated as zero.
  const readCollectedCount = () => {
    try {
      return transferRecognizer.countExamples()[label] || 0;
    } catch (error) {
      return 0;
    }
  };

  const pauseBetweenSamplesMs = Math.max(200, recordDuration / 5);
  let batchFailed = false;

  isRecording = true;
  buttonElement.disabled = true;
  buttonElement.innerText = '正在录制...';

  try {
    for (let i = 0; i < countToRecord; i++) {
      statusDiv.innerText = `正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`;
      try {
        await transferRecognizer.collectExample(
          label,
          { amplitudeRequired: true, durationMillis: recordDuration }
        );
        sampleCountSpanElement.innerText = readCollectedCount();
        if (i < countToRecord - 1) {
          // Short pause between consecutive samples.
          await new Promise((resolve) => setTimeout(resolve, pauseBetweenSamplesMs));
        }
      } catch (error) {
        statusDiv.innerText = `录制 "${label}" 样本失败: ${error.message}`;
        console.error(`录制 ${label} 样本失败:`, error);
        // One failed sample ends the current batch.
        batchFailed = true;
        break;
      }
    }
  } finally {
    buttonElement.disabled = false;
    buttonElement.innerText = '录制样本';
    isRecording = false;
  }

  // Samples recorded before a failure still count, so always re-check.
  try {
    checkTrainingReadiness();
  } catch (error) {
    console.error('检查训练就绪状态失败:', error);
  }

  if (!batchFailed) {
    statusDiv.innerText = `已为 "${label}" 收集了 ${readCollectedCount()} 个样本。`;
  }
}

// ============== Background-noise sample collection ==============
// Clicking the button records one batch of background-noise samples and
// reports progress through the background-noise counter element.
recordBackgroundNoiseBtn.onclick = async () => {
  const samplesPerBatch = 5;
  await recordMultipleExamples(
    BACKGROUND_NOISE_LABEL,
    backgroundNoiseSampleCountSpan,
    recordBackgroundNoiseBtn,
    samplesPerBatch
  );
};

// ======================= Custom category management =======================

/**
 * Adds a user-defined category: registers its label, creates its UI block
 * with a sample count of 0 and clears the name input.
 *
 * The name is trimmed here as well as by the caller, so whitespace-only names
 * are rejected and " cat " is recognised as a duplicate of "cat". Non-string
 * input is treated as an empty name. Duplicate detection is case-insensitive
 * and also covers the background-noise label.
 *
 * @param {string} categoryName - Name typed by the user.
 * @returns {void}
 */
function addCustomCategory(categoryName) {
  const name = typeof categoryName === 'string' ? categoryName.trim() : '';
  if (!name) {
    alert('类别名称不能为空!');
    return;
  }

  const wanted = name.toLowerCase();
  if (labels.some((label) => label.toLowerCase() === wanted)) {
    alert(`类别 "${name}" 已经存在!`);
    return;
  }

  // Keep the local label list in sync; createCategoryUI derives element ids
  // from the label's index in this list.
  labels.push(name);
  createCategoryUI(name, 0);
  newCategoryNameInput.value = '';

  // The readiness check reads example counts from the recognizer, which the
  // speech-commands library refuses to report while no example exists. A
  // failure there must not surface as an uncaught error after the category
  // has already been added.
  try {
    checkTrainingReadiness();
  } catch (error) {
    console.error('检查训练就绪状态失败:', error);
  }
}

// Click handler for the "add category" button: passes the trimmed text of
// the name input to addCustomCategory.
addCategoryBtn.onclick = () => {
  const requestedName = newCategoryNameInput.value.trim();
  addCustomCategory(requestedName);
};

/**
 * Builds the UI block for one category (heading, sample counter, record
 * button) and appends it to the category container. Used both when a category
 * is added by hand and when the UI is rebuilt after an import.
 *
 * The block is assembled from DOM nodes with `textContent` rather than an
 * `innerHTML` template: category names come from user input or from imported
 * files, and must never be parsed as markup.
 *
 * @param {string} categoryName - Label of the category; shown as the heading.
 * @param {number} sampleCount - Sample count to display initially.
 * @returns {void}
 */
function createCategoryUI(categoryName, sampleCount) {
  // Only used to derive element ids; the label itself is what gets recorded.
  const categoryId = labels.indexOf(categoryName);

  const categoryBlock = document.createElement('div');
  categoryBlock.className = 'category-block';
  // Id so the block can be located later.
  categoryBlock.id = `category-block-${encodeURIComponent(categoryName)}`;

  const heading = document.createElement('h3');
  heading.textContent = categoryName;

  const sampleCountSpan = document.createElement('span');
  sampleCountSpan.id = `sampleCount-${categoryId}`;
  sampleCountSpan.textContent = String(sampleCount);

  const countLine = document.createElement('p');
  countLine.appendChild(document.createTextNode('样本数量: '));
  countLine.appendChild(sampleCountSpan);

  const recordBtn = document.createElement('button');
  recordBtn.id = `recordBtn-${categoryId}`;
  recordBtn.textContent = '录制样本';
  recordBtn.onclick = async () => {
    await recordMultipleExamples(categoryName, sampleCountSpan, recordBtn, 5);
  };

  categoryBlock.appendChild(heading);
  categoryBlock.appendChild(countLine);
  categoryBlock.appendChild(recordBtn);
  categoryContainer.appendChild(categoryBlock);
}

// ======================= Readiness check =======================
/**
 * Enables or disables the export and train buttons from the current example
 * counts:
 *  - export is enabled as soon as any example exists;
 *  - training needs background-noise examples plus at least one custom
 *    category that has examples.
 *
 * The speech-commands `countExamples()` throws while no example has been
 * collected yet (which is why init() avoids calling this function). That
 * case is treated as "no examples" so the function is safe to call at any
 * time, e.g. right after a category is added.
 *
 * @returns {void}
 */
function checkTrainingReadiness() {
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    exampleCounts = {};
  }

  const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
  exportModelBtn.disabled = totalSamples === 0;

  const backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0;

  // Select custom labels by name rather than by position, so the result does
  // not depend on where the background-noise label sits in `labels`.
  const customCategoryReady = labels.some(
    (label) => label !== BACKGROUND_NOISE_LABEL && (exampleCounts[label] || 0) > 0
  );

  trainModelBtn.disabled = !(backgroundNoiseReady && customCategoryReady);
}

// ======================= Model training =======================
/**
 * Click handler for the train button.
 *
 * Validates the collected examples (at least two labels with examples, each
 * of them holding at least MIN_SAMPLES_PER_CLASS_FOR_TRAINING), then trains
 * the transfer recognizer while reporting per-epoch progress in the status
 * area. On success the start-predicting button is enabled and
 * `isModelTrainedFlag` is set; on failure the flag is cleared. The train
 * button is re-enabled in both cases.
 */
trainModelBtn.onclick = async () => {
  const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;

  // countExamples() throws while the dataset is empty; treat that as "no
  // examples" so the user gets the explanatory alert below instead of an
  // unhandled rejection.
  let exampleCounts;
  try {
    exampleCounts = transferRecognizer.countExamples();
  } catch (error) {
    exampleCounts = {};
  }
  console.log('--- DEBUG: 训练开始前,各类别样本数量:', exampleCounts);

  // Counts of every known label (background noise included) that has examples.
  const populatedCounts = labels
    .map((labelName) => exampleCounts[labelName] || 0)
    .filter((count) => count > 0);
  const validClasses = populatedCounts.length;
  const allClassesHaveEnoughSamples = populatedCounts.every(
    (count) => count >= MIN_SAMPLES_PER_CLASS_FOR_TRAINING
  );

  if (validClasses < 2) {
    alert(`训练需要至少 "背景噪音" (已存在) 和另一个自定义类别 (您需要添加并录制样本)。\n\n当前只有 ${validClasses} 个有效类别。`);
    return;
  }
  if (!allClassesHaveEnoughSamples) {
    alert(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。\n(当前某些类别样本不足,请检查!)\n\n建议每个类别多收集一些(例如 5-10 个)以获得更好的模型效果。`);
    return;
  }
  // A separate "no examples at all" check is not needed: that case already
  // fails the validClasses test above.

  statusDiv.innerText = '模型训练中...请稍候。';
  trainModelBtn.disabled = true;
  startPredictingBtn.disabled = true;
  stopPredictingBtn.disabled = true;

  // A metric of exactly 0 is a real value and must not be shown as "N/A".
  const formatMetric = (value) => (typeof value === 'number' ? value.toFixed(4) : 'N/A');

  const trainingConfig = {
    epochs: 50,
    batchSize: 16,
    validationSplit: 0.1,
    shuffle: true,
    yieldEvery: 'epoch',
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        statusDiv.innerText = `训练 Epoch ${epoch + 1}/${trainingConfig.epochs}, Loss: ${formatMetric(logs.loss)}, Accuracy: ${formatMetric(logs.acc)}`;
      }
    }
  };

  try {
    await transferRecognizer.train(trainingConfig);
    statusDiv.innerText = '模型训练完成!你可以开始识别了。';
    predictionResultDiv.innerText = '训练完成,等待识别...';
    startPredictingBtn.disabled = false;

    // The handlers rely on this flag rather than on the recognizer's own state.
    isModelTrainedFlag = true;
    console.log('--- DEBUG: 训练成功完成,此时 transferRecognizer.isTrained 为:', transferRecognizer.isTrained);
  } catch (error) {
    statusDiv.innerText = `模型训练失败: ${error.message}. 这通常是由于样本数量过少,类别不均,或录音质量问题导致。请确保每个类别至少有 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本,并且多录制一些(例如 5-10 个)!`;
    console.error('训练失败:', error);
    isModelTrainedFlag = false;
  } finally {
    trainModelBtn.disabled = false;
  }
};

// ======================= Live prediction =======================
/**
 * Click handler that starts continuous recognition with the trained transfer
 * model and shows the best-scoring label for each result.
 *
 * Fixes over the previous version:
 *  - `listen()` resolves to nothing, so its result can never be a stop
 *    function. `predictionStopFunction` now holds a function that calls
 *    `stopListening()`, which is how the library ends a listening session.
 *  - If `listen()` rejects (for example the microphone cannot be opened),
 *    the controls are restored instead of staying locked.
 *  - The stop button is enabled only once listening has really started.
 */
startPredictingBtn.onclick = async () => {
  console.log('--- DEBUG: 点击开始识别时, isModelTrainedFlag 为:', isModelTrainedFlag);
  if (isPredicting) {
    statusDiv.innerText = '识别已经在进行中...';
    return;
  }

  if (!isModelTrainedFlag) {
    alert('模型尚未训练完成,请先训练模型!');
    return;
  }

  // Locks or unlocks every control that could change the samples or the
  // model while recognition is running.
  const setEditingControlsDisabled = (disabled) => {
    trainModelBtn.disabled = disabled;
    recordBackgroundNoiseBtn.disabled = disabled;
    addCategoryBtn.disabled = disabled;
    document.querySelectorAll('.category-block button').forEach((btn) => {
      btn.disabled = disabled;
    });
  };

  isPredicting = true;
  startPredictingBtn.disabled = true;
  setEditingControlsDisabled(true);

  statusDiv.innerText = '正在开始识别... 请发出你训练过的声音。';
  predictionResultDiv.innerText = '等待识别结果...';

  try {
    await transferRecognizer.listen((result) => {
      if (!isPredicting) return;

      // Indices of result.scores line up with the recognizer's own label list.
      const classLabels = transferRecognizer.wordLabels();
      const scores = result.scores;
      const maxScore = Math.max(...scores);
      const predictedIndex = scores.indexOf(maxScore);

      let predictedLabel = classLabels[predictedIndex];
      // Show the internal background-noise label in a user-friendly form.
      if (predictedLabel === BACKGROUND_NOISE_LABEL) {
        predictedLabel = '背景噪音';
      }

      predictionResultDiv.innerText = `预测结果:${predictedLabel} (置信度: ${(maxScore * 100).toFixed(2)}%)`;
    }, {
      includeEmbedding: true,
      probabilityThreshold: 0.75,
      suppressionTimeMillis: 300,
      overlapFactor: 0.50,
    });
  } catch (error) {
    // Listening never started: undo the UI lock so the user can retry.
    isPredicting = false;
    predictionStopFunction = null;
    startPredictingBtn.disabled = false;
    stopPredictingBtn.disabled = true;
    setEditingControlsDisabled(false);
    statusDiv.innerText = `无法开始识别: ${error.message}`;
    console.error('开始识别失败:', error);
    return;
  }

  predictionStopFunction = () => transferRecognizer.stopListening();
  stopPredictingBtn.disabled = false;
};

/**
 * Click handler that stops continuous recognition and unlocks the controls.
 *
 * When no stop function has been stored, the handler now falls back to the
 * recognizer's `stopListening()` instead of only logging a warning, so the
 * listening session is actually ended. A failure while stopping is logged
 * and does not prevent the UI from being restored.
 */
stopPredictingBtn.onclick = async () => {
  if (!isPredicting) {
    return;
  }

  // Clear the flag first: the listen callback ignores results once it is false.
  isPredicting = false;
  stopPredictingBtn.disabled = true;

  const stopListening = typeof predictionStopFunction === 'function'
    ? predictionStopFunction
    : () => transferRecognizer.stopListening();
  predictionStopFunction = null;

  try {
    await stopListening();
  } catch (error) {
    console.warn('--- WARN: 停止监听时出错:', error);
  }

  startPredictingBtn.disabled = false;
  trainModelBtn.disabled = false;
  recordBackgroundNoiseBtn.disabled = false;
  addCategoryBtn.disabled = false;

  // Re-enable the per-category record buttons unless a recording is running.
  if (!isRecording) {
    document.querySelectorAll('.category-block button').forEach((btn) => {
      btn.disabled = false;
    });
  }

  statusDiv.innerText = '已停止识别。';
  predictionResultDiv.innerText = '停止识别。';
};

// ======================= Sample data export =======================
/**
 * Click handler that serializes every collected example and offers the
 * result as a timestamped `.bin` download. Failures are shown in the status
 * area, logged, and reported with an alert.
 */
exportModelBtn.onclick = async () => {
  try {
    // Wrap the serialized examples in a Blob and expose it through an object URL.
    const payload = new Blob(
      [transferRecognizer.serializeExamples()],
      { type: 'application/octet-stream' }
    );
    const objectUrl = URL.createObjectURL(payload);

    // File name carries the local date and time: YYYYMMDD_HHMMSS.
    const stamp = new Date();
    const twoDigits = (value) => value.toString().padStart(2, '0');
    const datePart = [
      stamp.getFullYear(),
      twoDigits(stamp.getMonth() + 1),
      twoDigits(stamp.getDate()),
    ].join('');
    const timePart = [
      twoDigits(stamp.getHours()),
      twoDigits(stamp.getMinutes()),
      twoDigits(stamp.getSeconds()),
    ].join('');
    const filename = `speech_commands_data_${datePart}_${timePart}.bin`;

    // Trigger the download through a temporary link element.
    const link = document.createElement('a');
    link.href = objectUrl;
    link.download = filename;
    document.body.appendChild(link);
    link.click();
    document.body.removeChild(link);
    URL.revokeObjectURL(objectUrl); // release the object URL

    statusDiv.innerText = `数据已成功导出为 "${filename}"。`;
  } catch (error) {
    statusDiv.innerText = `导出数据失败: ${error.message}`;
    console.error('导出数据失败:', error);
    alert('导出数据失败。请确保您已录制至少一个样本!');
  }
};

// ======================= Sample data import =======================
// The visible import button only forwards the click to the file input,
// which opens the browser's file picker.
importModelBtn.onclick = function openImportPicker() {
  importFileInput.click();
};

/**
 * Change handler of the file input: reads the chosen `.bin` file, loads
 * its examples into the transfer recognizer and rebuilds the UI from the
 * loaded data.
 *
 * Fixes over the previous version:
 *  - the success alert sums the per-label counts instead of reading a
 *    `_numExamples_` key that the counts map does not contain (the alert
 *    always showed 0);
 *  - a read failure reports the FileReader's own error; the value passed to
 *    `onerror` is an event and has no `message`;
 *  - the `.bin` extension check ignores case.
 */
importFileInput.onchange = async (event) => {
  const file = event.target.files[0];
  if (!file) {
    statusDiv.innerText = '未选择文件。';
    return;
  }

  if (!file.name.toLowerCase().endsWith('.bin')) {
    alert('请选择后缀名为 .bin 的文件!');
    statusDiv.innerText = '文件格式不正确,请选择 .bin 文件。';
    // Reset the input so the user can pick another file.
    importFileInput.value = '';
    return;
  }

  statusDiv.innerText = `正在导入文件 "${file.name}"...`;

  const reader = new FileReader();

  reader.onload = async (e) => {
    try {
      const dataBuffer = e.target.result; // ArrayBuffer with the file contents

      // NOTE(review): per the speech-commands documentation, loadExamples()
      // adds the incoming examples to those already collected unless its
      // second argument (clearExisting) is true. Existing samples are kept
      // here, as before.
      // TODO: consider asking the user whether existing samples should be
      // cleared first.
      await transferRecognizer.loadExamples(dataBuffer);

      // Rebuild labels, category blocks and button states from the loaded data.
      await syncUIWithLoadedData();

      // Total number of examples the recognizer now holds.
      const totalExamples = Object.values(transferRecognizer.countExamples())
        .reduce((sum, count) => sum + count, 0);

      statusDiv.innerText = `文件 "${file.name}" 导入成功!`;
      alert(`已成功导入 ${totalExamples} 个样本!`);
    } catch (error) {
      statusDiv.innerText = `导入数据失败: ${error.message}. 确保存储的是有效的模型样本数据。`;
      console.error('导入数据失败:', error);
      alert(`导入数据失败。请检查文件是否损坏或格式不正确。\n错误: ${error.message}`);
    } finally {
      // Reset the input so choosing the same file again fires onchange.
      importFileInput.value = '';
    }
  };

  reader.onerror = () => {
    // The failure details live on reader.error, not on the event argument.
    const reason = reader.error && reader.error.message ? reader.error.message : '未知错误';
    statusDiv.innerText = `读取文件失败: ${reason}`;
    console.error('文件读取失败:', reader.error);
    alert('文件读取失败。');
    importFileInput.value = '';
  };

  reader.readAsArrayBuffer(file);
};

// ======================= Post-import UI sync =======================
/**
 * Rebuilds the label list, the category blocks and the button states from
 * the example counts currently reported by the transfer recognizer.
 *
 * Every child of the category container is removed and recreated, the
 * trained flag is cleared, and the train / predict buttons are disabled
 * before the readiness check decides what to re-enable.
 *
 * @returns {Promise<void>}
 */
async function syncUIWithLoadedData() {
  // Drop every existing category block.
  let staleBlock;
  while ((staleBlock = categoryContainer.firstChild)) {
    categoryContainer.removeChild(staleBlock);
  }

  // Label list restarts with only the background-noise entry.
  labels = [BACKGROUND_NOISE_LABEL];

  const exampleCounts = transferRecognizer.countExamples();
  console.log('--- DEBUG: 导入后样本数量:', exampleCounts);

  backgroundNoiseSampleCountSpan.innerText = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;

  // Keys that never get a category block of their own: the background-noise
  // label and the two names the original code treats as internal metadata.
  const skippedKeys = [BACKGROUND_NOISE_LABEL, '_version_', '_numExamples_'];

  Object.keys(exampleCounts)
    .filter((key) => skippedKeys.indexOf(key) === -1)
    .forEach((customLabel) => {
      // Register the label before building its UI: createCategoryUI derives
      // element ids from the label's index in `labels`.
      if (labels.indexOf(customLabel) === -1) {
        labels.push(customLabel);
      }
      createCategoryUI(customLabel, exampleCounts[customLabel]);
    });

  // Loaded data invalidates any previously trained model state.
  isModelTrainedFlag = false;
  trainModelBtn.disabled = true;
  startPredictingBtn.disabled = true;
  stopPredictingBtn.disabled = true;

  // Let the readiness check re-enable whatever the loaded data allows.
  checkTrainingReadiness();
}

// ======================= Run on page load =======================
// Register init as the window's load handler. Assigning to window.onload
// replaces any handler previously stored in that property.
window.onload = init;
// End of script.