/*
 * mobileNet/废弃/三层神经网络/custom-classifier.js
 * (original listing: 1174 lines, 49 KiB, JavaScript)
 *
 * Note: this file contains Unicode characters that might be confused with
 * other characters (Chinese UI strings and emoji in the messages below).
 */

// 自定义图像分类器应用
class CustomImageClassifier {
constructor() {
this.mobilenet = null;
this.model = null;
this.classNames = [];
this.trainingDataInputs = [];
this.trainingDataOutputs = [];
this.isTraining = false;
this.webcamStream = null;
this.isPredicting = false;
this.lossHistory = [];
this.accuracyHistory = [];
this.valLossHistory = [];
this.valAccuracyHistory = [];
this.dataMean = null; // 存储训练数据的均值
this.dataStd = null; // 存储训练数据的标准差
this.temperature = 1.0; // 温度缩放参数
this.bestValLoss = Infinity;
this.patienceCounter = 0;
this.bestWeights = null;
this.init();
}
async init() {
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
try {
// 加载预训练的 MobileNet 模型
this.mobilenet = await mobilenet.load({
version: 2,
alpha: 1.0
});
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
this.setupEventListeners();
} catch (error) {
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
}
}
setupEventListeners() {
// 文件上传监听
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
document.getElementById(id).addEventListener('change', (e) => {
this.handleImageUpload(e, index + 1);
});
});
// 按钮监听
document.getElementById('addDataBtn').addEventListener('click', () => this.addToDataset());
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
document.getElementById('trainBtn').addEventListener('click', () => this.trainModel());
document.getElementById('stopBtn').addEventListener('click', () => this.stopTraining());
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
}
handleImageUpload(event, classIndex) {
const files = event.target.files;
const previewContainer = document.getElementById(`class${classIndex}Preview`);
const countElement = document.getElementById(`class${classIndex}Count`);
previewContainer.innerHTML = '';
Array.from(files).forEach(file => {
const reader = new FileReader();
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.className = 'preview-img';
previewContainer.appendChild(img);
};
reader.readAsDataURL(file);
});
countElement.textContent = `${files.length} 张图片`;
}
async addToDataset() {
const classes = [];
const imagesByClass = [];
// 收集所有类别和图片
for (let i = 1; i <= 3; i++) {
const className = document.getElementById(`class${i}Name`).value.trim();
const files = document.getElementById(`class${i}Images`).files;
if (className && files.length > 0) {
classes.push(className);
imagesByClass.push(files);
}
}
if (classes.length < 2) {
this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
return;
}
// 检查是否已有数据,如果有,询问是否追加
if (this.trainingDataInputs.length > 0) {
// 清空旧数据
this.clearDataset();
}
this.classNames = classes;
// 计算总图片数量和每个类别的数量
let totalToProcess = 0;
const classImageCounts = [];
imagesByClass.forEach((files, idx) => {
totalToProcess += files.length;
classImageCounts.push({
className: classes[idx],
count: files.length
});
});
// 检查数据平衡性
const minImages = Math.min(...classImageCounts.map(c => c.count));
const maxImages = Math.max(...classImageCounts.map(c => c.count));
if (maxImages > minImages * 3) {
console.warn('警告:数据集不平衡!');
classImageCounts.forEach(c => {
console.log(`${c.className}: ${c.count} 张图片`);
});
this.showStatus('dataStatus', 'warning',
`警告:数据集不平衡!建议每个类别的图片数量相近。`);
await new Promise(resolve => setTimeout(resolve, 2000)); // 显示警告2秒
}
this.showStatus('dataStatus', 'info', `正在处理 ${totalToProcess} 张图片...`);
let totalImages = 0;
let processedImages = 0;
const classSampleCounts = new Array(classes.length).fill(0);
// 处理每个类别的图片
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
const files = imagesByClass[classIndex];
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
try {
console.log(`处理图片 ${fileIndex + 1}/${files.length} - ${classes[classIndex]}`);
const img = await this.loadImage(files[fileIndex]);
// 确保图片加载成功
if (!img || !img.width || !img.height) {
console.error('图片加载失败或无效:', files[fileIndex].name);
processedImages++;
continue;
}
// 为每张图片提取特征(原始版本)
const features = await this.extractFeatures(img, false);
this.trainingDataInputs.push(features);
this.trainingDataOutputs.push(classIndex);
// 如果启用数据增强,生成额外的增强样本
if (document.getElementById('dataAugmentationCheck') &&
document.getElementById('dataAugmentationCheck').checked) {
try {
// 生成1个增强版本
const augFeatures = await this.extractFeatures(img, true);
this.trainingDataInputs.push(augFeatures);
this.trainingDataOutputs.push(classIndex);
totalImages++; // 增强样本也计入总数
classSampleCounts[classIndex]++;
} catch (augError) {
console.warn('数据增强失败,跳过增强样本:', augError);
}
}
totalImages++;
processedImages++;
classSampleCounts[classIndex]++;
// 更新处理进度
const progressPercent = Math.round((processedImages / totalToProcess) * 100);
this.showStatus('dataStatus', 'info',
`处理中... ${processedImages}/${totalToProcess} (${progressPercent}%) - ${classes[classIndex]}: ${classSampleCounts[classIndex]}`);
// 只有当img是HTMLImageElement时才调用remove
if (img && img.remove) {
img.remove();
}
} catch (error) {
console.error(`处理图片失败 (${files[fileIndex].name}):`, error);
processedImages++;
}
}
}
// 显示最终统计
let statsMessage = `成功添加 ${totalImages} 张图片到数据集!\n`;
classes.forEach((className, idx) => {
statsMessage += `${className}: ${classSampleCounts[idx]}\n`;
});
this.showStatus('dataStatus', 'success', statsMessage);
// 打印详细统计到控制台
console.log('数据集统计:');
console.log('总样本数:', totalImages);
classes.forEach((className, idx) => {
console.log(`${className}: ${classSampleCounts[idx]} 张 (${(classSampleCounts[idx]/totalImages*100).toFixed(1)}%)`);
});
}
async loadImage(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (e) => {
const img = new Image();
img.onload = () => resolve(img);
img.onerror = reject;
img.src = e.target.result;
};
reader.onerror = reject;
reader.readAsDataURL(file);
});
}
async extractFeatures(img, augment = false) {
try {
// 如果启用数据增强(仅在训练时)
if (augment && document.getElementById('dataAugmentationCheck') &&
document.getElementById('dataAugmentationCheck').checked) {
// 随机增强数据
const augmented = await this.augmentImage(img);
const activation = this.mobilenet.infer(augmented, 'conv_pw_13_relu');
// 如果augmented是canvas元素不需要dispose
if (augmented !== img && augmented.tagName === 'CANVAS') {
// Canvas元素不需要dispose
}
const shape = activation.shape;
if (shape.length === 4 && shape[0] === 1) {
return activation.squeeze([0]);
}
return activation;
}
// 正常提取特征
const activation = this.mobilenet.infer(img, 'conv_pw_13_relu');
// 检查张量形状
const shape = activation.shape;
// 如果是4D张量 [batch, height, width, channels],去掉批次维度
if (shape.length === 4 && shape[0] === 1) {
return activation.squeeze([0]); // 返回3D张量
}
// 如果已经是3D张量直接返回
return activation;
} catch (error) {
console.error('特征提取失败:', error);
throw error;
}
}
async augmentImage(img) {
try {
// 简单的数据增强:随机噪声
const imgTensor = tf.browser.fromPixels(img);
// 添加少量随机噪声
const noise = tf.randomNormal(imgTensor.shape, 0, 5); // 减少噪声强度
const augmented = imgTensor.add(noise).clipByValue(0, 255);
// 转换回图像格式
const canvas = document.createElement('canvas');
canvas.width = img.width || img.videoWidth || 224;
canvas.height = img.height || img.videoHeight || 224;
// 归一化后转换为图像
const normalized = augmented.div(255);
await tf.browser.toPixels(normalized, canvas);
// 清理张量
imgTensor.dispose();
noise.dispose();
augmented.dispose();
normalized.dispose();
return canvas;
} catch (error) {
console.error('数据增强失败:', error);
// 如果增强失败,返回原图
return img;
}
}
clearDataset() {
this.trainingDataInputs.forEach(tensor => tensor.dispose());
this.trainingDataInputs = [];
this.trainingDataOutputs = [];
this.classNames = [];
// 清空标准化参数
if (this.dataMean) {
this.dataMean.dispose();
this.dataMean = null;
}
if (this.dataStd) {
this.dataStd.dispose();
this.dataStd = null;
}
// 清空文件输入和预览
for (let i = 1; i <= 3; i++) {
document.getElementById(`class${i}Images`).value = '';
document.getElementById(`class${i}Preview`).innerHTML = '';
document.getElementById(`class${i}Count`).textContent = '0 张图片';
}
this.showStatus('dataStatus', 'info', '数据集已清空');
}
createCustomModel() {
const numClasses = this.classNames.length;
// 获取超参数
const dropoutRate = parseFloat(document.getElementById('dropoutRate').value);
const l2Reg = Math.pow(10, parseFloat(document.getElementById('l2Regularization').value));
const learningRate = Math.pow(10, parseFloat(document.getElementById('learningRate').value));
// 获取输入形状 - 现在应该是3D张量
const inputShape = this.trainingDataInputs[0].shape;
console.log('Input shape:', inputShape);
console.log('Hyperparameters:', { learningRate, dropoutRate, l2Reg });
// 创建更深的模型以提高区分能力
this.model = tf.sequential({
layers: [
tf.layers.flatten({
inputShape: inputShape
}),
// 第一层:更多神经元
tf.layers.dense({
units: 128, // 减少神经元数量,降低模型复杂度
activation: 'relu',
kernelInitializer: 'heNormal',
kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
}),
tf.layers.batchNormalization(),
tf.layers.dropout({ rate: dropoutRate, seed: 42 }), // 添加seed使dropout更一致
// 第二层
tf.layers.dense({
units: 64, // 减少神经元数量
activation: 'relu',
kernelInitializer: 'heNormal',
kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
}),
tf.layers.batchNormalization(),
tf.layers.dropout({ rate: dropoutRate, seed: 43 }),
// 第三层
tf.layers.dense({
units: 32, // 进一步减少神经元
activation: 'relu',
kernelInitializer: 'heNormal',
kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
}),
tf.layers.dropout({ rate: dropoutRate * 0.7, seed: 44 }),
// 输出层
tf.layers.dense({
units: numClasses,
activation: 'softmax',
kernelInitializer: 'glorotNormal'
})
]
});
// 使用更合适的优化器和学习率
this.model.compile({
optimizer: tf.train.adam(learningRate),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
console.log('Model created successfully');
this.model.summary();
}
async trainModel() {
if (this.trainingDataInputs.length === 0) {
this.showStatus('trainingStatus', 'error', '请先添加训练数据!');
return;
}
// 检查数据平衡性
const classCounts = {};
this.trainingDataOutputs.forEach(label => {
classCounts[label] = (classCounts[label] || 0) + 1;
});
console.log('每个类别的样本数量:', classCounts);
console.log('类别名称:', this.classNames);
// 计算类别权重以平衡数据
const totalSamples = this.trainingDataOutputs.length;
const numClasses = this.classNames.length;
const classWeights = {};
for (let i = 0; i < numClasses; i++) {
const classCount = classCounts[i] || 1;
// 使用反比例权重:样本少的类别权重更高
classWeights[i] = totalSamples / (numClasses * classCount);
}
console.log('类别权重:', classWeights);
this.showStatus('trainingStatus', 'info', '准备训练模型...');
// 创建模型
this.createCustomModel();
// 准备训练数据
let xs = tf.stack(this.trainingDataInputs);
const ys = tf.oneHot(tf.tensor1d(this.trainingDataOutputs, 'int32'), this.classNames.length);
// 手动分割训练集和验证集而不是使用validationSplit
const totalDataCount = xs.shape[0]; // 改名避免冲突
const indices = tf.util.createShuffledIndices(totalDataCount);
const splitIdx = Math.floor(totalDataCount * 0.8);
// 确保indices是TypedArray
const trainIndices = Array.from(indices).slice(0, splitIdx);
const valIndices = Array.from(indices).slice(splitIdx);
// 将indices转换为Tensor
const trainIndicesTensor = tf.tensor1d(trainIndices, 'int32');
const valIndicesTensor = tf.tensor1d(valIndices, 'int32');
// 使用gather来获取对应的数据
const xsTrain = tf.gather(xs, trainIndicesTensor);
const ysTrain = tf.gather(ys, trainIndicesTensor);
const xsVal = tf.gather(xs, valIndicesTensor);
const ysVal = tf.gather(ys, valIndicesTensor);
// 清理索引张量
trainIndicesTensor.dispose();
valIndicesTensor.dispose();
// 只在训练集上计算标准化参数
this.dataMean = xsTrain.mean([0], true);
this.dataStd = xsTrain.sub(this.dataMean).square().mean([0], true).sqrt();
// 应用标准化到训练集和验证集
const xsTrainNorm = xsTrain.sub(this.dataMean).div(this.dataStd.add(1e-7));
const xsValNorm = xsVal.sub(this.dataMean).div(this.dataStd.add(1e-7));
// 清理临时张量
xsTrain.dispose();
xsVal.dispose();
xs.dispose();
ys.dispose();
// 详细的数据集分析
console.log('\n📊 === 数据集分析 ===');
console.log(`训练集: ${trainIndices.length} 样本`);
console.log(`验证集: ${valIndices.length} 样本`);
console.log(`训练/验证比例: ${(trainIndices.length / totalDataCount * 100).toFixed(1)}% / ${(valIndices.length / totalDataCount * 100).toFixed(1)}%`);
// 分析每个类别在训练集和验证集中的分布
const trainLabels = trainIndices.map(i => this.trainingDataOutputs[i]);
const valLabels = valIndices.map(i => this.trainingDataOutputs[i]);
console.log('\n类别分布:');
for (let i = 0; i < this.classNames.length; i++) {
const trainCount = trainLabels.filter(l => l === i).length;
const valCount = valLabels.filter(l => l === i).length;
console.log(` ${this.classNames[i]}:`);
console.log(` - 训练集: ${trainCount} (${(trainCount/trainIndices.length*100).toFixed(1)}%)`);
console.log(` - 验证集: ${valCount} (${(valCount/valIndices.length*100).toFixed(1)}%)`);
}
// 警告检查
if (valIndices.length < 5) {
console.warn('⚠️ 验证集太小建议至少5个样本');
}
if (valIndices.length < totalDataCount * 0.1) {
console.warn('⚠️ 验证集比例太小建议至少10%的数据用于验证');
}
console.log('===================\n');
console.log('数据形状:', xs.shape);
console.log('均值形状:', this.dataMean.shape);
console.log('标准差形状:', this.dataStd.shape);
// 打印数据集统计信息
const uniqueLabels = [...new Set(this.trainingDataOutputs)];
console.log('训练数据统计:');
uniqueLabels.forEach(label => {
const count = this.trainingDataOutputs.filter(l => l === label).length;
console.log(` ${this.classNames[label]}: ${count} 张图片`);
});
// 显示进度条
document.getElementById('trainingProgress').classList.remove('hidden');
document.getElementById('trainBtn').disabled = true;
document.getElementById('stopBtn').disabled = false;
this.isTraining = true;
this.lossHistory = [];
this.accuracyHistory = [];
this.valLossHistory = [];
this.valAccuracyHistory = [];
this.bestValLoss = Infinity;
this.patienceCounter = 0;
this.bestWeights = null;
// 从界面获取超参数
const epochs = parseInt(document.getElementById('epochs').value);
const batchSize = parseInt(document.getElementById('batchSize').value);
const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked;
const patience = parseInt(document.getElementById('patience').value);
this.temperature = parseFloat(document.getElementById('temperature').value);
console.log('Training parameters:', { epochs, batchSize, enableEarlyStopping, patience, temperature: this.temperature });
try {
const startTime = Date.now();
await this.model.fit(xsTrainNorm, ysTrain, {
epochs,
batchSize,
validationData: [xsValNorm, ysVal],
shuffle: true, // 确保数据被打乱
classWeight: classWeights, // 应用类别权重
callbacks: {
onEpochEnd: (epoch, logs) => {
const progress = ((epoch + 1) / epochs) * 100;
document.getElementById('progressFill').style.width = `${progress}%`;
document.getElementById('progressFill').textContent = `${Math.round(progress)}%`;
this.lossHistory.push(logs.loss);
this.accuracyHistory.push(logs.acc);
this.valLossHistory.push(logs.val_loss);
this.valAccuracyHistory.push(logs.val_acc);
// 早停逻辑
const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked;
const patience = parseInt(document.getElementById('patience').value);
if (enableEarlyStopping) {
const improvement = this.bestValLoss - logs.val_loss;
const improvementPercent = (improvement / this.bestValLoss) * 100;
// 详细的早停分析
console.log(`\n=== Epoch ${epoch + 1} 早停分析 ===`);
console.log(`当前验证损失: ${logs.val_loss.toFixed(4)}`);
console.log(`最佳验证损失: ${this.bestValLoss.toFixed(4)}`);
console.log(`改善: ${improvement.toFixed(4)} (${improvementPercent.toFixed(2)}%)`);
console.log(`耐心计数器: ${this.patienceCounter}/${patience}`);
// 只有显著改善才重置计数器改善超过0.1%
if (improvement > 0.001 && improvementPercent > 0.1) {
this.bestValLoss = logs.val_loss;
this.patienceCounter = 0;
// 保存最佳权重
if (this.bestWeights) {
this.bestWeights.forEach(w => w.dispose());
}
this.bestWeights = this.model.getWeights().map(w => w.clone());
console.log(`✅ 新的最佳验证损失! 重置耐心计数器`);
} else {
this.patienceCounter++;
console.log(`⚠️ 没有显著改善,耐心计数器增加`);
// 分析为什么没有改善
if (logs.val_loss > logs.loss * 1.5) {
console.log(`📊 可能过拟合:验证损失 (${logs.val_loss.toFixed(4)}) >> 训练损失 (${logs.loss.toFixed(4)})`);
}
if (logs.val_acc < logs.acc - 0.1) {
console.log(`📊 验证准确率低:验证 ${(logs.val_acc * 100).toFixed(1)}% vs 训练 ${(logs.acc * 100).toFixed(1)}%`);
}
if (this.patienceCounter >= patience) {
console.log(`\n🛑 早停触发!原因分析:`);
console.log(`- 连续 ${patience} 个 epochs 没有改善`);
console.log(`- 最终训练损失: ${logs.loss.toFixed(4)}`);
console.log(`- 最终验证损失: ${logs.val_loss.toFixed(4)}`);
console.log(`- 最终训练准确率: ${(logs.acc * 100).toFixed(1)}%`);
console.log(`- 最终验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`);
// 提供建议
if (logs.val_loss > logs.loss * 1.5) {
console.log(`\n💡 建议:模型过拟合`);
console.log(` - 增加 Dropout 率`);
console.log(` - 增加 L2 正则化`);
console.log(` - 减少模型复杂度`);
console.log(` - 增加更多训练数据`);
} else if (logs.val_acc < 0.5) {
console.log(`\n💡 建议:模型欠拟合`);
console.log(` - 增加训练轮数`);
console.log(` - 增大学习率`);
console.log(` - 增加模型复杂度`);
}
this.isTraining = false;
this.model.stopTraining = true;
}
}
console.log(`======================\n`);
}
// 计算剩余时间
const elapsedTime = Date.now() - startTime;
const timePerEpoch = elapsedTime / (epoch + 1);
const remainingTime = timePerEpoch * (epochs - epoch - 1);
const remainingSeconds = Math.round(remainingTime / 1000);
// 检测可能的过拟合
const overfitGap = logs.acc - logs.val_acc;
if (overfitGap > 0.2) {
console.warn(`⚠️ 可能过拟合!训练准确率: ${(logs.acc * 100).toFixed(1)}%, 验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`);
}
// 如果训练准确率达到100%,发出警告
if (logs.acc >= 0.99) {
console.warn('⚠️ 训练准确率达到100%,模型可能过拟合!考虑:');
console.warn(' 1. 增加更多训练数据');
console.warn(' 2. 增大Dropout率');
console.warn(' 3. 增大L2正则化');
console.warn(' 4. 减少模型复杂度');
}
this.updateMetrics({
epoch: epoch + 1,
totalEpochs: epochs,
loss: logs.loss,
accuracy: logs.acc,
valLoss: logs.val_loss,
valAccuracy: logs.val_acc,
remainingTime: remainingSeconds
});
this.plotLossChart();
// 更新训练状态
const statusMsg = `训练中 - Epoch ${epoch + 1}/${epochs} | ` +
`准确率: ${(logs.acc * 100).toFixed(1)}% | ` +
`剩余时间: ${this.formatTime(remainingSeconds)}`;
this.showStatus('trainingStatus', 'info', statusMsg);
if (!this.isTraining) {
this.model.stopTraining = true;
}
},
onBatchEnd: (batch, logs) => {
// 可选:显示批次进度
if (batch % 5 === 0) {
console.log(`Batch ${batch}: loss = ${logs.loss.toFixed(4)}`);
}
}
}
});
// 如果启用了早停并且保存了最佳权重,恢复它们
if (this.bestWeights && document.getElementById('earlyStoppingCheck').checked) {
console.log('恢复最佳模型权重...');
this.model.setWeights(this.bestWeights);
}
this.showStatus('trainingStatus', 'success', '模型训练完成!');
} catch (error) {
this.showStatus('trainingStatus', 'error', `训练失败: ${error.message}`);
} finally {
xsTrainNorm.dispose();
ysTrain.dispose();
xsValNorm.dispose();
ysVal.dispose();
document.getElementById('trainBtn').disabled = false;
document.getElementById('stopBtn').disabled = true;
this.isTraining = false;
}
}
stopTraining() {
this.isTraining = false;
this.showStatus('trainingStatus', 'info', '正在停止训练...');
}
formatTime(seconds) {
if (seconds < 60) {
return `${seconds}`;
} else if (seconds < 3600) {
const minutes = Math.floor(seconds / 60);
const secs = seconds % 60;
return `${minutes}${secs}`;
} else {
const hours = Math.floor(seconds / 3600);
const minutes = Math.floor((seconds % 3600) / 60);
return `${hours} 小时 ${minutes}`;
}
}
updateMetrics(metrics) {
const container = document.getElementById('metricsContainer');
container.innerHTML = `
<div class="metric-card">
<div class="metric-label">Epoch</div>
<div class="metric-value">${metrics.epoch}/${metrics.totalEpochs}</div>
</div>
<div class="metric-card">
<div class="metric-label">损失</div>
<div class="metric-value">${metrics.loss.toFixed(4)}</div>
</div>
<div class="metric-card">
<div class="metric-label">准确率</div>
<div class="metric-value">${(metrics.accuracy * 100).toFixed(1)}%</div>
</div>
<div class="metric-card">
<div class="metric-label">验证损失</div>
<div class="metric-value">${metrics.valLoss.toFixed(4)}</div>
</div>
<div class="metric-card">
<div class="metric-label">验证准确率</div>
<div class="metric-value">${(metrics.valAccuracy * 100).toFixed(1)}%</div>
</div>
${metrics.remainingTime !== undefined ? `
<div class="metric-card">
<div class="metric-label">剩余时间</div>
<div class="metric-value">${this.formatTime(metrics.remainingTime)}</div>
</div>
` : ''}
`;
}
plotLossChart() {
const canvas = document.getElementById('lossChart');
const ctx = canvas.getContext('2d');
// 设置canvas尺寸
canvas.width = canvas.offsetWidth;
canvas.height = 300;
// 清除画布
ctx.clearRect(0, 0, canvas.width, canvas.height);
if (this.lossHistory.length === 0) return;
const padding = 40;
const chartWidth = canvas.width - 2 * padding;
const chartHeight = canvas.height - 2 * padding;
// 找到最大值
const maxLoss = Math.max(...this.lossHistory);
// 绘制背景网格
ctx.strokeStyle = '#e0e0e0';
ctx.lineWidth = 1;
for (let i = 0; i <= 5; i++) {
const y = padding + (i * chartHeight / 5);
ctx.beginPath();
ctx.moveTo(padding, y);
ctx.lineTo(canvas.width - padding, y);
ctx.stroke();
}
// 绘制坐标轴
ctx.strokeStyle = '#333';
ctx.lineWidth = 2;
ctx.beginPath();
ctx.moveTo(padding, padding);
ctx.lineTo(padding, canvas.height - padding);
ctx.lineTo(canvas.width - padding, canvas.height - padding);
ctx.stroke();
// 绘制损失曲线
ctx.strokeStyle = '#f56565';
ctx.lineWidth = 2;
ctx.beginPath();
for (let i = 0; i < this.lossHistory.length; i++) {
const x = padding + (i / (this.lossHistory.length - 1)) * chartWidth;
const y = canvas.height - padding - (this.lossHistory[i] / maxLoss) * chartHeight * 0.9;
if (i === 0) {
ctx.moveTo(x, y);
} else {
ctx.lineTo(x, y);
}
}
ctx.stroke();
// 绘制准确率曲线
ctx.strokeStyle = '#48bb78';
ctx.lineWidth = 2;
ctx.beginPath();
for (let i = 0; i < this.accuracyHistory.length; i++) {
const x = padding + (i / (this.accuracyHistory.length - 1)) * chartWidth;
const y = canvas.height - padding - this.accuracyHistory[i] * chartHeight * 0.9;
if (i === 0) {
ctx.moveTo(x, y);
} else {
ctx.lineTo(x, y);
}
}
ctx.stroke();
// 添加图例
ctx.fillStyle = '#f56565';
ctx.fillRect(canvas.width - 120, 20, 15, 15);
ctx.fillStyle = '#333';
ctx.font = '12px Arial';
ctx.fillText('损失', canvas.width - 100, 32);
ctx.fillStyle = '#48bb78';
ctx.fillRect(canvas.width - 120, 40, 15, 15);
ctx.fillStyle = '#333';
ctx.fillText('准确率', canvas.width - 100, 52);
}
async startWebcam() {
if (!this.model) {
this.showStatus('predictionStatus', 'error', '请先训练模型!');
return;
}
const video = document.getElementById('webcam');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
});
video.srcObject = stream;
this.webcamStream = stream;
document.getElementById('startWebcamBtn').disabled = true;
document.getElementById('stopWebcamBtn').disabled = false;
// 等待视频加载
video.addEventListener('loadeddata', () => {
this.isPredicting = true;
this.predictLoop();
});
this.showStatus('predictionStatus', 'success', '摄像头已启动');
} catch (error) {
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
}
}
stopWebcam() {
if (this.webcamStream) {
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
}
this.isPredicting = false;
const video = document.getElementById('webcam');
video.srcObject = null;
document.getElementById('startWebcamBtn').disabled = false;
document.getElementById('stopWebcamBtn').disabled = true;
document.getElementById('confidenceBars').innerHTML = '';
this.showStatus('predictionStatus', 'info', '摄像头已停止');
}
async predictLoop() {
if (!this.isPredicting) return;
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
// 提取特征 - 注意这里返回的是3D张量
const features = await this.extractFeatures(video);
// 需要添加批次维度进行预测 [batch, height, width, channels]
let featuresWithBatch = features.expandDims(0);
// 如果有保存的均值和标准差,进行标准化
if (this.dataMean && this.dataStd) {
featuresWithBatch = featuresWithBatch.sub(this.dataMean).div(this.dataStd.add(1e-7));
}
// 进行预测
let logits = this.model.predict(featuresWithBatch);
// 应用温度缩放
const temperature = parseFloat(document.getElementById('temperature').value) || 1.0;
if (temperature !== 1.0) {
logits = logits.div(temperature);
}
// 转换为概率
const prediction = await tf.softmax(logits).data();
// 更新置信度显示
this.updateConfidenceBars(prediction, temperature);
// 清理张量
features.dispose();
featuresWithBatch.dispose();
logits.dispose();
} catch (error) {
console.error('预测错误:', error);
}
}
// 继续预测循环
requestAnimationFrame(() => this.predictLoop());
}
updateConfidenceBars(predictions, temperature = 1.0) {
const container = document.getElementById('confidenceBars');
let html = '';
const classPredictions = [];
// 计算最大置信度和熵
const maxConfidence = Math.max(...predictions);
const maxConfidenceIndex = predictions.indexOf(maxConfidence);
const entropy = -predictions.reduce((sum, p) => sum + (p > 0 ? p * Math.log(p) : 0), 0);
const maxEntropy = Math.log(predictions.length); // 最大熵(均匀分布)
const uncertaintyScore = entropy / maxEntropy; // 0-1之间越高越不确定
// 保持原始顺序,不排序
for (let i = 0; i < predictions.length; i++) {
classPredictions.push({
className: this.classNames[i],
confidence: predictions[i],
index: i,
isMax: i === maxConfidenceIndex
});
}
// 如果不确定性很高(熵接近最大值),显示警告
let warningMessage = '';
if (uncertaintyScore > 0.8) {
warningMessage = `
<div style="background: #fff3cd; color: #856404; padding: 10px; border-radius: 5px; margin-bottom: 10px; border: 1px solid #ffeeba;">
⚠️ <strong>不确定性高</strong> - 模型对预测结果不确定
<br><small>可能是未见过的类别或模糊图像</small>
</div>
`;
} else if (maxConfidence < 0.5) {
warningMessage = `
<div style="background: #f8d7da; color: #721c24; padding: 10px; border-radius: 5px; margin-bottom: 10px; border: 1px solid #f5c6cb;">
❌ <strong>置信度低</strong> - 没有明显匹配的类别
<br><small>这可能不是训练过的类别</small>
</div>
`;
}
html = warningMessage;
// 显示温度信息
if (temperature !== 1.0) {
html += `
<div style="background: #d1ecf1; color: #0c5460; padding: 8px; border-radius: 5px; margin-bottom: 10px; border: 1px solid #bee5eb; font-size: 12px;">
🌡️ 温度: ${temperature.toFixed(1)} (越高越保守,越低越自信)
</div>
`;
}
// 按原始顺序显示,但标记最高置信度
classPredictions.forEach((pred) => {
const percentage = (pred.confidence * 100).toFixed(1);
const barWidth = Math.max(percentage, 5); // 最小宽度5%以显示标签
// 根据置信度设置颜色
let barColor = 'linear-gradient(90deg, #667eea, #764ba2)';
if (pred.confidence < 0.3) {
barColor = 'linear-gradient(90deg, #f56565, #e53e3e)';
} else if (pred.confidence < 0.6) {
barColor = 'linear-gradient(90deg, #ed8936, #dd6b20)';
} else if (pred.confidence > 0.8) {
barColor = 'linear-gradient(90deg, #48bb78, #38a169)';
}
// 为最高置信度的类别添加特殊样式
const isWinner = pred.isMax && pred.confidence > 0.5;
const borderStyle = isWinner ? 'border: 2px solid #48bb78; padding: 3px; border-radius: 5px;' : '';
html += `
<div class="confidence-item" style="${borderStyle}">
<div class="confidence-label">
<span>${pred.className} ${isWinner ? '👑' : ''}</span>
<span style="font-weight: ${isWinner ? 'bold' : 'normal'}">${percentage}%</span>
</div>
<div class="confidence-bar">
<div class="confidence-fill" style="width: ${barWidth}%; background: ${barColor}">
${percentage > 20 ? percentage + '%' : ''}
</div>
</div>
</div>
`;
});
// 添加不确定性分数
html += `
<div style="margin-top: 15px; padding: 10px; background: #f7fafc; border-radius: 5px; font-size: 12px; color: #4a5568;">
<strong>统计信息:</strong><br>
最高置信度: ${(maxConfidence * 100).toFixed(1)}%<br>
不确定性: ${(uncertaintyScore * 100).toFixed(1)}%<br>
熵: ${entropy.toFixed(3)} / ${maxEntropy.toFixed(3)}
</div>
`;
container.innerHTML = html;
}
async saveModel() {
if (!this.model) {
this.showStatus('predictionStatus', 'error', '没有可保存的模型');
return;
}
try {
// 保存模型和类别名称
await this.model.save('downloads://custom-image-classifier');
// 保存类别信息和标准化参数
const classInfo = {
classNames: this.classNames,
dataMean: this.dataMean ? await this.dataMean.data() : null,
dataStd: this.dataStd ? await this.dataStd.data() : null,
date: new Date().toISOString()
};
const blob = new Blob([JSON.stringify(classInfo, null, 2)],
{ type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'class-info.json';
a.click();
URL.revokeObjectURL(url);
this.showStatus('predictionStatus', 'success', '模型已保存');
} catch (error) {
this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
}
}
async loadModel() {
const modelInput = document.createElement('input');
modelInput.type = 'file';
modelInput.multiple = true;
modelInput.accept = '.json,.bin';
const classInput = document.createElement('input');
classInput.type = 'file';
classInput.accept = '.json';
modelInput.onchange = async (e) => {
try {
const files = e.target.files;
this.model = await tf.loadLayersModel(tf.io.browserFiles(files));
this.showStatus('predictionStatus', 'info', '模型已加载,请选择类别信息文件');
// 加载类别信息
classInput.click();
} catch (error) {
this.showStatus('predictionStatus', 'error', `加载模型失败: ${error.message}`);
}
};
classInput.onchange = async (e) => {
try {
const file = e.target.files[0];
const text = await file.text();
const classInfo = JSON.parse(text);
this.classNames = classInfo.classNames;
// 恢复标准化参数
if (classInfo.dataMean && classInfo.dataStd) {
this.dataMean = tf.tensor(classInfo.dataMean);
this.dataStd = tf.tensor(classInfo.dataStd);
console.log('已恢复数据标准化参数');
}
this.showStatus('predictionStatus', 'success',
`模型加载成功!类别: ${this.classNames.join(', ')}`);
} catch (error) {
this.showStatus('predictionStatus', 'error', `加载类别信息失败: ${error.message}`);
}
};
modelInput.click();
}
resetHyperparameters() {
document.getElementById('learningRate').value = -3;
document.getElementById('learningRateValue').textContent = '0.001';
document.getElementById('epochs').value = 100;
document.getElementById('epochsValue').textContent = '100';
document.getElementById('dropoutRate').value = 0.3;
document.getElementById('dropoutValue').textContent = '0.3';
document.getElementById('l2Regularization').value = -2;
document.getElementById('l2Value').textContent = '0.01';
document.getElementById('batchSize').value = 32;
document.getElementById('batchSizeValue').textContent = '32';
document.getElementById('temperature').value = 1.0;
document.getElementById('temperatureValue').textContent = '1.0';
document.getElementById('earlyStoppingCheck').checked = true;
document.getElementById('patience').value = 10;
document.getElementById('patienceValue').textContent = '10';
this.showStatus('trainingStatus', 'info', '超参数已重置为默认值');
}
showRecommendations() {
const message = `
💡 超参数设置建议:
如果过拟合(训练准确率高但验证准确率低):
- 增大 Dropout 率 (0.4-0.6)
- 增大 L2 正则化 (0.01-0.1)
- 减少训练轮数
- 启用早停
- 增加更多训练数据
如果欠拟合(训练和验证准确率都低):
- 减小 Dropout 率 (0.1-0.2)
- 减小 L2 正则化 (0.0001-0.001)
- 增加训练轮数
- 增大学习率 (0.001-0.01)
温度缩放:
- < 1.0: 更自信的预测(可能过度自信)
- = 1.0: 标准预测
- > 1.0: 更保守的预测(减少过度自信)
`;
alert(message);
}
showStatus(elementId, type, message) {
const element = document.getElementById(elementId);
const classMap = {
'success': 'status-success',
'error': 'status-error',
'info': 'status-info',
'warning': 'status-warning'
};
element.className = `status-message ${classMap[type] || 'status-info'}`;
element.textContent = message;
}
}
// Initialize the application once the DOM is available. `classifier` stays a
// top-level binding (presumably referenced by inline handlers in the page,
// e.g. for resetHyperparameters/showRecommendations; verify against the HTML).
let classifier;
const startClassifier = () => {
  classifier = new CustomImageClassifier();
};
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', startClassifier);
} else {
  // The DOM is already parsed (e.g. a deferred, async or late-injected
  // script), and DOMContentLoaded may already have fired.
  startClassifier();
}