1174 lines
49 KiB
JavaScript
1174 lines
49 KiB
JavaScript
// 自定义图像分类器应用
|
||
class CustomImageClassifier {
|
||
constructor() {
|
||
this.mobilenet = null;
|
||
this.model = null;
|
||
this.classNames = [];
|
||
this.trainingDataInputs = [];
|
||
this.trainingDataOutputs = [];
|
||
this.isTraining = false;
|
||
this.webcamStream = null;
|
||
this.isPredicting = false;
|
||
this.lossHistory = [];
|
||
this.accuracyHistory = [];
|
||
this.valLossHistory = [];
|
||
this.valAccuracyHistory = [];
|
||
this.dataMean = null; // 存储训练数据的均值
|
||
this.dataStd = null; // 存储训练数据的标准差
|
||
this.temperature = 1.0; // 温度缩放参数
|
||
this.bestValLoss = Infinity;
|
||
this.patienceCounter = 0;
|
||
this.bestWeights = null;
|
||
|
||
this.init();
|
||
}
|
||
|
||
async init() {
|
||
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
|
||
|
||
try {
|
||
// 加载预训练的 MobileNet 模型
|
||
this.mobilenet = await mobilenet.load({
|
||
version: 2,
|
||
alpha: 1.0
|
||
});
|
||
|
||
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
|
||
this.setupEventListeners();
|
||
} catch (error) {
|
||
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
|
||
}
|
||
}
|
||
|
||
setupEventListeners() {
|
||
// 文件上传监听
|
||
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
|
||
document.getElementById(id).addEventListener('change', (e) => {
|
||
this.handleImageUpload(e, index + 1);
|
||
});
|
||
});
|
||
|
||
// 按钮监听
|
||
document.getElementById('addDataBtn').addEventListener('click', () => this.addToDataset());
|
||
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
|
||
document.getElementById('trainBtn').addEventListener('click', () => this.trainModel());
|
||
document.getElementById('stopBtn').addEventListener('click', () => this.stopTraining());
|
||
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
|
||
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
|
||
document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
|
||
document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
|
||
}
|
||
|
||
handleImageUpload(event, classIndex) {
|
||
const files = event.target.files;
|
||
const previewContainer = document.getElementById(`class${classIndex}Preview`);
|
||
const countElement = document.getElementById(`class${classIndex}Count`);
|
||
|
||
previewContainer.innerHTML = '';
|
||
|
||
Array.from(files).forEach(file => {
|
||
const reader = new FileReader();
|
||
reader.onload = (e) => {
|
||
const img = document.createElement('img');
|
||
img.src = e.target.result;
|
||
img.className = 'preview-img';
|
||
previewContainer.appendChild(img);
|
||
};
|
||
reader.readAsDataURL(file);
|
||
});
|
||
|
||
countElement.textContent = `${files.length} 张图片`;
|
||
}
|
||
|
||
async addToDataset() {
|
||
const classes = [];
|
||
const imagesByClass = [];
|
||
|
||
// 收集所有类别和图片
|
||
for (let i = 1; i <= 3; i++) {
|
||
const className = document.getElementById(`class${i}Name`).value.trim();
|
||
const files = document.getElementById(`class${i}Images`).files;
|
||
|
||
if (className && files.length > 0) {
|
||
classes.push(className);
|
||
imagesByClass.push(files);
|
||
}
|
||
}
|
||
|
||
if (classes.length < 2) {
|
||
this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
|
||
return;
|
||
}
|
||
|
||
// 检查是否已有数据,如果有,询问是否追加
|
||
if (this.trainingDataInputs.length > 0) {
|
||
// 清空旧数据
|
||
this.clearDataset();
|
||
}
|
||
|
||
this.classNames = classes;
|
||
|
||
// 计算总图片数量和每个类别的数量
|
||
let totalToProcess = 0;
|
||
const classImageCounts = [];
|
||
imagesByClass.forEach((files, idx) => {
|
||
totalToProcess += files.length;
|
||
classImageCounts.push({
|
||
className: classes[idx],
|
||
count: files.length
|
||
});
|
||
});
|
||
|
||
// 检查数据平衡性
|
||
const minImages = Math.min(...classImageCounts.map(c => c.count));
|
||
const maxImages = Math.max(...classImageCounts.map(c => c.count));
|
||
|
||
if (maxImages > minImages * 3) {
|
||
console.warn('警告:数据集不平衡!');
|
||
classImageCounts.forEach(c => {
|
||
console.log(`${c.className}: ${c.count} 张图片`);
|
||
});
|
||
this.showStatus('dataStatus', 'warning',
|
||
`警告:数据集不平衡!建议每个类别的图片数量相近。`);
|
||
await new Promise(resolve => setTimeout(resolve, 2000)); // 显示警告2秒
|
||
}
|
||
|
||
this.showStatus('dataStatus', 'info', `正在处理 ${totalToProcess} 张图片...`);
|
||
|
||
let totalImages = 0;
|
||
let processedImages = 0;
|
||
const classSampleCounts = new Array(classes.length).fill(0);
|
||
|
||
// 处理每个类别的图片
|
||
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
|
||
const files = imagesByClass[classIndex];
|
||
|
||
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
|
||
try {
|
||
console.log(`处理图片 ${fileIndex + 1}/${files.length} - ${classes[classIndex]}`);
|
||
const img = await this.loadImage(files[fileIndex]);
|
||
|
||
// 确保图片加载成功
|
||
if (!img || !img.width || !img.height) {
|
||
console.error('图片加载失败或无效:', files[fileIndex].name);
|
||
processedImages++;
|
||
continue;
|
||
}
|
||
|
||
// 为每张图片提取特征(原始版本)
|
||
const features = await this.extractFeatures(img, false);
|
||
this.trainingDataInputs.push(features);
|
||
this.trainingDataOutputs.push(classIndex);
|
||
|
||
// 如果启用数据增强,生成额外的增强样本
|
||
if (document.getElementById('dataAugmentationCheck') &&
|
||
document.getElementById('dataAugmentationCheck').checked) {
|
||
try {
|
||
// 生成1个增强版本
|
||
const augFeatures = await this.extractFeatures(img, true);
|
||
this.trainingDataInputs.push(augFeatures);
|
||
this.trainingDataOutputs.push(classIndex);
|
||
totalImages++; // 增强样本也计入总数
|
||
classSampleCounts[classIndex]++;
|
||
} catch (augError) {
|
||
console.warn('数据增强失败,跳过增强样本:', augError);
|
||
}
|
||
}
|
||
|
||
totalImages++;
|
||
processedImages++;
|
||
classSampleCounts[classIndex]++;
|
||
|
||
// 更新处理进度
|
||
const progressPercent = Math.round((processedImages / totalToProcess) * 100);
|
||
this.showStatus('dataStatus', 'info',
|
||
`处理中... ${processedImages}/${totalToProcess} (${progressPercent}%) - ${classes[classIndex]}: ${classSampleCounts[classIndex]} 张`);
|
||
|
||
// 只有当img是HTMLImageElement时才调用remove
|
||
if (img && img.remove) {
|
||
img.remove();
|
||
}
|
||
} catch (error) {
|
||
console.error(`处理图片失败 (${files[fileIndex].name}):`, error);
|
||
processedImages++;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 显示最终统计
|
||
let statsMessage = `成功添加 ${totalImages} 张图片到数据集!\n`;
|
||
classes.forEach((className, idx) => {
|
||
statsMessage += `${className}: ${classSampleCounts[idx]} 张\n`;
|
||
});
|
||
|
||
this.showStatus('dataStatus', 'success', statsMessage);
|
||
|
||
// 打印详细统计到控制台
|
||
console.log('数据集统计:');
|
||
console.log('总样本数:', totalImages);
|
||
classes.forEach((className, idx) => {
|
||
console.log(`${className}: ${classSampleCounts[idx]} 张 (${(classSampleCounts[idx]/totalImages*100).toFixed(1)}%)`);
|
||
});
|
||
}
|
||
|
||
async loadImage(file) {
|
||
return new Promise((resolve, reject) => {
|
||
const reader = new FileReader();
|
||
reader.onload = (e) => {
|
||
const img = new Image();
|
||
img.onload = () => resolve(img);
|
||
img.onerror = reject;
|
||
img.src = e.target.result;
|
||
};
|
||
reader.onerror = reject;
|
||
reader.readAsDataURL(file);
|
||
});
|
||
}
|
||
|
||
async extractFeatures(img, augment = false) {
|
||
try {
|
||
// 如果启用数据增强(仅在训练时)
|
||
if (augment && document.getElementById('dataAugmentationCheck') &&
|
||
document.getElementById('dataAugmentationCheck').checked) {
|
||
// 随机增强数据
|
||
const augmented = await this.augmentImage(img);
|
||
const activation = this.mobilenet.infer(augmented, 'conv_pw_13_relu');
|
||
|
||
// 如果augmented是canvas元素,不需要dispose
|
||
if (augmented !== img && augmented.tagName === 'CANVAS') {
|
||
// Canvas元素不需要dispose
|
||
}
|
||
|
||
const shape = activation.shape;
|
||
if (shape.length === 4 && shape[0] === 1) {
|
||
return activation.squeeze([0]);
|
||
}
|
||
return activation;
|
||
}
|
||
|
||
// 正常提取特征
|
||
const activation = this.mobilenet.infer(img, 'conv_pw_13_relu');
|
||
|
||
// 检查张量形状
|
||
const shape = activation.shape;
|
||
|
||
// 如果是4D张量 [batch, height, width, channels],去掉批次维度
|
||
if (shape.length === 4 && shape[0] === 1) {
|
||
return activation.squeeze([0]); // 返回3D张量
|
||
}
|
||
|
||
// 如果已经是3D张量,直接返回
|
||
return activation;
|
||
} catch (error) {
|
||
console.error('特征提取失败:', error);
|
||
throw error;
|
||
}
|
||
}
|
||
|
||
async augmentImage(img) {
|
||
try {
|
||
// 简单的数据增强:随机噪声
|
||
const imgTensor = tf.browser.fromPixels(img);
|
||
|
||
// 添加少量随机噪声
|
||
const noise = tf.randomNormal(imgTensor.shape, 0, 5); // 减少噪声强度
|
||
const augmented = imgTensor.add(noise).clipByValue(0, 255);
|
||
|
||
// 转换回图像格式
|
||
const canvas = document.createElement('canvas');
|
||
canvas.width = img.width || img.videoWidth || 224;
|
||
canvas.height = img.height || img.videoHeight || 224;
|
||
|
||
// 归一化后转换为图像
|
||
const normalized = augmented.div(255);
|
||
await tf.browser.toPixels(normalized, canvas);
|
||
|
||
// 清理张量
|
||
imgTensor.dispose();
|
||
noise.dispose();
|
||
augmented.dispose();
|
||
normalized.dispose();
|
||
|
||
return canvas;
|
||
} catch (error) {
|
||
console.error('数据增强失败:', error);
|
||
// 如果增强失败,返回原图
|
||
return img;
|
||
}
|
||
}
|
||
|
||
clearDataset() {
|
||
this.trainingDataInputs.forEach(tensor => tensor.dispose());
|
||
this.trainingDataInputs = [];
|
||
this.trainingDataOutputs = [];
|
||
this.classNames = [];
|
||
|
||
// 清空标准化参数
|
||
if (this.dataMean) {
|
||
this.dataMean.dispose();
|
||
this.dataMean = null;
|
||
}
|
||
if (this.dataStd) {
|
||
this.dataStd.dispose();
|
||
this.dataStd = null;
|
||
}
|
||
|
||
// 清空文件输入和预览
|
||
for (let i = 1; i <= 3; i++) {
|
||
document.getElementById(`class${i}Images`).value = '';
|
||
document.getElementById(`class${i}Preview`).innerHTML = '';
|
||
document.getElementById(`class${i}Count`).textContent = '0 张图片';
|
||
}
|
||
|
||
this.showStatus('dataStatus', 'info', '数据集已清空');
|
||
}
|
||
|
||
createCustomModel() {
|
||
const numClasses = this.classNames.length;
|
||
|
||
// 获取超参数
|
||
const dropoutRate = parseFloat(document.getElementById('dropoutRate').value);
|
||
const l2Reg = Math.pow(10, parseFloat(document.getElementById('l2Regularization').value));
|
||
const learningRate = Math.pow(10, parseFloat(document.getElementById('learningRate').value));
|
||
|
||
// 获取输入形状 - 现在应该是3D张量
|
||
const inputShape = this.trainingDataInputs[0].shape;
|
||
console.log('Input shape:', inputShape);
|
||
console.log('Hyperparameters:', { learningRate, dropoutRate, l2Reg });
|
||
|
||
// 创建更深的模型以提高区分能力
|
||
this.model = tf.sequential({
|
||
layers: [
|
||
tf.layers.flatten({
|
||
inputShape: inputShape
|
||
}),
|
||
// 第一层:更多神经元
|
||
tf.layers.dense({
|
||
units: 128, // 减少神经元数量,降低模型复杂度
|
||
activation: 'relu',
|
||
kernelInitializer: 'heNormal',
|
||
kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
|
||
biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
|
||
}),
|
||
tf.layers.batchNormalization(),
|
||
tf.layers.dropout({ rate: dropoutRate, seed: 42 }), // 添加seed使dropout更一致
|
||
|
||
// 第二层
|
||
tf.layers.dense({
|
||
units: 64, // 减少神经元数量
|
||
activation: 'relu',
|
||
kernelInitializer: 'heNormal',
|
||
kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
|
||
biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
|
||
}),
|
||
tf.layers.batchNormalization(),
|
||
tf.layers.dropout({ rate: dropoutRate, seed: 43 }),
|
||
|
||
// 第三层
|
||
tf.layers.dense({
|
||
units: 32, // 进一步减少神经元
|
||
activation: 'relu',
|
||
kernelInitializer: 'heNormal',
|
||
kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }),
|
||
biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 })
|
||
}),
|
||
tf.layers.dropout({ rate: dropoutRate * 0.7, seed: 44 }),
|
||
|
||
// 输出层
|
||
tf.layers.dense({
|
||
units: numClasses,
|
||
activation: 'softmax',
|
||
kernelInitializer: 'glorotNormal'
|
||
})
|
||
]
|
||
});
|
||
|
||
// 使用更合适的优化器和学习率
|
||
this.model.compile({
|
||
optimizer: tf.train.adam(learningRate),
|
||
loss: 'categoricalCrossentropy',
|
||
metrics: ['accuracy']
|
||
});
|
||
|
||
console.log('Model created successfully');
|
||
this.model.summary();
|
||
}
|
||
|
||
async trainModel() {
|
||
if (this.trainingDataInputs.length === 0) {
|
||
this.showStatus('trainingStatus', 'error', '请先添加训练数据!');
|
||
return;
|
||
}
|
||
|
||
// 检查数据平衡性
|
||
const classCounts = {};
|
||
this.trainingDataOutputs.forEach(label => {
|
||
classCounts[label] = (classCounts[label] || 0) + 1;
|
||
});
|
||
|
||
console.log('每个类别的样本数量:', classCounts);
|
||
console.log('类别名称:', this.classNames);
|
||
|
||
// 计算类别权重以平衡数据
|
||
const totalSamples = this.trainingDataOutputs.length;
|
||
const numClasses = this.classNames.length;
|
||
const classWeights = {};
|
||
|
||
for (let i = 0; i < numClasses; i++) {
|
||
const classCount = classCounts[i] || 1;
|
||
// 使用反比例权重:样本少的类别权重更高
|
||
classWeights[i] = totalSamples / (numClasses * classCount);
|
||
}
|
||
|
||
console.log('类别权重:', classWeights);
|
||
|
||
this.showStatus('trainingStatus', 'info', '准备训练模型...');
|
||
|
||
// 创建模型
|
||
this.createCustomModel();
|
||
|
||
// 准备训练数据
|
||
let xs = tf.stack(this.trainingDataInputs);
|
||
const ys = tf.oneHot(tf.tensor1d(this.trainingDataOutputs, 'int32'), this.classNames.length);
|
||
|
||
// 手动分割训练集和验证集(而不是使用validationSplit)
|
||
const totalDataCount = xs.shape[0]; // 改名避免冲突
|
||
const indices = tf.util.createShuffledIndices(totalDataCount);
|
||
const splitIdx = Math.floor(totalDataCount * 0.8);
|
||
|
||
// 确保indices是TypedArray
|
||
const trainIndices = Array.from(indices).slice(0, splitIdx);
|
||
const valIndices = Array.from(indices).slice(splitIdx);
|
||
|
||
// 将indices转换为Tensor
|
||
const trainIndicesTensor = tf.tensor1d(trainIndices, 'int32');
|
||
const valIndicesTensor = tf.tensor1d(valIndices, 'int32');
|
||
|
||
// 使用gather来获取对应的数据
|
||
const xsTrain = tf.gather(xs, trainIndicesTensor);
|
||
const ysTrain = tf.gather(ys, trainIndicesTensor);
|
||
const xsVal = tf.gather(xs, valIndicesTensor);
|
||
const ysVal = tf.gather(ys, valIndicesTensor);
|
||
|
||
// 清理索引张量
|
||
trainIndicesTensor.dispose();
|
||
valIndicesTensor.dispose();
|
||
|
||
// 只在训练集上计算标准化参数
|
||
this.dataMean = xsTrain.mean([0], true);
|
||
this.dataStd = xsTrain.sub(this.dataMean).square().mean([0], true).sqrt();
|
||
|
||
// 应用标准化到训练集和验证集
|
||
const xsTrainNorm = xsTrain.sub(this.dataMean).div(this.dataStd.add(1e-7));
|
||
const xsValNorm = xsVal.sub(this.dataMean).div(this.dataStd.add(1e-7));
|
||
|
||
// 清理临时张量
|
||
xsTrain.dispose();
|
||
xsVal.dispose();
|
||
xs.dispose();
|
||
ys.dispose();
|
||
|
||
// 详细的数据集分析
|
||
console.log('\n📊 === 数据集分析 ===');
|
||
console.log(`训练集: ${trainIndices.length} 样本`);
|
||
console.log(`验证集: ${valIndices.length} 样本`);
|
||
console.log(`训练/验证比例: ${(trainIndices.length / totalDataCount * 100).toFixed(1)}% / ${(valIndices.length / totalDataCount * 100).toFixed(1)}%`);
|
||
|
||
// 分析每个类别在训练集和验证集中的分布
|
||
const trainLabels = trainIndices.map(i => this.trainingDataOutputs[i]);
|
||
const valLabels = valIndices.map(i => this.trainingDataOutputs[i]);
|
||
|
||
console.log('\n类别分布:');
|
||
for (let i = 0; i < this.classNames.length; i++) {
|
||
const trainCount = trainLabels.filter(l => l === i).length;
|
||
const valCount = valLabels.filter(l => l === i).length;
|
||
console.log(` ${this.classNames[i]}:`);
|
||
console.log(` - 训练集: ${trainCount} (${(trainCount/trainIndices.length*100).toFixed(1)}%)`);
|
||
console.log(` - 验证集: ${valCount} (${(valCount/valIndices.length*100).toFixed(1)}%)`);
|
||
}
|
||
|
||
// 警告检查
|
||
if (valIndices.length < 5) {
|
||
console.warn('⚠️ 验证集太小!建议至少5个样本');
|
||
}
|
||
|
||
if (valIndices.length < totalDataCount * 0.1) {
|
||
console.warn('⚠️ 验证集比例太小!建议至少10%的数据用于验证');
|
||
}
|
||
|
||
console.log('===================\n');
|
||
|
||
console.log('数据形状:', xs.shape);
|
||
console.log('均值形状:', this.dataMean.shape);
|
||
console.log('标准差形状:', this.dataStd.shape);
|
||
|
||
// 打印数据集统计信息
|
||
const uniqueLabels = [...new Set(this.trainingDataOutputs)];
|
||
console.log('训练数据统计:');
|
||
uniqueLabels.forEach(label => {
|
||
const count = this.trainingDataOutputs.filter(l => l === label).length;
|
||
console.log(` ${this.classNames[label]}: ${count} 张图片`);
|
||
});
|
||
|
||
// 显示进度条
|
||
document.getElementById('trainingProgress').classList.remove('hidden');
|
||
document.getElementById('trainBtn').disabled = true;
|
||
document.getElementById('stopBtn').disabled = false;
|
||
|
||
this.isTraining = true;
|
||
this.lossHistory = [];
|
||
this.accuracyHistory = [];
|
||
this.valLossHistory = [];
|
||
this.valAccuracyHistory = [];
|
||
this.bestValLoss = Infinity;
|
||
this.patienceCounter = 0;
|
||
this.bestWeights = null;
|
||
|
||
// 从界面获取超参数
|
||
const epochs = parseInt(document.getElementById('epochs').value);
|
||
const batchSize = parseInt(document.getElementById('batchSize').value);
|
||
const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked;
|
||
const patience = parseInt(document.getElementById('patience').value);
|
||
this.temperature = parseFloat(document.getElementById('temperature').value);
|
||
|
||
console.log('Training parameters:', { epochs, batchSize, enableEarlyStopping, patience, temperature: this.temperature });
|
||
|
||
try {
|
||
const startTime = Date.now();
|
||
|
||
await this.model.fit(xsTrainNorm, ysTrain, {
|
||
epochs,
|
||
batchSize,
|
||
validationData: [xsValNorm, ysVal],
|
||
shuffle: true, // 确保数据被打乱
|
||
classWeight: classWeights, // 应用类别权重
|
||
callbacks: {
|
||
onEpochEnd: (epoch, logs) => {
|
||
const progress = ((epoch + 1) / epochs) * 100;
|
||
document.getElementById('progressFill').style.width = `${progress}%`;
|
||
document.getElementById('progressFill').textContent = `${Math.round(progress)}%`;
|
||
|
||
this.lossHistory.push(logs.loss);
|
||
this.accuracyHistory.push(logs.acc);
|
||
this.valLossHistory.push(logs.val_loss);
|
||
this.valAccuracyHistory.push(logs.val_acc);
|
||
|
||
// 早停逻辑
|
||
const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked;
|
||
const patience = parseInt(document.getElementById('patience').value);
|
||
|
||
if (enableEarlyStopping) {
|
||
const improvement = this.bestValLoss - logs.val_loss;
|
||
const improvementPercent = (improvement / this.bestValLoss) * 100;
|
||
|
||
// 详细的早停分析
|
||
console.log(`\n=== Epoch ${epoch + 1} 早停分析 ===`);
|
||
console.log(`当前验证损失: ${logs.val_loss.toFixed(4)}`);
|
||
console.log(`最佳验证损失: ${this.bestValLoss.toFixed(4)}`);
|
||
console.log(`改善: ${improvement.toFixed(4)} (${improvementPercent.toFixed(2)}%)`);
|
||
console.log(`耐心计数器: ${this.patienceCounter}/${patience}`);
|
||
|
||
// 只有显著改善才重置计数器(改善超过0.1%)
|
||
if (improvement > 0.001 && improvementPercent > 0.1) {
|
||
this.bestValLoss = logs.val_loss;
|
||
this.patienceCounter = 0;
|
||
// 保存最佳权重
|
||
if (this.bestWeights) {
|
||
this.bestWeights.forEach(w => w.dispose());
|
||
}
|
||
this.bestWeights = this.model.getWeights().map(w => w.clone());
|
||
console.log(`✅ 新的最佳验证损失! 重置耐心计数器`);
|
||
} else {
|
||
this.patienceCounter++;
|
||
console.log(`⚠️ 没有显著改善,耐心计数器增加`);
|
||
|
||
// 分析为什么没有改善
|
||
if (logs.val_loss > logs.loss * 1.5) {
|
||
console.log(`📊 可能过拟合:验证损失 (${logs.val_loss.toFixed(4)}) >> 训练损失 (${logs.loss.toFixed(4)})`);
|
||
}
|
||
|
||
if (logs.val_acc < logs.acc - 0.1) {
|
||
console.log(`📊 验证准确率低:验证 ${(logs.val_acc * 100).toFixed(1)}% vs 训练 ${(logs.acc * 100).toFixed(1)}%`);
|
||
}
|
||
|
||
if (this.patienceCounter >= patience) {
|
||
console.log(`\n🛑 早停触发!原因分析:`);
|
||
console.log(`- 连续 ${patience} 个 epochs 没有改善`);
|
||
console.log(`- 最终训练损失: ${logs.loss.toFixed(4)}`);
|
||
console.log(`- 最终验证损失: ${logs.val_loss.toFixed(4)}`);
|
||
console.log(`- 最终训练准确率: ${(logs.acc * 100).toFixed(1)}%`);
|
||
console.log(`- 最终验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`);
|
||
|
||
// 提供建议
|
||
if (logs.val_loss > logs.loss * 1.5) {
|
||
console.log(`\n💡 建议:模型过拟合`);
|
||
console.log(` - 增加 Dropout 率`);
|
||
console.log(` - 增加 L2 正则化`);
|
||
console.log(` - 减少模型复杂度`);
|
||
console.log(` - 增加更多训练数据`);
|
||
} else if (logs.val_acc < 0.5) {
|
||
console.log(`\n💡 建议:模型欠拟合`);
|
||
console.log(` - 增加训练轮数`);
|
||
console.log(` - 增大学习率`);
|
||
console.log(` - 增加模型复杂度`);
|
||
}
|
||
|
||
this.isTraining = false;
|
||
this.model.stopTraining = true;
|
||
}
|
||
}
|
||
console.log(`======================\n`);
|
||
}
|
||
|
||
// 计算剩余时间
|
||
const elapsedTime = Date.now() - startTime;
|
||
const timePerEpoch = elapsedTime / (epoch + 1);
|
||
const remainingTime = timePerEpoch * (epochs - epoch - 1);
|
||
const remainingSeconds = Math.round(remainingTime / 1000);
|
||
|
||
// 检测可能的过拟合
|
||
const overfitGap = logs.acc - logs.val_acc;
|
||
if (overfitGap > 0.2) {
|
||
console.warn(`⚠️ 可能过拟合!训练准确率: ${(logs.acc * 100).toFixed(1)}%, 验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`);
|
||
}
|
||
|
||
// 如果训练准确率达到100%,发出警告
|
||
if (logs.acc >= 0.99) {
|
||
console.warn('⚠️ 训练准确率达到100%,模型可能过拟合!考虑:');
|
||
console.warn(' 1. 增加更多训练数据');
|
||
console.warn(' 2. 增大Dropout率');
|
||
console.warn(' 3. 增大L2正则化');
|
||
console.warn(' 4. 减少模型复杂度');
|
||
}
|
||
|
||
this.updateMetrics({
|
||
epoch: epoch + 1,
|
||
totalEpochs: epochs,
|
||
loss: logs.loss,
|
||
accuracy: logs.acc,
|
||
valLoss: logs.val_loss,
|
||
valAccuracy: logs.val_acc,
|
||
remainingTime: remainingSeconds
|
||
});
|
||
|
||
this.plotLossChart();
|
||
|
||
// 更新训练状态
|
||
const statusMsg = `训练中 - Epoch ${epoch + 1}/${epochs} | ` +
|
||
`准确率: ${(logs.acc * 100).toFixed(1)}% | ` +
|
||
`剩余时间: ${this.formatTime(remainingSeconds)}`;
|
||
this.showStatus('trainingStatus', 'info', statusMsg);
|
||
|
||
if (!this.isTraining) {
|
||
this.model.stopTraining = true;
|
||
}
|
||
},
|
||
onBatchEnd: (batch, logs) => {
|
||
// 可选:显示批次进度
|
||
if (batch % 5 === 0) {
|
||
console.log(`Batch ${batch}: loss = ${logs.loss.toFixed(4)}`);
|
||
}
|
||
}
|
||
}
|
||
});
|
||
|
||
// 如果启用了早停并且保存了最佳权重,恢复它们
|
||
if (this.bestWeights && document.getElementById('earlyStoppingCheck').checked) {
|
||
console.log('恢复最佳模型权重...');
|
||
this.model.setWeights(this.bestWeights);
|
||
}
|
||
|
||
this.showStatus('trainingStatus', 'success', '模型训练完成!');
|
||
} catch (error) {
|
||
this.showStatus('trainingStatus', 'error', `训练失败: ${error.message}`);
|
||
} finally {
|
||
xsTrainNorm.dispose();
|
||
ysTrain.dispose();
|
||
xsValNorm.dispose();
|
||
ysVal.dispose();
|
||
|
||
document.getElementById('trainBtn').disabled = false;
|
||
document.getElementById('stopBtn').disabled = true;
|
||
this.isTraining = false;
|
||
}
|
||
}
|
||
|
||
stopTraining() {
|
||
this.isTraining = false;
|
||
this.showStatus('trainingStatus', 'info', '正在停止训练...');
|
||
}
|
||
|
||
formatTime(seconds) {
|
||
if (seconds < 60) {
|
||
return `${seconds} 秒`;
|
||
} else if (seconds < 3600) {
|
||
const minutes = Math.floor(seconds / 60);
|
||
const secs = seconds % 60;
|
||
return `${minutes} 分 ${secs} 秒`;
|
||
} else {
|
||
const hours = Math.floor(seconds / 3600);
|
||
const minutes = Math.floor((seconds % 3600) / 60);
|
||
return `${hours} 小时 ${minutes} 分`;
|
||
}
|
||
}
|
||
|
||
updateMetrics(metrics) {
|
||
const container = document.getElementById('metricsContainer');
|
||
container.innerHTML = `
|
||
<div class="metric-card">
|
||
<div class="metric-label">Epoch</div>
|
||
<div class="metric-value">${metrics.epoch}/${metrics.totalEpochs}</div>
|
||
</div>
|
||
<div class="metric-card">
|
||
<div class="metric-label">损失</div>
|
||
<div class="metric-value">${metrics.loss.toFixed(4)}</div>
|
||
</div>
|
||
<div class="metric-card">
|
||
<div class="metric-label">准确率</div>
|
||
<div class="metric-value">${(metrics.accuracy * 100).toFixed(1)}%</div>
|
||
</div>
|
||
<div class="metric-card">
|
||
<div class="metric-label">验证损失</div>
|
||
<div class="metric-value">${metrics.valLoss.toFixed(4)}</div>
|
||
</div>
|
||
<div class="metric-card">
|
||
<div class="metric-label">验证准确率</div>
|
||
<div class="metric-value">${(metrics.valAccuracy * 100).toFixed(1)}%</div>
|
||
</div>
|
||
${metrics.remainingTime !== undefined ? `
|
||
<div class="metric-card">
|
||
<div class="metric-label">剩余时间</div>
|
||
<div class="metric-value">${this.formatTime(metrics.remainingTime)}</div>
|
||
</div>
|
||
` : ''}
|
||
`;
|
||
}
|
||
|
||
plotLossChart() {
|
||
const canvas = document.getElementById('lossChart');
|
||
const ctx = canvas.getContext('2d');
|
||
|
||
// 设置canvas尺寸
|
||
canvas.width = canvas.offsetWidth;
|
||
canvas.height = 300;
|
||
|
||
// 清除画布
|
||
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||
|
||
if (this.lossHistory.length === 0) return;
|
||
|
||
const padding = 40;
|
||
const chartWidth = canvas.width - 2 * padding;
|
||
const chartHeight = canvas.height - 2 * padding;
|
||
|
||
// 找到最大值
|
||
const maxLoss = Math.max(...this.lossHistory);
|
||
|
||
// 绘制背景网格
|
||
ctx.strokeStyle = '#e0e0e0';
|
||
ctx.lineWidth = 1;
|
||
for (let i = 0; i <= 5; i++) {
|
||
const y = padding + (i * chartHeight / 5);
|
||
ctx.beginPath();
|
||
ctx.moveTo(padding, y);
|
||
ctx.lineTo(canvas.width - padding, y);
|
||
ctx.stroke();
|
||
}
|
||
|
||
// 绘制坐标轴
|
||
ctx.strokeStyle = '#333';
|
||
ctx.lineWidth = 2;
|
||
ctx.beginPath();
|
||
ctx.moveTo(padding, padding);
|
||
ctx.lineTo(padding, canvas.height - padding);
|
||
ctx.lineTo(canvas.width - padding, canvas.height - padding);
|
||
ctx.stroke();
|
||
|
||
// 绘制损失曲线
|
||
ctx.strokeStyle = '#f56565';
|
||
ctx.lineWidth = 2;
|
||
ctx.beginPath();
|
||
for (let i = 0; i < this.lossHistory.length; i++) {
|
||
const x = padding + (i / (this.lossHistory.length - 1)) * chartWidth;
|
||
const y = canvas.height - padding - (this.lossHistory[i] / maxLoss) * chartHeight * 0.9;
|
||
|
||
if (i === 0) {
|
||
ctx.moveTo(x, y);
|
||
} else {
|
||
ctx.lineTo(x, y);
|
||
}
|
||
}
|
||
ctx.stroke();
|
||
|
||
// 绘制准确率曲线
|
||
ctx.strokeStyle = '#48bb78';
|
||
ctx.lineWidth = 2;
|
||
ctx.beginPath();
|
||
for (let i = 0; i < this.accuracyHistory.length; i++) {
|
||
const x = padding + (i / (this.accuracyHistory.length - 1)) * chartWidth;
|
||
const y = canvas.height - padding - this.accuracyHistory[i] * chartHeight * 0.9;
|
||
|
||
if (i === 0) {
|
||
ctx.moveTo(x, y);
|
||
} else {
|
||
ctx.lineTo(x, y);
|
||
}
|
||
}
|
||
ctx.stroke();
|
||
|
||
// 添加图例
|
||
ctx.fillStyle = '#f56565';
|
||
ctx.fillRect(canvas.width - 120, 20, 15, 15);
|
||
ctx.fillStyle = '#333';
|
||
ctx.font = '12px Arial';
|
||
ctx.fillText('损失', canvas.width - 100, 32);
|
||
|
||
ctx.fillStyle = '#48bb78';
|
||
ctx.fillRect(canvas.width - 120, 40, 15, 15);
|
||
ctx.fillStyle = '#333';
|
||
ctx.fillText('准确率', canvas.width - 100, 52);
|
||
}
|
||
|
||
async startWebcam() {
|
||
if (!this.model) {
|
||
this.showStatus('predictionStatus', 'error', '请先训练模型!');
|
||
return;
|
||
}
|
||
|
||
const video = document.getElementById('webcam');
|
||
|
||
try {
|
||
const stream = await navigator.mediaDevices.getUserMedia({
|
||
video: { facingMode: 'user' },
|
||
audio: false
|
||
});
|
||
|
||
video.srcObject = stream;
|
||
this.webcamStream = stream;
|
||
|
||
document.getElementById('startWebcamBtn').disabled = true;
|
||
document.getElementById('stopWebcamBtn').disabled = false;
|
||
|
||
// 等待视频加载
|
||
video.addEventListener('loadeddata', () => {
|
||
this.isPredicting = true;
|
||
this.predictLoop();
|
||
});
|
||
|
||
this.showStatus('predictionStatus', 'success', '摄像头已启动');
|
||
} catch (error) {
|
||
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||
}
|
||
}
|
||
|
||
stopWebcam() {
|
||
if (this.webcamStream) {
|
||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||
this.webcamStream = null;
|
||
}
|
||
|
||
this.isPredicting = false;
|
||
|
||
const video = document.getElementById('webcam');
|
||
video.srcObject = null;
|
||
|
||
document.getElementById('startWebcamBtn').disabled = false;
|
||
document.getElementById('stopWebcamBtn').disabled = true;
|
||
|
||
document.getElementById('confidenceBars').innerHTML = '';
|
||
|
||
this.showStatus('predictionStatus', 'info', '摄像头已停止');
|
||
}
|
||
|
||
async predictLoop() {
|
||
if (!this.isPredicting) return;
|
||
|
||
const video = document.getElementById('webcam');
|
||
|
||
if (video.readyState === 4) {
|
||
try {
|
||
// 提取特征 - 注意这里返回的是3D张量
|
||
const features = await this.extractFeatures(video);
|
||
|
||
// 需要添加批次维度进行预测 [batch, height, width, channels]
|
||
let featuresWithBatch = features.expandDims(0);
|
||
|
||
// 如果有保存的均值和标准差,进行标准化
|
||
if (this.dataMean && this.dataStd) {
|
||
featuresWithBatch = featuresWithBatch.sub(this.dataMean).div(this.dataStd.add(1e-7));
|
||
}
|
||
|
||
// 进行预测
|
||
let logits = this.model.predict(featuresWithBatch);
|
||
|
||
// 应用温度缩放
|
||
const temperature = parseFloat(document.getElementById('temperature').value) || 1.0;
|
||
if (temperature !== 1.0) {
|
||
logits = logits.div(temperature);
|
||
}
|
||
|
||
// 转换为概率
|
||
const prediction = await tf.softmax(logits).data();
|
||
|
||
// 更新置信度显示
|
||
this.updateConfidenceBars(prediction, temperature);
|
||
|
||
// 清理张量
|
||
features.dispose();
|
||
featuresWithBatch.dispose();
|
||
logits.dispose();
|
||
} catch (error) {
|
||
console.error('预测错误:', error);
|
||
}
|
||
}
|
||
|
||
// 继续预测循环
|
||
requestAnimationFrame(() => this.predictLoop());
|
||
}
|
||
|
||
updateConfidenceBars(predictions, temperature = 1.0) {
|
||
const container = document.getElementById('confidenceBars');
|
||
|
||
let html = '';
|
||
const classPredictions = [];
|
||
|
||
// 计算最大置信度和熵
|
||
const maxConfidence = Math.max(...predictions);
|
||
const maxConfidenceIndex = predictions.indexOf(maxConfidence);
|
||
const entropy = -predictions.reduce((sum, p) => sum + (p > 0 ? p * Math.log(p) : 0), 0);
|
||
const maxEntropy = Math.log(predictions.length); // 最大熵(均匀分布)
|
||
const uncertaintyScore = entropy / maxEntropy; // 0-1之间,越高越不确定
|
||
|
||
// 保持原始顺序,不排序
|
||
for (let i = 0; i < predictions.length; i++) {
|
||
classPredictions.push({
|
||
className: this.classNames[i],
|
||
confidence: predictions[i],
|
||
index: i,
|
||
isMax: i === maxConfidenceIndex
|
||
});
|
||
}
|
||
|
||
// 如果不确定性很高(熵接近最大值),显示警告
|
||
let warningMessage = '';
|
||
if (uncertaintyScore > 0.8) {
|
||
warningMessage = `
|
||
<div style="background: #fff3cd; color: #856404; padding: 10px; border-radius: 5px; margin-bottom: 10px; border: 1px solid #ffeeba;">
|
||
⚠️ <strong>不确定性高</strong> - 模型对预测结果不确定
|
||
<br><small>可能是未见过的类别或模糊图像</small>
|
||
</div>
|
||
`;
|
||
} else if (maxConfidence < 0.5) {
|
||
warningMessage = `
|
||
<div style="background: #f8d7da; color: #721c24; padding: 10px; border-radius: 5px; margin-bottom: 10px; border: 1px solid #f5c6cb;">
|
||
❌ <strong>置信度低</strong> - 没有明显匹配的类别
|
||
<br><small>这可能不是训练过的类别</small>
|
||
</div>
|
||
`;
|
||
}
|
||
|
||
html = warningMessage;
|
||
|
||
// 显示温度信息
|
||
if (temperature !== 1.0) {
|
||
html += `
|
||
<div style="background: #d1ecf1; color: #0c5460; padding: 8px; border-radius: 5px; margin-bottom: 10px; border: 1px solid #bee5eb; font-size: 12px;">
|
||
🌡️ 温度: ${temperature.toFixed(1)} (越高越保守,越低越自信)
|
||
</div>
|
||
`;
|
||
}
|
||
|
||
// 按原始顺序显示,但标记最高置信度
|
||
classPredictions.forEach((pred) => {
|
||
const percentage = (pred.confidence * 100).toFixed(1);
|
||
const barWidth = Math.max(percentage, 5); // 最小宽度5%以显示标签
|
||
|
||
// 根据置信度设置颜色
|
||
let barColor = 'linear-gradient(90deg, #667eea, #764ba2)';
|
||
if (pred.confidence < 0.3) {
|
||
barColor = 'linear-gradient(90deg, #f56565, #e53e3e)';
|
||
} else if (pred.confidence < 0.6) {
|
||
barColor = 'linear-gradient(90deg, #ed8936, #dd6b20)';
|
||
} else if (pred.confidence > 0.8) {
|
||
barColor = 'linear-gradient(90deg, #48bb78, #38a169)';
|
||
}
|
||
|
||
// 为最高置信度的类别添加特殊样式
|
||
const isWinner = pred.isMax && pred.confidence > 0.5;
|
||
const borderStyle = isWinner ? 'border: 2px solid #48bb78; padding: 3px; border-radius: 5px;' : '';
|
||
|
||
html += `
|
||
<div class="confidence-item" style="${borderStyle}">
|
||
<div class="confidence-label">
|
||
<span>${pred.className} ${isWinner ? '👑' : ''}</span>
|
||
<span style="font-weight: ${isWinner ? 'bold' : 'normal'}">${percentage}%</span>
|
||
</div>
|
||
<div class="confidence-bar">
|
||
<div class="confidence-fill" style="width: ${barWidth}%; background: ${barColor}">
|
||
${percentage > 20 ? percentage + '%' : ''}
|
||
</div>
|
||
</div>
|
||
</div>
|
||
`;
|
||
});
|
||
|
||
// 添加不确定性分数
|
||
html += `
|
||
<div style="margin-top: 15px; padding: 10px; background: #f7fafc; border-radius: 5px; font-size: 12px; color: #4a5568;">
|
||
<strong>统计信息:</strong><br>
|
||
最高置信度: ${(maxConfidence * 100).toFixed(1)}%<br>
|
||
不确定性: ${(uncertaintyScore * 100).toFixed(1)}%<br>
|
||
熵: ${entropy.toFixed(3)} / ${maxEntropy.toFixed(3)}
|
||
</div>
|
||
`;
|
||
|
||
container.innerHTML = html;
|
||
}
|
||
|
||
async saveModel() {
|
||
if (!this.model) {
|
||
this.showStatus('predictionStatus', 'error', '没有可保存的模型');
|
||
return;
|
||
}
|
||
|
||
try {
|
||
// 保存模型和类别名称
|
||
await this.model.save('downloads://custom-image-classifier');
|
||
|
||
// 保存类别信息和标准化参数
|
||
const classInfo = {
|
||
classNames: this.classNames,
|
||
dataMean: this.dataMean ? await this.dataMean.data() : null,
|
||
dataStd: this.dataStd ? await this.dataStd.data() : null,
|
||
date: new Date().toISOString()
|
||
};
|
||
|
||
const blob = new Blob([JSON.stringify(classInfo, null, 2)],
|
||
{ type: 'application/json' });
|
||
const url = URL.createObjectURL(blob);
|
||
const a = document.createElement('a');
|
||
a.href = url;
|
||
a.download = 'class-info.json';
|
||
a.click();
|
||
URL.revokeObjectURL(url);
|
||
|
||
this.showStatus('predictionStatus', 'success', '模型已保存');
|
||
} catch (error) {
|
||
this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
|
||
}
|
||
}
|
||
|
||
async loadModel() {
|
||
const modelInput = document.createElement('input');
|
||
modelInput.type = 'file';
|
||
modelInput.multiple = true;
|
||
modelInput.accept = '.json,.bin';
|
||
|
||
const classInput = document.createElement('input');
|
||
classInput.type = 'file';
|
||
classInput.accept = '.json';
|
||
|
||
modelInput.onchange = async (e) => {
|
||
try {
|
||
const files = e.target.files;
|
||
this.model = await tf.loadLayersModel(tf.io.browserFiles(files));
|
||
|
||
this.showStatus('predictionStatus', 'info', '模型已加载,请选择类别信息文件');
|
||
|
||
// 加载类别信息
|
||
classInput.click();
|
||
} catch (error) {
|
||
this.showStatus('predictionStatus', 'error', `加载模型失败: ${error.message}`);
|
||
}
|
||
};
|
||
|
||
classInput.onchange = async (e) => {
|
||
try {
|
||
const file = e.target.files[0];
|
||
const text = await file.text();
|
||
const classInfo = JSON.parse(text);
|
||
|
||
this.classNames = classInfo.classNames;
|
||
|
||
// 恢复标准化参数
|
||
if (classInfo.dataMean && classInfo.dataStd) {
|
||
this.dataMean = tf.tensor(classInfo.dataMean);
|
||
this.dataStd = tf.tensor(classInfo.dataStd);
|
||
console.log('已恢复数据标准化参数');
|
||
}
|
||
|
||
this.showStatus('predictionStatus', 'success',
|
||
`模型加载成功!类别: ${this.classNames.join(', ')}`);
|
||
} catch (error) {
|
||
this.showStatus('predictionStatus', 'error', `加载类别信息失败: ${error.message}`);
|
||
}
|
||
};
|
||
|
||
modelInput.click();
|
||
}
|
||
|
||
resetHyperparameters() {
|
||
document.getElementById('learningRate').value = -3;
|
||
document.getElementById('learningRateValue').textContent = '0.001';
|
||
document.getElementById('epochs').value = 100;
|
||
document.getElementById('epochsValue').textContent = '100';
|
||
document.getElementById('dropoutRate').value = 0.3;
|
||
document.getElementById('dropoutValue').textContent = '0.3';
|
||
document.getElementById('l2Regularization').value = -2;
|
||
document.getElementById('l2Value').textContent = '0.01';
|
||
document.getElementById('batchSize').value = 32;
|
||
document.getElementById('batchSizeValue').textContent = '32';
|
||
document.getElementById('temperature').value = 1.0;
|
||
document.getElementById('temperatureValue').textContent = '1.0';
|
||
document.getElementById('earlyStoppingCheck').checked = true;
|
||
document.getElementById('patience').value = 10;
|
||
document.getElementById('patienceValue').textContent = '10';
|
||
|
||
this.showStatus('trainingStatus', 'info', '超参数已重置为默认值');
|
||
}
|
||
|
||
showRecommendations() {
|
||
const message = `
|
||
💡 超参数设置建议:
|
||
|
||
如果过拟合(训练准确率高但验证准确率低):
|
||
- 增大 Dropout 率 (0.4-0.6)
|
||
- 增大 L2 正则化 (0.01-0.1)
|
||
- 减少训练轮数
|
||
- 启用早停
|
||
- 增加更多训练数据
|
||
|
||
如果欠拟合(训练和验证准确率都低):
|
||
- 减小 Dropout 率 (0.1-0.2)
|
||
- 减小 L2 正则化 (0.0001-0.001)
|
||
- 增加训练轮数
|
||
- 增大学习率 (0.001-0.01)
|
||
|
||
温度缩放:
|
||
- < 1.0: 更自信的预测(可能过度自信)
|
||
- = 1.0: 标准预测
|
||
- > 1.0: 更保守的预测(减少过度自信)
|
||
`;
|
||
alert(message);
|
||
}
|
||
|
||
showStatus(elementId, type, message) {
|
||
const element = document.getElementById(elementId);
|
||
|
||
const classMap = {
|
||
'success': 'status-success',
|
||
'error': 'status-error',
|
||
'info': 'status-info',
|
||
'warning': 'status-warning'
|
||
};
|
||
|
||
element.className = `status-message ${classMap[type] || 'status-info'}`;
|
||
element.textContent = message;
|
||
}
|
||
}
|
||
|
||
// Initialize the application.
// `classifier` is a script-level binding so other scripts or inline handlers can
// reach the instance.
let classifier;

// Create the classifier once the DOM can be queried. If this script runs after
// parsing has finished (for example when it is injected dynamically or loaded
// with `async`), DOMContentLoaded has already fired and a listener alone would
// never run, so initialize immediately in that case.
if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', () => {
        classifier = new CustomImageClassifier();
    });
} else {
    classifier = new CustomImageClassifier();
}