Spaces:
Sleeping
Sleeping
main
Browse files- gradio_app.py +22 -17
gradio_app.py
CHANGED
|
@@ -31,25 +31,29 @@ model_paths = {
|
|
| 31 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 32 |
'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp8_e4m3fn.safetensors",
|
| 33 |
'LORA_REPO': "showlab/makeanything",
|
| 34 |
-
'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors"
|
|
|
|
| 35 |
},
|
| 36 |
'LEGO': {
|
| 37 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 38 |
'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp8_e4m3fn.safetensors",
|
| 39 |
'LORA_REPO': "showlab/makeanything",
|
| 40 |
-
'LORA_FILE': "recraft/recraft_9f_lego.safetensors"
|
|
|
|
| 41 |
},
|
| 42 |
'Sketch': {
|
| 43 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 44 |
'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp8_e4m3fn.safetensors",
|
| 45 |
'LORA_REPO': "showlab/makeanything",
|
| 46 |
-
'LORA_FILE': "recraft/recraft_9f_sketch.safetensors"
|
|
|
|
| 47 |
},
|
| 48 |
'Portrait': {
|
| 49 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 50 |
'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp8_e4m3fn.safetensors",
|
| 51 |
'LORA_REPO': "showlab/makeanything",
|
| 52 |
-
'LORA_FILE': "recraft/recraft_9f_portrait.safetensors"
|
|
|
|
| 53 |
}
|
| 54 |
}
|
| 55 |
|
|
@@ -92,14 +96,15 @@ def load_target_model(selected_model):
|
|
| 92 |
|
| 93 |
logger.info("Loading models...")
|
| 94 |
try:
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
| 103 |
|
| 104 |
# Load LoRA weights
|
| 105 |
multiplier = 1.0
|
|
@@ -148,12 +153,15 @@ class ResizeWithPadding:
|
|
| 148 |
|
| 149 |
# The function to generate image from a prompt and conditional image
|
| 150 |
@spaces.GPU(duration=180)
|
| 151 |
-
def infer(prompt, sample_image,
|
| 152 |
global model, clip_l, t5xxl, ae, lora_model
|
| 153 |
if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
| 154 |
logger.error("Models not loaded. Please load the models first.")
|
| 155 |
return None
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
logger.info(f"Started generating image with prompt: {prompt}")
|
| 158 |
|
| 159 |
lora_model.to("cuda")
|
|
@@ -288,9 +296,6 @@ with gr.Blocks() as demo:
|
|
| 288 |
# File upload for image
|
| 289 |
sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
|
| 290 |
|
| 291 |
-
# Frame number selection
|
| 292 |
-
frame_num = gr.Radio([4, 9], label="Select Frame Number", value=9)
|
| 293 |
-
|
| 294 |
# Seed
|
| 295 |
seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
|
| 296 |
|
|
@@ -310,7 +315,7 @@ with gr.Blocks() as demo:
|
|
| 310 |
load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[status_box])
|
| 311 |
|
| 312 |
# Run Button
|
| 313 |
-
run_button.click(fn=infer, inputs=[prompt, sample_image,
|
| 314 |
|
| 315 |
# Launch the Gradio app
|
| 316 |
demo.launch()
|
|
|
|
| 31 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 32 |
'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp8_e4m3fn.safetensors",
|
| 33 |
'LORA_REPO': "showlab/makeanything",
|
| 34 |
+
'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
|
| 35 |
+
"Frame": 4
|
| 36 |
},
|
| 37 |
'LEGO': {
|
| 38 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 39 |
'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp8_e4m3fn.safetensors",
|
| 40 |
'LORA_REPO': "showlab/makeanything",
|
| 41 |
+
'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
|
| 42 |
+
"Frame": 9
|
| 43 |
},
|
| 44 |
'Sketch': {
|
| 45 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 46 |
'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp8_e4m3fn.safetensors",
|
| 47 |
'LORA_REPO': "showlab/makeanything",
|
| 48 |
+
'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
|
| 49 |
+
"Frame": 9
|
| 50 |
},
|
| 51 |
'Portrait': {
|
| 52 |
'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
| 53 |
'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp8_e4m3fn.safetensors",
|
| 54 |
'LORA_REPO': "showlab/makeanything",
|
| 55 |
+
'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
|
| 56 |
+
"Frame": 9
|
| 57 |
}
|
| 58 |
}
|
| 59 |
|
|
|
|
| 96 |
|
| 97 |
logger.info("Loading models...")
|
| 98 |
try:
|
| 99 |
+
if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
| 100 |
+
_, model = flux_utils.load_flow_model(
|
| 101 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
| 102 |
+
)
|
| 103 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
| 104 |
+
clip_l.eval()
|
| 105 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
| 106 |
+
t5xxl.eval()
|
| 107 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
| 108 |
|
| 109 |
# Load LoRA weights
|
| 110 |
multiplier = 1.0
|
|
|
|
| 153 |
|
| 154 |
# The function to generate image from a prompt and conditional image
|
| 155 |
@spaces.GPU(duration=180)
|
| 156 |
+
def infer(prompt, sample_image, recraft_model, seed=0):
|
| 157 |
global model, clip_l, t5xxl, ae, lora_model
|
| 158 |
if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
| 159 |
logger.error("Models not loaded. Please load the models first.")
|
| 160 |
return None
|
| 161 |
|
| 162 |
+
model_path = model_paths[selected_model]
|
| 163 |
+
frame_num = model_path['Frame']
|
| 164 |
+
|
| 165 |
logger.info(f"Started generating image with prompt: {prompt}")
|
| 166 |
|
| 167 |
lora_model.to("cuda")
|
|
|
|
| 296 |
# File upload for image
|
| 297 |
sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
|
| 298 |
|
|
|
|
|
|
|
|
|
|
| 299 |
# Seed
|
| 300 |
seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
|
| 301 |
|
|
|
|
| 315 |
load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[status_box])
|
| 316 |
|
| 317 |
# Run Button
|
| 318 |
+
run_button.click(fn=infer, inputs=[prompt, sample_image, recraft_model, seed], outputs=[result_image])
|
| 319 |
|
| 320 |
# Launch the Gradio app
|
| 321 |
demo.launch()
|