// Custom image classifier app: MobileNet embeddings feed a small trainable dense classifier.
class CustomImageClassifier {
  /**
   * Sets up empty dataset/training state and starts loading MobileNet (see init()).
   * Relies on the global `tf` and `mobilenet` objects and on DOM elements looked up by id.
   */
  constructor() {
    this.mobilenet = null;
    this.model = null;
    this.classNames = [];
    this.trainingDataInputs = []; // feature tensors, one per sample
    this.trainingDataOutputs = []; // class index per sample
    // Source-image id per sample: an image and its augmented copy share an id so the
    // train/validation split can keep them on the same side.
    this.trainingDataGroups = [];
    this.isTraining = false;
    this.webcamStream = null;
    this.isPredicting = false;
    this.lossHistory = [];
    this.accuracyHistory = [];
    this.valLossHistory = [];
    this.valAccuracyHistory = [];
    this.dataMean = null; // per-feature mean of the training split
    this.dataStd = null; // per-feature standard deviation of the training split
    this.temperature = 1.0; // temperature-scaling parameter
    this.bestValLoss = Infinity;
    this.patienceCounter = 0;
    this.bestWeights = null;
    this.init();
  }

  /**
   * Loads MobileNet v2 (alpha 1.0) and, on success, wires up the UI event listeners.
   * Failures are reported in the `dataStatus` element rather than thrown.
   */
  async init() {
    this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
    try {
      // Pretrained MobileNet, used here only as a feature extractor.
      this.mobilenet = await mobilenet.load({ version: 2, alpha: 1.0 });
      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  /** Wires the three class file inputs and all action buttons to their handlers. */
  setupEventListeners() {
    // File-upload listeners: input ids map to 1-based class slots.
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index + 1);
      });
    });

    // Button listeners
    document.getElementById('addDataBtn').addEventListener('click', () => this.addToDataset());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('trainBtn').addEventListener('click', () => this.trainModel());
    document.getElementById('stopBtn').addEventListener('click', () => this.stopTraining());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
    document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
    document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
  }

  /**
   * Renders thumbnails for the files chosen in one class's file input and updates its count label.
   * @param {Event} event - `change` event from the file input.
   * @param {number} classIndex - 1-based class slot (1..3).
   */
  handleImageUpload(event, classIndex) {
    const files = event.target.files;
    const previewContainer = document.getElementById(`class${classIndex}Preview`);
    const countElement = document.getElementById(`class${classIndex}Count`);

    previewContainer.innerHTML = '';
    Array.from(files).forEach(file => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    });

    countElement.textContent = `${files.length} 张图片`;
  }

  /**
   * Builds the training set from the class name / file inputs (2 or 3 classes).
   *
   * Any existing dataset is released first; the form inputs are left untouched. Every
   * image is run through MobileNet once. When the augmentation checkbox is ticked, one
   * additional noise-augmented copy is added per image. Progress and a per-class summary
   * are shown in `dataStatus`. Individual image failures are logged and skipped.
   */
  async addToDataset() {
    const classes = [];
    const imagesByClass = [];

    // Collect every class that has both a name and at least one image.
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;
      if (className && files.length > 0) {
        classes.push(className);
        // Copy the FileList into a plain array so later changes to the <input>
        // cannot affect the files being processed.
        imagesByClass.push(Array.from(files));
      }
    }

    if (classes.length < 2) {
      this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
      return;
    }

    // Replace any previous dataset. Only tensors/state are released here:
    // clearDataset() would also reset the file inputs the user just filled in.
    if (this.trainingDataInputs.length > 0) {
      this.disposeTrainingData();
    }

    this.classNames = classes;

    // Total image count and per-class counts.
    const classImageCounts = imagesByClass.map((files, idx) => ({
      className: classes[idx],
      count: files.length
    }));
    const totalToProcess = classImageCounts.reduce((sum, c) => sum + c.count, 0);

    // Warn (but continue) when the largest class has more than 3x the images of the smallest.
    const minImages = Math.min(...classImageCounts.map(c => c.count));
    const maxImages = Math.max(...classImageCounts.map(c => c.count));
    if (maxImages > minImages * 3) {
      console.warn('警告:数据集不平衡!');
      classImageCounts.forEach(c => {
        console.log(`${c.className}: ${c.count} 张图片`);
      });
      this.showStatus('dataStatus', 'warning', `警告:数据集不平衡!建议每个类别的图片数量相近。`);
      await new Promise(resolve => setTimeout(resolve, 2000)); // keep the warning visible for 2 s
    }

    this.showStatus('dataStatus', 'info', `正在处理 ${totalToProcess} 张图片...`);

    let totalImages = 0;
    let processedImages = 0;
    let nextGroupId = 0;
    const classSampleCounts = new Array(classes.length).fill(0);

    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      const files = imagesByClass[classIndex];
      for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
        const file = files[fileIndex];
        try {
          console.log(`处理图片 ${fileIndex + 1}/${files.length} - ${classes[classIndex]}`);
          const img = await this.loadImage(file);

          // Skip images that failed to decode or have no size.
          if (!img || !img.width || !img.height) {
            console.error('图片加载失败或无效:', file.name);
            processedImages++;
            continue;
          }

          const groupId = nextGroupId++;

          // Features of the original image.
          const features = await this.extractFeatures(img, false);
          this.trainingDataInputs.push(features);
          this.trainingDataOutputs.push(classIndex);
          this.trainingDataGroups.push(groupId);

          // Optionally add one augmented copy, tagged with the same group id.
          const augmentCheck = document.getElementById('dataAugmentationCheck');
          if (augmentCheck && augmentCheck.checked) {
            try {
              const augFeatures = await this.extractFeatures(img, true);
              this.trainingDataInputs.push(augFeatures);
              this.trainingDataOutputs.push(classIndex);
              this.trainingDataGroups.push(groupId);
              totalImages++; // augmented samples count towards the total
              classSampleCounts[classIndex]++;
            } catch (augError) {
              console.warn('数据增强失败,跳过增强样本:', augError);
            }
          }

          totalImages++;
          processedImages++;
          classSampleCounts[classIndex]++;

          const progressPercent = Math.round((processedImages / totalToProcess) * 100);
          this.showStatus('dataStatus', 'info', `处理中... ${processedImages}/${totalToProcess} (${progressPercent}%) - ${classes[classIndex]}: ${classSampleCounts[classIndex]} 张`);

          // Only call remove() when the object actually has it (HTMLImageElement).
          if (img && img.remove) {
            img.remove();
          }
        } catch (error) {
          console.error(`处理图片失败 (${file.name}):`, error);
          processedImages++;
        }
      }
    }

    // Final summary.
    let statsMessage = `成功添加 ${totalImages} 张图片到数据集!\n`;
    classes.forEach((className, idx) => {
      statsMessage += `${className}: ${classSampleCounts[idx]} 张\n`;
    });
    this.showStatus('dataStatus', 'success', statsMessage);

    console.log('数据集统计:');
    console.log('总样本数:', totalImages);
    classes.forEach((className, idx) => {
      // Avoid NaN% when every image failed (totalImages === 0).
      const share = totalImages > 0
        ? (classSampleCounts[idx] / totalImages * 100).toFixed(1)
        : '0.0';
      console.log(`${className}: ${classSampleCounts[idx]} 张 (${share}%)`);
    });
  }

  /**
   * Decodes an image File into a loaded HTMLImageElement via a data URL.
   * @param {File} file
   * @returns {Promise<HTMLImageElement>} rejects if reading or decoding fails.
   */
  async loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.onerror = reject;
        img.src = e.target.result;
      };
      reader.onerror = reject;
      reader.readAsDataURL(file);
    });
  }

  /**
   * Runs MobileNet on an image-like source and returns the resulting feature tensor.
   * NOTE(review): 'conv_pw_13_relu' is passed as infer()'s second argument; confirm
   * against the mobilenet package which output that selects.
   * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement} img - source pixels.
   * @param {boolean} [augment=false] - when true and the augmentation checkbox is
   *   ticked, features are computed from a noise-augmented copy instead.
   * @returns {Promise<tf.Tensor>} the caller owns (and must dispose) the tensor. A 4-D
   *   result with a leading dimension of 1 has that dimension squeezed away.
   */
  async extractFeatures(img, augment = false) {
    try {
      const augmentCheck = document.getElementById('dataAugmentationCheck');
      const source = augment && augmentCheck && augmentCheck.checked
        ? await this.augmentImage(img)
        : img;

      // tidy() releases the un-squeezed activation; only the returned tensor survives.
      return tf.tidy(() => {
        const activation = this.mobilenet.infer(source, 'conv_pw_13_relu');
        const shape = activation.shape;
        if (shape.length === 4 && shape[0] === 1) {
          return activation.squeeze([0]);
        }
        return activation;
      });
    } catch (error) {
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  /**
   * Returns a canvas holding a copy of `img` with Gaussian pixel noise (std 5 on the
   * 0-255 scale) added and clipped. Falls back to the original `img` if anything fails.
   * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement} img
   * @returns {Promise<HTMLCanvasElement|*>}
   */
  async augmentImage(img) {
    let normalized = null;
    try {
      // Intermediate tensors are released by tidy(); only the [0, 1] result is kept.
      normalized = tf.tidy(() => {
        const imgTensor = tf.browser.fromPixels(img);
        const noise = tf.randomNormal(imgTensor.shape, 0, 5); // mild noise
        return imgTensor.add(noise).clipByValue(0, 255).div(255);
      });

      const canvas = document.createElement('canvas');
      canvas.width = img.width || img.videoWidth || 224;
      canvas.height = img.height || img.videoHeight || 224;
      await tf.browser.toPixels(normalized, canvas);
      return canvas;
    } catch (error) {
      console.error('数据增强失败:', error);
      // Fall back to the original image.
      return img;
    } finally {
      if (normalized) {
        normalized.dispose();
      }
    }
  }

  /**
   * Disposes all sample tensors and normalisation statistics and resets the dataset
   * arrays and class names. Does not touch the DOM.
   */
  disposeTrainingData() {
    this.trainingDataInputs.forEach(tensor => tensor.dispose());
    this.trainingDataInputs = [];
    this.trainingDataOutputs = [];
    this.trainingDataGroups = [];
    this.classNames = [];

    if (this.dataMean) {
      this.dataMean.dispose();
      this.dataMean = null;
    }
    if (this.dataStd) {
      this.dataStd.dispose();
      this.dataStd = null;
    }
  }

  /** Releases the dataset and resets the three class file inputs, previews and counters. */
  clearDataset() {
    this.disposeTrainingData();

    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Preview`).innerHTML = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
    }

    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  /**
   * Builds and compiles a new classifier head (flatten -> 128 -> 64 -> 32 -> softmax)
   * from the hyperparameters in the UI, disposing any previous model first.
   */
  createCustomModel() {
    // Release the previous model's weights before replacing it.
    if (this.model) {
      this.model.dispose();
      this.model = null;
    }

    const numClasses = this.classNames.length;

    // Hyperparameters; the L2 and learning-rate inputs hold base-10 exponents.
    const dropoutRate = parseFloat(document.getElementById('dropoutRate').value);
    const l2Reg = Math.pow(10, parseFloat(document.getElementById('l2Regularization').value));
    const learningRate = Math.pow(10, parseFloat(document.getElementById('learningRate').value));

    // Shape of a single sample's feature tensor.
    const inputShape = this.trainingDataInputs[0].shape;
    console.log('Input shape:', inputShape);
    console.log('Hyperparameters:', { learningRate, dropoutRate, l2Reg });

    this.model = tf.sequential({
      layers: [
        tf.layers.flatten({ inputShape: inputShape }),
        // Hidden layer 1
        tf.layers.dense({
          units: 128,
          activation: 'relu',
          kernelInitializer: 'heNormal',
          kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
          biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
        }),
        tf.layers.batchNormalization(),
        tf.layers.dropout({ rate: dropoutRate, seed: 42 }), // fixed seed for repeatable dropout
        // Hidden layer 2
        tf.layers.dense({
          units: 64,
          activation: 'relu',
          kernelInitializer: 'heNormal',
          kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
          biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
        }),
        tf.layers.batchNormalization(),
        tf.layers.dropout({ rate: dropoutRate, seed: 43 }),
        // Hidden layer 3
        tf.layers.dense({
          units: 32,
          activation: 'relu',
          kernelInitializer: 'heNormal',
          kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
          biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
        }),
        tf.layers.dropout({ rate: dropoutRate * 0.7, seed: 44 }),
        // Output layer: the model emits probabilities, not logits.
        tf.layers.dense({
          units: numClasses,
          activation: 'softmax',
          kernelInitializer: 'glorotNormal'
        })
      ]
    });

    this.model.compile({
      optimizer: tf.train.adam(learningRate),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });

    console.log('Model created successfully');
    this.model.summary();
  }

  /**
   * Trains a fresh classifier head on the collected features.
   *
   * Steps:
   * 1. Compute inverse-frequency class weights.
   * 2. Split samples about 80/20 into train/validation, grouped by source image.
   * 3. Normalise with statistics fitted on the training split only. These are stored
   *    in this.dataMean / this.dataStd for inference.
   * 4. Fit, with optional early stopping that restores the best-validation-loss weights.
   *
   * Progress, metrics and errors are reported through the DOM; nothing is thrown.
   */
  async trainModel() {
    if (this.trainingDataInputs.length === 0) {
      this.showStatus('trainingStatus', 'error', '请先添加训练数据!');
      return;
    }

    // Per-class sample counts (logged and used for class weights).
    const classCounts = {};
    this.trainingDataOutputs.forEach(label => {
      classCounts[label] = (classCounts[label] || 0) + 1;
    });
    console.log('每个类别的样本数量:', classCounts);
    console.log('类别名称:', this.classNames);

    // Inverse-frequency weights: classes with fewer samples weigh more.
    const totalSamples = this.trainingDataOutputs.length;
    const numClasses = this.classNames.length;
    const classWeights = {};
    for (let i = 0; i < numClasses; i++) {
      const classCount = classCounts[i] || 1;
      classWeights[i] = totalSamples / (numClasses * classCount);
    }
    console.log('类别权重:', classWeights);

    // Split first so an impossible split fails before anything is built.
    const split = this.splitTrainValIndices();
    if (!split) {
      this.showStatus('trainingStatus', 'error', '训练数据不足:至少需要两张有效图片才能划分训练集和验证集!');
      return;
    }
    const { trainIndices, valIndices } = split;

    this.showStatus('trainingStatus', 'info', '准备训练模型...');

    let tensors = null;
    try {
      this.createCustomModel();
      // Local handle so the callbacks keep addressing the model being trained even if
      // this.model is replaced (e.g. by loadModel) while fit() is running.
      const model = this.model;

      tensors = this.buildNormalizedTensors(trainIndices, valIndices, numClasses);
      this.logDatasetSplit(trainIndices, valIndices);

      // Show the progress bar and lock the train button.
      document.getElementById('trainingProgress').classList.remove('hidden');
      document.getElementById('trainBtn').disabled = true;
      document.getElementById('stopBtn').disabled = false;

      this.isTraining = true;
      this.lossHistory = [];
      this.accuracyHistory = [];
      this.valLossHistory = [];
      this.valAccuracyHistory = [];
      this.bestValLoss = Infinity;
      this.patienceCounter = 0;
      if (this.bestWeights) {
        this.bestWeights.forEach(w => w.dispose());
      }
      this.bestWeights = null;

      // Training hyperparameters from the UI.
      const epochs = parseInt(document.getElementById('epochs').value, 10);
      const batchSize = parseInt(document.getElementById('batchSize').value, 10);
      const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked;
      const patience = parseInt(document.getElementById('patience').value, 10);
      this.temperature = parseFloat(document.getElementById('temperature').value);

      console.log('Training parameters:', {
        epochs,
        batchSize,
        enableEarlyStopping,
        patience,
        temperature: this.temperature
      });

      const startTime = Date.now();

      await model.fit(tensors.xsTrainNorm, tensors.ysTrain, {
        epochs,
        batchSize,
        validationData: [tensors.xsValNorm, tensors.ysVal],
        shuffle: true,
        classWeight: classWeights,
        callbacks: {
          onEpochEnd: (epoch, logs) => {
            const progress = ((epoch + 1) / epochs) * 100;
            const progressFill = document.getElementById('progressFill');
            progressFill.style.width = `${progress}%`;
            progressFill.textContent = `${Math.round(progress)}%`;

            this.lossHistory.push(logs.loss);
            this.accuracyHistory.push(logs.acc);
            this.valLossHistory.push(logs.val_loss);
            this.valAccuracyHistory.push(logs.val_acc);

            if (enableEarlyStopping) {
              this.checkEarlyStopping(model, epoch, logs, patience);
            }

            // Remaining-time estimate from the average epoch duration so far.
            const elapsedTime = Date.now() - startTime;
            const timePerEpoch = elapsedTime / (epoch + 1);
            const remainingTime = timePerEpoch * (epochs - epoch - 1);
            const remainingSeconds = Math.round(remainingTime / 1000);

            // Overfitting heuristics.
            const overfitGap = logs.acc - logs.val_acc;
            if (overfitGap > 0.2) {
              console.warn(`⚠️ 可能过拟合!训练准确率: ${(logs.acc * 100).toFixed(1)}%, 验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`);
            }
            if (logs.acc >= 0.99) {
              console.warn('⚠️ 训练准确率达到100%,模型可能过拟合!考虑:');
              console.warn(' 1. 增加更多训练数据');
              console.warn(' 2. 增大Dropout率');
              console.warn(' 3. 增大L2正则化');
              console.warn(' 4. 减少模型复杂度');
            }

            this.updateMetrics({
              epoch: epoch + 1,
              totalEpochs: epochs,
              loss: logs.loss,
              accuracy: logs.acc,
              valLoss: logs.val_loss,
              valAccuracy: logs.val_acc,
              remainingTime: remainingSeconds
            });
            this.plotLossChart();

            const statusMsg = `训练中 - Epoch ${epoch + 1}/${epochs} | ` +
              `准确率: ${(logs.acc * 100).toFixed(1)}% | ` +
              `剩余时间: ${this.formatTime(remainingSeconds)}`;
            this.showStatus('trainingStatus', 'info', statusMsg);

            // Honour the stop button.
            if (!this.isTraining) {
              model.stopTraining = true;
            }
          },
          onBatchEnd: (batch, logs) => {
            if (batch % 5 === 0) {
              console.log(`Batch ${batch}: loss = ${logs.loss.toFixed(4)}`);
            }
          }
        }
      });

      // Restore the best-validation-loss weights captured by early stopping.
      if (enableEarlyStopping && this.bestWeights) {
        console.log('恢复最佳模型权重...');
        model.setWeights(this.bestWeights);
      }

      this.showStatus('trainingStatus', 'success', '模型训练完成!');
    } catch (error) {
      this.showStatus('trainingStatus', 'error', `训练失败: ${error.message}`);
    } finally {
      if (tensors) {
        tensors.xsTrainNorm.dispose();
        tensors.ysTrain.dispose();
        tensors.xsValNorm.dispose();
        tensors.ysVal.dispose();
      }
      // The weight snapshots are only needed until the restore above.
      if (this.bestWeights) {
        this.bestWeights.forEach(w => w.dispose());
        this.bestWeights = null;
      }
      document.getElementById('trainBtn').disabled = false;
      document.getElementById('stopBtn').disabled = true;
      this.isTraining = false;
    }
  }

  /**
   * Randomly splits sample indices about 80/20 into training and validation sets.
   * Samples that came from the same source image (original + augmented copy) are kept
   * on the same side, so augmented duplicates cannot leak into validation.
   * @returns {{trainIndices: number[], valIndices: number[]}|null} null when fewer than
   *   two source groups exist (no split with both sides non-empty is possible).
   */
  splitTrainValIndices() {
    const sampleCount = this.trainingDataInputs.length;
    // Fall back to one group per sample if the group ids are missing or out of sync.
    const groups = this.trainingDataGroups.length === sampleCount
      ? this.trainingDataGroups
      : Array.from({ length: sampleCount }, (_, i) => i);
    const uniqueGroups = [...new Set(groups)];
    if (uniqueGroups.length < 2) {
      return null;
    }

    const shuffledOrder = Array.from(tf.util.createShuffledIndices(uniqueGroups.length));
    // floor(0.8 * n) < n, so the validation side always keeps at least one group.
    const trainGroupCount = Math.floor(uniqueGroups.length * 0.8);
    const trainGroups = new Set(shuffledOrder.slice(0, trainGroupCount).map(k => uniqueGroups[k]));

    const trainIndices = [];
    const valIndices = [];
    groups.forEach((group, sampleIndex) => {
      if (trainGroups.has(group)) {
        trainIndices.push(sampleIndex);
      } else {
        valIndices.push(sampleIndex);
      }
    });
    return { trainIndices, valIndices };
  }

  /**
   * Builds normalised training/validation tensors. Mean and standard deviation are
   * fitted on the training split only and stored in this.dataMean / this.dataStd.
   * Any previous statistics are disposed first.
   * @param {number[]} trainIndices
   * @param {number[]} valIndices
   * @param {number} numClasses - one-hot depth.
   * @returns {{xsTrainNorm: tf.Tensor, ysTrain: tf.Tensor, xsValNorm: tf.Tensor, ysVal: tf.Tensor}}
   *   tensors owned (and to be disposed) by the caller.
   */
  buildNormalizedTensors(trainIndices, valIndices, numClasses) {
    if (this.dataMean) {
      this.dataMean.dispose();
      this.dataMean = null;
    }
    if (this.dataStd) {
      this.dataStd.dispose();
      this.dataStd = null;
    }

    // tidy() frees every intermediate; only the tensors in the returned object survive.
    const result = tf.tidy(() => {
      const xs = tf.stack(this.trainingDataInputs);
      const ys = tf.oneHot(tf.tensor1d(this.trainingDataOutputs, 'int32'), numClasses);
      const trainIdx = tf.tensor1d(trainIndices, 'int32');
      const valIdx = tf.tensor1d(valIndices, 'int32');

      const xsTrain = tf.gather(xs, trainIdx);
      const xsVal = tf.gather(xs, valIdx);
      const mean = xsTrain.mean([0], true);
      const std = xsTrain.sub(mean).square().mean([0], true).sqrt();
      // Epsilon keeps constant features from dividing by zero.
      const safeStd = std.add(1e-7);

      return {
        xsTrainNorm: xsTrain.sub(mean).div(safeStd),
        ysTrain: tf.gather(ys, trainIdx),
        xsValNorm: xsVal.sub(mean).div(safeStd),
        ysVal: tf.gather(ys, valIdx),
        mean,
        std
      };
    });

    this.dataMean = result.mean;
    this.dataStd = result.std;
    return {
      xsTrainNorm: result.xsTrainNorm,
      ysTrain: result.ysTrain,
      xsValNorm: result.xsValNorm,
      ysVal: result.ysVal
    };
  }

  /**
   * Logs the train/validation split, per-class distribution, sanity warnings and
   * tensor shapes. Call after buildNormalizedTensors() so the statistics exist.
   * @param {number[]} trainIndices
   * @param {number[]} valIndices
   */
  logDatasetSplit(trainIndices, valIndices) {
    const totalDataCount = trainIndices.length + valIndices.length;

    console.log('\n📊 === 数据集分析 ===');
    console.log(`训练集: ${trainIndices.length} 样本`);
    console.log(`验证集: ${valIndices.length} 样本`);
    console.log(`训练/验证比例: ${(trainIndices.length / totalDataCount * 100).toFixed(1)}% / ${(valIndices.length / totalDataCount * 100).toFixed(1)}%`);

    // Class distribution in each split.
    const trainLabels = trainIndices.map(i => this.trainingDataOutputs[i]);
    const valLabels = valIndices.map(i => this.trainingDataOutputs[i]);

    console.log('\n类别分布:');
    for (let i = 0; i < this.classNames.length; i++) {
      const trainCount = trainLabels.filter(l => l === i).length;
      const valCount = valLabels.filter(l => l === i).length;
      console.log(` ${this.classNames[i]}:`);
      console.log(` - 训练集: ${trainCount} (${(trainCount/trainIndices.length*100).toFixed(1)}%)`);
      console.log(` - 验证集: ${valCount} (${(valCount/valIndices.length*100).toFixed(1)}%)`);
    }

    if (valIndices.length < 5) {
      console.warn('⚠️ 验证集太小!建议至少5个样本');
    }
    if (valIndices.length < totalDataCount * 0.1) {
      console.warn('⚠️ 验证集比例太小!建议至少10%的数据用于验证');
    }

    console.log('===================\n');
    console.log('数据形状:', [totalDataCount, ...this.trainingDataInputs[0].shape]);
    console.log('均值形状:', this.dataMean.shape);
    console.log('标准差形状:', this.dataStd.shape);

    const uniqueLabels = [...new Set(this.trainingDataOutputs)];
    console.log('训练数据统计:');
    uniqueLabels.forEach(label => {
      const count = this.trainingDataOutputs.filter(l => l === label).length;
      console.log(` ${this.classNames[label]}: ${count} 张图片`);
    });
  }

  /**
   * Early-stopping bookkeeping for one finished epoch. It:
   * - tracks the best validation loss;
   * - snapshots the weights whenever the loss improves significantly (> 0.001 absolute
   *   and > 0.1 % relative);
   * - sets `model.stopTraining` once `patience` epochs pass without such an improvement.
   * A diagnosis is logged to the console.
   * @param {tf.LayersModel} model - the model being trained.
   * @param {number} epoch - zero-based epoch index.
   * @param {Object} logs - fit() epoch logs; reads loss, acc, val_loss, val_acc.
   * @param {number} patience - allowed epochs without significant improvement.
   */
  checkEarlyStopping(model, epoch, logs, patience) {
    const hasBaseline = Number.isFinite(this.bestValLoss);
    const improvement = this.bestValLoss - logs.val_loss;
    const improvementPercent = hasBaseline ? (improvement / this.bestValLoss) * 100 : Infinity;

    console.log(`\n=== Epoch ${epoch + 1} 早停分析 ===`);
    console.log(`当前验证损失: ${logs.val_loss.toFixed(4)}`);
    console.log(`最佳验证损失: ${this.bestValLoss.toFixed(4)}`);
    console.log(`改善: ${improvement.toFixed(4)} (${improvementPercent.toFixed(2)}%)`);
    console.log(`耐心计数器: ${this.patienceCounter}/${patience}`);

    // Without a baseline (bestValLoss still Infinity) the relative test would compute
    // Infinity / Infinity = NaN and never pass, so the first finite loss always counts.
    const improved = hasBaseline
      ? improvement > 0.001 && improvementPercent > 0.1
      : Number.isFinite(logs.val_loss);

    if (improved) {
      this.bestValLoss = logs.val_loss;
      this.patienceCounter = 0;
      // Snapshot the current weights, replacing the previous snapshot.
      if (this.bestWeights) {
        this.bestWeights.forEach(w => w.dispose());
      }
      this.bestWeights = model.getWeights().map(w => w.clone());
      console.log(`✅ 新的最佳验证损失! 重置耐心计数器`);
    } else {
      this.patienceCounter++;
      console.log(`⚠️ 没有显著改善,耐心计数器增加`);

      if (logs.val_loss > logs.loss * 1.5) {
        console.log(`📊 可能过拟合:验证损失 (${logs.val_loss.toFixed(4)}) >> 训练损失 (${logs.loss.toFixed(4)})`);
      }
      if (logs.val_acc < logs.acc - 0.1) {
        console.log(`📊 验证准确率低:验证 ${(logs.val_acc * 100).toFixed(1)}% vs 训练 ${(logs.acc * 100).toFixed(1)}%`);
      }

      if (this.patienceCounter >= patience) {
        console.log(`\n🛑 早停触发!原因分析:`);
        console.log(`- 连续 ${patience} 个 epochs 没有改善`);
        console.log(`- 最终训练损失: ${logs.loss.toFixed(4)}`);
        console.log(`- 最终验证损失: ${logs.val_loss.toFixed(4)}`);
        console.log(`- 最终训练准确率: ${(logs.acc * 100).toFixed(1)}%`);
        console.log(`- 最终验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`);

        if (logs.val_loss > logs.loss * 1.5) {
          console.log(`\n💡 建议:模型过拟合`);
          console.log(` - 增加 Dropout 率`);
          console.log(` - 增加 L2 正则化`);
          console.log(` - 减少模型复杂度`);
          console.log(` - 增加更多训练数据`);
        } else if (logs.val_acc < 0.5) {
          console.log(`\n💡 建议:模型欠拟合`);
          console.log(` - 增加训练轮数`);
          console.log(` - 增大学习率`);
          console.log(` - 增加模型复杂度`);
        }

        this.isTraining = false;
        model.stopTraining = true;
      }
    }
    console.log(`======================\n`);
  }

  /** Requests a stop; the running fit() stops at the end of the current epoch. */
  stopTraining() {
    this.isTraining = false;
    this.showStatus('trainingStatus', 'info', '正在停止训练...');
  }

  /**
   * Formats a duration for display (seconds, minutes+seconds, or hours+minutes).
   * @param {number} seconds - whole seconds.
   * @returns {string}
   */
  formatTime(seconds) {
    if (seconds < 60) {
      return `${seconds} 秒`;
    } else if (seconds < 3600) {
      const minutes = Math.floor(seconds / 60);
      const secs = seconds % 60;
      return `${minutes} 分 ${secs} 秒`;
    } else {
      const hours = Math.floor(seconds / 3600);
      const minutes = Math.floor((seconds % 3600) / 60);
      return `${hours} 小时 ${minutes} 分`;
    }
  }

  /**
   * Renders the latest epoch metrics into #metricsContainer.
   * @param {{epoch: number, totalEpochs: number, loss: number, accuracy: number,
   *   valLoss: number, valAccuracy: number, remainingTime?: number}} metrics
   */
  updateMetrics(metrics) {
    const container = document.getElementById('metricsContainer');
    container.innerHTML = `
Epoch
${metrics.epoch}/${metrics.totalEpochs}
损失
${metrics.loss.toFixed(4)}
准确率
${(metrics.accuracy * 100).toFixed(1)}%
验证损失
${metrics.valLoss.toFixed(4)}
验证准确率
${(metrics.valAccuracy * 100).toFixed(1)}%
${metrics.remainingTime !== undefined ? `
剩余时间
${this.formatTime(metrics.remainingTime)}
` : ''} `;
  }

  /**
   * Draws the training-loss curve (red, scaled to its maximum) and the training-accuracy
   * curve (green, on a 0-1 scale) onto #lossChart, with a grid, axes and a legend.
   */
  plotLossChart() {
    const canvas = document.getElementById('lossChart');
    const ctx = canvas.getContext('2d');

    canvas.width = canvas.offsetWidth;
    canvas.height = 300;
    ctx.clearRect(0, 0, canvas.width, canvas.height);

    if (this.lossHistory.length === 0) return;

    const padding = 40;
    const chartWidth = canvas.width - 2 * padding;
    const chartHeight = canvas.height - 2 * padding;

    // Fall back to 1 when the maximum is 0 or NaN so the scaling never divides by it.
    const maxLoss = Math.max(...this.lossHistory) || 1;
    // With a single point, place it on the y-axis instead of computing 0/0 = NaN.
    const xAt = (i, count) => padding + (count > 1 ? i / (count - 1) : 0) * chartWidth;
    const yAt = (fraction) => canvas.height - padding - fraction * chartHeight * 0.9;

    // Background grid
    ctx.strokeStyle = '#e0e0e0';
    ctx.lineWidth = 1;
    for (let i = 0; i <= 5; i++) {
      const y = padding + (i * chartHeight / 5);
      ctx.beginPath();
      ctx.moveTo(padding, y);
      ctx.lineTo(canvas.width - padding, y);
      ctx.stroke();
    }

    // Axes
    ctx.strokeStyle = '#333';
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.moveTo(padding, padding);
    ctx.lineTo(padding, canvas.height - padding);
    ctx.lineTo(canvas.width - padding, canvas.height - padding);
    ctx.stroke();

    const drawSeries = (values, scale, color) => {
      ctx.strokeStyle = color;
      ctx.lineWidth = 2;
      ctx.beginPath();
      values.forEach((value, i) => {
        const x = xAt(i, values.length);
        const y = yAt(value / scale);
        if (i === 0) {
          ctx.moveTo(x, y);
        } else {
          ctx.lineTo(x, y);
        }
      });
      ctx.stroke();
    };

    drawSeries(this.lossHistory, maxLoss, '#f56565'); // loss
    drawSeries(this.accuracyHistory, 1, '#48bb78'); // accuracy

    // Legend
    ctx.fillStyle = '#f56565';
    ctx.fillRect(canvas.width - 120, 20, 15, 15);
    ctx.fillStyle = '#333';
    ctx.font = '12px Arial';
    ctx.fillText('损失', canvas.width - 100, 32);

    ctx.fillStyle = '#48bb78';
    ctx.fillRect(canvas.width - 120, 40, 15, 15);
    ctx.fillStyle = '#333';
    ctx.fillText('准确率', canvas.width - 100, 52);
  }

  /** Opens the user-facing camera and starts the prediction loop once video data arrives. */
  async startWebcam() {
    if (!this.model) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }

    const video = document.getElementById('webcam');

    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });

      video.srcObject = stream;
      this.webcamStream = stream;

      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;

      // `once` so repeated start/stop cycles do not stack listeners; each stacked
      // listener would start another concurrent predictLoop.
      video.addEventListener('loadeddata', () => {
        this.isPredicting = true;
        this.predictLoop();
      }, { once: true });

      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  /** Stops all camera tracks, ends the prediction loop and resets the prediction UI. */
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach(track => track.stop());
      this.webcamStream = null;
    }

    this.isPredicting = false;
    const video = document.getElementById('webcam');
    video.srcObject = null;

    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;
    document.getElementById('confidenceBars').innerHTML = '';

    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  /**
   * One prediction step on the current webcam frame, rescheduled with
   * requestAnimationFrame while this.isPredicting is true.
   */
  async predictLoop() {
    if (!this.isPredicting) return;

    const video = document.getElementById('webcam');

    if (video.readyState === 4) {
      let features = null;
      let probabilities = null;
      try {
        features = await this.extractFeatures(video);
        const temperature = parseFloat(document.getElementById('temperature').value) || 1.0;

        // All intermediates are released by tidy(); only the final probabilities survive.
        probabilities = tf.tidy(() => {
          // Add the batch dimension expected by predict().
          let input = features.expandDims(0);
          if (this.dataMean && this.dataStd) {
            input = input.sub(this.dataMean).div(this.dataStd.add(1e-7));
          }
          // The output layer is softmax, so predict() already returns probabilities.
          const probs = this.model.predict(input);
          if (temperature === 1.0) {
            return probs;
          }
          // Temperature scaling: log(p) equals the logits up to a per-row constant that
          // softmax ignores, so softmax(log(p) / T) rescales the original logits.
          const logits = tf.log(probs.clipByValue(1e-7, 1));
          return tf.softmax(logits.div(temperature));
        });

        const prediction = await probabilities.data();
        this.updateConfidenceBars(prediction, temperature);
      } catch (error) {
        console.error('预测错误:', error);
      } finally {
        if (features) {
          features.dispose();
        }
        if (probabilities) {
          probabilities.dispose();
        }
      }
    }

    requestAnimationFrame(() => this.predictLoop());
  }

  /**
   * Escapes HTML-significant characters so user-controlled text (class names typed in
   * the form or read from a class-info JSON file) can be interpolated into innerHTML.
   * @param {*} value - converted with String().
   * @returns {string}
   */
  escapeHtml(value) {
    return String(value)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }

  /**
   * Renders per-class confidences (in class order), uncertainty warnings and summary
   * statistics into #confidenceBars.
   * @param {ArrayLike<number>} predictions - class probabilities.
   * @param {number} [temperature=1.0] - temperature used; shown when not 1.0.
   */
  updateConfidenceBars(predictions, temperature = 1.0) {
    const container = document.getElementById('confidenceBars');
    let html = '';
    const classPredictions = [];

    // Max confidence, and entropy normalised to [0, 1] as an uncertainty score.
    const maxConfidence = Math.max(...predictions);
    const maxConfidenceIndex = predictions.indexOf(maxConfidence);
    const entropy = -predictions.reduce((sum, p) => sum + (p > 0 ? p * Math.log(p) : 0), 0);
    const maxEntropy = Math.log(predictions.length); // entropy of the uniform distribution
    const uncertaintyScore = entropy / maxEntropy; // higher = less certain

    // Keep the original class order (no sorting).
    for (let i = 0; i < predictions.length; i++) {
      classPredictions.push({
        className: this.classNames[i],
        confidence: predictions[i],
        index: i,
        isMax: i === maxConfidenceIndex
      });
    }

    let warningMessage = '';
    if (uncertaintyScore > 0.8) {
      warningMessage = `
⚠️ 不确定性高 - 模型对预测结果不确定
可能是未见过的类别或模糊图像
`;
    } else if (maxConfidence < 0.5) {
      warningMessage = `
置信度低 - 没有明显匹配的类别
这可能不是训练过的类别
`;
    }

    html = warningMessage;

    if (temperature !== 1.0) {
      html += `
🌡️ 温度: ${temperature.toFixed(1)} (越高越保守,越低越自信)
`;
    }

    classPredictions.forEach((pred) => {
      const percentage = (pred.confidence * 100).toFixed(1);
      const barWidth = Math.max(percentage, 5); // at least 5% so the label stays visible

      // Colour by confidence band.
      let barColor = 'linear-gradient(90deg, #667eea, #764ba2)';
      if (pred.confidence < 0.3) {
        barColor = 'linear-gradient(90deg, #f56565, #e53e3e)';
      } else if (pred.confidence < 0.6) {
        barColor = 'linear-gradient(90deg, #ed8936, #dd6b20)';
      } else if (pred.confidence > 0.8) {
        barColor = 'linear-gradient(90deg, #48bb78, #38a169)';
      }

      // Highlight the top class when it is above 50%.
      const isWinner = pred.isMax && pred.confidence > 0.5;
      const borderStyle = isWinner ? 'border: 2px solid #48bb78; padding: 3px; border-radius: 5px;' : '';

      // NOTE(review): barWidth, barColor and borderStyle are computed but not referenced
      // in the markup below — confirm the template still applies them.
      // Class names are user-controlled, so they are escaped before reaching innerHTML.
      html += `
${this.escapeHtml(pred.className)} ${isWinner ? '👑' : ''} ${percentage}%
${percentage > 20 ? percentage + '%' : ''}
`;
    });

    html += `
统计信息:
最高置信度: ${(maxConfidence * 100).toFixed(1)}%
不确定性: ${(uncertaintyScore * 100).toFixed(1)}%
熵: ${entropy.toFixed(3)} / ${maxEntropy.toFixed(3)}
`;

    container.innerHTML = html;
  }

  /**
   * Downloads the model (via tfjs `downloads://`) plus a class-info.json holding class
   * names, the normalisation statistics (as plain arrays) and their shapes.
   */
  async saveModel() {
    if (!this.model) {
      this.showStatus('predictionStatus', 'error', '没有可保存的模型');
      return;
    }

    try {
      await this.model.save('downloads://custom-image-classifier');

      // tensor.data() resolves to a typed array, and JSON.stringify turns typed arrays
      // into {"0": …} objects. Store plain arrays plus shapes so load can rebuild them.
      const mean = this.dataMean;
      const std = this.dataStd;
      const classInfo = {
        classNames: this.classNames,
        dataMean: mean ? Array.from(await mean.data()) : null,
        dataStd: std ? Array.from(await std.data()) : null,
        dataMeanShape: mean ? mean.shape : null,
        dataStdShape: std ? std.shape : null,
        date: new Date().toISOString()
      };

      const blob = new Blob([JSON.stringify(classInfo, null, 2)], { type: 'application/json' });
      const url = URL.createObjectURL(blob);
      const a = document.createElement('a');
      a.href = url;
      a.download = 'class-info.json';
      a.click();
      // Revoke later rather than immediately so the browser has time to start the download.
      setTimeout(() => URL.revokeObjectURL(url), 40000);

      this.showStatus('predictionStatus', 'success', '模型已保存');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
    }
  }

  /**
   * Rebuilds a tensor saved in class-info.json. Accepts plain arrays, and also the
   * index-keyed objects ({"0": …, "1": …}) that older saves produced.
   * @param {number[]|Object<string, number>} values
   * @param {number[]|null|undefined} shape - saved shape; when absent the tensor is 1-D.
   * @returns {tf.Tensor}
   */
  tensorFromSaved(values, shape) {
    // Object.values keeps integer-like keys in ascending numeric order.
    const flat = Array.isArray(values) ? values : Object.values(values);
    return Array.isArray(shape) ? tf.tensor(flat, shape) : tf.tensor(flat);
  }

  /**
   * Lets the user pick model files (.json + .bin), then a class-info.json.
   * Restores class names and normalisation statistics from the latter.
   */
  async loadModel() {
    const modelInput = document.createElement('input');
    modelInput.type = 'file';
    modelInput.multiple = true;
    modelInput.accept = '.json,.bin';

    const classInput = document.createElement('input');
    classInput.type = 'file';
    classInput.accept = '.json';

    modelInput.onchange = async (e) => {
      try {
        const files = e.target.files;
        this.model = await tf.loadLayersModel(tf.io.browserFiles(files));
        this.showStatus('predictionStatus', 'info', '模型已加载,请选择类别信息文件');
        // Next, ask for the class-info file.
        classInput.click();
      } catch (error) {
        this.showStatus('predictionStatus', 'error', `加载模型失败: ${error.message}`);
      }
    };

    classInput.onchange = async (e) => {
      try {
        const file = e.target.files[0];
        const text = await file.text();
        const classInfo = JSON.parse(text);

        this.classNames = classInfo.classNames;

        // Normalisation stats belong to the model being loaded. Drop any left over from
        // an earlier training run so they are never applied to the wrong model.
        if (this.dataMean) {
          this.dataMean.dispose();
          this.dataMean = null;
        }
        if (this.dataStd) {
          this.dataStd.dispose();
          this.dataStd = null;
        }
        if (classInfo.dataMean && classInfo.dataStd) {
          this.dataMean = this.tensorFromSaved(classInfo.dataMean, classInfo.dataMeanShape);
          this.dataStd = this.tensorFromSaved(classInfo.dataStd, classInfo.dataStdShape);
          console.log('已恢复数据标准化参数');
        }

        this.showStatus('predictionStatus', 'success', `模型加载成功!类别: ${this.classNames.join(', ')}`);
      } catch (error) {
        this.showStatus('predictionStatus', 'error', `加载类别信息失败: ${error.message}`);
      }
    };

    modelInput.click();
  }

  /** Restores every hyperparameter control (and its value label) to its default. */
  resetHyperparameters() {
    document.getElementById('learningRate').value = -3;
    document.getElementById('learningRateValue').textContent = '0.001';

    document.getElementById('epochs').value = 100;
    document.getElementById('epochsValue').textContent = '100';

    document.getElementById('dropoutRate').value = 0.3;
    document.getElementById('dropoutValue').textContent = '0.3';

    document.getElementById('l2Regularization').value = -2;
    document.getElementById('l2Value').textContent = '0.01';

    document.getElementById('batchSize').value = 32;
    document.getElementById('batchSizeValue').textContent = '32';

    document.getElementById('temperature').value = 1.0;
    document.getElementById('temperatureValue').textContent = '1.0';

    document.getElementById('earlyStoppingCheck').checked = true;
    document.getElementById('patience').value = 10;
    document.getElementById('patienceValue').textContent = '10';

    this.showStatus('trainingStatus', 'info', '超参数已重置为默认值');
  }

  /** Shows hyperparameter tuning advice in an alert dialog, one tip per line. */
  showRecommendations() {
    const message = `
💡 超参数设置建议:

如果过拟合(训练准确率高但验证准确率低):
- 增大 Dropout 率 (0.4-0.6)
- 增大 L2 正则化 (0.01-0.1)
- 减少训练轮数
- 启用早停
- 增加更多训练数据

如果欠拟合(训练和验证准确率都低):
- 减小 Dropout 率 (0.1-0.2)
- 减小 L2 正则化 (0.0001-0.001)
- 增加训练轮数
- 增大学习率 (0.001-0.01)

温度缩放:
- < 1.0: 更自信的预测(可能过度自信)
- = 1.0: 标准预测
- > 1.0: 更保守的预测(减少过度自信)
`;
    alert(message);
  }

  /**
   * Shows a message in a status element using a type-specific CSS class.
   * @param {string} elementId
   * @param {'success'|'error'|'info'|'warning'} type - unknown types fall back to info.
   * @param {string} message - set via textContent (not parsed as HTML).
   */
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    const classMap = {
      'success': 'status-success',
      'error': 'status-error',
      'info': 'status-info',
      'warning': 'status-warning'
    };
    element.className = `status-message ${classMap[type] || 'status-info'}`;
    element.textContent = message;
  }
}

// Create the app once the DOM is ready.
let classifier;
document.addEventListener('DOMContentLoaded', () => {
  classifier = new CustomImageClassifier();
});