Spaces:
Running
Running
import onnxruntime | |
import numpy as np | |
import cv2 | |
from typing import Tuple, List, Union | |
from .base_onnx import BaseONNX | |
class Counting(BaseONNX):
    """ONNX point-counting model wrapper.

    Prepares an image for the model, runs the ONNX session and can draw the
    predicted points onto an image.
    """

    # Longest allowed image side in pixels; larger images are downscaled.
    UPPER_BOUND = 2560
    # Model input height/width are rounded down to a multiple of this value.
    MULTIPLE_OF = 32

    def __init__(self, model_path):
        super().__init__(model_path)

    def preprocess_image(self, img: cv2.UMat, is_rgb: bool = True):
        """
        Preprocess an image: color conversion, normalization and resizing.

        Args:
            img: Input image in BGR or RGB channel order. Assumes an HxWx3
                numpy array (``.astype``/``.shape`` are used) -- confirm with
                callers.
            is_rgb: Whether ``img`` is already RGB. If False it is converted
                from BGR to RGB. Defaults to True.

        Returns:
            A tuple ``(tensor, resized_img)``:
              * ``tensor``: float32 array of shape (1, 3, H, W), normalized
                with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225)
                after scaling to [0, 1].
              * ``resized_img``: the un-normalized RGB image resized to the
                same (H, W) as ``tensor``, so both share one coordinate frame.
        """
        if not is_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Un-normalized copy returned for display/drawing purposes.
        img_copy = img.copy()

        # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        img = img.astype(np.float32)
        img /= 255.0
        img -= np.array([0.485, 0.456, 0.406])
        img /= np.array([0.229, 0.224, 0.225])

        # Downscale (keeping aspect ratio) if the longest side exceeds the cap.
        origin_h, origin_w = img.shape[:2]
        max_size = max(origin_h, origin_w)
        if max_size > self.UPPER_BOUND:
            scale = self.UPPER_BOUND / max_size
            img = cv2.resize(img, None, fx=scale, fy=scale)

        # Round height/width down to a multiple of MULTIPLE_OF, but never below
        # one multiple: a 0-px target would make cv2.resize fail on tiny images.
        h, w = img.shape[:2]
        new_h = max(self.MULTIPLE_OF, (h // self.MULTIPLE_OF) * self.MULTIPLE_OF)
        new_w = max(self.MULTIPLE_OF, (w // self.MULTIPLE_OF) * self.MULTIPLE_OF)
        if (h, w) != (new_h, new_w):
            img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        # Compare against the copy's OWN size: the UPPER_BOUND step alone may
        # already have changed the tensor's size even when no rounding is needed.
        if img_copy.shape[:2] != (new_h, new_w):
            img_copy = cv2.resize(img_copy, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # (H, W, C) -> (C, H, W), then add the batch dimension.
        img = np.transpose(img, (2, 0, 1))
        img = np.expand_dims(img, axis=0)
        return img, img_copy

    def run_inference(self, image: np.ndarray) -> List[np.ndarray]:
        """
        Run the ONNX session on a preprocessed tensor.

        Args:
            image: Tensor as produced by ``preprocess_image`` (float32,
                shape (1, 3, H, W)).

        Returns:
            The list of all model outputs, in the session's output order
            (``None`` is passed as output names).
        """
        # self.session / self.input_name are presumably set up by BaseONNX.
        return self.session.run(None, {self.input_name: image})

    def pred(self, image: Union[cv2.UMat, str], is_rgb: bool = True) -> Tuple[List[float], List[List[float]]]:
        """
        Predict scores and point locations for one image.

        Args:
            image: An image array, or a file path to read with ``cv2.imread``.
            is_rgb: Channel order of an in-memory ``image``. Ignored for paths,
                because ``cv2.imread`` always yields BGR.

        Returns:
            ``(scores, points)``: the first two outputs of the model.
            NOTE(review): points are presumably in the coordinate frame of the
            preprocessed (possibly resized) image, not the original -- confirm.

        Raises:
            FileNotFoundError: If ``image`` is a path that cannot be read.
        """
        if isinstance(image, str):
            path = image
            image = cv2.imread(path)
            if image is None:
                raise FileNotFoundError(f"Could not read image: {path}")
            is_rgb = False
        # preprocess_image does not modify its input, so no defensive copy here.
        processed_image, _ = self.preprocess_image(image, is_rgb)
        scores, points = self.run_inference(processed_image)
        return scores, points

    def draw_pred(self, image: cv2.UMat, scores: List[float], points: List[List[float]]) -> cv2.UMat:
        """
        Return a copy of ``image`` with a filled circle (radius 5) at each point.

        Coordinates are truncated to int. Points outside the image are
        skipped. Only as many points as there are scores are drawn (zip).
        The color (255, 0, 0) is red for RGB images and blue for BGR images.
        """
        # np.array copies by default, so the caller's image is never modified.
        marked_img = np.array(image)
        height, width = marked_img.shape[:2]
        for point, _score in zip(points, scores):
            x, y = int(point[0]), int(point[1])
            if 0 <= x < width and 0 <= y < height:
                cv2.circle(marked_img, (x, y), 5, (255, 0, 0), -1)
        return marked_img