435 lines
15 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* =============================================================================
* 动态版 - 手部姿态识别与模型管理脚本 (v3.0)
* 由人体姿态识别修改为手部姿态识别
* =============================================================================
*/
'use strict';
// --- Global state and DOM references ---

// Camera feed plus the overlay canvas the hand skeleton is drawn on.
const videoElement = document.getElementById('video');
const canvasElement = document.getElementById('canvas');
const canvasCtx = canvasElement.getContext('2d');

// Status banner and prediction read-out.
const statusElement = document.getElementById('status');
const resultElement = document.getElementById('result-text');

// Controls.
const poseClassesContainer = document.getElementById('pose-classes-container');
const addClassButton = document.getElementById('btn-add-class');
const predictButton = document.getElementById('btn-predict');
const exportButton = document.getElementById('btn-export');
const importButton = document.getElementById('btn-import');
const fileImporter = document.getElementById('file-importer');

// Assigned later by init() and the render loop.
let detector;
let classifier;
let animationFrameId;
let isPredicting = false;

// classMap: class id -> display name; nextClassId: id given to the next new class.
const appState = {
  classMap: {},
  nextClassId: 0,
};
// --- 主应用逻辑 ---
/**
 * Boots the app.
 * - Creates the KNN classifier and the MediaPipe Hands detector.
 * - Starts the webcam, wires up the UI and kicks off the render loop.
 * - Enables the controls and creates the first gesture class.
 * Any failure is logged and shown in the status banner.
 */
async function init() {
  try {
    classifier = knnClassifier.create();

    // MediaPipe runtime; the solution assets are loaded from the jsDelivr CDN.
    detector = await handPoseDetection.createDetector(
      handPoseDetection.SupportedModels.MediaPipeHands,
      {
        runtime: 'mediapipe',
        solutionPath: 'https://cdn.jsdelivr.net/npm/@mediapipe/hands',
      },
    );

    await setupCamera();
    setupEventListeners();
    mainLoop(); // not awaited: it schedules its own next frame via requestAnimationFrame

    statusElement.innerText = "手部模型和摄像头已就绪!";
    enableControls();
    addNewClass(); // start with one gesture class
  } catch (error) {
    console.error("初始化失败:", error);
    statusElement.innerText = "初始化失败,请检查摄像头权限或刷新。";
    Object.assign(statusElement.style, {
      backgroundColor: '#fce8e6',
      color: '#d93025',
    });
  }
}
/**
 * Requests the user's webcam and attaches the stream to the video element.
 * Once metadata is available, sizes the canvas to the video's intrinsic
 * resolution, then starts playback. The canvas is not mirrored.
 *
 * Rejects if the camera permission is denied, the stream fails to load, or
 * playback cannot start. Previously a rejected play() was left unhandled.
 * @returns {Promise<void>}
 */
async function setupCamera() {
  const stream = await navigator.mediaDevices.getUserMedia({ video: true });
  videoElement.srcObject = stream;

  await new Promise((resolve, reject) => {
    videoElement.onloadedmetadata = () => resolve();
    videoElement.onerror = () => reject(new Error('Failed to load the camera stream into the video element.'));
  });

  // videoWidth/videoHeight are known once metadata has loaded.
  canvasElement.width = videoElement.videoWidth;
  canvasElement.height = videoElement.videoHeight;

  // play() returns a promise; awaiting it surfaces autoplay/playback failures to init().
  await videoElement.play();
}
/**
 * Attaches the toolbar and file-input handlers.
 * The import button forwards its click to the hidden file input, whose
 * `change` event performs the actual import.
 */
function setupEventListeners() {
  const bindings = [
    [addClassButton, 'click', addNewClass],
    [predictButton, 'click', togglePrediction],
    [exportButton, 'click', exportModel],
    [importButton, 'click', () => fileImporter.click()],
    [fileImporter, 'change', importModel],
  ];
  bindings.forEach(([target, type, handler]) => target.addEventListener(type, handler));
}
// --- 动态类别管理 (无需修改,与之前一致) ---
/**
 * Builds the UI row for one gesture class and appends it to the class list.
 *
 * The class name is written through the DOM, not interpolated into HTML.
 * Names typed by the user or read from an imported model file therefore
 * cannot inject markup or break out of the `value` attribute.
 * @param {number} cId - Unique class id (also used as the KNN label).
 * @param {string} cName - Display name of the class.
 */
function createClassUI(cId, cName) {
  const poseClassDiv = document.createElement('div');
  poseClassDiv.className = 'pose-class';
  poseClassDiv.dataset.classId = cId;
  // Static markup only; every dynamic value is assigned below.
  poseClassDiv.innerHTML = `
    <div class="class-info">
      <input type="text" class="class-name-input">
      <span class="sample-count">(0 样本)</span>
    </div>
    <div class="class-actions">
      <button class="btn-sample">采集样本</button>
      <button class="btn-delete-class" title="删除类别">×</button>
    </div>
  `;

  const nameInput = poseClassDiv.querySelector('.class-name-input');
  const sampleButton = poseClassDiv.querySelector('.btn-sample');
  const deleteButton = poseClassDiv.querySelector('.btn-delete-class');

  // Each control also carries data-class-id so it can be mapped back to its class.
  [nameInput, sampleButton, deleteButton].forEach((el) => {
    el.dataset.classId = cId;
  });
  // Setting the attribute (not markup) keeps the name inert text.
  nameInput.setAttribute('value', cName);

  nameInput.addEventListener('change', (e) => {
    appState.classMap[cId] = e.target.value;
  });
  sampleButton.addEventListener('click', () => addExample(cId));
  deleteButton.addEventListener('click', () => deleteClass(cId));

  if (isPredicting) {
    // Match the locking applied by updatePredictionUI(): rows created while
    // predicting start disabled, like the existing rows.
    nameInput.disabled = true;
    sampleButton.disabled = true;
    deleteButton.disabled = true;
  }

  poseClassesContainer.appendChild(poseClassDiv);
}
/**
 * Registers a new gesture class with the default name "手势 N" (N = id + 1),
 * advances the id counter and renders the class row.
 */
function addNewClass() {
  const id = appState.nextClassId;
  appState.nextClassId += 1;
  const defaultName = `手势 ${id + 1}`;
  appState.classMap[id] = defaultName;
  createClassUI(id, defaultName);
}
/**
 * Deletes a gesture class after user confirmation.
 * Removes its UI row, its name mapping and any collected samples, then
 * refreshes sample counts, prediction UI and export availability.
 * @param {number} classId - Id of the class to delete.
 */
function deleteClass(classId) {
  if (!confirm(`确定要删除类别 "${appState.classMap[classId]}" 吗?所有样本都将丢失。`)) {
    return;
  }

  // Target the row itself rather than the first element that happens to carry the id.
  const rowElement = poseClassesContainer.querySelector(`.pose-class[data-class-id="${classId}"]`);
  if (rowElement) rowElement.remove();
  delete appState.classMap[classId];

  // KNNClassifier.clearClass throws for a label that has no examples, so
  // deleting a class with no samples would otherwise abort before the UI refresh.
  if (classifier.getClassifierDataset()[classId]) {
    classifier.clearClass(classId);
  }

  updateSampleCounts();
  updatePredictionUI();
  checkExportAbility();
}
/**
 * Captures one hand-pose sample from the current video frame and adds it to
 * the KNN classifier under `classId`.
 *
 * Only the first detected hand is used. `flipHorizontal: false` keeps the
 * coordinates in the video's own orientation (not mirrored).
 * Failures are logged instead of escaping as unhandled rejections, since
 * this runs from a click handler. The feature tensor is always released.
 * @param {number} classId - Id of the class the sample belongs to.
 */
async function addExample(classId) {
  let handTensor = null;
  try {
    const hands = await detector.estimateHands(videoElement, { flipHorizontal: false });
    if (!hands || hands.length === 0) {
      console.warn(`为类别 ${appState.classMap[classId]} 采集样本失败,未检测到手部。`);
      return;
    }

    handTensor = flattenHand(hands[0]);
    // addExample may throw, e.g. if the feature shape differs from earlier samples.
    classifier.addExample(handTensor, classId);
    updateSampleCounts();
    checkExportAbility();
  } catch (error) {
    console.error(`为类别 ${appState.classMap[classId]} 采集样本失败:`, error);
  } finally {
    if (handTensor) handTensor.dispose();
  }
}
// --- 模型与预测逻辑 ---
/**
 * Toggles live prediction on or off and refreshes the UI.
 * Starting requires the classifier to hold at least one class. Stopping is
 * always allowed, so the app can never get stuck in prediction mode if the
 * class count drops to zero while predicting.
 */
function togglePrediction() {
  if (!isPredicting && classifier.getNumClasses() === 0) {
    alert("请先为至少一个手势采集样本后再开始预测!");
    return;
  }
  isPredicting = !isPredicting;
  updatePredictionUI();
}
/**
 * Per-frame render/predict loop. Each frame it:
 * - detects hands and redraws the first one;
 * - while predicting, classifies that hand and shows the best label with its
 *   confidence.
 *
 * Errors in one frame are logged and the next frame is still scheduled, so a
 * transient detector/classifier failure does not freeze the app. The feature
 * tensor is always released.
 * Note: requestAnimationFrame passes a timestamp argument, so this function
 * must not gain a positional parameter.
 */
async function mainLoop() {
  // Number of nearest neighbours consulted by the KNN classifier.
  const k = 3;
  let handTensor = null;
  try {
    // flipHorizontal: false keeps coordinates in the video's own (unmirrored) orientation.
    const hands = await detector.estimateHands(videoElement, { flipHorizontal: false });
    canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);

    if (hands && hands.length > 0) {
      const hand = hands[0];
      drawHand(hand);

      if (isPredicting && classifier.getNumClasses() > 0) {
        handTensor = flattenHand(hand);
        const result = await classifier.predictClass(handTensor, k);
        // Prediction may have been stopped while awaiting; don't overwrite the stopped state.
        if (isPredicting) {
          const confidence = Math.round(result.confidences[result.label] * 100);
          const predictedClassName = appState.classMap[result.label] || '未知手势';
          resultElement.innerText = `手势: ${predictedClassName} (${confidence}%)`;
        }
      }
    }
  } catch (error) {
    console.error("帧处理失败:", error);
  } finally {
    if (handTensor) handTensor.dispose();
    animationFrameId = requestAnimationFrame(mainLoop);
  }
}
// --- 模型管理函数 (无需修改,与之前一致) ---
/**
 * Serializes the KNN dataset together with the class-name map and triggers a
 * browser download of `hand-knn-model.json` (`{ classMap, dataset }`).
 * Alerts and returns early when the classifier holds no classes.
 */
function exportModel() {
  if (classifier.getNumClasses() === 0) {
    alert('模型中还没有任何样本,无法导出!');
    return;
  }

  // label -> nested number array produced by arraySync()
  const dataset = classifier.getClassifierDataset();
  const serializedDataset = Object.fromEntries(
    Object.keys(dataset).map((label) => [label, dataset[label].arraySync()]),
  );
  const json = JSON.stringify({ classMap: appState.classMap, dataset: serializedDataset });

  const downloadUrl = URL.createObjectURL(new Blob([json], { type: "application/json" }));
  const link = Object.assign(document.createElement('a'), {
    href: downloadUrl,
    download: `hand-knn-model.json`,
  });
  document.body.appendChild(link);
  link.click();
  link.remove();
  URL.revokeObjectURL(downloadUrl);
}
/**
 * Loads a model JSON file chosen via the file input and replaces the current
 * classes and KNN dataset with its contents.
 *
 * Expected file shape (as written by exportModel): `{ classMap, dataset }`.
 * `classMap` maps class ids to display names. `dataset` maps class ids to
 * nested number arrays of samples.
 *
 * All tensors are built before any existing state is touched, so a malformed
 * file leaves the current model intact and its partial tensors are freed.
 * @param {Event} event - `change` event from the hidden file input.
 */
function importModel(event) {
  const file = event.target.files[0];
  if (!file) return;

  // Plain JSON object: not null and not an array.
  const isPlainObject = (value) => value !== null && typeof value === 'object' && !Array.isArray(value);

  const reader = new FileReader();
  reader.onload = (e) => {
    // Tensors created from the file that this handler still owns.
    let ownedTensors = [];
    try {
      const modelData = JSON.parse(e.target.result);
      if (!isPlainObject(modelData) || !isPlainObject(modelData.classMap) || !isPlainObject(modelData.dataset)) {
        throw new Error("无效的模型文件格式。");
      }
      const { classMap, dataset } = modelData;

      const newDataset = {};
      for (const [label, rows] of Object.entries(dataset)) {
        // tf.tensor2d rejects values whose inferred rank is not 2.
        const tensor = tf.tensor2d(rows);
        ownedTensors.push(tensor);
        newDataset[label] = tensor;
      }

      // Input validated: now replace the current model.
      classifier.clearAllClasses();
      poseClassesContainer.innerHTML = '';
      appState.classMap = classMap;

      const classIds = Object.keys(classMap).map(Number);
      // Include dataset labels so a newly added class never reuses an id that
      // already owns samples; ignore keys that are not integer ids.
      const usedIds = classIds.concat(Object.keys(newDataset).map(Number)).filter(Number.isInteger);
      appState.nextClassId = usedIds.length > 0 ? Math.max(...usedIds) + 1 : 0;
      classIds.forEach((id) => createClassUI(id, classMap[id]));

      classifier.setClassifierDataset(newDataset);
      ownedTensors = []; // ownership has passed to the classifier

      updateSampleCounts();
      checkExportAbility();
      alert('模型导入成功!');
    } catch (error) {
      ownedTensors.forEach((tensor) => tensor.dispose());
      console.error("导入模型失败:", error);
      alert(`导入失败请确保文件是正确的模型JSON文件。\n错误: ${error.message}`);
    } finally {
      // Reset so choosing the same file again still fires `change`.
      fileImporter.value = '';
    }
  };
  reader.onerror = () => {
    console.error("导入模型失败:", reader.error);
    alert(`导入失败请确保文件是正确的模型JSON文件。\n错误: ${reader.error ? reader.error.message : ''}`);
    fileImporter.value = '';
  };
  reader.readAsText(file);
}
// --- 辅助和UI更新函数 ---
/**
 * Converts one detected hand into a flat feature tensor.
 * Each keypoint's x/y is divided by the video's width/height and interleaved
 * as [x0, y0, x1, y1, ...], so the tensor has 2 * keypoints.length values.
 * The z coordinate, if any, is ignored.
 * Assumes keypoints are in video pixel coordinates (the division implies it —
 * TODO confirm against the detector's output contract).
 * @param {Object} hand - A single hand-detection result with a `keypoints` array.
 * @returns {tf.Tensor} 1-D tensor of normalized coordinates.
 */
function flattenHand(hand) {
  const { videoWidth, videoHeight } = videoElement;
  const coords = hand.keypoints.flatMap(({ x, y }) => [x / videoWidth, y / videoHeight]);
  return tf.tensor(coords);
}
/**
 * Skeleton edges between MediaPipe Hands keypoint indices (0 = wrist),
 * drawn as line segments by drawHand(). Each undirected edge appears exactly
 * once. The former palm row repeated [0, 5] and [0, 17], which were stroked
 * twice every frame.
 */
const HAND_CONNECTIONS = Object.freeze([
  [0, 1], [1, 2], [2, 3], [3, 4], // Thumb
  [0, 5], [5, 6], [6, 7], [7, 8], // Index finger
  [0, 9], [9, 10], [10, 11], [11, 12], // Middle finger
  [0, 13], [13, 14], [14, 15], [15, 16], // Ring finger
  [0, 17], [17, 18], [18, 19], [19, 20], // Pinky finger
  [5, 9], [9, 13], [13, 17], // Knuckle line across the palm
]);
/**
 * Draws one hand on the overlay canvas.
 * - Bones: one cyan 2px line per HAND_CONNECTIONS edge whose endpoints both exist.
 * - Joints: a red 4px-radius dot per keypoint, drawn on top of the bones.
 * Does nothing when the hand has no `keypoints`.
 * @param {Object} hand - A single hand-detection result.
 */
function drawHand(hand) {
  const { keypoints } = hand;
  if (!keypoints) return;

  canvasCtx.strokeStyle = '#00FFFF';
  canvasCtx.lineWidth = 2;
  HAND_CONNECTIONS.forEach(([fromIdx, toIdx]) => {
    const from = keypoints[fromIdx];
    const to = keypoints[toIdx];
    if (!from || !to) return;
    canvasCtx.beginPath();
    canvasCtx.moveTo(from.x, from.y);
    canvasCtx.lineTo(to.x, to.y);
    canvasCtx.stroke();
  });

  // Small radius: hand keypoints are packed much closer than body keypoints.
  canvasCtx.fillStyle = '#FF0000';
  keypoints.forEach((point) => {
    if (!point) return;
    canvasCtx.beginPath();
    canvasCtx.arc(point.x, point.y, 4, 0, 2 * Math.PI);
    canvasCtx.fill();
  });
}
/**
 * Refreshes the "(N 样本)" label on every class row.
 * N is the row count of that class's dataset tensor, or 0 when the class has
 * no entry in the classifier dataset.
 */
function updateSampleCounts() {
  const dataset = classifier.getClassifierDataset();
  for (const row of document.querySelectorAll('.pose-class')) {
    const samples = dataset[Number.parseInt(row.dataset.classId, 10)];
    const count = samples ? samples.shape[0] : 0;
    row.querySelector('.sample-count').innerText = `(${count} 样本)`;
  }
}
/**
 * Syncs the UI with `isPredicting`:
 * - sets the predict button's label and `stop` class, and the result text;
 * - locks or unlocks the class controls, the add-class button and the import
 *   button;
 * - re-evaluates export availability.
 * The predict button stays enabled while predicting, so the user can stop.
 * Otherwise it is enabled only when the classifier has at least one class.
 */
function updatePredictionUI() {
  // addClassButton is looked up by id ('btn-add-class'), so the class selector
  // `.btn-add-class` may not match it; include the element explicitly.
  const lockableButtons = [
    ...document.querySelectorAll('.btn-sample, .btn-delete-class, .btn-add-class, #btn-import'),
    addClassButton,
  ];

  if (isPredicting) {
    predictButton.innerText = "停止预测";
    predictButton.classList.add('stop');
    resultElement.innerText = "正在分析手势...";
  } else {
    predictButton.innerText = "开始预测";
    predictButton.classList.remove('stop');
    resultElement.innerText = "已停止";
  }

  lockableButtons.forEach((btn) => { btn.disabled = isPredicting; });
  document.querySelectorAll('.class-name-input').forEach((input) => { input.disabled = isPredicting; });
  checkExportAbility();

  predictButton.disabled = isPredicting ? false : classifier.getNumClasses() === 0;
}
/**
 * Enables the four main toolbar buttons, then re-applies the export-button
 * rule (which may disable export again).
 */
function enableControls() {
  const toolbar = [predictButton, importButton, exportButton, addClassButton];
  for (const button of toolbar) {
    button.disabled = false;
  }
  checkExportAbility();
}
/**
 * Updates the export button.
 * Export is disabled while predicting, or when the classifier holds no classes.
 */
function checkExportAbility() {
  if (isPredicting) {
    exportButton.disabled = true;
  } else {
    exportButton.disabled = classifier.getNumClasses() === 0;
  }
}
/**
 * Releases the app's resources; intended to run when the page is unloading.
 * - Cancels the pending animation frame.
 * - Stops the camera tracks.
 * - Disposes the hand detector.
 * - Clears the KNN classifier's stored example tensors.
 * Replaces a call to `tf.disposeAll()`, which does not exist in TensorFlow.js
 * and made this handler throw. It also replaces a detector branch that did
 * nothing.
 */
function cleanup() {
  // Stop the render loop first so nothing uses the detector after disposal.
  if (animationFrameId) cancelAnimationFrame(animationFrameId);

  const stream = videoElement.srcObject;
  if (stream && typeof stream.getTracks === 'function') {
    stream.getTracks().forEach((track) => track.stop());
    videoElement.srcObject = null;
  }

  if (detector && typeof detector.dispose === 'function') detector.dispose();
  if (classifier) classifier.clearAllClasses();

  console.log("Cleanup complete. Camera, detector and classifier resources released.");
}
// --- Start the app ---
// Run cleanup() before the page unloads.
window.addEventListener('beforeunload', cleanup);
init();