[CF]新增poseClassifier.js实现模型主要功能
This commit is contained in:
parent
35d18deb43
commit
6feb9b5d21
@ -42,6 +42,7 @@ css样式中描绘关节骨架与实际对比偏小,现在使用的方法是
|
|||||||
|
|
||||||
```
|
```
|
||||||
/
|
/
|
||||||
|
├── poseClassifier.js # 核心逻辑脚本,包含所有功能实现
|
||||||
├── index.html # 应用主页面,包含所有UI元素
|
├── index.html # 应用主页面,包含所有UI元素
|
||||||
├── style.css # 页面样式文件
|
├── style.css # 页面样式文件
|
||||||
└── script.js # 核心逻辑脚本,包含所有功能实现
|
└── script.js # 核心逻辑脚本,包含所有功能实现
|
||||||
|
243
姿态分类/poseClassifier.js
Normal file
243
姿态分类/poseClassifier.js
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
/**
|
||||||
|
* =============================================================================
|
||||||
|
* 姿态识别与分类核心功能模块
|
||||||
|
* =============================================================================
|
||||||
|
* 功能列表:
|
||||||
|
* - 实时姿态检测 (MoveNet)
|
||||||
|
* - KNN 分类器训练
|
||||||
|
* - 实时姿态预测
|
||||||
|
* - 坐标数据处理
|
||||||
|
* =============================================================================
|
||||||
|
*/
|
||||||
|
|
||||||
|
'use strict';
|
||||||
|
|
||||||
|
// --- 全局变量和常量 ---
|
||||||
|
let detector, classifier;
|
||||||
|
|
||||||
|
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
|
||||||
|
const appState = {
|
||||||
|
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
|
||||||
|
nextClassId: 0 // 用于生成唯一的 classId
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- 主要功能函数 ---
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 初始化姿态检测器和分类器
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
async function init() {
|
||||||
|
try {
|
||||||
|
classifier = knnClassifier.create();
|
||||||
|
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
|
||||||
|
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
|
||||||
|
return Promise.resolve();
|
||||||
|
} catch (error) {
|
||||||
|
return Promise.reject(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 姿态类别管理功能模块
|
||||||
|
*/
|
||||||
|
const ClassManager = {
|
||||||
|
/**
|
||||||
|
* 添加一个新的姿态类别
|
||||||
|
* @param {string} className - 类别名称
|
||||||
|
* @returns {number} 新增类别的ID
|
||||||
|
*/
|
||||||
|
addClass(className) {
|
||||||
|
const classId = appState.nextClassId;
|
||||||
|
appState.classMap[classId] = className || `Class ${classId + 1}`;
|
||||||
|
appState.nextClassId++;
|
||||||
|
return classId;
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重命名一个姿态类别
|
||||||
|
* @param {number} classId - 类别ID
|
||||||
|
* @param {string} newName - 新的类别名称
|
||||||
|
*/
|
||||||
|
renameClass(classId, newName) {
|
||||||
|
if (appState.classMap.hasOwnProperty(classId)) {
|
||||||
|
appState.classMap[classId] = newName;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除一个指定的姿态类别
|
||||||
|
* @param {number} classId - 要删除的类别的ID
|
||||||
|
*/
|
||||||
|
deleteClass(classId) {
|
||||||
|
// 从状态和分类器中移除
|
||||||
|
delete appState.classMap[classId];
|
||||||
|
classifier.clearClass(classId);
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有类别信息
|
||||||
|
* @returns {Object} 类别映射表
|
||||||
|
*/
|
||||||
|
getClassMap() {
|
||||||
|
return appState.classMap;
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取类别数量
|
||||||
|
* @returns {number} 类别数量
|
||||||
|
*/
|
||||||
|
getNumClasses() {
|
||||||
|
return classifier.getNumClasses();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 样本训练与预测功能模块
|
||||||
|
*/
|
||||||
|
const Trainer = {
|
||||||
|
/**
|
||||||
|
* 采集一个姿态样本并添加到KNN分类器
|
||||||
|
* @param {ImageData|HTMLVideoElement} source - 图像数据源
|
||||||
|
* @param {number} classId 类别的ID
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
async addExample(source, classId) {
|
||||||
|
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
|
||||||
|
if (poses && poses.length > 0) {
|
||||||
|
const poseTensor = flattenPose(poses[0]);
|
||||||
|
classifier.addExample(poseTensor, classId);
|
||||||
|
poseTensor.dispose();
|
||||||
|
return Promise.resolve();
|
||||||
|
} else {
|
||||||
|
return Promise.reject(new Error('未检测到姿态'));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 对输入的姿态进行预测
|
||||||
|
* @param {ImageData|HTMLVideoElement} source - 图像数据源
|
||||||
|
* @param {number} k - KNN中的K值,默认为3
|
||||||
|
* @returns {Promise<Object>} 预测结果,包含标签和置信度
|
||||||
|
*/
|
||||||
|
async predict(source, k = 3) {
|
||||||
|
if (classifier.getNumClasses() === 0) {
|
||||||
|
return Promise.reject(new Error("分类器中没有样本"));
|
||||||
|
}
|
||||||
|
|
||||||
|
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
|
||||||
|
if (poses && poses.length > 0) {
|
||||||
|
const poseTensor = flattenPose(poses[0]);
|
||||||
|
const result = await classifier.predictClass(poseTensor, k);
|
||||||
|
poseTensor.dispose();
|
||||||
|
|
||||||
|
// 动态获取类别名称
|
||||||
|
const predictedClassName = appState.classMap[result.label] || '未知类别';
|
||||||
|
|
||||||
|
return Promise.resolve({
|
||||||
|
classId: result.label,
|
||||||
|
className: predictedClassName,
|
||||||
|
confidence: result.confidences[result.label],
|
||||||
|
confidences: result.confidences
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
return Promise.reject(new Error("未检测到姿态"));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取各类别样本数量
|
||||||
|
* @returns {Object} 各类别样本数量统计
|
||||||
|
*/
|
||||||
|
getExampleCounts() {
|
||||||
|
return classifier.getClassifierDataset();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模型导入导出功能模块
|
||||||
|
*/
|
||||||
|
const ModelManager = {
|
||||||
|
/**
|
||||||
|
* 导出KNN模型为包含类别信息的JSON对象
|
||||||
|
* @returns {Object} 包含classMap和dataset的模型数据
|
||||||
|
*/
|
||||||
|
exportModel() {
|
||||||
|
if (classifier.getNumClasses() === 0) {
|
||||||
|
throw new Error('模型中还没有任何样本,无法导出!');
|
||||||
|
}
|
||||||
|
|
||||||
|
const dataset = classifier.getClassifierDataset();
|
||||||
|
const datasetObj = {};
|
||||||
|
Object.keys(dataset).forEach((key) => {
|
||||||
|
const data = dataset[key];
|
||||||
|
datasetObj[key] = data.arraySync();
|
||||||
|
});
|
||||||
|
|
||||||
|
// 导出格式: 同时保存 classMap 和 dataset
|
||||||
|
return {
|
||||||
|
classMap: appState.classMap,
|
||||||
|
dataset: datasetObj
|
||||||
|
};
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从JSON对象导入KNN模型并恢复类别状态
|
||||||
|
* @param {Object} modelData - 模型数据,包含classMap和dataset
|
||||||
|
*/
|
||||||
|
importModel(modelData) {
|
||||||
|
// 导入格式验证
|
||||||
|
if (!modelData.classMap || !modelData.dataset) {
|
||||||
|
throw new Error("无效的模型文件格式。");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 清理现有状态
|
||||||
|
classifier.clearAllClasses();
|
||||||
|
appState.classMap = {};
|
||||||
|
|
||||||
|
// 2. 加载新状态
|
||||||
|
appState.classMap = modelData.classMap;
|
||||||
|
const classIds = Object.keys(appState.classMap).map(Number);
|
||||||
|
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
|
||||||
|
|
||||||
|
// 3. 加载模型数据
|
||||||
|
const newDataset = {};
|
||||||
|
Object.keys(modelData.dataset).forEach((key) => {
|
||||||
|
newDataset[key] = tf.tensor(modelData.dataset[key]);
|
||||||
|
});
|
||||||
|
classifier.setClassifierDataset(newDataset);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- 辅助函数 ---
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将姿态关键点坐标扁平化为一维张量
|
||||||
|
* @param {Object} pose - 姿态对象
|
||||||
|
* @returns {tf.Tensor} 扁平化的张量
|
||||||
|
*/
|
||||||
|
function flattenPose(pose) {
|
||||||
|
// 假设调用者提供视频元素或图像尺寸
|
||||||
|
// 如果没有提供,则使用默认值
|
||||||
|
const width = 640; // 默认宽度
|
||||||
|
const height = 480; // 默认高度
|
||||||
|
const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat();
|
||||||
|
return tf.tensor(keypoints);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理资源
|
||||||
|
*/
|
||||||
|
function cleanup() {
|
||||||
|
if (detector) detector.dispose();
|
||||||
|
if (classifier) classifier.clearAllClasses();
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 导出公共接口 ---
|
||||||
|
window.PoseClassifier = {
|
||||||
|
init,
|
||||||
|
ClassManager,
|
||||||
|
Trainer,
|
||||||
|
ModelManager,
|
||||||
|
cleanup
|
||||||
|
};
|
Loading…
x
Reference in New Issue
Block a user