Spaces:
Running
Running
Commit
Β·
8950d34
1
Parent(s):
c534101
Update main.py
Browse files
main.py
CHANGED
@@ -43,13 +43,12 @@ MODEL_PATHS = {
|
|
43 |
1: "models_small_version_2/best.pt", # Small v2 PT
|
44 |
2: "models_medium/best.pt", # Medium v1 PT
|
45 |
3: "models_medium_version_2/best.pt", # Medium v2 PT
|
46 |
-
4: "models_large/best.pt", # Large PT (no ONNX for large)
|
47 |
|
48 |
# ONNX models (optimized with v1.19 + opset 21)
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
}
|
54 |
|
55 |
# Config paths - ONNX uses same config as PT version
|
@@ -58,19 +57,18 @@ CONFIG_PATHS = {
|
|
58 |
1: "config_version2.yaml", # Small v2 PT
|
59 |
2: "config.yaml", # Medium v1 PT
|
60 |
3: "config_version2.yaml", # Medium v2 PT
|
61 |
-
4: "config.yaml", #
|
62 |
-
5: "
|
63 |
-
6: "
|
64 |
-
7: "
|
65 |
-
8: "config_version2.yaml" # Medium v2 ONNX
|
66 |
}
|
67 |
|
68 |
# Mapping from PT index to ONNX index
|
69 |
PT_TO_ONNX_MAPPING = {
|
70 |
-
0:
|
71 |
-
1:
|
72 |
-
2:
|
73 |
-
3:
|
74 |
4: None # Large has no ONNX
|
75 |
}
|
76 |
|
@@ -80,7 +78,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
|
|
80 |
Enhanced model selection with performance optimization info
|
81 |
"""
|
82 |
# If user explicitly selects ONNX index (5..8) => use that ONNX with optimizations
|
83 |
-
if select_models in (5, 6, 7
|
84 |
onnx_path = Path(MODEL_PATHS.get(select_models, ""))
|
85 |
if not onnx_path.exists():
|
86 |
raise FileNotFoundError(
|
@@ -89,7 +87,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
|
|
89 |
return select_models
|
90 |
|
91 |
# Normalize to valid PT indices
|
92 |
-
if select_models not in (0, 1, 2, 3
|
93 |
select_models = 2 # default to medium v1
|
94 |
|
95 |
# PT preferred for 0..4
|
@@ -112,12 +110,11 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
|
|
112 |
|
113 |
def load_detector(select_models: int = 2, prefer_onnx: bool = True):
|
114 |
"""
|
115 |
-
|
116 |
|
117 |
Args:
|
118 |
select_models: Model selection
|
119 |
-
|
120 |
-
- 5-8: ONNX models (with maximum optimizations)
|
121 |
prefer_onnx: Whether to prefer ONNX format for fallback
|
122 |
"""
|
123 |
global detector, comparator, visualizer
|
@@ -147,14 +144,14 @@ def load_detector(select_models: int = 2, prefer_onnx: bool = True):
|
|
147 |
# Log model info with optimization status
|
148 |
model_type = "ONNX" if MODEL_PATHS[actual_model_index].endswith('.onnx') else "PyTorch"
|
149 |
model_labels = [
|
150 |
-
"Small v1", "Small v2", "Medium v1", "Medium v2",
|
151 |
"Small v1 ONNX", "Small v2 ONNX", "Medium v1 ONNX", "Medium v2 ONNX"
|
152 |
]
|
153 |
|
154 |
if 0 <= select_models < len(model_labels):
|
155 |
model_size = model_labels[select_models]
|
156 |
else:
|
157 |
-
raise ValueError(f"select_models={select_models} must be 0-
|
158 |
|
159 |
# Enhanced logging for optimization status
|
160 |
optimization_status = "π MAXIMUM OPTIMIZATIONS" if model_type == "ONNX" else "π¦ Standard PyTorch"
|
@@ -308,7 +305,7 @@ async def detect_single_image(
|
|
308 |
"""
|
309 |
try:
|
310 |
# Validate select_models
|
311 |
-
if select_models not in list(range(0,
|
312 |
raise HTTPException(status_code=400,
|
313 |
detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
|
314 |
|
@@ -499,7 +496,7 @@ async def compare_vehicle_damages(
|
|
499 |
"""
|
500 |
try:
|
501 |
# Validate select_models
|
502 |
-
if select_models not in list(range(0,
|
503 |
raise HTTPException(status_code=400,
|
504 |
detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
|
505 |
|
|
|
43 |
1: "models_small_version_2/best.pt", # Small v2 PT
|
44 |
2: "models_medium/best.pt", # Medium v1 PT
|
45 |
3: "models_medium_version_2/best.pt", # Medium v2 PT
|
|
|
46 |
|
47 |
# ONNX models (optimized with v1.19 + opset 21)
|
48 |
+
4: "models_small/best.onnx", # Small v1 ONNX
|
49 |
+
5: "models_small_version_2/best.onnx", # Small v2 ONNX
|
50 |
+
6: "models_medium/best.onnx", # Medium v1 ONNX
|
51 |
+
7: "models_medium_version_2/best.onnx" # Medium v2 ONNX
|
52 |
}
|
53 |
|
54 |
# Config paths - ONNX uses same config as PT version
|
|
|
57 |
1: "config_version2.yaml", # Small v2 PT
|
58 |
2: "config.yaml", # Medium v1 PT
|
59 |
3: "config_version2.yaml", # Medium v2 PT
|
60 |
+
4: "config.yaml", # Small v1 ONNX
|
61 |
+
5: "config_version2.yaml", # Small v2 ONNX
|
62 |
+
6: "config.yaml", # Medium v1 ONNX
|
63 |
+
7: "config_version2.yaml" # Medium v2 ONNX
|
|
|
64 |
}
|
65 |
|
66 |
# Mapping from PT index to ONNX index
|
67 |
PT_TO_ONNX_MAPPING = {
|
68 |
+
0: 4, # Small v1 -> ONNX
|
69 |
+
1: 5, # Small v2 -> ONNX
|
70 |
+
2: 6, # Medium v1 -> ONNX
|
71 |
+
3: 7, # Medium v2 -> ONNX
|
72 |
4: None # Large has no ONNX
|
73 |
}
|
74 |
|
|
|
78 |
Enhanced model selection with performance optimization info
|
79 |
"""
|
80 |
# If user explicitly selects ONNX index (5..8) => use that ONNX with optimizations
|
81 |
+
if select_models in (4, 5, 6, 7):
|
82 |
onnx_path = Path(MODEL_PATHS.get(select_models, ""))
|
83 |
if not onnx_path.exists():
|
84 |
raise FileNotFoundError(
|
|
|
87 |
return select_models
|
88 |
|
89 |
# Normalize to valid PT indices
|
90 |
+
if select_models not in (0, 1, 2, 3):
|
91 |
select_models = 2 # default to medium v1
|
92 |
|
93 |
# PT preferred for 0..4
|
|
|
110 |
|
111 |
def load_detector(select_models: int = 2, prefer_onnx: bool = True):
|
112 |
"""
|
113 |
+
|
114 |
|
115 |
Args:
|
116 |
select_models: Model selection
|
117 |
+
|
|
|
118 |
prefer_onnx: Whether to prefer ONNX format for fallback
|
119 |
"""
|
120 |
global detector, comparator, visualizer
|
|
|
144 |
# Log model info with optimization status
|
145 |
model_type = "ONNX" if MODEL_PATHS[actual_model_index].endswith('.onnx') else "PyTorch"
|
146 |
model_labels = [
|
147 |
+
"Small v1", "Small v2", "Medium v1", "Medium v2",
|
148 |
"Small v1 ONNX", "Small v2 ONNX", "Medium v1 ONNX", "Medium v2 ONNX"
|
149 |
]
|
150 |
|
151 |
if 0 <= select_models < len(model_labels):
|
152 |
model_size = model_labels[select_models]
|
153 |
else:
|
154 |
+
raise ValueError(f"select_models={select_models} must be 0-7")
|
155 |
|
156 |
# Enhanced logging for optimization status
|
157 |
optimization_status = "π MAXIMUM OPTIMIZATIONS" if model_type == "ONNX" else "π¦ Standard PyTorch"
|
|
|
305 |
"""
|
306 |
try:
|
307 |
# Validate select_models
|
308 |
+
if select_models not in list(range(0, 8)):
|
309 |
raise HTTPException(status_code=400,
|
310 |
detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
|
311 |
|
|
|
496 |
"""
|
497 |
try:
|
498 |
# Validate select_models
|
499 |
+
if select_models not in list(range(0, 8)):
|
500 |
raise HTTPException(status_code=400,
|
501 |
detail="select_models must be 0-8 (0-4=PyTorch, 5-8=ONNX optimized)")
|
502 |
|