[CF]pose 关键点检测模型

This commit is contained in:
songbingle 2025-06-09 17:07:02 +08:00
parent 1d2dc0725c
commit 4ebfdaf313
11 changed files with 92 additions and 4 deletions

BIN
runs/pose/predict/video.avi Normal file

Binary file not shown.

BIN
runs/pose/predict2/test.avi Normal file

Binary file not shown.

View File

@ -4,6 +4,7 @@ import os
from PIL import Image,ImageEnhance
from PIL import Image
import os
# 将彩色图片转为灰度图片并保存
folder_path = 'C:/workspace/le-yolo/data/images/train'
for filename in os.listdir(folder_path):
if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):

View File

@ -3,6 +3,8 @@ from ultralytics import YOLO
import os
from video import video_to_pic
# 用训练好的模型来进行模型自动标注
class Yolov8Detect():
def __init__(self, weights):
cuda = True if torch.cuda.is_available() else False

View File

@ -1,6 +1,7 @@
from ultralytics import YOLO
import cv2
import os
# 截取目标图片 并保存
# model = YOLO("yolov8n.pt")
model = YOLO(r"C:\workspace\le-yolo\runs\detect\train40\weights\best.pt")
target_class = 0

12
src/pose.py Normal file
View File

@ -0,0 +1,12 @@
from ultralytics import YOLO
import time
# 关键点检测
start_time = time.time()
model = YOLO('yolov8n-pose.pt')
# 预测
results = model.predict(source='../res/test.mp4', show=True, save=True)
print('预测完成')
end_time = time.time()
run_time = end_time - start_time
min = run_time // 60
print('运行时间:', min, '分钟')

11
src/pose.yaml Normal file
View File

@ -0,0 +1,11 @@
path: /home/le/le-yolo/data
train: images/train
val: images/val
test: images/test
# 关键点坐标 13个关键点每个关键点有3个坐标[x,y,v]
kpt_shape: [13, 3]
# 翻转索引
flip_idx: [5, 4, 3, 2, 1, 0, 8, 9, 6, 7]
names: [ 'person' ]

59
src/pose_show.py Normal file
View File

@ -0,0 +1,59 @@
import cv2
import cvzone
import numpy as np
from ultralytics import YOLO
model = YOLO("yolov8n-pose.pt")
vide_opath = '../res/test.mp4'
cap = cv2.VideoCapture(vide_opath)
# cap = cv2.VideoCapture(0) # 打开摄像头
# 连接顺序 0,1) 0-1
# 17个关键点 鼻子 眼睛 耳朵 肩膀 手肘 手腕 胯 膝盖 脚腕
connections = [
(3, 1), (1, 0), (0, 2), (2, 4), (1, 2), (4, 6), (3, 5),
(5, 6), (5, 7), (7, 9),
(6, 8), (8, 10),
(11, 12), (11, 13), (13, 15),
(12, 14), (14, 16),
(5, 11), (6, 12)
]
while cap.isOpened():
ret, frame = cap.read()
if not ret:
cap = cv2.VideoCapture("video.MOV")
continue
frame = cv2.resize(frame, (640, 480))
width, height = frame.shape[:2]
# 创建一个空白的图像 绘制目标关键点
blank_image = np.zeros((width, height, 3), dtype=np.uint8)
results = model(frame)
# 绘制关键点
for keypoints in results[0].keypoints.data:
keypoints = keypoints.cpu().numpy()
for i, keypoint in enumerate(keypoints):
x, y, confidence = keypoint
if confidence > 0.7:
cv2.circle(blank_image, (int(x), int(y)), radius=5, color=(0, 250, 50), thickness=1)
cv2.putText(blank_image, f'{i}', (int(x), int(y) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255),
1, cv2.LINE_AA)
# 连线 绘制骨架
for part_a, part_b in connections:
if keypoints.any():
x1, y1, conf1 = keypoints[part_a]
x2, y2, conf2 = keypoints[part_b]
# 置信度 大于0.5 连线
if conf1 > 0.5 and conf2 > 0.5:
cv2.line(blank_image, (int(x1), int(y1)), (int(x2), int(y2)), (50, 250, 250), thickness=2)
frame = results[0].plot(labels=False, conf=False, boxes=False)
# output = cvzone.stackImages([frame, blank_image], cols=2, scale=0.80)
# cv2.imshow("并排绘制", output)
cv2.imshow("叠加图像", frame)
# 按空格提前结束循环
if cv2.waitKey(1) & 0xFF == ord(' '):
break
cap.release()
cv2.destroyAllWindows()

View File

@ -3,9 +3,10 @@ import time
start_time = time.time()
model = YOLO("yolov8n.pt")
# model = YOLO(r"C:\workspace\le-yolo\runs\detect\train40\weights\last.pt")
model.train(data="data.yaml", epochs=100, batch=8, device=0, imgsz=640, augment = True)
# CPUdevice='cpu' GPUdevice=0
model.train(data="data.yaml", epochs=100, batch=8, device=0, imgsz=640, augment = True) # augment=True 数据增强
# model.train(data="data.yaml", epochs=100, batch=8, device='cpu', imgsz=640, augment = True, lr = 0.001,wight_decay = 0.0005 )
# model.val()
# model.val() # 验证
print('训练完成')
end_time = time.time()
run_time = end_time - start_time

View File

@ -1,4 +1,5 @@
import cv2
# 视频抽帧 并保存
def video_to_pic(vide_opath):
# vide_opath = 'C:/workspace/le-yolo/res/6.mp4'
video = cv2.VideoCapture(vide_opath)
@ -7,8 +8,8 @@ def video_to_pic(vide_opath):
ret, frame = video.read()
else:
ret = False
timeF = 10
filepath = 'C:/pic/test_'
timeF = 10 # 帧数
filepath = 'C:/pic/test_' # 保存图片的路径
while ret:
ret, frame = video.read()
if num % timeF == 0:

BIN
src/yolov8n-pose.pt Normal file

Binary file not shown.