# peopleCounting/counting/counting.py
# Last commit: 12c57ee "update preprocess" (yolo12138)
import onnxruntime
import numpy as np
import cv2
from typing import Tuple, List, Union
from .base_onnx import BaseONNX
class Counting(BaseONNX):
    """ONNX-backed point-based counting model.

    Wraps a model loaded by ``BaseONNX`` (``self.session`` and
    ``self.input_name`` are presumably set up there -- not visible in this
    file). The model takes a normalized ``(1, 3, H, W)`` float32 tensor and
    produces two outputs, which :meth:`pred` unpacks as ``scores`` and
    ``points``.
    """

    # Longest allowed image side in pixels; larger inputs are downscaled.
    UPPER_BOUND = 2560
    # Both spatial dims of the model input are floored to a multiple of this.
    MULTIPLE_OF = 32

    # Per-channel normalization statistics, in RGB order.
    _MEAN = np.array([0.485, 0.456, 0.406])
    _STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, model_path):
        super().__init__(model_path)

    def _normalize(self, img: np.ndarray) -> np.ndarray:
        """Return a float32 copy of ``img`` divided by 255 and mean/std normalized.

        Dividing by 255 assumes 8-bit input -- NOTE(review): confirm callers
        never pass float images already in [0, 1].
        """
        out = img.astype(np.float32)
        out /= 255.0
        # In-place ops keep the float32 dtype of ``out``.
        out -= self._MEAN
        out /= self._STD
        return out

    def _resize_for_model(self, img: np.ndarray) -> np.ndarray:
        """Downscale so the longest side is <= UPPER_BOUND, then floor each side
        to a multiple of MULTIPLE_OF (never below one MULTIPLE_OF)."""
        longest = max(img.shape[:2])
        if longest > self.UPPER_BOUND:
            scale = self.UPPER_BOUND / longest
            img = cv2.resize(img, None, fx=scale, fy=scale)

        h, w = img.shape[:2]
        step = self.MULTIPLE_OF
        # Clamp to at least one step: a side shorter than ``step`` would
        # otherwise floor to 0 and make cv2.resize fail.
        new_h = max(step, (h // step) * step)
        new_w = max(step, (w // step) * step)
        if (h, w) != (new_h, new_w):
            img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        return img

    def preprocess_image(
        self, img: np.ndarray, is_rgb: bool = True
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Turn an image into the model's input tensor.

        Steps: optional BGR->RGB conversion, /255 scaling, per-channel
        mean/std normalization, downscaling so the longest side is at most
        ``UPPER_BOUND``, and flooring both sides to a multiple of
        ``MULTIPLE_OF``.

        Args:
            img: Input image, ``(H, W, 3)``, BGR or RGB.
            is_rgb: Whether ``img`` is already RGB. Defaults to ``True``.

        Returns:
            ``(tensor, img_copy)`` where ``tensor`` is float32 with shape
            ``(1, 3, H', W')`` and ``img_copy`` is the un-normalized RGB image
            resized to ``(H', W')`` so it lines up with the tensor.
        """
        if not is_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Un-normalized copy for callers that want to visualize the input.
        img_copy = img.copy()

        tensor = self._resize_for_model(self._normalize(img))
        h, w = tensor.shape[:2]

        # Both the upper-bound downscale and the multiple-of-32 floor can
        # change the size; keep the copy at exactly the tensor's final size.
        if img_copy.shape[:2] != (h, w):
            img_copy = cv2.resize(img_copy, (w, h), interpolation=cv2.INTER_LINEAR)

        # (H, W, C) -> (C, H, W), then add the batch dimension.
        tensor = np.expand_dims(np.transpose(tensor, (2, 0, 1)), axis=0)
        return tensor, img_copy

    def run_inference(self, image: np.ndarray) -> List[np.ndarray]:
        """Run the ONNX session on an already-preprocessed tensor.

        Args:
            image: Input tensor, e.g. as returned by :meth:`preprocess_image`.

        Returns:
            The list returned by ``session.run`` -- every model output, since
            ``output_names`` is ``None``.
        """
        return self.session.run(None, {self.input_name: image})

    def pred(
        self, image: Union[cv2.UMat, np.ndarray, str], is_rgb: bool = True
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Preprocess an image (or image path) and return the model's outputs.

        Args:
            image: An image array, or a path to an image file.
            is_rgb: Whether an in-memory ``image`` is RGB. Ignored for paths,
                because ``cv2.imread`` always decodes to BGR.

        Returns:
            ``(scores, points)``: the model's two outputs, unchanged.
            NOTE(review): the points are presumably in the pixel frame of the
            preprocessed (possibly resized) tensor, not the original image --
            confirm before drawing them on the unresized input.

        Raises:
            ValueError: if ``image`` is a path that cv2 cannot read.
        """
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not read image: {image!r}")
            is_rgb = False
        else:
            # preprocess_image never mutates its input, so no defensive copy.
            img = image
        processed_image, _ = self.preprocess_image(img, is_rgb)
        scores, points = self.run_inference(processed_image)
        return scores, points

    def draw_pred(
        self,
        image: cv2.UMat,
        scores: List[float],
        points: List[List[float]],
        radius: int = 5,
        color: Tuple[int, int, int] = (255, 0, 0),
    ) -> np.ndarray:
        """Draw each point as a filled circle on a copy of ``image``.

        Args:
            image: Image to draw on; it is not modified.
            scores: Per-point scores; only paired with ``points`` via ``zip``,
                so drawing stops at the shorter of the two.
            points: Coordinates with ``point[0]`` = x (column) and
                ``point[1]`` = y (row). Points outside the image are skipped.
            radius: Circle radius in pixels.
            color: Circle color in the image's channel order
                (``(255, 0, 0)`` is red for RGB, blue for BGR).

        Returns:
            A new C-contiguous array with the points drawn.
        """
        # One copy, forced C-contiguous so cv2.circle can draw into it.
        canvas = np.array(image, order="C")
        height, width = canvas.shape[:2]
        for point, _score in zip(points, scores):
            # int() truncates toward zero to snap to a pixel.
            x, y = int(point[0]), int(point[1])
            if 0 <= x < width and 0 <= y < height:
                cv2.circle(canvas, (x, y), radius, color, -1)
        return canvas