Spaces:
Runtime error
Runtime error
Update mask_adapter/sam_maskadapter.py
Browse files
mask_adapter/sam_maskadapter.py
CHANGED
|
@@ -113,7 +113,7 @@ class SAMVisualizationDemo(object):
|
|
| 113 |
|
| 114 |
return clip_vis_dense
|
| 115 |
|
| 116 |
-
def run_on_image(self, ori_image, class_names):
|
| 117 |
height, width, _ = ori_image.shape
|
| 118 |
if width > height:
|
| 119 |
new_width = 896
|
|
@@ -140,15 +140,15 @@ class SAMVisualizationDemo(object):
|
|
| 140 |
|
| 141 |
image = image.unsqueeze(0)
|
| 142 |
|
| 143 |
-
if len(class_names) == 1:
|
| 144 |
-
|
| 145 |
-
txts = [f'a photo of {cls_name}' for cls_name in class_names]
|
| 146 |
-
text = open_clip.tokenize(txts)
|
| 147 |
|
| 148 |
|
| 149 |
with torch.no_grad():
|
| 150 |
-
text_features = self.clip_model.encode_text(text)
|
| 151 |
-
text_features /= text_features.norm(dim=-1, keepdim=True)
|
| 152 |
|
| 153 |
features = self.extract_features_convnext(image.float())
|
| 154 |
|
|
|
|
| 113 |
|
| 114 |
return clip_vis_dense
|
| 115 |
|
| 116 |
+
def run_on_image(self, ori_image, class_names, text_features):
|
| 117 |
height, width, _ = ori_image.shape
|
| 118 |
if width > height:
|
| 119 |
new_width = 896
|
|
|
|
| 140 |
|
| 141 |
image = image.unsqueeze(0)
|
| 142 |
|
| 143 |
+
# if len(class_names) == 1:
|
| 144 |
+
# class_names.append('others')
|
| 145 |
+
# txts = [f'a photo of {cls_name}' for cls_name in class_names]
|
| 146 |
+
# text = open_clip.tokenize(txts)
|
| 147 |
|
| 148 |
|
| 149 |
with torch.no_grad():
|
| 150 |
+
# text_features = self.clip_model.encode_text(text)
|
| 151 |
+
# text_features /= text_features.norm(dim=-1, keepdim=True)
|
| 152 |
|
| 153 |
features = self.extract_features_convnext(image.float())
|
| 154 |
|