// script.js
//
// Browser-side image classifier: MobileNet embeddings of webcam frames are
// classified by a KNN classifier whose examples are loaded from a model file
// (from a CDN at startup, or picked manually by the user).
// Expects the TensorFlow.js, MobileNet and KNN-classifier libraries to be
// available as globals (window.tf, window.mobilenet, window.knnClassifier).

// ===================================
// DOM references
// ===================================
const VIDEO = document.getElementById('webcam');
const LOAD_MODEL_BTN = document.getElementById('loadModelBtn');
const START_WEBCAM_BTN = document.getElementById('startWebcamBtn');
const STOP_WEBCAM_BTN = document.getElementById('stopWebcamBtn');
const MODEL_STATUS = document.getElementById('modelStatus');
const PREDICTION_OUTPUT = document.getElementById('prediction');
const WEBCAM_STATUS_DISPLAY = document.getElementById('webcam-status-display');

// ===================================
// Mutable application state
// ===================================
/** MobileNet model instance; assigned by initModel(). */
let mobilenet;
/** KNN classifier instance; created by initModel(), filled by the model loaders. */
let knnClassifier;
/** Display names indexed by numeric class label; set when a KNN model is loaded. */
let classNames = [];
/** Active camera MediaStream, or null while the webcam is stopped. */
let webcamStream = null;
/** While true, predictLoop keeps rescheduling itself. */
let isPredicting = false;
// ===================================
// Helper Functions (UI Status)
// ===================================

/**
 * Render a status message into an element.
 *
 * The element's classes are replaced by the shared `status-message` class plus
 * a `status-<type>` modifier. The message is written via textContent, so it is
 * never interpreted as HTML.
 *
 * @param {HTMLElement} element - Element whose className and textContent are overwritten.
 * @param {string} type - Modifier suffix, e.g. 'info', 'success', 'warning' or 'error'.
 * @param {string} message - Text to display.
 */
function showStatus(element, type, message) {
  const classes = ['status-message', `status-${type}`];
  Object.assign(element, {
    className: classes.join(' '),
    textContent: message,
  });
}
/**
 * Sync the webcam buttons, the webcam status line and the prediction output
 * with the camera's running state.
 *
 * @param {boolean} isRunning - true once a camera stream is attached, false after it stops.
 */
function updateWebcamUI(isRunning) {
  START_WEBCAM_BTN.disabled = isRunning;
  STOP_WEBCAM_BTN.disabled = !isRunning;

  if (isRunning) {
    showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头已启动,等待模型预测...');
    PREDICTION_OUTPUT.classList.remove('idle', 'error');
    return;
  }

  showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头未启动');
  // Drop any 'error' class left by the last prediction so the idle
  // placeholder is not rendered with error styling.
  PREDICTION_OUTPUT.classList.remove('error');
  PREDICTION_OUTPUT.classList.add('idle');
  PREDICTION_OUTPUT.textContent = '等待识别...';
}
/**
 * Update model-related controls after a model load attempt.
 *
 * The load button is always re-enabled so the user can retry. The start button
 * is enabled only when a KNN model is loaded AND no webcam stream is active.
 * Otherwise loading a model while the camera runs would leave Start and Stop
 * enabled at the same time.
 *
 * @param {boolean} isLoaded - whether a usable KNN model is now loaded.
 */
function updateModelUI(isLoaded) {
  LOAD_MODEL_BTN.disabled = false;
  START_WEBCAM_BTN.disabled = !isLoaded || Boolean(webcamStream);
}
// ===================================
// Core Logic: Model & Webcam
// ===================================

/**
 * Initialise the ML pipeline and try to load the default KNN model from the CDN.
 *
 * Steps:
 *   1. Check that window.tf, window.mobilenet and window.knnClassifier exist.
 *   2. Load MobileNet v2 (alpha 1.0) and create an empty KNN classifier.
 *   3. Try to load the default KNN model from the CDN.
 *
 * Every outcome is reported in the UI. A CDN failure only produces a warning,
 * because the user can still load a model manually. The function never rejects.
 *
 * @returns {Promise<void>}
 */
async function initModel() {
  showStatus(MODEL_STATUS, 'info', '正在加载 MobileNet 模型...');
  showStatus(WEBCAM_STATUS_DISPLAY, 'info', '系统初始化中...');

  try {
    if (!window.tf || !window.mobilenet || !window.knnClassifier) {
      showStatus(MODEL_STATUS, 'error', 'TensorFlow.js 核心库或模型库未加载。请检查 HTML 引入。');
      // Without this the webcam line would keep showing "initialising" forever.
      showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
      console.error('TensorFlow.js 核心库或模型库未加载。');
      return;
    }

    mobilenet = await window.mobilenet.load({ version: 2, alpha: 1.0 });
    knnClassifier = window.knnClassifier.create();
    showStatus(MODEL_STATUS, 'success', 'MobileNet 模型和 KNN 分类器已加载。');
    // Re-enables the manual "load model" button; Start stays disabled until a KNN model is loaded.
    updateModelUI(false);

    const cdnModelBaseUrl = 'https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/models/';
    const cdnJsonUrl = `${cdnModelBaseUrl}knn-model-juzi.json`;
    const cdnBinUrl = `${cdnModelBaseUrl}knn-model-juzi.bin`;

    console.log(`尝试从 CDN 加载模型: ${cdnJsonUrl}, ${cdnBinUrl}`);
    showStatus(MODEL_STATUS, 'info', '正在尝试从 CDN 自动加载 KNN 模型...');

    try {
      await loadKNNModel(cdnJsonUrl, cdnBinUrl);
      console.log('CDN 模型自动加载成功。');
      // Initialisation is finished: replace the "initialising" message.
      showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头未启动');
    } catch (cdnError) {
      showStatus(MODEL_STATUS, 'warning', `从 CDN 加载 KNN 模型失败: ${cdnError.message}。您可以尝试手动加载。`);
      console.warn('CDN KNN 模型加载失败:', cdnError);
      updateModelUI(false);
    }
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `模型加载失败: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('MobileNet/KNN加载失败:', error);
  }
}
/**
 * Compute a norm-scaled MobileNet embedding for an image-like input.
 *
 * `mobilenet.infer(img, true)` requests the embedding instead of class logits.
 * The embedding is then divided by `tf.norm` of itself (Euclidean by default).
 *
 * @param {*} img - Any input accepted by `mobilenet.infer` (here: the <video> element).
 * @returns {Promise<tf.Tensor>} The normalised embedding; the caller owns and must dispose it.
 * @throws {Error} If MobileNet has not been loaded yet.
 */
async function getFeatures(img) {
  if (mobilenet) {
    // tf.tidy releases the intermediate embedding and norm tensors; only the
    // returned tensor survives.
    return tf.tidy(() => {
      const embedding = mobilenet.infer(img, true);
      return tf.div(embedding, tf.norm(embedding));
    });
  }
  throw new Error("MobileNet model is not loaded.");
}
/**
 * Load a legacy single-file KNN model whose JSON embeds the feature vectors.
 *
 * Shape as read below:
 *   - `modelData.dataset` maps a numeric class label (string key) to a flat
 *     number array holding `featureDim` values per sample.
 *   - `featureDim` is optional and defaults to 1280.
 *   - Display names come from `classList` ([{ name }]) or `classNames`
 *     (string[]); otherwise generated names are used.
 *
 * On success all classifier examples are replaced, `classNames` is set and the
 * start button is enabled. On failure the error is shown in the UI, logged and
 * re-thrown. Validation happens before the classifier is cleared, so a
 * malformed file does not leave it half-populated.
 *
 * @param {object} modelData - Parsed model JSON.
 * @returns {Promise<void>}
 * @throws {Error} If the data does not match `featureDim` or the JSON is malformed.
 */
async function loadSingleJsonModel(modelData) {
  try {
    const featureDim = modelData.featureDim || 1280;
    const labels = Object.keys(modelData.dataset);

    for (const key of labels) {
      const data = modelData.dataset[key];
      if (data.length % featureDim !== 0) {
        throw new Error(`类别 ${key} 的特征数据长度 ${data.length} 与特征维度 ${featureDim} 不匹配!`);
      }
    }

    knnClassifier.clearAllClasses();
    for (const key of labels) {
      const data = modelData.dataset[key];
      const label = Number.parseInt(key, 10);
      // KNNClassifier.addExample flattens whatever tensor it receives into ONE
      // example, so each sample row must be added on its own (exactly like the
      // two-file .bin loader does).
      for (let offset = 0; offset < data.length; offset += featureDim) {
        const sample = tf.tensor(data.slice(offset, offset + featureDim), [1, featureDim]);
        try {
          knnClassifier.addExample(sample, label);
        } finally {
          tf.dispose(sample);
        }
      }
    }

    if (Array.isArray(modelData.classList)) {
      classNames = modelData.classList.map(c => c.name);
    } else if (Array.isArray(modelData.classNames)) {
      classNames = modelData.classNames;
    } else {
      console.warn('模型JSON中未找到 classList/classNames 字段,使用默认类别名称。');
      classNames = labels.map(key => `Class ${Number.parseInt(key, 10) + 1}`);
    }

    showStatus(MODEL_STATUS, 'success', `模型 (单文件JSON格式) 加载成功!类别: ${classNames.join(', ')}。`);
    updateModelUI(true);
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `加载单文件JSON模型失败: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('加载单文件JSON模型失败:', error);
    updateModelUI(false);
    throw error;
  }
}
/**
 * Load a KNN model into `knnClassifier`.
 *
 * Sources:
 *   - With both `jsonUrl` and `binUrl`, the JSON metadata and the binary
 *     feature file are fetched over HTTP.
 *   - Otherwise a file picker asks the user for the .json file and, for the
 *     two-file format, its .bin companion. A picked JSON without a `dataFile`
 *     field is handed to loadSingleJsonModel (legacy single-file format), which
 *     reports its own result.
 *
 * Two-file format, as read below:
 *   - `dataset[label] = { start, length }` selects a slice (counted in float32
 *     elements) of the .bin file.
 *   - Each slice holds `featureDim` (default 1280) values per sample.
 *   - `classList` ([{ name }]) supplies display names.
 *
 * On success the classifier examples and `classNames` are replaced and the
 * start button is enabled. On failure the error is shown in the UI, logged and
 * re-thrown. The function returns without loading anything if the classifier is
 * not initialised yet or the user dismisses the file picker.
 *
 * @param {?string} [jsonUrl=null] - URL of the model's .json file.
 * @param {?string} [binUrl=null] - URL of the model's .bin file.
 * @returns {Promise<void>}
 * @throws {Error} Any fetch/read/parse/validation error, after it has been reported.
 */
async function loadKNNModel(jsonUrl = null, binUrl = null) {
  if (!knnClassifier) {
    showStatus(MODEL_STATUS, 'error', 'KNN 分类器未初始化。请先加载 MobileNet 模型。');
    return;
  }

  try {
    const loaded = jsonUrl && binUrl
      ? await fetchKnnModelFiles(jsonUrl, binUrl)
      : await pickKnnModelFiles();

    // null: nothing was selected, or a legacy single-file model was already
    // loaded and reported by loadSingleJsonModel. Its success status must not
    // be overwritten here.
    if (!loaded) {
      return;
    }

    const { modelData, binData, modelName } = loaded;
    addExamplesFromBinary(modelData, binData);

    if (Array.isArray(modelData.classList)) {
      classNames = modelData.classList.map(c => c.name);
    } else {
      console.warn('模型JSON中未找到 classList 字段或格式不正确,使用默认类别名称。');
      classNames = Object.keys(modelData.dataset).map(key => `Class ${Number.parseInt(key, 10) + 1}`);
    }

    showStatus(MODEL_STATUS, 'success', `KNN 模型 "${modelName}" 加载成功!类别: ${classNames.join(', ')}。`);
    updateModelUI(true);
  } catch (error) {
    showStatus(MODEL_STATUS, 'error', `加载 KNN 模型失败: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '模型加载失败');
    console.error('加载 KNN 模型总失败:', error);
    updateModelUI(false);
    throw error;
  }
}

/** Format an HTTP failure; statusText is empty over HTTP/2, so the numeric status is included too. */
function httpStatusText(response) {
  return `${response.status} ${response.statusText}`.trim();
}

/**
 * View an ArrayBuffer as float32 samples.
 * Rejects truncated files with a readable error instead of an opaque RangeError.
 */
function toFloat32Array(arrayBuffer, sourceName) {
  if (arrayBuffer.byteLength % Float32Array.BYTES_PER_ELEMENT !== 0) {
    throw new Error(`.bin 文件 ${sourceName} 的字节长度 (${arrayBuffer.byteLength}) 不是 4 的整数倍,文件可能已损坏。`);
  }
  return new Float32Array(arrayBuffer);
}

/** Fetch a two-file model over HTTP; resolves to { modelData, binData, modelName }. */
async function fetchKnnModelFiles(jsonUrl, binUrl) {
  showStatus(MODEL_STATUS, 'info', `正在从 CDN 加载模型配置文件 (${jsonUrl})...`);
  const jsonResponse = await fetch(jsonUrl);
  if (!jsonResponse.ok) {
    throw new Error(`无法从 ${jsonUrl} 加载.json文件: ${httpStatusText(jsonResponse)}`);
  }
  const modelData = await jsonResponse.json();

  showStatus(MODEL_STATUS, 'info', `正在从 CDN 加载模型权重 (${binUrl})...`);
  const binResponse = await fetch(binUrl);
  if (!binResponse.ok) {
    throw new Error(`无法从 ${binUrl} 加载.bin文件: ${httpStatusText(binResponse)}`);
  }
  const binData = toFloat32Array(await binResponse.arrayBuffer(), binUrl);

  if (modelData.dataFile && !binUrl.endsWith(modelData.dataFile)) {
    console.warn(`CDN 加载警告:.bin URL (${binUrl}) 与 .json 中定义的 dataFile (${modelData.dataFile}) 不匹配。继续加载。`);
  }

  return { modelData, binData, modelName: jsonUrl.split('/').pop() };
}

/**
 * Ask the user for model files through a temporary multi-select file input.
 * Resolves like readKnnModelFiles, or to null when the dialog is dismissed.
 */
function pickKnnModelFiles() {
  return new Promise((resolve, reject) => {
    const input = document.createElement('input');
    input.type = 'file';
    input.accept = '.json,.bin';
    input.multiple = true; // the .json and its .bin are picked together

    // Browsers that support it fire 'cancel' when the dialog is dismissed.
    // Without this listener the returned promise would never settle.
    input.addEventListener('cancel', () => {
      showStatus(MODEL_STATUS, 'info', '未选择文件。');
      resolve(null);
    });
    input.onchange = () => {
      readKnnModelFiles(input.files).then(resolve, reject);
    };

    showStatus(MODEL_STATUS, 'info', '请选择 KNN 模型配置文件 (.json) 以及其对应的权重文件 (.bin)...');
    input.click();
  });
}

/**
 * Validate and read the user-selected files.
 *
 * Resolves to { modelData, binData, modelName } for a two-file model. Resolves
 * to null if nothing was selected or a legacy single-file model was loaded
 * directly. Thrown errors carry the user-facing message that loadKNNModel shows.
 */
async function readKnnModelFiles(files) {
  if (!files || files.length === 0) {
    showStatus(MODEL_STATUS, 'info', '未选择文件。');
    return null;
  }

  let jsonFile = null;
  let binFile = null;
  for (const file of files) {
    if (file.name.endsWith('.json')) {
      jsonFile = file;
    } else if (file.name.endsWith('.bin')) {
      binFile = file;
    }
  }

  if (!jsonFile) {
    throw new Error('请选择一个 .json 模型配置文件。');
  }

  showStatus(MODEL_STATUS, 'info', `正在解析 ${jsonFile.name}...`);
  let modelData;
  try {
    modelData = JSON.parse(await jsonFile.text());
  } catch (error) {
    console.error('解析 .json 失败:', error);
    throw new Error(`解析 .json 文件失败: ${error.message}`, { cause: error });
  }
  if (modelData === null || typeof modelData !== 'object') {
    throw new Error(`解析 .json 文件失败: ${jsonFile.name} 不是有效的模型配置。`);
  }

  // Legacy single-file JSON: the features live inside the JSON itself.
  if (!modelData.dataFile) {
    console.warn('模型JSON文件不包含 "dataFile" 字段,尝试以旧的单文件JSON格式加载。');
    await loadSingleJsonModel(modelData);
    return null;
  }

  if (!binFile) {
    throw new Error(`请同时选择与 ${jsonFile.name} 对应的 .bin 权重文件 (${modelData.dataFile})。`);
  }
  if (binFile.name !== modelData.dataFile) {
    throw new Error(`选择的 .bin 文件名 "${binFile.name}" 与 .json 中定义的 "${modelData.dataFile}" 不匹配!请选择正确的文件。`);
  }

  showStatus(MODEL_STATUS, 'info', `正在读取 ${binFile.name} (二进制权重文件)...`);
  let arrayBuffer;
  try {
    arrayBuffer = await binFile.arrayBuffer();
  } catch (error) {
    console.error('读取 .bin 失败:', error);
    throw new Error(`读取 .bin 文件失败: ${error.message}`, { cause: error });
  }

  return { modelData, binData: toFloat32Array(arrayBuffer, binFile.name), modelName: jsonFile.name };
}

/**
 * Replace the classifier's examples with the per-class slices of `binData`.
 *
 * Every class is validated before the classifier is cleared, so a corrupt file
 * leaves the previously loaded examples intact. Each sample is added as its own
 * [1, featureDim] example.
 */
function addExamplesFromBinary(modelData, binData) {
  if (!modelData.dataset || typeof modelData.dataset !== 'object') {
    throw new Error('模型数据错误: .json 中缺少 dataset 字段。');
  }
  const featureDim = modelData.featureDim || 1280;

  const classes = [];
  for (const label of Object.keys(modelData.dataset)) {
    const { start, length } = modelData.dataset[label];

    if (start + length > binData.length) {
      throw new Error(`模型数据错误: 类别 ${label} 的数据超出 .bin 文件范围。`);
    }

    const features = binData.subarray(start, start + length);
    if (features.length === 0) {
      console.warn(`类别 ${label} 没有找到特征数据,跳过。`);
      continue;
    }

    if (features.length % featureDim !== 0) {
      const actualSamples = features.length / featureDim;
      console.error(
        `--- 类别: ${label} ---`,
        `起始 Float32 元素索引: ${start}`,
        `该类别 Float32 元素数量: ${length}`,
        `ERROR: 特征数据长度 (${features.length} 个 Float32 元素) 与特征维度 (${featureDim}) 不匹配!` +
          `实际样本数计算为 ${actualSamples} (预期为整数)。`,
        `请检查您的模型导出逻辑和训练数据的完整性。`
      );
      throw new Error("模型数据完整性错误:特征数据长度与维度不匹配。");
    }

    classes.push({ label: Number.parseInt(label, 10), features });
  }

  knnClassifier.clearAllClasses();
  for (const { label, features } of classes) {
    for (let offset = 0; offset < features.length; offset += featureDim) {
      const sample = tf.tensor(features.subarray(offset, offset + featureDim), [1, featureDim]);
      try {
        knnClassifier.addExample(sample, label);
      } finally {
        tf.dispose(sample);
      }
    }
  }
}
/**
 * Start the front-facing camera, show it in VIDEO and begin predicting once
 * the first frame has loaded.
 *
 * Does nothing if a stream is already active. Refuses to start (with a status
 * message) until a KNN model with at least one class is loaded. Failures are
 * reported in the UI and console instead of being thrown.
 *
 * @returns {Promise<void>}
 */
async function startWebcam() {
  if (webcamStream) return;

  if (!knnClassifier || knnClassifier.getNumClasses() === 0) {
    showStatus(MODEL_STATUS, 'error', '请先加载训练好的模型!');
    return;
  }

  // Keep Start disabled while the permission prompt is pending so a second
  // click cannot request (and leak) a second camera stream.
  START_WEBCAM_BTN.disabled = true;

  try {
    // navigator.mediaDevices is undefined outside secure contexts (HTTPS / localhost).
    if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
      throw new Error('当前环境不支持摄像头访问 (需要 HTTPS 或 localhost)');
    }

    const stream = await navigator.mediaDevices.getUserMedia({ video: { facingMode: 'user' }, audio: false });

    if (webcamStream) {
      // A concurrent call already attached a stream; release this duplicate.
      stream.getTracks().forEach(track => track.stop());
      return;
    }

    webcamStream = stream;
    VIDEO.onloadeddata = () => {
      isPredicting = true;
      // predictLoop catches its own prediction errors, so its promise is not awaited.
      predictLoop();
      showStatus(WEBCAM_STATUS_DISPLAY, 'success', `摄像头已运行,识别中...`);
      PREDICTION_OUTPUT.classList.remove('idle', 'error');
    };
    VIDEO.srcObject = stream;
    updateWebcamUI(true);
  } catch (error) {
    // Reset the controls first: updateWebcamUI(false) writes its own webcam
    // status text, which would otherwise overwrite the error message below.
    updateWebcamUI(false);
    showStatus(MODEL_STATUS, 'error', `无法访问摄像头: ${error.message}`);
    showStatus(WEBCAM_STATUS_DISPLAY, 'error', '无法启动摄像头');
    console.error('启动摄像头失败:', error);
  }
}
/**
 * Stop the camera and the prediction loop, and return the UI to its idle state.
 * Safe to call when the webcam is not running.
 */
function stopWebcam() {
  const stream = webcamStream;
  if (stream) {
    for (const track of stream.getTracks()) {
      track.stop();
    }
    webcamStream = null;
  }

  // predictLoop checks this flag first and stops rescheduling itself.
  isPredicting = false;
  VIDEO.srcObject = null;

  updateWebcamUI(false);
  showStatus(WEBCAM_STATUS_DISPLAY, 'info', '摄像头已停止');
}
// ===================================
// Prediction loop
// ===================================

/** Plain-text form of the latest prediction; written on every frame, not read in this file. */
let currentDetectedClassLabel = '等待识别...';

/** Number of nearest neighbours the KNN classifier votes with. */
const PREDICTION_K = 3;
/** A class is reported only when its confidence is strictly above this value. */
const PREDICTION_CONFIDENCE_THRESHOLD = 0.75;

/**
 * Replace the prediction text and leave the output in exactly one visual state.
 * This keeps a stale 'idle' or 'error' class from a previous frame from lingering.
 *
 * @param {string} text - Text to display.
 * @param {?string} state - 'idle' or 'error', or null for a normal recognition result.
 */
function setPredictionOutput(text, state) {
  PREDICTION_OUTPUT.textContent = text;
  PREDICTION_OUTPUT.classList.remove('idle', 'error');
  if (state) {
    PREDICTION_OUTPUT.classList.add(state);
  }
}

/**
 * Classify the current video frame with the KNN model and render the result.
 * Prediction errors are caught, logged and shown in the output.
 */
async function classifyCurrentFrame() {
  if (!knnClassifier || knnClassifier.getNumClasses() === 0) {
    setPredictionOutput('KNN 分类器未就绪或无数据。', 'error');
    currentDetectedClassLabel = '模型未就绪';
    return;
  }

  let features = null;
  try {
    features = await getFeatures(VIDEO);
    const prediction = await knnClassifier.predictClass(features, PREDICTION_K);

    // The webcam may have been stopped while inference was running. In that
    // case, leave the idle placeholder shown on stop untouched.
    if (!isPredicting) return;

    if (!prediction || !prediction.confidences) {
      setPredictionOutput('无法识别。', 'error');
      currentDetectedClassLabel = '无法识别';
      return;
    }

    let bestIndex = -1;
    let bestConfidence = 0;
    for (const [key, confidence] of Object.entries(prediction.confidences)) {
      if (confidence > bestConfidence) {
        bestConfidence = confidence;
        bestIndex = Number.parseInt(key, 10);
      }
    }

    const percentage = (bestConfidence * 100).toFixed(1);
    if (bestIndex !== -1 && bestConfidence > PREDICTION_CONFIDENCE_THRESHOLD) {
      const className = classNames[bestIndex] || `Class ${bestIndex + 1}`;
      setPredictionOutput(`识别为: ${className} (${percentage}%)`, null);
      currentDetectedClassLabel = className;
    } else {
      setPredictionOutput(`未知或不确定... (最高置信度: ${percentage}%)`, 'idle');
      currentDetectedClassLabel = '未知或不确定';
    }
  } catch (error) {
    console.error('预测错误:', error);
    if (isPredicting) {
      setPredictionOutput(`预测错误: ${error.message}`, 'error');
      currentDetectedClassLabel = `错误: ${error.message}`;
    }
  } finally {
    // Always release the embedding tensor, even when predictClass throws.
    if (features) {
      features.dispose();
    }
  }
}

/**
 * Prediction loop. While `isPredicting` is true, it classifies the current
 * frame whenever the video has enough data, then reschedules itself with
 * requestAnimationFrame.
 */
async function predictLoop() {
  if (!isPredicting) return;

  const hasFrame = VIDEO.readyState === HTMLMediaElement.HAVE_ENOUGH_DATA
    && VIDEO.videoWidth > 0
    && VIDEO.videoHeight > 0;
  if (hasFrame) {
    await classifyCurrentFrame();
  }

  // Re-check: the webcam may have been stopped while the frame was being classified.
  if (isPredicting) {
    requestAnimationFrame(predictLoop);
  }
}
// ===================================
// Event Listeners
// ===================================

// loadKNNModel reports its own failures in the UI and console and then
// re-throws. The re-thrown error is dropped here so that a failed manual load
// does not also surface as an unhandled promise rejection.
LOAD_MODEL_BTN.addEventListener('click', () => {
  loadKNNModel(null, null).catch(() => {});
});
START_WEBCAM_BTN.addEventListener('click', startWebcam);
STOP_WEBCAM_BTN.addEventListener('click', stopWebcam);
// ===================================
// Initialization
// ===================================

// Run initModel once the DOM has been parsed. If this script executes after
// DOMContentLoaded has already fired (e.g. it was injected or loaded with
// `async`), that listener would never run, so start immediately instead.
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', () => {
    initModel();
  });
} else {
  initModel();
}
// ===================================
// Cleanup on page close / reload
// ===================================

// Stop the camera and release TensorFlow.js resources when the page unloads.
// The previous handler read the undeclared `animationFrameId` (a ReferenceError
// on every unload) and called the non-existent `tf.disposeAll()`, so no cleanup
// ever ran. Every step is now guarded, so a failed or partial initialisation
// cannot make the handler throw.
window.onbeforeunload = () => {
  // predictLoop checks this flag first and stops rescheduling itself.
  isPredicting = false;

  if (webcamStream) {
    webcamStream.getTracks().forEach(track => track.stop());
    webcamStream = null;
  }

  if (knnClassifier) {
    // Drops the classifier's stored examples.
    knnClassifier.clearAllClasses();
  }

  if (window.tf) {
    window.tf.disposeVariables();
  }

  console.log('Resources cleaned up.');
};