junxiliu commited on
Commit
44f10fc
·
1 Parent(s): ac87bd7

update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -32,7 +32,7 @@ if torch.cuda.is_available():
32
  setup_eval_logging()
33
  OUTPUT_DIR = Path("./output/gradio")
34
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
35
- NUM_SAMPLE=5
36
  snapshot_download(repo_id="google/flan-t5-large")
37
  a=AutoModel.from_pretrained('bert-base-uncased')
38
  b=AutoModel.from_pretrained('roberta-base')
@@ -207,6 +207,8 @@ def generate_audio_gradio(
207
  scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
208
  audio_embed,
209
  dim=-1)
 
 
210
  audio=audios[torch.argmax(scores).item()].float().cpu()
211
  safe_prompt = (
212
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
 
32
  setup_eval_logging()
33
  OUTPUT_DIR = Path("./output/gradio")
34
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
35
+ NUM_SAMPLE=8
36
  snapshot_download(repo_id="google/flan-t5-large")
37
  a=AutoModel.from_pretrained('bert-base-uncased')
38
  b=AutoModel.from_pretrained('roberta-base')
 
207
  scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
208
  audio_embed,
209
  dim=-1)
210
+ log.info(scores)
211
+ log.info(torch.argmax(scores).item())
212
  audio=audios[torch.argmax(scores).item()].float().cpu()
213
  safe_prompt = (
214
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))