[MF]修改audioClassifier.js (核心功能模块)

This commit is contained in:
51hhh 2025-08-14 15:42:05 +08:00
parent 5d36541dc5
commit 37dc1c5a76

View File

@ -1,147 +1,213 @@
// 全局变量和模型实例
let recognizer; // 基础的 SpeechCommands recognizer
let transferRecognizer; // 用于迁移学习的 recognizer
const labels = []; // 用户定义的类别标签数组 (包括背景噪音)
// 将背景噪音定义为第一个类别,其内部名称为 _background_noise_
const BACKGROUND_NOISE_LABEL = '_background_noise_';
const BACKGROUND_NOISE_INDEX = 0; // 仅用于本地 labels 数组索引不直接用于collectExample
// audioClassifier.js (核心功能模块)
// TensorFlow.js 和 Speech Commands 库不需要在这里再次导入,
// 假定在使用此模块的 HTML 页面中已经通过 <script> 标签引入。
// ======================= 模块内部共享变量 =======================
let recognizer; // 基础的 SpeechCommands recognizer 实例
let transferRecognizer; // 用于迁移学习的 recognizer 实例
// 用户定义的类别标签数组 (包括背景噪音)。
// 这是模块的核心状态之一,外部可以通过 getLabels() 获取。
let labels = [];
const BACKGROUND_NOISE_LABEL = '_background_noise_'; // 内部使用的背景噪音标签
let isPredicting = false; // 预测状态标志
let isRecording = false; // 录音状态标志,防止重复点击
const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
let isModelTrainedFlag = false; // 手动维护模型训练状态
let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
// ======================= 核心功能函数 =======================
/**
* 初始化函数 - 加载模型和创建迁移学习模型
* @returns {Promise<void>}
* 初始化音频分类器模型
* 必须在使用任何其他功能之前调用
* @returns {Promise<void>} resolve 表示成功reject 表示失败
*/
async function init() {
try {
recognizer = speechCommands.create(
'BROWSER_FFT' // 使用浏览器内置的 FFT 处理,性能更好
);
// 确保每次初始化时重置状态
recognizer = null;
transferRecognizer = null;
labels = [];
isPredicting = false;
isRecording = false;
isModelTrainedFlag = false;
if (predictionStopFunction) {
predictionStopFunction(); // 停止任何正在进行的预测
predictionStopFunction = null;
}
await recognizer.ensureModelLoaded();
try {
// 创建基础识别器
recognizer = speechCommands.create('BROWSER_FFT');
await recognizer.ensureModelLoaded(); // 确保模型加载完成
// 创建迁移学习识别器
transferRecognizer = recognizer.createTransfer('my-custom-model');
// 只有在 transferRecognizer 创建成功后,才将背景噪音标签加入我们的 local labels 数组
// 初始化时,将背景噪音标签加入到我们的内部 labels 数组
labels.push(BACKGROUND_NOISE_LABEL);
console.log('AudioClassifier: 模型和迁移学习器初始化成功');
return Promise.resolve();
} catch (error) {
return Promise.reject(error);
console.error('AudioClassifier: 初始化失败:', error);
return Promise.reject(new Error(`模型初始化失败或麦克风无法访问: ${error.message}`));
}
}
/**
* 批量录制样本的通用函数
* @param {string} label - 标签名称
* @param {number} countToRecord - 要录制的样本数量
* @returns {Promise<void>}
* 录制单个或批量音频样本
* 在录制示例时会暂时禁用 isRecording 标志以防止多重录制
*
* @param {string} label - 样本所属的类别标签
* @param {number} countToRecord - 要录制样本的数量
* @param {Object} [options] - 其他选项
* @param {Function} [options.onProgress] - (optional) 进度回调函数参数为 (currentCount, totalCount)
* @returns {Promise<number>} resolve 时返回该类别当前的总样本数量reject 时返回错误
*/
async function recordMultipleExamples(label, countToRecord = 5) {
async function recordMultipleExamples(label, countToRecord = 1, options = {}) {
if (isRecording) {
return Promise.reject(new Error('正在录制中,请等待当前录音完成'));
return Promise.reject(new Error('核心模块:正在录制中,请等待当前录音完成'));
}
isRecording = true;
for (let i = 0; i < countToRecord; i++) {
try {
await transferRecognizer.collectExample(
label,
{ amplitudeRequired: true, durationMillis: recordDuration }
);
let currentLabelCount = 0;
try {
for (let i = 0; i < countToRecord; i++) {
// console.log(`AudioClassifier: 正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`);
await transferRecognizer.collectExample(label, {
amplitudeRequired: true,
durationMillis: recordDuration
});
const exampleCounts = transferRecognizer.countExamples();
currentLabelCount = exampleCounts[label] || 0;
if (options.onProgress && typeof options.onProgress === 'function') {
options.onProgress(i + 1, countToRecord, currentLabelCount);
}
// 在每次录音之间增加短暂延迟,以便更好地分离样本
if (i < countToRecord - 1) {
await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
}
} catch (error) {
isRecording = false;
return Promise.reject(error);
}
console.log(`AudioClassifier: 已为 "${label}" 收集了 ${currentLabelCount} 个样本。`);
return Promise.resolve(currentLabelCount);
} catch (error) {
console.error(`AudioClassifier: 录制 "${label}" 样本失败:`, error);
return Promise.reject(new Error(`录制 "${label}" 样本失败: ${error.message}`));
} finally {
isRecording = false;
}
isRecording = false;
return Promise.resolve();
}
/**
* 添加自定义类别
* @param {string} categoryName - 类别名称
* 添加一个自定义类别标签到模型
* 标签将添加到内部的 `labels` 数组中
* @param {string} categoryName - 要添加的类别名称
* @returns {Promise<void>} resolve 表示成功reject 表示失败如名称为空或重复
*/
function addCustomCategory(categoryName) {
if (!categoryName) {
return Promise.reject(new Error('类别名称不能为空'));
if (!categoryName || categoryName.trim() === '') {
return Promise.reject(new Error('类别名称不能为空'));
}
// 检查是否与现有标签重复
if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) {
return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`));
const trimmedName = categoryName.trim();
if (labels.some(label => label.toLowerCase() === trimmedName.toLowerCase())) {
return Promise.reject(new Error(`类别 "${trimmedName}" 已经存在。`));
}
// 将标签添加到本地数组
labels.push(categoryName);
labels.push(trimmedName);
console.log(`AudioClassifier: 已添加新类别: "${trimmedName}"`);
return Promise.resolve();
}
/**
* 检查训练就绪状态
* @returns {boolean} 是否可以开始训练
* 检查模型是否已准备好进行训练
* @returns {Object} 包含 `ready` (boolean) `details` (string) 的对象
*/
function checkTrainingReadiness() {
const exampleCounts = transferRecognizer.countExamples();
let exampleCounts = {};
try {
exampleCounts = transferRecognizer.countExamples();
} catch (error) {
// 这是预期内的错误,当没有样本时 countExamples 会抛出。
// 将其捕获并视为空样本集处理。
console.warn("countExamples() 抛出异常 (无样本时正常行为)。", error.message);
}
let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0;
// 检查是否有任何样本,用于更新导出按钮等逻辑
const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
const hasAnyExamples = totalSamples > 0;
const backgroundNoiseCount = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;
const backgroundNoiseReady = backgroundNoiseCount > 0;
let customCategoriesReady = 0;
// 遍历本地 labels 数组,检查每个自定义类别是否有样本
for (let i = 1; i < labels.length; i++) { // 从索引 1 开始,因为 0 是背景噪音
let customCategoriesWithSamples = 0;
// 遍历内部 labels 数组,检查每个自定义类别是否有样本
// 从索引 1 开始,因为 0 是背景噪音
for (let i = 1; i < labels.length; i++) {
const customLabel = labels[i];
if ((exampleCounts[customLabel] || 0) > 0) {
customCategoriesReady++;
customCategoriesWithSamples++;
}
}
// 必须有背景噪音样本,并且至少一个自定义类别有样本
return backgroundNoiseReady && customCategoriesReady >= 1;
// 训练就绪的条件:必须有背景噪音样本,并且至少一个自定义类别有样本
const canTrain = backgroundNoiseReady && customCategoriesWithSamples >= 1;
// 更详细的训练就绪判断,包括样本数量建议
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
let trainingDetails = [];
if (!backgroundNoiseReady) {
trainingDetails.push('需要录制背景噪音样本。');
} else if (backgroundNoiseCount < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
trainingDetails.push(`建议背景噪音至少 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本 (当前: ${backgroundNoiseCount})。`);
}
if (customCategoriesWithSamples === 0) {
trainingDetails.push('需要至少一个自定义类别并录制样本。');
} else {
let insufficientCustomSamples = [];
for (let i = 1; i < labels.length; i++) {
const customLabel = labels[i];
const count = exampleCounts[customLabel] || 0;
if (count > 0 && count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
insufficientCustomSamples.push(`类别 "${customLabel}" 样本不足 (当前: ${count},建议: ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING})。`);
}
}
if (insufficientCustomSamples.length > 0) {
trainingDetails = trainingDetails.concat(insufficientCustomSamples);
}
}
return {
ready: canTrain && trainingDetails.length === 0, // 只有完全满足条件才认为是 ready
details: trainingDetails.join('\n') || '模型已准备好进行训练。',
hasAnyExamples: hasAnyExamples // 额外返回是否有任何样本,供 UI 判断是否启用导出按钮
};
}
/**
* 模型训练函数
* @param {Object} trainingConfig - 训练配置参数
* @returns {Promise<void>}
* 训练模型
* @param {Object} [config] - 训练配置
* @param {number} [config.epochs=50] - 训练迭代次数
* @param {number} [config.batchSize=16] - 批大小
* @param {number} [config.validationSplit=0.1] - 验证集比例
* @param {boolean} [config.shuffle=true] - 是否打乱数据
* @param {Function} [config.onEpochEnd] - (可选) 每个 epoch 结束时的回调函数参数为 (epoch, logs)
* @returns {Promise<void>} resolve 表示训练完成reject 表示训练失败
*/
async function trainModel(trainingConfig = {}) {
const exampleCounts = transferRecognizer.countExamples();
let totalExamples = 0;
let validClasses = 0;
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
let allClassesHaveEnoughSamples = true;
// 统计所有类别的有效样本数,并检查每个类别是否达到最低要求
for (const labelName of labels) {
if (exampleCounts[labelName] && exampleCounts[labelName] > 0) {
totalExamples += exampleCounts[labelName];
validClasses++;
if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
allClassesHaveEnoughSamples = false;
}
}
}
if (validClasses < 2) {
return Promise.reject(new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`));
}
if (!allClassesHaveEnoughSamples) {
return Promise.reject(new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`));
}
if (totalExamples === 0) {
return Promise.reject(new Error('没有收集到任何训练样本'));
async function trainModel(config = {}) {
// 再次进行训练前检查,防止外部在不满足条件时调用
const readiness = checkTrainingReadiness();
if (!readiness.ready) {
return Promise.reject(new Error(`模型未准备好训练:\n${readiness.details}`));
}
const defaultConfig = {
@ -149,116 +215,225 @@ async function trainModel(trainingConfig = {}) {
batchSize: 16,
validationSplit: 0.1,
shuffle: true,
yieldEvery: 'epoch'
// 外部传入的回调函数将在 train 内部被调用
callbacks: {
onEpochEnd: (epoch, logs) => {
if (config.onEpochEnd && typeof config.onEpochEnd === 'function') {
config.onEpochEnd(epoch, logs);
}
}
}
};
const config = Object.assign(defaultConfig, trainingConfig);
// 合并默认配置和用户指定的配置
const trainingConfig = { ...defaultConfig, ...config };
try {
await transferRecognizer.train(config);
await transferRecognizer.train(trainingConfig);
isModelTrainedFlag = true;
console.log('AudioClassifier: 模型训练完成。');
return Promise.resolve();
} catch (error) {
isModelTrainedFlag = false;
return Promise.reject(error);
console.error('AudioClassifier: 模型训练失败:', error);
return Promise.reject(new Error(`模型训练失败: ${error.message}. 请确保有足够的样本且类别均衡。`));
}
}
/**
* 开始实时预测
* @param {Function} onPrediction - 预测结果回调函数
* @param {Object} listenOptions - 监听选项
* @returns {Promise<Function>} 停止预测的函数
* 开始实时预测
* @param {Function} onPrediction - 预测结果回调函数参数为 {label: string, score: number, scores: number[], labels: string[]}
* @param {Object} [listenOptions] - 监听选项直接传递给 transferRecognizer.listen()
* @returns {Promise<Function>} resolve 时返回停止预测的函数reject 时返回错误
*/
async function startPrediction(onPrediction, listenOptions = {}) {
if (isPredicting) {
return Promise.reject(new Error('识别已经在进行中'));
return Promise.reject(new Error('核心模块:识别已经在进行中'));
}
if (!isModelTrainedFlag) {
return Promise.reject(new Error('模型尚未训练完成'));
return Promise.reject(new Error('核心模块:模型尚未训练完成,请先训练模型。'));
}
isPredicting = true;
const defaultOptions = {
includeEmbedding: true,
probabilityThreshold: 0.75,
suppressionTimeMillis: 300,
overlapFactor: 0.50,
probabilityThreshold: 0.75, // 预测置信度阈值
suppressionTimeMillis: 300, // 抑制时间,避免连续触发
overlapFactor: 0.50, // 帧重叠因子
};
const options = Object.assign(defaultOptions, listenOptions);
const options = { ...defaultOptions, ...listenOptions };
predictionStopFunction = await transferRecognizer.listen(result => {
if (!isPredicting) return;
try {
predictionStopFunction = await transferRecognizer.listen(result => {
if (!isPredicting) return; // 确保在停止后不再处理结果
const classLabels = transferRecognizer.wordLabels();
const scores = result.scores;
const maxScore = Math.max(...scores);
const predictedIndex = scores.indexOf(maxScore);
let predictedLabel = classLabels[predictedIndex];
// 如果预测结果是内部的背景噪音标签,转换成用户友好的显示
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
predictedLabel = '背景噪音';
}
// 调用回调函数返回预测结果
if (typeof onPrediction === 'function') {
onPrediction({
label: predictedLabel,
score: maxScore,
scores: scores,
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label)
});
}
const classLabels = transferRecognizer.wordLabels(); // 获取模型内部的标签列表
const scores = result.scores; // 预测分数
const maxScore = Math.max(...scores);
const predictedIndex = scores.indexOf(maxScore);
let predictedLabel = classLabels[predictedIndex];
// 对背景噪音标签进行友好化处理
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
predictedLabel = '背景噪音';
}
// 调用外部传入的回调函数,传递结构化的预测结果
if (typeof onPrediction === 'function') {
onPrediction({
label: predictedLabel,
score: maxScore,
scores: scores, // 原始分数数组
predictedIndex: predictedIndex, // 预测索引
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) // 友好化的标签列表
});
}
}, options);
}, options);
return Promise.resolve(predictionStopFunction);
}
/**
* 停止实时预测
*/
function stopPrediction() {
if (isPredicting) {
if (typeof predictionStopFunction === 'function') {
predictionStopFunction();
predictionStopFunction = null;
}
console.log('AudioClassifier: 开始识别。');
return Promise.resolve(predictionStopFunction); // 返回停止函数给调用者
} catch (error) {
isPredicting = false;
console.error('AudioClassifier: 开始识别失败:', error);
return Promise.reject(new Error(`开始识别失败: ${error.message}`));
}
}
/**
* 获取各类别样本数量
* @returns {Object} 各类别样本数量统计
* 停止实时预测
*/
function getExampleCounts() {
return transferRecognizer.countExamples();
function stopPrediction() {
if (isPredicting) {
if (typeof predictionStopFunction === 'function') {
predictionStopFunction(); // 调用 SpeechCommands 库返回的停止函数
predictionStopFunction = null;
} else {
console.warn('AudioClassifier: 停止预测函数未定义或不是函数。');
}
isPredicting = false;
console.log('AudioClassifier: 已停止识别。');
}
}
/**
* 获取模型是否已训练的状态
* @returns {boolean} 模型是否已训练
* 获取当前收集到的各类别样本数量
* @returns {Object} 键为类别名称值为样本数量
*/
function getExampleCounts() {
try {
return transferRecognizer.countExamples();
} catch (error) {
// 捕获无样本时的错误,返回空对象
return {};
}
}
/**
* 获取模型是否已训练的状态
* @returns {boolean} 如果模型已训练则返回 true否则 false
*/
function isModelTrained() {
return isModelTrainedFlag;
}
// 导出公共接口
/**
* 获取当前模块中所有标签的数组包括背景噪音和自定义标签
* @returns {string[]} 标签数组
*/
function getLabels() {
return [...labels]; // 返回副本,避免外部直接修改内部状态
}
/**
* 序列化所有收集到的样本数据到 ArrayBuffer
* @returns {Promise<ArrayBuffer>} 包含样本数据的 ArrayBuffer
*/
async function serializeExamples() {
try {
const serialized = await transferRecognizer.serializeExamples();
console.log('AudioClassifier: 样本数据序列化成功。');
return Promise.resolve(serialized);
} catch (error) {
console.error('AudioClassifier: 序列化样本失败:', error);
return Promise.reject(new Error(`序列化样本失败: ${error.message}`));
}
}
/**
* ArrayBuffer 导入样本数据
* 导入后会更新内部的 labels 数组和模型状态
* @param {ArrayBuffer} dataBuffer - 包含样本数据的 ArrayBuffer
* @returns {Promise<Object>} resolve 时返回导入后的样本总数统计reject 时返回错误
*/
async function loadExamples(dataBuffer) {
if (!dataBuffer instanceof ArrayBuffer) {
return Promise.reject(new Error('核心模块:导入数据必须是 ArrayBuffer 类型。'));
}
// 首先清空当前的 transferRecognizer 中的所有样本,再加载新的。
// 这等同于创建一个新的 transferRecognizer 实例,并同步其内部状态。
// SpeechCommands库没有直接的`clearExamples`或`resetExamples`方法,
// 最稳妥的方法是重新创建transferRecognizer并加载。
// 但loadExamples本身就是设计来覆盖和更新现有样本的
// 考虑到性能我们假设直接loadExamples即可。
// 注意loadExamples会清除原有数据。
try {
await transferRecognizer.loadExamples(dataBuffer);
// 成功加载后,需要更新内部的 `labels` 数组,以匹配导入的数据。
// `transferRecognizer.wordLabels()` 返回的是模型内部的实际标签。
const loadedLabels = transferRecognizer.wordLabels();
// 确保 BACKGROUND_NOISE_LABEL 还在第一个位置,并过滤掉 '_version_', '_numExamples_' 等内部标签
labels = loadedLabels.filter(label => label !== '_version_' && label !== '_numExamples_');
if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
// 如果导入的数据中没有背景噪音(不常见,但作为健壮性处理)
labels.unshift(BACKGROUND_NOISE_LABEL);
} else {
// 确保背景噪音在第一位
const bgIndex = labels.indexOf(BACKGROUND_NOISE_LABEL);
if (bgIndex > 0) {
const bgLabel = labels.splice(bgIndex, 1);
labels.unshift(bgLabel[0]);
}
}
// 重置模型的训练状态,因为样本已更改,需要重新训练
isModelTrainedFlag = false;
const exampleCounts = transferRecognizer.countExamples();
console.log('AudioClassifier: 样本数据导入成功。', exampleCounts);
return Promise.resolve(exampleCounts);
} catch (error) {
console.error('AudioClassifier: 加载样本失败:', error);
return Promise.reject(new Error(`加载样本失败: ${error.message}. 请确保文件有效且未损坏。`));
}
}
// ======================= 导出的公共接口 =======================
// 这里使用 CommonJS 风格的 module.exports因为这通常用于 Node.js 环境或 Webpack/Rollup 等打包工具。
// 如果你是在浏览器中直接作为 <script> 标签引入(没有打包工具),
// 并且希望它在全局作用域下可用,可以使用 `window.AudioClassifier = {...}`。
// 在这种情况下,我将提供 `window.AudioClassifier` 的方式。
// 为了 Web 浏览器环境,将核心功能挂载到全局对象上
window.AudioClassifier = {
init,
recordMultipleExamples,
addCustomCategory,
checkTrainingReadiness,
trainModel,
startPrediction,
stopPrediction,
getExampleCounts,
isModelTrained,
labels
};
init, // 初始化模型
recordMultipleExamples, // 录制样本
addCustomCategory, // 添加自定义类别
checkTrainingReadiness, // 检查模型是否可训练
trainModel, // 训练模型
startPrediction, // 开始预测
stopPrediction, // 停止预测
getExampleCounts, // 获取所有类别样本数量
isModelTrained, // 检查模型是否已训练
getLabels, // 获取当前所有标签
serializeExamples, // 导出样本数据
loadExamples // 导入样本数据
};