初始化
This commit is contained in:
commit
8f998b1915
69
README.md
Normal file
69
README.md
Normal file
@ -0,0 +1,69 @@
|
||||
# 图像分类器项目
|
||||
|
||||
本项目包含多种基于MobileNet特征提取的图像分类器实现,包括三层神经网络、随机森林和KNN算法。
|
||||
|
||||
## 模型对比
|
||||
|
||||
| 模型 | 算法类型 | 特点 | 适用场景 |
|
||||
|------|---------|------|---------|
|
||||
| 三层神经网络 | 深度学习 | 自定义全连接网络,支持训练监控、早停、正则化 | 需要高精度、可解释性强的场景 |
|
||||
| 随机森林 | 集成学习 | 多决策树投票,参数直观可调 | 中等规模数据集,需要模型解释性 |
|
||||
| KNN (原版) | 实例学习 | 简单实现,预测结果平滑处理 | 快速原型开发,小规模数据 |
|
||||
| KNN (完善版) | 实例学习 | 增强阈值控制,支持单类别检测 | 异常检测、单类别分类场景 |
|
||||
|
||||
## 详细说明
|
||||
|
||||
### 三层神经网络分类器
|
||||
- 使用MobileNet进行特征提取
|
||||
- 自定义三层全连接网络进行分类
|
||||
- 功能特点:
|
||||
- 训练过程可视化(损失/准确率曲线)
|
||||
- 支持早停、正则化等技巧
|
||||
- 模型保存/加载功能
|
||||
- 温度缩放调整预测置信度
|
||||
|
||||
### 随机森林分类器
|
||||
- 使用MobileNet特征作为输入
|
||||
- 构建多个决策树进行集成分类
|
||||
- 可调参数:
|
||||
- 决策树数量(默认10棵)
|
||||
- 训练集子集比例(默认70%)
|
||||
- 特点:
|
||||
- 训练速度快
|
||||
- 提供ImageNet标签显示功能
|
||||
|
||||
### KNN分类器(原版)
|
||||
- 基于MobileNet特征的K最近邻算法
|
||||
- 特点:
|
||||
- 实现简单
|
||||
- 低通滤波器平滑预测结果
|
||||
- 支持模型保存/加载
|
||||
|
||||
### KNN分类器(完善版)
|
||||
在原版基础上增强:
|
||||
- 距离阈值控制
|
||||
- 自适应阈值计算
|
||||
- 改进的单类别检测
|
||||
- 更详细的训练反馈
|
||||
|
||||
## 使用指南
|
||||
|
||||
1. 选择分类器类型
|
||||
2. 上传各类别训练图片
|
||||
3. 调整模型参数(如适用)
|
||||
4. 点击"训练模型"按钮
|
||||
5. 使用摄像头或上传图片进行预测
|
||||
|
||||
## 技术实现
|
||||
|
||||
所有分类器均基于以下技术:
|
||||
- 特征提取:MobileNet (TensorFlow.js)
|
||||
- 前端框架:纯JavaScript实现
|
||||
- 数据存储:浏览器本地存储(IndexedDB)
|
||||
|
||||
## 开发建议
|
||||
|
||||
- 对于高精度需求:使用三层神经网络
|
||||
- 对于可解释性需求:使用随机森林
|
||||
- 对于快速原型开发:使用KNN
|
||||
- 对于异常检测:使用完善版KNN
|
532
三层神经网络/custom-classifier.html
Normal file
532
三层神经网络/custom-classifier.html
Normal file
@ -0,0 +1,532 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>自定义图像分类器 - TensorFlow.js</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.main-container {
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: white;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
font-size: 2.5em;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.grid-container {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.grid-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
padding: 25px;
|
||||
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
border-bottom: 2px solid #667eea;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.class-input {
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.class-input h3 {
|
||||
color: #555;
|
||||
margin-bottom: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.class-number {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
border-radius: 50%;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 5px;
|
||||
margin-bottom: 10px;
|
||||
font-size: 16px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
input[type="file"] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-label {
|
||||
display: inline-block;
|
||||
padding: 10px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
transition: background 0.3s;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.file-label:hover {
|
||||
background: #5a67d8;
|
||||
}
|
||||
|
||||
.image-preview {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
margin-top: 10px;
|
||||
max-height: 150px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.preview-img {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
object-fit: cover;
|
||||
border-radius: 5px;
|
||||
border: 2px solid #e0e0e0;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 12px 30px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
font-size: 16px;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
margin: 5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: #5a67d8;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background: #48bb78;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background: #38a169;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background: #f56565;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background: #e53e3e;
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
background: #cbd5e0;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
#webcam-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
max-width: 640px;
|
||||
margin: 20px auto;
|
||||
}
|
||||
|
||||
#webcam {
|
||||
width: 100%;
|
||||
border-radius: 10px;
|
||||
background: #000;
|
||||
}
|
||||
|
||||
.confidence-bars {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.confidence-item {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.confidence-label {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 5px;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.confidence-bar {
|
||||
height: 30px;
|
||||
background: #e0e0e0;
|
||||
border-radius: 15px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.confidence-fill {
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, #667eea, #764ba2);
|
||||
border-radius: 15px;
|
||||
transition: width 0.3s ease;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: white;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 5px;
|
||||
margin: 10px 0;
|
||||
text-align: center;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.status-success {
|
||||
background: #c6f6d5;
|
||||
color: #22543d;
|
||||
border: 1px solid #9ae6b4;
|
||||
}
|
||||
|
||||
.status-error {
|
||||
background: #fed7d7;
|
||||
color: #742a2a;
|
||||
border: 1px solid #fc8181;
|
||||
}
|
||||
|
||||
.status-info {
|
||||
background: #bee3f8;
|
||||
color: #2c5282;
|
||||
border: 1px solid #90cdf4;
|
||||
}
|
||||
|
||||
.status-warning {
|
||||
background: #fef3c7;
|
||||
color: #92400e;
|
||||
border: 1px solid #fcd34d;
|
||||
}
|
||||
|
||||
.training-progress {
|
||||
margin: 20px 0;
|
||||
}
|
||||
|
||||
.progress-bar {
|
||||
height: 25px;
|
||||
background: #e0e0e0;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.progress-fill {
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, #48bb78, #38a169);
|
||||
transition: width 0.3s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: white;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.metrics {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
|
||||
gap: 15px;
|
||||
margin: 20px 0;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
background: #f7fafc;
|
||||
padding: 15px;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
border: 1px solid #e2e8f0;
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
color: #718096;
|
||||
font-size: 12px;
|
||||
text-transform: uppercase;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
color: #2d3748;
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.samples-count {
|
||||
display: inline-block;
|
||||
background: #edf2f7;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
font-size: 12px;
|
||||
color: #4a5568;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.full-width {
|
||||
grid-column: 1 / -1;
|
||||
}
|
||||
|
||||
.button-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin: 20px 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
#lossChart {
|
||||
width: 100%;
|
||||
height: 300px;
|
||||
margin-top: 20px;
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="main-container">
|
||||
<h1>🤖 自定义图像分类器</h1>
|
||||
|
||||
<div class="grid-container">
|
||||
<!-- 数据采集卡片 -->
|
||||
<div class="card">
|
||||
<h2>📸 数据采集</h2>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">1</span> 第一类</h3>
|
||||
<input type="text" id="class1Name" placeholder="输入类别名称(如:人)" value="类别1">
|
||||
<label class="file-label" for="class1Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class1Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class1Count">0 张图片</span>
|
||||
<div class="image-preview" id="class1Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">2</span> 第二类</h3>
|
||||
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
|
||||
<label class="file-label" for="class2Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class2Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class2Count">0 张图片</span>
|
||||
<div class="image-preview" id="class2Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">3</span> 第三类(可选)</h3>
|
||||
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="">
|
||||
<label class="file-label" for="class3Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class3Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class3Count">0 张图片</span>
|
||||
<div class="image-preview" id="class3Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-primary" id="addDataBtn">添加到数据集</button>
|
||||
<button class="btn btn-danger" id="clearDataBtn">清空数据集</button>
|
||||
</div>
|
||||
|
||||
<div id="dataStatus"></div>
|
||||
</div>
|
||||
|
||||
<!-- 训练控制卡片 -->
|
||||
<div class="card">
|
||||
<h2>🎯 模型训练</h2>
|
||||
|
||||
<!-- 超参数调节 -->
|
||||
<div class="hyperparameters" style="background: #f8f9fa; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
|
||||
<h3 style="margin-bottom: 15px; color: #555;">⚙️ 超参数设置</h3>
|
||||
|
||||
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px;">
|
||||
<div>
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
学习率: <span id="learningRateValue">0.001</span>
|
||||
</label>
|
||||
<input type="range" id="learningRate" min="-5" max="-1" step="0.1" value="-3"
|
||||
style="width: 100%;" oninput="document.getElementById('learningRateValue').textContent = Math.pow(10, this.value).toFixed(5)">
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
训练轮数: <span id="epochsValue">100</span>
|
||||
</label>
|
||||
<input type="range" id="epochs" min="10" max="200" step="10" value="100"
|
||||
style="width: 100%;" oninput="document.getElementById('epochsValue').textContent = this.value">
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
Dropout率: <span id="dropoutValue">0.3</span>
|
||||
</label>
|
||||
<input type="range" id="dropoutRate" min="0" max="0.7" step="0.05" value="0.3"
|
||||
style="width: 100%;" oninput="document.getElementById('dropoutValue').textContent = this.value">
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
L2正则化: <span id="l2Value">0.01</span>
|
||||
</label>
|
||||
<input type="range" id="l2Regularization" min="-4" max="-1" step="0.1" value="-2"
|
||||
style="width: 100%;" oninput="document.getElementById('l2Value').textContent = Math.pow(10, this.value).toFixed(4)">
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
批次大小: <span id="batchSizeValue">32</span>
|
||||
</label>
|
||||
<input type="range" id="batchSize" min="8" max="128" step="8" value="32"
|
||||
style="width: 100%;" oninput="document.getElementById('batchSizeValue').textContent = this.value">
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
温度缩放: <span id="temperatureValue">1.0</span>
|
||||
</label>
|
||||
<input type="range" id="temperature" min="0.5" max="5" step="0.1" value="1.0"
|
||||
style="width: 100%;" oninput="document.getElementById('temperatureValue').textContent = this.value">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="margin-top: 15px;">
|
||||
<label style="display: flex; align-items: center; gap: 10px; cursor: pointer;">
|
||||
<input type="checkbox" id="earlyStoppingCheck" checked>
|
||||
<span style="font-size: 14px; color: #666;">启用早停 (防止过拟合)</span>
|
||||
</label>
|
||||
|
||||
<div id="earlyStoppingOptions" style="margin-top: 10px; padding-left: 25px;">
|
||||
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
|
||||
耐心值: <span id="patienceValue">10</span> epochs
|
||||
</label>
|
||||
<input type="range" id="patience" min="5" max="30" step="5" value="10"
|
||||
style="width: 200px;" oninput="document.getElementById('patienceValue').textContent = this.value">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="margin-top: 15px;">
|
||||
<label style="display: flex; align-items: center; gap: 10px; cursor: pointer;">
|
||||
<input type="checkbox" id="dataAugmentationCheck">
|
||||
<span style="font-size: 14px; color: #666;">启用数据增强 (提高泛化能力)</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div style="margin-top: 15px; display: flex; gap: 10px;">
|
||||
<button class="btn btn-primary" onclick="classifier.resetHyperparameters()" style="padding: 8px 15px; font-size: 14px;">
|
||||
重置为默认值
|
||||
</button>
|
||||
<button class="btn btn-primary" onclick="classifier.showRecommendations()" style="padding: 8px 15px; font-size: 14px;">
|
||||
查看建议设置
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-success" id="trainBtn">开始训练</button>
|
||||
<button class="btn btn-danger" id="stopBtn" disabled>停止训练</button>
|
||||
</div>
|
||||
|
||||
<div id="trainingStatus"></div>
|
||||
|
||||
<div class="training-progress hidden" id="trainingProgress">
|
||||
<div class="progress-bar">
|
||||
<div class="progress-fill" id="progressFill">0%</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="metrics" id="metricsContainer"></div>
|
||||
|
||||
<canvas id="lossChart"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 预测卡片 -->
|
||||
<div class="card full-width">
|
||||
<h2>📹 实时预测</h2>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
|
||||
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
|
||||
<button class="btn btn-success" id="saveModelBtn">保存模型</button>
|
||||
<button class="btn btn-primary" id="loadModelBtn">加载模型</button>
|
||||
</div>
|
||||
|
||||
<div id="webcam-container">
|
||||
<video id="webcam" autoplay playsinline muted></video>
|
||||
</div>
|
||||
|
||||
<div class="confidence-bars" id="confidenceBars"></div>
|
||||
|
||||
<div id="predictionStatus"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="custom-classifier.js"></script>
|
||||
</body>
|
||||
</html>
|
1174
三层神经网络/custom-classifier.js
Normal file
1174
三层神经网络/custom-classifier.js
Normal file
File diff suppressed because it is too large
Load Diff
551
原版KNN/knn-classifier.html
Normal file
551
原版KNN/knn-classifier.html
Normal file
@ -0,0 +1,551 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>KNN 图像分类器 - TensorFlow.js</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier@latest"></script>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.main-container {
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: white;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
font-size: 2.5em;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.grid-container {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.grid-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
padding: 25px;
|
||||
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
border-bottom: 2px solid #667eea;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.class-input {
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.class-input h3 {
|
||||
color: #555;
|
||||
margin-bottom: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.class-number {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
border-radius: 50%;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 5px;
|
||||
margin-bottom: 10px;
|
||||
font-size: 16px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
input[type="file"] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-label {
|
||||
display: inline-block;
|
||||
padding: 10px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
transition: background 0.3s;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.file-label:hover {
|
||||
background: #5a67d8;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 12px 30px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
font-size: 16px;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
margin: 5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: #5a67d8;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background: #48bb78;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background: #38a169;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background: #f56565;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background: #e53e3e;
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
background: #cbd5e0;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
#webcam-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
max-width: 640px;
|
||||
margin: 20px auto;
|
||||
}
|
||||
|
||||
#webcam {
|
||||
width: 100%;
|
||||
border-radius: 10px;
|
||||
background: #000;
|
||||
}
|
||||
|
||||
.samples-count {
|
||||
display: inline-block;
|
||||
background: #edf2f7;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
font-size: 12px;
|
||||
color: #4a5568;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.image-preview {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
margin-top: 10px;
|
||||
max-height: 150px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.preview-img {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
object-fit: cover;
|
||||
border-radius: 5px;
|
||||
border: 2px solid #e0e0e0;
|
||||
}
|
||||
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 5px;
|
||||
margin: 10px 0;
|
||||
text-align: center;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.status-success {
|
||||
background: #c6f6d5;
|
||||
color: #22543d;
|
||||
border: 1px solid #9ae6b4;
|
||||
}
|
||||
|
||||
.status-error {
|
||||
background: #fed7d7;
|
||||
color: #742a2a;
|
||||
border: 1px solid #fc8181;
|
||||
}
|
||||
|
||||
.status-info {
|
||||
background: #bee3f8;
|
||||
color: #2c5282;
|
||||
border: 1px solid #90cdf4;
|
||||
}
|
||||
|
||||
.button-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin: 20px 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.full-width {
|
||||
grid-column: 1 / -1;
|
||||
}
|
||||
|
||||
.prediction-results {
|
||||
margin-top: 20px;
|
||||
padding: 20px;
|
||||
background: #f7fafc;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.prediction-item {
|
||||
padding: 15px;
|
||||
margin: 10px 0;
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
border-left: 4px solid #667eea;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
|
||||
}
|
||||
|
||||
.prediction-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.prediction-label {
|
||||
font-weight: 600;
|
||||
color: #2d3748;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.prediction-confidence {
|
||||
background: linear-gradient(135deg, #667eea, #764ba2);
|
||||
color: white;
|
||||
padding: 4px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
min-width: 60px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.confidence-bar-container {
|
||||
width: 100%;
|
||||
height: 24px;
|
||||
background: #e2e8f0;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.confidence-bar {
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, #667eea, #764ba2);
|
||||
border-radius: 12px;
|
||||
transition: width 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
position: relative;
|
||||
min-width: 0;
|
||||
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
|
||||
}
|
||||
|
||||
.confidence-bar::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.3), transparent);
|
||||
animation: shimmer 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes shimmer {
|
||||
0% { transform: translateX(-100%); }
|
||||
100% { transform: translateX(100%); }
|
||||
}
|
||||
|
||||
.confidence-bar.high {
|
||||
background: linear-gradient(90deg, #48bb78, #38a169);
|
||||
}
|
||||
|
||||
.confidence-bar.medium {
|
||||
background: linear-gradient(90deg, #ed8936, #dd6b20);
|
||||
}
|
||||
|
||||
.confidence-bar.low {
|
||||
background: linear-gradient(90deg, #f56565, #e53e3e);
|
||||
}
|
||||
|
||||
.confidence-percentage {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
color: white;
|
||||
font-weight: 600;
|
||||
font-size: 12px;
|
||||
text-shadow: 0 1px 2px rgba(0,0,0,0.2);
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.top-tags {
|
||||
margin: 20px 0;
|
||||
padding: 15px;
|
||||
background: #edf2fe;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.tag-item {
|
||||
display: inline-block;
|
||||
background: white;
|
||||
padding: 5px 12px;
|
||||
margin: 5px;
|
||||
border-radius: 15px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #cbd5e0;
|
||||
}
|
||||
|
||||
.tag-weight {
|
||||
color: #667eea;
|
||||
font-weight: bold;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.k-selector {
|
||||
margin: 15px 0;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.k-selector label {
|
||||
display: block;
|
||||
margin-bottom: 10px;
|
||||
color: #555;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.k-value-display {
|
||||
display: inline-block;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
padding: 2px 8px;
|
||||
border-radius: 5px;
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
input[type="range"] {
|
||||
width: 100%;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.model-info {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f0f4f8;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin: 5px 0;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
color: #718096;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: #2d3748;
|
||||
font-weight: 500;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="main-container">
|
||||
<h1>🤖 KNN 图像分类器(基于特征标签)</h1>
|
||||
|
||||
<div class="grid-container">
|
||||
<!-- 数据采集卡片 -->
|
||||
<div class="card">
|
||||
<h2>📸 数据采集</h2>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">1</span> 第一类</h3>
|
||||
<input type="text" id="class1Name" placeholder="输入类别名称(如:猫)" value="类别1">
|
||||
<label class="file-label" for="class1Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class1Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class1Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(0)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class1Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">2</span> 第二类</h3>
|
||||
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
|
||||
<label class="file-label" for="class2Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class2Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class2Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(1)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class2Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">3</span> 第三类(可选)</h3>
|
||||
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="">
|
||||
<label class="file-label" for="class3Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class3Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class3Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(2)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class3Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-success" id="addDataBtn">训练KNN模型</button>
|
||||
<button class="btn btn-danger" id="clearDataBtn">清空数据</button>
|
||||
</div>
|
||||
|
||||
<div id="dataStatus"></div>
|
||||
</div>
|
||||
|
||||
<!-- KNN模型信息卡片 -->
|
||||
<div class="card">
|
||||
<h2>🎯 KNN 模型设置</h2>
|
||||
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
K值(最近邻数量)
|
||||
<span class="k-value-display" id="kValueDisplay">3</span>
|
||||
</label>
|
||||
<input type="range" id="kValue" min="1" max="20" value="3"
|
||||
oninput="document.getElementById('kValueDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">K值越大,预测越保守;K值越小,对局部特征越敏感</small>
|
||||
</div>
|
||||
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
滤波器系数 (α)
|
||||
<span class="k-value-display" id="filterAlphaDisplay">0.3</span>
|
||||
</label>
|
||||
<input type="range" id="filterAlpha" min="0.05" max="1.0" step="0.05" value="0.3"
|
||||
oninput="document.getElementById('filterAlphaDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">低通滤波器系数:值越小输出越平滑(0.1-0.3推荐),值越大响应越快</small>
|
||||
</div>
|
||||
|
||||
<div class="top-tags" id="topTags">
|
||||
<h3 style="margin-bottom: 10px;">📊 特征标签提取预览</h3>
|
||||
<div id="tagsList">等待数据...</div>
|
||||
</div>
|
||||
|
||||
<div class="model-info">
|
||||
<h3 style="margin-bottom: 10px;">ℹ️ 模型信息</h3>
|
||||
<div class="info-item">
|
||||
<span class="info-label">预训练模型:</span>
|
||||
<span class="info-value">MobileNet v2</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">特征维度:</span>
|
||||
<span class="info-value">1000个标签</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">分类器类型:</span>
|
||||
<span class="info-value">K-最近邻 (KNN)</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">总样本数:</span>
|
||||
<span class="info-value" id="totalSamples">0</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 预测卡片 -->
|
||||
<div class="card full-width">
|
||||
<h2>📹 实时预测</h2>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
|
||||
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
|
||||
<button class="btn btn-success" id="saveModelBtn">保存模型</button>
|
||||
<button class="btn btn-primary" id="loadModelBtn">加载模型</button>
|
||||
</div>
|
||||
|
||||
<div id="webcam-container">
|
||||
<video id="webcam" autoplay playsinline muted></video>
|
||||
</div>
|
||||
|
||||
<div class="prediction-results" id="predictionResults">
|
||||
<h3>预测结果</h3>
|
||||
<div id="predictions">等待预测...</div>
|
||||
</div>
|
||||
|
||||
<div id="predictionStatus"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="knn-classifier.js"></script>
|
||||
</body>
|
||||
</html>
|
621
原版KNN/knn-classifier.js
Normal file
621
原版KNN/knn-classifier.js
Normal file
@ -0,0 +1,621 @@
|
||||
// KNN 图像分类器 - 基于MobileNet特征标签
|
||||
class KNNImageClassifier {
|
||||
constructor() {
|
||||
this.mobilenet = null;
|
||||
this.knnClassifier = null;
|
||||
this.classNames = [];
|
||||
this.webcamStream = null;
|
||||
this.isPredicting = false;
|
||||
this.currentCaptureClass = -1;
|
||||
this.imagenetClasses = null;
|
||||
|
||||
// 低通滤波器状态
|
||||
this.filteredConfidences = {};
|
||||
this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑
|
||||
|
||||
this.init();
|
||||
}
|
||||
|
||||
async init() {
|
||||
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
|
||||
|
||||
try {
|
||||
// 加载 MobileNet 模型
|
||||
this.mobilenet = await mobilenet.load({
|
||||
version: 2,
|
||||
alpha: 1.0
|
||||
});
|
||||
|
||||
// 创建 KNN 分类器
|
||||
this.knnClassifier = knnClassifier.create();
|
||||
|
||||
// 加载 ImageNet 类别名称
|
||||
await this.loadImageNetClasses();
|
||||
|
||||
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
|
||||
this.setupEventListeners();
|
||||
} catch (error) {
|
||||
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
async loadImageNetClasses() {
|
||||
// ImageNet 前10个类别名称(简化版)
|
||||
this.imagenetClasses = [
|
||||
'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
|
||||
'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
|
||||
];
|
||||
}
|
||||
|
||||
setupEventListeners() {
|
||||
// 文件上传监听
|
||||
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
|
||||
document.getElementById(id).addEventListener('change', (e) => {
|
||||
this.handleImageUpload(e, index);
|
||||
});
|
||||
});
|
||||
|
||||
// 按钮监听
|
||||
document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
|
||||
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
|
||||
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
|
||||
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
|
||||
document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
|
||||
document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
|
||||
}
|
||||
|
||||
handleImageUpload(event, classIndex) {
|
||||
const files = event.target.files;
|
||||
const countElement = document.getElementById(`class${classIndex + 1}Count`);
|
||||
const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
|
||||
|
||||
countElement.textContent = `${files.length} 张图片`;
|
||||
|
||||
// 清空之前的预览
|
||||
previewContainer.innerHTML = '';
|
||||
|
||||
// 添加图片预览
|
||||
Array.from(files).forEach(file => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (e) => {
|
||||
const img = document.createElement('img');
|
||||
img.src = e.target.result;
|
||||
img.className = 'preview-img';
|
||||
previewContainer.appendChild(img);
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// 从图像提取 MobileNet 标签和权重
|
||||
async extractImageNetTags(img) {
|
||||
try {
|
||||
// 获取 MobileNet 的预测(1000个类别的概率)
|
||||
const predictions = await this.mobilenet.classify(img);
|
||||
|
||||
// 获取完整的 logits(原始输出)
|
||||
const logits = this.mobilenet.infer(img, false); // false = 不使用嵌入层,获取原始1000维输出
|
||||
|
||||
// 获取前10个最高概率的标签
|
||||
const topK = await this.getTopKTags(logits, 10);
|
||||
|
||||
return {
|
||||
logits: logits, // 1000维特征向量
|
||||
predictions: predictions, // 前3个预测
|
||||
topTags: topK // 前10个标签和权重
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('特征提取失败:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 获取Top-K标签
|
||||
async getTopKTags(logits, k = 10) {
|
||||
const values = await logits.data();
|
||||
const valuesAndIndices = [];
|
||||
|
||||
for (let i = 0; i < values.length; i++) {
|
||||
valuesAndIndices.push({ value: values[i], index: i });
|
||||
}
|
||||
|
||||
valuesAndIndices.sort((a, b) => b.value - a.value);
|
||||
const topkValues = new Float32Array(k);
|
||||
const topkIndices = new Int32Array(k);
|
||||
|
||||
for (let i = 0; i < k; i++) {
|
||||
topkValues[i] = valuesAndIndices[i].value;
|
||||
topkIndices[i] = valuesAndIndices[i].index;
|
||||
}
|
||||
|
||||
const topTags = [];
|
||||
for (let i = 0; i < k; i++) {
|
||||
topTags.push({
|
||||
className: this.imagenetClasses[i] || `class_${topkIndices[i]}`,
|
||||
probability: this.softmax(topkValues)[i],
|
||||
logit: topkValues[i]
|
||||
});
|
||||
}
|
||||
|
||||
return topTags;
|
||||
}
|
||||
|
||||
// Softmax 函数
|
||||
softmax(arr) {
|
||||
const maxLogit = Math.max(...arr);
|
||||
const scores = arr.map(l => Math.exp(l - maxLogit));
|
||||
const sum = scores.reduce((a, b) => a + b);
|
||||
return scores.map(s => s / sum);
|
||||
}
|
||||
|
||||
// 训练 KNN 模型
|
||||
async trainKNN() {
|
||||
const classes = [];
|
||||
const imageFiles = [];
|
||||
|
||||
// 收集所有类别和图片
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
const className = document.getElementById(`class${i}Name`).value.trim();
|
||||
const files = document.getElementById(`class${i}Images`).files;
|
||||
|
||||
if (className && files.length > 0) {
|
||||
classes.push(className);
|
||||
imageFiles.push(files);
|
||||
}
|
||||
}
|
||||
|
||||
if (classes.length < 2) {
|
||||
this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
|
||||
return;
|
||||
}
|
||||
|
||||
this.classNames = classes;
|
||||
this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');
|
||||
|
||||
// 清空现有的KNN分类器
|
||||
this.knnClassifier.clearAllClasses();
|
||||
|
||||
let totalProcessed = 0;
|
||||
let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
|
||||
|
||||
// 处理每个类别的图片
|
||||
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
|
||||
const files = imageFiles[classIndex];
|
||||
console.log(`处理类别 ${classes[classIndex]}...`);
|
||||
|
||||
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
|
||||
try {
|
||||
const img = await this.loadImage(files[fileIndex]);
|
||||
|
||||
// 提取特征标签
|
||||
const features = await this.extractImageNetTags(img);
|
||||
|
||||
// 添加到KNN分类器
|
||||
// 使用完整的logits作为特征向量
|
||||
this.knnClassifier.addExample(features.logits, classIndex);
|
||||
|
||||
totalProcessed++;
|
||||
const progress = Math.round((totalProcessed / totalImages) * 100);
|
||||
this.showStatus('dataStatus', 'info',
|
||||
`处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
|
||||
|
||||
// 显示提取的标签
|
||||
if (fileIndex === 0) {
|
||||
this.displayTopTags(features.topTags);
|
||||
}
|
||||
|
||||
// 清理
|
||||
img.remove();
|
||||
features.logits.dispose();
|
||||
} catch (error) {
|
||||
console.error('处理图片失败:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新模型信息
|
||||
document.getElementById('totalSamples').textContent = totalProcessed;
|
||||
|
||||
this.showStatus('dataStatus', 'success',
|
||||
`KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`);
|
||||
|
||||
console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
|
||||
}
|
||||
|
||||
// 显示提取的标签
|
||||
displayTopTags(tags) {
|
||||
const container = document.getElementById('tagsList');
|
||||
let html = '';
|
||||
|
||||
tags.slice(0, 5).forEach(tag => {
|
||||
html += `
|
||||
<span class="tag-item">
|
||||
${tag.className}
|
||||
<span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
|
||||
</span>
|
||||
`;
|
||||
});
|
||||
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
// 加载图片
|
||||
async loadImage(file) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (e) => {
|
||||
const img = new Image();
|
||||
img.onload = () => resolve(img);
|
||||
img.onerror = reject;
|
||||
img.src = e.target.result;
|
||||
};
|
||||
reader.onerror = reject;
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// 清空数据集
|
||||
clearDataset() {
|
||||
this.knnClassifier.clearAllClasses();
|
||||
this.classNames = [];
|
||||
this.filteredConfidences = {}; // 重置滤波器状态
|
||||
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
document.getElementById(`class${i}Images`).value = '';
|
||||
document.getElementById(`class${i}Count`).textContent = '0 张图片';
|
||||
document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览
|
||||
}
|
||||
|
||||
document.getElementById('totalSamples').textContent = '0';
|
||||
document.getElementById('tagsList').innerHTML = '等待数据...';
|
||||
document.getElementById('predictions').innerHTML = '等待预测...';
|
||||
|
||||
this.showStatus('dataStatus', 'info', '数据集已清空');
|
||||
}
|
||||
|
||||
// 启动摄像头
|
||||
async startWebcam() {
|
||||
if (this.knnClassifier.getNumClasses() === 0) {
|
||||
this.showStatus('predictionStatus', 'error', '请先训练模型!');
|
||||
return;
|
||||
}
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
video: { facingMode: 'user' },
|
||||
audio: false
|
||||
});
|
||||
|
||||
video.srcObject = stream;
|
||||
this.webcamStream = stream;
|
||||
|
||||
document.getElementById('startWebcamBtn').disabled = true;
|
||||
document.getElementById('stopWebcamBtn').disabled = false;
|
||||
|
||||
// 等待视频加载
|
||||
video.addEventListener('loadeddata', () => {
|
||||
this.isPredicting = true;
|
||||
this.predictLoop();
|
||||
});
|
||||
|
||||
this.showStatus('predictionStatus', 'success', '摄像头已启动');
|
||||
} catch (error) {
|
||||
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 停止摄像头
|
||||
stopWebcam() {
|
||||
if (this.webcamStream) {
|
||||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||||
this.webcamStream = null;
|
||||
}
|
||||
|
||||
this.isPredicting = false;
|
||||
this.filteredConfidences = {}; // 重置滤波器状态
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
video.srcObject = null;
|
||||
|
||||
document.getElementById('startWebcamBtn').disabled = false;
|
||||
document.getElementById('stopWebcamBtn').disabled = true;
|
||||
|
||||
this.showStatus('predictionStatus', 'info', '摄像头已停止');
|
||||
}
|
||||
|
||||
// 预测循环
|
||||
async predictLoop() {
|
||||
if (!this.isPredicting) return;
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
if (video.readyState === 4) {
|
||||
try {
|
||||
// 提取特征
|
||||
const features = await this.extractImageNetTags(video);
|
||||
|
||||
// 使用原始KNN进行预测
|
||||
const k = parseInt(document.getElementById('kValue').value);
|
||||
const prediction = await this.knnClassifier.predictClass(features.logits, k);
|
||||
|
||||
// 应用低通滤波器
|
||||
const smoothedPrediction = this.applyLowPassFilter(prediction);
|
||||
|
||||
// 显示预测结果
|
||||
this.displayPrediction(smoothedPrediction);
|
||||
|
||||
// 显示提取的标签
|
||||
this.displayTopTags(features.topTags);
|
||||
|
||||
// 清理张量
|
||||
features.logits.dispose();
|
||||
} catch (error) {
|
||||
console.error('预测错误:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 继续预测循环
|
||||
requestAnimationFrame(() => this.predictLoop());
|
||||
}
|
||||
|
||||
// 应用低通滤波器到置信度
|
||||
applyLowPassFilter(prediction) {
|
||||
// 获取滤波器系数
|
||||
const alpha = parseFloat(document.getElementById('filterAlpha').value);
|
||||
|
||||
// 初始化滤波状态(如果是第一次)
|
||||
if (Object.keys(this.filteredConfidences).length === 0) {
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
this.filteredConfidences[i] = prediction.confidences[i] || 0;
|
||||
}
|
||||
return {
|
||||
label: prediction.label,
|
||||
confidences: {...this.filteredConfidences}
|
||||
};
|
||||
}
|
||||
|
||||
// 应用指数移动平均(EMA)低通滤波
|
||||
const newConfidences = {};
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
const currentValue = prediction.confidences[i] || 0;
|
||||
const previousValue = this.filteredConfidences[i] || 0;
|
||||
|
||||
// EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1]
|
||||
this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
|
||||
newConfidences[i] = this.filteredConfidences[i];
|
||||
}
|
||||
|
||||
// 归一化确保总和为1
|
||||
let sum = 0;
|
||||
Object.values(newConfidences).forEach(v => sum += v);
|
||||
if (sum > 0) {
|
||||
Object.keys(newConfidences).forEach(key => {
|
||||
newConfidences[key] = newConfidences[key] / sum;
|
||||
});
|
||||
}
|
||||
|
||||
// 找到最高置信度的类别
|
||||
let maxConfidence = 0;
|
||||
let bestLabel = 0;
|
||||
Object.keys(newConfidences).forEach(key => {
|
||||
if (newConfidences[key] > maxConfidence) {
|
||||
maxConfidence = newConfidences[key];
|
||||
bestLabel = parseInt(key);
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
label: bestLabel,
|
||||
confidences: newConfidences
|
||||
};
|
||||
}
|
||||
|
||||
// 显示预测结果
|
||||
displayPrediction(prediction) {
|
||||
const container = document.getElementById('predictions');
|
||||
let html = '';
|
||||
|
||||
// 直接使用滤波后的置信度
|
||||
const confidences = prediction.confidences;
|
||||
const predictedClass = prediction.label;
|
||||
|
||||
// 固定顺序显示(按类别索引)
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
const className = this.classNames[i];
|
||||
const confidence = confidences[i] || 0;
|
||||
const percentage = (confidence * 100).toFixed(1);
|
||||
const isWinner = i === predictedClass;
|
||||
|
||||
// 根据置信度决定颜色等级
|
||||
let barClass = '';
|
||||
if (confidence > 0.7) barClass = 'high';
|
||||
else if (confidence > 0.4) barClass = 'medium';
|
||||
else barClass = 'low';
|
||||
|
||||
// 如果是获胜类别,使用绿色
|
||||
if (isWinner) barClass = 'high';
|
||||
|
||||
html += `
|
||||
<div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
|
||||
<div class="prediction-header">
|
||||
<span class="prediction-label">
|
||||
${className} ${isWinner ? '👑' : ''}
|
||||
</span>
|
||||
<span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
|
||||
${percentage}%
|
||||
</span>
|
||||
</div>
|
||||
<div class="confidence-bar-container">
|
||||
<div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
|
||||
${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
// 从摄像头捕获样本
|
||||
async captureFromWebcam(classIndex) {
|
||||
if (!this.webcamStream) {
|
||||
// 临时启动摄像头
|
||||
const video = document.getElementById('webcam');
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
video: { facingMode: 'user' },
|
||||
audio: false
|
||||
});
|
||||
|
||||
video.srcObject = stream;
|
||||
this.webcamStream = stream;
|
||||
|
||||
// 等待视频加载
|
||||
setTimeout(async () => {
|
||||
await this.addWebcamSample(classIndex);
|
||||
|
||||
// 停止临时摄像头
|
||||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||||
this.webcamStream = null;
|
||||
video.srcObject = null;
|
||||
}, 1000);
|
||||
} catch (error) {
|
||||
this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||||
}
|
||||
} else {
|
||||
await this.addWebcamSample(classIndex);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加摄像头样本
|
||||
async addWebcamSample(classIndex) {
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
if (video.readyState === 4) {
|
||||
try {
|
||||
// 提取特征
|
||||
const features = await this.extractImageNetTags(video);
|
||||
|
||||
// 添加到KNN分类器
|
||||
this.knnClassifier.addExample(features.logits, classIndex);
|
||||
|
||||
// 更新计数
|
||||
const currentCount = this.knnClassifier.getClassExampleCount();
|
||||
const count = currentCount[classIndex] || 0;
|
||||
document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;
|
||||
|
||||
// 清理
|
||||
features.logits.dispose();
|
||||
|
||||
this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
|
||||
} catch (error) {
|
||||
console.error('添加样本失败:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存模型
|
||||
async saveModel() {
|
||||
if (this.knnClassifier.getNumClasses() === 0) {
|
||||
this.showStatus('predictionStatus', 'error', '没有可保存的模型');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取KNN分类器的数据
|
||||
const dataset = this.knnClassifier.getClassifierDataset();
|
||||
const datasetObj = {};
|
||||
|
||||
Object.keys(dataset).forEach(key => {
|
||||
const data = dataset[key].dataSync();
|
||||
datasetObj[key] = Array.from(data);
|
||||
});
|
||||
|
||||
// 保存为JSON
|
||||
const modelData = {
|
||||
dataset: datasetObj,
|
||||
classNames: this.classNames,
|
||||
k: document.getElementById('kValue').value,
|
||||
date: new Date().toISOString()
|
||||
};
|
||||
|
||||
const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = 'knn-model.json';
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
|
||||
this.showStatus('predictionStatus', 'success', '模型已保存');
|
||||
} catch (error) {
|
||||
this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 加载模型
|
||||
async loadModel() {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = '.json';
|
||||
|
||||
input.onchange = async (e) => {
|
||||
try {
|
||||
const file = e.target.files[0];
|
||||
const text = await file.text();
|
||||
const modelData = JSON.parse(text);
|
||||
|
||||
// 清空现有分类器
|
||||
this.knnClassifier.clearAllClasses();
|
||||
|
||||
// 恢复数据集
|
||||
Object.keys(modelData.dataset).forEach(key => {
|
||||
const tensor = tf.tensor(modelData.dataset[key], [modelData.dataset[key].length / 1024, 1024]);
|
||||
this.knnClassifier.setClassifierDataset({ [key]: tensor });
|
||||
});
|
||||
|
||||
this.classNames = modelData.classNames;
|
||||
document.getElementById('kValue').value = modelData.k;
|
||||
document.getElementById('kValueDisplay').textContent = modelData.k;
|
||||
|
||||
this.showStatus('predictionStatus', 'success',
|
||||
`模型加载成功!类别: ${this.classNames.join(', ')}`);
|
||||
} catch (error) {
|
||||
this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
input.click();
|
||||
}
|
||||
|
||||
// 显示状态
|
||||
showStatus(elementId, type, message) {
|
||||
const element = document.getElementById(elementId);
|
||||
|
||||
const classMap = {
|
||||
'success': 'status-success',
|
||||
'error': 'status-error',
|
||||
'info': 'status-info'
|
||||
};
|
||||
|
||||
element.className = `status-message ${classMap[type]}`;
|
||||
element.textContent = message;
|
||||
}
|
||||
}
|
||||
|
||||
// 全局函数:从摄像头捕获
|
||||
function captureFromWebcam(classIndex) {
|
||||
if (window.classifier) {
|
||||
window.classifier.captureFromWebcam(classIndex);
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化应用
|
||||
let classifier;
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
classifier = new KNNImageClassifier();
|
||||
window.classifier = classifier;
|
||||
});
|
561
完善KNN/knn-classifier.html
Normal file
561
完善KNN/knn-classifier.html
Normal file
@ -0,0 +1,561 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>KNN 图像分类器 - TensorFlow.js</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier@latest"></script>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.main-container {
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: white;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
font-size: 2.5em;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.grid-container {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.grid-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
padding: 25px;
|
||||
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
border-bottom: 2px solid #667eea;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.class-input {
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.class-input h3 {
|
||||
color: #555;
|
||||
margin-bottom: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.class-number {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
border-radius: 50%;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 5px;
|
||||
margin-bottom: 10px;
|
||||
font-size: 16px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
input[type="file"] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-label {
|
||||
display: inline-block;
|
||||
padding: 10px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
transition: background 0.3s;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.file-label:hover {
|
||||
background: #5a67d8;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 12px 30px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
font-size: 16px;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
margin: 5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: #5a67d8;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background: #48bb78;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background: #38a169;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background: #f56565;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background: #e53e3e;
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
background: #cbd5e0;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
#webcam-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
max-width: 640px;
|
||||
margin: 20px auto;
|
||||
}
|
||||
|
||||
#webcam {
|
||||
width: 100%;
|
||||
border-radius: 10px;
|
||||
background: #000;
|
||||
}
|
||||
|
||||
.samples-count {
|
||||
display: inline-block;
|
||||
background: #edf2f7;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
font-size: 12px;
|
||||
color: #4a5568;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.image-preview {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
margin-top: 10px;
|
||||
max-height: 150px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.preview-img {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
object-fit: cover;
|
||||
border-radius: 5px;
|
||||
border: 2px solid #e0e0e0;
|
||||
}
|
||||
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 5px;
|
||||
margin: 10px 0;
|
||||
text-align: center;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.status-success {
|
||||
background: #c6f6d5;
|
||||
color: #22543d;
|
||||
border: 1px solid #9ae6b4;
|
||||
}
|
||||
|
||||
.status-error {
|
||||
background: #fed7d7;
|
||||
color: #742a2a;
|
||||
border: 1px solid #fc8181;
|
||||
}
|
||||
|
||||
.status-info {
|
||||
background: #bee3f8;
|
||||
color: #2c5282;
|
||||
border: 1px solid #90cdf4;
|
||||
}
|
||||
|
||||
.button-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin: 20px 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.full-width {
|
||||
grid-column: 1 / -1;
|
||||
}
|
||||
|
||||
.prediction-results {
|
||||
margin-top: 20px;
|
||||
padding: 20px;
|
||||
background: #f7fafc;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.prediction-item {
|
||||
padding: 15px;
|
||||
margin: 10px 0;
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
border-left: 4px solid #667eea;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
|
||||
}
|
||||
|
||||
.prediction-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.prediction-label {
|
||||
font-weight: 600;
|
||||
color: #2d3748;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.prediction-confidence {
|
||||
background: linear-gradient(135deg, #667eea, #764ba2);
|
||||
color: white;
|
||||
padding: 4px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
min-width: 60px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.confidence-bar-container {
|
||||
width: 100%;
|
||||
height: 24px;
|
||||
background: #e2e8f0;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.confidence-bar {
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, #667eea, #764ba2);
|
||||
border-radius: 12px;
|
||||
transition: width 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
position: relative;
|
||||
min-width: 0;
|
||||
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
|
||||
}
|
||||
|
||||
.confidence-bar::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.3), transparent);
|
||||
animation: shimmer 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes shimmer {
|
||||
0% { transform: translateX(-100%); }
|
||||
100% { transform: translateX(100%); }
|
||||
}
|
||||
|
||||
.confidence-bar.high {
|
||||
background: linear-gradient(90deg, #48bb78, #38a169);
|
||||
}
|
||||
|
||||
.confidence-bar.medium {
|
||||
background: linear-gradient(90deg, #ed8936, #dd6b20);
|
||||
}
|
||||
|
||||
.confidence-bar.low {
|
||||
background: linear-gradient(90deg, #f56565, #e53e3e);
|
||||
}
|
||||
|
||||
.confidence-percentage {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
color: white;
|
||||
font-weight: 600;
|
||||
font-size: 12px;
|
||||
text-shadow: 0 1px 2px rgba(0,0,0,0.2);
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.top-tags {
|
||||
margin: 20px 0;
|
||||
padding: 15px;
|
||||
background: #edf2fe;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.tag-item {
|
||||
display: inline-block;
|
||||
background: white;
|
||||
padding: 5px 12px;
|
||||
margin: 5px;
|
||||
border-radius: 15px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #cbd5e0;
|
||||
}
|
||||
|
||||
.tag-weight {
|
||||
color: #667eea;
|
||||
font-weight: bold;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.k-selector {
|
||||
margin: 15px 0;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.k-selector label {
|
||||
display: block;
|
||||
margin-bottom: 10px;
|
||||
color: #555;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.k-value-display {
|
||||
display: inline-block;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
padding: 2px 8px;
|
||||
border-radius: 5px;
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
input[type="range"] {
|
||||
width: 100%;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.model-info {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f0f4f8;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin: 5px 0;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
color: #718096;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: #2d3748;
|
||||
font-weight: 500;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="main-container">
|
||||
<h1>🤖 KNN 图像分类器(基于特征标签)</h1>
|
||||
|
||||
<div class="grid-container">
|
||||
<!-- 数据采集卡片 -->
|
||||
<div class="card">
|
||||
<h2>📸 数据采集</h2>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">1</span> 第一类</h3>
|
||||
<input type="text" id="class1Name" placeholder="输入类别名称(如:猫)" value="类别1">
|
||||
<label class="file-label" for="class1Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class1Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class1Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(0)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class1Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">2</span> 第二类</h3>
|
||||
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
|
||||
<label class="file-label" for="class2Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class2Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class2Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(1)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class2Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">3</span> 第三类(可选)</h3>
|
||||
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="类别3">
|
||||
<label class="file-label" for="class3Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class3Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class3Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(2)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class3Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-success" id="addDataBtn">训练KNN模型</button>
|
||||
<button class="btn btn-danger" id="clearDataBtn">清空数据</button>
|
||||
</div>
|
||||
|
||||
<div id="dataStatus"></div>
|
||||
</div>
|
||||
|
||||
<!-- KNN模型信息卡片 -->
|
||||
<div class="card">
|
||||
<h2>🎯 KNN 模型设置</h2>
|
||||
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
K值(最近邻数量)
|
||||
<span class="k-value-display" id="kValueDisplay">3</span>
|
||||
</label>
|
||||
<input type="range" id="kValue" min="1" max="20" value="3"
|
||||
oninput="document.getElementById('kValueDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">K值越大,预测越保守;K值越小,对局部特征越敏感</small>
|
||||
</div>
|
||||
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
滤波器系数 (α)
|
||||
<span class="k-value-display" id="filterAlphaDisplay">0.3</span>
|
||||
</label>
|
||||
<input type="range" id="filterAlpha" min="0.05" max="1.0" step="0.05" value="0.3"
|
||||
oninput="document.getElementById('filterAlphaDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">低通滤波器系数:值越小输出越平滑(0.1-0.3推荐),值越大响应越快</small>
|
||||
</div>
|
||||
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
距离阈值 (Distance Threshold)
|
||||
<span class="k-value-display" id="distanceThresholdDisplay">0.5</span>
|
||||
</label>
|
||||
<input type="range" id="distanceThreshold" min="0.1" max="2.0" step="0.05" value="0.5"
|
||||
oninput="document.getElementById('distanceThresholdDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">距离阈值:样本与训练数据的最大距离,超过此值判定为"未知/背景"(单品类检测关键参数)</small>
|
||||
</div>
|
||||
|
||||
<div class="top-tags" id="topTags">
|
||||
<h3 style="margin-bottom: 10px;">📊 特征标签提取预览</h3>
|
||||
<div id="tagsList">等待数据...</div>
|
||||
</div>
|
||||
|
||||
<div class="model-info">
|
||||
<h3 style="margin-bottom: 10px;">ℹ️ 模型信息</h3>
|
||||
<div class="info-item">
|
||||
<span class="info-label">预训练模型:</span>
|
||||
<span class="info-value">MobileNet v2</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">特征维度:</span>
|
||||
<span class="info-value">1280维嵌入向量</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">分类器类型:</span>
|
||||
<span class="info-value">K-最近邻 (KNN)</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">总样本数:</span>
|
||||
<span class="info-value" id="totalSamples">0</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 预测卡片 -->
|
||||
<div class="card full-width">
|
||||
<h2>📹 实时预测</h2>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
|
||||
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
|
||||
<button class="btn btn-success" id="saveModelBtn">保存模型</button>
|
||||
<button class="btn btn-primary" id="loadModelBtn">加载模型</button>
|
||||
</div>
|
||||
|
||||
<div id="webcam-container">
|
||||
<video id="webcam" autoplay playsinline muted></video>
|
||||
</div>
|
||||
|
||||
<div class="prediction-results" id="predictionResults">
|
||||
<h3>预测结果</h3>
|
||||
<div id="predictions">等待预测...</div>
|
||||
</div>
|
||||
|
||||
<div id="predictionStatus"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="knn-classifier.js"></script>
|
||||
</body>
|
||||
</html>
|
985
完善KNN/knn-classifier.js
Normal file
985
完善KNN/knn-classifier.js
Normal file
@ -0,0 +1,985 @@
|
||||
// KNN 图像分类器 - 基于MobileNet特征标签
|
||||
class KNNImageClassifier {
|
||||
constructor() {
|
||||
this.mobilenet = null;
|
||||
this.knnClassifier = null;
|
||||
this.classNames = [];
|
||||
this.webcamStream = null;
|
||||
this.isPredicting = false;
|
||||
this.currentCaptureClass = -1;
|
||||
this.imagenetClasses = null;
|
||||
|
||||
// 低通滤波器状态
|
||||
this.filteredConfidences = {};
|
||||
this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑
|
||||
|
||||
// 距离阈值设置
|
||||
this.useDistanceThreshold = true;
|
||||
this.distanceThreshold = 0.5; // 默认距离阈值(归一化后的特征)
|
||||
this.adaptiveThreshold = null; // 自适应阈值
|
||||
|
||||
this.init();
|
||||
}
|
||||
|
||||
async init() {
|
||||
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
|
||||
|
||||
try {
|
||||
// 加载 MobileNet 模型
|
||||
this.mobilenet = await mobilenet.load({
|
||||
version: 2,
|
||||
alpha: 1.0
|
||||
});
|
||||
|
||||
// 创建 KNN 分类器
|
||||
this.knnClassifier = knnClassifier.create();
|
||||
|
||||
// 加载 ImageNet 类别名称
|
||||
await this.loadImageNetClasses();
|
||||
|
||||
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
|
||||
this.setupEventListeners();
|
||||
} catch (error) {
|
||||
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
async loadImageNetClasses() {
|
||||
// ImageNet 前10个类别名称(简化版)
|
||||
this.imagenetClasses = [
|
||||
'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
|
||||
'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
|
||||
];
|
||||
}
|
||||
|
||||
setupEventListeners() {
|
||||
// 文件上传监听
|
||||
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
|
||||
document.getElementById(id).addEventListener('change', (e) => {
|
||||
this.handleImageUpload(e, index);
|
||||
});
|
||||
});
|
||||
|
||||
// 按钮监听
|
||||
document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
|
||||
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
|
||||
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
|
||||
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
|
||||
document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
|
||||
document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
|
||||
}
|
||||
|
||||
handleImageUpload(event, classIndex) {
|
||||
const files = event.target.files;
|
||||
const countElement = document.getElementById(`class${classIndex + 1}Count`);
|
||||
const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
|
||||
|
||||
countElement.textContent = `${files.length} 张图片`;
|
||||
|
||||
// 清空之前的预览
|
||||
previewContainer.innerHTML = '';
|
||||
|
||||
// 添加图片预览
|
||||
Array.from(files).forEach(file => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (e) => {
|
||||
const img = document.createElement('img');
|
||||
img.src = e.target.result;
|
||||
img.className = 'preview-img';
|
||||
previewContainer.appendChild(img);
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// 从图像提取 MobileNet 标签和权重
|
||||
async extractImageNetTags(img) {
|
||||
try {
|
||||
// 获取 MobileNet 的预测(1000个类别的概率)
|
||||
const predictions = await this.mobilenet.classify(img);
|
||||
|
||||
// 获取用于KNN的特征(使用嵌入层获得更好的特征表示)
|
||||
const rawEmbeddings = this.mobilenet.infer(img, true); // true = 使用嵌入层,获取1280维特征
|
||||
|
||||
// L2归一化特征向量(重要:使距离计算更稳定)
|
||||
const embeddings = tf.tidy(() => {
|
||||
const norm = tf.norm(rawEmbeddings);
|
||||
const normalized = tf.div(rawEmbeddings, norm);
|
||||
rawEmbeddings.dispose(); // 清理原始嵌入
|
||||
return normalized;
|
||||
});
|
||||
|
||||
// 获取用于显示的logits(1000个类别)
|
||||
const logits = this.mobilenet.infer(img, false); // false = 获取原始1000维输出
|
||||
|
||||
// 获取前10个最高概率的标签
|
||||
const topK = await this.getTopKTags(logits, 10);
|
||||
|
||||
// 清理logits(只用于显示)
|
||||
logits.dispose();
|
||||
|
||||
return {
|
||||
logits: embeddings, // 使用1280维嵌入特征用于KNN
|
||||
predictions: predictions, // 前3个预测
|
||||
topTags: topK // 前10个标签和权重
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('特征提取失败:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 获取Top-K标签
|
||||
async getTopKTags(logits, k = 10) {
|
||||
const values = await logits.data();
|
||||
const valuesAndIndices = [];
|
||||
|
||||
for (let i = 0; i < values.length; i++) {
|
||||
valuesAndIndices.push({ value: values[i], index: i });
|
||||
}
|
||||
|
||||
valuesAndIndices.sort((a, b) => b.value - a.value);
|
||||
const topkValues = new Float32Array(k);
|
||||
const topkIndices = new Int32Array(k);
|
||||
|
||||
for (let i = 0; i < k; i++) {
|
||||
topkValues[i] = valuesAndIndices[i].value;
|
||||
topkIndices[i] = valuesAndIndices[i].index;
|
||||
}
|
||||
|
||||
const topTags = [];
|
||||
for (let i = 0; i < k; i++) {
|
||||
topTags.push({
|
||||
className: this.imagenetClasses[i] || `class_${topkIndices[i]}`,
|
||||
probability: this.softmax(topkValues)[i],
|
||||
logit: topkValues[i]
|
||||
});
|
||||
}
|
||||
|
||||
return topTags;
|
||||
}
|
||||
|
||||
// Softmax 函数
|
||||
softmax(arr) {
|
||||
const maxLogit = Math.max(...arr);
|
||||
const scores = arr.map(l => Math.exp(l - maxLogit));
|
||||
const sum = scores.reduce((a, b) => a + b);
|
||||
return scores.map(s => s / sum);
|
||||
}
|
||||
|
||||
// 训练 KNN 模型
|
||||
async trainKNN() {
|
||||
const classes = [];
|
||||
const imageFiles = [];
|
||||
|
||||
// 收集所有类别和图片
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
const className = document.getElementById(`class${i}Name`).value.trim();
|
||||
const files = document.getElementById(`class${i}Images`).files;
|
||||
|
||||
if (className && files && files.length > 0) {
|
||||
classes.push(className);
|
||||
imageFiles.push(files);
|
||||
console.log(`类别 ${i}: "${className}" - ${files.length} 张图片`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log('收集到的类别:', classes);
|
||||
console.log('类别数量:', classes.length);
|
||||
|
||||
// 支持单品类检测(One-Class Classification)
|
||||
if (classes.length < 1) {
|
||||
this.showStatus('dataStatus', 'error', '请至少添加一个类别的图片!');
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果只有一个类别,提示用户这是单品类检测模式
|
||||
if (classes.length === 1) {
|
||||
console.log('📍 单品类检测模式:只检测 "' + classes[0] + '",其他都视为背景/未知');
|
||||
}
|
||||
|
||||
this.classNames = classes;
|
||||
this.filteredConfidences = {}; // 重置滤波器状态
|
||||
this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');
|
||||
|
||||
// 清空现有的KNN分类器
|
||||
this.knnClassifier.clearAllClasses();
|
||||
|
||||
let totalProcessed = 0;
|
||||
let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
|
||||
|
||||
// 处理每个类别的图片
|
||||
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
|
||||
const files = imageFiles[classIndex];
|
||||
console.log(`处理类别 ${classes[classIndex]}...`);
|
||||
|
||||
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
|
||||
try {
|
||||
const img = await this.loadImage(files[fileIndex]);
|
||||
|
||||
// 提取特征标签
|
||||
const features = await this.extractImageNetTags(img);
|
||||
|
||||
// 添加到KNN分类器
|
||||
// 使用完整的logits作为特征向量
|
||||
this.knnClassifier.addExample(features.logits, classIndex);
|
||||
|
||||
totalProcessed++;
|
||||
const progress = Math.round((totalProcessed / totalImages) * 100);
|
||||
this.showStatus('dataStatus', 'info',
|
||||
`处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
|
||||
|
||||
// 显示提取的标签
|
||||
if (fileIndex === 0) {
|
||||
this.displayTopTags(features.topTags);
|
||||
}
|
||||
|
||||
// 清理
|
||||
img.remove();
|
||||
features.logits.dispose();
|
||||
} catch (error) {
|
||||
console.error('处理图片失败:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新模型信息
|
||||
document.getElementById('totalSamples').textContent = totalProcessed;
|
||||
|
||||
// 根据类别数量显示不同的消息
|
||||
let statusMessage;
|
||||
if (classes.length === 1) {
|
||||
statusMessage = `单品类检测模型训练完成!将只检测 "${classes[0]}",共 ${totalProcessed} 个样本`;
|
||||
} else {
|
||||
statusMessage = `KNN模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别`;
|
||||
}
|
||||
|
||||
this.showStatus('dataStatus', 'success', statusMessage);
|
||||
|
||||
console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
|
||||
if (classes.length === 1) {
|
||||
console.log('📍 单品类检测模式已启用,将基于距离阈值判断是否为:', classes[0]);
|
||||
// 计算自适应阈值
|
||||
await this.calculateAdaptiveThreshold();
|
||||
}
|
||||
}
|
||||
|
||||
// 计算自适应阈值(基于训练数据的内部距离)
|
||||
async calculateAdaptiveThreshold() {
|
||||
if (this.knnClassifier.getNumClasses() !== 1) return;
|
||||
|
||||
console.log('计算自适应阈值...');
|
||||
|
||||
const dataset = this.knnClassifier.getClassifierDataset();
|
||||
if (!dataset || !dataset[0]) return;
|
||||
|
||||
const trainData = await dataset[0].data();
|
||||
const numSamples = dataset[0].shape[0];
|
||||
const featureDim = dataset[0].shape[1];
|
||||
|
||||
// 计算训练样本之间的平均距离
|
||||
let totalDistance = 0;
|
||||
let count = 0;
|
||||
|
||||
for (let i = 0; i < Math.min(numSamples, 20); i++) { // 限制计算量
|
||||
for (let j = i + 1; j < Math.min(numSamples, 20); j++) {
|
||||
let distance = 0;
|
||||
for (let k = 0; k < featureDim; k++) {
|
||||
const diff = trainData[i * featureDim + k] - trainData[j * featureDim + k];
|
||||
distance += diff * diff;
|
||||
}
|
||||
distance = Math.sqrt(distance);
|
||||
totalDistance += distance;
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
const avgInternalDistance = totalDistance / count;
|
||||
// 自适应阈值设为内部平均距离的1.3-1.5倍(归一化后距离较小)
|
||||
this.adaptiveThreshold = avgInternalDistance * 1.3;
|
||||
|
||||
console.log(`内部平均距离: ${avgInternalDistance.toFixed(2)}`);
|
||||
console.log(`建议自适应阈值: ${this.adaptiveThreshold.toFixed(2)}`);
|
||||
|
||||
// 更新UI显示建议阈值
|
||||
const thresholdInput = document.getElementById('distanceThreshold');
|
||||
const thresholdDisplay = document.getElementById('distanceThresholdDisplay');
|
||||
if (thresholdInput && thresholdDisplay) {
|
||||
thresholdInput.value = this.adaptiveThreshold.toFixed(1);
|
||||
thresholdDisplay.textContent = this.adaptiveThreshold.toFixed(1);
|
||||
}
|
||||
|
||||
this.showStatus('dataStatus', 'info',
|
||||
`自适应阈值已计算: ${this.adaptiveThreshold.toFixed(1)} (基于训练数据内部距离)`);
|
||||
}
|
||||
}
|
||||
|
||||
// 显示提取的标签
|
||||
displayTopTags(tags) {
|
||||
const container = document.getElementById('tagsList');
|
||||
let html = '';
|
||||
|
||||
tags.slice(0, 5).forEach(tag => {
|
||||
html += `
|
||||
<span class="tag-item">
|
||||
${tag.className}
|
||||
<span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
|
||||
</span>
|
||||
`;
|
||||
});
|
||||
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
// 加载图片
|
||||
async loadImage(file) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (e) => {
|
||||
const img = new Image();
|
||||
img.onload = () => resolve(img);
|
||||
img.onerror = reject;
|
||||
img.src = e.target.result;
|
||||
};
|
||||
reader.onerror = reject;
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// 清空数据集
|
||||
clearDataset() {
|
||||
this.knnClassifier.clearAllClasses();
|
||||
this.classNames = [];
|
||||
this.filteredConfidences = {}; // 重置滤波器状态
|
||||
|
||||
console.log('数据集已清空,滤波器状态已重置');
|
||||
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
document.getElementById(`class${i}Images`).value = '';
|
||||
document.getElementById(`class${i}Count`).textContent = '0 张图片';
|
||||
document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览
|
||||
}
|
||||
|
||||
document.getElementById('totalSamples').textContent = '0';
|
||||
document.getElementById('tagsList').innerHTML = '等待数据...';
|
||||
document.getElementById('predictions').innerHTML = '等待预测...';
|
||||
|
||||
this.showStatus('dataStatus', 'info', '数据集已清空');
|
||||
}
|
||||
|
||||
// 启动摄像头
|
||||
async startWebcam() {
|
||||
if (this.knnClassifier.getNumClasses() === 0) {
|
||||
this.showStatus('predictionStatus', 'error', '请先训练模型!');
|
||||
return;
|
||||
}
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
video: { facingMode: 'user' },
|
||||
audio: false
|
||||
});
|
||||
|
||||
video.srcObject = stream;
|
||||
this.webcamStream = stream;
|
||||
|
||||
document.getElementById('startWebcamBtn').disabled = true;
|
||||
document.getElementById('stopWebcamBtn').disabled = false;
|
||||
|
||||
// 等待视频加载
|
||||
video.addEventListener('loadeddata', () => {
|
||||
this.isPredicting = true;
|
||||
this.predictLoop();
|
||||
});
|
||||
|
||||
this.showStatus('predictionStatus', 'success', '摄像头已启动');
|
||||
} catch (error) {
|
||||
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 停止摄像头
|
||||
stopWebcam() {
|
||||
if (this.webcamStream) {
|
||||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||||
this.webcamStream = null;
|
||||
}
|
||||
|
||||
this.isPredicting = false;
|
||||
this.filteredConfidences = {}; // 重置滤波器状态
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
video.srcObject = null;
|
||||
|
||||
document.getElementById('startWebcamBtn').disabled = false;
|
||||
document.getElementById('stopWebcamBtn').disabled = true;
|
||||
|
||||
this.showStatus('predictionStatus', 'info', '摄像头已停止');
|
||||
}
|
||||
|
||||
// 预测循环
|
||||
async predictLoop() {
|
||||
if (!this.isPredicting) return;
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
if (video.readyState === 4) {
|
||||
try {
|
||||
console.log('开始预测,类别数量:', this.classNames.length);
|
||||
|
||||
// 提取特征
|
||||
const features = await this.extractImageNetTags(video);
|
||||
|
||||
// 使用原始KNN进行预测(包含距离信息)
|
||||
const k = parseInt(document.getElementById('kValue').value);
|
||||
const predictionWithDistance = await this.predictWithDistance(features.logits, k);
|
||||
|
||||
console.log('预测结果:', predictionWithDistance);
|
||||
|
||||
// 检查是否为未知类(基于距离阈值)
|
||||
let finalPrediction;
|
||||
|
||||
// 单品类模式特殊处理
|
||||
if (predictionWithDistance.isSingleClass) {
|
||||
// 直接使用 predictWithDistance 返回的结果,它已经处理了阈值判断
|
||||
finalPrediction = {
|
||||
label: predictionWithDistance.label,
|
||||
confidences: predictionWithDistance.confidences,
|
||||
isUnknown: predictionWithDistance.label === -1,
|
||||
minDistance: predictionWithDistance.minDistance,
|
||||
isSingleClass: true
|
||||
};
|
||||
} else {
|
||||
// 多类别模式:直接使用KNN预测结果,不使用距离阈值
|
||||
finalPrediction = {
|
||||
label: predictionWithDistance.label,
|
||||
confidences: predictionWithDistance.confidences,
|
||||
isUnknown: false,
|
||||
minDistance: predictionWithDistance.minDistance,
|
||||
isSingleClass: false
|
||||
};
|
||||
}
|
||||
|
||||
// 应用低通滤波器
|
||||
const smoothedPrediction = this.applyLowPassFilter(finalPrediction);
|
||||
|
||||
// 显示预测结果
|
||||
this.displayPrediction(smoothedPrediction);
|
||||
|
||||
// 显示提取的标签
|
||||
this.displayTopTags(features.topTags);
|
||||
|
||||
// 清理张量
|
||||
features.logits.dispose();
|
||||
} catch (error) {
|
||||
console.error('预测错误:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 继续预测循环
|
||||
requestAnimationFrame(() => this.predictLoop());
|
||||
}
|
||||
|
||||
// 使用距离信息进行预测
|
||||
async predictWithDistance(logits, k) {
|
||||
// 如果没有训练数据,返回空结果
|
||||
if (this.knnClassifier.getNumClasses() === 0) {
|
||||
return {
|
||||
label: -1,
|
||||
confidences: {},
|
||||
minDistance: Infinity,
|
||||
isSingleClass: false
|
||||
};
|
||||
}
|
||||
|
||||
const numClasses = this.knnClassifier.getNumClasses();
|
||||
|
||||
// 单品类检测模式 - 使用实际距离计算
|
||||
if (numClasses === 1) {
|
||||
console.log('单品类检测模式 - 计算实际距离');
|
||||
|
||||
// 获取训练数据
|
||||
const dataset = this.knnClassifier.getClassifierDataset();
|
||||
if (!dataset || !dataset[0]) {
|
||||
return {
|
||||
label: -1,
|
||||
confidences: { 0: 0 },
|
||||
minDistance: Infinity,
|
||||
isSingleClass: true
|
||||
};
|
||||
}
|
||||
|
||||
// 计算输入样本与所有训练样本的欧氏距离
|
||||
const inputData = await logits.data();
|
||||
const trainData = await dataset[0].data();
|
||||
const numSamples = dataset[0].shape[0];
|
||||
const featureDim = dataset[0].shape[1];
|
||||
|
||||
console.log(`输入特征维度: ${inputData.length}, 训练数据维度: ${featureDim}, 样本数: ${numSamples}`);
|
||||
|
||||
// 确保维度匹配
|
||||
if (inputData.length !== featureDim) {
|
||||
console.error(`维度不匹配!输入: ${inputData.length}, 训练: ${featureDim}`);
|
||||
return {
|
||||
label: -1,
|
||||
confidences: { 0: 0 },
|
||||
minDistance: Infinity,
|
||||
isSingleClass: true
|
||||
};
|
||||
}
|
||||
|
||||
let minDistance = Infinity;
|
||||
const distances = [];
|
||||
|
||||
// 计算与每个训练样本的距离
|
||||
for (let i = 0; i < numSamples; i++) {
|
||||
let distance = 0;
|
||||
for (let j = 0; j < featureDim; j++) {
|
||||
const diff = inputData[j] - trainData[i * featureDim + j];
|
||||
distance += diff * diff;
|
||||
}
|
||||
distance = Math.sqrt(distance);
|
||||
distances.push(distance);
|
||||
if (distance < minDistance) {
|
||||
minDistance = distance;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`计算了 ${distances.length} 个距离,最小距离: ${minDistance.toFixed(2)}`);
|
||||
console.log(`前5个距离: ${distances.slice(0, 5).map(d => d.toFixed(2)).join(', ')}`);
|
||||
|
||||
// 获取K个最近邻的平均距离
|
||||
distances.sort((a, b) => a - b);
|
||||
const kNearest = distances.slice(0, Math.min(k, distances.length));
|
||||
const avgDistance = kNearest.reduce((sum, d) => sum + d, 0) / kNearest.length;
|
||||
|
||||
// 从UI获取距离阈值
|
||||
const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '15.0');
|
||||
|
||||
// 基于距离阈值判断是否属于该类
|
||||
const belongsToClass = avgDistance <= threshold;
|
||||
|
||||
// 二值化置信度:在阈值内100%,超出阈值0%
|
||||
const confidence = belongsToClass ? 1.0 : 0;
|
||||
|
||||
console.log(`单品类预测 - 平均距离: ${avgDistance.toFixed(2)}, 阈值: ${threshold}, 属于类别: ${belongsToClass}, 置信度: ${confidence.toFixed(3)}`);
|
||||
|
||||
return {
|
||||
label: belongsToClass ? 0 : -1,
|
||||
confidences: { 0: confidence },
|
||||
minDistance: avgDistance,
|
||||
isSingleClass: true
|
||||
};
|
||||
}
|
||||
|
||||
// 多品类模式:使用KNN分类器预测
|
||||
try {
|
||||
const prediction = await this.knnClassifier.predictClass(logits, k);
|
||||
|
||||
console.log('多品类预测结果:', prediction);
|
||||
console.log('预测标签:', prediction.label);
|
||||
console.log('置信度:', prediction.confidences);
|
||||
|
||||
// 确保confidences存在且格式正确
|
||||
let confidences = {};
|
||||
if (prediction.confidences) {
|
||||
// 检查是否是对象格式
|
||||
if (typeof prediction.confidences === 'object') {
|
||||
confidences = prediction.confidences;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果confidences为空,手动计算
|
||||
if (Object.keys(confidences).length === 0) {
|
||||
console.warn('置信度为空,使用默认值');
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
confidences[i] = i === prediction.label ? 1.0 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
// 计算实际距离(可选)
|
||||
let minDistance = 0.5; // 默认距离
|
||||
|
||||
return {
|
||||
label: prediction.label,
|
||||
confidences: confidences,
|
||||
minDistance: minDistance,
|
||||
isSingleClass: false
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('预测错误:', error);
|
||||
return {
|
||||
label: -1,
|
||||
confidences: {},
|
||||
minDistance: Infinity,
|
||||
isSingleClass: false
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 应用低通滤波器到置信度
|
||||
applyLowPassFilter(prediction) {
|
||||
// 获取滤波器系数
|
||||
const alpha = parseFloat(document.getElementById('filterAlpha').value);
|
||||
|
||||
// 保留原始的特殊字段
|
||||
const isSingleClass = prediction.isSingleClass || false;
|
||||
const minDistance = prediction.minDistance;
|
||||
const isUnknown = prediction.isUnknown;
|
||||
|
||||
// 初始化滤波状态(如果是第一次)
|
||||
if (Object.keys(this.filteredConfidences).length === 0) {
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
this.filteredConfidences[i] = prediction.confidences[i] || 0;
|
||||
}
|
||||
return {
|
||||
label: prediction.label,
|
||||
confidences: {...this.filteredConfidences},
|
||||
isSingleClass: isSingleClass,
|
||||
minDistance: minDistance,
|
||||
isUnknown: isUnknown
|
||||
};
|
||||
}
|
||||
|
||||
// 单品类模式下,使用二值化输出,不应用滤波
|
||||
if (isSingleClass) {
|
||||
// 获取距离阈值
|
||||
const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5');
|
||||
|
||||
// 二值化判断:在阈值内为1,超出为0
|
||||
const inThreshold = minDistance <= threshold;
|
||||
const confidence = inThreshold ? 1.0 : 0;
|
||||
|
||||
// 直接更新,不滤波
|
||||
this.filteredConfidences[0] = confidence;
|
||||
|
||||
return {
|
||||
label: inThreshold ? 0 : -1,
|
||||
confidences: { 0: confidence },
|
||||
isSingleClass: isSingleClass,
|
||||
minDistance: minDistance,
|
||||
isUnknown: !inThreshold
|
||||
};
|
||||
}
|
||||
|
||||
// 应用指数移动平均(EMA)低通滤波
|
||||
const newConfidences = {};
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
const currentValue = prediction.confidences[i] || 0;
|
||||
const previousValue = this.filteredConfidences[i] || 0;
|
||||
|
||||
// EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1]
|
||||
this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
|
||||
newConfidences[i] = this.filteredConfidences[i];
|
||||
}
|
||||
|
||||
// 归一化确保总和为1
|
||||
let sum = 0;
|
||||
Object.values(newConfidences).forEach(v => sum += v);
|
||||
if (sum > 0) {
|
||||
Object.keys(newConfidences).forEach(key => {
|
||||
newConfidences[key] = newConfidences[key] / sum;
|
||||
});
|
||||
}
|
||||
|
||||
// 找到最高置信度的类别
|
||||
let maxConfidence = 0;
|
||||
let bestLabel = 0;
|
||||
Object.keys(newConfidences).forEach(key => {
|
||||
if (newConfidences[key] > maxConfidence) {
|
||||
maxConfidence = newConfidences[key];
|
||||
bestLabel = parseInt(key);
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
label: bestLabel,
|
||||
confidences: newConfidences,
|
||||
isSingleClass: isSingleClass,
|
||||
minDistance: minDistance,
|
||||
isUnknown: isUnknown
|
||||
};
|
||||
}
|
||||
|
||||
// 显示预测结果
|
||||
displayPrediction(prediction) {
|
||||
const container = document.getElementById('predictions');
|
||||
let html = '';
|
||||
|
||||
// 单品类模式特殊处理
|
||||
if (this.classNames.length === 1) {
|
||||
const className = this.classNames[0];
|
||||
const confidence = prediction.confidences[0] || 0;
|
||||
const percentage = (confidence * 100).toFixed(1);
|
||||
const isDetected = prediction.label === 0; // 是否检测到该类
|
||||
|
||||
// 获取距离阈值
|
||||
const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5');
|
||||
const distance = prediction.minDistance || 0;
|
||||
|
||||
// 显示单品类检测结果(二值化显示)
|
||||
html = `
|
||||
<div class="prediction-item" style="${isDetected ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : 'border-left-color: #cbd5e0;'}">
|
||||
<div class="prediction-header">
|
||||
<span class="prediction-label">
|
||||
${className} ${isDetected ? '✓ 检测到' : '✗ 未检测到'}
|
||||
</span>
|
||||
<span class="prediction-confidence" style="${isDetected ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : 'background: #cbd5e0;'}">
|
||||
${isDetected ? '100%' : '0%'}
|
||||
</span>
|
||||
</div>
|
||||
<div class="confidence-bar-container">
|
||||
<div class="confidence-bar ${isDetected ? 'high' : 'low'}" style="width: ${isDetected ? '100' : '0'}%;">
|
||||
${isDetected ? `<span class="confidence-percentage">100%</span>` : ''}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div style="margin-top: 10px; padding: 10px; background: #f8f9fa; border-radius: 5px; font-size: 12px; color: #666;">
|
||||
<strong>距离:</strong> ${distance.toFixed(2)} | <strong>阈值:</strong> ${threshold.toFixed(2)}
|
||||
<span style="margin-left: 10px; color: ${distance <= threshold ? '#48bb78' : '#f56565'};">
|
||||
${distance <= threshold ? '✓ 在阈值范围内' : '✗ 超出阈值范围'}
|
||||
</span>
|
||||
</div>
|
||||
`;
|
||||
|
||||
container.innerHTML = html;
|
||||
return;
|
||||
}
|
||||
|
||||
// 多品类模式:直接使用滤波后的置信度
|
||||
const confidences = prediction.confidences;
|
||||
const predictedClass = prediction.label;
|
||||
|
||||
// 固定顺序显示(按类别索引)
|
||||
for (let i = 0; i < this.classNames.length; i++) {
|
||||
const className = this.classNames[i];
|
||||
const confidence = confidences[i] || 0;
|
||||
const percentage = (confidence * 100).toFixed(1);
|
||||
const isWinner = i === predictedClass;
|
||||
|
||||
// 根据置信度决定颜色等级
|
||||
let barClass = '';
|
||||
if (confidence > 0.7) barClass = 'high';
|
||||
else if (confidence > 0.4) barClass = 'medium';
|
||||
else barClass = 'low';
|
||||
|
||||
// 如果是获胜类别,使用绿色
|
||||
if (isWinner) barClass = 'high';
|
||||
|
||||
html += `
|
||||
<div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
|
||||
<div class="prediction-header">
|
||||
<span class="prediction-label">
|
||||
${className} ${isWinner ? '👑' : ''}
|
||||
</span>
|
||||
<span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
|
||||
${percentage}%
|
||||
</span>
|
||||
</div>
|
||||
<div class="confidence-bar-container">
|
||||
<div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
|
||||
${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
// 从摄像头捕获样本
|
||||
async captureFromWebcam(classIndex) {
|
||||
if (!this.webcamStream) {
|
||||
// 临时启动摄像头
|
||||
const video = document.getElementById('webcam');
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
video: { facingMode: 'user' },
|
||||
audio: false
|
||||
});
|
||||
|
||||
video.srcObject = stream;
|
||||
this.webcamStream = stream;
|
||||
|
||||
// 等待视频加载
|
||||
setTimeout(async () => {
|
||||
await this.addWebcamSample(classIndex);
|
||||
|
||||
// 停止临时摄像头
|
||||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||||
this.webcamStream = null;
|
||||
video.srcObject = null;
|
||||
}, 1000);
|
||||
} catch (error) {
|
||||
this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||||
}
|
||||
} else {
|
||||
await this.addWebcamSample(classIndex);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加摄像头样本
|
||||
async addWebcamSample(classIndex) {
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
if (video.readyState === 4) {
|
||||
try {
|
||||
// 提取特征
|
||||
const features = await this.extractImageNetTags(video);
|
||||
|
||||
// 添加到KNN分类器
|
||||
this.knnClassifier.addExample(features.logits, classIndex);
|
||||
|
||||
// 更新计数
|
||||
const currentCount = this.knnClassifier.getClassExampleCount();
|
||||
const count = currentCount[classIndex] || 0;
|
||||
document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;
|
||||
|
||||
// 清理
|
||||
features.logits.dispose();
|
||||
|
||||
this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
|
||||
} catch (error) {
|
||||
console.error('添加样本失败:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存模型
|
||||
async saveModel() {
|
||||
if (this.knnClassifier.getNumClasses() === 0) {
|
||||
this.showStatus('predictionStatus', 'error', '没有可保存的模型');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取KNN分类器的数据
|
||||
const dataset = this.knnClassifier.getClassifierDataset();
|
||||
const datasetObj = {};
|
||||
|
||||
Object.keys(dataset).forEach(key => {
|
||||
const data = dataset[key].dataSync();
|
||||
datasetObj[key] = Array.from(data);
|
||||
});
|
||||
|
||||
// 获取特征维度
|
||||
let featureDim = 1280; // 默认值
|
||||
const firstKey = Object.keys(dataset)[0];
|
||||
if (firstKey && dataset[firstKey]) {
|
||||
featureDim = dataset[firstKey].shape[1];
|
||||
}
|
||||
|
||||
// 保存为JSON
|
||||
const modelData = {
|
||||
dataset: datasetObj,
|
||||
classNames: this.classNames,
|
||||
k: document.getElementById('kValue').value,
|
||||
featureDim: featureDim, // 保存特征维度
|
||||
date: new Date().toISOString()
|
||||
};
|
||||
|
||||
const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = 'knn-model.json';
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
|
||||
this.showStatus('predictionStatus', 'success', '模型已保存');
|
||||
} catch (error) {
|
||||
this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 加载模型
|
||||
async loadModel() {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = '.json';
|
||||
|
||||
input.onchange = async (e) => {
|
||||
try {
|
||||
const file = e.target.files[0];
|
||||
const text = await file.text();
|
||||
const modelData = JSON.parse(text);
|
||||
|
||||
// 清空现有分类器
|
||||
this.knnClassifier.clearAllClasses();
|
||||
|
||||
// 恢复数据集
|
||||
Object.keys(modelData.dataset).forEach(key => {
|
||||
const data = modelData.dataset[key];
|
||||
|
||||
// 自动检测特征维度(兼容旧模型)
|
||||
let featureDim = modelData.featureDim;
|
||||
if (!featureDim) {
|
||||
// 尝试常见的维度
|
||||
const possibleDims = [1280, 1024, 1000];
|
||||
for (const dim of possibleDims) {
|
||||
if (data.length % dim === 0) {
|
||||
featureDim = dim;
|
||||
console.warn(`自动检测到特征维度: ${dim}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!featureDim) {
|
||||
console.error(`无法确定特征维度,数据长度: ${data.length}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const numSamples = data.length / featureDim;
|
||||
console.log(`加载类别 ${key}:${numSamples} 个样本,${featureDim} 维特征`);
|
||||
|
||||
const tensor = tf.tensor(data, [numSamples, featureDim]);
|
||||
this.knnClassifier.setClassifierDataset({ [key]: tensor });
|
||||
});
|
||||
|
||||
this.classNames = modelData.classNames;
|
||||
document.getElementById('kValue').value = modelData.k;
|
||||
document.getElementById('kValueDisplay').textContent = modelData.k;
|
||||
|
||||
this.showStatus('predictionStatus', 'success',
|
||||
`模型加载成功!类别: ${this.classNames.join(', ')}`);
|
||||
} catch (error) {
|
||||
this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
input.click();
|
||||
}
|
||||
|
||||
// 显示状态
|
||||
showStatus(elementId, type, message) {
|
||||
const element = document.getElementById(elementId);
|
||||
|
||||
const classMap = {
|
||||
'success': 'status-success',
|
||||
'error': 'status-error',
|
||||
'info': 'status-info'
|
||||
};
|
||||
|
||||
element.className = `status-message ${classMap[type]}`;
|
||||
element.textContent = message;
|
||||
}
|
||||
}
|
||||
|
||||
// 全局函数:从摄像头捕获
|
||||
function captureFromWebcam(classIndex) {
|
||||
if (window.classifier) {
|
||||
window.classifier.captureFromWebcam(classIndex);
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化应用
|
||||
let classifier;
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
classifier = new KNNImageClassifier();
|
||||
window.classifier = classifier;
|
||||
});
|
394
随机森林/decision-tree.js
Normal file
394
随机森林/decision-tree.js
Normal file
@ -0,0 +1,394 @@
|
||||
var dt = (function () {
|
||||
|
||||
/**
|
||||
* Creates an instance of DecisionTree
|
||||
*
|
||||
* @constructor
|
||||
* @param builder - contains training set and
|
||||
* some configuration parameters
|
||||
*/
|
||||
function DecisionTree(builder) {
|
||||
this.root = buildDecisionTree({
|
||||
trainingSet: builder.trainingSet,
|
||||
ignoredAttributes: arrayToHashSet(builder.ignoredAttributes),
|
||||
categoryAttr: builder.categoryAttr || 'category',
|
||||
minItemsCount: builder.minItemsCount || 1,
|
||||
entropyThrehold: builder.entropyThrehold || 0.01,
|
||||
maxTreeDepth: builder.maxTreeDepth || 70
|
||||
});
|
||||
}
|
||||
|
||||
DecisionTree.prototype.predict = function (item) {
|
||||
return predict(this.root, item);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of RandomForest
|
||||
* with specific number of trees
|
||||
*
|
||||
* @constructor
|
||||
* @param builder - contains training set and some
|
||||
* configuration parameters for
|
||||
* building decision trees
|
||||
*/
|
||||
function RandomForest(builder, treesNumber) {
|
||||
this.trees = buildRandomForest(builder, treesNumber);
|
||||
}
|
||||
|
||||
RandomForest.prototype.predict = function (item) {
|
||||
return predictRandomForest(this.trees, item);
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforming array to object with such attributes
|
||||
* as elements of array (afterwards it can be used as HashSet)
|
||||
*/
|
||||
function arrayToHashSet(array) {
|
||||
var hashSet = {};
|
||||
if (array) {
|
||||
for(var i in array) {
|
||||
var attr = array[i];
|
||||
hashSet[attr] = true;
|
||||
}
|
||||
}
|
||||
return hashSet;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculating how many objects have the same
|
||||
* values of specific attribute.
|
||||
*
|
||||
* @param items - array of objects
|
||||
*
|
||||
* @param attr - variable with name of attribute,
|
||||
* which embedded in each object
|
||||
*/
|
||||
function countUniqueValues(items, attr) {
|
||||
var counter = {};
|
||||
|
||||
// detecting different values of attribute
|
||||
for (var i = items.length - 1; i >= 0; i--) {
|
||||
// items[i][attr] - value of attribute
|
||||
counter[items[i][attr]] = 0;
|
||||
}
|
||||
|
||||
// counting number of occurrences of each of values
|
||||
// of attribute
|
||||
for (var i = items.length - 1; i >= 0; i--) {
|
||||
counter[items[i][attr]] += 1;
|
||||
}
|
||||
|
||||
return counter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculating entropy of array of objects
|
||||
* by specific attribute.
|
||||
*
|
||||
* @param items - array of objects
|
||||
*
|
||||
* @param attr - variable with name of attribute,
|
||||
* which embedded in each object
|
||||
*/
|
||||
function entropy(items, attr) {
|
||||
// counting number of occurrences of each of values
|
||||
// of attribute
|
||||
var counter = countUniqueValues(items, attr);
|
||||
|
||||
var entropy = 0;
|
||||
var p;
|
||||
for (var i in counter) {
|
||||
p = counter[i] / items.length;
|
||||
entropy += -p * Math.log(p);
|
||||
}
|
||||
|
||||
return entropy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Splitting array of objects by value of specific attribute,
|
||||
* using specific predicate and pivot.
|
||||
*
|
||||
* Items which matched by predicate will be copied to
|
||||
* the new array called 'match', and the rest of the items
|
||||
* will be copied to array with name 'notMatch'
|
||||
*
|
||||
* @param items - array of objects
|
||||
*
|
||||
* @param attr - variable with name of attribute,
|
||||
* which embedded in each object
|
||||
*
|
||||
* @param predicate - function(x, y)
|
||||
* which returns 'true' or 'false'
|
||||
*
|
||||
* @param pivot - used as the second argument when
|
||||
* calling predicate function:
|
||||
* e.g. predicate(item[attr], pivot)
|
||||
*/
|
||||
function split(items, attr, predicate, pivot) {
|
||||
var match = [];
|
||||
var notMatch = [];
|
||||
|
||||
var item,
|
||||
attrValue;
|
||||
|
||||
for (var i = items.length - 1; i >= 0; i--) {
|
||||
item = items[i];
|
||||
attrValue = item[attr];
|
||||
|
||||
if (predicate(attrValue, pivot)) {
|
||||
match.push(item);
|
||||
} else {
|
||||
notMatch.push(item);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
match: match,
|
||||
notMatch: notMatch
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Finding value of specific attribute which is most frequent
|
||||
* in given array of objects.
|
||||
*
|
||||
* @param items - array of objects
|
||||
*
|
||||
* @param attr - variable with name of attribute,
|
||||
* which embedded in each object
|
||||
*/
|
||||
function mostFrequentValue(items, attr) {
|
||||
// counting number of occurrences of each of values
|
||||
// of attribute
|
||||
var counter = countUniqueValues(items, attr);
|
||||
|
||||
var mostFrequentCount = 0;
|
||||
var mostFrequentValue;
|
||||
|
||||
for (var value in counter) {
|
||||
if (counter[value] > mostFrequentCount) {
|
||||
mostFrequentCount = counter[value];
|
||||
mostFrequentValue = value;
|
||||
}
|
||||
};
|
||||
|
||||
return mostFrequentValue;
|
||||
}
|
||||
|
||||
var predicates = {
|
||||
'==': function (a, b) { return a == b },
|
||||
'>=': function (a, b) { return a >= b }
|
||||
};
|
||||
|
||||
/**
|
||||
* Function for building decision tree
|
||||
*/
|
||||
function buildDecisionTree(builder) {
|
||||
|
||||
var trainingSet = builder.trainingSet;
|
||||
var minItemsCount = builder.minItemsCount;
|
||||
var categoryAttr = builder.categoryAttr;
|
||||
var entropyThrehold = builder.entropyThrehold;
|
||||
var maxTreeDepth = builder.maxTreeDepth;
|
||||
var ignoredAttributes = builder.ignoredAttributes;
|
||||
|
||||
if ((maxTreeDepth == 0) || (trainingSet.length <= minItemsCount)) {
|
||||
// restriction by maximal depth of tree
|
||||
// or size of training set is to small
|
||||
// so we have to terminate process of building tree
|
||||
return {
|
||||
category: mostFrequentValue(trainingSet, categoryAttr)
|
||||
};
|
||||
}
|
||||
|
||||
var initialEntropy = entropy(trainingSet, categoryAttr);
|
||||
|
||||
if (initialEntropy <= entropyThrehold) {
|
||||
// entropy of training set too small
|
||||
// (it means that training set is almost homogeneous),
|
||||
// so we have to terminate process of building tree
|
||||
return {
|
||||
category: mostFrequentValue(trainingSet, categoryAttr)
|
||||
};
|
||||
}
|
||||
|
||||
// used as hash-set for avoiding the checking of split by rules
|
||||
// with the same 'attribute-predicate-pivot' more than once
|
||||
var alreadyChecked = {};
|
||||
|
||||
// this variable expected to contain rule, which splits training set
|
||||
// into subsets with smaller values of entropy (produces informational gain)
|
||||
var bestSplit = {gain: 0};
|
||||
|
||||
for (var i = trainingSet.length - 1; i >= 0; i--) {
|
||||
var item = trainingSet[i];
|
||||
|
||||
// iterating over all attributes of item
|
||||
for (var attr in item) {
|
||||
if ((attr == categoryAttr) || ignoredAttributes[attr]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// let the value of current attribute be the pivot
|
||||
var pivot = item[attr];
|
||||
|
||||
// pick the predicate
|
||||
// depending on the type of the attribute value
|
||||
var predicateName;
|
||||
if (typeof pivot == 'number') {
|
||||
predicateName = '>=';
|
||||
} else {
|
||||
// there is no sense to compare non-numeric attributes
|
||||
// so we will check only equality of such attributes
|
||||
predicateName = '==';
|
||||
}
|
||||
|
||||
var attrPredPivot = attr + predicateName + pivot;
|
||||
if (alreadyChecked[attrPredPivot]) {
|
||||
// skip such pairs of 'attribute-predicate-pivot',
|
||||
// which been already checked
|
||||
continue;
|
||||
}
|
||||
alreadyChecked[attrPredPivot] = true;
|
||||
|
||||
var predicate = predicates[predicateName];
|
||||
|
||||
// splitting training set by given 'attribute-predicate-value'
|
||||
var currSplit = split(trainingSet, attr, predicate, pivot);
|
||||
|
||||
// calculating entropy of subsets
|
||||
var matchEntropy = entropy(currSplit.match, categoryAttr);
|
||||
var notMatchEntropy = entropy(currSplit.notMatch, categoryAttr);
|
||||
|
||||
// calculating informational gain
|
||||
var newEntropy = 0;
|
||||
newEntropy += matchEntropy * currSplit.match.length;
|
||||
newEntropy += notMatchEntropy * currSplit.notMatch.length;
|
||||
newEntropy /= trainingSet.length;
|
||||
var currGain = initialEntropy - newEntropy;
|
||||
|
||||
if (currGain > bestSplit.gain) {
|
||||
// remember pairs 'attribute-predicate-value'
|
||||
// which provides informational gain
|
||||
bestSplit = currSplit;
|
||||
bestSplit.predicateName = predicateName;
|
||||
bestSplit.predicate = predicate;
|
||||
bestSplit.attribute = attr;
|
||||
bestSplit.pivot = pivot;
|
||||
bestSplit.gain = currGain;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!bestSplit.gain) {
|
||||
// can't find optimal split
|
||||
return { category: mostFrequentValue(trainingSet, categoryAttr) };
|
||||
}
|
||||
|
||||
// building subtrees
|
||||
|
||||
builder.maxTreeDepth = maxTreeDepth - 1;
|
||||
|
||||
builder.trainingSet = bestSplit.match;
|
||||
var matchSubTree = buildDecisionTree(builder);
|
||||
|
||||
builder.trainingSet = bestSplit.notMatch;
|
||||
var notMatchSubTree = buildDecisionTree(builder);
|
||||
|
||||
return {
|
||||
attribute: bestSplit.attribute,
|
||||
predicate: bestSplit.predicate,
|
||||
predicateName: bestSplit.predicateName,
|
||||
pivot: bestSplit.pivot,
|
||||
match: matchSubTree,
|
||||
notMatch: notMatchSubTree,
|
||||
matchedCount: bestSplit.match.length,
|
||||
notMatchedCount: bestSplit.notMatch.length
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Classifying item, using decision tree
|
||||
*/
|
||||
function predict(tree, item) {
|
||||
var attr,
|
||||
value,
|
||||
predicate,
|
||||
pivot;
|
||||
|
||||
// Traversing tree from the root to leaf
|
||||
while(true) {
|
||||
|
||||
if (tree.category) {
|
||||
// only leafs contains predicted category
|
||||
return tree.category;
|
||||
}
|
||||
|
||||
attr = tree.attribute;
|
||||
value = item[attr];
|
||||
|
||||
predicate = tree.predicate;
|
||||
pivot = tree.pivot;
|
||||
|
||||
// move to one of subtrees
|
||||
if (predicate(value, pivot)) {
|
||||
tree = tree.match;
|
||||
} else {
|
||||
tree = tree.notMatch;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Building array of decision trees
|
||||
*/
|
||||
function buildRandomForest(builder, treesNumber) {
|
||||
var items = builder.trainingSet;
|
||||
|
||||
// creating training sets for each tree
|
||||
var trainingSets = [];
|
||||
for (var t = 0; t < treesNumber; t++) {
|
||||
trainingSets[t] = [];
|
||||
}
|
||||
for (var i = items.length - 1; i >= 0 ; i--) {
|
||||
// assigning items to training sets of each tree
|
||||
// using 'round-robin' strategy
|
||||
var correspondingTree = i % treesNumber;
|
||||
trainingSets[correspondingTree].push(items[i]);
|
||||
}
|
||||
|
||||
// building decision trees
|
||||
var forest = [];
|
||||
for (var t = 0; t < treesNumber; t++) {
|
||||
builder.trainingSet = trainingSets[t];
|
||||
|
||||
var tree = new DecisionTree(builder);
|
||||
forest.push(tree);
|
||||
}
|
||||
return forest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Each of decision tree classifying item
|
||||
* ('voting' that item corresponds to some class).
|
||||
*
|
||||
* This function returns hash, which contains
|
||||
* all classifying results, and number of votes
|
||||
* which were given for each of classifying results
|
||||
*/
|
||||
function predictRandomForest(forest, item) {
|
||||
var result = {};
|
||||
for (var i in forest) {
|
||||
var tree = forest[i];
|
||||
var prediction = tree.predict(item);
|
||||
result[prediction] = result[prediction] ? result[prediction] + 1 : 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
var exports = {};
|
||||
exports.DecisionTree = DecisionTree;
|
||||
exports.RandomForest = RandomForest;
|
||||
return exports;
|
||||
})();
|
542
随机森林/rf-classifier.html
Normal file
542
随机森林/rf-classifier.html
Normal file
@ -0,0 +1,542 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>图像分类器 - TensorFlow.js & decision-tree.js</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
|
||||
<!-- 引入 decision-tree.js -->
|
||||
<script src="decision-tree.js"></script>
|
||||
<style>
|
||||
/* 与之前提供的样式相同...为了简洁,省略 */
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.main-container {
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: white;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
font-size: 2.5em;
|
||||
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.grid-container {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.grid-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
padding: 25px;
|
||||
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
border-bottom: 2px solid #667eea;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.class-input {
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.class-input h3 {
|
||||
color: #555;
|
||||
margin-bottom: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.class-number {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
border-radius: 50%;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 5px;
|
||||
margin-bottom: 10px;
|
||||
font-size: 16px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
input[type="file"] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-label {
|
||||
display: inline-block;
|
||||
padding: 10px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
transition: background 0.3s;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.file-label:hover {
|
||||
background: #5a67d8;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 12px 30px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
font-size: 16px;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
margin: 5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: #5a67d8;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background: #48bb78;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background: #38a169;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background: #f56565;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background: #e53e3e;
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
background: #cbd5e0;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
#webcam-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
max-width: 640px;
|
||||
margin: 20px auto;
|
||||
}
|
||||
|
||||
#webcam {
|
||||
width: 100%;
|
||||
border-radius: 10px;
|
||||
background: #000;
|
||||
}
|
||||
|
||||
.samples-count {
|
||||
display: inline-block;
|
||||
background: #edf2f7;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
font-size: 12px;
|
||||
color: #4a5568;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.image-preview {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
margin-top: 10px;
|
||||
max-height: 150px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.preview-img {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
object-fit: cover;
|
||||
border-radius: 5px;
|
||||
border: 2px solid #e0e0e0;
|
||||
}
|
||||
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 5px;
|
||||
margin: 10px 0;
|
||||
text-align: center;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.status-success {
|
||||
background: #c6f6d5;
|
||||
color: #22543d;
|
||||
border: 1px solid #9ae6b4;
|
||||
}
|
||||
|
||||
.status-error {
|
||||
background: #fed7d7;
|
||||
color: #742a2a;
|
||||
border: 1px solid #fc8181;
|
||||
}
|
||||
|
||||
.status-info {
|
||||
background: #bee3f8;
|
||||
color: #2c5282;
|
||||
border: 1px solid #90cdf4;
|
||||
}
|
||||
|
||||
.button-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin: 20px 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.full-width {
|
||||
grid-column: 1 / -1;
|
||||
}
|
||||
|
||||
.prediction-results {
|
||||
margin-top: 20px;
|
||||
padding: 20px;
|
||||
background: #f7fafc;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.prediction-item {
|
||||
padding: 15px;
|
||||
margin: 10px 0;
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
border-left: 4px solid #667eea;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.prediction-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.prediction-label {
|
||||
font-weight: 600;
|
||||
color: #2d3748;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.prediction-confidence {
|
||||
background: linear-gradient(135deg, #667eea, #764ba2);
|
||||
color: white;
|
||||
padding: 4px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
min-width: 60px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.confidence-bar-container {
|
||||
width: 100%;
|
||||
height: 24px;
|
||||
background: #e2e8f0;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.confidence-bar {
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, #667eea, #764ba2);
|
||||
border-radius: 12px;
|
||||
transition: width 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
position: relative;
|
||||
min-width: 0;
|
||||
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
|
||||
}
|
||||
|
||||
.confidence-bar::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.3), transparent);
|
||||
animation: shimmer 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes shimmer {
|
||||
0% {
|
||||
transform: translateX(-100%);
|
||||
}
|
||||
|
||||
100% {
|
||||
transform: translateX(100%);
|
||||
}
|
||||
}
|
||||
|
||||
.confidence-bar.high {
|
||||
background: linear-gradient(90deg, #48bb78, #38a169);
|
||||
}
|
||||
|
||||
.confidence-bar.medium {
|
||||
background: linear-gradient(90deg, #ed8936, #dd6b20);
|
||||
}
|
||||
|
||||
.confidence-bar.low {
|
||||
background: linear-gradient(90deg, #f56565, #e53e3e);
|
||||
}
|
||||
|
||||
.confidence-percentage {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
color: white;
|
||||
font-weight: 600;
|
||||
font-size: 12px;
|
||||
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.top-tags {
|
||||
margin: 20px 0;
|
||||
padding: 15px;
|
||||
background: #edf2fe;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.tag-item {
|
||||
display: inline-block;
|
||||
background: white;
|
||||
padding: 5px 12px;
|
||||
margin: 5px;
|
||||
border-radius: 15px;
|
||||
font-size: 14px;
|
||||
border: 1px solid #cbd5e0;
|
||||
}
|
||||
|
||||
.tag-weight {
|
||||
color: #667eea;
|
||||
font-weight: bold;
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.k-selector {
|
||||
margin: 15px 0;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.k-selector label {
|
||||
display: block;
|
||||
margin-bottom: 10px;
|
||||
color: #555;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.k-value-display {
|
||||
display: inline-block;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
padding: 2px 8px;
|
||||
border-radius: 5px;
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
input[type="range"] {
|
||||
width: 100%;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.model-info {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f0f4f8;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin: 5px 0;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
color: #718096;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: #2d3748;
|
||||
font-weight: 500;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="main-container">
|
||||
<h1>🤖 图像分类器 - 随机森林</h1>
|
||||
|
||||
<div class="grid-container">
|
||||
<!-- 数据采集卡片 -->
|
||||
<div class="card">
|
||||
<h2>📸 数据采集</h2>
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">1</span> 第一类</h3>
|
||||
<input type="text" id="class1Name" placeholder="输入类别名称(如:猫)" value="类别1">
|
||||
<label class="file-label" for="class1Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class1Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class1Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(0)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class1Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">2</span> 第二类</h3>
|
||||
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
|
||||
<label class="file-label" for="class2Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class2Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class2Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(1)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class2Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="class-input">
|
||||
<h3><span class="class-number">3</span> 第三类(可选)</h3>
|
||||
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="">
|
||||
<label class="file-label" for="class3Images">
|
||||
选择图片
|
||||
</label>
|
||||
<input type="file" id="class3Images" multiple accept="image/*">
|
||||
<span class="samples-count" id="class3Count">0 张图片</span>
|
||||
<button class="btn btn-primary" onclick="captureFromWebcam(2)">从摄像头采集</button>
|
||||
<div class="image-preview" id="class3Preview"></div>
|
||||
</div>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-success" id="addDataBtn">训练模型</button>
|
||||
<button class="btn btn-danger" id="clearDataBtn">清空数据</button>
|
||||
</div>
|
||||
</div>
|
||||
<!--参数调整-->
|
||||
<div class="card">
|
||||
<h2>模型参数调整</h2>
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
树的数量
|
||||
<span class="k-value-display" id="numTreesDisplay">10</span>
|
||||
</label>
|
||||
<input type="range" id="numTrees" min="5" max="50" value="10"
|
||||
oninput="document.getElementById('numTreesDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">随机森林中决策树的数量</small>
|
||||
</div>
|
||||
<div class="k-selector">
|
||||
<label>
|
||||
子集大小 (比例)
|
||||
<span class="k-value-display" id="subsetSizeDisplay">0.7</span>
|
||||
</label>
|
||||
<input type="range" id="subsetSize" min="0.1" max="1" step="0.1" value="0.7"
|
||||
oninput="document.getElementById('subsetSizeDisplay').textContent = this.value">
|
||||
<small style="color: #718096;">用于训练每棵树的数据子集占比 (0.1-1.0)</small>
|
||||
</div>
|
||||
|
||||
<div id="dataStatus"></div>
|
||||
<div class="model-info">
|
||||
<h3 style="margin-bottom: 10px;">ℹ️ 模型信息</h3>
|
||||
<div class="info-item">
|
||||
<span class="info-label">预训练模型:</span>
|
||||
<span class="info-value">MobileNet v2</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">分类器类型:</span>
|
||||
<span class="info-value">随机森林</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<span class="info-label">总样本数:</span>
|
||||
<span class="info-value" id="totalSamples">0</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 预测卡片 -->
|
||||
<div class="card full-width">
|
||||
<h2>📹 实时预测</h2>
|
||||
|
||||
<div class="button-group">
|
||||
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
|
||||
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
|
||||
</div>
|
||||
|
||||
<div id="webcam-container">
|
||||
<video id="webcam" autoplay playsinline muted></video>
|
||||
</div>
|
||||
|
||||
<div class="prediction-results" id="predictionResults">
|
||||
<h3>预测结果</h3>
|
||||
<div id="predictions">等待预测...</div>
|
||||
</div>
|
||||
|
||||
<div id="predictionStatus"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="rf-classifier.js"></script>
|
||||
</body>
|
||||
</html>
|
472
随机森林/rf-classifier.js
Normal file
472
随机森林/rf-classifier.js
Normal file
@ -0,0 +1,472 @@
|
||||
// 图像分类器 - 基于MobileNet特征标签和 decision-tree.js 实现随机森林
|
||||
|
||||
class ImageClassifier {
|
||||
constructor() {
|
||||
this.mobilenet = null;
|
||||
this.randomForest = []; // 存储多个决策树
|
||||
this.classNames = [];
|
||||
this.webcamStream = null;
|
||||
this.isPredicting = false;
|
||||
this.imagenetClasses = null;
|
||||
this.trainingSet = [];
|
||||
this.numTrees = 10; // 随机森林中决策树的数量,可调整
|
||||
this.subsetSize = 0.7; // 训练集子集大小, 可调整
|
||||
|
||||
this.init();
|
||||
}
|
||||
|
||||
async init() {
|
||||
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
|
||||
|
||||
try {
|
||||
this.mobilenet = await mobilenet.load({
|
||||
version: 2,
|
||||
alpha: 1.0
|
||||
});
|
||||
|
||||
await this.loadImageNetClasses();
|
||||
|
||||
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
|
||||
this.setupEventListeners();
|
||||
} catch (error) {
|
||||
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
async loadImageNetClasses() {
|
||||
// ImageNet 前10个类别名称(简化版)
|
||||
this.imagenetClasses = [
|
||||
'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
|
||||
'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
|
||||
'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting',
|
||||
'robin', 'bulbul', 'jay', 'magpie', 'chickadee'
|
||||
];
|
||||
}
|
||||
|
||||
setupEventListeners() {
|
||||
// 文件上传监听
|
||||
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
|
||||
document.getElementById(id).addEventListener('change', (e) => {
|
||||
this.handleImageUpload(e, index);
|
||||
});
|
||||
});
|
||||
|
||||
// 按钮监听
|
||||
document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel());
|
||||
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
|
||||
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
|
||||
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
|
||||
|
||||
// 参数监听
|
||||
document.getElementById('numTrees').addEventListener('input', (e) => {
|
||||
this.numTrees = parseInt(e.target.value);
|
||||
});
|
||||
|
||||
document.getElementById('subsetSize').addEventListener('input', (e) => {
|
||||
this.subsetSize = parseFloat(e.target.value);
|
||||
});
|
||||
}
|
||||
|
||||
handleImageUpload(event, classIndex) {
|
||||
const files = event.target.files;
|
||||
const countElement = document.getElementById(`class${classIndex + 1}Count`);
|
||||
const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
|
||||
|
||||
countElement.textContent = `${files.length} 张图片`;
|
||||
|
||||
// 清空之前的预览
|
||||
previewContainer.innerHTML = '';
|
||||
|
||||
// 添加图片预览
|
||||
Array.from(files).forEach(file => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (e) => {
|
||||
const img = document.createElement('img');
|
||||
img.src = e.target.result;
|
||||
img.className = 'preview-img';
|
||||
previewContainer.appendChild(img);
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async extractImageNetTags(img) {
|
||||
try {
|
||||
const predictions = await this.mobilenet.classify(img);
|
||||
const logits = this.mobilenet.infer(img, false);
|
||||
return {
|
||||
logits: logits, // 1000维特征向量
|
||||
predictions: predictions, // 前3个预测
|
||||
topTags: await this.getTopKTags(logits, 10) // 前10个标签和权重
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('特征提取失败:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async getTopKTags(logits, k = 10) {
|
||||
const values = await logits.data();
|
||||
const valuesAndIndices = [];
|
||||
|
||||
for (let i = 0; i < values.length; i++) {
|
||||
valuesAndIndices.push({ value: values[i], index: i });
|
||||
}
|
||||
|
||||
valuesAndIndices.sort((a, b) => b.value - a.value);
|
||||
const topkValues = new Float32Array(k);
|
||||
const topkIndices = new Int32Array(k);
|
||||
|
||||
for (let i = 0; i < k; i++) {
|
||||
topkValues[i] = valuesAndIndices[i].value;
|
||||
topkIndices[i] = valuesAndIndices[i].index;
|
||||
}
|
||||
|
||||
const topTags = [];
|
||||
for (let i = 0; i < k; i++) {
|
||||
topTags.push({
|
||||
className: this.imagenetClasses[topkIndices[i]] || `class_${topkIndices[i]}`,
|
||||
probability: this.softmax(topkValues)[i],
|
||||
logit: topkValues[i]
|
||||
});
|
||||
}
|
||||
|
||||
return topTags;
|
||||
}
|
||||
|
||||
softmax(arr) {
|
||||
const maxLogit = Math.max(...arr);
|
||||
const scores = arr.map(l => Math.exp(l - maxLogit));
|
||||
const sum = scores.reduce((a, b) => a + b);
|
||||
return scores.map(s => s / sum);
|
||||
}
|
||||
// 训练随机森林模型
|
||||
async trainModel() {
|
||||
const classes = [];
|
||||
const imageFiles = [];
|
||||
|
||||
// 收集所有类别和图片
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
const className = document.getElementById(`class${i}Name`).value.trim();
|
||||
const files = document.getElementById(`class${i}Images`).files;
|
||||
|
||||
if (className && files.length > 0) {
|
||||
classes.push(className);
|
||||
imageFiles.push(files);
|
||||
}
|
||||
}
|
||||
|
||||
if (classes.length < 2) {
|
||||
this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
|
||||
return;
|
||||
}
|
||||
|
||||
this.classNames = classes;
|
||||
this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...');
|
||||
|
||||
// 准备训练数据
|
||||
this.trainingSet = [];
|
||||
|
||||
let totalProcessed = 0;
|
||||
let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
|
||||
|
||||
// 处理每个类别的图片
|
||||
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
|
||||
const files = imageFiles[classIndex];
|
||||
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
|
||||
try {
|
||||
const img = await this.loadImage(files[fileIndex]);
|
||||
const features = await this.extractImageNetTags(img);
|
||||
// 将logits从tf.Tensor转换为数组
|
||||
const featureVector = await features.logits.data();
|
||||
|
||||
// 将特征向量添加到训练数据
|
||||
const item = {};
|
||||
Array.from(featureVector).forEach((value, index) => {
|
||||
item[`feature_${index}`] = value;
|
||||
});
|
||||
item.category = classes[classIndex]; // 类别名称,而不是索引
|
||||
this.trainingSet.push(item);
|
||||
|
||||
totalProcessed++;
|
||||
const progress = Math.round((totalProcessed / totalImages) * 100);
|
||||
this.showStatus('dataStatus', 'info',
|
||||
`处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
|
||||
|
||||
img.remove();
|
||||
features.logits.dispose();
|
||||
} catch (error) {
|
||||
console.error('处理图片失败:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// 构建随机森林
|
||||
this.randomForest = [];
|
||||
for (let i = 0; i < this.numTrees; i++) {
|
||||
// 创建具有随机子集的决策树
|
||||
const trainingSubset = this.createTrainingSubset(this.trainingSet);
|
||||
|
||||
const builder = {
|
||||
trainingSet: trainingSubset,
|
||||
categoryAttr: 'category'
|
||||
};
|
||||
const tree = new dt.DecisionTree(builder);
|
||||
this.randomForest.push(tree);
|
||||
}
|
||||
|
||||
this.showStatus('dataStatus', 'success', `模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别, ${this.numTrees} 棵树`);
|
||||
|
||||
// 更新模型信息
|
||||
document.getElementById('totalSamples').textContent = totalProcessed;
|
||||
} catch (error) {
|
||||
console.error('训练失败:', error);
|
||||
this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 创建训练数据的随机子集(用于随机森林)
|
||||
createTrainingSubset(trainingSet) {
|
||||
const subset = [];
|
||||
const subsetSize = Math.floor(this.subsetSize * trainingSet.length);
|
||||
|
||||
for (let i = 0; i < subsetSize; i++) {
|
||||
const randomIndex = Math.floor(Math.random() * trainingSet.length);
|
||||
subset.push(trainingSet[randomIndex]);
|
||||
}
|
||||
|
||||
return subset;
|
||||
}
|
||||
|
||||
async loadImage(file) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (e) => {
|
||||
const img = new Image();
|
||||
img.onload = () => resolve(img);
|
||||
img.onerror = reject;
|
||||
img.src = e.target.result;
|
||||
};
|
||||
reader.onerror = reject;
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
clearDataset() {
|
||||
this.randomForest = [];
|
||||
this.classNames = [];
|
||||
this.trainingSet = [];
|
||||
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
document.getElementById(`class${i}Images`).value = '';
|
||||
document.getElementById(`class${i}Count`).textContent = '0 张图片';
|
||||
document.getElementById(`class${i}Preview`).innerHTML = '';
|
||||
}
|
||||
|
||||
document.getElementById('totalSamples').textContent = '0';
|
||||
document.getElementById('predictions').innerHTML = '等待预测...';
|
||||
|
||||
this.showStatus('dataStatus', 'info', '数据集已清空');
|
||||
}
|
||||
|
||||
startWebcam() {
|
||||
if (this.randomForest.length == 0) {
|
||||
this.showStatus('predictionStatus', 'error', '请先训练模型!');
|
||||
return;
|
||||
}
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
navigator.mediaDevices.getUserMedia({
|
||||
video: { facingMode: 'user' },
|
||||
audio: false
|
||||
})
|
||||
.then(stream => {
|
||||
video.srcObject = stream;
|
||||
this.webcamStream = stream;
|
||||
|
||||
document.getElementById('startWebcamBtn').disabled = true;
|
||||
document.getElementById('stopWebcamBtn').disabled = false;
|
||||
|
||||
this.isPredicting = true;
|
||||
this.predictLoop();
|
||||
|
||||
this.showStatus('predictionStatus', 'success', '摄像头已启动');
|
||||
})
|
||||
.catch(error => {
|
||||
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||||
});
|
||||
}
|
||||
|
||||
stopWebcam() {
|
||||
if (this.webcamStream) {
|
||||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||||
this.webcamStream = null;
|
||||
}
|
||||
|
||||
this.isPredicting = false;
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
video.srcObject = null;
|
||||
|
||||
document.getElementById('startWebcamBtn').disabled = false;
|
||||
document.getElementById('stopWebcamBtn').disabled = true;
|
||||
|
||||
this.showStatus('predictionStatus', 'info', '摄像头已停止');
|
||||
}
|
||||
|
||||
async predictLoop() {
|
||||
if (!this.isPredicting) return;
|
||||
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
if (video.readyState === 4) {
|
||||
try {
|
||||
const features = await this.extractImageNetTags(video);
|
||||
const featureVector = await features.logits.data();
|
||||
|
||||
const item = {};
|
||||
Array.from(featureVector).forEach((value, index) => {
|
||||
item[`feature_${index}`] = value;
|
||||
});
|
||||
|
||||
const { predictedCategory, probabilities } = this.predictWithRandomForest(item);
|
||||
|
||||
this.displayPrediction(predictedCategory, probabilities);
|
||||
features.logits.dispose();
|
||||
|
||||
|
||||
} catch (error) {
|
||||
console.error('预测错误:', error);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
requestAnimationFrame(() => this.predictLoop());
|
||||
}
|
||||
|
||||
predictWithRandomForest(item) {
|
||||
const votes = {};
|
||||
this.classNames.forEach(className => {
|
||||
votes[className] = 0;
|
||||
});
|
||||
|
||||
this.randomForest.forEach(tree => {
|
||||
const prediction = tree.predict(item);
|
||||
votes[prediction] = (votes[prediction] || 0) + 1;
|
||||
});
|
||||
|
||||
let predictedCategory = null;
|
||||
let maxVotes = 0;
|
||||
for (const category in votes) {
|
||||
if (votes[category] > maxVotes) {
|
||||
predictedCategory = category;
|
||||
maxVotes = votes[category];
|
||||
}
|
||||
}
|
||||
|
||||
const probabilities = {};
|
||||
for (const category in votes) {
|
||||
probabilities[category] = votes[category] / this.numTrees;
|
||||
}
|
||||
|
||||
return {
|
||||
predictedCategory: predictedCategory,
|
||||
probabilities: probabilities
|
||||
};
|
||||
}
|
||||
|
||||
async captureFromWebcam(classIndex) {
|
||||
if (!this.webcamStream) {
|
||||
// 临时启动摄像头
|
||||
const video = document.getElementById('webcam');
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
video: { facingMode: 'user' },
|
||||
audio: false
|
||||
});
|
||||
|
||||
video.srcObject = stream;
|
||||
this.webcamStream = stream;
|
||||
|
||||
// 等待视频加载
|
||||
setTimeout(async () => {
|
||||
await this.addWebcamSample(classIndex);
|
||||
|
||||
// 停止临时摄像头
|
||||
this.webcamStream.getTracks().forEach(track => track.stop());
|
||||
this.webcamStream = null;
|
||||
video.srcObject = null;
|
||||
}, 1000);
|
||||
} catch (error) {
|
||||
this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
|
||||
}
|
||||
} else {
|
||||
await this.addWebcamSample(classIndex);
|
||||
}
|
||||
}
|
||||
|
||||
async addWebcamSample(classIndex) {
|
||||
const video = document.getElementById('webcam');
|
||||
|
||||
if (video.readyState === 4) {
|
||||
try {
|
||||
const features = await this.extractImageNetTags(video);
|
||||
|
||||
const featureVector = await features.logits.data();
|
||||
// 使用logits作为对象的属性
|
||||
const item = {};
|
||||
|
||||
Array.from(featureVector).forEach((value, index) => {
|
||||
item[`feature_${index}`] = value;
|
||||
});
|
||||
|
||||
// 添加类别信息
|
||||
const className = document.getElementById(`class${classIndex + 1}Name`).value.trim();
|
||||
|
||||
item.category = className;
|
||||
|
||||
this.trainingSet.push(item);
|
||||
this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`);
|
||||
features.logits.dispose();
|
||||
} catch (error) {
|
||||
console.error('添加样本失败:', error);
|
||||
this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
displayPrediction(category, probabilities) {
|
||||
const container = document.getElementById('predictions');
|
||||
let html = `预测类别:${category}<br>`;
|
||||
for (const className in probabilities) {
|
||||
const probability = (probabilities[className] * 100).toFixed(2);
|
||||
html += `${className}: ${probability}%<br>`;
|
||||
}
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
showStatus(elementId, type, message) {
|
||||
const element = document.getElementById(elementId);
|
||||
|
||||
const classMap = {
|
||||
'success': 'status-success',
|
||||
'error': 'status-error',
|
||||
'info': 'status-info'
|
||||
};
|
||||
|
||||
element.className = `status-message ${classMap[type]}`;
|
||||
element.textContent = message;
|
||||
}
|
||||
}
|
||||
|
||||
// 全局函数:从摄像头采集
|
||||
function captureFromWebcam(classIndex) {
|
||||
if (window.classifier) {
|
||||
window.classifier.captureFromWebcam(classIndex);
|
||||
}
|
||||
}
|
||||
document.addEventListener('DOMContentLoaded', async () => {
|
||||
window.classifier = new ImageClassifier();
|
||||
});
|
Loading…
x
Reference in New Issue
Block a user