import asyncio
import functools
import io
import logging
import random
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, AsyncGenerator, Dict, Any

import librosa
import numpy as np
import sounddevice as sd
import torch
import torchaudio
from torch import nn
from IPython.display import Audio, display

from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt, RequestOutput
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter

from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

from .xtts2_config import XTTSConfig, XTTSGPTConfig
from .tokenizer import XTTSTokenizerFast

from ..xtts2_gpt.xtts2_gpt_modeling import LearnedPositionEmbeddings


def wav_to_mel_cloning(
    wav,
    mel_norms_file="../experiments/clips_mel_norms.pth",
    mel_norms=None,
    device=torch.device("cpu"),
    n_fft=4096,
    hop_length=1024,
    win_length=4096,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
):
    """Convert a waveform into a normalized log-mel spectrogram for voice cloning."""
    mel_stft = torchaudio.transforms.MelSpectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=power,
        normalized=normalized,
        sample_rate=sample_rate,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    wav = wav.to(device)
    mel = mel_stft(wav)
    mel = torch.log(torch.clamp(mel, min=1e-5))
    if mel_norms is None:
        mel_norms = torch.load(mel_norms_file, map_location=device)
    mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
    return mel


def load_audio(audiopath, sampling_rate):
    """Load an audio file, downmix to mono, resample, and clamp to [-1, 1]."""
    audio, lsr = torchaudio.load(audiopath)

    # Downmix multi-channel audio to mono.
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Resample to the requested sampling rate if needed.
    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    audio.clip_(-1, 1)
    return audio


@dataclass
class XTTSRequest:
    """Container for XTTS inference request data."""
    request_id: str
    text: Union[AsyncGenerator[str, None], str]
    language: str
    speaker_file: str
    generate_every_n_chars: Optional[int] = None
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 5.0
    length_penalty: float = 1.0
    do_sample: bool = True
    max_ref_length: int = 60
    gpt_cond_len: int = 30
    gpt_cond_chunk_len: int = 4


class HiddenStatesCollector:
    """Thread-safe collector that accumulates hidden states emitted per request."""

    def __init__(self):
        self.outputs = {}
        self.lock = threading.Lock()

    def __call__(self, outputs: Optional[torch.Tensor], request_id: str):
        """Save outputs for a specific request"""
        with self.lock:
            if request_id not in self.outputs:
                self.outputs[request_id] = []
            self.outputs[request_id].append(outputs)

    def get_hidden_states(self, request_id) -> Optional[torch.Tensor]:
        """Pop and concatenate every hidden-state tensor collected for a request."""
        with self.lock:
            outputs = self.outputs.pop(request_id, None)
            if outputs is not None:
                # Drop any None placeholders before concatenating.
                outputs = [o for o in outputs if o is not None]
                outputs = torch.cat(outputs, dim=0) if outputs else None
            return outputs

    def bind_to_request(self, request_id: str):
        """Return a collector callback bound to a fixed request id."""
        def bound_collector(outputs: Optional[torch.Tensor], _request_id: str = None):
            self(outputs, request_id)
        return bound_collector


class ExtendedSamplingParams(SamplingParams, kw_only=True):
    """Extended sampling parameters that allow additional fields while remaining compatible with SamplingParams.

    This class inherits from SamplingParams and adds new required fields
    without conflicting with the base class's optional-field ordering.
    """
    hidden_state_collector: HiddenStatesCollector


class LogitsRepetitionPenalizer:
    """A logits processor that applies a repetition penalty to discourage repetitive text generation."""

    def __init__(self, repetition_penalty: float):
        if repetition_penalty < 0:
            raise ValueError("Repetition penalty must be non-negative")
        self.repetition_penalty = repetition_penalty

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Apply repetition penalty to the logits based on previous tokens."""
        if self.repetition_penalty == 1.0 or not token_ids:
            return logits

        # Gather the logits of every token that has already been generated.
        repeated_tokens = torch.tensor(token_ids,
                                       device=logits.device,
                                       dtype=torch.long)
        repeated_logits = logits[repeated_tokens]

        # Penalize: shrink positive logits, amplify negative ones.
        repeated_logits = torch.where(
            repeated_logits > 0,
            repeated_logits / self.repetition_penalty,
            repeated_logits * self.repetition_penalty
        )

        logits[repeated_tokens] = repeated_logits
        return logits


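# Illustrative sketch (values are examples only): the penalizer plugs into vLLM as a
# logits processor, while the engine's built-in repetition_penalty is left at 1.0 so
# the two mechanisms do not compound:
#
#     params = SamplingParams(
#         logits_processors=[LogitsRepetitionPenalizer(5.0)],
#         repetition_penalty=1.0,
#     )

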
@dataclass
class XTTSOutput:
    """Container for XTTS inference output with integrated audio utilities"""
    request_id: str
    wav: np.ndarray
    sample_rate: int = 24000

    def to_tensor(self) -> torch.Tensor:
        """Convert numpy array to torch tensor"""
        if isinstance(self.wav, np.ndarray):
            return torch.from_numpy(self.wav)
        return self.wav

    def to_bytes(self, format: str = 'wav', sample_width: int = 2) -> bytes:
        """Convert audio to bytes format.

        Args:
            format: Output format ('wav' or 'raw')
            sample_width: Bit depth (1, 2, or 4 bytes per sample)

        Returns:
            Audio data as bytes
        """
        wav_tensor = self.to_tensor()

        # Ensure a channel dimension for torchaudio.
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        # Clamp to the valid [-1, 1] range before encoding.
        wav_tensor = torch.clamp(wav_tensor, -1.0, 1.0)

        if format == 'wav':
            buffer = io.BytesIO()
            torchaudio.save(
                buffer,
                wav_tensor,
                self.sample_rate,
                format="wav",
                encoding="PCM_S" if sample_width == 2 else "PCM_F",
                bits_per_sample=sample_width * 8
            )
            return buffer.getvalue()

        elif format == 'raw':
            # Scale to the integer range implied by the sample width.
            if sample_width == 2:
                wav_tensor = (wav_tensor * 32767).to(torch.int16)
            elif sample_width == 4:
                wav_tensor = (wav_tensor * 2147483647).to(torch.int32)
            else:
                wav_tensor = (wav_tensor * 127).to(torch.int8)
            return wav_tensor.cpu().numpy().tobytes()

        else:
            raise ValueError(f"Unsupported format: {format}")

    def save(self,
             filename: Union[str, Path],
             sample_rate: Optional[int] = None,
             format: Optional[str] = None) -> None:
        """Save audio to file.

        Args:
            filename: Output filename
            sample_rate: Optional new sample rate for resampling
            format: Optional format override (default: inferred from extension)
        """
        wav_tensor = self.to_tensor()
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        # Resample only if a different target sample rate was requested.
        if sample_rate and sample_rate != self.sample_rate:
            wav_tensor = torchaudio.functional.resample(
                wav_tensor,
                orig_freq=self.sample_rate,
                new_freq=sample_rate
            )
        else:
            sample_rate = self.sample_rate

        torchaudio.save(
            filename,
            wav_tensor,
            sample_rate,
            format=format
        )

    def resample(self, new_sample_rate: int) -> 'XTTSOutput':
        """Create new XTTSOutput with resampled audio.

        Args:
            new_sample_rate: Target sample rate

        Returns:
            New XTTSOutput instance with resampled audio
        """
        wav_tensor = self.to_tensor()
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        resampled = torchaudio.functional.resample(
            wav_tensor,
            orig_freq=self.sample_rate,
            new_freq=new_sample_rate
        )

        return XTTSOutput(
            request_id=self.request_id,
            wav=resampled.squeeze().numpy(),
            sample_rate=new_sample_rate
        )

    def get_info(self) -> Tuple[int, int, float]:
        """Get audio information.

        Returns:
            Tuple of (number of samples, sample rate, duration in seconds)
        """
        n_samples = len(self.wav)
        duration = n_samples / self.sample_rate
        return n_samples, self.sample_rate, duration

    @classmethod
    def from_tensor(cls, request_id: str, tensor: torch.Tensor, sample_rate: int = 24000) -> 'XTTSOutput':
        """Create XTTSOutput from torch tensor.

        Args:
            request_id: Request identifier
            tensor: Audio tensor
            sample_rate: Sample rate of the audio

        Returns:
            New XTTSOutput instance
        """
        return cls(
            request_id=request_id,
            wav=tensor.squeeze().cpu().numpy(),
            sample_rate=sample_rate
        )

    @classmethod
    def from_file(cls, request_id: str, filename: Union[str, Path]) -> 'XTTSOutput':
        """Create XTTSOutput from audio file.

        Args:
            request_id: Request identifier
            filename: Path to audio file

        Returns:
            New XTTSOutput instance
        """
        wav_tensor, sample_rate = torchaudio.load(filename)
        return cls.from_tensor(request_id, wav_tensor, sample_rate)

    def play(self) -> None:
        """Play the audio through the default sound device.
        For use in regular Python scripts/applications."""
        # Convert to a float32 numpy array in [-1, 1], as expected by sounddevice.
        if isinstance(self.wav, torch.Tensor):
            audio_data = self.wav.cpu().numpy()
        else:
            audio_data = self.wav

        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        audio_data = np.clip(audio_data, -1.0, 1.0)

        # Blocking playback.
        sd.play(audio_data, self.sample_rate)
        sd.wait()

    def display(self) -> Optional[Audio]:
        """Display an audio player in a Jupyter notebook.
        Returns the Audio widget if in a notebook, None otherwise."""
        try:
            audio_bytes = self.to_bytes(format='wav')
            audio_widget = Audio(audio_bytes, rate=self.sample_rate, autoplay=False)
            display(audio_widget)
            return audio_widget
        except Exception as e:
            print(f"Could not display audio widget: {str(e)}")
            print("Try using .play() method instead")
            return None

    def preview(self) -> None:
        """Play the audio with the most appropriate method for the current environment."""
        try:
            # Prefer the notebook widget; fall back to local playback.
            if self.display() is None:
                self.play()
        except Exception as e:
            print(f"Error playing audio: {str(e)}")


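# Illustrative sketch of the XTTSOutput utilities above (assumes `output` is an
# XTTSOutput produced by the model; here one second of silence at 24 kHz):
#
#     output = XTTSOutput(request_id="demo", wav=np.zeros(24000, dtype=np.float32))
#     output.save("demo.wav")                    # write WAV at the native 24 kHz
#     output_16k = output.resample(16000)        # new instance at 16 kHz
#     pcm = output.to_bytes(format='raw')        # 16-bit PCM bytes for streaming
#     n_samples, sr, seconds = output.get_info()

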
class Xtts(nn.Module):
    """Async XTTS model implementation using vLLM's AsyncLLMEngine."""

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
        super().__init__()

        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.mel_bos_token_id = gpt_config.start_audio_token
        self.mel_eos_token_id = gpt_config.stop_audio_token
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.hidden_states_collector = HiddenStatesCollector()

        # Mel normalization statistics (populated from the pretrained checkpoint).
        self.register_buffer("mel_stats", torch.ones(80))

        # Conditioning encoder that turns reference mel spectrograms into style embeddings.
        self.conditioning_encoder = ConditioningEncoder(
            gpt_config.audio_config.mel_channels,
            gpt_config.hidden_size,
            num_attn_heads=gpt_config.num_attention_heads
        )

        self.text_embedding = nn.Embedding(
            gpt_config.number_text_tokens,
            gpt_config.hidden_size
        )

        self.text_pos_embedding = (
            LearnedPositionEmbeddings(
                gpt_config.max_text_tokens + 2,
                gpt_config.hidden_size,
                supports_pp=False
            )
            if gpt_config.max_audio_tokens != -1
            else functools.partial(gpt_config.null_position_embeddings, dim=gpt_config.hidden_size)
        )

        if gpt_config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=gpt_config.hidden_size,
                depth=2,
                dim_context=gpt_config.hidden_size,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        # HiFi-GAN decoder that converts GPT latents into waveforms.
        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )

        self.text_head = nn.Linear(gpt_config.hidden_size, gpt_config.number_text_tokens, bias=True)
        self.final_norm = nn.LayerNorm(gpt_config.hidden_size, eps=1e-5, bias=True)

        # Start the vLLM async engine that runs the GPT part of the model.
        self.init_vllm_engine()

        # Limit the number of concurrent conditioning computations.
        self.max_concurrency = 10
        self.semaphore = asyncio.BoundedSemaphore(self.max_concurrency)

    def half(self):
        # No-op: the vLLM engine manages its own precision and the torch-side
        # modules are kept in float32 (see `to`).
        return self

    def to(self, *args, **kwargs):
        # Intercept .to() calls and force float32: half precision is not supported
        # by the torch-side modules (the vLLM engine handles its own dtype).
        dtype = kwargs.get('dtype', None)
        if dtype == torch.float16 or dtype == torch.bfloat16:
            kwargs['dtype'] = torch.float32
        elif len(args) > 0 and (args[0] == torch.float16 or args[0] == torch.bfloat16):
            args = list(args)
            args[0] = torch.float32
            args = tuple(args)
        return super().to(*args, **kwargs)

    @property
    def device(self):
        """Get the current device of the model."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Get the current dtype of the model."""
        return next(self.parameters()).dtype

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Return the requested number of bytes as a fraction of currently available GPU memory."""
        total_memory = torch.cuda.get_device_properties(0).total_memory
        reserved_memory = torch.cuda.memory_reserved(0)
        allocated_memory = torch.cuda.memory_allocated(0)
        available_memory = total_memory - reserved_memory - allocated_memory
        return memory / available_memory

    def init_vllm_engine(self):
        """Initialize the GPT model with vLLM's AsyncLLMEngine."""
        engine_args = AsyncEngineArgs(
            model="AstraMindAI/xtts2-gpt",
            tensor_parallel_size=self.tp,
            dtype="auto",
            disable_log_stats=True,
            max_model_len=self.gpt_config.max_text_tokens + self.gpt_config.max_audio_tokens,
            gpu_memory_utilization=self.get_memory_percentage(3 * 1024 ** 3),
            trust_remote_code=True,
            enforce_eager=True,
            limit_mm_per_prompt={"audio": 1},
            max_num_batched_tokens=7296,
        )

        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: str,
            torch_dtype: torch.dtype = torch.float32,
            device_map: Optional[str] = "auto",
            tensor_parallel_size: int = 1,
            **kwargs,
    ) -> "Xtts":
        """Load a pretrained XTTS model from the HuggingFace Hub or a local directory."""
        from huggingface_hub import hf_hub_download
        import json
        import os

        # Load the model configuration, downloading it from the Hub if needed.
        if not os.path.exists(pretrained_model_name_or_path):
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="config.json"
            )
            with open(config_file, 'r') as f:
                config = json.load(f)
        else:
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
                config = json.load(f)

        gpt_config = XTTSGPTConfig(**config['gpt_config'])
        hifi_config = XTTSConfig(**config)

        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs
        )

        # Resolve and load the HiFi-GAN / conditioning weights.
        if not os.path.exists(pretrained_model_name_or_path):
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xtts-v2.safetensors"
            )
        else:
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xtts-v2.safetensors")

        import safetensors.torch

        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.load_state_dict(hifigan_state)

        model.config = config

        # The torch-side modules stay in float32 (see `to`); move them to the GPU.
        model = model.to(torch_dtype)
        model = model.to('cuda')

        return model

    @staticmethod
    def load_audio(audio_path: Union[str, Path], sampling_rate: int = 22050) -> torch.Tensor:
        """Load an audio file, downmix to mono, resample, and clamp to [-1, 1]."""
        audio, lsr = torchaudio.load(audio_path)

        # Downmix multi-channel audio to mono.
        if audio.size(0) != 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Resample to the requested sampling rate if needed.
        if lsr != sampling_rate:
            audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

        audio.clip_(-1, 1)
        return audio

    @torch.inference_mode()
    def get_speaker_embedding(self, audio, sr):
        """Compute the speaker embedding from a reference waveform."""
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
        return (
            self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
            .unsqueeze(-1)
            .to(self.device)
        )

    @torch.inference_mode()
    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
        """Compute the conditioning latents for the GPT model from the given audio."""
        if sr != 22050:
            audio = torchaudio.functional.resample(audio, sr, 22050)
        if length > 0:
            audio = audio[:, : 22050 * length]
        if self.gpt_config.use_perceiver_resampler:
            # Encode the reference audio in chunks and average the style embeddings.
            style_embs = []
            for i in range(0, audio.shape[1], 22050 * chunk_length):
                audio_chunk = audio[:, i: i + 22050 * chunk_length]

                # Skip chunks shorter than ~0.33 seconds.
                if audio_chunk.size(-1) < 22050 * 0.33:
                    continue

                mel_chunk = wav_to_mel_cloning(
                    audio_chunk,
                    mel_norms=self.mel_stats.cpu(),
                    n_fft=2048,
                    hop_length=256,
                    win_length=1024,
                    power=2,
                    normalized=False,
                    sample_rate=22050,
                    f_min=0,
                    f_max=8000,
                    n_mels=80,
                )
                style_emb = self.get_style_emb(mel_chunk.to(self.device), None)
                style_embs.append(style_emb)

            cond_latent = torch.stack(style_embs).mean(dim=0)
        else:
            mel = wav_to_mel_cloning(
                audio,
                mel_norms=self.mel_stats.cpu(),
                n_fft=4096,
                hop_length=1024,
                win_length=4096,
                power=2,
                normalized=False,
                sample_rate=22050,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )
            cond_latent = self.get_style_emb(mel.to(self.device))
        return cond_latent.transpose(1, 2)

    @torch.inference_mode()
    def get_conditioning_latents(
            self,
            audio_path,
            max_ref_length=30,
            gpt_cond_len=6,
            gpt_cond_chunk_len=6,
            librosa_trim_db=None,
            sound_norm_refs=False,
            load_sr=22050,
    ):
        """Get the conditioning latents for the GPT model from the given audio."""
        assert isinstance(audio_path, str) or isinstance(audio_path, list), "audio_path must be a string or a list."

        if not isinstance(audio_path, list):
            audio_paths = [audio_path]
        else:
            audio_paths = audio_path

        speaker_embeddings = []
        audios = []
        for file_path in audio_paths:
            audio = load_audio(file_path, load_sr)
            audio = audio[:, : load_sr * max_ref_length].to(self.device).to(self.dtype)
            if sound_norm_refs:
                audio = (audio / torch.abs(audio).max()) * 0.75
            if librosa_trim_db is not None:
                audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

            # One speaker embedding per reference clip; averaged below.
            speaker_embedding = self.get_speaker_embedding(audio, load_sr)
            speaker_embeddings.append(speaker_embedding)

            audios.append(audio)

        # Concatenate all references and compute the GPT conditioning latents once.
        full_audio = torch.cat(audios, dim=-1)
        gpt_cond_latents = self.get_gpt_cond_latents(
            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
        )

        speaker_embedding = torch.stack(speaker_embeddings)
        speaker_embedding = speaker_embedding.mean(dim=0)

        return gpt_cond_latents, speaker_embedding

    def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor:
        """Get conditioning embeddings from mel spectrograms."""
        if not return_latent:
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)

            if hasattr(self, 'conditioning_perceiver'):
                conds = self.conditioning_perceiver(
                    conds.permute(0, 2, 1)
                ).transpose(1, 2)
        else:
            # The input is already a conditioning latent.
            conds = cond_input.unsqueeze(1)
        return conds

    async def prepare_text_tokens_async(self, text: str, language: str, split_text=False) \
            -> Tuple[List[Union[int, List[int]]], List[torch.Tensor]]:
        """Prepare text tokens for the given text and language."""

        async def elaborate_tokens(text_tokens: List[int]) -> torch.Tensor:
            # Wrap the token sequence with BOS/EOS and move it to the embedding device.
            text_tokens.insert(0, self.tokenizer.bos_token_id)
            text_tokens.append(self.tokenizer.eos_token_id)
            return torch.tensor(text_tokens).unsqueeze(0).to(self.text_embedding.weight.device)

        async def embed_tokens(text_tokens: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
            embeds = []
            if isinstance(text_tokens, list):
                for list_element in text_tokens:
                    embeds.append(self.text_embedding(list_element) + self.text_pos_embedding(list_element))
            else:
                embeds.append(self.text_embedding(text_tokens) + self.text_pos_embedding(text_tokens))
            return embeds

        fake_tokens_for_audio_generation = []
        if split_text:
            text_tokens = self.tokenizer.batch_encode_with_split(text, lang=[language])
            for idx, text_token in enumerate(text_tokens):
                text_tokens[idx] = await elaborate_tokens(text_token)
                fake_tokens_for_audio_generation.append([1] * len(text_token))
        else:
            text_tokens = self.tokenizer.batch_encode(text, lang=[language])
            text_tokens = await elaborate_tokens(text_tokens)
            fake_tokens_for_audio_generation = [1] * len(text_tokens)
        return fake_tokens_for_audio_generation, await embed_tokens(text_tokens)

    async def prepare_inputs_async(self, text: str, language: str, speaker_file: Union[str, Path],
                                   max_ref_length: int, gpt_cond_len: int, gpt_cond_chunk_len: int, split_text: bool) \
            -> Tuple[List[List[int]], List[torch.Tensor], torch.Tensor]:
        """Prepare input text with conditioning tokens. Return combined conditioning latents."""
        text_tokens, text_embeddings = await self.prepare_text_tokens_async(text, language, split_text)

        gpt_cond_latent, speaker_embeddings = await self.get_conditioning_latents_async(
            speaker_file,
            max_ref_length,
            gpt_cond_len,
            gpt_cond_chunk_len
        )

        # Prepend the speaker conditioning latents to each text embedding sequence.
        cond_latents = []
        for text_embedding in text_embeddings:
            cond_latents.append((torch.cat([gpt_cond_latent, text_embedding], dim=1).squeeze(0)
                                 .to(self.llm_engine.engine.model_config.dtype)))

        return text_tokens, cond_latents, speaker_embeddings

    async def get_conditioning_latents_async(
            self,
            audio_path,
            max_ref_length=30,
            gpt_cond_len=6,
            gpt_cond_chunk_len=6,
            librosa_trim_db=None,
            sound_norm_refs=False,
            load_sr=22050,
    ):
        """Async version of get_conditioning_latents with concurrency control."""
        async with self.semaphore:
            # Run the blocking conditioning computation in the default executor.
            result = await asyncio.get_event_loop().run_in_executor(
                None,
                functools.partial(self.get_conditioning_latents,
                                  audio_path,
                                  max_ref_length,
                                  gpt_cond_len,
                                  gpt_cond_chunk_len,
                                  librosa_trim_db,
                                  sound_norm_refs,
                                  load_sr)
            )
            return result

    async def get_model_logits(self, token_ids: List[int], conditioning: MultiModalDataDict) -> torch.Tensor:
        """Run a single forward pass over `token_ids` and return the collected hidden states."""
        request_id = uuid.uuid4().hex

        # Wrap the audio tokens with the start token and a few stop tokens.
        token_ids = [self.mel_bos_token_id] + token_ids + [self.mel_eos_token_id] * 5

        engine_inputs = TokensPrompt(prompt_token_ids=token_ids)
        engine_inputs["multi_modal_data"] = conditioning

        # Bind the collector to this request so its hidden states can be retrieved later.
        bound_collector = self.hidden_states_collector.bind_to_request(request_id)

        sampling_params = ExtendedSamplingParams(
            detokenize=False,
            max_tokens=1,
            hidden_state_collector=bound_collector,
        )

        generator = self.llm_engine.generate(
            prompt=engine_inputs,
            sampling_params=sampling_params,
            request_id=request_id
        )

        # Consume the generator to drive the forward pass that fills the collector.
        try:
            async def consume_generator():
                async for _ in generator:
                    pass

            await asyncio.wait_for(consume_generator(), timeout=300)
        except asyncio.TimeoutError:
            raise RuntimeError("Timeout while generating logits")

        hidden_states = self.hidden_states_collector.get_hidden_states(request_id)

        if hidden_states is None:
            raise RuntimeError(f"No hidden states collected for request {request_id}")

        return hidden_states[-len(token_ids):, ...].unsqueeze(0).to(self.device).to(self.dtype)

    async def process_tokens_to_speech(
            self,
            generators: List[AsyncGenerator[RequestOutput, None]],
            speaker_embeddings: torch.Tensor,
            multimodal_data: List[torch.Tensor],
            chunk_size: int = 20,
    ) -> AsyncGenerator[XTTSOutput, None]:
        """
        Process multiple token generators concurrently and emit results sequentially.
        Uses a queue-based approach to handle multiple generators reliably.
        """
        # One queue per generator; each worker task feeds its own queue.
        queues = [asyncio.Queue() for _ in generators]

        tasks = []
        for i, generator in enumerate(generators):
            task = asyncio.create_task(
                self._process_single_generator(
                    generator,
                    queues[i],
                    speaker_embeddings,
                    multimodal_data[i],
                    chunk_size
                )
            )
            tasks.append(task)

        try:
            # Drain the queues in order so outputs are yielded sequentially.
            for i, queue in enumerate(queues):
                while True:
                    result = await queue.get()
                    if result is None:
                        # Sentinel: this generator is exhausted.
                        break
                    else:
                        yield result
        finally:
            # Cancel any tasks that are still running and wait for them to finish.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _process_single_generator(
            self,
            generator: AsyncGenerator[RequestOutput, None],
            queue: asyncio.Queue,
            speaker_embeddings: torch.Tensor,
            gpt_embed_input: torch.Tensor,
            chunk_size: int
    ) -> None:
        """Process a single generator and put results in its queue."""
        try:
            last_decoded_token = 0
            accumulated_tokens = []

            async for output in generator:
                # Accumulate the newly generated audio tokens.
                new_tokens = output.outputs[0].token_ids[last_decoded_token:]
                accumulated_tokens.extend(new_tokens)
                last_decoded_token = len(accumulated_tokens)

                if output.finished:
                    # Run a logits-only pass to obtain the GPT hidden states for the
                    # full accumulated token sequence.
                    hidden_states = await self.get_model_logits(
                        accumulated_tokens,
                        {
                            "audio": {
                                'embeds': gpt_embed_input,
                                "is_logits_only_mode": True
                            }
                        }
                    )

                    # Decode the hidden states to a waveform off the event loop.
                    wav = await asyncio.get_event_loop().run_in_executor(
                        self.executor,
                        lambda: self.hifigan_decoder.inference(
                            hidden_states,
                            g=speaker_embeddings
                        ).cpu().numpy().squeeze()
                    )

                    await queue.put(XTTSOutput(
                        request_id=output.request_id,
                        wav=wav
                    ))

                    accumulated_tokens = []

                if output.finished:
                    break

        except Exception as e:
            logging.error(f"Error in generator processing: {e}")
        finally:
            # Sentinel so the consumer knows this generator is done.
            await queue.put(None)

    async def generate_speech_async_from_streaming_source(self, request: XTTSRequest) -> AsyncGenerator[XTTSOutput, None]:
        """Generate speech from a streaming text source: buffer incoming text, stream audio
        tokens for each buffered chunk, and yield the decoded audio as it becomes available."""
        assert isinstance(request.text, AsyncGenerator), "Text must be an AsyncGenerator for streaming source."

        gpt_cond_latent, speaker_embeddings = await self.get_conditioning_latents_async(
            request.speaker_file,
            request.max_ref_length,
            request.gpt_cond_len,
            request.gpt_cond_chunk_len
        )
        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            detokenize=False,
            top_k=request.top_k,
            logits_processors=[LogitsRepetitionPenalizer(request.repetition_penalty)],
            repetition_penalty=1.0,
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            ignore_eos=True,
            stop_token_ids=[self.mel_eos_token_id],
        )

        accumulated_text = ""
        async for text in request.text:
            text = text.strip()
            accumulated_text += text

            # request.generate_every_n_chars must be set for streaming requests.
            if len(accumulated_text) > request.generate_every_n_chars:
                tokens, embeddings = await self.prepare_text_tokens_async(accumulated_text, request.language)
                gpt_embed_input = [torch.cat([gpt_cond_latent, embeddings[0]], dim=0)]

                engine_inputs = TokensPrompt(prompt_token_ids=tokens)
                if gpt_embed_input is not None:
                    engine_inputs["multi_modal_data"] = {"audio": {"embeds": gpt_embed_input, "is_logits_only_mode": False}}
                token_generator = [self.llm_engine.generate(
                    prompt=engine_inputs,
                    sampling_params=sampling_params,
                    request_id=request.request_id,
                )]

                async for output in self.process_tokens_to_speech(
                        token_generator,
                        speaker_embeddings,
                        gpt_embed_input,
                        chunk_size=50
                ):
                    yield output

                accumulated_text = ""

    async def generate_speech_from_text_async(self, request: XTTSRequest) -> AsyncGenerator[XTTSOutput, None]:
        """Generate speech for a single request asynchronously."""
        tokens_list, gpt_embed_inputs, speaker_embeddings = await self.prepare_inputs_async(
            request.text,
            request.language,
            request.speaker_file,
            request.max_ref_length,
            request.gpt_cond_len,
            request.gpt_cond_chunk_len,
            split_text=True
        )

        # Start one vLLM generation per text split.
        generators = []
        for seq_index, sequence in enumerate(tokens_list):
            sampling_params = SamplingParams(
                temperature=request.temperature,
                top_p=request.top_p,
                detokenize=False,
                top_k=request.top_k,
                logits_processors=[LogitsRepetitionPenalizer(request.repetition_penalty)],
                repetition_penalty=1.0,
                max_tokens=self.gpt_config.gpt_max_audio_tokens,
                ignore_eos=True,
                stop_token_ids=[self.mel_eos_token_id],
            )

            engine_inputs = TokensPrompt(prompt_token_ids=sequence)
            if gpt_embed_inputs is not None:
                engine_inputs["multi_modal_data"] = {"audio": {"embeds": gpt_embed_inputs[seq_index], "is_logits_only_mode": False}}

            token_generator = self.llm_engine.generate(
                prompt=engine_inputs,
                sampling_params=sampling_params,
                request_id=f"{request.request_id}_{seq_index}",
            )
            generators.append(token_generator)

        # Decode the audio tokens to waveforms as each generator finishes.
        async for output in self.process_tokens_to_speech(
                generators,
                speaker_embeddings,
                gpt_embed_inputs,
                chunk_size=50
        ):
            yield output

    def generate_speech_from_text(self, request: XTTSRequest) -> List[XTTSOutput]:
        """
        Synchronous wrapper for generate_speech_from_text_async.

        Args:
            request: XTTSRequest object containing generation parameters

        Returns:
            List of XTTSOutput containing the generated speech segments
        """

        async def _collect_outputs():
            outputs = []
            async for output in self.generate_speech_from_text_async(request):
                outputs.append(output)
            return outputs

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            # An event loop is already running in this thread (e.g. inside Jupyter):
            # run the coroutine on a worker thread with its own event loop instead.
            results = self.executor.submit(asyncio.run, _collect_outputs()).result()
        else:
            results = loop.run_until_complete(_collect_outputs())

        return results
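

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the library API. The repo id and the
# reference clip "speaker.wav" are placeholders; adjust both for your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Xtts.from_pretrained("AstraMindAI/xtts-v2")  # hypothetical repo id

    request = XTTSRequest(
        request_id="demo-0",
        text="Hello from the async XTTS engine.",
        language="en",
        speaker_file="speaker.wav",
    )

    # Synchronous helper: collects every generated segment into a list.
    for i, segment in enumerate(model.generate_speech_from_text(request)):
        segment.save(f"output_{i}.wav")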