/**
* =============================================================================
* 动态版 - 姿态识别与模型管理脚本 (v2.1)
* - 新增自动采集样本功能
* =============================================================================
* 功能列表:
* - 实时姿态检测 (MoveNet)
* - KNN 分类器训练
* - 实时姿态预测
* - 坐标完美对齐 (Canvas与Video重叠)
* - 动态添加/删除/重命名姿态类别
* - 模型导出为包含类别信息的 JSON 文件
* - 从 JSON 文件导入模型并恢复类别状态
* - ✅ 新增:自动采集10次样本,间隔0.3秒
* =============================================================================
*/
'use strict';
// --- 全局变量和常量 ---
const videoElement = document.getElementById('video');
const canvasElement = document.getElementById('canvas');
const canvasCtx = canvasElement.getContext('2d');
const statusElement = document.getElementById('status');
const resultElement = document.getElementById('result-text');
// UI元素
const poseClassesContainer = document.getElementById('pose-classes-container');
const addClassButton = document.getElementById('btn-add-class');
const predictButton = document.getElementById('btn-predict');
const exportButton = document.getElementById('btn-export');
const importButton = document.getElementById('btn-import');
const fileImporter = document.getElementById('file-importer');
let detector, classifier, animationFrameId;
let isPredicting = false;
let isAutoCollecting = false; // 新增:标记是否正在进行自动采集
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
const appState = {
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
nextClassId: 0 // 用于生成唯一的 classId
};
// --- 主应用逻辑 ---
/**
* 初始化应用,加载模型并设置摄像头
*/
async function init() {
try {
classifier = knnClassifier.create();
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
await setupCamera();
setupEventListeners();
mainLoop();
statusElement.innerText = "模型和摄像头已就绪!";
enableControls();
addNewClass(); // 默认创建第一个类别
} catch (error) {
console.error("初始化失败:", error);
statusElement.innerText = "初始化失败,请检查摄像头权限或刷新。";
statusElement.style.backgroundColor = '#fce8e6';
statusElement.style.color = '#d93025';
}
}
/**
* 设置和启动用户摄像头
*/
async function setupCamera() {
const stream = await navigator.mediaDevices.getUserMedia({ video: true });
videoElement.srcObject = stream;
return new Promise((resolve) => {
videoElement.onloadedmetadata = () => {
videoElement.play();
canvasElement.width = videoElement.videoWidth;
canvasElement.height = videoElement.videoHeight;
resolve();
};
});
}
/**
* 为所有交互式元素绑定事件监听器
*/
function setupEventListeners() {
addClassButton.addEventListener('click', addNewClass);
predictButton.addEventListener('click', togglePrediction);
exportButton.addEventListener('click', exportModel);
importButton.addEventListener('click', () => fileImporter.click());
fileImporter.addEventListener('change', importModel);
}
// --- 动态类别管理 ---
/**
* 动态创建一个新类别的UI元素并添加到页面
* @param {number} classId - 类别的唯一ID
* @param {string} className - 类别的名称
*/
function createClassUI(classId, className) {
const poseClassDiv = document.createElement('div');
poseClassDiv.className = 'pose-class';
poseClassDiv.dataset.classId = classId;
// 📌 修改这里:添加 btn-auto-sample 按钮
poseClassDiv.innerHTML = `
(0 样本)
`;
poseClassesContainer.appendChild(poseClassDiv);
// 为新创建的元素绑定事件
const nameInput = poseClassDiv.querySelector('.class-name-input');
nameInput.addEventListener('change', (e) => {
appState.classMap[classId] = e.target.value;
});
const autoSampleButton = poseClassDiv.querySelector('.btn-auto-sample'); // 新增
autoSampleButton.addEventListener('click', () => toggleAutoCollection(classId, autoSampleButton)); // 新增
const sampleButton = poseClassDiv.querySelector('.btn-sample');
sampleButton.addEventListener('click', () => addExample(classId));
// 初始化时根据预测状态禁用按钮
if (isPredicting) {
sampleButton.disabled = true;
autoSampleButton.disabled = true; // 新增
}
const deleteButton = poseClassDiv.querySelector('.btn-delete-class');
deleteButton.addEventListener('click', () => deleteClass(classId));
}
/**
* 添加一个新的姿态类别
*/
function addNewClass() {
const classId = appState.nextClassId;
const className = `Class ${classId + 1}`;
appState.classMap[classId] = className;
appState.nextClassId++;
createClassUI(classId, className);
}
/**
* 删除一个指定的姿态类别
* @param {number} classId - 要删除的类别的ID
*/
function deleteClass(classId) {
if (confirm(`确定要删除类别 "${appState.classMap[classId]}" 吗?所有样本都将丢失。`)) {
// 从UI中移除
const elementToRemove = poseClassesContainer.querySelector(`[data-class-id="${classId}"]`);
if (elementToRemove) elementToRemove.remove();
// 从状态和分类器中移除
delete appState.classMap[classId];
classifier.clearClass(classId);
updateSampleCounts();
updatePredictionUI(); // 检查是否还有类别可以预测
checkExportAbility();
}
}
/**
* 采集一个姿态样本并添加到KNN分类器
* @param {number} classId 类别的ID
*/
async function addExample(classId) {
const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
if (poses && poses.length > 0) {
const poseTensor = flattenPose(poses[0]);
classifier.addExample(poseTensor, classId);
poseTensor.dispose();
updateSampleCounts();
checkExportAbility();
console.log(`为类别 ${appState.classMap[classId]} 采集1个样本。`);
return true; // 表示采集成功
} else {
console.warn(`为类别 ${appState.classMap[classId]} 采集样本失败,未检测到姿态。`);
return false; // 表示采集失败
}
}
// --- 新增:自动采集逻辑 ---
let autoCollectionIntervalId = null; // 用于存储 setInterval ID
let autoCollectionCount = 0; // 计数器
const AUTO_COLLECTION_TOTAL = 10; // 总共采集次数
const AUTO_COLLECTION_INTERVAL = 300; // 间隔时间 0.3 秒
async function toggleAutoCollection(classId, buttonElement) {
if (isAutoCollecting) {
// 如果正在自动采集,则停止
stopAutoCollection(buttonElement);
} else {
// 否则,开始自动采集
startAutoCollection(classId, buttonElement);
}
}
async function startAutoCollection(classId, buttonElement) {
isAutoCollecting = true;
autoCollectionCount = 0;
// 禁用其他采集和预测按钮
predictButton.disabled = true;
exportButton.disabled = true;
importButton.disabled = true;
addClassButton.disabled = true;
document.querySelectorAll('.btn-sample, .btn-auto-sample, .btn-delete-class, .class-name-input').forEach(btn => {
if (btn !== buttonElement) { // 不禁用当前自动采集按钮
btn.disabled = true;
}
if (btn.classList.contains('class-name-input')) btn.disabled = true;
});
buttonElement.innerText = `停止采集 (0/${AUTO_COLLECTION_TOTAL})`;
buttonElement.classList.add('stop'); // 添加停止样式
const performCollection = async () => {
if (autoCollectionCount < AUTO_COLLECTION_TOTAL) {
const success = await addExample(classId); // 调用手动采集功能
if (success) {
autoCollectionCount++;
}
buttonElement.innerText = `停止采集 (${autoCollectionCount}/${AUTO_COLLECTION_TOTAL})`;
} else {
stopAutoCollection(buttonElement);
alert(`类别 "${appState.classMap[classId]}" 自动采集完成!`);
}
};
// 立即执行一次,然后设置定时器
await performCollection();
if (autoCollectionCount < AUTO_COLLECTION_TOTAL) {
autoCollectionIntervalId = setInterval(performCollection, AUTO_COLLECTION_INTERVAL);
}
}
function stopAutoCollection(buttonElement) {
clearInterval(autoCollectionIntervalId);
autoCollectionIntervalId = null;
isAutoCollecting = false;
buttonElement.innerText = '自动采集';
buttonElement.classList.remove('stop'); // 移除停止样式
// 重新启用按钮(根据应用状态)
updatePredictionUI(); // 根据预测状态重新启用/禁用相关按钮
enableControls(); // 重新启用添加类别、导出、导入按钮
}
// --- 模型与预测逻辑 ---
/**
* 开始或停止姿态预测
*/
function togglePrediction() {
if (classifier.getNumClasses() === 0) {
alert("请先为至少一个姿态采集样本后再开始预测!");
return;
}
isPredicting = !isPredicting;
updatePredictionUI();
}
/**
* 应用的主循环
*/
async function mainLoop() {
const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);
if (poses && poses.length > 0) {
drawPose(poses[0]);
// 只有当不在自动采集状态时才进行预测
if (isPredicting && classifier.getNumClasses() > 0 && !isAutoCollecting) {
const poseTensor = flattenPose(poses[0]);
const result = await classifier.predictClass(poseTensor, 3);
poseTensor.dispose();
const confidence = Math.round(result.confidences[result.label] * 100);
const predictedClassName = appState.classMap[result.label] || '未知类别';
resultElement.innerText = `姿态: ${predictedClassName} (${confidence}%)`;
} else if (isAutoCollecting) {
resultElement.innerText = "自动采集中...";
}
} else {
resultElement.innerText = "未检测到姿态";
}
animationFrameId = requestAnimationFrame(mainLoop);
}
// --- 模型管理函数 (已更新以支持动态类别) ---
/**
* 导出KNN模型为包含类别信息的JSON文件
*/
function exportModel() {
if (classifier.getNumClasses() === 0) {
alert('模型中还没有任何样本,无法导出!');
return;
}
const dataset = classifier.getClassifierDataset();
const datasetObj = {};
Object.keys(dataset).forEach((key) => {
const data = dataset[key];
datasetObj[key] = data.arraySync();
});
const modelData = {
classMap: appState.classMap,
dataset: datasetObj
};
const jsonStr = JSON.stringify(modelData);
const blob = new Blob([jsonStr], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `pose-knn-model.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
/**
* 从JSON文件导入KNN模型并恢复类别状态
* @param {Event} event
*/
function importModel(event) {
const file = event.target.files[0];
if (!file) return;
const reader = new FileReader();
reader.onload = (e) => {
try {
const modelData = JSON.parse(e.target.result);
if (!modelData.classMap || !modelData.dataset) {
throw new Error("无效的模型文件格式。");
}
// 1. 清理现有状态
classifier.clearAllClasses();
poseClassesContainer.innerHTML = '';
appState.classMap = {};
// 2. 加载新状态
appState.classMap = modelData.classMap;
const classIds = Object.keys(appState.classMap).map(Number);
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
// 3. 恢复UI
classIds.forEach(id => {
createClassUI(id, appState.classMap[id]);
});
// 4. 加载模型数据
const newDataset = {};
Object.keys(modelData.dataset).forEach((key) => {
newDataset[key] = tf.tensor(modelData.dataset[key]);
});
classifier.setClassifierDataset(newDataset);
updateSampleCounts();
checkExportAbility();
alert('模型导入成功!');
} catch (error) {
console.error("导入模型失败:", error);
alert(`导入失败!请确保文件是正确的模型JSON文件。\n错误: ${error.message}`);
} finally {
fileImporter.value = '';
}
};
reader.readAsText(file);
}
// --- 辅助和UI更新函数 ---
function flattenPose(pose) {
const keypoints = pose.keypoints.map(p => [p.x / videoElement.videoWidth, p.y / videoElement.videoHeight]).flat();
return tf.tensor(keypoints);
}
function drawPose(pose) {
// 绘制关键点和骨骼...
if (pose.keypoints) {
// 绘制关键点
for (const keypoint of pose.keypoints) {
if (keypoint.score > 0.3) {
canvasCtx.beginPath();
canvasCtx.arc(keypoint.x, keypoint.y, 5, 0, 2 * Math.PI);
canvasCtx.fillStyle = '#1a73e8';
canvasCtx.fill();
}
}
// 绘制骨骼连接线
const adjacentPairs = poseDetection.util.getAdjacentPairs(poseDetection.SupportedModels.MoveNet);
adjacentPairs.forEach(([i, j]) => {
const kp1 = pose.keypoints[i];
const kp2 = pose.keypoints[j];
if (kp1.score > 0.3 && kp2.score > 0.3) {
canvasCtx.beginPath();
canvasCtx.moveTo(kp1.x, kp1.y);
canvasCtx.lineTo(kp2.x, kp2.y);
canvasCtx.strokeStyle = 'blue';
canvasCtx.lineWidth = 2;
canvasCtx.stroke();
}
});
}
}
/**
* 更新所有类别UI上的样本数量
*/
function updateSampleCounts() {
const dataset = classifier.getClassifierDataset();
const allClassElements = document.querySelectorAll('.pose-class');
allClassElements.forEach(el => {
const classId = parseInt(el.dataset.classId, 10);
const classInfo = dataset[classId];
const count = classInfo ? classInfo.shape[0] : 0;
el.querySelector('.sample-count').innerText = `(${count} 样本)`;
});
}
/**
* 根据状态更新UI
*/
function updatePredictionUI() {
// 禁用所有采集按钮(包括手动和自动)和删除按钮
document.querySelectorAll('.btn-sample, .btn-auto-sample, .btn-delete-class').forEach(btn => btn.disabled = isPredicting || isAutoCollecting);
// 禁用添加类别和导入模型的按钮
addClassButton.disabled = isPredicting || isAutoCollecting;
importButton.disabled = isPredicting || isAutoCollecting;
// 禁用类别名称输入框
document.querySelectorAll('.class-name-input').forEach(input => input.disabled = isPredicting || isAutoCollecting);
if (isPredicting) {
predictButton.innerText = "停止预测";
predictButton.classList.add('stop');
resultElement.innerText = "正在分析...";
} else {
predictButton.innerText = "开始预测";
predictButton.classList.remove('stop');
resultElement.innerText = "已停止";
}
// 只有在有类别且有样本时才能预测
predictButton.disabled = isPredicting ? false : classifier.getNumClasses() === 0 || isAutoCollecting;
checkExportAbility();
}
/**
* 通用启用/禁用控件 (在自动采集停止后调用)
*/
function enableControls() {
// 重新评估所有按钮的状态
// 自动采集按钮的状态由其自身管理
predictButton.disabled = classifier.getNumClasses() === 0;
importButton.disabled = false; // 导入按钮总是可以手动启用
addClassButton.disabled = false;
checkExportAbility(); // 重新检查导出按钮
updatePredictionUI(); // 再次调用,确保其他按钮状态正确
}
/** 检查是否可以导出模型并更新按钮状态 */
function checkExportAbility() {
exportButton.disabled = isPredicting || classifier.getNumClasses() === 0 || isAutoCollecting;
}
function cleanup() {
if (detector) detector.dispose();
if (classifier) classifier.clearAllClasses();
if (animationFrameId) cancelAnimationFrame(animationFrameId);
if (autoCollectionIntervalId) clearInterval(autoCollectionIntervalId); // 清理自动采集定时器
}
// --- 启动应用 ---
window.onbeforeunload = cleanup;
init();