/**
 * =============================================================================
 * 动态版 - 姿态识别与模型管理脚本 (v2.0)
 * =============================================================================
 * 功能列表:
 * - 实时姿态检测 (MoveNet)
 * - KNN 分类器训练
 * - 实时姿态预测
 * - 坐标完美对齐 (Canvas与Video重叠)
 * - ✅ 动态添加/删除/重命名姿态类别
 * - ✅ 模型导出为包含类别信息的 JSON 文件
 * - ✅ 从 JSON 文件导入模型并恢复类别状态
 * =============================================================================
 */

'use strict';

// --- Global DOM references ---

// Camera feed and the overlay canvas that the skeleton is drawn on.
const videoElement = document.getElementById('video');
const canvasElement = document.getElementById('canvas');
const canvasCtx = canvasElement.getContext('2d');

// Status banner and prediction read-out.
const statusElement = document.getElementById('status');
const resultElement = document.getElementById('result-text');

// Control panel.
const poseClassesContainer = document.getElementById('pose-classes-container');
const addClassButton = document.getElementById('btn-add-class');
const predictButton = document.getElementById('btn-predict');
const exportButton = document.getElementById('btn-export');
const importButton = document.getElementById('btn-import');
const fileImporter = document.getElementById('file-importer');

// Assigned later by init() / mainLoop().
let detector;
let classifier;
let animationFrameId;

// True while the main loop should classify each detected pose.
let isPredicting = false;

/**
 * Mutable state describing the user-defined pose classes.
 * - classMap:    classId -> display name, e.g. { 0: 'Class 1', 1: 'Class 2' }
 * - nextClassId: id handed out to the next class that gets created
 */
const appState = {
  classMap: {},
  nextClassId: 0,
};

// --- Main application logic ---

/**
 * Bootstraps the app: creates the KNN classifier and the MoveNet detector,
 * starts the camera, wires up the controls and kicks off the render loop.
 * If anything throws, the status banner switches to an error message/colors.
 */
async function init() {
  try {
    classifier = knnClassifier.create();
    detector = await poseDetection.createDetector(
      poseDetection.SupportedModels.MoveNet,
      { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING },
    );

    await setupCamera();
    setupEventListeners();
    // mainLoop() reschedules itself via requestAnimationFrame; it is deliberately not awaited.
    void mainLoop();

    statusElement.innerText = "模型和摄像头已就绪!";
    enableControls();
    addNewClass(); // Start with one default class.
  } catch (error) {
    console.error("初始化失败:", error);
    statusElement.innerText = "初始化失败,请检查摄像头权限或刷新。";
    Object.assign(statusElement.style, {
      backgroundColor: '#fce8e6',
      color: '#d93025',
    });
  }
}

/**
 * Requests the user's camera, attaches the stream to the <video> element and
 * sizes the overlay canvas to the video's intrinsic resolution.
 *
 * Resolves once playback has started. Rejects if camera access is denied or if
 * play() fails, so the caller's try/catch (init) can report the failure instead
 * of it surfacing as an unhandled promise rejection.
 */
async function setupCamera() {
  const stream = await navigator.mediaDevices.getUserMedia({ video: true });
  videoElement.srcObject = stream;

  // videoWidth/videoHeight are only known once the metadata has loaded.
  await new Promise((resolve) => {
    videoElement.onloadedmetadata = () => resolve();
  });

  canvasElement.width = videoElement.videoWidth;
  canvasElement.height = videoElement.videoHeight;

  // play() returns a promise; awaiting it propagates autoplay/playback errors.
  await videoElement.play();
}

/**
 * Attaches the handlers for the static controls. The import button forwards
 * its click to the file input, whose `change` event performs the import.
 */
function setupEventListeners() {
  const bindings = [
    [addClassButton, 'click', addNewClass],
    [predictButton, 'click', togglePrediction],
    [exportButton, 'click', exportModel],
    [importButton, 'click', () => fileImporter.click()],
    [fileImporter, 'change', importModel],
  ];
  for (const [element, eventName, handler] of bindings) {
    element.addEventListener(eventName, handler);
  }
}

// --- Dynamic class management ---

/**
 * Builds the UI card for one pose class and appends it to the class container.
 *
 * Only static markup goes through innerHTML. The class name and id are assigned
 * through DOM properties afterwards, so a name coming from an imported JSON file
 * (or typed by the user) cannot inject markup or break out of the attribute.
 *
 * @param {number} classId - Unique class id, also used as the KNN label.
 * @param {string} className - Display name of the class.
 */
function createClassUI(classId, className) {
  const poseClassDiv = document.createElement('div');
  poseClassDiv.className = 'pose-class';
  poseClassDiv.dataset.classId = classId;

  poseClassDiv.innerHTML = `
    <div class="class-info">
      <input type="text" class="class-name-input">
      <span class="sample-count">(0 样本)</span>
    </div>
    <div class="class-actions">
      <button class="btn-sample">采集样本</button>
      <button class="btn-delete-class" title="删除类别">×</button>
    </div>
  `;

  const nameInput = poseClassDiv.querySelector('.class-name-input');
  const sampleButton = poseClassDiv.querySelector('.btn-sample');
  const deleteButton = poseClassDiv.querySelector('.btn-delete-class');

  // Dynamic values are set as properties, never parsed as HTML.
  nameInput.value = className;
  for (const el of [nameInput, sampleButton, deleteButton]) {
    el.dataset.classId = classId;
  }

  poseClassesContainer.appendChild(poseClassDiv);

  nameInput.addEventListener('change', (e) => {
    appState.classMap[classId] = e.target.value;
  });
  sampleButton.addEventListener('click', () => addExample(classId));
  deleteButton.addEventListener('click', () => deleteClass(classId));

  // Mirror updatePredictionUI(): while predicting, every editing control is locked,
  // not just the sample button.
  if (isPredicting) {
    nameInput.disabled = true;
    sampleButton.disabled = true;
    deleteButton.disabled = true;
  }
}

/**
 * Registers a new pose class named "Class N" (N = id + 1) and renders its card.
 */
function addNewClass() {
  const classId = appState.nextClassId++;
  const className = `Class ${classId + 1}`;
  appState.classMap[classId] = className;
  createClassUI(classId, className);
}

/**
 * Deletes a pose class after user confirmation. Removes its UI card, its
 * name mapping and any samples the classifier holds for it, then refreshes
 * the counters and button states.
 *
 * @param {number} classId - Id of the class to delete.
 */
function deleteClass(classId) {
  if (!confirm(`确定要删除类别 "${appState.classMap[classId]}" 吗?所有样本都将丢失。`)) return;

  // Target the card itself; its child controls carry the same data-class-id.
  const card = poseClassesContainer.querySelector(`.pose-class[data-class-id="${classId}"]`);
  if (card) card.remove();

  delete appState.classMap[classId];

  // KNNClassifier.clearClass() throws for a label without examples (e.g. a class
  // that was just added), which would abort the refresh below. Only clear
  // classes that actually have samples.
  const dataset = classifier.getClassifierDataset();
  if (dataset[classId] != null) {
    classifier.clearClass(classId);
  }

  updateSampleCounts();
  updatePredictionUI(); // May disable "predict" if no sampled class is left.
  checkExportAbility();
}

/**
 * Captures the pose in the current video frame and adds it to the KNN
 * classifier as one training sample for the given class. Only a warning is
 * logged when no pose is detected.
 *
 * @param {number} classId - Id (KNN label) of the class receiving the sample.
 */
async function addExample(classId) {
  const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
  if (!poses || poses.length === 0) {
    console.warn(`为类别 ${appState.classMap[classId]} 采集样本失败,未检测到姿态。`);
    return;
  }

  const poseTensor = flattenPose(poses[0]);
  try {
    classifier.addExample(poseTensor, classId);
  } finally {
    // Release the input tensor even if the classifier rejects the example.
    poseTensor.dispose();
  }

  updateSampleCounts();
  checkExportAbility();
  // The predict button is disabled when the last sampled class is deleted;
  // re-apply updatePredictionUI()'s rule now that samples exist again.
  predictButton.disabled = !isPredicting && classifier.getNumClasses() === 0;
}

// --- Model and prediction logic ---

/**
 * Toggles live prediction on or off. Only alerts, leaving the state unchanged,
 * while the classifier has no class with samples.
 */
function togglePrediction() {
  const hasTrainingData = classifier.getNumClasses() !== 0;
  if (hasTrainingData) {
    isPredicting = !isPredicting;
    updatePredictionUI();
    return;
  }
  alert("请先为至少一个姿态采集样本后再开始预测!");
}

/**
 * Per-frame loop. Each frame it:
 * - detects the pose in the current video frame;
 * - redraws the overlay;
 * - while prediction is on, classifies the pose and shows the top class with
 *   its confidence.
 *
 * It always reschedules itself via requestAnimationFrame, even when a frame
 * throws (detector or classifier error). Otherwise one failed frame would
 * silently stop the loop for the rest of the session.
 */
async function mainLoop() {
  try {
    const poses = await detector.estimatePoses(videoElement, { flipHorizontal: true });
    canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);

    if (poses && poses.length > 0) {
      drawPose(poses[0]);

      if (isPredicting && classifier.getNumClasses() > 0) {
        const poseTensor = flattenPose(poses[0]);
        let result;
        try {
          result = await classifier.predictClass(poseTensor, 3);
        } finally {
          poseTensor.dispose(); // Released even if prediction fails.
        }

        // Prediction may have been stopped while awaiting; don't overwrite "已停止".
        if (isPredicting) {
          const confidence = Math.round(result.confidences[result.label] * 100);
          // Class names are user-editable, so look them up at display time.
          const predictedClassName = appState.classMap[result.label] || '未知类别';
          resultElement.innerText = `姿态: ${predictedClassName} (${confidence}%)`;
        }
      }
    }
  } catch (error) {
    console.error("主循环出错:", error);
  } finally {
    animationFrameId = requestAnimationFrame(mainLoop);
  }
}

// --- Model management (supports dynamic classes) ---

/**
 * Serializes the trained model and triggers a download of `pose-knn-model.json`.
 *
 * File layout: `{ classMap, dataset }`, where
 * - `classMap` is appState's id -> name map;
 * - `dataset` maps each classifier label to the nested-array form (`arraySync()`)
 *   of that label's sample tensor.
 *
 * Shows an alert and does nothing else when the classifier has no samples.
 */
function exportModel() {
  if (classifier.getNumClasses() === 0) {
    alert('模型中还没有任何样本,无法导出!');
    return;
  }

  const tensorsByLabel = classifier.getClassifierDataset();
  const serializedDataset = Object.fromEntries(
    Object.keys(tensorsByLabel).map((label) => [label, tensorsByLabel[label].arraySync()]),
  );

  const payload = JSON.stringify({
    classMap: appState.classMap,
    dataset: serializedDataset,
  });

  const downloadUrl = URL.createObjectURL(new Blob([payload], { type: "application/json" }));
  const link = document.createElement('a');
  link.href = downloadUrl;
  link.download = `pose-knn-model.json`;
  document.body.appendChild(link);
  link.click();
  document.body.removeChild(link);
  URL.revokeObjectURL(downloadUrl);
}

/**
 * Imports a model JSON file (the `{ classMap, dataset }` layout written by
 * exportModel) and restores both the classifier samples and the class UI.
 *
 * The whole file is parsed and converted to tensors BEFORE the current model
 * is touched. A malformed file therefore leaves the existing classes intact,
 * and any tensors created along the way are disposed.
 *
 * @param {Event} event - `change` event of the file input.
 */
function importModel(event) {
  const file = event.target.files[0];
  if (!file) return;

  const isObject = (value) => value !== null && typeof value === 'object' && !Array.isArray(value);

  const reader = new FileReader();
  reader.onload = (e) => {
    let pendingDataset = null; // Tensors we own until the classifier takes them.
    try {
      const modelData = JSON.parse(e.target.result);

      if (!isObject(modelData) || !isObject(modelData.classMap) || !isObject(modelData.dataset)) {
        throw new Error("无效的模型文件格式。");
      }

      // 1. Build every sample tensor first; any bad entry aborts the import here.
      pendingDataset = {};
      for (const [label, rows] of Object.entries(modelData.dataset)) {
        const tensor = tf.tensor(rows);
        pendingDataset[label] = tensor;
        // KNN class matrices are [numSamples, featureSize].
        if (tensor.rank !== 2) {
          throw new Error(`类别 ${label} 的样本数据不是二维数组。`);
        }
      }

      // 2. Replace the current state.
      classifier.clearAllClasses();
      poseClassesContainer.innerHTML = '';
      appState.classMap = modelData.classMap;

      const classIds = Object.keys(appState.classMap).map(Number);
      // New ids must not collide with any imported label (named or sample-only);
      // non-numeric keys are ignored so nextClassId never becomes NaN.
      const usedIds = [...classIds, ...Object.keys(pendingDataset).map(Number)].filter(Number.isFinite);
      appState.nextClassId = usedIds.length > 0 ? Math.max(...usedIds) + 1 : 0;

      // 3. Rebuild the UI.
      classIds.forEach((id) => createClassUI(id, appState.classMap[id]));

      // 4. Hand the tensors over; the classifier owns them from now on.
      classifier.setClassifierDataset(pendingDataset);
      pendingDataset = null;

      updateSampleCounts();
      checkExportAbility();
      // Same rule as updatePredictionUI(): starting needs at least one sampled class.
      predictButton.disabled = !isPredicting && classifier.getNumClasses() === 0;
      alert('模型导入成功!');
    } catch (error) {
      if (pendingDataset) {
        Object.values(pendingDataset).forEach((tensor) => tensor.dispose());
      }
      console.error("导入模型失败:", error);
      alert(`导入失败!请确保文件是正确的模型JSON文件。\n错误: ${error.message}`);
    } finally {
      fileImporter.value = ''; // Allow re-selecting the same file.
    }
  };
  reader.onerror = () => {
    console.error("导入模型失败:", reader.error);
    alert(`导入失败!请确保文件是正确的模型JSON文件。\n错误: ${reader.error?.message}`);
    fileImporter.value = '';
  };
  reader.readAsText(file);
}

// --- Helpers and UI updates ---

/**
 * Turns a pose into a flat feature tensor [x0, y0, x1, y1, ...]. Each
 * coordinate is divided by the video's width or height respectively.
 *
 * @param {{keypoints: Array<{x: number, y: number}>}} pose - Detected pose.
 * @returns {tf.Tensor} Tensor built from the flat array of normalized coordinates.
 */
function flattenPose(pose) {
  const features = pose.keypoints.flatMap(({ x, y }) => [
    x / videoElement.videoWidth,
    y / videoElement.videoHeight,
  ]);
  return tf.tensor(features);
}

/**
 * Draws one pose on the overlay canvas:
 * - a dot for every keypoint whose score is above 0.3;
 * - a line for every MoveNet skeleton edge whose two endpoints both pass that
 *   threshold.
 *
 * @param {{keypoints?: Array<{x: number, y: number, score: number}>}} pose - Detected pose.
 */
function drawPose(pose) {
  const { keypoints } = pose;
  if (!keypoints) return;

  const isConfident = (kp) => kp.score > 0.3;

  // Keypoints.
  keypoints.filter(isConfident).forEach(({ x, y }) => {
    canvasCtx.beginPath();
    canvasCtx.arc(x, y, 5, 0, 2 * Math.PI);
    canvasCtx.fillStyle = '#1a73e8';
    canvasCtx.fill();
  });

  // Skeleton edges.
  const edges = poseDetection.util.getAdjacentPairs(poseDetection.SupportedModels.MoveNet);
  for (const [from, to] of edges) {
    const start = keypoints[from];
    const end = keypoints[to];
    if (!isConfident(start) || !isConfident(end)) continue;
    canvasCtx.beginPath();
    canvasCtx.moveTo(start.x, start.y);
    canvasCtx.lineTo(end.x, end.y);
    canvasCtx.strokeStyle = 'blue';
    canvasCtx.lineWidth = 2;
    canvasCtx.stroke();
  }
}

/**
 * Refreshes the "(N 样本)" counter on every class card. N is the number of rows
 * the classifier holds for that class, or 0 when it has none.
 */
function updateSampleCounts() {
  const samplesByClass = classifier.getClassifierDataset();
  for (const card of document.querySelectorAll('.pose-class')) {
    const id = Number.parseInt(card.dataset.classId, 10);
    const matrix = samplesByClass[id];
    const sampleTotal = matrix ? matrix.shape[0] : 0;
    card.querySelector('.sample-count').innerText = `(${sampleTotal} 样本)`;
  }
}

/**
 * Syncs the UI with `isPredicting`:
 * - the predict button's label and "stop" style;
 * - the result text;
 * - the lock state of all editing controls;
 * - whether the export and predict buttons are enabled.
 */
function updatePredictionUI() {
  const locked = isPredicting;

  predictButton.innerText = locked ? "停止预测" : "开始预测";
  predictButton.classList.toggle('stop', locked);
  resultElement.innerText = locked ? "正在分析..." : "已停止";

  // Lock/unlock every editing control while predicting.
  document
    .querySelectorAll('.btn-sample, .btn-delete-class, .btn-add-class, #btn-import, .class-name-input')
    .forEach((el) => { el.disabled = locked; });
  // The add-class button is looked up by id ('btn-add-class'); the class selector
  // above only reaches it if the markup also gives it that class, so set it directly.
  addClassButton.disabled = locked;

  checkExportAbility();

  // Stopping is always allowed; starting needs at least one class with samples.
  predictButton.disabled = locked ? false : classifier.getNumClasses() === 0;
}

/**
 * Enables the static control buttons after a successful init. Then lets
 * checkExportAbility() decide whether export is actually available.
 */
function enableControls() {
  for (const button of [predictButton, importButton, exportButton, addClassButton]) {
    button.disabled = false;
  }
  checkExportAbility();
}

/**
 * Enables the export button only while not predicting and when at least one
 * class has samples.
 */
function checkExportAbility() {
  const canExport = !isPredicting && classifier.getNumClasses() !== 0;
  exportButton.disabled = !canExport;
}

/**
 * Releases the app's resources:
 * - stops the render loop;
 * - stops the camera stream's tracks;
 * - disposes the pose detector;
 * - clears the classifier's samples.
 */
function cleanup() {
  // Stop scheduling frames first so no new frame runs against a disposed detector.
  if (animationFrameId) cancelAnimationFrame(animationFrameId);

  // Release the camera explicitly instead of relying on page teardown.
  const stream = videoElement?.srcObject;
  if (stream && typeof stream.getTracks === 'function') {
    stream.getTracks().forEach((track) => track.stop());
  }

  if (detector) detector.dispose();
  if (classifier) classifier.clearAllClasses();
}

// --- App startup ---
// Run cleanup() when the page is about to unload.
window.onbeforeunload = cleanup;
// init() catches and reports its own errors, so its promise is intentionally not awaited.
void init();