/** * ============================================================================= * 姿态识别与分类核心功能模块 * ============================================================================= * 功能列表: * - 实时姿态检测 (MoveNet) * - KNN 分类器训练 * - 实时姿态预测 * - 坐标数据处理 * ============================================================================= */ 'use strict'; // --- 全局变量和常量 --- let detector, classifier; // 📌 核心状态管理: 使用一个对象来管理所有动态状态 const appState = { classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'} nextClassId: 0 // 用于生成唯一的 classId }; // --- 主要功能函数 --- /** * 初始化姿态检测器和分类器 * @returns {Promise} */ async function init() { try { classifier = knnClassifier.create(); const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }; detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig); return Promise.resolve(); } catch (error) { return Promise.reject(error); } } /** * 姿态类别管理功能模块 */ const ClassManager = { /** * 添加一个新的姿态类别 * @param {string} className - 类别名称 * @returns {number} 新增类别的ID */ addClass(className) { const classId = appState.nextClassId; appState.classMap[classId] = className || `Class ${classId + 1}`; appState.nextClassId++; return classId; }, /** * 重命名一个姿态类别 * @param {number} classId - 类别ID * @param {string} newName - 新的类别名称 */ renameClass(classId, newName) { if (appState.classMap.hasOwnProperty(classId)) { appState.classMap[classId] = newName; } }, /** * 删除一个指定的姿态类别 * @param {number} classId - 要删除的类别的ID */ deleteClass(classId) { // 从状态和分类器中移除 delete appState.classMap[classId]; classifier.clearClass(classId); }, /** * 获取所有类别信息 * @returns {Object} 类别映射表 */ getClassMap() { return appState.classMap; }, /** * 获取类别数量 * @returns {number} 类别数量 */ getNumClasses() { return classifier.getNumClasses(); } }; /** * 样本训练与预测功能模块 */ const Trainer = { /** * 采集一个姿态样本并添加到KNN分类器 * @param {ImageData|HTMLVideoElement} source - 图像数据源 * @param {number} classId 类别的ID * @returns {Promise} */ async addExample(source, classId) { const poses = await detector.estimatePoses(source, { flipHorizontal: true }); if (poses && poses.length > 0) { const poseTensor = flattenPose(poses[0]); classifier.addExample(poseTensor, classId); poseTensor.dispose(); return Promise.resolve(); } else { return Promise.reject(new Error('未检测到姿态')); } }, /** * 对输入的姿态进行预测 * @param {ImageData|HTMLVideoElement} source - 图像数据源 * @param {number} k - KNN中的K值,默认为3 * @returns {Promise} 预测结果,包含标签和置信度 */ async predict(source, k = 3) { if (classifier.getNumClasses() === 0) { return Promise.reject(new Error("分类器中没有样本")); } const poses = await detector.estimatePoses(source, { flipHorizontal: true }); if (poses && poses.length > 0) { const poseTensor = flattenPose(poses[0]); const result = await classifier.predictClass(poseTensor, k); poseTensor.dispose(); // 动态获取类别名称 const predictedClassName = appState.classMap[result.label] || '未知类别'; return Promise.resolve({ classId: result.label, className: predictedClassName, confidence: result.confidences[result.label], confidences: result.confidences }); } else { return Promise.reject(new Error("未检测到姿态")); } }, /** * 获取各类别样本数量 * @returns {Object} 各类别样本数量统计 */ getExampleCounts() { return classifier.getClassifierDataset(); } }; /** * 模型导入导出功能模块 */ const ModelManager = { /** * 导出KNN模型为包含类别信息的JSON对象 * @returns {Object} 包含classMap和dataset的模型数据 */ exportModel() { if (classifier.getNumClasses() === 0) { throw new Error('模型中还没有任何样本,无法导出!'); } const dataset = classifier.getClassifierDataset(); const datasetObj = {}; Object.keys(dataset).forEach((key) => { const data = dataset[key]; datasetObj[key] = data.arraySync(); }); // 导出格式: 同时保存 classMap 和 dataset return { classMap: appState.classMap, dataset: datasetObj }; }, /** * 从JSON对象导入KNN模型并恢复类别状态 * @param {Object} modelData - 模型数据,包含classMap和dataset */ importModel(modelData) { // 导入格式验证 if (!modelData.classMap || !modelData.dataset) { throw new Error("无效的模型文件格式。"); } // 1. 清理现有状态 classifier.clearAllClasses(); appState.classMap = {}; // 2. 加载新状态 appState.classMap = modelData.classMap; const classIds = Object.keys(appState.classMap).map(Number); appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0; // 3. 加载模型数据 const newDataset = {}; Object.keys(modelData.dataset).forEach((key) => { newDataset[key] = tf.tensor(modelData.dataset[key]); }); classifier.setClassifierDataset(newDataset); } }; // --- 辅助函数 --- /** * 将姿态关键点坐标扁平化为一维张量 * @param {Object} pose - 姿态对象 * @returns {tf.Tensor} 扁平化的张量 */ function flattenPose(pose) { // 假设调用者提供视频元素或图像尺寸 // 如果没有提供,则使用默认值 const width = 640; // 默认宽度 const height = 480; // 默认高度 const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat(); return tf.tensor(keypoints); } /** * 清理资源 */ function cleanup() { // 还原所有状态 if (detector) detector.dispose(); // 清理检测器 if (classifier) classifier.clearAllClasses(); // 清理分类器 } // --- 导出公共接口 --- window.PoseClassifier = { init, ClassManager, Trainer, ModelManager, cleanup };