cherrvak committed on
Commit 2c04fa5 · 1 Parent(s): 753e340

initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.vrm filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+__pycache__
+venv
app.py ADDED
@@ -0,0 +1,119 @@
+import streamlit as st
+from model_demo import model_demo
+from model_demo.inference.infer import init_pipeline
+import os
+import json
+import pydub
+import numpy as np
+from model_demo.inference.constants import BLENDSHAPE_NAMES
+
+
+def make_downloadable_json(blendshapes, headangles):
+    blendshape_dict = {}
+    for i, name in enumerate(BLENDSHAPE_NAMES):
+        blendshape_dict[name] = blendshapes[:, i].tolist()
+    headangle_dict = {}
+    for i, name in enumerate(["pitch", "yaw", "roll"]):
+        headangle_dict[name] = headangles[:, i].tolist()
+    return json.dumps({"blendshapes": blendshape_dict, "headangles": headangle_dict})
+
+
+if "pred_dict" not in st.session_state:
+    st.session_state.pred_dict = {}
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+onnx_path = "assets/onnx_models"
+hubert_path = f"{onnx_path}/hubert.onnx"
+encoder_path = f"{onnx_path}/encoder.onnx"
+decoder_path = f"{onnx_path}/decoder.onnx"
+pipeline = init_pipeline(hubert_path, encoder_path, decoder_path)
+
+(col1, col2) = st.columns([2, 3])
+
+with col1:
+    with st.container(border=True):
+        audio_tab, control_tab, vrm_tab = st.tabs(["Audio", "Controls", "Upload VRM"])
+
+        with audio_tab:
+            recorded_value = st.audio_input("Record audio")
+            st.write("Or")
+            uploaded_value = st.file_uploader("Upload audio", type=["wav"])
+            audio_value = (
+                recorded_value if recorded_value is not None else uploaded_value
+            )
+
+        with control_tab:
+            mouth_exaggeration = st.number_input("Lower face exaggeration", value=5.0)
+            brow_exaggeration = st.number_input("Upper face exaggeration", value=4.0)
+            head_wiggle_exaggeration = st.number_input(
+                "Head wiggle exaggeration", value=2.0
+            )
+            unsquinch_fix = st.number_input(
+                "Unsquinch fix",
+                value=0.75,
+            )
+            eye_contact_fix = st.number_input(
+                "Eye contact fix",
+                value=1.5,
+            )
+            exaggerate_above = st.number_input(
+                "Exaggerate above",
+                value=0.01,
+                min_value=0.0,
+                max_value=1.0,
+                step=0.001,
+                format="%.3f",
+            )
+            symmetrize_eyes = st.checkbox("Symmetrize eyes", value=True)
+
+        with vrm_tab:
+            vrm_file = st.file_uploader("Upload VRM file", type=["vrm"])
+            if vrm_file:
+                # Read the raw bytes from the uploaded file
+                vrm_bytes = vrm_file.read()
+                # Store the raw bytes in the session state
+                st.session_state.pred_dict["vrm_file"] = vrm_bytes
+
+    submit_button = st.button("Run Inference", disabled=not audio_value)
+
+if submit_button and audio_value:
+    audio_segment = (
+        pydub.AudioSegment.from_file(audio_value).set_frame_rate(16000).set_channels(1)
+    )
+    audio_array = np.array(audio_segment.get_array_of_samples())
+    blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound = (
+        pipeline.infer_audio_array(
+            np.array(audio_array),
+            32000,
+            48000,
+            mouth_exaggeration,
+            brow_exaggeration,
+            head_wiggle_exaggeration,
+            unsquinch_fix,
+            eye_contact_fix,
+            exaggerate_above,
+            symmetrize_eyes,
+        )
+    )
+    st.session_state.pred_dict["blendshapes"] = blendshapes
+    st.session_state.pred_dict["head_angles"] = head_angles
+    st.session_state.pred_dict["audio_data"] = audio_value.getvalue()
+    processing_string = f"Inference complete at {mean_rtf:.2f}x real-time."
+
+    with col1:
+        st.write(processing_string)
+        st.download_button(
+            label="Download results as JSON",
+            data=make_downloadable_json(blendshapes, head_angles),
+            file_name="inference_results.json",
+            mime="text/json",
+        )
+
+with col2:
+    model_demo(
+        blendshapes=st.session_state.pred_dict.get("blendshapes", None),
+        headangles=st.session_state.pred_dict.get("head_angles", None),
+        audio_data=st.session_state.pred_dict.get("audio_data", None),
+        vrm_data=st.session_state.pred_dict.get("vrm_file", None),
+        key="model_viewport",
+    )
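
For reference, a minimal sketch of reading the downloaded file back (illustrative only, not part of the commit; it assumes the download was saved as inference_results.json next to the script):

import json

with open("inference_results.json") as f:
    result = json.load(f)

# One list of per-frame values (30 fps) per blendshape / head-angle name.
jaw_open = result["blendshapes"]["jawOpen"]
pitch = result["headangles"]["pitch"]
print(len(jaw_open), len(pitch))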
assets/onnx_models/decoder.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ae52135ac99c7e48ec8ca77e96a7ff057b36bcd1379e3763ed25a23339781de
+size 22408905
assets/onnx_models/encoder.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32b26754ff956a46742232d0fb17adb444115ffb2d9cf155051d1e9aca27cf18
+size 5909306
assets/onnx_models/hubert.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6fcb81a8315972f672b9433e85886fda467b9c6d76f1799e5bf01f8c68f915b
+size 377746620
model_demo/__init__.py ADDED
@@ -0,0 +1,31 @@
+import os
+import streamlit.components.v1 as components
+
+_RELEASE = True
+
+if not _RELEASE:
+    _component_func = components.declare_component(
+        "model_demo",
+        url="http://localhost:3001",
+    )
+else:
+    # When we're distributing a production version of the component, we'll
+    # replace the `url` param with `path`, and point it to the component's
+    # build directory:
+    parent_dir = os.path.dirname(os.path.abspath(__file__))
+    build_dir = os.path.join(parent_dir, "frontend/build")
+    _component_func = components.declare_component("model_demo", path=build_dir)
+
+
+def model_demo(
+    blendshapes=None, headangles=None, audio_data=None, vrm_data=None, key=None
+):
+    component_value = _component_func(
+        blendshapes=blendshapes.tolist() if blendshapes is not None else None,
+        headangles=headangles.tolist() if headangles is not None else None,
+        audio_data=audio_data.decode("latin1") if audio_data is not None else None,
+        vrm_data=[int(b) for b in vrm_data] if vrm_data is not None else None,
+        key=key,
+        default=0,
+    )
+    return component_value
model_demo/frontend/build/asset-manifest.json ADDED
@@ -0,0 +1,10 @@
+{
+  "files": {
+    "main.js": "./static/js/main.88082681.js",
+    "index.html": "./index.html",
+    "main.88082681.js.map": "./static/js/main.88082681.js.map"
+  },
+  "entrypoints": [
+    "static/js/main.88082681.js"
+  ]
+}
model_demo/frontend/build/bootstrap.min.css ADDED
The diff for this file is too large to render. See raw diff
 
model_demo/frontend/build/index.html ADDED
@@ -0,0 +1 @@
+<!doctype html><html lang="en"><head><title>Streamlit Component</title><meta charset="UTF-8"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Streamlit Component"/><link rel="stylesheet" href="bootstrap.min.css"/><script defer="defer" src="./static/js/main.88082681.js"></script></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
model_demo/frontend/build/static/js/main.88082681.js ADDED
The diff for this file is too large to render. See raw diff
 
model_demo/frontend/build/static/js/main.88082681.js.LICENSE.txt ADDED
@@ -0,0 +1,98 @@
+/*
+object-assign
+(c) Sindre Sorhus
+@license MIT
+*/
+
+/*!
+ * @pixiv/three-vrm v3.3.4
+ * VRM file loader for three.js.
+ *
+ * Copyright (c) 2019-2025 pixiv Inc.
+ * @pixiv/three-vrm is distributed under MIT License
+ * https://github.com/pixiv/three-vrm/blob/release/LICENSE
+ */
+
+/**
+ * @license
+ * Copyright 2010-2024 Three.js Authors
+ * SPDX-License-Identifier: MIT
+ */
+
+/**
+ * @license React
+ * react-dom.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @license React
+ * react-jsx-runtime.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @license React
+ * react-reconciler-constants.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @license React
+ * react-reconciler.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @license React
+ * react.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @license React
+ * scheduler.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/** @license React v16.13.1
+ * react-is.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/** @license React v16.14.0
+ * react.production.min.js
+ *
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
model_demo/frontend/build/static/js/main.88082681.js.map ADDED
The diff for this file is too large to render. See raw diff
 
model_demo/frontend/build/vrm_model/demo.vrm ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cfb31b3ab6759ff5f4676130827d8b3aa570d0a44eb4f7431cdb56e8790501a6
+size 16932772
model_demo/inference/audio.py ADDED
@@ -0,0 +1,33 @@
+import numpy as np
+
+
+class AudioStream:
+    """
+    Class to mimic streaming audio input.
+    """
+
+    def __init__(
+        self, audio: np.ndarray, min_samples_per_step: int, max_samples_per_step: int
+    ):
+        self.audio = audio
+        self.min_samples_per_step = min_samples_per_step
+        self.max_samples_per_step = max_samples_per_step
+        self.current_idx = 0
+        self.can_step = True
+
+    def step(self) -> np.ndarray:
+        if not self.can_step:
+            raise StopIteration("End of audio stream")
+        start_idx = self.current_idx
+        if self.min_samples_per_step == self.max_samples_per_step:
+            samples_per_step = self.min_samples_per_step
+        else:
+            samples_per_step = np.random.randint(
+                self.min_samples_per_step, self.max_samples_per_step, (1,)
+            ).item()
+        end_idx = min(start_idx + samples_per_step, len(self.audio))
+        audio_chunk = self.audio[start_idx:end_idx]
+        self.current_idx = end_idx
+        if end_idx >= len(self.audio):
+            self.can_step = False
+        return audio_chunk
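
A minimal sketch of how AudioStream is meant to be driven (illustrative only, not part of the commit; the one-second silent buffer and the 0.25-0.5 s step sizes are arbitrary):

import numpy as np
from model_demo.inference.audio import AudioStream

stream = AudioStream(np.zeros(16000, dtype=np.float32), 4000, 8000)
chunks = []
while stream.can_step:
    chunks.append(stream.step())

# Every sample is emitted exactly once, in order, in random-sized chunks.
assert sum(len(c) for c in chunks) == 16000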
model_demo/inference/constants.py ADDED
@@ -0,0 +1,99 @@
+import os
+
+
+def exact_div(x, y):
+    assert x % y == 0
+    return x // y
+
+
+SAMPLE_RATE = 16000
+N_FFT = 400
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+# 3000 frames in a mel spectrogram input
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)
+
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
+TIMESTEP_S = 30 / 1500
+
+VIDEO_FPS = 30
+N_AUDIO_SAMPLES_PER_VIDEO_FRAME = SAMPLE_RATE // VIDEO_FPS
+N_VIDEO_FRAMES = CHUNK_LENGTH * VIDEO_FPS  # 900 frames in a 30-second video chunk
+
+
+def mel_frames_from_video_frames(n_video_frames):
+    return int(n_video_frames * N_SAMPLES_PER_TOKEN / VIDEO_FPS)
+
+
+MEL_FILTER_PATH = os.path.join(
+    os.path.dirname(__file__), "../../assets", "mel_filters.npz"
+)
+LANDMARKER_PATH = "pretrained_models/mediapipe/face_landmarker_v2_with_blendshapes.task"
+
+BLENDSHAPE_NAMES = [
+    "_neutral",
+    "browDownLeft",
+    "browDownRight",
+    "browInnerUp",
+    "browOuterUpLeft",
+    "browOuterUpRight",
+    "cheekPuff",
+    "cheekSquintLeft",
+    "cheekSquintRight",
+    "eyeBlinkLeft",
+    "eyeBlinkRight",
+    "eyeLookDownLeft",
+    "eyeLookDownRight",
+    "eyeLookInLeft",
+    "eyeLookInRight",
+    "eyeLookOutLeft",
+    "eyeLookOutRight",
+    "eyeLookUpLeft",
+    "eyeLookUpRight",
+    "eyeSquintLeft",
+    "eyeSquintRight",
+    "eyeWideLeft",
+    "eyeWideRight",
+    "jawForward",
+    "jawLeft",
+    "jawOpen",
+    "jawRight",
+    "mouthClose",
+    "mouthDimpleLeft",
+    "mouthDimpleRight",
+    "mouthFrownLeft",
+    "mouthFrownRight",
+    "mouthFunnel",
+    "mouthLeft",
+    "mouthLowerDownLeft",
+    "mouthLowerDownRight",
+    "mouthPressLeft",
+    "mouthPressRight",
+    "mouthPucker",
+    "mouthRight",
+    "mouthRollLower",
+    "mouthRollUpper",
+    "mouthShrugLower",
+    "mouthShrugUpper",
+    "mouthSmileLeft",
+    "mouthSmileRight",
+    "mouthStretchLeft",
+    "mouthStretchRight",
+    "mouthUpperUpLeft",
+    "mouthUpperUpRight",
+    "noseSneerLeft",
+    "noseSneerRight",
+]
+
+HEAD_ANGLE_NAMES = ["pitch", "yaw", "roll"]
+
+HEAD_LANDMARK_DIM = len(BLENDSHAPE_NAMES) + len(HEAD_ANGLE_NAMES)
+
+
+def get_n_mels(whisper_model_name: str):
+    if "v3" in whisper_model_name:
+        return 128
+    return 80
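
A quick illustration of how these constants line up with the step sizes used in app.py (illustrative only, not part of the commit):

from model_demo.inference.constants import (
    SAMPLE_RATE,
    VIDEO_FPS,
    N_AUDIO_SAMPLES_PER_VIDEO_FRAME,
)

# 16000 // 30 = 533 audio samples correspond to one 30 fps video frame.
print(N_AUDIO_SAMPLES_PER_VIDEO_FRAME)
# app.py streams 32000-48000 samples per step, i.e. 2-3 seconds of audio,
# which is roughly 60-90 generated video frames per step.
print(32000 / SAMPLE_RATE, 48000 / SAMPLE_RATE)
print(32000 // N_AUDIO_SAMPLES_PER_VIDEO_FRAME, 48000 // N_AUDIO_SAMPLES_PER_VIDEO_FRAME)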
model_demo/inference/infer.py ADDED
@@ -0,0 +1,386 @@
+import numpy as np
+from typing import Optional
+from pathlib import Path
+from typing import Tuple
+
+import time
+
+from model_demo.inference.audio import AudioStream
+from model_demo.inference.landmarks import (
+    unscale_and_uncenter_head_angles,
+    clean_up_blendshapes,
+    exaggerate_head_wiggle,
+)
+from model_demo.inference.constants import (
+    N_AUDIO_SAMPLES_PER_VIDEO_FRAME,
+    SAMPLE_RATE,
+    HEAD_LANDMARK_DIM,
+)
+
+import onnxruntime as ort
+from dataclasses import dataclass
+from typing import Optional, Union
+
+
+class InferencePipeline:
+    """
+    Pipeline for running WhisperLike model inference on a stream of audio chunks.
+
+    Added crossfade functionality to smooth transitions between chunks.
+    """
+
+    def __init__(
+        self,
+        max_chunk_size: int,
+        crossfade_size: int,
+        batch_size: int,
+    ) -> None:
+        """
+        Initialize streaming inference pipeline.
+
+        Args:
+            max_chunk_size: Maximum number of frames to process in a single chunk
+            crossfade_size: Number of frames to use for crossfading between chunks
+            batch_size: Batch size for inference
+            device: Device to run on
+        """
+        self.max_chunk_size = max_chunk_size
+        self.max_audio_input_size = (
+            self.max_chunk_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
+        )
+        self.crossfade_size = crossfade_size
+        self.audio_crossfade_size = crossfade_size * N_AUDIO_SAMPLES_PER_VIDEO_FRAME
+        self.n_feats = HEAD_LANDMARK_DIM
+
+        # Maintain state between chunks
+        self.prev_output = np.zeros((batch_size, 0, self.n_feats))
+        self.audio_buffer = np.zeros((batch_size, 0))
+
+        # Crossfade buffer stores the overlapping region from the previous chunk
+        self.crossfade_buffer = None
+
+        # Pre-compute crossfade weights
+        self.crossfade_weights = np.linspace(0, 1, crossfade_size)
+        self.crossfade_weights = self.crossfade_weights.reshape(1, -1)
+
+    def apply_crossfade(
+        self, current_chunk: np.ndarray, update_crossfade_buffer: bool
+    ) -> np.ndarray:
+        """Apply crossfade between previous and current chunk predictions."""
+        if self.crossfade_buffer is not None:
+            # Extract the crossfade region from the current chunk
+            current_fade_region = current_chunk[:, : self.crossfade_size]
+
+            # Blend the overlapping regions using the pre-computed weights
+            blended_region = np.multiply(
+                self.crossfade_buffer, np.expand_dims((1 - self.crossfade_weights), -1)
+            ) + np.multiply(
+                current_fade_region, np.expand_dims(self.crossfade_weights, -1)
+            )
+
+            # Replace the beginning of the current chunk with the blended region
+            output = current_chunk.copy()
+            output[:, : self.crossfade_size] = blended_region
+        else:
+            output = current_chunk
+        if update_crossfade_buffer:
+            self.crossfade_buffer = current_chunk[:, -self.crossfade_size :].copy()
+            output = output[:, : -self.crossfade_size]
+        return output
+
+    def model_generate(self, src, max_len, initial_context=None):
+        """
+        Generate output sequence with optional initial context.
+
+        Args:
+            src: Source audio features of shape [B, T_a, D], where T_a is the number of
+                audio frames corresponding to max_len video frames
+            max_len: Number of frames to generate
+            initial_context: Optional previous output context (B, J, D), where J is
+                in [1, max_len + 1]
+
+        Returns:
+            Predicted landmarks [B, max_len - J, D]
+        """
+        pass
+
+    def infer_chunk(self, audio: np.ndarray, new_audio_len: int) -> np.ndarray:
+        """Process a single chunk of audio, using previous context if available."""
+        n_new_frames = (
+            new_audio_len // N_AUDIO_SAMPLES_PER_VIDEO_FRAME + self.crossfade_size
+        )
+        n_generation_frames = audio.shape[1] // N_AUDIO_SAMPLES_PER_VIDEO_FRAME
+        n_context_frames = (n_generation_frames - n_new_frames) + 1
+        if n_context_frames > 0:
+            initial_context = self.prev_output[:, -n_context_frames:]
+        else:
+            initial_context = None
+        # Generate predictions
+        predictions = self.model_generate(audio, n_generation_frames, initial_context)
+
+        self.prev_output = np.concatenate([self.prev_output, predictions], axis=1)[
+            :, -self.max_chunk_size :
+        ]
+        return predictions
+
+    def prepare_input_chunk(self, audio: np.ndarray) -> Tuple[np.ndarray, int]:
+        new_audio_len = audio.shape[1]
+        self.audio_buffer = np.concatenate([self.audio_buffer, audio], axis=1)[
+            :, -self.max_audio_input_size :
+        ]
+        return self.audio_buffer, new_audio_len
+
+    def process_output_chunk(
+        self,
+        chunk: np.ndarray,
+        update_crossfade_buffer: bool,
+        mouth_exaggeration: float,
+        brow_exaggeration: float,
+        head_wiggle_exaggeration: float,
+        unsquinch_fix: float,
+        eye_contact_fix: float,
+        exaggerate_above: float,
+        symmetrize_eyes: bool,
+    ) -> np.ndarray:
+        chunk[..., :52] = clean_up_blendshapes(
+            chunk[..., :52],
+            mouth_exaggeration,
+            brow_exaggeration,
+            clear_neutral=True,
+            unsquinch_fix=unsquinch_fix,
+            eye_contact_fix=eye_contact_fix,
+            exaggerate_above=exaggerate_above,
+            symmetrize_eyes=symmetrize_eyes,
+        )
+        if head_wiggle_exaggeration != 1.0:
+            chunk[..., 52:] = exaggerate_head_wiggle(
+                chunk[..., 52:], head_wiggle_exaggeration
+            )
+        if self.crossfade_size > 0 and chunk.shape[1] > self.crossfade_size:
+            chunk = self.apply_crossfade(chunk, update_crossfade_buffer)
+        return chunk
+
+    def __call__(
+        self,
+        audio: np.ndarray,
+        audio_stream_can_step: bool,
+        mouth_exaggeration: float,
+        brow_exaggeration: float,
+        head_wiggle_exaggeration: float,
+        unsquinch_fix: float,
+        eye_contact_fix: float,
+        exaggerate_above: float,
+        symmetrize_eyes: bool,
+    ) -> np.ndarray:
+        """
+        Run the model on an audio tensor.
+
+        Args:
+            audio: Audio tensor of shape (batch_size, n_audio_samples)
+
+        Returns:
+            np.ndarray: Model predictions
+        """
+        input_chunk, new_audio_len = self.prepare_input_chunk(audio)
+        output_chunk = self.infer_chunk(input_chunk, new_audio_len)
+        return self.process_output_chunk(
+            output_chunk,
+            update_crossfade_buffer=audio_stream_can_step,
+            mouth_exaggeration=mouth_exaggeration,
+            brow_exaggeration=brow_exaggeration,
+            head_wiggle_exaggeration=head_wiggle_exaggeration,
+            unsquinch_fix=unsquinch_fix,
+            eye_contact_fix=eye_contact_fix,
+            exaggerate_above=exaggerate_above,
+            symmetrize_eyes=symmetrize_eyes,
+        )
+
+    def reset(self):
+        """Reset internal state"""
+        self.prev_output = np.zeros_like(self.prev_output)
+        self.audio_buffer = np.zeros_like(self.audio_buffer)
+        self.crossfade_buffer = None
+
+    def infer_audio_array(
+        self,
+        audio: np.ndarray,
+        min_audio_samples_per_step: int,
+        max_audio_samples_per_step: int,
+        mouth_exaggeration: float = 1.0,
+        brow_exaggeration: float = 1.0,
+        head_wiggle_exaggeration: float = 1.0,
+        unsquinch_fix: float = 0.0,
+        eye_contact_fix: float = 0.0,
+        exaggerate_above: float = 0.0,
+        symmetrize_eyes: bool = False,
+        max_audio_duration: Optional[float] = None,
+    ) -> Tuple[np.ndarray, np.ndarray, float, float, float]:
+        """
+        Run the model on an input audio or video file under simulated streaming conditions.
+
+        Args:
+            audio: Numpy array of audio samples
+            min_audio_samples_per_step: Minimum number of audio samples per step
+            max_audio_samples_per_step: Maximum number of audio samples per step
+            max_audio_duration: Maximum duration of audio to process in seconds
+
+        Returns:
+            Tuple of:
+                - Blendshapes of shape (T, 52)
+                - Head angles of shape (T, 3)
+                - Mean time per step in seconds
+                - Mean real-time factor
+        """
+        # Reset all buffers
+        self.reset()
+        # Apply duration limit if specified
+        if max_audio_duration is not None:
+            max_audio_duration_frames = int(max_audio_duration * SAMPLE_RATE)
+            audio_len = min(len(audio), max_audio_duration_frames)
+        else:
+            audio_len = len(audio)
+
+        audio_stream = AudioStream(
+            audio[:audio_len], min_audio_samples_per_step, max_audio_samples_per_step
+        )
+
+        # Process each chunk
+        outputs = []
+        step_times = []
+        audio_durations = []
+        while audio_stream.can_step:
+            audio_chunk = audio_stream.step()
+            audio_durations.append(audio_chunk.shape[-1] / SAMPLE_RATE)
+            # Process the chunk
+            start_time = time.time()
+            chunk_output = self(
+                np.expand_dims(audio_chunk, 0),
+                audio_stream.can_step,
+                mouth_exaggeration,
+                brow_exaggeration,
+                head_wiggle_exaggeration,
+                unsquinch_fix,
+                eye_contact_fix,
+                exaggerate_above,
+                symmetrize_eyes,
+            )
+            step_times.append(time.time() - start_time)
+            outputs.append(chunk_output)
+
+        # Concatenate all outputs
+        full_output = np.concatenate(outputs, axis=1)
+        mean_step_time = sum(step_times) / len(step_times)
+        mean_rtf = sum(audio_durations) / sum(step_times)
+        time_to_first_sound = step_times[0] + audio_durations[0]
+
+        blendshapes = full_output.squeeze(0)[:, :52]
+        head_angles = unscale_and_uncenter_head_angles(
+            full_output.squeeze(0)[:, 52:], bad_frames=[]
+        )
+
+        return blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound
+
+
+@dataclass
+class ONNXModels:
+    hubert_session: ort.InferenceSession
+    encoder_session: ort.InferenceSession
+    decoder_session: ort.InferenceSession
+
+
+class ONNXInferencePipeline(InferencePipeline):
+    """
+    ONNX version of the inference pipeline.
+    """
+
+    def __init__(
+        self,
+        onnx_models: ONNXModels,
+        max_chunk_size: int,
+        crossfade_size: int,
+        batch_size: int,
+    ):
+        """
+        Initialize ONNX inference pipeline.
+
+        Args:
+            onnx_models: ONNXModels containing hubert and decoder sessions
+            max_chunk_size: Maximum number of frames to process in a single chunk
+            crossfade_size: Number of frames to use for crossfading between chunks
+            batch_size: Batch size for inference
+            device: Device to run inference on
+        """
+        super().__init__(
+            max_chunk_size,
+            crossfade_size,
+            batch_size,
+        )
+        self.onnx_models = onnx_models
+
+    def model_generate(self, src, max_len, initial_context=None):
+        """
+        Generate output sequence using ONNX models.
+        """
+        # Run HuBERT through ONNX
+        src_np = src.astype(np.float32)
+        hubert_out = self.onnx_models.hubert_session.run(
+            None, {"input_values": src_np}
+        )[0]
+        src = self.onnx_models.encoder_session.run(None, {"src": hubert_out})[0]
+
+        if initial_context is not None:
+            decoder_in = initial_context.astype(np.float32)
+        else:
+            decoder_in = np.zeros((src.shape[0], 1, HEAD_LANDMARK_DIM)).astype(
+                np.float32
+            )
+
+        outputs = []
+        for i in range(max_len - decoder_in.shape[1] + 1):
+            # Run decoder step through ONNX
+            next_output = self.onnx_models.decoder_session.run(
+                None,
+                {"src": src.astype(np.float32), "decoder_in": decoder_in},
+            )[0]
+
+            decoder_in = np.concatenate([decoder_in, next_output], axis=1)
+            outputs.append(next_output)
+
+        pred_out = np.concatenate(outputs, axis=1)
+        return pred_out
+
+
+def init_pipeline(
+    hubert_onnx_path: Path,
+    encoder_onnx_path: Path,
+    decoder_onnx_path: Path,
+    device: str = "cpu",
+    chunk_size: int = 90,
+    crossfade_size: int = 5,
+    batch_size: int = 1,
+) -> Union[InferencePipeline, ONNXInferencePipeline]:
+    """
+    Initialize ONNX inference pipeline based on provided paths.
+
+    Args:
+        hubert_onnx_path: Path to ONNX HuBERT model
+        decoder_onnx_path: Path to ONNX decoder model
+        chunk_size: Maximum number of frames per chunk
+        crossfade_size: Number of frames for crossfading
+        batch_size: Batch size for inference
+        device: Device to run on
+
+    Returns:
+        ONNX inference pipeline
+    """
+    # ONNX pipeline
+    providers = (
+        ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+    )
+
+    hubert_session = ort.InferenceSession(str(hubert_onnx_path), providers=providers)
+    encoder_session = ort.InferenceSession(str(encoder_onnx_path), providers=providers)
+    decoder_session = ort.InferenceSession(str(decoder_onnx_path), providers=providers)
+
+    onnx_models = ONNXModels(hubert_session, encoder_session, decoder_session)
+    return ONNXInferencePipeline(onnx_models, chunk_size, crossfade_size, batch_size)
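
For orientation, a minimal end-to-end sketch of driving this pipeline outside Streamlit (illustrative only, not part of the commit; it assumes the ONNX files under assets/onnx_models/ are available and uses three seconds of silence in place of real speech):

import numpy as np
from model_demo.inference.infer import init_pipeline

pipeline = init_pipeline(
    "assets/onnx_models/hubert.onnx",
    "assets/onnx_models/encoder.onnx",
    "assets/onnx_models/decoder.onnx",
)
audio = np.zeros(3 * 16000, dtype=np.int16)  # 16 kHz mono, as app.py produces via pydub
blendshapes, head_angles, mean_step_time, mean_rtf, ttfs = pipeline.infer_audio_array(
    audio, 32000, 48000
)
print(blendshapes.shape, head_angles.shape)  # (T, 52) and (T, 3) at 30 fps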
model_demo/inference/landmarks.py ADDED
@@ -0,0 +1,115 @@
+from typing import Optional, Tuple, List, Union
+import numpy as np
+from math import e, pi
+
+from model_demo.inference.constants import BLENDSHAPE_NAMES
+
+
+def clean_up_blendshapes(
+    blendshapes: np.ndarray,
+    mouth_exaggeration: float,
+    brow_exaggeration: float,
+    unsquinch_fix: float,
+    eye_contact_fix: float,
+    clear_neutral: bool = False,
+    exaggerate_above: float = 0,
+    symmetrize_eyes: bool = False,
+) -> np.ndarray:
+    """
+    Exaggerate blendshapes by a given factor.
+
+    Args:
+        blendshapes: Blendshape coefficients of shape (B, T, D) or (T, D)
+        mouth_exaggeration / brow_exaggeration: Factors to exaggerate the lower / upper face blendshapes by
+        unsquinch_fix: Factor to reduce eye squint and blink blendshapes by in range [0, 1]
+        eye_contact_fix: Factor to reduce eye look blendshapes by in range [0, 1]
+        clear_neutral: Whether to clear the neutral expression blendshape (set to 0)
+        symmetrize_eyes: Whether to average the left and right eye blink blendshapes
+        exaggerate_above: Values above this threshold are exaggerated up, values below it down
+
+    Returns:
+        Exaggerated blendshape coefficients of shape (B, T, D) or (T, D)
+    """
+
+    def modify_blendshapes(
+        blendshapes: np.ndarray, target_substrings: List[str], factor: float
+    ) -> np.ndarray:
+        if factor != 1:
+            for i, shape in enumerate(BLENDSHAPE_NAMES):
+                if any(substring in shape for substring in target_substrings):
+                    blendshapes_offset = blendshapes[..., i] - exaggerate_above
+                    blendshapes[..., i] = blendshapes_offset * factor + exaggerate_above
+            blendshapes = np.clip(blendshapes, 0.0, 1.0)
+        return blendshapes
+
+    if clear_neutral:
+        blendshapes[..., 0] = 0
+
+    modify_blendshapes(blendshapes, ["mouth", "jaw", "cheek"], mouth_exaggeration)
+    modify_blendshapes(blendshapes, ["brow", "noseSneer", "eye"], brow_exaggeration)
+    if unsquinch_fix > 0:
+        eye_idx = [
+            i
+            for i, name in enumerate(BLENDSHAPE_NAMES)
+            if "eyeSquint" in name or "eyeBlink" in name
+        ]
+        for idx in eye_idx:
+            blendshapes[..., idx] -= unsquinch_fix
+    if eye_contact_fix > 0:
+        eye_idx = [i for i, name in enumerate(BLENDSHAPE_NAMES) if "eyeLook" in name]
+        for idx in eye_idx:
+            blendshapes[..., idx] -= eye_contact_fix
+    if symmetrize_eyes:
+        # average between eyeBlinkLeft and eyeBlinkRight
+        eye_blink_left_index = BLENDSHAPE_NAMES.index("eyeBlinkLeft")
+        eye_blink_right_index = BLENDSHAPE_NAMES.index("eyeBlinkRight")
+        avg_val = (
+            blendshapes[..., eye_blink_left_index]
+            + blendshapes[..., eye_blink_right_index]
+        ) / 2
+        blendshapes[..., eye_blink_left_index] = avg_val
+        blendshapes[..., eye_blink_right_index] = avg_val
+
+    blendshapes = np.clip(blendshapes, 0.0, 1.0)
+
+    return blendshapes
+
+
+def exaggerate_head_wiggle(
+    head_angles: np.ndarray[np.float32], exaggeration_factor: float
+) -> np.ndarray[np.float32]:
+    """
+    Exaggerate head angles by a given factor.
+
+    Args:
+        head_angles: Sequence of pitch, yaw, roll values of shape (temporal_dim, 3)
+        exaggeration_factor: Factor to exaggerate the head angles by
+
+    Returns:
+        Exaggerated head angles of shape (temporal_dim, 3)
+    """
+    return head_angles * exaggeration_factor
+
+
+def unscale_and_uncenter_head_angles(
+    head_angles: np.ndarray[np.float32],
+    mean_pos: Optional[np.ndarray[np.float32]] = None,
+    bad_frames: List[int] = [],
+) -> np.ndarray[np.float32]:
+    """
+    Rescale head angles in range [-1, 1] to [-pi, pi] and uncenter them.
+
+    Args:
+        head_angles: Sequence of pitch, yaw, roll values of shape (temporal_dim, 3)
+        mean_pos: Mean position to offset the head angles of shape (3,)
+        bad_frames: List of indices of frames where face detection failed
+
+    Returns:
+        Array of unscaled and uncentered head angles of shape (temporal_dim, 3)
+    """
+    if mean_pos is None:
+        mean_pos = np.zeros(3).astype(np.float32)
+    good_frames = [i for i in range(head_angles.shape[0]) if i not in bad_frames]
+    head_angles[good_frames] = head_angles[good_frames] + mean_pos
+    head_angles[good_frames] = head_angles[good_frames] * pi
+    return head_angles
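
To make the exaggeration math concrete: each targeted value is shifted by exaggerate_above, scaled by the factor, shifted back, and clipped to [0, 1], so with exaggerate_above = 0.1 and a factor of 2 a jawOpen value of 0.3 becomes (0.3 - 0.1) * 2 + 0.1 = 0.5, while 0.05 is pushed down and clipped to 0.0. A small sketch (illustrative only, not part of the commit):

import numpy as np
from model_demo.inference.constants import BLENDSHAPE_NAMES
from model_demo.inference.landmarks import clean_up_blendshapes

frames = np.zeros((2, len(BLENDSHAPE_NAMES)), dtype=np.float32)
jaw_open = BLENDSHAPE_NAMES.index("jawOpen")
frames[:, jaw_open] = [0.3, 0.05]

out = clean_up_blendshapes(
    frames,
    mouth_exaggeration=2.0,
    brow_exaggeration=1.0,
    unsquinch_fix=0.0,
    eye_contact_fix=0.0,
    exaggerate_above=0.1,
)
print(out[:, jaw_open])  # approximately [0.5, 0.0]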
requirements.txt ADDED
@@ -0,0 +1,3 @@
+pydub
+numpy
+onnxruntime