TDN-M commited on
Commit
f357c98
·
verified ·
1 Parent(s): ade94f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -11,11 +11,15 @@ def load_sam_model():
11
  checkpoint_path = hf_hub_download(repo_id="facebook/sam-vit-huge", filename="pytorch_model.bin")
12
 
13
  # Load checkpoint với map_location=torch.device('cpu')
 
 
 
14
  model_type = "vit_h"
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Sử dụng map_location=torch.device('cpu') để tải mô hình trên CPU
18
- sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
 
19
  sam.to(device=device)
20
  predictor = SamPredictor(sam)
21
  return predictor
 
11
  checkpoint_path = hf_hub_download(repo_id="facebook/sam-vit-huge", filename="pytorch_model.bin")
12
 
13
  # Load checkpoint với map_location=torch.device('cpu')
14
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
15
+
16
+ # Khởi tạo mô hình SAM
17
  model_type = "vit_h"
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
+ # Truyền checkpoint vào mô hình
21
+ sam = sam_model_registry[model_type]()
22
+ sam.load_state_dict(checkpoint)
23
  sam.to(device=device)
24
  predictor = SamPredictor(sam)
25
  return predictor