import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import onnxruntime as ort

from model_demo.inference.audio import AudioStream
from model_demo.inference.constants import (
    N_AUDIO_SAMPLES_PER_VIDEO_FRAME,
    SAMPLE_RATE,
    HEAD_LANDMARK_DIM,
)
from model_demo.inference.landmarks import (
    unscale_and_uncenter_head_angles,
    clean_up_blendshapes,
    exaggerate_head_wiggle,
)


class InferencePipeline:
    """
    Pipeline for running WhisperLike model inference on streamed audio,
    with crossfading to smooth transitions between chunks.
    """

    def __init__(
        self,
        max_chunk_size: int,
        crossfade_size: int,
        batch_size: int,
    ) -> None:
        """
        Initialize streaming inference pipeline.

        Args:
            max_chunk_size: Maximum number of frames to process in a single chunk
            crossfade_size: Number of frames to use for crossfading between chunks
            batch_size: Batch size for inference
        """
        self.max_chunk_size = max_chunk_size
        self.max_audio_input_size = (
            self.max_chunk_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
        )
        self.crossfade_size = crossfade_size
        self.audio_crossfade_size = crossfade_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
        self.n_feats = HEAD_LANDMARK_DIM

        # Maintain state between chunks
        self.prev_output = np.zeros((batch_size, 0, self.n_feats))
        self.audio_buffer = np.zeros((batch_size, 0))

        # Crossfade buffer stores the overlapping region from the previous chunk
        self.crossfade_buffer = None

        # Pre-compute crossfade weights of shape (1, crossfade_size)
        self.crossfade_weights = np.linspace(0, 1, crossfade_size)
        self.crossfade_weights = self.crossfade_weights.reshape(1, -1)

    def apply_crossfade(
        self, current_chunk: np.ndarray, update_crossfade_buffer: bool
    ) -> np.ndarray:
        """Apply crossfade between previous and current chunk predictions."""
        if self.crossfade_buffer is not None:
            # Extract the crossfade region from the current chunk
            current_fade_region = current_chunk[:, : self.crossfade_size]

            # Blend the overlapping regions using the pre-computed weights:
            # blended[t] = (1 - w[t]) * previous[t] + w[t] * current[t]
            blended_region = np.multiply(
                self.crossfade_buffer, np.expand_dims((1 - self.crossfade_weights), -1)
            ) + np.multiply(
                current_fade_region, np.expand_dims(self.crossfade_weights, -1)
            )

            # Replace the beginning of the current chunk with the blended region
            output = current_chunk.copy()
            output[:, : self.crossfade_size] = blended_region
        else:
            output = current_chunk

        if update_crossfade_buffer:
            self.crossfade_buffer = current_chunk[:, -self.crossfade_size :].copy()

        # Hold back the trailing crossfade region; it is blended into the next chunk
        output = output[:, : -self.crossfade_size]
        return output

    def model_generate(self, src, max_len, initial_context=None):
        """
        Generate output sequence with optional initial context.

        Args:
            src: Source audio features of shape [B, T_a, D], where T_a is the
                number of audio frames corresponding to max_len video frames
            max_len: Number of frames to generate
            initial_context: Optional previous output context (B, J, D),
                where J is in [1, max_len + 1]

        Returns:
            Predicted landmarks [B, max_len - J, D]
        """
        raise NotImplementedError

    def infer_chunk(self, audio: np.ndarray, new_audio_len: int) -> np.ndarray:
        """Process a single chunk of audio, using previous context if available."""
        n_new_frames = (
            new_audio_len // N_AUDIO_SAMPLES_PER_VIDEO_FRAME + self.crossfade_size
        )
        n_generation_frames = audio.shape[1] // N_AUDIO_SAMPLES_PER_VIDEO_FRAME
        n_context_frames = (n_generation_frames - n_new_frames) + 1
        if n_context_frames > 0:
            initial_context = self.prev_output[:, -n_context_frames:]
        else:
            initial_context = None

        # Generate predictions
        predictions = self.model_generate(audio, n_generation_frames, initial_context)

        self.prev_output = np.concatenate([self.prev_output, predictions], axis=1)[
            :, -self.max_chunk_size :
        ]
        return predictions

    def prepare_input_chunk(self, audio: np.ndarray) -> Tuple[np.ndarray, int]:
        new_audio_len = audio.shape[1]
        self.audio_buffer = np.concatenate([self.audio_buffer, audio], axis=1)[
            :, -self.max_audio_input_size :
        ]
        return self.audio_buffer, new_audio_len

    def process_output_chunk(
        self,
        chunk: np.ndarray,
        update_crossfade_buffer: bool,
        mouth_exaggeration: float,
        brow_exaggeration: float,
        head_wiggle_exaggeration: float,
        unsquinch_fix: float,
        eye_contact_fix: float,
        exaggerate_above: float,
        symmetrize_eyes: bool,
    ) -> np.ndarray:
        # The first 52 features are blendshapes; the remainder are head angles
        chunk[..., :52] = clean_up_blendshapes(
            chunk[..., :52],
            mouth_exaggeration,
            brow_exaggeration,
            clear_neutral=True,
            unsquinch_fix=unsquinch_fix,
            eye_contact_fix=eye_contact_fix,
            exaggerate_above=exaggerate_above,
            symmetrize_eyes=symmetrize_eyes,
        )
        if head_wiggle_exaggeration != 1.0:
            chunk[..., 52:] = exaggerate_head_wiggle(
                chunk[..., 52:], head_wiggle_exaggeration
            )
        if self.crossfade_size > 0 and chunk.shape[1] > self.crossfade_size:
            chunk = self.apply_crossfade(chunk, update_crossfade_buffer)
        return chunk

    def __call__(
        self,
        audio: np.ndarray,
        audio_stream_can_step: bool,
        mouth_exaggeration: float,
        brow_exaggeration: float,
        head_wiggle_exaggeration: float,
        unsquinch_fix: float,
        eye_contact_fix: float,
        exaggerate_above: float,
        symmetrize_eyes: bool,
    ) -> np.ndarray:
        """
        Run the model on an audio tensor.

        Args:
            audio: Audio tensor of shape (batch_size, n_audio_samples)

        Returns:
            np.ndarray: Model predictions
        """
        input_chunk, new_audio_len = self.prepare_input_chunk(audio)
        output_chunk = self.infer_chunk(input_chunk, new_audio_len)
        return self.process_output_chunk(
            output_chunk,
            update_crossfade_buffer=audio_stream_can_step,
            mouth_exaggeration=mouth_exaggeration,
            brow_exaggeration=brow_exaggeration,
            head_wiggle_exaggeration=head_wiggle_exaggeration,
            unsquinch_fix=unsquinch_fix,
            eye_contact_fix=eye_contact_fix,
            exaggerate_above=exaggerate_above,
            symmetrize_eyes=symmetrize_eyes,
        )

    def reset(self):
        """Reset internal state."""
        self.prev_output = np.zeros_like(self.prev_output)
        self.audio_buffer = np.zeros_like(self.audio_buffer)
        self.crossfade_buffer = None

    def infer_audio_array(
        self,
        audio: np.ndarray,
        min_audio_samples_per_step: int,
        max_audio_samples_per_step: int,
        mouth_exaggeration: float = 1.0,
        brow_exaggeration: float = 1.0,
        head_wiggle_exaggeration: float = 1.0,
        unsquinch_fix: float = 0.0,
        eye_contact_fix: float = 0.0,
        exaggerate_above: float = 0.0,
        symmetrize_eyes: bool = False,
        max_audio_duration: Optional[float] = None,
    ) -> Tuple[np.ndarray, np.ndarray, float, float, float]:
        """
        Run the model on an audio array under simulated streaming conditions.

        Args:
            audio: Numpy array of audio samples
            min_audio_samples_per_step: Minimum number of audio samples per step
            max_audio_samples_per_step: Maximum number of audio samples per step
            max_audio_duration: Maximum duration of audio to process in seconds

        Returns:
            Tuple of:
                - Blendshapes of shape (T, 52)
                - Head angles of shape (T, 3)
                - Mean time per step in seconds
                - Mean real-time factor
                - Time to first sound in seconds
        """
        # Reset all buffers
        self.reset()

        # Apply duration limit if specified
        if max_audio_duration is not None:
            max_audio_samples = int(max_audio_duration * SAMPLE_RATE)
            audio_len = min(len(audio), max_audio_samples)
        else:
            audio_len = len(audio)

        audio_stream = AudioStream(
            audio[:audio_len], min_audio_samples_per_step, max_audio_samples_per_step
        )

        # Process each chunk
        outputs = []
        step_times = []
        audio_durations = []
        while audio_stream.can_step:
            audio_chunk = audio_stream.step()
            audio_durations.append(audio_chunk.shape[-1] / SAMPLE_RATE)

            # Process the chunk
            start_time = time.time()
            chunk_output = self(
                np.expand_dims(audio_chunk, 0),
                audio_stream.can_step,
                mouth_exaggeration,
                brow_exaggeration,
                head_wiggle_exaggeration,
                unsquinch_fix,
                eye_contact_fix,
                exaggerate_above,
                symmetrize_eyes,
            )
            step_times.append(time.time() - start_time)
            outputs.append(chunk_output)

        # Concatenate all outputs
        full_output = np.concatenate(outputs, axis=1)
        mean_step_time = sum(step_times) / len(step_times)
        mean_rtf = sum(audio_durations) / sum(step_times)
        time_to_first_sound = step_times[0] + audio_durations[0]

        blendshapes = full_output.squeeze(0)[:, :52]
        head_angles = unscale_and_uncenter_head_angles(
            full_output.squeeze(0)[:, 52:], bad_frames=[]
        )
        return blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound
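

# Illustrative sketch only (not used by the pipelines below): a hypothetical
# subclass showing the `model_generate` contract documented above. It predicts
# all-zero landmarks and produces the same number of frames as the
# autoregressive loop in `ONNXInferencePipeline.model_generate`, i.e. one frame
# per decoding step.
class _ZeroLandmarkPipeline(InferencePipeline):
    def model_generate(self, src, max_len, initial_context=None):
        # J = number of context frames (a single zero frame when no context is given)
        n_context = 1 if initial_context is None else initial_context.shape[1]
        batch_size = src.shape[0]
        return np.zeros((batch_size, max_len - n_context + 1, self.n_feats))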


@dataclass
class ONNXModels:
    hubert_session: ort.InferenceSession
    encoder_session: ort.InferenceSession
    decoder_session: ort.InferenceSession


class ONNXInferencePipeline(InferencePipeline):
    """
    ONNX version of the inference pipeline.
    """

    def __init__(
        self,
        onnx_models: ONNXModels,
        max_chunk_size: int,
        crossfade_size: int,
        batch_size: int,
    ):
        """
        Initialize ONNX inference pipeline.

        Args:
            onnx_models: ONNXModels containing the HuBERT, encoder, and decoder sessions
            max_chunk_size: Maximum number of frames to process in a single chunk
            crossfade_size: Number of frames to use for crossfading between chunks
            batch_size: Batch size for inference
        """
        super().__init__(
            max_chunk_size,
            crossfade_size,
            batch_size,
        )
        self.onnx_models = onnx_models

    def model_generate(self, src, max_len, initial_context=None):
        """
        Generate output sequence using ONNX models.
        """
        # Run HuBERT and the encoder through ONNX
        src_np = src.astype(np.float32)
        hubert_out = self.onnx_models.hubert_session.run(
            None, {"input_values": src_np}
        )[0]
        src = self.onnx_models.encoder_session.run(None, {"src": hubert_out})[0]

        if initial_context is not None:
            decoder_in = initial_context.astype(np.float32)
        else:
            decoder_in = np.zeros((src.shape[0], 1, HEAD_LANDMARK_DIM)).astype(
                np.float32
            )

        # Autoregressive decoding: append each predicted frame to the decoder input
        outputs = []
        for _ in range(max_len - decoder_in.shape[1] + 1):
            next_output = self.onnx_models.decoder_session.run(
                None,
                {"src": src.astype(np.float32), "decoder_in": decoder_in},
            )[0]
            decoder_in = np.concatenate([decoder_in, next_output], axis=1)
            outputs.append(next_output)

        pred_out = np.concatenate(outputs, axis=1)
        return pred_out


def init_pipeline(
    hubert_onnx_path: Path,
    encoder_onnx_path: Path,
    decoder_onnx_path: Path,
    device: str = "cpu",
    chunk_size: int = 90,
    crossfade_size: int = 5,
    batch_size: int = 1,
) -> ONNXInferencePipeline:
    """
    Initialize the ONNX inference pipeline from the provided model paths.

    Args:
        hubert_onnx_path: Path to ONNX HuBERT model
        encoder_onnx_path: Path to ONNX encoder model
        decoder_onnx_path: Path to ONNX decoder model
        device: Device to run on ("cpu" or "cuda")
        chunk_size: Maximum number of frames per chunk
        crossfade_size: Number of frames for crossfading
        batch_size: Batch size for inference

    Returns:
        ONNX inference pipeline
    """
    providers = (
        ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
    )
    hubert_session = ort.InferenceSession(str(hubert_onnx_path), providers=providers)
    encoder_session = ort.InferenceSession(str(encoder_onnx_path), providers=providers)
    decoder_session = ort.InferenceSession(str(decoder_onnx_path), providers=providers)
    onnx_models = ONNXModels(hubert_session, encoder_session, decoder_session)

    return ONNXInferencePipeline(onnx_models, chunk_size, crossfade_size, batch_size)
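

# Minimal usage sketch: the ONNX paths below are placeholders and the audio is
# synthetic silence at SAMPLE_RATE; substitute real exported models and real audio.
if __name__ == "__main__":
    pipeline = init_pipeline(
        hubert_onnx_path=Path("models/hubert.onnx"),  # placeholder path
        encoder_onnx_path=Path("models/encoder.onnx"),  # placeholder path
        decoder_onnx_path=Path("models/decoder.onnx"),  # placeholder path
        device="cpu",
    )
    dummy_audio = np.zeros(int(SAMPLE_RATE), dtype=np.float32)  # one second of silence
    blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound = (
        pipeline.infer_audio_array(
            dummy_audio,
            min_audio_samples_per_step=int(SAMPLE_RATE) // 4,
            max_audio_samples_per_step=int(SAMPLE_RATE) // 2,
        )
    )
    print(blendshapes.shape, head_angles.shape, mean_step_time, mean_rtf)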