Ticoliro committed on
Commit
bb2b924
·
verified ·
1 Parent(s): c7ff3c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -8
app.py CHANGED
@@ -7,12 +7,16 @@ import re
7
 
8
  from parler_tts import ParlerTTSForConditionalGeneration
9
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
 
 
 
10
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
  repo_id = "parler-tts/parler-tts-mini-expresso"
14
 
15
- model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
 
16
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
17
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
18
 
@@ -66,15 +70,30 @@ def preprocess(text):
66
  return text
67
 
68
 
69
- @spaces.GPU
70
- def gen_tts(text, description):
71
- inputs = tokenizer(description, return_tensors="pt").to(device)
72
- prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
73
 
74
- set_seed(SEED)
75
- generation = model.generate(input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids)
76
- audio_arr = generation.cpu().numpy().squeeze()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return SAMPLE_RATE, audio_arr
79
 
80
 
 
7
 
8
  from parler_tts import ParlerTTSForConditionalGeneration
9
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
10
+ from functools import lru_cache
11
+ from torch.cuda.amp import autocast
12
+ import time
13
 
14
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
15
 
16
  repo_id = "parler-tts/parler-tts-mini-expresso"
17
 
18
+ model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.float16).to(device)
19
+ model = torch.compile(model) # Adiciona otimização com torch.compile
20
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
21
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
22
 
 
70
  return text
71
 
72
 
73
+ @lru_cache(maxsize=128)
74
+ def cached_tokenizer(text):
75
+ return tokenizer(text, return_tensors="pt").to(device)
 
76
 
 
 
 
77
 
78
+ @spaces.GPU
79
+ def gen_tts(text, description):
80
+ start_time = time.time()
81
+ with torch.no_grad(): # Desativa gradientes
82
+ inputs = cached_tokenizer(description)
83
+ prompt = cached_tokenizer(preprocess(text))
84
+
85
+ set_seed(SEED)
86
+ with autocast(): # Habilita precisão mista
87
+ generation = model.generate(
88
+ input_ids=inputs.input_ids,
89
+ prompt_input_ids=prompt.input_ids,
90
+ max_length=200, # Limita o comprimento máximo da saída
91
+ num_beams=3 # Usa beam search com 3 feixes
92
+ )
93
+ audio_arr = generation.cpu().numpy().squeeze()
94
+
95
+ end_time = time.time()
96
+ print(f"Generation completed in {end_time - start_time:.2f} seconds")
97
  return SAMPLE_RATE, audio_arr
98
 
99