Update utils.py
utils.py CHANGED
@@ -19,6 +19,7 @@ import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type

+
 from src.flux.generate import generate, seed_everything

 try:
@@ -29,78 +30,21 @@ except ImportError:

 import re

-# Global variables
 pipe = None
 model_dict = {}
-_MODEL_INITIALIZED = False
-_ADAPTERS_LOADED = False

 def init_flux_pipeline():
-
-    global pipe, _MODEL_INITIALIZED
-
+    global pipe
     if pipe is None:
-        print("Initializing Flux pipeline...")
         token = os.getenv("HF_TOKEN")
         if not token:
             raise ValueError("HF_TOKEN environment variable not set.")
-
         pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-schnell",
             use_auth_token=token,
             torch_dtype=torch.bfloat16
         )
         pipe = pipe.to("cuda")
-        _MODEL_INITIALIZED = True
-        print("Flux pipeline initialized successfully.")
-
-    return pipe
-
-def load_all_lora_adapters():
-    """Load all LoRA adapters, ensuring it runs only once"""
-    global pipe, _ADAPTERS_LOADED
-
-    # Ensure model is initialized
-    init_flux_pipeline()
-
-    if not _ADAPTERS_LOADED:
-        print("Loading all LoRA adapters...")
-
-        LORA_ADAPTERS = {
-            "add": "weights/add.safetensors",
-            "remove": "weights/remove.safetensors",
-            "action": "weights/action.safetensors",
-            "expression": "weights/expression.safetensors",
-            "addition": "weights/addition.safetensors",
-            "material": "weights/material.safetensors",
-            "color": "weights/color.safetensors",
-            "bg": "weights/bg.safetensors",
-            "appearance": "weights/appearance.safetensors",
-            "fusion": "weights/fusion.safetensors",
-            "overall": "weights/overall.safetensors",
-        }
-
-        for adapter_name, weight_path in LORA_ADAPTERS.items():
-            try:
-                pipe.load_lora_weights(
-                    "Cicici1109/IEAP",
-                    weight_name=weight_path,
-                    adapter_name=adapter_name,
-                )
-                print(f"✅ Successfully loaded adapter: {adapter_name}")
-            except Exception as e:
-                print(f"❌ Failed to load adapter {adapter_name}: {e}")
-
-        loaded_adapters = list(pipe.lora_adapters.keys())
-        print(f"Loaded adapters: {loaded_adapters}")
-
-        if loaded_adapters:
-            pipe.set_adapters(loaded_adapters[0])
-            print(f"Default adapter set to: {loaded_adapters[0]}")
-
-        _ADAPTERS_LOADED = True
-
-    return pipe

 def get_model(model_path):
     global model_dict
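For reference, the initializer that remains after this hunk is a plain lazy singleton: the pipeline is built on the first call and every later call is a no-op. Below is a minimal standalone sketch of that pattern, assuming the diffusers FluxPipeline API and an HF_TOKEN secret; newer diffusers releases accept token= in place of the deprecated use_auth_token= used in the commit.

import os
import torch
from diffusers import FluxPipeline

pipe = None  # module-level singleton shared by every caller

def init_flux_pipeline():
    # Build FLUX.1-schnell once; subsequent calls reuse the cached pipeline.
    global pipe
    if pipe is None:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN environment variable not set.")
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            token=token,  # the commit uses use_auth_token=, which still works but is deprecated
            torch_dtype=torch.bfloat16,
        ).to("cuda")

Dropping the _MODEL_INITIALIZED flag loses nothing here, since pipe is None already encodes the same state.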
@@ -221,55 +165,57 @@ def extract_last_bbox(result):

 @spaces.GPU
 def infer_with_DiT(task, image, instruction, category):
-
-
-
+    init_flux_pipeline()
+
     if task == 'RoI Inpainting':
         if category == 'Add' or category == 'Replace':
-
+            lora_path = "weights/add.safetensors"
             added = extract_object_with_gpt(instruction)
             instruction_dit = f"add {added} on the black region"
         elif category == 'Remove' or category == 'Action Change':
-
+            lora_path = "weights/remove.safetensors"
             instruction_dit = f"Fill the hole of the image"
+
         condition = Condition("scene", image, position_delta=(0, 0))
-
     elif task == 'RoI Editing':
         image = Image.open(image).convert('RGB').resize((512, 512))
         condition = Condition("scene", image, position_delta=(0, -32))
         instruction_dit = instruction
-
         if category == 'Action Change':
-
+            lora_path = "weights/action.safetensors"
         elif category == 'Expression Change':
-
+            lora_path = "weights/expression.safetensors"
         elif category == 'Add':
-
+            lora_path = "weights/addition.safetensors"
         elif category == 'Material Change':
-
+            lora_path = "weights/material.safetensors"
         elif category == 'Color Change':
-
+            lora_path = "weights/color.safetensors"
         elif category == 'Background Change':
-
+            lora_path = "weights/bg.safetensors"
         elif category == 'Appearance Change':
-
-
+            lora_path = "weights/appearance.safetensors"
+
     elif task == 'RoI Compositioning':
-
+        lora_path = "weights/fusion.safetensors"
         condition = Condition("scene", image, position_delta=(0, 0))
         instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"

     elif task == 'Global Transformation':
         image = Image.open(image).convert('RGB').resize((512, 512))
         instruction_dit = instruction
-
+        lora_path = "weights/overall.safetensors"
+
         condition = Condition("scene", image, position_delta=(0, -32))
     else:
         raise ValueError(f"Invalid task: '{task}'")

-
-
-
+    pipe.unload_lora_weights()
+    pipe.load_lora_weights(
+        "Cicici1109/IEAP",
+        weight_name=lora_path,
+        adapter_name="scene",
+    )

     result_img = generate(
         pipe,
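The per-call adapter handling above replaces the removed load_all_lora_adapters(): instead of attaching all eleven adapters up front and switching with set_adapters, each request now maps its task and category to a single .safetensors file and hot-swaps it under the fixed adapter name "scene". A hedged sketch of that swap as a standalone helper follows (swap_lora is a hypothetical name, not part of utils.py; it assumes the diffusers LoRA loader API used in the hunk).

def swap_lora(pipe, weight_file):
    # Drop whatever adapter is currently attached, then load exactly one
    # weight file from the Cicici1109/IEAP repo under a fixed adapter name.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(
        "Cicici1109/IEAP",
        weight_name=weight_file,  # e.g. "weights/action.safetensors"
        adapter_name="scene",
    )

# Example: prepare the pipeline for an action-change edit.
# swap_lora(pipe, "weights/action.safetensors")

The trade-off is memory for latency: only one adapter is resident at a time, but its weights are reloaded on every call to infer_with_DiT.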
@@ -646,4 +592,7 @@ def layout_change(bbox, instruction):
     result = response.choices[0].message.content.strip()

     bbox = extract_last_bbox(result)
-    return bbox
+    return bbox
+
+if __name__ == "__main__":
+    init_flux_pipeline()