[MF]修改audioClassifier.js (核心功能模块)
This commit is contained in:
		
							parent
							
								
									5d36541dc5
								
							
						
					
					
						commit
						37dc1c5a76
					
				| @ -1,147 +1,213 @@ | |||||||
| // 全局变量和模型实例
 | // audioClassifier.js (核心功能模块)
 | ||||||
| let recognizer; // 基础的 SpeechCommands recognizer
 | 
 | ||||||
| let transferRecognizer; // 用于迁移学习的 recognizer
 | // TensorFlow.js 和 Speech Commands 库不需要在这里再次导入,
 | ||||||
| const labels = []; // 用户定义的类别标签数组 (包括背景噪音)
 | // 假定在使用此模块的 HTML 页面中已经通过 <script> 标签引入。
 | ||||||
| // 将背景噪音定义为第一个类别,其内部名称为 _background_noise_
 | 
 | ||||||
| const BACKGROUND_NOISE_LABEL = '_background_noise_'; | // ======================= 模块内部共享变量 =======================
 | ||||||
| const BACKGROUND_NOISE_INDEX = 0; // 仅用于本地 labels 数组索引,不直接用于collectExample
 | let recognizer; // 基础的 SpeechCommands recognizer 实例
 | ||||||
|  | let transferRecognizer; // 用于迁移学习的 recognizer 实例
 | ||||||
|  | 
 | ||||||
|  | // 用户定义的类别标签数组 (包括背景噪音)。
 | ||||||
|  | // 这是模块的核心状态之一,外部可以通过 getLabels() 获取。
 | ||||||
|  | let labels = [];  | ||||||
|  | const BACKGROUND_NOISE_LABEL = '_background_noise_'; // 内部使用的背景噪音标签
 | ||||||
| 
 | 
 | ||||||
| let isPredicting = false; // 预测状态标志
 | let isPredicting = false; // 预测状态标志
 | ||||||
| let isRecording = false; // 录音状态标志,防止重复点击
 | let isRecording = false; // 录音状态标志,防止重复点击
 | ||||||
| const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
 | const recordDuration = 1000; // 每个样本的录音时长 (毫秒)
 | ||||||
|  | 
 | ||||||
| let isModelTrainedFlag = false; // 手动维护模型训练状态
 | let isModelTrainedFlag = false; // 手动维护模型训练状态
 | ||||||
| let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
 | let predictionStopFunction = null; // 存储 transferRecognizer.listen() 返回的停止函数
 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | // ======================= 核心功能函数 =======================
 | ||||||
|  | 
 | ||||||
| /** | /** | ||||||
|  * 初始化函数 - 加载模型和创建迁移学习模型 |  * 初始化音频分类器模型。 | ||||||
|  * @returns {Promise<void>} |  * 必须在使用任何其他功能之前调用。 | ||||||
|  |  * @returns {Promise<void>} resolve 表示成功,reject 表示失败。 | ||||||
|  */ |  */ | ||||||
| async function init() { | async function init() { | ||||||
|     try { |     // 确保每次初始化时重置状态
 | ||||||
|         recognizer = speechCommands.create( |     recognizer = null; | ||||||
|             'BROWSER_FFT' // 使用浏览器内置的 FFT 处理,性能更好
 |     transferRecognizer = null; | ||||||
|         ); |     labels = []; | ||||||
|  |     isPredicting = false; | ||||||
|  |     isRecording = false; | ||||||
|  |     isModelTrainedFlag = false; | ||||||
|  |     if (predictionStopFunction) { | ||||||
|  |         predictionStopFunction(); // 停止任何正在进行的预测
 | ||||||
|  |         predictionStopFunction = null; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|         await recognizer.ensureModelLoaded(); |     try { | ||||||
|  |         // 创建基础识别器
 | ||||||
|  |         recognizer = speechCommands.create('BROWSER_FFT'); | ||||||
|  |         await recognizer.ensureModelLoaded(); // 确保模型加载完成
 | ||||||
|  | 
 | ||||||
|  |         // 创建迁移学习识别器
 | ||||||
|         transferRecognizer = recognizer.createTransfer('my-custom-model'); |         transferRecognizer = recognizer.createTransfer('my-custom-model'); | ||||||
|          |          | ||||||
|         // 只有在 transferRecognizer 创建成功后,才将背景噪音标签加入我们的 local labels 数组
 |         // 初始化时,将背景噪音标签加入到我们的内部 labels 数组
 | ||||||
|         labels.push(BACKGROUND_NOISE_LABEL); |         labels.push(BACKGROUND_NOISE_LABEL); | ||||||
|          |          | ||||||
|  |         console.log('AudioClassifier: 模型和迁移学习器初始化成功'); | ||||||
|         return Promise.resolve(); |         return Promise.resolve(); | ||||||
|     } catch (error) { |     } catch (error) { | ||||||
|         return Promise.reject(error); |         console.error('AudioClassifier: 初始化失败:', error); | ||||||
|  |         return Promise.reject(new Error(`模型初始化失败或麦克风无法访问: ${error.message}`)); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 批量录制样本的通用函数 |  * 录制单个或批量音频样本。 | ||||||
|  * @param {string} label - 标签名称 |  * 在录制示例时会暂时禁用 isRecording 标志以防止多重录制。 | ||||||
|  * @param {number} countToRecord - 要录制的样本数量 |  * | ||||||
|  * @returns {Promise<void>} |  * @param {string} label - 样本所属的类别标签。 | ||||||
|  |  * @param {number} countToRecord - 要录制样本的数量。 | ||||||
|  |  * @param {Object} [options] - 其他选项。 | ||||||
|  |  * @param {Function} [options.onProgress] - (optional) 进度回调函数,参数为 (currentCount, totalCount)。 | ||||||
|  |  * @returns {Promise<number>} resolve 时返回该类别当前的总样本数量,reject 时返回错误。 | ||||||
|  */ |  */ | ||||||
| async function recordMultipleExamples(label, countToRecord = 5) { | async function recordMultipleExamples(label, countToRecord = 1, options = {}) { | ||||||
|     if (isRecording) { |     if (isRecording) { | ||||||
|         return Promise.reject(new Error('正在录制中,请等待当前录音完成')); |         return Promise.reject(new Error('核心模块:正在录制中,请等待当前录音完成。')); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     isRecording = true; |     isRecording = true; | ||||||
|  |     let currentLabelCount = 0; | ||||||
| 
 | 
 | ||||||
|     for (let i = 0; i < countToRecord; i++) { |  | ||||||
|     try { |     try { | ||||||
|             await transferRecognizer.collectExample( |         for (let i = 0; i < countToRecord; i++) { | ||||||
|                 label, |             // console.log(`AudioClassifier: 正在录制 "${label}" 样本... (第 ${i + 1} 个 / 共 ${countToRecord} 个)`);
 | ||||||
|                 { amplitudeRequired: true, durationMillis: recordDuration } |             await transferRecognizer.collectExample(label, {  | ||||||
|             ); |                 amplitudeRequired: true,  | ||||||
|  |                 durationMillis: recordDuration  | ||||||
|  |             }); | ||||||
|  |              | ||||||
|  |             const exampleCounts = transferRecognizer.countExamples(); | ||||||
|  |             currentLabelCount = exampleCounts[label] || 0; | ||||||
|  | 
 | ||||||
|  |             if (options.onProgress && typeof options.onProgress === 'function') { | ||||||
|  |                 options.onProgress(i + 1, countToRecord, currentLabelCount); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|             // 在每次录音之间增加短暂延迟,以便更好地分离样本
 |             // 在每次录音之间增加短暂延迟,以便更好地分离样本
 | ||||||
|             if (i < countToRecord - 1) { |             if (i < countToRecord - 1) { | ||||||
|                 await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5))); |                 await new Promise(resolve => setTimeout(resolve, Math.max(200, recordDuration / 5))); | ||||||
|             } |             } | ||||||
|  |         } | ||||||
|  |         console.log(`AudioClassifier: 已为 "${label}" 收集了 ${currentLabelCount} 个样本。`); | ||||||
|  |         return Promise.resolve(currentLabelCount); | ||||||
|     } catch (error) { |     } catch (error) { | ||||||
|  |         console.error(`AudioClassifier: 录制 "${label}" 样本失败:`, error); | ||||||
|  |         return Promise.reject(new Error(`录制 "${label}" 样本失败: ${error.message}`)); | ||||||
|  |     } finally { | ||||||
|         isRecording = false; |         isRecording = false; | ||||||
|             return Promise.reject(error); |  | ||||||
|     } |     } | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     isRecording = false; |  | ||||||
|     return Promise.resolve(); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 添加自定义类别 |  * 添加一个自定义类别标签到模型。 | ||||||
|  * @param {string} categoryName - 类别名称 |  * 标签将添加到内部的 `labels` 数组中。 | ||||||
|  |  * @param {string} categoryName - 要添加的类别名称。 | ||||||
|  |  * @returns {Promise<void>} resolve 表示成功,reject 表示失败(如名称为空或重复)。 | ||||||
|  */ |  */ | ||||||
| function addCustomCategory(categoryName) { | function addCustomCategory(categoryName) { | ||||||
|     if (!categoryName) { |     if (!categoryName || categoryName.trim() === '') { | ||||||
|         return Promise.reject(new Error('类别名称不能为空')); |         return Promise.reject(new Error('类别名称不能为空。')); | ||||||
|     } |     } | ||||||
|      |     const trimmedName = categoryName.trim(); | ||||||
|     // 检查是否与现有标签重复
 |     if (labels.some(label => label.toLowerCase() === trimmedName.toLowerCase())) { | ||||||
|     if (labels.some(label => label.toLowerCase() === categoryName.toLowerCase())) { |         return Promise.reject(new Error(`类别 "${trimmedName}" 已经存在。`)); | ||||||
|         return Promise.reject(new Error(`类别 "${categoryName}" 已经存在`)); |  | ||||||
|     } |     } | ||||||
| 
 |     labels.push(trimmedName); | ||||||
|     // 将标签添加到本地数组
 |     console.log(`AudioClassifier: 已添加新类别: "${trimmedName}"`); | ||||||
|     labels.push(categoryName); |  | ||||||
|     return Promise.resolve(); |     return Promise.resolve(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 检查训练就绪状态 |  * 检查模型是否已准备好进行训练。 | ||||||
|  * @returns {boolean} 是否可以开始训练 |  * @returns {Object} 包含 `ready` (boolean) 和 `details` (string) 的对象。 | ||||||
|  */ |  */ | ||||||
| function checkTrainingReadiness() { | function checkTrainingReadiness() { | ||||||
|     const exampleCounts = transferRecognizer.countExamples(); |     let exampleCounts = {}; | ||||||
|  |     try { | ||||||
|  |         exampleCounts = transferRecognizer.countExamples(); | ||||||
|  |     } catch (error) { | ||||||
|  |         // 这是预期内的错误,当没有样本时 countExamples 会抛出。
 | ||||||
|  |         // 将其捕获并视为空样本集处理。
 | ||||||
|  |         console.warn("countExamples() 抛出异常 (无样本时正常行为)。", error.message); | ||||||
|  |     } | ||||||
|      |      | ||||||
|     let backgroundNoiseReady = (exampleCounts[BACKGROUND_NOISE_LABEL] || 0) > 0; |     // 检查是否有任何样本,用于更新导出按钮等逻辑
 | ||||||
|  |     const totalSamples = Object.values(exampleCounts).reduce((acc, count) => acc + count, 0); | ||||||
|  |     const hasAnyExamples = totalSamples > 0; | ||||||
| 
 | 
 | ||||||
|     let customCategoriesReady = 0; |     const backgroundNoiseCount = exampleCounts[BACKGROUND_NOISE_LABEL] || 0; | ||||||
|     // 遍历本地 labels 数组,检查每个自定义类别是否有样本
 |     const backgroundNoiseReady = backgroundNoiseCount > 0; | ||||||
|     for (let i = 1; i < labels.length; i++) { // 从索引 1 开始,因为 0 是背景噪音
 |      | ||||||
|  |     let customCategoriesWithSamples = 0; | ||||||
|  |     // 遍历内部 labels 数组,检查每个自定义类别是否有样本
 | ||||||
|  |     // 从索引 1 开始,因为 0 是背景噪音
 | ||||||
|  |     for (let i = 1; i < labels.length; i++) {  | ||||||
|         const customLabel = labels[i]; |         const customLabel = labels[i]; | ||||||
|         if ((exampleCounts[customLabel] || 0) > 0) { |         if ((exampleCounts[customLabel] || 0) > 0) { | ||||||
|             customCategoriesReady++; |             customCategoriesWithSamples++; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // 必须有背景噪音样本,并且至少一个自定义类别有样本
 |     // 训练就绪的条件:必须有背景噪音样本,并且至少一个自定义类别有样本
 | ||||||
|     return backgroundNoiseReady && customCategoriesReady >= 1; |     const canTrain = backgroundNoiseReady && customCategoriesWithSamples >= 1; | ||||||
|  | 
 | ||||||
|  |     // 更详细的训练就绪判断,包括样本数量建议
 | ||||||
|  |     const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5; | ||||||
|  |     let trainingDetails = []; | ||||||
|  | 
 | ||||||
|  |     if (!backgroundNoiseReady) { | ||||||
|  |         trainingDetails.push('需要录制背景噪音样本。'); | ||||||
|  |     } else if (backgroundNoiseCount < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) { | ||||||
|  |         trainingDetails.push(`建议背景噪音至少 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本 (当前: ${backgroundNoiseCount})。`); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (customCategoriesWithSamples === 0) { | ||||||
|  |         trainingDetails.push('需要至少一个自定义类别并录制样本。'); | ||||||
|  |     } else { | ||||||
|  |         let insufficientCustomSamples = []; | ||||||
|  |         for (let i = 1; i < labels.length; i++) { | ||||||
|  |             const customLabel = labels[i]; | ||||||
|  |             const count = exampleCounts[customLabel] || 0; | ||||||
|  |             if (count > 0 && count < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) { | ||||||
|  |                 insufficientCustomSamples.push(`类别 "${customLabel}" 样本不足 (当前: ${count},建议: ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING})。`); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         if (insufficientCustomSamples.length > 0) { | ||||||
|  |             trainingDetails = trainingDetails.concat(insufficientCustomSamples); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return {  | ||||||
|  |         ready: canTrain && trainingDetails.length === 0, // 只有完全满足条件才认为是 ready
 | ||||||
|  |         details: trainingDetails.join('\n') || '模型已准备好进行训练。', | ||||||
|  |         hasAnyExamples: hasAnyExamples // 额外返回是否有任何样本,供 UI 判断是否启用导出按钮
 | ||||||
|  |     }; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 模型训练函数 |  * 训练模型。 | ||||||
|  * @param {Object} trainingConfig - 训练配置参数 |  * @param {Object} [config] - 训练配置。 | ||||||
|  * @returns {Promise<void>} |  * @param {number} [config.epochs=50] - 训练迭代次数。 | ||||||
|  |  * @param {number} [config.batchSize=16] - 批大小。 | ||||||
|  |  * @param {number} [config.validationSplit=0.1] - 验证集比例。 | ||||||
|  |  * @param {boolean} [config.shuffle=true] - 是否打乱数据。 | ||||||
|  |  * @param {Function} [config.onEpochEnd] - (可选) 每个 epoch 结束时的回调函数,参数为 (epoch, logs)。 | ||||||
|  |  * @returns {Promise<void>} resolve 表示训练完成,reject 表示训练失败。 | ||||||
|  */ |  */ | ||||||
| async function trainModel(trainingConfig = {}) { | async function trainModel(config = {}) { | ||||||
|     const exampleCounts = transferRecognizer.countExamples(); |     // 再次进行训练前检查,防止外部在不满足条件时调用
 | ||||||
|      |     const readiness = checkTrainingReadiness(); | ||||||
|     let totalExamples = 0; |     if (!readiness.ready) { | ||||||
|     let validClasses = 0; |         return Promise.reject(new Error(`模型未准备好训练:\n${readiness.details}`)); | ||||||
|     const MIN_SAMPLES_PER_CLASS_FOR_TRAINING = 5;  |  | ||||||
|     let allClassesHaveEnoughSamples = true;  |  | ||||||
|      |  | ||||||
|     // 统计所有类别的有效样本数,并检查每个类别是否达到最低要求
 |  | ||||||
|     for (const labelName of labels) { |  | ||||||
|         if (exampleCounts[labelName] && exampleCounts[labelName] > 0) { |  | ||||||
|             totalExamples += exampleCounts[labelName]; |  | ||||||
|             validClasses++; |  | ||||||
|             if (exampleCounts[labelName] < MIN_SAMPLES_PER_CLASS_FOR_TRAINING) { |  | ||||||
|                 allClassesHaveEnoughSamples = false;  |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     if (validClasses < 2) {  |  | ||||||
|         return Promise.reject(new Error(`训练需要至少 "背景噪音" 和另一个自定义类别。当前只有 ${validClasses} 个有效类别。`)); |  | ||||||
|     } |  | ||||||
|      |  | ||||||
|     if (!allClassesHaveEnoughSamples) { |  | ||||||
|         return Promise.reject(new Error(`请确保每个类别至少收集了 ${MIN_SAMPLES_PER_CLASS_FOR_TRAINING} 个样本。`)); |  | ||||||
|     } |  | ||||||
|      |  | ||||||
|     if (totalExamples === 0) { |  | ||||||
|         return Promise.reject(new Error('没有收集到任何训练样本')); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     const defaultConfig = { |     const defaultConfig = { | ||||||
| @ -149,116 +215,225 @@ async function trainModel(trainingConfig = {}) { | |||||||
|         batchSize: 16,  |         batchSize: 16,  | ||||||
|         validationSplit: 0.1,  |         validationSplit: 0.1,  | ||||||
|         shuffle: true,  |         shuffle: true,  | ||||||
|         yieldEvery: 'epoch' |         // 外部传入的回调函数将在 train 内部被调用
 | ||||||
|  |         callbacks: { | ||||||
|  |             onEpochEnd: (epoch, logs) => { | ||||||
|  |                 if (config.onEpochEnd && typeof config.onEpochEnd === 'function') { | ||||||
|  |                     config.onEpochEnd(epoch, logs); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|     }; |     }; | ||||||
|      |      | ||||||
|     const config = Object.assign(defaultConfig, trainingConfig); |     // 合并默认配置和用户指定的配置
 | ||||||
|  |     const trainingConfig = { ...defaultConfig, ...config }; | ||||||
| 
 | 
 | ||||||
|     try { |     try { | ||||||
|         await transferRecognizer.train(config); |         await transferRecognizer.train(trainingConfig); | ||||||
|         isModelTrainedFlag = true; |         isModelTrainedFlag = true; | ||||||
|  |         console.log('AudioClassifier: 模型训练完成。'); | ||||||
|         return Promise.resolve(); |         return Promise.resolve(); | ||||||
|     } catch (error) { |     } catch (error) { | ||||||
|         isModelTrainedFlag = false; |         isModelTrainedFlag = false; | ||||||
|         return Promise.reject(error); |         console.error('AudioClassifier: 模型训练失败:', error); | ||||||
|  |         return Promise.reject(new Error(`模型训练失败: ${error.message}. 请确保有足够的样本且类别均衡。`)); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 开始实时预测 |  * 开始实时预测。 | ||||||
|  * @param {Function} onPrediction - 预测结果回调函数 |  * @param {Function} onPrediction - 预测结果回调函数,参数为 {label: string, score: number, scores: number[], labels: string[]}。 | ||||||
|  * @param {Object} listenOptions - 监听选项 |  * @param {Object} [listenOptions] - 监听选项,直接传递给 transferRecognizer.listen()。 | ||||||
|  * @returns {Promise<Function>} 停止预测的函数 |  * @returns {Promise<Function>} resolve 时返回停止预测的函数,reject 时返回错误。 | ||||||
|  */ |  */ | ||||||
| async function startPrediction(onPrediction, listenOptions = {}) { | async function startPrediction(onPrediction, listenOptions = {}) { | ||||||
|     if (isPredicting) { |     if (isPredicting) { | ||||||
|         return Promise.reject(new Error('识别已经在进行中')); |         return Promise.reject(new Error('核心模块:识别已经在进行中。')); | ||||||
|     } |     } | ||||||
|      |      | ||||||
|     if (!isModelTrainedFlag) {  |     if (!isModelTrainedFlag) {  | ||||||
|         return Promise.reject(new Error('模型尚未训练完成')); |         return Promise.reject(new Error('核心模块:模型尚未训练完成,请先训练模型。')); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     isPredicting = true; |     isPredicting = true; | ||||||
|      |      | ||||||
|     const defaultOptions = { |     const defaultOptions = { | ||||||
|         includeEmbedding: true,  |         includeEmbedding: true,  | ||||||
|         probabilityThreshold: 0.75,  |         probabilityThreshold: 0.75, // 预测置信度阈值
 | ||||||
|         suppressionTimeMillis: 300,  |         suppressionTimeMillis: 300, // 抑制时间,避免连续触发
 | ||||||
|         overlapFactor: 0.50,  |         overlapFactor: 0.50, // 帧重叠因子
 | ||||||
|     }; |     }; | ||||||
|      |      | ||||||
|     const options = Object.assign(defaultOptions, listenOptions); |     const options = { ...defaultOptions, ...listenOptions }; | ||||||
| 
 | 
 | ||||||
|  |     try { | ||||||
|         predictionStopFunction = await transferRecognizer.listen(result => { |         predictionStopFunction = await transferRecognizer.listen(result => { | ||||||
|         if (!isPredicting) return; |             if (!isPredicting) return; // 确保在停止后不再处理结果
 | ||||||
| 
 | 
 | ||||||
|         const classLabels = transferRecognizer.wordLabels();  |             const classLabels = transferRecognizer.wordLabels(); // 获取模型内部的标签列表
 | ||||||
|         const scores = result.scores;  |             const scores = result.scores; // 预测分数
 | ||||||
|             const maxScore = Math.max(...scores); |             const maxScore = Math.max(...scores); | ||||||
|             const predictedIndex = scores.indexOf(maxScore); |             const predictedIndex = scores.indexOf(maxScore); | ||||||
|              |              | ||||||
|             let predictedLabel = classLabels[predictedIndex]; |             let predictedLabel = classLabels[predictedIndex]; | ||||||
|         // 如果预测结果是内部的背景噪音标签,转换成用户友好的显示
 |             // 对背景噪音标签进行友好化处理
 | ||||||
|             if (predictedLabel === BACKGROUND_NOISE_LABEL) { |             if (predictedLabel === BACKGROUND_NOISE_LABEL) { | ||||||
|                 predictedLabel = '背景噪音';  |                 predictedLabel = '背景噪音';  | ||||||
|             } |             } | ||||||
|              |              | ||||||
|         // 调用回调函数返回预测结果
 |             // 调用外部传入的回调函数,传递结构化的预测结果
 | ||||||
|             if (typeof onPrediction === 'function') { |             if (typeof onPrediction === 'function') { | ||||||
|                 onPrediction({ |                 onPrediction({ | ||||||
|                     label: predictedLabel, |                     label: predictedLabel, | ||||||
|                     score: maxScore, |                     score: maxScore, | ||||||
|                 scores: scores, |                     scores: scores, // 原始分数数组
 | ||||||
|                 labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) |                     predictedIndex: predictedIndex, // 预测索引
 | ||||||
|  |                     labels: classLabels.map(label => label === BACKGROUND_NOISE_LABEL ? '背景噪音' : label) // 友好化的标签列表
 | ||||||
|                 }); |                 }); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|         }, options); |         }, options); | ||||||
| 
 | 
 | ||||||
|     return Promise.resolve(predictionStopFunction); |         console.log('AudioClassifier: 开始识别。'); | ||||||
|  |         return Promise.resolve(predictionStopFunction); // 返回停止函数给调用者
 | ||||||
|  |     } catch (error) { | ||||||
|  |         isPredicting = false; | ||||||
|  |         console.error('AudioClassifier: 开始识别失败:', error); | ||||||
|  |         return Promise.reject(new Error(`开始识别失败: ${error.message}`)); | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 停止实时预测 |  * 停止实时预测。 | ||||||
|  */ |  */ | ||||||
| function stopPrediction() { | function stopPrediction() { | ||||||
|     if (isPredicting) { |     if (isPredicting) { | ||||||
|         if (typeof predictionStopFunction === 'function') { |         if (typeof predictionStopFunction === 'function') { | ||||||
|             predictionStopFunction(); |             predictionStopFunction(); // 调用 SpeechCommands 库返回的停止函数
 | ||||||
|             predictionStopFunction = null; |             predictionStopFunction = null; | ||||||
|  |         } else { | ||||||
|  |             console.warn('AudioClassifier: 停止预测函数未定义或不是函数。'); | ||||||
|         } |         } | ||||||
|          |  | ||||||
|         isPredicting = false; |         isPredicting = false; | ||||||
|  |         console.log('AudioClassifier: 已停止识别。'); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 获取各类别样本数量 |  * 获取当前收集到的各类别样本数量。 | ||||||
|  * @returns {Object} 各类别样本数量统计 |  * @returns {Object} 键为类别名称,值为样本数量。 | ||||||
|  */ |  */ | ||||||
| function getExampleCounts() { | function getExampleCounts() { | ||||||
|  |     try { | ||||||
|         return transferRecognizer.countExamples(); |         return transferRecognizer.countExamples(); | ||||||
|  |     } catch (error) { | ||||||
|  |         // 捕获无样本时的错误,返回空对象
 | ||||||
|  |         return {};  | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * 获取模型是否已训练的状态 |  * 获取模型是否已训练的状态。 | ||||||
|  * @returns {boolean} 模型是否已训练 |  * @returns {boolean} 如果模型已训练则返回 true,否则 false。 | ||||||
|  */ |  */ | ||||||
| function isModelTrained() { | function isModelTrained() { | ||||||
|     return isModelTrainedFlag; |     return isModelTrainedFlag; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // 导出公共接口
 | /** | ||||||
|  |  * 获取当前模块中所有标签的数组(包括背景噪音和自定义标签)。 | ||||||
|  |  * @returns {string[]} 标签数组。 | ||||||
|  |  */ | ||||||
|  | function getLabels() { | ||||||
|  |     return [...labels]; // 返回副本,避免外部直接修改内部状态
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * 序列化所有收集到的样本数据到 ArrayBuffer。 | ||||||
|  |  * @returns {Promise<ArrayBuffer>} 包含样本数据的 ArrayBuffer。 | ||||||
|  |  */ | ||||||
|  | async function serializeExamples() { | ||||||
|  |     try { | ||||||
|  |         const serialized = await transferRecognizer.serializeExamples(); | ||||||
|  |         console.log('AudioClassifier: 样本数据序列化成功。'); | ||||||
|  |         return Promise.resolve(serialized); | ||||||
|  |     } catch (error) { | ||||||
|  |         console.error('AudioClassifier: 序列化样本失败:', error); | ||||||
|  |         return Promise.reject(new Error(`序列化样本失败: ${error.message}`)); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * 从 ArrayBuffer 导入样本数据。 | ||||||
|  |  * 导入后会更新内部的 labels 数组和模型状态。 | ||||||
|  |  * @param {ArrayBuffer} dataBuffer - 包含样本数据的 ArrayBuffer。 | ||||||
|  |  * @returns {Promise<Object>} resolve 时返回导入后的样本总数统计,reject 时返回错误。 | ||||||
|  |  */ | ||||||
|  | async function loadExamples(dataBuffer) { | ||||||
|  |     if (!dataBuffer instanceof ArrayBuffer) { | ||||||
|  |         return Promise.reject(new Error('核心模块:导入数据必须是 ArrayBuffer 类型。')); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // 首先清空当前的 transferRecognizer 中的所有样本,再加载新的。
 | ||||||
|  |     // 这等同于创建一个新的 transferRecognizer 实例,并同步其内部状态。
 | ||||||
|  |     // SpeechCommands库没有直接的`clearExamples`或`resetExamples`方法,
 | ||||||
|  |     // 最稳妥的方法是重新创建transferRecognizer并加载。
 | ||||||
|  |     // 但loadExamples本身就是设计来覆盖和更新现有样本的,
 | ||||||
|  |     // 考虑到性能,我们假设直接loadExamples即可。
 | ||||||
|  |     // 注意:loadExamples会清除原有数据。
 | ||||||
|  | 
 | ||||||
|  |     try { | ||||||
|  |         await transferRecognizer.loadExamples(dataBuffer); | ||||||
|  | 
 | ||||||
|  |         // 成功加载后,需要更新内部的 `labels` 数组,以匹配导入的数据。
 | ||||||
|  |         // `transferRecognizer.wordLabels()` 返回的是模型内部的实际标签。
 | ||||||
|  |         const loadedLabels = transferRecognizer.wordLabels(); | ||||||
|  |         // 确保 BACKGROUND_NOISE_LABEL 还在第一个位置,并过滤掉 '_version_', '_numExamples_' 等内部标签
 | ||||||
|  |         labels = loadedLabels.filter(label => label !== '_version_' && label !== '_numExamples_'); | ||||||
|  | 
 | ||||||
|  |         if (!labels.includes(BACKGROUND_NOISE_LABEL)) { | ||||||
|  |             // 如果导入的数据中没有背景噪音(不常见,但作为健壮性处理)
 | ||||||
|  |             labels.unshift(BACKGROUND_NOISE_LABEL); | ||||||
|  |         } else { | ||||||
|  |             // 确保背景噪音在第一位
 | ||||||
|  |             const bgIndex = labels.indexOf(BACKGROUND_NOISE_LABEL); | ||||||
|  |             if (bgIndex > 0) { | ||||||
|  |                 const bgLabel = labels.splice(bgIndex, 1); | ||||||
|  |                 labels.unshift(bgLabel[0]); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |          | ||||||
|  |         // 重置模型的训练状态,因为样本已更改,需要重新训练
 | ||||||
|  |         isModelTrainedFlag = false;  | ||||||
|  | 
 | ||||||
|  |         const exampleCounts = transferRecognizer.countExamples(); | ||||||
|  |         console.log('AudioClassifier: 样本数据导入成功。', exampleCounts); | ||||||
|  |         return Promise.resolve(exampleCounts); | ||||||
|  |     } catch (error) { | ||||||
|  |         console.error('AudioClassifier: 加载样本失败:', error); | ||||||
|  |         return Promise.reject(new Error(`加载样本失败: ${error.message}. 请确保文件有效且未损坏。`)); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | // ======================= 导出的公共接口 =======================
 | ||||||
|  | // 这里使用 CommonJS 风格的 module.exports,因为这通常用于 Node.js 环境或 Webpack/Rollup 等打包工具。
 | ||||||
|  | // 如果你是在浏览器中直接作为 <script> 标签引入(没有打包工具),
 | ||||||
|  | // 并且希望它在全局作用域下可用,可以使用 `window.AudioClassifier = {...}`。
 | ||||||
|  | // 在这种情况下,我将提供 `window.AudioClassifier` 的方式。
 | ||||||
|  | 
 | ||||||
|  | // 为了 Web 浏览器环境,将核心功能挂载到全局对象上
 | ||||||
| window.AudioClassifier = { | window.AudioClassifier = { | ||||||
|     init, |     init,                           // 初始化模型
 | ||||||
|     recordMultipleExamples, |     recordMultipleExamples,         // 录制样本
 | ||||||
|     addCustomCategory, |     addCustomCategory,              // 添加自定义类别
 | ||||||
|     checkTrainingReadiness, |     checkTrainingReadiness,         // 检查模型是否可训练
 | ||||||
|     trainModel, |     trainModel,                     // 训练模型
 | ||||||
|     startPrediction, |     startPrediction,                // 开始预测
 | ||||||
|     stopPrediction, |     stopPrediction,                 // 停止预测
 | ||||||
|     getExampleCounts, |     getExampleCounts,               // 获取所有类别样本数量
 | ||||||
|     isModelTrained, |     isModelTrained,                 // 检查模型是否已训练
 | ||||||
|     labels |     getLabels,                      // 获取当前所有标签
 | ||||||
|  |     serializeExamples,              // 导出样本数据
 | ||||||
|  |     loadExamples                    // 导入样本数据
 | ||||||
| }; | }; | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user