TractorVision初期移植
This commit is contained in:
84
lib/alg/predict_yolo8_template.py
Normal file
84
lib/alg/predict_yolo8_template.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# YOLOv8Detector类:用于加载YOLO模型,进行目标检测并在图像上绘制信息
|
||||
class YOLOv8Detector:
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化YOLOv8检测器
|
||||
:param classes: 自定义类别名称列表(暂未使用)
|
||||
"""
|
||||
self.model = YOLO(r"model/yolov8n.pt") # 使用YOLOv8官方预训练模型
|
||||
# self.model = YOLO(r"model/best.pt") # 使用自训练模型
|
||||
|
||||
self.class_names = self.model.names # 获取类别名称
|
||||
self._frame_counter = 0 # 帧计数器(用于调试打印)
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
# 执行YOLO目标检测,并在图像中绘制检测框、中心坐标和尺寸信息(不再过滤类别)
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
def detect_and_draw(self, image):
|
||||
height, width = image.shape[:2]
|
||||
|
||||
# 使用YOLO模型进行检测
|
||||
results = self.model(image, imgsz=640)
|
||||
frame_id = self._frame_counter
|
||||
self._frame_counter += 1
|
||||
|
||||
result_info = [] # 存储每个目标的检测信息
|
||||
# print(f"\n[FRAME {frame_id}] Detection Result:")
|
||||
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
|
||||
# 遍历所有检测框(不再按类别筛选)
|
||||
for i in range(len(boxes)):
|
||||
cls_id = int(boxes.cls[i].item()) # 类别索引
|
||||
cls_name = self.class_names[cls_id] if cls_id < len(self.class_names) else f"class_{cls_id}"
|
||||
score = float(boxes.conf[i].item()) # 置信度
|
||||
|
||||
# (1)boxes.xyxy[i]提取第i个种类边界框坐标[x1, y1, x2, y2];(2)tolist()实现将PyTorch 的张量(Tensor)[x1, y1, x2, y2]转换为 普通的 Python 列表;(3)使用 map() 函数将列表中的 浮点数坐标转换为整数坐标,因为像素坐标用于图像绘制时必须是整数。
|
||||
x1_box, y1_box, x2_box, y2_box = map(int, boxes.xyxy[i].tolist())
|
||||
|
||||
# 计算中心点与宽高
|
||||
x_center = (x1_box + x2_box) / 2
|
||||
y_center = (y1_box + y2_box) / 2
|
||||
bbox_width = x2_box - x1_box
|
||||
bbox_height = y2_box - y1_box
|
||||
|
||||
# 打印调试信息
|
||||
# print(f"[{cls_name}] Score: {score:.2f}")
|
||||
# print(f" Center: ({x_center:.1f}, {y_center:.1f})")
|
||||
# print(f" BBox: Width={bbox_width}px, Height={bbox_height}px")
|
||||
# print(f" 左上角坐标: ({x1_box}px, {y1_box}px)")
|
||||
|
||||
# 绘制检测框(绿色),参数2表示线宽
|
||||
cv2.rectangle(image, (x1_box, y1_box), (x2_box, y2_box), (0, 255, 0), 2)
|
||||
|
||||
# 绘制框右上角顶部标签(类别名+得分)
|
||||
label = f"{cls_name} {score:.2f}"
|
||||
cv2.putText(image, label, (x2_box, y1_box - 60),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 2.5, (0, 255, 0), 3)
|
||||
# 绘制框右上角顶部标签(类别名+得分)
|
||||
cv2.putText(image, f"L-corner-coor:{x1_box, y1_box}", (x1_box, y1_box - 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 2.5, (0, 255, 0), 3)
|
||||
|
||||
# 在框底部左下角显示中心坐标和尺寸信息
|
||||
cv2.putText(image, f"Center: ({x_center:.1f}, {y_center:.1f})",
|
||||
(x1_box, y2_box + 70),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 2.5, (0, 255, 255), 3)
|
||||
cv2.putText(image, f"Width: {bbox_width}px, Height: {bbox_height}px",
|
||||
(x1_box, y2_box + 150),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 2.5, (0, 255, 255), 3)
|
||||
|
||||
# 保存检测信息字典."L-corner-coor":(x1_box, y1_box)左上角坐标点
|
||||
result_dict = {
|
||||
"class": cls_name,
|
||||
"score": score,
|
||||
"center": (x_center, y_center),
|
||||
"bbox": (bbox_width, bbox_height),
|
||||
"L-corner-coor":(x1_box, y1_box)
|
||||
}
|
||||
result_info.append(result_dict)
|
||||
|
||||
return image, result_info
|
Reference in New Issue
Block a user