# OmniAvatar / app.py
import spaces
import subprocess
import gradio as gr
import os, sys
from glob import glob
from datetime import datetime
import math
import random
import librosa
import numpy as np
import uuid
import shutil
import importlib, site
from PIL import Image
from huggingface_hub import hf_hub_download, snapshot_download
# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
site.addsitedir(sitedir)
# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()
def sh(cmd): subprocess.check_call(cmd, shell=True)
flash_attention_wheel = hf_hub_download(
repo_id="alexnasa/flash-attn-3",
repo_type="model",
filename="flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
)
sh(f"pip install {flash_attention_wheel}")
# Re-scan site-packages so the freshly installed flash-attn wheel is importable
site.addsitedir(site.getsitepackages()[0])
importlib.invalidate_caches()
import torch
print(f'torch version:{torch.__version__}')
import torch.nn as nn
from tqdm import tqdm
from functools import partial
from omegaconf import OmegaConf
from argparse import Namespace
# Load the dumped run configuration and expose it as an argparse-style Namespace
_args_cfg = OmegaConf.load("args_config.yaml")
args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
from OmniAvatar.utils.args_config import set_global_args
set_global_args(args)
# args = parse_args()
from OmniAvatar.utils.io_utils import load_state_dict
from peft import LoraConfig, inject_adapter_in_model
from OmniAvatar.models.model_manager import ModelManager
from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
from OmniAvatar.wan_video import WanVideoPipeline
from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
import torchvision.transforms as TT
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
import torch.nn.functional as F
from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
def tensor_to_pil(tensor):
"""
Args:
tensor: torch.Tensor with shape like
(1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
values in [-1, 1], on any device.
Returns:
A PIL.Image in RGB mode.
"""
# 1) Remove batch dim if it exists
if tensor.dim() > 3 and tensor.shape[0] == 1:
tensor = tensor[0]
# 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
tensor = tensor.squeeze()
# Now we should have exactly 3 dims: (C, H, W)
if tensor.dim() != 3:
raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
# 3) Move to CPU float32
tensor = tensor.cpu().float()
# 4) Undo normalization from [-1,1] -> [0,1]
tensor = (tensor + 1.0) / 2.0
# 5) Clamp to [0,1]
tensor = torch.clamp(tensor, 0.0, 1.0)
# 6) To NumPy H×W×C in [0,255]
np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
# 7) Build PIL Image
return Image.fromarray(np_img)
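# Illustrative usage (hypothetical shapes): a single generated frame of shape
# (1, 3, 1, 720, 720) with values in [-1, 1] becomes a 720x720 RGB PIL image:
#   frame = torch.rand(1, 3, 1, 720, 720) * 2 - 1
#   pil = tensor_to_pil(frame)  # PIL.Image.Image, mode "RGB", size (720, 720)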
def set_seed(seed: int = 42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)      # seed the current GPU
    torch.cuda.manual_seed_all(seed)  # seed all GPUs
def read_from_file(p):
with open(p, "r") as fin:
for l in fin:
yield l.strip()
def match_size(image_size, h, w):
ratio_ = 9999
size_ = 9999
select_size = None
for image_s in image_size:
ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
size_tmp = abs(max(image_s) - max(w, h))
        if ratio_tmp < ratio_:
            ratio_ = ratio_tmp
            size_ = size_tmp
            select_size = image_s
        elif ratio_tmp == ratio_ and size_tmp == size_:
            select_size = image_s
return select_size
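# Illustrative note: match_size() returns, from the configured candidate resolutions
# (args.image_sizes_<max_hw> in args_config.yaml), the (h, w) pair whose aspect ratio is
# closest to the input's h/w. E.g. with hypothetical candidates [(720, 720), (720, 1280),
# (1280, 720)], a 1080x1920 (h x w) landscape image has h/w = 0.5625, which matches
# 720/1280 exactly, so (720, 1280) is selected.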
def resize_pad(image, ori_size, tgt_size):
h, w = ori_size
scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
scale_h = int(h * scale_ratio)
scale_w = int(w * scale_ratio)
image = transforms.Resize(size=[scale_h, scale_w])(image)
padding_h = tgt_size[0] - scale_h
padding_w = tgt_size[1] - scale_w
pad_top = padding_h // 2
pad_bottom = padding_h - pad_top
pad_left = padding_w // 2
pad_right = padding_w - pad_left
image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
return image
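# Illustrative note: despite its name, resize_pad() scales by the *max* ratio (cover), so the
# computed padding is never positive and F.pad effectively center-crops. E.g. (hypothetical
# numbers) a 480x640 image with target (720, 720): scale_ratio = max(720/480, 720/640) = 1.5
# -> resized to 720x960, padding_w = 720 - 960 = -240, i.e. 120 columns are cropped from each
# side, giving exactly 720x720.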
class WanInferencePipeline(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
        self.device = torch.device("cuda")
        self.dtype = torch.bfloat16
        self.pipe = self.load_model()
        self.transform = TT.Compose([TT.ToTensor()])
if self.args.use_audio:
from OmniAvatar.models.wav2vec import Wav2VecModel
self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
self.args.wav2vec_path
)
self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
self.audio_encoder.feature_extractor._freeze_parameters()
def load_model(self):
ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
if self.args.train_architecture == 'lora':
self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
else:
resume_path = ckpt_path
self.step = 0
# Load models
model_manager = ModelManager(device="cuda", infer=True)
model_manager.load_models(
[
self.args.dit_path.split(","),
self.args.vae_path,
self.args.text_encoder_path
],
torch_dtype=self.dtype,
device='cuda',
)
pipe = WanVideoPipeline.from_model_manager(model_manager,
torch_dtype=self.dtype,
device="cuda",
use_usp=False,
infer=True)
if self.args.train_architecture == "lora":
print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
self.add_lora_to_model(
pipe.denoising_model(),
lora_rank=self.args.lora_rank,
lora_alpha=self.args.lora_alpha,
lora_target_modules=self.args.lora_target_modules,
init_lora_weights=self.args.init_lora_weights,
pretrained_lora_path=pretrained_lora_path,
)
print(next(pipe.denoising_model().parameters()).device)
else:
missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
pipe.requires_grad_(False)
pipe.eval()
# pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
return pipe
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
# Add LoRA to UNet
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","),
)
model = inject_adapter_in_model(lora_config, model)
        # Load pretrained LoRA weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
def get_times(self, prompt,
image_path=None,
audio_path=None,
                seq_len=101,  # ignored when audio_path is provided
height=720,
width=720,
overlap_frame=None,
num_steps=None,
negative_prompt=None,
guidance_scale=None,
audio_scale=None):
overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
num_steps = num_steps if num_steps is not None else self.args.num_steps
negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
if image_path is not None:
from PIL import Image
image = Image.open(image_path).convert("RGB")
image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
_, _, h, w = image.shape
            select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
image = resize_pad(image, (h, w), select_size)
image = image * 2.0 - 1.0
image = image[:, :, None]
else:
image = None
select_size = [height, width]
num = self.args.max_tokens * 16 * 16 * 4
den = select_size[0] * select_size[1]
L0 = num // den
diff = (L0 - 1) % 4
L = L0 - diff
if L < 1:
L = 1
T = (L + 3) // 4
if self.args.random_prefix_frames:
fixed_frame = overlap_frame
assert fixed_frame % 4 == 1
else:
fixed_frame = 1
prefix_lat_frame = (3 + fixed_frame) // 4
first_fixed_frame = 1
        audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
input_values = np.squeeze(
self.wav_feature_extractor(audio, sampling_rate=16000).input_values
)
input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
if audio_len < L - first_fixed_frame:
audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
seq_len = audio_len
times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
if times * (L-fixed_frame) + fixed_frame < seq_len:
times += 1
return times
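    # Note: get_times() mirrors the frame/chunk arithmetic of forward() without running the
    # model or touching the GPU; it only estimates how many sliding-window denoising passes
    # an input needs so that get_duration() below can request the right ZeroGPU budget.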
@torch.no_grad()
def forward(self, prompt,
image_path=None,
audio_path=None,
                seq_len=101,  # ignored when audio_path is provided
height=720,
width=720,
overlap_frame=None,
num_steps=None,
negative_prompt=None,
guidance_scale=None,
audio_scale=None):
overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
num_steps = num_steps if num_steps is not None else self.args.num_steps
negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
if image_path is not None:
from PIL import Image
image = Image.open(image_path).convert("RGB")
image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
_, _, h, w = image.shape
select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
image = resize_pad(image, (h, w), select_size)
image = image * 2.0 - 1.0
image = image[:, :, None]
else:
image = None
select_size = [height, width]
# L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
# L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
# T = (L + 3) // 4 # latent frames
# step 1: numerator and denominator as ints
        num = self.args.max_tokens * 16 * 16 * 4
den = select_size[0] * select_size[1]
# step 2: integer division
L0 = num // den # exact floor division, no float in sight
# step 3: make it ≡ 1 mod 4
# if L0 % 4 == 1, keep L0;
# otherwise subtract the difference so that (L0 - diff) % 4 == 1,
# but ensure the result stays positive.
diff = (L0 - 1) % 4
L = L0 - diff
if L < 1:
            L = 1  # enforce at least one video frame
# step 4: latent frames
T = (L + 3) // 4
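        # Worked example (hypothetical config values): with max_tokens = 30000 and a selected
        # size of 720x720, L0 = 30000*16*16*4 // (720*720) = 59, diff = (59 - 1) % 4 = 2,
        # so L = 57 video frames (57 % 4 == 1) and T = (57 + 3) // 4 = 15 latent frames.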
if self.args.i2v:
if self.args.random_prefix_frames:
fixed_frame = overlap_frame
assert fixed_frame % 4 == 1
else:
fixed_frame = 1
prefix_lat_frame = (3 + fixed_frame) // 4
first_fixed_frame = 1
else:
fixed_frame = 0
prefix_lat_frame = 0
first_fixed_frame = 0
if audio_path is not None and self.args.use_audio:
audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
input_values = np.squeeze(
self.wav_feature_extractor(audio, sampling_rate=16000).input_values
)
input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
input_values = input_values.unsqueeze(0)
# padding audio
if audio_len < L - first_fixed_frame:
audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
with torch.no_grad():
hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
audio_embeddings = hidden_states.last_hidden_state
for mid_hidden_states in hidden_states.hidden_states:
audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
seq_len = audio_len
audio_embeddings = audio_embeddings.squeeze(0)
audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
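            # The audio conditioning appears to concatenate Wav2Vec2's last_hidden_state with all
            # intermediate hidden states along the feature axis, one vector per target video frame
            # (the encoder is queried with seq_len=audio_len). audio_prefix is a zero placeholder
            # prepended for the fixed reference frame of the first chunk.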
else:
audio_embeddings = None
# loop
times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
if times * (L-fixed_frame) + fixed_frame < seq_len:
times += 1
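        # Chunking sketch (hypothetical numbers): with L = 57, fixed_frame = 13 and
        # first_fixed_frame = 1, the first chunk contributes L - first_fixed_frame = 56 new
        # frames and every later chunk advances by L - fixed_frame = 44, so a padded
        # seq_len of 188 frames needs times = (188 - 57 + 1) // 44 + 1 = 4 denoising passes.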
video = []
image_emb = {}
img_lat = None
if self.args.i2v:
self.pipe.load_models_to_device(['vae'])
img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
image_cat = img_lat.repeat(1, 1, T, 1, 1)
msk[:, :, 1:] = 1
image_emb["y"] = torch.cat([image_cat, msk], dim=1)
for t in range(times):
print(f"[{t+1}/{times}]")
audio_emb = {}
if t == 0:
overlap = first_fixed_frame
else:
overlap = fixed_frame
image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
prefix_overlap = (3 + overlap) // 4
if audio_embeddings is not None:
if t == 0:
audio_tensor = audio_embeddings[
:min(L - overlap, audio_embeddings.shape[0])
]
else:
audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
audio_tensor = audio_embeddings[
audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
]
audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
audio_prefix = audio_tensor[-fixed_frame:]
audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
audio_emb["audio_emb"] = audio_tensor
else:
audio_prefix = None
if image is not None and img_lat is None:
self.pipe.load_models_to_device(['vae'])
img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
assert img_lat.shape[2] == prefix_overlap
img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
negative_prompt, num_inference_steps=num_steps,
cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
return_latent=True,
tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
torch.cuda.empty_cache()
img_lat = None
image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
if t == 0:
video.append(frames)
else:
video.append(frames[:, overlap:])
video = torch.cat(video, dim=1)
video = video[:, :ori_audio_len + 1]
return video
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
# snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
# snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
# snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
import tempfile
from PIL import Image
set_seed(args.seed)
seq_len = args.seq_len
inferpipe = WanInferencePipeline(args)
def update_generate_button(image_path, audio_path, text, num_steps):
if image_path is None or audio_path is None:
return gr.update(value="⌚ Zero GPU Required: --")
duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s")
def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
audio_chunks = inferpipe.get_times(
prompt=text,
image_path=image_path,
audio_path=audio_path,
seq_len=args.seq_len,
num_steps=num_steps
)
warmup_s = 30
duration_s = (20 * num_steps) + warmup_s
if audio_chunks > 1:
duration_s = (20 * num_steps * audio_chunks) + warmup_s
    print(f'{audio_chunks} audio chunk(s): estimated {duration_s}s of GPU time')
return int(duration_s)
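# Illustrative estimate from the heuristic above: roughly 20 s of ZeroGPU time per denoising
# step per audio chunk plus a 30 s warmup, e.g. 8 steps over 2 chunks -> 20 * 8 * 2 + 30 = 350 s.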
def preprocess_img(image_path, session_id = None):
if session_id is None:
session_id = uuid.uuid4().hex
image = Image.open(image_path).convert("RGB")
image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
_, _, h, w = image.shape
    select_size = match_size(getattr(args, f'image_sizes_{args.max_hw}'), h, w)
image = resize_pad(image, (h, w), select_size)
image = image * 2.0 - 1.0
image = image[:, :, None]
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
img_dir = output_dir + '/image'
os.makedirs(img_dir, exist_ok=True)
    input_img_path = os.path.join(img_dir, "img_input.jpg")
image = tensor_to_pil(image)
image.save(input_img_path)
return input_img_path
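# Note: preprocess_img() saves the resized/cropped reference frame and returns its path; the
# upload handler below routes that path back into image_input, so the preview shown in the UI
# matches the frame the model will actually be conditioned on.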
@spaces.GPU(duration=get_duration)
def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
if session_id is None:
session_id = uuid.uuid4().hex
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
audio_dir = output_dir + '/audio'
os.makedirs(audio_dir, exist_ok=True)
if args.silence_duration_s > 0:
        input_audio_path = os.path.join(audio_dir, "audio_input.wav")
else:
input_audio_path = audio_path
prompt_dir = output_dir + '/prompt'
os.makedirs(prompt_dir, exist_ok=True)
if args.silence_duration_s > 0:
add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
    tmp2_audio_path = os.path.join(audio_dir, "audio_out.wav")
    prompt_path = os.path.join(prompt_dir, "prompt.txt")
video = inferpipe(
prompt=text,
image_path=image_path,
audio_path=input_audio_path,
seq_len=args.seq_len,
num_steps=num_steps
)
torch.cuda.empty_cache()
add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
video_paths = save_video_as_grid_and_mp4(video,
output_dir,
args.fps,
prompt=text,
prompt_path = prompt_path,
audio_path=tmp2_audio_path if args.use_audio else None,
prefix=f'result')
return video_paths[0]
def cleanup(request: gr.Request):
sid = request.session_hash
if sid:
d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
shutil.rmtree(d1, ignore_errors=True)
def start_session(request: gr.Request):
return request.session_hash
css = """
#col-container {
margin: 0 auto;
max-width: 1560px;
}
"""
theme = gr.themes.Ocean()
with gr.Blocks(css=css, theme=theme) as demo:
session_state = gr.State()
demo.load(start_session, outputs=[session_state])
with gr.Column(elem_id="col-container"):
gr.HTML(
"""
<div style="text-align: left;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
</p>
<a href="https://github.com/Omni-Avatar/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
</a>
</div>
<div style="text-align: left;">
            HF Space by: <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
            <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="Follow me on Twitter">
</a>
<a href="https://huggingface.co/alexnasa">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
"""
)
with gr.Row():
with gr.Column():
image_input = gr.Image(label="Reference Image", type="filepath", height=512)
audio_input = gr.Audio(label="Input Audio", type="filepath")
with gr.Column():
output_video = gr.Video(label="Avatar", height=512)
num_steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
                time_required = gr.Text(value="⌚ ZeroGPU Time Required: --", show_label=False)
infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
                    text_input = gr.Textbox(label="Video Prompt", lines=6, value="A realistic video of a man speaking and sometimes looking directly at the camera, moving his eyes, pupils and head accordingly; he shakes his head in disappointment and tells the viewer to look straight into the camera, with dynamic, rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
with gr.Column():
examples = gr.Examples(
examples=[
[
"examples/images/male-001.png",
"examples/audios/denial.wav",
"A realistic video of a man speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and he shakes his head in disappointment and tell look stright into the camera , with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
12
],
[
"examples/images/female-001.png",
"examples/audios/script.wav",
"A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
14
],
[
"examples/images/female-002.png",
"examples/audios/nature.wav",
"A realistic video of a woman speaking and sometimes looking directly to the camera and moving her eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on her movements, standing in the woods, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
10
],
],
inputs=[image_input, audio_input, text_input, num_steps],
outputs=[output_video],
fn=infer,
cache_examples=True
)
infer_btn.click(
fn=infer,
inputs=[image_input, audio_input, text_input, num_steps, session_state],
outputs=[output_video]
)
image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input]).then(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
audio_input.upload(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
if __name__ == "__main__":
demo.unload(cleanup)
demo.queue()
demo.launch(ssr_mode=False)