yolo12138 commited on
Commit
12c57ee
·
1 Parent(s): cd5089e

update preprocess

Browse files
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
  import time
3
- from onnx.counting import Counting
4
 
5
 
6
- counting = Counting("onnx/apgcc.onnx")
7
 
8
 
9
  def filter_with_threshold(scores, points, threshold):
@@ -19,12 +19,14 @@ def filter_with_threshold(scores, points, threshold):
19
  def pred(img, threshold):
20
  # 计算处理时间
21
  start_at = time.time()
 
 
22
 
23
- scores, points = counting.pred(img, is_rgb=True)
24
 
25
  scores, points = filter_with_threshold(scores, points, threshold)
26
 
27
- draw = counting.draw_pred(img, scores, points)
28
 
29
  elapsed_time = time.time() - start_at
30
  use_time = f"use: {elapsed_time:.3f}s"
 
1
  import gradio as gr
2
  import time
3
+ from counting.counting import Counting
4
 
5
 
6
+ counting = Counting("counting/apgcc.onnx")
7
 
8
 
9
  def filter_with_threshold(scores, points, threshold):
 
19
  def pred(img, threshold):
20
  # 计算处理时间
21
  start_at = time.time()
22
+
23
+ processed_image, processed_image_original = counting.preprocess_image(img, True)
24
 
25
+ scores, points = counting.run_inference(processed_image)
26
 
27
  scores, points = filter_with_threshold(scores, points, threshold)
28
 
29
+ draw = counting.draw_pred(processed_image_original, scores, points)
30
 
31
  elapsed_time = time.time() - start_at
32
  use_time = f"use: {elapsed_time:.3f}s"
{onnx → counting}/apgcc.onnx RENAMED
File without changes
{onnx → counting}/base_onnx.py RENAMED
File without changes
{onnx → counting}/counting.py RENAMED
@@ -27,6 +27,8 @@ class Counting(BaseONNX):
27
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28
  else:
29
  img = img
 
 
30
 
31
  # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
  # 转换为 float32 类型
@@ -56,14 +58,14 @@ class Counting(BaseONNX):
56
 
57
  if h != new_h or w != new_w:
58
  img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
59
-
60
  # 调整维度顺序 (H,W,C) -> (C,H,W)
61
  img = np.transpose(img, (2, 0, 1))
62
 
63
  # 添加 batch 维度
64
  img = np.expand_dims(img, axis=0)
65
 
66
- return img
67
 
68
  def run_inference(self, image: np.ndarray) -> any:
69
  """
@@ -97,7 +99,7 @@ class Counting(BaseONNX):
97
  else:
98
  img_bgr = image.copy()
99
 
100
- processed_image = self.preprocess_image(img_bgr, is_rgb)
101
 
102
  scores, points = self.run_inference(processed_image)
103
 
 
27
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28
  else:
29
  img = img
30
+
31
+ img_copy = img.copy()
32
 
33
  # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
34
  # 转换为 float32 类型
 
58
 
59
  if h != new_h or w != new_w:
60
  img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
61
+ img_copy = cv2.resize(img_copy, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
62
  # 调整维度顺序 (H,W,C) -> (C,H,W)
63
  img = np.transpose(img, (2, 0, 1))
64
 
65
  # 添加 batch 维度
66
  img = np.expand_dims(img, axis=0)
67
 
68
+ return img, img_copy
69
 
70
  def run_inference(self, image: np.ndarray) -> any:
71
  """
 
99
  else:
100
  img_bgr = image.copy()
101
 
102
+ processed_image, _ = self.preprocess_image(img_bgr, is_rgb)
103
 
104
  scores, points = self.run_inference(processed_image)
105