# Author: artificialguybr
# Last commit: 1ec1f0d (verified) - "Update app.py"
import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tqdm import tqdm
import gc
from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
# Registry of selectable adapters; insertion order is the dropdown order.
# Every entry carries the same fields:
#   repo_id / filename -> Hub location of the LoRA weights (None for the base model)
#   type               -> "style" feeds the style reference image to the pipeline,
#                         "edit" feeds the input image
#   method             -> "none", "standard" (pipeline LoRA loader + fuse) or
#                         "manual_fuse" (weights merged by fuse_lora_manual)
#   prompt_template    -> str.format template; the user's text fills {prompt}
#   description        -> Markdown shown under the dropdown
LORA_CONFIG = {
    "None": dict(
        repo_id=None,
        filename=None,
        type="edit",
        method="none",
        prompt_template="{prompt}",
        description="Use the base Qwen-Image-Edit model without any LoRA.",
    ),
    "InStyle (Style Transfer)": dict(
        repo_id="peteromallet/Qwen-Image-Edit-InStyle",
        filename="InStyle-0.5.safetensors",
        type="style",
        method="manual_fuse",
        prompt_template="Make an image in this style of {prompt}",
        description="Transfers the style from a reference image to a new image described by the prompt.",
    ),
    "InScene (In-Scene Editing)": dict(
        repo_id="flymy-ai/qwen-image-edit-inscene-lora",
        filename="flymy_qwen_image_edit_inscene_lora.safetensors",
        type="edit",
        method="standard",
        prompt_template="{prompt}",
        description="Improves in-scene editing, object positioning, and camera perspective changes.",
    ),
    "Face Segmentation": dict(
        repo_id="TsienDragon/qwen-image-edit-lora-face-segmentation",
        filename="pytorch_lora_weights.safetensors",
        type="edit",
        method="standard",
        # Fixed prompt: no {prompt} slot, so user text is not used.
        prompt_template="change the face to face segmentation mask",
        description="Transforms a facial image into a precise segmentation mask.",
    ),
    "Object Remover": dict(
        repo_id="valiantcat/Qwen-Image-Edit-Remover-General-LoRA",
        filename="qwen-edit-remover.safetensors",
        type="edit",
        method="standard",
        prompt_template="Remove {prompt}",
        description="Removes objects from an image while maintaining background consistency.",
    ),
}
# --- Model initialisation (runs once at import time) ---
print("Initializing model...")
# All weights are loaded in bfloat16; fall back to CPU when CUDA is unavailable.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",
    torch_dtype=dtype
).to(device)
# Rebind the loaded transformer to the local QwenImageTransformer2DModel class,
# then install the FA3 attention processor on it.
# NOTE(review): the __class__ swap assumes the local class is attribute/layout
# compatible with the class from_pretrained instantiated - confirm.
pipe.transformer.__class__ = QwenImageTransformer2DModel
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
# Snapshot that load_and_fuse_lora loads back before activating a LoRA.
# NOTE(review): per the PyTorch docs, state_dict() is a shallow copy whose tensors
# share storage with the live parameters, so in-place weight edits (LoRA fusion)
# also change this snapshot. Cloning it would make it a true restore point, at
# the cost of a second full copy of the transformer weights in memory.
original_transformer_state_dict = pipe.transformer.state_dict()
print("Base model loaded and ready.")
def fuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
    """Merge LoRA deltas directly into the transformer's ``Linear`` weights.

    LoRA tensors are grouped by module path. The path is the key with any
    ``diffusion_model.`` removed and the trailing ``.lora_*`` part dropped. For
    every ``torch.nn.Linear`` in ``transformer`` whose name matches a group
    holding both a down and an up matrix, ``alpha * (up @ down)`` is added to
    its weight in place. The product must have the weight's shape.

    Args:
        transformer: Module whose ``Linear`` layers receive the merged deltas.
        lora_state_dict: Mapping of LoRA tensor names to tensors. Both
            ``lora_A``/``lora_B`` and ``lora_down``/``lora_up`` naming are
            recognised. Keys without ``.lora_`` (e.g. scalar ``alpha`` entries)
            are ignored.
        alpha: Scale applied to every delta. Passing the negated scale of an
            earlier call subtracts that LoRA again, up to dtype rounding.

    Returns:
        The same ``transformer`` object, modified in place.
    """
    down_roles = {"A", "down"}
    up_roles = {"B", "up"}

    key_mapping = {}
    for key, tensor in lora_state_dict.items():
        base_key, sep, suffix = key.replace('diffusion_model.', '').rpartition('.lora_')
        if not sep:
            continue  # not a LoRA matrix; nothing to fuse
        role = suffix.split('.', 1)[0]
        if role in down_roles:
            key_mapping.setdefault(base_key, {})['down'] = tensor
        elif role in up_roles:
            key_mapping.setdefault(base_key, {})['up'] = tensor

    fused_layers = 0
    for name, module in tqdm(transformer.named_modules(), desc="Fusing layers"):
        lora_weights = key_mapping.get(name)
        if lora_weights is None or not isinstance(module, torch.nn.Linear):
            continue
        if 'down' not in lora_weights or 'up' not in lora_weights:
            continue
        weight = module.weight
        # Build the low-rank product in float32 and round to the weight dtype
        # once. This keeps the delta identical between a fuse and a later
        # un-fuse with -alpha.
        lora_down = lora_weights['down'].to(weight.device, dtype=torch.float32)
        lora_up = lora_weights['up'].to(weight.device, dtype=torch.float32)
        merged_delta = (lora_up @ lora_down) * alpha
        weight.data.add_(merged_delta.to(weight.dtype))
        fused_layers += 1

    # Previously a naming mismatch made this function a silent no-op.
    if fused_layers == 0:
        print("Warning: no LoRA weights matched any Linear layer; the model is unchanged.")
    else:
        print(f"Fused LoRA deltas into {fused_layers} Linear layers.")
    return transformer
# LoRA state dict most recently merged by the "manual_fuse" path. It is kept so
# the next call can subtract it again before activating another adapter.
_fused_manual_lora = None


def load_and_fuse_lora(lora_name):
    """Reset the transformer to its base weights, then fuse the selected LoRA.

    Operates on the module-level ``pipe`` in place and returns ``None``.

    Args:
        lora_name: Key of ``LORA_CONFIG`` naming the adapter to activate.

    Raises:
        KeyError: If ``lora_name`` is not a key of ``LORA_CONFIG``.
        ValueError: If the entry's ``method`` is not ``"none"``,
            ``"standard"`` or ``"manual_fuse"``.
    """
    global _fused_manual_lora
    config = LORA_CONFIG[lora_name]

    # unfuse_lora() reverts merged weights but does not remove the adapter
    # modules added by load_lora_weights(). Drop them so adapters do not pile
    # up and parameter names match the base snapshot for the strict reload.
    pipe.unload_lora_weights()

    # The snapshot returned by state_dict() shares storage with the live
    # parameters, so reloading it cannot undo fuse_lora_manual's in-place
    # edits. Subtract the previously fused deltas explicitly.
    if _fused_manual_lora is not None:
        print("Removing previously fused manual LoRA...")
        fuse_lora_manual(pipe.transformer, _fused_manual_lora, alpha=-1.0)
        _fused_manual_lora = None

    print("Resetting transformer to original state...")
    pipe.transformer.load_state_dict(original_transformer_state_dict)

    method = config["method"]
    if method == "none":
        print("No LoRA selected. Using base model.")
        return

    print(f"Loading LoRA: {lora_name}")
    lora_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])
    if method == "standard":
        print("Using standard loading method...")
        pipe.load_lora_weights(lora_path)
        print("Fusing LoRA into the model...")
        pipe.fuse_lora()
    elif method == "manual_fuse":
        print("Using manual fusion method...")
        lora_state_dict = load_file(lora_path)
        pipe.transformer = fuse_lora_manual(pipe.transformer, lora_state_dict)
        # Remember only after a successful fuse, so a partial failure is never
        # "undone" for layers that were not modified.
        _fused_manual_lora = lora_state_dict
    else:
        raise ValueError(f"Unknown LoRA loading method: {method!r}")

    gc.collect()
    torch.cuda.empty_cache()
    print(f"LoRA '{lora_name}' is now active.")
@spaces.GPU(duration=60)
def infer(
    lora_name,
    input_image,
    style_image,
    prompt,
    seed,
    randomize_seed,
    true_guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one Qwen-Image-Edit generation with the selected LoRA active.

    Args:
        lora_name: Key of ``LORA_CONFIG`` chosen in the dropdown.
        input_image: Image to edit; required for ``"edit"`` LoRAs.
        style_image: Style reference image; required for ``"style"`` LoRAs.
        prompt: User text substituted into the LoRA's ``prompt_template``.
            Only required when the template contains ``{prompt}``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed in ``[0, int32 max]``.
        true_guidance_scale: Passed to the pipeline as ``true_cfg_scale``.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually used.

    Raises:
        gr.Error: If no LoRA is selected, the required image is missing, or a
            prompt is required but empty.
    """
    if not lora_name:
        raise gr.Error("Please select a LoRA model.")
    config = LORA_CONFIG[lora_name]

    if config["type"] == "style":
        if style_image is None:
            raise gr.Error("Style Transfer LoRA requires a Style Reference Image.")
        image_for_pipeline = style_image
    else:  # 'edit'
        if input_image is None:
            raise gr.Error("This LoRA requires an Input Image.")
        image_for_pipeline = input_image

    template = config["prompt_template"]
    # Only templates with a {prompt} slot consume user text. Fixed-prompt LoRAs
    # (e.g. Face Segmentation) hide the textbox and ignore its value.
    if "{prompt}" in template and not (prompt or "").strip():
        raise gr.Error("A text prompt is required for this LoRA.")
    final_prompt = template.format(prompt=prompt)

    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    generator = torch.Generator(device=device).manual_seed(int(seed))

    try:
        load_and_fuse_lora(lora_name)
        print("--- Running Inference ---")
        print(f"LoRA: {lora_name}")
        print(f"Prompt: {final_prompt}")
        print(f"Seed: {seed}, Steps: {num_inference_steps}, CFG: {true_guidance_scale}")
        with torch.inference_mode():
            result_image = pipe(
                image=image_for_pipeline,
                prompt=final_prompt,
                negative_prompt=" ",
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                true_cfg_scale=true_guidance_scale,
            ).images[0]
    finally:
        # Always revert and drop adapter-based LoRAs, even if loading or
        # inference failed, so the next request starts from clean weights.
        pipe.unfuse_lora()
        pipe.unload_lora_weights()
        gc.collect()
        torch.cuda.empty_cache()
    return result_image, seed
def on_lora_change(lora_name):
    """Build Gradio updates that match the inputs to the selected LoRA.

    Args:
        lora_name: Current dropdown value, normally a ``LORA_CONFIG`` key. It
            may be empty if the dropdown is cleared.

    Returns:
        Dict mapping components to update objects. For a known LoRA:
        - the description is shown;
        - only the image box it uses (style reference vs. input) is visible;
        - the prompt box is visible only when the template has a ``{prompt}`` slot.
        For an unknown or empty value, only the description is hidden.
    """
    config = LORA_CONFIG.get(lora_name)
    if config is None:
        # Nothing valid selected: avoid a KeyError and leave the inputs as they are.
        return {lora_description: gr.Markdown(visible=False)}

    is_style_lora = config["type"] == "style"
    # Same rule infer() applies: fixed-prompt templates need no user text.
    needs_prompt = "{prompt}" in config["prompt_template"]
    return {
        lora_description: gr.Markdown(visible=True, value=f"**Description:** {config['description']}"),
        input_image_box: gr.Image(visible=not is_style_lora),
        style_image_box: gr.Image(visible=is_style_lora),
        prompt_box: gr.Textbox(visible=needs_prompt),
    }
# --- Gradio interface: component creation order defines the layout ---
with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 1024px; }") as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML('<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_edit_logo.png" alt="Qwen-Image Logo" style="width: 400px; margin: 0 auto; display: block;">')
        gr.Markdown("<h2 style='text-align: center;'>Qwen-Image-Edit Multi-LoRA Playground</h2>")
        with gr.Row():
            # Left column: LoRA selection, inputs and the run button.
            with gr.Column(scale=1):
                lora_selector = gr.Dropdown(
                    label="Select LoRA Model",
                    choices=list(LORA_CONFIG.keys()),
                    value="InStyle (Style Transfer)"
                )
                # Initial visibility matches the default "InStyle" selection,
                # whose type is "style". on_lora_change keeps it in sync afterwards.
                lora_description = gr.Markdown(visible=False)
                input_image_box = gr.Image(label="Input Image", type="pil", visible=False)
                style_image_box = gr.Image(label="Style Reference Image", type="pil", visible=True)
                prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the content or object to remove...")
                run_button = gr.Button("Generate!", variant="primary")
            # Right column: generated image and the seed that produced it.
            with gr.Column(scale=1):
                result_image = gr.Image(label="Result", type="pil")
                used_seed = gr.Number(label="Used Seed", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=42)
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            cfg_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, step=0.1, value=4.0)
            steps_slider = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=25)
    # Re-sync input visibility whenever the selection changes...
    lora_selector.change(
        fn=on_lora_change,
        inputs=lora_selector,
        outputs=[lora_description, input_image_box, style_image_box, prompt_box]
    )
    # ...and once when the page loads, so the description appears immediately.
    demo.load(
        fn=on_lora_change,
        inputs=lora_selector,
        outputs=[lora_description, input_image_box, style_image_box, prompt_box]
    )
    # Input order must match infer's positional parameters.
    run_button.click(
        fn=infer,
        inputs=[
            lora_selector,
            input_image_box, style_image_box,
            prompt_box,
            seed_slider, randomize_seed_checkbox,
            cfg_slider, steps_slider
        ],
        outputs=[result_image, used_seed]
    )
if __name__ == "__main__":
    demo.launch()