440 lines
18 KiB
JavaScript
440 lines
18 KiB
JavaScript
// audioClassifier.js (核心功能模块)
|
||
|
||
// TensorFlow.js 和 Speech Commands 库不需要在这里再次导入,
|
||
// 假定在使用此模块的 HTML 页面中已经通过 <script> 标签引入。
|
||
|
||
// ======================= 模块内部共享变量 =======================
|
||
let recognizer; // 基础的 SpeechCommands recognizer 实例
|
||
let transferRecognizer; // 用于迁移学习的 recognizer 实例
|
||
|
||
// 用户定义的类别标签数组 (包括背景噪音)。
|
||
// 这是模块的核心状态之一,外部可以通过 getLabels() 获取。
|
||
let labels = [];
|
||
const BACKGROUND_NOISE_LABEL = '_background_noise_'; // 内部使用的背景噪音标签
|
||
|
||
let isPredicting = false; // 预测状态标志
|
||
let isRecording = false; // 录音状态标志,防止重复点击
|
||
const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
|
||
|
||
let isModelTrainedFlag = false; // 手动维护模型训练状态
|
||
let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
|
||
|
||
|
||
// ======================= 核心功能函数 =======================
|
||
|
||
/**
|
||
* 初始化音频分类器模型。
|
||
* 必须在使用任何其他功能之前调用。
|
||
* @returns {Promise<void>} resolve 表示成功,reject 表示失败。
|
||
*/
|
||
async function init() {
|
||
// 确保每次初始化时重置状态
|
||
recognizer = null;
|
||
transferRecognizer = null;
|
||
labels = [];
|
||
isPredicting = false;
|
||
isRecording = false;
|
||
isModelTrainedFlag = false;
|
||
if (predictionStopFunction) {
|
||
predictionStopFunction(); // 停止任何正在进行的预测
|
||
predictionStopFunction = null;
|
||
}
|
||
|
||
try {
|
||
// 创建基础识别器
|
||
recognizer = speechCommands.create('BROWSER_FFT');
|
||
await recognizer.ensureModelLoaded(); // 确保模型加载完成
|
||
|
||
// 创建迁移学习识别器
|
||
transferRecognizer = recognizer.createTransfer('my-custom-model');
|
||
|
||
// 初始化时,将背景噪音标签加入到我们的内部 labels 数组
|
||
labels.push(BACKGROUND_NOISE_LABEL);
|
||
|
||
console.log('AudioClassifier: 模型和迁移学习器初始化成功');
|
||
return Promise.resolve();
|
||
} catch (error) {
|
||
console.error('AudioClassifier: 初始化失败:', error);
|
||
return Promise.reject(new Error(`模型初始化失败或麦克风无法访问: ${error.message}`));
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 录制单个或批量音频样本。
|
||
* 在录制示例时会暂时禁用 isRecording 标志以防止多重录制。
|
||
*
|
||
* @param {string} label - 样本所属的类别标签。
|
||
* @param {number} countToRecord - 要录制样本的数量。
|
||
* @param {Object} [options] - 其他选项。
|
||
* @param {Function} [options.onProgress] - (optional) 进度回调函数,参数为 (currentCount, totalCount)。
|
||
* @returns {Promise<number>} resolve 时返回该类别当前的总样本数量,reject 时返回错误。
|
||
*/
|
||
async function recordMultipleExamples(label, countToRecord = 1, options = {}) {
|
||
if (isRecording) {
|
||
return Promise.reject(new Error('核心模块:正在录制中,请等待当前录音完成。'));
|
||
}
|
||
|
||
isRecording = true;
|
||
let currentLabelCount = 0;
|
||
|
||
try {
|
||
for (let i = 0; i < countToRecord; i++) {
|
||
// console.log(`AudioClassifier: 正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`);
|
||
await transferRecognizer.collectExample(label, {
|
||
amplitudeRequired: true,
|
||
durationMillis: recordDuration
|
||
});
|
||
|
||
const exampleCounts = transferRecognizer.countExamples();
|
||
currentLabelCount = exampleCounts[label] || 0;
|
||
|
||
if (options.onProgress && typeof options.onProgress === 'function') {
|
||
options.onProgress(i + 1, countToRecord, currentLabelCount);
|
||
}
|
||
|
||
// 在每次录音之间增加短暂延迟,以便更好地分离样本
|
||
if (i < countToRecord - 1) {
|
||
await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
|
||
}
|
||
}
|
||
console.log(`AudioClassifier: 已为 "${label}" 收集了 ${currentLabelCount} 个样本。`);
|
||
return Promise.resolve(currentLabelCount);
|
||
} catch (error) {
|
||
console.error(`AudioClassifier: 录制 "${label}" 样本失败:`, error);
|
||
return Promise.reject(new Error(`录制 "${label}" 样本失败: ${error.message}`));
|
||
} finally {
|
||
isRecording = false;
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 添加一个自定义类别标签到模型。
|
||
* 标签将添加到内部的 `labels` 数组中。
|
||
* @param {string} categoryName - 要添加的类别名称。
|
||
* @returns {Promise<void>} resolve 表示成功,reject 表示失败(如名称为空或重复)。
|
||
*/
|
||
function addCustomCategory(categoryName) {
|
||
if (!categoryName || categoryName.trim() === '') {
|
||
return Promise.reject(new Error('类别名称不能为空。'));
|
||
}
|
||
const trimmedName = categoryName.trim();
|
||
if (labels.some(label => label.toLowerCase() === trimmedName.toLowerCase())) {
|
||
return Promise.reject(new Error(`类别 "${trimmedName}" 已经存在。`));
|
||
}
|
||
labels.push(trimmedName);
|
||
console.log(`AudioClassifier: 已添加新类别: "${trimmedName}"`);
|
||
return Promise.resolve();
|
||
}
|
||
|
||
/**
|
||
* 检查模型是否已准备好进行训练。
|
||
* @returns {Object} 包含 `ready` (boolean) 和 `details` (string) 的对象。
|
||
*/
|
||
function checkTrainingReadiness() {
|
||
let exampleCounts = {};
|
||
try {
|
||
exampleCounts = transferRecognizer.countExamples();
|
||
} catch (error) {
|
||
// 这是预期内的错误,当没有样本时 countExamples 会抛出。
|
||
// 将其捕获并视为空样本集处理。
|
||
console.warn("countExamples() 抛出异常 (无样本时正常行为)。", error.message);
|
||
}
|
||
|
||
// 检查是否有任何样本,用于更新导出按钮等逻辑
|
||
const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
|
||
const hasAnyExamples = totalSamples > 0;
|
||
|
||
const backgroundNoiseCount = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;
|
||
const backgroundNoiseReady = backgroundNoiseCount > 0;
|
||
|
||
let customCategoriesWithSamples = 0;
|
||
// 遍历内部 labels 数组,检查每个自定义类别是否有样本
|
||
// 从索引 1 开始,因为 0 是背景噪音
|
||
for (let i = 1; i < labels.length; i++) {
|
||
const customLabel = labels[i];
|
||
if ((exampleCounts[customLabel] || 0) > 0) {
|
||
customCategoriesWithSamples++;
|
||
}
|
||
}
|
||
|
||
// 训练就绪的条件:必须有背景噪音样本,并且至少一个自定义类别有样本
|
||
const canTrain = backgroundNoiseReady && customCategoriesWithSamples >= 1;
|
||
|
||
// 更详细的训练就绪判断,包括样本数量建议
|
||
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
|
||
let trainingDetails = [];
|
||
|
||
if (!backgroundNoiseReady) {
|
||
trainingDetails.push('需要录制背景噪音样本。');
|
||
} else if (backgroundNoiseCount < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||
trainingDetails.push(`建议背景噪音至少 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本 (当前: ${backgroundNoiseCount})。`);
|
||
}
|
||
|
||
if (customCategoriesWithSamples === 0) {
|
||
trainingDetails.push('需要至少一个自定义类别并录制样本。');
|
||
} else {
|
||
let insufficientCustomSamples = [];
|
||
for (let i = 1; i < labels.length; i++) {
|
||
const customLabel = labels[i];
|
||
const count = exampleCounts[customLabel] || 0;
|
||
if (count > 0 && count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||
insufficientCustomSamples.push(`类别 "${customLabel}" 样本不足 (当前: ${count},建议: ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING})。`);
|
||
}
|
||
}
|
||
if (insufficientCustomSamples.length > 0) {
|
||
trainingDetails = trainingDetails.concat(insufficientCustomSamples);
|
||
}
|
||
}
|
||
|
||
return {
|
||
ready: canTrain && trainingDetails.length === 0, // 只有完全满足条件才认为是 ready
|
||
details: trainingDetails.join('\n') || '模型已准备好进行训练。',
|
||
hasAnyExamples: hasAnyExamples // 额外返回是否有任何样本,供 UI 判断是否启用导出按钮
|
||
};
|
||
}
|
||
|
||
/**
|
||
* 训练模型。
|
||
* @param {Object} [config] - 训练配置。
|
||
* @param {number} [config.epochs=50] - 训练迭代次数。
|
||
* @param {number} [config.batchSize=16] - 批大小。
|
||
* @param {number} [config.validationSplit=0.1] - 验证集比例。
|
||
* @param {boolean} [config.shuffle=true] - 是否打乱数据。
|
||
* @param {Function} [config.onEpochEnd] - (可选) 每个 epoch 结束时的回调函数,参数为 (epoch, logs)。
|
||
* @returns {Promise<void>} resolve 表示训练完成,reject 表示训练失败。
|
||
*/
|
||
async function trainModel(config = {}) {
|
||
// 再次进行训练前检查,防止外部在不满足条件时调用
|
||
const readiness = checkTrainingReadiness();
|
||
if (!readiness.ready) {
|
||
return Promise.reject(new Error(`模型未准备好训练:\n${readiness.details}`));
|
||
}
|
||
|
||
const defaultConfig = {
|
||
epochs: 50,
|
||
batchSize: 16,
|
||
validationSplit: 0.1,
|
||
shuffle: true,
|
||
// 外部传入的回调函数将在 train 内部被调用
|
||
callbacks: {
|
||
onEpochEnd: (epoch, logs) => {
|
||
if (config.onEpochEnd && typeof config.onEpochEnd === 'function') {
|
||
config.onEpochEnd(epoch, logs);
|
||
}
|
||
}
|
||
}
|
||
};
|
||
|
||
// 合并默认配置和用户指定的配置
|
||
const trainingConfig = { ...defaultConfig, ...config };
|
||
|
||
try {
|
||
await transferRecognizer.train(trainingConfig);
|
||
isModelTrainedFlag = true;
|
||
console.log('AudioClassifier: 模型训练完成。');
|
||
return Promise.resolve();
|
||
} catch (error) {
|
||
isModelTrainedFlag = false;
|
||
console.error('AudioClassifier: 模型训练失败:', error);
|
||
return Promise.reject(new Error(`模型训练失败: ${error.message}. 请确保有足够的样本且类别均衡。`));
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 开始实时预测。
|
||
* @param {Function} onPrediction - 预测结果回调函数,参数为 {label: string, score: number, scores: number[], labels: string[]}。
|
||
* @param {Object} [listenOptions] - 监听选项,直接传递给 transferRecognizer.listen()。
|
||
* @returns {Promise<Function>} resolve 时返回停止预测的函数,reject 时返回错误。
|
||
*/
|
||
async function startPrediction(onPrediction, listenOptions = {}) {
|
||
if (isPredicting) {
|
||
return Promise.reject(new Error('核心模块:识别已经在进行中。'));
|
||
}
|
||
|
||
if (!isModelTrainedFlag) {
|
||
return Promise.reject(new Error('核心模块:模型尚未训练完成,请先训练模型。'));
|
||
}
|
||
|
||
isPredicting = true;
|
||
|
||
const defaultOptions = {
|
||
includeEmbedding: true,
|
||
probabilityThreshold: 0.75, // 预测置信度阈值
|
||
suppressionTimeMillis: 300, // 抑制时间,避免连续触发
|
||
overlapFactor: 0.50, // 帧重叠因子
|
||
};
|
||
|
||
const options = { ...defaultOptions, ...listenOptions };
|
||
|
||
try {
|
||
predictionStopFunction = await transferRecognizer.listen(result => {
|
||
if (!isPredicting) return; // 确保在停止后不再处理结果
|
||
|
||
const classLabels = transferRecognizer.wordLabels(); // 获取模型内部的标签列表
|
||
const scores = result.scores; // 预测分数
|
||
const maxScore = Math.max(...scores);
|
||
const predictedIndex = scores.indexOf(maxScore);
|
||
|
||
let predictedLabel = classLabels[predictedIndex];
|
||
// 对背景噪音标签进行友好化处理
|
||
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
|
||
predictedLabel = '背景噪音';
|
||
}
|
||
|
||
// 调用外部传入的回调函数,传递结构化的预测结果
|
||
if (typeof onPrediction === 'function') {
|
||
onPrediction({
|
||
label: predictedLabel,
|
||
score: maxScore,
|
||
scores: scores, // 原始分数数组
|
||
predictedIndex: predictedIndex, // 预测索引
|
||
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) // 友好化的标签列表
|
||
});
|
||
}
|
||
|
||
}, options);
|
||
|
||
console.log('AudioClassifier: 开始识别。');
|
||
return Promise.resolve(predictionStopFunction); // 返回停止函数给调用者
|
||
} catch (error) {
|
||
isPredicting = false;
|
||
console.error('AudioClassifier: 开始识别失败:', error);
|
||
return Promise.reject(new Error(`开始识别失败: ${error.message}`));
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 停止实时预测。
|
||
*/
|
||
function stopPrediction() {
|
||
if (isPredicting) {
|
||
if (typeof predictionStopFunction === 'function') {
|
||
predictionStopFunction(); // 调用 SpeechCommands 库返回的停止函数
|
||
predictionStopFunction = null;
|
||
} else {
|
||
console.warn('AudioClassifier: 停止预测函数未定义或不是函数。');
|
||
}
|
||
isPredicting = false;
|
||
console.log('AudioClassifier: 已停止识别。');
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 获取当前收集到的各类别样本数量。
|
||
* @returns {Object} 键为类别名称,值为样本数量。
|
||
*/
|
||
function getExampleCounts() {
|
||
try {
|
||
return transferRecognizer.countExamples();
|
||
} catch (error) {
|
||
// 捕获无样本时的错误,返回空对象
|
||
return {};
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 获取模型是否已训练的状态。
|
||
* @returns {boolean} 如果模型已训练则返回 true,否则 false。
|
||
*/
|
||
function isModelTrained() {
|
||
return isModelTrainedFlag;
|
||
}
|
||
|
||
/**
|
||
* 获取当前模块中所有标签的数组(包括背景噪音和自定义标签)。
|
||
* @returns {string[]} 标签数组。
|
||
*/
|
||
function getLabels() {
|
||
return [...labels]; // 返回副本,避免外部直接修改内部状态
|
||
}
|
||
|
||
/**
|
||
* 序列化所有收集到的样本数据到 ArrayBuffer。
|
||
* @returns {Promise<ArrayBuffer>} 包含样本数据的 ArrayBuffer。
|
||
*/
|
||
async function serializeExamples() {
|
||
try {
|
||
const serialized = await transferRecognizer.serializeExamples();
|
||
console.log('AudioClassifier: 样本数据序列化成功。');
|
||
return Promise.resolve(serialized);
|
||
} catch (error) {
|
||
console.error('AudioClassifier: 序列化样本失败:', error);
|
||
return Promise.reject(new Error(`序列化样本失败: ${error.message}`));
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 从 ArrayBuffer 导入样本数据。
|
||
* 导入后会更新内部的 labels 数组和模型状态。
|
||
* @param {ArrayBuffer} dataBuffer - 包含样本数据的 ArrayBuffer。
|
||
* @returns {Promise<Object>} resolve 时返回导入后的样本总数统计,reject 时返回错误。
|
||
*/
|
||
async function loadExamples(dataBuffer) {
|
||
if (!dataBuffer instanceof ArrayBuffer) {
|
||
return Promise.reject(new Error('核心模块:导入数据必须是 ArrayBuffer 类型。'));
|
||
}
|
||
|
||
// 首先清空当前的 transferRecognizer 中的所有样本,再加载新的。
|
||
// 这等同于创建一个新的 transferRecognizer 实例,并同步其内部状态。
|
||
// SpeechCommands库没有直接的`clearExamples`或`resetExamples`方法,
|
||
// 最稳妥的方法是重新创建transferRecognizer并加载。
|
||
// 但loadExamples本身就是设计来覆盖和更新现有样本的,
|
||
// 考虑到性能,我们假设直接loadExamples即可。
|
||
// 注意:loadExamples会清除原有数据。
|
||
|
||
try {
|
||
await transferRecognizer.loadExamples(dataBuffer);
|
||
|
||
// 成功加载后,需要更新内部的 `labels` 数组,以匹配导入的数据。
|
||
// `transferRecognizer.wordLabels()` 返回的是模型内部的实际标签。
|
||
const loadedLabels = transferRecognizer.wordLabels();
|
||
// 确保 BACKGROUND_NOISE_LABEL 还在第一个位置,并过滤掉 '_version_', '_numExamples_' 等内部标签
|
||
labels = loadedLabels.filter(label => label !== '_version_' && label !== '_numExamples_');
|
||
|
||
if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
|
||
// 如果导入的数据中没有背景噪音(不常见,但作为健壮性处理)
|
||
labels.unshift(BACKGROUND_NOISE_LABEL);
|
||
} else {
|
||
// 确保背景噪音在第一位
|
||
const bgIndex = labels.indexOf(BACKGROUND_NOISE_LABEL);
|
||
if (bgIndex > 0) {
|
||
const bgLabel = labels.splice(bgIndex, 1);
|
||
labels.unshift(bgLabel[0]);
|
||
}
|
||
}
|
||
|
||
// 重置模型的训练状态,因为样本已更改,需要重新训练
|
||
isModelTrainedFlag = false;
|
||
|
||
const exampleCounts = transferRecognizer.countExamples();
|
||
console.log('AudioClassifier: 样本数据导入成功。', exampleCounts);
|
||
return Promise.resolve(exampleCounts);
|
||
} catch (error) {
|
||
console.error('AudioClassifier: 加载样本失败:', error);
|
||
return Promise.reject(new Error(`加载样本失败: ${error.message}. 请确保文件有效且未损坏。`));
|
||
}
|
||
}
|
||
|
||
|
||
// ======================= 导出的公共接口 =======================
|
||
// 这里使用 CommonJS 风格的 module.exports,因为这通常用于 Node.js 环境或 Webpack/Rollup 等打包工具。
|
||
// 如果你是在浏览器中直接作为 <script> 标签引入(没有打包工具),
|
||
// 并且希望它在全局作用域下可用,可以使用 `window.AudioClassifier = {...}`。
|
||
// 在这种情况下,我将提供 `window.AudioClassifier` 的方式。
|
||
|
||
// 为了 Web 浏览器环境,将核心功能挂载到全局对象上
|
||
window.AudioClassifier = {
|
||
init, // 初始化模型
|
||
recordMultipleExamples, // 录制样本
|
||
addCustomCategory, // 添加自定义类别
|
||
checkTrainingReadiness, // 检查模型是否可训练
|
||
trainModel, // 训练模型
|
||
startPrediction, // 开始预测
|
||
stopPrediction, // 停止预测
|
||
getExampleCounts, // 获取所有类别样本数量
|
||
isModelTrained, // 检查模型是否已训练
|
||
getLabels, // 获取当前所有标签
|
||
serializeExamples, // 导出样本数据
|
||
loadExamples // 导入样本数据
|
||
};
|