mobileNet/废弃/随机森林/rf-classifier.js

473 lines
16 KiB
JavaScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Image classifier: MobileNet logits are used as features for a random forest
// assembled from decision-tree.js trees (dt.DecisionTree).
class ImageClassifier {
  constructor() {
    this.mobilenet = null;
    this.randomForest = []; // the trained decision trees
    this.classNames = [];
    this.webcamStream = null;
    this.isPredicting = false;
    this.isTraining = false; // guards against overlapping trainModel() runs
    this.predictLoopId = 0; // bumped on every start/stop so a stale predictLoop exits
    this.imagenetClasses = null;
    this.trainingSet = [];
    // Feature items captured from the webcam, one array per class slot (1-3).
    // Kept apart from trainingSet because trainModel() rebuilds that from the uploaded files.
    this.webcamSamples = [[], [], []];
    this.numTrees = 10; // number of trees in the forest (adjustable via #numTrees)
    this.subsetSize = 0.7; // bootstrap sample size as a fraction of the training set (adjustable via #subsetSize)
    // init() handles its own errors and reports them via showStatus.
    this.init();
  }

  // Loads MobileNet v2 and the class-name list, then wires up the UI.
  async init() {
    this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
    try {
      this.mobilenet = await mobilenet.load({
        version: 2,
        alpha: 1.0
      });
      await this.loadImageNetClasses();
      this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
      this.setupEventListeners();
    } catch (error) {
      this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
    }
  }

  // Short list of the first 20 ImageNet class names; other indices fall back to
  // `class_<i>` in getTopKTags().
  async loadImageNetClasses() {
    this.imagenetClasses = [
      'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
      'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
      'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting',
      'robin', 'bulbul', 'jay', 'magpie', 'chickadee'
    ];
  }

  // Attaches listeners to the upload inputs, action buttons and parameter inputs.
  setupEventListeners() {
    // File upload inputs, one per class slot
    ['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
      document.getElementById(id).addEventListener('change', (e) => {
        this.handleImageUpload(e, index);
      });
    });
    // Buttons
    document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel());
    document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
    document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
    document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
    // Parameters: ignore empty/invalid input instead of storing NaN or a non-positive value
    document.getElementById('numTrees').addEventListener('input', (e) => {
      const value = Number.parseInt(e.target.value, 10);
      if (Number.isInteger(value) && value >= 1) {
        this.numTrees = value;
      }
    });
    document.getElementById('subsetSize').addEventListener('input', (e) => {
      const value = Number.parseFloat(e.target.value);
      if (Number.isFinite(value) && value > 0) {
        this.subsetSize = value;
      }
    });
  }

  // Updates the file count and thumbnail previews for one class slot (0-based classIndex).
  handleImageUpload(event, classIndex) {
    const files = event.target.files;
    const countElement = document.getElementById(`class${classIndex + 1}Count`);
    const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
    countElement.textContent = `${files.length} 张图片`;
    // Replace the previous previews
    previewContainer.innerHTML = '';
    for (const file of Array.from(files)) {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = document.createElement('img');
        img.src = e.target.result;
        img.className = 'preview-img';
        previewContainer.appendChild(img);
      };
      reader.readAsDataURL(file);
    }
  }

  /**
   * Runs MobileNet on an image/video element.
   * @returns {Promise<{logits, predictions, topTags}>} `logits` is the tensor from
   *   mobilenet.infer(img, false) and must be disposed by the caller; `predictions`
   *   comes from mobilenet.classify(); `topTags` are the 10 highest logits.
   */
  async extractImageNetTags(img) {
    let logits = null;
    try {
      const predictions = await this.mobilenet.classify(img);
      logits = this.mobilenet.infer(img, false);
      const topTags = await this.getTopKTags(logits, 10);
      return { logits, predictions, topTags };
    } catch (error) {
      // Ownership only passes to the caller on success; don't leak the tensor otherwise
      if (logits) {
        logits.dispose();
      }
      console.error('特征提取失败:', error);
      throw error;
    }
  }

  // Single forward pass: returns the logits of `input` as a plain {feature_i: value}
  // object, the attribute format handed to decision-tree.js. The tensor is always disposed.
  async extractFeatureItem(input) {
    const logits = this.mobilenet.infer(input, false);
    try {
      return this.toFeatureItem(await logits.data());
    } finally {
      logits.dispose();
    }
  }

  // Converts an array-like of numbers into {feature_0: v0, feature_1: v1, ...}.
  toFeatureItem(featureVector) {
    const item = {};
    for (let i = 0; i < featureVector.length; i++) {
      item[`feature_${i}`] = featureVector[i];
    }
    return item;
  }

  /**
   * Returns the k largest logits as {className, probability, logit}, highest first.
   * `probability` is a softmax over those k logits only (renormalised), not the
   * full distribution. k is clamped to the number of logits available.
   */
  async getTopKTags(logits, k = 10) {
    const values = await logits.data();
    const count = Math.max(0, Math.min(k, values.length));
    const ranked = Array.from(values, (value, index) => ({ value, index }))
      .sort((a, b) => b.value - a.value)
      .slice(0, count);
    // Computed once rather than once per tag
    const probabilities = this.softmax(ranked.map((entry) => entry.value));
    return ranked.map(({ value, index }, i) => ({
      className: (this.imagenetClasses && this.imagenetClasses[index]) || `class_${index}`,
      probability: probabilities[i],
      logit: value
    }));
  }

  // Numerically stable softmax (subtracts the max before exponentiating). Empty input -> empty output.
  softmax(arr) {
    const maxLogit = arr.reduce((max, x) => Math.max(max, x), -Infinity);
    const scores = arr.map((l) => Math.exp(l - maxLogit));
    const sum = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sum);
  }

  /**
   * Extracts MobileNet features for every uploaded image (plus any webcam samples
   * captured for that slot) and builds a new random forest. Bound to #addDataBtn.
   */
  async trainModel() {
    // A second click while features are still being extracted would interleave two runs
    if (this.isTraining) return;

    const classes = [];
    const imageFiles = [];
    const capturedItems = [];
    // Collect every class slot that has a name and at least one sample
    for (let i = 1; i <= 3; i++) {
      const className = document.getElementById(`class${i}Name`).value.trim();
      const files = document.getElementById(`class${i}Images`).files;
      const captured = this.webcamSamples[i - 1];
      if (className && (files.length > 0 || captured.length > 0)) {
        classes.push(className);
        imageFiles.push(files);
        capturedItems.push(captured);
      }
    }
    // Identical names collapse into one category, so count distinct names
    if (new Set(classes).size < 2) {
      this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
      return;
    }

    this.isTraining = true;
    try {
      this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...');
      this.trainingSet = await this.buildTrainingSet(classes, imageFiles, capturedItems);
      if (this.trainingSet.length === 0) {
        this.showStatus('dataStatus', 'error', '没有可用的训练样本!');
        return;
      }
      this.buildForest(classes);
    } finally {
      this.isTraining = false;
    }
  }

  // Runs MobileNet over each uploaded file and appends the stored webcam samples,
  // labelling every item with its class name. Images that fail are logged and skipped.
  async buildTrainingSet(classes, imageFiles, capturedItems) {
    const trainingSet = [];
    const totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
    let processed = 0;
    for (let classIndex = 0; classIndex < classes.length; classIndex++) {
      const category = classes[classIndex]; // the class name, not its index
      const files = imageFiles[classIndex];
      for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
        try {
          const img = await this.loadImage(files[fileIndex]);
          const item = await this.extractFeatureItem(img);
          item.category = category;
          trainingSet.push(item);
          processed++;
          const progress = Math.round((processed / totalImages) * 100);
          this.showStatus('dataStatus', 'info',
            `处理中... ${processed}/${totalImages} (${progress}%)`);
        } catch (error) {
          console.error('处理图片失败:', error);
        }
      }
      // Webcam samples are stored unlabelled so the class's current name is applied here
      capturedItems[classIndex].forEach((features) => {
        trainingSet.push(Object.assign({}, features, { category }));
      });
    }
    return trainingSet;
  }

  // Builds this.numTrees trees on bootstrap samples of this.trainingSet and installs
  // them (together with `classes`) as the active model, reporting the outcome in #dataStatus.
  buildForest(classes) {
    try {
      const forest = [];
      for (let i = 0; i < this.numTrees; i++) {
        forest.push(new dt.DecisionTree({
          trainingSet: this.createTrainingSubset(this.trainingSet),
          categoryAttr: 'category'
        }));
      }
      // Swap in the finished model at once so live prediction never sees a half-built forest
      this.randomForest = forest;
      this.classNames = classes;
      this.showStatus('dataStatus', 'success', `模型训练完成!共 ${this.trainingSet.length} 个样本,${classes.length} 个类别, ${forest.length} 棵树`);
      document.getElementById('totalSamples').textContent = this.trainingSet.length;
    } catch (error) {
      console.error('训练失败:', error);
      this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`);
    }
  }

  // Bootstrap sample (drawn with replacement) of floor(subsetSize * n) items, at least one;
  // an empty training set yields an empty subset.
  createTrainingSubset(trainingSet) {
    if (trainingSet.length === 0) return [];
    const scaled = Math.floor(this.subsetSize * trainingSet.length);
    const size = Number.isFinite(scaled) && scaled > 0 ? scaled : 1;
    const subset = [];
    for (let i = 0; i < size; i++) {
      const randomIndex = Math.floor(Math.random() * trainingSet.length);
      subset.push(trainingSet[randomIndex]);
    }
    return subset;
  }

  // Reads a File into a decoded (detached) HTMLImageElement.
  async loadImage(file) {
    return new Promise((resolve, reject) => {
      const reader = new FileReader();
      reader.onload = (e) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.onerror = reject;
        img.src = e.target.result;
      };
      reader.onerror = reject;
      reader.readAsDataURL(file);
    });
  }

  // Discards the model, all samples (uploaded and webcam) and resets the UI.
  clearDataset() {
    // The prediction loop would otherwise keep running against an empty model
    if (this.isPredicting) {
      this.stopWebcam();
    }
    this.randomForest = [];
    this.classNames = [];
    this.trainingSet = [];
    this.webcamSamples = [[], [], []];
    for (let i = 1; i <= 3; i++) {
      document.getElementById(`class${i}Images`).value = '';
      document.getElementById(`class${i}Count`).textContent = '0 张图片';
      document.getElementById(`class${i}Preview`).innerHTML = '';
    }
    document.getElementById('totalSamples').textContent = '0';
    document.getElementById('predictions').textContent = '等待预测...';
    this.showStatus('dataStatus', 'info', '数据集已清空');
  }

  // Opens the user-facing camera and starts the live prediction loop.
  async startWebcam() {
    if (this.randomForest.length === 0) {
      this.showStatus('predictionStatus', 'error', '请先训练模型!');
      return;
    }
    const video = document.getElementById('webcam');
    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
      video.srcObject = stream;
      this.webcamStream = stream;
      document.getElementById('startWebcamBtn').disabled = true;
      document.getElementById('stopWebcamBtn').disabled = false;
      this.isPredicting = true;
      this.predictLoopId++;
      // Not awaited: the loop reschedules itself through requestAnimationFrame
      this.predictLoop(this.predictLoopId);
      this.showStatus('predictionStatus', 'success', '摄像头已启动');
    } catch (error) {
      this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
    }
  }

  // Stops the camera tracks and ends the prediction loop.
  stopWebcam() {
    if (this.webcamStream) {
      this.webcamStream.getTracks().forEach((track) => track.stop());
      this.webcamStream = null;
    }
    this.isPredicting = false;
    this.predictLoopId++;
    const video = document.getElementById('webcam');
    video.srcObject = null;
    document.getElementById('startWebcamBtn').disabled = false;
    document.getElementById('stopWebcamBtn').disabled = true;
    this.showStatus('predictionStatus', 'info', '摄像头已停止');
  }

  // Classifies the current webcam frame, then schedules itself for the next animation frame.
  async predictLoop(loopId = this.predictLoopId) {
    // Exit when prediction was stopped or a newer loop superseded this one
    // (a stop + restart while a frame is in flight would otherwise leave two loops running)
    if (!this.isPredicting || loopId !== this.predictLoopId) return;
    const video = document.getElementById('webcam');
    // readyState 4 === HAVE_ENOUGH_DATA
    if (video.readyState === 4 && this.randomForest.length > 0) {
      try {
        const item = await this.extractFeatureItem(video);
        const { predictedCategory, probabilities } = this.predictWithRandomForest(item);
        this.displayPrediction(predictedCategory, probabilities);
      } catch (error) {
        console.error('预测错误:', error);
      }
    }
    requestAnimationFrame(() => this.predictLoop(loopId));
  }

  /**
   * Majority vote over the forest.
   * @returns {{predictedCategory: string|null, probabilities: Object<string, number>}}
   *   probabilities are vote shares over the trees actually in the forest (0 when empty).
   *   Ties go to the first category reaching the maximum.
   */
  predictWithRandomForest(item) {
    const votes = {};
    this.classNames.forEach((className) => {
      votes[className] = 0;
    });
    for (const tree of this.randomForest) {
      const prediction = tree.predict(item);
      votes[prediction] = (votes[prediction] || 0) + 1;
    }
    let predictedCategory = null;
    let maxVotes = 0;
    for (const category in votes) {
      if (votes[category] > maxVotes) {
        predictedCategory = category;
        maxVotes = votes[category];
      }
    }
    // Normalise by the real forest size: this.numTrees may have been changed since training
    const treeCount = this.randomForest.length;
    const probabilities = {};
    for (const category in votes) {
      probabilities[category] = treeCount > 0 ? votes[category] / treeCount : 0;
    }
    return {
      predictedCategory: predictedCategory,
      probabilities: probabilities
    };
  }

  // Captures one webcam sample for a class slot (0-based). Opens a temporary camera
  // stream when none is running and closes it again afterwards.
  async captureFromWebcam(classIndex) {
    if (this.webcamStream) {
      await this.addWebcamSample(classIndex);
      return;
    }
    const video = document.getElementById('webcam');
    let stream;
    try {
      stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'user' },
        audio: false
      });
    } catch (error) {
      this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
      return;
    }
    video.srcObject = stream;
    this.webcamStream = stream;
    try {
      // Give the video element time to receive frames before sampling
      await new Promise((resolve) => setTimeout(resolve, 1000));
      await this.addWebcamSample(classIndex);
    } finally {
      // Always release the temporary camera ...
      stream.getTracks().forEach((track) => track.stop());
      // ... but only reset shared state if startWebcam() has not replaced the stream meanwhile
      if (this.webcamStream === stream) {
        this.webcamStream = null;
        video.srcObject = null;
      }
    }
  }

  // Extracts features from the current webcam frame and stores them for the given
  // class slot (0-based); they are labelled and used on the next trainModel() run.
  async addWebcamSample(classIndex) {
    const video = document.getElementById('webcam');
    // readyState 4 === HAVE_ENOUGH_DATA; tell the user instead of silently doing nothing
    if (video.readyState !== 4) {
      this.showStatus('dataStatus', 'error', '摄像头尚未就绪,请稍后重试');
      return;
    }
    try {
      const item = await this.extractFeatureItem(video);
      const className = document.getElementById(`class${classIndex + 1}Name`).value.trim();
      this.webcamSamples[classIndex].push(item);
      this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`);
    } catch (error) {
      console.error('添加样本失败:', error);
      this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
    }
  }

  // Renders the predicted category and per-class vote shares into #predictions.
  displayPrediction(category, probabilities) {
    const container = document.getElementById('predictions');
    // Built from text nodes: class names come from user-editable inputs and must not be parsed as HTML
    container.textContent = '';
    const addLine = (text) => {
      container.appendChild(document.createTextNode(text));
      container.appendChild(document.createElement('br'));
    };
    addLine(`预测类别:${category}`);
    for (const className in probabilities) {
      const probability = (probabilities[className] * 100).toFixed(2);
      addLine(`${className}: ${probability}%`);
    }
  }

  // Shows `message` in the element `elementId`, styled by type: 'success' | 'error' | 'info'.
  showStatus(elementId, type, message) {
    const element = document.getElementById(elementId);
    const classMap = {
      'success': 'status-success',
      'error': 'status-error',
      'info': 'status-info'
    };
    element.className = `status-message ${classMap[type]}`;
    element.textContent = message;
  }
}
// Global entry point (e.g. for inline handlers): capture one webcam sample for the
// given class slot. Does nothing until the classifier instance exists.
function captureFromWebcam(classIndex) {
  const classifier = window.classifier;
  if (!classifier) {
    return;
  }
  classifier.captureFromWebcam(classIndex);
}
// Create the single classifier instance once the DOM is ready; it is exposed on
// `window` so the global captureFromWebcam() helper can reach it.
// The listener awaits nothing, so it is a plain (non-async) function.
document.addEventListener('DOMContentLoaded', () => {
  window.classifier = new ImageClassifier();
});