Spaces:
Paused
Paused
| import os | |
| import requests | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import login | |
| from diffusers.utils import load_image | |
| from models.transformer_sd3 import SD3Transformer2DModel | |
| from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline | |
# ----------------------------
# Step 1: Download IP Adapter if not exists
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"

if not os.path.exists(file_path):
    print("File not found, downloading...")
    # Stream into a sibling ".part" file and rename it only once the whole
    # body has been written. Otherwise an interrupted download would leave a
    # truncated "ip-adapter.bin" that the existence check above accepts on the
    # next start.
    partial_path = file_path + ".part"
    try:
        # (connect, read) timeouts keep a stalled connection from hanging startup.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as weights.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                # 1 MiB chunks: far fewer Python-level writes than 1 KiB ones.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        file.write(chunk)
        # Atomic on the same filesystem: the final name only ever refers to a
        # complete file.
        os.replace(partial_path, file_path)
    finally:
        # After a successful rename the partial file no longer exists; this
        # only removes leftovers from a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download completed!")
# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
# Authenticate with the Hub before any model weights are requested. The token
# is read from the environment; a missing or empty value aborts startup.
token = os.environ.get("HF_TOKEN")
if token is None or token == "":
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")

login(token=token)
# ----------------------------
# Step 3: Model locations
# ----------------------------
model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"

# ----------------------------
# Step 4: Load the transformer and the pipeline
# ----------------------------
# Both the transformer and the surrounding pipeline are loaded in bfloat16.
weights_dtype = torch.bfloat16

# The project-local transformer class is loaded from the "transformer"
# subfolder of the base model and handed to the pipeline explicitly.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=weights_dtype,
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=weights_dtype,
)
pipe = pipe.to("cuda")

# ----------------------------
# Step 5: Attach the IP-Adapter
# ----------------------------
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
# ----------------------------
# Step 6: Gradio Function
# ----------------------------
# NOTE(review): `spaces` was imported but never used. This assumes the Space
# targets ZeroGPU hardware, where GPU work must run inside a function decorated
# with `spaces.GPU`; the decorator is documented as a no-op elsewhere. Confirm
# against the Space's hardware setting.
@spaces.GPU
def gui_generation(prompt, negative_prompt, ref_img, guidance_scale, ipadapter_scale):
    """Generate one 1024x1024 image from a text prompt and a reference image.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.
        ref_img: Value produced by the ``gr.File`` input: either a path string
            or an object exposing the path as ``.name``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        ipadapter_scale: Forwarded to the pipeline as ``ipadapter_scale``.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no reference image was uploaded.
    """
    # An empty upload arrives as None; report it in the UI instead of failing
    # with an AttributeError on `.name`.
    if ref_img is None:
        raise gr.Error("Please upload a reference image.")

    # Accept both shapes the File component can deliver: a plain path string
    # or a temp-file wrapper whose `.name` holds the path.
    image_path = ref_img if isinstance(ref_img, str) else ref_img.name
    reference = load_image(image_path).convert('RGB')

    # please note that SD3.5 Large is sensitive to highres generation like 1536x1536
    result = pipe(
        width=1024,
        height=1024,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=24,
        guidance_scale=guidance_scale,
        # Fixed seed: identical inputs reproduce the same image.
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=reference,
        ipadapter_scale=ipadapter_scale,
    )
    return result.images[0]
# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
# The placeholder previously repeated the positive-prompt hint; it now
# describes what this field is for.
negative_prompt_box = gr.Textbox(
    label="Negative Prompt",
    placeholder="Enter what the image should avoid",
    value="lowres, low quality, worst quality",
)
# Only image files are accepted: the upload is opened as an image by
# gui_generation, so any other file type could only fail there.
ref_img = gr.File(label="Upload Reference Image", file_types=["image"])
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt",
)
ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt",
)

# The input order must match the positional parameters of gui_generation.
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, negative_prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter.",
)

# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
interface.launch()