// 图像分类器 - 基于MobileNet特征标签和 decision-tree.js 实现随机森林
/**
 * Image classifier that turns MobileNet logits into flat feature rows and
 * classifies them with a bootstrap-sampled ensemble ("random forest") of
 * decision-tree.js trees (`dt.DecisionTree`).
 */
class ImageClassifier {
  constructor() {
    this.mobilenet = null;
    this.randomForest = []; // trained trees; each exposes predict(item)
    this.classNames = [];
    this.webcamStream = null;
    this.isPredicting = false;
    // Incremented whenever prediction starts or stops, so a stale predictLoop
    // (e.g. one still awaiting inference) can detect it has been superseded.
    this.predictLoopToken = 0;
    this.imagenetClasses = null;
    this.trainingSet = [];
    this.numTrees = 10; // number of trees in the forest (UI-adjustable)
    this.subsetSize = 0.7; // bootstrap sample size as a fraction of the training set (UI-adjustable)
    this.init();
  }

  /**
   * Loads MobileNet and the label table, then wires up the UI.
   * Errors are reported through the 'dataStatus' element.
   */
  async init() {
    this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
    try {
      this.mobilenet = await mobilenet.load({
        version: 2,
        alpha: 1.0
      });
      await this.loadImageNetClasses();
      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  /**
   * Populates the ImageNet label lookup used by getTopKTags().
   * Only the first 20 ImageNet class names are listed (a simplified table);
   * getTopKTags() renders any other index as `class_<index>`.
   */
  async loadImageNetClasses() {
    this.imagenetClasses = [
      'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
      'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
      'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting',
      'robin', 'bulbul', 'jay', 'magpie', 'chickadee'
    ];
  }

  /** Attaches the file-input, button and parameter listeners. */
  setupEventListeners() {
    // File uploads: input index i corresponds to class i (0-based).
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index);
      });
    });

    document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());

    // Forest parameters. Invalid or transient input (e.g. an emptied field)
    // is ignored. Otherwise NaN would be stored and trainModel would build a
    // forest with zero trees.
    document.getElementById('numTrees').addEventListener('input', (e) => {
      const value = Number.parseInt(e.target.value, 10);
      if (Number.isFinite(value) && value >= 1) {
        this.numTrees = value;
      }
    });
    document.getElementById('subsetSize').addEventListener('input', (e) => {
      const value = Number.parseFloat(e.target.value);
      if (Number.isFinite(value) && value > 0) {
        this.subsetSize = value;
      }
    });
  }

  /**
   * Updates the image count and thumbnail preview for one class.
   * @param {Event} event - 'change' event from a file input.
   * @param {number} classIndex - 0-based class index.
   */
  handleImageUpload(event, classIndex) {
    const files = event.target.files;
    const countElement = document.getElementById(`class${classIndex + 1}Count`);
    const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
    countElement.textContent = `${files.length} 张图片`;

    // Replace any previous previews.
    previewContainer.innerHTML = '';
    Array.from(files).forEach((file) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    });
  }

  /**
   * Runs MobileNet on `img` and returns its classification plus raw logits.
   * The caller owns the returned `logits` tensor and must dispose it. On
   * failure the tensor is disposed here before the error is rethrown.
   * @param {HTMLImageElement|HTMLVideoElement} img
   * @returns {Promise<{logits: Object, predictions: Array, topTags: Array}>}
   */
  async extractImageNetTags(img) {
    try {
      const predictions = await this.mobilenet.classify(img);
      const logits = this.mobilenet.infer(img, false);
      let topTags;
      try {
        topTags = await this.getTopKTags(logits, 10);
      } catch (error) {
        logits.dispose();
        throw error;
      }
      return { logits, predictions, topTags };
    } catch (error) {
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  /**
   * Runs MobileNet inference on `source` and flattens the logits into the row
   * shape the trees use: { feature_0: v0, feature_1: v1, ... }.
   * The intermediate tensor is always disposed, even if reading it fails.
   * @param {HTMLImageElement|HTMLVideoElement} source
   * @returns {Promise<Object<string, number>>}
   */
  async extractFeatureItem(source) {
    const logits = this.mobilenet.infer(source, false);
    try {
      const values = await logits.data();
      return this.toFeatureItem(values);
    } finally {
      logits.dispose();
    }
  }

  /**
   * Converts an array-like of numbers into { feature_<index>: value }.
   * @param {ArrayLike<number>} values
   * @returns {Object<string, number>}
   */
  toFeatureItem(values) {
    const item = {};
    Array.from(values).forEach((value, index) => {
      item[`feature_${index}`] = value;
    });
    return item;
  }

  /**
   * Returns the `k` highest logits with their labels.
   * `probability` is a softmax over the selected top-k logits only, not over
   * the full distribution. `k` is clamped to the number of available logits.
   * @param {{data: function(): Promise<ArrayLike<number>>}} logits
   * @param {number} [k=10]
   * @returns {Promise<Array<{className: string, probability: number, logit: number}>>}
   */
  async getTopKTags(logits, k = 10) {
    const values = await logits.data();
    const count = Math.max(0, Math.min(k, values.length));
    const ranked = Array.from(values, (value, index) => ({ value, index }))
      .sort((a, b) => b.value - a.value)
      .slice(0, count);
    const topkValues = Float32Array.from(ranked, (entry) => entry.value);
    // Computed once; the original recomputed it for every tag.
    const probabilities = this.softmax(topkValues);
    return ranked.map(({ index }, i) => ({
      className: this.imagenetClasses[index] || `class_${index}`,
      probability: probabilities[i],
      logit: topkValues[i]
    }));
  }

  /**
   * Numerically stable softmax (the max logit is subtracted before exp).
   * Returns the same array kind it receives; an empty input yields an empty result.
   * @param {Array<number>|Float32Array} arr
   */
  softmax(arr) {
    const maxLogit = Math.max(...arr);
    const scores = arr.map((l) => Math.exp(l - maxLogit));
    const sum = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sum);
  }

  /**
   * Extracts features from every uploaded image and trains the forest.
   * Requires at least two classes that each have a name and at least one image.
   * Progress and results are reported through 'dataStatus'.
   */
  async trainModel() {
    const classes = [];
    const imageFiles = [];
    // Collect each class that has both a name and at least one uploaded file.
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;
      if (className && files.length > 0) {
        classes.push(className);
        imageFiles.push(files);
      }
    }
    if (classes.length < 2) {
      this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
      return;
    }
    this.classNames = classes;
    this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...');

    // NOTE(review): this discards any samples added earlier through
    // addWebcamSample(), so webcam captures never reach the forest. Confirm intended.
    this.trainingSet = [];
    let totalProcessed = 0;
    const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);

    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      const files = imageFiles[classIndex];
      for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
        try {
          const img = await this.loadImage(files[fileIndex]);
          const item = await this.extractFeatureItem(img);
          // The label is the class name, not its index.
          item.category = classes[classIndex];
          this.trainingSet.push(item);
          totalProcessed++;
          const progress = Math.round((totalProcessed / totalImages) * 100);
          this.showStatus('dataStatus', 'info',
            `处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
        } catch (error) {
          // Best effort: one unreadable image should not abort the whole run.
          console.error('处理图片失败:', error);
        }
      }
    }

    if (this.trainingSet.length === 0) {
      this.randomForest = [];
      this.showStatus('dataStatus', 'error', '训练失败: 没有成功处理的图片');
      return;
    }

    try {
      // Build into a local array so a failure never leaves a half-built forest.
      const numTrees = this.numTrees;
      const forest = [];
      for (let i = 0; i < numTrees; i++) {
        forest.push(new dt.DecisionTree({
          trainingSet: this.createTrainingSubset(this.trainingSet),
          categoryAttr: 'category'
        }));
      }
      this.randomForest = forest;
      this.showStatus('dataStatus', 'success', `模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别, ${forest.length} 棵树`);
      document.getElementById('totalSamples').textContent = totalProcessed;
    } catch (error) {
      this.randomForest = [];
      console.error('训练失败:', error);
      this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`);
    }
  }

  /**
   * Draws a bootstrap sample (with replacement) for one tree.
   * The size is floor(subsetSize * n), clamped to at least 1 for a non-empty
   * set; it falls back to n if subsetSize is not a finite number.
   * @param {Array<Object>} trainingSet
   * @returns {Array<Object>}
   */
  createTrainingSubset(trainingSet) {
    if (trainingSet.length === 0) {
      return [];
    }
    const requested = Math.floor(this.subsetSize * trainingSet.length);
    const size = Number.isFinite(requested) ? Math.max(1, requested) : trainingSet.length;
    return Array.from({ length: size },
      () => trainingSet[Math.floor(Math.random() * trainingSet.length)]);
  }

  /**
   * Decodes a File into a loaded (but not DOM-attached) HTMLImageElement.
   * @param {File} file
   * @returns {Promise<HTMLImageElement>}
   */
  async loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.onerror = reject;
        img.src = e.target.result;
      };
      reader.onerror = reject;
      reader.readAsDataURL(file);
    });
  }

  /** Discards the model and samples, resets the upload UI, and stops live prediction. */
  clearDataset() {
    if (this.isPredicting) {
      // The forest is being discarded, so live prediction cannot continue.
      this.stopWebcam();
    }
    this.randomForest = [];
    this.classNames = [];
    this.trainingSet = [];
    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
      document.getElementById(`class${i}Preview`).innerHTML = '';
    }
    document.getElementById('totalSamples').textContent = '0';
    document.getElementById('predictions').innerHTML = '等待预测...';
    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  /**
   * Opens the front camera and starts the prediction loop.
   * Does nothing (apart from an error status) when no model is trained.
   */
  async startWebcam() {
    if (this.randomForest.length === 0) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }
    const video = document.getElementById('webcam');
    try {
      // Inside try so a missing navigator.mediaDevices (e.g. insecure origin)
      // is reported rather than thrown.
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
      video.srcObject = stream;
      this.webcamStream = stream;
      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;
      this.isPredicting = true;
      this.predictLoopToken += 1;
      // Intentionally not awaited: the loop reschedules itself until stopped
      // and handles its own errors.
      void this.predictLoop(this.predictLoopToken);
      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  /** Stops all camera tracks, halts prediction, and resets the buttons. */
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach((track) => track.stop());
      this.webcamStream = null;
    }
    this.isPredicting = false;
    // Invalidate any loop iteration that is still awaiting inference.
    this.predictLoopToken += 1;
    const video = document.getElementById('webcam');
    video.srcObject = null;
    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;
    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  /**
   * One prediction step per animation frame; reschedules itself until
   * prediction stops or a newer loop (higher token) replaces it.
   * @param {number} [token] - loop identity issued by startWebcam().
   */
  async predictLoop(token = this.predictLoopToken) {
    if (!this.isPredicting || token !== this.predictLoopToken) {
      return;
    }
    const video = document.getElementById('webcam');
    // readyState 4 is HTMLMediaElement.HAVE_ENOUGH_DATA.
    if (video.readyState === 4 && this.randomForest.length > 0) {
      try {
        const item = await this.extractFeatureItem(video);
        // Prediction may have stopped or restarted while inference was running.
        if (!this.isPredicting || token !== this.predictLoopToken) {
          return;
        }
        const { predictedCategory, probabilities } = this.predictWithRandomForest(item);
        this.displayPrediction(predictedCategory, probabilities);
      } catch (error) {
        console.error('预测错误:', error);
      }
    }
    requestAnimationFrame(() => this.predictLoop(token));
  }

  /**
   * Majority vote across the forest.
   * Ties go to the category seen first (trained class order, then any
   * unexpected labels). Probabilities are vote shares of the trees that
   * actually voted. this.numTrees is a live UI setting and may no longer
   * match the trained forest, so it is not used here.
   * @param {Object<string, number>} item - feature row from extractFeatureItem().
   * @returns {{predictedCategory: (string|null), probabilities: Object<string, number>}}
   */
  predictWithRandomForest(item) {
    const votes = {};
    for (const className of this.classNames) {
      votes[className] = 0;
    }
    for (const tree of this.randomForest) {
      const prediction = tree.predict(item);
      votes[prediction] = (votes[prediction] || 0) + 1;
    }

    let predictedCategory = null;
    let maxVotes = 0;
    for (const category of Object.keys(votes)) {
      if (votes[category] > maxVotes) {
        predictedCategory = category;
        maxVotes = votes[category];
      }
    }

    const totalVotes = this.randomForest.length;
    const probabilities = {};
    for (const category of Object.keys(votes)) {
      probabilities[category] = totalVotes > 0 ? votes[category] / totalVotes : 0;
    }
    return { predictedCategory, probabilities };
  }

  /**
   * Adds one webcam frame to the training set for class `classIndex` (0-based).
   * If no stream is open, a temporary one is opened for about 1 s and then
   * closed. Only that temporary stream is closed, never one opened meanwhile
   * by startWebcam().
   * @param {number} classIndex
   */
  async captureFromWebcam(classIndex) {
    if (this.webcamStream) {
      await this.addWebcamSample(classIndex);
      return;
    }
    const video = document.getElementById('webcam');
    let stream;
    try {
      stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
    } catch (error) {
      this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
      return;
    }
    video.srcObject = stream;
    this.webcamStream = stream;
    // Give the video element time to receive frames before sampling.
    await new Promise((resolve) => setTimeout(resolve, 1000));
    try {
      await this.addWebcamSample(classIndex);
    } finally {
      stream.getTracks().forEach((track) => track.stop());
      if (this.webcamStream === stream) {
        this.webcamStream = null;
        video.srcObject = null;
      }
    }
  }

  /**
   * Extracts features from the current webcam frame and appends a training
   * row labelled with the name typed for class `classIndex` (0-based).
   * @param {number} classIndex
   */
  async addWebcamSample(classIndex) {
    const video = document.getElementById('webcam');
    // readyState 4 is HTMLMediaElement.HAVE_ENOUGH_DATA.
    if (video.readyState !== 4) {
      this.showStatus('dataStatus', 'error', '摄像头画面尚未就绪,请重试');
      return;
    }
    try {
      const item = await this.extractFeatureItem(video);
      const className = document.getElementById(`class${classIndex + 1}Name`).value.trim();
      item.category = className;
      this.trainingSet.push(item);
      this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`);
    } catch (error) {
      console.error('添加样本失败:', error);
      this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
    }
  }

  /**
   * Renders the predicted class and per-class vote shares.
   * Class names come from user-editable inputs, so they are HTML-escaped
   * before being inserted through innerHTML.
   * @param {string|null} category
   * @param {Object<string, number>} probabilities
   */
  displayPrediction(category, probabilities) {
    const container = document.getElementById('predictions');
    let html = `预测类别:${this.escapeHtml(category)}\n`;
    for (const className of Object.keys(probabilities)) {
      const probability = (probabilities[className] * 100).toFixed(2);
      html += `${this.escapeHtml(className)}: ${probability}%\n`;
    }
    container.innerHTML = html;
  }

  /**
   * Escapes the five HTML-significant characters of String(text).
   * @param {*} text
   * @returns {string}
   */
  escapeHtml(text) {
    return String(text)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');
  }

  /**
   * Shows `message` in element `elementId`, styled by `type`
   * ('success' | 'error' | 'info').
   */
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    const classMap = {
      'success': 'status-success',
      'error': 'status-error',
      'info': 'status-info'
    };
    element.className = `status-message ${classMap[type]}`;
    element.textContent = message;
  }
}
// Global helper: capture one webcam sample into class `classIndex` (0-based).
// Presumably called from inline handlers in the page; confirm against the HTML.
// Returns the capture promise, or undefined if the classifier is not created yet.
// Failures are logged instead of becoming unhandled promise rejections.
function captureFromWebcam(classIndex) {
  if (!window.classifier) {
    return undefined;
  }
  return window.classifier.captureFromWebcam(classIndex).catch((error) => {
    console.error('摄像头采集失败:', error);
  });
}
// Create the app instance once the DOM is ready. It is stored on `window` so
// the global captureFromWebcam() helper can reach it. The callback awaits
// nothing, so it is a plain (non-async) function.
document.addEventListener('DOMContentLoaded', () => {
  window.classifier = new ImageClassifier();
});