yolo12138 commited on
Commit
fbf373d
·
1 Parent(s): 20d8e4e

Reset git-lfs tracking

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ coverage
2
+ .DS_Store
3
+ __pycache__
4
+ *.pyc
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from onnx.counting import Counting
4
+
5
+
6
+ counting = Counting("onnx/apgcc.onnx")
7
+
8
+
9
+ def filter_with_threshold(scores, points, threshold):
10
+ filtered_scores = []
11
+ filtered_points = []
12
+ for score, point in zip(scores, points):
13
+ if score > threshold:
14
+ filtered_scores.append(score)
15
+ filtered_points.append(point)
16
+ return filtered_scores, filtered_points
17
+
18
+
19
+ def pred(img, threshold):
20
+ # 计算处理时间
21
+ start_at = time.time()
22
+
23
+ scores, points = counting.pred(img, is_rgb=True)
24
+
25
+ scores, points = filter_with_threshold(scores, points, threshold)
26
+
27
+ draw = counting.draw_pred(img, scores, points)
28
+
29
+ elapsed_time = time.time() - start_at
30
+ use_time = f"use: {elapsed_time:.3f}s"
31
+
32
+
33
+ total = len(points)
34
+
35
+ return draw, total, use_time
36
+
37
+
38
+ model_description = """
39
+ # APGCC People Counting
40
+
41
+ APGCC (Adaptive Perspective Guidance for Crowd Counting)
42
+
43
+ ### based on
44
+
45
+ - [APGCC](https://github.com/AaronCIH/APGCC)
46
+ """
47
+
48
+ demo = gr.Interface(
49
+ description=model_description,
50
+ fn=pred,
51
+ inputs=["image",
52
+ gr.Slider(0, 1, 0.5, label="Threshold")],
53
+ outputs=[
54
+ "image",
55
+ gr.Number(label="Count"),
56
+ gr.Textbox(label="useTime"),
57
+ ],
58
+ examples=[
59
+ ["examples/crowd-001.jpg", 0.5],
60
+ ["examples/crowd-002.jpg", 0.5],
61
+ ["examples/image.png", 0.5],
62
+ ["examples/image2.png", 0.5],
63
+ ["examples/few-001.png", 0.5],
64
+ ])
65
+
66
+ demo.launch()
examples/crowd-001.jpg ADDED

Git LFS Details

  • SHA256: fffb92cbe6928641dd501724e52f8679df5c434ab16984c81d8833e3fd36f2b7
  • Pointer size: 131 Bytes
  • Size of remote file: 410 kB
examples/crowd-002.jpg ADDED

Git LFS Details

  • SHA256: ae88bdb130de7b1b498a24fd22be972a33c25ddcc308cba5dacd5d726a0981f6
  • Pointer size: 131 Bytes
  • Size of remote file: 760 kB
examples/few-001.png ADDED

Git LFS Details

  • SHA256: 6893e80042edbda26baaa3319bd46b4d22b229ba2e7f8cc266cf38d2ffc35255
  • Pointer size: 131 Bytes
  • Size of remote file: 299 kB
examples/image.png ADDED

Git LFS Details

  • SHA256: afa8c18919f5f8237d9837fdbc67dcfd42d4db94380d10a01d80ee3960b75129
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
examples/image2.png ADDED

Git LFS Details

  • SHA256: c1ae4f5270afdadca3622fe091d3d6fcab0bf568ddfb630c62ba32d3b20831c5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.25 MB
onnx/apgcc.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85904b6137f2fd7f94fe6acc1d996f84d5a76dfeff904d10336584e5a4db68eb
3
+ size 71659567
onnx/base_onnx.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import numpy as np
3
+ import cv2
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Tuple, Union, List
6
+
7
+ class BaseONNX(ABC):
8
+ def __init__(self, model_path: str):
9
+ """初始化ONNX模型基类
10
+
11
+ Args:
12
+ model_path (str): ONNX模型路径
13
+ input_size (tuple): 模型输入尺寸 (width, height)
14
+ """
15
+ self.session = onnxruntime.InferenceSession(model_path)
16
+ self.input_name = self.session.get_inputs()[0].name
17
+
18
+ def load_image(self, image: Union[cv2.UMat, str]) -> cv2.UMat:
19
+ """加载图像
20
+
21
+ Args:
22
+ image (Union[cv2.UMat, str]): 图像路径或cv2图像对象
23
+
24
+ Returns:
25
+ cv2.UMat: 加载的图像
26
+ """
27
+ if isinstance(image, str):
28
+ return cv2.imread(image)
29
+ return image.copy()
30
+
31
+ @abstractmethod
32
+ def preprocess_image(self, img_bgr: cv2.UMat, *args, **kwargs) -> np.ndarray:
33
+ """图像预处理抽象方法
34
+
35
+ Args:
36
+ img_bgr (cv2.UMat): BGR格式的输入图像
37
+
38
+ Returns:
39
+ np.ndarray: 预处理后的图像
40
+ """
41
+ pass
42
+
43
+ @abstractmethod
44
+ def run_inference(self, image: np.ndarray) -> Any:
45
+ """运行推理的抽象方法
46
+
47
+ Args:
48
+ image (np.ndarray): 预处理后的输入图像
49
+
50
+ Returns:
51
+ Any: 模型输出结果
52
+ """
53
+ pass
54
+
55
+ @abstractmethod
56
+ def pred(self, image: Union[cv2.UMat, str], *args, **kwargs) -> Any:
57
+ """预测的抽象方法
58
+
59
+ Args:
60
+ image (Union[cv2.UMat, str]): 输入图像或图像路径
61
+
62
+ Returns:
63
+ Any: 预测结果
64
+ """
65
+ pass
66
+
67
+ @abstractmethod
68
+ def draw_pred(self, img: cv2.UMat, *args, **kwargs) -> cv2.UMat:
69
+ """绘制预测结果的抽象方法
70
+
71
+ Args:
72
+ img (cv2.UMat): 要绘制的图像
73
+
74
+ Returns:
75
+ cv2.UMat: 绘制结果后的图像
76
+ """
77
+ pass
78
+
79
+
80
+ def check_images_list(self, images: List[Union[cv2.UMat, str, np.ndarray]]):
81
+ """
82
+ 检查图像列表是否有效
83
+ """
84
+ for image in images:
85
+ if not isinstance(image, cv2.UMat) and not isinstance(image, str) and not isinstance(image, np.ndarray):
86
+ raise ValueError("The images must be a list of cv2.UMat or str or np.ndarray.")
87
+
onnx/counting.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import numpy as np
3
+ import cv2
4
+ from typing import Tuple, List, Union
5
+ from .base_onnx import BaseONNX
6
+
7
+ class Counting(BaseONNX):
8
+ UPPER_BOUND = 2560
9
+ MULTIPLE_OF = 32
10
+
11
+ def __init__(self, model_path):
12
+ super().__init__(model_path)
13
+
14
+
15
+ def preprocess_image(self, img: cv2.UMat, is_rgb: bool = True):
16
+ """
17
+ 预处理图像,包括颜色转换、缩放和标准化
18
+
19
+ Args:
20
+ img: 输入图像,BGR或RGB格式
21
+ is_rgb: 是否已经是RGB格式,默认为True
22
+
23
+ Returns:
24
+ 预处理后的图像张量,形状为(1, 3, H, W)
25
+ """
26
+ if not is_rgb:
27
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28
+ else:
29
+ img = img
30
+
31
+ # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
+ # 转换为 float32 类型
33
+ img = img.astype(np.float32)
34
+ # 除以 255.0
35
+ img /= 255.0
36
+ # 减去均值
37
+ img -= np.array([0.485, 0.456, 0.406])
38
+ # 除以标准差
39
+ img /= np.array([0.229, 0.224, 0.225])
40
+
41
+
42
+ # 检查图像大小是否超过上限
43
+ origin_h, origin_w = img.shape[:2]
44
+ max_size = max(origin_h, origin_w)
45
+
46
+ if max_size > self.UPPER_BOUND:
47
+ scale = self.UPPER_BOUND / max_size
48
+ img = cv2.resize(img, None, fx=scale, fy=scale)
49
+
50
+
51
+ h, w = img.shape[:2]
52
+
53
+ # 确保图像尺寸是32的倍数
54
+ new_h = (h // self.MULTIPLE_OF) * self.MULTIPLE_OF
55
+ new_w = (w // self.MULTIPLE_OF) * self.MULTIPLE_OF
56
+
57
+ if h != new_h or w != new_w:
58
+ img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
59
+
60
+ # 调整维度顺序 (H,W,C) -> (C,H,W)
61
+ img = np.transpose(img, (2, 0, 1))
62
+
63
+ # 添加 batch 维度
64
+ img = np.expand_dims(img, axis=0)
65
+
66
+ return img
67
+
68
+ def run_inference(self, image: np.ndarray) -> any:
69
+ """
70
+ Run inference on the image.
71
+
72
+ Args:
73
+ image (np.ndarray): The image to run inference on.
74
+
75
+ Returns:
76
+ tuple: A tuple containing the detection results and labels.
77
+ """
78
+
79
+ # 运行推理
80
+ result = self.session.run(None, {self.input_name: image})
81
+
82
+ return result
83
+
84
+ def pred(self, image: List[Union[cv2.UMat, str]], is_rgb: bool = True) -> Tuple[List[float], List[List[float]],]:
85
+ """
86
+ Predict the detection results of the image.
87
+
88
+ Args:
89
+ image (cv2.UMat, str): The image to predict.
90
+
91
+ Returns:
92
+
93
+ """
94
+ if isinstance(image, str):
95
+ img_bgr = cv2.imread(image)
96
+ is_rgb = False
97
+ else:
98
+ img_bgr = image.copy()
99
+
100
+ processed_image = self.preprocess_image(img_bgr, is_rgb)
101
+
102
+ scores, points = self.run_inference(processed_image)
103
+
104
+ return scores, points
105
+
106
+
107
+ def draw_pred(self, image: cv2.UMat, scores: List[float], points: List[List[float]]) -> cv2.UMat:
108
+
109
+ marked_img = np.array(image.copy())
110
+
111
+ for point, score in zip(points, scores):
112
+ # 确保点坐标在合理范围内
113
+ x, y = int(point[0]), int(point[1])
114
+ if 0 <= x < marked_img.shape[1] and 0 <= y < marked_img.shape[0]:
115
+ cv2.circle(marked_img, (x, y), 5, (255, 0, 0), -1)
116
+
117
+ return marked_img
118
+
119
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ opencv-python
2
+ onnxruntime