406 lines
14 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* =============================================================================
* 动态版 - 姿态识别与模型管理脚本 (v2.0)
* =============================================================================
* 功能列表:
* - 实时姿态检测 (MoveNet)
* - KNN 分类器训练
* - 实时姿态预测
* - 坐标完美对齐 (Canvas与Video重叠)
* - ✅ 动态添加/删除/重命名姿态类别
* - ✅ 模型导出为包含类别信息的 JSON 文件
* - ✅ 从 JSON 文件导入模型并恢复类别状态
* =============================================================================
*/
'use strict';
// --- Global DOM references and shared mutable state ---
const videoElement = document.getElementById('video');
const canvasElement = document.getElementById('canvas');
const canvasCtx = canvasElement.getContext('2d');
const statusElement = document.getElementById('status');
const resultElement = document.getElementById('result-text');
// Controls for class management, prediction and model import/export
const poseClassesContainer = document.getElementById('pose-classes-container');
const addClassButton = document.getElementById('btn-add-class');
const predictButton = document.getElementById('btn-predict');
const exportButton = document.getElementById('btn-export');
const importButton = document.getElementById('btn-import');
const fileImporter = document.getElementById('file-importer');
// detector and classifier are created in init(); animationFrameId holds the latest
// requestAnimationFrame id scheduled by mainLoop().
let detector, classifier, animationFrameId;
// True while live classification is running (flipped by togglePrediction()).
let isPredicting = false;
// Central store for the dynamic class state shared by the UI, the classifier labels
// and model export/import.
const appState = {
classMap: {}, // classId -> display name, e.g. {0: 'Class 1', 1: 'Class 2'}
nextClassId: 0 // id that addNewClass() will assign next
};
// --- 主应用逻辑 ---
/**
 * Boots the app: creates the KNN classifier and the MoveNet (single-pose,
 * lightning) detector, starts the camera, wires the controls, starts the
 * render loop and creates the first class.
 * Any failure is logged and reported in the status element with error colors.
 */
async function init() {
  try {
    classifier = knnClassifier.create();
    detector = await poseDetection.createDetector(
      poseDetection.SupportedModels.MoveNet,
      { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING },
    );
    await setupCamera();
    setupEventListeners();
    // Not awaited: the loop keeps itself running via requestAnimationFrame.
    void mainLoop();
    statusElement.innerText = "模型和摄像头已就绪!";
    enableControls();
    // Start with one empty class so the user can begin collecting samples.
    addNewClass();
  } catch (error) {
    console.error("初始化失败:", error);
    statusElement.innerText = "初始化失败,请检查摄像头权限或刷新。";
    Object.assign(statusElement.style, { backgroundColor: '#fce8e6', color: '#d93025' });
  }
}
/**
 * Requests the user's camera, attaches the stream to the <video> element,
 * sizes the overlay canvas to the video's pixel dimensions and starts playback.
 * @returns {Promise<void>} resolves once metadata is loaded and playback has
 *   started; rejects if camera access is denied or playback cannot start, so
 *   init() can report the failure.
 */
async function setupCamera() {
  const stream = await navigator.mediaDevices.getUserMedia({ video: true });
  videoElement.srcObject = stream;
  await new Promise((resolve) => {
    videoElement.onloadedmetadata = () => resolve();
  });
  // Match the canvas to the video so keypoint pixel coordinates line up with the image.
  canvasElement.width = videoElement.videoWidth;
  canvasElement.height = videoElement.videoHeight;
  // play() returns a promise that rejects when playback is blocked; awaiting it
  // surfaces that to the caller instead of leaving an unhandled rejection.
  await videoElement.play();
}
/**
 * Wires the static toolbar controls to their handlers. The import button only
 * opens the hidden file input; the input's `change` event performs the import.
 */
function setupEventListeners() {
  const bindings = [
    [addClassButton, 'click', addNewClass],
    [predictButton, 'click', togglePrediction],
    [exportButton, 'click', exportModel],
    [importButton, 'click', () => fileImporter.click()],
    [fileImporter, 'change', importModel],
  ];
  for (const [target, eventType, handler] of bindings) {
    target.addEventListener(eventType, handler);
  }
}
// --- 动态类别管理 ---
/**
 * Builds the UI row for one pose class and appends it to the classes container:
 * an editable name input, a sample counter, a "collect sample" button and a
 * delete button, all tagged with data-class-id.
 * The class name is written through the DOM, never through HTML, because names
 * can come from an imported JSON file.
 * @param {number} classId - unique id of the class
 * @param {string} className - display name shown in the name input
 */
function createClassUI(classId, className) {
  const poseClassDiv = document.createElement('div');
  poseClassDiv.className = 'pose-class';
  poseClassDiv.dataset.classId = classId;
  // Static markup only: nothing caller-supplied is interpolated into HTML.
  poseClassDiv.innerHTML = `
<div class="class-info">
<input type="text" class="class-name-input">
<span class="sample-count">(0 样本)</span>
</div>
<div class="class-actions">
<button class="btn-sample">采集样本</button>
<button class="btn-delete-class" title="删除类别">×</button>
</div>
`;
  const nameInput = poseClassDiv.querySelector('.class-name-input');
  const sampleButton = poseClassDiv.querySelector('.btn-sample');
  const deleteButton = poseClassDiv.querySelector('.btn-delete-class');
  // setAttribute never parses its value as markup, so any name is rendered literally.
  nameInput.setAttribute('value', className);
  for (const control of [nameInput, sampleButton, deleteButton]) {
    control.dataset.classId = classId;
    // Same lock-out that updatePredictionUI() applies to existing rows while predicting.
    if (isPredicting) control.disabled = true;
  }
  nameInput.addEventListener('change', (e) => {
    appState.classMap[classId] = e.target.value;
  });
  sampleButton.addEventListener('click', () => addExample(classId));
  deleteButton.addEventListener('click', () => deleteClass(classId));
  poseClassesContainer.appendChild(poseClassDiv);
}
/**
 * Registers a new pose class with the default name "Class N" (N = id + 1),
 * advances the id counter and renders the class's UI row.
 */
function addNewClass() {
  const id = appState.nextClassId++;
  const label = `Class ${id + 1}`;
  appState.classMap[id] = label;
  createClassUI(id, label);
}
/**
 * Deletes a pose class after the user confirms. It removes the class's row
 * from the page, its entry in appState.classMap and any samples it holds in
 * the classifier, then refreshes the dependent UI.
 * @param {number} classId - id of the class to delete
 */
function deleteClass(classId) {
  if (!confirm(`确定要删除类别 "${appState.classMap[classId]}" 吗?所有样本都将丢失。`)) return;
  // Target the row container explicitly; its inputs/buttons carry the same data-class-id.
  const row = poseClassesContainer.querySelector(`.pose-class[data-class-id="${classId}"]`);
  if (row) row.remove();
  delete appState.classMap[classId];
  // KNNClassifier.clearClass() throws for a label with no stored examples (a class
  // that never received a sample), so only clear labels present in the dataset.
  if (classifier.getClassifierDataset()[classId] != null) {
    classifier.clearClass(classId);
  }
  // Prediction cannot continue without any trained class; stop it rather than
  // leave a running state with nothing to classify.
  if (isPredicting && classifier.getNumClasses() === 0) {
    isPredicting = false;
  }
  updateSampleCounts();
  updatePredictionUI();
  checkExportAbility();
}
/**
 * Captures the current pose from the video and adds it to the KNN classifier
 * as one training example for the given class.
 * Logs a warning and adds nothing when no pose is detected.
 * @param {number} classId - id of the class receiving the sample
 */
async function addExample(classId) {
  const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
  // The class may have been deleted while pose estimation was in flight; adding the
  // sample then would recreate it in the classifier with no UI row and no name.
  if (!(classId in appState.classMap)) return;
  if (!poses || poses.length === 0) {
    console.warn(`为类别 ${appState.classMap[classId]} 采集样本失败,未检测到姿态。`);
    return;
  }
  const poseTensor = flattenPose(poses[0]);
  try {
    classifier.addExample(poseTensor, classId);
  } finally {
    // Free the input tensor even if addExample throws.
    poseTensor.dispose();
  }
  updateSampleCounts();
  checkExportAbility();
  // A trained class now exists. Re-enable prediction in case an earlier
  // deletion (deleteClass -> updatePredictionUI) disabled the predict button.
  predictButton.disabled = false;
}
// --- 模型与预测逻辑 ---
/**
 * Starts or stops live pose prediction and refreshes the UI.
 * Starting requires at least one class with samples; stopping is always allowed.
 */
function togglePrediction() {
  // Only starting needs trained data; the check must not block the stop action.
  if (!isPredicting && classifier.getNumClasses() === 0) {
    alert("请先为至少一个姿态采集样本后再开始预测!");
    return;
  }
  isPredicting = !isPredicting;
  updatePredictionUI();
}
/**
 * Per-frame loop. It estimates the pose, redraws the skeleton overlay and,
 * while prediction is on, classifies the pose and shows the top class with
 * its confidence.
 * The next frame is always scheduled (in `finally`), so one failing frame does
 * not freeze the loop.
 */
async function mainLoop() {
  try {
    const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
    canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);
    if (poses && poses.length > 0) {
      drawPose(poses[0]);
      if (isPredicting && classifier.getNumClasses() > 0) {
        const poseTensor = flattenPose(poses[0]);
        let result;
        try {
          result = await classifier.predictClass(poseTensor, 3);
        } finally {
          poseTensor.dispose();
        }
        // Prediction may have been stopped while predictClass was awaited; don't
        // overwrite the "stopped" text with a stale result.
        if (isPredicting) {
          const confidence = Math.round(result.confidences[result.label] * 100);
          // Look up the current display name for the predicted class id.
          const predictedClassName = appState.classMap[result.label] || '未知类别';
          resultElement.innerText = `姿态: ${predictedClassName} (${confidence}%)`;
        }
      }
    }
  } catch (error) {
    console.error("主循环出错:", error);
  } finally {
    animationFrameId = requestAnimationFrame(mainLoop);
  }
}
// --- 模型管理函数 (已更新以支持动态类别) ---
/**
 * Exports the KNN model as a downloadable JSON file, `pose-knn-model.json`.
 * The file has the shape `{ classMap, dataset }`: `classMap` is
 * appState.classMap and `dataset` maps each label to its example matrix as
 * nested arrays.
 * Shows an alert and exports nothing when the classifier has no examples.
 */
function exportModel() {
  if (classifier.getNumClasses() === 0) {
    alert('模型中还没有任何样本,无法导出!');
    return;
  }
  const datasetObj = {};
  for (const [label, matrix] of Object.entries(classifier.getClassifierDataset())) {
    datasetObj[label] = matrix.arraySync();
  }
  const modelData = { classMap: appState.classMap, dataset: datasetObj };
  const blob = new Blob([JSON.stringify(modelData)], { type: "application/json" });
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = `pose-knn-model.json`;
  document.body.appendChild(a);
  a.click();
  document.body.removeChild(a);
  // Revoke on a later task instead of synchronously after click(): some browsers
  // read the blob URL asynchronously, and an immediate revoke can cancel the download.
  setTimeout(() => URL.revokeObjectURL(url), 0);
}
/**
 * Imports a model JSON file written by exportModel() and replaces the current
 * classes, UI rows and classifier data with its contents.
 * The file is parsed and validated before any existing state is touched. This
 * includes creating every tensor, so a malformed file leaves the current model
 * intact.
 * @param {Event} event - `change` event from the hidden file input
 */
function importModel(event) {
  const file = event.target.files[0];
  if (!file) return;
  const reportFailure = (error) => {
    console.error("导入模型失败:", error);
    alert(`导入失败请确保文件是正确的模型JSON文件。\n错误: ${error.message}`);
  };
  const reader = new FileReader();
  reader.onload = (e) => {
    try {
      const { classMap, dataset } = parseModelFile(e.target.result);
      // Validation passed: only now is it safe to drop the current model.
      classifier.clearAllClasses();
      poseClassesContainer.innerHTML = '';
      appState.classMap = classMap;
      const classIds = Object.keys(classMap).map(Number);
      appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
      classIds.forEach((id) => createClassUI(id, classMap[id]));
      classifier.setClassifierDataset(dataset);
      updateSampleCounts();
      checkExportAbility();
      // Same rule updatePredictionUI() applies when idle: predict needs trained classes.
      predictButton.disabled = classifier.getNumClasses() === 0;
      alert('模型导入成功!');
    } catch (error) {
      reportFailure(error);
    } finally {
      // Reset so selecting the same file again still fires `change`.
      fileImporter.value = '';
    }
  };
  reader.onerror = () => {
    reportFailure(reader.error || new Error('FileReader error'));
    fileImporter.value = '';
  };
  reader.readAsText(file);
}

/**
 * Parses the text of an exported model file into a class-name map and a
 * dataset of 2-D tensors, validating it along the way.
 * A class id that has samples but no name gets the default "Class N" name,
 * the same one used by addNewClass(). If parsing fails, any tensors already
 * created are disposed.
 * @param {string} text - raw file contents
 * @returns {{classMap: Object<number, string>, dataset: Object<number, tf.Tensor2D>}}
 *   the caller takes ownership of the tensors in `dataset`
 * @throws {Error} on invalid JSON, missing `classMap`/`dataset` objects, class ids
 *   that are not non-negative integers, non-matrix sample data, or matrices whose
 *   feature widths differ
 */
function parseModelFile(text) {
  const modelData = JSON.parse(text);
  const isObject = (value) => typeof value === 'object' && value !== null;
  if (!isObject(modelData) || !isObject(modelData.classMap) || !isObject(modelData.dataset)) {
    throw new Error("无效的模型文件格式。");
  }
  // Class ids must be non-negative integers: they become data-class-id values and
  // feed nextClassId, where NaN would break every later addNewClass().
  const toClassId = (key) => {
    const id = Number(key);
    if (!/^\d+$/.test(key) || !Number.isSafeInteger(id)) {
      throw new Error(`无效的类别ID: ${key}`);
    }
    return id;
  };
  const classMap = {};
  for (const [key, name] of Object.entries(modelData.classMap)) {
    classMap[toClassId(key)] = String(name);
  }
  const dataset = {};
  try {
    let featureWidth = null;
    for (const [key, rows] of Object.entries(modelData.dataset)) {
      const id = toClassId(key);
      // tensor2d() throws unless `rows` is a rectangular number[][].
      const matrix = tf.tensor2d(rows);
      dataset[id] = matrix;
      if (featureWidth === null) featureWidth = matrix.shape[1];
      if (matrix.shape[1] !== featureWidth) {
        throw new Error(`类别 ${id} 的样本维度不一致。`);
      }
      if (!(id in classMap)) classMap[id] = `Class ${id + 1}`;
    }
  } catch (error) {
    Object.values(dataset).forEach((matrix) => matrix.dispose());
    throw error;
  }
  return { classMap, dataset };
}
// --- 辅助和UI更新函数 ---
/**
 * Converts a pose into the classifier's feature vector [x0, y0, x1, y1, ...].
 * Each x is divided by the video width and each y by the video height.
 * @param {object} pose - pose whose `keypoints` entries have `x` and `y`
 * @returns {tf.Tensor} 1-D tensor; the caller is responsible for dispose()
 */
function flattenPose(pose) {
  const features = pose.keypoints.flatMap(({ x, y }) => [
    x / videoElement.videoWidth,
    y / videoElement.videoHeight,
  ]);
  return tf.tensor(features);
}
/**
 * Draws one pose on the overlay canvas. A keypoint is drawn as a dot only
 * when its score exceeds 0.3. A MoveNet bone is drawn only when both of its
 * endpoints pass that threshold. All dots are drawn first, then the bones on
 * top of them.
 * @param {object} pose - pose from detector.estimatePoses(); reads keypoints[i].{x, y, score}
 */
function drawPose(pose) {
  const { keypoints } = pose;
  if (!keypoints) return;
  const isVisible = (kp) => kp.score > 0.3;

  keypoints.filter(isVisible).forEach(({ x, y }) => {
    canvasCtx.beginPath();
    canvasCtx.arc(x, y, 5, 0, 2 * Math.PI);
    canvasCtx.fillStyle = '#1a73e8';
    canvasCtx.fill();
  });

  const bones = poseDetection.util.getAdjacentPairs(poseDetection.SupportedModels.MoveNet);
  for (const [startIdx, endIdx] of bones) {
    const start = keypoints[startIdx];
    const end = keypoints[endIdx];
    if (!isVisible(start) || !isVisible(end)) continue;
    canvasCtx.beginPath();
    canvasCtx.moveTo(start.x, start.y);
    canvasCtx.lineTo(end.x, end.y);
    canvasCtx.strokeStyle = 'blue';
    canvasCtx.lineWidth = 2;
    canvasCtx.stroke();
  }
}
/**
 * Refreshes the "(N 样本)" counter of every class row. N is the number of rows
 * in that class's example matrix in the classifier dataset, or 0 when the
 * class has no examples.
 */
function updateSampleCounts() {
  const matrices = classifier.getClassifierDataset();
  for (const row of document.querySelectorAll('.pose-class')) {
    const matrix = matrices[Number.parseInt(row.dataset.classId, 10)];
    const total = matrix ? matrix.shape[0] : 0;
    row.querySelector('.sample-count').innerText = `(${total} 样本)`;
  }
}
/**
 * Syncs the controls with the prediction state.
 * While predicting, the predict button becomes a "stop" button and every
 * editing control is disabled: sample, delete and rename on each class, plus
 * add-class and import. Export follows checkExportAbility(). When idle, the
 * predict button is enabled only if the classifier has at least one trained
 * class.
 */
function updatePredictionUI() {
  const editControls = [
    ...document.querySelectorAll('.btn-sample, .btn-delete-class, .btn-add-class, #btn-import, .class-name-input'),
    // The add-class button is looked up by id ('btn-add-class'); the `.btn-add-class`
    // class selector above does not match it unless the markup also sets that class.
    addClassButton,
  ];
  editControls.forEach((control) => {
    control.disabled = isPredicting;
  });
  predictButton.innerText = isPredicting ? "停止预测" : "开始预测";
  predictButton.classList.toggle('stop', isPredicting);
  resultElement.innerText = isPredicting ? "正在分析..." : "已停止";
  checkExportAbility();
  // Stop must always be clickable; start needs at least one trained class.
  predictButton.disabled = isPredicting ? false : classifier.getNumClasses() === 0;
}
/**
 * Enables the toolbar buttons once the models and camera are ready, then lets
 * checkExportAbility() re-apply the export rule (which may disable export again).
 */
function enableControls() {
  for (const button of [predictButton, importButton, exportButton, addClassButton]) {
    button.disabled = false;
  }
  checkExportAbility();
}
/**
 * Updates the export button. Export is disabled while predicting or when the
 * classifier holds no trained class.
 */
function checkExportAbility() {
  if (isPredicting) {
    exportButton.disabled = true;
    return;
  }
  exportButton.disabled = classifier.getNumClasses() === 0;
}
/**
 * Releases resources when the page unloads.
 * It cancels the pending animation frame before disposing the detector, so
 * the frame cannot call estimatePoses() on a disposed detector. It then stops
 * the camera tracks, disposes the detector and clears the classifier's stored
 * examples.
 */
function cleanup() {
  if (animationFrameId) cancelAnimationFrame(animationFrameId);
  // Stop the camera explicitly instead of relying on the browser to release it.
  const stream = videoElement.srcObject;
  if (stream) stream.getTracks().forEach((track) => track.stop());
  if (detector) detector.dispose();
  if (classifier) classifier.clearAllClasses();
}
// --- Start the app ---
// Run cleanup() when the page is about to unload.
window.onbeforeunload = cleanup;
// init() is async and catches its own errors, reporting them in the status element.
init();