From 8f998b19154aeb80a89d9e55b67e0d238a0fe7aa Mon Sep 17 00:00:00 2001 From: 51hhh Date: Mon, 11 Aug 2025 17:44:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 69 ++ 三层神经网络/custom-classifier.html | 532 ++++++++++++ 三层神经网络/custom-classifier.js | 1174 +++++++++++++++++++++++++++ 原版KNN/knn-classifier.html | 551 +++++++++++++ 原版KNN/knn-classifier.js | 621 ++++++++++++++ 完善KNN/knn-classifier.html | 561 +++++++++++++ 完善KNN/knn-classifier.js | 985 ++++++++++++++++++++++ 随机森林/decision-tree.js | 394 +++++++++ 随机森林/rf-classifier.html | 542 +++++++++++++ 随机森林/rf-classifier.js | 472 +++++++++++ 10 files changed, 5901 insertions(+) create mode 100644 README.md create mode 100644 三层神经网络/custom-classifier.html create mode 100644 三层神经网络/custom-classifier.js create mode 100644 原版KNN/knn-classifier.html create mode 100644 原版KNN/knn-classifier.js create mode 100644 完善KNN/knn-classifier.html create mode 100644 完善KNN/knn-classifier.js create mode 100644 随机森林/decision-tree.js create mode 100644 随机森林/rf-classifier.html create mode 100644 随机森林/rf-classifier.js diff --git a/README.md b/README.md new file mode 100644 index 0000000..806b66a --- /dev/null +++ b/README.md @@ -0,0 +1,69 @@ +# 图像分类器项目 + +本项目包含多种基于MobileNet特征提取的图像分类器实现,包括三层神经网络、随机森林和KNN算法。 + +## 模型对比 + +| 模型 | 算法类型 | 特点 | 适用场景 | +|------|---------|------|---------| +| 三层神经网络 | 深度学习 | 自定义全连接网络,支持训练监控、早停、正则化 | 需要高精度、可解释性强的场景 | +| 随机森林 | 集成学习 | 多决策树投票,参数直观可调 | 中等规模数据集,需要模型解释性 | +| KNN (原版) | 实例学习 | 简单实现,预测结果平滑处理 | 快速原型开发,小规模数据 | +| KNN (完善版) | 实例学习 | 增强阈值控制,支持单类别检测 | 异常检测、单类别分类场景 | + +## 详细说明 + +### 三层神经网络分类器 +- 使用MobileNet进行特征提取 +- 自定义三层全连接网络进行分类 +- 功能特点: + - 训练过程可视化(损失/准确率曲线) + - 支持早停、正则化等技巧 + - 模型保存/加载功能 + - 温度缩放调整预测置信度 + +### 随机森林分类器 +- 使用MobileNet特征作为输入 +- 构建多个决策树进行集成分类 +- 可调参数: + - 决策树数量(默认10棵) + - 训练集子集比例(默认70%) +- 特点: + - 训练速度快 + - 提供ImageNet标签显示功能 + +### KNN分类器(原版) +- 基于MobileNet特征的K最近邻算法 +- 特点: + - 实现简单 + - 低通滤波器平滑预测结果 + - 支持模型保存/加载 + +### KNN分类器(完善版) +在原版基础上增强: +- 距离阈值控制 +- 自适应阈值计算 +- 改进的单类别检测 +- 更详细的训练反馈 + +## 使用指南 + +1. 选择分类器类型 +2. 上传各类别训练图片 +3. 调整模型参数(如适用) +4. 点击"训练模型"按钮 +5. 使用摄像头或上传图片进行预测 + +## 技术实现 + +所有分类器均基于以下技术: +- 特征提取:MobileNet (TensorFlow.js) +- 前端框架:纯JavaScript实现 +- 数据存储:浏览器本地存储(IndexedDB) + +## 开发建议 + +- 对于高精度需求:使用三层神经网络 +- 对于可解释性需求:使用随机森林 +- 对于快速原型开发:使用KNN +- 对于异常检测:使用完善版KNN diff --git a/三层神经网络/custom-classifier.html b/三层神经网络/custom-classifier.html new file mode 100644 index 0000000..aceabc6 --- /dev/null +++ b/三层神经网络/custom-classifier.html @@ -0,0 +1,532 @@ + + + + + + 自定义图像分类器 - TensorFlow.js + + + + + +
+

🤖 自定义图像分类器

+ +
+ +
+

📸 数据采集

+ +
+

1 第一类

+ + + + 0 张图片 +
+
+ +
+

2 第二类

+ + + + 0 张图片 +
+
+ +
+

3 第三类(可选)

+ + + + 0 张图片 +
+
+ +
+ + +
+ +
+
+ + +
+

🎯 模型训练

+ + +
+

⚙️ 超参数设置

+ +
+
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+
+ +
+ + +
+ + +
+
+ +
+ +
+ +
+ + +
+
+ +
+ + +
+ +
+ + + +
+ + +
+
+ + +
+

📹 实时预测

+ +
+ + + + +
+ +
+ +
+ +
+ +
+
+
+ + + + \ No newline at end of file diff --git a/三层神经网络/custom-classifier.js b/三层神经网络/custom-classifier.js new file mode 100644 index 0000000..27678da --- /dev/null +++ b/三层神经网络/custom-classifier.js @@ -0,0 +1,1174 @@ +// 自定义图像分类器应用 +class CustomImageClassifier { + constructor() { + this.mobilenet = null; + this.model = null; + this.classNames = []; + this.trainingDataInputs = []; + this.trainingDataOutputs = []; + this.isTraining = false; + this.webcamStream = null; + this.isPredicting = false; + this.lossHistory = []; + this.accuracyHistory = []; + this.valLossHistory = []; + this.valAccuracyHistory = []; + this.dataMean = null; // 存储训练数据的均值 + this.dataStd = null; // 存储训练数据的标准差 + this.temperature = 1.0; // 温度缩放参数 + this.bestValLoss = Infinity; + this.patienceCounter = 0; + this.bestWeights = null; + + this.init(); + } + + async init() { + this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...'); + + try { + // 加载预训练的 MobileNet 模型 + this.mobilenet = await mobilenet.load({ + version: 2, + alpha: 1.0 + }); + + this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!'); + this.setupEventListeners(); + } catch (error) { + this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`); + } + } + + setupEventListeners() { + // 文件上传监听 + ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => { + document.getElementById(id).addEventListener('change', (e) => { + this.handleImageUpload(e, index + 1); + }); + }); + + // 按钮监听 + document.getElementById('addDataBtn').addEventListener('click', () => this.addToDataset()); + document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset()); + document.getElementById('trainBtn').addEventListener('click', () => this.trainModel()); + document.getElementById('stopBtn').addEventListener('click', () => this.stopTraining()); + document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam()); + document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam()); + document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel()); + document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel()); + } + + handleImageUpload(event, classIndex) { + const files = event.target.files; + const previewContainer = document.getElementById(`class${classIndex}Preview`); + const countElement = document.getElementById(`class${classIndex}Count`); + + previewContainer.innerHTML = ''; + + Array.from(files).forEach(file => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = document.createElement('img'); + img.src = e.target.result; + img.className = 'preview-img'; + previewContainer.appendChild(img); + }; + reader.readAsDataURL(file); + }); + + countElement.textContent = `${files.length} 张图片`; + } + + async addToDataset() { + const classes = []; + const imagesByClass = []; + + // 收集所有类别和图片 + for (let i = 1; i <= 3; i++) { + const className = document.getElementById(`class${i}Name`).value.trim(); + const files = document.getElementById(`class${i}Images`).files; + + if (className && files.length > 0) { + classes.push(className); + imagesByClass.push(files); + } + } + + if (classes.length < 2) { + this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!'); + return; + } + + // 检查是否已有数据,如果有,询问是否追加 + if (this.trainingDataInputs.length > 0) { + // 清空旧数据 + this.clearDataset(); + } + + this.classNames = classes; + + // 计算总图片数量和每个类别的数量 + let totalToProcess = 0; + const classImageCounts = []; + imagesByClass.forEach((files, idx) => { + totalToProcess += files.length; + classImageCounts.push({ + className: classes[idx], + count: files.length + }); + }); + + // 检查数据平衡性 + const minImages = Math.min(...classImageCounts.map(c => c.count)); + const maxImages = Math.max(...classImageCounts.map(c => c.count)); + + if (maxImages > minImages * 3) { + console.warn('警告:数据集不平衡!'); + classImageCounts.forEach(c => { + console.log(`${c.className}: ${c.count} 张图片`); + }); + this.showStatus('dataStatus', 'warning', + `警告:数据集不平衡!建议每个类别的图片数量相近。`); + await new Promise(resolve => setTimeout(resolve, 2000)); // 显示警告2秒 + } + + this.showStatus('dataStatus', 'info', `正在处理 ${totalToProcess} 张图片...`); + + let totalImages = 0; + let processedImages = 0; + const classSampleCounts = new Array(classes.length).fill(0); + + // 处理每个类别的图片 + for (let classIndex = 0; classIndex < classes.length; classIndex++) { + const files = imagesByClass[classIndex]; + + for (let fileIndex = 0; fileIndex < files.length; fileIndex++) { + try { + console.log(`处理图片 ${fileIndex + 1}/${files.length} - ${classes[classIndex]}`); + const img = await this.loadImage(files[fileIndex]); + + // 确保图片加载成功 + if (!img || !img.width || !img.height) { + console.error('图片加载失败或无效:', files[fileIndex].name); + processedImages++; + continue; + } + + // 为每张图片提取特征(原始版本) + const features = await this.extractFeatures(img, false); + this.trainingDataInputs.push(features); + this.trainingDataOutputs.push(classIndex); + + // 如果启用数据增强,生成额外的增强样本 + if (document.getElementById('dataAugmentationCheck') && + document.getElementById('dataAugmentationCheck').checked) { + try { + // 生成1个增强版本 + const augFeatures = await this.extractFeatures(img, true); + this.trainingDataInputs.push(augFeatures); + this.trainingDataOutputs.push(classIndex); + totalImages++; // 增强样本也计入总数 + classSampleCounts[classIndex]++; + } catch (augError) { + console.warn('数据增强失败,跳过增强样本:', augError); + } + } + + totalImages++; + processedImages++; + classSampleCounts[classIndex]++; + + // 更新处理进度 + const progressPercent = Math.round((processedImages / totalToProcess) * 100); + this.showStatus('dataStatus', 'info', + `处理中... ${processedImages}/${totalToProcess} (${progressPercent}%) - ${classes[classIndex]}: ${classSampleCounts[classIndex]} 张`); + + // 只有当img是HTMLImageElement时才调用remove + if (img && img.remove) { + img.remove(); + } + } catch (error) { + console.error(`处理图片失败 (${files[fileIndex].name}):`, error); + processedImages++; + } + } + } + + // 显示最终统计 + let statsMessage = `成功添加 ${totalImages} 张图片到数据集!\n`; + classes.forEach((className, idx) => { + statsMessage += `${className}: ${classSampleCounts[idx]} 张\n`; + }); + + this.showStatus('dataStatus', 'success', statsMessage); + + // 打印详细统计到控制台 + console.log('数据集统计:'); + console.log('总样本数:', totalImages); + classes.forEach((className, idx) => { + console.log(`${className}: ${classSampleCounts[idx]} 张 (${(classSampleCounts[idx]/totalImages*100).toFixed(1)}%)`); + }); + } + + async loadImage(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = new Image(); + img.onload = () => resolve(img); + img.onerror = reject; + img.src = e.target.result; + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + async extractFeatures(img, augment = false) { + try { + // 如果启用数据增强(仅在训练时) + if (augment && document.getElementById('dataAugmentationCheck') && + document.getElementById('dataAugmentationCheck').checked) { + // 随机增强数据 + const augmented = await this.augmentImage(img); + const activation = this.mobilenet.infer(augmented, 'conv_pw_13_relu'); + + // 如果augmented是canvas元素,不需要dispose + if (augmented !== img && augmented.tagName === 'CANVAS') { + // Canvas元素不需要dispose + } + + const shape = activation.shape; + if (shape.length === 4 && shape[0] === 1) { + return activation.squeeze([0]); + } + return activation; + } + + // 正常提取特征 + const activation = this.mobilenet.infer(img, 'conv_pw_13_relu'); + + // 检查张量形状 + const shape = activation.shape; + + // 如果是4D张量 [batch, height, width, channels],去掉批次维度 + if (shape.length === 4 && shape[0] === 1) { + return activation.squeeze([0]); // 返回3D张量 + } + + // 如果已经是3D张量,直接返回 + return activation; + } catch (error) { + console.error('特征提取失败:', error); + throw error; + } + } + + async augmentImage(img) { + try { + // 简单的数据增强:随机噪声 + const imgTensor = tf.browser.fromPixels(img); + + // 添加少量随机噪声 + const noise = tf.randomNormal(imgTensor.shape, 0, 5); // 减少噪声强度 + const augmented = imgTensor.add(noise).clipByValue(0, 255); + + // 转换回图像格式 + const canvas = document.createElement('canvas'); + canvas.width = img.width || img.videoWidth || 224; + canvas.height = img.height || img.videoHeight || 224; + + // 归一化后转换为图像 + const normalized = augmented.div(255); + await tf.browser.toPixels(normalized, canvas); + + // 清理张量 + imgTensor.dispose(); + noise.dispose(); + augmented.dispose(); + normalized.dispose(); + + return canvas; + } catch (error) { + console.error('数据增强失败:', error); + // 如果增强失败,返回原图 + return img; + } + } + + clearDataset() { + this.trainingDataInputs.forEach(tensor => tensor.dispose()); + this.trainingDataInputs = []; + this.trainingDataOutputs = []; + this.classNames = []; + + // 清空标准化参数 + if (this.dataMean) { + this.dataMean.dispose(); + this.dataMean = null; + } + if (this.dataStd) { + this.dataStd.dispose(); + this.dataStd = null; + } + + // 清空文件输入和预览 + for (let i = 1; i <= 3; i++) { + document.getElementById(`class${i}Images`).value = ''; + document.getElementById(`class${i}Preview`).innerHTML = ''; + document.getElementById(`class${i}Count`).textContent = '0 张图片'; + } + + this.showStatus('dataStatus', 'info', '数据集已清空'); + } + + createCustomModel() { + const numClasses = this.classNames.length; + + // 获取超参数 + const dropoutRate = parseFloat(document.getElementById('dropoutRate').value); + const l2Reg = Math.pow(10, parseFloat(document.getElementById('l2Regularization').value)); + const learningRate = Math.pow(10, parseFloat(document.getElementById('learningRate').value)); + + // 获取输入形状 - 现在应该是3D张量 + const inputShape = this.trainingDataInputs[0].shape; + console.log('Input shape:', inputShape); + console.log('Hyperparameters:', { learningRate, dropoutRate, l2Reg }); + + // 创建更深的模型以提高区分能力 + this.model = tf.sequential({ + layers: [ + tf.layers.flatten({ + inputShape: inputShape + }), + // 第一层:更多神经元 + tf.layers.dense({ + units: 128, // 减少神经元数量,降低模型复杂度 + activation: 'relu', + kernelInitializer: 'heNormal', + kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }), + biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 }) + }), + tf.layers.batchNormalization(), + tf.layers.dropout({ rate: dropoutRate, seed: 42 }), // 添加seed使dropout更一致 + + // 第二层 + tf.layers.dense({ + units: 64, // 减少神经元数量 + activation: 'relu', + kernelInitializer: 'heNormal', + kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }), + biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 }) + }), + tf.layers.batchNormalization(), + tf.layers.dropout({ rate: dropoutRate, seed: 43 }), + + // 第三层 + tf.layers.dense({ + units: 32, // 进一步减少神经元 + activation: 'relu', + kernelInitializer: 'heNormal', + kernelRegularizer: tf.regularizers.l2({ l2: l2Reg }), + biasRegularizer: tf.regularizers.l2({ l2: l2Reg * 0.5 }) + }), + tf.layers.dropout({ rate: dropoutRate * 0.7, seed: 44 }), + + // 输出层 + tf.layers.dense({ + units: numClasses, + activation: 'softmax', + kernelInitializer: 'glorotNormal' + }) + ] + }); + + // 使用更合适的优化器和学习率 + this.model.compile({ + optimizer: tf.train.adam(learningRate), + loss: 'categoricalCrossentropy', + metrics: ['accuracy'] + }); + + console.log('Model created successfully'); + this.model.summary(); + } + + async trainModel() { + if (this.trainingDataInputs.length === 0) { + this.showStatus('trainingStatus', 'error', '请先添加训练数据!'); + return; + } + + // 检查数据平衡性 + const classCounts = {}; + this.trainingDataOutputs.forEach(label => { + classCounts[label] = (classCounts[label] || 0) + 1; + }); + + console.log('每个类别的样本数量:', classCounts); + console.log('类别名称:', this.classNames); + + // 计算类别权重以平衡数据 + const totalSamples = this.trainingDataOutputs.length; + const numClasses = this.classNames.length; + const classWeights = {}; + + for (let i = 0; i < numClasses; i++) { + const classCount = classCounts[i] || 1; + // 使用反比例权重:样本少的类别权重更高 + classWeights[i] = totalSamples / (numClasses * classCount); + } + + console.log('类别权重:', classWeights); + + this.showStatus('trainingStatus', 'info', '准备训练模型...'); + + // 创建模型 + this.createCustomModel(); + + // 准备训练数据 + let xs = tf.stack(this.trainingDataInputs); + const ys = tf.oneHot(tf.tensor1d(this.trainingDataOutputs, 'int32'), this.classNames.length); + + // 手动分割训练集和验证集(而不是使用validationSplit) + const totalDataCount = xs.shape[0]; // 改名避免冲突 + const indices = tf.util.createShuffledIndices(totalDataCount); + const splitIdx = Math.floor(totalDataCount * 0.8); + + // 确保indices是TypedArray + const trainIndices = Array.from(indices).slice(0, splitIdx); + const valIndices = Array.from(indices).slice(splitIdx); + + // 将indices转换为Tensor + const trainIndicesTensor = tf.tensor1d(trainIndices, 'int32'); + const valIndicesTensor = tf.tensor1d(valIndices, 'int32'); + + // 使用gather来获取对应的数据 + const xsTrain = tf.gather(xs, trainIndicesTensor); + const ysTrain = tf.gather(ys, trainIndicesTensor); + const xsVal = tf.gather(xs, valIndicesTensor); + const ysVal = tf.gather(ys, valIndicesTensor); + + // 清理索引张量 + trainIndicesTensor.dispose(); + valIndicesTensor.dispose(); + + // 只在训练集上计算标准化参数 + this.dataMean = xsTrain.mean([0], true); + this.dataStd = xsTrain.sub(this.dataMean).square().mean([0], true).sqrt(); + + // 应用标准化到训练集和验证集 + const xsTrainNorm = xsTrain.sub(this.dataMean).div(this.dataStd.add(1e-7)); + const xsValNorm = xsVal.sub(this.dataMean).div(this.dataStd.add(1e-7)); + + // 清理临时张量 + xsTrain.dispose(); + xsVal.dispose(); + xs.dispose(); + ys.dispose(); + + // 详细的数据集分析 + console.log('\n📊 === 数据集分析 ==='); + console.log(`训练集: ${trainIndices.length} 样本`); + console.log(`验证集: ${valIndices.length} 样本`); + console.log(`训练/验证比例: ${(trainIndices.length / totalDataCount * 100).toFixed(1)}% / ${(valIndices.length / totalDataCount * 100).toFixed(1)}%`); + + // 分析每个类别在训练集和验证集中的分布 + const trainLabels = trainIndices.map(i => this.trainingDataOutputs[i]); + const valLabels = valIndices.map(i => this.trainingDataOutputs[i]); + + console.log('\n类别分布:'); + for (let i = 0; i < this.classNames.length; i++) { + const trainCount = trainLabels.filter(l => l === i).length; + const valCount = valLabels.filter(l => l === i).length; + console.log(` ${this.classNames[i]}:`); + console.log(` - 训练集: ${trainCount} (${(trainCount/trainIndices.length*100).toFixed(1)}%)`); + console.log(` - 验证集: ${valCount} (${(valCount/valIndices.length*100).toFixed(1)}%)`); + } + + // 警告检查 + if (valIndices.length < 5) { + console.warn('⚠️ 验证集太小!建议至少5个样本'); + } + + if (valIndices.length < totalDataCount * 0.1) { + console.warn('⚠️ 验证集比例太小!建议至少10%的数据用于验证'); + } + + console.log('===================\n'); + + console.log('数据形状:', xs.shape); + console.log('均值形状:', this.dataMean.shape); + console.log('标准差形状:', this.dataStd.shape); + + // 打印数据集统计信息 + const uniqueLabels = [...new Set(this.trainingDataOutputs)]; + console.log('训练数据统计:'); + uniqueLabels.forEach(label => { + const count = this.trainingDataOutputs.filter(l => l === label).length; + console.log(` ${this.classNames[label]}: ${count} 张图片`); + }); + + // 显示进度条 + document.getElementById('trainingProgress').classList.remove('hidden'); + document.getElementById('trainBtn').disabled = true; + document.getElementById('stopBtn').disabled = false; + + this.isTraining = true; + this.lossHistory = []; + this.accuracyHistory = []; + this.valLossHistory = []; + this.valAccuracyHistory = []; + this.bestValLoss = Infinity; + this.patienceCounter = 0; + this.bestWeights = null; + + // 从界面获取超参数 + const epochs = parseInt(document.getElementById('epochs').value); + const batchSize = parseInt(document.getElementById('batchSize').value); + const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked; + const patience = parseInt(document.getElementById('patience').value); + this.temperature = parseFloat(document.getElementById('temperature').value); + + console.log('Training parameters:', { epochs, batchSize, enableEarlyStopping, patience, temperature: this.temperature }); + + try { + const startTime = Date.now(); + + await this.model.fit(xsTrainNorm, ysTrain, { + epochs, + batchSize, + validationData: [xsValNorm, ysVal], + shuffle: true, // 确保数据被打乱 + classWeight: classWeights, // 应用类别权重 + callbacks: { + onEpochEnd: (epoch, logs) => { + const progress = ((epoch + 1) / epochs) * 100; + document.getElementById('progressFill').style.width = `${progress}%`; + document.getElementById('progressFill').textContent = `${Math.round(progress)}%`; + + this.lossHistory.push(logs.loss); + this.accuracyHistory.push(logs.acc); + this.valLossHistory.push(logs.val_loss); + this.valAccuracyHistory.push(logs.val_acc); + + // 早停逻辑 + const enableEarlyStopping = document.getElementById('earlyStoppingCheck').checked; + const patience = parseInt(document.getElementById('patience').value); + + if (enableEarlyStopping) { + const improvement = this.bestValLoss - logs.val_loss; + const improvementPercent = (improvement / this.bestValLoss) * 100; + + // 详细的早停分析 + console.log(`\n=== Epoch ${epoch + 1} 早停分析 ===`); + console.log(`当前验证损失: ${logs.val_loss.toFixed(4)}`); + console.log(`最佳验证损失: ${this.bestValLoss.toFixed(4)}`); + console.log(`改善: ${improvement.toFixed(4)} (${improvementPercent.toFixed(2)}%)`); + console.log(`耐心计数器: ${this.patienceCounter}/${patience}`); + + // 只有显著改善才重置计数器(改善超过0.1%) + if (improvement > 0.001 && improvementPercent > 0.1) { + this.bestValLoss = logs.val_loss; + this.patienceCounter = 0; + // 保存最佳权重 + if (this.bestWeights) { + this.bestWeights.forEach(w => w.dispose()); + } + this.bestWeights = this.model.getWeights().map(w => w.clone()); + console.log(`✅ 新的最佳验证损失! 重置耐心计数器`); + } else { + this.patienceCounter++; + console.log(`⚠️ 没有显著改善,耐心计数器增加`); + + // 分析为什么没有改善 + if (logs.val_loss > logs.loss * 1.5) { + console.log(`📊 可能过拟合:验证损失 (${logs.val_loss.toFixed(4)}) >> 训练损失 (${logs.loss.toFixed(4)})`); + } + + if (logs.val_acc < logs.acc - 0.1) { + console.log(`📊 验证准确率低:验证 ${(logs.val_acc * 100).toFixed(1)}% vs 训练 ${(logs.acc * 100).toFixed(1)}%`); + } + + if (this.patienceCounter >= patience) { + console.log(`\n🛑 早停触发!原因分析:`); + console.log(`- 连续 ${patience} 个 epochs 没有改善`); + console.log(`- 最终训练损失: ${logs.loss.toFixed(4)}`); + console.log(`- 最终验证损失: ${logs.val_loss.toFixed(4)}`); + console.log(`- 最终训练准确率: ${(logs.acc * 100).toFixed(1)}%`); + console.log(`- 最终验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`); + + // 提供建议 + if (logs.val_loss > logs.loss * 1.5) { + console.log(`\n💡 建议:模型过拟合`); + console.log(` - 增加 Dropout 率`); + console.log(` - 增加 L2 正则化`); + console.log(` - 减少模型复杂度`); + console.log(` - 增加更多训练数据`); + } else if (logs.val_acc < 0.5) { + console.log(`\n💡 建议:模型欠拟合`); + console.log(` - 增加训练轮数`); + console.log(` - 增大学习率`); + console.log(` - 增加模型复杂度`); + } + + this.isTraining = false; + this.model.stopTraining = true; + } + } + console.log(`======================\n`); + } + + // 计算剩余时间 + const elapsedTime = Date.now() - startTime; + const timePerEpoch = elapsedTime / (epoch + 1); + const remainingTime = timePerEpoch * (epochs - epoch - 1); + const remainingSeconds = Math.round(remainingTime / 1000); + + // 检测可能的过拟合 + const overfitGap = logs.acc - logs.val_acc; + if (overfitGap > 0.2) { + console.warn(`⚠️ 可能过拟合!训练准确率: ${(logs.acc * 100).toFixed(1)}%, 验证准确率: ${(logs.val_acc * 100).toFixed(1)}%`); + } + + // 如果训练准确率达到100%,发出警告 + if (logs.acc >= 0.99) { + console.warn('⚠️ 训练准确率达到100%,模型可能过拟合!考虑:'); + console.warn(' 1. 增加更多训练数据'); + console.warn(' 2. 增大Dropout率'); + console.warn(' 3. 增大L2正则化'); + console.warn(' 4. 减少模型复杂度'); + } + + this.updateMetrics({ + epoch: epoch + 1, + totalEpochs: epochs, + loss: logs.loss, + accuracy: logs.acc, + valLoss: logs.val_loss, + valAccuracy: logs.val_acc, + remainingTime: remainingSeconds + }); + + this.plotLossChart(); + + // 更新训练状态 + const statusMsg = `训练中 - Epoch ${epoch + 1}/${epochs} | ` + + `准确率: ${(logs.acc * 100).toFixed(1)}% | ` + + `剩余时间: ${this.formatTime(remainingSeconds)}`; + this.showStatus('trainingStatus', 'info', statusMsg); + + if (!this.isTraining) { + this.model.stopTraining = true; + } + }, + onBatchEnd: (batch, logs) => { + // 可选:显示批次进度 + if (batch % 5 === 0) { + console.log(`Batch ${batch}: loss = ${logs.loss.toFixed(4)}`); + } + } + } + }); + + // 如果启用了早停并且保存了最佳权重,恢复它们 + if (this.bestWeights && document.getElementById('earlyStoppingCheck').checked) { + console.log('恢复最佳模型权重...'); + this.model.setWeights(this.bestWeights); + } + + this.showStatus('trainingStatus', 'success', '模型训练完成!'); + } catch (error) { + this.showStatus('trainingStatus', 'error', `训练失败: ${error.message}`); + } finally { + xsTrainNorm.dispose(); + ysTrain.dispose(); + xsValNorm.dispose(); + ysVal.dispose(); + + document.getElementById('trainBtn').disabled = false; + document.getElementById('stopBtn').disabled = true; + this.isTraining = false; + } + } + + stopTraining() { + this.isTraining = false; + this.showStatus('trainingStatus', 'info', '正在停止训练...'); + } + + formatTime(seconds) { + if (seconds < 60) { + return `${seconds} 秒`; + } else if (seconds < 3600) { + const minutes = Math.floor(seconds / 60); + const secs = seconds % 60; + return `${minutes} 分 ${secs} 秒`; + } else { + const hours = Math.floor(seconds / 3600); + const minutes = Math.floor((seconds % 3600) / 60); + return `${hours} 小时 ${minutes} 分`; + } + } + + updateMetrics(metrics) { + const container = document.getElementById('metricsContainer'); + container.innerHTML = ` +
+
Epoch
+
${metrics.epoch}/${metrics.totalEpochs}
+
+
+
损失
+
${metrics.loss.toFixed(4)}
+
+
+
准确率
+
${(metrics.accuracy * 100).toFixed(1)}%
+
+
+
验证损失
+
${metrics.valLoss.toFixed(4)}
+
+
+
验证准确率
+
${(metrics.valAccuracy * 100).toFixed(1)}%
+
+ ${metrics.remainingTime !== undefined ? ` +
+
剩余时间
+
${this.formatTime(metrics.remainingTime)}
+
+ ` : ''} + `; + } + + plotLossChart() { + const canvas = document.getElementById('lossChart'); + const ctx = canvas.getContext('2d'); + + // 设置canvas尺寸 + canvas.width = canvas.offsetWidth; + canvas.height = 300; + + // 清除画布 + ctx.clearRect(0, 0, canvas.width, canvas.height); + + if (this.lossHistory.length === 0) return; + + const padding = 40; + const chartWidth = canvas.width - 2 * padding; + const chartHeight = canvas.height - 2 * padding; + + // 找到最大值 + const maxLoss = Math.max(...this.lossHistory); + + // 绘制背景网格 + ctx.strokeStyle = '#e0e0e0'; + ctx.lineWidth = 1; + for (let i = 0; i <= 5; i++) { + const y = padding + (i * chartHeight / 5); + ctx.beginPath(); + ctx.moveTo(padding, y); + ctx.lineTo(canvas.width - padding, y); + ctx.stroke(); + } + + // 绘制坐标轴 + ctx.strokeStyle = '#333'; + ctx.lineWidth = 2; + ctx.beginPath(); + ctx.moveTo(padding, padding); + ctx.lineTo(padding, canvas.height - padding); + ctx.lineTo(canvas.width - padding, canvas.height - padding); + ctx.stroke(); + + // 绘制损失曲线 + ctx.strokeStyle = '#f56565'; + ctx.lineWidth = 2; + ctx.beginPath(); + for (let i = 0; i < this.lossHistory.length; i++) { + const x = padding + (i / (this.lossHistory.length - 1)) * chartWidth; + const y = canvas.height - padding - (this.lossHistory[i] / maxLoss) * chartHeight * 0.9; + + if (i === 0) { + ctx.moveTo(x, y); + } else { + ctx.lineTo(x, y); + } + } + ctx.stroke(); + + // 绘制准确率曲线 + ctx.strokeStyle = '#48bb78'; + ctx.lineWidth = 2; + ctx.beginPath(); + for (let i = 0; i < this.accuracyHistory.length; i++) { + const x = padding + (i / (this.accuracyHistory.length - 1)) * chartWidth; + const y = canvas.height - padding - this.accuracyHistory[i] * chartHeight * 0.9; + + if (i === 0) { + ctx.moveTo(x, y); + } else { + ctx.lineTo(x, y); + } + } + ctx.stroke(); + + // 添加图例 + ctx.fillStyle = '#f56565'; + ctx.fillRect(canvas.width - 120, 20, 15, 15); + ctx.fillStyle = '#333'; + ctx.font = '12px Arial'; + ctx.fillText('损失', canvas.width - 100, 32); + + ctx.fillStyle = '#48bb78'; + ctx.fillRect(canvas.width - 120, 40, 15, 15); + ctx.fillStyle = '#333'; + ctx.fillText('准确率', canvas.width - 100, 52); + } + + async startWebcam() { + if (!this.model) { + this.showStatus('predictionStatus', 'error', '请先训练模型!'); + return; + } + + const video = document.getElementById('webcam'); + + try { + const stream = await navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }); + + video.srcObject = stream; + this.webcamStream = stream; + + document.getElementById('startWebcamBtn').disabled = true; + document.getElementById('stopWebcamBtn').disabled = false; + + // 等待视频加载 + video.addEventListener('loadeddata', () => { + this.isPredicting = true; + this.predictLoop(); + }); + + this.showStatus('predictionStatus', 'success', '摄像头已启动'); + } catch (error) { + this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`); + } + } + + stopWebcam() { + if (this.webcamStream) { + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + } + + this.isPredicting = false; + + const video = document.getElementById('webcam'); + video.srcObject = null; + + document.getElementById('startWebcamBtn').disabled = false; + document.getElementById('stopWebcamBtn').disabled = true; + + document.getElementById('confidenceBars').innerHTML = ''; + + this.showStatus('predictionStatus', 'info', '摄像头已停止'); + } + + async predictLoop() { + if (!this.isPredicting) return; + + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + // 提取特征 - 注意这里返回的是3D张量 + const features = await this.extractFeatures(video); + + // 需要添加批次维度进行预测 [batch, height, width, channels] + let featuresWithBatch = features.expandDims(0); + + // 如果有保存的均值和标准差,进行标准化 + if (this.dataMean && this.dataStd) { + featuresWithBatch = featuresWithBatch.sub(this.dataMean).div(this.dataStd.add(1e-7)); + } + + // 进行预测 + let logits = this.model.predict(featuresWithBatch); + + // 应用温度缩放 + const temperature = parseFloat(document.getElementById('temperature').value) || 1.0; + if (temperature !== 1.0) { + logits = logits.div(temperature); + } + + // 转换为概率 + const prediction = await tf.softmax(logits).data(); + + // 更新置信度显示 + this.updateConfidenceBars(prediction, temperature); + + // 清理张量 + features.dispose(); + featuresWithBatch.dispose(); + logits.dispose(); + } catch (error) { + console.error('预测错误:', error); + } + } + + // 继续预测循环 + requestAnimationFrame(() => this.predictLoop()); + } + + updateConfidenceBars(predictions, temperature = 1.0) { + const container = document.getElementById('confidenceBars'); + + let html = ''; + const classPredictions = []; + + // 计算最大置信度和熵 + const maxConfidence = Math.max(...predictions); + const maxConfidenceIndex = predictions.indexOf(maxConfidence); + const entropy = -predictions.reduce((sum, p) => sum + (p > 0 ? p * Math.log(p) : 0), 0); + const maxEntropy = Math.log(predictions.length); // 最大熵(均匀分布) + const uncertaintyScore = entropy / maxEntropy; // 0-1之间,越高越不确定 + + // 保持原始顺序,不排序 + for (let i = 0; i < predictions.length; i++) { + classPredictions.push({ + className: this.classNames[i], + confidence: predictions[i], + index: i, + isMax: i === maxConfidenceIndex + }); + } + + // 如果不确定性很高(熵接近最大值),显示警告 + let warningMessage = ''; + if (uncertaintyScore > 0.8) { + warningMessage = ` +
+ ⚠️ 不确定性高 - 模型对预测结果不确定 +
可能是未见过的类别或模糊图像 +
+ `; + } else if (maxConfidence < 0.5) { + warningMessage = ` +
+ ❌ 置信度低 - 没有明显匹配的类别 +
这可能不是训练过的类别 +
+ `; + } + + html = warningMessage; + + // 显示温度信息 + if (temperature !== 1.0) { + html += ` +
+ 🌡️ 温度: ${temperature.toFixed(1)} (越高越保守,越低越自信) +
+ `; + } + + // 按原始顺序显示,但标记最高置信度 + classPredictions.forEach((pred) => { + const percentage = (pred.confidence * 100).toFixed(1); + const barWidth = Math.max(percentage, 5); // 最小宽度5%以显示标签 + + // 根据置信度设置颜色 + let barColor = 'linear-gradient(90deg, #667eea, #764ba2)'; + if (pred.confidence < 0.3) { + barColor = 'linear-gradient(90deg, #f56565, #e53e3e)'; + } else if (pred.confidence < 0.6) { + barColor = 'linear-gradient(90deg, #ed8936, #dd6b20)'; + } else if (pred.confidence > 0.8) { + barColor = 'linear-gradient(90deg, #48bb78, #38a169)'; + } + + // 为最高置信度的类别添加特殊样式 + const isWinner = pred.isMax && pred.confidence > 0.5; + const borderStyle = isWinner ? 'border: 2px solid #48bb78; padding: 3px; border-radius: 5px;' : ''; + + html += ` +
+
+ ${pred.className} ${isWinner ? '👑' : ''} + ${percentage}% +
+
+
+ ${percentage > 20 ? percentage + '%' : ''} +
+
+
+ `; + }); + + // 添加不确定性分数 + html += ` +
+ 统计信息:
+ 最高置信度: ${(maxConfidence * 100).toFixed(1)}%
+ 不确定性: ${(uncertaintyScore * 100).toFixed(1)}%
+ 熵: ${entropy.toFixed(3)} / ${maxEntropy.toFixed(3)} +
+ `; + + container.innerHTML = html; + } + + async saveModel() { + if (!this.model) { + this.showStatus('predictionStatus', 'error', '没有可保存的模型'); + return; + } + + try { + // 保存模型和类别名称 + await this.model.save('downloads://custom-image-classifier'); + + // 保存类别信息和标准化参数 + const classInfo = { + classNames: this.classNames, + dataMean: this.dataMean ? await this.dataMean.data() : null, + dataStd: this.dataStd ? await this.dataStd.data() : null, + date: new Date().toISOString() + }; + + const blob = new Blob([JSON.stringify(classInfo, null, 2)], + { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'class-info.json'; + a.click(); + URL.revokeObjectURL(url); + + this.showStatus('predictionStatus', 'success', '模型已保存'); + } catch (error) { + this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`); + } + } + + async loadModel() { + const modelInput = document.createElement('input'); + modelInput.type = 'file'; + modelInput.multiple = true; + modelInput.accept = '.json,.bin'; + + const classInput = document.createElement('input'); + classInput.type = 'file'; + classInput.accept = '.json'; + + modelInput.onchange = async (e) => { + try { + const files = e.target.files; + this.model = await tf.loadLayersModel(tf.io.browserFiles(files)); + + this.showStatus('predictionStatus', 'info', '模型已加载,请选择类别信息文件'); + + // 加载类别信息 + classInput.click(); + } catch (error) { + this.showStatus('predictionStatus', 'error', `加载模型失败: ${error.message}`); + } + }; + + classInput.onchange = async (e) => { + try { + const file = e.target.files[0]; + const text = await file.text(); + const classInfo = JSON.parse(text); + + this.classNames = classInfo.classNames; + + // 恢复标准化参数 + if (classInfo.dataMean && classInfo.dataStd) { + this.dataMean = tf.tensor(classInfo.dataMean); + this.dataStd = tf.tensor(classInfo.dataStd); + console.log('已恢复数据标准化参数'); + } + + this.showStatus('predictionStatus', 'success', + `模型加载成功!类别: ${this.classNames.join(', ')}`); + } catch (error) { + this.showStatus('predictionStatus', 'error', `加载类别信息失败: ${error.message}`); + } + }; + + modelInput.click(); + } + + resetHyperparameters() { + document.getElementById('learningRate').value = -3; + document.getElementById('learningRateValue').textContent = '0.001'; + document.getElementById('epochs').value = 100; + document.getElementById('epochsValue').textContent = '100'; + document.getElementById('dropoutRate').value = 0.3; + document.getElementById('dropoutValue').textContent = '0.3'; + document.getElementById('l2Regularization').value = -2; + document.getElementById('l2Value').textContent = '0.01'; + document.getElementById('batchSize').value = 32; + document.getElementById('batchSizeValue').textContent = '32'; + document.getElementById('temperature').value = 1.0; + document.getElementById('temperatureValue').textContent = '1.0'; + document.getElementById('earlyStoppingCheck').checked = true; + document.getElementById('patience').value = 10; + document.getElementById('patienceValue').textContent = '10'; + + this.showStatus('trainingStatus', 'info', '超参数已重置为默认值'); + } + + showRecommendations() { + const message = ` +💡 超参数设置建议: + +如果过拟合(训练准确率高但验证准确率低): +- 增大 Dropout 率 (0.4-0.6) +- 增大 L2 正则化 (0.01-0.1) +- 减少训练轮数 +- 启用早停 +- 增加更多训练数据 + +如果欠拟合(训练和验证准确率都低): +- 减小 Dropout 率 (0.1-0.2) +- 减小 L2 正则化 (0.0001-0.001) +- 增加训练轮数 +- 增大学习率 (0.001-0.01) + +温度缩放: +- < 1.0: 更自信的预测(可能过度自信) +- = 1.0: 标准预测 +- > 1.0: 更保守的预测(减少过度自信) + `; + alert(message); + } + + showStatus(elementId, type, message) { + const element = document.getElementById(elementId); + + const classMap = { + 'success': 'status-success', + 'error': 'status-error', + 'info': 'status-info', + 'warning': 'status-warning' + }; + + element.className = `status-message ${classMap[type] || 'status-info'}`; + element.textContent = message; + } +} + +// 初始化应用 +let classifier; +document.addEventListener('DOMContentLoaded', () => { + classifier = new CustomImageClassifier(); +}); \ No newline at end of file diff --git a/原版KNN/knn-classifier.html b/原版KNN/knn-classifier.html new file mode 100644 index 0000000..dc7fce6 --- /dev/null +++ b/原版KNN/knn-classifier.html @@ -0,0 +1,551 @@ + + + + + + KNN 图像分类器 - TensorFlow.js + + + + + + +
+

🤖 KNN 图像分类器(基于特征标签)

+ +
+ +
+

📸 数据采集

+ +
+

1 第一类

+ + + + 0 张图片 + +
+
+ +
+

2 第二类

+ + + + 0 张图片 + +
+
+ +
+

3 第三类(可选)

+ + + + 0 张图片 + +
+
+ +
+ + +
+ +
+
+ + +
+

🎯 KNN 模型设置

+ +
+ + + K值越大,预测越保守;K值越小,对局部特征越敏感 +
+ +
+ + + 低通滤波器系数:值越小输出越平滑(0.1-0.3推荐),值越大响应越快 +
+ +
+

📊 特征标签提取预览

+
等待数据...
+
+ +
+

ℹ️ 模型信息

+
+ 预训练模型: + MobileNet v2 +
+
+ 特征维度: + 1000个标签 +
+
+ 分类器类型: + K-最近邻 (KNN) +
+
+ 总样本数: + 0 +
+
+
+
+ + +
+

📹 实时预测

+ +
+ + + + +
+ +
+ +
+ +
+

预测结果

+
等待预测...
+
+ +
+
+
+ + + + \ No newline at end of file diff --git a/原版KNN/knn-classifier.js b/原版KNN/knn-classifier.js new file mode 100644 index 0000000..7d99d8a --- /dev/null +++ b/原版KNN/knn-classifier.js @@ -0,0 +1,621 @@ +// KNN 图像分类器 - 基于MobileNet特征标签 +class KNNImageClassifier { + constructor() { + this.mobilenet = null; + this.knnClassifier = null; + this.classNames = []; + this.webcamStream = null; + this.isPredicting = false; + this.currentCaptureClass = -1; + this.imagenetClasses = null; + + // 低通滤波器状态 + this.filteredConfidences = {}; + this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑 + + this.init(); + } + + async init() { + this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...'); + + try { + // 加载 MobileNet 模型 + this.mobilenet = await mobilenet.load({ + version: 2, + alpha: 1.0 + }); + + // 创建 KNN 分类器 + this.knnClassifier = knnClassifier.create(); + + // 加载 ImageNet 类别名称 + await this.loadImageNetClasses(); + + this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!'); + this.setupEventListeners(); + } catch (error) { + this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`); + } + } + + async loadImageNetClasses() { + // ImageNet 前10个类别名称(简化版) + this.imagenetClasses = [ + 'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead', + 'electric_ray', 'stingray', 'cock', 'hen', 'ostrich' + ]; + } + + setupEventListeners() { + // 文件上传监听 + ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => { + document.getElementById(id).addEventListener('change', (e) => { + this.handleImageUpload(e, index); + }); + }); + + // 按钮监听 + document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN()); + document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset()); + document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam()); + document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam()); + document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel()); + document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel()); + } + + handleImageUpload(event, classIndex) { + const files = event.target.files; + const countElement = document.getElementById(`class${classIndex + 1}Count`); + const previewContainer = document.getElementById(`class${classIndex + 1}Preview`); + + countElement.textContent = `${files.length} 张图片`; + + // 清空之前的预览 + previewContainer.innerHTML = ''; + + // 添加图片预览 + Array.from(files).forEach(file => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = document.createElement('img'); + img.src = e.target.result; + img.className = 'preview-img'; + previewContainer.appendChild(img); + }; + reader.readAsDataURL(file); + }); + } + + // 从图像提取 MobileNet 标签和权重 + async extractImageNetTags(img) { + try { + // 获取 MobileNet 的预测(1000个类别的概率) + const predictions = await this.mobilenet.classify(img); + + // 获取完整的 logits(原始输出) + const logits = this.mobilenet.infer(img, false); // false = 不使用嵌入层,获取原始1000维输出 + + // 获取前10个最高概率的标签 + const topK = await this.getTopKTags(logits, 10); + + return { + logits: logits, // 1000维特征向量 + predictions: predictions, // 前3个预测 + topTags: topK // 前10个标签和权重 + }; + } catch (error) { + console.error('特征提取失败:', error); + throw error; + } + } + + // 获取Top-K标签 + async getTopKTags(logits, k = 10) { + const values = await logits.data(); + const valuesAndIndices = []; + + for (let i = 0; i < values.length; i++) { + valuesAndIndices.push({ value: values[i], index: i }); + } + + valuesAndIndices.sort((a, b) => b.value - a.value); + const topkValues = new Float32Array(k); + const topkIndices = new Int32Array(k); + + for (let i = 0; i < k; i++) { + topkValues[i] = valuesAndIndices[i].value; + topkIndices[i] = valuesAndIndices[i].index; + } + + const topTags = []; + for (let i = 0; i < k; i++) { + topTags.push({ + className: this.imagenetClasses[i] || `class_${topkIndices[i]}`, + probability: this.softmax(topkValues)[i], + logit: topkValues[i] + }); + } + + return topTags; + } + + // Softmax 函数 + softmax(arr) { + const maxLogit = Math.max(...arr); + const scores = arr.map(l => Math.exp(l - maxLogit)); + const sum = scores.reduce((a, b) => a + b); + return scores.map(s => s / sum); + } + + // 训练 KNN 模型 + async trainKNN() { + const classes = []; + const imageFiles = []; + + // 收集所有类别和图片 + for (let i = 1; i <= 3; i++) { + const className = document.getElementById(`class${i}Name`).value.trim(); + const files = document.getElementById(`class${i}Images`).files; + + if (className && files.length > 0) { + classes.push(className); + imageFiles.push(files); + } + } + + if (classes.length < 2) { + this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!'); + return; + } + + this.classNames = classes; + this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...'); + + // 清空现有的KNN分类器 + this.knnClassifier.clearAllClasses(); + + let totalProcessed = 0; + let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0); + + // 处理每个类别的图片 + for (let classIndex = 0; classIndex < classes.length; classIndex++) { + const files = imageFiles[classIndex]; + console.log(`处理类别 ${classes[classIndex]}...`); + + for (let fileIndex = 0; fileIndex < files.length; fileIndex++) { + try { + const img = await this.loadImage(files[fileIndex]); + + // 提取特征标签 + const features = await this.extractImageNetTags(img); + + // 添加到KNN分类器 + // 使用完整的logits作为特征向量 + this.knnClassifier.addExample(features.logits, classIndex); + + totalProcessed++; + const progress = Math.round((totalProcessed / totalImages) * 100); + this.showStatus('dataStatus', 'info', + `处理中... ${totalProcessed}/${totalImages} (${progress}%)`); + + // 显示提取的标签 + if (fileIndex === 0) { + this.displayTopTags(features.topTags); + } + + // 清理 + img.remove(); + features.logits.dispose(); + } catch (error) { + console.error('处理图片失败:', error); + } + } + } + + // 更新模型信息 + document.getElementById('totalSamples').textContent = totalProcessed; + + this.showStatus('dataStatus', 'success', + `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`); + + console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别'); + } + + // 显示提取的标签 + displayTopTags(tags) { + const container = document.getElementById('tagsList'); + let html = ''; + + tags.slice(0, 5).forEach(tag => { + html += ` + + ${tag.className} + ${(tag.probability * 100).toFixed(1)}% + + `; + }); + + container.innerHTML = html; + } + + // 加载图片 + async loadImage(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = new Image(); + img.onload = () => resolve(img); + img.onerror = reject; + img.src = e.target.result; + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + // 清空数据集 + clearDataset() { + this.knnClassifier.clearAllClasses(); + this.classNames = []; + this.filteredConfidences = {}; // 重置滤波器状态 + + for (let i = 1; i <= 3; i++) { + document.getElementById(`class${i}Images`).value = ''; + document.getElementById(`class${i}Count`).textContent = '0 张图片'; + document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览 + } + + document.getElementById('totalSamples').textContent = '0'; + document.getElementById('tagsList').innerHTML = '等待数据...'; + document.getElementById('predictions').innerHTML = '等待预测...'; + + this.showStatus('dataStatus', 'info', '数据集已清空'); + } + + // 启动摄像头 + async startWebcam() { + if (this.knnClassifier.getNumClasses() === 0) { + this.showStatus('predictionStatus', 'error', '请先训练模型!'); + return; + } + + const video = document.getElementById('webcam'); + + try { + const stream = await navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }); + + video.srcObject = stream; + this.webcamStream = stream; + + document.getElementById('startWebcamBtn').disabled = true; + document.getElementById('stopWebcamBtn').disabled = false; + + // 等待视频加载 + video.addEventListener('loadeddata', () => { + this.isPredicting = true; + this.predictLoop(); + }); + + this.showStatus('predictionStatus', 'success', '摄像头已启动'); + } catch (error) { + this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`); + } + } + + // 停止摄像头 + stopWebcam() { + if (this.webcamStream) { + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + } + + this.isPredicting = false; + this.filteredConfidences = {}; // 重置滤波器状态 + + const video = document.getElementById('webcam'); + video.srcObject = null; + + document.getElementById('startWebcamBtn').disabled = false; + document.getElementById('stopWebcamBtn').disabled = true; + + this.showStatus('predictionStatus', 'info', '摄像头已停止'); + } + + // 预测循环 + async predictLoop() { + if (!this.isPredicting) return; + + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + // 提取特征 + const features = await this.extractImageNetTags(video); + + // 使用原始KNN进行预测 + const k = parseInt(document.getElementById('kValue').value); + const prediction = await this.knnClassifier.predictClass(features.logits, k); + + // 应用低通滤波器 + const smoothedPrediction = this.applyLowPassFilter(prediction); + + // 显示预测结果 + this.displayPrediction(smoothedPrediction); + + // 显示提取的标签 + this.displayTopTags(features.topTags); + + // 清理张量 + features.logits.dispose(); + } catch (error) { + console.error('预测错误:', error); + } + } + + // 继续预测循环 + requestAnimationFrame(() => this.predictLoop()); + } + + // 应用低通滤波器到置信度 + applyLowPassFilter(prediction) { + // 获取滤波器系数 + const alpha = parseFloat(document.getElementById('filterAlpha').value); + + // 初始化滤波状态(如果是第一次) + if (Object.keys(this.filteredConfidences).length === 0) { + for (let i = 0; i < this.classNames.length; i++) { + this.filteredConfidences[i] = prediction.confidences[i] || 0; + } + return { + label: prediction.label, + confidences: {...this.filteredConfidences} + }; + } + + // 应用指数移动平均(EMA)低通滤波 + const newConfidences = {}; + for (let i = 0; i < this.classNames.length; i++) { + const currentValue = prediction.confidences[i] || 0; + const previousValue = this.filteredConfidences[i] || 0; + + // EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1] + this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue; + newConfidences[i] = this.filteredConfidences[i]; + } + + // 归一化确保总和为1 + let sum = 0; + Object.values(newConfidences).forEach(v => sum += v); + if (sum > 0) { + Object.keys(newConfidences).forEach(key => { + newConfidences[key] = newConfidences[key] / sum; + }); + } + + // 找到最高置信度的类别 + let maxConfidence = 0; + let bestLabel = 0; + Object.keys(newConfidences).forEach(key => { + if (newConfidences[key] > maxConfidence) { + maxConfidence = newConfidences[key]; + bestLabel = parseInt(key); + } + }); + + return { + label: bestLabel, + confidences: newConfidences + }; + } + + // 显示预测结果 + displayPrediction(prediction) { + const container = document.getElementById('predictions'); + let html = ''; + + // 直接使用滤波后的置信度 + const confidences = prediction.confidences; + const predictedClass = prediction.label; + + // 固定顺序显示(按类别索引) + for (let i = 0; i < this.classNames.length; i++) { + const className = this.classNames[i]; + const confidence = confidences[i] || 0; + const percentage = (confidence * 100).toFixed(1); + const isWinner = i === predictedClass; + + // 根据置信度决定颜色等级 + let barClass = ''; + if (confidence > 0.7) barClass = 'high'; + else if (confidence > 0.4) barClass = 'medium'; + else barClass = 'low'; + + // 如果是获胜类别,使用绿色 + if (isWinner) barClass = 'high'; + + html += ` +
+
+ + ${className} ${isWinner ? '👑' : ''} + + + ${percentage}% + +
+
+
+ ${confidence > 0.15 ? `${percentage}%` : ''} +
+
+
+ `; + } + + container.innerHTML = html; + } + + // 从摄像头捕获样本 + async captureFromWebcam(classIndex) { + if (!this.webcamStream) { + // 临时启动摄像头 + const video = document.getElementById('webcam'); + try { + const stream = await navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }); + + video.srcObject = stream; + this.webcamStream = stream; + + // 等待视频加载 + setTimeout(async () => { + await this.addWebcamSample(classIndex); + + // 停止临时摄像头 + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + video.srcObject = null; + }, 1000); + } catch (error) { + this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`); + } + } else { + await this.addWebcamSample(classIndex); + } + } + + // 添加摄像头样本 + async addWebcamSample(classIndex) { + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + // 提取特征 + const features = await this.extractImageNetTags(video); + + // 添加到KNN分类器 + this.knnClassifier.addExample(features.logits, classIndex); + + // 更新计数 + const currentCount = this.knnClassifier.getClassExampleCount(); + const count = currentCount[classIndex] || 0; + document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`; + + // 清理 + features.logits.dispose(); + + this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`); + } catch (error) { + console.error('添加样本失败:', error); + } + } + } + + // 保存模型 + async saveModel() { + if (this.knnClassifier.getNumClasses() === 0) { + this.showStatus('predictionStatus', 'error', '没有可保存的模型'); + return; + } + + try { + // 获取KNN分类器的数据 + const dataset = this.knnClassifier.getClassifierDataset(); + const datasetObj = {}; + + Object.keys(dataset).forEach(key => { + const data = dataset[key].dataSync(); + datasetObj[key] = Array.from(data); + }); + + // 保存为JSON + const modelData = { + dataset: datasetObj, + classNames: this.classNames, + k: document.getElementById('kValue').value, + date: new Date().toISOString() + }; + + const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'knn-model.json'; + a.click(); + URL.revokeObjectURL(url); + + this.showStatus('predictionStatus', 'success', '模型已保存'); + } catch (error) { + this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`); + } + } + + // 加载模型 + async loadModel() { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = '.json'; + + input.onchange = async (e) => { + try { + const file = e.target.files[0]; + const text = await file.text(); + const modelData = JSON.parse(text); + + // 清空现有分类器 + this.knnClassifier.clearAllClasses(); + + // 恢复数据集 + Object.keys(modelData.dataset).forEach(key => { + const tensor = tf.tensor(modelData.dataset[key], [modelData.dataset[key].length / 1024, 1024]); + this.knnClassifier.setClassifierDataset({ [key]: tensor }); + }); + + this.classNames = modelData.classNames; + document.getElementById('kValue').value = modelData.k; + document.getElementById('kValueDisplay').textContent = modelData.k; + + this.showStatus('predictionStatus', 'success', + `模型加载成功!类别: ${this.classNames.join(', ')}`); + } catch (error) { + this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`); + } + }; + + input.click(); + } + + // 显示状态 + showStatus(elementId, type, message) { + const element = document.getElementById(elementId); + + const classMap = { + 'success': 'status-success', + 'error': 'status-error', + 'info': 'status-info' + }; + + element.className = `status-message ${classMap[type]}`; + element.textContent = message; + } +} + +// 全局函数:从摄像头捕获 +function captureFromWebcam(classIndex) { + if (window.classifier) { + window.classifier.captureFromWebcam(classIndex); + } +} + +// 初始化应用 +let classifier; +document.addEventListener('DOMContentLoaded', () => { + classifier = new KNNImageClassifier(); + window.classifier = classifier; +}); \ No newline at end of file diff --git a/完善KNN/knn-classifier.html b/完善KNN/knn-classifier.html new file mode 100644 index 0000000..3b4710c --- /dev/null +++ b/完善KNN/knn-classifier.html @@ -0,0 +1,561 @@ + + + + + + KNN 图像分类器 - TensorFlow.js + + + + + + +
+

🤖 KNN 图像分类器(基于特征标签)

+ +
+ +
+

📸 数据采集

+ +
+

1 第一类

+ + + + 0 张图片 + +
+
+ +
+

2 第二类

+ + + + 0 张图片 + +
+
+ +
+

3 第三类(可选)

+ + + + 0 张图片 + +
+
+ +
+ + +
+ +
+
+ + +
+

🎯 KNN 模型设置

+ +
+ + + K值越大,预测越保守;K值越小,对局部特征越敏感 +
+ +
+ + + 低通滤波器系数:值越小输出越平滑(0.1-0.3推荐),值越大响应越快 +
+ +
+ + + 距离阈值:样本与训练数据的最大距离,超过此值判定为"未知/背景"(单品类检测关键参数) +
+ +
+

📊 特征标签提取预览

+
等待数据...
+
+ +
+

ℹ️ 模型信息

+
+ 预训练模型: + MobileNet v2 +
+
+ 特征维度: + 1280维嵌入向量 +
+
+ 分类器类型: + K-最近邻 (KNN) +
+
+ 总样本数: + 0 +
+
+
+
+ + +
+

📹 实时预测

+ +
+ + + + +
+ +
+ +
+ +
+

预测结果

+
等待预测...
+
+ +
+
+
+ + + + \ No newline at end of file diff --git a/完善KNN/knn-classifier.js b/完善KNN/knn-classifier.js new file mode 100644 index 0000000..db05844 --- /dev/null +++ b/完善KNN/knn-classifier.js @@ -0,0 +1,985 @@ +// KNN 图像分类器 - 基于MobileNet特征标签 +class KNNImageClassifier { + constructor() { + this.mobilenet = null; + this.knnClassifier = null; + this.classNames = []; + this.webcamStream = null; + this.isPredicting = false; + this.currentCaptureClass = -1; + this.imagenetClasses = null; + + // 低通滤波器状态 + this.filteredConfidences = {}; + this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑 + + // 距离阈值设置 + this.useDistanceThreshold = true; + this.distanceThreshold = 0.5; // 默认距离阈值(归一化后的特征) + this.adaptiveThreshold = null; // 自适应阈值 + + this.init(); + } + + async init() { + this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...'); + + try { + // 加载 MobileNet 模型 + this.mobilenet = await mobilenet.load({ + version: 2, + alpha: 1.0 + }); + + // 创建 KNN 分类器 + this.knnClassifier = knnClassifier.create(); + + // 加载 ImageNet 类别名称 + await this.loadImageNetClasses(); + + this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!'); + this.setupEventListeners(); + } catch (error) { + this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`); + } + } + + async loadImageNetClasses() { + // ImageNet 前10个类别名称(简化版) + this.imagenetClasses = [ + 'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead', + 'electric_ray', 'stingray', 'cock', 'hen', 'ostrich' + ]; + } + + setupEventListeners() { + // 文件上传监听 + ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => { + document.getElementById(id).addEventListener('change', (e) => { + this.handleImageUpload(e, index); + }); + }); + + // 按钮监听 + document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN()); + document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset()); + document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam()); + document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam()); + document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel()); + document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel()); + } + + handleImageUpload(event, classIndex) { + const files = event.target.files; + const countElement = document.getElementById(`class${classIndex + 1}Count`); + const previewContainer = document.getElementById(`class${classIndex + 1}Preview`); + + countElement.textContent = `${files.length} 张图片`; + + // 清空之前的预览 + previewContainer.innerHTML = ''; + + // 添加图片预览 + Array.from(files).forEach(file => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = document.createElement('img'); + img.src = e.target.result; + img.className = 'preview-img'; + previewContainer.appendChild(img); + }; + reader.readAsDataURL(file); + }); + } + + // 从图像提取 MobileNet 标签和权重 + async extractImageNetTags(img) { + try { + // 获取 MobileNet 的预测(1000个类别的概率) + const predictions = await this.mobilenet.classify(img); + + // 获取用于KNN的特征(使用嵌入层获得更好的特征表示) + const rawEmbeddings = this.mobilenet.infer(img, true); // true = 使用嵌入层,获取1280维特征 + + // L2归一化特征向量(重要:使距离计算更稳定) + const embeddings = tf.tidy(() => { + const norm = tf.norm(rawEmbeddings); + const normalized = tf.div(rawEmbeddings, norm); + rawEmbeddings.dispose(); // 清理原始嵌入 + return normalized; + }); + + // 获取用于显示的logits(1000个类别) + const logits = this.mobilenet.infer(img, false); // false = 获取原始1000维输出 + + // 获取前10个最高概率的标签 + const topK = await this.getTopKTags(logits, 10); + + // 清理logits(只用于显示) + logits.dispose(); + + return { + logits: embeddings, // 使用1280维嵌入特征用于KNN + predictions: predictions, // 前3个预测 + topTags: topK // 前10个标签和权重 + }; + } catch (error) { + console.error('特征提取失败:', error); + throw error; + } + } + + // 获取Top-K标签 + async getTopKTags(logits, k = 10) { + const values = await logits.data(); + const valuesAndIndices = []; + + for (let i = 0; i < values.length; i++) { + valuesAndIndices.push({ value: values[i], index: i }); + } + + valuesAndIndices.sort((a, b) => b.value - a.value); + const topkValues = new Float32Array(k); + const topkIndices = new Int32Array(k); + + for (let i = 0; i < k; i++) { + topkValues[i] = valuesAndIndices[i].value; + topkIndices[i] = valuesAndIndices[i].index; + } + + const topTags = []; + for (let i = 0; i < k; i++) { + topTags.push({ + className: this.imagenetClasses[i] || `class_${topkIndices[i]}`, + probability: this.softmax(topkValues)[i], + logit: topkValues[i] + }); + } + + return topTags; + } + + // Softmax 函数 + softmax(arr) { + const maxLogit = Math.max(...arr); + const scores = arr.map(l => Math.exp(l - maxLogit)); + const sum = scores.reduce((a, b) => a + b); + return scores.map(s => s / sum); + } + + // 训练 KNN 模型 + async trainKNN() { + const classes = []; + const imageFiles = []; + + // 收集所有类别和图片 + for (let i = 1; i <= 3; i++) { + const className = document.getElementById(`class${i}Name`).value.trim(); + const files = document.getElementById(`class${i}Images`).files; + + if (className && files && files.length > 0) { + classes.push(className); + imageFiles.push(files); + console.log(`类别 ${i}: "${className}" - ${files.length} 张图片`); + } + } + + console.log('收集到的类别:', classes); + console.log('类别数量:', classes.length); + + // 支持单品类检测(One-Class Classification) + if (classes.length < 1) { + this.showStatus('dataStatus', 'error', '请至少添加一个类别的图片!'); + return; + } + + // 如果只有一个类别,提示用户这是单品类检测模式 + if (classes.length === 1) { + console.log('📍 单品类检测模式:只检测 "' + classes[0] + '",其他都视为背景/未知'); + } + + this.classNames = classes; + this.filteredConfidences = {}; // 重置滤波器状态 + this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...'); + + // 清空现有的KNN分类器 + this.knnClassifier.clearAllClasses(); + + let totalProcessed = 0; + let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0); + + // 处理每个类别的图片 + for (let classIndex = 0; classIndex < classes.length; classIndex++) { + const files = imageFiles[classIndex]; + console.log(`处理类别 ${classes[classIndex]}...`); + + for (let fileIndex = 0; fileIndex < files.length; fileIndex++) { + try { + const img = await this.loadImage(files[fileIndex]); + + // 提取特征标签 + const features = await this.extractImageNetTags(img); + + // 添加到KNN分类器 + // 使用完整的logits作为特征向量 + this.knnClassifier.addExample(features.logits, classIndex); + + totalProcessed++; + const progress = Math.round((totalProcessed / totalImages) * 100); + this.showStatus('dataStatus', 'info', + `处理中... ${totalProcessed}/${totalImages} (${progress}%)`); + + // 显示提取的标签 + if (fileIndex === 0) { + this.displayTopTags(features.topTags); + } + + // 清理 + img.remove(); + features.logits.dispose(); + } catch (error) { + console.error('处理图片失败:', error); + } + } + } + + // 更新模型信息 + document.getElementById('totalSamples').textContent = totalProcessed; + + // 根据类别数量显示不同的消息 + let statusMessage; + if (classes.length === 1) { + statusMessage = `单品类检测模型训练完成!将只检测 "${classes[0]}",共 ${totalProcessed} 个样本`; + } else { + statusMessage = `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`; + } + + this.showStatus('dataStatus', 'success', statusMessage); + + console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别'); + if (classes.length === 1) { + console.log('📍 单品类检测模式已启用,将基于距离阈值判断是否为:', classes[0]); + // 计算自适应阈值 + await this.calculateAdaptiveThreshold(); + } + } + + // 计算自适应阈值(基于训练数据的内部距离) + async calculateAdaptiveThreshold() { + if (this.knnClassifier.getNumClasses() !== 1) return; + + console.log('计算自适应阈值...'); + + const dataset = this.knnClassifier.getClassifierDataset(); + if (!dataset || !dataset[0]) return; + + const trainData = await dataset[0].data(); + const numSamples = dataset[0].shape[0]; + const featureDim = dataset[0].shape[1]; + + // 计算训练样本之间的平均距离 + let totalDistance = 0; + let count = 0; + + for (let i = 0; i < Math.min(numSamples, 20); i++) { // 限制计算量 + for (let j = i + 1; j < Math.min(numSamples, 20); j++) { + let distance = 0; + for (let k = 0; k < featureDim; k++) { + const diff = trainData[i * featureDim + k] - trainData[j * featureDim + k]; + distance += diff * diff; + } + distance = Math.sqrt(distance); + totalDistance += distance; + count++; + } + } + + if (count > 0) { + const avgInternalDistance = totalDistance / count; + // 自适应阈值设为内部平均距离的1.3-1.5倍(归一化后距离较小) + this.adaptiveThreshold = avgInternalDistance * 1.3; + + console.log(`内部平均距离: ${avgInternalDistance.toFixed(2)}`); + console.log(`建议自适应阈值: ${this.adaptiveThreshold.toFixed(2)}`); + + // 更新UI显示建议阈值 + const thresholdInput = document.getElementById('distanceThreshold'); + const thresholdDisplay = document.getElementById('distanceThresholdDisplay'); + if (thresholdInput && thresholdDisplay) { + thresholdInput.value = this.adaptiveThreshold.toFixed(1); + thresholdDisplay.textContent = this.adaptiveThreshold.toFixed(1); + } + + this.showStatus('dataStatus', 'info', + `自适应阈值已计算: ${this.adaptiveThreshold.toFixed(1)} (基于训练数据内部距离)`); + } + } + + // 显示提取的标签 + displayTopTags(tags) { + const container = document.getElementById('tagsList'); + let html = ''; + + tags.slice(0, 5).forEach(tag => { + html += ` + + ${tag.className} + ${(tag.probability * 100).toFixed(1)}% + + `; + }); + + container.innerHTML = html; + } + + // 加载图片 + async loadImage(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = new Image(); + img.onload = () => resolve(img); + img.onerror = reject; + img.src = e.target.result; + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + // 清空数据集 + clearDataset() { + this.knnClassifier.clearAllClasses(); + this.classNames = []; + this.filteredConfidences = {}; // 重置滤波器状态 + + console.log('数据集已清空,滤波器状态已重置'); + + for (let i = 1; i <= 3; i++) { + document.getElementById(`class${i}Images`).value = ''; + document.getElementById(`class${i}Count`).textContent = '0 张图片'; + document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览 + } + + document.getElementById('totalSamples').textContent = '0'; + document.getElementById('tagsList').innerHTML = '等待数据...'; + document.getElementById('predictions').innerHTML = '等待预测...'; + + this.showStatus('dataStatus', 'info', '数据集已清空'); + } + + // 启动摄像头 + async startWebcam() { + if (this.knnClassifier.getNumClasses() === 0) { + this.showStatus('predictionStatus', 'error', '请先训练模型!'); + return; + } + + const video = document.getElementById('webcam'); + + try { + const stream = await navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }); + + video.srcObject = stream; + this.webcamStream = stream; + + document.getElementById('startWebcamBtn').disabled = true; + document.getElementById('stopWebcamBtn').disabled = false; + + // 等待视频加载 + video.addEventListener('loadeddata', () => { + this.isPredicting = true; + this.predictLoop(); + }); + + this.showStatus('predictionStatus', 'success', '摄像头已启动'); + } catch (error) { + this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`); + } + } + + // 停止摄像头 + stopWebcam() { + if (this.webcamStream) { + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + } + + this.isPredicting = false; + this.filteredConfidences = {}; // 重置滤波器状态 + + const video = document.getElementById('webcam'); + video.srcObject = null; + + document.getElementById('startWebcamBtn').disabled = false; + document.getElementById('stopWebcamBtn').disabled = true; + + this.showStatus('predictionStatus', 'info', '摄像头已停止'); + } + + // 预测循环 + async predictLoop() { + if (!this.isPredicting) return; + + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + console.log('开始预测,类别数量:', this.classNames.length); + + // 提取特征 + const features = await this.extractImageNetTags(video); + + // 使用原始KNN进行预测(包含距离信息) + const k = parseInt(document.getElementById('kValue').value); + const predictionWithDistance = await this.predictWithDistance(features.logits, k); + + console.log('预测结果:', predictionWithDistance); + + // 检查是否为未知类(基于距离阈值) + let finalPrediction; + + // 单品类模式特殊处理 + if (predictionWithDistance.isSingleClass) { + // 直接使用 predictWithDistance 返回的结果,它已经处理了阈值判断 + finalPrediction = { + label: predictionWithDistance.label, + confidences: predictionWithDistance.confidences, + isUnknown: predictionWithDistance.label === -1, + minDistance: predictionWithDistance.minDistance, + isSingleClass: true + }; + } else { + // 多类别模式:直接使用KNN预测结果,不使用距离阈值 + finalPrediction = { + label: predictionWithDistance.label, + confidences: predictionWithDistance.confidences, + isUnknown: false, + minDistance: predictionWithDistance.minDistance, + isSingleClass: false + }; + } + + // 应用低通滤波器 + const smoothedPrediction = this.applyLowPassFilter(finalPrediction); + + // 显示预测结果 + this.displayPrediction(smoothedPrediction); + + // 显示提取的标签 + this.displayTopTags(features.topTags); + + // 清理张量 + features.logits.dispose(); + } catch (error) { + console.error('预测错误:', error); + } + } + + // 继续预测循环 + requestAnimationFrame(() => this.predictLoop()); + } + + // 使用距离信息进行预测 + async predictWithDistance(logits, k) { + // 如果没有训练数据,返回空结果 + if (this.knnClassifier.getNumClasses() === 0) { + return { + label: -1, + confidences: {}, + minDistance: Infinity, + isSingleClass: false + }; + } + + const numClasses = this.knnClassifier.getNumClasses(); + + // 单品类检测模式 - 使用实际距离计算 + if (numClasses === 1) { + console.log('单品类检测模式 - 计算实际距离'); + + // 获取训练数据 + const dataset = this.knnClassifier.getClassifierDataset(); + if (!dataset || !dataset[0]) { + return { + label: -1, + confidences: { 0: 0 }, + minDistance: Infinity, + isSingleClass: true + }; + } + + // 计算输入样本与所有训练样本的欧氏距离 + const inputData = await logits.data(); + const trainData = await dataset[0].data(); + const numSamples = dataset[0].shape[0]; + const featureDim = dataset[0].shape[1]; + + console.log(`输入特征维度: ${inputData.length}, 训练数据维度: ${featureDim}, 样本数: ${numSamples}`); + + // 确保维度匹配 + if (inputData.length !== featureDim) { + console.error(`维度不匹配!输入: ${inputData.length}, 训练: ${featureDim}`); + return { + label: -1, + confidences: { 0: 0 }, + minDistance: Infinity, + isSingleClass: true + }; + } + + let minDistance = Infinity; + const distances = []; + + // 计算与每个训练样本的距离 + for (let i = 0; i < numSamples; i++) { + let distance = 0; + for (let j = 0; j < featureDim; j++) { + const diff = inputData[j] - trainData[i * featureDim + j]; + distance += diff * diff; + } + distance = Math.sqrt(distance); + distances.push(distance); + if (distance < minDistance) { + minDistance = distance; + } + } + + console.log(`计算了 ${distances.length} 个距离,最小距离: ${minDistance.toFixed(2)}`); + console.log(`前5个距离: ${distances.slice(0, 5).map(d => d.toFixed(2)).join(', ')}`); + + // 获取K个最近邻的平均距离 + distances.sort((a, b) => a - b); + const kNearest = distances.slice(0, Math.min(k, distances.length)); + const avgDistance = kNearest.reduce((sum, d) => sum + d, 0) / kNearest.length; + + // 从UI获取距离阈值 + const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '15.0'); + + // 基于距离阈值判断是否属于该类 + const belongsToClass = avgDistance <= threshold; + + // 二值化置信度:在阈值内100%,超出阈值0% + const confidence = belongsToClass ? 1.0 : 0; + + console.log(`单品类预测 - 平均距离: ${avgDistance.toFixed(2)}, 阈值: ${threshold}, 属于类别: ${belongsToClass}, 置信度: ${confidence.toFixed(3)}`); + + return { + label: belongsToClass ? 0 : -1, + confidences: { 0: confidence }, + minDistance: avgDistance, + isSingleClass: true + }; + } + + // 多品类模式:使用KNN分类器预测 + try { + const prediction = await this.knnClassifier.predictClass(logits, k); + + console.log('多品类预测结果:', prediction); + console.log('预测标签:', prediction.label); + console.log('置信度:', prediction.confidences); + + // 确保confidences存在且格式正确 + let confidences = {}; + if (prediction.confidences) { + // 检查是否是对象格式 + if (typeof prediction.confidences === 'object') { + confidences = prediction.confidences; + } + } + + // 如果confidences为空,手动计算 + if (Object.keys(confidences).length === 0) { + console.warn('置信度为空,使用默认值'); + for (let i = 0; i < this.classNames.length; i++) { + confidences[i] = i === prediction.label ? 1.0 : 0; + } + } + + // 计算实际距离(可选) + let minDistance = 0.5; // 默认距离 + + return { + label: prediction.label, + confidences: confidences, + minDistance: minDistance, + isSingleClass: false + }; + } catch (error) { + console.error('预测错误:', error); + return { + label: -1, + confidences: {}, + minDistance: Infinity, + isSingleClass: false + }; + } + } + + // 应用低通滤波器到置信度 + applyLowPassFilter(prediction) { + // 获取滤波器系数 + const alpha = parseFloat(document.getElementById('filterAlpha').value); + + // 保留原始的特殊字段 + const isSingleClass = prediction.isSingleClass || false; + const minDistance = prediction.minDistance; + const isUnknown = prediction.isUnknown; + + // 初始化滤波状态(如果是第一次) + if (Object.keys(this.filteredConfidences).length === 0) { + for (let i = 0; i < this.classNames.length; i++) { + this.filteredConfidences[i] = prediction.confidences[i] || 0; + } + return { + label: prediction.label, + confidences: {...this.filteredConfidences}, + isSingleClass: isSingleClass, + minDistance: minDistance, + isUnknown: isUnknown + }; + } + + // 单品类模式下,使用二值化输出,不应用滤波 + if (isSingleClass) { + // 获取距离阈值 + const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5'); + + // 二值化判断:在阈值内为1,超出为0 + const inThreshold = minDistance <= threshold; + const confidence = inThreshold ? 1.0 : 0; + + // 直接更新,不滤波 + this.filteredConfidences[0] = confidence; + + return { + label: inThreshold ? 0 : -1, + confidences: { 0: confidence }, + isSingleClass: isSingleClass, + minDistance: minDistance, + isUnknown: !inThreshold + }; + } + + // 应用指数移动平均(EMA)低通滤波 + const newConfidences = {}; + for (let i = 0; i < this.classNames.length; i++) { + const currentValue = prediction.confidences[i] || 0; + const previousValue = this.filteredConfidences[i] || 0; + + // EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1] + this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue; + newConfidences[i] = this.filteredConfidences[i]; + } + + // 归一化确保总和为1 + let sum = 0; + Object.values(newConfidences).forEach(v => sum += v); + if (sum > 0) { + Object.keys(newConfidences).forEach(key => { + newConfidences[key] = newConfidences[key] / sum; + }); + } + + // 找到最高置信度的类别 + let maxConfidence = 0; + let bestLabel = 0; + Object.keys(newConfidences).forEach(key => { + if (newConfidences[key] > maxConfidence) { + maxConfidence = newConfidences[key]; + bestLabel = parseInt(key); + } + }); + + return { + label: bestLabel, + confidences: newConfidences, + isSingleClass: isSingleClass, + minDistance: minDistance, + isUnknown: isUnknown + }; + } + + // 显示预测结果 + displayPrediction(prediction) { + const container = document.getElementById('predictions'); + let html = ''; + + // 单品类模式特殊处理 + if (this.classNames.length === 1) { + const className = this.classNames[0]; + const confidence = prediction.confidences[0] || 0; + const percentage = (confidence * 100).toFixed(1); + const isDetected = prediction.label === 0; // 是否检测到该类 + + // 获取距离阈值 + const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5'); + const distance = prediction.minDistance || 0; + + // 显示单品类检测结果(二值化显示) + html = ` +
+
+ + ${className} ${isDetected ? '✓ 检测到' : '✗ 未检测到'} + + + ${isDetected ? '100%' : '0%'} + +
+
+
+ ${isDetected ? `100%` : ''} +
+
+
+
+ 距离: ${distance.toFixed(2)} | 阈值: ${threshold.toFixed(2)} + + ${distance <= threshold ? '✓ 在阈值范围内' : '✗ 超出阈值范围'} + +
+ `; + + container.innerHTML = html; + return; + } + + // 多品类模式:直接使用滤波后的置信度 + const confidences = prediction.confidences; + const predictedClass = prediction.label; + + // 固定顺序显示(按类别索引) + for (let i = 0; i < this.classNames.length; i++) { + const className = this.classNames[i]; + const confidence = confidences[i] || 0; + const percentage = (confidence * 100).toFixed(1); + const isWinner = i === predictedClass; + + // 根据置信度决定颜色等级 + let barClass = ''; + if (confidence > 0.7) barClass = 'high'; + else if (confidence > 0.4) barClass = 'medium'; + else barClass = 'low'; + + // 如果是获胜类别,使用绿色 + if (isWinner) barClass = 'high'; + + html += ` +
+
+ + ${className} ${isWinner ? '👑' : ''} + + + ${percentage}% + +
+
+
+ ${confidence > 0.15 ? `${percentage}%` : ''} +
+
+
+ `; + } + + container.innerHTML = html; + } + + // 从摄像头捕获样本 + async captureFromWebcam(classIndex) { + if (!this.webcamStream) { + // 临时启动摄像头 + const video = document.getElementById('webcam'); + try { + const stream = await navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }); + + video.srcObject = stream; + this.webcamStream = stream; + + // 等待视频加载 + setTimeout(async () => { + await this.addWebcamSample(classIndex); + + // 停止临时摄像头 + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + video.srcObject = null; + }, 1000); + } catch (error) { + this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`); + } + } else { + await this.addWebcamSample(classIndex); + } + } + + // 添加摄像头样本 + async addWebcamSample(classIndex) { + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + // 提取特征 + const features = await this.extractImageNetTags(video); + + // 添加到KNN分类器 + this.knnClassifier.addExample(features.logits, classIndex); + + // 更新计数 + const currentCount = this.knnClassifier.getClassExampleCount(); + const count = currentCount[classIndex] || 0; + document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`; + + // 清理 + features.logits.dispose(); + + this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`); + } catch (error) { + console.error('添加样本失败:', error); + } + } + } + + // 保存模型 + async saveModel() { + if (this.knnClassifier.getNumClasses() === 0) { + this.showStatus('predictionStatus', 'error', '没有可保存的模型'); + return; + } + + try { + // 获取KNN分类器的数据 + const dataset = this.knnClassifier.getClassifierDataset(); + const datasetObj = {}; + + Object.keys(dataset).forEach(key => { + const data = dataset[key].dataSync(); + datasetObj[key] = Array.from(data); + }); + + // 获取特征维度 + let featureDim = 1280; // 默认值 + const firstKey = Object.keys(dataset)[0]; + if (firstKey && dataset[firstKey]) { + featureDim = dataset[firstKey].shape[1]; + } + + // 保存为JSON + const modelData = { + dataset: datasetObj, + classNames: this.classNames, + k: document.getElementById('kValue').value, + featureDim: featureDim, // 保存特征维度 + date: new Date().toISOString() + }; + + const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'knn-model.json'; + a.click(); + URL.revokeObjectURL(url); + + this.showStatus('predictionStatus', 'success', '模型已保存'); + } catch (error) { + this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`); + } + } + + // 加载模型 + async loadModel() { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = '.json'; + + input.onchange = async (e) => { + try { + const file = e.target.files[0]; + const text = await file.text(); + const modelData = JSON.parse(text); + + // 清空现有分类器 + this.knnClassifier.clearAllClasses(); + + // 恢复数据集 + Object.keys(modelData.dataset).forEach(key => { + const data = modelData.dataset[key]; + + // 自动检测特征维度(兼容旧模型) + let featureDim = modelData.featureDim; + if (!featureDim) { + // 尝试常见的维度 + const possibleDims = [1280, 1024, 1000]; + for (const dim of possibleDims) { + if (data.length % dim === 0) { + featureDim = dim; + console.warn(`自动检测到特征维度: ${dim}`); + break; + } + } + } + + if (!featureDim) { + console.error(`无法确定特征维度,数据长度: ${data.length}`); + return; + } + + const numSamples = data.length / featureDim; + console.log(`加载类别 ${key}:${numSamples} 个样本,${featureDim} 维特征`); + + const tensor = tf.tensor(data, [numSamples, featureDim]); + this.knnClassifier.setClassifierDataset({ [key]: tensor }); + }); + + this.classNames = modelData.classNames; + document.getElementById('kValue').value = modelData.k; + document.getElementById('kValueDisplay').textContent = modelData.k; + + this.showStatus('predictionStatus', 'success', + `模型加载成功!类别: ${this.classNames.join(', ')}`); + } catch (error) { + this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`); + } + }; + + input.click(); + } + + // 显示状态 + showStatus(elementId, type, message) { + const element = document.getElementById(elementId); + + const classMap = { + 'success': 'status-success', + 'error': 'status-error', + 'info': 'status-info' + }; + + element.className = `status-message ${classMap[type]}`; + element.textContent = message; + } +} + +// 全局函数:从摄像头捕获 +function captureFromWebcam(classIndex) { + if (window.classifier) { + window.classifier.captureFromWebcam(classIndex); + } +} + +// 初始化应用 +let classifier; +document.addEventListener('DOMContentLoaded', () => { + classifier = new KNNImageClassifier(); + window.classifier = classifier; +}); \ No newline at end of file diff --git a/随机森林/decision-tree.js b/随机森林/decision-tree.js new file mode 100644 index 0000000..2aa8123 --- /dev/null +++ b/随机森林/decision-tree.js @@ -0,0 +1,394 @@ +var dt = (function () { + + /** + * Creates an instance of DecisionTree + * + * @constructor + * @param builder - contains training set and + * some configuration parameters + */ + function DecisionTree(builder) { + this.root = buildDecisionTree({ + trainingSet: builder.trainingSet, + ignoredAttributes: arrayToHashSet(builder.ignoredAttributes), + categoryAttr: builder.categoryAttr || 'category', + minItemsCount: builder.minItemsCount || 1, + entropyThrehold: builder.entropyThrehold || 0.01, + maxTreeDepth: builder.maxTreeDepth || 70 + }); + } + + DecisionTree.prototype.predict = function (item) { + return predict(this.root, item); + } + + /** + * Creates an instance of RandomForest + * with specific number of trees + * + * @constructor + * @param builder - contains training set and some + * configuration parameters for + * building decision trees + */ + function RandomForest(builder, treesNumber) { + this.trees = buildRandomForest(builder, treesNumber); + } + + RandomForest.prototype.predict = function (item) { + return predictRandomForest(this.trees, item); + } + + /** + * Transforming array to object with such attributes + * as elements of array (afterwards it can be used as HashSet) + */ + function arrayToHashSet(array) { + var hashSet = {}; + if (array) { + for(var i in array) { + var attr = array[i]; + hashSet[attr] = true; + } + } + return hashSet; + } + + /** + * Calculating how many objects have the same + * values of specific attribute. + * + * @param items - array of objects + * + * @param attr - variable with name of attribute, + * which embedded in each object + */ + function countUniqueValues(items, attr) { + var counter = {}; + + // detecting different values of attribute + for (var i = items.length - 1; i >= 0; i--) { + // items[i][attr] - value of attribute + counter[items[i][attr]] = 0; + } + + // counting number of occurrences of each of values + // of attribute + for (var i = items.length - 1; i >= 0; i--) { + counter[items[i][attr]] += 1; + } + + return counter; + } + + /** + * Calculating entropy of array of objects + * by specific attribute. + * + * @param items - array of objects + * + * @param attr - variable with name of attribute, + * which embedded in each object + */ + function entropy(items, attr) { + // counting number of occurrences of each of values + // of attribute + var counter = countUniqueValues(items, attr); + + var entropy = 0; + var p; + for (var i in counter) { + p = counter[i] / items.length; + entropy += -p * Math.log(p); + } + + return entropy; + } + + /** + * Splitting array of objects by value of specific attribute, + * using specific predicate and pivot. + * + * Items which matched by predicate will be copied to + * the new array called 'match', and the rest of the items + * will be copied to array with name 'notMatch' + * + * @param items - array of objects + * + * @param attr - variable with name of attribute, + * which embedded in each object + * + * @param predicate - function(x, y) + * which returns 'true' or 'false' + * + * @param pivot - used as the second argument when + * calling predicate function: + * e.g. predicate(item[attr], pivot) + */ + function split(items, attr, predicate, pivot) { + var match = []; + var notMatch = []; + + var item, + attrValue; + + for (var i = items.length - 1; i >= 0; i--) { + item = items[i]; + attrValue = item[attr]; + + if (predicate(attrValue, pivot)) { + match.push(item); + } else { + notMatch.push(item); + } + }; + + return { + match: match, + notMatch: notMatch + }; + } + + /** + * Finding value of specific attribute which is most frequent + * in given array of objects. + * + * @param items - array of objects + * + * @param attr - variable with name of attribute, + * which embedded in each object + */ + function mostFrequentValue(items, attr) { + // counting number of occurrences of each of values + // of attribute + var counter = countUniqueValues(items, attr); + + var mostFrequentCount = 0; + var mostFrequentValue; + + for (var value in counter) { + if (counter[value] > mostFrequentCount) { + mostFrequentCount = counter[value]; + mostFrequentValue = value; + } + }; + + return mostFrequentValue; + } + + var predicates = { + '==': function (a, b) { return a == b }, + '>=': function (a, b) { return a >= b } + }; + + /** + * Function for building decision tree + */ + function buildDecisionTree(builder) { + + var trainingSet = builder.trainingSet; + var minItemsCount = builder.minItemsCount; + var categoryAttr = builder.categoryAttr; + var entropyThrehold = builder.entropyThrehold; + var maxTreeDepth = builder.maxTreeDepth; + var ignoredAttributes = builder.ignoredAttributes; + + if ((maxTreeDepth == 0) || (trainingSet.length <= minItemsCount)) { + // restriction by maximal depth of tree + // or size of training set is to small + // so we have to terminate process of building tree + return { + category: mostFrequentValue(trainingSet, categoryAttr) + }; + } + + var initialEntropy = entropy(trainingSet, categoryAttr); + + if (initialEntropy <= entropyThrehold) { + // entropy of training set too small + // (it means that training set is almost homogeneous), + // so we have to terminate process of building tree + return { + category: mostFrequentValue(trainingSet, categoryAttr) + }; + } + + // used as hash-set for avoiding the checking of split by rules + // with the same 'attribute-predicate-pivot' more than once + var alreadyChecked = {}; + + // this variable expected to contain rule, which splits training set + // into subsets with smaller values of entropy (produces informational gain) + var bestSplit = {gain: 0}; + + for (var i = trainingSet.length - 1; i >= 0; i--) { + var item = trainingSet[i]; + + // iterating over all attributes of item + for (var attr in item) { + if ((attr == categoryAttr) || ignoredAttributes[attr]) { + continue; + } + + // let the value of current attribute be the pivot + var pivot = item[attr]; + + // pick the predicate + // depending on the type of the attribute value + var predicateName; + if (typeof pivot == 'number') { + predicateName = '>='; + } else { + // there is no sense to compare non-numeric attributes + // so we will check only equality of such attributes + predicateName = '=='; + } + + var attrPredPivot = attr + predicateName + pivot; + if (alreadyChecked[attrPredPivot]) { + // skip such pairs of 'attribute-predicate-pivot', + // which been already checked + continue; + } + alreadyChecked[attrPredPivot] = true; + + var predicate = predicates[predicateName]; + + // splitting training set by given 'attribute-predicate-value' + var currSplit = split(trainingSet, attr, predicate, pivot); + + // calculating entropy of subsets + var matchEntropy = entropy(currSplit.match, categoryAttr); + var notMatchEntropy = entropy(currSplit.notMatch, categoryAttr); + + // calculating informational gain + var newEntropy = 0; + newEntropy += matchEntropy * currSplit.match.length; + newEntropy += notMatchEntropy * currSplit.notMatch.length; + newEntropy /= trainingSet.length; + var currGain = initialEntropy - newEntropy; + + if (currGain > bestSplit.gain) { + // remember pairs 'attribute-predicate-value' + // which provides informational gain + bestSplit = currSplit; + bestSplit.predicateName = predicateName; + bestSplit.predicate = predicate; + bestSplit.attribute = attr; + bestSplit.pivot = pivot; + bestSplit.gain = currGain; + } + } + } + + if (!bestSplit.gain) { + // can't find optimal split + return { category: mostFrequentValue(trainingSet, categoryAttr) }; + } + + // building subtrees + + builder.maxTreeDepth = maxTreeDepth - 1; + + builder.trainingSet = bestSplit.match; + var matchSubTree = buildDecisionTree(builder); + + builder.trainingSet = bestSplit.notMatch; + var notMatchSubTree = buildDecisionTree(builder); + + return { + attribute: bestSplit.attribute, + predicate: bestSplit.predicate, + predicateName: bestSplit.predicateName, + pivot: bestSplit.pivot, + match: matchSubTree, + notMatch: notMatchSubTree, + matchedCount: bestSplit.match.length, + notMatchedCount: bestSplit.notMatch.length + }; + } + + /** + * Classifying item, using decision tree + */ + function predict(tree, item) { + var attr, + value, + predicate, + pivot; + + // Traversing tree from the root to leaf + while(true) { + + if (tree.category) { + // only leafs contains predicted category + return tree.category; + } + + attr = tree.attribute; + value = item[attr]; + + predicate = tree.predicate; + pivot = tree.pivot; + + // move to one of subtrees + if (predicate(value, pivot)) { + tree = tree.match; + } else { + tree = tree.notMatch; + } + } + } + + /** + * Building array of decision trees + */ + function buildRandomForest(builder, treesNumber) { + var items = builder.trainingSet; + + // creating training sets for each tree + var trainingSets = []; + for (var t = 0; t < treesNumber; t++) { + trainingSets[t] = []; + } + for (var i = items.length - 1; i >= 0 ; i--) { + // assigning items to training sets of each tree + // using 'round-robin' strategy + var correspondingTree = i % treesNumber; + trainingSets[correspondingTree].push(items[i]); + } + + // building decision trees + var forest = []; + for (var t = 0; t < treesNumber; t++) { + builder.trainingSet = trainingSets[t]; + + var tree = new DecisionTree(builder); + forest.push(tree); + } + return forest; + } + + /** + * Each of decision tree classifying item + * ('voting' that item corresponds to some class). + * + * This function returns hash, which contains + * all classifying results, and number of votes + * which were given for each of classifying results + */ + function predictRandomForest(forest, item) { + var result = {}; + for (var i in forest) { + var tree = forest[i]; + var prediction = tree.predict(item); + result[prediction] = result[prediction] ? result[prediction] + 1 : 1; + } + return result; + } + + var exports = {}; + exports.DecisionTree = DecisionTree; + exports.RandomForest = RandomForest; + return exports; +})(); \ No newline at end of file diff --git a/随机森林/rf-classifier.html b/随机森林/rf-classifier.html new file mode 100644 index 0000000..22213ed --- /dev/null +++ b/随机森林/rf-classifier.html @@ -0,0 +1,542 @@ + + + + + + 图像分类器 - TensorFlow.js & decision-tree.js + + + + + + + +
+

🤖 图像分类器 - 随机森林

+ +
+ +
+

📸 数据采集

+
+

1 第一类

+ + + + 0 张图片 + +
+
+ +
+

2 第二类

+ + + + 0 张图片 + +
+
+ +
+

3 第三类(可选)

+ + + + 0 张图片 + +
+
+ +
+ + +
+
+ +
+

模型参数调整

+
+ + + 随机森林中决策树的数量 +
+
+ + + 用于训练每棵树的数据子集占比 (0.1-1.0) +
+ +
+
+

ℹ️ 模型信息

+
+ 预训练模型: + MobileNet v2 +
+
+ 分类器类型: + 随机森林 +
+
+ 总样本数: + 0 +
+
+
+
+ + +
+

📹 实时预测

+ +
+ + +
+ +
+ +
+ +
+

预测结果

+
等待预测...
+
+ +
+
+
+ + + + diff --git a/随机森林/rf-classifier.js b/随机森林/rf-classifier.js new file mode 100644 index 0000000..bc3478f --- /dev/null +++ b/随机森林/rf-classifier.js @@ -0,0 +1,472 @@ +// 图像分类器 - 基于MobileNet特征标签和 decision-tree.js 实现随机森林 + +class ImageClassifier { + constructor() { + this.mobilenet = null; + this.randomForest = []; // 存储多个决策树 + this.classNames = []; + this.webcamStream = null; + this.isPredicting = false; + this.imagenetClasses = null; + this.trainingSet = []; + this.numTrees = 10; // 随机森林中决策树的数量,可调整 + this.subsetSize = 0.7; // 训练集子集大小, 可调整 + + this.init(); + } + + async init() { + this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...'); + + try { + this.mobilenet = await mobilenet.load({ + version: 2, + alpha: 1.0 + }); + + await this.loadImageNetClasses(); + + this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!'); + this.setupEventListeners(); + } catch (error) { + this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`); + } + } + + async loadImageNetClasses() { + // ImageNet 前10个类别名称(简化版) + this.imagenetClasses = [ + 'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead', + 'electric_ray', 'stingray', 'cock', 'hen', 'ostrich', + 'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting', + 'robin', 'bulbul', 'jay', 'magpie', 'chickadee' + ]; + } + + setupEventListeners() { + // 文件上传监听 + ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => { + document.getElementById(id).addEventListener('change', (e) => { + this.handleImageUpload(e, index); + }); + }); + + // 按钮监听 + document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel()); + document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset()); + document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam()); + document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam()); + + // 参数监听 + document.getElementById('numTrees').addEventListener('input', (e) => { + this.numTrees = parseInt(e.target.value); + }); + + document.getElementById('subsetSize').addEventListener('input', (e) => { + this.subsetSize = parseFloat(e.target.value); + }); + } + + handleImageUpload(event, classIndex) { + const files = event.target.files; + const countElement = document.getElementById(`class${classIndex + 1}Count`); + const previewContainer = document.getElementById(`class${classIndex + 1}Preview`); + + countElement.textContent = `${files.length} 张图片`; + + // 清空之前的预览 + previewContainer.innerHTML = ''; + + // 添加图片预览 + Array.from(files).forEach(file => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = document.createElement('img'); + img.src = e.target.result; + img.className = 'preview-img'; + previewContainer.appendChild(img); + }; + reader.readAsDataURL(file); + }); + } + + async extractImageNetTags(img) { + try { + const predictions = await this.mobilenet.classify(img); + const logits = this.mobilenet.infer(img, false); + return { + logits: logits, // 1000维特征向量 + predictions: predictions, // 前3个预测 + topTags: await this.getTopKTags(logits, 10) // 前10个标签和权重 + }; + } catch (error) { + console.error('特征提取失败:', error); + throw error; + } + } + + async getTopKTags(logits, k = 10) { + const values = await logits.data(); + const valuesAndIndices = []; + + for (let i = 0; i < values.length; i++) { + valuesAndIndices.push({ value: values[i], index: i }); + } + + valuesAndIndices.sort((a, b) => b.value - a.value); + const topkValues = new Float32Array(k); + const topkIndices = new Int32Array(k); + + for (let i = 0; i < k; i++) { + topkValues[i] = valuesAndIndices[i].value; + topkIndices[i] = valuesAndIndices[i].index; + } + + const topTags = []; + for (let i = 0; i < k; i++) { + topTags.push({ + className: this.imagenetClasses[topkIndices[i]] || `class_${topkIndices[i]}`, + probability: this.softmax(topkValues)[i], + logit: topkValues[i] + }); + } + + return topTags; + } + + softmax(arr) { + const maxLogit = Math.max(...arr); + const scores = arr.map(l => Math.exp(l - maxLogit)); + const sum = scores.reduce((a, b) => a + b); + return scores.map(s => s / sum); + } + // 训练随机森林模型 + async trainModel() { + const classes = []; + const imageFiles = []; + + // 收集所有类别和图片 + for (let i = 1; i <= 3; i++) { + const className = document.getElementById(`class${i}Name`).value.trim(); + const files = document.getElementById(`class${i}Images`).files; + + if (className && files.length > 0) { + classes.push(className); + imageFiles.push(files); + } + } + + if (classes.length < 2) { + this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!'); + return; + } + + this.classNames = classes; + this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...'); + + // 准备训练数据 + this.trainingSet = []; + + let totalProcessed = 0; + let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0); + + // 处理每个类别的图片 + for (let classIndex = 0; classIndex < classes.length; classIndex++) { + const files = imageFiles[classIndex]; + for (let fileIndex = 0; fileIndex < files.length; fileIndex++) { + try { + const img = await this.loadImage(files[fileIndex]); + const features = await this.extractImageNetTags(img); + // 将logits从tf.Tensor转换为数组 + const featureVector = await features.logits.data(); + + // 将特征向量添加到训练数据 + const item = {}; + Array.from(featureVector).forEach((value, index) => { + item[`feature_${index}`] = value; + }); + item.category = classes[classIndex]; // 类别名称,而不是索引 + this.trainingSet.push(item); + + totalProcessed++; + const progress = Math.round((totalProcessed / totalImages) * 100); + this.showStatus('dataStatus', 'info', + `处理中... ${totalProcessed}/${totalImages} (${progress}%)`); + + img.remove(); + features.logits.dispose(); + } catch (error) { + console.error('处理图片失败:', error); + } + } + } + + try { + // 构建随机森林 + this.randomForest = []; + for (let i = 0; i < this.numTrees; i++) { + // 创建具有随机子集的决策树 + const trainingSubset = this.createTrainingSubset(this.trainingSet); + + const builder = { + trainingSet: trainingSubset, + categoryAttr: 'category' + }; + const tree = new dt.DecisionTree(builder); + this.randomForest.push(tree); + } + + this.showStatus('dataStatus', 'success', `模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别, ${this.numTrees} 棵树`); + + // 更新模型信息 + document.getElementById('totalSamples').textContent = totalProcessed; + } catch (error) { + console.error('训练失败:', error); + this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`); + } + } + + // 创建训练数据的随机子集(用于随机森林) + createTrainingSubset(trainingSet) { + const subset = []; + const subsetSize = Math.floor(this.subsetSize * trainingSet.length); + + for (let i = 0; i < subsetSize; i++) { + const randomIndex = Math.floor(Math.random() * trainingSet.length); + subset.push(trainingSet[randomIndex]); + } + + return subset; + } + + async loadImage(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + const img = new Image(); + img.onload = () => resolve(img); + img.onerror = reject; + img.src = e.target.result; + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + clearDataset() { + this.randomForest = []; + this.classNames = []; + this.trainingSet = []; + + for (let i = 1; i <= 3; i++) { + document.getElementById(`class${i}Images`).value = ''; + document.getElementById(`class${i}Count`).textContent = '0 张图片'; + document.getElementById(`class${i}Preview`).innerHTML = ''; + } + + document.getElementById('totalSamples').textContent = '0'; + document.getElementById('predictions').innerHTML = '等待预测...'; + + this.showStatus('dataStatus', 'info', '数据集已清空'); + } + + startWebcam() { + if (this.randomForest.length == 0) { + this.showStatus('predictionStatus', 'error', '请先训练模型!'); + return; + } + + const video = document.getElementById('webcam'); + + navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }) + .then(stream => { + video.srcObject = stream; + this.webcamStream = stream; + + document.getElementById('startWebcamBtn').disabled = true; + document.getElementById('stopWebcamBtn').disabled = false; + + this.isPredicting = true; + this.predictLoop(); + + this.showStatus('predictionStatus', 'success', '摄像头已启动'); + }) + .catch(error => { + this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`); + }); + } + + stopWebcam() { + if (this.webcamStream) { + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + } + + this.isPredicting = false; + + const video = document.getElementById('webcam'); + video.srcObject = null; + + document.getElementById('startWebcamBtn').disabled = false; + document.getElementById('stopWebcamBtn').disabled = true; + + this.showStatus('predictionStatus', 'info', '摄像头已停止'); + } + + async predictLoop() { + if (!this.isPredicting) return; + + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + const features = await this.extractImageNetTags(video); + const featureVector = await features.logits.data(); + + const item = {}; + Array.from(featureVector).forEach((value, index) => { + item[`feature_${index}`] = value; + }); + + const { predictedCategory, probabilities } = this.predictWithRandomForest(item); + + this.displayPrediction(predictedCategory, probabilities); + features.logits.dispose(); + + + } catch (error) { + console.error('预测错误:', error); + } + + } + + requestAnimationFrame(() => this.predictLoop()); + } + + predictWithRandomForest(item) { + const votes = {}; + this.classNames.forEach(className => { + votes[className] = 0; + }); + + this.randomForest.forEach(tree => { + const prediction = tree.predict(item); + votes[prediction] = (votes[prediction] || 0) + 1; + }); + + let predictedCategory = null; + let maxVotes = 0; + for (const category in votes) { + if (votes[category] > maxVotes) { + predictedCategory = category; + maxVotes = votes[category]; + } + } + + const probabilities = {}; + for (const category in votes) { + probabilities[category] = votes[category] / this.numTrees; + } + + return { + predictedCategory: predictedCategory, + probabilities: probabilities + }; + } + + async captureFromWebcam(classIndex) { + if (!this.webcamStream) { + // 临时启动摄像头 + const video = document.getElementById('webcam'); + try { + const stream = await navigator.mediaDevices.getUserMedia({ + video: { facingMode: 'user' }, + audio: false + }); + + video.srcObject = stream; + this.webcamStream = stream; + + // 等待视频加载 + setTimeout(async () => { + await this.addWebcamSample(classIndex); + + // 停止临时摄像头 + this.webcamStream.getTracks().forEach(track => track.stop()); + this.webcamStream = null; + video.srcObject = null; + }, 1000); + } catch (error) { + this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`); + } + } else { + await this.addWebcamSample(classIndex); + } + } + + async addWebcamSample(classIndex) { + const video = document.getElementById('webcam'); + + if (video.readyState === 4) { + try { + const features = await this.extractImageNetTags(video); + + const featureVector = await features.logits.data(); + // 使用logits作为对象的属性 + const item = {}; + + Array.from(featureVector).forEach((value, index) => { + item[`feature_${index}`] = value; + }); + + // 添加类别信息 + const className = document.getElementById(`class${classIndex + 1}Name`).value.trim(); + + item.category = className; + + this.trainingSet.push(item); + this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`); + features.logits.dispose(); + } catch (error) { + console.error('添加样本失败:', error); + this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`); + } + } + } + + displayPrediction(category, probabilities) { + const container = document.getElementById('predictions'); + let html = `预测类别:${category}
`; + for (const className in probabilities) { + const probability = (probabilities[className] * 100).toFixed(2); + html += `${className}: ${probability}%
`; + } + container.innerHTML = html; + } + + showStatus(elementId, type, message) { + const element = document.getElementById(elementId); + + const classMap = { + 'success': 'status-success', + 'error': 'status-error', + 'info': 'status-info' + }; + + element.className = `status-message ${classMap[type]}`; + element.textContent = message; + } +} + +// 全局函数:从摄像头采集 +function captureFromWebcam(classIndex) { + if (window.classifier) { + window.classifier.captureFromWebcam(classIndex); + } +} +document.addEventListener('DOMContentLoaded', async () => { + window.classifier = new ImageClassifier(); +});