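# TrajCrafter inference wrapper: estimates per-frame depth with DepthCrafter,
# captions the clip with BLIP-2, forward-warps the frames into a target camera
# trajectory, and inpaints the dis-occluded regions with a CogVideoX-based
# diffusion pipeline (TrajCrafter_Pipeline).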

import gc
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from omegaconf import OmegaConf
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration, T5EncoderModel
from diffusers import (
    AutoencoderKL,
    CogVideoXDDIMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    PNDMScheduler,
)

from extern.depthcrafter.infer import DepthCrafterDemo
# from extern.video_depth_anything.vdademo import VDADemo
from models.crosstransformer3d import CrossTransformer3DModel
from models.autoencoder_magvit import AutoencoderKLCogVideoX
from models.pipeline_trajectorycrafter import TrajCrafter_Pipeline
from models.utils import *  # Warper, read_video_frames, save_video, generate_traj_* helpers


class TrajCrafter:
    def __init__(self, opts, gradio=False):
        self.funwarp = Warper(device=opts.device)
        # self.depth_estimater = VDADemo(pre_train_path=opts.pre_train_path_vda, device=opts.device)
        self.depth_estimater = DepthCrafterDemo(
            unet_path=opts.unet_path,
            pre_train_path=opts.pre_train_path,
            cpu_offload=opts.cpu_offload,
            device=opts.device,
        )
        self.caption_processor = AutoProcessor.from_pretrained(opts.blip_path)
        self.captioner = Blip2ForConditionalGeneration.from_pretrained(
            opts.blip_path, torch_dtype=torch.float16
        ).to(opts.device)
        self.setup_diffusion(opts)
        if gradio:
            self.opts = opts
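
    # The three infer_* entry points share the same pipeline:
    #   1. read_video_frames, then BLIP-2 caption of the middle frame (get_caption)
    #   2. DepthCrafter depth for every frame
    #   3. source/target camera poses from get_poses
    #   4. forward-warp each frame into the target view, collecting RGB + occlusion masks
    #   5. TrajCrafter_Pipeline inpaints the warped (masked) video
    # They differ only in how source frames and target poses are paired per step.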
    def infer_gradual(self, opts):
        frames = read_video_frames(opts.video_path, opts.video_length, opts.stride, opts.max_res)
        prompt = self.get_caption(opts, frames[opts.video_length // 2])
        # depths = self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
        depths = self.depth_estimater.infer(
            frames,
            opts.near,
            opts.far,
            opts.depth_inference_steps,
            opts.depth_guidance_scale,
            window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)
        # 49 576 1024 3 -> 49 3 576 1024, [-1, 1]
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
        assert frames.shape[0] == opts.video_length
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)
        warped_images = []
        masks = []
        for i in tqdm(range(opts.video_length)):
            warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(
                frames[i : i + 1],
                None,
                depths[i : i + 1],
                pose_s[i : i + 1],
                pose_t[i : i + 1],
                K[i : i + 1],
                None,
                opts.mask,
                twice=False,
            )
            warped_images.append(warped_frame2)
            masks.append(mask2)
        cond_video = (torch.cat(warped_images) + 1.0) / 2.0
        cond_masks = torch.cat(masks)
        frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
        cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
        cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
        save_video((frames.permute(0, 2, 3, 1) + 1.0) / 2.0, os.path.join(opts.save_dir, 'input.mp4'), fps=opts.fps)
        save_video(cond_video.permute(0, 2, 3, 1), os.path.join(opts.save_dir, 'render.mp4'), fps=opts.fps)
        save_video(cond_masks.repeat(1, 3, 1, 1).permute(0, 2, 3, 1), os.path.join(opts.save_dir, 'mask.mp4'), fps=opts.fps)
        frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.0) / 2.0
        frames_ref = frames[:, :, :10, :, :]
        cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
        cond_masks = (1.0 - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.0
        generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
        del self.depth_estimater
        del self.caption_processor
        del self.captioner
        gc.collect()
        torch.cuda.empty_cache()
        with torch.no_grad():
            sample = self.pipeline(
                prompt,
                num_frames=opts.video_length,
                negative_prompt=opts.negative_prompt,
                height=opts.sample_size[0],
                width=opts.sample_size[1],
                generator=generator,
                guidance_scale=opts.diffusion_guidance_scale,
                num_inference_steps=opts.diffusion_inference_steps,
                video=cond_video,
                mask_video=cond_masks,
                reference=frames_ref,
            ).videos
        save_video(sample[0].permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        viz = True
        if viz:
            tensor_left = frames[0].to(opts.device)
            tensor_right = sample[0].to(opts.device)
            interval = torch.ones(3, 49, 384, 30).to(opts.device)
            result = torch.cat((tensor_left, interval, tensor_right), dim=3)
            result_reverse = torch.flip(result, dims=[1])
            final_result = torch.cat((result, result_reverse[:, 1:, :, :]), dim=1)
            save_video(final_result.permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'viz.mp4'), fps=opts.fps * 2)
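
    # infer_direct first "jumps" the camera: for the first `cut` steps it warps
    # frame 0 along the target trajectory, then it replays the source video
    # (shifted by `cut`) from the final target viewpoint.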
    def infer_direct(self, opts):
        opts.cut = 20
        frames = read_video_frames(opts.video_path, opts.video_length, opts.stride, opts.max_res)
        prompt = self.get_caption(opts, frames[opts.video_length // 2])
        # depths = self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
        depths = self.depth_estimater.infer(
            frames,
            opts.near,
            opts.far,
            opts.depth_inference_steps,
            opts.depth_guidance_scale,
            window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)
        # 49 576 1024 3 -> 49 3 576 1024, [-1, 1]
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
        assert frames.shape[0] == opts.video_length
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.cut)
        warped_images = []
        masks = []
        for i in tqdm(range(opts.video_length)):
            if i < opts.cut:
                warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(
                    frames[0:1],
                    None,
                    depths[0:1],
                    pose_s[0:1],
                    pose_t[i : i + 1],
                    K[0:1],
                    None,
                    opts.mask,
                    twice=False,
                )
                warped_images.append(warped_frame2)
                masks.append(mask2)
            else:
                warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(
                    frames[i - opts.cut : i - opts.cut + 1],
                    None,
                    depths[i - opts.cut : i - opts.cut + 1],
                    pose_s[0:1],
                    pose_t[-1:],
                    K[0:1],
                    None,
                    opts.mask,
                    twice=False,
                )
                warped_images.append(warped_frame2)
                masks.append(mask2)
        cond_video = (torch.cat(warped_images) + 1.0) / 2.0
        cond_masks = torch.cat(masks)
        frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
        cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
        cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
        save_video((frames[: opts.video_length - opts.cut].permute(0, 2, 3, 1) + 1.0) / 2.0, os.path.join(opts.save_dir, 'input.mp4'), fps=opts.fps)
        save_video(cond_video[opts.cut :].permute(0, 2, 3, 1), os.path.join(opts.save_dir, 'render.mp4'), fps=opts.fps)
        save_video(cond_masks[opts.cut :].repeat(1, 3, 1, 1).permute(0, 2, 3, 1), os.path.join(opts.save_dir, 'mask.mp4'), fps=opts.fps)
        frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.0) / 2.0
        frames_ref = frames[:, :, :10, :, :]
        cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
        cond_masks = (1.0 - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.0
        generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
        del self.depth_estimater
        del self.caption_processor
        del self.captioner
        gc.collect()
        torch.cuda.empty_cache()
        with torch.no_grad():
            sample = self.pipeline(
                prompt,
                num_frames=opts.video_length,
                negative_prompt=opts.negative_prompt,
                height=opts.sample_size[0],
                width=opts.sample_size[1],
                generator=generator,
                guidance_scale=opts.diffusion_guidance_scale,
                num_inference_steps=opts.diffusion_inference_steps,
                video=cond_video,
                mask_video=cond_masks,
                reference=frames_ref,
            ).videos
        save_video(sample[0].permute(1, 2, 3, 0)[opts.cut :], os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        viz = True
        if viz:
            tensor_left = frames[0][:, : opts.video_length - opts.cut, ...].to(opts.device)
            tensor_right = sample[0][:, opts.cut :, ...].to(opts.device)
            interval = torch.ones(3, opts.video_length - opts.cut, 384, 30).to(opts.device)
            result = torch.cat((tensor_left, interval, tensor_right), dim=3)
            result_reverse = torch.flip(result, dims=[1])
            final_result = torch.cat((result, result_reverse[:, 1:, :, :]), dim=1)
            save_video(final_result.permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'viz.mp4'), fps=opts.fps * 2)
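
    # infer_bullet freezes the last source frame and only moves the camera
    # ("bullet time"): every target pose re-renders frames[-1] / depths[-1].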
    def infer_bullet(self, opts):
        frames = read_video_frames(opts.video_path, opts.video_length, opts.stride, opts.max_res)
        prompt = self.get_caption(opts, frames[opts.video_length // 2])
        # depths = self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
        depths = self.depth_estimater.infer(
            frames,
            opts.near,
            opts.far,
            opts.depth_inference_steps,
            opts.depth_guidance_scale,
            window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)
        # 49 576 1024 3 -> 49 3 576 1024, [-1, 1]
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
        assert frames.shape[0] == opts.video_length
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)
        warped_images = []
        masks = []
        for i in tqdm(range(opts.video_length)):
            warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(
                frames[-1:],
                None,
                depths[-1:],
                pose_s[0:1],
                pose_t[i : i + 1],
                K[0:1],
                None,
                opts.mask,
                twice=False,
            )
            warped_images.append(warped_frame2)
            masks.append(mask2)
        cond_video = (torch.cat(warped_images) + 1.0) / 2.0
        cond_masks = torch.cat(masks)
        frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
        cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
        cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
        save_video((frames.permute(0, 2, 3, 1) + 1.0) / 2.0, os.path.join(opts.save_dir, 'input.mp4'), fps=opts.fps)
        save_video(cond_video.permute(0, 2, 3, 1), os.path.join(opts.save_dir, 'render.mp4'), fps=opts.fps)
        save_video(cond_masks.repeat(1, 3, 1, 1).permute(0, 2, 3, 1), os.path.join(opts.save_dir, 'mask.mp4'), fps=opts.fps)
        frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.0) / 2.0
        frames_ref = frames[:, :, -10:, :, :]
        cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
        cond_masks = (1.0 - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.0
        generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
        del self.depth_estimater
        del self.caption_processor
        del self.captioner
        gc.collect()
        torch.cuda.empty_cache()
        with torch.no_grad():
            sample = self.pipeline(
                prompt,
                num_frames=opts.video_length,
                negative_prompt=opts.negative_prompt,
                height=opts.sample_size[0],
                width=opts.sample_size[1],
                generator=generator,
                guidance_scale=opts.diffusion_guidance_scale,
                num_inference_steps=opts.diffusion_inference_steps,
                video=cond_video,
                mask_video=cond_masks,
                reference=frames_ref,
            ).videos
        save_video(sample[0].permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        viz = True
        if viz:
            tensor_left = frames[0].to(opts.device)
            tensor_left_full = torch.cat([tensor_left, tensor_left[:, -1:, :, :].repeat(1, 48, 1, 1)], dim=1)
            tensor_right = sample[0].to(opts.device)
            tensor_right_full = torch.cat([tensor_left, tensor_right[:, 1:, :, :]], dim=1)
            interval = torch.ones(3, 49 * 2 - 1, 384, 30).to(opts.device)
            result = torch.cat((tensor_left_full, interval, tensor_right_full), dim=3)
            result_reverse = torch.flip(result, dims=[1])
            final_result = torch.cat((result, result_reverse[:, 1:, :, :]), dim=1)
            save_video(final_result.permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'viz.mp4'), fps=opts.fps * 4)

    def get_caption(self, opts, image):
        image_array = (image * 255).astype(np.uint8)
        pil_image = Image.fromarray(image_array)
        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(opts.device, torch.float16)
        generated_ids = self.captioner.generate(**inputs)
        generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return generated_text + opts.refine_prompt

    def get_poses(self, opts, depths, num_frames):
        radius = depths[0, 0, depths.shape[-2] // 2, depths.shape[-1] // 2].cpu() * opts.radius_scale
        radius = min(radius, 5)
        cx = 512.0  # depths.shape[-1] // 2
        cy = 288.0  # depths.shape[-2] // 2
        f = 500  # 500.
        K = torch.tensor([[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]]).repeat(num_frames, 1, 1).to(opts.device)
        c2w_init = torch.tensor([
            [-1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]).to(opts.device).unsqueeze(0)
        if opts.camera == 'target':
            dtheta, dphi, dr, dx, dy = opts.target_pose
            poses = generate_traj_specified(c2w_init, dtheta, dphi, dr * radius, dx, dy, num_frames, opts.device)
        elif opts.camera == 'traj':
            with open(opts.traj_txt, 'r') as file:
                lines = file.readlines()
                theta = [float(i) for i in lines[0].split()]
                phi = [float(i) for i in lines[1].split()]
                r = [float(i) * radius for i in lines[2].split()]
            poses = generate_traj_txt(c2w_init, phi, theta, r, num_frames, opts.device)
        poses[:, 2, 3] = poses[:, 2, 3] + radius
        pose_s = poses[opts.anchor_idx : opts.anchor_idx + 1].repeat(num_frames, 1, 1)
        pose_t = poses
        return pose_s, pose_t, K
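
    # Camera specification consumed by get_poses:
    #   opts.camera == 'target': opts.target_pose = (dtheta, dphi, dr, dx, dy);
    #       dr is scaled by the scene radius before the trajectory is generated.
    #   opts.camera == 'traj': opts.traj_txt is a text file with three whitespace-
    #       separated lines -- theta keyframes, phi keyframes, and radius multipliers.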

    # def setup_diffusion(self, opts):
    #     # transformer = CrossTransformer3DModel.from_pretrained_cus(opts.transformer_path).to(opts.weight_dtype)
    #     transformer = CrossTransformer3DModel.from_pretrained(opts.transformer_path).to(opts.weight_dtype)
    #     # transformer = transformer.to(opts.weight_dtype)
    #     vae = AutoencoderKLCogVideoX.from_pretrained(
    #         opts.model_name,
    #         subfolder="vae"
    #     ).to(opts.weight_dtype)
    #     text_encoder = T5EncoderModel.from_pretrained(
    #         opts.model_name, subfolder="text_encoder", torch_dtype=opts.weight_dtype
    #     )
    #     # Get Scheduler
    #     Choosen_Scheduler = {
    #         "Euler": EulerDiscreteScheduler,
    #         "Euler A": EulerAncestralDiscreteScheduler,
    #         "DPM++": DPMSolverMultistepScheduler,
    #         "PNDM": PNDMScheduler,
    #         "DDIM_Cog": CogVideoXDDIMScheduler,
    #         "DDIM_Origin": DDIMScheduler,
    #     }[opts.sampler_name]
    #     scheduler = Choosen_Scheduler.from_pretrained(
    #         opts.model_name,
    #         subfolder="scheduler"
    #     )
    #     self.pipeline = TrajCrafter_Pipeline.from_pretrained(
    #         opts.model_name,
    #         vae=vae,
    #         text_encoder=text_encoder,
    #         transformer=transformer,
    #         scheduler=scheduler,
    #         torch_dtype=opts.weight_dtype
    #     )
    #     if opts.low_gpu_memory_mode:
    #         self.pipeline.enable_sequential_cpu_offload()
    #     else:
    #         self.pipeline.enable_model_cpu_offload()

    def setup_diffusion(self, opts):
        # 1) Pick the device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # (Optional) Avoid efficient-attention errors when running on CPU;
        # harmless on CUDA as well.
        try:
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
            torch.backends.cuda.enable_math_sdp(True)
        except Exception:
            pass
        # 2) Load the submodules and move them to the target device + dtype.
        # Note: the previous version only called .to(dtype) and never specified
        # a device; that is filled in here.
        transformer = CrossTransformer3DModel.from_pretrained(opts.transformer_path)
        transformer = transformer.to(device=device, dtype=opts.weight_dtype)
        vae = AutoencoderKLCogVideoX.from_pretrained(
            opts.model_name,
            subfolder="vae",
            # Only the dtype is set here; the module is moved to the device afterwards,
            # since some from_pretrained implementations do not accept a device argument.
        ).to(dtype=opts.weight_dtype).to(device)
        text_encoder = T5EncoderModel.from_pretrained(
            opts.model_name,
            subfolder="text_encoder",
            torch_dtype=opts.weight_dtype,
        ).to(device)
        # 3) Scheduler selection is unchanged.
        Choosen_Scheduler = {
            "Euler": EulerDiscreteScheduler,
            "Euler A": EulerAncestralDiscreteScheduler,
            "DPM++": DPMSolverMultistepScheduler,
            "PNDM": PNDMScheduler,
            "DDIM_Cog": CogVideoXDDIMScheduler,
            "DDIM_Origin": DDIMScheduler,
        }[opts.sampler_name]
        scheduler = Choosen_Scheduler.from_pretrained(
            opts.model_name,
            subfolder="scheduler",
        )
        # 4) Assemble the pipeline and make sure it ends up on the right device/dtype.
        self.pipeline = TrajCrafter_Pipeline.from_pretrained(
            opts.model_name,
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            torch_dtype=opts.weight_dtype,
        )
        # Offload strategy:
        # - With enough GPU memory, moving the whole pipeline to the device is
        #   the most predictable option.
        # - When memory is tight, enable CPU offload instead (requires accelerate).
        if opts.low_gpu_memory_mode:
            # Both offload variants move blocks to the GPU only while they are
            # needed; slightly slower, but much lighter on GPU memory.
            # Keep whichever of the two matches your previous setup:
            # self.pipeline.enable_sequential_cpu_offload()
            self.pipeline.enable_model_cpu_offload()
        else:
            self.pipeline.to(device)
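
    # opts.sampler_name must be one of the scheduler keys above ("Euler",
    # "Euler A", "DPM++", "PNDM", "DDIM_Cog", "DDIM_Origin"); any other value
    # raises a KeyError when the scheduler table is indexed.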

    def run_gradio(self, input_video, stride, radius_scale, pose, steps, seed):
        frames = read_video_frames(input_video, self.opts.video_length, stride, self.opts.max_res)
        prompt = self.get_caption(self.opts, frames[self.opts.video_length // 2])
        # depths = self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
        depths = self.depth_estimater.infer(
            frames,
            self.opts.near,
            self.opts.far,
            self.opts.depth_inference_steps,
            self.opts.depth_guidance_scale,
            window_size=self.opts.window_size,
            overlap=self.opts.overlap,
        ).to(self.opts.device)
        # 49 576 1024 3 -> 49 3 576 1024, [-1, 1]
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(self.opts.device) * 2.0 - 1.0
        num_frames = frames.shape[0]
        assert num_frames == self.opts.video_length
        radius_scale = float(radius_scale)
        radius = depths[0, 0, depths.shape[-2] // 2, depths.shape[-1] // 2].cpu() * radius_scale
        radius = min(radius, 5)
        cx = 512.0  # depths.shape[-1] // 2
        cy = 288.0  # depths.shape[-2] // 2
        f = 500  # 500.
        K = torch.tensor([[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]]).repeat(num_frames, 1, 1).to(self.opts.device)
        c2w_init = torch.tensor([
            [-1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]).to(self.opts.device).unsqueeze(0)
        theta, phi, r, x, y = [float(i) for i in pose.split(';')]
        # theta,phi,r,x,y = [float(i) for i in theta.split()],[float(i) for i in phi.split()],[float(i) for i in r.split()],[float(i) for i in x.split()],[float(i) for i in y.split()]
        # target mode
        poses = generate_traj_specified(c2w_init, theta, phi, r * radius, x, y, num_frames, self.opts.device)
        poses[:, 2, 3] = poses[:, 2, 3] + radius
        pose_s = poses[self.opts.anchor_idx : self.opts.anchor_idx + 1].repeat(num_frames, 1, 1)
        pose_t = poses
        warped_images = []
        masks = []
        for i in tqdm(range(self.opts.video_length)):
            warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(
                frames[i : i + 1],
                None,
                depths[i : i + 1],
                pose_s[i : i + 1],
                pose_t[i : i + 1],
                K[i : i + 1],
                None,
                self.opts.mask,
                twice=False,
            )
            warped_images.append(warped_frame2)
            masks.append(mask2)
        cond_video = (torch.cat(warped_images) + 1.0) / 2.0
        cond_masks = torch.cat(masks)
        frames = F.interpolate(frames, size=self.opts.sample_size, mode='bilinear', align_corners=False)
        cond_video = F.interpolate(cond_video, size=self.opts.sample_size, mode='bilinear', align_corners=False)
        cond_masks = F.interpolate(cond_masks, size=self.opts.sample_size, mode='nearest')
        save_video((frames.permute(0, 2, 3, 1) + 1.0) / 2.0, os.path.join(self.opts.save_dir, 'input.mp4'), fps=self.opts.fps)
        save_video(cond_video.permute(0, 2, 3, 1), os.path.join(self.opts.save_dir, 'render.mp4'), fps=self.opts.fps)
        save_video(cond_masks.repeat(1, 3, 1, 1).permute(0, 2, 3, 1), os.path.join(self.opts.save_dir, 'mask.mp4'), fps=self.opts.fps)
        frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.0) / 2.0
        frames_ref = frames[:, :, :10, :, :]
        cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
        cond_masks = (1.0 - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.0
        generator = torch.Generator(device=self.opts.device).manual_seed(seed)
        # The depth and caption models are kept alive here (unlike infer_*),
        # so the Gradio demo can serve repeated requests.
        # del self.depth_estimater
        # del self.caption_processor
        # del self.captioner
        # gc.collect()
        torch.cuda.empty_cache()
        with torch.no_grad():
            sample = self.pipeline(
                prompt,
                num_frames=self.opts.video_length,
                negative_prompt=self.opts.negative_prompt,
                height=self.opts.sample_size[0],
                width=self.opts.sample_size[1],
                generator=generator,
                guidance_scale=self.opts.diffusion_guidance_scale,
                num_inference_steps=steps,
                video=cond_video,
                mask_video=cond_masks,
                reference=frames_ref,
            ).videos
        save_video(sample[0].permute(1, 2, 3, 0), os.path.join(self.opts.save_dir, 'gen.mp4'), fps=self.opts.fps)
        viz = True
        if viz:
            tensor_left = frames[0].to(self.opts.device)
            tensor_right = sample[0].to(self.opts.device)
            interval = torch.ones(3, 49, 384, 30).to(self.opts.device)
            result = torch.cat((tensor_left, interval, tensor_right), dim=3)
            result_reverse = torch.flip(result, dims=[1])
            final_result = torch.cat((result, result_reverse[:, 1:, :, :]), dim=1)
            save_video(final_result.permute(1, 2, 3, 0), os.path.join(self.opts.save_dir, 'viz.mp4'), fps=self.opts.fps * 2)
        return os.path.join(self.opts.save_dir, 'viz.mp4')
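

# Minimal launch sketch. The attribute names below mirror the opts.* fields
# accessed in this file; the checkpoint paths and numeric values are
# placeholders (assumptions), not the project's defaults -- adapt them to your
# weights and hardware before running.
if __name__ == "__main__":
    import argparse

    opts = argparse.Namespace(
        # device / precision
        device="cuda",
        weight_dtype=torch.bfloat16,
        # checkpoints (placeholder paths)
        unet_path="checkpoints/DepthCrafter",
        pre_train_path="checkpoints/stable-video-diffusion-img2vid",
        cpu_offload=None,  # or the offload mode DepthCrafterDemo expects
        blip_path="checkpoints/blip2-opt-2.7b",
        transformer_path="checkpoints/TrajectoryCrafter",
        model_name="checkpoints/CogVideoX-Fun-V1.1-5b-InP",
        sampler_name="DDIM_Origin",
        low_gpu_memory_mode=False,
        # input video and output
        video_path="test/example.mp4",
        video_length=49,
        stride=1,
        max_res=1024,
        sample_size=[384, 672],
        fps=10,
        save_dir="experiments/demo",
        seed=43,
        # depth estimation
        near=0.0001,
        far=10000.0,
        depth_inference_steps=5,
        depth_guidance_scale=1.0,
        window_size=110,
        overlap=25,
        mask=True,
        # camera trajectory ('target' uses target_pose, 'traj' reads traj_txt)
        camera="target",
        target_pose=[0.3, 0.0, 0.0, 0.0, 0.0],
        traj_txt=None,
        anchor_idx=0,
        radius_scale=1.0,
        # diffusion inpainting
        diffusion_guidance_scale=6.0,
        diffusion_inference_steps=50,
        negative_prompt="low quality, blurry, distorted",
        refine_prompt=", high quality, clear and sharp view",
    )
    os.makedirs(opts.save_dir, exist_ok=True)
    TrajCrafter(opts).infer_gradual(opts)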