mobileNet/完善KNN/knn-classifier.js
2025-08-11 17:44:57 +08:00

985 lines
38 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// KNN image classifier that uses MobileNet embeddings as features and shows
// the top ImageNet tags of each processed image.
class KNNImageClassifier {
  /**
   * Sets the default state and starts asynchronous model loading.
   * Loading progress and failures are reported through showStatus().
   */
  constructor() {
    this.mobilenet = null;
    this.knnClassifier = null;
    this.classNames = [];
    this.webcamStream = null;
    this.isPredicting = false;
    this.currentCaptureClass = -1;
    this.imagenetClasses = null;
    // Low-pass filter state: class index -> smoothed confidence.
    this.filteredConfidences = {};
    // Default EMA coefficient (0-1); smaller values smooth more. Used when
    // the `filterAlpha` control is missing or not numeric.
    this.filterAlpha = 0.3;
    // Distance threshold settings.
    this.useDistanceThreshold = true;
    // Default threshold for L2-normalised features; used when the
    // `distanceThreshold` control is missing or not numeric.
    this.distanceThreshold = 0.5;
    // Threshold derived from the training data in single-class mode.
    this.adaptiveThreshold = null;
    // init() reports its own failures; the catch only prevents an unhandled
    // rejection if something unexpected escapes it.
    this.init().catch((error) => console.error('初始化失败:', error));
  }

  /**
   * Loads MobileNet, creates the KNN classifier, loads the tag names and
   * wires up the UI. Any failure is shown in the `dataStatus` element.
   */
  async init() {
    this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
    try {
      // Load the MobileNet model.
      this.mobilenet = await mobilenet.load({
        version: 2,
        alpha: 1.0
      });
      // Create the KNN classifier.
      this.knnClassifier = knnClassifier.create();
      // Load the ImageNet class names.
      await this.loadImageNetClasses();
      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  /**
   * Populates `imagenetClasses` with a shortened list holding only the names
   * of the first 10 ImageNet classes; other indices are shown as `class_N`.
   */
  async loadImageNetClasses() {
    this.imagenetClasses = [
      'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
      'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
    ];
  }

  /** Attaches the file-input and button handlers. */
  setupEventListeners() {
    // File upload inputs, one per class slot.
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index);
      });
    });
    // Buttons.
    document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
    document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
    document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
  }

  /**
   * Updates the image counter and thumbnail previews for one class slot.
   * @param {Event} event - `change` event of the file input.
   * @param {number} classIndex - Zero-based class slot index.
   */
  handleImageUpload(event, classIndex) {
    const files = event.target.files;
    const countElement = document.getElementById(`class${classIndex + 1}Count`);
    const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
    countElement.textContent = `${files.length} 张图片`;
    // Drop the previous previews.
    previewContainer.innerHTML = '';
    // Add one thumbnail per selected file.
    Array.from(files).forEach((file) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    });
  }

  /**
   * Extracts the KNN feature vector and the display tags for an image.
   * @param {HTMLImageElement|HTMLVideoElement} img - Image or video frame.
   * @returns {Promise<{logits: object, predictions: Array, topTags: Array}>}
   *   `logits` is the L2-normalised embedding tensor (despite its name) and
   *   must be disposed by the caller; `predictions` is the result of
   *   `mobilenet.classify`; `topTags` is the output of getTopKTags().
   * @throws Re-throws any error raised during inference.
   */
  async extractImageNetTags(img) {
    let embeddings = null;
    try {
      const predictions = await this.mobilenet.classify(img);
      // Embedding-layer features, L2-normalised so distances are comparable.
      // tf.tidy releases the raw embedding and the norm tensor.
      embeddings = tf.tidy(() => {
        const rawEmbeddings = this.mobilenet.infer(img, true);
        return tf.div(rawEmbeddings, tf.norm(rawEmbeddings));
      });
      // Class logits are only needed for the tag display.
      const logits = this.mobilenet.infer(img, false);
      let topK;
      try {
        topK = await this.getTopKTags(logits, 10);
      } finally {
        logits.dispose();
      }
      return {
        logits: embeddings,
        predictions: predictions,
        topTags: topK
      };
    } catch (error) {
      // Do not leak the embedding when a later step fails.
      if (embeddings) embeddings.dispose();
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  /**
   * Returns the k highest-scoring classes of a logits tensor.
   * @param {{data: function(): Promise<ArrayLike<number>>}} logits - Tensor-like
   *   object whose data() resolves to the per-class scores.
   * @param {number} [k=10] - Number of tags wanted; clamped to the number of
   *   available scores.
   * @returns {Promise<Array<{className: string, probability: number, logit: number}>>}
   *   Tags in descending score order. `probability` is a softmax taken over
   *   the returned top-k scores only.
   */
  async getTopKTags(logits, k = 10) {
    const values = await logits.data();
    const count = Math.max(0, Math.min(k, values.length));
    const ranked = Array.from(values, (value, index) => ({ value, index }))
      .sort((a, b) => b.value - a.value)
      .slice(0, count);
    if (ranked.length === 0) return [];
    const topkValues = Float32Array.from(ranked, (entry) => entry.value);
    // Computed once rather than once per tag.
    const probabilities = this.softmax(topkValues);
    return ranked.map((entry, rank) => ({
      // Look the name up by ImageNet class index, not by rank position.
      className: this.imagenetClasses?.[entry.index] ?? `class_${entry.index}`,
      probability: probabilities[rank],
      logit: topkValues[rank]
    }));
  }

  /**
   * Numerically stable softmax.
   * @param {number[]|Float32Array} arr - Input scores.
   * @returns {number[]|Float32Array} Same container type as the input; empty
   *   input yields an empty result.
   */
  softmax(arr) {
    const maxLogit = Math.max(...arr);
    const scores = arr.map((l) => Math.exp(l - maxLogit));
    const sum = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sum);
  }

  /**
   * Euclidean distance between two vectors stored inside flat arrays.
   * @param {ArrayLike<number>} a - First buffer.
   * @param {number} aOffset - Start of the vector in `a`.
   * @param {ArrayLike<number>} b - Second buffer.
   * @param {number} bOffset - Start of the vector in `b`.
   * @param {number} dim - Vector length.
   * @returns {number} The L2 distance.
   */
  euclideanDistance(a, aOffset, b, bOffset, dim) {
    let sumSquares = 0;
    for (let i = 0; i < dim; i++) {
      const diff = a[aOffset + i] - b[bOffset + i];
      sumSquares += diff * diff;
    }
    return Math.sqrt(sumSquares);
  }

  /**
   * Reads the distance threshold from the `distanceThreshold` control.
   * @returns {number} The control value, or the adaptive threshold (when one
   *   has been computed) or `this.distanceThreshold` when the control is
   *   missing or not numeric.
   */
  getDistanceThreshold() {
    const parsed = Number.parseFloat(document.getElementById('distanceThreshold')?.value);
    if (Number.isFinite(parsed)) return parsed;
    return this.adaptiveThreshold ?? this.distanceThreshold;
  }

  /**
   * Reads the EMA coefficient from the `filterAlpha` control.
   * @returns {number} Value clamped to [0, 1]; `this.filterAlpha` when the
   *   control is missing or not numeric.
   */
  getFilterAlpha() {
    const parsed = Number.parseFloat(document.getElementById('filterAlpha')?.value);
    if (!Number.isFinite(parsed)) return this.filterAlpha;
    return Math.min(1, Math.max(0, parsed));
  }

  /**
   * Reads K from the `kValue` control.
   * @returns {number} A positive integer; 3 when the control is missing or
   *   does not hold a positive integer.
   */
  getKValue() {
    const parsed = Number.parseInt(document.getElementById('kValue')?.value, 10);
    return Number.isInteger(parsed) && parsed > 0 ? parsed : 3;
  }

  /**
   * Escapes text for safe interpolation into HTML markup.
   * @param {*} value - Value to escape (converted with String()).
   * @returns {string} The escaped text.
   */
  escapeHtml(value) {
    return String(value)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }

  /**
   * Rebuilds the KNN classifier from the images selected in the three class
   * slots. A slot is used only when it has both a name and at least one
   * image. With exactly one usable slot the classifier runs in single-class
   * (distance threshold) mode and an adaptive threshold is computed.
   */
  async trainKNN() {
    const classes = [];
    const imageFiles = [];
    // Collect every class slot that has a name and images.
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;
      if (className && files && files.length > 0) {
        classes.push(className);
        imageFiles.push(files);
        console.log(`类别 ${i}: "${className}" - ${files.length} 张图片`);
      }
    }
    console.log('收集到的类别:', classes);
    console.log('类别数量:', classes.length);
    if (classes.length < 1) {
      this.showStatus('dataStatus', 'error', '请至少添加一个类别的图片!');
      return;
    }
    // A single class means one-class detection.
    if (classes.length === 1) {
      console.log(`📍 单品类检测模式:只检测 "${classes[0]}",其他都视为背景/未知`);
    }
    this.classNames = classes;
    this.filteredConfidences = {}; // reset the filter state
    this.adaptiveThreshold = null; // belongs to the previous training set
    this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');
    // Start from an empty classifier.
    this.knnClassifier.clearAllClasses();
    let totalProcessed = 0;
    const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      const files = imageFiles[classIndex];
      console.log(`处理类别 ${classes[classIndex]}...`);
      for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
        let features = null;
        try {
          const img = await this.loadImage(files[fileIndex]);
          features = await this.extractImageNetTags(img);
          // The normalised embedding is the KNN feature vector.
          this.knnClassifier.addExample(features.logits, classIndex);
          totalProcessed++;
          const progress = Math.round((totalProcessed / totalImages) * 100);
          this.showStatus('dataStatus', 'info',
            `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
          // Show the tags of the first image of each class.
          if (fileIndex === 0) {
            this.displayTopTags(features.topTags);
          }
        } catch (error) {
          // A single bad image must not abort the whole training run.
          console.error('处理图片失败:', error);
        } finally {
          // Release the tensor even when addExample or the UI update failed.
          if (features) features.logits.dispose();
        }
      }
    }
    document.getElementById('totalSamples').textContent = totalProcessed;
    if (totalProcessed === 0) {
      // Nothing was learned, so do not report a trained model.
      this.classNames = [];
      this.showStatus('dataStatus', 'error', '所有图片处理失败,模型未训练');
      return;
    }
    let statusMessage;
    if (classes.length === 1) {
      statusMessage = `单品类检测模型训练完成!将只检测 "${classes[0]}",共 ${totalProcessed} 个样本`;
    } else {
      statusMessage = `KNN模型训练完成${totalProcessed} 个样本,${classes.length} 个类别`;
    }
    this.showStatus('dataStatus', 'success', statusMessage);
    console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
    if (classes.length === 1) {
      console.log('📍 单品类检测模式已启用,将基于距离阈值判断是否为:', classes[0]);
      await this.calculateAdaptiveThreshold();
    }
  }

  /**
   * Derives a distance threshold for single-class mode: 1.3 times the mean
   * pairwise distance among (at most) the first 20 training samples. Stores
   * it in `adaptiveThreshold` and pushes it to the threshold controls when
   * they exist. Does nothing unless exactly one class with at least two
   * samples is present.
   */
  async calculateAdaptiveThreshold() {
    if (this.knnClassifier.getNumClasses() !== 1) return;
    console.log('计算自适应阈值...');
    const dataset = this.knnClassifier.getClassifierDataset();
    if (!dataset || !dataset[0]) return;
    const trainData = await dataset[0].data();
    const numSamples = dataset[0].shape[0];
    const featureDim = dataset[0].shape[1];
    // Cap the number of samples to bound the O(n^2) pair loop.
    const limit = Math.min(numSamples, 20);
    let totalDistance = 0;
    let count = 0;
    for (let i = 0; i < limit; i++) {
      for (let j = i + 1; j < limit; j++) {
        totalDistance += this.euclideanDistance(
          trainData, i * featureDim, trainData, j * featureDim, featureDim);
        count++;
      }
    }
    if (count === 0) return;
    const avgInternalDistance = totalDistance / count;
    this.adaptiveThreshold = avgInternalDistance * 1.3;
    console.log(`内部平均距离: ${avgInternalDistance.toFixed(2)}`);
    console.log(`建议自适应阈值: ${this.adaptiveThreshold.toFixed(2)}`);
    // Two decimals: distances between L2-normalised vectors lie in [0, 2],
    // so one decimal would round the threshold too coarsely.
    const formatted = this.adaptiveThreshold.toFixed(2);
    const thresholdInput = document.getElementById('distanceThreshold');
    const thresholdDisplay = document.getElementById('distanceThresholdDisplay');
    if (thresholdInput && thresholdDisplay) {
      thresholdInput.value = formatted;
      thresholdDisplay.textContent = formatted;
    }
    this.showStatus('dataStatus', 'info',
      `自适应阈值已计算: ${formatted} (基于训练数据内部距离)`);
  }

  /**
   * Renders the first five tags into the `tagsList` element.
   * @param {Array<{className: string, probability: number}>} tags - Tags as
   *   returned by getTopKTags().
   */
  displayTopTags(tags) {
    const container = document.getElementById('tagsList');
    let html = '';
    tags.slice(0, 5).forEach((tag) => {
      html += `
        <span class="tag-item">
        ${this.escapeHtml(tag.className)}
        <span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
        </span>
      `;
    });
    container.innerHTML = html;
  }

  /**
   * Decodes an image file into an Image element.
   * @param {File} file - Image file chosen by the user.
   * @returns {Promise<HTMLImageElement>} Resolves once the image has loaded.
   */
  async loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        // Reject with an Error so callers get a usable message.
        img.onerror = () => reject(new Error(`无法解码图片: ${file.name}`));
        img.src = e.target.result;
      };
      reader.onerror = () => reject(reader.error || new Error(`无法读取文件: ${file.name}`));
      reader.readAsDataURL(file);
    });
  }

  /** Clears the classifier, the filter state and the related UI elements. */
  clearDataset() {
    this.knnClassifier.clearAllClasses();
    this.classNames = [];
    this.filteredConfidences = {}; // reset the filter state
    this.adaptiveThreshold = null;
    console.log('数据集已清空,滤波器状态已重置');
    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
      document.getElementById(`class${i}Preview`).innerHTML = ''; // drop previews
    }
    document.getElementById('totalSamples').textContent = '0';
    document.getElementById('tagsList').innerHTML = '等待数据...';
    document.getElementById('predictions').innerHTML = '等待预测...';
    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  /**
   * Opens the camera and starts the prediction loop once the first frame is
   * available. Requires a trained classifier.
   */
  async startWebcam() {
    if (this.knnClassifier.getNumClasses() === 0) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }
    const video = document.getElementById('webcam');
    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
      // `once` keeps repeated start/stop cycles from stacking listeners,
      // which would otherwise launch several prediction loops at once.
      video.addEventListener('loadeddata', () => {
        // Ignore the event if the camera was stopped or replaced meanwhile.
        if (this.webcamStream !== stream || this.isPredicting) return;
        this.isPredicting = true;
        this.predictLoop();
      }, { once: true });
      video.srcObject = stream;
      this.webcamStream = stream;
      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;
      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  /** Stops the camera, the prediction loop and resets the filter state. */
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach((track) => track.stop());
      this.webcamStream = null;
    }
    this.isPredicting = false;
    this.filteredConfidences = {}; // reset the filter state
    const video = document.getElementById('webcam');
    video.srcObject = null;
    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;
    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  /**
   * One iteration of the prediction loop: classifies the current video
   * frame, renders the result and schedules the next iteration with
   * requestAnimationFrame while `isPredicting` is true.
   */
  async predictLoop() {
    if (!this.isPredicting) return;
    const video = document.getElementById('webcam');
    if (video.readyState === 4) {
      let features = null;
      try {
        console.log('开始预测,类别数量:', this.classNames.length);
        features = await this.extractImageNetTags(video);
        const result = await this.predictWithDistance(features.logits, this.getKValue());
        console.log('预测结果:', result);
        const isSingleClass = Boolean(result.isSingleClass);
        const finalPrediction = {
          label: result.label,
          confidences: result.confidences,
          // Only single-class mode rejects samples through the threshold.
          isUnknown: isSingleClass && result.label === -1,
          minDistance: result.minDistance,
          isSingleClass: isSingleClass
        };
        // Skip rendering (and filter updates) if the camera was stopped
        // while this frame was being processed.
        if (this.isPredicting) {
          this.displayPrediction(this.applyLowPassFilter(finalPrediction));
          this.displayTopTags(features.topTags);
        }
      } catch (error) {
        console.error('预测错误:', error);
      } finally {
        // Release the feature tensor on every path.
        if (features) features.logits.dispose();
      }
    }
    requestAnimationFrame(() => this.predictLoop());
  }

  /**
   * Classifies a feature tensor.
   * @param {object} logits - Feature tensor from extractImageNetTags().
   * @param {number} k - Number of neighbours.
   * @returns {Promise<{label: number, confidences: Object, minDistance: number, isSingleClass: boolean}>}
   *   `label` is -1 when nothing could be predicted (or, in single-class
   *   mode, when the sample is outside the threshold). In single-class mode
   *   `minDistance` is the mean distance to the k nearest samples; in
   *   multi-class mode it is the placeholder 0.5.
   */
  async predictWithDistance(logits, k) {
    const numClasses = this.knnClassifier.getNumClasses();
    // No training data.
    if (numClasses === 0) {
      return {
        label: -1,
        confidences: {},
        minDistance: Infinity,
        isSingleClass: false
      };
    }
    if (numClasses === 1) {
      return this.predictSingleClass(logits, k);
    }
    try {
      const prediction = await this.knnClassifier.predictClass(logits, k);
      console.log('多品类预测结果:', prediction);
      console.log('预测标签:', prediction.label);
      console.log('置信度:', prediction.confidences);
      // Convert numeric string labels to numbers so later `===` comparisons
      // against class indices work whichever type the library returns.
      const numericLabel = Number(prediction.label);
      const label = Number.isInteger(numericLabel) ? numericLabel : prediction.label;
      let confidences = {};
      if (prediction.confidences && typeof prediction.confidences === 'object') {
        confidences = prediction.confidences;
      }
      // Fall back to a one-hot result when no confidences were returned.
      if (Object.keys(confidences).length === 0) {
        console.warn('置信度为空,使用默认值');
        for (let i = 0; i < this.classNames.length; i++) {
          confidences[i] = i === label ? 1.0 : 0;
        }
      }
      return {
        label: label,
        confidences: confidences,
        minDistance: 0.5, // placeholder; no real distance is computed here
        isSingleClass: false
      };
    } catch (error) {
      console.error('预测错误:', error);
      return {
        label: -1,
        confidences: {},
        minDistance: Infinity,
        isSingleClass: false
      };
    }
  }

  /**
   * Single-class detection: accepts the sample when the mean Euclidean
   * distance to its k nearest training samples is within the threshold.
   * @param {object} logits - Feature tensor of the sample.
   * @param {number} k - Number of neighbours; 3 is used when not a positive
   *   integer.
   * @returns {Promise<{label: number, confidences: Object, minDistance: number, isSingleClass: boolean}>}
   *   `label` 0 and confidence 1 when accepted, otherwise -1 and 0.
   */
  async predictSingleClass(logits, k) {
    const rejected = {
      label: -1,
      confidences: { 0: 0 },
      minDistance: Infinity,
      isSingleClass: true
    };
    console.log('单品类检测模式 - 计算实际距离');
    const dataset = this.knnClassifier.getClassifierDataset();
    if (!dataset || !dataset[0]) return rejected;
    const inputData = await logits.data();
    const trainData = await dataset[0].data();
    const numSamples = dataset[0].shape[0];
    const featureDim = dataset[0].shape[1];
    console.log(`输入特征维度: ${inputData.length}, 训练数据维度: ${featureDim}, 样本数: ${numSamples}`);
    if (inputData.length !== featureDim) {
      console.error(`维度不匹配!输入: ${inputData.length}, 训练: ${featureDim}`);
      return rejected;
    }
    if (numSamples === 0) return rejected;
    // Distance from the sample to every training example.
    const distances = [];
    let minDistance = Infinity;
    for (let i = 0; i < numSamples; i++) {
      const distance = this.euclideanDistance(inputData, 0, trainData, i * featureDim, featureDim);
      distances.push(distance);
      if (distance < minDistance) minDistance = distance;
    }
    console.log(`计算了 ${distances.length} 个距离,最小距离: ${minDistance.toFixed(2)}`);
    console.log(`前5个距离: ${distances.slice(0, 5).map((d) => d.toFixed(2)).join(', ')}`);
    // Mean distance of the k nearest neighbours.
    const effectiveK = Number.isInteger(k) && k > 0 ? k : 3;
    distances.sort((a, b) => a - b);
    const kNearest = distances.slice(0, Math.min(effectiveK, distances.length));
    const avgDistance = kNearest.reduce((sum, d) => sum + d, 0) / kNearest.length;
    // Same threshold source as applyLowPassFilter() and displayPrediction().
    const threshold = this.getDistanceThreshold();
    const belongsToClass = avgDistance <= threshold;
    // Binary confidence: 1 inside the threshold, 0 outside.
    const confidence = belongsToClass ? 1.0 : 0;
    console.log(`单品类预测 - 平均距离: ${avgDistance.toFixed(2)}, 阈值: ${threshold}, 属于类别: ${belongsToClass}, 置信度: ${confidence.toFixed(3)}`);
    return {
      label: belongsToClass ? 0 : -1,
      confidences: { 0: confidence },
      minDistance: avgDistance,
      isSingleClass: true
    };
  }

  /**
   * Smooths per-class confidences with an exponential moving average.
   * The first call after a reset only seeds the filter state. Single-class
   * predictions are not smoothed; they are re-thresholded to a binary result.
   * @param {{label: number, confidences: Object, isSingleClass: boolean, minDistance: number, isUnknown: boolean}} prediction
   * @returns {{label: number, confidences: Object, isSingleClass: boolean, minDistance: number, isUnknown: boolean}}
   */
  applyLowPassFilter(prediction) {
    const alpha = this.getFilterAlpha();
    const isSingleClass = prediction.isSingleClass || false;
    const minDistance = prediction.minDistance;
    const isUnknown = prediction.isUnknown;
    const incoming = prediction.confidences || {};
    // Seed the filter state on the first call.
    if (Object.keys(this.filteredConfidences).length === 0) {
      for (let i = 0; i < this.classNames.length; i++) {
        this.filteredConfidences[i] = incoming[i] || 0;
      }
      return {
        label: prediction.label,
        confidences: { ...this.filteredConfidences },
        isSingleClass: isSingleClass,
        minDistance: minDistance,
        isUnknown: isUnknown
      };
    }
    // Single-class mode: binary output, no smoothing.
    if (isSingleClass) {
      const inThreshold = minDistance <= this.getDistanceThreshold();
      const confidence = inThreshold ? 1.0 : 0;
      this.filteredConfidences[0] = confidence;
      return {
        label: inThreshold ? 0 : -1,
        confidences: { 0: confidence },
        isSingleClass: isSingleClass,
        minDistance: minDistance,
        isUnknown: !inThreshold
      };
    }
    // EMA: y[n] = alpha * x[n] + (1 - alpha) * y[n-1]
    const newConfidences = {};
    for (let i = 0; i < this.classNames.length; i++) {
      const currentValue = incoming[i] || 0;
      const previousValue = this.filteredConfidences[i] || 0;
      this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
      newConfidences[i] = this.filteredConfidences[i];
    }
    // Normalise so the displayed confidences sum to 1.
    const sum = Object.values(newConfidences).reduce((acc, v) => acc + v, 0);
    if (sum > 0) {
      Object.keys(newConfidences).forEach((key) => {
        newConfidences[key] = newConfidences[key] / sum;
      });
    }
    // Pick the class with the highest smoothed confidence. When every value
    // is 0, keep the incoming label instead of silently choosing class 0.
    let maxConfidence = 0;
    let bestLabel = prediction.label;
    Object.keys(newConfidences).forEach((key) => {
      if (newConfidences[key] > maxConfidence) {
        maxConfidence = newConfidences[key];
        bestLabel = Number.parseInt(key, 10);
      }
    });
    return {
      label: bestLabel,
      confidences: newConfidences,
      isSingleClass: isSingleClass,
      minDistance: minDistance,
      isUnknown: isUnknown
    };
  }

  /**
   * Renders a prediction into the `predictions` element. With one class a
   * binary detected / not-detected card plus distance information is shown;
   * otherwise one confidence bar per class, in class-index order.
   * @param {{label: number, confidences: Object, minDistance: number}} prediction
   */
  displayPrediction(prediction) {
    const container = document.getElementById('predictions');
    // Single-class mode.
    if (this.classNames.length === 1) {
      // Class names come from user input or a loaded file: escape them.
      const className = this.escapeHtml(this.classNames[0]);
      const isDetected = prediction.label === 0;
      const threshold = this.getDistanceThreshold();
      const distance = prediction.minDistance || 0;
      const withinThreshold = distance <= threshold;
      container.innerHTML = `
        <div class="prediction-item" style="${isDetected ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : 'border-left-color: #cbd5e0;'}">
        <div class="prediction-header">
        <span class="prediction-label">
        ${className} ${isDetected ? '✓ 检测到' : '✗ 未检测到'}
        </span>
        <span class="prediction-confidence" style="${isDetected ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : 'background: #cbd5e0;'}">
        ${isDetected ? '100%' : '0%'}
        </span>
        </div>
        <div class="confidence-bar-container">
        <div class="confidence-bar ${isDetected ? 'high' : 'low'}" style="width: ${isDetected ? '100' : '0'}%;">
        ${isDetected ? `<span class="confidence-percentage">100%</span>` : ''}
        </div>
        </div>
        </div>
        <div style="margin-top: 10px; padding: 10px; background: #f8f9fa; border-radius: 5px; font-size: 12px; color: #666;">
        <strong>距离:</strong> ${distance.toFixed(2)} | <strong>阈值:</strong> ${threshold.toFixed(2)}
        <span style="margin-left: 10px; color: ${withinThreshold ? '#48bb78' : '#f56565'};">
        ${withinThreshold ? '✓ 在阈值范围内' : '✗ 超出阈值范围'}
        </span>
        </div>
      `;
      return;
    }
    // Multi-class mode: show the (already filtered) confidences.
    const confidences = prediction.confidences || {};
    const predictedClass = prediction.label;
    let html = '';
    for (let i = 0; i < this.classNames.length; i++) {
      const className = this.escapeHtml(this.classNames[i]);
      const confidence = confidences[i] || 0;
      const percentage = (confidence * 100).toFixed(1);
      const isWinner = i === predictedClass;
      // Bar colour follows the confidence level; the winner is always green.
      let barClass = 'low';
      if (isWinner || confidence > 0.7) barClass = 'high';
      else if (confidence > 0.4) barClass = 'medium';
      html += `
        <div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
        <div class="prediction-header">
        <span class="prediction-label">
        ${className} ${isWinner ? '👑' : ''}
        </span>
        <span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
        ${percentage}%
        </span>
        </div>
        <div class="confidence-bar-container">
        <div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
        ${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
        </div>
        </div>
        </div>
      `;
    }
    container.innerHTML = html;
  }

  /**
   * Adds the current camera frame as a sample of a class. When the camera is
   * not running it is opened temporarily for about one second and closed
   * again afterwards.
   * @param {number} classIndex - Zero-based class index.
   * @returns {Promise<void>} Resolves after the capture attempt finished.
   */
  async captureFromWebcam(classIndex) {
    if (this.webcamStream) {
      await this.addWebcamSample(classIndex);
      return;
    }
    const video = document.getElementById('webcam');
    let stream;
    try {
      stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
    } catch (error) {
      this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
      return;
    }
    video.srcObject = stream;
    this.webcamStream = stream;
    try {
      // Give the video element time to receive frames.
      await new Promise((resolve) => setTimeout(resolve, 1000));
      await this.addWebcamSample(classIndex);
    } finally {
      // Stop the tracks through the local reference: `this.webcamStream`
      // may have been cleared or replaced while waiting.
      stream.getTracks().forEach((track) => track.stop());
      if (this.webcamStream === stream) {
        this.webcamStream = null;
        video.srcObject = null;
      }
    }
  }

  /**
   * Extracts features from the current video frame and adds them to the
   * classifier as an example of the given class.
   * @param {number} classIndex - Zero-based class index.
   */
  async addWebcamSample(classIndex) {
    const video = document.getElementById('webcam');
    if (video.readyState !== 4) {
      this.showStatus('dataStatus', 'error', '摄像头画面尚未就绪,请稍后重试');
      return;
    }
    let features = null;
    try {
      features = await this.extractImageNetTags(video);
      this.knnClassifier.addExample(features.logits, classIndex);
      // Refresh the per-class sample counter.
      const currentCount = this.knnClassifier.getClassExampleCount();
      const count = currentCount[classIndex] || 0;
      document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;
      this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
    } catch (error) {
      console.error('添加样本失败:', error);
      this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
    } finally {
      if (features) features.logits.dispose();
    }
  }

  /**
   * Downloads the classifier dataset, class names, K and feature dimension
   * as `knn-model.json`.
   */
  async saveModel() {
    if (this.knnClassifier.getNumClasses() === 0) {
      this.showStatus('predictionStatus', 'error', '没有可保存的模型');
      return;
    }
    try {
      const dataset = this.knnClassifier.getClassifierDataset();
      const datasetObj = {};
      Object.keys(dataset).forEach((key) => {
        datasetObj[key] = Array.from(dataset[key].dataSync());
      });
      // Feature dimension of the stored vectors (1280 if it cannot be read).
      let featureDim = 1280;
      const firstKey = Object.keys(dataset)[0];
      if (firstKey && dataset[firstKey]) {
        featureDim = dataset[firstKey].shape[1];
      }
      const modelData = {
        dataset: datasetObj,
        classNames: this.classNames,
        k: document.getElementById('kValue').value,
        featureDim: featureDim,
        date: new Date().toISOString()
      };
      const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
      const url = URL.createObjectURL(blob);
      const a = document.createElement('a');
      a.href = url;
      a.download = 'knn-model.json';
      a.click();
      // Revoke on a later task so the download can start first.
      setTimeout(() => URL.revokeObjectURL(url), 0);
      this.showStatus('predictionStatus', 'success', '模型已保存');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
    }
  }

  /**
   * Determines the feature dimension of a flattened class dataset.
   * @param {number|undefined} declaredDim - `featureDim` stored in the model
   *   file, if any.
   * @param {number} dataLength - Number of values in the flattened data.
   * @returns {number|null} The dimension, or null when the declared value
   *   does not divide the data length or, for files without a declared
   *   value, none of 1280 / 1024 / 1000 does.
   */
  resolveFeatureDim(declaredDim, dataLength) {
    if (Number.isInteger(declaredDim) && declaredDim > 0) {
      return dataLength % declaredDim === 0 ? declaredDim : null;
    }
    // Older model files: try the common dimensions.
    const candidates = [1280, 1024, 1000];
    return candidates.find((dim) => dataLength % dim === 0) ?? null;
  }

  /**
   * Lets the user pick a model JSON file and restores the classifier from
   * it. The current model is replaced only after the whole file has been
   * parsed and converted successfully.
   */
  async loadModel() {
    const input = document.createElement('input');
    input.type = 'file';
    input.accept = '.json';
    input.onchange = async (e) => {
      const restored = {};
      let committed = false;
      try {
        const file = e.target.files[0];
        if (!file) return;
        const modelData = JSON.parse(await file.text());
        if (!modelData || typeof modelData.dataset !== 'object' || modelData.dataset === null) {
          throw new Error('模型文件缺少 dataset 字段');
        }
        for (const [key, data] of Object.entries(modelData.dataset)) {
          if (!Array.isArray(data) || data.length === 0) continue;
          const featureDim = this.resolveFeatureDim(modelData.featureDim, data.length);
          if (!featureDim) {
            throw new Error(`无法确定特征维度,数据长度: ${data.length}`);
          }
          if (!modelData.featureDim) {
            console.warn(`自动检测到特征维度: ${featureDim}`);
          }
          const numSamples = data.length / featureDim;
          console.log(`加载类别 ${key}: ${numSamples} 个样本,${featureDim} 维特征`);
          restored[key] = tf.tensor(data, [numSamples, featureDim]);
        }
        // Install every class in a single call: setClassifierDataset replaces
        // the whole dataset, so per-class calls would keep only the last one.
        this.knnClassifier.clearAllClasses();
        this.knnClassifier.setClassifierDataset(restored);
        committed = true;
        this.classNames = Array.isArray(modelData.classNames) ? modelData.classNames : [];
        // Filter state and adaptive threshold belong to the previous model.
        this.filteredConfidences = {};
        this.adaptiveThreshold = null;
        if (modelData.k !== undefined && modelData.k !== null) {
          const kInput = document.getElementById('kValue');
          const kDisplay = document.getElementById('kValueDisplay');
          if (kInput) kInput.value = modelData.k;
          if (kDisplay) kDisplay.textContent = modelData.k;
        }
        this.showStatus('predictionStatus', 'success',
          `模型加载成功!类别: ${this.classNames.join(', ')}`);
      } catch (error) {
        this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
      } finally {
        // Free tensors that never reached the classifier.
        if (!committed) {
          Object.values(restored).forEach((tensor) => tensor.dispose());
        }
      }
    };
    input.click();
  }

  /**
   * Shows a status message in the given element.
   * @param {string} elementId - Id of the status element.
   * @param {string} type - 'success', 'error' or 'info'; anything else is
   *   styled as 'info'.
   * @param {string} message - Text to display.
   */
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    if (!element) {
      // A missing status element must not break the calling operation.
      console.warn(`状态元素 #${elementId} 不存在: ${message}`);
      return;
    }
    const classMap = {
      'success': 'status-success',
      'error': 'status-error',
      'info': 'status-info'
    };
    element.className = `status-message ${classMap[type] || classMap.info}`;
    element.textContent = message;
  }
}
// Global entry point: capture a webcam sample for the given class.
// Forwards to the classifier instance published on `window`; does nothing
// while no instance is available.
function captureFromWebcam(classIndex) {
  const instance = window.classifier;
  if (!instance) {
    return;
  }
  instance.captureFromWebcam(classIndex);
}
// Application bootstrap: once the DOM is ready, create the classifier and
// publish it both as the `classifier` binding and as `window.classifier`.
let classifier;
document.addEventListener('DOMContentLoaded', function onDomReady() {
  const instance = new KNNImageClassifier();
  classifier = instance;
  window.classifier = instance;
});