import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr

from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# Paths for the base model and the control LoRA weights
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# Environment variable for API token (set this in your Hugging Face space settings)
API_TOKEN = os.environ.get("HF_TOKEN")

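# Load the base FLUX pipeline, then swap in the project's custom FluxTransformer2DModel
# so the Ghibli control LoRA can be attached to its attention processors below.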
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda")

def clear_cache(transformer):
    # Release the key/value banks cached by the LoRA attention processors
    # so memory does not accumulate across generations.
    for name, attn_processor in transformer.attn_processors.items():
        attn_processor.bank_kv.clear()

# Generation function with optional API token verification
@spaces.GPU()
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, api_token=""):
    # Reject the request if a token is configured and the supplied one does not match
    if API_TOKEN and api_token != API_TOKEN:
        raise gr.Error("Invalid API token. Please provide a valid token to generate images.")
    
    try:
        # Ensure height and width are divisible by 8
        height = int(height)
        width = int(width)
        
        if height % 8 != 0 or width % 8 != 0:
            # Round down to the nearest multiple of 8
            height = (height // 8) * 8
            width = (width // 8) * 8
            print(f"Dimensions adjusted to be divisible by 8: {height}x{width}")
        
        # Select the control LoRA (only the Ghibli LoRA ships with this demo)
        if control_type == "Ghibli":
            lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        else:
            raise ValueError(f"Unsupported control type: {control_type}")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
        
        # Process the image
        spatial_imgs = [spatial_img] if spatial_img else []
        image = pipe(
            prompt,
            height=height,
            width=width,
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
        clear_cache(pipe.transformer)
        return image
    except Exception as e:
        # Surface the failure in the Gradio UI; returning a plain string to an
        # Image output component would not render.
        error_message = f"Error during generation: {e}"
        print(error_message)
        raise gr.Error(error_message)

# Define the Gradio interface components
control_types = ["Ghibli"]

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
    
    # Only show token field if API token is required
    if API_TOKEN:
        gr.Markdown("⚠️ **AUTHENTICATION REQUIRED**: Please enter your API token to use this service.")
        api_token = gr.Textbox(label="API Token", type="password", value="")
    else:
        api_token = gr.Textbox(visible=False, value="")  # Hidden field with empty value
    
    gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
    gr.Markdown("Generate images using EasyControl with the Ghibli control LoRA. (Due to hardware constraints, only low-resolution images can be generated. For high resolution (1024+), please set up your own environment.)")
    
    gr.Markdown("**[Attention!!]**: The recommended prompts for the Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
    gr.Markdown("😊😊 If you like this demo, please give us a star (GitHub: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
    gr.Markdown("**NOTE**: Both height and width must be divisible by 8. Values will be automatically adjusted if needed.")

    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
                spatial_img = gr.Image(label="Ghibli Image", type="pil")
                height = gr.Slider(minimum=256, maximum=1024, step=8, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=8, label="Width", value=768)
                seed = gr.Number(label="Seed", value=42)
                control_type = gr.Dropdown(choices=control_types, label="Control Type", value="Ghibli")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Set up examples (with token automatically added if present)
        example_inputs = [prompt, spatial_img, height, width, seed, control_type]
        if API_TOKEN:
            # Add token to examples for convenience
            example_data = [
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli", API_TOKEN],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli", API_TOKEN],
            ]
            example_inputs.append(api_token)
        else:
            # Use examples without token
            example_data = [
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli"],
                ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli"],
            ]

        gr.Examples(
            examples=example_data,
            inputs=example_inputs,
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,
            label="Single Condition Examples"
        )

    # Link the buttons to the functions with API token included
    inputs = [prompt, spatial_img, height, width, seed, control_type]
    if API_TOKEN:
        inputs.append(api_token)
    
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=inputs,
        outputs=single_output_image
    )

# Launch the Gradio app
demo.queue().launch()