AshwinSankar commited on
Commit
8e05620
·
verified ·
1 Parent(s): a2e8c71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -14
app.py CHANGED
@@ -39,8 +39,10 @@ This Gradio demo showcases **IndicSeamlessM4T**, a fine-tuned **SeamlessM4T** mo
39
  """
40
 
41
  hf_token = os.getenv("HF_TOKEN")
 
 
42
 
43
- model = SeamlessM4Tv2ForSpeechToText.from_pretrained("ai4bharat/seamless-m4t-v2-large-stt", torch_dtype=torch.float16, token=hf_token).to("cuda")
44
  processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/seamless-m4t-v2-large-stt", token=hf_token)
45
  tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/seamless-m4t-v2-large-stt", token=hf_token)
46
 
@@ -50,17 +52,6 @@ AUDIO_SAMPLE_RATE = 16000.0
50
  MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
51
  DEFAULT_TARGET_LANGUAGE = "Hindi"
52
 
53
- if torch.cuda.is_available():
54
- device = torch.device("cuda:0")
55
- dtype = torch.float16
56
- else:
57
- device = torch.device("cpu")
58
- dtype = torch.float32
59
-
60
-
61
-
62
-
63
-
64
  def preprocess_audio(input_audio: str) -> None:
65
  arr, org_sr = torchaudio.load(input_audio)
66
  new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
@@ -78,7 +69,7 @@ def run_s2tt(input_audio: str, source_language: str, target_language: str) -> st
78
 
79
  input_audio, orig_freq = torchaudio.load(input_audio)
80
  input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
81
- audio_inputs= processor(input_audio, sampling_rate=16000, return_tensors="pt").to(device="cuda",dtype=torch.float16)
82
 
83
  text_out = model.generate(**audio_inputs, tgt_lang=target_language_code)[0].float().cpu().numpy().squeeze()
84
 
@@ -91,7 +82,7 @@ def run_asr(input_audio: str, target_language: str) -> str:
91
 
92
  input_audio, orig_freq = torchaudio.load(input_audio)
93
  input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
94
- audio_inputs= processor(input_audio, sampling_rate=16000, return_tensors="pt").to(device="cuda",dtype=torch.float16)
95
 
96
  text_out = model.generate(**audio_inputs, tgt_lang=target_language_code)[0].float().cpu().numpy().squeeze()
97
 
 
39
  """
40
 
41
  hf_token = os.getenv("HF_TOKEN")
42
+ device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
43
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
44
 
45
+ model = SeamlessM4Tv2ForSpeechToText.from_pretrained("ai4bharat/seamless-m4t-v2-large-stt", torch_dtype=torch_dtype, token=hf_token).to(device)
46
  processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/seamless-m4t-v2-large-stt", token=hf_token)
47
  tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/seamless-m4t-v2-large-stt", token=hf_token)
48
 
 
52
  MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
53
  DEFAULT_TARGET_LANGUAGE = "Hindi"
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  def preprocess_audio(input_audio: str) -> None:
56
  arr, org_sr = torchaudio.load(input_audio)
57
  new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
 
69
 
70
  input_audio, orig_freq = torchaudio.load(input_audio)
71
  input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
72
+ audio_inputs= processor(input_audio, sampling_rate=16000, return_tensors="pt").to(device="cuda", dtype=torch_dtype)
73
 
74
  text_out = model.generate(**audio_inputs, tgt_lang=target_language_code)[0].float().cpu().numpy().squeeze()
75
 
 
82
 
83
  input_audio, orig_freq = torchaudio.load(input_audio)
84
  input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
85
+ audio_inputs= processor(input_audio, sampling_rate=16000, return_tensors="pt").to(device="cuda", dtype=torch_dtype)
86
 
87
  text_out = model.generate(**audio_inputs, tgt_lang=target_language_code)[0].float().cpu().numpy().squeeze()
88