[CF]新增poseClassifier.js实现模型主要功能

This commit is contained in:
51hhh 2025-08-14 14:46:06 +08:00
parent 35d18deb43
commit 6feb9b5d21
2 changed files with 244 additions and 0 deletions

View File

@ -42,6 +42,7 @@ css样式中描绘关节骨架与实际对比偏小现在使用的方法是
```
/
├── poseClassifier.js # 核心逻辑脚本,包含所有功能实现
├── index.html # 应用主页面包含所有UI元素
├── style.css # 页面样式文件
└── script.js # 核心逻辑脚本,包含所有功能实现

View File

@ -0,0 +1,243 @@
/**
* =============================================================================
* 姿态识别与分类核心功能模块
* =============================================================================
* 功能列表:
* - 实时姿态检测 (MoveNet)
* - KNN 分类器训练
* - 实时姿态预测
* - 坐标数据处理
* =============================================================================
*/
'use strict';
// --- 全局变量和常量 ---
let detector, classifier;
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
const appState = {
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
nextClassId: 0 // 用于生成唯一的 classId
};
// --- 主要功能函数 ---
/**
* 初始化姿态检测器和分类器
* @returns {Promise<void>}
*/
async function init() {
try {
classifier = knnClassifier.create();
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
return Promise.resolve();
} catch (error) {
return Promise.reject(error);
}
}
/**
* 姿态类别管理功能模块
*/
const ClassManager = {
/**
* 添加一个新的姿态类别
* @param {string} className - 类别名称
* @returns {number} 新增类别的ID
*/
addClass(className) {
const classId = appState.nextClassId;
appState.classMap[classId] = className || `Class ${classId + 1}`;
appState.nextClassId++;
return classId;
},
/**
* 重命名一个姿态类别
* @param {number} classId - 类别ID
* @param {string} newName - 新的类别名称
*/
renameClass(classId, newName) {
if (appState.classMap.hasOwnProperty(classId)) {
appState.classMap[classId] = newName;
}
},
/**
* 删除一个指定的姿态类别
* @param {number} classId - 要删除的类别的ID
*/
deleteClass(classId) {
// 从状态和分类器中移除
delete appState.classMap[classId];
classifier.clearClass(classId);
},
/**
* 获取所有类别信息
* @returns {Object} 类别映射表
*/
getClassMap() {
return appState.classMap;
},
/**
* 获取类别数量
* @returns {number} 类别数量
*/
getNumClasses() {
return classifier.getNumClasses();
}
};
/**
* 样本训练与预测功能模块
*/
const Trainer = {
/**
* 采集一个姿态样本并添加到KNN分类器
* @param {ImageData|HTMLVideoElement} source - 图像数据源
* @param {number} classId 类别的ID
* @returns {Promise<void>}
*/
async addExample(source, classId) {
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
if (poses && poses.length > 0) {
const poseTensor = flattenPose(poses[0]);
classifier.addExample(poseTensor, classId);
poseTensor.dispose();
return Promise.resolve();
} else {
return Promise.reject(new Error('未检测到姿态'));
}
},
/**
* 对输入的姿态进行预测
* @param {ImageData|HTMLVideoElement} source - 图像数据源
* @param {number} k - KNN中的K值默认为3
* @returns {Promise<Object>} 预测结果包含标签和置信度
*/
async predict(source, k = 3) {
if (classifier.getNumClasses() === 0) {
return Promise.reject(new Error("分类器中没有样本"));
}
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
if (poses && poses.length > 0) {
const poseTensor = flattenPose(poses[0]);
const result = await classifier.predictClass(poseTensor, k);
poseTensor.dispose();
// 动态获取类别名称
const predictedClassName = appState.classMap[result.label] || '未知类别';
return Promise.resolve({
classId: result.label,
className: predictedClassName,
confidence: result.confidences[result.label],
confidences: result.confidences
});
} else {
return Promise.reject(new Error("未检测到姿态"));
}
},
/**
* 获取各类别样本数量
* @returns {Object} 各类别样本数量统计
*/
getExampleCounts() {
return classifier.getClassifierDataset();
}
};
/**
* 模型导入导出功能模块
*/
const ModelManager = {
/**
* 导出KNN模型为包含类别信息的JSON对象
* @returns {Object} 包含classMap和dataset的模型数据
*/
exportModel() {
if (classifier.getNumClasses() === 0) {
throw new Error('模型中还没有任何样本,无法导出!');
}
const dataset = classifier.getClassifierDataset();
const datasetObj = {};
Object.keys(dataset).forEach((key) => {
const data = dataset[key];
datasetObj[key] = data.arraySync();
});
// 导出格式: 同时保存 classMap 和 dataset
return {
classMap: appState.classMap,
dataset: datasetObj
};
},
/**
* 从JSON对象导入KNN模型并恢复类别状态
* @param {Object} modelData - 模型数据包含classMap和dataset
*/
importModel(modelData) {
// 导入格式验证
if (!modelData.classMap || !modelData.dataset) {
throw new Error("无效的模型文件格式。");
}
// 1. 清理现有状态
classifier.clearAllClasses();
appState.classMap = {};
// 2. 加载新状态
appState.classMap = modelData.classMap;
const classIds = Object.keys(appState.classMap).map(Number);
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
// 3. 加载模型数据
const newDataset = {};
Object.keys(modelData.dataset).forEach((key) => {
newDataset[key] = tf.tensor(modelData.dataset[key]);
});
classifier.setClassifierDataset(newDataset);
}
};
// --- 辅助函数 ---
/**
* 将姿态关键点坐标扁平化为一维张量
* @param {Object} pose - 姿态对象
* @returns {tf.Tensor} 扁平化的张量
*/
function flattenPose(pose) {
// 假设调用者提供视频元素或图像尺寸
// 如果没有提供,则使用默认值
const width = 640; // 默认宽度
const height = 480; // 默认高度
const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat();
return tf.tensor(keypoints);
}
/**
* 清理资源
*/
function cleanup() {
if (detector) detector.dispose();
if (classifier) classifier.clearAllClasses();
}
// --- 导出公共接口 ---
window.PoseClassifier = {
init,
ClassManager,
Trainer,
ModelManager,
cleanup
};