Update custom_interface_app.py
custom_interface_app.py  +32 -5
@@ -41,7 +41,7 @@ class ASR(Pretrained):
         # Forward encoder + decoder
         tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
         tokens = tokens.to(device)
-        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
+        enc_out, logits, _ = self.mods.whisper(wavs.detach(), tokens.detach())
         log_probs = self.hparams.log_softmax(logits)

         hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
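Note on the change above: `.detach()` cuts `wavs` and `tokens` out of the autograd graph, so decoding does not accumulate gradient state. A minimal sketch of the effect, with an `nn.Linear` standing in for the repo's `self.mods.whisper`:

```python
import torch
import torch.nn as nn

# Stand-in for the encoder-decoder forward pass (hypothetical module).
model = nn.Linear(16, 4)
wavs = torch.randn(1, 16, requires_grad=True)

# Detaching (or running under no_grad) keeps inference from building
# a computation graph, saving memory during decoding.
with torch.no_grad():
    logits = model(wavs.detach())

print(logits.requires_grad)  # False: no gradient history is tracked
```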
@@ -100,6 +100,30 @@ class ASR(Pretrained):
             seq.append(token)
         output = []
         return seq
+
+
+    def increase_volume(self, waveform, threshold_db=-25):
+        # Measure loudness using RMS
+        loudness_vector = librosa.feature.rms(y=waveform)
+        average_loudness = np.mean(loudness_vector)
+        average_loudness_db = librosa.amplitude_to_db(average_loudness)
+
+        print(f"Average Loudness: {average_loudness_db} dB")
+
+        # Check if loudness is below threshold and apply gain if needed
+        if average_loudness_db < threshold_db:
+            # Calculate gain needed
+            gain_db = threshold_db - average_loudness_db
+            gain = librosa.db_to_amplitude(gain_db)  # Convert dB to amplitude factor
+
+            # Apply gain to the audio signal
+            waveform = waveform * gain
+            loudness_vector = librosa.feature.rms(y=waveform)
+            average_loudness = np.mean(loudness_vector)
+            average_loudness_db = librosa.amplitude_to_db(average_loudness)
+
+            print(f"Average Loudness: {average_loudness_db} dB")
+        return waveform


     def classify_file_w2v2(self, path, device):
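A usage sketch for the new `increase_volume` helper. The function below is a standalone copy of the method's gain logic, and `sample.wav` is a placeholder path; the point is that boosting by `threshold_db - average_loudness_db` brings the mean RMS loudness up to roughly the threshold:

```python
import librosa
import numpy as np

def increase_volume(waveform, threshold_db=-25):
    # Standalone copy of the added method's logic, lifted out of the class.
    loudness_db = librosa.amplitude_to_db(np.mean(librosa.feature.rms(y=waveform)))
    if loudness_db < threshold_db:
        # db_to_amplitude turns the dB shortfall into a linear gain factor.
        waveform = waveform * librosa.db_to_amplitude(threshold_db - loudness_db)
    return waveform

waveform, sr = librosa.load("sample.wav", sr=16000)  # placeholder path
boosted = increase_volume(waveform)
# Mean RMS loudness now sits at roughly the -25 dB threshold.
print(librosa.amplitude_to_db(np.mean(librosa.feature.rms(y=boosted))))
```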
@@ -150,7 +174,6 @@ class ASR(Pretrained):

     def classify_file_whisper_mkd(self, path, device):
         # Load the audio file
-        # path = "long_sample.wav"
         waveform, sr = librosa.load(path, sr=16000)

         # Get audio length in seconds
@@ -178,7 +201,8 @@ class ASR(Pretrained):

             # Fake a batch for the segment
             batch = segment_tensor.unsqueeze(0).to(device)
-
+            batch = batch.to(torch.float16)
+            rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)

             # Pass the segment through the ASR model
             segment_output = self.encode_batch_whisper(device, batch, rel_length)
@@ -186,13 +210,14 @@ class ASR(Pretrained):
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
-            # Fake a batch:
             batch = waveform.unsqueeze(0)
-
+            batch = batch.to(torch.float16)
+            rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
             outputs = self.encode_batch_whisper(device, batch, rel_length)
             yield outputs


+
     def classify_file_whisper(self, path, pipe, device):
         waveform, sr = librosa.load(path, sr=16000)
         transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
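The float16 casts in the two hunks above suggest the Whisper checkpoint is run in half precision; inputs must match the model's parameter dtype or PyTorch raises a dtype-mismatch error. An illustrative sketch with a stand-in module:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2).to(torch.float16)  # stand-in for a half-precision model
batch = torch.randn(1, 8)                  # float32 by default

# model(batch) here would raise "mat1 and mat2 must have the same dtype",
# so the batch (and any companion tensors) are cast to float16 first.
batch = batch.to(torch.float16)
rel_length = torch.tensor([1.0], dtype=torch.float16)  # mirrors the diff

out = model(batch)
print(out.dtype)  # torch.float16
```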
@@ -228,6 +253,7 @@ class ASR(Pretrained):

             # Pass the segment through the ASR model
             inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
+            inputs['input_values'] = inputs['input_values'].to(torch.float16)
             outputs = model(**inputs).logits
             ids = torch.argmax(outputs, dim=-1)[0]
             segment_output = processor.decode(ids)
@@ -235,6 +261,7 @@ class ASR(Pretrained):
         else:
             waveform = torch.tensor(waveform).to(device)
             inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
+            inputs['input_values'] = inputs['input_values'].to(torch.float16)
             outputs = model(**inputs).logits
             ids = torch.argmax(outputs, dim=-1)[0]
             transcription = processor.decode(ids)
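Same idea on the wav2vec2 path: the processor returns float32 `input_values`, which are cast to match the model. A sketch assuming the checkpoint is loaded in half precision; the model name is illustrative, not the one this app uses:

```python
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

name = "facebook/wav2vec2-base-960h"  # illustrative checkpoint
processor = Wav2Vec2Processor.from_pretrained(name)
model = Wav2Vec2ForCTC.from_pretrained(name, torch_dtype=torch.float16)

waveform = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt")
inputs["input_values"] = inputs["input_values"].to(torch.float16)  # match model dtype

with torch.no_grad():
    ids = torch.argmax(model(**inputs).logits, dim=-1)[0]
print(processor.decode(ids))
```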
|