mobileNet/姿态分类/poseClassifier.js

243 lines
7.0 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* =============================================================================
* 姿态识别与分类核心功能模块
* =============================================================================
* 功能列表:
* - 实时姿态检测 (MoveNet)
* - KNN 分类器训练
* - 实时姿态预测
* - 坐标数据处理
* =============================================================================
*/
'use strict';
// --- 全局变量和常量 ---
let detector, classifier;
// 📌 核心状态管理: 使用一个对象来管理所有动态状态
const appState = {
classMap: {}, // 存储 classId -> className 的映射, e.g., {0: '姿态 A', 1: '姿态 B'}
nextClassId: 0 // 用于生成唯一的 classId
};
// --- 主要功能函数 ---
/**
* 初始化姿态检测器和分类器
* @returns {Promise<void>}
*/
async function init() {
try {
classifier = knnClassifier.create();
const detectorConfig = { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING };
detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, detectorConfig);
return Promise.resolve();
} catch (error) {
return Promise.reject(error);
}
}
/**
* 姿态类别管理功能模块
*/
const ClassManager = {
/**
* 添加一个新的姿态类别
* @param {string} className - 类别名称
* @returns {number} 新增类别的ID
*/
addClass(className) {
const classId = appState.nextClassId;
appState.classMap[classId] = className || `Class ${classId + 1}`;
appState.nextClassId++;
return classId;
},
/**
* 重命名一个姿态类别
* @param {number} classId - 类别ID
* @param {string} newName - 新的类别名称
*/
renameClass(classId, newName) {
if (appState.classMap.hasOwnProperty(classId)) {
appState.classMap[classId] = newName;
}
},
/**
* 删除一个指定的姿态类别
* @param {number} classId - 要删除的类别的ID
*/
deleteClass(classId) {
// 从状态和分类器中移除
delete appState.classMap[classId];
classifier.clearClass(classId);
},
/**
* 获取所有类别信息
* @returns {Object} 类别映射表
*/
getClassMap() {
return appState.classMap;
},
/**
* 获取类别数量
* @returns {number} 类别数量
*/
getNumClasses() {
return classifier.getNumClasses();
}
};
/**
* 样本训练与预测功能模块
*/
const Trainer = {
/**
* 采集一个姿态样本并添加到KNN分类器
* @param {ImageData|HTMLVideoElement} source - 图像数据源
* @param {number} classId 类别的ID
* @returns {Promise<void>}
*/
async addExample(source, classId) {
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
if (poses && poses.length > 0) {
const poseTensor = flattenPose(poses[0]);
classifier.addExample(poseTensor, classId);
poseTensor.dispose();
return Promise.resolve();
} else {
return Promise.reject(new Error('未检测到姿态'));
}
},
/**
* 对输入的姿态进行预测
* @param {ImageData|HTMLVideoElement} source - 图像数据源
* @param {number} k - KNN中的K值默认为3
* @returns {Promise<Object>} 预测结果,包含标签和置信度
*/
async predict(source, k = 3) {
if (classifier.getNumClasses() === 0) {
return Promise.reject(new Error("分类器中没有样本"));
}
const poses = await detector.estimatePoses(source, { flipHorizontal: true });
if (poses && poses.length > 0) {
const poseTensor = flattenPose(poses[0]);
const result = await classifier.predictClass(poseTensor, k);
poseTensor.dispose();
// 动态获取类别名称
const predictedClassName = appState.classMap[result.label] || '未知类别';
return Promise.resolve({
classId: result.label,
className: predictedClassName,
confidence: result.confidences[result.label],
confidences: result.confidences
});
} else {
return Promise.reject(new Error("未检测到姿态"));
}
},
/**
* 获取各类别样本数量
* @returns {Object} 各类别样本数量统计
*/
getExampleCounts() {
return classifier.getClassifierDataset();
}
};
/**
* 模型导入导出功能模块
*/
const ModelManager = {
/**
* 导出KNN模型为包含类别信息的JSON对象
* @returns {Object} 包含classMap和dataset的模型数据
*/
exportModel() {
if (classifier.getNumClasses() === 0) {
throw new Error('模型中还没有任何样本,无法导出!');
}
const dataset = classifier.getClassifierDataset();
const datasetObj = {};
Object.keys(dataset).forEach((key) => {
const data = dataset[key];
datasetObj[key] = data.arraySync();
});
// 导出格式: 同时保存 classMap 和 dataset
return {
classMap: appState.classMap,
dataset: datasetObj
};
},
/**
* 从JSON对象导入KNN模型并恢复类别状态
* @param {Object} modelData - 模型数据包含classMap和dataset
*/
importModel(modelData) {
// 导入格式验证
if (!modelData.classMap || !modelData.dataset) {
throw new Error("无效的模型文件格式。");
}
// 1. 清理现有状态
classifier.clearAllClasses();
appState.classMap = {};
// 2. 加载新状态
appState.classMap = modelData.classMap;
const classIds = Object.keys(appState.classMap).map(Number);
appState.nextClassId = classIds.length > 0 ? Math.max(...classIds) + 1 : 0;
// 3. 加载模型数据
const newDataset = {};
Object.keys(modelData.dataset).forEach((key) => {
newDataset[key] = tf.tensor(modelData.dataset[key]);
});
classifier.setClassifierDataset(newDataset);
}
};
// --- 辅助函数 ---
/**
* 将姿态关键点坐标扁平化为一维张量
* @param {Object} pose - 姿态对象
* @returns {tf.Tensor} 扁平化的张量
*/
function flattenPose(pose) {
// 假设调用者提供视频元素或图像尺寸
// 如果没有提供,则使用默认值
const width = 640; // 默认宽度
const height = 480; // 默认高度
const keypoints = pose.keypoints.map(p => [p.x / width, p.y / height]).flat();
return tf.tensor(keypoints);
}
/**
* 清理资源
*/
function cleanup() {
if (detector) detector.dispose();
if (classifier) classifier.clearAllClasses();
}
// --- 导出公共接口 ---
window.PoseClassifier = {
init,
ClassManager,
Trainer,
ModelManager,
cleanup
};