[CF] pose keypoint detection model
parent 1d2dc0725c
commit 4ebfdaf313

BIN  runs/pose/predict/video.avi  Normal file
Binary file not shown.
BIN  runs/pose/predict2/test.avi  Normal file
Binary file not shown.

@@ -4,6 +4,7 @@ import os
from PIL import Image, ImageEnhance
from PIL import Image
import os
# Convert color images to grayscale and save them
folder_path = 'C:/workspace/le-yolo/data/images/train'
for filename in os.listdir(folder_path):
    if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):

@@ -3,6 +3,8 @@ from ultralytics import YOLO
import os
from video import video_to_pic

# Use the trained model to automatically label the dataset

class Yolov8Detect():
    def __init__(self, weights):
        cuda = True if torch.cuda.is_available() else False

@@ -1,6 +1,7 @@
from ultralytics import YOLO
import cv2
import os
# Crop the detected targets from images and save them
# model = YOLO("yolov8n.pt")
model = YOLO(r"C:\workspace\le-yolo\runs\detect\train40\weights\best.pt")
target_class = 0

12  src/pose.py  Normal file
@@ -0,0 +1,12 @@
from ultralytics import YOLO
import time
# Keypoint detection
start_time = time.time()
model = YOLO('yolov8n-pose.pt')
# Run prediction on the test video
results = model.predict(source='../res/test.mp4', show=True, save=True)
print('Prediction finished')
end_time = time.time()
run_time = end_time - start_time
minutes = run_time // 60
print('Run time:', minutes, 'minutes')
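
For context, a hedged sketch of how the predictions returned in src/pose.py can be consumed programmatically; the confidence threshold and loop structure below are illustrative assumptions, not part of this commit:

from ultralytics import YOLO

model = YOLO('yolov8n-pose.pt')
results = model.predict(source='../res/test.mp4', save=True)

# Each element of `results` is one frame; `keypoints.data` is a tensor of
# shape (num_people, num_keypoints, 3) holding [x, y, confidence] rows.
for r in results:
    for person in r.keypoints.data:
        for x, y, conf in person.tolist():
            if conf > 0.5:  # illustrative threshold
                print(f'keypoint at ({x:.1f}, {y:.1f}) conf={conf:.2f}')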

11  src/pose.yaml  Normal file
@@ -0,0 +1,11 @@
path: /home/le/le-yolo/data
train: images/train
val: images/val
test: images/test

# Keypoint coordinates: 13 keypoints, each with 3 values [x, y, v]
kpt_shape: [13, 3]

# Flip index (left/right keypoint swap used for horizontal-flip augmentation)
flip_idx: [5, 4, 3, 2, 1, 0, 8, 9, 6, 7]
names: [ 'person' ]
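
For reference, a dataset described by pose.yaml is normally consumed through a pose-model training call such as the sketch below; the epoch count, batch size and device are illustrative assumptions rather than values taken from this commit. Note that Ultralytics expects flip_idx to provide one swap index per keypoint, so a kpt_shape of [13, 3] would normally go with a 13-entry flip_idx.

from ultralytics import YOLO

# Minimal training sketch for the 13-keypoint dataset defined in pose.yaml.
# The hyperparameters below are placeholders, not the project's actual settings.
model = YOLO('yolov8n-pose.pt')   # pose weights added in this commit
model.train(data='pose.yaml', epochs=100, imgsz=640, batch=8, device=0)
metrics = model.val()             # optional: evaluate on the val split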

59  src/pose_show.py  Normal file
@@ -0,0 +1,59 @@
import cv2
import cvzone
import numpy as np
from ultralytics import YOLO

model = YOLO("yolov8n-pose.pt")

vide_opath = '../res/test.mp4'
cap = cv2.VideoCapture(vide_opath)
# cap = cv2.VideoCapture(0)  # open the webcam instead

# connection order: (0, 1) means draw a line from keypoint 0 to keypoint 1
# 17 keypoints: nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles
connections = [
    (3, 1), (1, 0), (0, 2), (2, 4), (1, 2), (4, 6), (3, 5),
    (5, 6), (5, 7), (7, 9),
    (6, 8), (8, 10),
    (11, 12), (11, 13), (13, 15),
    (12, 14), (14, 16),
    (5, 11), (6, 12)
]

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        cap = cv2.VideoCapture("video.MOV")
        continue
    frame = cv2.resize(frame, (640, 480))
    height, width = frame.shape[:2]
    # create a blank image on which the detected keypoints are drawn
    blank_image = np.zeros((height, width, 3), dtype=np.uint8)
    results = model(frame)
    # draw the keypoints
    for keypoints in results[0].keypoints.data:
        keypoints = keypoints.cpu().numpy()
        for i, keypoint in enumerate(keypoints):
            x, y, confidence = keypoint
            if confidence > 0.7:
                cv2.circle(blank_image, (int(x), int(y)), radius=5, color=(0, 250, 50), thickness=1)
                cv2.putText(blank_image, f'{i}', (int(x), int(y) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255),
                            1, cv2.LINE_AA)
        # connect keypoints to draw the skeleton
        for part_a, part_b in connections:
            if keypoints.any():
                x1, y1, conf1 = keypoints[part_a]
                x2, y2, conf2 = keypoints[part_b]
                # only connect when both confidences are above 0.5
                if conf1 > 0.5 and conf2 > 0.5:
                    cv2.line(blank_image, (int(x1), int(y1)), (int(x2), int(y2)), (50, 250, 250), thickness=2)
    frame = results[0].plot(labels=False, conf=False, boxes=False)
    # output = cvzone.stackImages([frame, blank_image], cols=2, scale=0.80)
    # cv2.imshow("Side-by-side view", output)
    cv2.imshow("Pose overlay", frame)
    # press space to end the loop early
    if cv2.waitKey(1) & 0xFF == ord(' '):
        break

cap.release()
cv2.destroyAllWindows()
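
The index pairs in the connections list above follow the standard 17-keypoint COCO ordering predicted by yolov8n-pose.pt; the mapping below is added purely as a reference and is not part of the original file.

# COCO keypoint indices assumed by the `connections` list in pose_show.py
COCO_KEYPOINTS = {
    0: 'nose', 1: 'left_eye', 2: 'right_eye', 3: 'left_ear', 4: 'right_ear',
    5: 'left_shoulder', 6: 'right_shoulder', 7: 'left_elbow', 8: 'right_elbow',
    9: 'left_wrist', 10: 'right_wrist', 11: 'left_hip', 12: 'right_hip',
    13: 'left_knee', 14: 'right_knee', 15: 'left_ankle', 16: 'right_ankle',
}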

@@ -3,9 +3,10 @@ import time
start_time = time.time()
model = YOLO("yolov8n.pt")
# model = YOLO(r"C:\workspace\le-yolo\runs\detect\train40\weights\last.pt")
model.train(data="data.yaml", epochs=100, batch=8, device=0, imgsz=640, augment=True)
# CPU: device='cpu'   GPU: device=0
model.train(data="data.yaml", epochs=100, batch=8, device=0, imgsz=640, augment=True)  # augment=True enables data augmentation
# model.train(data="data.yaml", epochs=100, batch=8, device='cpu', imgsz=640, augment=True, lr0=0.001, weight_decay=0.0005)
# model.val()
# model.val()  # validation
print('Training finished')
end_time = time.time()
run_time = end_time - start_time

@@ -1,4 +1,5 @@
import cv2
# Extract frames from a video and save them as images
def video_to_pic(vide_opath):
    # vide_opath = 'C:/workspace/le-yolo/res/6.mp4'
    video = cv2.VideoCapture(vide_opath)
@@ -7,8 +8,8 @@ def video_to_pic(vide_opath):
        ret, frame = video.read()
    else:
        ret = False
    timeF = 10
    filepath = 'C:/pic/test_'
    timeF = 10  # frame sampling interval
    filepath = 'C:/pic/test_'  # path where the extracted frames are saved
    while ret:
        ret, frame = video.read()
        if num % timeF == 0:

BIN  src/yolov8n-pose.pt  Normal file
Binary file not shown.