|
import gradio as gr
import torch
import spaces
from PIL import Image, ImageDraw, ImageFont

from diffusers.pipelines import FluxPipeline
import numpy as np
import requests
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch.multiprocessing as mp

import argparse
import logging
import math
import os
import re
import random
import shutil
from contextlib import nullcontext
from pathlib import Path

import accelerate
import datasets
import torch.nn.functional as F
from torch import Tensor, nn
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from transformers.utils import ContextManagers
from omegaconf import OmegaConf
from copy import deepcopy
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
from diffusers.utils import check_min_version, deprecate, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from einops import rearrange
from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from src.flux.util import (configs, load_ae, load_clip,
                           load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint)
from src.flux.modules.layers import (DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor,
                                     IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel)
from src.flux.xflux_pipeline import XFluxSampler

from image_datasets.dataset import loader, eval_image_pair_loader, image_resize

import json

def get_models(name: str, device, offload: bool, is_schnell: bool):
    # Load the T5 and CLIP text encoders, the FLUX flow-matching transformer, and the autoencoder.
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    clip.requires_grad_(False)
    model = load_flow_model2(name, device="cpu")
    vae = load_ae(name, device="cpu" if offload else device)
    return model, vae, t5, clip
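
# The configuration, base models, and sampler below are constructed once at import time
# and shared across all Gradio requests.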
|
|
|
args = OmegaConf.load("inference_configs/inference.yaml")
is_schnell = args.model_name == "flux-schnell"
set_seed(args.seed)

device = "cuda"
dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)

if args.use_ip:
    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True,
                           spatial_condition=False, clip_image_processor=ip_clip_image_processor,
                           image_encoder=ip_image_encoder, improj=ip_improj)
elif args.use_spatial_condition:
    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False,
                           spatial_condition=True, clip_image_processor=None, image_encoder=None,
                           improj=None, share_position_embedding=args.share_position_embedding)
else:
    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False,
                           spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
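# NOTE: the use_ip branch above references ip_clip_image_processor, ip_image_encoder, and
# ip_improj, which are not defined in this file; keep args.use_ip disabled unless the
# IP-Adapter components are loaded beforehand.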
|
|
|
|
|
|
|
def generate(image, edit_prompt):
    # Install LoRA attention processors on the selected double/single-stream blocks.
    # (ip_attn_procs is created but never installed on the model in this file.)
    if args.use_lora:
        lora_attn_procs = {}
    if args.use_ip:
        ip_attn_procs = {}
    if args.double_blocks is None:
        double_blocks_idx = list(range(19))
    else:
        double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")]

    if args.single_blocks is None:
        single_blocks_idx = list(range(38))
    else:
        single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")]

    if args.use_lora:
        for name, attn_processor in dit.attn_processors.items():
            match = re.search(r'\.(\d+)\.', name)
            if match:
                layer_index = int(match.group(1))

            if name.startswith("double_blocks") and layer_index in double_blocks_idx:
                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
                    dim=3072, rank=args.rank
                )
            elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
                    dim=3072, rank=args.rank
                )
            else:
                lora_attn_procs[name] = attn_processor

        dit.set_attn_processor(lora_attn_procs)

    # Keep the frozen components from tracking gradients.
    vae.requires_grad_(False)
    t5.requires_grad_(False)
    clip.requires_grad_(False)

    # Download the ByteMorpher DiT checkpoint and load it into the transformer.
    model_path = hf_hub_download(
        repo_id="Boese0601/ByteMorpher",
        filename="dit.safetensors"
    )
    state_dict = load_file(model_path)
    dit.load_state_dict(state_dict)
    weight_dtype = torch.bfloat16  # weight_dtype is not set elsewhere in this file; bfloat16 is assumed
    dit = dit.to(weight_dtype)
    dit.eval()
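    # Note: hf_hub_download caches the checkpoint on disk, but the state dict is still
    # re-loaded into the model on every call to generate(); hoisting this block to module
    # scope would avoid that per-request cost.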
|
|
|
|
|
    test_dataloader = eval_image_pair_loader(**args.data_config)
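    # Note: test_dataloader is not used anywhere below in this demo.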
|
|
|
|
|
|
|
|
|
    # accelerator is not created elsewhere in this file; a default Accelerator instance is assumed here.
    accelerator = Accelerator()
    dit = accelerator.prepare(dit)

    # Resize the condition image toward 512 resolution, snap both dimensions to multiples of 32,
    # and normalize pixel values to [-1, 1] in channel-first (NCHW) layout.
    img = image_resize(image, 512)
    w, h = img.size
    new_w = (w // 32) * 32
    new_h = (h // 32) * 32
    img = img.resize((new_w, new_h))
    img = torch.from_numpy((np.array(img) / 127.5) - 1)
    img = img.permute(2, 0, 1).unsqueeze(0)
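    # Note: the generation resolution comes from args.sample_width / args.sample_height in the
    # config, not from the resized condition image above.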
|
|
|
    # Run the sampler under no_grad; the result is returned directly as the Gradio output image.
    with torch.no_grad():
        result = sampler(prompt=edit_prompt,
                         width=args.sample_width,
                         height=args.sample_height,
                         num_steps=args.sample_steps,
                         image_prompt=None,
                         true_gs=args.cfg_scale,
                         seed=args.seed,
                         ip_scale=args.ip_scale if args.use_ip else 1.0,
                         source_image=img if args.use_spatial_condition else None,
                         )
    gen_img = result

    return gen_img

def get_samples():
    sample_list = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
    ]
    return [
        [
            Image.open(sample["image"]).resize((512, 512)),
            sample["edit_prompt"],
        ]
        for sample in sample_list
    ]
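
# The example above assumes assets/0_camera_zoom/20486354.png is present in the repository.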
|
|
|
|
|
header = """
# ByteMorph

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href=""><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github" alt="GitHub"></a>
</div>
"""
|
|
|
|
|
def create_app():
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
                submit_btn = gr.Button("Run", elem_id="submit_btn")

            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        with gr.Row():
            examples = gr.Examples(
                examples=get_samples(),
                inputs=[original_image, edit_prompt],
                label="Examples",
            )

        submit_btn.click(
            fn=generate,
            inputs=[original_image, edit_prompt],
            outputs=output_image,
        )
        gr.HTML(
            """
            <div style="text-align: center;">
            * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
            </div>
            """
        )
    return app
|
|
|
|
|
if __name__ == "__main__":
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA version:", torch.version.cuda)
    print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")

    create_app().launch(debug=False, share=True, ssr_mode=False)
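    # share=True exposes a temporary public gradio.live URL when the script is run locally;
    # it is not needed when the app is hosted on Hugging Face Spaces.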
|
|