244 lines
7.1 KiB
JavaScript
244 lines
7.1 KiB
JavaScript
/**
|
||
* =============================================================================
|
||
* 姿态识别与分类核心功能模块
|
||
* =============================================================================
|
||
* 功能列表:
|
||
* - 实时姿态检测 (MoveNet)
|
||
* - KNN 分类器训练
|
||
* - 实时姿态预测
|
||
* - 坐标数据处理
|
||
* =============================================================================
|
||
*/
|
||
|
||
'use strict';
|
||
|
||
// --- 全局变量和常量 ---
|
||
let detector, classifier;
|
||
|
||
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
|
||
const appState = {
|
||
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
|
||
nextClassId: 0 // 用于生成唯一的 classId
|
||
};
|
||
|
||
// --- 主要功能函数 ---
|
||
|
||
/**
|
||
* 初始化姿态检测器和分类器
|
||
* @returns {Promise<void>}
|
||
*/
|
||
async function init() {
|
||
try {
|
||
classifier = knnClassifier.create();
|
||
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
|
||
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
|
||
return Promise.resolve();
|
||
} catch (error) {
|
||
return Promise.reject(error);
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 姿态类别管理功能模块
|
||
*/
|
||
const ClassManager = {
|
||
/**
|
||
* 添加一个新的姿态类别
|
||
* @param {string} className - 类别名称
|
||
* @returns {number} 新增类别的ID
|
||
*/
|
||
addClass(className) {
|
||
const classId = appState.nextClassId;
|
||
appState.classMap[classId] = className || `Class ${classId + 1}`;
|
||
appState.nextClassId++;
|
||
return classId;
|
||
},
|
||
|
||
/**
|
||
* 重命名一个姿态类别
|
||
* @param {number} classId - 类别ID
|
||
* @param {string} newName - 新的类别名称
|
||
*/
|
||
renameClass(classId, newName) {
|
||
if (appState.classMap.hasOwnProperty(classId)) {
|
||
appState.classMap[classId] = newName;
|
||
}
|
||
},
|
||
|
||
/**
|
||
* 删除一个指定的姿态类别
|
||
* @param {number} classId - 要删除的类别的ID
|
||
*/
|
||
deleteClass(classId) {
|
||
// 从状态和分类器中移除
|
||
delete appState.classMap[classId];
|
||
classifier.clearClass(classId);
|
||
},
|
||
|
||
/**
|
||
* 获取所有类别信息
|
||
* @returns {Object} 类别映射表
|
||
*/
|
||
getClassMap() {
|
||
return appState.classMap;
|
||
},
|
||
|
||
/**
|
||
* 获取类别数量
|
||
* @returns {number} 类别数量
|
||
*/
|
||
getNumClasses() {
|
||
return classifier.getNumClasses();
|
||
}
|
||
};
|
||
|
||
/**
|
||
* 样本训练与预测功能模块
|
||
*/
|
||
const Trainer = {
|
||
/**
|
||
* 采集一个姿态样本并添加到KNN分类器
|
||
* @param {ImageData|HTMLVideoElement} source - 图像数据源
|
||
* @param {number} classId 类别的ID
|
||
* @returns {Promise<void>}
|
||
*/
|
||
async addExample(source, classId) {
|
||
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
|
||
if (poses && poses.length > 0) {
|
||
const poseTensor = flattenPose(poses[0]);
|
||
classifier.addExample(poseTensor, classId);
|
||
poseTensor.dispose();
|
||
return Promise.resolve();
|
||
} else {
|
||
return Promise.reject(new Error('未检测到姿态'));
|
||
}
|
||
},
|
||
|
||
/**
|
||
* 对输入的姿态进行预测
|
||
* @param {ImageData|HTMLVideoElement} source - 图像数据源
|
||
* @param {number} k - KNN中的K值,默认为3
|
||
* @returns {Promise<Object>} 预测结果,包含标签和置信度
|
||
*/
|
||
async predict(source, k = 3) {
|
||
if (classifier.getNumClasses() === 0) {
|
||
return Promise.reject(new Error("分类器中没有样本"));
|
||
}
|
||
|
||
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
|
||
if (poses && poses.length > 0) {
|
||
const poseTensor = flattenPose(poses[0]);
|
||
const result = await classifier.predictClass(poseTensor, k);
|
||
poseTensor.dispose();
|
||
|
||
// 动态获取类别名称
|
||
const predictedClassName = appState.classMap[result.label] || '未知类别';
|
||
|
||
return Promise.resolve({
|
||
classId: result.label,
|
||
className: predictedClassName,
|
||
confidence: result.confidences[result.label],
|
||
confidences: result.confidences
|
||
});
|
||
} else {
|
||
return Promise.reject(new Error("未检测到姿态"));
|
||
}
|
||
},
|
||
|
||
/**
|
||
* 获取各类别样本数量
|
||
* @returns {Object} 各类别样本数量统计
|
||
*/
|
||
getExampleCounts() {
|
||
return classifier.getClassifierDataset();
|
||
}
|
||
};
|
||
|
||
/**
|
||
* 模型导入导出功能模块
|
||
*/
|
||
const ModelManager = {
|
||
/**
|
||
* 导出KNN模型为包含类别信息的JSON对象
|
||
* @returns {Object} 包含classMap和dataset的模型数据
|
||
*/
|
||
exportModel() {
|
||
if (classifier.getNumClasses() === 0) {
|
||
throw new Error('模型中还没有任何样本,无法导出!');
|
||
}
|
||
|
||
const dataset = classifier.getClassifierDataset();
|
||
const datasetObj = {};
|
||
Object.keys(dataset).forEach((key) => {
|
||
const data = dataset[key];
|
||
datasetObj[key] = data.arraySync();
|
||
});
|
||
|
||
// 导出格式: 同时保存 classMap 和 dataset
|
||
return {
|
||
classMap: appState.classMap,
|
||
dataset: datasetObj
|
||
};
|
||
},
|
||
|
||
/**
|
||
* 从JSON对象导入KNN模型并恢复类别状态
|
||
* @param {Object} modelData - 模型数据,包含classMap和dataset
|
||
*/
|
||
importModel(modelData) {
|
||
// 导入格式验证
|
||
if (!modelData.classMap || !modelData.dataset) {
|
||
throw new Error("无效的模型文件格式。");
|
||
}
|
||
|
||
// 1. 清理现有状态
|
||
classifier.clearAllClasses();
|
||
appState.classMap = {};
|
||
|
||
// 2. 加载新状态
|
||
appState.classMap = modelData.classMap;
|
||
const classIds = Object.keys(appState.classMap).map(Number);
|
||
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
|
||
|
||
// 3. 加载模型数据
|
||
const newDataset = {};
|
||
Object.keys(modelData.dataset).forEach((key) => {
|
||
newDataset[key] = tf.tensor(modelData.dataset[key]);
|
||
});
|
||
classifier.setClassifierDataset(newDataset);
|
||
}
|
||
};
|
||
|
||
// --- 辅助函数 ---
|
||
|
||
/**
|
||
* 将姿态关键点坐标扁平化为一维张量
|
||
* @param {Object} pose - 姿态对象
|
||
* @returns {tf.Tensor} 扁平化的张量
|
||
*/
|
||
function flattenPose(pose) {
|
||
// 假设调用者提供视频元素或图像尺寸
|
||
// 如果没有提供,则使用默认值
|
||
const width = 640; // 默认宽度
|
||
const height = 480; // 默认高度
|
||
const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat();
|
||
return tf.tensor(keypoints);
|
||
}
|
||
|
||
/**
|
||
* 清理资源
|
||
*/
|
||
function cleanup() {
|
||
// 还原所有状态
|
||
if (detector) detector.dispose(); // 清理检测器
|
||
if (classifier) classifier.clearAllClasses(); // 清理分类器
|
||
}
|
||
|
||
// --- 导出公共接口 ---
|
||
window.PoseClassifier = {
|
||
init,
|
||
ClassManager,
|
||
Trainer,
|
||
ModelManager,
|
||
cleanup
|
||
}; |