[MF]修改audioClassifier.js (核心功能模块)
This commit is contained in:
parent
5d36541dc5
commit
37dc1c5a76
@ -1,147 +1,213 @@
|
|||||||
// 全局变量和模型实例
|
// audioClassifier.js (核心功能模块)
|
||||||
let recognizer; // 基础的 SpeechCommands recognizer
|
|
||||||
let transferRecognizer; // 用于迁移学习的 recognizer
|
// TensorFlow.js 和 Speech Commands 库不需要在这里再次导入,
|
||||||
const labels = []; // 用户定义的类别标签数组 (包括背景噪音)
|
// 假定在使用此模块的 HTML 页面中已经通过 <script> 标签引入。
|
||||||
// 将背景噪音定义为第一个类别,其内部名称为 _background_noise_
|
|
||||||
const BACKGROUND_NOISE_LABEL = '_background_noise_';
|
// ======================= 模块内部共享变量 =======================
|
||||||
const BACKGROUND_NOISE_INDEX = 0; // 仅用于本地 labels 数组索引,不直接用于collectExample
|
let recognizer; // 基础的 SpeechCommands recognizer 实例
|
||||||
|
let transferRecognizer; // 用于迁移学习的 recognizer 实例
|
||||||
|
|
||||||
|
// 用户定义的类别标签数组 (包括背景噪音)。
|
||||||
|
// 这是模块的核心状态之一,外部可以通过 getLabels() 获取。
|
||||||
|
let labels = [];
|
||||||
|
const BACKGROUND_NOISE_LABEL = '_background_noise_'; // 内部使用的背景噪音标签
|
||||||
|
|
||||||
let isPredicting = false; // 预测状态标志
|
let isPredicting = false; // 预测状态标志
|
||||||
let isRecording = false; // 录音状态标志,防止重复点击
|
let isRecording = false; // 录音状态标志,防止重复点击
|
||||||
const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
|
const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
|
||||||
|
|
||||||
let isModelTrainedFlag = false; // 手动维护模型训练状态
|
let isModelTrainedFlag = false; // 手动维护模型训练状态
|
||||||
let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
|
let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
|
||||||
|
|
||||||
|
|
||||||
|
// ======================= 核心功能函数 =======================
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 初始化函数 - 加载模型和创建迁移学习模型
|
* 初始化音频分类器模型。
|
||||||
* @returns {Promise<void>}
|
* 必须在使用任何其他功能之前调用。
|
||||||
|
* @returns {Promise<void>} resolve 表示成功,reject 表示失败。
|
||||||
*/
|
*/
|
||||||
async function init() {
|
async function init() {
|
||||||
try {
|
// 确保每次初始化时重置状态
|
||||||
recognizer = speechCommands.create(
|
recognizer = null;
|
||||||
'BROWSER_FFT' // 使用浏览器内置的 FFT 处理,性能更好
|
transferRecognizer = null;
|
||||||
);
|
labels = [];
|
||||||
|
isPredicting = false;
|
||||||
|
isRecording = false;
|
||||||
|
isModelTrainedFlag = false;
|
||||||
|
if (predictionStopFunction) {
|
||||||
|
predictionStopFunction(); // 停止任何正在进行的预测
|
||||||
|
predictionStopFunction = null;
|
||||||
|
}
|
||||||
|
|
||||||
await recognizer.ensureModelLoaded();
|
try {
|
||||||
|
// 创建基础识别器
|
||||||
|
recognizer = speechCommands.create('BROWSER_FFT');
|
||||||
|
await recognizer.ensureModelLoaded(); // 确保模型加载完成
|
||||||
|
|
||||||
|
// 创建迁移学习识别器
|
||||||
transferRecognizer = recognizer.createTransfer('my-custom-model');
|
transferRecognizer = recognizer.createTransfer('my-custom-model');
|
||||||
|
|
||||||
// 只有在 transferRecognizer 创建成功后,才将背景噪音标签加入我们的 local labels 数组
|
// 初始化时,将背景噪音标签加入到我们的内部 labels 数组
|
||||||
labels.push(BACKGROUND_NOISE_LABEL);
|
labels.push(BACKGROUND_NOISE_LABEL);
|
||||||
|
|
||||||
|
console.log('AudioClassifier: 模型和迁移学习器初始化成功');
|
||||||
return Promise.resolve();
|
return Promise.resolve();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
return Promise.reject(error);
|
console.error('AudioClassifier: 初始化失败:', error);
|
||||||
|
return Promise.reject(new Error(`模型初始化失败或麦克风无法访问: ${error.message}`));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 批量录制样本的通用函数
|
* 录制单个或批量音频样本。
|
||||||
* @param {string} label - 标签名称
|
* 在录制示例时会暂时禁用 isRecording 标志以防止多重录制。
|
||||||
* @param {number} countToRecord - 要录制的样本数量
|
*
|
||||||
* @returns {Promise<void>}
|
* @param {string} label - 样本所属的类别标签。
|
||||||
|
* @param {number} countToRecord - 要录制样本的数量。
|
||||||
|
* @param {Object} [options] - 其他选项。
|
||||||
|
* @param {Function} [options.onProgress] - (optional) 进度回调函数,参数为 (currentCount, totalCount)。
|
||||||
|
* @returns {Promise<number>} resolve 时返回该类别当前的总样本数量,reject 时返回错误。
|
||||||
*/
|
*/
|
||||||
async function recordMultipleExamples(label, countToRecord = 5) {
|
async function recordMultipleExamples(label, countToRecord = 1, options = {}) {
|
||||||
if (isRecording) {
|
if (isRecording) {
|
||||||
return Promise.reject(new Error('正在录制中,请等待当前录音完成'));
|
return Promise.reject(new Error('核心模块:正在录制中,请等待当前录音完成。'));
|
||||||
}
|
}
|
||||||
|
|
||||||
isRecording = true;
|
isRecording = true;
|
||||||
|
let currentLabelCount = 0;
|
||||||
|
|
||||||
for (let i = 0; i < countToRecord; i++) {
|
|
||||||
try {
|
try {
|
||||||
await transferRecognizer.collectExample(
|
for (let i = 0; i < countToRecord; i++) {
|
||||||
label,
|
// console.log(`AudioClassifier: 正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`);
|
||||||
{ amplitudeRequired: true, durationMillis: recordDuration }
|
await transferRecognizer.collectExample(label, {
|
||||||
);
|
amplitudeRequired: true,
|
||||||
|
durationMillis: recordDuration
|
||||||
|
});
|
||||||
|
|
||||||
|
const exampleCounts = transferRecognizer.countExamples();
|
||||||
|
currentLabelCount = exampleCounts[label] || 0;
|
||||||
|
|
||||||
|
if (options.onProgress && typeof options.onProgress === 'function') {
|
||||||
|
options.onProgress(i + 1, countToRecord, currentLabelCount);
|
||||||
|
}
|
||||||
|
|
||||||
// 在每次录音之间增加短暂延迟,以便更好地分离样本
|
// 在每次录音之间增加短暂延迟,以便更好地分离样本
|
||||||
if (i < countToRecord - 1) {
|
if (i < countToRecord - 1) {
|
||||||
await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
|
await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
console.log(`AudioClassifier: 已为 "${label}" 收集了 ${currentLabelCount} 个样本。`);
|
||||||
|
return Promise.resolve(currentLabelCount);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
console.error(`AudioClassifier: 录制 "${label}" 样本失败:`, error);
|
||||||
|
return Promise.reject(new Error(`录制 "${label}" 样本失败: ${error.message}`));
|
||||||
|
} finally {
|
||||||
isRecording = false;
|
isRecording = false;
|
||||||
return Promise.reject(error);
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
isRecording = false;
|
|
||||||
return Promise.resolve();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 添加自定义类别
|
* 添加一个自定义类别标签到模型。
|
||||||
* @param {string} categoryName - 类别名称
|
* 标签将添加到内部的 `labels` 数组中。
|
||||||
|
* @param {string} categoryName - 要添加的类别名称。
|
||||||
|
* @returns {Promise<void>} resolve 表示成功,reject 表示失败(如名称为空或重复)。
|
||||||
*/
|
*/
|
||||||
function addCustomCategory(categoryName) {
|
function addCustomCategory(categoryName) {
|
||||||
if (!categoryName) {
|
if (!categoryName || categoryName.trim() === '') {
|
||||||
return Promise.reject(new Error('类别名称不能为空'));
|
return Promise.reject(new Error('类别名称不能为空。'));
|
||||||
}
|
}
|
||||||
|
const trimmedName = categoryName.trim();
|
||||||
// 检查是否与现有标签重复
|
if (labels.some(label => label.toLowerCase() === trimmedName.toLowerCase())) {
|
||||||
if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) {
|
return Promise.reject(new Error(`类别 "${trimmedName}" 已经存在。`));
|
||||||
return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`));
|
|
||||||
}
|
}
|
||||||
|
labels.push(trimmedName);
|
||||||
// 将标签添加到本地数组
|
console.log(`AudioClassifier: 已添加新类别: "${trimmedName}"`);
|
||||||
labels.push(categoryName);
|
|
||||||
return Promise.resolve();
|
return Promise.resolve();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 检查训练就绪状态
|
* 检查模型是否已准备好进行训练。
|
||||||
* @returns {boolean} 是否可以开始训练
|
* @returns {Object} 包含 `ready` (boolean) 和 `details` (string) 的对象。
|
||||||
*/
|
*/
|
||||||
function checkTrainingReadiness() {
|
function checkTrainingReadiness() {
|
||||||
const exampleCounts = transferRecognizer.countExamples();
|
let exampleCounts = {};
|
||||||
|
try {
|
||||||
|
exampleCounts = transferRecognizer.countExamples();
|
||||||
|
} catch (error) {
|
||||||
|
// 这是预期内的错误,当没有样本时 countExamples 会抛出。
|
||||||
|
// 将其捕获并视为空样本集处理。
|
||||||
|
console.warn("countExamples() 抛出异常 (无样本时正常行为)。", error.message);
|
||||||
|
}
|
||||||
|
|
||||||
let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0;
|
// 检查是否有任何样本,用于更新导出按钮等逻辑
|
||||||
|
const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
|
||||||
|
const hasAnyExamples = totalSamples > 0;
|
||||||
|
|
||||||
let customCategoriesReady = 0;
|
const backgroundNoiseCount = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;
|
||||||
// 遍历本地 labels 数组,检查每个自定义类别是否有样本
|
const backgroundNoiseReady = backgroundNoiseCount > 0;
|
||||||
for (let i = 1; i < labels.length; i++) { // 从索引 1 开始,因为 0 是背景噪音
|
|
||||||
|
let customCategoriesWithSamples = 0;
|
||||||
|
// 遍历内部 labels 数组,检查每个自定义类别是否有样本
|
||||||
|
// 从索引 1 开始,因为 0 是背景噪音
|
||||||
|
for (let i = 1; i < labels.length; i++) {
|
||||||
const customLabel = labels[i];
|
const customLabel = labels[i];
|
||||||
if ((exampleCounts[customLabel] || 0) > 0) {
|
if ((exampleCounts[customLabel] || 0) > 0) {
|
||||||
customCategoriesReady++;
|
customCategoriesWithSamples++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 必须有背景噪音样本,并且至少一个自定义类别有样本
|
// 训练就绪的条件:必须有背景噪音样本,并且至少一个自定义类别有样本
|
||||||
return backgroundNoiseReady && customCategoriesReady >= 1;
|
const canTrain = backgroundNoiseReady && customCategoriesWithSamples >= 1;
|
||||||
|
|
||||||
|
// 更详细的训练就绪判断,包括样本数量建议
|
||||||
|
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
|
||||||
|
let trainingDetails = [];
|
||||||
|
|
||||||
|
if (!backgroundNoiseReady) {
|
||||||
|
trainingDetails.push('需要录制背景噪音样本。');
|
||||||
|
} else if (backgroundNoiseCount < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||||||
|
trainingDetails.push(`建议背景噪音至少 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本 (当前: ${backgroundNoiseCount})。`);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (customCategoriesWithSamples === 0) {
|
||||||
|
trainingDetails.push('需要至少一个自定义类别并录制样本。');
|
||||||
|
} else {
|
||||||
|
let insufficientCustomSamples = [];
|
||||||
|
for (let i = 1; i < labels.length; i++) {
|
||||||
|
const customLabel = labels[i];
|
||||||
|
const count = exampleCounts[customLabel] || 0;
|
||||||
|
if (count > 0 && count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||||||
|
insufficientCustomSamples.push(`类别 "${customLabel}" 样本不足 (当前: ${count},建议: ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING})。`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (insufficientCustomSamples.length > 0) {
|
||||||
|
trainingDetails = trainingDetails.concat(insufficientCustomSamples);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
ready: canTrain && trainingDetails.length === 0, // 只有完全满足条件才认为是 ready
|
||||||
|
details: trainingDetails.join('\n') || '模型已准备好进行训练。',
|
||||||
|
hasAnyExamples: hasAnyExamples // 额外返回是否有任何样本,供 UI 判断是否启用导出按钮
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 模型训练函数
|
* 训练模型。
|
||||||
* @param {Object} trainingConfig - 训练配置参数
|
* @param {Object} [config] - 训练配置。
|
||||||
* @returns {Promise<void>}
|
* @param {number} [config.epochs=50] - 训练迭代次数。
|
||||||
|
* @param {number} [config.batchSize=16] - 批大小。
|
||||||
|
* @param {number} [config.validationSplit=0.1] - 验证集比例。
|
||||||
|
* @param {boolean} [config.shuffle=true] - 是否打乱数据。
|
||||||
|
* @param {Function} [config.onEpochEnd] - (可选) 每个 epoch 结束时的回调函数,参数为 (epoch, logs)。
|
||||||
|
* @returns {Promise<void>} resolve 表示训练完成,reject 表示训练失败。
|
||||||
*/
|
*/
|
||||||
async function trainModel(trainingConfig = {}) {
|
async function trainModel(config = {}) {
|
||||||
const exampleCounts = transferRecognizer.countExamples();
|
// 再次进行训练前检查,防止外部在不满足条件时调用
|
||||||
|
const readiness = checkTrainingReadiness();
|
||||||
let totalExamples = 0;
|
if (!readiness.ready) {
|
||||||
let validClasses = 0;
|
return Promise.reject(new Error(`模型未准备好训练:\n${readiness.details}`));
|
||||||
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
|
|
||||||
let allClassesHaveEnoughSamples = true;
|
|
||||||
|
|
||||||
// 统计所有类别的有效样本数,并检查每个类别是否达到最低要求
|
|
||||||
for (const labelName of labels) {
|
|
||||||
if (exampleCounts[labelName] && exampleCounts[labelName] > 0) {
|
|
||||||
totalExamples += exampleCounts[labelName];
|
|
||||||
validClasses++;
|
|
||||||
if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
|
||||||
allClassesHaveEnoughSamples = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (validClasses < 2) {
|
|
||||||
return Promise.reject(new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!allClassesHaveEnoughSamples) {
|
|
||||||
return Promise.reject(new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalExamples === 0) {
|
|
||||||
return Promise.reject(new Error('没有收集到任何训练样本'));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultConfig = {
|
const defaultConfig = {
|
||||||
@ -149,116 +215,225 @@ async function trainModel(trainingConfig = {}) {
|
|||||||
batchSize: 16,
|
batchSize: 16,
|
||||||
validationSplit: 0.1,
|
validationSplit: 0.1,
|
||||||
shuffle: true,
|
shuffle: true,
|
||||||
yieldEvery: 'epoch'
|
// 外部传入的回调函数将在 train 内部被调用
|
||||||
|
callbacks: {
|
||||||
|
onEpochEnd: (epoch, logs) => {
|
||||||
|
if (config.onEpochEnd && typeof config.onEpochEnd === 'function') {
|
||||||
|
config.onEpochEnd(epoch, logs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const config = Object.assign(defaultConfig, trainingConfig);
|
// 合并默认配置和用户指定的配置
|
||||||
|
const trainingConfig = { ...defaultConfig, ...config };
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await transferRecognizer.train(config);
|
await transferRecognizer.train(trainingConfig);
|
||||||
isModelTrainedFlag = true;
|
isModelTrainedFlag = true;
|
||||||
|
console.log('AudioClassifier: 模型训练完成。');
|
||||||
return Promise.resolve();
|
return Promise.resolve();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
isModelTrainedFlag = false;
|
isModelTrainedFlag = false;
|
||||||
return Promise.reject(error);
|
console.error('AudioClassifier: 模型训练失败:', error);
|
||||||
|
return Promise.reject(new Error(`模型训练失败: ${error.message}. 请确保有足够的样本且类别均衡。`));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 开始实时预测
|
* 开始实时预测。
|
||||||
* @param {Function} onPrediction - 预测结果回调函数
|
* @param {Function} onPrediction - 预测结果回调函数,参数为 {label: string, score: number, scores: number[], labels: string[]}。
|
||||||
* @param {Object} listenOptions - 监听选项
|
* @param {Object} [listenOptions] - 监听选项,直接传递给 transferRecognizer.listen()。
|
||||||
* @returns {Promise<Function>} 停止预测的函数
|
* @returns {Promise<Function>} resolve 时返回停止预测的函数,reject 时返回错误。
|
||||||
*/
|
*/
|
||||||
async function startPrediction(onPrediction, listenOptions = {}) {
|
async function startPrediction(onPrediction, listenOptions = {}) {
|
||||||
if (isPredicting) {
|
if (isPredicting) {
|
||||||
return Promise.reject(new Error('识别已经在进行中'));
|
return Promise.reject(new Error('核心模块:识别已经在进行中。'));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isModelTrainedFlag) {
|
if (!isModelTrainedFlag) {
|
||||||
return Promise.reject(new Error('模型尚未训练完成'));
|
return Promise.reject(new Error('核心模块:模型尚未训练完成,请先训练模型。'));
|
||||||
}
|
}
|
||||||
|
|
||||||
isPredicting = true;
|
isPredicting = true;
|
||||||
|
|
||||||
const defaultOptions = {
|
const defaultOptions = {
|
||||||
includeEmbedding: true,
|
includeEmbedding: true,
|
||||||
probabilityThreshold: 0.75,
|
probabilityThreshold: 0.75, // 预测置信度阈值
|
||||||
suppressionTimeMillis: 300,
|
suppressionTimeMillis: 300, // 抑制时间,避免连续触发
|
||||||
overlapFactor: 0.50,
|
overlapFactor: 0.50, // 帧重叠因子
|
||||||
};
|
};
|
||||||
|
|
||||||
const options = Object.assign(defaultOptions, listenOptions);
|
const options = { ...defaultOptions, ...listenOptions };
|
||||||
|
|
||||||
|
try {
|
||||||
predictionStopFunction = await transferRecognizer.listen(result => {
|
predictionStopFunction = await transferRecognizer.listen(result => {
|
||||||
if (!isPredicting) return;
|
if (!isPredicting) return; // 确保在停止后不再处理结果
|
||||||
|
|
||||||
const classLabels = transferRecognizer.wordLabels();
|
const classLabels = transferRecognizer.wordLabels(); // 获取模型内部的标签列表
|
||||||
const scores = result.scores;
|
const scores = result.scores; // 预测分数
|
||||||
const maxScore = Math.max(...scores);
|
const maxScore = Math.max(...scores);
|
||||||
const predictedIndex = scores.indexOf(maxScore);
|
const predictedIndex = scores.indexOf(maxScore);
|
||||||
|
|
||||||
let predictedLabel = classLabels[predictedIndex];
|
let predictedLabel = classLabels[predictedIndex];
|
||||||
// 如果预测结果是内部的背景噪音标签,转换成用户友好的显示
|
// 对背景噪音标签进行友好化处理
|
||||||
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
|
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
|
||||||
predictedLabel = '背景噪音';
|
predictedLabel = '背景噪音';
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用回调函数返回预测结果
|
// 调用外部传入的回调函数,传递结构化的预测结果
|
||||||
if (typeof onPrediction === 'function') {
|
if (typeof onPrediction === 'function') {
|
||||||
onPrediction({
|
onPrediction({
|
||||||
label: predictedLabel,
|
label: predictedLabel,
|
||||||
score: maxScore,
|
score: maxScore,
|
||||||
scores: scores,
|
scores: scores, // 原始分数数组
|
||||||
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label)
|
predictedIndex: predictedIndex, // 预测索引
|
||||||
|
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) // 友好化的标签列表
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
}, options);
|
}, options);
|
||||||
|
|
||||||
return Promise.resolve(predictionStopFunction);
|
console.log('AudioClassifier: 开始识别。');
|
||||||
|
return Promise.resolve(predictionStopFunction); // 返回停止函数给调用者
|
||||||
|
} catch (error) {
|
||||||
|
isPredicting = false;
|
||||||
|
console.error('AudioClassifier: 开始识别失败:', error);
|
||||||
|
return Promise.reject(new Error(`开始识别失败: ${error.message}`));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 停止实时预测
|
* 停止实时预测。
|
||||||
*/
|
*/
|
||||||
function stopPrediction() {
|
function stopPrediction() {
|
||||||
if (isPredicting) {
|
if (isPredicting) {
|
||||||
if (typeof predictionStopFunction === 'function') {
|
if (typeof predictionStopFunction === 'function') {
|
||||||
predictionStopFunction();
|
predictionStopFunction(); // 调用 SpeechCommands 库返回的停止函数
|
||||||
predictionStopFunction = null;
|
predictionStopFunction = null;
|
||||||
|
} else {
|
||||||
|
console.warn('AudioClassifier: 停止预测函数未定义或不是函数。');
|
||||||
}
|
}
|
||||||
|
|
||||||
isPredicting = false;
|
isPredicting = false;
|
||||||
|
console.log('AudioClassifier: 已停止识别。');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取各类别样本数量
|
* 获取当前收集到的各类别样本数量。
|
||||||
* @returns {Object} 各类别样本数量统计
|
* @returns {Object} 键为类别名称,值为样本数量。
|
||||||
*/
|
*/
|
||||||
function getExampleCounts() {
|
function getExampleCounts() {
|
||||||
|
try {
|
||||||
return transferRecognizer.countExamples();
|
return transferRecognizer.countExamples();
|
||||||
|
} catch (error) {
|
||||||
|
// 捕获无样本时的错误,返回空对象
|
||||||
|
return {};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取模型是否已训练的状态
|
* 获取模型是否已训练的状态。
|
||||||
* @returns {boolean} 模型是否已训练
|
* @returns {boolean} 如果模型已训练则返回 true,否则 false。
|
||||||
*/
|
*/
|
||||||
function isModelTrained() {
|
function isModelTrained() {
|
||||||
return isModelTrainedFlag;
|
return isModelTrainedFlag;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 导出公共接口
|
/**
|
||||||
|
* 获取当前模块中所有标签的数组(包括背景噪音和自定义标签)。
|
||||||
|
* @returns {string[]} 标签数组。
|
||||||
|
*/
|
||||||
|
function getLabels() {
|
||||||
|
return [...labels]; // 返回副本,避免外部直接修改内部状态
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 序列化所有收集到的样本数据到 ArrayBuffer。
|
||||||
|
* @returns {Promise<ArrayBuffer>} 包含样本数据的 ArrayBuffer。
|
||||||
|
*/
|
||||||
|
async function serializeExamples() {
|
||||||
|
try {
|
||||||
|
const serialized = await transferRecognizer.serializeExamples();
|
||||||
|
console.log('AudioClassifier: 样本数据序列化成功。');
|
||||||
|
return Promise.resolve(serialized);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('AudioClassifier: 序列化样本失败:', error);
|
||||||
|
return Promise.reject(new Error(`序列化样本失败: ${error.message}`));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从 ArrayBuffer 导入样本数据。
|
||||||
|
* 导入后会更新内部的 labels 数组和模型状态。
|
||||||
|
* @param {ArrayBuffer} dataBuffer - 包含样本数据的 ArrayBuffer。
|
||||||
|
* @returns {Promise<Object>} resolve 时返回导入后的样本总数统计,reject 时返回错误。
|
||||||
|
*/
|
||||||
|
async function loadExamples(dataBuffer) {
|
||||||
|
if (!dataBuffer instanceof ArrayBuffer) {
|
||||||
|
return Promise.reject(new Error('核心模块:导入数据必须是 ArrayBuffer 类型。'));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 首先清空当前的 transferRecognizer 中的所有样本,再加载新的。
|
||||||
|
// 这等同于创建一个新的 transferRecognizer 实例,并同步其内部状态。
|
||||||
|
// SpeechCommands库没有直接的`clearExamples`或`resetExamples`方法,
|
||||||
|
// 最稳妥的方法是重新创建transferRecognizer并加载。
|
||||||
|
// 但loadExamples本身就是设计来覆盖和更新现有样本的,
|
||||||
|
// 考虑到性能,我们假设直接loadExamples即可。
|
||||||
|
// 注意:loadExamples会清除原有数据。
|
||||||
|
|
||||||
|
try {
|
||||||
|
await transferRecognizer.loadExamples(dataBuffer);
|
||||||
|
|
||||||
|
// 成功加载后,需要更新内部的 `labels` 数组,以匹配导入的数据。
|
||||||
|
// `transferRecognizer.wordLabels()` 返回的是模型内部的实际标签。
|
||||||
|
const loadedLabels = transferRecognizer.wordLabels();
|
||||||
|
// 确保 BACKGROUND_NOISE_LABEL 还在第一个位置,并过滤掉 '_version_', '_numExamples_' 等内部标签
|
||||||
|
labels = loadedLabels.filter(label => label !== '_version_' && label !== '_numExamples_');
|
||||||
|
|
||||||
|
if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
|
||||||
|
// 如果导入的数据中没有背景噪音(不常见,但作为健壮性处理)
|
||||||
|
labels.unshift(BACKGROUND_NOISE_LABEL);
|
||||||
|
} else {
|
||||||
|
// 确保背景噪音在第一位
|
||||||
|
const bgIndex = labels.indexOf(BACKGROUND_NOISE_LABEL);
|
||||||
|
if (bgIndex > 0) {
|
||||||
|
const bgLabel = labels.splice(bgIndex, 1);
|
||||||
|
labels.unshift(bgLabel[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重置模型的训练状态,因为样本已更改,需要重新训练
|
||||||
|
isModelTrainedFlag = false;
|
||||||
|
|
||||||
|
const exampleCounts = transferRecognizer.countExamples();
|
||||||
|
console.log('AudioClassifier: 样本数据导入成功。', exampleCounts);
|
||||||
|
return Promise.resolve(exampleCounts);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('AudioClassifier: 加载样本失败:', error);
|
||||||
|
return Promise.reject(new Error(`加载样本失败: ${error.message}. 请确保文件有效且未损坏。`));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================= 导出的公共接口 =======================
|
||||||
|
// 这里使用 CommonJS 风格的 module.exports,因为这通常用于 Node.js 环境或 Webpack/Rollup 等打包工具。
|
||||||
|
// 如果你是在浏览器中直接作为 <script> 标签引入(没有打包工具),
|
||||||
|
// 并且希望它在全局作用域下可用,可以使用 `window.AudioClassifier = {...}`。
|
||||||
|
// 在这种情况下,我将提供 `window.AudioClassifier` 的方式。
|
||||||
|
|
||||||
|
// 为了 Web 浏览器环境,将核心功能挂载到全局对象上
|
||||||
window.AudioClassifier = {
|
window.AudioClassifier = {
|
||||||
init,
|
init, // 初始化模型
|
||||||
recordMultipleExamples,
|
recordMultipleExamples, // 录制样本
|
||||||
addCustomCategory,
|
addCustomCategory, // 添加自定义类别
|
||||||
checkTrainingReadiness,
|
checkTrainingReadiness, // 检查模型是否可训练
|
||||||
trainModel,
|
trainModel, // 训练模型
|
||||||
startPrediction,
|
startPrediction, // 开始预测
|
||||||
stopPrediction,
|
stopPrediction, // 停止预测
|
||||||
getExampleCounts,
|
getExampleCounts, // 获取所有类别样本数量
|
||||||
isModelTrained,
|
isModelTrained, // 检查模型是否已训练
|
||||||
labels
|
getLabels, // 获取当前所有标签
|
||||||
|
serializeExamples, // 导出样本数据
|
||||||
|
loadExamples // 导入样本数据
|
||||||
};
|
};
|
Loading…
x
Reference in New Issue
Block a user