初始化

This commit is contained in:
51hhh 2025-08-11 17:44:57 +08:00
commit 8f998b1915
10 changed files with 5901 additions and 0 deletions

69
README.md Normal file
View File

@ -0,0 +1,69 @@
# 图像分类器项目
本项目包含多种基于MobileNet特征提取的图像分类器实现,包括三层神经网络、随机森林和KNN算法。
## 模型对比
| 模型 | 算法类型 | 特点 | 适用场景 |
|------|---------|------|---------|
| 三层神经网络 | 深度学习 | 自定义全连接网络,支持训练监控、早停、正则化 | 需要高精度、可解释性强的场景 |
| 随机森林 | 集成学习 | 多决策树投票,参数直观可调 | 中等规模数据集,需要模型解释性 |
| KNN (原版) | 实例学习 | 简单实现,预测结果平滑处理 | 快速原型开发,小规模数据 |
| KNN (完善版) | 实例学习 | 增强阈值控制,支持单类别检测 | 异常检测、单类别分类场景 |
## 详细说明
### 三层神经网络分类器
- 使用MobileNet进行特征提取
- 自定义三层全连接网络进行分类
- 功能特点:
- 训练过程可视化(损失/准确率曲线)
- 支持早停、正则化等技巧
- 模型保存/加载功能
- 温度缩放调整预测置信度
### 随机森林分类器
- 使用MobileNet特征作为输入
- 构建多个决策树进行集成分类
- 可调参数:
- 决策树数量(默认10棵)
- 训练集子集比例(默认70%)
- 特点:
- 训练速度快
- 提供ImageNet标签显示功能
### KNN分类器(原版)
- 基于MobileNet特征的K最近邻算法
- 特点:
- 实现简单
- 低通滤波器平滑预测结果
- 支持模型保存/加载
### KNN分类器(完善版)
在原版基础上增强:
- 距离阈值控制
- 自适应阈值计算
- 改进的单类别检测
- 更详细的训练反馈
## 使用指南
1. 选择分类器类型
2. 上传各类别训练图片
3. 调整模型参数(如适用)
4. 点击"训练模型"按钮
5. 使用摄像头或上传图片进行预测
## 技术实现
所有分类器均基于以下技术:
- 特征提取:MobileNet (TensorFlow.js)
- 前端框架:纯JavaScript实现
- 数据存储:浏览器本地存储(IndexedDB)
## 开发建议
- 对于高精度需求:使用三层神经网络
- 对于可解释性需求:使用随机森林
- 对于快速原型开发:使用KNN
- 对于异常检测:使用完善版KNN

View File

@ -0,0 +1,532 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>自定义图像分类器 - TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.main-container {
max-width: 1400px;
margin: 0 auto;
}
h1 {
color: white;
text-align: center;
margin-bottom: 30px;
font-size: 2.5em;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.grid-container {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-bottom: 20px;
}
@media (max-width: 768px) {
.grid-container {
grid-template-columns: 1fr;
}
}
.card {
background: white;
border-radius: 15px;
padding: 25px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.class-input {
margin-bottom: 20px;
padding: 15px;
background: #f8f9fa;
border-radius: 10px;
}
.class-input h3 {
color: #555;
margin-bottom: 10px;
display: flex;
align-items: center;
gap: 10px;
}
.class-number {
background: #667eea;
color: white;
width: 25px;
height: 25px;
border-radius: 50%;
display: inline-flex;
align-items: center;
justify-content: center;
font-size: 14px;
}
input[type="text"] {
width: 100%;
padding: 10px;
border: 2px solid #e0e0e0;
border-radius: 5px;
margin-bottom: 10px;
font-size: 16px;
transition: border-color 0.3s;
}
input[type="text"]:focus {
outline: none;
border-color: #667eea;
}
input[type="file"] {
display: none;
}
.file-label {
display: inline-block;
padding: 10px 20px;
background: #667eea;
color: white;
border-radius: 5px;
cursor: pointer;
transition: background 0.3s;
margin-right: 10px;
}
.file-label:hover {
background: #5a67d8;
}
.image-preview {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-top: 10px;
max-height: 150px;
overflow-y: auto;
}
.preview-img {
width: 60px;
height: 60px;
object-fit: cover;
border-radius: 5px;
border: 2px solid #e0e0e0;
}
.btn {
padding: 12px 30px;
border: none;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
transition: all 0.3s;
margin: 5px;
}
.btn-primary {
background: #667eea;
color: white;
}
.btn-primary:hover {
background: #5a67d8;
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn-success {
background: #48bb78;
color: white;
}
.btn-success:hover {
background: #38a169;
}
.btn-danger {
background: #f56565;
color: white;
}
.btn-danger:hover {
background: #e53e3e;
}
.btn:disabled {
background: #cbd5e0;
cursor: not-allowed;
transform: none;
}
#webcam-container {
position: relative;
width: 100%;
max-width: 640px;
margin: 20px auto;
}
#webcam {
width: 100%;
border-radius: 10px;
background: #000;
}
.confidence-bars {
margin-top: 20px;
}
.confidence-item {
margin-bottom: 15px;
}
.confidence-label {
display: flex;
justify-content: space-between;
margin-bottom: 5px;
font-weight: 500;
}
.confidence-bar {
height: 30px;
background: #e0e0e0;
border-radius: 15px;
overflow: hidden;
position: relative;
}
.confidence-fill {
height: 100%;
background: linear-gradient(90deg, #667eea, #764ba2);
border-radius: 15px;
transition: width 0.3s ease;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-weight: bold;
}
.status-message {
padding: 15px;
border-radius: 5px;
margin: 10px 0;
text-align: center;
font-weight: 500;
}
.status-success {
background: #c6f6d5;
color: #22543d;
border: 1px solid #9ae6b4;
}
.status-error {
background: #fed7d7;
color: #742a2a;
border: 1px solid #fc8181;
}
.status-info {
background: #bee3f8;
color: #2c5282;
border: 1px solid #90cdf4;
}
.status-warning {
background: #fef3c7;
color: #92400e;
border: 1px solid #fcd34d;
}
.training-progress {
margin: 20px 0;
}
.progress-bar {
height: 25px;
background: #e0e0e0;
border-radius: 12px;
overflow: hidden;
}
.progress-fill {
height: 100%;
background: linear-gradient(90deg, #48bb78, #38a169);
transition: width 0.3s;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-weight: bold;
}
.metrics {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
gap: 15px;
margin: 20px 0;
}
.metric-card {
background: #f7fafc;
padding: 15px;
border-radius: 8px;
text-align: center;
border: 1px solid #e2e8f0;
}
.metric-label {
color: #718096;
font-size: 12px;
text-transform: uppercase;
margin-bottom: 5px;
}
.metric-value {
color: #2d3748;
font-size: 24px;
font-weight: bold;
}
.samples-count {
display: inline-block;
background: #edf2f7;
padding: 2px 8px;
border-radius: 10px;
font-size: 12px;
color: #4a5568;
margin-left: 5px;
}
.full-width {
grid-column: 1 / -1;
}
.button-group {
display: flex;
gap: 10px;
margin: 20px 0;
flex-wrap: wrap;
}
.hidden {
display: none;
}
#lossChart {
width: 100%;
height: 300px;
margin-top: 20px;
border: 1px solid #e0e0e0;
border-radius: 8px;
}
</style>
</head>
<body>
<div class="main-container">
<h1>🤖 自定义图像分类器</h1>
<div class="grid-container">
<!-- 数据采集卡片 -->
<div class="card">
<h2>📸 数据采集</h2>
<div class="class-input">
<h3><span class="class-number">1</span> 第一类</h3>
<input type="text" id="class1Name" placeholder="输入类别名称(如:人)" value="类别1">
<label class="file-label" for="class1Images">
选择图片
</label>
<input type="file" id="class1Images" multiple accept="image/*">
<span class="samples-count" id="class1Count">0 张图片</span>
<div class="image-preview" id="class1Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">2</span> 第二类</h3>
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
<label class="file-label" for="class2Images">
选择图片
</label>
<input type="file" id="class2Images" multiple accept="image/*">
<span class="samples-count" id="class2Count">0 张图片</span>
<div class="image-preview" id="class2Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">3</span> 第三类(可选)</h3>
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="">
<label class="file-label" for="class3Images">
选择图片
</label>
<input type="file" id="class3Images" multiple accept="image/*">
<span class="samples-count" id="class3Count">0 张图片</span>
<div class="image-preview" id="class3Preview"></div>
</div>
<div class="button-group">
<button class="btn btn-primary" id="addDataBtn">添加到数据集</button>
<button class="btn btn-danger" id="clearDataBtn">清空数据集</button>
</div>
<div id="dataStatus"></div>
</div>
<!-- 训练控制卡片 -->
<div class="card">
<h2>🎯 模型训练</h2>
<!-- 超参数调节 -->
<div class="hyperparameters" style="background: #f8f9fa; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
<h3 style="margin-bottom: 15px; color: #555;">⚙️ 超参数设置</h3>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px;">
<div>
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
学习率: <span id="learningRateValue">0.001</span>
</label>
<input type="range" id="learningRate" min="-5" max="-1" step="0.1" value="-3"
style="width: 100%;" oninput="document.getElementById('learningRateValue').textContent = Math.pow(10, this.value).toFixed(5)">
</div>
<div>
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
训练轮数: <span id="epochsValue">100</span>
</label>
<input type="range" id="epochs" min="10" max="200" step="10" value="100"
style="width: 100%;" oninput="document.getElementById('epochsValue').textContent = this.value">
</div>
<div>
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
Dropout率: <span id="dropoutValue">0.3</span>
</label>
<input type="range" id="dropoutRate" min="0" max="0.7" step="0.05" value="0.3"
style="width: 100%;" oninput="document.getElementById('dropoutValue').textContent = this.value">
</div>
<div>
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
L2正则化: <span id="l2Value">0.01</span>
</label>
<input type="range" id="l2Regularization" min="-4" max="-1" step="0.1" value="-2"
style="width: 100%;" oninput="document.getElementById('l2Value').textContent = Math.pow(10, this.value).toFixed(4)">
</div>
<div>
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
批次大小: <span id="batchSizeValue">32</span>
</label>
<input type="range" id="batchSize" min="8" max="128" step="8" value="32"
style="width: 100%;" oninput="document.getElementById('batchSizeValue').textContent = this.value">
</div>
<div>
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
温度缩放: <span id="temperatureValue">1.0</span>
</label>
<input type="range" id="temperature" min="0.5" max="5" step="0.1" value="1.0"
style="width: 100%;" oninput="document.getElementById('temperatureValue').textContent = this.value">
</div>
</div>
<div style="margin-top: 15px;">
<label style="display: flex; align-items: center; gap: 10px; cursor: pointer;">
<input type="checkbox" id="earlyStoppingCheck" checked>
<span style="font-size: 14px; color: #666;">启用早停 (防止过拟合)</span>
</label>
<div id="earlyStoppingOptions" style="margin-top: 10px; padding-left: 25px;">
<label style="display: block; margin-bottom: 5px; font-size: 14px; color: #666;">
耐心值: <span id="patienceValue">10</span> epochs
</label>
<input type="range" id="patience" min="5" max="30" step="5" value="10"
style="width: 200px;" oninput="document.getElementById('patienceValue').textContent = this.value">
</div>
</div>
<div style="margin-top: 15px;">
<label style="display: flex; align-items: center; gap: 10px; cursor: pointer;">
<input type="checkbox" id="dataAugmentationCheck">
<span style="font-size: 14px; color: #666;">启用数据增强 (提高泛化能力)</span>
</label>
</div>
<div style="margin-top: 15px; display: flex; gap: 10px;">
<button class="btn btn-primary" onclick="classifier.resetHyperparameters()" style="padding: 8px 15px; font-size: 14px;">
重置为默认值
</button>
<button class="btn btn-primary" onclick="classifier.showRecommendations()" style="padding: 8px 15px; font-size: 14px;">
查看建议设置
</button>
</div>
</div>
<div class="button-group">
<button class="btn btn-success" id="trainBtn">开始训练</button>
<button class="btn btn-danger" id="stopBtn" disabled>停止训练</button>
</div>
<div id="trainingStatus"></div>
<div class="training-progress hidden" id="trainingProgress">
<div class="progress-bar">
<div class="progress-fill" id="progressFill">0%</div>
</div>
</div>
<div class="metrics" id="metricsContainer"></div>
<canvas id="lossChart"></canvas>
</div>
</div>
<!-- 预测卡片 -->
<div class="card full-width">
<h2>📹 实时预测</h2>
<div class="button-group">
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
<button class="btn btn-success" id="saveModelBtn">保存模型</button>
<button class="btn btn-primary" id="loadModelBtn">加载模型</button>
</div>
<div id="webcam-container">
<video id="webcam" autoplay playsinline muted></video>
</div>
<div class="confidence-bars" id="confidenceBars"></div>
<div id="predictionStatus"></div>
</div>
</div>
<script src="custom-classifier.js"></script>
</body>
</html>

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,551 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>KNN 图像分类器 - TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier@latest"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.main-container {
max-width: 1400px;
margin: 0 auto;
}
h1 {
color: white;
text-align: center;
margin-bottom: 30px;
font-size: 2.5em;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.grid-container {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-bottom: 20px;
}
@media (max-width: 768px) {
.grid-container {
grid-template-columns: 1fr;
}
}
.card {
background: white;
border-radius: 15px;
padding: 25px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.class-input {
margin-bottom: 20px;
padding: 15px;
background: #f8f9fa;
border-radius: 10px;
}
.class-input h3 {
color: #555;
margin-bottom: 10px;
display: flex;
align-items: center;
gap: 10px;
}
.class-number {
background: #667eea;
color: white;
width: 25px;
height: 25px;
border-radius: 50%;
display: inline-flex;
align-items: center;
justify-content: center;
font-size: 14px;
}
input[type="text"] {
width: 100%;
padding: 10px;
border: 2px solid #e0e0e0;
border-radius: 5px;
margin-bottom: 10px;
font-size: 16px;
transition: border-color 0.3s;
}
input[type="text"]:focus {
outline: none;
border-color: #667eea;
}
input[type="file"] {
display: none;
}
.file-label {
display: inline-block;
padding: 10px 20px;
background: #667eea;
color: white;
border-radius: 5px;
cursor: pointer;
transition: background 0.3s;
margin-right: 10px;
}
.file-label:hover {
background: #5a67d8;
}
.btn {
padding: 12px 30px;
border: none;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
transition: all 0.3s;
margin: 5px;
}
.btn-primary {
background: #667eea;
color: white;
}
.btn-primary:hover {
background: #5a67d8;
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn-success {
background: #48bb78;
color: white;
}
.btn-success:hover {
background: #38a169;
}
.btn-danger {
background: #f56565;
color: white;
}
.btn-danger:hover {
background: #e53e3e;
}
.btn:disabled {
background: #cbd5e0;
cursor: not-allowed;
transform: none;
}
#webcam-container {
position: relative;
width: 100%;
max-width: 640px;
margin: 20px auto;
}
#webcam {
width: 100%;
border-radius: 10px;
background: #000;
}
.samples-count {
display: inline-block;
background: #edf2f7;
padding: 2px 8px;
border-radius: 10px;
font-size: 12px;
color: #4a5568;
margin-left: 5px;
}
.image-preview {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-top: 10px;
max-height: 150px;
overflow-y: auto;
}
.preview-img {
width: 60px;
height: 60px;
object-fit: cover;
border-radius: 5px;
border: 2px solid #e0e0e0;
}
.status-message {
padding: 15px;
border-radius: 5px;
margin: 10px 0;
text-align: center;
font-weight: 500;
}
.status-success {
background: #c6f6d5;
color: #22543d;
border: 1px solid #9ae6b4;
}
.status-error {
background: #fed7d7;
color: #742a2a;
border: 1px solid #fc8181;
}
.status-info {
background: #bee3f8;
color: #2c5282;
border: 1px solid #90cdf4;
}
.button-group {
display: flex;
gap: 10px;
margin: 20px 0;
flex-wrap: wrap;
}
.full-width {
grid-column: 1 / -1;
}
.prediction-results {
margin-top: 20px;
padding: 20px;
background: #f7fafc;
border-radius: 10px;
}
.prediction-item {
padding: 15px;
margin: 10px 0;
background: white;
border-radius: 8px;
border-left: 4px solid #667eea;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.prediction-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 8px;
}
.prediction-label {
font-weight: 600;
color: #2d3748;
font-size: 16px;
}
.prediction-confidence {
background: linear-gradient(135deg, #667eea, #764ba2);
color: white;
padding: 4px 12px;
border-radius: 20px;
font-size: 14px;
font-weight: 500;
min-width: 60px;
text-align: center;
}
.confidence-bar-container {
width: 100%;
height: 24px;
background: #e2e8f0;
border-radius: 12px;
overflow: hidden;
position: relative;
}
.confidence-bar {
height: 100%;
background: linear-gradient(90deg, #667eea, #764ba2);
border-radius: 12px;
transition: width 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
min-width: 0;
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
}
.confidence-bar::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.3), transparent);
animation: shimmer 2s infinite;
}
@keyframes shimmer {
0% { transform: translateX(-100%); }
100% { transform: translateX(100%); }
}
.confidence-bar.high {
background: linear-gradient(90deg, #48bb78, #38a169);
}
.confidence-bar.medium {
background: linear-gradient(90deg, #ed8936, #dd6b20);
}
.confidence-bar.low {
background: linear-gradient(90deg, #f56565, #e53e3e);
}
.confidence-percentage {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
color: white;
font-weight: 600;
font-size: 12px;
text-shadow: 0 1px 2px rgba(0,0,0,0.2);
z-index: 1;
}
.top-tags {
margin: 20px 0;
padding: 15px;
background: #edf2fe;
border-radius: 10px;
}
.tag-item {
display: inline-block;
background: white;
padding: 5px 12px;
margin: 5px;
border-radius: 15px;
font-size: 14px;
border: 1px solid #cbd5e0;
}
.tag-weight {
color: #667eea;
font-weight: bold;
margin-left: 5px;
}
.k-selector {
margin: 15px 0;
padding: 15px;
background: #f8f9fa;
border-radius: 8px;
}
.k-selector label {
display: block;
margin-bottom: 10px;
color: #555;
font-weight: 500;
}
.k-value-display {
display: inline-block;
background: #667eea;
color: white;
padding: 2px 8px;
border-radius: 5px;
margin-left: 10px;
}
input[type="range"] {
width: 100%;
margin: 10px 0;
}
.model-info {
margin-top: 20px;
padding: 15px;
background: #f0f4f8;
border-radius: 8px;
font-size: 14px;
}
.info-item {
display: flex;
justify-content: space-between;
margin: 5px 0;
}
.info-label {
color: #718096;
}
.info-value {
color: #2d3748;
font-weight: 500;
}
</style>
</head>
<body>
<div class="main-container">
<h1>🤖 KNN 图像分类器(基于特征标签)</h1>
<div class="grid-container">
<!-- 数据采集卡片 -->
<div class="card">
<h2>📸 数据采集</h2>
<div class="class-input">
<h3><span class="class-number">1</span> 第一类</h3>
<input type="text" id="class1Name" placeholder="输入类别名称(如:猫)" value="类别1">
<label class="file-label" for="class1Images">
选择图片
</label>
<input type="file" id="class1Images" multiple accept="image/*">
<span class="samples-count" id="class1Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(0)">从摄像头采集</button>
<div class="image-preview" id="class1Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">2</span> 第二类</h3>
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
<label class="file-label" for="class2Images">
选择图片
</label>
<input type="file" id="class2Images" multiple accept="image/*">
<span class="samples-count" id="class2Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(1)">从摄像头采集</button>
<div class="image-preview" id="class2Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">3</span> 第三类(可选)</h3>
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="">
<label class="file-label" for="class3Images">
选择图片
</label>
<input type="file" id="class3Images" multiple accept="image/*">
<span class="samples-count" id="class3Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(2)">从摄像头采集</button>
<div class="image-preview" id="class3Preview"></div>
</div>
<div class="button-group">
<button class="btn btn-success" id="addDataBtn">训练KNN模型</button>
<button class="btn btn-danger" id="clearDataBtn">清空数据</button>
</div>
<div id="dataStatus"></div>
</div>
<!-- KNN模型信息卡片 -->
<div class="card">
<h2>🎯 KNN 模型设置</h2>
<div class="k-selector">
<label>
K值(最近邻数量)
<span class="k-value-display" id="kValueDisplay">3</span>
</label>
<input type="range" id="kValue" min="1" max="20" value="3"
oninput="document.getElementById('kValueDisplay').textContent = this.value">
<small style="color: #718096;">K值越大,预测越保守;K值越小,对局部特征越敏感</small>
</div>
<div class="k-selector">
<label>
滤波器系数 (α)
<span class="k-value-display" id="filterAlphaDisplay">0.3</span>
</label>
<input type="range" id="filterAlpha" min="0.05" max="1.0" step="0.05" value="0.3"
oninput="document.getElementById('filterAlphaDisplay').textContent = this.value">
<small style="color: #718096;">低通滤波器系数:值越小输出越平滑(0.1-0.3推荐),值越大响应越快</small>
</div>
<div class="top-tags" id="topTags">
<h3 style="margin-bottom: 10px;">📊 特征标签提取预览</h3>
<div id="tagsList">等待数据...</div>
</div>
<div class="model-info">
<h3 style="margin-bottom: 10px;"> 模型信息</h3>
<div class="info-item">
<span class="info-label">预训练模型:</span>
<span class="info-value">MobileNet v2</span>
</div>
<div class="info-item">
<span class="info-label">特征维度:</span>
<span class="info-value">1000个标签</span>
</div>
<div class="info-item">
<span class="info-label">分类器类型:</span>
<span class="info-value">K-最近邻 (KNN)</span>
</div>
<div class="info-item">
<span class="info-label">总样本数:</span>
<span class="info-value" id="totalSamples">0</span>
</div>
</div>
</div>
</div>
<!-- 预测卡片 -->
<div class="card full-width">
<h2>📹 实时预测</h2>
<div class="button-group">
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
<button class="btn btn-success" id="saveModelBtn">保存模型</button>
<button class="btn btn-primary" id="loadModelBtn">加载模型</button>
</div>
<div id="webcam-container">
<video id="webcam" autoplay playsinline muted></video>
</div>
<div class="prediction-results" id="predictionResults">
<h3>预测结果</h3>
<div id="predictions">等待预测...</div>
</div>
<div id="predictionStatus"></div>
</div>
</div>
<script src="knn-classifier.js"></script>
</body>
</html>

621
原版KNN/knn-classifier.js Normal file
View File

@ -0,0 +1,621 @@
// KNN 图像分类器 - 基于MobileNet特征标签
class KNNImageClassifier {
constructor() {
this.mobilenet = null;
this.knnClassifier = null;
this.classNames = [];
this.webcamStream = null;
this.isPredicting = false;
this.currentCaptureClass = -1;
this.imagenetClasses = null;
// 低通滤波器状态
this.filteredConfidences = {};
this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑
this.init();
}
async init() {
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
try {
// 加载 MobileNet 模型
this.mobilenet = await mobilenet.load({
version: 2,
alpha: 1.0
});
// 创建 KNN 分类器
this.knnClassifier = knnClassifier.create();
// 加载 ImageNet 类别名称
await this.loadImageNetClasses();
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
this.setupEventListeners();
} catch (error) {
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
}
}
async loadImageNetClasses() {
// ImageNet 前10个类别名称简化版
this.imagenetClasses = [
'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
];
}
setupEventListeners() {
// 文件上传监听
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
document.getElementById(id).addEventListener('change', (e) => {
this.handleImageUpload(e, index);
});
});
// 按钮监听
document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
}
handleImageUpload(event, classIndex) {
const files = event.target.files;
const countElement = document.getElementById(`class${classIndex + 1}Count`);
const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
countElement.textContent = `${files.length} 张图片`;
// 清空之前的预览
previewContainer.innerHTML = '';
// 添加图片预览
Array.from(files).forEach(file => {
const reader = new FileReader();
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.className = 'preview-img';
previewContainer.appendChild(img);
};
reader.readAsDataURL(file);
});
}
// 从图像提取 MobileNet 标签和权重
async extractImageNetTags(img) {
try {
// 获取 MobileNet 的预测1000个类别的概率
const predictions = await this.mobilenet.classify(img);
// 获取完整的 logits原始输出
const logits = this.mobilenet.infer(img, false); // false = 不使用嵌入层获取原始1000维输出
// 获取前10个最高概率的标签
const topK = await this.getTopKTags(logits, 10);
return {
logits: logits, // 1000维特征向量
predictions: predictions, // 前3个预测
topTags: topK // 前10个标签和权重
};
} catch (error) {
console.error('特征提取失败:', error);
throw error;
}
}
// 获取Top-K标签
async getTopKTags(logits, k = 10) {
const values = await logits.data();
const valuesAndIndices = [];
for (let i = 0; i < values.length; i++) {
valuesAndIndices.push({ value: values[i], index: i });
}
valuesAndIndices.sort((a, b) => b.value - a.value);
const topkValues = new Float32Array(k);
const topkIndices = new Int32Array(k);
for (let i = 0; i < k; i++) {
topkValues[i] = valuesAndIndices[i].value;
topkIndices[i] = valuesAndIndices[i].index;
}
const topTags = [];
for (let i = 0; i < k; i++) {
topTags.push({
className: this.imagenetClasses[i] || `class_${topkIndices[i]}`,
probability: this.softmax(topkValues)[i],
logit: topkValues[i]
});
}
return topTags;
}
// Softmax 函数
softmax(arr) {
const maxLogit = Math.max(...arr);
const scores = arr.map(l => Math.exp(l - maxLogit));
const sum = scores.reduce((a, b) => a + b);
return scores.map(s => s / sum);
}
// 训练 KNN 模型
async trainKNN() {
const classes = [];
const imageFiles = [];
// 收集所有类别和图片
for (let i = 1; i <= 3; i++) {
const className = document.getElementById(`class${i}Name`).value.trim();
const files = document.getElementById(`class${i}Images`).files;
if (className && files.length > 0) {
classes.push(className);
imageFiles.push(files);
}
}
if (classes.length < 2) {
this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
return;
}
this.classNames = classes;
this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');
// 清空现有的KNN分类器
this.knnClassifier.clearAllClasses();
let totalProcessed = 0;
let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
// 处理每个类别的图片
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
const files = imageFiles[classIndex];
console.log(`处理类别 ${classes[classIndex]}...`);
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
try {
const img = await this.loadImage(files[fileIndex]);
// 提取特征标签
const features = await this.extractImageNetTags(img);
// 添加到KNN分类器
// 使用完整的logits作为特征向量
this.knnClassifier.addExample(features.logits, classIndex);
totalProcessed++;
const progress = Math.round((totalProcessed / totalImages) * 100);
this.showStatus('dataStatus', 'info',
`处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
// 显示提取的标签
if (fileIndex === 0) {
this.displayTopTags(features.topTags);
}
// 清理
img.remove();
features.logits.dispose();
} catch (error) {
console.error('处理图片失败:', error);
}
}
}
// 更新模型信息
document.getElementById('totalSamples').textContent = totalProcessed;
this.showStatus('dataStatus', 'success',
`KNN模型训练完成${totalProcessed} 个样本,${classes.length} 个类别`);
console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
}
// 显示提取的标签
displayTopTags(tags) {
const container = document.getElementById('tagsList');
let html = '';
tags.slice(0, 5).forEach(tag => {
html += `
<span class="tag-item">
${tag.className}
<span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
</span>
`;
});
container.innerHTML = html;
}
// 加载图片
async loadImage(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (e) => {
const img = new Image();
img.onload = () => resolve(img);
img.onerror = reject;
img.src = e.target.result;
};
reader.onerror = reject;
reader.readAsDataURL(file);
});
}
// 清空数据集
clearDataset() {
this.knnClassifier.clearAllClasses();
this.classNames = [];
this.filteredConfidences = {}; // 重置滤波器状态
for (let i = 1; i <= 3; i++) {
document.getElementById(`class${i}Images`).value = '';
document.getElementById(`class${i}Count`).textContent = '0 张图片';
document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览
}
document.getElementById('totalSamples').textContent = '0';
document.getElementById('tagsList').innerHTML = '等待数据...';
document.getElementById('predictions').innerHTML = '等待预测...';
this.showStatus('dataStatus', 'info', '数据集已清空');
}
// 启动摄像头
async startWebcam() {
if (this.knnClassifier.getNumClasses() === 0) {
this.showStatus('predictionStatus', 'error', '请先训练模型!');
return;
}
const video = document.getElementById('webcam');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
});
video.srcObject = stream;
this.webcamStream = stream;
document.getElementById('startWebcamBtn').disabled = true;
document.getElementById('stopWebcamBtn').disabled = false;
// 等待视频加载
video.addEventListener('loadeddata', () => {
this.isPredicting = true;
this.predictLoop();
});
this.showStatus('predictionStatus', 'success', '摄像头已启动');
} catch (error) {
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
}
}
// 停止摄像头
stopWebcam() {
if (this.webcamStream) {
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
}
this.isPredicting = false;
this.filteredConfidences = {}; // 重置滤波器状态
const video = document.getElementById('webcam');
video.srcObject = null;
document.getElementById('startWebcamBtn').disabled = false;
document.getElementById('stopWebcamBtn').disabled = true;
this.showStatus('predictionStatus', 'info', '摄像头已停止');
}
// 预测循环
async predictLoop() {
if (!this.isPredicting) return;
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
// 提取特征
const features = await this.extractImageNetTags(video);
// 使用原始KNN进行预测
const k = parseInt(document.getElementById('kValue').value);
const prediction = await this.knnClassifier.predictClass(features.logits, k);
// 应用低通滤波器
const smoothedPrediction = this.applyLowPassFilter(prediction);
// 显示预测结果
this.displayPrediction(smoothedPrediction);
// 显示提取的标签
this.displayTopTags(features.topTags);
// 清理张量
features.logits.dispose();
} catch (error) {
console.error('预测错误:', error);
}
}
// 继续预测循环
requestAnimationFrame(() => this.predictLoop());
}
// 应用低通滤波器到置信度
applyLowPassFilter(prediction) {
// 获取滤波器系数
const alpha = parseFloat(document.getElementById('filterAlpha').value);
// 初始化滤波状态(如果是第一次)
if (Object.keys(this.filteredConfidences).length === 0) {
for (let i = 0; i < this.classNames.length; i++) {
this.filteredConfidences[i] = prediction.confidences[i] || 0;
}
return {
label: prediction.label,
confidences: {...this.filteredConfidences}
};
}
// 应用指数移动平均EMA低通滤波
const newConfidences = {};
for (let i = 0; i < this.classNames.length; i++) {
const currentValue = prediction.confidences[i] || 0;
const previousValue = this.filteredConfidences[i] || 0;
// EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1]
this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
newConfidences[i] = this.filteredConfidences[i];
}
// 归一化确保总和为1
let sum = 0;
Object.values(newConfidences).forEach(v => sum += v);
if (sum > 0) {
Object.keys(newConfidences).forEach(key => {
newConfidences[key] = newConfidences[key] / sum;
});
}
// 找到最高置信度的类别
let maxConfidence = 0;
let bestLabel = 0;
Object.keys(newConfidences).forEach(key => {
if (newConfidences[key] > maxConfidence) {
maxConfidence = newConfidences[key];
bestLabel = parseInt(key);
}
});
return {
label: bestLabel,
confidences: newConfidences
};
}
// 显示预测结果
displayPrediction(prediction) {
const container = document.getElementById('predictions');
let html = '';
// 直接使用滤波后的置信度
const confidences = prediction.confidences;
const predictedClass = prediction.label;
// 固定顺序显示(按类别索引)
for (let i = 0; i < this.classNames.length; i++) {
const className = this.classNames[i];
const confidence = confidences[i] || 0;
const percentage = (confidence * 100).toFixed(1);
const isWinner = i === predictedClass;
// 根据置信度决定颜色等级
let barClass = '';
if (confidence > 0.7) barClass = 'high';
else if (confidence > 0.4) barClass = 'medium';
else barClass = 'low';
// 如果是获胜类别,使用绿色
if (isWinner) barClass = 'high';
html += `
<div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
<div class="prediction-header">
<span class="prediction-label">
${className} ${isWinner ? '👑' : ''}
</span>
<span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
${percentage}%
</span>
</div>
<div class="confidence-bar-container">
<div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
</div>
</div>
</div>
`;
}
container.innerHTML = html;
}
// 从摄像头捕获样本
async captureFromWebcam(classIndex) {
if (!this.webcamStream) {
// 临时启动摄像头
const video = document.getElementById('webcam');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
});
video.srcObject = stream;
this.webcamStream = stream;
// 等待视频加载
setTimeout(async () => {
await this.addWebcamSample(classIndex);
// 停止临时摄像头
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
video.srcObject = null;
}, 1000);
} catch (error) {
this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
}
} else {
await this.addWebcamSample(classIndex);
}
}
// 添加摄像头样本
async addWebcamSample(classIndex) {
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
// 提取特征
const features = await this.extractImageNetTags(video);
// 添加到KNN分类器
this.knnClassifier.addExample(features.logits, classIndex);
// 更新计数
const currentCount = this.knnClassifier.getClassExampleCount();
const count = currentCount[classIndex] || 0;
document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;
// 清理
features.logits.dispose();
this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
} catch (error) {
console.error('添加样本失败:', error);
}
}
}
// 保存模型
async saveModel() {
if (this.knnClassifier.getNumClasses() === 0) {
this.showStatus('predictionStatus', 'error', '没有可保存的模型');
return;
}
try {
// 获取KNN分类器的数据
const dataset = this.knnClassifier.getClassifierDataset();
const datasetObj = {};
Object.keys(dataset).forEach(key => {
const data = dataset[key].dataSync();
datasetObj[key] = Array.from(data);
});
// 保存为JSON
const modelData = {
dataset: datasetObj,
classNames: this.classNames,
k: document.getElementById('kValue').value,
date: new Date().toISOString()
};
const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'knn-model.json';
a.click();
URL.revokeObjectURL(url);
this.showStatus('predictionStatus', 'success', '模型已保存');
} catch (error) {
this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
}
}
// 加载模型
async loadModel() {
const input = document.createElement('input');
input.type = 'file';
input.accept = '.json';
input.onchange = async (e) => {
try {
const file = e.target.files[0];
const text = await file.text();
const modelData = JSON.parse(text);
// 清空现有分类器
this.knnClassifier.clearAllClasses();
// 恢复数据集
Object.keys(modelData.dataset).forEach(key => {
const tensor = tf.tensor(modelData.dataset[key], [modelData.dataset[key].length / 1024, 1024]);
this.knnClassifier.setClassifierDataset({ [key]: tensor });
});
this.classNames = modelData.classNames;
document.getElementById('kValue').value = modelData.k;
document.getElementById('kValueDisplay').textContent = modelData.k;
this.showStatus('predictionStatus', 'success',
`模型加载成功!类别: ${this.classNames.join(', ')}`);
} catch (error) {
this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
}
};
input.click();
}
// 显示状态
showStatus(elementId, type, message) {
const element = document.getElementById(elementId);
const classMap = {
'success': 'status-success',
'error': 'status-error',
'info': 'status-info'
};
element.className = `status-message ${classMap[type]}`;
element.textContent = message;
}
}
/**
 * Global entry point used by the inline `onclick` handlers: forwards the
 * capture request to the classifier instance on `window`, if one exists.
 *
 * @param {number} classIndex Zero-based index of the class to capture for.
 */
function captureFromWebcam(classIndex) {
    const activeClassifier = window.classifier;
    if (!activeClassifier) {
        return;
    }
    activeClassifier.captureFromWebcam(classIndex);
}
// Application bootstrap: create the classifier once the DOM is ready and
// expose it on `window` so the inline onclick handlers can reach it.
let classifier;
document.addEventListener('DOMContentLoaded', () => {
    window.classifier = classifier = new KNNImageClassifier();
});

View File

@ -0,0 +1,561 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>KNN 图像分类器 - TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier@latest"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.main-container {
max-width: 1400px;
margin: 0 auto;
}
h1 {
color: white;
text-align: center;
margin-bottom: 30px;
font-size: 2.5em;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.grid-container {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-bottom: 20px;
}
@media (max-width: 768px) {
.grid-container {
grid-template-columns: 1fr;
}
}
.card {
background: white;
border-radius: 15px;
padding: 25px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.class-input {
margin-bottom: 20px;
padding: 15px;
background: #f8f9fa;
border-radius: 10px;
}
.class-input h3 {
color: #555;
margin-bottom: 10px;
display: flex;
align-items: center;
gap: 10px;
}
.class-number {
background: #667eea;
color: white;
width: 25px;
height: 25px;
border-radius: 50%;
display: inline-flex;
align-items: center;
justify-content: center;
font-size: 14px;
}
input[type="text"] {
width: 100%;
padding: 10px;
border: 2px solid #e0e0e0;
border-radius: 5px;
margin-bottom: 10px;
font-size: 16px;
transition: border-color 0.3s;
}
input[type="text"]:focus {
outline: none;
border-color: #667eea;
}
input[type="file"] {
display: none;
}
.file-label {
display: inline-block;
padding: 10px 20px;
background: #667eea;
color: white;
border-radius: 5px;
cursor: pointer;
transition: background 0.3s;
margin-right: 10px;
}
.file-label:hover {
background: #5a67d8;
}
.btn {
padding: 12px 30px;
border: none;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
transition: all 0.3s;
margin: 5px;
}
.btn-primary {
background: #667eea;
color: white;
}
.btn-primary:hover {
background: #5a67d8;
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn-success {
background: #48bb78;
color: white;
}
.btn-success:hover {
background: #38a169;
}
.btn-danger {
background: #f56565;
color: white;
}
.btn-danger:hover {
background: #e53e3e;
}
.btn:disabled {
background: #cbd5e0;
cursor: not-allowed;
transform: none;
}
#webcam-container {
position: relative;
width: 100%;
max-width: 640px;
margin: 20px auto;
}
#webcam {
width: 100%;
border-radius: 10px;
background: #000;
}
.samples-count {
display: inline-block;
background: #edf2f7;
padding: 2px 8px;
border-radius: 10px;
font-size: 12px;
color: #4a5568;
margin-left: 5px;
}
.image-preview {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-top: 10px;
max-height: 150px;
overflow-y: auto;
}
.preview-img {
width: 60px;
height: 60px;
object-fit: cover;
border-radius: 5px;
border: 2px solid #e0e0e0;
}
.status-message {
padding: 15px;
border-radius: 5px;
margin: 10px 0;
text-align: center;
font-weight: 500;
}
.status-success {
background: #c6f6d5;
color: #22543d;
border: 1px solid #9ae6b4;
}
.status-error {
background: #fed7d7;
color: #742a2a;
border: 1px solid #fc8181;
}
.status-info {
background: #bee3f8;
color: #2c5282;
border: 1px solid #90cdf4;
}
.button-group {
display: flex;
gap: 10px;
margin: 20px 0;
flex-wrap: wrap;
}
.full-width {
grid-column: 1 / -1;
}
.prediction-results {
margin-top: 20px;
padding: 20px;
background: #f7fafc;
border-radius: 10px;
}
.prediction-item {
padding: 15px;
margin: 10px 0;
background: white;
border-radius: 8px;
border-left: 4px solid #667eea;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.prediction-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 8px;
}
.prediction-label {
font-weight: 600;
color: #2d3748;
font-size: 16px;
}
.prediction-confidence {
background: linear-gradient(135deg, #667eea, #764ba2);
color: white;
padding: 4px 12px;
border-radius: 20px;
font-size: 14px;
font-weight: 500;
min-width: 60px;
text-align: center;
}
.confidence-bar-container {
width: 100%;
height: 24px;
background: #e2e8f0;
border-radius: 12px;
overflow: hidden;
position: relative;
}
.confidence-bar {
height: 100%;
background: linear-gradient(90deg, #667eea, #764ba2);
border-radius: 12px;
transition: width 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
min-width: 0;
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
}
.confidence-bar::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.3), transparent);
animation: shimmer 2s infinite;
}
@keyframes shimmer {
0% { transform: translateX(-100%); }
100% { transform: translateX(100%); }
}
.confidence-bar.high {
background: linear-gradient(90deg, #48bb78, #38a169);
}
.confidence-bar.medium {
background: linear-gradient(90deg, #ed8936, #dd6b20);
}
.confidence-bar.low {
background: linear-gradient(90deg, #f56565, #e53e3e);
}
.confidence-percentage {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
color: white;
font-weight: 600;
font-size: 12px;
text-shadow: 0 1px 2px rgba(0,0,0,0.2);
z-index: 1;
}
.top-tags {
margin: 20px 0;
padding: 15px;
background: #edf2fe;
border-radius: 10px;
}
.tag-item {
display: inline-block;
background: white;
padding: 5px 12px;
margin: 5px;
border-radius: 15px;
font-size: 14px;
border: 1px solid #cbd5e0;
}
.tag-weight {
color: #667eea;
font-weight: bold;
margin-left: 5px;
}
.k-selector {
margin: 15px 0;
padding: 15px;
background: #f8f9fa;
border-radius: 8px;
}
.k-selector label {
display: block;
margin-bottom: 10px;
color: #555;
font-weight: 500;
}
.k-value-display {
display: inline-block;
background: #667eea;
color: white;
padding: 2px 8px;
border-radius: 5px;
margin-left: 10px;
}
input[type="range"] {
width: 100%;
margin: 10px 0;
}
.model-info {
margin-top: 20px;
padding: 15px;
background: #f0f4f8;
border-radius: 8px;
font-size: 14px;
}
.info-item {
display: flex;
justify-content: space-between;
margin: 5px 0;
}
.info-label {
color: #718096;
}
.info-value {
color: #2d3748;
font-weight: 500;
}
</style>
</head>
<body>
<div class="main-container">
<h1>🤖 KNN 图像分类器(基于特征标签)</h1>
<div class="grid-container">
<!-- 数据采集卡片 -->
<div class="card">
<h2>📸 数据采集</h2>
<div class="class-input">
<h3><span class="class-number">1</span> 第一类</h3>
<input type="text" id="class1Name" placeholder="输入类别名称(如:猫)" value="类别1">
<label class="file-label" for="class1Images">
选择图片
</label>
<input type="file" id="class1Images" multiple accept="image/*">
<span class="samples-count" id="class1Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(0)">从摄像头采集</button>
<div class="image-preview" id="class1Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">2</span> 第二类</h3>
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
<label class="file-label" for="class2Images">
选择图片
</label>
<input type="file" id="class2Images" multiple accept="image/*">
<span class="samples-count" id="class2Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(1)">从摄像头采集</button>
<div class="image-preview" id="class2Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">3</span> 第三类(可选)</h3>
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="类别3">
<label class="file-label" for="class3Images">
选择图片
</label>
<input type="file" id="class3Images" multiple accept="image/*">
<span class="samples-count" id="class3Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(2)">从摄像头采集</button>
<div class="image-preview" id="class3Preview"></div>
</div>
<div class="button-group">
<button class="btn btn-success" id="addDataBtn">训练KNN模型</button>
<button class="btn btn-danger" id="clearDataBtn">清空数据</button>
</div>
<div id="dataStatus"></div>
</div>
<!-- KNN模型信息卡片 -->
<div class="card">
<h2>🎯 KNN 模型设置</h2>
<div class="k-selector">
<label>
K值最近邻数量
<span class="k-value-display" id="kValueDisplay">3</span>
</label>
<input type="range" id="kValue" min="1" max="20" value="3"
oninput="document.getElementById('kValueDisplay').textContent = this.value">
<small style="color: #718096;">K值越大预测越保守K值越小对局部特征越敏感</small>
</div>
<div class="k-selector">
<label>
滤波器系数 (α)
<span class="k-value-display" id="filterAlphaDisplay">0.3</span>
</label>
<input type="range" id="filterAlpha" min="0.05" max="1.0" step="0.05" value="0.3"
oninput="document.getElementById('filterAlphaDisplay').textContent = this.value">
<small style="color: #718096;">低通滤波器系数:值越小输出越平滑(0.1-0.3推荐),值越大响应越快</small>
</div>
<div class="k-selector">
<label>
距离阈值 (Distance Threshold)
<span class="k-value-display" id="distanceThresholdDisplay">0.5</span>
</label>
<input type="range" id="distanceThreshold" min="0.1" max="2.0" step="0.05" value="0.5"
oninput="document.getElementById('distanceThresholdDisplay').textContent = this.value">
<small style="color: #718096;">距离阈值:样本与训练数据的最大距离,超过此值判定为"未知/背景"(单品类检测关键参数)</small>
</div>
<div class="top-tags" id="topTags">
<h3 style="margin-bottom: 10px;">📊 特征标签提取预览</h3>
<div id="tagsList">等待数据...</div>
</div>
<div class="model-info">
<h3 style="margin-bottom: 10px;"> 模型信息</h3>
<div class="info-item">
<span class="info-label">预训练模型:</span>
<span class="info-value">MobileNet v2</span>
</div>
<div class="info-item">
<span class="info-label">特征维度:</span>
<span class="info-value">1280维嵌入向量</span>
</div>
<div class="info-item">
<span class="info-label">分类器类型:</span>
<span class="info-value">K-最近邻 (KNN)</span>
</div>
<div class="info-item">
<span class="info-label">总样本数:</span>
<span class="info-value" id="totalSamples">0</span>
</div>
</div>
</div>
</div>
<!-- 预测卡片 -->
<div class="card full-width">
<h2>📹 实时预测</h2>
<div class="button-group">
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
<button class="btn btn-success" id="saveModelBtn">保存模型</button>
<button class="btn btn-primary" id="loadModelBtn">加载模型</button>
</div>
<div id="webcam-container">
<video id="webcam" autoplay playsinline muted></video>
</div>
<div class="prediction-results" id="predictionResults">
<h3>预测结果</h3>
<div id="predictions">等待预测...</div>
</div>
<div id="predictionStatus"></div>
</div>
</div>
<script src="knn-classifier.js"></script>
</body>
</html>

985
完善KNN/knn-classifier.js Normal file
View File

@ -0,0 +1,985 @@
// KNN 图像分类器 - 基于MobileNet特征标签
class KNNImageClassifier {
constructor() {
this.mobilenet = null;
this.knnClassifier = null;
this.classNames = [];
this.webcamStream = null;
this.isPredicting = false;
this.currentCaptureClass = -1;
this.imagenetClasses = null;
// 低通滤波器状态
this.filteredConfidences = {};
this.filterAlpha = 0.3; // 滤波器系数 (0-1),越小越平滑
// 距离阈值设置
this.useDistanceThreshold = true;
this.distanceThreshold = 0.5; // 默认距离阈值(归一化后的特征)
this.adaptiveThreshold = null; // 自适应阈值
this.init();
}
async init() {
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
try {
// 加载 MobileNet 模型
this.mobilenet = await mobilenet.load({
version: 2,
alpha: 1.0
});
// 创建 KNN 分类器
this.knnClassifier = knnClassifier.create();
// 加载 ImageNet 类别名称
await this.loadImageNetClasses();
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
this.setupEventListeners();
} catch (error) {
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
}
}
async loadImageNetClasses() {
// ImageNet 前10个类别名称简化版
this.imagenetClasses = [
'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
];
}
setupEventListeners() {
// 文件上传监听
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
document.getElementById(id).addEventListener('change', (e) => {
this.handleImageUpload(e, index);
});
});
// 按钮监听
document.getElementById('addDataBtn').addEventListener('click', () => this.trainKNN());
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
document.getElementById('saveModelBtn').addEventListener('click', () => this.saveModel());
document.getElementById('loadModelBtn').addEventListener('click', () => this.loadModel());
}
handleImageUpload(event, classIndex) {
const files = event.target.files;
const countElement = document.getElementById(`class${classIndex + 1}Count`);
const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
countElement.textContent = `${files.length} 张图片`;
// 清空之前的预览
previewContainer.innerHTML = '';
// 添加图片预览
Array.from(files).forEach(file => {
const reader = new FileReader();
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.className = 'preview-img';
previewContainer.appendChild(img);
};
reader.readAsDataURL(file);
});
}
// 从图像提取 MobileNet 标签和权重
async extractImageNetTags(img) {
try {
// 获取 MobileNet 的预测1000个类别的概率
const predictions = await this.mobilenet.classify(img);
// 获取用于KNN的特征使用嵌入层获得更好的特征表示
const rawEmbeddings = this.mobilenet.infer(img, true); // true = 使用嵌入层获取1280维特征
// L2归一化特征向量重要使距离计算更稳定
const embeddings = tf.tidy(() => {
const norm = tf.norm(rawEmbeddings);
const normalized = tf.div(rawEmbeddings, norm);
rawEmbeddings.dispose(); // 清理原始嵌入
return normalized;
});
// 获取用于显示的logits1000个类别
const logits = this.mobilenet.infer(img, false); // false = 获取原始1000维输出
// 获取前10个最高概率的标签
const topK = await this.getTopKTags(logits, 10);
// 清理logits只用于显示
logits.dispose();
return {
logits: embeddings, // 使用1280维嵌入特征用于KNN
predictions: predictions, // 前3个预测
topTags: topK // 前10个标签和权重
};
} catch (error) {
console.error('特征提取失败:', error);
throw error;
}
}
// 获取Top-K标签
async getTopKTags(logits, k = 10) {
const values = await logits.data();
const valuesAndIndices = [];
for (let i = 0; i < values.length; i++) {
valuesAndIndices.push({ value: values[i], index: i });
}
valuesAndIndices.sort((a, b) => b.value - a.value);
const topkValues = new Float32Array(k);
const topkIndices = new Int32Array(k);
for (let i = 0; i < k; i++) {
topkValues[i] = valuesAndIndices[i].value;
topkIndices[i] = valuesAndIndices[i].index;
}
const topTags = [];
for (let i = 0; i < k; i++) {
topTags.push({
className: this.imagenetClasses[i] || `class_${topkIndices[i]}`,
probability: this.softmax(topkValues)[i],
logit: topkValues[i]
});
}
return topTags;
}
// Softmax 函数
softmax(arr) {
const maxLogit = Math.max(...arr);
const scores = arr.map(l => Math.exp(l - maxLogit));
const sum = scores.reduce((a, b) => a + b);
return scores.map(s => s / sum);
}
// 训练 KNN 模型
async trainKNN() {
const classes = [];
const imageFiles = [];
// 收集所有类别和图片
for (let i = 1; i <= 3; i++) {
const className = document.getElementById(`class${i}Name`).value.trim();
const files = document.getElementById(`class${i}Images`).files;
if (className && files && files.length > 0) {
classes.push(className);
imageFiles.push(files);
console.log(`类别 ${i}: "${className}" - ${files.length} 张图片`);
}
}
console.log('收集到的类别:', classes);
console.log('类别数量:', classes.length);
// 支持单品类检测One-Class Classification
if (classes.length < 1) {
this.showStatus('dataStatus', 'error', '请至少添加一个类别的图片!');
return;
}
// 如果只有一个类别,提示用户这是单品类检测模式
if (classes.length === 1) {
console.log('📍 单品类检测模式:只检测 "' + classes[0] + '",其他都视为背景/未知');
}
this.classNames = classes;
this.filteredConfidences = {}; // 重置滤波器状态
this.showStatus('dataStatus', 'info', '正在处理图片并训练KNN模型...');
// 清空现有的KNN分类器
this.knnClassifier.clearAllClasses();
let totalProcessed = 0;
let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
// 处理每个类别的图片
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
const files = imageFiles[classIndex];
console.log(`处理类别 ${classes[classIndex]}...`);
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
try {
const img = await this.loadImage(files[fileIndex]);
// 提取特征标签
const features = await this.extractImageNetTags(img);
// 添加到KNN分类器
// 使用完整的logits作为特征向量
this.knnClassifier.addExample(features.logits, classIndex);
totalProcessed++;
const progress = Math.round((totalProcessed / totalImages) * 100);
this.showStatus('dataStatus', 'info',
`处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
// 显示提取的标签
if (fileIndex === 0) {
this.displayTopTags(features.topTags);
}
// 清理
img.remove();
features.logits.dispose();
} catch (error) {
console.error('处理图片失败:', error);
}
}
}
// 更新模型信息
document.getElementById('totalSamples').textContent = totalProcessed;
// 根据类别数量显示不同的消息
let statusMessage;
if (classes.length === 1) {
statusMessage = `单品类检测模型训练完成!将只检测 "${classes[0]}",共 ${totalProcessed} 个样本`;
} else {
statusMessage = `KNN模型训练完成${totalProcessed} 个样本,${classes.length} 个类别`;
}
this.showStatus('dataStatus', 'success', statusMessage);
console.log('KNN分类器状态:', this.knnClassifier.getNumClasses(), '个类别');
if (classes.length === 1) {
console.log('📍 单品类检测模式已启用,将基于距离阈值判断是否为:', classes[0]);
// 计算自适应阈值
await this.calculateAdaptiveThreshold();
}
}
// 计算自适应阈值(基于训练数据的内部距离)
async calculateAdaptiveThreshold() {
if (this.knnClassifier.getNumClasses() !== 1) return;
console.log('计算自适应阈值...');
const dataset = this.knnClassifier.getClassifierDataset();
if (!dataset || !dataset[0]) return;
const trainData = await dataset[0].data();
const numSamples = dataset[0].shape[0];
const featureDim = dataset[0].shape[1];
// 计算训练样本之间的平均距离
let totalDistance = 0;
let count = 0;
for (let i = 0; i < Math.min(numSamples, 20); i++) { // 限制计算量
for (let j = i + 1; j < Math.min(numSamples, 20); j++) {
let distance = 0;
for (let k = 0; k < featureDim; k++) {
const diff = trainData[i * featureDim + k] - trainData[j * featureDim + k];
distance += diff * diff;
}
distance = Math.sqrt(distance);
totalDistance += distance;
count++;
}
}
if (count > 0) {
const avgInternalDistance = totalDistance / count;
// 自适应阈值设为内部平均距离的1.3-1.5倍(归一化后距离较小)
this.adaptiveThreshold = avgInternalDistance * 1.3;
console.log(`内部平均距离: ${avgInternalDistance.toFixed(2)}`);
console.log(`建议自适应阈值: ${this.adaptiveThreshold.toFixed(2)}`);
// 更新UI显示建议阈值
const thresholdInput = document.getElementById('distanceThreshold');
const thresholdDisplay = document.getElementById('distanceThresholdDisplay');
if (thresholdInput && thresholdDisplay) {
thresholdInput.value = this.adaptiveThreshold.toFixed(1);
thresholdDisplay.textContent = this.adaptiveThreshold.toFixed(1);
}
this.showStatus('dataStatus', 'info',
`自适应阈值已计算: ${this.adaptiveThreshold.toFixed(1)} (基于训练数据内部距离)`);
}
}
// 显示提取的标签
displayTopTags(tags) {
const container = document.getElementById('tagsList');
let html = '';
tags.slice(0, 5).forEach(tag => {
html += `
<span class="tag-item">
${tag.className}
<span class="tag-weight">${(tag.probability * 100).toFixed(1)}%</span>
</span>
`;
});
container.innerHTML = html;
}
// 加载图片
async loadImage(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (e) => {
const img = new Image();
img.onload = () => resolve(img);
img.onerror = reject;
img.src = e.target.result;
};
reader.onerror = reject;
reader.readAsDataURL(file);
});
}
// 清空数据集
clearDataset() {
this.knnClassifier.clearAllClasses();
this.classNames = [];
this.filteredConfidences = {}; // 重置滤波器状态
console.log('数据集已清空,滤波器状态已重置');
for (let i = 1; i <= 3; i++) {
document.getElementById(`class${i}Images`).value = '';
document.getElementById(`class${i}Count`).textContent = '0 张图片';
document.getElementById(`class${i}Preview`).innerHTML = ''; // 清空预览
}
document.getElementById('totalSamples').textContent = '0';
document.getElementById('tagsList').innerHTML = '等待数据...';
document.getElementById('predictions').innerHTML = '等待预测...';
this.showStatus('dataStatus', 'info', '数据集已清空');
}
// 启动摄像头
async startWebcam() {
if (this.knnClassifier.getNumClasses() === 0) {
this.showStatus('predictionStatus', 'error', '请先训练模型!');
return;
}
const video = document.getElementById('webcam');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
});
video.srcObject = stream;
this.webcamStream = stream;
document.getElementById('startWebcamBtn').disabled = true;
document.getElementById('stopWebcamBtn').disabled = false;
// 等待视频加载
video.addEventListener('loadeddata', () => {
this.isPredicting = true;
this.predictLoop();
});
this.showStatus('predictionStatus', 'success', '摄像头已启动');
} catch (error) {
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
}
}
// 停止摄像头
stopWebcam() {
if (this.webcamStream) {
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
}
this.isPredicting = false;
this.filteredConfidences = {}; // 重置滤波器状态
const video = document.getElementById('webcam');
video.srcObject = null;
document.getElementById('startWebcamBtn').disabled = false;
document.getElementById('stopWebcamBtn').disabled = true;
this.showStatus('predictionStatus', 'info', '摄像头已停止');
}
// 预测循环
async predictLoop() {
if (!this.isPredicting) return;
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
console.log('开始预测,类别数量:', this.classNames.length);
// 提取特征
const features = await this.extractImageNetTags(video);
// 使用原始KNN进行预测包含距离信息
const k = parseInt(document.getElementById('kValue').value);
const predictionWithDistance = await this.predictWithDistance(features.logits, k);
console.log('预测结果:', predictionWithDistance);
// 检查是否为未知类(基于距离阈值)
let finalPrediction;
// 单品类模式特殊处理
if (predictionWithDistance.isSingleClass) {
// 直接使用 predictWithDistance 返回的结果,它已经处理了阈值判断
finalPrediction = {
label: predictionWithDistance.label,
confidences: predictionWithDistance.confidences,
isUnknown: predictionWithDistance.label === -1,
minDistance: predictionWithDistance.minDistance,
isSingleClass: true
};
} else {
// 多类别模式直接使用KNN预测结果不使用距离阈值
finalPrediction = {
label: predictionWithDistance.label,
confidences: predictionWithDistance.confidences,
isUnknown: false,
minDistance: predictionWithDistance.minDistance,
isSingleClass: false
};
}
// 应用低通滤波器
const smoothedPrediction = this.applyLowPassFilter(finalPrediction);
// 显示预测结果
this.displayPrediction(smoothedPrediction);
// 显示提取的标签
this.displayTopTags(features.topTags);
// 清理张量
features.logits.dispose();
} catch (error) {
console.error('预测错误:', error);
}
}
// 继续预测循环
requestAnimationFrame(() => this.predictLoop());
}
// 使用距离信息进行预测
async predictWithDistance(logits, k) {
// 如果没有训练数据,返回空结果
if (this.knnClassifier.getNumClasses() === 0) {
return {
label: -1,
confidences: {},
minDistance: Infinity,
isSingleClass: false
};
}
const numClasses = this.knnClassifier.getNumClasses();
// 单品类检测模式 - 使用实际距离计算
if (numClasses === 1) {
console.log('单品类检测模式 - 计算实际距离');
// 获取训练数据
const dataset = this.knnClassifier.getClassifierDataset();
if (!dataset || !dataset[0]) {
return {
label: -1,
confidences: { 0: 0 },
minDistance: Infinity,
isSingleClass: true
};
}
// 计算输入样本与所有训练样本的欧氏距离
const inputData = await logits.data();
const trainData = await dataset[0].data();
const numSamples = dataset[0].shape[0];
const featureDim = dataset[0].shape[1];
console.log(`输入特征维度: ${inputData.length}, 训练数据维度: ${featureDim}, 样本数: ${numSamples}`);
// 确保维度匹配
if (inputData.length !== featureDim) {
console.error(`维度不匹配!输入: ${inputData.length}, 训练: ${featureDim}`);
return {
label: -1,
confidences: { 0: 0 },
minDistance: Infinity,
isSingleClass: true
};
}
let minDistance = Infinity;
const distances = [];
// 计算与每个训练样本的距离
for (let i = 0; i < numSamples; i++) {
let distance = 0;
for (let j = 0; j < featureDim; j++) {
const diff = inputData[j] - trainData[i * featureDim + j];
distance += diff * diff;
}
distance = Math.sqrt(distance);
distances.push(distance);
if (distance < minDistance) {
minDistance = distance;
}
}
console.log(`计算了 ${distances.length} 个距离,最小距离: ${minDistance.toFixed(2)}`);
console.log(`前5个距离: ${distances.slice(0, 5).map(d => d.toFixed(2)).join(', ')}`);
// 获取K个最近邻的平均距离
distances.sort((a, b) => a - b);
const kNearest = distances.slice(0, Math.min(k, distances.length));
const avgDistance = kNearest.reduce((sum, d) => sum + d, 0) / kNearest.length;
// 从UI获取距离阈值
const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '15.0');
// 基于距离阈值判断是否属于该类
const belongsToClass = avgDistance <= threshold;
// 二值化置信度在阈值内100%超出阈值0%
const confidence = belongsToClass ? 1.0 : 0;
console.log(`单品类预测 - 平均距离: ${avgDistance.toFixed(2)}, 阈值: ${threshold}, 属于类别: ${belongsToClass}, 置信度: ${confidence.toFixed(3)}`);
return {
label: belongsToClass ? 0 : -1,
confidences: { 0: confidence },
minDistance: avgDistance,
isSingleClass: true
};
}
// 多品类模式使用KNN分类器预测
try {
const prediction = await this.knnClassifier.predictClass(logits, k);
console.log('多品类预测结果:', prediction);
console.log('预测标签:', prediction.label);
console.log('置信度:', prediction.confidences);
// 确保confidences存在且格式正确
let confidences = {};
if (prediction.confidences) {
// 检查是否是对象格式
if (typeof prediction.confidences === 'object') {
confidences = prediction.confidences;
}
}
// 如果confidences为空手动计算
if (Object.keys(confidences).length === 0) {
console.warn('置信度为空,使用默认值');
for (let i = 0; i < this.classNames.length; i++) {
confidences[i] = i === prediction.label ? 1.0 : 0;
}
}
// 计算实际距离(可选)
let minDistance = 0.5; // 默认距离
return {
label: prediction.label,
confidences: confidences,
minDistance: minDistance,
isSingleClass: false
};
} catch (error) {
console.error('预测错误:', error);
return {
label: -1,
confidences: {},
minDistance: Infinity,
isSingleClass: false
};
}
}
// 应用低通滤波器到置信度
applyLowPassFilter(prediction) {
// 获取滤波器系数
const alpha = parseFloat(document.getElementById('filterAlpha').value);
// 保留原始的特殊字段
const isSingleClass = prediction.isSingleClass || false;
const minDistance = prediction.minDistance;
const isUnknown = prediction.isUnknown;
// 初始化滤波状态(如果是第一次)
if (Object.keys(this.filteredConfidences).length === 0) {
for (let i = 0; i < this.classNames.length; i++) {
this.filteredConfidences[i] = prediction.confidences[i] || 0;
}
return {
label: prediction.label,
confidences: {...this.filteredConfidences},
isSingleClass: isSingleClass,
minDistance: minDistance,
isUnknown: isUnknown
};
}
// 单品类模式下,使用二值化输出,不应用滤波
if (isSingleClass) {
// 获取距离阈值
const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5');
// 二值化判断在阈值内为1超出为0
const inThreshold = minDistance <= threshold;
const confidence = inThreshold ? 1.0 : 0;
// 直接更新,不滤波
this.filteredConfidences[0] = confidence;
return {
label: inThreshold ? 0 : -1,
confidences: { 0: confidence },
isSingleClass: isSingleClass,
minDistance: minDistance,
isUnknown: !inThreshold
};
}
// 应用指数移动平均EMA低通滤波
const newConfidences = {};
for (let i = 0; i < this.classNames.length; i++) {
const currentValue = prediction.confidences[i] || 0;
const previousValue = this.filteredConfidences[i] || 0;
// EMA公式: y[n] = α * x[n] + (1 - α) * y[n-1]
this.filteredConfidences[i] = alpha * currentValue + (1 - alpha) * previousValue;
newConfidences[i] = this.filteredConfidences[i];
}
// 归一化确保总和为1
let sum = 0;
Object.values(newConfidences).forEach(v => sum += v);
if (sum > 0) {
Object.keys(newConfidences).forEach(key => {
newConfidences[key] = newConfidences[key] / sum;
});
}
// 找到最高置信度的类别
let maxConfidence = 0;
let bestLabel = 0;
Object.keys(newConfidences).forEach(key => {
if (newConfidences[key] > maxConfidence) {
maxConfidence = newConfidences[key];
bestLabel = parseInt(key);
}
});
return {
label: bestLabel,
confidences: newConfidences,
isSingleClass: isSingleClass,
minDistance: minDistance,
isUnknown: isUnknown
};
}
// 显示预测结果
displayPrediction(prediction) {
const container = document.getElementById('predictions');
let html = '';
// 单品类模式特殊处理
if (this.classNames.length === 1) {
const className = this.classNames[0];
const confidence = prediction.confidences[0] || 0;
const percentage = (confidence * 100).toFixed(1);
const isDetected = prediction.label === 0; // 是否检测到该类
// 获取距离阈值
const threshold = parseFloat(document.getElementById('distanceThreshold')?.value || '0.5');
const distance = prediction.minDistance || 0;
// 显示单品类检测结果(二值化显示)
html = `
<div class="prediction-item" style="${isDetected ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : 'border-left-color: #cbd5e0;'}">
<div class="prediction-header">
<span class="prediction-label">
${className} ${isDetected ? '✓ 检测到' : '✗ 未检测到'}
</span>
<span class="prediction-confidence" style="${isDetected ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : 'background: #cbd5e0;'}">
${isDetected ? '100%' : '0%'}
</span>
</div>
<div class="confidence-bar-container">
<div class="confidence-bar ${isDetected ? 'high' : 'low'}" style="width: ${isDetected ? '100' : '0'}%;">
${isDetected ? `<span class="confidence-percentage">100%</span>` : ''}
</div>
</div>
</div>
<div style="margin-top: 10px; padding: 10px; background: #f8f9fa; border-radius: 5px; font-size: 12px; color: #666;">
<strong>距离:</strong> ${distance.toFixed(2)} | <strong>:</strong> ${threshold.toFixed(2)}
<span style="margin-left: 10px; color: ${distance <= threshold ? '#48bb78' : '#f56565'};">
${distance <= threshold ? '✓ 在阈值范围内' : '✗ 超出阈值范围'}
</span>
</div>
`;
container.innerHTML = html;
return;
}
// 多品类模式:直接使用滤波后的置信度
const confidences = prediction.confidences;
const predictedClass = prediction.label;
// 固定顺序显示(按类别索引)
for (let i = 0; i < this.classNames.length; i++) {
const className = this.classNames[i];
const confidence = confidences[i] || 0;
const percentage = (confidence * 100).toFixed(1);
const isWinner = i === predictedClass;
// 根据置信度决定颜色等级
let barClass = '';
if (confidence > 0.7) barClass = 'high';
else if (confidence > 0.4) barClass = 'medium';
else barClass = 'low';
// 如果是获胜类别,使用绿色
if (isWinner) barClass = 'high';
html += `
<div class="prediction-item" style="${isWinner ? 'border-left-color: #48bb78; background: linear-gradient(to right, #f0fff4, white);' : ''}">
<div class="prediction-header">
<span class="prediction-label">
${className} ${isWinner ? '👑' : ''}
</span>
<span class="prediction-confidence" style="${isWinner ? 'background: linear-gradient(135deg, #48bb78, #38a169);' : ''}">
${percentage}%
</span>
</div>
<div class="confidence-bar-container">
<div class="confidence-bar ${barClass}" style="width: ${percentage}%;">
${confidence > 0.15 ? `<span class="confidence-percentage">${percentage}%</span>` : ''}
</div>
</div>
</div>
`;
}
container.innerHTML = html;
}
// 从摄像头捕获样本
async captureFromWebcam(classIndex) {
if (!this.webcamStream) {
// 临时启动摄像头
const video = document.getElementById('webcam');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
});
video.srcObject = stream;
this.webcamStream = stream;
// 等待视频加载
setTimeout(async () => {
await this.addWebcamSample(classIndex);
// 停止临时摄像头
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
video.srcObject = null;
}, 1000);
} catch (error) {
this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
}
} else {
await this.addWebcamSample(classIndex);
}
}
// 添加摄像头样本
async addWebcamSample(classIndex) {
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
// 提取特征
const features = await this.extractImageNetTags(video);
// 添加到KNN分类器
this.knnClassifier.addExample(features.logits, classIndex);
// 更新计数
const currentCount = this.knnClassifier.getClassExampleCount();
const count = currentCount[classIndex] || 0;
document.getElementById(`class${classIndex + 1}Count`).textContent = `${count} 个样本`;
// 清理
features.logits.dispose();
this.showStatus('dataStatus', 'success', `已添加样本到类别 ${classIndex + 1}`);
} catch (error) {
console.error('添加样本失败:', error);
}
}
}
// 保存模型
async saveModel() {
if (this.knnClassifier.getNumClasses() === 0) {
this.showStatus('predictionStatus', 'error', '没有可保存的模型');
return;
}
try {
// 获取KNN分类器的数据
const dataset = this.knnClassifier.getClassifierDataset();
const datasetObj = {};
Object.keys(dataset).forEach(key => {
const data = dataset[key].dataSync();
datasetObj[key] = Array.from(data);
});
// 获取特征维度
let featureDim = 1280; // 默认值
const firstKey = Object.keys(dataset)[0];
if (firstKey && dataset[firstKey]) {
featureDim = dataset[firstKey].shape[1];
}
// 保存为JSON
const modelData = {
dataset: datasetObj,
classNames: this.classNames,
k: document.getElementById('kValue').value,
featureDim: featureDim, // 保存特征维度
date: new Date().toISOString()
};
const blob = new Blob([JSON.stringify(modelData)], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'knn-model.json';
a.click();
URL.revokeObjectURL(url);
this.showStatus('predictionStatus', 'success', '模型已保存');
} catch (error) {
this.showStatus('predictionStatus', 'error', `保存失败: ${error.message}`);
}
}
// 加载模型
async loadModel() {
const input = document.createElement('input');
input.type = 'file';
input.accept = '.json';
input.onchange = async (e) => {
try {
const file = e.target.files[0];
const text = await file.text();
const modelData = JSON.parse(text);
// 清空现有分类器
this.knnClassifier.clearAllClasses();
// 恢复数据集
Object.keys(modelData.dataset).forEach(key => {
const data = modelData.dataset[key];
// 自动检测特征维度(兼容旧模型)
let featureDim = modelData.featureDim;
if (!featureDim) {
// 尝试常见的维度
const possibleDims = [1280, 1024, 1000];
for (const dim of possibleDims) {
if (data.length % dim === 0) {
featureDim = dim;
console.warn(`自动检测到特征维度: ${dim}`);
break;
}
}
}
if (!featureDim) {
console.error(`无法确定特征维度,数据长度: ${data.length}`);
return;
}
const numSamples = data.length / featureDim;
console.log(`加载类别 ${key}${numSamples} 个样本,${featureDim} 维特征`);
const tensor = tf.tensor(data, [numSamples, featureDim]);
this.knnClassifier.setClassifierDataset({ [key]: tensor });
});
this.classNames = modelData.classNames;
document.getElementById('kValue').value = modelData.k;
document.getElementById('kValueDisplay').textContent = modelData.k;
this.showStatus('predictionStatus', 'success',
`模型加载成功!类别: ${this.classNames.join(', ')}`);
} catch (error) {
this.showStatus('predictionStatus', 'error', `加载失败: ${error.message}`);
}
};
input.click();
}
// 显示状态
showStatus(elementId, type, message) {
const element = document.getElementById(elementId);
const classMap = {
'success': 'status-success',
'error': 'status-error',
'info': 'status-info'
};
element.className = `status-message ${classMap[type]}`;
element.textContent = message;
}
}
// 全局函数:从摄像头捕获
/**
 * Global entry point for inline `onclick` handlers: forwards the capture
 * request to the classifier instance stored on `window`, if there is one.
 *
 * @param {number} classIndex - Zero-based index of the target class.
 */
function captureFromWebcam(classIndex) {
    const instance = window.classifier;
    if (!instance) {
        return;
    }
    instance.captureFromWebcam(classIndex);
}
// 初始化应用
// Application bootstrap: once the DOM is ready, create the classifier and
// expose it on `window` for the inline event handlers.
let classifier;
document.addEventListener('DOMContentLoaded', function onDomReady() {
    classifier = new KNNImageClassifier();
    window.classifier = classifier;
});

View File

@ -0,0 +1,394 @@
var dt = (function () {
/**
* Creates an instance of DecisionTree
*
* @constructor
* @param builder - contains training set and
* some configuration parameters
*/
function DecisionTree(builder) {
this.root = buildDecisionTree({
trainingSet: builder.trainingSet,
ignoredAttributes: arrayToHashSet(builder.ignoredAttributes),
categoryAttr: builder.categoryAttr || 'category',
minItemsCount: builder.minItemsCount || 1,
entropyThrehold: builder.entropyThrehold || 0.01,
maxTreeDepth: builder.maxTreeDepth || 70
});
}
DecisionTree.prototype.predict = function (item) {
return predict(this.root, item);
}
/**
* Creates an instance of RandomForest
* with specific number of trees
*
* @constructor
* @param builder - contains training set and some
* configuration parameters for
* building decision trees
*/
function RandomForest(builder, treesNumber) {
this.trees = buildRandomForest(builder, treesNumber);
}
RandomForest.prototype.predict = function (item) {
return predictRandomForest(this.trees, item);
}
/**
* Transforming array to object with such attributes
* as elements of array (afterwards it can be used as HashSet)
*/
function arrayToHashSet(array) {
var hashSet = {};
if (array) {
for(var i in array) {
var attr = array[i];
hashSet[attr] = true;
}
}
return hashSet;
}
/**
* Calculating how many objects have the same
* values of specific attribute.
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*/
function countUniqueValues(items, attr) {
var counter = {};
// detecting different values of attribute
for (var i = items.length - 1; i >= 0; i--) {
// items[i][attr] - value of attribute
counter[items[i][attr]] = 0;
}
// counting number of occurrences of each of values
// of attribute
for (var i = items.length - 1; i >= 0; i--) {
counter[items[i][attr]] += 1;
}
return counter;
}
/**
* Calculating entropy of array of objects
* by specific attribute.
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*/
function entropy(items, attr) {
// counting number of occurrences of each of values
// of attribute
var counter = countUniqueValues(items, attr);
var entropy = 0;
var p;
for (var i in counter) {
p = counter[i] / items.length;
entropy += -p * Math.log(p);
}
return entropy;
}
/**
* Splitting array of objects by value of specific attribute,
* using specific predicate and pivot.
*
* Items which matched by predicate will be copied to
* the new array called 'match', and the rest of the items
* will be copied to array with name 'notMatch'
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*
* @param predicate - function(x, y)
* which returns 'true' or 'false'
*
* @param pivot - used as the second argument when
* calling predicate function:
* e.g. predicate(item[attr], pivot)
*/
function split(items, attr, predicate, pivot) {
var match = [];
var notMatch = [];
var item,
attrValue;
for (var i = items.length - 1; i >= 0; i--) {
item = items[i];
attrValue = item[attr];
if (predicate(attrValue, pivot)) {
match.push(item);
} else {
notMatch.push(item);
}
};
return {
match: match,
notMatch: notMatch
};
}
/**
* Finding value of specific attribute which is most frequent
* in given array of objects.
*
* @param items - array of objects
*
* @param attr - variable with name of attribute,
* which embedded in each object
*/
function mostFrequentValue(items, attr) {
// counting number of occurrences of each of values
// of attribute
var counter = countUniqueValues(items, attr);
var mostFrequentCount = 0;
var mostFrequentValue;
for (var value in counter) {
if (counter[value] > mostFrequentCount) {
mostFrequentCount = counter[value];
mostFrequentValue = value;
}
};
return mostFrequentValue;
}
var predicates = {
'==': function (a, b) { return a == b },
'>=': function (a, b) { return a >= b }
};
/**
* Function for building decision tree
*/
function buildDecisionTree(builder) {
var trainingSet = builder.trainingSet;
var minItemsCount = builder.minItemsCount;
var categoryAttr = builder.categoryAttr;
var entropyThrehold = builder.entropyThrehold;
var maxTreeDepth = builder.maxTreeDepth;
var ignoredAttributes = builder.ignoredAttributes;
if ((maxTreeDepth == 0) || (trainingSet.length <= minItemsCount)) {
// restriction by maximal depth of tree
// or size of training set is to small
// so we have to terminate process of building tree
return {
category: mostFrequentValue(trainingSet, categoryAttr)
};
}
var initialEntropy = entropy(trainingSet, categoryAttr);
if (initialEntropy <= entropyThrehold) {
// entropy of training set too small
// (it means that training set is almost homogeneous),
// so we have to terminate process of building tree
return {
category: mostFrequentValue(trainingSet, categoryAttr)
};
}
// used as hash-set for avoiding the checking of split by rules
// with the same 'attribute-predicate-pivot' more than once
var alreadyChecked = {};
// this variable expected to contain rule, which splits training set
// into subsets with smaller values of entropy (produces informational gain)
var bestSplit = {gain: 0};
for (var i = trainingSet.length - 1; i >= 0; i--) {
var item = trainingSet[i];
// iterating over all attributes of item
for (var attr in item) {
if ((attr == categoryAttr) || ignoredAttributes[attr]) {
continue;
}
// let the value of current attribute be the pivot
var pivot = item[attr];
// pick the predicate
// depending on the type of the attribute value
var predicateName;
if (typeof pivot == 'number') {
predicateName = '>=';
} else {
// there is no sense to compare non-numeric attributes
// so we will check only equality of such attributes
predicateName = '==';
}
var attrPredPivot = attr + predicateName + pivot;
if (alreadyChecked[attrPredPivot]) {
// skip such pairs of 'attribute-predicate-pivot',
// which been already checked
continue;
}
alreadyChecked[attrPredPivot] = true;
var predicate = predicates[predicateName];
// splitting training set by given 'attribute-predicate-value'
var currSplit = split(trainingSet, attr, predicate, pivot);
// calculating entropy of subsets
var matchEntropy = entropy(currSplit.match, categoryAttr);
var notMatchEntropy = entropy(currSplit.notMatch, categoryAttr);
// calculating informational gain
var newEntropy = 0;
newEntropy += matchEntropy * currSplit.match.length;
newEntropy += notMatchEntropy * currSplit.notMatch.length;
newEntropy /= trainingSet.length;
var currGain = initialEntropy - newEntropy;
if (currGain > bestSplit.gain) {
// remember pairs 'attribute-predicate-value'
// which provides informational gain
bestSplit = currSplit;
bestSplit.predicateName = predicateName;
bestSplit.predicate = predicate;
bestSplit.attribute = attr;
bestSplit.pivot = pivot;
bestSplit.gain = currGain;
}
}
}
if (!bestSplit.gain) {
// can't find optimal split
return { category: mostFrequentValue(trainingSet, categoryAttr) };
}
// building subtrees
builder.maxTreeDepth = maxTreeDepth - 1;
builder.trainingSet = bestSplit.match;
var matchSubTree = buildDecisionTree(builder);
builder.trainingSet = bestSplit.notMatch;
var notMatchSubTree = buildDecisionTree(builder);
return {
attribute: bestSplit.attribute,
predicate: bestSplit.predicate,
predicateName: bestSplit.predicateName,
pivot: bestSplit.pivot,
match: matchSubTree,
notMatch: notMatchSubTree,
matchedCount: bestSplit.match.length,
notMatchedCount: bestSplit.notMatch.length
};
}
/**
* Classifying item, using decision tree
*/
function predict(tree, item) {
var attr,
value,
predicate,
pivot;
// Traversing tree from the root to leaf
while(true) {
if (tree.category) {
// only leafs contains predicted category
return tree.category;
}
attr = tree.attribute;
value = item[attr];
predicate = tree.predicate;
pivot = tree.pivot;
// move to one of subtrees
if (predicate(value, pivot)) {
tree = tree.match;
} else {
tree = tree.notMatch;
}
}
}
/**
* Building array of decision trees
*/
function buildRandomForest(builder, treesNumber) {
var items = builder.trainingSet;
// creating training sets for each tree
var trainingSets = [];
for (var t = 0; t < treesNumber; t++) {
trainingSets[t] = [];
}
for (var i = items.length - 1; i >= 0 ; i--) {
// assigning items to training sets of each tree
// using 'round-robin' strategy
var correspondingTree = i % treesNumber;
trainingSets[correspondingTree].push(items[i]);
}
// building decision trees
var forest = [];
for (var t = 0; t < treesNumber; t++) {
builder.trainingSet = trainingSets[t];
var tree = new DecisionTree(builder);
forest.push(tree);
}
return forest;
}
/**
* Each of decision tree classifying item
* ('voting' that item corresponds to some class).
*
* This function returns hash, which contains
* all classifying results, and number of votes
* which were given for each of classifying results
*/
function predictRandomForest(forest, item) {
var result = {};
for (var i in forest) {
var tree = forest[i];
var prediction = tree.predict(item);
result[prediction] = result[prediction] ? result[prediction] + 1 : 1;
}
return result;
}
var exports = {};
exports.DecisionTree = DecisionTree;
exports.RandomForest = RandomForest;
return exports;
})();

View File

@ -0,0 +1,542 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>图像分类器 - TensorFlow.js & decision-tree.js</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
<!-- 引入 decision-tree.js -->
<script src="decision-tree.js"></script>
<style>
/* 页面样式:与其他分类器页面使用同一套样式 */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.main-container {
max-width: 1400px;
margin: 0 auto;
}
h1 {
color: white;
text-align: center;
margin-bottom: 30px;
font-size: 2.5em;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
}
.grid-container {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-bottom: 20px;
}
@media (max-width: 768px) {
.grid-container {
grid-template-columns: 1fr;
}
}
.card {
background: white;
border-radius: 15px;
padding: 25px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.class-input {
margin-bottom: 20px;
padding: 15px;
background: #f8f9fa;
border-radius: 10px;
}
.class-input h3 {
color: #555;
margin-bottom: 10px;
display: flex;
align-items: center;
gap: 10px;
}
.class-number {
background: #667eea;
color: white;
width: 25px;
height: 25px;
border-radius: 50%;
display: inline-flex;
align-items: center;
justify-content: center;
font-size: 14px;
}
input[type="text"] {
width: 100%;
padding: 10px;
border: 2px solid #e0e0e0;
border-radius: 5px;
margin-bottom: 10px;
font-size: 16px;
transition: border-color 0.3s;
}
input[type="text"]:focus {
outline: none;
border-color: #667eea;
}
input[type="file"] {
display: none;
}
.file-label {
display: inline-block;
padding: 10px 20px;
background: #667eea;
color: white;
border-radius: 5px;
cursor: pointer;
transition: background 0.3s;
margin-right: 10px;
}
.file-label:hover {
background: #5a67d8;
}
.btn {
padding: 12px 30px;
border: none;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
transition: all 0.3s;
margin: 5px;
}
.btn-primary {
background: #667eea;
color: white;
}
.btn-primary:hover {
background: #5a67d8;
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn-success {
background: #48bb78;
color: white;
}
.btn-success:hover {
background: #38a169;
}
.btn-danger {
background: #f56565;
color: white;
}
.btn-danger:hover {
background: #e53e3e;
}
.btn:disabled {
background: #cbd5e0;
cursor: not-allowed;
transform: none;
}
#webcam-container {
position: relative;
width: 100%;
max-width: 640px;
margin: 20px auto;
}
#webcam {
width: 100%;
border-radius: 10px;
background: #000;
}
.samples-count {
display: inline-block;
background: #edf2f7;
padding: 2px 8px;
border-radius: 10px;
font-size: 12px;
color: #4a5568;
margin-left: 5px;
}
.image-preview {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-top: 10px;
max-height: 150px;
overflow-y: auto;
}
.preview-img {
width: 60px;
height: 60px;
object-fit: cover;
border-radius: 5px;
border: 2px solid #e0e0e0;
}
.status-message {
padding: 15px;
border-radius: 5px;
margin: 10px 0;
text-align: center;
font-weight: 500;
}
.status-success {
background: #c6f6d5;
color: #22543d;
border: 1px solid #9ae6b4;
}
.status-error {
background: #fed7d7;
color: #742a2a;
border: 1px solid #fc8181;
}
.status-info {
background: #bee3f8;
color: #2c5282;
border: 1px solid #90cdf4;
}
.button-group {
display: flex;
gap: 10px;
margin: 20px 0;
flex-wrap: wrap;
}
.full-width {
grid-column: 1 / -1;
}
.prediction-results {
margin-top: 20px;
padding: 20px;
background: #f7fafc;
border-radius: 10px;
}
.prediction-item {
padding: 15px;
margin: 10px 0;
background: white;
border-radius: 8px;
border-left: 4px solid #667eea;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.prediction-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 8px;
}
.prediction-label {
font-weight: 600;
color: #2d3748;
font-size: 16px;
}
.prediction-confidence {
background: linear-gradient(135deg, #667eea, #764ba2);
color: white;
padding: 4px 12px;
border-radius: 20px;
font-size: 14px;
font-weight: 500;
min-width: 60px;
text-align: center;
}
.confidence-bar-container {
width: 100%;
height: 24px;
background: #e2e8f0;
border-radius: 12px;
overflow: hidden;
position: relative;
}
.confidence-bar {
height: 100%;
background: linear-gradient(90deg, #667eea, #764ba2);
border-radius: 12px;
transition: width 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
min-width: 0;
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
}
.confidence-bar::after {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.3), transparent);
animation: shimmer 2s infinite;
}
@keyframes shimmer {
0% {
transform: translateX(-100%);
}
100% {
transform: translateX(100%);
}
}
.confidence-bar.high {
background: linear-gradient(90deg, #48bb78, #38a169);
}
.confidence-bar.medium {
background: linear-gradient(90deg, #ed8936, #dd6b20);
}
.confidence-bar.low {
background: linear-gradient(90deg, #f56565, #e53e3e);
}
.confidence-percentage {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
color: white;
font-weight: 600;
font-size: 12px;
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
z-index: 1;
}
.top-tags {
margin: 20px 0;
padding: 15px;
background: #edf2fe;
border-radius: 10px;
}
.tag-item {
display: inline-block;
background: white;
padding: 5px 12px;
margin: 5px;
border-radius: 15px;
font-size: 14px;
border: 1px solid #cbd5e0;
}
.tag-weight {
color: #667eea;
font-weight: bold;
margin-left: 5px;
}
.k-selector {
margin: 15px 0;
padding: 15px;
background: #f8f9fa;
border-radius: 8px;
}
.k-selector label {
display: block;
margin-bottom: 10px;
color: #555;
font-weight: 500;
}
.k-value-display {
display: inline-block;
background: #667eea;
color: white;
padding: 2px 8px;
border-radius: 5px;
margin-left: 10px;
}
input[type="range"] {
width: 100%;
margin: 10px 0;
}
.model-info {
margin-top: 20px;
padding: 15px;
background: #f0f4f8;
border-radius: 8px;
font-size: 14px;
}
.info-item {
display: flex;
justify-content: space-between;
margin: 5px 0;
}
.info-label {
color: #718096;
}
.info-value {
color: #2d3748;
font-weight: 500;
}
</style>
</head>
<body>
<div class="main-container">
<h1>🤖 图像分类器 - 随机森林</h1>
<div class="grid-container">
<!-- 数据采集卡片 -->
<div class="card">
<h2>📸 数据采集</h2>
<div class="class-input">
<h3><span class="class-number">1</span> 第一类</h3>
<input type="text" id="class1Name" placeholder="输入类别名称(如:猫)" value="类别1">
<label class="file-label" for="class1Images">
选择图片
</label>
<input type="file" id="class1Images" multiple accept="image/*">
<span class="samples-count" id="class1Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(0)">从摄像头采集</button>
<div class="image-preview" id="class1Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">2</span> 第二类</h3>
<input type="text" id="class2Name" placeholder="输入类别名称(如:狗)" value="类别2">
<label class="file-label" for="class2Images">
选择图片
</label>
<input type="file" id="class2Images" multiple accept="image/*">
<span class="samples-count" id="class2Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(1)">从摄像头采集</button>
<div class="image-preview" id="class2Preview"></div>
</div>
<div class="class-input">
<h3><span class="class-number">3</span> 第三类(可选)</h3>
<input type="text" id="class3Name" placeholder="输入类别名称(可选)" value="">
<label class="file-label" for="class3Images">
选择图片
</label>
<input type="file" id="class3Images" multiple accept="image/*">
<span class="samples-count" id="class3Count">0 张图片</span>
<button class="btn btn-primary" onclick="captureFromWebcam(2)">从摄像头采集</button>
<div class="image-preview" id="class3Preview"></div>
</div>
<div class="button-group">
<button class="btn btn-success" id="addDataBtn">训练模型</button>
<button class="btn btn-danger" id="clearDataBtn">清空数据</button>
</div>
</div>
<!--参数调整-->
<div class="card">
<h2>模型参数调整</h2>
<div class="k-selector">
<label>
树的数量
<span class="k-value-display" id="numTreesDisplay">10</span>
</label>
<input type="range" id="numTrees" min="5" max="50" value="10"
oninput="document.getElementById('numTreesDisplay').textContent = this.value">
<small style="color: #718096;">随机森林中决策树的数量</small>
</div>
<div class="k-selector">
<label>
子集大小 (比例)
<span class="k-value-display" id="subsetSizeDisplay">0.7</span>
</label>
<input type="range" id="subsetSize" min="0.1" max="1" step="0.1" value="0.7"
oninput="document.getElementById('subsetSizeDisplay').textContent = this.value">
<small style="color: #718096;">用于训练每棵树的数据子集占比 (0.1-1.0)</small>
</div>
<div id="dataStatus"></div>
<div class="model-info">
<h3 style="margin-bottom: 10px;"> 模型信息</h3>
<div class="info-item">
<span class="info-label">预训练模型:</span>
<span class="info-value">MobileNet v2</span>
</div>
<div class="info-item">
<span class="info-label">分类器类型:</span>
<span class="info-value">随机森林</span>
</div>
<div class="info-item">
<span class="info-label">总样本数:</span>
<span class="info-value" id="totalSamples">0</span>
</div>
</div>
</div>
</div>
<!-- 预测卡片 -->
<div class="card full-width">
<h2>📹 实时预测</h2>
<div class="button-group">
<button class="btn btn-primary" id="startWebcamBtn">启动摄像头</button>
<button class="btn btn-danger" id="stopWebcamBtn" disabled>停止摄像头</button>
</div>
<div id="webcam-container">
<video id="webcam" autoplay playsinline muted></video>
</div>
<div class="prediction-results" id="predictionResults">
<h3>预测结果</h3>
<div id="predictions">等待预测...</div>
</div>
<div id="predictionStatus"></div>
</div>
</div>
<script src="rf-classifier.js"></script>
</body>
</html>

View File

@ -0,0 +1,472 @@
// 图像分类器 - 基于MobileNet特征标签和 decision-tree.js 实现随机森林
class ImageClassifier {
constructor() {
this.mobilenet = null;
this.randomForest = []; // 存储多个决策树
this.classNames = [];
this.webcamStream = null;
this.isPredicting = false;
this.imagenetClasses = null;
this.trainingSet = [];
this.numTrees = 10; // 随机森林中决策树的数量,可调整
this.subsetSize = 0.7; // 训练集子集大小, 可调整
this.init();
}
async init() {
this.showStatus('dataStatus', 'info', '正在加载 MobileNet 模型...');
try {
this.mobilenet = await mobilenet.load({
version: 2,
alpha: 1.0
});
await this.loadImageNetClasses();
this.showStatus('dataStatus', 'success', 'MobileNet 模型加载完成!');
this.setupEventListeners();
} catch (error) {
this.showStatus('dataStatus', 'error', `加载失败: ${error.message}`);
}
}
async loadImageNetClasses() {
// ImageNet 前10个类别名称简化版
this.imagenetClasses = [
'tench', 'goldfish', 'shark', 'tiger_shark', 'hammerhead',
'electric_ray', 'stingray', 'cock', 'hen', 'ostrich',
'brambling', 'goldfinch', 'house_finch', 'junco', 'indigo_bunting',
'robin', 'bulbul', 'jay', 'magpie', 'chickadee'
];
}
setupEventListeners() {
// 文件上传监听
['class1Images', 'class2Images', 'class3Images'].forEach((id, index) => {
document.getElementById(id).addEventListener('change', (e) => {
this.handleImageUpload(e, index);
});
});
// 按钮监听
document.getElementById('addDataBtn').addEventListener('click', () => this.trainModel());
document.getElementById('clearDataBtn').addEventListener('click', () => this.clearDataset());
document.getElementById('startWebcamBtn').addEventListener('click', () => this.startWebcam());
document.getElementById('stopWebcamBtn').addEventListener('click', () => this.stopWebcam());
// 参数监听
document.getElementById('numTrees').addEventListener('input', (e) => {
this.numTrees = parseInt(e.target.value);
});
document.getElementById('subsetSize').addEventListener('input', (e) => {
this.subsetSize = parseFloat(e.target.value);
});
}
handleImageUpload(event, classIndex) {
const files = event.target.files;
const countElement = document.getElementById(`class${classIndex + 1}Count`);
const previewContainer = document.getElementById(`class${classIndex + 1}Preview`);
countElement.textContent = `${files.length} 张图片`;
// 清空之前的预览
previewContainer.innerHTML = '';
// 添加图片预览
Array.from(files).forEach(file => {
const reader = new FileReader();
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.className = 'preview-img';
previewContainer.appendChild(img);
};
reader.readAsDataURL(file);
});
}
async extractImageNetTags(img) {
try {
const predictions = await this.mobilenet.classify(img);
const logits = this.mobilenet.infer(img, false);
return {
logits: logits, // 1000维特征向量
predictions: predictions, // 前3个预测
topTags: await this.getTopKTags(logits, 10) // 前10个标签和权重
};
} catch (error) {
console.error('特征提取失败:', error);
throw error;
}
}
async getTopKTags(logits, k = 10) {
const values = await logits.data();
const valuesAndIndices = [];
for (let i = 0; i < values.length; i++) {
valuesAndIndices.push({ value: values[i], index: i });
}
valuesAndIndices.sort((a, b) => b.value - a.value);
const topkValues = new Float32Array(k);
const topkIndices = new Int32Array(k);
for (let i = 0; i < k; i++) {
topkValues[i] = valuesAndIndices[i].value;
topkIndices[i] = valuesAndIndices[i].index;
}
const topTags = [];
for (let i = 0; i < k; i++) {
topTags.push({
className: this.imagenetClasses[topkIndices[i]] || `class_${topkIndices[i]}`,
probability: this.softmax(topkValues)[i],
logit: topkValues[i]
});
}
return topTags;
}
softmax(arr) {
const maxLogit = Math.max(...arr);
const scores = arr.map(l => Math.exp(l - maxLogit));
const sum = scores.reduce((a, b) => a + b);
return scores.map(s => s / sum);
}
// 训练随机森林模型
async trainModel() {
const classes = [];
const imageFiles = [];
// 收集所有类别和图片
for (let i = 1; i <= 3; i++) {
const className = document.getElementById(`class${i}Name`).value.trim();
const files = document.getElementById(`class${i}Images`).files;
if (className && files.length > 0) {
classes.push(className);
imageFiles.push(files);
}
}
if (classes.length < 2) {
this.showStatus('dataStatus', 'error', '请至少添加两个类别的图片!');
return;
}
this.classNames = classes;
this.showStatus('dataStatus', 'info', '正在处理图片并训练模型...');
// 准备训练数据
this.trainingSet = [];
let totalProcessed = 0;
let totalImages = imageFiles.reduce((sum, files) => sum + files.length, 0);
// 处理每个类别的图片
for (let classIndex = 0; classIndex < classes.length; classIndex++) {
const files = imageFiles[classIndex];
for (let fileIndex = 0; fileIndex < files.length; fileIndex++) {
try {
const img = await this.loadImage(files[fileIndex]);
const features = await this.extractImageNetTags(img);
// 将logits从tf.Tensor转换为数组
const featureVector = await features.logits.data();
// 将特征向量添加到训练数据
const item = {};
Array.from(featureVector).forEach((value, index) => {
item[`feature_${index}`] = value;
});
item.category = classes[classIndex]; // 类别名称,而不是索引
this.trainingSet.push(item);
totalProcessed++;
const progress = Math.round((totalProcessed / totalImages) * 100);
this.showStatus('dataStatus', 'info',
`处理中... ${totalProcessed}/${totalImages} (${progress}%)`);
img.remove();
features.logits.dispose();
} catch (error) {
console.error('处理图片失败:', error);
}
}
}
try {
// 构建随机森林
this.randomForest = [];
for (let i = 0; i < this.numTrees; i++) {
// 创建具有随机子集的决策树
const trainingSubset = this.createTrainingSubset(this.trainingSet);
const builder = {
trainingSet: trainingSubset,
categoryAttr: 'category'
};
const tree = new dt.DecisionTree(builder);
this.randomForest.push(tree);
}
this.showStatus('dataStatus', 'success', `模型训练完成!共 ${totalProcessed} 个样本,${classes.length} 个类别, ${this.numTrees} 棵树`);
// 更新模型信息
document.getElementById('totalSamples').textContent = totalProcessed;
} catch (error) {
console.error('训练失败:', error);
this.showStatus('dataStatus', 'error', `训练失败: ${error.message}`);
}
}
// 创建训练数据的随机子集(用于随机森林)
createTrainingSubset(trainingSet) {
const subset = [];
const subsetSize = Math.floor(this.subsetSize * trainingSet.length);
for (let i = 0; i < subsetSize; i++) {
const randomIndex = Math.floor(Math.random() * trainingSet.length);
subset.push(trainingSet[randomIndex]);
}
return subset;
}
async loadImage(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (e) => {
const img = new Image();
img.onload = () => resolve(img);
img.onerror = reject;
img.src = e.target.result;
};
reader.onerror = reject;
reader.readAsDataURL(file);
});
}
clearDataset() {
this.randomForest = [];
this.classNames = [];
this.trainingSet = [];
for (let i = 1; i <= 3; i++) {
document.getElementById(`class${i}Images`).value = '';
document.getElementById(`class${i}Count`).textContent = '0 张图片';
document.getElementById(`class${i}Preview`).innerHTML = '';
}
document.getElementById('totalSamples').textContent = '0';
document.getElementById('predictions').innerHTML = '等待预测...';
this.showStatus('dataStatus', 'info', '数据集已清空');
}
startWebcam() {
if (this.randomForest.length == 0) {
this.showStatus('predictionStatus', 'error', '请先训练模型!');
return;
}
const video = document.getElementById('webcam');
navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
})
.then(stream => {
video.srcObject = stream;
this.webcamStream = stream;
document.getElementById('startWebcamBtn').disabled = true;
document.getElementById('stopWebcamBtn').disabled = false;
this.isPredicting = true;
this.predictLoop();
this.showStatus('predictionStatus', 'success', '摄像头已启动');
})
.catch(error => {
this.showStatus('predictionStatus', 'error', `无法访问摄像头: ${error.message}`);
});
}
stopWebcam() {
if (this.webcamStream) {
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
}
this.isPredicting = false;
const video = document.getElementById('webcam');
video.srcObject = null;
document.getElementById('startWebcamBtn').disabled = false;
document.getElementById('stopWebcamBtn').disabled = true;
this.showStatus('predictionStatus', 'info', '摄像头已停止');
}
async predictLoop() {
if (!this.isPredicting) return;
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
const features = await this.extractImageNetTags(video);
const featureVector = await features.logits.data();
const item = {};
Array.from(featureVector).forEach((value, index) => {
item[`feature_${index}`] = value;
});
const { predictedCategory, probabilities } = this.predictWithRandomForest(item);
this.displayPrediction(predictedCategory, probabilities);
features.logits.dispose();
} catch (error) {
console.error('预测错误:', error);
}
}
requestAnimationFrame(() => this.predictLoop());
}
predictWithRandomForest(item) {
const votes = {};
this.classNames.forEach(className => {
votes[className] = 0;
});
this.randomForest.forEach(tree => {
const prediction = tree.predict(item);
votes[prediction] = (votes[prediction] || 0) + 1;
});
let predictedCategory = null;
let maxVotes = 0;
for (const category in votes) {
if (votes[category] > maxVotes) {
predictedCategory = category;
maxVotes = votes[category];
}
}
const probabilities = {};
for (const category in votes) {
probabilities[category] = votes[category] / this.numTrees;
}
return {
predictedCategory: predictedCategory,
probabilities: probabilities
};
}
async captureFromWebcam(classIndex) {
if (!this.webcamStream) {
// 临时启动摄像头
const video = document.getElementById('webcam');
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { facingMode: 'user' },
audio: false
});
video.srcObject = stream;
this.webcamStream = stream;
// 等待视频加载
setTimeout(async () => {
await this.addWebcamSample(classIndex);
// 停止临时摄像头
this.webcamStream.getTracks().forEach(track => track.stop());
this.webcamStream = null;
video.srcObject = null;
}, 1000);
} catch (error) {
this.showStatus('dataStatus', 'error', `无法访问摄像头: ${error.message}`);
}
} else {
await this.addWebcamSample(classIndex);
}
}
async addWebcamSample(classIndex) {
const video = document.getElementById('webcam');
if (video.readyState === 4) {
try {
const features = await this.extractImageNetTags(video);
const featureVector = await features.logits.data();
// 使用logits作为对象的属性
const item = {};
Array.from(featureVector).forEach((value, index) => {
item[`feature_${index}`] = value;
});
// 添加类别信息
const className = document.getElementById(`class${classIndex + 1}Name`).value.trim();
item.category = className;
this.trainingSet.push(item);
this.showStatus('dataStatus', 'success', `从摄像头添加 ${className} 类的样本`);
features.logits.dispose();
} catch (error) {
console.error('添加样本失败:', error);
this.showStatus('dataStatus', 'error', `添加样本失败: ${error.message}`);
}
}
}
displayPrediction(category, probabilities) {
const container = document.getElementById('predictions');
let html = `预测类别:${category}<br>`;
for (const className in probabilities) {
const probability = (probabilities[className] * 100).toFixed(2);
html += `${className}: ${probability}%<br>`;
}
container.innerHTML = html;
}
showStatus(elementId, type, message) {
const element = document.getElementById(elementId);
const classMap = {
'success': 'status-success',
'error': 'status-error',
'info': 'status-info'
};
element.className = `status-message ${classMap[type]}`;
element.textContent = message;
}
}
// Global entry point: forwards a webcam capture request to the classifier
// instance stored on window, and does nothing while none exists yet.
function captureFromWebcam(classIndex) {
    const { classifier } = window;
    if (!classifier) {
        return;
    }
    classifier.captureFromWebcam(classIndex);
}
// Create the classifier once the DOM has been parsed and expose it on
// window so the global helpers can reach it.
document.addEventListener('DOMContentLoaded', async () => {
    const classifier = new ImageClassifier();
    window.classifier = classifier;
});