# animate / app.py
# Author: ahmdliaqat, commit 7611753 ("animate"), 5.01 kB
import gradio as gr
import torch
import os
import spaces
import uuid
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
# Constants
# Base checkpoints offered in the UI: display label -> Hugging Face repo id.
bases = dict([
    ("Cartoon", "frankjoshua/toonyou_beta6"),
    ("Realistic", "emilianJR/epiCRealism"),
    ("3d", "Lykon/DreamShaper"),
    ("Anime", "Yntec/mistoonAnime2"),
])

# Markers for what is currently loaded into ``pipe`` (Lightning step count,
# base model label, motion LoRA repo); generate_image compares against these
# to decide what must be reloaded.
step_loaded = None
motion_loaded = None
base_loaded = "Realistic"

# CPU configuration: run on the CPU in full precision.
dtype = torch.float32
device = "cpu"

# Initialize pipeline for CPU from the default base model.
pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded],
    torch_dtype=dtype,
).to(device)

# Replace the scheduler with Euler, reusing the base scheduler's config but
# overriding the timestep spacing and beta schedule.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    beta_schedule="linear",
    timestep_spacing="trailing",
)

# Safety checkers
# NOTE(review): nothing in this file reads ``feature_extractor``; confirm
# whether this download is still needed.
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)
# Function
def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Progress()):
    """Render a short video for ``prompt`` and return the path of the MP4 file.

    The shared module-level ``pipe`` is reconfigured lazily. The
    AnimateDiff-Lightning checkpoint for ``step``, the UNet weights of the
    ``base`` model and the optional motion LoRA are (re)loaded only when they
    differ from what is currently loaded.

    Args:
        prompt: Text prompt for the generation.
        base: Key into ``bases`` selecting the base model's UNet weights.
        motion: Hugging Face repo id of a motion LoRA, or ``""`` for none.
        step: Number of inference steps; also selects which
            AnimateDiff-Lightning checkpoint is loaded.
        progress: Gradio progress tracker, advanced once per denoising step.

    Returns:
        Path of the exported MP4 file inside the system temporary directory.

    Raises:
        gr.Error: If loading weights, sampling or exporting fails, so the
            failure is shown in the UI instead of an empty video.
    """
    import tempfile
    import traceback

    print(prompt, base, step)
    try:
        _sync_weights(base, motion, step)

        progress((0, step))

        def progress_callback(i, t, z):
            # Legacy diffusers callback: (step index, timestep, latents).
            progress((i + 1, step))

        output = pipe(
            prompt=prompt,
            guidance_scale=1.2,
            num_inference_steps=step,
            callback=progress_callback,
            callback_steps=1,
        )
        # Random file name so successive results never overwrite each other.
        path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4")
        export_to_video(output.frames[0], path, fps=10)
        return path
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        traceback.print_exc()
        # Surface the failure to the user rather than silently returning None.
        raise gr.Error(f"Error during generation: {e}") from e


def _sync_weights(base, motion, step):
    """Load into ``pipe`` the step checkpoint, base UNet and motion LoRA that differ from the cached state.

    Each ``*_loaded`` marker is cleared right before its weights are mutated
    and set again only after the load succeeds. A failed load therefore
    forces a retry on the next call instead of leaving a stale marker that
    claims weights are present when they are not.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        weights = load_file(hf_hub_download(repo, ckpt), device=device)
        step_loaded = None
        pipe.unet.load_state_dict(weights, strict=False)
        step_loaded = step

    if base_loaded != base:
        # weights_only=True: this .bin is a pickle downloaded from a
        # third-party repo, so restrict unpickling to tensors and containers.
        state_dict = torch.load(
            hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
            map_location=device,
            weights_only=True,
        )
        base_loaded = None
        pipe.unet.load_state_dict(state_dict, strict=False)
        base_loaded = base

    if motion_loaded != motion:
        pipe.unload_lora_weights()
        # The previous LoRA is gone; mark the state unknown until the new one is in.
        motion_loaded = None
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion
# Gradio Interface
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>Textual Imagination : A Text To Video Synthesis</center></h1>"
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Prompt'
            )
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                # Derived from ``bases`` so the UI only offers keys that
                # generate_image can actually load.
                choices=list(bases),
                value=base_loaded,
                interactive=True
            )
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="guoyww/animatediff-motion-lora-zoom-in",
                interactive=True
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[
                    ('1-Step', 1),
                    ('2-Step', 2),
                    ('4-Step', 4),
                    ('8-Step', 8),
                ],
                value=4,
                interactive=True
            )
            submit = gr.Button(
                scale=1,
                variant='primary'
            )
    video = gr.Video(
        label='AnimateDiff-Lightning',
        autoplay=True,
        height=512,
        width=512,
        elem_id="video_output"
    )

    # Run through the queue (the default). With queue=False, gr.Progress
    # updates are not delivered and concurrent requests could mutate the
    # shared global ``pipe`` and its *_loaded markers at the same time.
    gr.on(
        triggers=[
            submit.click,
            prompt.submit,
        ],
        fn=generate_image,
        inputs=[prompt, select_base, select_motion, select_step],
        outputs=[video],
        api_name="instant_video",
    )

demo.queue().launch()