mobileNet/音频分类/audioClassifier.js

440 lines
18 KiB
JavaScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// audioClassifier.js (核心功能模块)
// TensorFlow.js 和 Speech Commands 库不需要在这里再次导入,
// 假定在使用此模块的 HTML 页面中已经通过 <script> 标签引入。
// ======================= 模块内部共享变量 =======================
let recognizer; // 基础的 SpeechCommands recognizer 实例
let transferRecognizer; // 用于迁移学习的 recognizer 实例
// 用户定义的类别标签数组 (包括背景噪音)。
// 这是模块的核心状态之一,外部可以通过 getLabels() 获取。
let labels = [];
const BACKGROUND_NOISE_LABEL = '_background_noise_'; // 内部使用的背景噪音标签
let isPredicting = false; // 预测状态标志
let isRecording = false; // 录音状态标志,防止重复点击
const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
let isModelTrainedFlag = false; // 手动维护模型训练状态
let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
// ======================= 核心功能函数 =======================
/**
* 初始化音频分类器模型。
* 必须在使用任何其他功能之前调用。
* @returns {Promise<void>} resolve 表示成功reject 表示失败。
*/
async function init() {
// 确保每次初始化时重置状态
recognizer = null;
transferRecognizer = null;
labels = [];
isPredicting = false;
isRecording = false;
isModelTrainedFlag = false;
if (predictionStopFunction) {
predictionStopFunction(); // 停止任何正在进行的预测
predictionStopFunction = null;
}
try {
// 创建基础识别器
recognizer = speechCommands.create('BROWSER_FFT');
await recognizer.ensureModelLoaded(); // 确保模型加载完成
// 创建迁移学习识别器
transferRecognizer = recognizer.createTransfer('my-custom-model');
// 初始化时,将背景噪音标签加入到我们的内部 labels 数组
labels.push(BACKGROUND_NOISE_LABEL);
console.log('AudioClassifier: 模型和迁移学习器初始化成功');
return Promise.resolve();
} catch (error) {
console.error('AudioClassifier: 初始化失败:', error);
return Promise.reject(new Error(`模型初始化失败或麦克风无法访问: ${error.message}`));
}
}
/**
* 录制单个或批量音频样本。
* 在录制示例时会暂时禁用 isRecording 标志以防止多重录制。
*
* @param {string} label - 样本所属的类别标签。
* @param {number} countToRecord - 要录制样本的数量。
* @param {Object} [options] - 其他选项。
* @param {Function} [options.onProgress] - (optional) 进度回调函数,参数为 (currentCount, totalCount)。
* @returns {Promise<number>} resolve 时返回该类别当前的总样本数量reject 时返回错误。
*/
async function recordMultipleExamples(label, countToRecord = 1, options = {}) {
if (isRecording) {
return Promise.reject(new Error('核心模块:正在录制中,请等待当前录音完成。'));
}
isRecording = true;
let currentLabelCount = 0;
try {
for (let i = 0; i < countToRecord; i++) {
// console.log(`AudioClassifier: 正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`);
await transferRecognizer.collectExample(label, {
amplitudeRequired: true,
durationMillis: recordDuration
});
const exampleCounts = transferRecognizer.countExamples();
currentLabelCount = exampleCounts[label] || 0;
if (options.onProgress && typeof options.onProgress === 'function') {
options.onProgress(i + 1, countToRecord, currentLabelCount);
}
// 在每次录音之间增加短暂延迟,以便更好地分离样本
if (i < countToRecord - 1) {
await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
}
}
console.log(`AudioClassifier: 已为 "${label}" 收集了 ${currentLabelCount} 个样本。`);
return Promise.resolve(currentLabelCount);
} catch (error) {
console.error(`AudioClassifier: 录制 "${label}" 样本失败:`, error);
return Promise.reject(new Error(`录制 "${label}" 样本失败: ${error.message}`));
} finally {
isRecording = false;
}
}
/**
* 添加一个自定义类别标签到模型。
* 标签将添加到内部的 `labels` 数组中。
* @param {string} categoryName - 要添加的类别名称。
* @returns {Promise<void>} resolve 表示成功reject 表示失败(如名称为空或重复)。
*/
function addCustomCategory(categoryName) {
if (!categoryName || categoryName.trim() === '') {
return Promise.reject(new Error('类别名称不能为空。'));
}
const trimmedName = categoryName.trim();
if (labels.some(label => label.toLowerCase() === trimmedName.toLowerCase())) {
return Promise.reject(new Error(`类别 "${trimmedName}" 已经存在。`));
}
labels.push(trimmedName);
console.log(`AudioClassifier: 已添加新类别: "${trimmedName}"`);
return Promise.resolve();
}
/**
* 检查模型是否已准备好进行训练。
* @returns {Object} 包含 `ready` (boolean) 和 `details` (string) 的对象。
*/
function checkTrainingReadiness() {
let exampleCounts = {};
try {
exampleCounts = transferRecognizer.countExamples();
} catch (error) {
// 这是预期内的错误,当没有样本时 countExamples 会抛出。
// 将其捕获并视为空样本集处理。
console.warn("countExamples() 抛出异常 (无样本时正常行为)。", error.message);
}
// 检查是否有任何样本,用于更新导出按钮等逻辑
const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
const hasAnyExamples = totalSamples > 0;
const backgroundNoiseCount = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;
const backgroundNoiseReady = backgroundNoiseCount > 0;
let customCategoriesWithSamples = 0;
// 遍历内部 labels 数组,检查每个自定义类别是否有样本
// 从索引 1 开始,因为 0 是背景噪音
for (let i = 1; i < labels.length; i++) {
const customLabel = labels[i];
if ((exampleCounts[customLabel] || 0) > 0) {
customCategoriesWithSamples++;
}
}
// 训练就绪的条件:必须有背景噪音样本,并且至少一个自定义类别有样本
const canTrain = backgroundNoiseReady && customCategoriesWithSamples >= 1;
// 更详细的训练就绪判断,包括样本数量建议
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
let trainingDetails = [];
if (!backgroundNoiseReady) {
trainingDetails.push('需要录制背景噪音样本。');
} else if (backgroundNoiseCount < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
trainingDetails.push(`建议背景噪音至少 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本 (当前: ${backgroundNoiseCount})。`);
}
if (customCategoriesWithSamples === 0) {
trainingDetails.push('需要至少一个自定义类别并录制样本。');
} else {
let insufficientCustomSamples = [];
for (let i = 1; i < labels.length; i++) {
const customLabel = labels[i];
const count = exampleCounts[customLabel] || 0;
if (count > 0 && count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
insufficientCustomSamples.push(`类别 "${customLabel}" 样本不足 (当前: ${count},建议: ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING})。`);
}
}
if (insufficientCustomSamples.length > 0) {
trainingDetails = trainingDetails.concat(insufficientCustomSamples);
}
}
return {
ready: canTrain && trainingDetails.length === 0, // 只有完全满足条件才认为是 ready
details: trainingDetails.join('\n') || '模型已准备好进行训练。',
hasAnyExamples: hasAnyExamples // 额外返回是否有任何样本,供 UI 判断是否启用导出按钮
};
}
/**
* 训练模型。
* @param {Object} [config] - 训练配置。
* @param {number} [config.epochs=50] - 训练迭代次数。
* @param {number} [config.batchSize=16] - 批大小。
* @param {number} [config.validationSplit=0.1] - 验证集比例。
* @param {boolean} [config.shuffle=true] - 是否打乱数据。
* @param {Function} [config.onEpochEnd] - (可选) 每个 epoch 结束时的回调函数,参数为 (epoch, logs)。
* @returns {Promise<void>} resolve 表示训练完成reject 表示训练失败。
*/
async function trainModel(config = {}) {
// 再次进行训练前检查,防止外部在不满足条件时调用
const readiness = checkTrainingReadiness();
if (!readiness.ready) {
return Promise.reject(new Error(`模型未准备好训练:\n${readiness.details}`));
}
const defaultConfig = {
epochs: 50,
batchSize: 16,
validationSplit: 0.1,
shuffle: true,
// 外部传入的回调函数将在 train 内部被调用
callbacks: {
onEpochEnd: (epoch, logs) => {
if (config.onEpochEnd && typeof config.onEpochEnd === 'function') {
config.onEpochEnd(epoch, logs);
}
}
}
};
// 合并默认配置和用户指定的配置
const trainingConfig = { ...defaultConfig, ...config };
try {
await transferRecognizer.train(trainingConfig);
isModelTrainedFlag = true;
console.log('AudioClassifier: 模型训练完成。');
return Promise.resolve();
} catch (error) {
isModelTrainedFlag = false;
console.error('AudioClassifier: 模型训练失败:', error);
return Promise.reject(new Error(`模型训练失败: ${error.message}. 请确保有足够的样本且类别均衡。`));
}
}
/**
* 开始实时预测。
* @param {Function} onPrediction - 预测结果回调函数,参数为 {label: string, score: number, scores: number[], labels: string[]}。
* @param {Object} [listenOptions] - 监听选项,直接传递给 transferRecognizer.listen()。
* @returns {Promise<Function>} resolve 时返回停止预测的函数reject 时返回错误。
*/
async function startPrediction(onPrediction, listenOptions = {}) {
if (isPredicting) {
return Promise.reject(new Error('核心模块:识别已经在进行中。'));
}
if (!isModelTrainedFlag) {
return Promise.reject(new Error('核心模块:模型尚未训练完成,请先训练模型。'));
}
isPredicting = true;
const defaultOptions = {
includeEmbedding: true,
probabilityThreshold: 0.75, // 预测置信度阈值
suppressionTimeMillis: 300, // 抑制时间,避免连续触发
overlapFactor: 0.50, // 帧重叠因子
};
const options = { ...defaultOptions, ...listenOptions };
try {
predictionStopFunction = await transferRecognizer.listen(result => {
if (!isPredicting) return; // 确保在停止后不再处理结果
const classLabels = transferRecognizer.wordLabels(); // 获取模型内部的标签列表
const scores = result.scores; // 预测分数
const maxScore = Math.max(...scores);
const predictedIndex = scores.indexOf(maxScore);
let predictedLabel = classLabels[predictedIndex];
// 对背景噪音标签进行友好化处理
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
predictedLabel = '背景噪音';
}
// 调用外部传入的回调函数,传递结构化的预测结果
if (typeof onPrediction === 'function') {
onPrediction({
label: predictedLabel,
score: maxScore,
scores: scores, // 原始分数数组
predictedIndex: predictedIndex, // 预测索引
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) // 友好化的标签列表
});
}
}, options);
console.log('AudioClassifier: 开始识别。');
return Promise.resolve(predictionStopFunction); // 返回停止函数给调用者
} catch (error) {
isPredicting = false;
console.error('AudioClassifier: 开始识别失败:', error);
return Promise.reject(new Error(`开始识别失败: ${error.message}`));
}
}
/**
* 停止实时预测。
*/
function stopPrediction() {
if (isPredicting) {
if (typeof predictionStopFunction === 'function') {
predictionStopFunction(); // 调用 SpeechCommands 库返回的停止函数
predictionStopFunction = null;
} else {
console.warn('AudioClassifier: 停止预测函数未定义或不是函数。');
}
isPredicting = false;
console.log('AudioClassifier: 已停止识别。');
}
}
/**
* 获取当前收集到的各类别样本数量。
* @returns {Object} 键为类别名称,值为样本数量。
*/
function getExampleCounts() {
try {
return transferRecognizer.countExamples();
} catch (error) {
// 捕获无样本时的错误,返回空对象
return {};
}
}
/**
* 获取模型是否已训练的状态。
* @returns {boolean} 如果模型已训练则返回 true否则 false。
*/
function isModelTrained() {
return isModelTrainedFlag;
}
/**
* 获取当前模块中所有标签的数组(包括背景噪音和自定义标签)。
* @returns {string[]} 标签数组。
*/
function getLabels() {
return [...labels]; // 返回副本,避免外部直接修改内部状态
}
/**
* 序列化所有收集到的样本数据到 ArrayBuffer。
* @returns {Promise<ArrayBuffer>} 包含样本数据的 ArrayBuffer。
*/
async function serializeExamples() {
try {
const serialized = await transferRecognizer.serializeExamples();
console.log('AudioClassifier: 样本数据序列化成功。');
return Promise.resolve(serialized);
} catch (error) {
console.error('AudioClassifier: 序列化样本失败:', error);
return Promise.reject(new Error(`序列化样本失败: ${error.message}`));
}
}
/**
* 从 ArrayBuffer 导入样本数据。
* 导入后会更新内部的 labels 数组和模型状态。
* @param {ArrayBuffer} dataBuffer - 包含样本数据的 ArrayBuffer。
* @returns {Promise<Object>} resolve 时返回导入后的样本总数统计reject 时返回错误。
*/
async function loadExamples(dataBuffer) {
if (!dataBuffer instanceof ArrayBuffer) {
return Promise.reject(new Error('核心模块:导入数据必须是 ArrayBuffer 类型。'));
}
// 首先清空当前的 transferRecognizer 中的所有样本,再加载新的。
// 这等同于创建一个新的 transferRecognizer 实例,并同步其内部状态。
// SpeechCommands库没有直接的`clearExamples`或`resetExamples`方法,
// 最稳妥的方法是重新创建transferRecognizer并加载。
// 但loadExamples本身就是设计来覆盖和更新现有样本的
// 考虑到性能我们假设直接loadExamples即可。
// 注意loadExamples会清除原有数据。
try {
await transferRecognizer.loadExamples(dataBuffer);
// 成功加载后,需要更新内部的 `labels` 数组,以匹配导入的数据。
// `transferRecognizer.wordLabels()` 返回的是模型内部的实际标签。
const loadedLabels = transferRecognizer.wordLabels();
// 确保 BACKGROUND_NOISE_LABEL 还在第一个位置,并过滤掉 '_version_', '_numExamples_' 等内部标签
labels = loadedLabels.filter(label => label !== '_version_' && label !== '_numExamples_');
if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
// 如果导入的数据中没有背景噪音(不常见,但作为健壮性处理)
labels.unshift(BACKGROUND_NOISE_LABEL);
} else {
// 确保背景噪音在第一位
const bgIndex = labels.indexOf(BACKGROUND_NOISE_LABEL);
if (bgIndex > 0) {
const bgLabel = labels.splice(bgIndex, 1);
labels.unshift(bgLabel[0]);
}
}
// 重置模型的训练状态,因为样本已更改,需要重新训练
isModelTrainedFlag = false;
const exampleCounts = transferRecognizer.countExamples();
console.log('AudioClassifier: 样本数据导入成功。', exampleCounts);
return Promise.resolve(exampleCounts);
} catch (error) {
console.error('AudioClassifier: 加载样本失败:', error);
return Promise.reject(new Error(`加载样本失败: ${error.message}. 请确保文件有效且未损坏。`));
}
}
// ======================= 导出的公共接口 =======================
// 这里使用 CommonJS 风格的 module.exports因为这通常用于 Node.js 环境或 Webpack/Rollup 等打包工具。
// 如果你是在浏览器中直接作为 <script> 标签引入(没有打包工具),
// 并且希望它在全局作用域下可用,可以使用 `window.AudioClassifier = {...}`。
// 在这种情况下,我将提供 `window.AudioClassifier` 的方式。
// 为了 Web 浏览器环境,将核心功能挂载到全局对象上
window.AudioClassifier = {
init, // 初始化模型
recordMultipleExamples, // 录制样本
addCustomCategory, // 添加自定义类别
checkTrainingReadiness, // 检查模型是否可训练
trainModel, // 训练模型
startPrediction, // 开始预测
stopPrediction, // 停止预测
getExampleCounts, // 获取所有类别样本数量
isModelTrained, // 检查模型是否已训练
getLabels, // 获取当前所有标签
serializeExamples, // 导出样本数据
loadExamples // 导入样本数据
};