Porjaz committed
Commit fed23a6 · verified · 1 Parent(s): 51c1dfe

Update custom_interface_app.py

Files changed (1): custom_interface_app.py (+32 -5)
custom_interface_app.py CHANGED
@@ -41,7 +41,7 @@ class ASR(Pretrained):
         # Forward encoder + decoder
         tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
         tokens = tokens.to(device)
-        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
+        enc_out, logits, _ = self.mods.whisper(wavs.detach(), tokens.detach())
         log_probs = self.hparams.log_softmax(logits)
 
         hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
@@ -100,6 +100,30 @@ class ASR(Pretrained):
             seq.append(token)
         output = []
         return seq
+
+
+    def increase_volume(self, waveform, threshold_db=-25):
+        # Measure loudness using RMS
+        loudness_vector = librosa.feature.rms(y=waveform)
+        average_loudness = np.mean(loudness_vector)
+        average_loudness_db = librosa.amplitude_to_db(average_loudness)
+
+        print(f"Average Loudness: {average_loudness_db} dB")
+
+        # Check if loudness is below threshold and apply gain if needed
+        if average_loudness_db < threshold_db:
+            # Calculate gain needed
+            gain_db = threshold_db - average_loudness_db
+            gain = librosa.db_to_amplitude(gain_db)  # Convert dB to amplitude factor
+
+            # Apply gain to the audio signal
+            waveform = waveform * gain
+            loudness_vector = librosa.feature.rms(y=waveform)
+            average_loudness = np.mean(loudness_vector)
+            average_loudness_db = librosa.amplitude_to_db(average_loudness)
+
+            print(f"Average Loudness: {average_loudness_db} dB")
+        return waveform
 
 
     def classify_file_w2v2(self, path, device):
@@ -150,7 +174,6 @@ class ASR(Pretrained):
 
     def classify_file_whisper_mkd(self, path, device):
         # Load the audio file
-        # path = "long_sample.wav"
         waveform, sr = librosa.load(path, sr=16000)
 
         # Get audio length in seconds
@@ -178,7 +201,8 @@ class ASR(Pretrained):
 
                 # Fake a batch for the segment
                 batch = segment_tensor.unsqueeze(0).to(device)
-                rel_length = torch.tensor([1.0]).to(device)
+                batch = batch.to(torch.float16)
+                rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
 
                 # Pass the segment through the ASR model
                 segment_output = self.encode_batch_whisper(device, batch, rel_length)
@@ -186,13 +210,14 @@ class ASR(Pretrained):
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
-            # Fake a batch:
             batch = waveform.unsqueeze(0)
-            rel_length = torch.tensor([1.0]).to(device)
+            batch = batch.to(torch.float16)
+            rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
             outputs = self.encode_batch_whisper(device, batch, rel_length)
             yield outputs
 
 
+
     def classify_file_whisper(self, path, pipe, device):
         waveform, sr = librosa.load(path, sr=16000)
         transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
@@ -228,6 +253,7 @@ class ASR(Pretrained):
 
                 # Pass the segment through the ASR model
                 inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
+                inputs['input_values'] = inputs['input_values'].to(torch.float16)
                 outputs = model(**inputs).logits
                 ids = torch.argmax(outputs, dim=-1)[0]
                 segment_output = processor.decode(ids)
@@ -235,6 +261,7 @@ class ASR(Pretrained):
         else:
             waveform = torch.tensor(waveform).to(device)
             inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
+            inputs['input_values'] = inputs['input_values'].to(torch.float16)
             outputs = model(**inputs).logits
             ids = torch.argmax(outputs, dim=-1)[0]
             transcription = processor.decode(ids)
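
Taken together, the commit detaches the Whisper inputs, casts batches and relative lengths to float16 for half-precision inference, and adds an increase_volume helper that applies RMS-based gain to quiet recordings. A minimal usage sketch of how these pieces might fit together; the transcribe_quiet_file wrapper and the way the asr object is obtained are assumptions for illustration, not part of the commit:

import librosa
import torch

def transcribe_quiet_file(asr, path, device="cuda", threshold_db=-25):
    # asr is assumed to be an instance of the ASR class from
    # custom_interface_app.py, loaded via SpeechBrain's Pretrained interface.

    # Load audio at the 16 kHz rate the model expects
    waveform, sr = librosa.load(path, sr=16000)

    # Boost quiet recordings with the helper added in this commit
    waveform = asr.increase_volume(waveform, threshold_db=threshold_db)

    # Fake a batch and cast to float16, matching the new inference path
    batch = torch.tensor(waveform).unsqueeze(0).to(device).to(torch.float16)
    rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)

    # Same call signature the diff uses for encode_batch_whisper
    return asr.encode_batch_whisper(device, batch, rel_length)

Note that the float16 casts only line up if the model weights are themselves loaded in half precision; with float32 weights, PyTorch raises a dtype-mismatch error at the first linear layer.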