// KNN image classifier built on MobileNet feature tags.
/**
 * Browser-side image classifier that feeds MobileNet logits into a KNN
 * classifier.
 *
 * Training images come from three file inputs (`class1Images`..`class3Images`)
 * or from webcam captures; live predictions are read from the `webcam` video
 * element, smoothed with an exponential moving average and rendered into the
 * `predictions` element.
 *
 * Relies on the globals `mobilenet`, `knnClassifier` and `tf` being provided
 * by the page, and on the DOM element ids referenced in the methods below.
 */
class KNNImageClassifier {
    /**
     * Sets up empty state and kicks off asynchronous model loading.
     * `init()` handles its own errors, so the un-awaited call cannot reject.
     */
    constructor() {
        this.mobilenet = null;
        this.knnClassifier = null;
        this.classNames = [];
        this.webcamStream = null;
        this.isPredicting = false;
        this.currentCaptureClass = -1;
        this.imagenetClasses = null;

        // Low-pass filter state: per-class smoothed confidences.
        this.filteredConfidences = {};
        // Default filter coefficient (0-1); smaller values smooth more.
        // Used when the `filterAlpha` control holds no usable number.
        this.filterAlpha = 0.3;

        this.init();
    }

    /**
     * Loads MobileNet, creates the KNN classifier, loads the label list and
     * only then wires up the UI. Any failure is reported in `dataStatus`.
     */
    async init() {
        try {
            this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');

            this.mobilenet = await mobilenet.load({
                version: 2,
                alpha: 1.0,
            });
            this.knnClassifier = knnClassifier.create();
            await this.loadImageNetClasses();

            this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
            this.setupEventListeners();
        } catch (error) {
            this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
        }
    }

    /**
     * Populates `imagenetClasses` with the names of the first ten ImageNet
     * classes (simplified list); other indices are rendered as `class_<n>`.
     */
    async loadImageNetClasses() {
        this.imagenetClasses = [
            'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
            'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
        ];
    }

    /** Attaches change/click handlers to the file inputs and action buttons. */
    setupEventListeners() {
        // File-upload inputs, one per class slot.
        ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
            document.getElementById(id).addEventListener('change', (e) => {
                this.handleImageUpload(e, index);
            });
        });

        // Action buttons.
        const buttonActions = {
            addDataBtn: () => this.trainKNN(),
            clearDataBtn: () => this.clearDataset(),
            startWebcamBtn: () => this.startWebcam(),
            stopWebcamBtn: () => this.stopWebcam(),
            saveModelBtn: () => this.saveModel(),
            loadModelBtn: () => this.loadModel(),
        };
        Object.keys(buttonActions).forEach((id) => {
            document.getElementById(id).addEventListener('click', buttonActions[id]);
        });
    }

    /**
     * Updates the image counter and thumbnail preview for one class slot.
     *
     * @param {Event} event - `change` event of the file input.
     * @param {number} classIndex - Zero-based class slot (element ids are 1-based).
     */
    handleImageUpload(event, classIndex) {
        const files = event.target.files;
        const countElement = document.getElementById(`class${classIndex + 1}Count`);
        const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);

        countElement.textContent = `${files.length} 张图片`;

        // Drop previews from an earlier selection.
        previewContainer.innerHTML = '';

        for (const file of Array.from(files)) {
            const reader = new FileReader();
            reader.onload = (e) => {
                const img = document.createElement('img');
                img.src = e.target.result;
                img.className = 'preview-img';
                previewContainer.appendChild(img);
            };
            reader.readAsDataURL(file);
        }
    }

    /**
     * Runs MobileNet on an image or video element.
     *
     * @param {HTMLImageElement|HTMLVideoElement} img - Input handed to MobileNet.
     * @returns {Promise<{logits: object, predictions: Array, topTags: Array}>}
     *   `logits` is a tensor the CALLER must dispose; `predictions` is the
     *   result of `mobilenet.classify`; `topTags` is the 10 highest logits.
     * @throws Re-throws any model error after logging it.
     */
    async extractImageNetTags(img) {
        let logits = null;
        try {
            const predictions = await this.mobilenet.classify(img);

            // `false` = skip the embedding and return the raw class logits.
            logits = this.mobilenet.infer(img, false);

            const topK = await this.getTopKTags(logits, 10);

            return {
                logits,
                predictions,
                topTags: topK,
            };
        } catch (error) {
            // Ownership of the tensor never reached the caller, so free it here.
            if (logits) {
                logits.dispose();
            }
            console.error('特征提取失败:', error);
            throw error;
        }
    }

    /**
     * Picks the `k` largest logits and labels them.
     *
     * @param {{data: function(): Promise<ArrayLike<number>>}} logits - Tensor-like
     *   object whose `data()` resolves to the flat logit values.
     * @param {number} [k=10] - Number of tags wanted; clamped to the number of
     *   available logits.
     * @returns {Promise<Array<{className: string, probability: number, logit: number}>>}
     *   Tags sorted by descending logit. `probability` is a softmax taken over
     *   the selected logits only, not over the full output vector.
     */
    async getTopKTags(logits, k = 10) {
        const values = await logits.data();
        const ranked = Array.from(values, (value, index) => ({ value, index }));
        ranked.sort((a, b) => b.value - a.value);

        const count = Math.max(0, Math.min(k, ranked.length));
        const best = ranked.slice(0, count);
        const topkValues = Float32Array.from(best, (entry) => entry.value);

        // Computed once; it does not depend on the tag being built.
        const probabilities = this.softmax(topkValues);
        const names = this.imagenetClasses || [];

        return best.map((entry, rank) => ({
            // Look the name up by ImageNet class index, not by rank.
            className: names[entry.index] || `class_${entry.index}`,
            probability: probabilities[rank],
            logit: topkValues[rank],
        }));
    }

    /**
     * Numerically stable softmax.
     *
     * @param {Array<number>|Float32Array} arr - Logits.
     * @returns {Array<number>|Float32Array} Probabilities in a container of the
     *   same kind as `arr`; an empty input yields an empty result.
     */
    softmax(arr) {
        if (arr.length === 0) {
            return arr.slice();
        }
        // Subtract the maximum so Math.exp never overflows.
        const maxLogit = arr.reduce((best, value) => Math.max(best, value), -Infinity);
        const scores = arr.map((logit) => Math.exp(logit - maxLogit));
        const sum = scores.reduce((total, score) => total + score, 0);
        return scores.map((score) => score / sum);
    }

    /**
     * Rebuilds the KNN classifier from the images selected in the file inputs.
     * Slots without a name or without files are skipped; at least two usable
     * slots are required. Images that fail to load are logged and skipped.
     */
    async trainKNN() {
        const classes = [];
        const imageFiles = [];

        // Collect every class slot that has both a name and files.
        for (let i = 1; i <= 3; i++) {
            const className = document.getElementById(`class${i}Name`).value.trim();
            const files = document.getElementById(`class${i}Images`).files;

            if (className && files.length > 0) {
                classes.push(className);
                imageFiles.push(files);
            }
        }

        if (classes.length < 2) {
            this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
            return;
        }

        this.classNames = classes;
        // Smoothed confidences from a previous class layout are meaningless now.
        this.filteredConfidences = {};
        this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');

        this.knnClassifier.clearAllClasses();

        let totalProcessed = 0;
        const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);

        for (let classIndex = 0; classIndex < classes.length; classIndex++) {
            const files = imageFiles[classIndex];
            console.log(`处理类别 ${classes[classIndex]}...`);

            for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
                let img = null;
                let features = null;
                try {
                    img = await this.loadImage(files[fileIndex]);
                    features = await this.extractImageNetTags(img);

                    // The full logit vector is the KNN feature.
                    this.knnClassifier.addExample(features.logits, classIndex);

                    totalProcessed++;
                    const progress = Math.round((totalProcessed / totalImages) * 100);
                    this.showStatus('dataStatus', 'info',
                        `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);

                    // Show the tags of the first image of each class.
                    if (fileIndex === 0) {
                        this.displayTopTags(features.topTags);
                    }
                } catch (error) {
                    console.error('处理图片失败:', error);
                } finally {
                    // Release resources even when a step above failed.
                    if (features) {
                        features.logits.dispose();
                    }
                    if (img) {
                        img.remove();
                    }
                }
            }
        }

        document.getElementById('totalSamples').textContent = totalProcessed;

        this.showStatus('dataStatus', 'success',
            `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`);

        console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
    }

    /**
     * Renders up to five tags into the `tagsList` element.
     *
     * @param {Array<{className: string, probability: number}>} tags
     */
    displayTopTags(tags) {
        const container = document.getElementById('tagsList');

        const html = tags.slice(0, 5).map((tag) => `
            <span class="tag-item">
                ${this.escapeHtml(tag.className)}
                <span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
            </span>
        `).join('');

        container.innerHTML = html;
    }

    /**
     * Decodes a file into an `Image` element.
     *
     * @param {Blob} file - Image file.
     * @returns {Promise<HTMLImageElement>} Resolves once the image has loaded;
     *   rejects if reading or decoding fails.
     */
    async loadImage(file) {
        return new Promise((resolve, reject) => {
            const reader = new FileReader();
            reader.onload = (e) => {
                const img = new Image();
                img.onload = () => resolve(img);
                img.onerror = reject;
                img.src = e.target.result;
            };
            reader.onerror = reject;
            reader.readAsDataURL(file);
        });
    }

    /** Clears the classifier, the filter state and all dataset-related UI. */
    clearDataset() {
        this.knnClassifier.clearAllClasses();
        this.classNames = [];
        this.filteredConfidences = {};

        for (let i = 1; i <= 3; i++) {
            document.getElementById(`class${i}Images`).value = '';
            document.getElementById(`class${i}Count`).textContent = '0 张图片';
            document.getElementById(`class${i}Preview`).innerHTML = '';
        }

        document.getElementById('totalSamples').textContent = '0';
        document.getElementById('tagsList').innerHTML = '等待数据...';
        document.getElementById('predictions').innerHTML = '等待预测...';

        this.showStatus('dataStatus', 'info', '数据集已清空');
    }

    /**
     * Opens the front camera and starts the prediction loop once the first
     * frame is available. Refuses to start while the classifier is empty.
     */
    async startWebcam() {
        if (this.knnClassifier.getNumClasses() === 0) {
            this.showStatus('predictionStatus', 'error', '请先训练模型!');
            return;
        }

        const video = document.getElementById('webcam');

        try {
            const stream = await navigator.mediaDevices.getUserMedia({
                video: { facingMode: 'user' },
                audio: false,
            });

            video.srcObject = stream;
            this.webcamStream = stream;

            document.getElementById('startWebcamBtn').disabled = true;
            document.getElementById('stopWebcamBtn').disabled = false;

            // `once` keeps listeners from piling up across start/stop cycles
            // (each leftover listener would spawn another predict loop). The
            // stream check turns a listener left over from a start that was
            // stopped before its first frame into a no-op.
            video.addEventListener('loadeddata', () => {
                if (this.webcamStream !== stream) {
                    return;
                }
                this.isPredicting = true;
                this.predictLoop();
            }, { once: true });

            this.showStatus('predictionStatus', 'success', '摄像头已启动');
        } catch (error) {
            this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
        }
    }

    /** Stops the camera, the prediction loop and resets the filter state. */
    stopWebcam() {
        if (this.webcamStream) {
            this.webcamStream.getTracks().forEach((track) => track.stop());
            this.webcamStream = null;
        }

        this.isPredicting = false;
        this.filteredConfidences = {};

        const video = document.getElementById('webcam');
        video.srcObject = null;

        document.getElementById('startWebcamBtn').disabled = false;
        document.getElementById('stopWebcamBtn').disabled = true;

        this.showStatus('predictionStatus', 'info', '摄像头已停止');
    }

    /**
     * One iteration of the live prediction loop; reschedules itself with
     * `requestAnimationFrame` until `isPredicting` becomes false.
     */
    async predictLoop() {
        if (!this.isPredicting) return;

        const video = document.getElementById('webcam');

        // 4 === HAVE_ENOUGH_DATA: a frame is available to classify.
        if (video.readyState === 4) {
            let features = null;
            try {
                features = await this.extractImageNetTags(video);

                const requestedK = parseInt(document.getElementById('kValue').value, 10);
                // Fall back to 3 (the KNN library's default k) when the control
                // holds no usable positive integer.
                const k = Number.isNaN(requestedK) || requestedK < 1 ? 3 : requestedK;
                const prediction = await this.knnClassifier.predictClass(features.logits, k);

                // Skip rendering if the webcam was stopped while we awaited,
                // otherwise the just-reset filter state would be re-seeded.
                if (this.isPredicting) {
                    const smoothedPrediction = this.applyLowPassFilter(prediction);
                    this.displayPrediction(smoothedPrediction);
                    this.displayTopTags(features.topTags);
                }
            } catch (error) {
                console.error('预测错误:', error);
            } finally {
                // Always release the tensor, including on prediction failure.
                if (features) {
                    features.logits.dispose();
                }
            }
        }

        requestAnimationFrame(() => this.predictLoop());
    }

    /**
     * Smooths per-class confidences with an exponential moving average
     * (y[n] = alpha * x[n] + (1 - alpha) * y[n-1]), renormalises them and
     * picks the class with the highest smoothed confidence.
     *
     * Alpha is read from the `filterAlpha` control; when that is missing or
     * not a number `this.filterAlpha` is used. Alpha is clamped to [0, 1].
     *
     * @param {{label: *, confidences: Object<string, number>}} prediction -
     *   Raw KNN prediction keyed by class index.
     * @returns {{label: number, confidences: Object<string, number>}}
     *   `label` is always a numeric class index.
     */
    applyLowPassFilter(prediction) {
        const alphaControl = document.getElementById('filterAlpha');
        const parsedAlpha = alphaControl ? parseFloat(alphaControl.value) : NaN;
        const requestedAlpha = Number.isNaN(parsedAlpha) ? this.filterAlpha : parsedAlpha;
        const alpha = Math.min(1, Math.max(0, requestedAlpha));

        // With no history yet the filter is seeded with the raw values.
        const isFirstSample = Object.keys(this.filteredConfidences).length === 0;

        const newConfidences = {};
        for (let i = 0; i < this.classNames.length; i++) {
            const currentValue = prediction.confidences[i] || 0;
            if (isFirstSample) {
                this.filteredConfidences[i] = currentValue;
            } else {
                const previousValue = this.filteredConfidences[i] || 0;
                this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
            }
            newConfidences[i] = this.filteredConfidences[i];
        }

        // Renormalise so the displayed values sum to 1.
        const sum = Object.values(newConfidences).reduce((total, value) => total + value, 0);
        if (sum > 0) {
            Object.keys(newConfidences).forEach((key) => {
                newConfidences[key] = newConfidences[key] / sum;
            });
        }

        // Arg-max over the smoothed values.
        let maxConfidence = 0;
        let bestLabel = 0;
        Object.keys(newConfidences).forEach((key) => {
            if (newConfidences[key] > maxConfidence) {
                maxConfidence = newConfidences[key];
                bestLabel = parseInt(key, 10);
            }
        });

        return {
            label: bestLabel,
            confidences: newConfidences,
        };
    }

    /**
     * Renders one confidence bar per class, in class-index order, into the
     * `predictions` element and highlights the winning class.
     *
     * @param {{label: number|string, confidences: Object<string, number>}} prediction
     */
    displayPrediction(prediction) {
        const container = document.getElementById('predictions');
        const confidences = prediction.confidences;
        // Accept numeric or string labels so the winner check cannot miss.
        const predictedClass = Number(prediction.label);

        let html = '';

        for (let i = 0; i < this.classNames.length; i++) {
            // Class names come from user input or a loaded file: escape them.
            const className = this.escapeHtml(this.classNames[i]);
            const confidence = confidences[i] || 0;
            const percentage = (confidence * 100).toFixed(1);
            const isWinner = i === predictedClass;

            // Colour tier by confidence; the winner is always shown as "high".
            let barClass = 'low';
            if (isWinner || confidence > 0.7) {
                barClass = 'high';
            } else if (confidence > 0.4) {
                barClass = 'medium';
            }

            const itemStyle = isWinner
                ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);'
                : '';
            const badgeStyle = isWinner
                ? 'background: linear-gradient(135deg, #48bb78, #38a169);'
                : '';

            html += `
                <div class="prediction-item" style="${itemStyle}">
                    <div class="prediction-header">
                        <span class="prediction-label">
                            ${className} ${isWinner ? '👑' : ''}
                        </span>
                        <span class="prediction-confidence" style="${badgeStyle}">
                            ${percentage}%
                        </span>
                    </div>
                    <div class="confidence-bar-container">
                        <div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
                            ${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
                        </div>
                    </div>
                </div>
            `;
        }

        container.innerHTML = html;
    }

    /**
     * Adds one webcam frame as a training example. If no camera is running, a
     * temporary stream is opened for the capture and closed again afterwards.
     *
     * NOTE(review): `classIndex` is used directly as the KNN label, whereas
     * `trainKNN` numbers only the non-empty class slots; confirm the two stay
     * aligned when both input paths are mixed.
     *
     * @param {number} classIndex - Zero-based class label.
     * @returns {Promise<void>} Settles after the sample attempt has finished.
     */
    async captureFromWebcam(classIndex) {
        if (this.webcamStream) {
            await this.addWebcamSample(classIndex);
            return;
        }

        const video = document.getElementById('webcam');
        let stream = null;
        try {
            stream = await navigator.mediaDevices.getUserMedia({
                video: { facingMode: 'user' },
                audio: false,
            });

            video.srcObject = stream;
            this.webcamStream = stream;

            // Give the video element a moment to produce its first frame.
            await new Promise((resolve) => setTimeout(resolve, 1000));
            await this.addWebcamSample(classIndex);
        } catch (error) {
            this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
        } finally {
            if (stream) {
                // Stop the stream we opened, via the local reference, because
                // `this.webcamStream` may have been replaced or cleared meanwhile.
                stream.getTracks().forEach((track) => track.stop());
                if (this.webcamStream === stream) {
                    this.webcamStream = null;
                    video.srcObject = null;
                }
            }
        }
    }

    /**
     * Classifies the current webcam frame with MobileNet and stores its logits
     * as a KNN example, then refreshes the sample counter of that class.
     * Does nothing (besides a console warning) when no frame is ready.
     *
     * @param {number} classIndex - Zero-based class label.
     */
    async addWebcamSample(classIndex) {
        const video = document.getElementById('webcam');

        // 4 === HAVE_ENOUGH_DATA.
        if (video.readyState !== 4) {
            console.warn('Webcam frame not ready; sample skipped');
            return;
        }

        let features = null;
        try {
            features = await this.extractImageNetTags(video);
            this.knnClassifier.addExample(features.logits, classIndex);

            const currentCount = this.knnClassifier.getClassExampleCount();
            const count = currentCount[classIndex] || 0;
            document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;

            this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
        } catch (error) {
            console.error('添加样本失败:', error);
        } finally {
            if (features) {
                features.logits.dispose();
            }
        }
    }

    /**
     * Downloads the classifier as `knn-model.json`. Besides the flattened
     * per-class data the file records each tensor's shape so that `loadModel`
     * can restore it without guessing the feature width.
     */
    async saveModel() {
        if (this.knnClassifier.getNumClasses() === 0) {
            this.showStatus('predictionStatus', 'error', '没有可保存的模型');
            return;
        }

        try {
            const dataset = this.knnClassifier.getClassifierDataset();
            const datasetObj = {};
            const shapes = {};

            Object.keys(dataset).forEach((key) => {
                datasetObj[key] = Array.from(dataset[key].dataSync());
                shapes[key] = Array.from(dataset[key].shape);
            });

            const modelData = {
                dataset: datasetObj,
                shapes,
                classNames: this.classNames,
                k: document.getElementById('kValue').value,
                date: new Date().toISOString(),
            };

            const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
            const url = URL.createObjectURL(blob);
            const a = document.createElement('a');
            a.href = url;
            a.download = 'knn-model.json';
            a.click();
            URL.revokeObjectURL(url);

            this.showStatus('predictionStatus', 'success', '模型已保存');
        } catch (error) {
            this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
        }
    }

    /**
     * Opens a file picker and restores a model written by `saveModel`.
     * The current classifier is only replaced after the file parsed and all
     * tensors were rebuilt successfully.
     */
    async loadModel() {
        const input = document.createElement('input');
        input.type = 'file';
        input.accept = '.json';

        input.onchange = async (e) => {
            try {
                const file = e.target.files[0];
                if (!file) {
                    return;
                }
                const modelData = JSON.parse(await file.text());

                this.restoreModelData(modelData);

                document.getElementById('kValue').value = modelData.k;
                document.getElementById('kValueDisplay').textContent = modelData.k;

                this.showStatus('predictionStatus', 'success',
                    `模型加载成功!类别: ${this.classNames.join(', ')}`);
            } catch (error) {
                this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
            }
        };

        input.click();
    }

    /**
     * Rebuilds the KNN dataset from parsed model JSON and installs it.
     *
     * @param {{dataset: Object<string, number[]>, shapes?: Object<string, number[]>,
     *   classNames?: string[]}} modelData - Parsed content of a saved model.
     * @throws {Error} If `dataset` is missing or a tensor cannot be built.
     */
    restoreModelData(modelData) {
        if (!modelData || typeof modelData.dataset !== 'object' || modelData.dataset === null) {
            throw new Error('Invalid model file: missing dataset');
        }

        // Files written before shapes were stored hold raw MobileNet logits,
        // which `infer(img, false)` documents as 1000 values per example.
        const legacyFeatureSize = 1000;
        const savedShapes = modelData.shapes || {};
        const restored = {};

        try {
            Object.keys(modelData.dataset).forEach((key) => {
                const flat = modelData.dataset[key];
                const savedShape = savedShapes[key];
                const shape = Array.isArray(savedShape) && savedShape.length === 2
                    ? savedShape
                    : [flat.length / legacyFeatureSize, legacyFeatureSize];
                restored[key] = tf.tensor(flat, shape);
            });
        } catch (error) {
            // Do not leak the tensors that were already created.
            Object.keys(restored).forEach((key) => restored[key].dispose());
            throw error;
        }

        this.knnClassifier.clearAllClasses();
        // Install every class in ONE call: setClassifierDataset replaces the
        // whole dataset, so per-class calls would keep only the last class.
        this.knnClassifier.setClassifierDataset(restored);

        this.classNames = Array.isArray(modelData.classNames) ? modelData.classNames : [];
        this.filteredConfidences = {};
    }

    /**
     * Escapes the characters that are significant in HTML text and attributes.
     *
     * @param {*} value - Value to render; converted with `String()`.
     * @returns {string} Text that is safe to interpolate into `innerHTML`.
     */
    escapeHtml(value) {
        const replacements = {
            '&': '&amp;',
            '<': '&lt;',
            '>': '&gt;',
            '"': '&quot;',
            "'": '&#39;',
        };
        return String(value).replace(/[&<>"']/g, (character) => replacements[character]);
    }

    /**
     * Writes a message into a status element and styles it by severity.
     *
     * @param {string} elementId - Id of the status element.
     * @param {string} type - 'success', 'error' or 'info'; anything else is
     *   styled as 'info'.
     * @param {string} message - Text to show.
     */
    showStatus(elementId, type, message) {
        const element = document.getElementById(elementId);
        if (!element) {
            return;
        }

        const classMap = {
            'success': 'status-success',
            'error': 'status-error',
            'info': 'status-info',
        };

        element.className = `status-message ${classMap[type] || classMap.info}`;
        element.textContent = message;
    }
}

// Global helper: capture a training sample from the webcam.
/**
 * Global wrapper around `window.classifier.captureFromWebcam`.
 *
 * @param {number} classIndex - Class index forwarded unchanged to the classifier.
 * @returns {*} Whatever the classifier's `captureFromWebcam` returns (so an
 *   async result can be awaited), or `undefined` when no classifier has been
 *   published on `window` yet.
 */
function captureFromWebcam(classIndex) {
    // `typeof` guard: a bare `window` reference throws where it is undefined.
    const activeClassifier = typeof window !== 'undefined' ? window.classifier : undefined;
    if (!activeClassifier) {
        return undefined;
    }
    return activeClassifier.captureFromWebcam(classIndex);
}

// Application bootstrap.
// The single classifier instance; also published as `window.classifier` so
// the global `captureFromWebcam` helper can reach it.
let classifier;
{
    const startApp = () => {
        classifier = new KNNImageClassifier();
        window.classifier = classifier;
    };

    // DOMContentLoaded fires only once: if this script runs after the DOM has
    // already been parsed, a listener would never be called, so start now.
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', startApp);
    } else {
        startApp();
    }
}