diff --git a/姿态分类/README.md b/姿态分类/README.md index 673ab93..91b8509 100644 --- a/姿态分类/README.md +++ b/姿态分类/README.md @@ -42,6 +42,7 @@ css样式中描绘关节骨架与实际对比偏小,现在使用的方法是 ``` / +├── poseClassifier.js # 核心逻辑脚本,包含所有功能实现 ├── index.html # 应用主页面,包含所有UI元素 ├── style.css # 页面样式文件 └── script.js # 核心逻辑脚本,包含所有功能实现 diff --git a/姿态分类/poseClassifier.js b/姿态分类/poseClassifier.js new file mode 100644 index 0000000..78b5fde --- /dev/null +++ b/姿态分类/poseClassifier.js @@ -0,0 +1,243 @@ +/** + * ============================================================================= + * 姿态识别与分类核心功能模块 + * ============================================================================= + * 功能列表: + * - 实时姿态检测 (MoveNet) + * - KNN 分类器训练 + * - 实时姿态预测 + * - 坐标数据处理 + * ============================================================================= + */ + +'use strict'; + +// --- 全局变量和常量 --- +let detector, classifier; + +// 📌 核心状态管理: 使用一个对象来管理所有动态状态 +const appState = { + classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'} + nextClassId: 0 // 用于生成唯一的 classId +}; + +// --- 主要功能函数 --- + +/** + * 初始化姿态检测器和分类器 + * @returns {Promise} + */ +async function init() { + try { + classifier = knnClassifier.create(); + const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }; + detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig); + return Promise.resolve(); + } catch (error) { + return Promise.reject(error); + } +} + +/** + * 姿态类别管理功能模块 + */ +const ClassManager = { + /** + * 添加一个新的姿态类别 + * @param {string} className - 类别名称 + * @returns {number} 新增类别的ID + */ + addClass(className) { + const classId = appState.nextClassId; + appState.classMap[classId] = className || `Class ${classId + 1}`; + appState.nextClassId++; + return classId; + }, + + /** + * 重命名一个姿态类别 + * @param {number} classId - 类别ID + * @param {string} newName - 新的类别名称 + */ + renameClass(classId, newName) { + if (appState.classMap.hasOwnProperty(classId)) { + appState.classMap[classId] = newName; + } + }, + + /** + * 删除一个指定的姿态类别 + * @param {number} classId - 要删除的类别的ID + */ + deleteClass(classId) { + // 从状态和分类器中移除 + delete appState.classMap[classId]; + classifier.clearClass(classId); + }, + + /** + * 获取所有类别信息 + * @returns {Object} 类别映射表 + */ + getClassMap() { + return appState.classMap; + }, + + /** + * 获取类别数量 + * @returns {number} 类别数量 + */ + getNumClasses() { + return classifier.getNumClasses(); + } +}; + +/** + * 样本训练与预测功能模块 + */ +const Trainer = { + /** + * 采集一个姿态样本并添加到KNN分类器 + * @param {ImageData|HTMLVideoElement} source - 图像数据源 + * @param {number} classId 类别的ID + * @returns {Promise} + */ + async addExample(source, classId) { + const poses = await detector.estimatePoses(source, { flipHorizontal: true }); + if (poses && poses.length > 0) { + const poseTensor = flattenPose(poses[0]); + classifier.addExample(poseTensor, classId); + poseTensor.dispose(); + return Promise.resolve(); + } else { + return Promise.reject(new Error('未检测到姿态')); + } + }, + + /** + * 对输入的姿态进行预测 + * @param {ImageData|HTMLVideoElement} source - 图像数据源 + * @param {number} k - KNN中的K值,默认为3 + * @returns {Promise} 预测结果,包含标签和置信度 + */ + async predict(source, k = 3) { + if (classifier.getNumClasses() === 0) { + return Promise.reject(new Error("分类器中没有样本")); + } + + const poses = await detector.estimatePoses(source, { flipHorizontal: true }); + if (poses && poses.length > 0) { + const poseTensor = flattenPose(poses[0]); + const result = await classifier.predictClass(poseTensor, k); + poseTensor.dispose(); + + // 动态获取类别名称 + const predictedClassName = appState.classMap[result.label] || '未知类别'; + + return Promise.resolve({ + classId: result.label, + className: predictedClassName, + confidence: result.confidences[result.label], + confidences: result.confidences + }); + } else { + return Promise.reject(new Error("未检测到姿态")); + } + }, + + /** + * 获取各类别样本数量 + * @returns {Object} 各类别样本数量统计 + */ + getExampleCounts() { + return classifier.getClassifierDataset(); + } +}; + +/** + * 模型导入导出功能模块 + */ +const ModelManager = { + /** + * 导出KNN模型为包含类别信息的JSON对象 + * @returns {Object} 包含classMap和dataset的模型数据 + */ + exportModel() { + if (classifier.getNumClasses() === 0) { + throw new Error('模型中还没有任何样本,无法导出!'); + } + + const dataset = classifier.getClassifierDataset(); + const datasetObj = {}; + Object.keys(dataset).forEach((key) => { + const data = dataset[key]; + datasetObj[key] = data.arraySync(); + }); + + // 导出格式: 同时保存 classMap 和 dataset + return { + classMap: appState.classMap, + dataset: datasetObj + }; + }, + + /** + * 从JSON对象导入KNN模型并恢复类别状态 + * @param {Object} modelData - 模型数据,包含classMap和dataset + */ + importModel(modelData) { + // 导入格式验证 + if (!modelData.classMap || !modelData.dataset) { + throw new Error("无效的模型文件格式。"); + } + + // 1. 清理现有状态 + classifier.clearAllClasses(); + appState.classMap = {}; + + // 2. 加载新状态 + appState.classMap = modelData.classMap; + const classIds = Object.keys(appState.classMap).map(Number); + appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0; + + // 3. 加载模型数据 + const newDataset = {}; + Object.keys(modelData.dataset).forEach((key) => { + newDataset[key] = tf.tensor(modelData.dataset[key]); + }); + classifier.setClassifierDataset(newDataset); + } +}; + +// --- 辅助函数 --- + +/** + * 将姿态关键点坐标扁平化为一维张量 + * @param {Object} pose - 姿态对象 + * @returns {tf.Tensor} 扁平化的张量 + */ +function flattenPose(pose) { + // 假设调用者提供视频元素或图像尺寸 + // 如果没有提供,则使用默认值 + const width = 640; // 默认宽度 + const height = 480; // 默认高度 + const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat(); + return tf.tensor(keypoints); +} + +/** + * 清理资源 + */ +function cleanup() { + if (detector) detector.dispose(); + if (classifier) classifier.clearAllClasses(); +} + +// --- 导出公共接口 --- +window.PoseClassifier = { + init, + ClassManager, + Trainer, + ModelManager, + cleanup +}; \ No newline at end of file