[CF]新增poseClassifier.js实现模型主要功能
This commit is contained in:
parent
35d18deb43
commit
6feb9b5d21
@ -42,6 +42,7 @@ css样式中描绘关节骨架与实际对比偏小,现在使用的方法是
|
||||
|
||||
```
|
||||
/
|
||||
├── poseClassifier.js # 核心逻辑脚本,包含所有功能实现
|
||||
├── index.html # 应用主页面,包含所有UI元素
|
||||
├── style.css # 页面样式文件
|
||||
└── script.js # 核心逻辑脚本,包含所有功能实现
|
||||
|
243
姿态分类/poseClassifier.js
Normal file
243
姿态分类/poseClassifier.js
Normal file
@ -0,0 +1,243 @@
|
||||
/**
|
||||
* =============================================================================
|
||||
* 姿态识别与分类核心功能模块
|
||||
* =============================================================================
|
||||
* 功能列表:
|
||||
* - 实时姿态检测 (MoveNet)
|
||||
* - KNN 分类器训练
|
||||
* - 实时姿态预测
|
||||
* - 坐标数据处理
|
||||
* =============================================================================
|
||||
*/
|
||||
|
||||
'use strict';
|
||||
|
||||
// --- 全局变量和常量 ---
|
||||
let detector, classifier;
|
||||
|
||||
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
|
||||
const appState = {
|
||||
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
|
||||
nextClassId: 0 // 用于生成唯一的 classId
|
||||
};
|
||||
|
||||
// --- 主要功能函数 ---
|
||||
|
||||
/**
|
||||
* 初始化姿态检测器和分类器
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function init() {
|
||||
try {
|
||||
classifier = knnClassifier.create();
|
||||
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
|
||||
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
|
||||
return Promise.resolve();
|
||||
} catch (error) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 姿态类别管理功能模块
|
||||
*/
|
||||
const ClassManager = {
|
||||
/**
|
||||
* 添加一个新的姿态类别
|
||||
* @param {string} className - 类别名称
|
||||
* @returns {number} 新增类别的ID
|
||||
*/
|
||||
addClass(className) {
|
||||
const classId = appState.nextClassId;
|
||||
appState.classMap[classId] = className || `Class ${classId + 1}`;
|
||||
appState.nextClassId++;
|
||||
return classId;
|
||||
},
|
||||
|
||||
/**
|
||||
* 重命名一个姿态类别
|
||||
* @param {number} classId - 类别ID
|
||||
* @param {string} newName - 新的类别名称
|
||||
*/
|
||||
renameClass(classId, newName) {
|
||||
if (appState.classMap.hasOwnProperty(classId)) {
|
||||
appState.classMap[classId] = newName;
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除一个指定的姿态类别
|
||||
* @param {number} classId - 要删除的类别的ID
|
||||
*/
|
||||
deleteClass(classId) {
|
||||
// 从状态和分类器中移除
|
||||
delete appState.classMap[classId];
|
||||
classifier.clearClass(classId);
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取所有类别信息
|
||||
* @returns {Object} 类别映射表
|
||||
*/
|
||||
getClassMap() {
|
||||
return appState.classMap;
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取类别数量
|
||||
* @returns {number} 类别数量
|
||||
*/
|
||||
getNumClasses() {
|
||||
return classifier.getNumClasses();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* 样本训练与预测功能模块
|
||||
*/
|
||||
const Trainer = {
|
||||
/**
|
||||
* 采集一个姿态样本并添加到KNN分类器
|
||||
* @param {ImageData|HTMLVideoElement} source - 图像数据源
|
||||
* @param {number} classId 类别的ID
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async addExample(source, classId) {
|
||||
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
|
||||
if (poses && poses.length > 0) {
|
||||
const poseTensor = flattenPose(poses[0]);
|
||||
classifier.addExample(poseTensor, classId);
|
||||
poseTensor.dispose();
|
||||
return Promise.resolve();
|
||||
} else {
|
||||
return Promise.reject(new Error('未检测到姿态'));
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 对输入的姿态进行预测
|
||||
* @param {ImageData|HTMLVideoElement} source - 图像数据源
|
||||
* @param {number} k - KNN中的K值,默认为3
|
||||
* @returns {Promise<Object>} 预测结果,包含标签和置信度
|
||||
*/
|
||||
async predict(source, k = 3) {
|
||||
if (classifier.getNumClasses() === 0) {
|
||||
return Promise.reject(new Error("分类器中没有样本"));
|
||||
}
|
||||
|
||||
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
|
||||
if (poses && poses.length > 0) {
|
||||
const poseTensor = flattenPose(poses[0]);
|
||||
const result = await classifier.predictClass(poseTensor, k);
|
||||
poseTensor.dispose();
|
||||
|
||||
// 动态获取类别名称
|
||||
const predictedClassName = appState.classMap[result.label] || '未知类别';
|
||||
|
||||
return Promise.resolve({
|
||||
classId: result.label,
|
||||
className: predictedClassName,
|
||||
confidence: result.confidences[result.label],
|
||||
confidences: result.confidences
|
||||
});
|
||||
} else {
|
||||
return Promise.reject(new Error("未检测到姿态"));
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取各类别样本数量
|
||||
* @returns {Object} 各类别样本数量统计
|
||||
*/
|
||||
getExampleCounts() {
|
||||
return classifier.getClassifierDataset();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* 模型导入导出功能模块
|
||||
*/
|
||||
const ModelManager = {
|
||||
/**
|
||||
* 导出KNN模型为包含类别信息的JSON对象
|
||||
* @returns {Object} 包含classMap和dataset的模型数据
|
||||
*/
|
||||
exportModel() {
|
||||
if (classifier.getNumClasses() === 0) {
|
||||
throw new Error('模型中还没有任何样本,无法导出!');
|
||||
}
|
||||
|
||||
const dataset = classifier.getClassifierDataset();
|
||||
const datasetObj = {};
|
||||
Object.keys(dataset).forEach((key) => {
|
||||
const data = dataset[key];
|
||||
datasetObj[key] = data.arraySync();
|
||||
});
|
||||
|
||||
// 导出格式: 同时保存 classMap 和 dataset
|
||||
return {
|
||||
classMap: appState.classMap,
|
||||
dataset: datasetObj
|
||||
};
|
||||
},
|
||||
|
||||
/**
|
||||
* 从JSON对象导入KNN模型并恢复类别状态
|
||||
* @param {Object} modelData - 模型数据,包含classMap和dataset
|
||||
*/
|
||||
importModel(modelData) {
|
||||
// 导入格式验证
|
||||
if (!modelData.classMap || !modelData.dataset) {
|
||||
throw new Error("无效的模型文件格式。");
|
||||
}
|
||||
|
||||
// 1. 清理现有状态
|
||||
classifier.clearAllClasses();
|
||||
appState.classMap = {};
|
||||
|
||||
// 2. 加载新状态
|
||||
appState.classMap = modelData.classMap;
|
||||
const classIds = Object.keys(appState.classMap).map(Number);
|
||||
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
|
||||
|
||||
// 3. 加载模型数据
|
||||
const newDataset = {};
|
||||
Object.keys(modelData.dataset).forEach((key) => {
|
||||
newDataset[key] = tf.tensor(modelData.dataset[key]);
|
||||
});
|
||||
classifier.setClassifierDataset(newDataset);
|
||||
}
|
||||
};
|
||||
|
||||
// --- 辅助函数 ---
|
||||
|
||||
/**
|
||||
* 将姿态关键点坐标扁平化为一维张量
|
||||
* @param {Object} pose - 姿态对象
|
||||
* @returns {tf.Tensor} 扁平化的张量
|
||||
*/
|
||||
function flattenPose(pose) {
|
||||
// 假设调用者提供视频元素或图像尺寸
|
||||
// 如果没有提供,则使用默认值
|
||||
const width = 640; // 默认宽度
|
||||
const height = 480; // 默认高度
|
||||
const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat();
|
||||
return tf.tensor(keypoints);
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理资源
|
||||
*/
|
||||
function cleanup() {
|
||||
if (detector) detector.dispose();
|
||||
if (classifier) classifier.clearAllClasses();
|
||||
}
|
||||
|
||||
// --- 导出公共接口 ---
|
||||
window.PoseClassifier = {
|
||||
init,
|
||||
ClassManager,
|
||||
Trainer,
|
||||
ModelManager,
|
||||
cleanup
|
||||
};
|
Loading…
x
Reference in New Issue
Block a user