minh9972t12 commited on
Commit
8950d34
Β·
1 Parent(s): c534101

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -23
main.py CHANGED
@@ -43,13 +43,12 @@ MODEL_PATHS = {
43
  1: "models_small_version_2/best.pt", # Small v2 PT
44
  2: "models_medium/best.pt", # Medium v1 PT
45
  3: "models_medium_version_2/best.pt", # Medium v2 PT
46
- 4: "models_large/best.pt", # Large PT (no ONNX for large)
47
 
48
  # ONNX models (optimized with v1.19 + opset 21)
49
- 5: "models_small/best.onnx", # Small v1 ONNX
50
- 6: "models_small_version_2/best.onnx", # Small v2 ONNX
51
- 7: "models_medium/best.onnx", # Medium v1 ONNX
52
- 8: "models_medium_version_2/best.onnx" # Medium v2 ONNX
53
  }
54
 
55
  # Config paths - ONNX uses same config as PT version
@@ -58,19 +57,18 @@ CONFIG_PATHS = {
58
  1: "config_version2.yaml", # Small v2 PT
59
  2: "config.yaml", # Medium v1 PT
60
  3: "config_version2.yaml", # Medium v2 PT
61
- 4: "config.yaml", # Large PT
62
- 5: "config.yaml", # Small v1 ONNX
63
- 6: "config_version2.yaml", # Small v2 ONNX
64
- 7: "config.yaml", # Medium v1 ONNX
65
- 8: "config_version2.yaml" # Medium v2 ONNX
66
  }
67
 
68
  # Mapping from PT index to ONNX index
69
  PT_TO_ONNX_MAPPING = {
70
- 0: 5, # Small v1 -> ONNX
71
- 1: 6, # Small v2 -> ONNX
72
- 2: 7, # Medium v1 -> ONNX
73
- 3: 8, # Medium v2 -> ONNX
74
  4: None # Large has no ONNX
75
  }
76
 
@@ -80,7 +78,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
80
  Enhanced model selection with performance optimization info
81
  """
82
  # If user explicitly selects ONNX index (5..8) => use that ONNX with optimizations
83
- if select_models in (5, 6, 7, 8):
84
  onnx_path = Path(MODEL_PATHS.get(select_models, ""))
85
  if not onnx_path.exists():
86
  raise FileNotFoundError(
@@ -89,7 +87,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
89
  return select_models
90
 
91
  # Normalize to valid PT indices
92
- if select_models not in (0, 1, 2, 3, 4):
93
  select_models = 2 # default to medium v1
94
 
95
  # PT preferred for 0..4
@@ -112,12 +110,11 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
112
 
113
  def load_detector(select_models: int = 2, prefer_onnx: bool = True):
114
  """
115
- Load detector with optimized ONNX Runtime v1.19 support
116
 
117
  Args:
118
  select_models: Model selection
119
- - 0-4: PyTorch models (original logic)
120
- - 5-8: ONNX models (with maximum optimizations)
121
  prefer_onnx: Whether to prefer ONNX format for fallback
122
  """
123
  global detector, comparator, visualizer
@@ -147,14 +144,14 @@ def load_detector(select_models: int = 2, prefer_onnx: bool = True):
147
  # Log model info with optimization status
148
  model_type = "ONNX" if MODEL_PATHS[actual_model_index].endswith('.onnx') else "PyTorch"
149
  model_labels = [
150
- "Small v1", "Small v2", "Medium v1", "Medium v2", "Large",
151
  "Small v1 ONNX", "Small v2 ONNX", "Medium v1 ONNX", "Medium v2 ONNX"
152
  ]
153
 
154
  if 0 <= select_models < len(model_labels):
155
  model_size = model_labels[select_models]
156
  else:
157
- raise ValueError(f"select_models={select_models} must be 0-8")
158
 
159
  # Enhanced logging for optimization status
160
  optimization_status = "πŸš€ MAXIMUM OPTIMIZATIONS" if model_type == "ONNX" else "πŸ“¦ Standard PyTorch"
@@ -308,7 +305,7 @@ async def detect_single_image(
308
  """
309
  try:
310
  # Validate select_models
311
- if select_models not in list(range(0, 9)):
312
  raise HTTPException(status_code=400,
313
  detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
314
 
@@ -499,7 +496,7 @@ async def compare_vehicle_damages(
499
  """
500
  try:
501
  # Validate select_models
502
- if select_models not in list(range(0, 9)):
503
  raise HTTPException(status_code=400,
504
  detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
505
 
 
43
  1: "models_small_version_2/best.pt", # Small v2 PT
44
  2: "models_medium/best.pt", # Medium v1 PT
45
  3: "models_medium_version_2/best.pt", # Medium v2 PT
 
46
 
47
  # ONNX models (optimized with v1.19 + opset 21)
48
+ 4: "models_small/best.onnx", # Small v1 ONNX
49
+ 5: "models_small_version_2/best.onnx", # Small v2 ONNX
50
+ 6: "models_medium/best.onnx", # Medium v1 ONNX
51
+ 7: "models_medium_version_2/best.onnx" # Medium v2 ONNX
52
  }
53
 
54
  # Config paths - ONNX uses same config as PT version
 
57
  1: "config_version2.yaml", # Small v2 PT
58
  2: "config.yaml", # Medium v1 PT
59
  3: "config_version2.yaml", # Medium v2 PT
60
+ 4: "config.yaml", # Small v1 ONNX
61
+ 5: "config_version2.yaml", # Small v2 ONNX
62
+ 6: "config.yaml", # Medium v1 ONNX
63
+ 7: "config_version2.yaml" # Medium v2 ONNX
 
64
  }
65
 
66
  # Mapping from PT index to ONNX index
67
  PT_TO_ONNX_MAPPING = {
68
+ 0: 4, # Small v1 -> ONNX
69
+ 1: 5, # Small v2 -> ONNX
70
+ 2: 6, # Medium v1 -> ONNX
71
+ 3: 7, # Medium v2 -> ONNX
72
  4: None # Large has no ONNX
73
  }
74
 
 
78
  Enhanced model selection with performance optimization info
79
  """
80
  # If user explicitly selects ONNX index (5..8) => use that ONNX with optimizations
81
+ if select_models in (4, 5, 6, 7):
82
  onnx_path = Path(MODEL_PATHS.get(select_models, ""))
83
  if not onnx_path.exists():
84
  raise FileNotFoundError(
 
87
  return select_models
88
 
89
  # Normalize to valid PT indices
90
+ if select_models not in (0, 1, 2, 3):
91
  select_models = 2 # default to medium v1
92
 
93
  # PT preferred for 0..4
 
110
 
111
  def load_detector(select_models: int = 2, prefer_onnx: bool = True):
112
  """
113
+
114
 
115
  Args:
116
  select_models: Model selection
117
+
 
118
  prefer_onnx: Whether to prefer ONNX format for fallback
119
  """
120
  global detector, comparator, visualizer
 
144
  # Log model info with optimization status
145
  model_type = "ONNX" if MODEL_PATHS[actual_model_index].endswith('.onnx') else "PyTorch"
146
  model_labels = [
147
+ "Small v1", "Small v2", "Medium v1", "Medium v2",
148
  "Small v1 ONNX", "Small v2 ONNX", "Medium v1 ONNX", "Medium v2 ONNX"
149
  ]
150
 
151
  if 0 <= select_models < len(model_labels):
152
  model_size = model_labels[select_models]
153
  else:
154
+ raise ValueError(f"select_models={select_models} must be 0-7")
155
 
156
  # Enhanced logging for optimization status
157
  optimization_status = "πŸš€ MAXIMUM OPTIMIZATIONS" if model_type == "ONNX" else "πŸ“¦ Standard PyTorch"
 
305
  """
306
  try:
307
  # Validate select_models
308
+ if select_models not in list(range(0, 8)):
309
  raise HTTPException(status_code=400,
310
  detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
311
 
 
496
  """
497
  try:
498
  # Validate select_models
499
+ if select_models not in list(range(0, 8)):
500
  raise HTTPException(status_code=400,
501
  detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
502