vk commited on
Commit
b596a1b
·
1 Parent(s): bf7cfc3

hydra dependency removed

Browse files
Files changed (2) hide show
  1. app.py +7 -13
  2. requirements.txt +4 -2
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import torch
2
- from sam2.build_sam import build_sam2
3
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
4
  from PIL import Image
5
  from matplotlib import pyplot as plt
@@ -66,13 +65,10 @@ def get_response(image):
66
  return overlay_masks_on_image(image,masks)
67
 
68
  def download_checkpoint():
69
- os.system('wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt')
70
 
71
 
72
 
73
- def load_model_cfg():
74
- return SAM2AutomaticMaskGenerator(build_sam2(model_cfg, checkpoint, device='cpu'))
75
-
76
 
77
  if __name__ == "__main__":
78
 
@@ -85,15 +81,13 @@ if __name__ == "__main__":
85
  title="Segmenting Microscopic images with Segment Anything",
86
  description="Segmenting Microscopic images with Meta Segment Anything")
87
 
88
-
89
- checkpoint = "sam2.1_hiera_large.pt"
90
- model_cfg = "sam2.1_hiera_l.yaml"
91
-
92
- if not os.path.exists(checkpoint):
93
- print('Downloading checkpoint')
94
  download_checkpoint()
95
- print('Checkpoint Downloaded')
96
 
97
- mask_generator = load_model_cfg()
 
98
 
99
  iface.launch()
 
1
  import torch
 
2
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
3
  from PIL import Image
4
  from matplotlib import pyplot as plt
 
65
  return overlay_masks_on_image(image,masks)
66
 
67
  def download_checkpoint():
68
+ os.system('gdown 1RHSO8lHko3IK3dmABOzFDJuq7wmKVcun')
69
 
70
 
71
 
 
 
 
72
 
73
  if __name__ == "__main__":
74
 
 
81
  title="Segmenting Microscopic images with Segment Anything",
82
  description="Segmenting Microscopic images with Meta Segment Anything")
83
 
84
+ model_path='model.pth'
85
+ if not os.path.exists(model_path):
86
+ print('Downloading model with weights')
 
 
 
87
  download_checkpoint()
88
+ print('Model with weights Downloaded')
89
 
90
+ model = torch.load(model_path, map_location="cpu", weights_only=False)
91
+ mask_generator = SAM2AutomaticMaskGenerator(model)
92
 
93
  iface.launch()
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
- #hydra-core==1.3.2
2
- antlr4-python3-runtime==4.9.3
 
 
 
1
+ torch
2
+ torchvision
3
+ opencv_python
4
+ gdown