// Image classifier: a random forest built with decision-tree.js on top of MobileNet feature tags
/**
 * Browser-side image classifier.
 *
 * MobileNet logits are used as the feature vector of each image, and a small
 * random forest (a list of decision-tree.js trees, each trained on a random
 * subset of the samples) votes on the class.
 *
 * The class relies on globals supplied by the page: `mobilenet` (model
 * loader), `dt` (decision-tree.js) and a set of DOM elements looked up by id.
 */
class ImageClassifier {
  /** Sets up the default state and starts loading the model. */
  constructor() {
    this.mobilenet = null;
    this.randomForest = []; // trained decision trees
    this.classNames = [];
    this.webcamStream = null;
    this.isPredicting = false;
    this.imagenetClasses = null;
    this.trainingSet = [];
    // Samples captured from the webcam. They are stored separately because
    // trainModel() rebuilds trainingSet from the file inputs on every run and
    // would otherwise throw the captured samples away.
    this.webcamSamples = [];
    this.numTrees = 10; // number of trees in the forest (adjustable)
    this.subsetSize = 0.7; // fraction of the training set given to each tree (adjustable)

    // init() reports its own failures; this handler only keeps an unexpected
    // failure from surfacing as an unhandled rejection.
    this.init().catch((error) => {
      console.error('初始化失败:', error);
    });
  }

  /**
   * Loads MobileNet and the label list, then wires up the UI.
   * Failures are reported in the `dataStatus` element.
   */
  async init() {
    try {
      this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');

      this.mobilenet = await mobilenet.load({
        version: 2,
        alpha: 1.0
      });

      await this.loadImageNetClasses();

      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  /**
   * Fills `imagenetClasses` with a simplified list of the first 20 label
   * names. Indices beyond this list are rendered as `class_<index>` by
   * getTopKTags().
   */
  async loadImageNetClasses() {
    this.imagenetClasses = [
      'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
      'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
      'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting',
      'robin', 'bulbul', 'jay', 'magpie', 'chickadee'
    ];
  }

  /** Attaches listeners to the file inputs, buttons and parameter inputs. */
  setupEventListeners() {
    // File uploads
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index);
      });
    });

    // Buttons
    document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());

    // Parameters. Invalid input (empty field, NaN, non-positive numbers) is
    // ignored so that a half-typed value cannot produce a forest of zero trees.
    document.getElementById('numTrees').addEventListener('input', (e) => {
      const value = Number.parseInt(e.target.value, 10);
      if (Number.isInteger(value) && value > 0) {
        this.numTrees = value;
      }
    });

    document.getElementById('subsetSize').addEventListener('input', (e) => {
      const value = Number.parseFloat(e.target.value);
      if (Number.isFinite(value) && value > 0) {
        this.subsetSize = value;
      }
    });
  }

  /**
   * Updates the image counter and the thumbnail preview of one class.
   *
   * @param {Event} event - `change` event of the file input.
   * @param {number} classIndex - Zero-based class index (DOM ids are one-based).
   */
  handleImageUpload(event, classIndex) {
    const files = Array.from(event.target.files);
    const countElement = document.getElementById(`class${classIndex + 1}Count`);
    const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);

    countElement.textContent = `${files.length} 张图片`;

    // Drop the previous preview before adding the new thumbnails.
    previewContainer.innerHTML = '';

    files.forEach((file) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    });
  }

  /**
   * Runs MobileNet on an image source and returns its logits, the model's
   * own predictions and the top 10 tags.
   *
   * The caller owns the returned `logits` tensor and must dispose it. If
   * anything fails after the tensor was created, it is disposed here before
   * the error is rethrown.
   *
   * @param {*} img - Image source accepted by the MobileNet model.
   * @returns {Promise<{logits: *, predictions: *, topTags: Array}>}
   */
  async extractImageNetTags(img) {
    let logits = null;
    try {
      const predictions = await this.mobilenet.classify(img);
      logits = this.mobilenet.infer(img, false);
      const topTags = await this.getTopKTags(logits, 10);
      return { logits, predictions, topTags };
    } catch (error) {
      if (logits) {
        logits.dispose();
      }
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  /**
   * Returns the `k` highest logits as tags, highest first.
   *
   * `probability` is a softmax over the selected top-k logits only, not over
   * the full logits vector. When fewer than `k` values are available, all of
   * them are returned.
   *
   * @param {{data: function(): Promise<ArrayLike<number>>}} logits
   * @param {number} [k=10]
   * @returns {Promise<Array<{className: string, probability: number, logit: number}>>}
   */
  async getTopKTags(logits, k = 10) {
    const values = await logits.data();
    const names = this.imagenetClasses || [];

    const ranked = Array.from(values, (value, index) => ({ value, index }))
      .sort((a, b) => b.value - a.value)
      .slice(0, Math.max(0, k));

    // Computed once for all tags instead of once per tag.
    const probabilities = this.softmax(ranked.map((entry) => entry.value));

    return ranked.map((entry, position) => ({
      className: names[entry.index] || `class_${entry.index}`,
      probability: probabilities[position],
      logit: entry.value
    }));
  }

  /**
   * Numerically stable softmax (the maximum is subtracted before exp).
   *
   * @param {number[]|Float32Array} arr
   * @returns {number[]|Float32Array} Same kind of array as the input; empty
   *   input yields an empty result.
   */
  softmax(arr) {
    if (arr.length === 0) {
      return arr.slice();
    }
    const maxLogit = Math.max(...arr);
    const scores = arr.map((l) => Math.exp(l - maxLogit));
    const sum = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sum);
  }

  /**
   * Converts a feature vector into the flat record decision-tree.js is fed
   * with: `{ feature_0: v0, feature_1: v1, ... }`.
   *
   * @param {ArrayLike<number>} featureVector
   * @returns {Object<string, number>}
   */
  toFeatureItem(featureVector) {
    const item = {};
    for (let i = 0; i < featureVector.length; i++) {
      item[`feature_${i}`] = featureVector[i];
    }
    return item;
  }

  /**
   * Extracts the logits of an image source as a feature record. The tensor is
   * always disposed, even when reading its data fails.
   *
   * @param {*} source - Image source accepted by the MobileNet model.
   * @returns {Promise<Object<string, number>>}
   */
  async extractFeatureItem(source) {
    const logits = this.mobilenet.infer(source, false);
    try {
      return this.toFeatureItem(await logits.data());
    } finally {
      logits.dispose();
    }
  }

  /**
   * Trains the random forest from the uploaded images plus any samples
   * captured from the webcam.
   *
   * A class takes part when it has a name and at least one uploaded image, or
   * when webcam samples were captured for it. At least two classes are
   * required. Images that cannot be processed are logged, counted and skipped.
   * The forest and class list are replaced only when training succeeds.
   */
  async trainModel() {
    const classes = [];
    const imageFiles = [];

    // Collect the classes that have both a name and uploaded files.
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;

      if (className && files.length > 0) {
        classes.push(className);
        imageFiles.push(files);
      }
    }

    // Classes that only have webcam samples still count as classes.
    const webcamSamples = this.webcamSamples || [];
    for (const sample of webcamSamples) {
      if (!classes.includes(sample.category)) {
        classes.push(sample.category);
        imageFiles.push([]);
      }
    }

    if (classes.length < 2) {
      this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
      return;
    }

    this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...');

    // Built locally and published afterwards, so a half-built set is never
    // visible through this.trainingSet.
    const samples = webcamSamples.slice();
    const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
    let totalProcessed = 0;
    let failedCount = 0;

    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      const files = imageFiles[classIndex];
      for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
        try {
          const img = await this.loadImage(files[fileIndex]);
          const item = await this.extractFeatureItem(img);
          item.category = classes[classIndex]; // class name, not its index
          samples.push(item);

          totalProcessed++;
          const progress = Math.round((totalProcessed / totalImages) * 100);
          this.showStatus('dataStatus', 'info',
            `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
        } catch (error) {
          failedCount++;
          console.error('处理图片失败:', error);
        }
      }
    }

    this.trainingSet = samples;

    if (samples.length === 0) {
      this.showStatus('dataStatus', 'error', '没有可用的训练样本,请检查图片文件!');
      return;
    }

    try {
      // Each tree is trained on its own random subset of the samples.
      const forest = [];
      for (let i = 0; i < this.numTrees; i++) {
        const tree = new dt.DecisionTree({
          trainingSet: this.createTrainingSubset(samples),
          categoryAttr: 'category'
        });
        forest.push(tree);
      }

      this.randomForest = forest;
      this.classNames = classes;

      const failureNote = failedCount > 0 ? `(${failedCount} 张图片处理失败)` : '';
      this.showStatus('dataStatus', 'success',
        `模型训练完成!共 ${samples.length} 个样本,${classes.length} 个类别, ${forest.length} 棵树${failureNote}`);

      document.getElementById('totalSamples').textContent = samples.length;
    } catch (error) {
      console.error('训练失败:', error);
      this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`);
    }
  }

  /**
   * Draws a random subset of the training set, with replacement, for one tree.
   *
   * The subset holds `floor(subsetSize * length)` items but never fewer than
   * one when the training set is non-empty, so that small training sets do
   * not produce trees trained on nothing.
   *
   * @param {Array<Object>} trainingSet
   * @returns {Array<Object>}
   */
  createTrainingSubset(trainingSet) {
    if (trainingSet.length === 0) {
      return [];
    }

    const subsetSize = Math.max(1, Math.floor(this.subsetSize * trainingSet.length));
    const subset = [];
    for (let i = 0; i < subsetSize; i++) {
      const randomIndex = Math.floor(Math.random() * trainingSet.length);
      subset.push(trainingSet[randomIndex]);
    }
    return subset;
  }

  /**
   * Reads a file and decodes it into an image element.
   *
   * @param {File} file
   * @returns {Promise<HTMLImageElement>} Rejects with an Error when the file
   *   cannot be read or decoded.
   */
  loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        // The raw callback argument is an Event; reject with a real Error.
        img.onerror = () => reject(new Error(`无法解码图片: ${file.name}`));
        img.src = e.target.result;
      };
      reader.onerror = () => reject(reader.error || new Error(`无法读取文件: ${file.name}`));
      reader.readAsDataURL(file);
    });
  }

  /** Drops the model and all samples, and resets the related UI elements. */
  clearDataset() {
    this.randomForest = [];
    this.classNames = [];
    this.trainingSet = [];
    this.webcamSamples = [];

    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
      document.getElementById(`class${i}Preview`).innerHTML = '';
    }

    document.getElementById('totalSamples').textContent = '0';
    document.getElementById('predictions').textContent = '等待预测...';

    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  /**
   * Starts the webcam and the prediction loop. Requires a trained model.
   * Camera errors are reported in the `predictionStatus` element.
   */
  async startWebcam() {
    if (this.randomForest.length === 0) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }

    try {
      const video = document.getElementById('webcam');
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });

      video.srcObject = stream;
      this.webcamStream = stream;

      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;

      this.isPredicting = true;
      this.predictLoop();

      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  /** Stops the webcam stream and the prediction loop. */
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach((track) => track.stop());
      this.webcamStream = null;
    }

    this.isPredicting = false;

    const video = document.getElementById('webcam');
    video.srcObject = null;

    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;

    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  /**
   * Classifies the current webcam frame and schedules itself for the next
   * animation frame for as long as `isPredicting` is set.
   */
  async predictLoop() {
    if (!this.isPredicting) return;

    const video = document.getElementById('webcam');

    if (video && video.readyState === 4) {
      try {
        const item = await this.extractFeatureItem(video);
        // The webcam may have been stopped while the frame was processed.
        if (!this.isPredicting) return;

        const { predictedCategory, probabilities } = this.predictWithRandomForest(item);
        this.displayPrediction(predictedCategory, probabilities);
      } catch (error) {
        console.error('预测错误:', error);
      }
    }

    requestAnimationFrame(() => this.predictLoop());
  }

  /**
   * Lets every tree vote on a feature record.
   *
   * @param {Object<string, number>} item - Feature record (`feature_<i>` keys).
   * @returns {{predictedCategory: ?string, probabilities: Object<string, number>}}
   *   The category with the most votes (`null` when nothing was voted for) and
   *   each category's share of the votes. Shares are relative to the number of
   *   trees actually trained, not to the current `numTrees` setting, which may
   *   have been changed since training.
   */
  predictWithRandomForest(item) {
    const votes = {};
    this.classNames.forEach((className) => {
      votes[className] = 0;
    });

    this.randomForest.forEach((tree) => {
      const prediction = tree.predict(item);
      votes[prediction] = (votes[prediction] || 0) + 1;
    });

    const treeCount = this.randomForest.length;
    const probabilities = {};
    let predictedCategory = null;
    let maxVotes = 0;

    for (const category in votes) {
      if (votes[category] > maxVotes) {
        predictedCategory = category;
        maxVotes = votes[category];
      }
      probabilities[category] = treeCount > 0 ? votes[category] / treeCount : 0;
    }

    return { predictedCategory, probabilities };
  }

  /**
   * Captures one training sample from the webcam for the given class.
   *
   * When no stream is active, the camera is opened temporarily, given one
   * second to deliver frames, and closed again afterwards. Only the stream
   * opened here is stopped, so a stream started elsewhere in the meantime is
   * left alone.
   *
   * @param {number} classIndex - Zero-based class index.
   */
  async captureFromWebcam(classIndex) {
    if (this.webcamStream) {
      await this.addWebcamSample(classIndex);
      return;
    }

    const video = document.getElementById('webcam');
    let stream = null;

    try {
      stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });

      video.srcObject = stream;
      this.webcamStream = stream;

      // Give the video element time to receive its first frames.
      await new Promise((resolve) => setTimeout(resolve, 1000));
      await this.addWebcamSample(classIndex);
    } catch (error) {
      this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
    } finally {
      if (stream) {
        stream.getTracks().forEach((track) => track.stop());
        if (this.webcamStream === stream) {
          this.webcamStream = null;
          if (video) {
            video.srcObject = null;
          }
        }
      }
    }
  }

  /**
   * Adds the current webcam frame as a sample of the given class. The sample
   * takes part in the next training run.
   *
   * @param {number} classIndex - Zero-based class index.
   */
  async addWebcamSample(classIndex) {
    const video = document.getElementById('webcam');

    if (!video || video.readyState !== 4) {
      this.showStatus('dataStatus', 'error', '摄像头画面尚未就绪,请稍后重试');
      return;
    }

    try {
      const className = document.getElementById(`class${classIndex + 1}Name`).value.trim();
      if (!className) {
        this.showStatus('dataStatus', 'error', '请先填写类别名称!');
        return;
      }

      const item = await this.extractFeatureItem(video);
      item.category = className;

      if (!this.webcamSamples) {
        this.webcamSamples = [];
      }
      this.webcamSamples.push(item);
      this.trainingSet.push(item);

      this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`);
    } catch (error) {
      console.error('添加样本失败:', error);
      this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
    }
  }

  /**
   * Renders the predicted category and the per-class vote shares.
   *
   * Class names are typed in by the user, so the output is built from text
   * nodes instead of an HTML string.
   *
   * @param {?string} category
   * @param {Object<string, number>} probabilities - Vote share per class (0..1).
   */
  displayPrediction(category, probabilities) {
    const container = document.getElementById('predictions');
    const fragment = document.createDocumentFragment();

    const appendLine = (text) => {
      fragment.appendChild(document.createTextNode(text));
      fragment.appendChild(document.createElement('br'));
    };

    appendLine(`预测类别:${category}`);
    for (const className in probabilities) {
      const probability = (probabilities[className] * 100).toFixed(2);
      appendLine(`${className}: ${probability}%`);
    }

    container.textContent = '';
    container.appendChild(fragment);
  }

  /**
   * Shows a status message in the given element.
   *
   * @param {string} elementId - Id of the status element.
   * @param {string} type - 'success', 'error' or 'info'; any other value is
   *   styled like 'info'.
   * @param {string} message
   */
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    if (!element) {
      // A missing status element must not break the operation being reported.
      console.warn(`找不到状态元素 #${elementId}: ${message}`);
      return;
    }

    const classMap = {
      'success': 'status-success',
      'error': 'status-error',
      'info': 'status-info'
    };
    const statusClass = Object.prototype.hasOwnProperty.call(classMap, type)
      ? classMap[type]
      : classMap.info;

    element.className = `status-message ${statusClass}`;
    element.textContent = message;
  }
}
|
||
|
||
/**
 * Global entry point: captures a webcam sample for the given class.
 *
 * Does nothing until `window.classifier` exists. The classifier method is
 * asynchronous; a rejection is logged here rather than left unhandled.
 *
 * @param {number} classIndex - Zero-based class index.
 */
function captureFromWebcam(classIndex) {
  const classifier = window.classifier;
  if (!classifier) {
    return;
  }

  Promise.resolve(classifier.captureFromWebcam(classIndex)).catch((error) => {
    console.error('摄像头采集失败:', error);
  });
}
|
||
// Create the classifier once the DOM is available. When the script runs after
// DOMContentLoaded has already fired (deferred or dynamically injected
// script), the event would never arrive, so the classifier is created
// immediately in that case.
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', () => {
    window.classifier = new ImageClassifier();
  }, { once: true });
} else {
  window.classifier = new ImageClassifier();
}
|