[MF]修改audioClassifier.js (核心功能模块)
This commit is contained in:
parent
5d36541dc5
commit
37dc1c5a76
@ -1,147 +1,213 @@
|
||||
// 全局变量和模型实例
|
||||
let recognizer; // 基础的 SpeechCommands recognizer
|
||||
let transferRecognizer; // 用于迁移学习的 recognizer
|
||||
const labels = []; // 用户定义的类别标签数组 (包括背景噪音)
|
||||
// 将背景噪音定义为第一个类别,其内部名称为 _background_noise_
|
||||
const BACKGROUND_NOISE_LABEL = '_background_noise_';
|
||||
const BACKGROUND_NOISE_INDEX = 0; // 仅用于本地 labels 数组索引,不直接用于collectExample
|
||||
// audioClassifier.js (核心功能模块)
|
||||
|
||||
// TensorFlow.js 和 Speech Commands 库不需要在这里再次导入,
|
||||
// 假定在使用此模块的 HTML 页面中已经通过 <script> 标签引入。
|
||||
|
||||
// ======================= 模块内部共享变量 =======================
|
||||
let recognizer; // 基础的 SpeechCommands recognizer 实例
|
||||
let transferRecognizer; // 用于迁移学习的 recognizer 实例
|
||||
|
||||
// 用户定义的类别标签数组 (包括背景噪音)。
|
||||
// 这是模块的核心状态之一,外部可以通过 getLabels() 获取。
|
||||
let labels = [];
|
||||
const BACKGROUND_NOISE_LABEL = '_background_noise_'; // 内部使用的背景噪音标签
|
||||
|
||||
let isPredicting = false; // 预测状态标志
|
||||
let isRecording = false; // 录音状态标志,防止重复点击
|
||||
const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
|
||||
|
||||
let isModelTrainedFlag = false; // 手动维护模型训练状态
|
||||
let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
|
||||
|
||||
|
||||
// ======================= 核心功能函数 =======================
|
||||
|
||||
/**
|
||||
* 初始化函数 - 加载模型和创建迁移学习模型
|
||||
* @returns {Promise<void>}
|
||||
* 初始化音频分类器模型。
|
||||
* 必须在使用任何其他功能之前调用。
|
||||
* @returns {Promise<void>} resolve 表示成功,reject 表示失败。
|
||||
*/
|
||||
async function init() {
|
||||
try {
|
||||
recognizer = speechCommands.create(
|
||||
'BROWSER_FFT' // 使用浏览器内置的 FFT 处理,性能更好
|
||||
);
|
||||
// 确保每次初始化时重置状态
|
||||
recognizer = null;
|
||||
transferRecognizer = null;
|
||||
labels = [];
|
||||
isPredicting = false;
|
||||
isRecording = false;
|
||||
isModelTrainedFlag = false;
|
||||
if (predictionStopFunction) {
|
||||
predictionStopFunction(); // 停止任何正在进行的预测
|
||||
predictionStopFunction = null;
|
||||
}
|
||||
|
||||
await recognizer.ensureModelLoaded();
|
||||
try {
|
||||
// 创建基础识别器
|
||||
recognizer = speechCommands.create('BROWSER_FFT');
|
||||
await recognizer.ensureModelLoaded(); // 确保模型加载完成
|
||||
|
||||
// 创建迁移学习识别器
|
||||
transferRecognizer = recognizer.createTransfer('my-custom-model');
|
||||
|
||||
// 只有在 transferRecognizer 创建成功后,才将背景噪音标签加入我们的 local labels 数组
|
||||
// 初始化时,将背景噪音标签加入到我们的内部 labels 数组
|
||||
labels.push(BACKGROUND_NOISE_LABEL);
|
||||
|
||||
console.log('AudioClassifier: 模型和迁移学习器初始化成功');
|
||||
return Promise.resolve();
|
||||
} catch (error) {
|
||||
return Promise.reject(error);
|
||||
console.error('AudioClassifier: 初始化失败:', error);
|
||||
return Promise.reject(new Error(`模型初始化失败或麦克风无法访问: ${error.message}`));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量录制样本的通用函数
|
||||
* @param {string} label - 标签名称
|
||||
* @param {number} countToRecord - 要录制的样本数量
|
||||
* @returns {Promise<void>}
|
||||
* 录制单个或批量音频样本。
|
||||
* 在录制示例时会暂时禁用 isRecording 标志以防止多重录制。
|
||||
*
|
||||
* @param {string} label - 样本所属的类别标签。
|
||||
* @param {number} countToRecord - 要录制样本的数量。
|
||||
* @param {Object} [options] - 其他选项。
|
||||
* @param {Function} [options.onProgress] - (optional) 进度回调函数,参数为 (currentCount, totalCount)。
|
||||
* @returns {Promise<number>} resolve 时返回该类别当前的总样本数量,reject 时返回错误。
|
||||
*/
|
||||
async function recordMultipleExamples(label, countToRecord = 5) {
|
||||
async function recordMultipleExamples(label, countToRecord = 1, options = {}) {
|
||||
if (isRecording) {
|
||||
return Promise.reject(new Error('正在录制中,请等待当前录音完成'));
|
||||
return Promise.reject(new Error('核心模块:正在录制中,请等待当前录音完成。'));
|
||||
}
|
||||
|
||||
isRecording = true;
|
||||
|
||||
for (let i = 0; i < countToRecord; i++) {
|
||||
try {
|
||||
await transferRecognizer.collectExample(
|
||||
label,
|
||||
{ amplitudeRequired: true, durationMillis: recordDuration }
|
||||
);
|
||||
let currentLabelCount = 0;
|
||||
|
||||
try {
|
||||
for (let i = 0; i < countToRecord; i++) {
|
||||
// console.log(`AudioClassifier: 正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`);
|
||||
await transferRecognizer.collectExample(label, {
|
||||
amplitudeRequired: true,
|
||||
durationMillis: recordDuration
|
||||
});
|
||||
|
||||
const exampleCounts = transferRecognizer.countExamples();
|
||||
currentLabelCount = exampleCounts[label] || 0;
|
||||
|
||||
if (options.onProgress && typeof options.onProgress === 'function') {
|
||||
options.onProgress(i + 1, countToRecord, currentLabelCount);
|
||||
}
|
||||
|
||||
// 在每次录音之间增加短暂延迟,以便更好地分离样本
|
||||
if (i < countToRecord - 1) {
|
||||
await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5)));
|
||||
}
|
||||
} catch (error) {
|
||||
isRecording = false;
|
||||
return Promise.reject(error);
|
||||
}
|
||||
console.log(`AudioClassifier: 已为 "${label}" 收集了 ${currentLabelCount} 个样本。`);
|
||||
return Promise.resolve(currentLabelCount);
|
||||
} catch (error) {
|
||||
console.error(`AudioClassifier: 录制 "${label}" 样本失败:`, error);
|
||||
return Promise.reject(new Error(`录制 "${label}" 样本失败: ${error.message}`));
|
||||
} finally {
|
||||
isRecording = false;
|
||||
}
|
||||
|
||||
isRecording = false;
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加自定义类别
|
||||
* @param {string} categoryName - 类别名称
|
||||
* 添加一个自定义类别标签到模型。
|
||||
* 标签将添加到内部的 `labels` 数组中。
|
||||
* @param {string} categoryName - 要添加的类别名称。
|
||||
* @returns {Promise<void>} resolve 表示成功,reject 表示失败(如名称为空或重复)。
|
||||
*/
|
||||
function addCustomCategory(categoryName) {
|
||||
if (!categoryName) {
|
||||
return Promise.reject(new Error('类别名称不能为空'));
|
||||
if (!categoryName || categoryName.trim() === '') {
|
||||
return Promise.reject(new Error('类别名称不能为空。'));
|
||||
}
|
||||
|
||||
// 检查是否与现有标签重复
|
||||
if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) {
|
||||
return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`));
|
||||
const trimmedName = categoryName.trim();
|
||||
if (labels.some(label => label.toLowerCase() === trimmedName.toLowerCase())) {
|
||||
return Promise.reject(new Error(`类别 "${trimmedName}" 已经存在。`));
|
||||
}
|
||||
|
||||
// 将标签添加到本地数组
|
||||
labels.push(categoryName);
|
||||
labels.push(trimmedName);
|
||||
console.log(`AudioClassifier: 已添加新类别: "${trimmedName}"`);
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查训练就绪状态
|
||||
* @returns {boolean} 是否可以开始训练
|
||||
* 检查模型是否已准备好进行训练。
|
||||
* @returns {Object} 包含 `ready` (boolean) 和 `details` (string) 的对象。
|
||||
*/
|
||||
function checkTrainingReadiness() {
|
||||
const exampleCounts = transferRecognizer.countExamples();
|
||||
let exampleCounts = {};
|
||||
try {
|
||||
exampleCounts = transferRecognizer.countExamples();
|
||||
} catch (error) {
|
||||
// 这是预期内的错误,当没有样本时 countExamples 会抛出。
|
||||
// 将其捕获并视为空样本集处理。
|
||||
console.warn("countExamples() 抛出异常 (无样本时正常行为)。", error.message);
|
||||
}
|
||||
|
||||
let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0;
|
||||
// 检查是否有任何样本,用于更新导出按钮等逻辑
|
||||
const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0);
|
||||
const hasAnyExamples = totalSamples > 0;
|
||||
|
||||
const backgroundNoiseCount = exampleCounts[BACKGROUND_NOISE_LABEL] || 0;
|
||||
const backgroundNoiseReady = backgroundNoiseCount > 0;
|
||||
|
||||
let customCategoriesReady = 0;
|
||||
// 遍历本地 labels 数组,检查每个自定义类别是否有样本
|
||||
for (let i = 1; i < labels.length; i++) { // 从索引 1 开始,因为 0 是背景噪音
|
||||
let customCategoriesWithSamples = 0;
|
||||
// 遍历内部 labels 数组,检查每个自定义类别是否有样本
|
||||
// 从索引 1 开始,因为 0 是背景噪音
|
||||
for (let i = 1; i < labels.length; i++) {
|
||||
const customLabel = labels[i];
|
||||
if ((exampleCounts[customLabel] || 0) > 0) {
|
||||
customCategoriesReady++;
|
||||
customCategoriesWithSamples++;
|
||||
}
|
||||
}
|
||||
|
||||
// 必须有背景噪音样本,并且至少一个自定义类别有样本
|
||||
return backgroundNoiseReady && customCategoriesReady >= 1;
|
||||
// 训练就绪的条件:必须有背景噪音样本,并且至少一个自定义类别有样本
|
||||
const canTrain = backgroundNoiseReady && customCategoriesWithSamples >= 1;
|
||||
|
||||
// 更详细的训练就绪判断,包括样本数量建议
|
||||
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
|
||||
let trainingDetails = [];
|
||||
|
||||
if (!backgroundNoiseReady) {
|
||||
trainingDetails.push('需要录制背景噪音样本。');
|
||||
} else if (backgroundNoiseCount < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||||
trainingDetails.push(`建议背景噪音至少 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本 (当前: ${backgroundNoiseCount})。`);
|
||||
}
|
||||
|
||||
if (customCategoriesWithSamples === 0) {
|
||||
trainingDetails.push('需要至少一个自定义类别并录制样本。');
|
||||
} else {
|
||||
let insufficientCustomSamples = [];
|
||||
for (let i = 1; i < labels.length; i++) {
|
||||
const customLabel = labels[i];
|
||||
const count = exampleCounts[customLabel] || 0;
|
||||
if (count > 0 && count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||||
insufficientCustomSamples.push(`类别 "${customLabel}" 样本不足 (当前: ${count},建议: ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING})。`);
|
||||
}
|
||||
}
|
||||
if (insufficientCustomSamples.length > 0) {
|
||||
trainingDetails = trainingDetails.concat(insufficientCustomSamples);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
ready: canTrain && trainingDetails.length === 0, // 只有完全满足条件才认为是 ready
|
||||
details: trainingDetails.join('\n') || '模型已准备好进行训练。',
|
||||
hasAnyExamples: hasAnyExamples // 额外返回是否有任何样本,供 UI 判断是否启用导出按钮
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 模型训练函数
|
||||
* @param {Object} trainingConfig - 训练配置参数
|
||||
* @returns {Promise<void>}
|
||||
* 训练模型。
|
||||
* @param {Object} [config] - 训练配置。
|
||||
* @param {number} [config.epochs=50] - 训练迭代次数。
|
||||
* @param {number} [config.batchSize=16] - 批大小。
|
||||
* @param {number} [config.validationSplit=0.1] - 验证集比例。
|
||||
* @param {boolean} [config.shuffle=true] - 是否打乱数据。
|
||||
* @param {Function} [config.onEpochEnd] - (可选) 每个 epoch 结束时的回调函数,参数为 (epoch, logs)。
|
||||
* @returns {Promise<void>} resolve 表示训练完成,reject 表示训练失败。
|
||||
*/
|
||||
async function trainModel(trainingConfig = {}) {
|
||||
const exampleCounts = transferRecognizer.countExamples();
|
||||
|
||||
let totalExamples = 0;
|
||||
let validClasses = 0;
|
||||
const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;
|
||||
let allClassesHaveEnoughSamples = true;
|
||||
|
||||
// 统计所有类别的有效样本数,并检查每个类别是否达到最低要求
|
||||
for (const labelName of labels) {
|
||||
if (exampleCounts[labelName] && exampleCounts[labelName] > 0) {
|
||||
totalExamples += exampleCounts[labelName];
|
||||
validClasses++;
|
||||
if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) {
|
||||
allClassesHaveEnoughSamples = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validClasses < 2) {
|
||||
return Promise.reject(new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`));
|
||||
}
|
||||
|
||||
if (!allClassesHaveEnoughSamples) {
|
||||
return Promise.reject(new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`));
|
||||
}
|
||||
|
||||
if (totalExamples === 0) {
|
||||
return Promise.reject(new Error('没有收集到任何训练样本'));
|
||||
async function trainModel(config = {}) {
|
||||
// 再次进行训练前检查,防止外部在不满足条件时调用
|
||||
const readiness = checkTrainingReadiness();
|
||||
if (!readiness.ready) {
|
||||
return Promise.reject(new Error(`模型未准备好训练:\n${readiness.details}`));
|
||||
}
|
||||
|
||||
const defaultConfig = {
|
||||
@ -149,116 +215,225 @@ async function trainModel(trainingConfig = {}) {
|
||||
batchSize: 16,
|
||||
validationSplit: 0.1,
|
||||
shuffle: true,
|
||||
yieldEvery: 'epoch'
|
||||
// 外部传入的回调函数将在 train 内部被调用
|
||||
callbacks: {
|
||||
onEpochEnd: (epoch, logs) => {
|
||||
if (config.onEpochEnd && typeof config.onEpochEnd === 'function') {
|
||||
config.onEpochEnd(epoch, logs);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const config = Object.assign(defaultConfig, trainingConfig);
|
||||
// 合并默认配置和用户指定的配置
|
||||
const trainingConfig = { ...defaultConfig, ...config };
|
||||
|
||||
try {
|
||||
await transferRecognizer.train(config);
|
||||
await transferRecognizer.train(trainingConfig);
|
||||
isModelTrainedFlag = true;
|
||||
console.log('AudioClassifier: 模型训练完成。');
|
||||
return Promise.resolve();
|
||||
} catch (error) {
|
||||
isModelTrainedFlag = false;
|
||||
return Promise.reject(error);
|
||||
console.error('AudioClassifier: 模型训练失败:', error);
|
||||
return Promise.reject(new Error(`模型训练失败: ${error.message}. 请确保有足够的样本且类别均衡。`));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 开始实时预测
|
||||
* @param {Function} onPrediction - 预测结果回调函数
|
||||
* @param {Object} listenOptions - 监听选项
|
||||
* @returns {Promise<Function>} 停止预测的函数
|
||||
* 开始实时预测。
|
||||
* @param {Function} onPrediction - 预测结果回调函数,参数为 {label: string, score: number, scores: number[], labels: string[]}。
|
||||
* @param {Object} [listenOptions] - 监听选项,直接传递给 transferRecognizer.listen()。
|
||||
* @returns {Promise<Function>} resolve 时返回停止预测的函数,reject 时返回错误。
|
||||
*/
|
||||
async function startPrediction(onPrediction, listenOptions = {}) {
|
||||
if (isPredicting) {
|
||||
return Promise.reject(new Error('识别已经在进行中'));
|
||||
return Promise.reject(new Error('核心模块:识别已经在进行中。'));
|
||||
}
|
||||
|
||||
if (!isModelTrainedFlag) {
|
||||
return Promise.reject(new Error('模型尚未训练完成'));
|
||||
return Promise.reject(new Error('核心模块:模型尚未训练完成,请先训练模型。'));
|
||||
}
|
||||
|
||||
isPredicting = true;
|
||||
|
||||
const defaultOptions = {
|
||||
includeEmbedding: true,
|
||||
probabilityThreshold: 0.75,
|
||||
suppressionTimeMillis: 300,
|
||||
overlapFactor: 0.50,
|
||||
probabilityThreshold: 0.75, // 预测置信度阈值
|
||||
suppressionTimeMillis: 300, // 抑制时间,避免连续触发
|
||||
overlapFactor: 0.50, // 帧重叠因子
|
||||
};
|
||||
|
||||
const options = Object.assign(defaultOptions, listenOptions);
|
||||
const options = { ...defaultOptions, ...listenOptions };
|
||||
|
||||
predictionStopFunction = await transferRecognizer.listen(result => {
|
||||
if (!isPredicting) return;
|
||||
try {
|
||||
predictionStopFunction = await transferRecognizer.listen(result => {
|
||||
if (!isPredicting) return; // 确保在停止后不再处理结果
|
||||
|
||||
const classLabels = transferRecognizer.wordLabels();
|
||||
const scores = result.scores;
|
||||
const maxScore = Math.max(...scores);
|
||||
const predictedIndex = scores.indexOf(maxScore);
|
||||
|
||||
let predictedLabel = classLabels[predictedIndex];
|
||||
// 如果预测结果是内部的背景噪音标签,转换成用户友好的显示
|
||||
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
|
||||
predictedLabel = '背景噪音';
|
||||
}
|
||||
|
||||
// 调用回调函数返回预测结果
|
||||
if (typeof onPrediction === 'function') {
|
||||
onPrediction({
|
||||
label: predictedLabel,
|
||||
score: maxScore,
|
||||
scores: scores,
|
||||
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label)
|
||||
});
|
||||
}
|
||||
const classLabels = transferRecognizer.wordLabels(); // 获取模型内部的标签列表
|
||||
const scores = result.scores; // 预测分数
|
||||
const maxScore = Math.max(...scores);
|
||||
const predictedIndex = scores.indexOf(maxScore);
|
||||
|
||||
let predictedLabel = classLabels[predictedIndex];
|
||||
// 对背景噪音标签进行友好化处理
|
||||
if (predictedLabel === BACKGROUND_NOISE_LABEL) {
|
||||
predictedLabel = '背景噪音';
|
||||
}
|
||||
|
||||
// 调用外部传入的回调函数,传递结构化的预测结果
|
||||
if (typeof onPrediction === 'function') {
|
||||
onPrediction({
|
||||
label: predictedLabel,
|
||||
score: maxScore,
|
||||
scores: scores, // 原始分数数组
|
||||
predictedIndex: predictedIndex, // 预测索引
|
||||
labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) // 友好化的标签列表
|
||||
});
|
||||
}
|
||||
|
||||
}, options);
|
||||
}, options);
|
||||
|
||||
return Promise.resolve(predictionStopFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* 停止实时预测
|
||||
*/
|
||||
function stopPrediction() {
|
||||
if (isPredicting) {
|
||||
if (typeof predictionStopFunction === 'function') {
|
||||
predictionStopFunction();
|
||||
predictionStopFunction = null;
|
||||
}
|
||||
|
||||
console.log('AudioClassifier: 开始识别。');
|
||||
return Promise.resolve(predictionStopFunction); // 返回停止函数给调用者
|
||||
} catch (error) {
|
||||
isPredicting = false;
|
||||
console.error('AudioClassifier: 开始识别失败:', error);
|
||||
return Promise.reject(new Error(`开始识别失败: ${error.message}`));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取各类别样本数量
|
||||
* @returns {Object} 各类别样本数量统计
|
||||
* 停止实时预测。
|
||||
*/
|
||||
function getExampleCounts() {
|
||||
return transferRecognizer.countExamples();
|
||||
function stopPrediction() {
|
||||
if (isPredicting) {
|
||||
if (typeof predictionStopFunction === 'function') {
|
||||
predictionStopFunction(); // 调用 SpeechCommands 库返回的停止函数
|
||||
predictionStopFunction = null;
|
||||
} else {
|
||||
console.warn('AudioClassifier: 停止预测函数未定义或不是函数。');
|
||||
}
|
||||
isPredicting = false;
|
||||
console.log('AudioClassifier: 已停止识别。');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型是否已训练的状态
|
||||
* @returns {boolean} 模型是否已训练
|
||||
* 获取当前收集到的各类别样本数量。
|
||||
* @returns {Object} 键为类别名称,值为样本数量。
|
||||
*/
|
||||
function getExampleCounts() {
|
||||
try {
|
||||
return transferRecognizer.countExamples();
|
||||
} catch (error) {
|
||||
// 捕获无样本时的错误,返回空对象
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型是否已训练的状态。
|
||||
* @returns {boolean} 如果模型已训练则返回 true,否则 false。
|
||||
*/
|
||||
function isModelTrained() {
|
||||
return isModelTrainedFlag;
|
||||
}
|
||||
|
||||
// 导出公共接口
|
||||
/**
|
||||
* 获取当前模块中所有标签的数组(包括背景噪音和自定义标签)。
|
||||
* @returns {string[]} 标签数组。
|
||||
*/
|
||||
function getLabels() {
|
||||
return [...labels]; // 返回副本,避免外部直接修改内部状态
|
||||
}
|
||||
|
||||
/**
|
||||
* 序列化所有收集到的样本数据到 ArrayBuffer。
|
||||
* @returns {Promise<ArrayBuffer>} 包含样本数据的 ArrayBuffer。
|
||||
*/
|
||||
async function serializeExamples() {
|
||||
try {
|
||||
const serialized = await transferRecognizer.serializeExamples();
|
||||
console.log('AudioClassifier: 样本数据序列化成功。');
|
||||
return Promise.resolve(serialized);
|
||||
} catch (error) {
|
||||
console.error('AudioClassifier: 序列化样本失败:', error);
|
||||
return Promise.reject(new Error(`序列化样本失败: ${error.message}`));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 ArrayBuffer 导入样本数据。
|
||||
* 导入后会更新内部的 labels 数组和模型状态。
|
||||
* @param {ArrayBuffer} dataBuffer - 包含样本数据的 ArrayBuffer。
|
||||
* @returns {Promise<Object>} resolve 时返回导入后的样本总数统计,reject 时返回错误。
|
||||
*/
|
||||
async function loadExamples(dataBuffer) {
|
||||
if (!dataBuffer instanceof ArrayBuffer) {
|
||||
return Promise.reject(new Error('核心模块:导入数据必须是 ArrayBuffer 类型。'));
|
||||
}
|
||||
|
||||
// 首先清空当前的 transferRecognizer 中的所有样本,再加载新的。
|
||||
// 这等同于创建一个新的 transferRecognizer 实例,并同步其内部状态。
|
||||
// SpeechCommands库没有直接的`clearExamples`或`resetExamples`方法,
|
||||
// 最稳妥的方法是重新创建transferRecognizer并加载。
|
||||
// 但loadExamples本身就是设计来覆盖和更新现有样本的,
|
||||
// 考虑到性能,我们假设直接loadExamples即可。
|
||||
// 注意:loadExamples会清除原有数据。
|
||||
|
||||
try {
|
||||
await transferRecognizer.loadExamples(dataBuffer);
|
||||
|
||||
// 成功加载后,需要更新内部的 `labels` 数组,以匹配导入的数据。
|
||||
// `transferRecognizer.wordLabels()` 返回的是模型内部的实际标签。
|
||||
const loadedLabels = transferRecognizer.wordLabels();
|
||||
// 确保 BACKGROUND_NOISE_LABEL 还在第一个位置,并过滤掉 '_version_', '_numExamples_' 等内部标签
|
||||
labels = loadedLabels.filter(label => label !== '_version_' && label !== '_numExamples_');
|
||||
|
||||
if (!labels.includes(BACKGROUND_NOISE_LABEL)) {
|
||||
// 如果导入的数据中没有背景噪音(不常见,但作为健壮性处理)
|
||||
labels.unshift(BACKGROUND_NOISE_LABEL);
|
||||
} else {
|
||||
// 确保背景噪音在第一位
|
||||
const bgIndex = labels.indexOf(BACKGROUND_NOISE_LABEL);
|
||||
if (bgIndex > 0) {
|
||||
const bgLabel = labels.splice(bgIndex, 1);
|
||||
labels.unshift(bgLabel[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// 重置模型的训练状态,因为样本已更改,需要重新训练
|
||||
isModelTrainedFlag = false;
|
||||
|
||||
const exampleCounts = transferRecognizer.countExamples();
|
||||
console.log('AudioClassifier: 样本数据导入成功。', exampleCounts);
|
||||
return Promise.resolve(exampleCounts);
|
||||
} catch (error) {
|
||||
console.error('AudioClassifier: 加载样本失败:', error);
|
||||
return Promise.reject(new Error(`加载样本失败: ${error.message}. 请确保文件有效且未损坏。`));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ======================= 导出的公共接口 =======================
|
||||
// 这里使用 CommonJS 风格的 module.exports,因为这通常用于 Node.js 环境或 Webpack/Rollup 等打包工具。
|
||||
// 如果你是在浏览器中直接作为 <script> 标签引入(没有打包工具),
|
||||
// 并且希望它在全局作用域下可用,可以使用 `window.AudioClassifier = {...}`。
|
||||
// 在这种情况下,我将提供 `window.AudioClassifier` 的方式。
|
||||
|
||||
// 为了 Web 浏览器环境,将核心功能挂载到全局对象上
|
||||
window.AudioClassifier = {
|
||||
init,
|
||||
recordMultipleExamples,
|
||||
addCustomCategory,
|
||||
checkTrainingReadiness,
|
||||
trainModel,
|
||||
startPrediction,
|
||||
stopPrediction,
|
||||
getExampleCounts,
|
||||
isModelTrained,
|
||||
labels
|
||||
};
|
||||
init, // 初始化模型
|
||||
recordMultipleExamples, // 录制样本
|
||||
addCustomCategory, // 添加自定义类别
|
||||
checkTrainingReadiness, // 检查模型是否可训练
|
||||
trainModel, // 训练模型
|
||||
startPrediction, // 开始预测
|
||||
stopPrediction, // 停止预测
|
||||
getExampleCounts, // 获取所有类别样本数量
|
||||
isModelTrained, // 检查模型是否已训练
|
||||
getLabels, // 获取当前所有标签
|
||||
serializeExamples, // 导出样本数据
|
||||
loadExamples // 导入样本数据
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user