olaaoamgo commited on
Commit
2e75204
·
verified ·
1 Parent(s): a0df9d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -59
app.py CHANGED
@@ -9,12 +9,12 @@ from src_inference.pipeline import FluxPipeline
9
  from src_inference.lora_helper import set_single_lora
10
 
11
  BASE_PATH = "black-forest-labs/FLUX.1-dev"
12
- LOCAL_LORA_DIR = "./LoRAs"
13
- CUSTOM_LORA_DIR = "./Custom_LoRAs"
14
  os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
15
  os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
16
 
17
- # ------------------ DEVICE SETUP (✅ supports CPU-only Spaces) ------------------ #
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
20
  print(f"🚀 Running on device: {device}")
@@ -37,32 +37,16 @@ pipe = FluxPipeline.from_pretrained(
37
  set_single_lora(pipe.transformer, omni_consistency_path,
38
  lora_weights=[1], cond_size=512)
39
 
40
- # ------------------ Download LoRA Styles ------------------ #
41
- def download_all_loras():
42
- lora_names = [
43
- "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
44
- "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
45
- "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
46
- "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
47
- "Snoopy", "Van_Gogh", "Vector"
48
- ]
49
- for name in lora_names:
50
- hf_hub_download(
51
- repo_id="showlab/OmniConsistency",
52
- filename=f"LoRAs/{name}_rank128_bf16.safetensors",
53
- local_dir=LOCAL_LORA_DIR,
54
- )
55
- download_all_loras()
56
-
57
  def clear_cache(transformer):
58
  for _, attn_processor in transformer.attn_processors.items():
59
  attn_processor.bank_kv.clear()
60
 
61
- # ------------------ Generation Function ------------------ #
62
- @spaces.GPU() # Will fallback silently if GPU not available
63
  def generate_image(
64
- lora_name,
65
- custom_repo_id,
66
  prompt,
67
  uploaded_image,
68
  width, height,
@@ -73,6 +57,7 @@ def generate_image(
73
  width, height = int(width), int(height)
74
  generator = torch.Generator("cpu").manual_seed(seed)
75
 
 
76
  if custom_repo_id and custom_repo_id.strip():
77
  repo_id = custom_repo_id.strip()
78
  try:
@@ -91,8 +76,12 @@ def generate_image(
91
  except Exception as e:
92
  raise gr.Error(f"Load custom LoRA failed: {e}")
93
  else:
94
- lora_path = os.path.join(
95
- f"{LOCAL_LORA_DIR}/LoRAs", f"{lora_name}_rank128_bf16.safetensors"
 
 
 
 
96
  )
97
 
98
  pipe.unload_lora_weights()
@@ -104,13 +93,14 @@ def generate_image(
104
  except Exception as e:
105
  raise gr.Error(f"Load LoRA failed: {e}")
106
 
107
- spatial_image = [uploaded_image.convert("RGB")]
108
  subject_images = []
 
109
  start = time.time()
110
  out_img = pipe(
111
  prompt,
112
  height=(height // 8) * 8,
113
- width=(width // 8) * 8,
114
  guidance_scale=guidance_scale,
115
  num_inference_steps=num_inference_steps,
116
  max_sequence_length=512,
@@ -124,7 +114,7 @@ def generate_image(
124
  clear_cache(pipe.transformer)
125
  return uploaded_image, out_img
126
 
127
- # ------------------ UI Interface ------------------ #
128
  def create_interface():
129
  demo_lora_names = [
130
  "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
@@ -135,22 +125,22 @@ def create_interface():
135
  ]
136
 
137
  def update_trigger_word(lora_name, prompt):
138
- for name in demo_lora_names:
139
- trigger = " ".join(name.split("_")) + " style,"
140
- prompt = prompt.replace(trigger, "")
141
- new_trigger = " ".join(lora_name.split("_"))+ " style,"
142
- return new_trigger + prompt
143
 
144
  examples = [
145
- ["3D_Chibi", "", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
146
  Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
147
- ["Clay_Toy", "", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
148
  Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
149
- ["American_Cartoon", "", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' — capturing both tension and humor.",
150
  Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
151
- ["Origami", "", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.",
152
  Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
153
- ["Vector", "", "Vector style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
154
  Image.open("./test_imgs/04.png"), 512, 1024, 3.5, 24, 42]
155
  ]
156
 
@@ -169,33 +159,32 @@ def create_interface():
169
 
170
  with gr.Row():
171
  with gr.Column(scale=1):
172
- image_input = gr.Image(type="pil", label="Upload Image")
173
- prompt_box = gr.Textbox(label="Prompt",
174
- value="3D Chibi style,",
175
- info="Remember to include the necessary trigger words if you're using a custom LoRA."
 
176
  )
177
- lora_dropdown = gr.Dropdown(
178
  demo_lora_names, label="Select built-in LoRA")
179
- custom_repo_box = gr.Textbox(
180
  label="Enter Custom LoRA",
181
- placeholder="LoRA Hugging Face path (e.g., 'username/repo_name')",
182
- info="If you want to use a custom LoRA, enter its Hugging Face repo ID here and built-in LoRA will be Overridden. Leave empty to use built-in LoRAs. [Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)"
183
  )
184
- gen_btn = gr.Button("Generate")
 
185
  with gr.Column(scale=1):
186
  output_image = gr.ImageSlider(label="Generated Image")
 
187
  with gr.Accordion("Advanced Options", open=False):
188
- height_box = gr.Textbox(value="1024", label="Height")
189
- width_box = gr.Textbox(value="1024", label="Width")
190
- guidance_slider = gr.Slider(
191
- 0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
192
- steps_slider = gr.Slider(
193
- 1, 50, value=25, step=1, label="Inference Steps")
194
- seed_slider = gr.Slider(
195
- 1, 2_147_483_647, value=42, step=1, label="Seed")
196
-
197
- lora_dropdown.select(fn=update_trigger_word, inputs=[lora_dropdown,prompt_box],
198
- outputs=prompt_box)
199
 
200
  gr.Examples(
201
  examples=examples,
@@ -213,8 +202,10 @@ def create_interface():
213
  width_box, height_box, guidance_slider, steps_slider, seed_slider],
214
  outputs=output_image
215
  )
 
216
  return demo
217
 
 
218
  if __name__ == "__main__":
219
  demo = create_interface()
220
  demo.launch(ssr_mode=False)
 
9
  from src_inference.lora_helper import set_single_lora
10
 
11
  BASE_PATH = "black-forest-labs/FLUX.1-dev"
12
+ LOCAL_LORA_DIR = "./LoRAs"
13
+ CUSTOM_LORA_DIR = "./Custom_LoRAs"
14
  os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
15
  os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
16
 
17
+ # ------------------ DEVICE SETUP ------------------ #
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
20
  print(f"🚀 Running on device: {device}")
 
37
  set_single_lora(pipe.transformer, omni_consistency_path,
38
  lora_weights=[1], cond_size=512)
39
 
40
+ # ------------------ Util ------------------ #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def clear_cache(transformer):
42
  for _, attn_processor in transformer.attn_processors.items():
43
  attn_processor.bank_kv.clear()
44
 
45
+ # ------------------ Generation ------------------ #
46
+ @spaces.GPU()
47
  def generate_image(
48
+ lora_name,
49
+ custom_repo_id,
50
  prompt,
51
  uploaded_image,
52
  width, height,
 
57
  width, height = int(width), int(height)
58
  generator = torch.Generator("cpu").manual_seed(seed)
59
 
60
+ # Custom LoRA path
61
  if custom_repo_id and custom_repo_id.strip():
62
  repo_id = custom_repo_id.strip()
63
  try:
 
76
  except Exception as e:
77
  raise gr.Error(f"Load custom LoRA failed: {e}")
78
  else:
79
+ # Built-in LoRA: download only the one selected
80
+ lora_filename = f"LoRAs/{lora_name}_rank128_bf16.safetensors"
81
+ lora_path = hf_hub_download(
82
+ repo_id="showlab/OmniConsistency",
83
+ filename=lora_filename,
84
+ local_dir=LOCAL_LORA_DIR
85
  )
86
 
87
  pipe.unload_lora_weights()
 
93
  except Exception as e:
94
  raise gr.Error(f"Load LoRA failed: {e}")
95
 
96
+ spatial_image = [uploaded_image.convert("RGB")]
97
  subject_images = []
98
+
99
  start = time.time()
100
  out_img = pipe(
101
  prompt,
102
  height=(height // 8) * 8,
103
+ width=(width // 8) * 8,
104
  guidance_scale=guidance_scale,
105
  num_inference_steps=num_inference_steps,
106
  max_sequence_length=512,
 
114
  clear_cache(pipe.transformer)
115
  return uploaded_image, out_img
116
 
117
+ # ------------------ Gradio UI ------------------ #
118
  def create_interface():
119
  demo_lora_names = [
120
  "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
 
125
  ]
126
 
127
  def update_trigger_word(lora_name, prompt):
128
+ for name in demo_lora_names:
129
+ trigger = " ".join(name.split("_")) + " style,"
130
+ prompt = prompt.replace(trigger, "")
131
+ new_trigger = " ".join(lora_name.split("_")) + " style,"
132
+ return new_trigger + prompt
133
 
134
  examples = [
135
+ ["3D_Chibi", "", "3D Chibi style, Two smiling colleagues high-five at a whiteboard filled with technical notes.",
136
  Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
137
+ ["Clay_Toy", "", "Clay Toy style, A holiday-themed OpenAI team photo full of smiles and warmth.",
138
  Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
139
+ ["American_Cartoon", "", "American Cartoon style, A dramatic subtitle moment from a classic film.",
140
  Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
141
+ ["Origami", "", "Origami style, A Portugal football fan posing with Cristiano Ronaldo.",
142
  Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
143
+ ["Vector", "", "Vector style, The distracted boyfriend meme reimagined.",
144
  Image.open("./test_imgs/04.png"), 512, 1024, 3.5, 24, 42]
145
  ]
146
 
 
159
 
160
  with gr.Row():
161
  with gr.Column(scale=1):
162
+ image_input = gr.Image(type="pil", label="Upload Image")
163
+ prompt_box = gr.Textbox(
164
+ label="Prompt",
165
+ value="3D Chibi style,",
166
+ info="Include a style like 'Ghibli style,' in your prompt for better results."
167
  )
168
+ lora_dropdown = gr.Dropdown(
169
  demo_lora_names, label="Select built-in LoRA")
170
+ custom_repo_box = gr.Textbox(
171
  label="Enter Custom LoRA",
172
+ placeholder="e.g. username/repo_name",
173
+ info="Overrides built-in LoRA if provided."
174
  )
175
+ gen_btn = gr.Button("Generate")
176
+
177
  with gr.Column(scale=1):
178
  output_image = gr.ImageSlider(label="Generated Image")
179
+
180
  with gr.Accordion("Advanced Options", open=False):
181
+ height_box = gr.Textbox(value="1024", label="Height")
182
+ width_box = gr.Textbox(value="1024", label="Width")
183
+ guidance_slider = gr.Slider(0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
184
+ steps_slider = gr.Slider(1, 50, value=25, step=1, label="Inference Steps")
185
+ seed_slider = gr.Slider(1, 2_147_483_647, value=42, step=1, label="Seed")
186
+
187
+ lora_dropdown.select(fn=update_trigger_word, inputs=[lora_dropdown, prompt_box], outputs=prompt_box)
 
 
 
 
188
 
189
  gr.Examples(
190
  examples=examples,
 
202
  width_box, height_box, guidance_slider, steps_slider, seed_slider],
203
  outputs=output_image
204
  )
205
+
206
  return demo
207
 
208
+ # ------------------ Run ------------------ #
209
  if __name__ == "__main__":
210
  demo = create_interface()
211
  demo.launch(ssr_mode=False)