Spaces:
Runtime error
Runtime error
File size: 5,607 Bytes
6a87547 dae6484 6a87547 73ee197 ad0cea8 702754c 6a87547 b664a31 6a87547 b832af5 6a87547 73ee197 7341603 b664a31 6a87547 ad0cea8 b664a31 ad0cea8 b664a31 ad0cea8 b664a31 6a87547 b664a31 6a87547 ad0cea8 7341603 b664a31 6a87547 ad0cea8 4254e9c 7341603 d093812 4254e9c e0e8e31 16c1b5a 4254e9c ad0cea8 4254e9c ad0cea8 6a87547 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import torch
import gradio as gr
from PIL import Image, ImageOps
from huggingface_hub import snapshot_download
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import export_to_video
#import spaces
import uuid
# Constants
# Local directory the checkpoint is downloaded into and loaded from.
MODEL_PATH = "pyramid-flow-model"
# Hugging Face Hub repository id handed to snapshot_download().
MODEL_REPO = "rain1011/pyramid-flow-sd3"
# Passed as `model_variant` when constructing PyramidDiTForVideoGeneration.
MODEL_VARIANT = "diffusion_transformer_384p"
# Passed to the model constructor; generate_video() also maps "bf16" to
# torch.bfloat16 and any other value to torch.float32 for autocast.
MODEL_DTYPE = "bf16"
def center_crop(image, target_width, target_height):
    """Crop ``image`` around its centre to the target aspect ratio.

    Only the ratio ``target_width : target_height`` is used. The result is
    not resized, so callers needing exact pixel dimensions must resize it
    afterwards.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method, e.g. a PIL image.
        target_width: Width component of the desired aspect ratio (> 0).
        target_height: Height component of the desired aspect ratio (> 0).

    Returns:
        The value returned by ``image.crop`` for the centred box.

    Raises:
        ValueError: If either target dimension is not positive.
    """
    if target_width <= 0 or target_height <= 0:
        raise ValueError(
            f"target dimensions must be positive, got {target_width}x{target_height}"
        )

    width, height = image.size

    # Compare the ratios by cross-multiplication and size the crop with floor
    # division. Going through a float ratio can truncate one pixel short
    # (e.g. int(49 * (1 / 49)) == 0), yielding an empty or off-by-one crop.
    if width * target_height > height * target_width:
        # Image is too wide: trim equally from the left and right.
        new_width = int(height * target_width // target_height)
        left = (width - new_width) // 2
        box = (left, 0, left + new_width, height)
    else:
        # Image is too tall (or already matches): trim top and bottom.
        new_height = int(width * target_height // target_width)
        top = (height - new_height) // 2
        box = (0, top, width, top + new_height)

    return image.crop(box)
# Download and load the model
def load_model():
    """Return a ready-to-use PyramidDiTForVideoGeneration instance.

    The checkpoint is downloaded from ``MODEL_REPO`` into ``MODEL_PATH`` only
    when that path does not exist yet. The VAE, DiT and text encoder are then
    moved to ``"cuda"`` and VAE tiling is enabled.
    """
    checkpoint_missing = not os.path.exists(MODEL_PATH)
    if checkpoint_missing:
        snapshot_download(
            MODEL_REPO,
            local_dir=MODEL_PATH,
            local_dir_use_symlinks=False,
            repo_type='model',
        )

    pipeline = PyramidDiTForVideoGeneration(
        MODEL_PATH,
        MODEL_DTYPE,
        model_variant=MODEL_VARIANT,
    )

    # Move each sub-module to the GPU, in the same order as before.
    for component_name in ("vae", "dit", "text_encoder"):
        getattr(pipeline, component_name).to("cuda")

    pipeline.vae.enable_tiling()
    return pipeline
# Global model variable: loaded once when the module is imported and read by
# generate_video() on every request.
model = load_model()
# Text-to-video generation function
#@spaces.GPU(duration=120)
def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guidance_scale=5, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from a prompt, optionally conditioned on an image.

    Args:
        prompt: Text prompt forwarded to the model.
        image: Optional conditioning image. When given, it is centre-cropped
            and resized to 640x384 and ``model.generate_i2v`` is used;
            otherwise ``model.generate`` runs in text-to-video mode.
        duration: Requested clip length in seconds.
        guidance_scale: Forwarded to ``model.generate`` (text-to-video only).
        video_guidance_scale: Forwarded to both generation paths.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared so Gradio can mirror tqdm progress in the UI.

    Returns:
        Path of the MP4 file written to the current working directory.
    """
    width, height = 640, 384

    # The model's `temp` argument is derived as duration * 3 + 1.
    # NOTE(review): presumably each temp unit covers 8 frames at the 24 FPS
    # used for export below - confirm against the Pyramid Flow documentation.
    multiplier = 3
    temp = int(duration * multiplier) + 1
    torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32

    # Arguments common to both generation modes.
    shared_kwargs = dict(
        prompt=prompt,
        temp=temp,
        video_guidance_scale=video_guidance_scale,
        output_type="pil",
        save_memory=True,
    )

    # Explicit None check: do not rely on the truthiness of an image object.
    if image is not None:
        conditioning = center_crop(image, width, height).resize((width, height))
        run = model.generate_i2v
        mode_kwargs = dict(
            input_image=conditioning,
            num_inference_steps=[10, 10, 10],
        )
    else:
        run = model.generate
        mode_kwargs = dict(
            num_inference_steps=[20, 20, 20],
            video_num_inference_steps=[10, 10, 10],
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )

    # torch.autocast("cuda", ...) replaces the deprecated
    # torch.cuda.amp.autocast(...) spelling.
    with torch.no_grad(), torch.autocast("cuda", enabled=True, dtype=torch_dtype):
        frames = run(**shared_kwargs, **mode_kwargs)

    # A random prefix keeps concurrent requests from overwriting each other.
    output_path = f"{uuid.uuid4()}_output_video.mp4"
    export_to_video(frames, output_path, fps=24)
    return output_path
# Gradio interface
# NOTE(review): the nesting of the `with` blocks below was inferred from the
# statement order; confirm that it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Pyramid Flow 384p demo")
    gr.Markdown("Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours")
    gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)] [[Project Page]](https://pyramid-flow.github.io)")

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            with gr.Accordion("Image to Video (optional)", open=False):
                i2v_image = gr.Image(type="pil", label="Input Image")
            t2v_prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced settings", open=False):
                t2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
                t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=7, step=0.1, label="Guidance Scale")
                t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
            t2v_generate_btn = gr.Button("Generate Video")
        # Right column: result.
        with gr.Column():
            t2v_output = gr.Video(label="Generated Video")

    # Examples pass only the prompt, so generate_video() runs with its own
    # default values for every other parameter.
    gr.Examples(
        examples=[
            "A futuristic explorer, 30 years old, travels across distant galaxies in a sleek silver space suit, gliding through a glowing nebula. The scene is illuminated by vibrant starbursts and cosmic dust, captured with a futuristic drone in ultra-high-definition, showcasing vibrant purples and blues",
            "In a serene winter landscape, a futuristic metropolis hums with life. The camera glides along an icy street as citizens, wrapped in advanced thermal suits, enjoy the wintry scene. Holographic advertisements flicker above snow-covered buildings, while sleek flying vehicles zip overhead. In the background, delicate crystalline structures refract light through the snowflakes."
        ],
        fn=generate_video,
        inputs=t2v_prompt,
        outputs=t2v_output,
        cache_examples="lazy"
    )

    t2v_generate_btn.click(
        generate_video,
        inputs=[t2v_prompt, i2v_image, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
        outputs=t2v_output
    )

demo.launch()