import onnxruntime
import numpy as np
import cv2
from typing import Tuple, List, Union
from .base_onnx import BaseONNX


class Counting(BaseONNX):
    """Point-based counting model running on an ONNX Runtime session.

    Inference uses ``self.session`` and ``self.input_name``, which are
    presumably set up by ``BaseONNX.__init__`` (not visible here — confirm).
    """

    # Longest allowed image side in pixels; larger inputs are downscaled.
    UPPER_BOUND = 2560
    # Model-input height and width are forced to multiples of this value.
    MULTIPLE_OF = 32
    # Per-channel normalisation constants in RGB order, equivalent to
    # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
    # Kept as default (float64) arrays so results match the previous inline
    # np.array(...) constants exactly.
    _NORM_MEAN = np.array([0.485, 0.456, 0.406])
    _NORM_STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, model_path):
        super().__init__(model_path)

    def preprocess_image(self, img: cv2.UMat, is_rgb: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """Convert, normalise and resize an image into a model input tensor.

        The image must be H x W x 3 (the 3-element mean/std and the
        (2, 0, 1) transpose both require it).

        Args:
            img: Input image, in BGR or RGB channel order.
            is_rgb: Whether ``img`` is already RGB. When False it is
                converted from BGR to RGB first. Defaults to True.

        Returns:
            A tuple ``(tensor, img_copy)``:
            - ``tensor``: float32 array of shape (1, 3, H, W), normalised, with
              H and W multiples of ``MULTIPLE_OF`` and at most ``UPPER_BOUND``
              (up to the multiple-of rounding).
            - ``img_copy``: the un-normalised RGB image resized to the same
              H x W as ``tensor``.
        """
        if not is_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Un-normalised copy kept for display; it is resized alongside `img`.
        img_copy = img.copy()

        # Scale to [0, 1], then standardise each channel.
        # astype() returns a new array, so the caller's image is never mutated.
        img = img.astype(np.float32)
        img /= 255.0
        img -= self._NORM_MEAN
        img /= self._NORM_STD

        # Downscale if the longest side exceeds UPPER_BOUND.
        origin_h, origin_w = img.shape[:2]
        max_size = max(origin_h, origin_w)
        if max_size > self.UPPER_BOUND:
            scale = self.UPPER_BOUND / max_size
            img = cv2.resize(img, None, fx=scale, fy=scale)

        # Round each side down to a multiple of MULTIPLE_OF, but never below
        # MULTIPLE_OF itself: flooring a shorter side would give 0, which
        # cv2.resize rejects.
        h, w = img.shape[:2]
        new_h = max(self.MULTIPLE_OF, (h // self.MULTIPLE_OF) * self.MULTIPLE_OF)
        new_w = max(self.MULTIPLE_OF, (w // self.MULTIPLE_OF) * self.MULTIPLE_OF)
        if h != new_h or w != new_w:
            img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Bring the display copy to the final model-input size. Checking the
        # copy's own shape also covers the case where the UPPER_BOUND
        # downscale already produced multiple-of sides, which previously left
        # the copy at its original size.
        if img_copy.shape[:2] != (new_h, new_w):
            img_copy = cv2.resize(img_copy, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # (H, W, C) -> (C, H, W), then add the batch dimension.
        img = np.transpose(img, (2, 0, 1))
        img = np.expand_dims(img, axis=0)

        return img, img_copy

    def run_inference(self, image: np.ndarray) -> List[np.ndarray]:
        """Run the ONNX session on an already-preprocessed input tensor.

        Args:
            image: Model input, e.g. the (1, 3, H, W) tensor produced by
                ``preprocess_image``.

        Returns:
            The list returned by ``session.run``. It holds all model outputs,
            because ``output_names`` is passed as ``None``.
        """
        return self.session.run(None, {self.input_name: image})

    def pred(
        self,
        image: Union[cv2.UMat, np.ndarray, str],
        is_rgb: bool = True,
    ) -> Tuple[List[float], List[List[float]]]:
        """Predict counting points for a single image.

        Args:
            image: An in-memory image, or a path to an image file. Paths are
                loaded with ``cv2.imread`` (BGR), so ``is_rgb`` is ignored
                for them.
            is_rgb: Whether an in-memory ``image`` is already RGB.
                Defaults to True.

        Returns:
            ``(scores, points)``: the two model outputs, unpacked in order.
            NOTE(review): the model runs on the preprocessed (possibly
            resized) image, so ``points`` are presumably in that resized
            pixel frame rather than the original image's — confirm before
            drawing them on the original.

        Raises:
            OSError: If ``image`` is a path that ``cv2.imread`` cannot read.
        """
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                # cv2.imread reports failure by returning None, not by raising.
                raise OSError(f"cv2.imread could not read image file: {image!r}")
            is_rgb = False
        else:
            # No defensive copy is needed: preprocess_image never mutates its input.
            img = image

        processed_image, _ = self.preprocess_image(img, is_rgb)
        scores, points = self.run_inference(processed_image)
        return scores, points

    def draw_pred(self, image: cv2.UMat, scores: List[float], points: List[List[float]]) -> cv2.UMat:
        """Draw each predicted point as a filled circle on a copy of ``image``.

        Coordinates are truncated to int. Points outside the image bounds are
        skipped. ``scores`` is only zipped with ``points`` (its values are
        unused), so drawing stops at the shorter of the two sequences.

        Returns:
            A new array with radius-5 circles drawn in colour (255, 0, 0).
            That is blue for BGR images and red for RGB.
        """
        # np.array copies its input by default, so the caller's image is untouched.
        marked_img = np.array(image)
        height, width = marked_img.shape[:2]
        for point, _score in zip(points, scores):
            x, y = int(point[0]), int(point[1])
            # Keep only points that land inside the image.
            if 0 <= x < width and 0 <= y < height:
                cv2.circle(marked_img, (x, y), 5, (255, 0, 0), -1)
        return marked_img