Spaces:
Sleeping
Sleeping
import onnxruntime | |
import numpy as np | |
import cv2 | |
from typing import Tuple, List, Union | |
from .base_onnx import BaseONNX | |
def center_crop(image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray: | |
""" | |
Center crop the image to the target size. | |
Args: | |
image (np.ndarray): The input image. | |
target_size (Tuple[int, int]): The desired output size (height, width). | |
Returns: | |
np.ndarray: The cropped image. | |
""" | |
h, w, _ = image.shape | |
target_w, target_h = target_size | |
center_x = w // 2 | |
center_y = h // 2 | |
start_x = int(center_x - target_w // 2) | |
start_y = int(center_y - target_h // 2) | |
cropped_image = image[start_y:start_y + target_h, start_x:start_x + target_w] | |
return cropped_image | |
dict_cate_names = { | |
'point': '.', | |
'other': 'x', | |
'red_king': 'K', | |
'red_advisor': 'A', | |
'red_bishop': 'B', | |
'red_knight': 'N', | |
'red_rook': 'R', | |
'red_cannon': 'C', | |
'red_pawn': 'P', | |
'black_king': 'k', | |
'black_advisor': 'a', | |
'black_bishop': 'b', | |
'black_knight': 'n', | |
'black_rook': 'r', | |
'black_cannon': 'c', | |
'black_pawn': 'p', | |
} | |
class FULL_CLASSIFIER_ONNX(BaseONNX): | |
label_2_short = dict_cate_names | |
classes_labels = list(dict_cate_names.keys()) | |
def __init__(self, | |
model_path, | |
# 输入图片大小 | |
input_size=(280, 315), # (w, h) | |
# 图片裁剪大小 | |
crop_size=(400, 450), # (w, h) | |
): | |
super().__init__(model_path, input_size) | |
self.crop_size = crop_size | |
def preprocess_image(self, img_bgr: cv2.UMat, is_rgb: bool = True): | |
if not is_rgb: | |
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) | |
else: | |
img_rgb = img_bgr | |
if img_rgb.shape[:2] != self.crop_size: | |
# 调整图片大小 执行 center crop | |
img_rgb = center_crop(img_rgb, self.crop_size) # dst_size = (w, h) | |
# resize 到 input_size | |
img_rgb = cv2.resize(img_rgb, self.input_size) | |
# normalize mean and std | |
img = (img_rgb - np.array([ 123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375]) | |
img = img.astype(np.float32) | |
# 转换为浮点型并归一化 | |
# img = img.astype(np.float32) / 255.0 | |
# 调整维度顺序 (H,W,C) -> (C,H,W) | |
img = np.transpose(img, (2, 0, 1)) | |
# 添加 batch 维度 | |
img = np.expand_dims(img, axis=0) | |
return img | |
def run_inference(self, image: np.ndarray) -> np.ndarray: | |
""" | |
Run inference on the image. | |
Args: | |
image (np.ndarray): The image to run inference on. | |
Returns: | |
tuple: A tuple containing the detection results and labels. | |
""" | |
# 运行推理 | |
outputs, = self.session.run(None, {self.input_name: image}) | |
return outputs | |
def pred(self, image: List[Union[cv2.UMat, str]], is_rgb: bool = True) -> Tuple[List[List[str]], List[List[str]], List[List[float]], str]: | |
""" | |
Predict the detection results of the image. | |
Args: | |
image (cv2.UMat, str): The image to predict. | |
Returns: | |
""" | |
if isinstance(image, str): | |
img_bgr = cv2.imread(image) | |
is_rgb = False | |
else: | |
img_bgr = image.copy() | |
image = self.preprocess_image(img_bgr, is_rgb) | |
labels = self.run_inference(image) | |
# 校验 labels 的 shape | |
assert labels.shape[1:] == (90, 16) | |
# shape (90, 16) | |
first_batch_labels = labels[0] | |
# 获取置信度最高的标签 | |
# list[int] | |
label_indexes = np.argmax(first_batch_labels, axis=-1).tolist() | |
# 将标签索引转换为标签 | |
# list[str] | |
label_names = [self.classes_labels[index] for index in label_indexes] | |
# list[str] | |
label_short = [self.label_2_short[name] for name in label_names] | |
# 获取置信度, 根据 first_batch_labels 和 label_indexes | |
confidence = first_batch_labels[np.arange(first_batch_labels.shape[0]), label_indexes] | |
label_names_10x9 = [label_names[i*9:(i+1)*9] for i in range(10)] | |
label_short_10x9 = [label_short[i*9:(i+1)*9] for i in range(10)] | |
confidence_10x9 = [confidence[i*9:(i+1)*9] for i in range(10)] | |
layout_str = "\n".join(["".join(row) for row in label_short_10x9]) | |
return label_names_10x9, label_short_10x9, confidence_10x9, layout_str | |
def draw_pred(self, image: cv2.UMat, label_index: int, label_name: str, label_short: str, confidence: float) -> cv2.UMat: | |
# 在图像上绘制预测结果 | |
cv2.putText(image, f"{label_short} {confidence:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) | |
return image | |
def draw_pred_with_result(self, image: cv2.UMat, results: List[Tuple[int, str, str, float]], cells_xyxy: np.ndarray, is_rgb: bool = True) -> cv2.UMat: | |
assert len(results) == cells_xyxy.shape[0] | |
if not is_rgb: | |
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) | |
for i, (label_index, label_name, label_short, confidence) in enumerate(results): | |
# 确保坐标是整数类型 | |
x1, y1, x2, y2 = map(int, cells_xyxy[i]) | |
if label_name.startswith('red'): | |
color = (180, 105, 255) # 粉红色 | |
elif label_name.startswith('black'): | |
color = (0, 100, 50) # 黑色 | |
else: | |
color = (0, 0, 255) # 蓝色 | |
if confidence < 0.5: | |
# yellow | |
color = (255, 255, 0) | |
# confidence:.2f 仅保留两位小数 移除 | |
label_str = f"{label_short} {confidence:.2f}" if confidence < 0.9 else f"{label_short}" | |
cv2.putText(image, label_str, (x1 + 8, y2 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1) | |
return image | |