import argparse

import torch
from diffusers.utils import export_to_video
from PIL import Image

from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
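
# Example launch (a sketch, not part of the original script): a torchrun-style
# launcher is assumed, since init_distributed_mode conventionally reads the
# RANK/WORLD_SIZE/LOCAL_RANK environment variables that torchrun sets. The
# script filename and checkpoint path below are placeholders.
#
#   torchrun --nproc_per_node=2 inference_multigpu.py \
#       --model_path /path/to/pyramid-flow \
#       --variant diffusion_transformer_768p \
#       --task t2v --temp 16 --sp_group_size 2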


def get_args():
    parser = argparse.ArgumentParser('PyTorch Multi-process Script', add_help=False)
    parser.add_argument('--model_name', default='pyramid_flux', type=str, choices=["pyramid_flux", "pyramid_mmdit"], help="The model name")
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16, fp16, or fp32")
    parser.add_argument('--model_path', default='/home/jinyang06/models/pyramid-flow', type=str, help='Set it to the downloaded checkpoint directory')
    parser.add_argument('--variant', default='diffusion_transformer_768p', type=str, choices=['diffusion_transformer_768p', 'diffusion_transformer_384p'])
    parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
    parser.add_argument('--temp', default=16, type=int, help='The number of generated latent frames; num_frames = temp * 8 + 1')
    parser.add_argument('--sp_group_size', default=2, type=int, help="The number of GPUs used for inference; should be 2 or 4")
    parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for sequence parallelism; -1 means use all processes")
    return parser.parse_args()
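
# With the default temp=16, num_frames = 16 * 8 + 1 = 129 frames, i.e. roughly
# 5.4 seconds at the fps=24 passed to export_to_video below.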


def main():
    args = get_args()

    # Set up DDP state (rank, world_size, device) from the launcher environment
    init_distributed_mode(args)

    assert args.world_size == args.sp_group_size, "The sequence parallel group size should equal the DDP world size"

    # All ranks form a single sequence parallel group for inference
    init_sequence_parallel_group(args)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    # Tiled VAE decoding lowers peak memory for long, high-resolution videos
    model.vae.enable_tiling()
if model_dtype == "bf16": |
|
torch_dtype = torch.bfloat16 |
|
elif model_dtype == "fp16": |
|
torch_dtype = torch.float16 |
|
else: |
|
torch_dtype = torch.float32 |
|
|
|
|
|
if args.variant == 'diffusion_transformer_768p': |
|
width = 1280 |
|
height = 768 |
|
else: |
|
assert args.variant == 'diffusion_transformer_384p' |
|
width = 640 |
|
height = 384 |
|
|
|

    if args.task == 't2v':
        prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=model_dtype != 'fp32', dtype=torch_dtype):
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=height,
                width=width,
                temp=args.temp,
                guidance_scale=7.0,
                video_guidance_scale=5.0,
                output_type="pil",
                save_memory=True,
                cpu_offloading=False,
                inference_multigpu=True,
            )
        # Every rank holds the full result; only rank 0 writes the file
        if rank == 0:
            export_to_video(frames, "./text_to_video_sample.mp4", fps=24)

    else:
        assert args.task == 'i2v'

        # Condition generation on a single input image, resized to the variant's resolution
        image_path = 'assets/the_great_wall.jpg'
        image = Image.open(image_path).convert("RGB")
        image = image.resize((width, height))

        prompt = "FPV flying over the Great Wall"

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=model_dtype != 'fp32', dtype=torch_dtype):
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=image,
                num_inference_steps=[10, 10, 10],
                temp=args.temp,
                video_guidance_scale=4.0,
                output_type="pil",
                save_memory=True,
                cpu_offloading=False,
                inference_multigpu=True,
            )

        if rank == 0:
            export_to_video(frames, "./image_to_video_sample.mp4", fps=24)

    # Keep all ranks in sync before exiting
    torch.distributed.barrier()
if __name__ == "__main__": |
|
main() |