# VoiceStar / inference_commandline.py
import os
import torch
import torchaudio
import numpy as np
import random
import whisper
import fire
from argparse import Namespace
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
from models import voice_star
from inference_tts_utils import inference_one_sample
############################################################
# Utility Functions
############################################################
def seed_everything(seed=1):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
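    # Deterministic cuDNN settings below trade some speed for run-to-run reproducibility.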
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def estimate_duration(ref_audio_path, ref_text, target_text):
    """
    Estimate the target duration from the reference audio's seconds-per-character
    rate, i.e. assume the target is spoken at the same rate as the reference.
    """
    info = torchaudio.info(ref_audio_path)
    audio_duration = info.num_frames / info.sample_rate
    spc = audio_duration / max(len(ref_text), 1)  # seconds per character of the reference
    return len(target_text) * spc
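# Worked example (illustrative numbers, not from the original snippet): an 8 s
# reference whose transcript is 100 characters gives 0.08 s/char, so a
# 120-character target text is estimated at 120 * 0.08 = 9.6 seconds.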
############################################################
# Main Inference Function
############################################################
def run_inference(
reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="I cannot believe that the same model can do text to speech synthesis too! And you know what? This audio is 8 seconds long.",
# Model
    model_name="VoiceStar_840M_30s",  # or VoiceStar_840M_40s; the latter is trained on speech up to 40 s long
model_root="./pretrained",
# Additional optional
reference_text=None, # if None => run whisper on reference_speech
target_duration=None, # if None => estimate from reference_speech and target_text
# Default hyperparameters from snippet
codec_audio_sr=16000, # do not change
codec_sr=50, # do not change
top_k=10, # try 10, 20, 30, 40
top_p=1, # do not change
min_p=1, # do not change
temperature=1,
silence_tokens=None, # do not change it
kvcache=1, # if OOM, set to 0
multi_trial=None, # do not change it
    repeat_prompt=1,  # increase this to improve speaker similarity; but if the total reference speech duration plus the target duration exceeds the maximal training duration, quality may drop
    stop_repetition=3,  # not used
sample_batch_size=1, # do not change
# Others
seed=1,
output_dir="./generated_tts",
# Some snippet-based defaults
    cut_off_sec=100,  # do not adjust this; we always use the entire reference speech. If you do change it, also update reference_text so that it covers only the portion of speech that remains
):
"""
Inference script using Fire.
Example:
python inference_commandline.py \
--reference_speech "./demo/5895_34622_000026_000002.wav" \
--target_text "I cannot believe ... this audio is 10 seconds long." \
--reference_text "(optional) text to use as prefix" \
--target_duration (optional float)
"""
# Seed everything
seed_everything(seed)
# Load model, phn2num, and args
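    # Allow-list argparse.Namespace so torch.load(..., weights_only=True) can
    # unpickle the training args stored inside the checkpoint.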
torch.serialization.add_safe_globals([Namespace])
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt_fn = os.path.join(model_root, model_name+".pth")
    if not os.path.exists(ckpt_fn):
        # Download the checkpoint with wget (the URL is quoted so the shell
        # does not mangle the "?download=true" query string)
        print(f"[Info] Downloading {model_name} checkpoint...")
        os.makedirs(model_root, exist_ok=True)
        os.system(f'wget "https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true" -O {ckpt_fn}')
bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
args = bundle["args"]
phn2num = bundle["phn2num"]
model = voice_star.VoiceStar(args)
model.load_state_dict(bundle["model"])
model.to(device)
model.eval()
# If reference_text not provided, use whisper large-v3-turbo
if reference_text is None:
print("[Info] No reference_text provided, transcribing reference_speech with Whisper.")
wh_model = whisper.load_model("large-v3-turbo")
result = wh_model.transcribe(reference_speech)
prefix_transcript = result["text"]
print(f"[Info] Whisper transcribed text: {prefix_transcript}")
else:
prefix_transcript = reference_text
    # If target_duration not provided, estimate from the reference speaking rate and target_text
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, prefix_transcript, target_text)
        print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration.")
    else:
        target_generation_length = float(target_duration)
# signature from snippet
if args.n_codebooks == 4:
signature = "./pretrained/encodec_6f79c6a8.th"
elif args.n_codebooks == 8:
signature = "./pretrained/encodec_8cb1024_giga.th"
else:
        # fallback: default to the encodec_6f79c6a8 checkpoint
signature = "./pretrained/encodec_6f79c6a8.th"
if silence_tokens is None:
# default from snippet
silence_tokens = []
if multi_trial is None:
# default from snippet
multi_trial = []
delay_pattern_increment = args.n_codebooks + 1 # from snippet
# We can compute prompt_end_frame if we want, from snippet
info = torchaudio.info(reference_speech)
prompt_end_frame = int(cut_off_sec * info.sample_rate)
# Prepare tokenizers
audio_tokenizer = AudioTokenizer(signature=signature)
text_tokenizer = TextTokenizer(backend="espeak")
# decode_config from snippet
decode_config = {
'top_k': top_k,
'top_p': top_p,
'min_p': min_p,
'temperature': temperature,
'stop_repetition': stop_repetition,
'kvcache': kvcache,
'codec_audio_sr': codec_audio_sr,
'codec_sr': codec_sr,
'silence_tokens': silence_tokens,
'sample_batch_size': sample_batch_size
}
# Run inference
print("[Info] Running TTS inference...")
concated_audio, gen_audio = inference_one_sample(
model, args, phn2num, text_tokenizer, audio_tokenizer,
reference_speech, target_text,
device, decode_config,
prompt_end_frame=prompt_end_frame,
target_generation_length=target_generation_length,
delay_pattern_increment=delay_pattern_increment,
prefix_transcript=prefix_transcript,
multi_trial=multi_trial,
repeat_prompt=repeat_prompt,
)
# The model returns a list of waveforms, pick the first
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
# Save the audio (just the generated portion, as the snippet does)
os.makedirs(output_dir, exist_ok=True)
out_filename = "generated.wav"
out_path = os.path.join(output_dir, out_filename)
torchaudio.save(out_path, gen_audio, codec_audio_sr)
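    # Optional sketch (not part of the original script): also save the
    # reference+generated concatenation returned above, for side-by-side listening.
    # concat_path = os.path.join(output_dir, "concatenated.wav")
    # torchaudio.save(concat_path, concated_audio, codec_audio_sr)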
print(f"[Success] Generated audio saved to {out_path}")
def main():
fire.Fire(run_inference)
if __name__ == "__main__":
main()