import spaces
import os
import torch
from PIL import Image
import gradio as gr

from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora
from huggingface_hub import login
token = os.getenv("hugface_token")
if token:
    login(token=token)
    print("Login successful!")
else:
    print("hugface_token not found in environment variables.")

# Base model and local LoRA weight locations
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"


pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
# try:
pipe.to("cuda")
# Right after model initialization, the GPU cache can be cleared and attention slicing enabled:
# torch.cuda.empty_cache()         # clear the GPU cache
# pipe.enable_attention_slicing()  # enable attention slicing to reduce memory usage
# except torch.cuda.OutOfMemoryError:
#     print("CUDA out of memory. Switching to CPU.")
#     pipe.to("cpu")
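# A lighter-VRAM alternative (a sketch, assuming this custom FluxPipeline inherits
# diffusers' DiffusionPipeline offloading support; untested here):
# pipe.enable_model_cpu_offload()  # keep submodules on the GPU only while they run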

def clear_cache(transformer):
    # EasyControl's LoRA attention processors cache condition key/value
    # tensors in `bank_kv`; clear them so state does not leak between runs.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()

# Define the Gradio interface
@spaces.GPU()
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
    # Load the LoRA matching the selected control type
    if control_type == "Ghibli":
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

        # Run the pipeline with the uploaded image as the spatial condition
        spatial_imgs = [spatial_img] if spatial_img else []
        image = pipe(
            prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            # gr.Number delivers a float, so cast before seeding
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
        clear_cache(pipe.transformer)
        return image
    raise gr.Error(f"Unsupported control type: {control_type}")
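
# Hypothetical local smoke test (kept commented so it never runs on import;
# assumes the example image and LoRA weights are present):
# preview = single_condition_generate_image(
#     "Ghibli Studio style, Charming hand-drawn anime-style illustration",
#     Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli",
# )
# preview.save("preview.png")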

# Define the Gradio interface components
control_types = ["Ghibli"]

# Example data
single_examples = [
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli"],
]


# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli Studio Control Image Generation")
    gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
    gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
    
    gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
    # gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")

    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
                spatial_img = gr.Image(label="Ghibli Image", type="pil")  # the condition image to upload
                height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                seed = gr.Number(label="Seed", value=42)
                control_type = gr.Dropdown(choices=control_types, label="Control Type", value="Ghibli")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Add examples for Single Condition Generation
        gr.Examples(
            examples=single_examples,
            inputs=[prompt, spatial_img, height, width, seed, control_type],
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,  # do not pre-render example outputs at startup
            label="Single Condition Examples"
        )

    # Link the buttons to the functions
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, spatial_img, height, width, seed, control_type],
        outputs=single_output_image
    )


# Launch the Gradio app
demo.queue().launch()