Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	fix some bugs
Browse files- app/run_app.sh +5 -0
- app/src/brushedit_app.py +53 -42
- app/src/vlm_pipeline.py +24 -18
    	
        app/run_app.sh
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            export PYTHONPATH=.:$PYTHONPATH
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            export CUDA_VISIBLE_DEVICES=0
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            python app/src/brushedit_app.py
         | 
    	
        app/src/brushedit_app.py
    CHANGED
    
    | @@ -337,7 +337,7 @@ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_M | |
| 337 | 
             
            if vlm_processor != "" and vlm_model != "":
         | 
| 338 | 
             
                vlm_model.to(device)
         | 
| 339 | 
             
            else:
         | 
| 340 | 
            -
                gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
         | 
| 341 |  | 
| 342 |  | 
| 343 | 
             
            ## init base model
         | 
| @@ -504,7 +504,7 @@ def random_mask_func(mask, dilation_type='square', dilation_size=20): | |
| 504 | 
             
                    dilated_mask = np.zeros_like(binary_mask, dtype=bool)
         | 
| 505 | 
             
                    dilated_mask[ellipse_mask] = True
         | 
| 506 | 
             
                else:
         | 
| 507 | 
            -
                     | 
| 508 |  | 
| 509 | 
             
                # use binary dilation
         | 
| 510 | 
             
                dilated_mask =  np.uint8(dilated_mask[:,:,np.newaxis]) * 255
         | 
| @@ -637,7 +637,8 @@ def process(input_image, | |
| 637 | 
             
                        image_pil = input_image["background"].convert("RGB")
         | 
| 638 | 
             
                        original_image = np.array(image_pil)
         | 
| 639 | 
             
                if prompt is None or prompt == "":
         | 
| 640 | 
            -
                     | 
|  | |
| 641 |  | 
| 642 | 
             
                alpha_mask = input_image["layers"][0].split()[3]
         | 
| 643 | 
             
                input_mask = np.asarray(alpha_mask)
         | 
| @@ -687,17 +688,23 @@ def process(input_image, | |
| 687 | 
             
                        original_mask = input_mask
         | 
| 688 |  | 
| 689 |  | 
| 690 | 
            -
             | 
| 691 | 
             
                if category is not None:
         | 
| 692 | 
            -
                    pass | 
|  | |
|  | |
| 693 | 
             
                else:
         | 
| 694 | 
            -
                     | 
| 695 | 
            -
             | 
|  | |
|  | |
| 696 |  | 
|  | |
| 697 | 
             
                if original_mask is not None:
         | 
| 698 | 
             
                    original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
         | 
| 699 | 
             
                else:
         | 
| 700 | 
            -
                     | 
|  | |
| 701 | 
             
                                                            vlm_processor, 
         | 
| 702 | 
             
                                                            vlm_model, 
         | 
| 703 | 
             
                                                            original_image,
         | 
| @@ -705,30 +712,37 @@ def process(input_image, | |
| 705 | 
             
                                                            prompt,
         | 
| 706 | 
             
                                                            device)
         | 
| 707 |  | 
| 708 | 
            -
             | 
| 709 | 
            -
             | 
| 710 | 
            -
             | 
| 711 | 
            -
             | 
| 712 | 
            -
             | 
| 713 | 
            -
             | 
| 714 | 
            -
             | 
| 715 | 
            -
             | 
| 716 | 
            -
             | 
| 717 | 
            -
             | 
| 718 | 
            -
             | 
|  | |
|  | |
|  | |
| 719 | 
             
                if original_mask.ndim == 2:
         | 
| 720 | 
             
                    original_mask = original_mask[:,:,None]
         | 
| 721 |  | 
| 722 |  | 
| 723 | 
            -
                if len(target_prompt)  | 
| 724 | 
            -
                    prompt_after_apply_instruction =  | 
|  | |
|  | |
|  | |
|  | |
| 725 | 
             
                                                                                vlm_processor, 
         | 
| 726 | 
             
                                                                                vlm_model, 
         | 
| 727 | 
             
                                                                                original_image,
         | 
| 728 | 
             
                                                                                prompt,
         | 
| 729 | 
             
                                                                                device)
         | 
| 730 | 
            -
             | 
| 731 | 
            -
             | 
| 732 |  | 
| 733 | 
             
                generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
         | 
| 734 |  | 
| @@ -758,7 +772,8 @@ def process(input_image, | |
| 758 | 
             
                # image[3].save(f"outputs/image_edit_{uuid}_3.png")
         | 
| 759 | 
             
                # mask_image.save(f"outputs/mask_{uuid}.png")
         | 
| 760 | 
             
                # masked_image.save(f"outputs/masked_image_{uuid}.png")
         | 
| 761 | 
            -
                 | 
|  | |
| 762 |  | 
| 763 |  | 
| 764 | 
             
            def generate_target_prompt(input_image, 
         | 
| @@ -774,7 +789,7 @@ def generate_target_prompt(input_image, | |
| 774 | 
             
                                                                        original_image,
         | 
| 775 | 
             
                                                                        prompt,
         | 
| 776 | 
             
                                                                        device)
         | 
| 777 | 
            -
                return prompt_after_apply_instruction | 
| 778 |  | 
| 779 |  | 
| 780 | 
             
            def process_mask(input_image, 
         | 
| @@ -1415,7 +1430,7 @@ def init_img(base, | |
| 1415 | 
             
                    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
         | 
| 1416 | 
             
                    return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "", "Custom resolution", False, False, example_change_times
         | 
| 1417 | 
             
                else:
         | 
| 1418 | 
            -
                    return base, original_image, None, "", None, None, None, "", "",  | 
| 1419 |  | 
| 1420 |  | 
| 1421 | 
             
            def reset_func(input_image, 
         | 
| @@ -1423,7 +1438,7 @@ def reset_func(input_image, | |
| 1423 | 
             
                           original_mask, 
         | 
| 1424 | 
             
                           prompt, 
         | 
| 1425 | 
             
                           target_prompt, 
         | 
| 1426 | 
            -
                            | 
| 1427 | 
             
                input_image = None
         | 
| 1428 | 
             
                original_image = None
         | 
| 1429 | 
             
                original_mask = None
         | 
| @@ -1432,10 +1447,9 @@ def reset_func(input_image, | |
| 1432 | 
             
                masked_gallery = []
         | 
| 1433 | 
             
                result_gallery = []
         | 
| 1434 | 
             
                target_prompt = ''
         | 
| 1435 | 
            -
                target_prompt_output = ''
         | 
| 1436 | 
             
                if torch.cuda.is_available():
         | 
| 1437 | 
             
                    torch.cuda.empty_cache()
         | 
| 1438 | 
            -
                return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt,  | 
| 1439 |  | 
| 1440 |  | 
| 1441 | 
             
            def update_example(example_type, 
         | 
| @@ -1458,7 +1472,8 @@ def update_example(example_type, | |
| 1458 | 
             
                original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
         | 
| 1459 | 
             
                aspect_ratio = "Custom resolution"
         | 
| 1460 | 
             
                example_change_times += 1
         | 
| 1461 | 
            -
                return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "",  | 
|  | |
| 1462 |  | 
| 1463 | 
             
            block = gr.Blocks(
         | 
| 1464 | 
             
                    theme=gr.themes.Soft(
         | 
| @@ -1498,6 +1513,8 @@ with block as demo: | |
| 1498 | 
             
                                sources=["upload"],
         | 
| 1499 | 
             
                                )
         | 
| 1500 |  | 
|  | |
|  | |
| 1501 |  | 
| 1502 | 
             
                        vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
         | 
| 1503 | 
             
                        with gr.Group():    
         | 
| @@ -1510,12 +1527,6 @@ with block as demo: | |
| 1510 | 
             
                        aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
         | 
| 1511 | 
             
                        resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
         | 
| 1512 |  | 
| 1513 | 
            -
             | 
| 1514 | 
            -
                        prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
         | 
| 1515 | 
            -
             | 
| 1516 | 
            -
                        run_button = gr.Button("💫 Run")
         | 
| 1517 | 
            -
             | 
| 1518 | 
            -
             | 
| 1519 | 
             
                        with gr.Row():
         | 
| 1520 | 
             
                            mask_button = gr.Button("Generate Mask")
         | 
| 1521 | 
             
                            random_mask_button = gr.Button("Square/Circle Mask ")
         | 
| @@ -1603,7 +1614,7 @@ with block as demo: | |
| 1603 | 
             
                        with gr.Tab(elem_classes="feedback", label="Output"):
         | 
| 1604 | 
             
                            result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
         | 
| 1605 |  | 
| 1606 | 
            -
                        target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
         | 
| 1607 |  | 
| 1608 | 
             
                        reset_button = gr.Button("Reset")
         | 
| 1609 |  | 
| @@ -1634,9 +1645,9 @@ with block as demo: | |
| 1634 | 
             
                input_image.upload(
         | 
| 1635 | 
             
                    init_img,
         | 
| 1636 | 
             
                    [input_image, init_type, prompt, aspect_ratio, example_change_times],
         | 
| 1637 | 
            -
                    [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt,  | 
| 1638 | 
             
                ) 
         | 
| 1639 | 
            -
                example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt,  | 
| 1640 |  | 
| 1641 | 
             
                ## vlm and base model dropdown
         | 
| 1642 | 
             
                vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
         | 
| @@ -1666,7 +1677,7 @@ with block as demo: | |
| 1666 | 
             
                     invert_mask_state]
         | 
| 1667 |  | 
| 1668 | 
             
                ## run brushedit
         | 
| 1669 | 
            -
                run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt,  | 
| 1670 |  | 
| 1671 | 
             
                ## mask func
         | 
| 1672 | 
             
                mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
         | 
| @@ -1681,10 +1692,10 @@ with block as demo: | |
| 1681 | 
             
                move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])    
         | 
| 1682 |  | 
| 1683 | 
             
                ## prompt func
         | 
| 1684 | 
            -
                generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt | 
| 1685 |  | 
| 1686 | 
             
                ## reset func
         | 
| 1687 | 
            -
                reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt | 
| 1688 |  | 
| 1689 |  | 
| 1690 | 
             
            demo.launch()
         | 
|  | |
| 337 | 
             
            if vlm_processor != "" and vlm_model != "":
         | 
| 338 | 
             
                vlm_model.to(device)
         | 
| 339 | 
             
            else:
         | 
| 340 | 
            +
                raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
         | 
| 341 |  | 
| 342 |  | 
| 343 | 
             
            ## init base model
         | 
|  | |
| 504 | 
             
                    dilated_mask = np.zeros_like(binary_mask, dtype=bool)
         | 
| 505 | 
             
                    dilated_mask[ellipse_mask] = True
         | 
| 506 | 
             
                else:
         | 
| 507 | 
            +
                    ValueError("dilation_type must be 'square' or 'ellipse'")
         | 
| 508 |  | 
| 509 | 
             
                # use binary dilation
         | 
| 510 | 
             
                dilated_mask =  np.uint8(dilated_mask[:,:,np.newaxis]) * 255
         | 
|  | |
| 637 | 
             
                        image_pil = input_image["background"].convert("RGB")
         | 
| 638 | 
             
                        original_image = np.array(image_pil)
         | 
| 639 | 
             
                if prompt is None or prompt == "":
         | 
| 640 | 
            +
                    if target_prompt is None or target_prompt == "":
         | 
| 641 | 
            +
                        raise gr.Error("Please input your instructions, e.g., remove the xxx")
         | 
| 642 |  | 
| 643 | 
             
                alpha_mask = input_image["layers"][0].split()[3]
         | 
| 644 | 
             
                input_mask = np.asarray(alpha_mask)
         | 
|  | |
| 688 | 
             
                        original_mask = input_mask
         | 
| 689 |  | 
| 690 |  | 
| 691 | 
            +
                ## inpainting directly if target_prompt is not None
         | 
| 692 | 
             
                if category is not None:
         | 
| 693 | 
            +
                    pass
         | 
| 694 | 
            +
                elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
         | 
| 695 | 
            +
                    pass
         | 
| 696 | 
             
                else:
         | 
| 697 | 
            +
                    try:
         | 
| 698 | 
            +
                        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
         | 
| 699 | 
            +
                    except Exception as e:
         | 
| 700 | 
            +
                        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
         | 
| 701 |  | 
| 702 | 
            +
             | 
| 703 | 
             
                if original_mask is not None:
         | 
| 704 | 
             
                    original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
         | 
| 705 | 
             
                else:
         | 
| 706 | 
            +
                    try:
         | 
| 707 | 
            +
                        object_wait_for_edit = vlm_response_object_wait_for_edit(
         | 
| 708 | 
             
                                                            vlm_processor, 
         | 
| 709 | 
             
                                                            vlm_model, 
         | 
| 710 | 
             
                                                            original_image,
         | 
|  | |
| 712 | 
             
                                                            prompt,
         | 
| 713 | 
             
                                                            device)
         | 
| 714 |  | 
| 715 | 
            +
                        original_mask = vlm_response_mask(vlm_processor,
         | 
| 716 | 
            +
                                                        vlm_model,
         | 
| 717 | 
            +
                                                        category, 
         | 
| 718 | 
            +
                                                        original_image, 
         | 
| 719 | 
            +
                                                        prompt, 
         | 
| 720 | 
            +
                                                        object_wait_for_edit, 
         | 
| 721 | 
            +
                                                        sam,
         | 
| 722 | 
            +
                                                        sam_predictor,
         | 
| 723 | 
            +
                                                        sam_automask_generator,
         | 
| 724 | 
            +
                                                        groundingdino_model,
         | 
| 725 | 
            +
                                                        device)
         | 
| 726 | 
            +
                    except Exception as e:
         | 
| 727 | 
            +
                        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
         | 
| 728 | 
            +
             | 
| 729 | 
             
                if original_mask.ndim == 2:
         | 
| 730 | 
             
                    original_mask = original_mask[:,:,None]
         | 
| 731 |  | 
| 732 |  | 
| 733 | 
            +
                if target_prompt is not None and len(target_prompt) >= 1:
         | 
| 734 | 
            +
                    prompt_after_apply_instruction = target_prompt
         | 
| 735 | 
            +
                    
         | 
| 736 | 
            +
                else:
         | 
| 737 | 
            +
                    try:
         | 
| 738 | 
            +
                        prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
         | 
| 739 | 
             
                                                                                vlm_processor, 
         | 
| 740 | 
             
                                                                                vlm_model, 
         | 
| 741 | 
             
                                                                                original_image,
         | 
| 742 | 
             
                                                                                prompt,
         | 
| 743 | 
             
                                                                                device)
         | 
| 744 | 
            +
                    except Exception as e:
         | 
| 745 | 
            +
                        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
         | 
| 746 |  | 
| 747 | 
             
                generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
         | 
| 748 |  | 
|  | |
| 772 | 
             
                # image[3].save(f"outputs/image_edit_{uuid}_3.png")
         | 
| 773 | 
             
                # mask_image.save(f"outputs/mask_{uuid}.png")
         | 
| 774 | 
             
                # masked_image.save(f"outputs/masked_image_{uuid}.png")
         | 
| 775 | 
            +
                # gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=16)
         | 
| 776 | 
            +
                return image, [mask_image], [masked_image], prompt, '', False
         | 
| 777 |  | 
| 778 |  | 
| 779 | 
             
            def generate_target_prompt(input_image, 
         | 
|  | |
| 789 | 
             
                                                                        original_image,
         | 
| 790 | 
             
                                                                        prompt,
         | 
| 791 | 
             
                                                                        device)
         | 
| 792 | 
            +
                return prompt_after_apply_instruction
         | 
| 793 |  | 
| 794 |  | 
| 795 | 
             
            def process_mask(input_image, 
         | 
|  | |
| 1430 | 
             
                    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
         | 
| 1431 | 
             
                    return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "", "Custom resolution", False, False, example_change_times
         | 
| 1432 | 
             
                else:
         | 
| 1433 | 
            +
                    return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
         | 
| 1434 |  | 
| 1435 |  | 
| 1436 | 
             
            def reset_func(input_image, 
         | 
|  | |
| 1438 | 
             
                           original_mask, 
         | 
| 1439 | 
             
                           prompt, 
         | 
| 1440 | 
             
                           target_prompt, 
         | 
| 1441 | 
            +
                           ):
         | 
| 1442 | 
             
                input_image = None
         | 
| 1443 | 
             
                original_image = None
         | 
| 1444 | 
             
                original_mask = None
         | 
|  | |
| 1447 | 
             
                masked_gallery = []
         | 
| 1448 | 
             
                result_gallery = []
         | 
| 1449 | 
             
                target_prompt = ''
         | 
|  | |
| 1450 | 
             
                if torch.cuda.is_available():
         | 
| 1451 | 
             
                    torch.cuda.empty_cache()
         | 
| 1452 | 
            +
                return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
         | 
| 1453 |  | 
| 1454 |  | 
| 1455 | 
             
            def update_example(example_type, 
         | 
|  | |
| 1472 | 
             
                original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
         | 
| 1473 | 
             
                aspect_ratio = "Custom resolution"
         | 
| 1474 | 
             
                example_change_times += 1
         | 
| 1475 | 
            +
                return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
         | 
| 1476 | 
            +
             | 
| 1477 |  | 
| 1478 | 
             
            block = gr.Blocks(
         | 
| 1479 | 
             
                    theme=gr.themes.Soft(
         | 
|  | |
| 1513 | 
             
                                sources=["upload"],
         | 
| 1514 | 
             
                                )
         | 
| 1515 |  | 
| 1516 | 
            +
                        prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
         | 
| 1517 | 
            +
                        run_button = gr.Button("💫 Run")
         | 
| 1518 |  | 
| 1519 | 
             
                        vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
         | 
| 1520 | 
             
                        with gr.Group():    
         | 
|  | |
| 1527 | 
             
                        aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
         | 
| 1528 | 
             
                        resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
         | 
| 1529 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1530 | 
             
                        with gr.Row():
         | 
| 1531 | 
             
                            mask_button = gr.Button("Generate Mask")
         | 
| 1532 | 
             
                            random_mask_button = gr.Button("Square/Circle Mask ")
         | 
|  | |
| 1614 | 
             
                        with gr.Tab(elem_classes="feedback", label="Output"):
         | 
| 1615 | 
             
                            result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
         | 
| 1616 |  | 
| 1617 | 
            +
                        # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
         | 
| 1618 |  | 
| 1619 | 
             
                        reset_button = gr.Button("Reset")
         | 
| 1620 |  | 
|  | |
| 1645 | 
             
                input_image.upload(
         | 
| 1646 | 
             
                    init_img,
         | 
| 1647 | 
             
                    [input_image, init_type, prompt, aspect_ratio, example_change_times],
         | 
| 1648 | 
            +
                    [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
         | 
| 1649 | 
             
                ) 
         | 
| 1650 | 
            +
                example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
         | 
| 1651 |  | 
| 1652 | 
             
                ## vlm and base model dropdown
         | 
| 1653 | 
             
                vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
         | 
|  | |
| 1677 | 
             
                     invert_mask_state]
         | 
| 1678 |  | 
| 1679 | 
             
                ## run brushedit
         | 
| 1680 | 
            +
                run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
         | 
| 1681 |  | 
| 1682 | 
             
                ## mask func
         | 
| 1683 | 
             
                mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
         | 
|  | |
| 1692 | 
             
                move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])    
         | 
| 1693 |  | 
| 1694 | 
             
                ## prompt func
         | 
| 1695 | 
            +
                generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
         | 
| 1696 |  | 
| 1697 | 
             
                ## reset func
         | 
| 1698 | 
            +
                reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
         | 
| 1699 |  | 
| 1700 |  | 
| 1701 | 
             
            demo.launch()
         | 
    	
        app/src/vlm_pipeline.py
    CHANGED
    
    | @@ -98,10 +98,12 @@ def vlm_response_editing_type(vlm_processor, | |
| 98 | 
             
                    messages = create_editing_category_messages_qwen2(editing_prompt)
         | 
| 99 | 
             
                    response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
         | 
| 100 |  | 
| 101 | 
            -
                 | 
| 102 | 
            -
                     | 
| 103 | 
            -
                         | 
| 104 | 
            -
             | 
|  | |
|  | |
| 105 |  | 
| 106 |  | 
| 107 | 
             
            ### response object to be edited        
         | 
| @@ -206,17 +208,21 @@ def vlm_response_prompt_after_apply_instruction(vlm_processor, | |
| 206 | 
             
                                                            image, 
         | 
| 207 | 
             
                                                            editing_prompt,
         | 
| 208 | 
             
                                                            device):
         | 
| 209 | 
            -
             | 
| 210 | 
            -
             | 
| 211 | 
            -
                     | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
                     | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
| 218 | 
            -
                     | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
             | 
|  | |
|  | |
|  | |
|  | 
|  | |
| 98 | 
             
                    messages = create_editing_category_messages_qwen2(editing_prompt)
         | 
| 99 | 
             
                    response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
         | 
| 100 |  | 
| 101 | 
            +
                try:
         | 
| 102 | 
            +
                    for category_name in ["Addition","Remove","Local","Global","Background"]:
         | 
| 103 | 
            +
                        if category_name.lower() in response_str.lower():
         | 
| 104 | 
            +
                            return category_name
         | 
| 105 | 
            +
                except Exception as e:
         | 
| 106 | 
            +
                    raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")
         | 
| 107 |  | 
| 108 |  | 
| 109 | 
             
            ### response object to be edited        
         | 
|  | |
| 208 | 
             
                                                            image, 
         | 
| 209 | 
             
                                                            editing_prompt,
         | 
| 210 | 
             
                                                            device):
         | 
| 211 | 
            +
                                                            
         | 
| 212 | 
            +
                try:
         | 
| 213 | 
            +
                    if isinstance(vlm_model, OpenAI):
         | 
| 214 | 
            +
                        base64_image = encode_image(image)  
         | 
| 215 | 
            +
                        messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
         | 
| 216 | 
            +
                        response_str = run_gpt4o_vl_inference(vlm_model, messages)
         | 
| 217 | 
            +
                    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
         | 
| 218 | 
            +
                        messages = create_apply_editing_messages_llava(editing_prompt)
         | 
| 219 | 
            +
                        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
         | 
| 220 | 
            +
                    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
         | 
| 221 | 
            +
                        base64_image = encode_image(image)  
         | 
| 222 | 
            +
                        messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
         | 
| 223 | 
            +
                        response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
         | 
| 224 | 
            +
                    else:
         | 
| 225 | 
            +
                        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
         | 
| 226 | 
            +
                except Exception as e:
         | 
| 227 | 
            +
                    raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
         | 
| 228 | 
            +
                return response_str
         | 
