"""Gradio demo for the DMM checkpoint (https://huggingface.co/MCG-NJU/DMM)."""

import random

# NOTE: `spaces` is kept ahead of `torch`, matching the original order; the
# Hugging Face ZeroGPU docs ask for `spaces` to be imported before torch.
import spaces
import torch
import gradio as gr
from modeling.dmm_pipeline import StableDiffusionDMMPipeline
from huggingface_hub import snapshot_download

# Download the DMM checkpoint into ./ckpt and load it once at import time,
# in fp16 from safetensors, on the CUDA device.
ckpt_path = "ckpt"
snapshot_download(repo_id="MCG-NJU/DMM", local_dir=ckpt_path)
pipe = StableDiffusionDMMPipeline.from_pretrained(
    ckpt_path, torch_dtype=torch.float16, use_safetensors=True
)
pipe.to("cuda")


def _require_int(value, label: str) -> int:
    """Convert a Gradio number input to ``int``; an empty field raises ``gr.Error``."""
    # gr.Number passes None for a cleared field and a float when no
    # `precision` is set, so normalise both here.
    if value is None:
        raise gr.Error(f"{label} must not be empty.")
    return int(value)


def _generate_one(
    prompt: str,
    negative_prompt: str,
    model_id: int,
    seed: int,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
):
    """Run the pipeline once with sub-model ``model_id`` and return ``.images[0]``."""
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        model_id=model_id,
        # A freshly seeded generator per call, so every model starts from
        # the same seed and the outputs are comparable across models.
        generator=torch.Generator().manual_seed(seed),
    ).images[0]


@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str,
    model_id: int,
    seed: int = 1234,
    height: int = 512,
    width: int = 512,
    all: bool = True,  # name kept for compatibility; shadows the builtin here
    *,
    num_inference_steps: int = 25,
    guidance_scale: float = 7,
):
    """Generate one image per selected DMM sub-model.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        model_id: Index of the sub-model to use when ``all`` is false. The UI
            dropdown uses ``type="index"``, so this may be ``None`` when
            nothing is selected.
        seed: Seed for the generator; each model call is reseeded with it.
        height: Output height in pixels.
        width: Output width in pixels.
        all: If true, ignore ``model_id`` and run every sub-model reported by
            ``pipe.unet.get_num_models()``.
        num_inference_steps: Denoising steps per call (default 25).
        guidance_scale: Guidance scale per call (default 7).

    Returns:
        A list with the first image of each pipeline call, in model-index
        order (a single-element list when ``all`` is false).

    Raises:
        gr.Error: If ``seed``, ``height`` or ``width`` is empty, or if ``all``
            is false and no model is selected.
    """
    seed = _require_int(seed, "Seed")
    height = _require_int(height, "Height")
    width = _require_int(width, "Width")

    if all:
        model_ids = range(pipe.unet.get_num_models())
    else:
        if model_id is None:
            raise gr.Error("Please select a model, or tick 'All'.")
        model_ids = [int(model_id)]

    return [
        _generate_one(
            prompt,
            negative_prompt,
            i,
            seed,
            height,
            width,
            num_inference_steps,
            guidance_scale,
        )
        for i in model_ids
    ]


# Dropdown labels; their position in this list is the model index that the
# dropdown (type="index") passes to `generate`.
candidates = [
    "0. [JuggernautReborn] realistic",
    "1. [MajicmixRealisticV7] realistic, Asia portrait",
    "2. [EpicRealismV5] realistic",
    "3. [RealisticVisionV5] realistic",
    "4. [MajicmixFantasyV3] animation",
    "5. [MinimalismV2] illustration",
    "6. [RealCartoon3dV17] cartoon 3d",
    "7. [AWPaintingV1.4] animation",
]


def main():
    """Build the Gradio Blocks UI for `generate` and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# DMM Demo\n\n"
            "The checkpoint is https://huggingface.co/MCG-NJU/DMM.\n"
        )
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    model_id = gr.Dropdown(candidates, label="Model Index", type="index")
                    all_check = gr.Checkbox(label="All (ignore the selection above)")
                prompt = gr.Textbox(
                    "portrait photo of a girl, long golden hair, flowers, best quality",
                    label="Prompt",
                )
                negative_prompt = gr.Textbox(
                    "worst quality,low quality,normal quality,lowres,watermark,nsfw",
                    label="Negative Prompt",
                )
                with gr.Row():
                    seed = gr.Number(0, label="Seed", precision=0, scale=3)
                    update_seed_btn = gr.Button("🎲", scale=1)
                with gr.Row():
                    # precision=0 so these arrive as ints rather than floats.
                    height = gr.Number(
                        768, step=8, precision=0, label="Height (suggest 512~768)"
                    )
                    width = gr.Number(512, step=8, precision=0, label="Width")
                submit_btn = gr.Button("Submit", variant="primary")
            output = gr.Gallery(label="images")

        submit_btn.click(
            generate,
            inputs=[prompt, negative_prompt, model_id, seed, height, width, all_check],
            outputs=[output],
        )
        update_seed_btn.click(lambda: random.randint(0, 1000000), outputs=[seed])

    demo.launch()


if __name__ == "__main__":
    main()