/*
 * Source-viewer header captured with this file (not part of the script):
 * 525 lines · 22 KiB · JavaScript · Raw | Permalink | Blame | History
 *
 * Viewer notice: this file contains ambiguous Unicode characters, i.e. characters that
 * might be confused with other characters. If that is intentional, the warning can be
 * ignored; the viewer's Escape button reveals them.
 */

// script.js
// Webcam image classifier: MobileNet embeddings classified by a KNN model (TensorFlow.js).

/** Shorthand for document.getElementById, used only for the one-time DOM lookups below. */
const byId = (id) => document.getElementById(id);

// --- DOM handles (looked up once, when the script first runs) ---
const VIDEO = byId('webcam');
const LOAD_MODEL_BTN = byId('loadModelBtn');
const START_WEBCAM_BTN = byId('startWebcamBtn');
const STOP_WEBCAM_BTN = byId('stopWebcamBtn');
const MODEL_STATUS = byId('modelStatus');
const PREDICTION_OUTPUT = byId('prediction');
const WEBCAM_STATUS_DISPLAY = byId('webcam-status-display');

// --- Application state ---
let mobilenet; // MobileNet model instance, assigned by initModel()
let knnClassifier; // KNN classifier instance, assigned by initModel()
let classNames = []; // display name per class label, filled when a KNN model is loaded
let webcamStream = null; // active MediaStream while the webcam is running, otherwise null
let isPredicting = false; // predictLoop() keeps rescheduling itself only while this is true

// The former Web Serial output (port/writer handles, 9600 baud, send throttling and the
// command-confirmation delay) has been removed; see version control for the old implementation.
// ===================================
// Helper Functions (UI Status)
// ===================================
/**
 * Writes a message into a status element and tags it with the matching status CSS class.
 * @param {HTMLElement} element - Status element to update.
 * @param {string} type - Status kind used as the `status-<type>` class suffix (e.g. 'info', 'success', 'error').
 * @param {string} message - Text to show; assigned as textContent, so it is never parsed as HTML.
 */
function showStatus(element, type, message) {
  const classes = ['status-message', `status-${type}`];
  element.textContent = message;
  element.className = classes.join(' ');
}
// REMOVED: updateSerialUI function
/*
function updateSerialUI(isConnected) {
CONNECT_SERIAL_BTN.disabled = isConnected;
DISCONNECT_SERIAL_BTN.disabled = !isConnected;
isSerialConnectedState = isConnected;
if (!isConnected) {
showStatus(SERIAL_STATUS, 'info', '串口未连接。点击 "连接串口" 开始。');
}
}
*/
/**
 * Syncs the webcam buttons, the webcam status line and the prediction panel with the webcam state.
 * @param {boolean} isRunning - Whether the webcam stream is currently active.
 */
function updateWebcamUI(isRunning) {
  STOP_WEBCAM_BTN.disabled = !isRunning;
  START_WEBCAM_BTN.disabled = isRunning;

  if (!isRunning) {
    showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头未启动');
    PREDICTION_OUTPUT.classList.add('idle');
    PREDICTION_OUTPUT.textContent = '等待识别...';
    return;
  }

  showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头已启动,等待模型预测...');
  PREDICTION_OUTPUT.classList.remove('idle', 'error');
}
/**
 * Updates the controls for the current model state.
 * The load button always stays enabled so another model can be loaded at any time;
 * the webcam can only be started once a trained model is loaded.
 * @param {boolean} isLoaded - Whether a trained KNN model is loaded.
 */
function updateModelUI(isLoaded) {
  START_WEBCAM_BTN.disabled = !isLoaded;
  LOAD_MODEL_BTN.disabled = false;
}
// ===================================
// Core Logic: Model & Webcam
// ===================================
/**
 * Initialises the ML stack: loads MobileNet (version 2, alpha 1.0) and creates an empty KNN classifier.
 * The trained KNN dataset is NOT loaded here; the user loads it with the "load model" button.
 * Every failure is shown in the status elements and logged, so the returned promise never rejects.
 */
async function initModel() {
  showStatus(MODEL_STATUS, 'info', '正在加载 MobileNet 模型...');
  showStatus(WEBCAM_STATUS_DISPLAY, 'info', '系统初始化中...');

  // The libraries are expected as page globals (window.tf / window.mobilenet / window.knnClassifier).
  if (!window.tf || !window.mobilenet || !window.knnClassifier) {
    showStatus(MODEL_STATUS, 'error', 'TensorFlow.js 核心库或模型库未加载。请检查 HTML 引入。');
    // Replace the "initialising" message so the webcam panel does not look stuck.
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('TensorFlow.js 核心库或模型库未加载。');
    return;
  }

  try {
    mobilenet = await window.mobilenet.load({ version: 2, alpha: 1.0 });
    knnClassifier = window.knnClassifier.create();
    showStatus(MODEL_STATUS, 'success', 'MobileNet 模型和 KNN 分类器已加载。');
    // The classifier is still empty, so the webcam stays disabled until a KNN model is loaded.
    updateModelUI(false);
    // Otherwise "系统初始化中..." would stay visible until the webcam is started.
    showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头未启动');

    // Automatic loading of the pre-trained KNN model from the CDN is disabled. To re-enable it, call
    //   loadKNNModel(
    //     'https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/models/knn-model-juzi.json',
    //     'https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/models/knn-model-juzi.bin')
    // here and catch its rejection (it rethrows after reporting the error).
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `模型加载失败: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('MobileNet/KNN加载失败:', error);
  }
}
/**
 * Computes the MobileNet embedding of an image source and scales it to unit L2 norm.
 * The norm is taken over the whole embedding tensor returned by mobilenet.infer().
 * @param {*} img - Any input accepted by mobilenet.infer() (the webcam <video> element in this file).
 * @returns {Promise<tf.Tensor>} The normalised embedding; the caller is responsible for disposing it.
 * @throws {Error} "MobileNet model is not loaded." (as a rejected promise) if initModel() has not succeeded.
 */
async function getFeatures(img) {
  if (mobilenet) {
    // tf.tidy releases the intermediate tensors and keeps only the returned one alive.
    return tf.tidy(() => {
      const embedding = mobilenet.infer(img, true);
      return tf.div(embedding, tf.norm(embedding));
    });
  }
  throw new Error("MobileNet model is not loaded.");
}
/**
 * Loads a KNN model stored in the legacy single-file JSON format into `knnClassifier`.
 *
 * Fields read: `dataset` maps each class label to a flat numeric array of N samples x `featureDim`
 * values (featureDim defaults to 1280). Class names come from `classList[].name`, else `classNames`,
 * else default to "Class <label + 1>".
 *
 * Every class is validated before the classifier is cleared, so a malformed file leaves the
 * currently loaded model untouched. Samples are added one [1, featureDim] row at a time, the same way
 * loadKNNModel() adds samples for the split .json/.bin format: KNNClassifier.addExample() treats its
 * whole input tensor as ONE example, so an [N, featureDim] tensor would become a single
 * N*featureDim-long example instead of N samples.
 *
 * @param {object} modelData - Parsed model JSON.
 * @returns {Promise<void>} Resolves once the model is loaded and the UI is updated.
 * @throws {Error} On missing or malformed data; the error is shown in the status line, logged, and rethrown.
 */
async function loadSingleJsonModel(modelData) {
  try {
    if (!modelData || typeof modelData.dataset !== 'object' || modelData.dataset === null) {
      throw new Error('模型JSON中缺少 dataset 字段。');
    }
    const featureDim = modelData.featureDim || 1280;

    // Validate every class before clearing the classifier.
    const classes = Object.entries(modelData.dataset).map(([key, data]) => {
      if (data.length % featureDim !== 0) {
        throw new Error(`类别 ${key} 的特征数据长度 ${data.length} 与特征维度 ${featureDim} 不匹配!`);
      }
      return { label: Number.parseInt(key, 10), data };
    });

    knnClassifier.clearAllClasses();
    for (const { label, data } of classes) {
      for (let offset = 0; offset < data.length; offset += featureDim) {
        const sampleTensor = tf.tensor(data.slice(offset, offset + featureDim), [1, featureDim]);
        knnClassifier.addExample(sampleTensor, label);
        tf.dispose(sampleTensor);
      }
    }

    if (Array.isArray(modelData.classList)) {
      classNames = modelData.classList.map((c) => c.name);
    } else if (Array.isArray(modelData.classNames)) {
      classNames = modelData.classNames;
    } else {
      console.warn('模型JSON中未找到 classList/classNames 字段,使用默认类别名称。');
      classNames = Object.keys(modelData.dataset).map((key) => `Class ${Number.parseInt(key, 10) + 1}`);
    }
    showStatus(MODEL_STATUS, 'success', `模型 (单文件JSON格式) 加载成功!类别: ${classNames.join(', ')}`);
    updateModelUI(true);
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `加载单文件JSON模型失败: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('加载单文件JSON模型失败:', error);
    updateModelUI(false);
    throw error;
  }
}
/**
 * Loads a trained KNN model into `knnClassifier`.
 *
 * With both `jsonUrl` and `binUrl` the model is fetched from those URLs; otherwise the user picks the
 * files. Two formats are handled:
 *   - split format: a .json whose `dataFile` names a .bin of raw Float32 values, with
 *     `dataset[label] = { start, length }` (counted in Float32 elements) pointing into that .bin;
 *   - legacy single-file JSON (no `dataFile`), delegated to loadSingleJsonModel().
 *
 * Progress and errors are shown in MODEL_STATUS. Dismissing the file picker resolves quietly.
 *
 * @param {string|null} [jsonUrl=null] - URL of the model .json (remote mode only).
 * @param {string|null} [binUrl=null] - URL of the matching .bin (remote mode only).
 * @returns {Promise<void>}
 * @throws {Error} Any load/validation error, rethrown after it has been shown in the UI and logged.
 */
async function loadKNNModel(jsonUrl = null, binUrl = null) {
  if (!knnClassifier) {
    showStatus(MODEL_STATUS, 'error', 'KNN 分类器未初始化。请先加载 MobileNet 模型。');
    return;
  }
  try {
    const source = jsonUrl && binUrl
      ? await fetchSplitKnnModel(jsonUrl, binUrl)
      : await pickKnnModelFromFiles();
    // null: the picker was dismissed, or a legacy single-file model has already been applied.
    if (source) {
      applySplitKnnModel(source);
    }
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `加载 KNN 模型失败: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('加载 KNN 模型总失败:', error);
    updateModelUI(false);
    throw error;
  }
}

/**
 * Downloads a split-format model (.json + .bin) from the given URLs.
 * @returns {Promise<{modelData: object, binData: Float32Array, modelName: string}>}
 */
async function fetchSplitKnnModel(jsonUrl, binUrl) {
  showStatus(MODEL_STATUS, 'info', `正在从 CDN 加载模型配置文件 (${jsonUrl})...`);
  const jsonResponse = await fetch(jsonUrl);
  if (!jsonResponse.ok) {
    throw new Error(`无法从 ${jsonUrl} 加载.json文件: ${jsonResponse.statusText}`);
  }
  const modelData = await jsonResponse.json();

  showStatus(MODEL_STATUS, 'info', `正在从 CDN 加载模型权重 (${binUrl})...`);
  const binResponse = await fetch(binUrl);
  if (!binResponse.ok) {
    throw new Error(`无法从 ${binUrl} 加载.bin文件: ${binResponse.statusText}`);
  }
  const binData = new Float32Array(await binResponse.arrayBuffer());

  if (modelData.dataFile && !binUrl.endsWith(modelData.dataFile)) {
    console.warn(`CDN 加载警告:.bin URL (${binUrl}) 与 .json 中定义的 dataFile (${modelData.dataFile}) 不匹配。继续加载。`);
  }
  return { modelData, binData, modelName: jsonUrl.split('/').pop() };
}

/**
 * Opens a file picker that accepts several .json/.bin files at once.
 * Resolves with the chosen files, or with an empty array when the dialog is dismissed.
 */
function chooseKnnModelFiles() {
  return new Promise((resolve) => {
    const input = document.createElement('input');
    input.type = 'file';
    input.accept = '.json,.bin';
    // The .json and its .bin have to be selected together.
    input.multiple = true;
    input.addEventListener('change', () => resolve(Array.from(input.files || [])));
    // Without handling `cancel`, dismissing the dialog would leave this promise pending forever.
    input.addEventListener('cancel', () => resolve([]));
    input.click();
  });
}

/**
 * Lets the user pick the model files and reads them.
 * A legacy single-file JSON model is applied immediately through loadSingleJsonModel().
 * @returns {Promise<{modelData: object, binData: Float32Array, modelName: string}|null>}
 *   The split-format model to apply, or null when there is nothing left to do.
 */
async function pickKnnModelFromFiles() {
  showStatus(MODEL_STATUS, 'info', '请选择 KNN 模型配置文件 (.json) 以及其对应的权重文件 (.bin)...');
  const files = await chooseKnnModelFiles();
  if (files.length === 0) {
    showStatus(MODEL_STATUS, 'info', '未选择文件。');
    return null;
  }

  // If several files share an extension, the last one in the selection wins.
  let jsonFile = null;
  let binFile = null;
  for (const file of files) {
    if (file.name.endsWith('.json')) {
      jsonFile = file;
    } else if (file.name.endsWith('.bin')) {
      binFile = file;
    }
  }
  if (!jsonFile) {
    throw new Error('请选择一个 .json 模型配置文件。');
  }

  showStatus(MODEL_STATUS, 'info', `正在解析 ${jsonFile.name}...`);
  let modelData;
  try {
    modelData = JSON.parse(await jsonFile.text());
  } catch (error) {
    throw new Error(`解析 .json 文件失败: ${error.message}`);
  }
  if (modelData === null || typeof modelData !== 'object') {
    throw new Error('解析 .json 文件失败: 内容不是 JSON 对象。');
  }

  if (!modelData.dataFile) {
    console.warn('模型JSON文件不包含 "dataFile" 字段尝试以旧的单文件JSON格式加载。');
    // loadSingleJsonModel() reports success/failure in the UI itself.
    await loadSingleJsonModel(modelData);
    return null;
  }

  if (!binFile) {
    throw new Error(`请同时选择与 ${jsonFile.name} 对应的 .bin 权重文件 (${modelData.dataFile})。`);
  }
  if (binFile.name !== modelData.dataFile) {
    throw new Error(`选择的 .bin 文件名 "${binFile.name}" 与 .json 中定义的 "${modelData.dataFile}" 不匹配!请选择正确的文件。`);
  }

  showStatus(MODEL_STATUS, 'info', `正在读取 ${binFile.name} (二进制权重文件)...`);
  let binData;
  try {
    binData = new Float32Array(await binFile.arrayBuffer());
  } catch (error) {
    throw new Error(`读取 .bin 文件失败: ${error.message}`);
  }
  return { modelData, binData, modelName: jsonFile.name };
}

/**
 * Validates a split-format model and loads it into the classifier, replacing the current model.
 * All classes are checked before clearAllClasses(), so a bad file leaves the current model intact.
 * Each sample is added as its own [1, featureDim] example.
 */
function applySplitKnnModel({ modelData, binData, modelName }) {
  if (typeof modelData.dataset !== 'object' || modelData.dataset === null) {
    throw new Error('模型JSON中缺少 dataset 字段。');
  }
  const featureDim = modelData.featureDim || 1280;

  const classes = [];
  for (const [label, classDataMeta] of Object.entries(modelData.dataset)) {
    const startIndex = classDataMeta.start;
    const elementCount = classDataMeta.length;
    const inRange = Number.isInteger(startIndex) && Number.isInteger(elementCount)
      && startIndex >= 0 && elementCount >= 0
      && startIndex + elementCount <= binData.length;
    if (!inRange) {
      throw new Error(`模型数据错误: 类别 ${label} 的数据超出 .bin 文件范围。`);
    }

    const classFeatures = binData.subarray(startIndex, startIndex + elementCount);
    if (classFeatures.length === 0) {
      console.warn(`类别 ${label} 没有找到特征数据,跳过。`);
      continue;
    }
    if (classFeatures.length % featureDim !== 0) {
      console.error(
        `--- 类别: ${label} ---`,
        `起始 Float32 元素索引: ${startIndex}`,
        `该类别 Float32 元素数量: ${elementCount}`,
        `ERROR: 特征数据长度 (${classFeatures.length} 个 Float32 元素) 与特征维度 (${featureDim}) 不匹配!` +
        `实际样本数计算为 ${classFeatures.length / featureDim} (预期为整数)。`,
        `请检查您的模型导出逻辑和训练数据的完整性。`
      );
      throw new Error("模型数据完整性错误:特征数据长度与维度不匹配。");
    }
    classes.push({ label: Number.parseInt(label, 10), classFeatures });
  }

  knnClassifier.clearAllClasses();
  for (const { label, classFeatures } of classes) {
    for (let offset = 0; offset < classFeatures.length; offset += featureDim) {
      const sampleTensor = tf.tensor(classFeatures.subarray(offset, offset + featureDim), [1, featureDim]);
      knnClassifier.addExample(sampleTensor, label);
      tf.dispose(sampleTensor);
    }
  }

  if (Array.isArray(modelData.classList)) {
    classNames = modelData.classList.map((c) => c.name);
  } else {
    console.warn('模型JSON中未找到 classList 字段或格式不正确,使用默认类别名称。');
    classNames = Object.keys(modelData.dataset).map((key) => `Class ${Number.parseInt(key, 10) + 1}`);
  }
  showStatus(MODEL_STATUS, 'success', `KNN 模型 "${modelName}" 加载成功!类别: ${classNames.join(', ')}`);
  updateModelUI(true);
}
/**
 * Opens the front-facing webcam and starts the prediction loop once the video has frame data.
 * Requires a loaded KNN model with at least one class. Errors (e.g. permission denied) are shown in
 * the UI and logged; the returned promise never rejects.
 */
async function startWebcam() {
  if (webcamStream) return;
  if (!knnClassifier || knnClassifier.getNumClasses() === 0) {
    showStatus(MODEL_STATUS, 'error', '请先加载训练好的模型!');
    return;
  }
  // Disable the button before awaiting getUserMedia(): while the permission prompt is open a second
  // click would otherwise pass the `webcamStream` guard and open a stream that is never stopped.
  START_WEBCAM_BTN.disabled = true;
  try {
    const stream = await navigator.mediaDevices.getUserMedia({ video: { facingMode: 'user' }, audio: false });
    VIDEO.onloadeddata = () => {
      // Ignore a late event for a stream stopWebcam() already tore down, and never start a
      // second loop if the event fires again while one is running.
      if (webcamStream !== stream || isPredicting) return;
      isPredicting = true;
      void predictLoop(); // predictLoop() handles its own errors
      showStatus(WEBCAM_STATUS_DISPLAY, 'success', `摄像头已运行,识别中...`);
      PREDICTION_OUTPUT.classList.remove('idle', 'error');
    };
    VIDEO.srcObject = stream;
    webcamStream = stream;
    updateWebcamUI(true);
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `无法访问摄像头: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '无法启动摄像头');
    console.error('启动摄像头失败:', error);
    // Re-enables the start button that was disabled above.
    updateWebcamUI(false);
  }
}
/**
 * Stops every track of the active webcam stream, halts the prediction loop and resets the UI.
 * Safe to call when the webcam is not running.
 */
function stopWebcam() {
  if (webcamStream) {
    const tracks = webcamStream.getTracks();
    for (const track of tracks) {
      track.stop();
    }
    webcamStream = null;
  }
  // predictLoop() checks this flag and stops rescheduling itself.
  isPredicting = false;
  VIDEO.srcObject = null;
  updateWebcamUI(false);
  showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头已停止');
}
// Web Serial output was removed; predictions are only rendered in the UI.

/** Number of nearest neighbours passed to knnClassifier.predictClass(). */
const PREDICTION_K = 3;
/** A class is reported only when its KNN confidence is strictly above this value. */
const PREDICTION_CONFIDENCE_THRESHOLD = 0.75;

// Text describing the most recent prediction outcome.
let currentDetectedClassLabel = '等待识别...';

/**
 * Replaces the prediction panel text and sets exactly one state class (or none).
 * @param {string} text - Text to display.
 * @param {'idle'|'error'|null} state - State class to apply; null leaves both 'idle' and 'error' off.
 */
function renderPrediction(text, state) {
  PREDICTION_OUTPUT.textContent = text;
  // Clear both first so states from earlier frames never accumulate.
  PREDICTION_OUTPUT.classList.remove('idle', 'error');
  if (state) {
    PREDICTION_OUTPUT.classList.add(state);
  }
}

/**
 * Classifies the current webcam frame.
 * @returns {Promise<{text: string, state: ('idle'|'error'|null), label: string}>} What to display.
 */
async function classifyCurrentFrame() {
  // Check readiness first so no MobileNet inference is wasted on an empty classifier.
  if (!knnClassifier || knnClassifier.getNumClasses() === 0) {
    return { text: 'KNN 分类器未就绪或无数据。', state: 'error', label: '模型未就绪' };
  }

  const features = await getFeatures(VIDEO);
  let prediction;
  try {
    prediction = await knnClassifier.predictClass(features, PREDICTION_K);
  } finally {
    // Release the embedding even if predictClass() throws.
    features.dispose();
  }

  if (!prediction || !prediction.confidences) {
    return { text: '无法识别。', state: 'error', label: '无法识别' };
  }

  // Highest confidence wins; on ties the first entry in iteration order is kept.
  let bestIndex = -1;
  let bestConfidence = 0;
  for (const [key, confidence] of Object.entries(prediction.confidences)) {
    if (confidence > bestConfidence) {
      bestConfidence = confidence;
      bestIndex = Number.parseInt(key, 10);
    }
  }

  const percentage = (bestConfidence * 100).toFixed(1);
  if (bestIndex !== -1 && bestConfidence > PREDICTION_CONFIDENCE_THRESHOLD) {
    const className = classNames[bestIndex] || `Class ${bestIndex + 1}`;
    return { text: `识别为: ${className} (${percentage}%)`, state: null, label: className };
  }
  return { text: `未知或不确定... (最高置信度: ${percentage}%)`, state: 'idle', label: '未知或不确定' };
}

/**
 * One iteration of the prediction loop; reschedules itself with requestAnimationFrame while
 * `isPredicting` is true. Frames are skipped, but the loop keeps running, until the video reports
 * readyState 4 and a non-zero size.
 */
async function predictLoop() {
  if (!isPredicting) return;

  if (VIDEO.readyState === 4 && VIDEO.videoWidth > 0 && VIDEO.videoHeight > 0) {
    let result;
    try {
      result = await classifyCurrentFrame();
    } catch (error) {
      console.error('预测错误:', error);
      result = { text: `预测错误: ${error.message}`, state: 'error', label: `错误: ${error.message}` };
    }
    // stopWebcam() may have run while we were awaiting: keep its reset UI and end the loop.
    if (!isPredicting) return;
    renderPrediction(result.text, result.state);
    currentDetectedClassLabel = result.label;
  }

  requestAnimationFrame(predictLoop);
}
// REMOVED: Web Serial API Logic and Event Listeners
/*
async function checkWebSerialCompatibility() { ... }
async function connectSerial() { ... }
async function disconnectSerial() { ... }
async function sendToSerialPort(command) { ... }
CONNECT_SERIAL_BTN.addEventListener('click', connectSerial);
DISCONNECT_SERIAL_BTN.addEventListener('click', disconnectSerial);
*/
// ===================================
// Event Listeners (Simplified)
// ===================================
// loadKNNModel() shows load errors in MODEL_STATUS and logs them before rethrowing. The extra
// .catch() here only stops that rethrown error from surfacing as an unhandled promise rejection.
LOAD_MODEL_BTN.addEventListener('click', () => {
  loadKNNModel(null, null).catch(() => {});
});
START_WEBCAM_BTN.addEventListener('click', startWebcam);
STOP_WEBCAM_BTN.addEventListener('click', stopWebcam);
// ===================================
// Initialization (Simplified)
// ===================================
// Start model initialisation once the DOM is ready. If this script runs after DOMContentLoaded has
// already fired (e.g. it was loaded with `async`), the listener would never be called, so run it directly.
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', () => {
    initModel();
  });
} else {
  initModel();
}
// Best-effort cleanup when the page is closed or reloaded.
window.onbeforeunload = () => {
  // predictLoop() keeps no requestAnimationFrame id; clearing this flag makes it stop on its next frame.
  isPredicting = false;
  if (knnClassifier) {
    knnClassifier.clearAllClasses();
  }
  // Guard: `tf` is absent when the library failed to load. TensorFlow.js has no disposeAll();
  // disposeVariables() is its public call for releasing variables, and the browser reclaims
  // remaining memory when the page unloads.
  if (window.tf) {
    tf.disposeVariables();
    console.log('Resources cleaned up (tf.disposeVariables()).');
  }
};