minh9972t12 committed
Commit 4a98225 · 1 Parent(s): a3450c9

Update src/comparison.py

Files changed (1):
  1. src/comparison.py +30 -22
src/comparison.py CHANGED

@@ -17,6 +17,32 @@ except ImportError:
     print(" Install with: pip install transformers")
     CLIP_AVAILABLE = False
 
+_GLOBAL_CLIP_MODEL = None
+_GLOBAL_CLIP_PROCESSOR = None
+
+
+def get_clip_model():
+    """Get or initialize global CLIP model"""
+    global _GLOBAL_CLIP_MODEL, _GLOBAL_CLIP_PROCESSOR
+
+    if _GLOBAL_CLIP_MODEL is None and CLIP_AVAILABLE:
+        try:
+            model_name = "openai/clip-vit-base-patch32"
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            _GLOBAL_CLIP_MODEL = CLIPModel.from_pretrained(model_name).to(device)
+            _GLOBAL_CLIP_PROCESSOR = CLIPProcessor.from_pretrained(model_name)
+            _GLOBAL_CLIP_MODEL.eval()
+
+            for param in _GLOBAL_CLIP_MODEL.parameters():
+                param.requires_grad = False
+
+            print(f"✓ CLIP model loaded for ReID: {model_name}")
+        except Exception as e:
+            print(f"⚠ CLIP loading failed: {e}. Using fallback features.")
+            _GLOBAL_CLIP_MODEL = None
+            _GLOBAL_CLIP_PROCESSOR = None
+
+    return _GLOBAL_CLIP_MODEL, _GLOBAL_CLIP_PROCESSOR
 
 class DamageComparator:
     """Enhanced damage comparator with view-invariant re-identification"""
@@ -32,32 +58,14 @@
         # Device selection
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        # Initialize CLIP if available
-        self.clip_model = None
-        self.clip_processor = None
-        if CLIP_AVAILABLE:
-            self._init_clip_model()
+        # Get global CLIP model instead of creating new one
+        self.clip_model, self.clip_processor = get_clip_model()
 
         # ReID thresholds
-        self.reid_similarity_threshold = 0.6  # For matching same damage across views
-        self.feature_cache = {}  # Cache extracted features
-
-    def _init_clip_model(self):
-        """Initialize CLIP model for view-invariant features"""
-        try:
-            model_name = "openai/clip-vit-base-patch32"  # Smaller model for speed
-            self.clip_model = CLIPModel.from_pretrained(model_name).to(self.device)
-            self.clip_processor = CLIPProcessor.from_pretrained(model_name)
-            self.clip_model.eval()
+        self.reid_similarity_threshold = 0.6
+        self.feature_cache = {}
 
-            # Freeze parameters
-            for param in self.clip_model.parameters():
-                param.requires_grad = False
 
-            print(f"✓ CLIP model loaded for ReID: {model_name}")
-        except Exception as e:
-            print(f"⚠ CLIP loading failed: {e}. Using fallback features.")
-            self.clip_model = None
 
     def calculate_iou(self, box1: List[int], box2: List[int]) -> float:
         """Calculate Intersection over Union between two boxes"""