Spaces:
Runtime error
Runtime error
liuyizhang
committed on
Commit
·
dfba81f
1
Parent(s):
c4d99b7
update app.py
Browse files
app.py
CHANGED
|
@@ -38,6 +38,7 @@ import cv2
|
|
| 38 |
import numpy as np
|
| 39 |
import matplotlib.pyplot as plt
|
| 40 |
|
|
|
|
| 41 |
sam_enable = True
|
| 42 |
inpainting_enable = True
|
| 43 |
ram_enable = True
|
|
@@ -103,16 +104,10 @@ sam_predictor = None
|
|
| 103 |
sam_mask_generator = None
|
| 104 |
sd_model = None
|
| 105 |
lama_cleaner_model= None
|
| 106 |
-
lama_cleaner_model_device = device
|
| 107 |
ram_model = None
|
| 108 |
kosmos_model = None
|
| 109 |
kosmos_processor = None
|
| 110 |
|
| 111 |
-
def get_sam_vit_h_4b8939():
|
| 112 |
-
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
| 113 |
-
logger.info(f"get sam_vit_h_4b8939.pth...")
|
| 114 |
-
result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
|
| 115 |
-
print(f'wget sam_vit_h_4b8939.pth result = {result}')
|
| 116 |
|
| 117 |
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
| 118 |
args = SLConfig.fromfile(model_config_path)
|
|
@@ -282,24 +277,31 @@ def set_device():
|
|
| 282 |
device = 'cpu'
|
| 283 |
print(f'device={device}')
|
| 284 |
|
| 285 |
-
def load_groundingdino_model():
|
| 286 |
# initialize groundingdino model
|
| 287 |
-
global groundingdino_model
|
| 288 |
logger.info(f"initialize groundingdino model...")
|
| 289 |
-
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device='cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
-
def load_sam_model():
|
| 292 |
# initialize SAM
|
| 293 |
-
global sam_model, sam_predictor, sam_mask_generator, sam_device
|
|
|
|
| 294 |
logger.info(f"initialize SAM model...")
|
| 295 |
sam_device = device
|
| 296 |
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
| 297 |
sam_predictor = SamPredictor(sam_model)
|
| 298 |
sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
|
| 299 |
|
| 300 |
-
def load_sd_model():
|
| 301 |
# initialize stable-diffusion-inpainting
|
| 302 |
-
global sd_model
|
| 303 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
| 304 |
sd_model = None
|
| 305 |
if os.environ.get('IS_MY_DEBUG') is None:
|
|
@@ -311,14 +313,14 @@ def load_sd_model():
|
|
| 311 |
)
|
| 312 |
sd_model = sd_model.to(device)
|
| 313 |
|
| 314 |
-
def load_lama_cleaner_model():
|
| 315 |
# initialize lama_cleaner
|
| 316 |
-
global lama_cleaner_model
|
| 317 |
logger.info(f"initialize lama_cleaner...")
|
| 318 |
|
| 319 |
lama_cleaner_model = ModelManager(
|
| 320 |
name='lama',
|
| 321 |
-
device=
|
| 322 |
)
|
| 323 |
|
| 324 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
|
@@ -390,7 +392,7 @@ class Ram_Predictor(RamPredictor):
|
|
| 390 |
self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
|
| 391 |
self.model.train()
|
| 392 |
|
| 393 |
-
def load_ram_model():
|
| 394 |
# load ram model
|
| 395 |
global ram_model
|
| 396 |
if os.environ.get('IS_MY_DEBUG') is not None:
|
|
@@ -830,20 +832,20 @@ if __name__ == "__main__":
|
|
| 830 |
if kosmos_enable:
|
| 831 |
kosmos_model, kosmos_processor = load_kosmos_model(device)
|
| 832 |
|
| 833 |
-
|
|
|
|
| 834 |
|
| 835 |
if sam_enable:
|
| 836 |
-
|
| 837 |
-
load_sam_model()
|
| 838 |
|
| 839 |
if inpainting_enable:
|
| 840 |
-
load_sd_model()
|
| 841 |
|
| 842 |
if lama_cleaner_enable:
|
| 843 |
-
load_lama_cleaner_model()
|
| 844 |
|
| 845 |
if ram_enable:
|
| 846 |
-
load_ram_model()
|
| 847 |
|
| 848 |
if os.environ.get('IS_MY_DEBUG') is None:
|
| 849 |
os.system("pip list")
|
|
@@ -865,7 +867,7 @@ if __name__ == "__main__":
|
|
| 865 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
| 866 |
value=mask_source_segment, label="Mask from",
|
| 867 |
visible=False)
|
| 868 |
-
text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each
|
| 869 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
| 870 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
| 871 |
|
|
@@ -946,6 +948,7 @@ if __name__ == "__main__":
|
|
| 946 |
<a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
|
| 947 |
gr.Markdown(DESCRIPTION)
|
| 948 |
|
|
|
|
| 949 |
computer_info()
|
| 950 |
block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
|
| 951 |
|
|
|
|
| 38 |
import numpy as np
|
| 39 |
import matplotlib.pyplot as plt
|
| 40 |
|
| 41 |
+
groundingdino_enable = True
|
| 42 |
sam_enable = True
|
| 43 |
inpainting_enable = True
|
| 44 |
ram_enable = True
|
|
|
|
| 104 |
sam_mask_generator = None
|
| 105 |
sd_model = None
|
| 106 |
lama_cleaner_model= None
|
|
|
|
| 107 |
ram_model = None
|
| 108 |
kosmos_model = None
|
| 109 |
kosmos_processor = None
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
| 113 |
args = SLConfig.fromfile(model_config_path)
|
|
|
|
| 277 |
device = 'cpu'
|
| 278 |
print(f'device={device}')
|
| 279 |
|
| 280 |
+
def load_groundingdino_model(device):
|
| 281 |
# initialize groundingdino model
|
|
|
|
| 282 |
logger.info(f"initialize groundingdino model...")
|
| 283 |
+
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
|
| 284 |
+
return groundingdino_model
|
| 285 |
+
|
| 286 |
+
def get_sam_vit_h_4b8939():
|
| 287 |
+
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
| 288 |
+
logger.info(f"get sam_vit_h_4b8939.pth...")
|
| 289 |
+
result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
|
| 290 |
+
print(f'wget sam_vit_h_4b8939.pth result = {result}')
|
| 291 |
|
| 292 |
+
def load_sam_model(device):
|
| 293 |
# initialize SAM
|
| 294 |
+
global sam_model, sam_predictor, sam_mask_generator, sam_device
|
| 295 |
+
get_sam_vit_h_4b8939()
|
| 296 |
logger.info(f"initialize SAM model...")
|
| 297 |
sam_device = device
|
| 298 |
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
| 299 |
sam_predictor = SamPredictor(sam_model)
|
| 300 |
sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
|
| 301 |
|
| 302 |
+
def load_sd_model(device):
|
| 303 |
# initialize stable-diffusion-inpainting
|
| 304 |
+
global sd_model
|
| 305 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
| 306 |
sd_model = None
|
| 307 |
if os.environ.get('IS_MY_DEBUG') is None:
|
|
|
|
| 313 |
)
|
| 314 |
sd_model = sd_model.to(device)
|
| 315 |
|
| 316 |
+
def load_lama_cleaner_model(device):
|
| 317 |
# initialize lama_cleaner
|
| 318 |
+
global lama_cleaner_model
|
| 319 |
logger.info(f"initialize lama_cleaner...")
|
| 320 |
|
| 321 |
lama_cleaner_model = ModelManager(
|
| 322 |
name='lama',
|
| 323 |
+
device=device,
|
| 324 |
)
|
| 325 |
|
| 326 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
|
|
|
| 392 |
self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
|
| 393 |
self.model.train()
|
| 394 |
|
| 395 |
+
def load_ram_model(device):
|
| 396 |
# load ram model
|
| 397 |
global ram_model
|
| 398 |
if os.environ.get('IS_MY_DEBUG') is not None:
|
|
|
|
| 832 |
if kosmos_enable:
|
| 833 |
kosmos_model, kosmos_processor = load_kosmos_model(device)
|
| 834 |
|
| 835 |
+
if groundingdino_enable:
|
| 836 |
+
groundingdino_model = load_groundingdino_model('cpu')
|
| 837 |
|
| 838 |
if sam_enable:
|
| 839 |
+
load_sam_model(device)
|
|
|
|
| 840 |
|
| 841 |
if inpainting_enable:
|
| 842 |
+
load_sd_model(device)
|
| 843 |
|
| 844 |
if lama_cleaner_enable:
|
| 845 |
+
load_lama_cleaner_model(device)
|
| 846 |
|
| 847 |
if ram_enable:
|
| 848 |
+
load_ram_model(device)
|
| 849 |
|
| 850 |
if os.environ.get('IS_MY_DEBUG') is None:
|
| 851 |
os.system("pip list")
|
|
|
|
| 867 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
| 868 |
value=mask_source_segment, label="Mask from",
|
| 869 |
visible=False)
|
| 870 |
+
text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
|
| 871 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
| 872 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
| 873 |
|
|
|
|
| 948 |
<a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
|
| 949 |
gr.Markdown(DESCRIPTION)
|
| 950 |
|
| 951 |
+
print(f'device={device}')
|
| 952 |
computer_info()
|
| 953 |
block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
|
| 954 |
|