"""Gradio demo comparing SDXL generations with and without TKG chroma-key noise."""

import spaces  # import first

import random
from collections.abc import Sequence

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr

from tkg import apply_tkg_noise, ColorSet, COLOR_SET_MAP

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = "cuda"
model_repo_id = "cagliostrolab/animagine-xl-4.0"  # Replace with the model you would like to use

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_repo_id,  # use the variable so editing it above actually changes the model
    torch_dtype=torch.bfloat16,
    custom_pipeline="lpw_stable_diffusion_xl",
    add_watermarker=False,
)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Immutable default for infer's tkg_channels (avoids a shared mutable default list).
DEFAULT_TKG_CHANNELS = (0, 1, 1, 0)


@spaces.GPU
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    tkg_channels: Sequence[int] = DEFAULT_TKG_CHANNELS,
    chroma_key_shift: float = 0.11,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with TKG-shifted initial latents and one without.

    Both images start from the same random latents (drawn from a generator seeded
    with ``seed``). The first is built from those latents after
    ``apply_tkg_noise``; the second uses them untouched. Both are denoised in a
    single batched pipeline call.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is set.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels; the latent width is ``width // 8``.
        height: Output height in pixels; the latent height is ``height // 8``.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        tkg_channels: Channel selection passed to ``apply_tkg_noise`` as a list.
        chroma_key_shift: ``shift`` argument passed to ``apply_tkg_noise``.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        Tuple ``(image_with_tkg, image_without_tkg, seed_used)``.
    """
    # NOTE(review): Gradio only wires progress tracking for the event handler's own
    # gr.Progress parameter; confirm whether the bar shows up when infer is called
    # through on_generate on ZeroGPU.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Slider values may arrive as floats depending on the Gradio version.
    # Tensor shapes and Generator.manual_seed require ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device=device).manual_seed(seed)

    latents = torch.randn(
        (
            1,
            4,  # 4 channels
            height // 8,
            width // 8,
        ),
        generator=generator,
        device=device,
        dtype=torch.bfloat16,
    )
    tkg_latents = apply_tkg_noise(
        latents,
        shift=chroma_key_shift,
        delta_shift=0.1,
        std_dev=0.5,
        factor=8,
        # Fresh list each call so the helper never sees a shared default object.
        channels=list(tkg_channels),
    ).to(torch.bfloat16)

    # Batch index 0 = TKG latents, index 1 = plain latents; matches the unpacking below.
    latents = torch.cat([tkg_latents, latents], dim=0)

    images = pipe(
        latents=latents,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        num_images_per_prompt=2,  # one image per latent in the batch of two
        generator=generator,
    ).images
    w_tkg, wo_tkg = images

    return w_tkg, wo_tkg, seed


def color_name_to_channels(color_name: str) -> list[int]:
    """Return the ``channels`` entry of ``COLOR_SET_MAP[color_name]``.

    Raises:
        ValueError: If ``color_name`` is not a key of ``COLOR_SET_MAP``.
    """
    try:
        color_set = COLOR_SET_MAP[color_name]
    except KeyError:
        raise ValueError(f"Unknown color name: {color_name}") from None
    return color_set.channels


def on_generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    color_name: str,
    chroma_key_shift: float,
    *args,
    **kwargs,
):
    """Gradio event handler: map the colour name to TKG channels and call ``infer``.

    Any extra positional/keyword arguments are forwarded to ``infer`` after
    ``chroma_key_shift``.

    Returns:
        Tuple ``(image_with_tkg, image_without_tkg, seed_used)``.
    """
    tkg_channels = color_name_to_channels(color_name)  # TODO: custom channels

    # Pass everything positionally, so extra positional args fill infer's
    # parameters after chroma_key_shift. Mixing tkg_channels=... with a
    # trailing *args would bind the first extra arg to tkg_channels and raise
    # "got multiple values".
    w_tkg, wo_tkg, seed = infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        tkg_channels,
        chroma_key_shift,
        *args,
        **kwargs,
    )

    return w_tkg, wo_tkg, seed


examples = [
    # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
    "1girl, solo, school uniform, cat ears, full body, looking at viewer, straight-on, chibi, simple background, best quality",
    "1girl, solo, hand up, waving, long hair, sideways glance, upper body, cropped torso, simple background, best quality",
]

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            """
# TKG Chroma-Key with AnimagineXL 4.0

TKG-DM🥚🍚: Training-free Chroma Key Content Generation Diffusion Model

- arXiv: https://arxiv.org/abs/2411.15580
- GitHub: https://github.com/ryugo417/TKG-DM
"""
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    max_lines=4,
                    placeholder="Enter your prompt",
                )
                color_set = gr.Dropdown(
                    label="Background color",
                    choices=list(COLOR_SET_MAP.keys()),
                    value="green",
                )

                with gr.Accordion("TKG Settings", open=False):
                    chroma_key_shift = gr.Slider(
                        label="Latent mean shift for chroma key",
                        minimum=0.0,
                        maximum=0.2,
                        step=0.005,
                        value=0.11,
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(
                        label="Negative prompt",
                        max_lines=4,
                        placeholder="Enter a negative prompt",
                        value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                    )

                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=832,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1152,
                        )

                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )

            with gr.Column():
                run_button = gr.Button("Generate", variant="primary")
                with gr.Row():
                    result_w_tkg = gr.Image(label="with TKG")
                    result_wo_tkg = gr.Image(label="without TKG")

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=on_generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            color_set,
            chroma_key_shift,
        ],
        outputs=[result_w_tkg, result_wo_tkg, seed],
    )

if __name__ == "__main__":
    demo.launch()