Files
hirres_tractor_vision/lib/alg/predict_yolo8_template.py
2025-05-30 16:30:37 +08:00

85 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
from ultralytics import YOLO
class YOLOv8Detector:
    """Run an Ultralytics YOLO model on a frame and annotate every detection.

    Each detected box is drawn onto the frame together with its class name,
    confidence, top-left corner, center and size. Detections are not filtered
    by class.
    """

    # Drawing style shared by all annotations. The text scale/thickness are
    # large, presumably because the frames are high resolution -- confirm
    # against the camera setup before changing them.
    _FONT_SCALE = 2.5
    _TEXT_THICKNESS = 3
    _BOX_THICKNESS = 2
    _BOX_COLOR = (0, 255, 0)     # green: box outline, class label, corner text
    _INFO_COLOR = (0, 255, 255)  # center / size text (yellow if the frame is BGR)

    def __init__(self, model_path=r"model/yolov8n.pt"):
        """Load the YOLO weights and cache the model's class-name table.

        :param model_path: path of the weights file handed to ``YOLO``.
            Defaults to the official pretrained ``yolov8n`` checkpoint; pass
            e.g. ``"model/best.pt"`` to use self-trained weights.
        """
        self.model = YOLO(model_path)
        # Class-id -> name table as exposed by the loaded model.
        self.class_names = self.model.names
        # Number of frames that went through the model (debugging aid).
        self._frame_counter = 0

    @staticmethod
    def _resolve_class_name(class_names, cls_id):
        """Return the name for ``cls_id``, or ``"class_<id>"`` if unknown.

        Looks the id up directly instead of comparing it with
        ``len(class_names)``: Ultralytics exposes ``names`` as an id -> name
        mapping, whose ids need not be the dense range ``0..len-1``. A plain
        list of names works as well.
        """
        try:
            return class_names[cls_id]
        except (KeyError, IndexError):
            return f"class_{cls_id}"

    def _put_text(self, image, text, origin, color):
        """Draw ``text`` on ``image`` at ``origin`` with the shared font style."""
        cv2.putText(image, text, origin, cv2.FONT_HERSHEY_SIMPLEX,
                    self._FONT_SCALE, color, self._TEXT_THICKNESS)

    def _draw_detection(self, image, detection, bottom_right):
        """Draw one detection's box and its four text annotations on ``image``.

        :param detection: the result dictionary built by ``detect_and_draw``.
        :param bottom_right: ``(x2, y2)`` integer pixel corner of the box.
        """
        x1, y1 = detection["L-corner-coor"]
        x2, y2 = bottom_right
        x_center, y_center = detection["center"]
        bbox_width, bbox_height = detection["bbox"]

        cv2.rectangle(image, (x1, y1), (x2, y2),
                      self._BOX_COLOR, self._BOX_THICKNESS)
        # Class name + confidence, anchored at the top-right corner, above the box.
        self._put_text(image, f"{detection['class']} {detection['score']:.2f}",
                       (x2, y1 - 60), self._BOX_COLOR)
        # Top-left corner coordinates, just above the top-left corner.
        self._put_text(image, f"L-corner-coor:{(x1, y1)}",
                       (x1, y1 - 20), self._BOX_COLOR)
        # Center and size, stacked below the bottom-left corner.
        self._put_text(image, f"Center: ({x_center:.1f}, {y_center:.1f})",
                       (x1, y2 + 70), self._INFO_COLOR)
        self._put_text(image, f"Width: {bbox_width}px, Height: {bbox_height}px",
                       (x1, y2 + 150), self._INFO_COLOR)

    def detect_and_draw(self, image):
        """Detect objects in ``image`` and draw the results onto it in place.

        :param image: frame exposing ``.shape`` and accepted by the model and
            by the OpenCV drawing calls (presumably a BGR ``numpy`` array --
            confirm against the caller).
        :return: ``(image, result_info)`` where ``image`` is the same object
            that was passed in, now annotated, and ``result_info`` is a list
            with one dict per detection:

            * ``"class"``: class name (``"class_<id>"`` if the id is unknown)
            * ``"score"``: confidence as a float
            * ``"center"``: ``(x, y)`` box center in pixels
            * ``"bbox"``: ``(width, height)`` of the box in pixels
            * ``"L-corner-coor"``: ``(x1, y1)`` top-left corner in pixels
        """
        height, width = image.shape[:2]
        if height == 0 or width == 0:
            # An empty frame has nothing to detect; skip inference entirely.
            return image, []

        results = self.model(image, imgsz=640)
        self._frame_counter += 1

        result_info = []
        for result in results:
            boxes = result.boxes
            if boxes is None:
                # Results without box outputs carry nothing to draw.
                continue
            # Convert each tensor to plain Python values once instead of
            # indexing and calling .item() per detection.
            rows = zip(boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist())
            for cls_value, score, corners in rows:
                # Pixel coordinates must be integers for the drawing calls.
                x1, y1, x2, y2 = map(int, corners)
                detection = {
                    "class": self._resolve_class_name(self.class_names, int(cls_value)),
                    "score": score,
                    "center": ((x1 + x2) / 2, (y1 + y2) / 2),
                    "bbox": (x2 - x1, y2 - y1),
                    "L-corner-coor": (x1, y1),
                }
                self._draw_detection(image, detection, (x2, y2))
                result_info.append(detection)
        return image, result_info