# (removed copy/paste UI artifacts — "Spaces:" / "Running" — left over from extraction)
# tts_utils.py
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
# Updated load_model function in tts_utils.py
def load_model():
    """Load the Indic Parler-TTS model, dynamically quantized for CPU inference,
    together with its two tokenizers.

    Returns:
        tuple: ``(quantized_model, tokenizer, description_tokenizer)`` where
        ``tokenizer`` encodes the text to be spoken and ``description_tokenizer``
        encodes the voice-description prompt.
    """
    model = ParlerTTSForConditionalGeneration.from_pretrained(
        "ai4bharat/indic-parler-tts",
        torch_dtype=torch.float32,  # quantize_dynamic requires a float32 CPU model
    )
    # Dynamic int8 quantization of the Linear layers shrinks the model and speeds
    # up CPU inference. NOTE(review): a dynamically-quantized model runs on CPU
    # only — do not move it to CUDA.
    quantized_model = torch.ao.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},  # target layer type
        dtype=torch.qint8,
    )
    tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indic-parler-tts")
    # FIX: the description (voice prompt) tokenizer must match the model's text
    # encoder (FLAN-T5), not the TTS repo's prompt tokenizer — per the official
    # Indic Parler-TTS usage example.
    description_tokenizer = AutoTokenizer.from_pretrained(
        model.config.text_encoder._name_or_path
    )
    return quantized_model, tokenizer, description_tokenizer
def generate_speech(text, voice_prompt, model, tokenizer, description_tokenizer,
                    max_new_tokens=1024):
    """Generate speech audio for ``text`` in the voice described by ``voice_prompt``.

    Args:
        text: The text to be spoken.
        voice_prompt: Natural-language description of the desired voice.
        model: A Parler-TTS conditional-generation model.
        tokenizer: Tokenizer for ``text``.
        description_tokenizer: Tokenizer for ``voice_prompt``.
        max_new_tokens: Upper bound on generated audio tokens (default 1024,
            matching the previous hard-coded value).

    Returns:
        A squeezed 1-D numpy array of generated audio samples.
    """
    # NOTE(review): if the model was dynamically quantized (qint8), moving it to
    # CUDA will fail — confirm callers pass a CPU model when CUDA is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    description_inputs = description_tokenizer(
        voice_prompt,
        return_tensors="pt"
    ).to(device)
    prompt_inputs = tokenizer(text, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        generation = model.generate(
            input_ids=description_inputs.input_ids,
            attention_mask=description_inputs.attention_mask,
            prompt_input_ids=prompt_inputs.input_ids,
            prompt_attention_mask=prompt_inputs.attention_mask,
            max_new_tokens=max_new_tokens,
        )
    return generation.cpu().numpy().squeeze()