import json
import os

import numpy as np
import pydub
import streamlit as st

from model_demo import model_demo
from model_demo.inference.constants import BLENDSHAPE_NAMES
from model_demo.inference.infer import init_pipeline


def make_downloadable_json(blendshapes, headangles):
    """Serialize per-frame blendshape and head-angle arrays as a JSON string."""
    blendshape_dict = {}
    for i, name in enumerate(BLENDSHAPE_NAMES):
        blendshape_dict[name] = blendshapes[:, i].tolist()
    headangle_dict = {}
    for i, name in enumerate(["pitch", "yaw", "roll"]):
        headangle_dict[name] = headangles[:, i].tolist()
    # Use json.dumps rather than str() so the downloaded file is valid JSON
    # (str() on a dict produces single-quoted Python repr, not JSON).
    return json.dumps({"blendshapes": blendshape_dict, "headangles": headangle_dict})


if "pred_dict" not in st.session_state:
    st.session_state.pred_dict = {}

# Resolve model paths relative to this file so the app works regardless of
# the working directory it is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
onnx_path = os.path.join(current_dir, "assets", "onnx_models")
hubert_path = os.path.join(onnx_path, "hubert.onnx")
encoder_path = os.path.join(onnx_path, "encoder.onnx")
decoder_path = os.path.join(onnx_path, "decoder.onnx")


@st.cache_resource
def load_pipeline(hubert, encoder, decoder):
    # Cache the pipeline so the ONNX models are loaded once per process
    # instead of on every Streamlit rerun.
    return init_pipeline(hubert, encoder, decoder)


pipeline = load_pipeline(hubert_path, encoder_path, decoder_path)

col1, col2 = st.columns([2, 3])

with col1:
    with st.container(border=True):
        audio_tab, control_tab, vrm_tab = st.tabs(["Audio", "Controls", "Upload VRM"])

        with audio_tab:
            recorded_value = st.audio_input("Record audio")
            st.write("Or")
            uploaded_value = st.file_uploader("Upload audio", type=["wav"])
            # Prefer a fresh recording over an uploaded file.
            audio_value = (
                recorded_value if recorded_value is not None else uploaded_value
            )

        with control_tab:
            mouth_exaggeration = st.number_input("Lower face exaggeration", value=5.0)
            brow_exaggeration = st.number_input("Upper face exaggeration", value=4.0)
            head_wiggle_exaggeration = st.number_input(
                "Head wiggle exaggeration", value=2.0
            )
            unsquinch_fix = st.number_input("Unsquinch fix", value=0.75)
            eye_contact_fix = st.number_input("Eye contact fix", value=1.5)
            exaggerate_above = st.number_input(
                "Exaggerate above",
                value=0.01,
                min_value=0.0,
                max_value=1.0,
                step=0.001,
                format="%.3f",
            )
            symmetrize_eyes = st.checkbox("Symmetrize eyes", value=True)

        with vrm_tab:
            vrm_file = st.file_uploader("Upload VRM file", type=["vrm"])
            if vrm_file:
                # Read the raw bytes from the uploaded file and store them in
                # session state so the avatar persists across reruns.
                st.session_state.pred_dict["vrm_file"] = vrm_file.read()

    submit_button = st.button("Run Inference", disabled=audio_value is None)

if submit_button and audio_value:
    # Resample to 16 kHz mono, the input format HuBERT expects.
    audio_segment = (
        pydub.AudioSegment.from_file(audio_value)
        .set_frame_rate(16000)
        .set_channels(1)
    )
    audio_array = np.array(audio_segment.get_array_of_samples())
    blendshapes, head_angles, mean_step_time, mean_rtf, time_to_first_sound = (
        pipeline.infer_audio_array(
            audio_array,
            32000,
            48000,
            mouth_exaggeration,
            brow_exaggeration,
            head_wiggle_exaggeration,
            unsquinch_fix,
            eye_contact_fix,
            exaggerate_above,
            symmetrize_eyes,
        )
    )
    st.session_state.pred_dict["blendshapes"] = blendshapes
    st.session_state.pred_dict["head_angles"] = head_angles
    st.session_state.pred_dict["audio_data"] = audio_value.getvalue()

    with col1:
        st.write(f"Inference complete at {mean_rtf:.2f}x real-time.")
        st.download_button(
            label="Download results as JSON",
            data=make_downloadable_json(blendshapes, head_angles),
            file_name="inference_results.json",
            mime="application/json",
        )

with col2:
    model_demo(
        blendshapes=st.session_state.pred_dict.get("blendshapes"),
        headangles=st.session_state.pred_dict.get("head_angles"),
        audio_data=st.session_state.pred_dict.get("audio_data"),
        vrm_data=st.session_state.pred_dict.get("vrm_file"),
        key="model_viewport",
    )
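
# A minimal way to try the demo locally, assuming this script is saved as
# streamlit_app.py (hypothetical filename) alongside an assets/onnx_models/
# directory containing the three exported ONNX models:
#
#   streamlit run streamlit_app.py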