File size: 1,802 Bytes
683afc3
 
8d2ed6a
4fbc46c
c1497a6
4fbc46c
c1497a6
4fbc46c
c1497a6
683afc3
c1497a6
ca7110d
c1497a6
92fa744
c1497a6
8d2ed6a
 
683afc3
8d2ed6a
 
ca7110d
c1497a6
 
683afc3
8d2ed6a
 
c1497a6
 
 
 
8d2ed6a
683afc3
8d2ed6a
 
 
 
 
 
683afc3
c1497a6
8d2ed6a
 
 
683afc3
8d2ed6a
 
683afc3
8d2ed6a
 
 
683afc3
 
8d2ed6a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the token stored in the
# HF_TOKEN environment variable.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
else:
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered in a headless deployment. Skip it and let the model download
    # below report any authorization problem directly.
    print("HF_TOKEN is not set; skipping Hugging Face login.")

# Model IDs for the base Stable Diffusion model and ControlNet variant.
# NOTE(review): the base ID names a Stable Diffusion 3.5 checkpoint while the
# ControlNet ID names an SD 1.5 ("sd15") inpaint ControlNet, and both are loaded
# through StableDiffusionControlNetPipeline. These look mismatched -- confirm
# the combination actually loads, or pick a base/ControlNet/pipeline trio from
# the same model family.
model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
controlnet_id = "lllyasviel/control_v11p_sd15_inpaint"  # Make sure this ControlNet is compatible

# Load the ControlNet weights first, then build the pipeline around them.
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float32,
)

# Move the pipeline to the GPU when one is available; otherwise leave it as loaded.
if torch.cuda.is_available():
    pipeline = pipeline.to("cuda")


# Define the Gradio interface function
def generate_image(prompt, reference_image):
    """Generate an image from a text prompt and a reference image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        reference_image: PIL image from the Gradio image input. Gradio passes
            ``None`` when the user submits without uploading an image.

    Returns:
        The first image in the pipeline's result.

    Raises:
        gr.Error: If no reference image was provided, so the UI shows a
            readable message instead of an opaque attribute error.
    """
    if reference_image is None:
        raise gr.Error("Please upload a reference image before generating.")

    # Normalize the upload to 3-channel RGB at a fixed 512x512 size before
    # handing it to the pipeline as the ControlNet conditioning image.
    control_image = reference_image.convert("RGB").resize((512, 512))

    # Run the ControlNet-conditioned pipeline with fixed sampling settings.
    result = pipeline(
        prompt=prompt,
        image=control_image,
        controlnet_conditioning_scale=1.0,
        guidance_scale=7.5,
        num_inference_steps=50,
    )
    return result.images[0]


# Build the UI: a prompt box and a reference-image upload go in, one image comes out.
_prompt_input = gr.Textbox(label="Prompt")
_reference_input = gr.Image(type="pil", label="Reference Image (Style)")

interface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_input, _reference_input],
    outputs="image",
    title="Image Generation with Reference-Only Style Transfer",
    description=(
        "Generate an image based on a text prompt and style reference image "
        "using Stable Diffusion 3.5 with ControlNet (reference-only mode)."
    ),
)

# Start serving the app.
interface.launch()