le-yolo/src/labels.py
2025-06-06 14:35:16 +08:00

43 lines
1.7 KiB
Python

import torch
from ultralytics import YOLO
import os
class Yolov8Detect():
    """Run a YOLO detection model and export its detections as label files.

    The model is loaded once in ``__init__``; ``inferences`` runs it on one
    image path and hands the collected detections to ``txt_construct``.
    """

    # Folder that receives the label files when the caller does not pass one.
    DEFAULT_SAVE_FOLDER = r'C:/workspace/le-yolo/data/labels/aut/'

    def __init__(self, weights):
        """Load the model and move it to the GPU when one is available.

        Args:
            weights: Value forwarded to ``YOLO(...)``; the script entry point
                passes the path of a trained ``.pt`` file.
        """
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu'
        )
        self.detect_model = YOLO(weights)
        self.detect_model.to(self.device)

    def inferences(self, inputs, save_folder=None):
        """Detect objects in one image and write them to a ``.txt`` label file.

        The label file is named after the image (extension replaced by
        ``.txt``) and placed in ``save_folder``.

        Args:
            inputs: Path of the image to process. It must be a path string,
                because the output file name is derived from its basename.
            save_folder: Destination folder for the label file. Defaults to
                ``DEFAULT_SAVE_FOLDER``. The folder is created if missing.
        """
        if save_folder is None:
            save_folder = self.DEFAULT_SAVE_FOLDER

        results = self.detect_model(inputs)

        filename = os.path.splitext(os.path.basename(inputs))[0] + '.txt'
        save_path = os.path.join(save_folder, filename)

        # The file name depends only on ``inputs``, so if the model returns
        # more than one result the last one written wins.
        for result in results:
            # One entry per detected box: [class index, ``xywhn`` array].
            # ``xywhn`` is the box in Ultralytics' normalized xywh format;
            # it is kept as a 2-D array because ``txt_construct`` reads row 0.
            label_text = [
                [int(box.cls.cpu()), box.xywhn.cpu().numpy()]
                for box in result.boxes
            ]
            # Make sure the destination exists before the file is opened.
            os.makedirs(save_folder, exist_ok=True)
            txt_construct(save_path, label_text=label_text)
def txt_construct(save_path, label_text):
    """Write detections to ``save_path``, one line per label.

    Any existing file at ``save_path`` is overwritten. An empty
    ``label_text`` produces an empty file.

    Args:
        save_path: Path of the text file to create.
        label_text: Iterable of ``[class_id, boxes]`` entries, where
            ``boxes[0]`` exposes ``.tolist()`` (e.g. a 2-D NumPy array whose
            first row holds the box values).

    Each line has the form ``<class_id> <v1> <v2> ...`` with the values
    taken from ``boxes[0].tolist()`` and separated by single spaces.
    """
    # Open once in write mode: this truncates the file and avoids reopening
    # it in append mode for every single label.
    with open(save_path, 'w') as txt_file:
        for label in label_text:
            class_id = label[0]
            values = label[1][0].tolist()
            line = f"{class_id} {' '.join(map(str, values))}"
            print('result', line)
            txt_file.write(line + '\n')
if __name__ == '__main__':
    # Script entry point: label every test image with the trained detector.
    import glob

    weights_file = r'C:\workspace\le-yolo\runs\detect\train38\weights\best.pt'
    detector = Yolov8Detect(weights_file)

    # The pattern is relative to the current working directory.
    test_images = glob.glob('../data/images/test/*.jpg')
    for test_image in test_images:
        detector.inferences(test_image)