# Qwen-Image-Edit Multi-LoRA Playground (Hugging Face Space, running on ZeroGPU).
import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tqdm import tqdm
import gc

from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
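
# Registry of the available LoRAs. Each entry describes where to fetch the
# weights (Hub repo_id + filename), which image input the UI should expose
# ("style" shows the style-reference image, "edit" shows the input image),
# how the weights are applied ("standard" uses diffusers' load_lora_weights +
# fuse_lora; "manual_fuse" merges the low-rank delta into the Linear weights
# by hand), and a prompt_template that wraps the user's text.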
LORA_CONFIG = {
    "None": {
        "repo_id": None,
        "filename": None,
        "type": "edit",
        "method": "none",
        "prompt_template": "{prompt}",
        "description": "Use the base Qwen-Image-Edit model without any LoRA.",
    },
    "InStyle (Style Transfer)": {
        "repo_id": "peteromallet/Qwen-Image-Edit-InStyle",
        "filename": "InStyle-0.5.safetensors",
        "type": "style",
        "method": "manual_fuse",
        "prompt_template": "Make an image in this style of {prompt}",
        "description": "Transfers the style from a reference image to a new image described by the prompt.",
    },
    "InScene (In-Scene Editing)": {
        "repo_id": "flymy-ai/qwen-image-edit-inscene-lora",
        "filename": "flymy_qwen_image_edit_inscene_lora.safetensors",
        "type": "edit",
        "method": "standard",
        "prompt_template": "{prompt}",
        "description": "Improves in-scene editing, object positioning, and camera perspective changes.",
    },
    "Face Segmentation": {
        "repo_id": "TsienDragon/qwen-image-edit-lora-face-segmentation",
        "filename": "pytorch_lora_weights.safetensors",
        "type": "edit",
        "method": "standard",
        "prompt_template": "change the face to face segmentation mask",
        "description": "Transforms a facial image into a precise segmentation mask.",
    },
    "Object Remover": {
        "repo_id": "valiantcat/Qwen-Image-Edit-Remover-General-LoRA",
        "filename": "qwen-edit-remover.safetensors",
        "type": "edit",
        "method": "standard",
        "prompt_template": "Remove {prompt}",
        "description": "Removes objects from an image while maintaining background consistency.",
    },
}
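
# To register another LoRA, append an entry with the same schema. The repo id
# and filename below are placeholders for illustration, not a real checkpoint:
#
#     "My Sketch LoRA": {
#         "repo_id": "your-username/your-qwen-edit-lora",  # hypothetical
#         "filename": "your_lora.safetensors",             # hypothetical
#         "type": "edit",
#         "method": "standard",
#         "prompt_template": "{prompt}",
#         "description": "What this LoRA does.",
#     },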
print("Initializing model...") | |
dtype = torch.bfloat16 | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
pipe = QwenImageEditPipeline.from_pretrained( | |
"Qwen/Qwen-Image-Edit", | |
torch_dtype=dtype | |
).to(device) | |
pipe.transformer.__class__ = QwenImageTransformer2DModel | |
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3()) | |
original_transformer_state_dict = pipe.transformer.state_dict() | |
print("Base model loaded and ready.") | |
def fuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
    # Group the checkpoint tensors by base module name, so each module ends up
    # with a {'down': lora_A, 'up': lora_B} pair.
    key_mapping = {}
    for key in lora_state_dict.keys():
        base_key = key.replace('diffusion_model.', '').rsplit('.lora_', 1)[0]
        if base_key not in key_mapping:
            key_mapping[base_key] = {}
        if 'lora_A' in key:
            key_mapping[base_key]['down'] = lora_state_dict[key]
        elif 'lora_B' in key:
            key_mapping[base_key]['up'] = lora_state_dict[key]

    # Merge the low-rank delta into every matching Linear layer in place.
    for name, module in tqdm(transformer.named_modules(), desc="Fusing layers"):
        if name in key_mapping and isinstance(module, torch.nn.Linear):
            lora_weights = key_mapping[name]
            if 'down' in lora_weights and 'up' in lora_weights:
                device = module.weight.device
                dtype = module.weight.dtype
                lora_down = lora_weights['down'].to(device, dtype=dtype)
                lora_up = lora_weights['up'].to(device, dtype=dtype)
                merged_delta = lora_up @ lora_down
                module.weight.data += alpha * merged_delta
    return transformer
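
# The inverse operation, shown here as a sketch only: the app never calls it,
# since load_and_fuse_lora() instead restores the saved state dict. Subtracting
# the same delta with the same alpha undoes the merge (up to floating-point
# error).
def unfuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
    """Hypothetical helper: reverse fuse_lora_manual by negating the delta."""
    return fuse_lora_manual(transformer, lora_state_dict, alpha=-alpha)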
def load_and_fuse_lora(lora_name):
    """Load a LoRA and fuse it into the global pipeline (modified in place)."""
    config = LORA_CONFIG[lora_name]

    # Always start from clean base weights so LoRAs never stack across calls.
    print("Resetting transformer to original state...")
    pipe.transformer.load_state_dict(original_transformer_state_dict)

    if config["method"] == "none":
        print("No LoRA selected. Using base model.")
        return

    print(f"Loading LoRA: {lora_name}")
    lora_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])

    if config["method"] == "standard":
        print("Using standard loading method...")
        pipe.load_lora_weights(lora_path)
        print("Fusing LoRA into the model...")
        pipe.fuse_lora()
    elif config["method"] == "manual_fuse":
        print("Using manual fusion method...")
        lora_state_dict = load_file(lora_path)
        pipe.transformer = fuse_lora_manual(pipe.transformer, lora_state_dict)

    gc.collect()
    torch.cuda.empty_cache()
    print(f"LoRA '{lora_name}' is now active.")
@spaces.GPU
def infer(
    lora_name,
    input_image,
    style_image,
    prompt,
    seed,
    randomize_seed,
    true_guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if not lora_name:
        raise gr.Error("Please select a LoRA model.")

    config = LORA_CONFIG[lora_name]

    # Pick the image the selected LoRA expects: style LoRAs condition on a
    # style reference, edit LoRAs on the image being edited.
    if config["type"] == "style":
        if style_image is None:
            raise gr.Error("Style Transfer LoRA requires a Style Reference Image.")
        image_for_pipeline = style_image
    else:  # 'edit'
        if input_image is None:
            raise gr.Error("This LoRA requires an Input Image.")
        image_for_pipeline = input_image

    # Face Segmentation uses a fixed prompt, so user text is optional there.
    if not prompt and config["prompt_template"] != "change the face to face segmentation mask":
        raise gr.Error("A text prompt is required for this LoRA.")

    load_and_fuse_lora(lora_name)

    final_prompt = config["prompt_template"].format(prompt=prompt)

    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    generator = torch.Generator(device=device).manual_seed(int(seed))

    print("--- Running Inference ---")
    print(f"LoRA: {lora_name}")
    print(f"Prompt: {final_prompt}")
    print(f"Seed: {seed}, Steps: {num_inference_steps}, CFG: {true_guidance_scale}")

    with torch.inference_mode():
        result_image = pipe(
            image=image_for_pipeline,
            prompt=final_prompt,
            negative_prompt=" ",
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            true_cfg_scale=true_guidance_scale,
        ).images[0]

    # Undo a standard fuse; manually fused weights are restored on the next
    # load_and_fuse_lora() call via the saved state dict.
    pipe.unfuse_lora()
    gc.collect()
    torch.cuda.empty_cache()

    return result_image, seed
def on_lora_change(lora_name):
    """Toggle which inputs are visible for the selected LoRA."""
    config = LORA_CONFIG[lora_name]
    is_style_lora = config["type"] == "style"
    return {
        lora_description: gr.Markdown(visible=True, value=f"**Description:** {config['description']}"),
        input_image_box: gr.Image(visible=not is_style_lora),
        style_image_box: gr.Image(visible=is_style_lora),
        prompt_box: gr.Textbox(visible=(config["prompt_template"] != "change the face to face segmentation mask")),
    }
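
# Gradio event handlers may return a dict keyed by output component; each
# value (here a freshly constructed component) overrides that component's
# props, which lets one handler drive several visibility toggles at once.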
with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 1024px; }") as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML('<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_edit_logo.png" alt="Qwen-Image Logo" style="width: 400px; margin: 0 auto; display: block;">')
        gr.Markdown("<h2 style='text-align: center;'>Qwen-Image-Edit Multi-LoRA Playground</h2>")

        with gr.Row():
            with gr.Column(scale=1):
                lora_selector = gr.Dropdown(
                    label="Select LoRA Model",
                    choices=list(LORA_CONFIG.keys()),
                    value="InStyle (Style Transfer)",
                )
                lora_description = gr.Markdown(visible=False)
                input_image_box = gr.Image(label="Input Image", type="pil", visible=False)
                style_image_box = gr.Image(label="Style Reference Image", type="pil", visible=True)
                prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the content or object to remove...")
                run_button = gr.Button("Generate!", variant="primary")

            with gr.Column(scale=1):
                result_image = gr.Image(label="Result", type="pil")
                used_seed = gr.Number(label="Used Seed", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=42)
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            cfg_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, step=0.1, value=4.0)
            steps_slider = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=25)

    lora_selector.change(
        fn=on_lora_change,
        inputs=lora_selector,
        outputs=[lora_description, input_image_box, style_image_box, prompt_box],
    )
    demo.load(
        fn=on_lora_change,
        inputs=lora_selector,
        outputs=[lora_description, input_image_box, style_image_box, prompt_box],
    )
    run_button.click(
        fn=infer,
        inputs=[
            lora_selector,
            input_image_box, style_image_box,
            prompt_box,
            seed_slider, randomize_seed_checkbox,
            cfg_slider, steps_slider,
        ],
        outputs=[result_image, used_seed],
    )
if __name__ == "__main__":
    demo.launch()