minh9972t12 commited on
Commit
52f96c3
·
verified ·
1 Parent(s): 5f69192

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +17 -25
main.py CHANGED
@@ -1,7 +1,5 @@
1
  import io
2
  from typing import List, Dict
3
-
4
- import torch
5
  import uvicorn
6
  import numpy as np
7
  import uuid
@@ -18,18 +16,9 @@ from src.comparison import DamageComparator
18
  from src.visualization import DamageVisualizer
19
  from pathlib import Path
20
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
21
  import gc
22
 
23
-
24
- _initialized = False
25
- def init_globals_once(select_models=2, prefer_onnx=True):
26
- global detector, comparator, visualizer, _initialized
27
- if _initialized:
28
- return
29
- detector = load_detector(select_models, prefer_onnx)
30
- comparator = DamageComparator(config_path=CONFIG_PATHS[select_models])
31
- visualizer = DamageVisualizer(config_path=CONFIG_PATHS[select_models])
32
- _initialized = True
33
  app = FastAPI(
34
  title="Car Damage Detection API",
35
  description="YOLOv11-based car damage detection with DINOv2 ReID (Memory Optimized)",
@@ -105,6 +94,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
105
  if not onnx_path.exists():
106
  raise FileNotFoundError(
107
  f"Requested ONNX model index {select_models} not found at {MODEL_PATHS.get(select_models)}")
 
108
  return select_models
109
 
110
  # Normalize to valid PT indices
@@ -114,6 +104,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
114
  # PT preferred for 0..4
115
  pt_path = Path(MODEL_PATHS.get(select_models, ""))
116
  if pt_path.exists():
 
117
  return select_models
118
 
119
  # If PT not found and prefer_onnx: fallback to ONNX with optimizations
@@ -121,6 +112,7 @@ def get_optimal_model_index(select_models: int, prefer_onnx: bool = True) -> int
121
  if prefer_onnx and onnx_index is not None:
122
  onnx_path = Path(MODEL_PATHS.get(onnx_index, ""))
123
  if onnx_path.exists():
 
124
  return onnx_index
125
 
126
  # No suitable file found
@@ -180,12 +172,15 @@ def load_detector(select_models: int = 2, prefer_onnx: bool = True):
180
  else:
181
  raise ValueError(f"select_models={select_models} must be 0-11")
182
 
183
- optimization_status = "MAXIMUM OPTIMIZATIONS" if model_type == "ONNX" else "Standard PyTorch"
 
 
184
 
185
  return detector
186
 
187
 
188
  # Initialize default detector with medium model (preferring ONNX for performance)
 
189
  detector = load_detector(2, prefer_onnx=True)
190
  comparator = DamageComparator(config_path=CONFIG_PATHS[2])
191
  visualizer = DamageVisualizer(config_path=CONFIG_PATHS[2])
@@ -530,14 +525,11 @@ async def compare_vehicle_damages(
530
  before_2: UploadFile = File(...),
531
  before_3: UploadFile = File(...),
532
  before_4: UploadFile = File(...),
533
-
534
  # After delivery images (6 positions)
535
  after_1: UploadFile = File(...),
536
  after_2: UploadFile = File(...),
537
  after_3: UploadFile = File(...),
538
  after_4: UploadFile = File(...),
539
-
540
- # Model selection
541
  select_models: int = Form(2),
542
  prefer_onnx: bool = Form(True)
543
  ):
@@ -578,7 +570,7 @@ async def compare_vehicle_damages(
578
  all_after_detections = []
579
 
580
  # Use ThreadPoolExecutor to share memory (avoid OOM)
581
- print(f"Processing {len(before_images)} image pairs...")
582
 
583
  with ThreadPoolExecutor(max_workers=2) as executor: # Limit workers to avoid memory issues
584
  futures = [
@@ -619,6 +611,7 @@ async def compare_vehicle_damages(
619
  position_results.sort(key=lambda x: int(list(x.keys())[0].split('_')[1]))
620
 
621
  # Deduplicate BEFORE damages across all 6 views using DINOv2
 
622
  unique_before = comparator.deduplicate_detections_across_views(
623
  all_before_detections, all_before_images
624
  )
@@ -629,8 +622,8 @@ async def compare_vehicle_damages(
629
  )
630
 
631
  print(
632
- f"Before: {sum(len(d['boxes']) for d in all_before_detections)} detections → {len(unique_before)} unique")
633
- print(f"After: {sum(len(d['boxes']) for d in all_after_detections)} detections → {len(unique_after)} unique")
634
 
635
  # Determine overall case with deduplication
636
  actual_new_damages = max(0, len(unique_after) - len(unique_before))
@@ -717,12 +710,11 @@ async def compare_vehicle_damages(
717
 
718
 
719
  if __name__ == "__main__":
720
- import os
721
- uvicorn.run(
722
  "main:app",
723
  host="0.0.0.0",
724
- port=int(os.environ.get("PORT", 7860)),
725
- workers=2,
726
- reload=True,
727
  log_level="info"
728
- )
 
1
  import io
2
  from typing import List, Dict
 
 
3
  import uvicorn
4
  import numpy as np
5
  import uuid
 
16
  from src.visualization import DamageVisualizer
17
  from pathlib import Path
18
  from concurrent.futures import ThreadPoolExecutor, as_completed
19
+ import torch
20
  import gc
21
 
 
 
 
 
 
 
 
 
 
 
22
  app = FastAPI(
23
  title="Car Damage Detection API",
24
  description="YOLOv11-based car damage detection with DINOv2 ReID (Memory Optimized)",
 
94
  if not onnx_path.exists():
95
  raise FileNotFoundError(
96
  f"Requested ONNX model index {select_models} not found at {MODEL_PATHS.get(select_models)}")
97
+ print(f"🚀 Selected ONNX model with MAXIMUM optimizations: {MODEL_PATHS[select_models]}")
98
  return select_models
99
 
100
  # Normalize to valid PT indices
 
104
  # PT preferred for 0..4
105
  pt_path = Path(MODEL_PATHS.get(select_models, ""))
106
  if pt_path.exists():
107
+ print(f"📦 Selected PyTorch model: {MODEL_PATHS[select_models]}")
108
  return select_models
109
 
110
  # If PT not found and prefer_onnx: fallback to ONNX with optimizations
 
112
  if prefer_onnx and onnx_index is not None:
113
  onnx_path = Path(MODEL_PATHS.get(onnx_index, ""))
114
  if onnx_path.exists():
115
+ print(f"PT not found at {pt_path}, falling back to optimized ONNX {MODEL_PATHS[onnx_index]}")
116
  return onnx_index
117
 
118
  # No suitable file found
 
172
  else:
173
  raise ValueError(f"select_models={select_models} must be 0-11")
174
 
175
+ optimization_status = "🚀 MAXIMUM OPTIMIZATIONS" if model_type == "ONNX" else "📦 Standard PyTorch"
176
+ print(f"Loaded {model_size} model in {model_type} format - {optimization_status}")
177
+ print(f"✅ DINOv2 ReID enabled for damage comparison")
178
 
179
  return detector
180
 
181
 
182
  # Initialize default detector with medium model (preferring ONNX for performance)
183
+ print("🚀 Initializing API with optimized ONNX Runtime and DINOv2 ReID support...")
184
  detector = load_detector(2, prefer_onnx=True)
185
  comparator = DamageComparator(config_path=CONFIG_PATHS[2])
186
  visualizer = DamageVisualizer(config_path=CONFIG_PATHS[2])
 
525
  before_2: UploadFile = File(...),
526
  before_3: UploadFile = File(...),
527
  before_4: UploadFile = File(...),
 
528
  # After delivery images (6 positions)
529
  after_1: UploadFile = File(...),
530
  after_2: UploadFile = File(...),
531
  after_3: UploadFile = File(...),
532
  after_4: UploadFile = File(...),
 
 
533
  select_models: int = Form(2),
534
  prefer_onnx: bool = Form(True)
535
  ):
 
570
  all_after_detections = []
571
 
572
  # Use ThreadPoolExecutor to share memory (avoid OOM)
573
+ print(f"🔄 Processing {len(before_images)} image pairs using ThreadPoolExecutor...")
574
 
575
  with ThreadPoolExecutor(max_workers=2) as executor: # Limit workers to avoid memory issues
576
  futures = [
 
611
  position_results.sort(key=lambda x: int(list(x.keys())[0].split('_')[1]))
612
 
613
  # Deduplicate BEFORE damages across all 6 views using DINOv2
614
+ print("🔍 Deduplicating damages across views using DINOv2...")
615
  unique_before = comparator.deduplicate_detections_across_views(
616
  all_before_detections, all_before_images
617
  )
 
622
  )
623
 
624
  print(
625
+ f"Before: {sum(len(d['boxes']) for d in all_before_detections)} detections → {len(unique_before)} unique")
626
+ print(f"After: {sum(len(d['boxes']) for d in all_after_detections)} detections → {len(unique_after)} unique")
627
 
628
  # Determine overall case with deduplication
629
  actual_new_damages = max(0, len(unique_after) - len(unique_before))
 
710
 
711
 
712
  if __name__ == "__main__":
713
+ import os
714
+ uvicorn.run(
715
  "main:app",
716
  host="0.0.0.0",
717
+ port=int(os.environ.get("PORT", 7860)),
718
+ reload=False,
 
719
  log_level="info"
720
+ )