import subprocess
import os, sys
from glob import glob
from datetime import datetime
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import math
import random
import librosa
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from functools import partial
from omegaconf import OmegaConf
from argparse import Namespace
# # load the one true config you dumped
# _args_cfg = OmegaConf.load("demo_out/config/args_config.yaml")
# args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
# from OmniAvatar.utils.args_config import set_global_args
# set_global_args(args)
from OmniAvatar.utils.args_config import parse_args
args = parse_args()
from OmniAvatar.utils.io_utils import load_state_dict
from peft import LoraConfig, inject_adapter_in_model
from OmniAvatar.models.model_manager import ModelManager
from OmniAvatar.wan_video import WanVideoPipeline
from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
import torchvision.transforms as TT
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
import torch.nn.functional as F
from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
from huggingface_hub import hf_hub_download
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # seed the current GPU
    torch.cuda.manual_seed_all(seed)  # seed all GPUs
def read_from_file(p):
    with open(p, "r") as fin:
        for l in fin:
            yield l.strip()
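# Select, from the candidate [height, width] resolutions, the one whose aspect ratio is
# closest to the input image's h/w.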
def match_size(image_size, h, w):
    ratio_ = 9999
    size_ = 9999
    select_size = None
    for image_s in image_size:
        ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
        size_tmp = abs(max(image_s) - max(w, h))
        if ratio_tmp < ratio_:
            ratio_ = ratio_tmp
            size_ = size_tmp
            select_size = image_s
        if ratio_ == ratio_tmp:
            if size_ == size_tmp:
                select_size = image_s
    return select_size
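# Resize the image so it fully covers the target resolution, then center it and pad/crop to
# exactly tgt_size (F.pad with negative padding crops the overflow).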
def resize_pad(image, ori_size, tgt_size):
    h, w = ori_size
    scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
    scale_h = int(h * scale_ratio)
    scale_w = int(w * scale_ratio)
    image = transforms.Resize(size=[scale_h, scale_w])(image)
    padding_h = tgt_size[0] - scale_h
    padding_w = tgt_size[1] - scale_w
    pad_top = padding_h // 2
    pad_bottom = padding_h - pad_top
    pad_left = padding_w // 2
    pad_right = padding_w - pad_left
    image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
    return image
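# End-to-end inference wrapper: loads the Wan video diffusion pipeline (DiT, text encoder,
# VAE) plus the OmniAvatar LoRA and a wav2vec2 audio encoder, and generates audio-driven
# avatar video in overlapping chunks.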
class WanInferencePipeline(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.device = torch.device("cuda")
        if self.args.dtype == 'bf16':
            self.dtype = torch.bfloat16
        elif self.args.dtype == 'fp16':
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32
        self.pipe = self.load_model()
        if self.args.i2v:
            chained_transforms = []
            chained_transforms.append(TT.ToTensor())
            self.transform = TT.Compose(chained_transforms)
        if self.args.use_audio:
            from OmniAvatar.models.wav2vec import Wav2VecModel
            self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                self.args.wav2vec_path
            )
            self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device)
            self.audio_encoder.feature_extractor._freeze_parameters()
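    # Build the WanVideoPipeline: load the base weights on CPU, fetch the CausVid LoRA from
    # the Hugging Face Hub, attach the OmniAvatar LoRA (or a full checkpoint), then freeze
    # the pipeline and enable VRAM management.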
    def load_model(self):
        torch.cuda.set_device(0)
        ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
        assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
        if self.args.train_architecture == 'lora':
            self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
        else:
            resume_path = ckpt_path
        self.step = 0
        # Load models
        model_manager = ModelManager(device="cpu", infer=True)
        model_manager.load_models(
            [
                self.args.dit_path.split(","),
                self.args.text_encoder_path,
                self.args.vae_path
            ],
            torch_dtype=self.dtype,  # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
            device='cpu',
        )
        LORA_REPO_ID = "Kijai/WanVideo_comfy"
        LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
        causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
        model_manager.load_lora(causvid_path, lora_alpha=1.0)
        pipe = WanVideoPipeline.from_model_manager(model_manager,
                                                   torch_dtype=self.dtype,
                                                   device="cuda",
                                                   use_usp=True if self.args.sp_size > 1 else False,
                                                   infer=True)
        if self.args.train_architecture == "lora":
            print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
            self.add_lora_to_model(
                pipe.denoising_model(),
                lora_rank=self.args.lora_rank,
                lora_alpha=self.args.lora_alpha,
                lora_target_modules=self.args.lora_target_modules,
                init_lora_weights=self.args.init_lora_weights,
                pretrained_lora_path=pretrained_lora_path,
            )
        else:
            missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
            print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
        pipe.requires_grad_(False)
        pipe.eval()
        pipe.enable_vram_management(num_persistent_param_in_dit=self.args.num_persistent_param_in_dit)  # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
        return pipe
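    # Inject LoRA adapters into the denoising model via PEFT and load the pretrained LoRA
    # weights non-strictly (the untouched base weights therefore show up as "missing" keys).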
    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
        # Add LoRA to UNet
        self.lora_alpha = lora_alpha
        if init_lora_weights == "kaiming":
            init_lora_weights = True
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        model = inject_adapter_in_model(lora_config, model)
        # Load pretrained LoRA weights
        if pretrained_lora_path is not None:
            state_dict = load_state_dict(pretrained_lora_path)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            all_keys = [i for i, _ in model.named_parameters()]
            num_updated_keys = len(all_keys) - len(missing_keys)
            num_unexpected_keys = len(unexpected_keys)
            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
    def forward(self, prompt,
                image_path=None,
                audio_path=None,
                seq_len=101,  # ignored when audio_path is provided
                height=720,
                width=720,
                overlap_frame=None,
                num_steps=None,
                negative_prompt=None,
                guidance_scale=None,
                audio_scale=None):
        overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
        num_steps = num_steps if num_steps is not None else self.args.num_steps
        negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
        guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
        audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
        if image_path is not None:
            from PIL import Image
            image = Image.open(image_path).convert("RGB")
            image = self.transform(image).unsqueeze(0).to(self.device)
            _, _, h, w = image.shape
            select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
            image = resize_pad(image, (h, w), select_size)
            image = image * 2.0 - 1.0
            image = image[:, :, None]
        else:
            image = None
            select_size = [height, width]
        L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
        L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3  # video frames per chunk
        T = (L + 3) // 4  # latent frames per chunk
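        # i2v: decide how many frames of each chunk are reused from the previous chunk
        # (the conditioning prefix) and how many latent frames that prefix occupies.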
        if self.args.i2v:
            if self.args.random_prefix_frames:
                fixed_frame = overlap_frame
                assert fixed_frame % 4 == 1
            else:
                fixed_frame = 1
            prefix_lat_frame = (3 + fixed_frame) // 4
            first_fixed_frame = 1
        else:
            fixed_frame = 0
            prefix_lat_frame = 0
            first_fixed_frame = 0
        if audio_path is not None and self.args.use_audio:
            audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
            input_values = np.squeeze(
                self.wav_feature_extractor(audio, sampling_rate=16000).input_values
            )
            input_values = torch.from_numpy(input_values).float().to(device=self.device)
            ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
            input_values = input_values.unsqueeze(0)
            # pad the audio so it fills an integer number of generation chunks
            if audio_len < L - first_fixed_frame:
                audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
            elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
                audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
            input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
            with torch.no_grad():
                hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
                audio_embeddings = hidden_states.last_hidden_state
                for mid_hidden_states in hidden_states.hidden_states:
                    audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
            seq_len = audio_len
            audio_embeddings = audio_embeddings.squeeze(0)
            audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
        else:
            audio_embeddings = None
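        # Number of chunks needed to cover seq_len frames, given that every chunk after the
        # first regenerates only L - fixed_frame new frames.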
        # loop
        times = (seq_len - L + first_fixed_frame) // (L - fixed_frame) + 1
        if times * (L - fixed_frame) + fixed_frame < seq_len:
            times += 1
        video = []
        image_emb = {}
        img_lat = None
        if self.args.i2v:
            self.pipe.load_models_to_device(['vae'])
            img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
            msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:, :1])
            image_cat = img_lat.repeat(1, 1, T, 1, 1)
            msk[:, :, 1:] = 1
            image_emb["y"] = torch.cat([image_cat, msk], dim=1)
        for t in range(times):
            print(f"[{t+1}/{times}]")
            audio_emb = {}
            if t == 0:
                overlap = first_fixed_frame
            else:
                overlap = fixed_frame
                image_emb["y"][:, -1:, :prefix_lat_frame] = 0  # on the first pass only one frame (the reference) is masked; on later passes the whole overlap is masked
            prefix_overlap = (3 + overlap) // 4
            if audio_embeddings is not None:
                if t == 0:
                    audio_tensor = audio_embeddings[
                        :min(L - overlap, audio_embeddings.shape[0])
                    ]
                else:
                    audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
                    audio_tensor = audio_embeddings[
                        audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
                    ]
                audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
                audio_prefix = audio_tensor[-fixed_frame:]
                audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
                audio_emb["audio_emb"] = audio_tensor
            else:
                audio_prefix = None
            if image is not None and img_lat is None:
                self.pipe.load_models_to_device(['vae'])
                img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
                assert img_lat.shape[2] == prefix_overlap
            img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1))], dim=2)
            frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
                                                     negative_prompt, num_inference_steps=num_steps,
                                                     cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
                                                     return_latent=True,
                                                     tea_cache_l1_thresh=self.args.tea_cache_l1_thresh, tea_cache_model_id="Wan2.1-T2V-14B")
            img_lat = None
            image = (frames[:, -fixed_frame:].clip(0, 1) * 2 - 1).permute(0, 2, 1, 3, 4).contiguous()
            if t == 0:
                video.append(frames)
            else:
                video.append(frames[:, overlap:])
        video = torch.cat(video, dim=1)
        video = video[:, :ori_audio_len + 1]
        return video
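# Demo driver: generates a single talking-avatar clip from the hard-coded prompt, reference
# image and audio file, then saves the result together with the silence-shifted audio.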
def main():
    # os.makedirs("demo_out/config", exist_ok=True)
    # OmegaConf.save(config=OmegaConf.create(vars(args)),
    #                f="demo_out/config/args_config.yaml")
    # print("Saved merged args to demo_out/config/args_config.yaml")
    set_seed(args.seed)
    # load data
    data_iter = read_from_file(args.input_file)
    exp_name = os.path.basename(args.exp_path)
    seq_len = args.seq_len
    # Text-to-video
    inferpipe = WanInferencePipeline(args)
    output_dir = 'demo_out'
    idx = 0
    text = "A realistic video of a man speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
    image_path = "examples/images/0000.jpeg"
    audio_path = "examples/audios/0000.MP3"
    audio_dir = output_dir + '/audio'
    os.makedirs(audio_dir, exist_ok=True)
    if args.silence_duration_s > 0:
        input_audio_path = os.path.join(audio_dir, f"audio_input_{idx:03d}.wav")
    else:
        input_audio_path = audio_path
    prompt_dir = output_dir + '/prompt'
    os.makedirs(prompt_dir, exist_ok=True)
    if args.silence_duration_s > 0:
        add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
    video = inferpipe(
        prompt=text,
        image_path=image_path,
        audio_path=input_audio_path,
        seq_len=seq_len
    )
    tmp2_audio_path = os.path.join(audio_dir, f"audio_out_{idx:03d}.wav")  # the first frame is the reference frame, so the audio has to be shifted by one frame (1/fps s)
    prompt_path = os.path.join(prompt_dir, f"prompt_{idx:03d}.txt")
    add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
    save_video_as_grid_and_mp4(video,
                               output_dir,
                               args.fps,
                               prompt=text,
                               prompt_path=prompt_path,
                               audio_path=tmp2_audio_path if args.use_audio else None,
                               prefix=f'result_{idx:03d}')
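# Stdout sink used to silence prints on non-zero ranks during multi-GPU inference.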
class NoPrint:
    def write(self, x):
        pass
    def flush(self):
        pass
if __name__ == '__main__':
    if not args.debug:
        if args.local_rank != 0:  # suppress output from all ranks except rank 0
            sys.stdout = NoPrint()
    main()