Cicici1109 commited on
Commit
8a0dce0
·
verified ·
1 Parent(s): e4b0032

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +28 -79
utils.py CHANGED
@@ -19,6 +19,7 @@ import subprocess
19
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
21
 
 
22
  from src.flux.generate import generate, seed_everything
23
 
24
  try:
@@ -29,78 +30,21 @@ except ImportError:
29
 
30
  import re
31
 
32
- # Global variables
33
  pipe = None
34
  model_dict = {}
35
- _MODEL_INITIALIZED = False
36
- _ADAPTERS_LOADED = False
37
 
38
  def init_flux_pipeline():
39
- """Initialize Flux model, ensuring it runs only once"""
40
- global pipe, _MODEL_INITIALIZED
41
-
42
  if pipe is None:
43
- print("Initializing Flux pipeline...")
44
  token = os.getenv("HF_TOKEN")
45
  if not token:
46
  raise ValueError("HF_TOKEN environment variable not set.")
47
-
48
  pipe = FluxPipeline.from_pretrained(
49
  "black-forest-labs/FLUX.1-schnell",
50
  use_auth_token=token,
51
  torch_dtype=torch.bfloat16
52
  )
53
  pipe = pipe.to("cuda")
54
- _MODEL_INITIALIZED = True
55
- print("Flux pipeline initialized successfully.")
56
-
57
- return pipe
58
-
59
- def load_all_lora_adapters():
60
- """Load all LoRA adapters, ensuring it runs only once"""
61
- global pipe, _ADAPTERS_LOADED
62
-
63
- # Ensure model is initialized
64
- init_flux_pipeline()
65
-
66
- if not _ADAPTERS_LOADED:
67
- print("Loading all LoRA adapters...")
68
-
69
- LORA_ADAPTERS = {
70
- "add": "weights/add.safetensors",
71
- "remove": "weights/remove.safetensors",
72
- "action": "weights/action.safetensors",
73
- "expression": "weights/expression.safetensors",
74
- "addition": "weights/addition.safetensors",
75
- "material": "weights/material.safetensors",
76
- "color": "weights/color.safetensors",
77
- "bg": "weights/bg.safetensors",
78
- "appearance": "weights/appearance.safetensors",
79
- "fusion": "weights/fusion.safetensors",
80
- "overall": "weights/overall.safetensors",
81
- }
82
-
83
- for adapter_name, weight_path in LORA_ADAPTERS.items():
84
- try:
85
- pipe.load_lora_weights(
86
- "Cicici1109/IEAP",
87
- weight_name=weight_path,
88
- adapter_name=adapter_name,
89
- )
90
- print(f"✅ Successfully loaded adapter: {adapter_name}")
91
- except Exception as e:
92
- print(f"❌ Failed to load adapter {adapter_name}: {e}")
93
-
94
- loaded_adapters = list(pipe.lora_adapters.keys())
95
- print(f"Loaded adapters: {loaded_adapters}")
96
-
97
- if loaded_adapters:
98
- pipe.set_adapters(loaded_adapters[0])
99
- print(f"Default adapter set to: {loaded_adapters[0]}")
100
-
101
- _ADAPTERS_LOADED = True
102
-
103
- return pipe
104
 
105
  def get_model(model_path):
106
  global model_dict
@@ -221,55 +165,57 @@ def extract_last_bbox(result):
221
 
222
  @spaces.GPU
223
  def infer_with_DiT(task, image, instruction, category):
224
- # Ensure model and adapters are initialized
225
- load_all_lora_adapters()
226
-
227
  if task == 'RoI Inpainting':
228
  if category == 'Add' or category == 'Replace':
229
- adapter_name = "add"
230
  added = extract_object_with_gpt(instruction)
231
  instruction_dit = f"add {added} on the black region"
232
  elif category == 'Remove' or category == 'Action Change':
233
- adapter_name = "remove"
234
  instruction_dit = f"Fill the hole of the image"
 
235
  condition = Condition("scene", image, position_delta=(0, 0))
236
-
237
  elif task == 'RoI Editing':
238
  image = Image.open(image).convert('RGB').resize((512, 512))
239
  condition = Condition("scene", image, position_delta=(0, -32))
240
  instruction_dit = instruction
241
-
242
  if category == 'Action Change':
243
- adapter_name = "action"
244
  elif category == 'Expression Change':
245
- adapter_name = "expression"
246
  elif category == 'Add':
247
- adapter_name = "addition"
248
  elif category == 'Material Change':
249
- adapter_name = "material"
250
  elif category == 'Color Change':
251
- adapter_name = "color"
252
  elif category == 'Background Change':
253
- adapter_name = "bg"
254
  elif category == 'Appearance Change':
255
- adapter_name = "appearance"
256
-
257
  elif task == 'RoI Compositioning':
258
- adapter_name = "fusion"
259
  condition = Condition("scene", image, position_delta=(0, 0))
260
  instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"
261
 
262
  elif task == 'Global Transformation':
263
  image = Image.open(image).convert('RGB').resize((512, 512))
264
  instruction_dit = instruction
265
- adapter_name = "overall"
 
266
  condition = Condition("scene", image, position_delta=(0, -32))
267
  else:
268
  raise ValueError(f"Invalid task: '{task}'")
269
 
270
- # Switch to the specified adapter
271
- print(f"Switching to adapter: {adapter_name}")
272
- pipe.set_adapters(adapter_name)
 
 
 
273
 
274
  result_img = generate(
275
  pipe,
@@ -646,4 +592,7 @@ def layout_change(bbox, instruction):
646
  result = response.choices[0].message.content.strip()
647
 
648
  bbox = extract_last_bbox(result)
649
- return bbox
 
 
 
 
19
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
21
 
22
+
23
  from src.flux.generate import generate, seed_everything
24
 
25
  try:
 
30
 
31
  import re
32
 
 
33
  pipe = None
34
  model_dict = {}
 
 
35
 
36
  def init_flux_pipeline():
37
+ global pipe
 
 
38
  if pipe is None:
 
39
  token = os.getenv("HF_TOKEN")
40
  if not token:
41
  raise ValueError("HF_TOKEN environment variable not set.")
 
42
  pipe = FluxPipeline.from_pretrained(
43
  "black-forest-labs/FLUX.1-schnell",
44
  use_auth_token=token,
45
  torch_dtype=torch.bfloat16
46
  )
47
  pipe = pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def get_model(model_path):
50
  global model_dict
 
165
 
166
  @spaces.GPU
167
  def infer_with_DiT(task, image, instruction, category):
168
+ init_flux_pipeline()
169
+
 
170
  if task == 'RoI Inpainting':
171
  if category == 'Add' or category == 'Replace':
172
+ lora_path = "weights/add.safetensors"
173
  added = extract_object_with_gpt(instruction)
174
  instruction_dit = f"add {added} on the black region"
175
  elif category == 'Remove' or category == 'Action Change':
176
+ lora_path = "weights/remove.safetensors"
177
  instruction_dit = f"Fill the hole of the image"
178
+
179
  condition = Condition("scene", image, position_delta=(0, 0))
 
180
  elif task == 'RoI Editing':
181
  image = Image.open(image).convert('RGB').resize((512, 512))
182
  condition = Condition("scene", image, position_delta=(0, -32))
183
  instruction_dit = instruction
 
184
  if category == 'Action Change':
185
+ lora_path = "weights/action.safetensors"
186
  elif category == 'Expression Change':
187
+ lora_path = "weights/expression.safetensors"
188
  elif category == 'Add':
189
+ lora_path = "weights/addition.safetensors"
190
  elif category == 'Material Change':
191
+ lora_path = "weights/material.safetensors"
192
  elif category == 'Color Change':
193
+ lora_path = "weights/color.safetensors"
194
  elif category == 'Background Change':
195
+ lora_path = "weights/bg.safetensors"
196
  elif category == 'Appearance Change':
197
+ lora_path = "weights/appearance.safetensors"
198
+
199
  elif task == 'RoI Compositioning':
200
+ lora_path = "weights/fusion.safetensors"
201
  condition = Condition("scene", image, position_delta=(0, 0))
202
  instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"
203
 
204
  elif task == 'Global Transformation':
205
  image = Image.open(image).convert('RGB').resize((512, 512))
206
  instruction_dit = instruction
207
+ lora_path = "weights/overall.safetensors"
208
+
209
  condition = Condition("scene", image, position_delta=(0, -32))
210
  else:
211
  raise ValueError(f"Invalid task: '{task}'")
212
 
213
+ pipe.unload_lora_weights()
214
+ pipe.load_lora_weights(
215
+ "Cicici1109/IEAP",
216
+ weight_name=lora_path,
217
+ adapter_name="scene",
218
+ )
219
 
220
  result_img = generate(
221
  pipe,
 
592
  result = response.choices[0].message.content.strip()
593
 
594
  bbox = extract_last_bbox(result)
595
+ return bbox
596
+
597
+ if __name__ == "__main__":
598
+ init_flux_pipeline()