// Image classifier: a random forest (built with decision-tree.js) trained on
// MobileNet logits ("ImageNet tag" features) extracted from user images.
class ImageClassifier {
  constructor() {
    this.mobilenet = null;
    this.randomForest = []; // the decision trees that make up the forest
    this.classNames = [];
    this.webcamStream = null;
    this.isPredicting = false;
    this.imagenetClasses = null;
    this.trainingSet = [];
    // Samples captured from the webcam. Kept separately because trainModel()
    // rebuilds trainingSet from the uploaded files and would otherwise drop them.
    this.webcamSamples = [];
    this.numTrees = 10; // number of decision trees in the forest (adjustable)
    this.subsetSize = 0.7; // bootstrap sample size as a fraction of the training set (adjustable)
    this.isTraining = false;
    // Bumped on every startWebcam() so a stale predictLoop() from an earlier
    // session exits instead of running alongside the new one.
    this.predictionLoopId = 0;
    // init() reports its own failures through showStatus().
    this.init();
  }

  /** Loads MobileNet and the label table, then wires up the UI controls. */
  async init() {
    this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
    try {
      this.mobilenet = await mobilenet.load({ version: 2, alpha: 1.0 });
      await this.loadImageNetClasses();
      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  /**
   * Populates the label table used by getTopKTags(). Only the first 20
   * ImageNet class names are listed; other indices fall back to `class_<i>`.
   */
  async loadImageNetClasses() {
    // First 20 ImageNet class names (abridged list).
    this.imagenetClasses = [
      'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
      'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
      'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting',
      'robin', 'bulbul', 'jay', 'magpie', 'chickadee',
    ];
  }

  /** Attaches listeners to the file inputs, buttons and hyper-parameter inputs. */
  setupEventListeners() {
    // One file input per class slot.
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index);
      });
    });

    // Buttons.
    document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());

    // Hyper-parameters. Invalid or empty input (NaN, <= 0) is ignored so the
    // last valid value stays in effect, instead of silently training a forest
    // with zero trees or empty bootstrap samples.
    document.getElementById('numTrees').addEventListener('input', (e) => {
      const value = Number.parseInt(e.target.value, 10);
      if (Number.isFinite(value) && value >= 1) {
        this.numTrees = value;
      }
    });
    document.getElementById('subsetSize').addEventListener('input', (e) => {
      const value = Number.parseFloat(e.target.value);
      if (Number.isFinite(value) && value > 0) {
        this.subsetSize = value;
      }
    });
  }

  /**
   * Shows the file count and thumbnails for the files chosen in a class slot.
   * @param {Event} event - `change` event from the slot's file input.
   * @param {number} classIndex - zero-based slot index (0..2).
   */
  handleImageUpload(event, classIndex) {
    const files = event.target.files;
    const slot = classIndex + 1;
    const countElement = document.getElementById(`class${slot}Count`);
    const previewContainer = document.getElementById(`class${slot}Preview`);
    countElement.textContent = `${files.length} 张图片`;

    // Drop the previous previews.
    previewContainer.innerHTML = '';

    // Add one thumbnail per selected file.
    for (const file of Array.from(files)) {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    }
  }

  /**
   * Runs MobileNet on an image or video element.
   * @returns {Promise<{logits: object, predictions: Array, topTags: Array}>}
   *   `logits` is the tensor returned by `mobilenet.infer(img, false)`. The
   *   caller owns it and must dispose() it. `predictions` is MobileNet's own
   *   classify() result. `topTags` are the 10 highest logits with names and
   *   softmax weights.
   * @throws whatever MobileNet throws (logged first).
   */
  async extractImageNetTags(img) {
    let logits = null;
    try {
      const predictions = await this.mobilenet.classify(img);
      logits = this.mobilenet.infer(img, false);
      return {
        logits,
        predictions,
        topTags: await this.getTopKTags(logits, 10),
      };
    } catch (error) {
      // On failure nobody else holds the tensor, so free it here.
      if (logits) {
        logits.dispose();
      }
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  /**
   * Returns the `k` largest logits (fewer if the tensor is smaller), highest
   * first. Each entry has its label, its logit, and a softmax probability
   * computed over the selected top-k logits only.
   */
  async getTopKTags(logits, k = 10) {
    const values = await logits.data();
    const count = Math.min(k, values.length);
    const ranked = Array.from(values, (value, index) => ({ value, index }))
      .sort((a, b) => b.value - a.value)
      .slice(0, count);
    // Softmax once for all tags, rather than once per tag.
    const probabilities = this.softmax(ranked.map(({ value }) => value));
    return ranked.map(({ value, index }, i) => ({
      className: (this.imagenetClasses && this.imagenetClasses[index]) || `class_${index}`,
      probability: probabilities[i],
      logit: value,
    }));
  }

  /** Numerically stable softmax. An empty input yields an empty result. */
  softmax(arr) {
    const maxLogit = Math.max(...arr); // subtract the max to avoid exp() overflow
    const scores = arr.map((l) => Math.exp(l - maxLogit));
    const sum = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sum);
  }

  /**
   * Converts a feature vector into the flat record format used for
   * decision-tree.js: one `feature_<i>` attribute per component, plus
   * `category` when a label is given.
   */
  toFeatureItem(featureVector, category) {
    const item = {};
    Array.from(featureVector).forEach((value, index) => {
      item[`feature_${index}`] = value;
    });
    if (category !== undefined) {
      item.category = category;
    }
    return item;
  }

  /** Extracts MobileNet logits from `source` as a feature item. The tensor is always freed. */
  async extractFeatureItem(source, category) {
    const features = await this.extractImageNetTags(source);
    try {
      const featureVector = await features.logits.data();
      return this.toFeatureItem(featureVector, category);
    } finally {
      features.logits.dispose();
    }
  }

  /**
   * Trains the random forest on the uploaded images plus any webcam samples.
   * Progress and errors are reported through the `dataStatus` element.
   */
  async trainModel() {
    // A run is already in progress. Concurrent runs would interleave writes
    // into this.trainingSet.
    if (this.isTraining) {
      return;
    }

    const { classes, imageFiles } = this.collectUploadedClasses();
    // Class names that only have webcam samples count as classes too.
    const allClasses = classes.slice();
    this.webcamSamples.forEach((sample) => {
      if (!allClasses.includes(sample.category)) {
        allClasses.push(sample.category);
      }
    });
    if (allClasses.length < 2) {
      this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
      return;
    }

    this.isTraining = true;
    try {
      this.classNames = allClasses;
      this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...');
      await this.buildTrainingSet(classes, imageFiles);
      if (this.trainingSet.length === 0) {
        this.showStatus('dataStatus', 'error', '没有可用的训练样本,请检查图片!');
        return;
      }
      this.buildForest();
      this.showStatus('dataStatus', 'success',
        `模型训练完成!共 ${this.trainingSet.length} 个样本,${allClasses.length} 个类别, ${this.randomForest.length} 棵树`);
      // Update the model info panel.
      document.getElementById('totalSamples').textContent = String(this.trainingSet.length);
    } catch (error) {
      console.error('训练失败:', error);
      this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`);
    } finally {
      this.isTraining = false;
    }
  }

  /** Reads the three class slots and returns the named slots that have files selected. */
  collectUploadedClasses() {
    const classes = [];
    const imageFiles = [];
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;
      if (className && files.length > 0) {
        classes.push(className);
        imageFiles.push(files);
      }
    }
    return { classes, imageFiles };
  }

  /**
   * Rebuilds this.trainingSet. It starts with the stored webcam samples, then
   * adds one item per uploaded image that processed successfully. Images that
   * fail to load or classify are logged and skipped.
   */
  async buildTrainingSet(classes, imageFiles) {
    this.trainingSet = this.webcamSamples.slice();
    const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
    let totalProcessed = 0;

    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      for (const file of Array.from(imageFiles[classIndex])) {
        try {
          const img = await this.loadImage(file);
          // The category is stored by name (not index).
          this.trainingSet.push(await this.extractFeatureItem(img, classes[classIndex]));
          totalProcessed++;
          const progress = Math.round((totalProcessed / totalImages) * 100);
          this.showStatus('dataStatus', 'info', `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
        } catch (error) {
          console.error('处理图片失败:', error);
        }
      }
    }
  }

  /** Builds this.numTrees decision trees, each on its own bootstrap sample of this.trainingSet. */
  buildForest() {
    // Discard the old model up front so a failure never leaves a half-built forest.
    this.randomForest = [];
    const forest = [];
    for (let i = 0; i < this.numTrees; i++) {
      const trainingSubset = this.createTrainingSubset(this.trainingSet);
      forest.push(new dt.DecisionTree({ trainingSet: trainingSubset, categoryAttr: 'category' }));
    }
    this.randomForest = forest;
  }

  /**
   * Draws a bootstrap sample (uniform, with replacement) of
   * `floor(subsetSize * trainingSet.length)` items. The sample always has at
   * least one item when the training set is non-empty.
   */
  createTrainingSubset(trainingSet) {
    if (trainingSet.length === 0) {
      return [];
    }
    // `|| 0` maps a NaN subsetSize to the one-item minimum.
    const subsetSize = Math.max(1, Math.floor(this.subsetSize * trainingSet.length) || 0);
    const subset = [];
    for (let i = 0; i < subsetSize; i++) {
      subset.push(trainingSet[Math.floor(Math.random() * trainingSet.length)]);
    }
    return subset;
  }

  /** Decodes a File into a loaded HTMLImageElement (via a data URL). */
  async loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.onerror = reject;
        img.src = e.target.result;
      };
      reader.onerror = reject;
      reader.readAsDataURL(file);
    });
  }

  /** Discards the model and all samples, and resets the upload UI. */
  clearDataset() {
    // The model is going away, so stop any live prediction session first.
    if (this.isPredicting) {
      this.stopWebcam();
    }
    this.randomForest = [];
    this.classNames = [];
    this.trainingSet = [];
    this.webcamSamples = [];
    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
      document.getElementById(`class${i}Preview`).innerHTML = '';
    }
    document.getElementById('totalSamples').textContent = '0';
    document.getElementById('predictions').innerHTML = '等待预测...';
    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  /** Opens the camera and starts the live prediction loop. Requires a trained model. */
  async startWebcam() {
    if (this.randomForest.length === 0) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }

    const video = document.getElementById('webcam');
    try {
      // Inside the try so a missing navigator.mediaDevices (insecure context) is reported too.
      const stream = await navigator.mediaDevices.getUserMedia({ video: { facingMode: 'user' }, audio: false });
      video.srcObject = stream;
      this.webcamStream = stream;
      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;
      this.isPredicting = true;
      this.predictionLoopId += 1;
      // predictLoop() catches its own per-frame errors.
      this.predictLoop();
      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  /** Stops the camera tracks and the prediction loop, and resets the buttons. */
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach((track) => track.stop());
      this.webcamStream = null;
    }
    this.isPredicting = false;
    const video = document.getElementById('webcam');
    video.srcObject = null;
    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;
    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  /**
   * Classifies the current webcam frame, then reschedules itself with
   * requestAnimationFrame.
   * @param {number} [loopId] - session id. The loop exits once a newer session starts.
   */
  async predictLoop(loopId = this.predictionLoopId) {
    if (!this.isPredicting || loopId !== this.predictionLoopId) {
      return;
    }

    const video = document.getElementById('webcam');
    // readyState 4 is HAVE_ENOUGH_DATA. Frames are also skipped while no model exists.
    if (video.readyState === 4 && this.randomForest.length > 0) {
      try {
        const item = await this.extractFeatureItem(video);
        // The session may have been stopped or replaced while the model ran.
        if (this.isPredicting && loopId === this.predictionLoopId) {
          const { predictedCategory, probabilities } = this.predictWithRandomForest(item);
          this.displayPrediction(predictedCategory, probabilities);
        }
      } catch (error) {
        console.error('预测错误:', error);
      }
    }

    requestAnimationFrame(() => this.predictLoop(loopId));
  }

  /**
   * Majority vote over the forest.
   * @returns {{predictedCategory: (string|null), probabilities: Object<string, number>}}
   *   The share of trees voting for each category. `predictedCategory` is null
   *   when no tree voted. On ties, the first category in key order wins.
   */
  predictWithRandomForest(item) {
    const votes = {};
    this.classNames.forEach((className) => {
      votes[className] = 0;
    });
    this.randomForest.forEach((tree) => {
      const prediction = tree.predict(item);
      votes[prediction] = (votes[prediction] || 0) + 1;
    });

    let predictedCategory = null;
    let maxVotes = 0;
    for (const category of Object.keys(votes)) {
      if (votes[category] > maxVotes) {
        predictedCategory = category;
        maxVotes = votes[category];
      }
    }

    // Normalise by the trees actually in the forest. this.numTrees may have
    // been changed in the UI since the forest was trained.
    const treeCount = this.randomForest.length;
    const probabilities = {};
    for (const category of Object.keys(votes)) {
      probabilities[category] = treeCount > 0 ? votes[category] / treeCount : 0;
    }
    return { predictedCategory, probabilities };
  }

  /**
   * Adds one labelled sample for class slot `classIndex` from the webcam. If
   * the camera is off, it is opened for about one second just for this capture.
   */
  async captureFromWebcam(classIndex) {
    if (this.webcamStream) {
      await this.addWebcamSample(classIndex);
      return;
    }

    // Start a temporary camera.
    const video = document.getElementById('webcam');
    let stream;
    try {
      stream = await navigator.mediaDevices.getUserMedia({ video: { facingMode: 'user' }, audio: false });
    } catch (error) {
      this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
      return;
    }
    video.srcObject = stream;
    this.webcamStream = stream;
    try {
      // Give the video element time to start producing frames.
      await new Promise((resolve) => setTimeout(resolve, 1000));
      await this.addWebcamSample(classIndex);
    } finally {
      // Stop only the stream opened here. If startWebcam() replaced it in the
      // meantime, leave the newer stream and the video element alone.
      stream.getTracks().forEach((track) => track.stop());
      if (this.webcamStream === stream) {
        this.webcamStream = null;
        video.srcObject = null;
      }
    }
  }

  /**
   * Extracts features from the current webcam frame. The sample is labelled
   * with the name typed into class slot `classIndex` and stored as a webcam
   * sample. Reports the outcome through `dataStatus`.
   */
  async addWebcamSample(classIndex) {
    const video = document.getElementById('webcam');
    if (video.readyState !== 4) {
      this.showStatus('dataStatus', 'error', '摄像头画面尚未就绪,请稍后重试');
      return;
    }
    const className = document.getElementById(`class${classIndex + 1}Name`).value.trim();
    if (!className) {
      this.showStatus('dataStatus', 'error', '请先填写类别名称!');
      return;
    }

    try {
      const item = await this.extractFeatureItem(video, className);
      // webcamSamples survives retraining. trainingSet is updated for
      // consistency until the next trainModel() rebuilds it.
      this.webcamSamples.push(item);
      this.trainingSet.push(item);
      this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`);
    } catch (error) {
      console.error('添加样本失败:', error);
      this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
    }
  }

  /** Renders the predicted category and per-class vote shares into #predictions. */
  displayPrediction(category, probabilities) {
    const container = document.getElementById('predictions');
    // Class names come from user input, so escape them before building HTML.
    // NOTE(review): this markup looks like it lost its HTML tags (only the
    // newlines survive). Confirm the intended structure.
    let html = `预测类别:${ImageClassifier.escapeHtml(category)}\n`;
    for (const className of Object.keys(probabilities)) {
      const probability = (probabilities[className] * 100).toFixed(2);
      html += `${ImageClassifier.escapeHtml(className)}: ${probability}%\n`;
    }
    container.innerHTML = html;
  }

  /** Escapes the five HTML-significant characters in String(value). */
  static escapeHtml(value) {
    return String(value)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }

  /** Shows `message` in element `elementId`, styled by `type` ('success' | 'error' | 'info'). */
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    const classMap = {
      success: 'status-success',
      error: 'status-error',
      info: 'status-info',
    };
    element.className = `status-message ${classMap[type]}`;
    element.textContent = message;
  }
}

// Global helper for capturing a webcam sample. Presumably called from inline
// handlers in the page; verify against the HTML.
function captureFromWebcam(classIndex) {
  if (window.classifier) {
    window.classifier.captureFromWebcam(classIndex);
  }
}

// Create the classifier once the DOM is ready. The typeof check lets the file
// be loaded outside a browser (e.g. in unit tests) without throwing.
if (typeof document !== 'undefined') {
  const createClassifier = () => {
    window.classifier = new ImageClassifier();
  };
  if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', createClassifier);
  } else {
    // DOMContentLoaded has already fired (script loaded late or async).
    createClassifier();
  }
}