chriswu25 committed on
Commit 3f1893b · verified · 1 Parent(s): 04e3a10

Update app.py

Files changed (1)
  1. app.py +105 -109
app.py CHANGED
@@ -1,110 +1,106 @@
- import spaces
- import os
- import json
- import time
- import torch
- from PIL import Image
- from tqdm import tqdm
- import gradio as gr
-
- from safetensors.torch import save_file
- from src.pipeline import FluxPipeline
- from src.transformer_flux import FluxTransformer2DModel
- from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
-
- # Initialize the image processor
- base_path = "black-forest-labs/FLUX.1-dev"
- lora_base_path = "./models"
-
-
- pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
- transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
- pipe.transformer = transformer
- pipe.to("cuda")
-
- def clear_cache(transformer):
-     for name, attn_processor in transformer.attn_processors.items():
-         attn_processor.bank_kv.clear()
-
- # Define the Gradio interface
- @spaces.GPU()
- def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
-     # Set the control type
-     if control_type == "Ghibli":
-         lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
-         set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
-
-     # Process the image
-     spatial_imgs = [spatial_img] if spatial_img else []
-     image = pipe(
-         prompt,
-         height=int(height),
-         width=int(width),
-         guidance_scale=3.5,
-         num_inference_steps=25,
-         max_sequence_length=512,
-         generator=torch.Generator("cpu").manual_seed(seed),
-         subject_images=[],
-         spatial_images=spatial_imgs,
-         cond_size=512,
-     ).images[0]
-     clear_cache(pipe.transformer)
-     return image
-
- # Define the Gradio interface components
- control_types = ["Ghibli"]
-
- # Example data
- single_examples = [
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli"],
-     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli"],
- ]
-
-
- # Create the Gradio Blocks interface
- with gr.Blocks() as demo:
-     gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
-     gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
-     gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs. (Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)")
-
-     gr.Markdown("**[Attention!!]**: The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
-     gr.Markdown("😊😊 If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
-
-     with gr.Tab("Ghibli Condition Generation"):
-         with gr.Row():
-             with gr.Column():
-                 prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
-                 spatial_img = gr.Image(label="Ghibli Image", type="pil")  # Upload image file
-                 height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
-                 width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
-                 seed = gr.Number(label="Seed", value=42)
-                 control_type = gr.Dropdown(choices=control_types, label="Control Type")
-                 single_generate_btn = gr.Button("Generate Image")
-             with gr.Column():
-                 single_output_image = gr.Image(label="Generated Image")
-
-         # Add examples for Single Condition Generation
-         gr.Examples(
-             examples=single_examples,
-             inputs=[prompt, spatial_img, height, width, seed, control_type],
-             outputs=single_output_image,
-             fn=single_condition_generate_image,
-             cache_examples=False,  # Cache example results to speed up loading
-             label="Single Condition Examples"
-         )
-
-         # Link the buttons to the functions
-         single_generate_btn.click(
-             single_condition_generate_image,
-             inputs=[prompt, spatial_img, height, width, seed, control_type],
-             outputs=single_output_image
-         )
-
- # Launch the Gradio app
+ import spaces
+ import os
+ import json
+ import time
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ import gradio as gr
+
+ from safetensors.torch import save_file
+ from src.pipeline import FluxPipeline
+ from src.transformer_flux import FluxTransformer2DModel
+ from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
+
+ # Initialize the image processor
+ base_path = "black-forest-labs/FLUX.1-dev"
+ lora_base_path = "./models"
+
+ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
+ transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
+ pipe.transformer = transformer
+ # Removed pipe.to("cuda"); use the CPU by default
+
+ def clear_cache(transformer):
+     for name, attn_processor in transformer.attn_processors.items():
+         attn_processor.bank_kv.clear()
+
+ # Define the Gradio interface
+ @spaces.GPU()  # Change to @spaces.CPU() or remove entirely, since the free tier has no GPU
+ def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
+     # Set the control type
+     if control_type == "Ghibli":
+         lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
+         set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
+
+     # Process the image
+     spatial_imgs = [spatial_img] if spatial_img else []
+     image = pipe(
+         prompt,
+         height=int(height),
+         width=int(width),
+         guidance_scale=3.5,
+         num_inference_steps=15,  # Fewer steps to suit the CPU
+         max_sequence_length=512,
+         generator=torch.Generator("cpu").manual_seed(seed),
+         subject_images=[],
+         spatial_images=spatial_imgs,
+         cond_size=512,
+     ).images[0]
+     clear_cache(pipe.transformer)
+     return image
+
+ # Define the Gradio interface components
+ control_types = ["Ghibli"]
+
+ # Example data (resolutions reduced to suit the CPU)
+ single_examples = [
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 512, 512, 5, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 512, 512, 42, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 512, 512, 1, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 512, 512, 1, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 512, 512, 1, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 512, 512, 1, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 512, 512, 1, "Ghibli"],
+     ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 512, 512, 1, "Ghibli"],
+ ]
+
+ # Create the Gradio Blocks interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
+     gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
+     gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs. (Running on CPU due to free tier limitations; expect slower performance and lower resolution.)")
+
+     gr.Markdown("**[Attention!!]**: The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
+     gr.Markdown("😊😊 If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
+
+     with gr.Tab("Ghibli Condition Generation"):
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
+                 spatial_img = gr.Image(label="Ghibli Image", type="pil")
+                 height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512)  # Cap the maximum resolution
+                 width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512)  # Cap the maximum resolution
+                 seed = gr.Number(label="Seed", value=42)
+                 control_type = gr.Dropdown(choices=control_types, label="Control Type")
+                 single_generate_btn = gr.Button("Generate Image")
+             with gr.Column():
+                 single_output_image = gr.Image(label="Generated Image")
+
+         gr.Examples(
+             examples=single_examples,
+             inputs=[prompt, spatial_img, height, width, seed, control_type],
+             outputs=single_output_image,
+             fn=single_condition_generate_image,
+             cache_examples=False,
+             label="Single Condition Examples"
+         )
+
+         single_generate_btn.click(
+             single_condition_generate_image,
+             inputs=[prompt, spatial_img, height, width, seed, control_type],
+             outputs=single_output_image
+         )
+
+ # Launch the Gradio app
  demo.queue().launch()
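
For reference, a minimal device-aware sketch of the setup this diff changes, not part of the commit itself: it assumes the same `FluxPipeline` and `FluxTransformer2DModel` classes from `src/` and the `base_path` defined above, and it simply detects CUDA at startup so one `app.py` can serve both the free CPU tier and GPU hardware.

```python
import torch

# Hypothetical alternative to hard-coding CPU or GPU: detect the device at startup.
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 works well on recent GPUs; float32 is the safer default on CPU.
dtype = torch.bfloat16 if device == "cuda" else torch.float32

pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=dtype)
transformer = FluxTransformer2DModel.from_pretrained(
    base_path, subfolder="transformer", torch_dtype=dtype
)
pipe.transformer = transformer
pipe.to(device)

# Scale generation cost with the hardware, mirroring the values in the diff:
# 15 steps / 512 px on CPU, the original 25 steps / 1024 px when a GPU is present.
num_inference_steps = 25 if device == "cuda" else 15
max_resolution = 1024 if device == "cuda" else 512
```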