TDN-M commited on
Commit
3c568cd
·
verified ·
1 Parent(s): b3b1f2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -1,17 +1,22 @@
1
  import cv2
2
  import numpy as np
3
  import gradio as gr
4
- import torch
5
  from segment_anything import sam_model_registry, SamPredictor
 
 
 
 
 
 
 
6
 
7
- # Load the Segment Anything Model (SAM)
8
- sam_checkpoint = "sam_vit_h_4b8939.pth" # Path to the SAM checkpoint
9
- model_type = "vit_h"
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
13
- sam.to(device=device)
14
- predictor = SamPredictor(sam)
15
 
16
  def generate_mask(image, event: gr.SelectData):
17
  """
 
1
  import cv2
2
  import numpy as np
3
  import gradio as gr
4
+ from huggingface_hub import hf_hub_download
5
  from segment_anything import sam_model_registry, SamPredictor
6
+ import torch
7
+
8
+ # Load the Segment Anything Model (SAM) from Hugging Face
9
+ def load_sam_model():
10
+ checkpoint_path = hf_hub_download(repo_id="facebook/sam-vit-huge", filename="sam_vit_h_4b8939.pth")
11
+ model_type = "vit_h"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
+ sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
15
+ sam.to(device=device)
16
+ predictor = SamPredictor(sam)
17
+ return predictor
18
 
19
+ predictor = load_sam_model()
 
 
20
 
21
  def generate_mask(image, event: gr.SelectData):
22
  """