pszemraj committed
Commit 89e989d · verified · 1 Parent(s): 5e283ac

Update app.py

Files changed (1)
  1. app.py (+29 -29)
app.py CHANGED
@@ -42,34 +42,25 @@ def cleanup_memory():
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
 
 
 def load_model(name):
-    """Load model with proper memory management and dtype handling"""
+    """Load model with CORRECT dtype"""
     global processor, model
 
-    # Clean up existing model
     cleanup_memory()
-
     model_id = MODEL_MAP[name]
 
-    # Load processor
     processor = AutoImageProcessor.from_pretrained(model_id)
 
-    model = (
-        AutoModel.from_pretrained(
-            model_id,
-            torch_dtype="auto",
-        )
-        .to(DEVICE)
-        .eval()
-    )
-
-    # Get model info
+    model = AutoModel.from_pretrained(
+        model_id,
+        torch_dtype="auto",
+    ).eval()
+
     param_count = sum(p.numel() for p in model.parameters()) / 1e9
-
-    return f"Loaded: {name} | {param_count:.1f}B params | {DEVICE.upper()}"
+    return f"Loaded: {name} | {param_count:.1f}B params | Ready"
 
 
 # Initialize default model
@@ -79,22 +70,31 @@ load_model(DEFAULT_NAME)
 @spaces.GPU(duration=60)
 def _extract_grid(img):
     """Extract feature grid from image"""
+    global model
+
     with torch.inference_mode():
-        pv = processor(images=img, return_tensors="pt").pixel_values
-
-        if DEVICE == "cuda":
-            pv = pv.to(DEVICE)
-
+        # Move model to GPU for this call
+        model = model.to('cuda')
+
+        # Process image and move to GPU
+        pv = processor(images=img, return_tensors="pt").pixel_values.to(model.device)
+
+        # Run inference
         out = model(pixel_values=pv)
         last = out.last_hidden_state[0].to(torch.float32)
-
+
+        # Extract features
         num_reg = getattr(model.config, "num_register_tokens", 0)
         p = model.config.patch_size
         _, _, Ht, Wt = pv.shape
         gh, gw = Ht // p, Wt // p
-
-        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
-
+
+        feats = last[1 + num_reg:, :].reshape(gh, gw, -1).cpu()
+
+        # Move model back to CPU before function exits
+        model = model.cpu()
+        torch.cuda.empty_cache()
+
         return feats, gh, gw
 
 
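
For context, a minimal self-contained sketch of the ZeroGPU pattern this change adopts: load the model on CPU at startup, move it onto the GPU only inside the @spaces.GPU-decorated call, and move it back before returning. The checkpoint id and function name below are placeholders for illustration, not values taken from app.py.

import spaces
import torch
from transformers import AutoImageProcessor, AutoModel

MODEL_ID = "facebook/dinov2-base"  # placeholder checkpoint, not necessarily one of MODEL_MAP's entries

# Load on CPU at startup; on ZeroGPU, CUDA is only available inside @spaces.GPU calls.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID, torch_dtype="auto").eval()


@spaces.GPU(duration=60)
def embed(img):
    """Illustrative helper: run one forward pass on the borrowed GPU."""
    global model
    with torch.inference_mode():
        model = model.to("cuda")  # GPU exists only for the duration of this call
        pv = processor(images=img, return_tensors="pt").pixel_values.to(model.device)
        feats = model(pixel_values=pv).last_hidden_state.float().cpu()
        model = model.cpu()  # release the GPU copy before the allocation ends
        torch.cuda.empty_cache()
    return feats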
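The slice in _extract_grid assumes the token layout [CLS, register tokens, patch tokens]. A small worked example of the grid arithmetic, using an assumed 224x224 processed input, patch size 14, and 4 register tokens (illustrative numbers, not read from the model config):

# Assumed numbers for illustration only.
Ht, Wt = 224, 224          # processed image size
p = 14                     # patch size
num_reg = 4                # register tokens (0 for models without registers)

gh, gw = Ht // p, Wt // p          # 16 x 16 patch grid
seq_len = 1 + num_reg + gh * gw    # 1 CLS + 4 registers + 256 patches = 261 tokens
# last has shape (seq_len, hidden_dim); dropping the first 1 + num_reg tokens
# leaves exactly gh * gw rows, which reshape(gh, gw, -1) turns into the feature grid.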