import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import onnxruntime as ort

from model_demo.inference.audio import AudioStream
from model_demo.inference.constants import (
    N_AUDIO_SAMPLES_PER_VIDEO_FRAME,
    SAMPLE_RATE,
    HEAD_LANDMARK_DIM,
)
from model_demo.inference.landmarks import (
    unscale_and_uncenter_head_angles,
    clean_up_blendshapes,
    exaggerate_head_wiggle,
)


class InferencePipeline:
    """
    Pipeline for running WhisperLike model inference on a video file.

    Adds crossfade functionality to smooth transitions between chunks.
    """

    def __init__(
        self,
        max_chunk_size: int,
        crossfade_size: int,
        batch_size: int,
    ) -> None:
        """
        Initialize the streaming inference pipeline.

        Args:
            max_chunk_size: Maximum number of frames to process in a single chunk
            crossfade_size: Number of frames to use for crossfading between chunks
            batch_size: Batch size for inference
        """
        self.max_chunk_size = max_chunk_size
        self.max_audio_input_size = (
            self.max_chunk_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
        )
        self.crossfade_size = crossfade_size
        self.audio_crossfade_size = crossfade_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
        self.n_feats = HEAD_LANDMARK_DIM
        # Maintain state between chunks
        self.prev_output = np.zeros((batch_size, 0, self.n_feats))
        self.audio_buffer = np.zeros((batch_size, 0))
        # Crossfade buffer stores the overlapping region from the previous chunk
        self.crossfade_buffer = None
        # Pre-compute crossfade weights
        self.crossfade_weights = np.linspace(0, 1, crossfade_size)
        self.crossfade_weights = self.crossfade_weights.reshape(1, -1)

    def apply_crossfade(
        self, current_chunk: np.ndarray, update_crossfade_buffer: bool
    ) -> np.ndarray:
        """Apply crossfade between previous and current chunk predictions."""
        if self.crossfade_buffer is not None:
            # Extract the crossfade region from the current chunk
            current_fade_region = current_chunk[:, : self.crossfade_size]
            # Blend the overlapping regions using the pre-computed weights
            blended_region = np.multiply(
                self.crossfade_buffer, np.expand_dims((1 - self.crossfade_weights), -1)
            ) + np.multiply(
                current_fade_region, np.expand_dims(self.crossfade_weights, -1)
            )
            # Replace the beginning of the current chunk with the blended region
            output = current_chunk.copy()
            output[:, : self.crossfade_size] = blended_region
        else:
            output = current_chunk
        if update_crossfade_buffer:
            self.crossfade_buffer = current_chunk[:, -self.crossfade_size :].copy()
            output = output[:, : -self.crossfade_size]
        return output
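
    # Crossfade blend, for reference: with weights w = linspace(0, 1, K), the
    # first K frames of the current chunk become
    #     (1 - w)[None, :, None] * previous_tail + w[None, :, None] * current_head
    # so frame 0 comes entirely from the previous chunk's tail and frame K - 1
    # entirely from the current chunk, giving a linear ramp across the seam.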

    def model_generate(self, src, max_len, initial_context=None):
        """
        Generate output sequence with optional initial context.

        Args:
            src: Source audio features of shape [B, T_a, D], where T_a is the number of
                audio frames corresponding to max_len video frames
            max_len: Number of frames to generate
            initial_context: Optional previous output context (B, J, D), where J is
                in [1, max_len + 1]

        Returns:
            Predicted landmarks [B, max_len - J, D]
        """
        raise NotImplementedError("Subclasses must implement model_generate.")

    def infer_chunk(self, audio: np.ndarray, new_audio_len: int) -> np.ndarray:
        """Process a single chunk of audio, using previous context if available."""
        n_new_frames = (
            new_audio_len // N_AUDIO_SAMPLES_PER_VIDEO_FRAME + self.crossfade_size
        )
        n_generation_frames = audio.shape[1] // N_AUDIO_SAMPLES_PER_VIDEO_FRAME
        n_context_frames = (n_generation_frames - n_new_frames) + 1
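        # The chunk re-generates `crossfade_size` frames that overlap the tail of
        # the previous output (blended away later in apply_crossfade) in addition
        # to the genuinely new frames; the extra +1 context frame serves as the
        # decoder seed, so model_generate emits n_new_frames frames in total.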
        if n_context_frames > 0:
            initial_context = self.prev_output[:, -n_context_frames:]
        else:
            initial_context = None
        # Generate predictions
        predictions = self.model_generate(audio, n_generation_frames, initial_context)
        self.prev_output = np.concatenate([self.prev_output, predictions], axis=1)[
            :, -self.max_chunk_size :
        ]
        return predictions

    def prepare_input_chunk(self, audio: np.ndarray) -> Tuple[np.ndarray, int]:
        new_audio_len = audio.shape[1]
        self.audio_buffer = np.concatenate([self.audio_buffer, audio], axis=1)[
            :, -self.max_audio_input_size :
        ]
        return self.audio_buffer, new_audio_len

    def process_output_chunk(
        self,
        chunk: np.ndarray,
        update_crossfade_buffer: bool,
        mouth_exaggeration: float,
        brow_exaggeration: float,
        head_wiggle_exaggeration: float,
        unsquinch_fix: float,
        eye_contact_fix: float,
        exaggerate_above: float,
        symmetrize_eyes: bool,
    ) -> np.ndarray:
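        # Feature layout (per infer_audio_array's return shapes): the first 52
        # channels are blendshape coefficients, the remaining channels are the
        # head-angle features.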
        chunk[..., :52] = clean_up_blendshapes(
            chunk[..., :52],
            mouth_exaggeration,
            brow_exaggeration,
            clear_neutral=True,
            unsquinch_fix=unsquinch_fix,
            eye_contact_fix=eye_contact_fix,
            exaggerate_above=exaggerate_above,
            symmetrize_eyes=symmetrize_eyes,
        )
        if head_wiggle_exaggeration != 1.0:
            chunk[..., 52:] = exaggerate_head_wiggle(
                chunk[..., 52:], head_wiggle_exaggeration
            )
        if self.crossfade_size > 0 and chunk.shape[1] > self.crossfade_size:
            chunk = self.apply_crossfade(chunk, update_crossfade_buffer)
        return chunk

    def __call__(
        self,
        audio: np.ndarray,
        audio_stream_can_step: bool,
        mouth_exaggeration: float,
        brow_exaggeration: float,
        head_wiggle_exaggeration: float,
        unsquinch_fix: float,
        eye_contact_fix: float,
        exaggerate_above: float,
        symmetrize_eyes: bool,
    ) -> np.ndarray:
        """
        Run the model on an audio tensor.

        Args:
            audio: Audio tensor of shape (batch_size, n_audio_samples)
            audio_stream_can_step: Whether more audio will follow; controls whether
                the crossfade buffer is updated for the next chunk

        Returns:
            np.ndarray: Model predictions
        """
        input_chunk, new_audio_len = self.prepare_input_chunk(audio)
        output_chunk = self.infer_chunk(input_chunk, new_audio_len)
        return self.process_output_chunk(
            output_chunk,
            update_crossfade_buffer=audio_stream_can_step,
            mouth_exaggeration=mouth_exaggeration,
            brow_exaggeration=brow_exaggeration,
            head_wiggle_exaggeration=head_wiggle_exaggeration,
            unsquinch_fix=unsquinch_fix,
            eye_contact_fix=eye_contact_fix,
            exaggerate_above=exaggerate_above,
            symmetrize_eyes=symmetrize_eyes,
        )

    def reset(self):
        """Reset internal state to match a freshly initialized pipeline."""
        batch_size = self.prev_output.shape[0]
        self.prev_output = np.zeros((batch_size, 0, self.n_feats))
        self.audio_buffer = np.zeros((batch_size, 0))
        self.crossfade_buffer = None

    def infer_audio_array(
        self,
        audio: np.ndarray,
        min_audio_samples_per_step: int,
        max_audio_samples_per_step: int,
        mouth_exaggeration: float = 1.0,
        brow_exaggeration: float = 1.0,
        head_wiggle_exaggeration: float = 1.0,
        unsquinch_fix: float = 0.0,
        eye_contact_fix: float = 0.0,
        exaggerate_above: float = 0.0,
        symmetrize_eyes: bool = False,
        max_audio_duration: Optional[float] = None,
    ) -> Tuple[np.ndarray, np.ndarray, float, float, float]:
        """
        Run the model on an input audio array under simulated streaming conditions.

        Args:
            audio: Numpy array of audio samples
            min_audio_samples_per_step: Minimum number of audio samples per step
            max_audio_samples_per_step: Maximum number of audio samples per step
            max_audio_duration: Maximum duration of audio to process in seconds

        Returns:
            Tuple of:
                - Blendshapes of shape (T, 52)
                - Head angles of shape (T, 3)
                - Mean time per step in seconds
                - Mean real-time factor (audio seconds processed per second of compute)
                - Time to first sound in seconds
        """
        # Reset all buffers
        self.reset()
        # Apply duration limit if specified
        if max_audio_duration is not None:
            max_audio_duration_samples = int(max_audio_duration * SAMPLE_RATE)
            audio_len = min(len(audio), max_audio_duration_samples)
        else:
            audio_len = len(audio)
        audio_stream = AudioStream(
            audio[:audio_len], min_audio_samples_per_step, max_audio_samples_per_step
        )
        # Process each chunk
        outputs = []
        step_times = []
        audio_durations = []
        while audio_stream.can_step:
            audio_chunk = audio_stream.step()
            audio_durations.append(audio_chunk.shape[-1] / SAMPLE_RATE)
            # Process the chunk
            start_time = time.time()
            chunk_output = self(
                np.expand_dims(audio_chunk, 0),
                audio_stream.can_step,
                mouth_exaggeration,
                brow_exaggeration,
                head_wiggle_exaggeration,
                unsquinch_fix,
                eye_contact_fix,
                exaggerate_above,
                symmetrize_eyes,
            )
            step_times.append(time.time() - start_time)
            outputs.append(chunk_output)
        # Concatenate all outputs
        full_output = np.concatenate(outputs, axis=1)
        mean_step_time = sum(step_times) / len(step_times)
        mean_rtf = sum(audio_durations) / sum(step_times)
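        # With this definition, RTF > 1 means the pipeline runs faster than real time.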
        time_to_first_sound = step_times[0] + audio_durations[0]
        blendshapes = full_output.squeeze(0)[:, :52]
        head_angles = unscale_and_uncenter_head_angles(
            full_output.squeeze(0)[:, 52:], bad_frames=[]
        )
        return blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound


@dataclass
class ONNXModels:
    hubert_session: ort.InferenceSession
    encoder_session: ort.InferenceSession
    decoder_session: ort.InferenceSession


class ONNXInferencePipeline(InferencePipeline):
    """
    ONNX version of the inference pipeline.
    """

    def __init__(
        self,
        onnx_models: ONNXModels,
        max_chunk_size: int,
        crossfade_size: int,
        batch_size: int,
    ):
        """
        Initialize the ONNX inference pipeline.

        Args:
            onnx_models: ONNXModels holding the HuBERT, encoder, and decoder sessions
            max_chunk_size: Maximum number of frames to process in a single chunk
            crossfade_size: Number of frames to use for crossfading between chunks
            batch_size: Batch size for inference
        """
        super().__init__(
            max_chunk_size,
            crossfade_size,
            batch_size,
        )
        self.onnx_models = onnx_models

    def model_generate(self, src, max_len, initial_context=None):
        """
        Generate output sequence using ONNX models.
        """
        # Run HuBERT through ONNX
        src_np = src.astype(np.float32)
        hubert_out = self.onnx_models.hubert_session.run(
            None, {"input_values": src_np}
        )[0]
        src = self.onnx_models.encoder_session.run(None, {"src": hubert_out})[0]
        if initial_context is not None:
            decoder_in = initial_context.astype(np.float32)
        else:
            decoder_in = np.zeros((src.shape[0], 1, HEAD_LANDMARK_DIM)).astype(
                np.float32
            )
        outputs = []
        for i in range(max_len - decoder_in.shape[1] + 1):
            # Run decoder step through ONNX
            next_output = self.onnx_models.decoder_session.run(
                None,
                {"src": src.astype(np.float32), "decoder_in": decoder_in},
            )[0]
            decoder_in = np.concatenate([decoder_in, next_output], axis=1)
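            # The full (growing) decoder_in sequence is fed to the decoder graph
            # on every step; no incremental state is carried between run() calls,
            # so per-step cost grows with the length of the generated sequence.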
            outputs.append(next_output)
        pred_out = np.concatenate(outputs, axis=1)
        return pred_out


def init_pipeline(
    hubert_onnx_path: Path,
    encoder_onnx_path: Path,
    decoder_onnx_path: Path,
    device: str = "cpu",
    chunk_size: int = 90,
    crossfade_size: int = 5,
    batch_size: int = 1,
) -> Union[InferencePipeline, ONNXInferencePipeline]:
    """
    Initialize an ONNX inference pipeline from the provided model paths.

    Args:
        hubert_onnx_path: Path to ONNX HuBERT model
        encoder_onnx_path: Path to ONNX encoder model
        decoder_onnx_path: Path to ONNX decoder model
        device: Device to run on
        chunk_size: Maximum number of frames per chunk
        crossfade_size: Number of frames for crossfading
        batch_size: Batch size for inference

    Returns:
        ONNX inference pipeline
    """
    # ONNX pipeline
    providers = (
        ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
    )
    hubert_session = ort.InferenceSession(str(hubert_onnx_path), providers=providers)
    encoder_session = ort.InferenceSession(str(encoder_onnx_path), providers=providers)
    decoder_session = ort.InferenceSession(str(decoder_onnx_path), providers=providers)
    onnx_models = ONNXModels(hubert_session, encoder_session, decoder_session)
    return ONNXInferencePipeline(onnx_models, chunk_size, crossfade_size, batch_size)
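

# Minimal usage sketch (not part of the library proper). The model paths below are
# placeholders for wherever the ONNX files were exported, and the synthetic sine
# wave simply stands in for real speech audio.
if __name__ == "__main__":
    model_dir = Path("models")  # hypothetical export directory
    pipeline = init_pipeline(
        hubert_onnx_path=model_dir / "hubert.onnx",
        encoder_onnx_path=model_dir / "encoder.onnx",
        decoder_onnx_path=model_dir / "decoder.onnx",
        device="cpu",
    )
    # One second of synthetic audio at the model sample rate.
    t = np.linspace(0, 1, SAMPLE_RATE, endpoint=False)
    demo_audio = (0.1 * np.sin(2 * np.pi * 220.0 * t)).astype(np.float32)
    blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound = (
        pipeline.infer_audio_array(
            demo_audio,
            min_audio_samples_per_step=SAMPLE_RATE // 10,
            max_audio_samples_per_step=SAMPLE_RATE // 2,
        )
    )
    print(
        blendshapes.shape,
        head_angles.shape,
        mean_step_time,
        mean_rtf,
        time_to_first_sound,
    )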