Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -14,9 +14,7 @@ from torchvision import transforms | |
| 14 | 
             
            from torchvision.transforms import functional as TF
         | 
| 15 | 
             
            from tqdm import trange
         | 
| 16 | 
             
            from transformers import CLIPProcessor, CLIPModel
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            # from diffusion_models import Diffusion  # Swapped Diffusion model for DALL·E 2 based model - REMOVED
         | 
| 19 | 
            -
            from huggingface_hub import hf_hub_url, cached_download
         | 
| 20 | 
             
            import gradio as gr  # 🎨 The magic canvas for AI-powered image generation!
         | 
| 21 | 
             
            import math
         | 
| 22 |  | 
| @@ -130,8 +128,9 @@ def ddpm_sample(model, x, steps, **kwargs): | |
| 130 | 
             
            # NOTE: The HuggingFace URLs you provided might be placeholders.
         | 
| 131 | 
             
            # Make sure these point to the correct model files.
         | 
| 132 | 
             
            try:
         | 
| 133 | 
            -
                 | 
| 134 | 
            -
                 | 
|  | |
| 135 | 
             
            except Exception as e:
         | 
| 136 | 
             
                print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
         | 
| 137 | 
             
                print("Using placeholder models which will not produce good images.")
         | 
| @@ -213,7 +212,7 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method | |
| 213 | 
             
                    target_embeds.append(text_embed)
         | 
| 214 | 
             
                    weights.append(1.0)
         | 
| 215 |  | 
| 216 | 
            -
                #  | 
| 217 | 
             
                # Assign a default weight for image prompts
         | 
| 218 | 
             
                image_prompt_weight = 1.0
         | 
| 219 | 
             
                for image_path in images:
         | 
| @@ -250,7 +249,6 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method | |
| 250 | 
             
                    return v
         | 
| 251 |  | 
| 252 | 
             
                # 🎞️ Run the sampler to generate images
         | 
| 253 | 
            -
                # **FIXED**: Call sampling functions directly without the 'sampling.' prefix
         | 
| 254 | 
             
                def run(x, steps):
         | 
| 255 | 
             
                    if method == 'ddpm':
         | 
| 256 | 
             
                        return ddpm_sample(cfg_model_fn, x, steps)
         | 
| @@ -310,7 +308,6 @@ iface = gr.Interface( | |
| 310 | 
             
                fn=gen_ims,
         | 
| 311 | 
             
                inputs=[
         | 
| 312 | 
             
                    gr.Textbox(label="Text prompt"),
         | 
| 313 | 
            -
                    # **FIXED**: Removed deprecated 'optional=True' argument
         | 
| 314 | 
             
                    gr.Image(label="Image prompt", type='filepath')
         | 
| 315 | 
             
                ],
         | 
| 316 | 
             
                outputs=gr.Image(type="pil", label="Generated Image"),
         | 
|  | |
| 14 | 
             
            from torchvision.transforms import functional as TF
         | 
| 15 | 
             
            from tqdm import trange
         | 
| 16 | 
             
            from transformers import CLIPProcessor, CLIPModel
         | 
| 17 | 
            +
            from huggingface_hub import hf_hub_download # FIXED: Replaced deprecated function
         | 
|  | |
|  | |
| 18 | 
             
            import gradio as gr  # 🎨 The magic canvas for AI-powered image generation!
         | 
| 19 | 
             
            import math
         | 
| 20 |  | 
|  | |
| 128 | 
             
            # NOTE: The HuggingFace URLs you provided might be placeholders.
         | 
| 129 | 
             
            # Make sure these point to the correct model files.
         | 
| 130 | 
             
            try:
         | 
| 131 | 
            +
                # FIXED: Using the new hf_hub_download function with keyword arguments
         | 
| 132 | 
            +
                vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
         | 
| 133 | 
            +
                diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
         | 
| 134 | 
             
            except Exception as e:
         | 
| 135 | 
             
                print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
         | 
| 136 | 
             
                print("Using placeholder models which will not produce good images.")
         | 
|  | |
| 212 | 
             
                    target_embeds.append(text_embed)
         | 
| 213 | 
             
                    weights.append(1.0)
         | 
| 214 |  | 
| 215 | 
            +
                # Correctly process image prompts from Gradio
         | 
| 216 | 
             
                # Assign a default weight for image prompts
         | 
| 217 | 
             
                image_prompt_weight = 1.0
         | 
| 218 | 
             
                for image_path in images:
         | 
|  | |
| 249 | 
             
                    return v
         | 
| 250 |  | 
| 251 | 
             
                # 🎞️ Run the sampler to generate images
         | 
|  | |
| 252 | 
             
                def run(x, steps):
         | 
| 253 | 
             
                    if method == 'ddpm':
         | 
| 254 | 
             
                        return ddpm_sample(cfg_model_fn, x, steps)
         | 
|  | |
| 308 | 
             
                fn=gen_ims,
         | 
| 309 | 
             
                inputs=[
         | 
| 310 | 
             
                    gr.Textbox(label="Text prompt"),
         | 
|  | |
| 311 | 
             
                    gr.Image(label="Image prompt", type='filepath')
         | 
| 312 | 
             
                ],
         | 
| 313 | 
             
                outputs=gr.Image(type="pil", label="Generated Image"),
         | 
