"""Minimal Moshi DSM text-to-speech generation script (CPU-only)."""

import argparse

import numpy as np
import torch

from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel


def load_model():
    """Load the default DSM TTS model onto the CPU.

    Returns:
        A ``TTSModel`` configured with 32 quantizers (``n_q=32``) and
        sampling temperature 0.6, ready for generation.
    """
    checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
    tts_model = TTSModel.from_checkpoint_info(
        checkpoint_info, n_q=32, temp=0.6, device=torch.device("cpu")
    )
    return tts_model


def generate_audio(tts_model, text, voice):
    """Synthesize ``text`` with ``voice`` and return the audio as a numpy array.

    Args:
        tts_model: A loaded ``TTSModel`` (see :func:`load_model`).
        text: The text to synthesize; passed as a single-entry script.
        voice: Voice identifier resolvable by ``tts_model.get_voice_path``.

    Returns:
        A 1-D numpy array of PCM samples clipped to ``[-1, 1]``.

    Raises:
        RuntimeError: If generation produced no decodable audio frames.
    """
    entries = tts_model.prepare_script([text], padding_between=1)
    voice_path = tts_model.get_voice_path(voice)
    # cfg_coef sets the classifier-free-guidance strength for conditioning.
    condition_attributes = tts_model.make_condition_attributes(
        [voice_path], cfg_coef=2.0
    )

    pcms = []

    def _on_frame(frame):
        # Frames containing -1 appear to be padding/invalid token frames
        # and are skipped — TODO(review): confirm against moshi docs.
        if (frame != -1).all():
            # frame[:, 1:, :] drops the first codebook stream before audio
            # decoding (presumably the non-acoustic stream; verify).
            pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
            pcms.append(np.clip(pcm[0, 0], -1, 1))

    all_entries = [entries]
    all_condition_attributes = [condition_attributes]
    # Open a mimi streaming context with batch size equal to the number of
    # entries (1 here) so decoder state persists across per-frame decodes.
    with tts_model.mimi.streaming(len(all_entries)):
        tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)

    if not pcms:
        # np.concatenate([]) raises an opaque ValueError; fail loudly with
        # a message that points at the actual problem instead.
        raise RuntimeError("TTS generation produced no audio frames")
    return np.concatenate(pcms, axis=-1)