// KNN image classifier - based on MobileNet feature labels

/**
 * Browser-side KNN image classifier built on MobileNet features.
 *
 * The class wires DOM controls (file inputs, buttons, sliders) to a
 * TensorFlow.js pipeline: images or webcam frames are passed through
 * MobileNet, the L2-normalised embedding is stored in / queried against a
 * KNN classifier, and the result is rendered into the page.
 *
 * It relies on the globals `tf`, `mobilenet` and `knnClassifier` being loaded
 * by the hosting page, and on DOM elements with the ids referenced below.
 *
 * NOTE(review): `classNames` is index-aligned with the numeric KNN labels used
 * by `trainKNN`. Samples added only through `captureFromWebcam` do not register
 * a class name, so their predictions are not rendered until `trainKNN` or
 * `loadModel` has populated `classNames` — confirm whether that is intended.
 */
class KNNImageClassifier {
    /** Sets up default state and starts asynchronous model loading. */
    constructor() {
        this.mobilenet = null;
        this.knnClassifier = null;
        this.classNames = [];
        this.webcamStream = null;
        this.isPredicting = false;
        this.currentCaptureClass = -1;
        this.imagenetClasses = null;

        // Low-pass filter state: per-class smoothed confidences.
        this.filteredConfidences = {};
        // Fallback filter coefficient (0-1); smaller means smoother output.
        this.filterAlpha = 0.3;

        // Distance-threshold settings for single-class detection.
        this.useDistanceThreshold = true;
        // Fallback threshold, expressed for L2-normalised features.
        this.distanceThreshold = 0.5;
        this.adaptiveThreshold = null;

        this.init();
    }

    /**
     * Escapes a value for safe interpolation into an HTML string.
     * @param {*} value - Value to render as text.
     * @returns {string} The HTML-escaped string.
     */
    static escapeHtml(value) {
        const replacements = {
            '&': '&amp;',
            '<': '&lt;',
            '>': '&gt;',
            '"': '&quot;',
            "'": '&#39;',
        };
        return String(value).replace(/[&<>"']/g, (ch) => replacements[ch]);
    }

    /**
     * Euclidean distance between two vectors stored inside flat arrays.
     * @param {ArrayLike<number>} a - First flat buffer.
     * @param {number} aOffset - Index of the first component in `a`.
     * @param {ArrayLike<number>} b - Second flat buffer.
     * @param {number} bOffset - Index of the first component in `b`.
     * @param {number} dim - Number of components to compare.
     * @returns {number} The Euclidean distance.
     */
    static euclideanDistance(a, aOffset, b, bOffset, dim) {
        let sumOfSquares = 0;
        for (let i = 0; i < dim; i++) {
            const diff = a[aOffset + i] - b[bOffset + i];
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Reads the distance threshold from the `distanceThreshold` input,
     * falling back to `this.distanceThreshold` when the input is missing or
     * does not hold a finite number.
     * @returns {number} The threshold to compare distances against.
     */
    getDistanceThreshold() {
        const parsed = Number.parseFloat(document.getElementById('distanceThreshold')?.value);
        return Number.isFinite(parsed) ? parsed : this.distanceThreshold;
    }

    /**
     * Reads K from the `kValue` input, falling back to 3 when the input is
     * missing or does not hold a positive integer.
     * @returns {number} Number of neighbours to use.
     */
    getKValue() {
        const parsed = Number.parseInt(document.getElementById('kValue')?.value, 10);
        return Number.isInteger(parsed) && parsed > 0 ? parsed : 3;
    }

    /**
     * Reads the EMA coefficient from the `filterAlpha` input, falling back to
     * `this.filterAlpha` when the input is missing or not a finite number.
     * @returns {number} The smoothing coefficient.
     */
    getFilterAlpha() {
        const parsed = Number.parseFloat(document.getElementById('filterAlpha')?.value);
        return Number.isFinite(parsed) ? parsed : this.filterAlpha;
    }

    /**
     * Loads MobileNet, creates the KNN classifier and attaches UI listeners.
     * Failures are reported in the `dataStatus` element.
     * @returns {Promise<void>}
     */
    async init() {
        this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');

        try {
            this.mobilenet = await mobilenet.load({
                version: 2,
                alpha: 1.0,
            });
            this.knnClassifier = knnClassifier.create();
            await this.loadImageNetClasses();

            this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
            this.setupEventListeners();
        } catch (error) {
            this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
        }
    }

    /**
     * Populates `imagenetClasses` with the names of the first ten ImageNet
     * classes (simplified list); other indices are shown as `class_<index>`.
     * @returns {Promise<void>}
     */
    async loadImageNetClasses() {
        this.imagenetClasses = [
            'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
            'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
        ];
    }

    /** Attaches change/click handlers to the file inputs and buttons. */
    setupEventListeners() {
        ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
            document.getElementById(id).addEventListener('change', (e) => {
                this.handleImageUpload(e, index);
            });
        });

        document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
        document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
        document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
        document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
        document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
        document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
    }

    /**
     * Updates the image count and thumbnail previews for one class.
     * @param {Event} event - `change` event of the file input.
     * @param {number} classIndex - Zero-based class index.
     */
    handleImageUpload(event, classIndex) {
        const files = event.target.files;
        const countElement = document.getElementById(`class${classIndex + 1}Count`);
        const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);

        countElement.textContent = `${files.length} 张图片`;
        previewContainer.innerHTML = '';

        // Object URLs are created synchronously, so previews keep the
        // selection order and a late-finishing read from an earlier selection
        // can no longer append stale thumbnails.
        for (const file of Array.from(files)) {
            const img = document.createElement('img');
            const objectUrl = URL.createObjectURL(file);
            const release = () => URL.revokeObjectURL(objectUrl);
            img.onload = release;
            img.onerror = release;
            img.src = objectUrl;
            img.className = 'preview-img';
            previewContainer.appendChild(img);
        }
    }

    /**
     * Runs MobileNet on an image-like element and returns the features.
     *
     * The returned `logits` tensor is the L2-normalised embedding (not the
     * class logits, despite the property name) and is owned by the caller,
     * who must dispose it.
     *
     * @param {HTMLImageElement|HTMLVideoElement} img - Input for MobileNet.
     * @returns {Promise<{logits: object, predictions: object, topTags: Array}>}
     *   Embedding tensor, `classify` predictions and the top-10 tags.
     * @throws Rethrows any error raised by the model after logging it.
     */
    async extractImageNetTags(img) {
        let embeddings = null;
        let logits = null;

        try {
            const predictions = await this.mobilenet.classify(img);

            // `true` requests the embedding instead of class logits.
            // Intermediate tensors are released by tf.tidy.
            embeddings = tf.tidy(() => {
                const rawEmbeddings = this.mobilenet.infer(img, true);
                return tf.div(rawEmbeddings, tf.norm(rawEmbeddings));
            });

            // Class logits are only needed to build the displayed tags.
            logits = this.mobilenet.infer(img, false);
            const topK = await this.getTopKTags(logits, 10);

            const result = {
                logits: embeddings,
                predictions,
                topTags: topK,
            };
            embeddings = null; // Ownership moves to the caller.
            return result;
        } catch (error) {
            console.error('特征提取失败:', error);
            throw error;
        } finally {
            // Release whatever is still owned here, including on failure.
            logits?.dispose();
            embeddings?.dispose();
        }
    }

    /**
     * Returns the `k` highest-scoring classes of a logits tensor.
     *
     * `probability` is a softmax over the selected top-k logits only, so the
     * values are relative to each other rather than to all classes.
     *
     * @param {{data: function(): Promise<ArrayLike<number>>}} logits - Tensor
     *   (or tensor-like object) exposing `data()`.
     * @param {number} [k=10] - Number of tags; clamped to the logits length.
     * @returns {Promise<Array<{className: string, probability: number, logit: number}>>}
     */
    async getTopKTags(logits, k = 10) {
        const values = await logits.data();
        const limit = Math.max(0, Math.min(k, values.length));

        const ranked = Array.from(values, (value, index) => ({ value, index }))
            .sort((a, b) => b.value - a.value)
            .slice(0, limit);

        const probabilities = this.softmax(ranked.map((entry) => entry.value));

        return ranked.map((entry, position) => ({
            // Look the name up by the class index, not by the rank position.
            className: this.imagenetClasses?.[entry.index] || `class_${entry.index}`,
            probability: probabilities[position],
            logit: entry.value,
        }));
    }

    /**
     * Numerically stable softmax.
     * @param {ArrayLike<number>} arr - Input scores.
     * @returns {number[]} Probabilities summing to 1 (empty for empty input).
     */
    softmax(arr) {
        if (arr.length === 0) {
            return [];
        }
        const maxLogit = Math.max(...arr);
        const scores = Array.from(arr, (logit) => Math.exp(logit - maxLogit));
        const sum = scores.reduce((a, b) => a + b, 0);
        return scores.map((score) => score / sum);
    }

    /**
     * Rebuilds the KNN classifier from the images selected in the UI.
     *
     * Every class with a non-empty name and at least one file becomes one KNN
     * label (numbered in collection order). With a single class the
     * classifier works in one-class detection mode and an adaptive distance
     * threshold is computed afterwards.
     *
     * @returns {Promise<void>}
     */
    async trainKNN() {
        const classes = [];
        const imageFiles = [];

        for (let i = 1; i <= 3; i++) {
            const className = document.getElementById(`class${i}Name`).value.trim();
            const files = document.getElementById(`class${i}Images`).files;

            if (className && files && files.length > 0) {
                classes.push(className);
                imageFiles.push(files);
                console.log(`类别 ${i}: "${className}" - ${files.length} 张图片`);
            }
        }

        console.log('收集到的类别:', classes);
        console.log('类别数量:', classes.length);

        if (classes.length < 1) {
            this.showStatus('dataStatus', 'error', '请至少添加一个类别的图片!');
            return;
        }

        const isSingleClass = classes.length === 1;
        if (isSingleClass) {
            console.log(`📍 单品类检测模式:只检测 "${classes[0]}",其他都视为背景/未知`);
        }

        this.classNames = classes;
        this.filteredConfidences = {};
        this.adaptiveThreshold = null;
        this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');

        this.knnClassifier.clearAllClasses();

        let totalProcessed = 0;
        const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);

        for (let classIndex = 0; classIndex < classes.length; classIndex++) {
            const files = imageFiles[classIndex];
            console.log(`处理类别 ${classes[classIndex]}...`);

            for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
                let features = null;
                try {
                    const img = await this.loadImage(files[fileIndex]);
                    features = await this.extractImageNetTags(img);

                    this.knnClassifier.addExample(features.logits, classIndex);

                    totalProcessed++;
                    const progress = Math.round((totalProcessed / totalImages) * 100);
                    this.showStatus('dataStatus', 'info',
                        `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);

                    // Show the tags of the first image of each class.
                    if (fileIndex === 0) {
                        this.displayTopTags(features.topTags);
                    }

                    img.remove();
                } catch (error) {
                    // A broken image must not abort the whole training run.
                    console.error('处理图片失败:', error);
                } finally {
                    features?.logits.dispose();
                }
            }
        }

        document.getElementById('totalSamples').textContent = totalProcessed;

        if (totalProcessed === 0) {
            // Nothing was learned; do not claim that a model exists.
            this.classNames = [];
            this.showStatus('dataStatus', 'error', '所有图片处理失败,模型未训练');
            return;
        }

        const statusMessage = isSingleClass
            ? `单品类检测模型训练完成!将只检测 "${classes[0]}",共 ${totalProcessed} 个样本`
            : `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`;
        this.showStatus('dataStatus', 'success', statusMessage);

        console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
        if (isSingleClass) {
            console.log('📍 单品类检测模式已启用,将基于距离阈值判断是否为:', classes[0]);
            await this.calculateAdaptiveThreshold();
        }
    }

    /**
     * Derives a distance threshold from the training data in single-class
     * mode: 1.3 times the mean pairwise distance of (at most) the first 20
     * samples. The value is stored in `adaptiveThreshold` and pushed into the
     * threshold UI controls when they exist.
     * @returns {Promise<void>}
     */
    async calculateAdaptiveThreshold() {
        if (this.knnClassifier.getNumClasses() !== 1) return;

        console.log('计算自适应阈值...');

        const dataset = this.knnClassifier.getClassifierDataset();
        if (!dataset || !dataset[0]) return;

        const trainData = await dataset[0].data();
        const [numSamples, featureDim] = dataset[0].shape;

        // Cap the number of compared samples to bound the quadratic cost.
        const sampleLimit = Math.min(numSamples, 20);
        let totalDistance = 0;
        let count = 0;

        for (let i = 0; i < sampleLimit; i++) {
            for (let j = i + 1; j < sampleLimit; j++) {
                totalDistance += KNNImageClassifier.euclideanDistance(
                    trainData, i * featureDim, trainData, j * featureDim, featureDim);
                count++;
            }
        }

        if (count === 0) return;

        const avgInternalDistance = totalDistance / count;
        this.adaptiveThreshold = avgInternalDistance * 1.3;

        console.log(`内部平均距离: ${avgInternalDistance.toFixed(2)}`);
        console.log(`建议自适应阈值: ${this.adaptiveThreshold.toFixed(2)}`);

        // Two decimals: distances between unit vectors lie in [0, 2], so one
        // decimal would round small thresholds to an unusable value.
        const formatted = this.adaptiveThreshold.toFixed(2);
        const thresholdInput = document.getElementById('distanceThreshold');
        const thresholdDisplay = document.getElementById('distanceThresholdDisplay');
        if (thresholdInput && thresholdDisplay) {
            thresholdInput.value = formatted;
            thresholdDisplay.textContent = formatted;
        }

        this.showStatus('dataStatus', 'info',
            `自适应阈值已计算: ${formatted} (基于训练数据内部距离)`);
    }

    /**
     * Renders up to five tags into the `tagsList` element.
     * @param {Array<{className: string, probability: number}>} tags
     */
    displayTopTags(tags) {
        const container = document.getElementById('tagsList');

        container.innerHTML = tags.slice(0, 5).map((tag) => `
            <span class="tag-item">
                ${KNNImageClassifier.escapeHtml(tag.className)}
                <span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
            </span>
        `).join('');
    }

    /**
     * Decodes a file into an `Image` element.
     * @param {File} file - Image file chosen by the user.
     * @returns {Promise<HTMLImageElement>} Resolves once the image has loaded;
     *   rejects when reading or decoding fails.
     */
    async loadImage(file) {
        return new Promise((resolve, reject) => {
            const reader = new FileReader();
            reader.onload = (e) => {
                const img = new Image();
                img.onload = () => resolve(img);
                img.onerror = reject;
                img.src = e.target.result;
            };
            reader.onerror = reject;
            reader.readAsDataURL(file);
        });
    }

    /** Clears the classifier, the filter state and all related UI. */
    clearDataset() {
        this.knnClassifier.clearAllClasses();
        this.classNames = [];
        this.filteredConfidences = {};
        this.adaptiveThreshold = null;

        console.log('数据集已清空,滤波器状态已重置');

        for (let i = 1; i <= 3; i++) {
            document.getElementById(`class${i}Images`).value = '';
            document.getElementById(`class${i}Count`).textContent = '0 张图片';
            document.getElementById(`class${i}Preview`).innerHTML = '';
        }

        document.getElementById('totalSamples').textContent = '0';
        document.getElementById('tagsList').innerHTML = '等待数据...';
        document.getElementById('predictions').innerHTML = '等待预测...';

        this.showStatus('dataStatus', 'info', '数据集已清空');
    }

    /**
     * Starts the webcam and, once video data is available, the prediction
     * loop. Requires a trained classifier.
     * @returns {Promise<void>}
     */
    async startWebcam() {
        if (this.knnClassifier.getNumClasses() === 0) {
            this.showStatus('predictionStatus', 'error', '请先训练模型!');
            return;
        }

        const video = document.getElementById('webcam');

        try {
            const stream = await navigator.mediaDevices.getUserMedia({
                video: { facingMode: 'user' },
                audio: false,
            });

            video.srcObject = stream;
            this.webcamStream = stream;

            document.getElementById('startWebcamBtn').disabled = true;
            document.getElementById('stopWebcamBtn').disabled = false;

            // `once` prevents listeners piling up across start/stop cycles,
            // which would otherwise launch several concurrent loops.
            video.addEventListener('loadeddata', () => {
                if (this.webcamStream !== stream || this.isPredicting) return;
                this.isPredicting = true;
                this.predictLoop();
            }, { once: true });

            this.showStatus('predictionStatus', 'success', '摄像头已启动');
        } catch (error) {
            this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
        }
    }

    /** Stops the webcam stream and the prediction loop, and resets the UI. */
    stopWebcam() {
        if (this.webcamStream) {
            this.webcamStream.getTracks().forEach((track) => track.stop());
            this.webcamStream = null;
        }

        this.isPredicting = false;
        this.filteredConfidences = {};

        document.getElementById('webcam').srcObject = null;
        document.getElementById('startWebcamBtn').disabled = false;
        document.getElementById('stopWebcamBtn').disabled = true;

        this.showStatus('predictionStatus', 'info', '摄像头已停止');
    }

    /**
     * One iteration of the prediction loop: classify the current webcam
     * frame, smooth and render the result, then schedule the next iteration
     * with `requestAnimationFrame`. Stops when `isPredicting` is false.
     * @returns {Promise<void>}
     */
    async predictLoop() {
        if (!this.isPredicting) return;

        const video = document.getElementById('webcam');

        if (video.readyState === 4) {
            let features = null;
            try {
                console.log('开始预测,类别数量:', this.classNames.length);

                features = await this.extractImageNetTags(video);

                const raw = await this.predictWithDistance(features.logits, this.getKValue());
                console.log('预测结果:', raw);

                // Only single-class mode can report "unknown"; multi-class
                // mode always takes the KNN winner.
                const finalPrediction = {
                    label: raw.label,
                    confidences: raw.confidences,
                    isUnknown: raw.isSingleClass ? raw.label === -1 : false,
                    minDistance: raw.minDistance,
                    isSingleClass: Boolean(raw.isSingleClass),
                };

                this.displayPrediction(this.applyLowPassFilter(finalPrediction));
                this.displayTopTags(features.topTags);
            } catch (error) {
                console.error('预测错误:', error);
            } finally {
                // Release the embedding even when a step above failed.
                features?.logits.dispose();
            }
        }

        requestAnimationFrame(() => this.predictLoop());
    }

    /**
     * Classifies a feature tensor.
     *
     * With exactly one trained class the mean Euclidean distance to the `k`
     * nearest training samples is compared against the distance threshold
     * (label `0` when within the threshold, `-1` otherwise). With several
     * classes the KNN classifier's own prediction is returned.
     *
     * @param {object} logits - Feature tensor (must expose `data()`).
     * @param {number} k - Number of neighbours; non-positive or non-integer
     *   values fall back to 3.
     * @returns {Promise<{label: number, confidences: Object<string, number>,
     *   minDistance: number, isSingleClass: boolean}>}
     */
    async predictWithDistance(logits, k) {
        const numClasses = this.knnClassifier.getNumClasses();
        if (numClasses === 0) {
            return { label: -1, confidences: {}, minDistance: Infinity, isSingleClass: false };
        }

        const neighbours = Number.isInteger(k) && k > 0 ? k : 3;

        if (numClasses === 1) {
            return this.predictSingleClass(logits, neighbours);
        }

        try {
            const prediction = await this.knnClassifier.predictClass(logits, neighbours);

            console.log('多品类预测结果:', prediction);
            console.log('预测标签:', prediction.label);
            console.log('置信度:', prediction.confidences);

            // The classifier reports the label as a string; the rest of this
            // class compares labels with numeric indices.
            const numericLabel = Number(prediction.label);
            const label = Number.isInteger(numericLabel) ? numericLabel : -1;

            let confidences = {};
            if (prediction.confidences && typeof prediction.confidences === 'object') {
                confidences = prediction.confidences;
            }

            if (Object.keys(confidences).length === 0) {
                console.warn('置信度为空,使用默认值');
                for (let i = 0; i < this.classNames.length; i++) {
                    confidences[i] = i === label ? 1.0 : 0;
                }
            }

            return {
                label,
                confidences,
                // Placeholder: no real distance is computed in this mode.
                minDistance: 0.5,
                isSingleClass: false,
            };
        } catch (error) {
            console.error('预测错误:', error);
            return { label: -1, confidences: {}, minDistance: Infinity, isSingleClass: false };
        }
    }

    /**
     * Single-class branch of `predictWithDistance`.
     * @param {object} logits - Feature tensor (must expose `data()`).
     * @param {number} k - Positive number of neighbours to average.
     * @returns {Promise<{label: number, confidences: Object<string, number>,
     *   minDistance: number, isSingleClass: boolean}>} `minDistance` holds the
     *   mean distance to the k nearest samples.
     */
    async predictSingleClass(logits, k) {
        const notDetected = () => ({
            label: -1,
            confidences: { 0: 0 },
            minDistance: Infinity,
            isSingleClass: true,
        });

        console.log('单品类检测模式 - 计算实际距离');

        const dataset = this.knnClassifier.getClassifierDataset();
        if (!dataset || !dataset[0]) {
            return notDetected();
        }

        const inputData = await logits.data();
        const trainData = await dataset[0].data();
        const [numSamples, featureDim] = dataset[0].shape;

        console.log(`输入特征维度: ${inputData.length}, 训练数据维度: ${featureDim}, 样本数: ${numSamples}`);

        if (inputData.length !== featureDim) {
            console.error(`维度不匹配!输入: ${inputData.length}, 训练: ${featureDim}`);
            return notDetected();
        }

        const distances = [];
        for (let i = 0; i < numSamples; i++) {
            distances.push(KNNImageClassifier.euclideanDistance(
                inputData, 0, trainData, i * featureDim, featureDim));
        }
        const minDistance = distances.reduce((min, d) => Math.min(min, d), Infinity);

        console.log(`计算了 ${distances.length} 个距离,最小距离: ${minDistance.toFixed(2)}`);
        console.log(`前5个距离: ${distances.slice(0, 5).map((d) => d.toFixed(2)).join(', ')}`);

        distances.sort((a, b) => a - b);
        const kNearest = distances.slice(0, Math.min(k, distances.length));
        if (kNearest.length === 0) {
            return notDetected();
        }
        const avgDistance = kNearest.reduce((sum, d) => sum + d, 0) / kNearest.length;

        const threshold = this.getDistanceThreshold();
        const belongsToClass = avgDistance <= threshold;
        // Binary confidence: 100% inside the threshold, 0% outside.
        const confidence = belongsToClass ? 1.0 : 0;

        console.log(`单品类预测 - 平均距离: ${avgDistance.toFixed(2)}, 阈值: ${threshold}, 属于类别: ${belongsToClass}, 置信度: ${confidence.toFixed(3)}`);

        return {
            label: belongsToClass ? 0 : -1,
            confidences: { 0: confidence },
            minDistance: avgDistance,
            isSingleClass: true,
        };
    }

    /**
     * Smooths per-class confidences with an exponential moving average.
     *
     * The first call after a reset seeds the filter with the raw values.
     * Single-class predictions are not smoothed: they are re-thresholded and
     * reported as 0 or 1.
     *
     * @param {{label: number, confidences: Object<string, number>,
     *   isSingleClass: boolean, minDistance: number, isUnknown: boolean}} prediction
     * @returns {{label: number, confidences: Object<string, number>,
     *   isSingleClass: boolean, minDistance: number, isUnknown: boolean}}
     */
    applyLowPassFilter(prediction) {
        const alpha = this.getFilterAlpha();
        const isSingleClass = prediction.isSingleClass || false;
        const { minDistance, isUnknown } = prediction;

        if (Object.keys(this.filteredConfidences).length === 0) {
            for (let i = 0; i < this.classNames.length; i++) {
                this.filteredConfidences[i] = prediction.confidences[i] || 0;
            }
            return {
                label: prediction.label,
                confidences: { ...this.filteredConfidences },
                isSingleClass,
                minDistance,
                isUnknown,
            };
        }

        if (isSingleClass) {
            const inThreshold = minDistance <= this.getDistanceThreshold();
            const confidence = inThreshold ? 1.0 : 0;
            this.filteredConfidences[0] = confidence;

            return {
                label: inThreshold ? 0 : -1,
                confidences: { 0: confidence },
                isSingleClass,
                minDistance,
                isUnknown: !inThreshold,
            };
        }

        // EMA: y[n] = alpha * x[n] + (1 - alpha) * y[n-1]
        const newConfidences = {};
        let sum = 0;
        for (let i = 0; i < this.classNames.length; i++) {
            const currentValue = prediction.confidences[i] || 0;
            const previousValue = this.filteredConfidences[i] || 0;
            this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
            newConfidences[i] = this.filteredConfidences[i];
            sum += newConfidences[i];
        }

        // Renormalise the reported values so that they sum to 1 and pick
        // the winner (ties go to the lower index).
        let maxConfidence = 0;
        let bestLabel = 0;
        for (let i = 0; i < this.classNames.length; i++) {
            if (sum > 0) {
                newConfidences[i] /= sum;
            }
            if (newConfidences[i] > maxConfidence) {
                maxConfidence = newConfidences[i];
                bestLabel = i;
            }
        }

        return {
            label: bestLabel,
            confidences: newConfidences,
            isSingleClass,
            minDistance,
            isUnknown,
        };
    }

    /**
     * Renders a prediction into the `predictions` element.
     * @param {{label: number, confidences: Object<string, number>,
     *   minDistance: number}} prediction - Output of `applyLowPassFilter`.
     */
    displayPrediction(prediction) {
        const container = document.getElementById('predictions');
        const escape = KNNImageClassifier.escapeHtml;

        if (this.classNames.length === 1) {
            const isDetected = prediction.label === 0;
            const threshold = this.getDistanceThreshold();
            const distance = prediction.minDistance || 0;
            const withinThreshold = distance <= threshold;

            container.innerHTML = `
                <div class="prediction-item" style="${isDetected ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : 'border-left-color: #cbd5e0;'}">
                    <div class="prediction-header">
                        <span class="prediction-label">
                            ${escape(this.classNames[0])} ${isDetected ? '✓ 检测到' : '✗ 未检测到'}
                        </span>
                        <span class="prediction-confidence" style="${isDetected ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : 'background: #cbd5e0;'}">
                            ${isDetected ? '100%' : '0%'}
                        </span>
                    </div>
                    <div class="confidence-bar-container">
                        <div class="confidence-bar ${isDetected ? 'high' : 'low'}" style="width: ${isDetected ? '100' : '0'}%;">
                            ${isDetected ? '<span class="confidence-percentage">100%</span>' : ''}
                        </div>
                    </div>
                </div>
                <div style="margin-top: 10px; padding: 10px; background: #f8f9fa; border-radius: 5px; font-size: 12px; color: #666;">
                    <strong>距离:</strong> ${distance.toFixed(2)} | <strong>阈值:</strong> ${threshold.toFixed(2)}
                    <span style="margin-left: 10px; color: ${withinThreshold ? '#48bb78' : '#f56565'};">
                        ${withinThreshold ? '✓ 在阈值范围内' : '✗ 超出阈值范围'}
                    </span>
                </div>
            `;
            return;
        }

        // Multi-class mode: one row per class, in class-index order.
        const confidences = prediction.confidences;
        const predictedClass = prediction.label;

        container.innerHTML = this.classNames.map((className, i) => {
            const confidence = confidences[i] || 0;
            const percentage = (confidence * 100).toFixed(1);
            const isWinner = i === predictedClass;

            // The winner is always drawn with the "high" colour.
            let barClass = 'low';
            if (isWinner || confidence > 0.7) {
                barClass = 'high';
            } else if (confidence > 0.4) {
                barClass = 'medium';
            }

            return `
                <div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
                    <div class="prediction-header">
                        <span class="prediction-label">
                            ${escape(className)} ${isWinner ? '👑' : ''}
                        </span>
                        <span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
                            ${percentage}%
                        </span>
                    </div>
                    <div class="confidence-bar-container">
                        <div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
                            ${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
                        </div>
                    </div>
                </div>
            `;
        }).join('');
    }

    /**
     * Adds one webcam frame as a training sample for a class. When the
     * webcam is not running it is started temporarily and stopped again
     * after the capture.
     * @param {number} classIndex - Zero-based class index (KNN label).
     * @returns {Promise<void>} Resolves once the capture attempt finished.
     */
    async captureFromWebcam(classIndex) {
        if (this.webcamStream) {
            await this.addWebcamSample(classIndex);
            return;
        }

        const video = document.getElementById('webcam');
        let stream;
        try {
            stream = await navigator.mediaDevices.getUserMedia({
                video: { facingMode: 'user' },
                audio: false,
            });
        } catch (error) {
            this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
            return;
        }

        video.srcObject = stream;
        this.webcamStream = stream;

        try {
            // Give the video element time to receive frames.
            await new Promise((resolve) => setTimeout(resolve, 1000));
            await this.addWebcamSample(classIndex);
        } finally {
            // Stop the temporary stream through the local reference: the
            // instance field may have been cleared or replaced meanwhile.
            stream.getTracks().forEach((track) => track.stop());
            if (this.webcamStream === stream) {
                this.webcamStream = null;
                video.srcObject = null;
            }
        }
    }

    /**
     * Extracts features from the current webcam frame and adds them to the
     * KNN classifier under `classIndex`.
     * @param {number} classIndex - Zero-based class index (KNN label).
     * @returns {Promise<void>}
     */
    async addWebcamSample(classIndex) {
        const video = document.getElementById('webcam');

        if (video.readyState !== 4) {
            this.showStatus('dataStatus', 'error', '摄像头画面尚未就绪,请稍后重试');
            return;
        }

        let features = null;
        try {
            features = await this.extractImageNetTags(video);
            this.knnClassifier.addExample(features.logits, classIndex);

            const counts = this.knnClassifier.getClassExampleCount();
            const count = counts[classIndex] || 0;
            document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;

            this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
        } catch (error) {
            console.error('添加样本失败:', error);
        } finally {
            features?.logits.dispose();
        }
    }

    /**
     * Serialises the classifier dataset, class names, K and feature
     * dimension to JSON and offers it as a `knn-model.json` download.
     * @returns {Promise<void>}
     */
    async saveModel() {
        if (this.knnClassifier.getNumClasses() === 0) {
            this.showStatus('predictionStatus', 'error', '没有可保存的模型');
            return;
        }

        try {
            const dataset = this.knnClassifier.getClassifierDataset();
            const keys = Object.keys(dataset);

            const datasetObj = {};
            for (const key of keys) {
                datasetObj[key] = Array.from(await dataset[key].data());
            }

            // 1280 is only used when the dataset has no readable entry.
            const firstEntry = dataset[keys[0]];
            const featureDim = firstEntry ? firstEntry.shape[1] : 1280;

            const modelData = {
                dataset: datasetObj,
                classNames: this.classNames,
                k: document.getElementById('kValue').value,
                featureDim,
                date: new Date().toISOString(),
            };

            const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
            const url = URL.createObjectURL(blob);
            const anchor = document.createElement('a');
            anchor.href = url;
            anchor.download = 'knn-model.json';
            anchor.click();
            // Revoke on the next tick so the download can start first.
            setTimeout(() => URL.revokeObjectURL(url), 0);

            this.showStatus('predictionStatus', 'success', '模型已保存');
        } catch (error) {
            this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
        }
    }

    /**
     * Opens a file picker and restores a model previously written by
     * `saveModel`. Outcome is reported in the `predictionStatus` element.
     * @returns {Promise<void>} Resolves after the picker has been opened.
     */
    async loadModel() {
        const input = document.createElement('input');
        input.type = 'file';
        input.accept = '.json';

        input.onchange = async (e) => {
            try {
                const file = e.target.files[0];
                if (!file) return;

                this.restoreModelData(JSON.parse(await file.text()));

                this.showStatus('predictionStatus', 'success',
                    `模型加载成功!类别: ${this.classNames.join(', ')}`);
            } catch (error) {
                this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
            }
        };

        input.click();
    }

    /**
     * Replaces the classifier contents with a parsed model file.
     *
     * All class tensors are built first and handed to the classifier in a
     * single `setClassifierDataset` call, because that call replaces the
     * whole dataset. Nothing is modified when the data is invalid.
     *
     * @param {{dataset: Object<string, number[]>, classNames: string[],
     *   k: (string|number), featureDim: (number|undefined)}} modelData
     * @throws {Error} When the dataset is missing or a class's data length
     *   does not match a usable feature dimension.
     */
    restoreModelData(modelData) {
        if (!modelData || typeof modelData.dataset !== 'object' || modelData.dataset === null) {
            throw new Error('模型文件格式无效');
        }

        const restored = {};
        let totalSamples = 0;
        try {
            for (const [key, data] of Object.entries(modelData.dataset)) {
                // Older files carry no featureDim: try the common sizes.
                let featureDim = modelData.featureDim;
                if (!featureDim) {
                    featureDim = [1280, 1024, 1000].find((dim) => data.length % dim === 0);
                    if (featureDim) {
                        console.warn(`自动检测到特征维度: ${featureDim}`);
                    }
                }
                if (!featureDim || data.length % featureDim !== 0) {
                    throw new Error(`无法确定特征维度,数据长度: ${data.length}`);
                }

                const numSamples = data.length / featureDim;
                console.log(`加载类别 ${key}:${numSamples} 个样本,${featureDim} 维特征`);

                restored[key] = tf.tensor(data, [numSamples, featureDim]);
                totalSamples += numSamples;
            }
        } catch (error) {
            Object.values(restored).forEach((tensor) => tensor.dispose());
            throw error;
        }

        this.knnClassifier.clearAllClasses();
        this.knnClassifier.setClassifierDataset(restored);

        this.classNames = Array.isArray(modelData.classNames) ? modelData.classNames : [];
        // State derived from the previous model no longer applies.
        this.filteredConfidences = {};
        this.adaptiveThreshold = null;

        document.getElementById('kValue').value = modelData.k;
        document.getElementById('kValueDisplay').textContent = modelData.k;
        const totalElement = document.getElementById('totalSamples');
        if (totalElement) {
            totalElement.textContent = totalSamples;
        }
    }

    /**
     * Shows a status message in the given element.
     * @param {string} elementId - Id of the status element.
     * @param {('success'|'error'|'info')} type - Selects the CSS class.
     * @param {string} message - Text to display.
     */
    showStatus(elementId, type, message) {
        const element = document.getElementById(elementId);
        if (!element) {
            console.warn(`Status element "${elementId}" not found:`, message);
            return;
        }

        const classMap = {
            success: 'status-success',
            error: 'status-error',
            info: 'status-info',
        };

        element.className = `status-message ${classMap[type]}`;
        element.textContent = message;
    }
}

// Global function: capture a sample from the webcam
/**
 * Global entry point that forwards to the page's classifier instance
 * (`window.classifier`), if one exists.
 *
 * @param {number} classIndex - Zero-based class index passed through to the
 *   classifier's `captureFromWebcam` method.
 * @returns {*} Whatever the classifier's `captureFromWebcam` returns, so
 *   callers can await it; `undefined` when no classifier is set.
 */
function captureFromWebcam(classIndex) {
    if (!window.classifier) {
        return undefined;
    }
    return window.classifier.captureFromWebcam(classIndex);
}

// Initialise the application
// Application-wide classifier instance; also exposed as `window.classifier`
// so the global `captureFromWebcam` wrapper can reach it.
let classifier;

/** Creates the classifier and publishes it on `window`. */
function bootstrapClassifier() {
    classifier = new KNNImageClassifier();
    window.classifier = classifier;
}

// If the script runs after the document has been parsed, DOMContentLoaded
// has already fired and a listener would never be called.
if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', bootstrapClassifier, { once: true });
} else {
    bootstrapClassifier();
}