update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,7 @@ if torch.cuda.is_available():
|
|
32 |
setup_eval_logging()
|
33 |
OUTPUT_DIR = Path("./output/gradio")
|
34 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
35 |
-
NUM_SAMPLE=
|
36 |
snapshot_download(repo_id="google/flan-t5-large")
|
37 |
a=AutoModel.from_pretrained('bert-base-uncased')
|
38 |
b=AutoModel.from_pretrained('roberta-base')
|
@@ -207,6 +207,8 @@ def generate_audio_gradio(
|
|
207 |
scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
|
208 |
audio_embed,
|
209 |
dim=-1)
|
|
|
|
|
210 |
audio=audios[torch.argmax(scores).item()].float().cpu()
|
211 |
safe_prompt = (
|
212 |
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|
|
|
32 |
setup_eval_logging()
|
33 |
OUTPUT_DIR = Path("./output/gradio")
|
34 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
35 |
+
NUM_SAMPLE=8
|
36 |
snapshot_download(repo_id="google/flan-t5-large")
|
37 |
a=AutoModel.from_pretrained('bert-base-uncased')
|
38 |
b=AutoModel.from_pretrained('roberta-base')
|
|
|
207 |
scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
|
208 |
audio_embed,
|
209 |
dim=-1)
|
210 |
+
log.info(scores)
|
211 |
+
log.info(torch.argmax(scores).item())
|
212 |
audio=audios[torch.argmax(scores).item()].float().cpu()
|
213 |
safe_prompt = (
|
214 |
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|