import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import onnxruntime as ort

from model_demo.inference.audio import AudioStream
from model_demo.inference.constants import (
    N_AUDIO_SAMPLES_PER_VIDEO_FRAME,
    SAMPLE_RATE,
    HEAD_LANDMARK_DIM,
)
from model_demo.inference.landmarks import (
    unscale_and_uncenter_head_angles,
    clean_up_blendshapes,
    exaggerate_head_wiggle,
)


class InferencePipeline:
"""
Pipeline for running WhisperLike model inference on a video file.
Added crossfade functionality to smooth transitions between chunks.
"""
def __init__(
self,
max_chunk_size: int,
crossfade_size: int,
batch_size: int,
) -> None:
"""
Initialize streaming inference pipeline.
Args:
max_chunk_size: Maximum number of frames to process in a single chunk
crossfade_size: Number of frames to use for crossfading between chunks
batch_size: Batch size for inference
device: Device to run on
"""
self.max_chunk_size = max_chunk_size
self.max_audio_input_size = (
self.max_chunk_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
)
self.crossfade_size = crossfade_size
self.audio_crossfade_size = crossfade_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
self.n_feats = HEAD_LANDMARK_DIM
# Maintain state between chunks
self.prev_output = np.zeros((batch_size, 0, self.n_feats))
self.audio_buffer = np.zeros((batch_size, 0))
# Crossfade buffer stores the overlapping region from the previous chunk
self.crossfade_buffer = None
# Pre-compute crossfade weights
self.crossfade_weights = np.linspace(0, 1, crossfade_size)
self.crossfade_weights = self.crossfade_weights.reshape(1, -1)
def apply_crossfade(
self, current_chunk: np.ndarray, update_crossfade_buffer: bool
) -> np.ndarray:
"""Apply crossfade between previous and current chunk predictions."""
if self.crossfade_buffer is not None:
# Extract the crossfade region from the current chunk
current_fade_region = current_chunk[:, : self.crossfade_size]
# Blend the overlapping regions using the pre-computed weights
blended_region = np.multiply(
self.crossfade_buffer, np.expand_dims((1 - self.crossfade_weights), -1)
) + np.multiply(
current_fade_region, np.expand_dims(self.crossfade_weights, -1)
)
# Replace the beginning of the current chunk with the blended region
output = current_chunk.copy()
output[:, : self.crossfade_size] = blended_region
else:
output = current_chunk
if update_crossfade_buffer:
self.crossfade_buffer = current_chunk[:, -self.crossfade_size :].copy()
output = output[:, : -self.crossfade_size]
return output
    def model_generate(self, src, max_len, initial_context=None):
        """
        Generate the output sequence with optional initial context.

        Args:
            src: Source audio for the chunk, of shape [B, T_a, D], where T_a is the
                number of audio frames corresponding to max_len video frames
            max_len: Number of frames to generate
            initial_context: Optional previous output context (B, J, D), where J is
                in [1, max_len + 1]

        Returns:
            Predicted landmarks [B, max_len - J + 1, D]
        """
        raise NotImplementedError("Subclasses must implement model_generate")
def infer_chunk(self, audio: np.ndarray, new_audio_len: int) -> np.ndarray:
"""Process a single chunk of audio, using previous context if available."""
n_new_frames = (
new_audio_len // N_AUDIO_SAMPLES_PER_VIDEO_FRAME + self.crossfade_size
)
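        # n_new_frames includes the crossfade overlap: those frames were held
        # back from the previous chunk's output and are regenerated here so
        # apply_crossfade() can blend them with the buffered tail.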
n_generation_frames = audio.shape[1] // N_AUDIO_SAMPLES_PER_VIDEO_FRAME
n_context_frames = (n_generation_frames - n_new_frames) + 1
if n_context_frames > 0:
initial_context = self.prev_output[:, -n_context_frames:]
else:
initial_context = None
# Generate predictions
predictions = self.model_generate(audio, n_generation_frames, initial_context)
self.prev_output = np.concatenate([self.prev_output, predictions], axis=1)[
:, -self.max_chunk_size :
]
return predictions
    def prepare_input_chunk(self, audio: np.ndarray) -> Tuple[np.ndarray, int]:
        """Append new audio to the rolling buffer and return it with the new sample count."""
        new_audio_len = audio.shape[1]
        self.audio_buffer = np.concatenate([self.audio_buffer, audio], axis=1)[
            :, -self.max_audio_input_size :
        ]
        return self.audio_buffer, new_audio_len
def process_output_chunk(
self,
chunk: np.ndarray,
update_crossfade_buffer: bool,
mouth_exaggeration: float,
brow_exaggeration: float,
head_wiggle_exaggeration: float,
unsquinch_fix: float,
eye_contact_fix: float,
exaggerate_above: float,
symmetrize_eyes: bool,
) -> np.ndarray:
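        """Clean up blendshapes, exaggerate head motion, and crossfade with the previous chunk."""
        # The first 52 features are blendshape coefficients; the remaining
        # features are head angles (see unscale_and_uncenter_head_angles).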
chunk[..., :52] = clean_up_blendshapes(
chunk[..., :52],
mouth_exaggeration,
brow_exaggeration,
clear_neutral=True,
unsquinch_fix=unsquinch_fix,
eye_contact_fix=eye_contact_fix,
exaggerate_above=exaggerate_above,
symmetrize_eyes=symmetrize_eyes,
)
if head_wiggle_exaggeration != 1.0:
chunk[..., 52:] = exaggerate_head_wiggle(
chunk[..., 52:], head_wiggle_exaggeration
)
if self.crossfade_size > 0 and chunk.shape[1] > self.crossfade_size:
chunk = self.apply_crossfade(chunk, update_crossfade_buffer)
return chunk
def __call__(
self,
audio: np.ndarray,
audio_stream_can_step: bool,
mouth_exaggeration: float,
brow_exaggeration: float,
head_wiggle_exaggeration: float,
unsquinch_fix: float,
eye_contact_fix: float,
exaggerate_above: float,
symmetrize_eyes: bool,
) -> np.ndarray:
"""
Run the model on an audio tensor.
Args:
audio: Audio tensor of shape (batch_size, n_audio_samples)
Returns:
np.ndarray: Model predictions
"""
input_chunk, new_audio_len = self.prepare_input_chunk(audio)
output_chunk = self.infer_chunk(input_chunk, new_audio_len)
return self.process_output_chunk(
output_chunk,
update_crossfade_buffer=audio_stream_can_step,
mouth_exaggeration=mouth_exaggeration,
brow_exaggeration=brow_exaggeration,
head_wiggle_exaggeration=head_wiggle_exaggeration,
unsquinch_fix=unsquinch_fix,
eye_contact_fix=eye_contact_fix,
exaggerate_above=exaggerate_above,
symmetrize_eyes=symmetrize_eyes,
)
def reset(self):
"""Reset internal state"""
self.prev_output = np.zeros_like(self.prev_output)
self.audio_buffer = np.zeros_like(self.audio_buffer)
self.crossfade_buffer = None
def infer_audio_array(
self,
audio: np.ndarray,
min_audio_samples_per_step: int,
max_audio_samples_per_step: int,
mouth_exaggeration: float = 1.0,
brow_exaggeration: float = 1.0,
head_wiggle_exaggeration: float = 1.0,
unsquinch_fix: float = 0.0,
eye_contact_fix: float = 0.0,
exaggerate_above: float = 0.0,
symmetrize_eyes: bool = False,
max_audio_duration: Optional[float] = None,
    ) -> Tuple[np.ndarray, np.ndarray, float, float, float]:
        """
        Run the model on an audio array under simulated streaming conditions.

        Args:
            audio: Numpy array of audio samples
            min_audio_samples_per_step: Minimum number of audio samples per step
            max_audio_samples_per_step: Maximum number of audio samples per step
            max_audio_duration: Maximum duration of audio to process in seconds

        Returns:
            Tuple of:
                - Blendshapes of shape (T, 52)
                - Head angles of shape (T, 3)
                - Mean time per step in seconds
                - Mean real-time factor
                - Time to first sound in seconds
        """
# Reset all buffers
self.reset()
# Apply duration limit if specified
        if max_audio_duration is not None:
            max_audio_samples = int(max_audio_duration * SAMPLE_RATE)
            audio_len = min(len(audio), max_audio_samples)
else:
audio_len = len(audio)
audio_stream = AudioStream(
audio[:audio_len], min_audio_samples_per_step, max_audio_samples_per_step
)
# Process each chunk
outputs = []
step_times = []
audio_durations = []
while audio_stream.can_step:
audio_chunk = audio_stream.step()
audio_durations.append(audio_chunk.shape[-1] / SAMPLE_RATE)
# Process the chunk
start_time = time.time()
chunk_output = self(
np.expand_dims(audio_chunk, 0),
audio_stream.can_step,
mouth_exaggeration,
brow_exaggeration,
head_wiggle_exaggeration,
unsquinch_fix,
eye_contact_fix,
exaggerate_above,
symmetrize_eyes,
)
step_times.append(time.time() - start_time)
outputs.append(chunk_output)
# Concatenate all outputs
full_output = np.concatenate(outputs, axis=1)
mean_step_time = sum(step_times) / len(step_times)
mean_rtf = sum(audio_durations) / sum(step_times)
time_to_first_sound = step_times[0] + audio_durations[0]
blendshapes = full_output.squeeze(0)[:, :52]
head_angles = unscale_and_uncenter_head_angles(
full_output.squeeze(0)[:, 52:], bad_frames=[]
)
return blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound
@dataclass
class ONNXModels:
hubert_session: ort.InferenceSession
encoder_session: ort.InferenceSession
decoder_session: ort.InferenceSession
class ONNXInferencePipeline(InferencePipeline):
"""
ONNX version of the inference pipeline.
"""
def __init__(
self,
onnx_models: ONNXModels,
max_chunk_size: int,
crossfade_size: int,
batch_size: int,
):
"""
Initialize ONNX inference pipeline.
Args:
onnx_models: ONNXModels containing hubert and decoder sessions
max_chunk_size: Maximum number of frames to process in a single chunk
crossfade_size: Number of frames to use for crossfading between chunks
batch_size: Batch size for inference
device: Device to run inference on
"""
super().__init__(
max_chunk_size,
crossfade_size,
batch_size,
)
self.onnx_models = onnx_models
def model_generate(self, src, max_len, initial_context=None):
"""
Generate output sequence using ONNX models.
"""
# Run HuBERT through ONNX
src_np = src.astype(np.float32)
hubert_out = self.onnx_models.hubert_session.run(
None, {"input_values": src_np}
)[0]
src = self.onnx_models.encoder_session.run(None, {"src": hubert_out})[0]
if initial_context is not None:
decoder_in = initial_context.astype(np.float32)
else:
decoder_in = np.zeros((src.shape[0], 1, HEAD_LANDMARK_DIM)).astype(
np.float32
)
outputs = []
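        # Autoregressive decoding: each step predicts the next frame, which is
        # appended to decoder_in and fed back on the following step.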
for i in range(max_len - decoder_in.shape[1] + 1):
# Run decoder step through ONNX
next_output = self.onnx_models.decoder_session.run(
None,
{"src": src.astype(np.float32), "decoder_in": decoder_in},
)[0]
decoder_in = np.concatenate([decoder_in, next_output], axis=1)
outputs.append(next_output)
pred_out = np.concatenate(outputs, axis=1)
return pred_out
def init_pipeline(
hubert_onnx_path: Path,
encoder_onnx_path: Path,
decoder_onnx_path: Path,
device: str = "cpu",
chunk_size: int = 90,
crossfade_size: int = 5,
batch_size: int = 1,
) -> Union[InferencePipeline, ONNXInferencePipeline]:
"""
Initialize ONNX inference pipeline based on provided paths.
Args:
hubert_onnx_path: Path to ONNX HuBERT model
decoder_onnx_path: Path to ONNX decoder model
chunk_size: Maximum number of frames per chunk
crossfade_size: Number of frames for crossfading
batch_size: Batch size for inference
device: Device to run on
Returns:
ONNX inference pipeline
"""
# ONNX pipeline
providers = (
["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
)
hubert_session = ort.InferenceSession(str(hubert_onnx_path), providers=providers)
encoder_session = ort.InferenceSession(str(encoder_onnx_path), providers=providers)
decoder_session = ort.InferenceSession(str(decoder_onnx_path), providers=providers)
onnx_models = ONNXModels(hubert_session, encoder_session, decoder_session)
return ONNXInferencePipeline(onnx_models, chunk_size, crossfade_size, batch_size)
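

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The ONNX paths
# below are hypothetical placeholders and the audio is synthetic silence;
# substitute real exported models and mono samples at SAMPLE_RATE to produce
# meaningful output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline = init_pipeline(
        hubert_onnx_path=Path("models/hubert.onnx"),    # hypothetical path
        encoder_onnx_path=Path("models/encoder.onnx"),  # hypothetical path
        decoder_onnx_path=Path("models/decoder.onnx"),  # hypothetical path
        device="cpu",
    )
    # One second of silence at the model sample rate, just to exercise the loop.
    dummy_audio = np.zeros(SAMPLE_RATE, dtype=np.float32)
    blendshapes, head_angles, mean_step_time, mean_rtf, ttfs = pipeline.infer_audio_array(
        dummy_audio,
        min_audio_samples_per_step=SAMPLE_RATE // 10,
        max_audio_samples_per_step=SAMPLE_RATE // 2,
    )
    print(f"blendshapes: {blendshapes.shape}, head angles: {head_angles.shape}")
    print(f"mean step time: {mean_step_time:.3f}s, RTF: {mean_rtf:.2f}, TTFS: {ttfs:.3f}s")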