Spaces:
Running
Running
Commit
·
4a98225
1
Parent(s):
a3450c9
Update src/comparison.py
Browse files- src/comparison.py +30 -22
src/comparison.py
CHANGED
@@ -17,6 +17,32 @@ except ImportError:
|
|
17 |
print(" Install with: pip install transformers")
|
18 |
CLIP_AVAILABLE = False
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
class DamageComparator:
|
22 |
"""Enhanced damage comparator with view-invariant re-identification"""
|
@@ -32,32 +58,14 @@ class DamageComparator:
|
|
32 |
# Device selection
|
33 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
34 |
|
35 |
-
#
|
36 |
-
self.clip_model =
|
37 |
-
self.clip_processor = None
|
38 |
-
if CLIP_AVAILABLE:
|
39 |
-
self._init_clip_model()
|
40 |
|
41 |
# ReID thresholds
|
42 |
-
self.reid_similarity_threshold = 0.6
|
43 |
-
self.feature_cache = {}
|
44 |
-
|
45 |
-
def _init_clip_model(self):
|
46 |
-
"""Initialize CLIP model for view-invariant features"""
|
47 |
-
try:
|
48 |
-
model_name = "openai/clip-vit-base-patch32" # Smaller model for speed
|
49 |
-
self.clip_model = CLIPModel.from_pretrained(model_name).to(self.device)
|
50 |
-
self.clip_processor = CLIPProcessor.from_pretrained(model_name)
|
51 |
-
self.clip_model.eval()
|
52 |
|
53 |
-
# Freeze parameters
|
54 |
-
for param in self.clip_model.parameters():
|
55 |
-
param.requires_grad = False
|
56 |
|
57 |
-
print(f"✓ CLIP model loaded for ReID: {model_name}")
|
58 |
-
except Exception as e:
|
59 |
-
print(f"⚠ CLIP loading failed: {e}. Using fallback features.")
|
60 |
-
self.clip_model = None
|
61 |
|
62 |
def calculate_iou(self, box1: List[int], box2: List[int]) -> float:
|
63 |
"""Calculate Intersection over Union between two boxes"""
|
|
|
17 |
print(" Install with: pip install transformers")
|
18 |
CLIP_AVAILABLE = False
|
19 |
|
20 |
+
_GLOBAL_CLIP_MODEL = None
|
21 |
+
_GLOBAL_CLIP_PROCESSOR = None
|
22 |
+
|
23 |
+
|
24 |
+
def get_clip_model():
|
25 |
+
"""Get or initialize global CLIP model"""
|
26 |
+
global _GLOBAL_CLIP_MODEL, _GLOBAL_CLIP_PROCESSOR
|
27 |
+
|
28 |
+
if _GLOBAL_CLIP_MODEL is None and CLIP_AVAILABLE:
|
29 |
+
try:
|
30 |
+
model_name = "openai/clip-vit-base-patch32"
|
31 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
+
_GLOBAL_CLIP_MODEL = CLIPModel.from_pretrained(model_name).to(device)
|
33 |
+
_GLOBAL_CLIP_PROCESSOR = CLIPProcessor.from_pretrained(model_name)
|
34 |
+
_GLOBAL_CLIP_MODEL.eval()
|
35 |
+
|
36 |
+
for param in _GLOBAL_CLIP_MODEL.parameters():
|
37 |
+
param.requires_grad = False
|
38 |
+
|
39 |
+
print(f"✓ CLIP model loaded for ReID: {model_name}")
|
40 |
+
except Exception as e:
|
41 |
+
print(f"⚠ CLIP loading failed: {e}. Using fallback features.")
|
42 |
+
_GLOBAL_CLIP_MODEL = None
|
43 |
+
_GLOBAL_CLIP_PROCESSOR = None
|
44 |
+
|
45 |
+
return _GLOBAL_CLIP_MODEL, _GLOBAL_CLIP_PROCESSOR
|
46 |
|
47 |
class DamageComparator:
|
48 |
"""Enhanced damage comparator with view-invariant re-identification"""
|
|
|
58 |
# Device selection
|
59 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
60 |
|
61 |
+
# Get global CLIP model instead of creating new one
|
62 |
+
self.clip_model, self.clip_processor = get_clip_model()
|
|
|
|
|
|
|
63 |
|
64 |
# ReID thresholds
|
65 |
+
self.reid_similarity_threshold = 0.6
|
66 |
+
self.feature_cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
|
|
|
|
|
|
68 |
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def calculate_iou(self, box1: List[int], box2: List[int]) -> float:
|
71 |
"""Calculate Intersection over Union between two boxes"""
|