#!/usr/bin/env python3
"""
实时手部检测Web服务器

支持WebSocket通信、实时视频流处理和机械臂控制
"""

import asyncio
import base64
import io
import json
import logging
import os
import time
from threading import Event, Lock, Thread
from typing import Any, Dict, Optional

import cv2
import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from PIL import Image

from hand_detection_3d import load_mediapipe_model, process_frame_3d

# Module-wide logging: INFO and above, routed through the root logger's
# default handler (basicConfig is a no-op if the root already has handlers).
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


class HandDetectionWebServer:
    """Real-time hand-detection web server.

    Receives video frames over Socket.IO (or plays a local video file), runs
    ``process_frame_3d`` on each one, and broadcasts the resulting control
    signal together with a JPEG preview to every connected client. The
    control signal is also forwarded to clients registered with type
    ``'robot'``.
    """

    # Directories (relative to the working directory) scanned by ``get_video_list``.
    VIDEO_DIRS = ('data/videos', 'videos', '.')
    # File extensions recognised as videos; matched case-insensitively.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv')
    # Fallback clips tried in order when ``start_local_test`` supplies no path.
    DEFAULT_TEST_VIDEOS = (
        'data/videos/test_basic.mp4',
        'data/videos/test_gesture.mp4',
    )
    # Target playback rate of a local video test (processing time comes on top).
    LOCAL_TEST_FPS = 30

    def __init__(self, host='0.0.0.0', port=5000):
        """Create the Flask/Socket.IO app, load the model and register handlers.

        Args:
            host: Interface address handed to ``socketio.run``.
            port: TCP port handed to ``socketio.run``.
        """
        self.host = host
        self.port = port

        # Templates live in the ``templates`` folder one level above this
        # file's directory; abspath keeps that true if ``__file__`` is relative.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        template_dir = os.path.join(project_root, 'templates')
        self.app = Flask(__name__, template_folder=template_dir)
        # Allow the environment to override the built-in default secret.
        self.app.config['SECRET_KEY'] = os.environ.get(
            'SECRET_KEY', 'hand_detection_secret_key'
        )
        self.socketio = SocketIO(self.app, cors_allowed_origins="*", async_mode='threading')

        # Hand-tracking model shared by every frame source.
        self.hands_model = load_mediapipe_model()

        # Connected clients keyed by Socket.IO session id.
        self.clients = {}
        # Kept for compatibility; not written by any method of this class.
        self.current_frame = None
        # hand_data returned by the previous process_frame_3d call, fed into the next.
        self.previous_hand_data = None
        # Latest control signal; these values are reported until a frame is processed.
        self.detection_results = {
            'x_angle': 90,
            'y_angle': 90,
            'z_angle': 90,
            'grip': 0,
            'action': 'none',
            'speed': 5,
            'timestamp': time.time(),
        }

        # Frame-rate bookkeeping maintained by _update_fps().
        self.fps_counter = 0
        self.last_fps_time = time.time()
        self.current_fps = 0

        # Serializes frame processing. The local-test thread runs concurrently
        # with the Socket.IO handlers, and each of them read-modify-writes
        # previous_hand_data, detection_results and the FPS counters.
        self._state_lock = Lock()
        # Stop flag of the most recently started local video test, if any.
        self._local_test_stop = None

        self._setup_routes()
        self._setup_socket_events()

    def _setup_routes(self):
        """Register the HTTP routes on the Flask app."""

        @self.app.route('/')
        def index():
            """Serve the main page."""
            return render_template('index.html')

        @self.app.route('/api/status')
        def api_status():
            """Return server status, current FPS, client count and latest results."""
            return {
                'status': 'running',
                'fps': self.current_fps,
                'clients': len(self.clients),
                'detection_results': self.detection_results,
            }

    def _setup_socket_events(self):
        """Register the Socket.IO event handlers."""

        @self.socketio.on('connect')
        def handle_connect():
            """Record the new client and greet it with the current results."""
            client_id = request.sid
            now = time.time()
            self.clients[client_id] = {
                'connected_at': now,
                'type': 'unknown',
                'last_ping': now,
            }
            logger.info(f"客户端 {client_id} 已连接")

            emit('status', {
                'message': '连接成功',
                'client_id': client_id,
                'current_results': self.detection_results,
            })

        @self.socketio.on('disconnect')
        def handle_disconnect():
            """Forget a disconnected client."""
            client_id = request.sid
            self.clients.pop(client_id, None)
            logger.info(f"客户端 {client_id} 已断开连接")

        @self.socketio.on('register_client')
        def handle_register_client(data):
            """Set the sender's client type (e.g. ``'robot'``) from ``data['type']``."""
            client_id = request.sid
            # Tolerate a missing or non-dict payload instead of raising AttributeError.
            client_type = data.get('type', 'unknown') if isinstance(data, dict) else 'unknown'

            if client_id in self.clients:
                self.clients[client_id]['type'] = client_type
            logger.info(f"客户端 {client_id} 注册为: {client_type}")

            emit('registration_success', {
                'client_id': client_id,
                'type': client_type,
            })

        @self.socketio.on('video_frame')
        def handle_video_frame(data):
            """Decode ``data['frame']`` (base64 or data URL), run detection, broadcast.

            Any failure is logged and reported to the sender as an ``'error'`` event.
            """
            try:
                frame_data = self._decode_base64_payload(data['frame'])
                frame = self._decode_frame(frame_data)
                if frame is not None:
                    self._process_and_broadcast(frame)
            except Exception as e:
                logger.error(f"处理视频帧时出错: {e}")
                emit('error', {'message': str(e)})

        @self.socketio.on('ping')
        def handle_ping():
            """Refresh the sender's ``last_ping`` and answer with ``'pong'``."""
            client_id = request.sid
            if client_id in self.clients:
                self.clients[client_id]['last_ping'] = time.time()
            emit('pong', {'timestamp': time.time()})

        @self.socketio.on('get_detection_results')
        def handle_get_detection_results():
            """Send the sender the latest detection results and FPS."""
            emit('detection_results', {
                'control_signal': self.detection_results,
                'fps': self.current_fps,
            })

        @self.socketio.on('start_local_test')
        def handle_start_local_test(data=None):
            """Start a local video test on ``data['video_path']`` or the first default clip.

            Emits ``'test_started'`` on success and ``'test_error'`` otherwise.
            """
            try:
                if isinstance(data, dict) and 'video_path' in data:
                    # NOTE(review): this path comes straight from the client, so any
                    # readable file on the server can be opened. Consider restricting
                    # it to VIDEO_DIRS if clients are not trusted.
                    test_video = data['video_path']
                    # isfile rather than exists: a directory would otherwise pass
                    # here and then fail silently inside the playback thread.
                    if not isinstance(test_video, str) or not os.path.isfile(test_video):
                        emit('test_error', {
                            'message': f'视频文件不存在: {test_video}'
                        })
                        return
                else:
                    test_video = next(
                        (path for path in self.DEFAULT_TEST_VIDEOS if os.path.isfile(path)),
                        None,
                    )
                    if not test_video:
                        # No default clip available: tell the user how to create one.
                        emit('test_error', {
                            'message': '未找到测试视频文件',
                            'help': '请先生成测试视频:python create_test_video.py'
                        })
                        return

                logger.info(f"开始本地测试,使用视频: {test_video}")
                self.start_local_video_test(test_video)

                emit('test_started', {
                    'message': f'本地测试已开始,使用视频: {os.path.basename(test_video)}',
                    'video_path': test_video
                })

            except Exception as e:
                logger.error(f"启动本地测试时出错: {e}")
                emit('test_error', {
                    'message': f'启动本地测试失败: {str(e)}'
                })

        @self.socketio.on('get_video_list')
        def handle_get_video_list():
            """Send the sender a name-sorted list of video files found in VIDEO_DIRS."""
            try:
                videos = []
                for video_dir in self.VIDEO_DIRS:
                    try:
                        entries = os.listdir(video_dir)
                    except OSError:
                        # Missing, unreadable or not a directory: skip it, keep scanning.
                        continue

                    for file in entries:
                        if not file.lower().endswith(self.VIDEO_EXTENSIONS):
                            continue
                        file_path = os.path.join(video_dir, file)
                        # Skip directories that merely carry a video-like name.
                        if not os.path.isfile(file_path):
                            continue
                        try:
                            size_mb = os.path.getsize(file_path) / (1024 * 1024)
                        except OSError:
                            continue
                        videos.append({
                            'path': file_path,
                            'name': file,
                            'size': f'{size_mb:.1f}MB'
                        })

                videos.sort(key=lambda video: video['name'])
                emit('video_list', {
                    'videos': videos
                })

            except Exception as e:
                logger.error(f"获取视频列表时出错: {e}")
                emit('video_list', {
                    'videos': []
                })

    @staticmethod
    def _decode_base64_payload(payload) -> bytes:
        """Decode a base64 frame payload, accepting an optional data-URL prefix.

        Plain base64 (``str`` or ``bytes``) is decoded directly. A string such
        as ``"data:image/jpeg;base64,<b64>"`` (the format ``_encode_frame``
        produces) has everything up to the first comma stripped first.
        """
        if isinstance(payload, str) and payload.startswith('data:'):
            _, _, payload = payload.partition(',')
        return base64.b64decode(payload)

    def _decode_frame(self, frame_data: bytes) -> Optional[np.ndarray]:
        """Decode encoded image bytes (any format PIL reads) into a BGR array.

        Returns None, after logging the error, if decoding fails.
        """
        try:
            # Normalize to 3-channel RGB first. Grayscale ('L') and palette ('P')
            # images would make RGB2BGR fail, and CMYK would be mis-read as RGBA.
            image = Image.open(io.BytesIO(frame_data)).convert('RGB')
            return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.error(f"解码帧时出错: {e}")
            return None

    def _encode_frame(self, frame: np.ndarray) -> str:
        """Encode a BGR frame as a JPEG data URL (quality 80).

        Returns an empty string, after logging the error, if encoding fails.
        """
        try:
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            buffer = io.BytesIO()
            image.save(buffer, format='JPEG', quality=80)
            frame_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{frame_data}"
        except Exception as e:
            logger.error(f"编码帧时出错: {e}")
            return ""

    def _process_and_broadcast(self, frame: np.ndarray) -> None:
        """Run detection on one BGR frame, update shared state and notify clients.

        Used by both the ``video_frame`` handler and the local video test, so
        both frame sources follow the same pipeline.
        """
        with self._state_lock:
            control_signal, hand_data = process_frame_3d(
                frame, self.hands_model, self.previous_hand_data
            )
            # The signal dict itself becomes detection_results and is stamped in
            # place, so the broadcast payload carries the same timestamp.
            control_signal['timestamp'] = time.time()
            self.detection_results = control_signal
            self.previous_hand_data = hand_data
            # Report the FPS measured so far, then count this frame.
            fps = self.current_fps
            self._update_fps()

        # NOTE(review): presumably process_frame_3d draws its overlay onto `frame`
        # in place, which would make this the annotated preview. Confirm in
        # hand_detection_3d.
        processed_frame_data = self._encode_frame(frame)

        # With no `to`/`room`, SocketIO.emit sends to every connected client.
        self.socketio.emit('detection_results', {
            'control_signal': control_signal,
            'processed_frame': processed_frame_data,
            'fps': fps,
        })
        self._send_to_robot_clients(control_signal)

    def _send_to_robot_clients(self, control_signal: Dict[str, Any]):
        """Emit ``control_signal`` as ``'robot_control'`` to each ``'robot'`` client."""
        # Iterate a snapshot: connect/disconnect handlers may add or remove
        # entries from other threads while we loop.
        for client_id, info in list(self.clients.items()):
            if info.get('type') == 'robot':
                self.socketio.emit('robot_control', control_signal, room=client_id)

    def _update_fps(self):
        """Count one processed frame and publish the count about once per second.

        When at least one second has passed since the last publish, the number
        of frames counted in that window becomes ``current_fps`` and the counter
        resets. ``_process_and_broadcast`` calls this while holding _state_lock.
        """
        self.fps_counter += 1
        current_time = time.time()

        if current_time - self.last_fps_time >= 1.0:
            self.current_fps = self.fps_counter
            self.fps_counter = 0
            self.last_fps_time = current_time

    def start_local_video_test(self, video_path: str):
        """Play ``video_path`` through the detection pipeline on a daemon thread.

        A previously started local test is signalled to stop first, so two
        playback threads never interleave frames into the shared tracking
        state. Frames are paced at about LOCAL_TEST_FPS (plus processing time).
        """
        previous_stop = self._local_test_stop
        if previous_stop is not None:
            previous_stop.set()
        stop_event = Event()
        self._local_test_stop = stop_event
        frame_interval = 1 / self.LOCAL_TEST_FPS

        def video_test_thread():
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    logger.error(f"无法打开测试视频: {video_path}")
                    return
                while not stop_event.is_set():
                    ret, frame = cap.read()
                    if not ret:
                        break  # end of file or read failure
                    self._process_and_broadcast(frame)
                    # Like sleep(frame_interval), but returns early once stopped.
                    stop_event.wait(frame_interval)
            except Exception as e:
                logger.exception(f"本地视频测试出错: {e}")
            finally:
                # Release the capture even when processing raised.
                cap.release()

        Thread(target=video_test_thread, daemon=True).start()
        logger.info(f"本地视频测试已启动: {video_path}")

    def run(self, debug=False):
        """Start serving on ``self.host:self.port``; blocks until the server stops.

        ``allow_unsafe_werkzeug=True`` lets Flask-SocketIO run on the Werkzeug
        development server, which is not intended for production deployments.
        """
        logger.info(f"启动手部检测Web服务器 http://{self.host}:{self.port}")
        self.socketio.run(
            self.app,
            host=self.host,
            port=self.port,
            debug=debug,
            allow_unsafe_werkzeug=True
        )


def main():
    """Command-line entry point: parse options, optionally start a local test, serve."""
    import argparse

    cli = argparse.ArgumentParser(description='实时手部检测Web服务器')
    cli.add_argument('--host', default='0.0.0.0', help='服务器地址')
    cli.add_argument('--port', type=int, default=5000, help='端口号')
    cli.add_argument('--debug', action='store_true', help='调试模式')
    cli.add_argument('--test-video', help='本地测试视频路径')
    opts = cli.parse_args()

    web_server = HandDetectionWebServer(host=opts.host, port=opts.port)

    # run() does not return until the server stops, so any local video test
    # has to be launched beforehand.
    test_video = opts.test_video
    if test_video:
        web_server.start_local_video_test(test_video)

    web_server.run(debug=opts.debug)


if __name__ == '__main__':
    main()