2025-08-20 16:41:21 +08:00

503 lines
18 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* =============================================================================
* 动态版 - 姿态识别与模型管理脚本 (v2.1)
* - 新增自动采集样本功能
* =============================================================================
* 功能列表:
* - 实时姿态检测 (MoveNet)
* - KNN 分类器训练
* - 实时姿态预测
* - 坐标完美对齐 (Canvas与Video重叠)
* - 动态添加/删除/重命名姿态类别
* - 模型导出为包含类别信息的 JSON 文件
* - 从 JSON 文件导入模型并恢复类别状态
* - ✅ 新增自动采集10次样本间隔0.3秒
* =============================================================================
*/
'use strict';
// --- 全局变量和常量 ---
const videoElement = document.getElementById('video');
const canvasElement = document.getElementById('canvas');
const canvasCtx = canvasElement.getContext('2d');
const statusElement = document.getElementById('status');
const resultElement = document.getElementById('result-text');
// UI元素
const poseClassesContainer = document.getElementById('pose-classes-container');
const addClassButton = document.getElementById('btn-add-class');
const predictButton = document.getElementById('btn-predict');
const exportButton = document.getElementById('btn-export');
const importButton = document.getElementById('btn-import');
const fileImporter = document.getElementById('file-importer');
let detector, classifier, animationFrameId;
let isPredicting = false;
let isAutoCollecting = false; // 新增:标记是否正在进行自动采集
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
const appState = {
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
nextClassId: 0 // 用于生成唯一的 classId
};
// --- 主应用逻辑 ---
/**
* 初始化应用,加载模型并设置摄像头
*/
async function init() {
try {
classifier = knnClassifier.create();
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
await setupCamera();
setupEventListeners();
mainLoop();
statusElement.innerText = "模型和摄像头已就绪!";
enableControls();
addNewClass(); // 默认创建第一个类别
} catch (error) {
console.error("初始化失败:", error);
statusElement.innerText = "初始化失败,请检查摄像头权限或刷新。";
statusElement.style.backgroundColor = '#fce8e6';
statusElement.style.color = '#d93025';
}
}
/**
* 设置和启动用户摄像头
*/
async function setupCamera() {
const stream = await navigator.mediaDevices.getUserMedia({ video: true });
videoElement.srcObject = stream;
return new Promise((resolve) => {
videoElement.onloadedmetadata = () => {
videoElement.play();
canvasElement.width = videoElement.videoWidth;
canvasElement.height = videoElement.videoHeight;
resolve();
};
});
}
/**
* 为所有交互式元素绑定事件监听器
*/
function setupEventListeners() {
addClassButton.addEventListener('click', addNewClass);
predictButton.addEventListener('click', togglePrediction);
exportButton.addEventListener('click', exportModel);
importButton.addEventListener('click', () => fileImporter.click());
fileImporter.addEventListener('change', importModel);
}
// --- 动态类别管理 ---
/**
* 动态创建一个新类别的UI元素并添加到页面
* @param {number} classId - 类别的唯一ID
* @param {string} className - 类别的名称
*/
function createClassUI(classId, className) {
const poseClassDiv = document.createElement('div');
poseClassDiv.className = 'pose-class';
poseClassDiv.dataset.classId = classId;
// 📌 修改这里:添加 btn-auto-sample 按钮
poseClassDiv.innerHTML = `
<div class="class-info">
<input type="text" class="class-name-input" value="${className}" data-class-id="${classId}">
<span class="sample-count">(0 样本)</span>
</div>
<div class="class-actions">
<button class="btn-auto-sample" data-class-id="${classId}">自动采集</button>
<button class="btn-sample" data-class-id="${classId}">采集样本</button>
<button class="btn-delete-class" title="删除类别" data-class-id="${classId}">×</button>
</div>
`;
poseClassesContainer.appendChild(poseClassDiv);
// 为新创建的元素绑定事件
const nameInput = poseClassDiv.querySelector('.class-name-input');
nameInput.addEventListener('change', (e) => {
appState.classMap[classId] = e.target.value;
});
const autoSampleButton = poseClassDiv.querySelector('.btn-auto-sample'); // 新增
autoSampleButton.addEventListener('click', () => toggleAutoCollection(classId, autoSampleButton)); // 新增
const sampleButton = poseClassDiv.querySelector('.btn-sample');
sampleButton.addEventListener('click', () => addExample(classId));
// 初始化时根据预测状态禁用按钮
if (isPredicting) {
sampleButton.disabled = true;
autoSampleButton.disabled = true; // 新增
}
const deleteButton = poseClassDiv.querySelector('.btn-delete-class');
deleteButton.addEventListener('click', () => deleteClass(classId));
}
/**
* 添加一个新的姿态类别
*/
function addNewClass() {
const classId = appState.nextClassId;
const className = `Class ${classId + 1}`;
appState.classMap[classId] = className;
appState.nextClassId++;
createClassUI(classId, className);
}
/**
* 删除一个指定的姿态类别
* @param {number} classId - 要删除的类别的ID
*/
function deleteClass(classId) {
if (confirm(`确定要删除类别 "${appState.classMap[classId]}" 吗?所有样本都将丢失。`)) {
// 从UI中移除
const elementToRemove = poseClassesContainer.querySelector(`[data-class-id="${classId}"]`);
if (elementToRemove) elementToRemove.remove();
// 从状态和分类器中移除
delete appState.classMap[classId];
classifier.clearClass(classId);
updateSampleCounts();
updatePredictionUI(); // 检查是否还有类别可以预测
checkExportAbility();
}
}
/**
* 采集一个姿态样本并添加到KNN分类器
* @param {number} classId 类别的ID
*/
async function addExample(classId) {
const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
if (poses && poses.length > 0) {
const poseTensor = flattenPose(poses[0]);
classifier.addExample(poseTensor, classId);
poseTensor.dispose();
updateSampleCounts();
checkExportAbility();
console.log(`为类别 ${appState.classMap[classId]} 采集1个样本。`);
return true; // 表示采集成功
} else {
console.warn(`为类别 ${appState.classMap[classId]} 采集样本失败,未检测到姿态。`);
return false; // 表示采集失败
}
}
// --- 新增:自动采集逻辑 ---
let autoCollectionIntervalId = null; // 用于存储 setInterval ID
let autoCollectionCount = 0; // 计数器
const AUTO_COLLECTION_TOTAL = 10; // 总共采集次数
const AUTO_COLLECTION_INTERVAL = 300; // 间隔时间 0.3 秒
async function toggleAutoCollection(classId, buttonElement) {
if (isAutoCollecting) {
// 如果正在自动采集,则停止
stopAutoCollection(buttonElement);
} else {
// 否则,开始自动采集
startAutoCollection(classId, buttonElement);
}
}
async function startAutoCollection(classId, buttonElement) {
isAutoCollecting = true;
autoCollectionCount = 0;
// 禁用其他采集和预测按钮
predictButton.disabled = true;
exportButton.disabled = true;
importButton.disabled = true;
addClassButton.disabled = true;
document.querySelectorAll('.btn-sample, .btn-auto-sample, .btn-delete-class, .class-name-input').forEach(btn => {
if (btn !== buttonElement) { // 不禁用当前自动采集按钮
btn.disabled = true;
}
if (btn.classList.contains('class-name-input')) btn.disabled = true;
});
buttonElement.innerText = `停止采集 (0/${AUTO_COLLECTION_TOTAL})`;
buttonElement.classList.add('stop'); // 添加停止样式
const performCollection = async () => {
if (autoCollectionCount < AUTO_COLLECTION_TOTAL) {
const success = await addExample(classId); // 调用手动采集功能
if (success) {
autoCollectionCount++;
}
buttonElement.innerText = `停止采集 (${autoCollectionCount}/${AUTO_COLLECTION_TOTAL})`;
} else {
stopAutoCollection(buttonElement);
alert(`类别 "${appState.classMap[classId]}" 自动采集完成!`);
}
};
// 立即执行一次,然后设置定时器
await performCollection();
if (autoCollectionCount < AUTO_COLLECTION_TOTAL) {
autoCollectionIntervalId = setInterval(performCollection, AUTO_COLLECTION_INTERVAL);
}
}
function stopAutoCollection(buttonElement) {
clearInterval(autoCollectionIntervalId);
autoCollectionIntervalId = null;
isAutoCollecting = false;
buttonElement.innerText = '自动采集';
buttonElement.classList.remove('stop'); // 移除停止样式
// 重新启用按钮(根据应用状态)
updatePredictionUI(); // 根据预测状态重新启用/禁用相关按钮
enableControls(); // 重新启用添加类别、导出、导入按钮
}
// --- 模型与预测逻辑 ---
/**
* 开始或停止姿态预测
*/
function togglePrediction() {
if (classifier.getNumClasses() === 0) {
alert("请先为至少一个姿态采集样本后再开始预测!");
return;
}
isPredicting = !isPredicting;
updatePredictionUI();
}
/**
* 应用的主循环
*/
async function mainLoop() {
const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);
if (poses && poses.length > 0) {
drawPose(poses[0]);
// 只有当不在自动采集状态时才进行预测
if (isPredicting && classifier.getNumClasses() > 0 && !isAutoCollecting) {
const poseTensor = flattenPose(poses[0]);
const result = await classifier.predictClass(poseTensor, 3);
poseTensor.dispose();
const confidence = Math.round(result.confidences[result.label] * 100);
const predictedClassName = appState.classMap[result.label] || '未知类别';
resultElement.innerText = `姿态: ${predictedClassName} (${confidence}%)`;
} else if (isAutoCollecting) {
resultElement.innerText = "自动采集中...";
}
} else {
resultElement.innerText = "未检测到姿态";
}
animationFrameId = requestAnimationFrame(mainLoop);
}
// --- 模型管理函数 (已更新以支持动态类别) ---
/**
* 导出KNN模型为包含类别信息的JSON文件
*/
function exportModel() {
if (classifier.getNumClasses() === 0) {
alert('模型中还没有任何样本,无法导出!');
return;
}
const dataset = classifier.getClassifierDataset();
const datasetObj = {};
Object.keys(dataset).forEach((key) => {
const data = dataset[key];
datasetObj[key] = data.arraySync();
});
const modelData = {
classMap: appState.classMap,
dataset: datasetObj
};
const jsonStr = JSON.stringify(modelData);
const blob = new Blob([jsonStr], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `pose-knn-model.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
/**
* 从JSON文件导入KNN模型并恢复类别状态
* @param {Event} event
*/
function importModel(event) {
const file = event.target.files[0];
if (!file) return;
const reader = new FileReader();
reader.onload = (e) => {
try {
const modelData = JSON.parse(e.target.result);
if (!modelData.classMap || !modelData.dataset) {
throw new Error("无效的模型文件格式。");
}
// 1. 清理现有状态
classifier.clearAllClasses();
poseClassesContainer.innerHTML = '';
appState.classMap = {};
// 2. 加载新状态
appState.classMap = modelData.classMap;
const classIds = Object.keys(appState.classMap).map(Number);
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
// 3. 恢复UI
classIds.forEach(id => {
createClassUI(id, appState.classMap[id]);
});
// 4. 加载模型数据
const newDataset = {};
Object.keys(modelData.dataset).forEach((key) => {
newDataset[key] = tf.tensor(modelData.dataset[key]);
});
classifier.setClassifierDataset(newDataset);
updateSampleCounts();
checkExportAbility();
alert('模型导入成功!');
} catch (error) {
console.error("导入模型失败:", error);
alert(`导入失败请确保文件是正确的模型JSON文件。\n错误: ${error.message}`);
} finally {
fileImporter.value = '';
}
};
reader.readAsText(file);
}
// --- 辅助和UI更新函数 ---
function flattenPose(pose) {
const keypoints = pose.keypoints.map(p => [p.x / videoElement.videoWidth, p.y / videoElement.videoHeight]).flat();
return tf.tensor(keypoints);
}
function drawPose(pose) {
// 绘制关键点和骨骼...
if (pose.keypoints) {
// 绘制关键点
for (const keypoint of pose.keypoints) {
if (keypoint.score > 0.3) {
canvasCtx.beginPath();
canvasCtx.arc(keypoint.x, keypoint.y, 5, 0, 2 * Math.PI);
canvasCtx.fillStyle = '#1a73e8';
canvasCtx.fill();
}
}
// 绘制骨骼连接线
const adjacentPairs = poseDetection.util.getAdjacentPairs(poseDetection.SupportedModels.MoveNet);
adjacentPairs.forEach(([i, j]) => {
const kp1 = pose.keypoints[i];
const kp2 = pose.keypoints[j];
if (kp1.score > 0.3 && kp2.score > 0.3) {
canvasCtx.beginPath();
canvasCtx.moveTo(kp1.x, kp1.y);
canvasCtx.lineTo(kp2.x, kp2.y);
canvasCtx.strokeStyle = 'blue';
canvasCtx.lineWidth = 2;
canvasCtx.stroke();
}
});
}
}
/**
* 更新所有类别UI上的样本数量
*/
function updateSampleCounts() {
const dataset = classifier.getClassifierDataset();
const allClassElements = document.querySelectorAll('.pose-class');
allClassElements.forEach(el => {
const classId = parseInt(el.dataset.classId, 10);
const classInfo = dataset[classId];
const count = classInfo ? classInfo.shape[0] : 0;
el.querySelector('.sample-count').innerText = `(${count} 样本)`;
});
}
/**
* 根据状态更新UI
*/
function updatePredictionUI() {
// 禁用所有采集按钮(包括手动和自动)和删除按钮
document.querySelectorAll('.btn-sample, .btn-auto-sample, .btn-delete-class').forEach(btn => btn.disabled = isPredicting || isAutoCollecting);
// 禁用添加类别和导入模型的按钮
addClassButton.disabled = isPredicting || isAutoCollecting;
importButton.disabled = isPredicting || isAutoCollecting;
// 禁用类别名称输入框
document.querySelectorAll('.class-name-input').forEach(input => input.disabled = isPredicting || isAutoCollecting);
if (isPredicting) {
predictButton.innerText = "停止预测";
predictButton.classList.add('stop');
resultElement.innerText = "正在分析...";
} else {
predictButton.innerText = "开始预测";
predictButton.classList.remove('stop');
resultElement.innerText = "已停止";
}
// 只有在有类别且有样本时才能预测
predictButton.disabled = isPredicting ? false : classifier.getNumClasses() === 0 || isAutoCollecting;
checkExportAbility();
}
/**
* 通用启用/禁用控件 (在自动采集停止后调用)
*/
function enableControls() {
// 重新评估所有按钮的状态
// 自动采集按钮的状态由其自身管理
predictButton.disabled = classifier.getNumClasses() === 0;
importButton.disabled = false; // 导入按钮总是可以手动启用
addClassButton.disabled = false;
checkExportAbility(); // 重新检查导出按钮
updatePredictionUI(); // 再次调用,确保其他按钮状态正确
}
/** 检查是否可以导出模型并更新按钮状态 */
function checkExportAbility() {
exportButton.disabled = isPredicting || classifier.getNumClasses() === 0 || isAutoCollecting;
}
function cleanup() {
if (detector) detector.dispose();
if (classifier) classifier.clearAllClasses();
if (animationFrameId) cancelAnimationFrame(animationFrameId);
if (autoCollectionIntervalId) clearInterval(autoCollectionIntervalId); // 清理自动采集定时器
}
// --- 启动应用 ---
window.onbeforeunload = cleanup;
init();