206 lines
9.8 KiB
HTML
206 lines
9.8 KiB
HTML
<!DOCTYPE html>
|
||
<html lang="zh-CN">
|
||
<head>
|
||
<meta charset="UTF-8">
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||
<title>动态姿态分类器</title>
|
||
<link rel="stylesheet" href="style.css">
|
||
</head>
|
||
<body>
|
||
<header>
|
||
<h1>动态姿态分类器</h1>
|
||
<p>一个使用 TensorFlow.js 和 MoveNet 实现的实时姿态训练与推理工具</p>
|
||
</header>
|
||
|
||
<main id="main-container">
|
||
<div id="video-wrapper">
|
||
<div id="status">正在加载模型,请稍候...</div>
|
||
<div id="video-container">
|
||
<video id="video" width="640" height="480" autoplay muted playsinline></video>
|
||
<canvas id="canvas" width="640" height="480"></canvas>
|
||
</div>
|
||
</div>
|
||
|
||
<div id="controls-panel">
|
||
<h2>控制面板</h2>
|
||
|
||
<!-- ==================== 训练区域重大更新 ==================== -->
|
||
<div class="control-section" id="training-section">
|
||
<h3>第一步: 训练模型</h3>
|
||
<p>点击下方按钮添加姿态分类,为每个姿态采集足够样本。</p>
|
||
|
||
<!-- 📌 新增: 这是动态添加的姿态类别的容器 -->
|
||
<div id="pose-classes-container">
|
||
<!-- JavaScript 将在此处动态生成类别UI -->
|
||
</div>
|
||
|
||
<!-- 📌 新增: 添加新类别的按钮 -->
|
||
<div class="add-class-wrapper">
|
||
<button id="btn-add-class" class="btn-add-class" disabled>+ 增加分类</button>
|
||
</div>
|
||
</div>
|
||
<!-- ========================================================= -->
|
||
|
||
<div class="control-section">
|
||
<h3>模型管理</h3>
|
||
<div class="model-controls">
|
||
<button id="btn-export" class="btn-sample" disabled>导出模型</button>
|
||
<button id="btn-import" class="btn-sample" disabled>导入模型</button>
|
||
<!-- 这个input是隐藏的,由上面的“导入模型”按钮触发 -->
|
||
<input type="file" id="file-importer" accept=".json" style="display: none;">
|
||
</div>
|
||
</div>
|
||
|
||
|
||
<div class="control-section" id="inference-section">
|
||
<h3>第二步: 开始推理</h3>
|
||
<p>训练完成后,点击下方按钮开始实时预测。</p>
|
||
<button id="btn-predict" class="btn-predict" disabled>开始预测</button>
|
||
<div id="result-container">
|
||
<strong>预测结果:</strong>
|
||
<div id="result-text">尚未开始</div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
</main>
|
||
|
||
<!-- !!!!!! 核心劫持代码:确保在任何 TF.js 库之前加载 !!!!!! -->
|
||
<script>
|
||
(function() {
|
||
// 定义你的镜像服务器的公共前缀,用于存放 MoveNet 模型文件
|
||
// 例如:'https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/movenet-mirror/'
|
||
// 重要:确保你的镜像服务器的目录结构与原始模型文件的路径部分匹配。
|
||
// 举例:
|
||
// 如果原始是 https://tfhub.dev/deepmind/movenet/singlepose-lightning/4/model.json
|
||
// 那么在你的CDN上,你需要部署为:
|
||
// https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/movenet-mirror/tfhub.dev/deepmind/movenet/singlepose-lightning/4/model.json
|
||
//
|
||
// 如果你的CDN就是直接放了 model.json, group1-shard*of*.bin,那么 MIRROR_BASE_URL 将不含后续路径。
|
||
//
|
||
// **** 根据你前一个回复的镜像路径,我们假设你的镜像结构是: ****
|
||
// https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/fetch/snake_game/model.json
|
||
// https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/fetch/snake_game/group1-shard1of2.bin
|
||
// https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/fetch/snake_game/group1-shard2of2.bin
|
||
//
|
||
// 那么,我们需要将匹配的原始URL路径重写为 MIRROR_SPECIFIC_FILENAME_PREFIX
|
||
// 也就是将类似 "https://tfhub.dev/.../model.json..."
|
||
// 替换为 "https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/fetch/snake_game/model.json"
|
||
const MIRROR_SPECIFIC_FILENAME_PREFIX = 'https://goood-space-assets.oss-cn-beijing.aliyuncs.com/public/fetch/snake_game/';
|
||
|
||
// 定义需要被劫持的原始 URL 的域名模式
|
||
const INTERCEPT_DOMAINS = [
|
||
// 'https://tfhub.dev/',
|
||
'https://tfhub.dev/google/tfjs-model/movenet/singlepose/lightning/4/',
|
||
// 如果实际的最终模型文件仍然解析到 storage.googleapis.com,也需要包含
|
||
// 例如:'https://storage.googleapis.com/tfjs-models/'
|
||
// 或者你观察到的实际的最终 Google Storage 域名
|
||
];
|
||
|
||
// 备份原始的 fetch 函数
|
||
const originalFetch = window.fetch;
|
||
|
||
window.fetch = function(input, init) {
|
||
let url = input;
|
||
if (input instanceof Request) {
|
||
url = input.url;
|
||
}
|
||
|
||
let newUrl = url;
|
||
let isIntercepted = false;
|
||
|
||
// 检查 URL 是否以我们关注的域名开头
|
||
for (const domain of INTERCEPT_DOMAINS) {
|
||
if (url.startsWith(domain)) {
|
||
// 尝试从 URL 中提取文件名 (不包含查询参数)
|
||
// 匹配 model.json 或 group1-shardXof2.bin
|
||
const fileNameMatch = url.match(/(model\.json|group1-shard\dof2\.bin)/);
|
||
if (fileNameMatch) {
|
||
const fileName = fileNameMatch[0]; // 获取匹配到的文件名
|
||
newUrl = MIRROR_SPECIFIC_FILENAME_PREFIX + fileName; // 拼接新的镜像 URL
|
||
isIntercepted = true;
|
||
break; // 找到匹配的域名和文件,停止循环
|
||
}
|
||
}
|
||
}
|
||
|
||
if (isIntercepted) {
|
||
console.warn(`[TFJS Fetch Intercepted] Original: ${url}`);
|
||
console.warn(`[TFJS Fetch Intercepted] Redirecting to: ${newUrl}`);
|
||
|
||
if (input instanceof Request) {
|
||
input = new Request(newUrl, {
|
||
method: input.method,
|
||
headers: input.headers,
|
||
body: input.body,
|
||
referrer: input.referrer,
|
||
referrerPolicy: input.referrerPolicy,
|
||
mode: 'cors',
|
||
credentials: input.credentials,
|
||
cache: 'default',
|
||
redirect: 'follow',
|
||
integrity: undefined, // 移除 integrity 属性以避免校验失败
|
||
signal: input.signal,
|
||
});
|
||
} else {
|
||
input = newUrl;
|
||
}
|
||
}
|
||
|
||
return originalFetch(input, init).catch(error => {
|
||
console.error(`[TFJS Fetch Intercepted Error] Failed to load ${url} (redirected to ${newUrl || url || input}):`, error);
|
||
throw error;
|
||
});
|
||
};
|
||
|
||
// -------------------- 劫持 XMLHttpRequest API (备用安全网) --------------------
|
||
// 尽管 TF.js 主要用 fetch,但安全起见保留 XHR 劫持
|
||
const originalXHR = window.XMLHttpRequest;
|
||
window.XMLHttpRequest = function() {
|
||
const xhr = new originalXHR();
|
||
const originalOpen = xhr.open;
|
||
xhr.open = function(method, url, async = true, user = null, password = null) {
|
||
let newUrl = url;
|
||
let isIntercepted = false;
|
||
|
||
for (const domain of INTERCEPT_DOMAINS) {
|
||
if (url.startsWith(domain)) {
|
||
const fileNameMatch = url.match(/(model\.json|group1-shard\dof\d\.bin)/);
|
||
if (fileNameMatch) {
|
||
const fileName = fileNameMatch[0];
|
||
newUrl = MIRROR_SPECIFIC_FILENAME_PREFIX + fileName;
|
||
isIntercepted = true;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
|
||
if (isIntercepted) {
|
||
console.warn(`[TFJS XHR Intercepted] Original: ${url}`);
|
||
console.warn(`[TFJS XHR Intercepted] Redirecting to: ${newUrl}`);
|
||
url = newUrl; // 修改传入 open 的 URL
|
||
}
|
||
|
||
return originalOpen.apply(this, arguments);
|
||
};
|
||
|
||
for (const key in originalXHR) {
|
||
if (originalXHR.hasOwnProperty(key)) {
|
||
window.XMLHttpRequest[key] = originalXHR[key];
|
||
}
|
||
}
|
||
return xhr;
|
||
};
|
||
|
||
})();
|
||
</script>
|
||
|
||
<!-- 引入所有依赖库 -->
|
||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/dist/tf.min.js"></script>
|
||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/pose-detection@2.1.3/dist/pose-detection.min.js"></script>
|
||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier@1.2.2/dist/knn-classifier.min.js"></script>
|
||
|
||
<!-- 引入我们自己的逻辑脚本 -->
|
||
<script src="script.js"></script>
|
||
</body>
|
||
</html>
|