tts-1.6b-en_fr / inference.py
rumbleFTW's picture
tts
f2c3e4e verified
import argparse
import numpy as np
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def load_model(device="cpu"):
    """Build a ``TTSModel`` from the default DSM TTS checkpoint repo.

    The checkpoint is resolved with ``CheckpointInfo.from_hf_repo`` using
    ``DEFAULT_DSM_TTS_REPO`` and the model is created with ``n_q=32`` and
    ``temp=0.6``.

    Args:
        device: Target device, given either as a ``torch.device`` or as
            anything ``torch.device`` accepts (e.g. ``"cpu"``, ``"cuda"``).
            Defaults to ``"cpu"``, the value this function used to hard-code,
            so existing zero-argument callers behave exactly as before.

    Returns:
        The model object returned by ``TTSModel.from_checkpoint_info``.
    """
    checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
    return TTSModel.from_checkpoint_info(
        checkpoint_info,
        n_q=32,
        temp=0.6,
        # Normalise strings and existing torch.device objects alike.
        device=torch.device(device),
    )
def generate_audio(tts_model, text, voice):
    """Synthesize ``text`` with the given ``voice`` and return the waveform.

    Args:
        tts_model: A TTS model such as the one returned by ``load_model``. It
            must provide ``prepare_script``, ``get_voice_path``,
            ``make_condition_attributes``, ``generate`` and a ``mimi`` codec
            with ``decode`` and ``streaming``.
        text: The text to speak; passed as a single-element script.
        voice: Voice identifier handed to ``tts_model.get_voice_path``.

    Returns:
        A NumPy array made by concatenating every decoded chunk along the
        last axis, with sample values clipped to ``[-1, 1]``.

    Raises:
        ValueError: If generation produced no decodable frame, so there is
            nothing to concatenate.
    """
    entries = tts_model.prepare_script([text], padding_between=1)
    voice_path = tts_model.get_voice_path(voice)
    condition_attributes = tts_model.make_condition_attributes(
        [voice_path], cfg_coef=2.0
    )

    pcms = []

    def _on_frame(frame):
        # Frames that still contain -1 are skipped.
        # NOTE(review): presumably -1 marks codebook slots not yet generated
        # -- confirm against the moshi TTS implementation.
        if (frame != -1).all():
            # Index 0 on the second axis is dropped before decoding.
            # NOTE(review): presumably that slot is the text stream and the
            # rest are audio codebooks -- confirm.
            pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
            # Only the [0, 0] slice is kept: a single script is generated.
            pcms.append(np.clip(pcm[0, 0], -1, 1))

    all_entries = [entries]
    all_condition_attributes = [condition_attributes]
    with tts_model.mimi.streaming(len(all_entries)):
        tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)

    if not pcms:
        # np.concatenate([]) would raise an opaque ValueError; raise the same
        # exception type with a message that points at the actual cause.
        raise ValueError(
            "TTS generation produced no audio frames; "
            "check that the input text is not empty"
        )
    return np.concatenate(pcms, axis=-1)