// KNN 图像分类器 - 基于MobileNet特征标签 class KNNImageClassifier { constructor() { this.mobilenet = null; this.knnClassifier = null; this.classNames = []; this.webcamStream = null; this.isPredicting = false; this.currentCaptureClass = -1; this.imagenetClasses = null; // 低通滤波器状态 this.filteredConfidences = {}; this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑 this.init(); } async init() { this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...'); try { // 加载 MobileNet 模型 this.mobilenet = await mobilenet.load({ version: 2, alpha: 1.0 }); // 创建 KNN 分类器 this.knnClassifier = knnClassifier.create(); // 加载 ImageNet 类别名称 await this.loadImageNetClasses(); this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!'); this.setupEventListeners(); } catch (error) { this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`); } } async loadImageNetClasses() { // ImageNet 前10个类别名称(简化版) this.imagenetClasses = [ 'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead', 'electric_ray', 'stingray', 'cock', 'hen', 'ostrich' ]; } setupEventListeners() { // 文件上传监听 ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => { document.getElementById(id).addEventListener('change', (e) => { this.handleImageUpload(e, index); }); }); // 按钮监听 document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN()); document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset()); document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam()); document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam()); document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel()); document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel()); } handleImageUpload(event, classIndex) { const files = event.target.files; const countElement = document.getElementById(`class${classIndex + 1}Count`); const previewContainer = document.getElementById(`class${classIndex + 1}Preview`); countElement.textContent = `${files.length} 张图片`; // 清空之前的预览 previewContainer.innerHTML = ''; // 添加图片预览 Array.from(files).forEach(file => { const reader = new FileReader(); reader.onload = (e) => { const img = document.createElement('img'); img.src = e.target.result; img.className = 'preview-img'; previewContainer.appendChild(img); }; reader.readAsDataURL(file); }); } // 从图像提取 MobileNet 标签和权重 async extractImageNetTags(img) { try { // 获取 MobileNet 的预测(1000个类别的概率) const predictions = await this.mobilenet.classify(img); // 获取完整的 logits(原始输出) const logits = this.mobilenet.infer(img, false); // false = 不使用嵌入层,获取原始1000维输出 // 获取前10个最高概率的标签 const topK = await this.getTopKTags(logits, 10); return { logits: logits, // 1000维特征向量 predictions: predictions, // 前3个预测 topTags: topK // 前10个标签和权重 }; } catch (error) { console.error('特征提取失败:', error); throw error; } } // 获取Top-K标签 async getTopKTags(logits, k = 10) { const values = await logits.data(); const valuesAndIndices = []; for (let i = 0; i < values.length; i++) { valuesAndIndices.push({ value: values[i], index: i }); } valuesAndIndices.sort((a, b) => b.value - a.value); const topkValues = new Float32Array(k); const topkIndices = new Int32Array(k); for (let i = 0; i < k; i++) { topkValues[i] = valuesAndIndices[i].value; topkIndices[i] = valuesAndIndices[i].index; } const topTags = []; for (let i = 0; i < k; i++) { topTags.push({ className: this.imagenetClasses[i] || `class_${topkIndices[i]}`, probability: this.softmax(topkValues)[i], logit: topkValues[i] }); } return topTags; } // Softmax 函数 softmax(arr) { const maxLogit = Math.max(...arr); const scores = arr.map(l => Math.exp(l - maxLogit)); const sum = scores.reduce((a, b) => a + b); return scores.map(s => s / sum); } // 训练 KNN 模型 async trainKNN() { const classes = []; const imageFiles = []; // 收集所有类别和图片 for (let i = 1; i <= 3; i++) { const className = document.getElementById(`class${i}Name`).value.trim(); const files = document.getElementById(`class${i}Images`).files; if (className && files.length > 0) { classes.push(className); imageFiles.push(files); } } if (classes.length < 2) { this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!'); return; } this.classNames = classes; this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...'); // 清空现有的KNN分类器 this.knnClassifier.clearAllClasses(); let totalProcessed = 0; let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0); // 处理每个类别的图片 for (let classIndex = 0; classIndex < classes.length; classIndex++) { const files = imageFiles[classIndex]; console.log(`处理类别 ${classes[classIndex]}...`); for (let fileIndex = 0; fileIndex < files.length; fileIndex++) { try { const img = await this.loadImage(files[fileIndex]); // 提取特征标签 const features = await this.extractImageNetTags(img); // 添加到KNN分类器 // 使用完整的logits作为特征向量 this.knnClassifier.addExample(features.logits, classIndex); totalProcessed++; const progress = Math.round((totalProcessed / totalImages) * 100); this.showStatus('dataStatus', 'info', `处理中... ${totalProcessed}/${totalImages} (${progress}%)`); // 显示提取的标签 if (fileIndex === 0) { this.displayTopTags(features.topTags); } // 清理 img.remove(); features.logits.dispose(); } catch (error) { console.error('处理图片失败:', error); } } } // 更新模型信息 document.getElementById('totalSamples').textContent = totalProcessed; this.showStatus('dataStatus', 'success', `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`); console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别'); } // 显示提取的标签 displayTopTags(tags) { const container = document.getElementById('tagsList'); let html = ''; tags.slice(0, 5).forEach(tag => { html += ` ${tag.className} ${(tag.probability * 100).toFixed(1)}% `; }); container.innerHTML = html; } // 加载图片 async loadImage(file) { return new Promise((resolve, reject) => { const reader = new FileReader(); reader.onload = (e) => { const img = new Image(); img.onload = () => resolve(img); img.onerror = reject; img.src = e.target.result; }; reader.onerror = reject; reader.readAsDataURL(file); }); } // 清空数据集 clearDataset() { this.knnClassifier.clearAllClasses(); this.classNames = []; this.filteredConfidences = {}; // 重置滤波器状态 for (let i = 1; i <= 3; i++) { document.getElementById(`class${i}Images`).value = ''; document.getElementById(`class${i}Count`).textContent = '0 张图片'; document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览 } document.getElementById('totalSamples').textContent = '0'; document.getElementById('tagsList').innerHTML = '等待数据...'; document.getElementById('predictions').innerHTML = '等待预测...'; this.showStatus('dataStatus', 'info', '数据集已清空'); } // 启动摄像头 async startWebcam() { if (this.knnClassifier.getNumClasses() === 0) { this.showStatus('predictionStatus', 'error', '请先训练模型!'); return; } const video = document.getElementById('webcam'); try { const stream = await navigator.mediaDevices.getUserMedia({ video: { facingMode: 'user' }, audio: false }); video.srcObject = stream; this.webcamStream = stream; document.getElementById('startWebcamBtn').disabled = true; document.getElementById('stopWebcamBtn').disabled = false; // 等待视频加载 video.addEventListener('loadeddata', () => { this.isPredicting = true; this.predictLoop(); }); this.showStatus('predictionStatus', 'success', '摄像头已启动'); } catch (error) { this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`); } } // 停止摄像头 stopWebcam() { if (this.webcamStream) { this.webcamStream.getTracks().forEach(track => track.stop()); this.webcamStream = null; } this.isPredicting = false; this.filteredConfidences = {}; // 重置滤波器状态 const video = document.getElementById('webcam'); video.srcObject = null; document.getElementById('startWebcamBtn').disabled = false; document.getElementById('stopWebcamBtn').disabled = true; this.showStatus('predictionStatus', 'info', '摄像头已停止'); } // 预测循环 async predictLoop() { if (!this.isPredicting) return; const video = document.getElementById('webcam'); if (video.readyState === 4) { try { // 提取特征 const features = await this.extractImageNetTags(video); // 使用原始KNN进行预测 const k = parseInt(document.getElementById('kValue').value); const prediction = await this.knnClassifier.predictClass(features.logits, k); // 应用低通滤波器 const smoothedPrediction = this.applyLowPassFilter(prediction); // 显示预测结果 this.displayPrediction(smoothedPrediction); // 显示提取的标签 this.displayTopTags(features.topTags); // 清理张量 features.logits.dispose(); } catch (error) { console.error('预测错误:', error); } } // 继续预测循环 requestAnimationFrame(() => this.predictLoop()); } // 应用低通滤波器到置信度 applyLowPassFilter(prediction) { // 获取滤波器系数 const alpha = parseFloat(document.getElementById('filterAlpha').value); // 初始化滤波状态(如果是第一次) if (Object.keys(this.filteredConfidences).length === 0) { for (let i = 0; i < this.classNames.length; i++) { this.filteredConfidences[i] = prediction.confidences[i] || 0; } return { label: prediction.label, confidences: {...this.filteredConfidences} }; } // 应用指数移动平均(EMA)低通滤波 const newConfidences = {}; for (let i = 0; i < this.classNames.length; i++) { const currentValue = prediction.confidences[i] || 0; const previousValue = this.filteredConfidences[i] || 0; // EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1] this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue; newConfidences[i] = this.filteredConfidences[i]; } // 归一化确保总和为1 let sum = 0; Object.values(newConfidences).forEach(v => sum += v); if (sum > 0) { Object.keys(newConfidences).forEach(key => { newConfidences[key] = newConfidences[key] / sum; }); } // 找到最高置信度的类别 let maxConfidence = 0; let bestLabel = 0; Object.keys(newConfidences).forEach(key => { if (newConfidences[key] > maxConfidence) { maxConfidence = newConfidences[key]; bestLabel = parseInt(key); } }); return { label: bestLabel, confidences: newConfidences }; } // 显示预测结果 displayPrediction(prediction) { const container = document.getElementById('predictions'); let html = ''; // 直接使用滤波后的置信度 const confidences = prediction.confidences; const predictedClass = prediction.label; // 固定顺序显示(按类别索引) for (let i = 0; i < this.classNames.length; i++) { const className = this.classNames[i]; const confidence = confidences[i] || 0; const percentage = (confidence * 100).toFixed(1); const isWinner = i === predictedClass; // 根据置信度决定颜色等级 let barClass = ''; if (confidence > 0.7) barClass = 'high'; else if (confidence > 0.4) barClass = 'medium'; else barClass = 'low'; // 如果是获胜类别,使用绿色 if (isWinner) barClass = 'high'; html += `