liuyizhang committed · Commit c419c35 · Parent(s): e5f7fa3

update app.py

app.py CHANGED
@@ -116,18 +116,16 @@ def load_image(image_path):
     image, _ = transform(image_pil, None)  # 3, h, w
     return image_pil, image
 
-
 def load_model(model_config_path, model_checkpoint_path, device):
     args = SLConfig.fromfile(model_config_path)
     args.device = device
     model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+    checkpoint = torch.load(model_checkpoint_path, map_location=device) #"cpu")
     load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
     print(load_res)
     _ = model.eval()
     return model
 
-
 def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
     caption = caption.lower()
     caption = caption.strip()
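The only behavioral change in this hunk is the `map_location` argument: the checkpoint is now deserialized directly onto the target device instead of onto CPU first. A minimal sketch of that loading pattern (the tiny `nn.Linear` model and the checkpoint filename are made up for illustration):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(4, 2).to(device)  # hypothetical stand-in for build_model(args)
torch.save({"model": model.state_dict()}, "ckpt.pth")

# map_location=device remaps every tensor in the file onto `device`
# at load time, skipping the intermediate CPU copy the old code made.
checkpoint = torch.load("ckpt.pth", map_location=device)

# strict=False tolerates missing/unexpected keys and returns a report,
# which the app prints as load_res
load_res = model.load_state_dict(checkpoint["model"], strict=False)
print(load_res)
```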
@@ -172,14 +170,12 @@ def show_mask(mask, ax, random_color=False):
     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
     ax.imshow(mask_image)
 
-
 def show_box(box, ax, label):
     x0, y0 = box[0], box[1]
     w, h = box[2] - box[0], box[3] - box[1]
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
     ax.text(x0, y0, label)
 
-
 config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
 ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
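`show_box` is untouched here apart from the blank-line cleanup; it draws a labeled xyxy box on a Matplotlib axes. A self-contained usage sketch, with a made-up canvas and box:

```python
import matplotlib.pyplot as plt
import numpy as np

def show_box(box, ax, label):
    # box is xyxy; Rectangle wants origin + width/height
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label)

fig, ax = plt.subplots()
ax.imshow(np.zeros((240, 320, 3)))       # hypothetical blank 320x240 image
show_box([40, 30, 200, 160], ax, "dog")  # one detection in xyxy pixel coords
fig.savefig("boxes_preview.jpg")
```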
@@ -189,6 +185,19 @@ device = "cuda"
 
 device = get_device()
 
+# initialize groundingdino model
+groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
+
+# initialize SAM
+sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
+
+# initialize stable-diffusion-inpainting
+sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
+        "runwayml/stable-diffusion-inpainting", 
+        # torch_dtype=torch.float16
+)
+sd_pipe = sd_pipe.to(device)
+
 def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
     assert text_prompt, 'text_prompt is not found!'
 
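This hunk is the core of the commit: the three heavyweight models are hoisted out of `run_grounded_sam` and built once at module scope, so every Gradio request reuses resident weights instead of re-loading them. A minimal sketch of the load-once pattern, with a hypothetical `load_weights` helper standing in for the checkpoint loads:

```python
import time

def load_weights(name):
    time.sleep(1.0)                  # stand-in for an expensive checkpoint load
    return f"<{name} weights>"

# module scope: the cost is paid once, at import/startup
MODEL = load_weights("groundingdino")

def handle_request(prompt):
    # per-request work only; MODEL is already resident
    return f"ran {MODEL} on {prompt!r}"

print(handle_request("a dog"))
print(handle_request("a cat"))       # no second load
```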
@@ -196,24 +205,20 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
     os.makedirs(output_dir, exist_ok=True)
     # load image
     image_pil, image = load_image(image_path.convert("RGB"))
-    # load model
-    model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
 
     # visualize raw image
     image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
 
     # run grounding dino model
     boxes_filt, pred_phrases = get_grounding_output(
-        model, image, text_prompt, box_threshold, text_threshold, device=device
+        groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=device
     )
 
     size = image_pil.size
 
     if task_type == 'segment' or task_type == 'inpainting':
-        # initialize SAM
-        predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
         image = np.array(image_path)
-        predictor.set_image(image)
+        sam_predictor.set_image(image)
 
         H, W = size[1], size[0]
         for i in range(boxes_filt.size(0)):
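The `for` loop this hunk ends on (its body is context in the next hunk) rescales GroundingDINO's normalized (cx, cy, w, h) boxes by the image size and converts them to xyxy corners. The same arithmetic, vectorized, on a made-up box:

```python
import torch

H, W = 480, 640                                # image height/width
boxes = torch.tensor([[0.5, 0.5, 0.25, 0.5]])  # hypothetical (cx, cy, w, h), normalized

boxes = boxes * torch.tensor([W, H, W, H])     # scale to absolute pixels
boxes[:, :2] -= boxes[:, 2:] / 2               # center -> top-left corner
boxes[:, 2:] += boxes[:, :2]                   # (x0, y0, w, h) -> (x0, y0, x1, y1)
print(boxes)                                   # tensor([[240., 120., 400., 360.]])
```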
@@ -222,9 +227,9 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
             boxes_filt[i][2:] += boxes_filt[i][:2]
 
         boxes_filt = boxes_filt.cpu()
-        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
 
-        masks, _, _ = predictor.predict_torch(
+        masks, _, _ = sam_predictor.predict_torch(
             point_coords = None,
             point_labels = None,
             boxes = transformed_boxes,
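`predict_torch` is the batched-box entry point of `SamPredictor`: boxes must first be mapped into SAM's resized input frame via `transform.apply_boxes_torch`, and it returns one mask per box. A hedged end-to-end sketch, assuming `segment_anything` is installed and the checkpoint path points at a downloaded ViT-H weight file (the blank image and box are made up):

```python
import numpy as np
import torch
from segment_anything import SamPredictor, build_sam

sam_checkpoint = "sam_vit_h_4b8939.pth"           # assumed local checkpoint path
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))

image = np.zeros((480, 640, 3), dtype=np.uint8)   # hypothetical RGB image
predictor.set_image(image)                        # embeds the image once

boxes = torch.tensor([[240., 120., 400., 360.]])  # xyxy, in original pixels
# map boxes into the resized frame SAM was fed internally
transformed = predictor.transform.apply_boxes_torch(boxes, image.shape[:2])

masks, scores, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed.to(predictor.device),
    multimask_output=False,
)
print(masks.shape)  # (num_boxes, 1, 480, 640) boolean masks
```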
@@ -266,14 +271,8 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
         mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
         mask_pil = Image.fromarray(mask)
         image_pil = Image.fromarray(image)
-
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(
-                "runwayml/stable-diffusion-inpainting", 
-                # torch_dtype=torch.float16
-        )
-        pipe = pipe.to(device)
 
-        image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
+        image = sd_pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
         image_path = os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")
         image.save(image_path)
         image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
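The inpainting call itself is unchanged apart from using the module-level `sd_pipe`. For reference, a minimal standalone sketch of the diffusers inpainting API; the prompt and the blank image/mask are placeholders where the app passes the photo and the SAM mask:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    # torch_dtype=torch.float16,  # halves memory on GPU
)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

init = Image.new("RGB", (512, 512))      # hypothetical input photo
mask = Image.new("L", (512, 512), 255)   # white = region to repaint
out = pipe(prompt="a red sofa", image=init, mask_image=mask).images[0]
out.save("inpainted.jpg")
```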