// KNN 图像分类器 - 基于MobileNet特征标签 class KNNImageClassifier { constructor() { this.mobilenet = null; this.knnClassifier = null; this.classNames = []; this.webcamStream = null; this.isPredicting = false; this.currentCaptureClass = -1; this.imagenetClasses = null; // 低通滤波器状态 this.filteredConfidences = {}; this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑 // 距离阈值设置 this.useDistanceThreshold = true; this.distanceThreshold = 0.5; // 默认距离阈值(归一化后的特征) this.adaptiveThreshold = null; // 自适应阈值 this.init(); } async init() { this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...'); try { // 加载 MobileNet 模型 this.mobilenet = await mobilenet.load({ version: 2, alpha: 1.0 }); // 创建 KNN 分类器 this.knnClassifier = knnClassifier.create(); // 加载 ImageNet 类别名称 await this.loadImageNetClasses(); this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!'); this.setupEventListeners(); } catch (error) { this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`); } } async loadImageNetClasses() { // ImageNet 前10个类别名称(简化版) this.imagenetClasses = [ 'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead', 'electric_ray', 'stingray', 'cock', 'hen', 'ostrich' ]; } setupEventListeners() { // 文件上传监听 ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => { document.getElementById(id).addEventListener('change', (e) => { this.handleImageUpload(e, index); }); }); // 按钮监听 document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN()); document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset()); document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam()); document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam()); document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel()); document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel()); } handleImageUpload(event, classIndex) { const files = event.target.files; const countElement = document.getElementById(`class${classIndex + 1}Count`); const previewContainer = document.getElementById(`class${classIndex + 1}Preview`); countElement.textContent = `${files.length} 张图片`; // 清空之前的预览 previewContainer.innerHTML = ''; // 添加图片预览 Array.from(files).forEach(file => { const reader = new FileReader(); reader.onload = (e) => { const img = document.createElement('img'); img.src = e.target.result; img.className = 'preview-img'; previewContainer.appendChild(img); }; reader.readAsDataURL(file); }); } // 从图像提取 MobileNet 标签和权重 async extractImageNetTags(img) { try { // 获取 MobileNet 的预测(1000个类别的概率) const predictions = await this.mobilenet.classify(img); // 获取用于KNN的特征(使用嵌入层获得更好的特征表示) const rawEmbeddings = this.mobilenet.infer(img, true); // true = 使用嵌入层,获取1280维特征 // L2归一化特征向量(重要:使距离计算更稳定) const embeddings = tf.tidy(() => { const norm = tf.norm(rawEmbeddings); const normalized = tf.div(rawEmbeddings, norm); rawEmbeddings.dispose(); // 清理原始嵌入 return normalized; }); // 获取用于显示的logits(1000个类别) const logits = this.mobilenet.infer(img, false); // false = 获取原始1000维输出 // 获取前10个最高概率的标签 const topK = await this.getTopKTags(logits, 10); // 清理logits(只用于显示) logits.dispose(); return { logits: embeddings, // 使用1280维嵌入特征用于KNN predictions: predictions, // 前3个预测 topTags: topK // 前10个标签和权重 }; } catch (error) { console.error('特征提取失败:', error); throw error; } } // 获取Top-K标签 async getTopKTags(logits, k = 10) { const values = await logits.data(); const valuesAndIndices = []; for (let i = 0; i < values.length; i++) { valuesAndIndices.push({ value: values[i], index: i }); } valuesAndIndices.sort((a, b) => b.value - a.value); const topkValues = new Float32Array(k); const topkIndices = new Int32Array(k); for (let i = 0; i < k; i++) { topkValues[i] = valuesAndIndices[i].value; topkIndices[i] = valuesAndIndices[i].index; } const topTags = []; for (let i = 0; i < k; i++) { topTags.push({ className: this.imagenetClasses[i] || `class_${topkIndices[i]}`, probability: this.softmax(topkValues)[i], logit: topkValues[i] }); } return topTags; } // Softmax 函数 softmax(arr) { const maxLogit = Math.max(...arr); const scores = arr.map(l => Math.exp(l - maxLogit)); const sum = scores.reduce((a, b) => a + b); return scores.map(s => s / sum); } // 训练 KNN 模型 async trainKNN() { const classes = []; const imageFiles = []; // 收集所有类别和图片 for (let i = 1; i <= 3; i++) { const className = document.getElementById(`class${i}Name`).value.trim(); const files = document.getElementById(`class${i}Images`).files; if (className && files && files.length > 0) { classes.push(className); imageFiles.push(files); console.log(`类别 ${i}: "${className}" - ${files.length} 张图片`); } } console.log('收集到的类别:', classes); console.log('类别数量:', classes.length); // 支持单品类检测(One-Class Classification) if (classes.length < 1) { this.showStatus('dataStatus', 'error', '请至少添加一个类别的图片!'); return; } // 如果只有一个类别,提示用户这是单品类检测模式 if (classes.length === 1) { console.log('📍 单品类检测模式:只检测 "' + classes[0] + '",其他都视为背景/未知'); } this.classNames = classes; this.filteredConfidences = {}; // 重置滤波器状态 this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...'); // 清空现有的KNN分类器 this.knnClassifier.clearAllClasses(); let totalProcessed = 0; let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0); // 处理每个类别的图片 for (let classIndex = 0; classIndex < classes.length; classIndex++) { const files = imageFiles[classIndex]; console.log(`处理类别 ${classes[classIndex]}...`); for (let fileIndex = 0; fileIndex < files.length; fileIndex++) { try { const img = await this.loadImage(files[fileIndex]); // 提取特征标签 const features = await this.extractImageNetTags(img); // 添加到KNN分类器 // 使用完整的logits作为特征向量 this.knnClassifier.addExample(features.logits, classIndex); totalProcessed++; const progress = Math.round((totalProcessed / totalImages) * 100); this.showStatus('dataStatus', 'info', `处理中... ${totalProcessed}/${totalImages} (${progress}%)`); // 显示提取的标签 if (fileIndex === 0) { this.displayTopTags(features.topTags); } // 清理 img.remove(); features.logits.dispose(); } catch (error) { console.error('处理图片失败:', error); } } } // 更新模型信息 document.getElementById('totalSamples').textContent = totalProcessed; // 根据类别数量显示不同的消息 let statusMessage; if (classes.length === 1) { statusMessage = `单品类检测模型训练完成!将只检测 "${classes[0]}",共 ${totalProcessed} 个样本`; } else { statusMessage = `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`; } this.showStatus('dataStatus', 'success', statusMessage); console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别'); if (classes.length === 1) { console.log('📍 单品类检测模式已启用,将基于距离阈值判断是否为:', classes[0]); // 计算自适应阈值 await this.calculateAdaptiveThreshold(); } } // 计算自适应阈值(基于训练数据的内部距离) async calculateAdaptiveThreshold() { if (this.knnClassifier.getNumClasses() !== 1) return; console.log('计算自适应阈值...'); const dataset = this.knnClassifier.getClassifierDataset(); if (!dataset || !dataset[0]) return; const trainData = await dataset[0].data(); const numSamples = dataset[0].shape[0]; const featureDim = dataset[0].shape[1]; // 计算训练样本之间的平均距离 let totalDistance = 0; let count = 0; for (let i = 0; i < Math.min(numSamples, 20); i++) { // 限制计算量 for (let j = i + 1; j < Math.min(numSamples, 20); j++) { let distance = 0; for (let k = 0; k < featureDim; k++) { const diff = trainData[i * featureDim + k] - trainData[j * featureDim + k]; distance += diff * diff; } distance = Math.sqrt(distance); totalDistance += distance; count++; } } if (count > 0) { const avgInternalDistance = totalDistance / count; // 自适应阈值设为内部平均距离的1.3-1.5倍(归一化后距离较小) this.adaptiveThreshold = avgInternalDistance * 1.3; console.log(`内部平均距离: ${avgInternalDistance.toFixed(2)}`); console.log(`建议自适应阈值: ${this.adaptiveThreshold.toFixed(2)}`); // 更新UI显示建议阈值 const thresholdInput = document.getElementById('distanceThreshold'); const thresholdDisplay = document.getElementById('distanceThresholdDisplay'); if (thresholdInput && thresholdDisplay) { thresholdInput.value = this.adaptiveThreshold.toFixed(1); thresholdDisplay.textContent = this.adaptiveThreshold.toFixed(1); } this.showStatus('dataStatus', 'info', `自适应阈值已计算: ${this.adaptiveThreshold.toFixed(1)} (基于训练数据内部距离)`); } } // 显示提取的标签 displayTopTags(tags) { const container = document.getElementById('tagsList'); let html = ''; tags.slice(0, 5).forEach(tag => { html += ` ${tag.className} ${(tag.probability * 100).toFixed(1)}% `; }); container.innerHTML = html; } // 加载图片 async loadImage(file) { return new Promise((resolve, reject) => { const reader = new FileReader(); reader.onload = (e) => { const img = new Image(); img.onload = () => resolve(img); img.onerror = reject; img.src = e.target.result; }; reader.onerror = reject; reader.readAsDataURL(file); }); } // 清空数据集 clearDataset() { this.knnClassifier.clearAllClasses(); this.classNames = []; this.filteredConfidences = {}; // 重置滤波器状态 console.log('数据集已清空,滤波器状态已重置'); for (let i = 1; i <= 3; i++) { document.getElementById(`class${i}Images`).value = ''; document.getElementById(`class${i}Count`).textContent = '0 张图片'; document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览 } document.getElementById('totalSamples').textContent = '0'; document.getElementById('tagsList').innerHTML = '等待数据...'; document.getElementById('predictions').innerHTML = '等待预测...'; this.showStatus('dataStatus', 'info', '数据集已清空'); } // 启动摄像头 async startWebcam() { if (this.knnClassifier.getNumClasses() === 0) { this.showStatus('predictionStatus', 'error', '请先训练模型!'); return; } const video = document.getElementById('webcam'); try { const stream = await navigator.mediaDevices.getUserMedia({ video: { facingMode: 'user' }, audio: false }); video.srcObject = stream; this.webcamStream = stream; document.getElementById('startWebcamBtn').disabled = true; document.getElementById('stopWebcamBtn').disabled = false; // 等待视频加载 video.addEventListener('loadeddata', () => { this.isPredicting = true; this.predictLoop(); }); this.showStatus('predictionStatus', 'success', '摄像头已启动'); } catch (error) { this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`); } } // 停止摄像头 stopWebcam() { if (this.webcamStream) { this.webcamStream.getTracks().forEach(track => track.stop()); this.webcamStream = null; } this.isPredicting = false; this.filteredConfidences = {}; // 重置滤波器状态 const video = document.getElementById('webcam'); video.srcObject = null; document.getElementById('startWebcamBtn').disabled = false; document.getElementById('stopWebcamBtn').disabled = true; this.showStatus('predictionStatus', 'info', '摄像头已停止'); } // 预测循环 async predictLoop() { if (!this.isPredicting) return; const video = document.getElementById('webcam'); if (video.readyState === 4) { try { console.log('开始预测,类别数量:', this.classNames.length); // 提取特征 const features = await this.extractImageNetTags(video); // 使用原始KNN进行预测(包含距离信息) const k = parseInt(document.getElementById('kValue').value); const predictionWithDistance = await this.predictWithDistance(features.logits, k); console.log('预测结果:', predictionWithDistance); // 检查是否为未知类(基于距离阈值) let finalPrediction; // 单品类模式特殊处理 if (predictionWithDistance.isSingleClass) { // 直接使用 predictWithDistance 返回的结果,它已经处理了阈值判断 finalPrediction = { label: predictionWithDistance.label, confidences: predictionWithDistance.confidences, isUnknown: predictionWithDistance.label === -1, minDistance: predictionWithDistance.minDistance, isSingleClass: true }; } else { // 多类别模式:直接使用KNN预测结果,不使用距离阈值 finalPrediction = { label: predictionWithDistance.label, confidences: predictionWithDistance.confidences, isUnknown: false, minDistance: predictionWithDistance.minDistance, isSingleClass: false }; } // 应用低通滤波器 const smoothedPrediction = this.applyLowPassFilter(finalPrediction); // 显示预测结果 this.displayPrediction(smoothedPrediction); // 显示提取的标签 this.displayTopTags(features.topTags); // 清理张量 features.logits.dispose(); } catch (error) { console.error('预测错误:', error); } } // 继续预测循环 requestAnimationFrame(() => this.predictLoop()); } // 使用距离信息进行预测 async predictWithDistance(logits, k) { // 如果没有训练数据,返回空结果 if (this.knnClassifier.getNumClasses() === 0) { return { label: -1, confidences: {}, minDistance: Infinity, isSingleClass: false }; } const numClasses = this.knnClassifier.getNumClasses(); // 单品类检测模式 - 使用实际距离计算 if (numClasses === 1) { console.log('单品类检测模式 - 计算实际距离'); // 获取训练数据 const dataset = this.knnClassifier.getClassifierDataset(); if (!dataset || !dataset[0]) { return { label: -1, confidences: { 0: 0 }, minDistance: Infinity, isSingleClass: true }; } // 计算输入样本与所有训练样本的欧氏距离 const inputData = await logits.data(); const trainData = await dataset[0].data(); const numSamples = dataset[0].shape[0]; const featureDim = dataset[0].shape[1]; console.log(`输入特征维度: ${inputData.length}, 训练数据维度: ${featureDim}, 样本数: ${numSamples}`); // 确保维度匹配 if (inputData.length !== featureDim) { console.error(`维度不匹配!输入: ${inputData.length}, 训练: ${featureDim}`); return { label: -1, confidences: { 0: 0 }, minDistance: Infinity, isSingleClass: true }; } let minDistance = Infinity; const distances = []; // 计算与每个训练样本的距离 for (let i = 0; i < numSamples; i++) { let distance = 0; for (let j = 0; j < featureDim; j++) { const diff = inputData[j] - trainData[i * featureDim + j]; distance += diff * diff; } distance = Math.sqrt(distance); distances.push(distance); if (distance < minDistance) { minDistance = distance; } } console.log(`计算了 ${distances.length} 个距离,最小距离: ${minDistance.toFixed(2)}`); console.log(`前5个距离: ${distances.slice(0, 5).map(d => d.toFixed(2)).join(', ')}`); // 获取K个最近邻的平均距离 distances.sort((a, b) => a - b); const kNearest = distances.slice(0, Math.min(k, distances.length)); const avgDistance = kNearest.reduce((sum, d) => sum + d, 0) / kNearest.length; // 从UI获取距离阈值 const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '15.0'); // 基于距离阈值判断是否属于该类 const belongsToClass = avgDistance <= threshold; // 二值化置信度:在阈值内100%,超出阈值0% const confidence = belongsToClass ? 1.0 : 0; console.log(`单品类预测 - 平均距离: ${avgDistance.toFixed(2)}, 阈值: ${threshold}, 属于类别: ${belongsToClass}, 置信度: ${confidence.toFixed(3)}`); return { label: belongsToClass ? 0 : -1, confidences: { 0: confidence }, minDistance: avgDistance, isSingleClass: true }; } // 多品类模式:使用KNN分类器预测 try { const prediction = await this.knnClassifier.predictClass(logits, k); console.log('多品类预测结果:', prediction); console.log('预测标签:', prediction.label); console.log('置信度:', prediction.confidences); // 确保confidences存在且格式正确 let confidences = {}; if (prediction.confidences) { // 检查是否是对象格式 if (typeof prediction.confidences === 'object') { confidences = prediction.confidences; } } // 如果confidences为空,手动计算 if (Object.keys(confidences).length === 0) { console.warn('置信度为空,使用默认值'); for (let i = 0; i < this.classNames.length; i++) { confidences[i] = i === prediction.label ? 1.0 : 0; } } // 计算实际距离(可选) let minDistance = 0.5; // 默认距离 return { label: prediction.label, confidences: confidences, minDistance: minDistance, isSingleClass: false }; } catch (error) { console.error('预测错误:', error); return { label: -1, confidences: {}, minDistance: Infinity, isSingleClass: false }; } } // 应用低通滤波器到置信度 applyLowPassFilter(prediction) { // 获取滤波器系数 const alpha = parseFloat(document.getElementById('filterAlpha').value); // 保留原始的特殊字段 const isSingleClass = prediction.isSingleClass || false; const minDistance = prediction.minDistance; const isUnknown = prediction.isUnknown; // 初始化滤波状态(如果是第一次) if (Object.keys(this.filteredConfidences).length === 0) { for (let i = 0; i < this.classNames.length; i++) { this.filteredConfidences[i] = prediction.confidences[i] || 0; } return { label: prediction.label, confidences: {...this.filteredConfidences}, isSingleClass: isSingleClass, minDistance: minDistance, isUnknown: isUnknown }; } // 单品类模式下,使用二值化输出,不应用滤波 if (isSingleClass) { // 获取距离阈值 const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5'); // 二值化判断:在阈值内为1,超出为0 const inThreshold = minDistance <= threshold; const confidence = inThreshold ? 1.0 : 0; // 直接更新,不滤波 this.filteredConfidences[0] = confidence; return { label: inThreshold ? 0 : -1, confidences: { 0: confidence }, isSingleClass: isSingleClass, minDistance: minDistance, isUnknown: !inThreshold }; } // 应用指数移动平均(EMA)低通滤波 const newConfidences = {}; for (let i = 0; i < this.classNames.length; i++) { const currentValue = prediction.confidences[i] || 0; const previousValue = this.filteredConfidences[i] || 0; // EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1] this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue; newConfidences[i] = this.filteredConfidences[i]; } // 归一化确保总和为1 let sum = 0; Object.values(newConfidences).forEach(v => sum += v); if (sum > 0) { Object.keys(newConfidences).forEach(key => { newConfidences[key] = newConfidences[key] / sum; }); } // 找到最高置信度的类别 let maxConfidence = 0; let bestLabel = 0; Object.keys(newConfidences).forEach(key => { if (newConfidences[key] > maxConfidence) { maxConfidence = newConfidences[key]; bestLabel = parseInt(key); } }); return { label: bestLabel, confidences: newConfidences, isSingleClass: isSingleClass, minDistance: minDistance, isUnknown: isUnknown }; } // 显示预测结果 displayPrediction(prediction) { const container = document.getElementById('predictions'); let html = ''; // 单品类模式特殊处理 if (this.classNames.length === 1) { const className = this.classNames[0]; const confidence = prediction.confidences[0] || 0; const percentage = (confidence * 100).toFixed(1); const isDetected = prediction.label === 0; // 是否检测到该类 // 获取距离阈值 const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5'); const distance = prediction.minDistance || 0; // 显示单品类检测结果(二值化显示) html = `