mobileNet/原版KNN/knn-classifier.js
2025-08-11 17:44:57 +08:00

621 lines
22 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// KNN 图像分类器 - 基于MobileNet特征标签
/**
 * KNN image classifier that uses MobileNet logits as feature vectors.
 *
 * Relies on the browser globals `mobilenet`, `knnClassifier` and `tf`
 * (TensorFlow.js and its model packages) and on a page that provides the
 * element ids referenced by the methods below.
 */
class KNNImageClassifier {
  /**
   * Sets up empty state and kicks off asynchronous model loading.
   */
  constructor() {
    this.mobilenet = null;
    this.knnClassifier = null;
    this.classNames = [];
    this.webcamStream = null;
    this.isPredicting = false;
    this.currentCaptureClass = -1;
    this.imagenetClasses = null;
    // Low-pass filter state: smoothed confidence keyed by class index.
    this.filteredConfidences = {};
    // Fallback filter coefficient (0-1); smaller values smooth more.
    // Used when the `filterAlpha` control does not yield a number.
    this.filterAlpha = 0.3;
    // init() reports its own failures; the catch only keeps the promise
    // from floating if something unexpected escapes.
    this.init().catch((error) => console.error('初始化失败:', error));
  }

  /**
   * Loads MobileNet, creates the KNN classifier, loads the label list and
   * wires up the UI. Any failure is shown in the `dataStatus` element.
   */
  async init() {
    try {
      this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
      // Load the MobileNet model.
      this.mobilenet = await mobilenet.load({
        version: 2,
        alpha: 1.0
      });
      // Create the KNN classifier.
      this.knnClassifier = knnClassifier.create();
      // Load the ImageNet class names.
      await this.loadImageNetClasses();
      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  /**
   * Populates `imagenetClasses` with a shortened label list.
   * Only the names for class indices 0-9 are present; any other index is
   * displayed as `class_<index>` by getTopKTags.
   */
  async loadImageNetClasses() {
    this.imagenetClasses = [
      'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
      'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
    ];
  }

  /**
   * Attaches change listeners to the three file inputs and click listeners
   * to the control buttons.
   */
  setupEventListeners() {
    // File upload listeners.
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index);
      });
    });
    // Button listeners.
    document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
    document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
    document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
  }

  /**
   * Updates the image counter and thumbnail preview for one class.
   *
   * @param {Event} event - change event of the file input.
   * @param {number} classIndex - zero-based class slot (0-2).
   */
  handleImageUpload(event, classIndex) {
    const files = event.target.files;
    const countElement = document.getElementById(`class${classIndex + 1}Count`);
    const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
    countElement.textContent = `${files.length} 张图片`;
    // Drop the previous preview before adding the new thumbnails.
    previewContainer.innerHTML = '';
    Array.from(files).forEach((file) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    });
  }

  /**
   * Escapes text for safe interpolation into an HTML string.
   *
   * @param {*} text - value to escape (converted with String()).
   * @returns {string} text with &, <, >, " and ' replaced by entities.
   */
  escapeHtml(text) {
    const entities = {
      '&': '&amp;',
      '<': '&lt;',
      '>': '&gt;',
      '"': '&quot;',
      "'": '&#39;'
    };
    return String(text).replace(/[&<>"']/g, (ch) => entities[ch]);
  }

  /**
   * Runs MobileNet on an image/video element and returns its logits together
   * with the classifier predictions and the top-10 tags.
   *
   * The caller owns the returned `logits` tensor and must dispose it. If
   * anything fails after the tensor was created it is disposed here before
   * the error is rethrown.
   *
   * @param {HTMLImageElement|HTMLVideoElement} img - input accepted by MobileNet.
   * @returns {Promise<{logits: object, predictions: Array, topTags: Array}>}
   * @throws rethrows any error raised by the model calls.
   */
  async extractImageNetTags(img) {
    let logits = null;
    try {
      // MobileNet's own predictions (top-k chosen by the library default).
      const predictions = await this.mobilenet.classify(img);
      // Raw output: `false` requests logits instead of the embedding layer.
      logits = this.mobilenet.infer(img, false);
      const topK = await this.getTopKTags(logits, 10);
      return {
        logits: logits,
        predictions: predictions,
        topTags: topK
      };
    } catch (error) {
      if (logits) {
        logits.dispose();
      }
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  /**
   * Returns the k highest-scoring classes of a logits tensor.
   *
   * Each tag is labelled by its class index (not by its rank), and its
   * probability is the softmax over ALL logits, so the top-k probabilities
   * sum to at most 1.
   *
   * @param {{data: function(): Promise<ArrayLike<number>>}} logits - tensor-like
   *   object whose data() resolves to the flat logit values.
   * @param {number} [k=10] - number of tags wanted; clamped to the number of logits.
   * @returns {Promise<Array<{className: string, probability: number, logit: number}>>}
   */
  async getTopKTags(logits, k = 10) {
    const values = await logits.data();
    // Computed once over the whole vector rather than per tag.
    const probabilities = this.softmax(values);
    const ranked = Array.from(values, (value, index) => ({ value, index }));
    ranked.sort((a, b) => b.value - a.value);
    const names = this.imagenetClasses || [];
    const count = Math.max(0, Math.min(k, ranked.length));
    return ranked.slice(0, count).map(({ value, index }) => ({
      className: names[index] || `class_${index}`,
      probability: probabilities[index],
      logit: value
    }));
  }

  /**
   * Numerically stable softmax.
   *
   * @param {Array<number>|Float32Array} arr - logits.
   * @returns {Array<number>|Float32Array} probabilities, same container type
   *   as the input; empty input yields an empty result.
   */
  softmax(arr) {
    // Subtracting the maximum keeps Math.exp from overflowing.
    const maxLogit = arr.reduce((max, value) => Math.max(max, value), -Infinity);
    const scores = arr.map((l) => Math.exp(l - maxLogit));
    const sum = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sum);
  }

  /**
   * Rebuilds the KNN classifier from the images selected in the three class
   * slots. At least two slots need both a name and images.
   */
  async trainKNN() {
    const classes = [];
    const imageFiles = [];
    // Collect every slot that has a name and at least one file.
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;
      if (className && files.length > 0) {
        classes.push(className);
        imageFiles.push(files);
      }
    }
    if (classes.length < 2) {
      this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
      return;
    }
    this.classNames = classes;
    // Class indices may now mean something else, so old smoothing is stale.
    this.filteredConfidences = {};
    this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');
    // Start from an empty classifier.
    this.knnClassifier.clearAllClasses();
    let totalProcessed = 0;
    const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      const files = imageFiles[classIndex];
      console.log(`处理类别 ${classes[classIndex]}...`);
      for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
        let features = null;
        try {
          const img = await this.loadImage(files[fileIndex]);
          features = await this.extractImageNetTags(img);
          // The full logits vector is the KNN feature vector.
          this.knnClassifier.addExample(features.logits, classIndex);
          totalProcessed++;
          const progress = Math.round((totalProcessed / totalImages) * 100);
          this.showStatus('dataStatus', 'info',
            `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
          // Show the tags of the first image of each class.
          if (fileIndex === 0) {
            this.displayTopTags(features.topTags);
          }
          img.remove();
        } catch (error) {
          // A single bad image must not abort the whole training run.
          console.error('处理图片失败:', error);
        } finally {
          // Dispose even when addExample or the UI update threw.
          if (features) {
            features.logits.dispose();
          }
        }
      }
    }
    document.getElementById('totalSamples').textContent = totalProcessed;
    this.showStatus('dataStatus', 'success',
      `KNN模型训练完成${totalProcessed} 个样本,${classes.length} 个类别`);
    console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
  }

  /**
   * Renders the first five tags into the `tagsList` element.
   *
   * @param {Array<{className: string, probability: number}>} tags
   */
  displayTopTags(tags) {
    const container = document.getElementById('tagsList');
    let html = '';
    tags.slice(0, 5).forEach((tag) => {
      html += `
        <span class="tag-item">
          ${this.escapeHtml(tag.className)}
          <span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
        </span>
      `;
    });
    container.innerHTML = html;
  }

  /**
   * Reads a File into an Image element.
   *
   * @param {File} file
   * @returns {Promise<HTMLImageElement>} resolves once the image has loaded;
   *   rejects on reader or image errors.
   */
  async loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.onerror = reject;
        img.src = e.target.result;
      };
      reader.onerror = reject;
      reader.readAsDataURL(file);
    });
  }

  /**
   * Clears the classifier, the class names, the filter state and the
   * related UI elements.
   */
  clearDataset() {
    this.knnClassifier.clearAllClasses();
    this.classNames = [];
    this.filteredConfidences = {}; // reset the filter state
    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
      document.getElementById(`class${i}Preview`).innerHTML = ''; // clear preview
    }
    document.getElementById('totalSamples').textContent = '0';
    document.getElementById('tagsList').innerHTML = '等待数据...';
    document.getElementById('predictions').innerHTML = '等待预测...';
    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  /**
   * Opens the webcam and starts the prediction loop once video data is
   * available. Requires a trained classifier.
   */
  async startWebcam() {
    if (this.knnClassifier.getNumClasses() === 0) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }
    const video = document.getElementById('webcam');
    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
      video.srcObject = stream;
      this.webcamStream = stream;
      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;
      // `once` stops listeners piling up across start/stop cycles, which
      // would otherwise launch several prediction loops in parallel.
      video.addEventListener('loadeddata', () => {
        // Ignore a listener left over from a stream that was stopped
        // before its first frame arrived.
        if (this.webcamStream !== stream) {
          return;
        }
        this.isPredicting = true;
        this.predictLoop();
      }, { once: true });
      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  /**
   * Stops the webcam tracks, ends the prediction loop and resets the UI.
   */
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach((track) => track.stop());
      this.webcamStream = null;
    }
    this.isPredicting = false;
    this.filteredConfidences = {}; // reset the filter state
    const video = document.getElementById('webcam');
    video.srcObject = null;
    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;
    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  /**
   * One iteration of the prediction loop: classify the current video frame,
   * smooth and display the result, then schedule the next iteration.
   */
  async predictLoop() {
    if (!this.isPredicting) return;
    const video = document.getElementById('webcam');
    if (video.readyState === 4) {
      let features = null;
      try {
        features = await this.extractImageNetTags(video);
        const requestedK = Number.parseInt(document.getElementById('kValue').value, 10);
        // Fall back to 3 (assumed to match the KNN library default) when
        // the control holds no positive integer.
        const k = Number.isFinite(requestedK) && requestedK > 0 ? requestedK : 3;
        const prediction = await this.knnClassifier.predictClass(features.logits, k);
        const smoothedPrediction = this.applyLowPassFilter(prediction);
        this.displayPrediction(smoothedPrediction);
        this.displayTopTags(features.topTags);
      } catch (error) {
        console.error('预测错误:', error);
      } finally {
        // Runs on failure too; otherwise one tensor would leak per frame.
        if (features) {
          features.logits.dispose();
        }
      }
    }
    // Schedule the next iteration.
    requestAnimationFrame(() => this.predictLoop());
  }

  /**
   * Smooths the per-class confidences with an exponential moving average:
   * y[n] = alpha * x[n] + (1 - alpha) * y[n-1].
   *
   * Alpha comes from the `filterAlpha` control, clamped to [0, 1]; if the
   * control is missing or not numeric, `this.filterAlpha` is used. The first
   * call after a reset seeds the filter with the raw confidences.
   *
   * @param {{label: *, confidences: Object<string, number>}} prediction -
   *   result of the KNN classifier.
   * @returns {{label: number, confidences: Object<string, number>}} the
   *   normalised smoothed confidences and the index of the largest one.
   *   `label` is always a number, even if the classifier reports its own
   *   label as a string.
   */
  applyLowPassFilter(prediction) {
    const alphaControl = document.getElementById('filterAlpha');
    const requestedAlpha = alphaControl ? parseFloat(alphaControl.value) : NaN;
    const rawAlpha = Number.isFinite(requestedAlpha) ? requestedAlpha : this.filterAlpha;
    const alpha = Math.min(1, Math.max(0, rawAlpha));
    const isFirstSample = Object.keys(this.filteredConfidences).length === 0;
    const newConfidences = {};
    for (let i = 0; i < this.classNames.length; i++) {
      const currentValue = prediction.confidences[i] || 0;
      if (isFirstSample) {
        this.filteredConfidences[i] = currentValue;
      } else {
        const previousValue = this.filteredConfidences[i] || 0;
        this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
      }
      newConfidences[i] = this.filteredConfidences[i];
    }
    // Normalise so the displayed confidences sum to 1.
    const sum = Object.values(newConfidences).reduce((total, v) => total + v, 0);
    if (sum > 0) {
      Object.keys(newConfidences).forEach((key) => {
        newConfidences[key] = newConfidences[key] / sum;
      });
    }
    // Pick the class with the highest smoothed confidence.
    let maxConfidence = 0;
    let bestLabel = 0;
    Object.keys(newConfidences).forEach((key) => {
      if (newConfidences[key] > maxConfidence) {
        maxConfidence = newConfidences[key];
        bestLabel = Number.parseInt(key, 10);
      }
    });
    return {
      label: bestLabel,
      confidences: newConfidences
    };
  }

  /**
   * Renders one confidence bar per class, in class-index order, into the
   * `predictions` element. Class names are escaped because they come from
   * user input or from a loaded model file.
   *
   * @param {{label: number, confidences: Object<string, number>}} prediction
   */
  displayPrediction(prediction) {
    const container = document.getElementById('predictions');
    let html = '';
    const confidences = prediction.confidences;
    const predictedClass = prediction.label;
    for (let i = 0; i < this.classNames.length; i++) {
      const className = this.escapeHtml(this.classNames[i]);
      const confidence = confidences[i] || 0;
      const percentage = (confidence * 100).toFixed(1);
      const isWinner = i === predictedClass;
      // Colour level follows the confidence; the winner is always "high".
      let barClass = '';
      if (confidence > 0.7) barClass = 'high';
      else if (confidence > 0.4) barClass = 'medium';
      else barClass = 'low';
      if (isWinner) barClass = 'high';
      html += `
        <div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
          <div class="prediction-header">
            <span class="prediction-label">
              ${className} ${isWinner ? '👑' : ''}
            </span>
            <span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
              ${percentage}%
            </span>
          </div>
          <div class="confidence-bar-container">
            <div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
              ${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
            </div>
          </div>
        </div>
      `;
    }
    container.innerHTML = html;
  }

  /**
   * Adds one webcam frame as a training sample for a class. If the webcam is
   * not running it is opened temporarily, given one second to deliver
   * frames, and closed again afterwards.
   *
   * The returned promise resolves after the sample attempt has finished.
   *
   * @param {number} classIndex - zero-based class slot.
   */
  async captureFromWebcam(classIndex) {
    if (this.webcamStream) {
      await this.addWebcamSample(classIndex);
      return;
    }
    const video = document.getElementById('webcam');
    let stream;
    try {
      stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
      video.srcObject = stream;
      this.webcamStream = stream;
    } catch (error) {
      if (stream) {
        stream.getTracks().forEach((track) => track.stop());
      }
      this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
      return;
    }
    try {
      // Give the video element time to load its first frames.
      await new Promise((resolve) => setTimeout(resolve, 1000));
      await this.addWebcamSample(classIndex);
    } finally {
      // Stop the local stream reference: this.webcamStream may have been
      // cleared or replaced while we were waiting.
      stream.getTracks().forEach((track) => track.stop());
      if (this.webcamStream === stream) {
        this.webcamStream = null;
        video.srcObject = null;
      }
    }
  }

  /**
   * Extracts features from the current webcam frame and adds them to the
   * classifier. Does nothing when the video has no frame data yet.
   *
   * @param {number} classIndex - zero-based class slot.
   */
  async addWebcamSample(classIndex) {
    const video = document.getElementById('webcam');
    if (video.readyState !== 4) {
      return;
    }
    let features = null;
    try {
      features = await this.extractImageNetTags(video);
      this.knnClassifier.addExample(features.logits, classIndex);
      // Refresh the per-class sample counter.
      const currentCount = this.knnClassifier.getClassExampleCount();
      const count = currentCount[classIndex] || 0;
      document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;
      this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
    } catch (error) {
      console.error('添加样本失败:', error);
    } finally {
      if (features) {
        features.logits.dispose();
      }
    }
  }

  /**
   * Serialises the classifier dataset to JSON and triggers a download of
   * `knn-model.json`. The tensor shape of every class is stored alongside
   * the flat data so the file can be restored without guessing dimensions.
   */
  async saveModel() {
    if (this.knnClassifier.getNumClasses() === 0) {
      this.showStatus('predictionStatus', 'error', '没有可保存的模型');
      return;
    }
    try {
      const dataset = this.knnClassifier.getClassifierDataset();
      const datasetObj = {};
      const shapes = {};
      Object.keys(dataset).forEach((key) => {
        datasetObj[key] = Array.from(dataset[key].dataSync());
        shapes[key] = Array.from(dataset[key].shape);
      });
      const modelData = {
        dataset: datasetObj,
        shapes: shapes,
        classNames: this.classNames,
        k: document.getElementById('kValue').value,
        date: new Date().toISOString()
      };
      const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
      const url = URL.createObjectURL(blob);
      const a = document.createElement('a');
      a.href = url;
      a.download = 'knn-model.json';
      a.click();
      // Revoke on a later tick so the download has started first.
      setTimeout(() => URL.revokeObjectURL(url), 0);
      this.showStatus('predictionStatus', 'success', '模型已保存');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
    }
  }

  /**
   * Determines the feature-vector width of one saved class.
   *
   * @param {number} length - number of values stored for the class.
   * @param {number} [savedDim] - width recorded in the file, if any.
   * @returns {number} the first of savedDim, 1000 (logit width produced by
   *   this trainer) and 1024 (width assumed by earlier versions of this
   *   loader) that divides `length` evenly.
   * @throws {Error} if none of the candidates fits.
   */
  resolveFeatureDim(length, savedDim) {
    const candidates = [savedDim, 1000, 1024];
    for (const dim of candidates) {
      if (Number.isInteger(dim) && dim > 0 && length > 0 && length % dim === 0) {
        return dim;
      }
    }
    throw new Error(`无法确定特征维度 (数据长度 ${length})`);
  }

  /**
   * Lets the user pick a JSON file written by saveModel and restores the
   * classifier from it.
   *
   * All classes are passed to setClassifierDataset in a single call, because
   * that call replaces the whole dataset and calling it per class would keep
   * only the last one. The file is parsed and validated before the current
   * classifier is touched.
   */
  async loadModel() {
    const input = document.createElement('input');
    input.type = 'file';
    input.accept = '.json';
    input.onchange = async (e) => {
      const restored = {};
      let handedOver = false;
      try {
        const file = e.target.files[0];
        if (!file) {
          return;
        }
        const text = await file.text();
        const modelData = JSON.parse(text);
        if (!modelData || typeof modelData.dataset !== 'object' || modelData.dataset === null ||
            !Array.isArray(modelData.classNames)) {
          throw new Error('模型文件格式无效');
        }
        Object.keys(modelData.dataset).forEach((key) => {
          const flat = modelData.dataset[key];
          const savedShape = modelData.shapes && modelData.shapes[key];
          const savedDim = Array.isArray(savedShape) ? savedShape[1] : undefined;
          const featureDim = this.resolveFeatureDim(flat.length, savedDim);
          restored[key] = tf.tensor(flat, [flat.length / featureDim, featureDim]);
        });
        // Swap the datasets only after every tensor was built successfully.
        this.knnClassifier.clearAllClasses();
        this.knnClassifier.setClassifierDataset(restored);
        handedOver = true;
        this.classNames = modelData.classNames;
        this.filteredConfidences = {}; // smoothing state belongs to the old model
        document.getElementById('kValue').value = modelData.k;
        document.getElementById('kValueDisplay').textContent = modelData.k;
        this.showStatus('predictionStatus', 'success',
          `模型加载成功!类别: ${this.classNames.join(', ')}`);
      } catch (error) {
        if (!handedOver) {
          // The classifier never received these tensors; free them here.
          Object.keys(restored).forEach((key) => restored[key].dispose());
        }
        this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
      }
    };
    input.click();
  }

  /**
   * Writes a status message into an element and sets its status class.
   *
   * @param {string} elementId - id of the status element.
   * @param {string} type - 'success', 'error' or 'info'; anything else is
   *   styled as 'info'.
   * @param {string} message - text to display.
   */
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    if (!element) {
      // A missing status element must not break the calling operation.
      console.warn(`status element "${elementId}" not found:`, message);
      return;
    }
    const classMap = {
      'success': 'status-success',
      'error': 'status-error',
      'info': 'status-info'
    };
    element.className = `status-message ${classMap[type] || classMap.info}`;
    element.textContent = message;
  }
}
// 全局函数:从摄像头捕获
/**
 * Global entry point for capturing a webcam sample: forwards the request
 * to the classifier instance stored on `window`, if one exists.
 *
 * @param {number} classIndex - class slot passed through to the classifier.
 */
function captureFromWebcam(classIndex) {
  const activeClassifier = window.classifier;
  if (!activeClassifier) {
    return;
  }
  activeClassifier.captureFromWebcam(classIndex);
}
// 初始化应用
// Holds the single classifier instance; created once the DOM is parsed and
// also exposed as window.classifier for the global capture helper.
let classifier;
document.addEventListener('DOMContentLoaded', function onDomReady() {
  window.classifier = classifier = new KNNImageClassifier();
});