/** * ============================================================================= * 动态版 - 姿态识别与模型管理脚本 (v2.1) * - 新增自动采集样本功能 * ============================================================================= * 功能列表: * - 实时姿态检测 (MoveNet) * - KNN 分类器训练 * - 实时姿态预测 * - 坐标完美对齐 (Canvas与Video重叠) * - 动态添加/删除/重命名姿态类别 * - 模型导出为包含类别信息的 JSON 文件 * - 从 JSON 文件导入模型并恢复类别状态 * - ✅ 新增:自动采集10次样本,间隔0.3秒 * ============================================================================= */ 'use strict'; // --- 全局变量和常量 --- const videoElement = document.getElementById('video'); const canvasElement = document.getElementById('canvas'); const canvasCtx = canvasElement.getContext('2d'); const statusElement = document.getElementById('status'); const resultElement = document.getElementById('result-text'); // UI元素 const poseClassesContainer = document.getElementById('pose-classes-container'); const addClassButton = document.getElementById('btn-add-class'); const predictButton = document.getElementById('btn-predict'); const exportButton = document.getElementById('btn-export'); const importButton = document.getElementById('btn-import'); const fileImporter = document.getElementById('file-importer'); let detector, classifier, animationFrameId; let isPredicting = false; let isAutoCollecting = false; // 新增:标记是否正在进行自动采集 // 📌 核心状态管理: 使用一个对象来管理所有动态状态 const appState = { classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'} nextClassId: 0 // 用于生成唯一的 classId }; // --- 主应用逻辑 --- /** * 初始化应用,加载模型并设置摄像头 */ async function init() { try { classifier = knnClassifier.create(); const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }; detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig); await setupCamera(); setupEventListeners(); mainLoop(); statusElement.innerText = "模型和摄像头已就绪!"; enableControls(); addNewClass(); // 默认创建第一个类别 } catch (error) { console.error("初始化失败:", error); statusElement.innerText = "初始化失败,请检查摄像头权限或刷新。"; statusElement.style.backgroundColor = '#fce8e6'; statusElement.style.color = '#d93025'; } } /** * 设置和启动用户摄像头 */ async function setupCamera() { const stream = await navigator.mediaDevices.getUserMedia({ video: true }); videoElement.srcObject = stream; return new Promise((resolve) => { videoElement.onloadedmetadata = () => { videoElement.play(); canvasElement.width = videoElement.videoWidth; canvasElement.height = videoElement.videoHeight; resolve(); }; }); } /** * 为所有交互式元素绑定事件监听器 */ function setupEventListeners() { addClassButton.addEventListener('click', addNewClass); predictButton.addEventListener('click', togglePrediction); exportButton.addEventListener('click', exportModel); importButton.addEventListener('click', () => fileImporter.click()); fileImporter.addEventListener('change', importModel); } // --- 动态类别管理 --- /** * 动态创建一个新类别的UI元素并添加到页面 * @param {number} classId - 类别的唯一ID * @param {string} className - 类别的名称 */ function createClassUI(classId, className) { const poseClassDiv = document.createElement('div'); poseClassDiv.className = 'pose-class'; poseClassDiv.dataset.classId = classId; // 📌 修改这里:添加 btn-auto-sample 按钮 poseClassDiv.innerHTML = `
(0 样本)
`; poseClassesContainer.appendChild(poseClassDiv); // 为新创建的元素绑定事件 const nameInput = poseClassDiv.querySelector('.class-name-input'); nameInput.addEventListener('change', (e) => { appState.classMap[classId] = e.target.value; }); const autoSampleButton = poseClassDiv.querySelector('.btn-auto-sample'); // 新增 autoSampleButton.addEventListener('click', () => toggleAutoCollection(classId, autoSampleButton)); // 新增 const sampleButton = poseClassDiv.querySelector('.btn-sample'); sampleButton.addEventListener('click', () => addExample(classId)); // 初始化时根据预测状态禁用按钮 if (isPredicting) { sampleButton.disabled = true; autoSampleButton.disabled = true; // 新增 } const deleteButton = poseClassDiv.querySelector('.btn-delete-class'); deleteButton.addEventListener('click', () => deleteClass(classId)); } /** * 添加一个新的姿态类别 */ function addNewClass() { const classId = appState.nextClassId; const className = `Class ${classId + 1}`; appState.classMap[classId] = className; appState.nextClassId++; createClassUI(classId, className); } /** * 删除一个指定的姿态类别 * @param {number} classId - 要删除的类别的ID */ function deleteClass(classId) { if (confirm(`确定要删除类别 "${appState.classMap[classId]}" 吗?所有样本都将丢失。`)) { // 从UI中移除 const elementToRemove = poseClassesContainer.querySelector(`[data-class-id="${classId}"]`); if (elementToRemove) elementToRemove.remove(); // 从状态和分类器中移除 delete appState.classMap[classId]; classifier.clearClass(classId); updateSampleCounts(); updatePredictionUI(); // 检查是否还有类别可以预测 checkExportAbility(); } } /** * 采集一个姿态样本并添加到KNN分类器 * @param {number} classId 类别的ID */ async function addExample(classId) { const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true }); if (poses && poses.length > 0) { const poseTensor = flattenPose(poses[0]); classifier.addExample(poseTensor, classId); poseTensor.dispose(); updateSampleCounts(); checkExportAbility(); console.log(`为类别 ${appState.classMap[classId]} 采集1个样本。`); return true; // 表示采集成功 } else { console.warn(`为类别 ${appState.classMap[classId]} 采集样本失败,未检测到姿态。`); return false; // 表示采集失败 } } // --- 新增:自动采集逻辑 --- let autoCollectionIntervalId = null; // 用于存储 setInterval ID let autoCollectionCount = 0; // 计数器 const AUTO_COLLECTION_TOTAL = 10; // 总共采集次数 const AUTO_COLLECTION_INTERVAL = 300; // 间隔时间 0.3 秒 async function toggleAutoCollection(classId, buttonElement) { if (isAutoCollecting) { // 如果正在自动采集,则停止 stopAutoCollection(buttonElement); } else { // 否则,开始自动采集 startAutoCollection(classId, buttonElement); } } async function startAutoCollection(classId, buttonElement) { isAutoCollecting = true; autoCollectionCount = 0; // 禁用其他采集和预测按钮 predictButton.disabled = true; exportButton.disabled = true; importButton.disabled = true; addClassButton.disabled = true; document.querySelectorAll('.btn-sample, .btn-auto-sample, .btn-delete-class, .class-name-input').forEach(btn => { if (btn !== buttonElement) { // 不禁用当前自动采集按钮 btn.disabled = true; } if (btn.classList.contains('class-name-input')) btn.disabled = true; }); buttonElement.innerText = `停止采集 (0/${AUTO_COLLECTION_TOTAL})`; buttonElement.classList.add('stop'); // 添加停止样式 const performCollection = async () => { if (autoCollectionCount < AUTO_COLLECTION_TOTAL) { const success = await addExample(classId); // 调用手动采集功能 if (success) { autoCollectionCount++; } buttonElement.innerText = `停止采集 (${autoCollectionCount}/${AUTO_COLLECTION_TOTAL})`; } else { stopAutoCollection(buttonElement); alert(`类别 "${appState.classMap[classId]}" 自动采集完成!`); } }; // 立即执行一次,然后设置定时器 await performCollection(); if (autoCollectionCount < AUTO_COLLECTION_TOTAL) { autoCollectionIntervalId = setInterval(performCollection, AUTO_COLLECTION_INTERVAL); } } function stopAutoCollection(buttonElement) { clearInterval(autoCollectionIntervalId); autoCollectionIntervalId = null; isAutoCollecting = false; buttonElement.innerText = '自动采集'; buttonElement.classList.remove('stop'); // 移除停止样式 // 重新启用按钮(根据应用状态) updatePredictionUI(); // 根据预测状态重新启用/禁用相关按钮 enableControls(); // 重新启用添加类别、导出、导入按钮 } // --- 模型与预测逻辑 --- /** * 开始或停止姿态预测 */ function togglePrediction() { if (classifier.getNumClasses() === 0) { alert("请先为至少一个姿态采集样本后再开始预测!"); return; } isPredicting = !isPredicting; updatePredictionUI(); } /** * 应用的主循环 */ async function mainLoop() { const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true }); canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height); if (poses && poses.length > 0) { drawPose(poses[0]); // 只有当不在自动采集状态时才进行预测 if (isPredicting && classifier.getNumClasses() > 0 && !isAutoCollecting) { const poseTensor = flattenPose(poses[0]); const result = await classifier.predictClass(poseTensor, 3); poseTensor.dispose(); const confidence = Math.round(result.confidences[result.label] * 100); const predictedClassName = appState.classMap[result.label] || '未知类别'; resultElement.innerText = `姿态: ${predictedClassName} (${confidence}%)`; } else if (isAutoCollecting) { resultElement.innerText = "自动采集中..."; } } else { resultElement.innerText = "未检测到姿态"; } animationFrameId = requestAnimationFrame(mainLoop); } // --- 模型管理函数 (已更新以支持动态类别) --- /** * 导出KNN模型为包含类别信息的JSON文件 */ function exportModel() { if (classifier.getNumClasses() === 0) { alert('模型中还没有任何样本,无法导出!'); return; } const dataset = classifier.getClassifierDataset(); const datasetObj = {}; Object.keys(dataset).forEach((key) => { const data = dataset[key]; datasetObj[key] = data.arraySync(); }); const modelData = { classMap: appState.classMap, dataset: datasetObj }; const jsonStr = JSON.stringify(modelData); const blob = new Blob([jsonStr], { type: "application/json" }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; a.download = `pose-knn-model.json`; document.body.appendChild(a); a.click(); document.body.removeChild(a); URL.revokeObjectURL(url); } /** * 从JSON文件导入KNN模型并恢复类别状态 * @param {Event} event */ function importModel(event) { const file = event.target.files[0]; if (!file) return; const reader = new FileReader(); reader.onload = (e) => { try { const modelData = JSON.parse(e.target.result); if (!modelData.classMap || !modelData.dataset) { throw new Error("无效的模型文件格式。"); } // 1. 清理现有状态 classifier.clearAllClasses(); poseClassesContainer.innerHTML = ''; appState.classMap = {}; // 2. 加载新状态 appState.classMap = modelData.classMap; const classIds = Object.keys(appState.classMap).map(Number); appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0; // 3. 恢复UI classIds.forEach(id => { createClassUI(id, appState.classMap[id]); }); // 4. 加载模型数据 const newDataset = {}; Object.keys(modelData.dataset).forEach((key) => { newDataset[key] = tf.tensor(modelData.dataset[key]); }); classifier.setClassifierDataset(newDataset); updateSampleCounts(); checkExportAbility(); alert('模型导入成功!'); } catch (error) { console.error("导入模型失败:", error); alert(`导入失败!请确保文件是正确的模型JSON文件。\n错误: ${error.message}`); } finally { fileImporter.value = ''; } }; reader.readAsText(file); } // --- 辅助和UI更新函数 --- function flattenPose(pose) { const keypoints = pose.keypoints.map(p => [p.x / videoElement.videoWidth, p.y / videoElement.videoHeight]).flat(); return tf.tensor(keypoints); } function drawPose(pose) { // 绘制关键点和骨骼... if (pose.keypoints) { // 绘制关键点 for (const keypoint of pose.keypoints) { if (keypoint.score > 0.3) { canvasCtx.beginPath(); canvasCtx.arc(keypoint.x, keypoint.y, 5, 0, 2 * Math.PI); canvasCtx.fillStyle = '#1a73e8'; canvasCtx.fill(); } } // 绘制骨骼连接线 const adjacentPairs = poseDetection.util.getAdjacentPairs(poseDetection.SupportedModels.MoveNet); adjacentPairs.forEach(([i, j]) => { const kp1 = pose.keypoints[i]; const kp2 = pose.keypoints[j]; if (kp1.score > 0.3 && kp2.score > 0.3) { canvasCtx.beginPath(); canvasCtx.moveTo(kp1.x, kp1.y); canvasCtx.lineTo(kp2.x, kp2.y); canvasCtx.strokeStyle = 'blue'; canvasCtx.lineWidth = 2; canvasCtx.stroke(); } }); } } /** * 更新所有类别UI上的样本数量 */ function updateSampleCounts() { const dataset = classifier.getClassifierDataset(); const allClassElements = document.querySelectorAll('.pose-class'); allClassElements.forEach(el => { const classId = parseInt(el.dataset.classId, 10); const classInfo = dataset[classId]; const count = classInfo ? classInfo.shape[0] : 0; el.querySelector('.sample-count').innerText = `(${count} 样本)`; }); } /** * 根据状态更新UI */ function updatePredictionUI() { // 禁用所有采集按钮(包括手动和自动)和删除按钮 document.querySelectorAll('.btn-sample, .btn-auto-sample, .btn-delete-class').forEach(btn => btn.disabled = isPredicting || isAutoCollecting); // 禁用添加类别和导入模型的按钮 addClassButton.disabled = isPredicting || isAutoCollecting; importButton.disabled = isPredicting || isAutoCollecting; // 禁用类别名称输入框 document.querySelectorAll('.class-name-input').forEach(input => input.disabled = isPredicting || isAutoCollecting); if (isPredicting) { predictButton.innerText = "停止预测"; predictButton.classList.add('stop'); resultElement.innerText = "正在分析..."; } else { predictButton.innerText = "开始预测"; predictButton.classList.remove('stop'); resultElement.innerText = "已停止"; } // 只有在有类别且有样本时才能预测 predictButton.disabled = isPredicting ? false : classifier.getNumClasses() === 0 || isAutoCollecting; checkExportAbility(); } /** * 通用启用/禁用控件 (在自动采集停止后调用) */ function enableControls() { // 重新评估所有按钮的状态 // 自动采集按钮的状态由其自身管理 predictButton.disabled = classifier.getNumClasses() === 0; importButton.disabled = false; // 导入按钮总是可以手动启用 addClassButton.disabled = false; checkExportAbility(); // 重新检查导出按钮 updatePredictionUI(); // 再次调用,确保其他按钮状态正确 } /** 检查是否可以导出模型并更新按钮状态 */ function checkExportAbility() { exportButton.disabled = isPredicting || classifier.getNumClasses() === 0 || isAutoCollecting; } function cleanup() { if (detector) detector.dispose(); if (classifier) classifier.clearAllClasses(); if (animationFrameId) cancelAnimationFrame(animationFrameId); if (autoCollectionIntervalId) clearInterval(autoCollectionIntervalId); // 清理自动采集定时器 } // --- 启动应用 --- window.onbeforeunload = cleanup; init();