nithinraok commited on
Commit
41fce2e
·
1 Parent(s): c2d1a8a

SRT and cosmetics

Browse files

Signed-off-by: nithinraok <[email protected]>

Files changed (1) hide show
  1. app.py +79 -38
app.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  import os
11
  import gradio.themes as gr_themes
12
  import csv
 
13
  from supported_languages import SUPPORTED_LANGS_MAP
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -24,6 +25,34 @@ AVAILABLE_TGT_LANGS = list(SUPPORTED_LANGS_MAP.keys())
24
  DEFAULT_TGT_LANG = "English"
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def start_session(request: gr.Request):
28
  session_hash = request.session_hash
29
  session_dir = Path(f'/tmp/{session_hash}')
@@ -100,16 +129,21 @@ def get_audio_segment(audio_path, start_second, end_second):
100
  def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_lang):
101
  if not audio_path:
102
  gr.Error("No audio file path provided for transcription.", duration=None)
103
- # Return an update to hide the button
104
- return [], [], None, gr.DownloadButton(visible=False)
105
 
106
  vis_data = [["N/A", "N/A", "Processing failed"]]
107
  raw_times_data = [[0.0, 0.0]]
108
  processed_audio_path = None
109
  csv_file_path = None
 
110
  original_path_name = Path(audio_path).name
111
  audio_name = Path(audio_path).stem
112
 
 
 
 
 
113
  try:
114
  try:
115
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
@@ -117,8 +151,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
117
  print('Audio loaded successfully')
118
  except Exception as load_e:
119
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
120
- # Return an update to hide the button
121
- return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
122
 
123
  resampled = False
124
  mono = False
@@ -130,8 +163,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
130
  resampled = True
131
  except Exception as resample_e:
132
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
133
- # Return an update to hide the button
134
- return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
135
 
136
  if audio.channels == 2:
137
  try:
@@ -139,12 +171,10 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
139
  mono = True
140
  except Exception as mono_e:
141
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
142
- # Return an update to hide the button
143
- return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
144
  elif audio.channels > 2:
145
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
146
- # Return an update to hide the button
147
- return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
148
 
149
  if resampled or mono:
150
  try:
@@ -156,68 +186,77 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
156
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
157
  if processed_audio_path and os.path.exists(processed_audio_path):
158
  os.remove(processed_audio_path)
159
- # Return an update to hide the button
160
- return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
161
  else:
162
  transcribe_path = audio_path
163
  info_path_name = original_path_name
164
 
165
  try:
166
  model.to(device)
167
- gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
 
 
 
 
 
168
  output = model.transcribe([transcribe_path], timestamps=True, source_lang=SUPPORTED_LANGS_MAP[source_lang], target_lang=SUPPORTED_LANGS_MAP[target_lang])
169
 
170
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
171
- gr.Error("Transcription failed or produced unexpected output format.", duration=None)
172
- # Return an update to hide the button
173
- return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
174
 
175
  segment_timestamps = output[0].timestamp['segment']
176
  csv_headers = ["Start (s)", "End (s)", "Segment"]
177
  vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
178
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
179
 
180
- # Default button update (hidden) in case CSV writing fails
181
- button_update = gr.DownloadButton(visible=False)
182
  try:
183
- csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
184
  writer = csv.writer(open(csv_file_path, 'w'))
185
  writer.writerow(csv_headers)
186
  writer.writerows(vis_data)
187
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
188
- # If CSV is saved, create update to show button with path
189
- button_update = gr.DownloadButton(value=csv_file_path, visible=True)
190
  except Exception as csv_e:
191
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
192
  print(f"Error writing CSV: {csv_e}")
193
- # csv_file_path remains None, button_update remains hidden
194
 
195
- gr.Info("Transcription complete.", duration=2)
196
- # Return the data and the button update dictionary
197
- return vis_data, raw_times_data, audio_path, button_update
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  except torch.cuda.OutOfMemoryError as e:
200
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
201
  print(f"CUDA OutOfMemoryError: {e}")
202
  gr.Error(error_msg, duration=None)
203
- # Return an update to hide the button
204
- return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
205
 
206
  except FileNotFoundError:
207
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
208
  print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
209
  gr.Error(error_msg, duration=None)
210
- # Return an update to hide the button
211
- return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
212
 
213
  except Exception as e:
214
- error_msg = f"Transcription failed: {e}"
215
- print(f"Error during transcription processing: {e}")
216
  gr.Error(error_msg, duration=None)
217
  vis_data = [["Error", "Error", error_msg]]
218
  raw_times_data = [[0.0, 0.0]]
219
- # Return an update to hide the button
220
- return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
221
  finally:
222
  try:
223
  if 'model' in locals() and hasattr(model, 'cpu'):
@@ -354,14 +393,16 @@ with gr.Blocks(theme=nvidia_theme) as demo:
354
  gr.Markdown("---")
355
  gr.HTML("<h3 style='text-align: center'>Ready to dive in? Click on the text to jump to the part you need!</h3>")
356
 
357
- # Define the DownloadButton *before* the DataFrame
358
- download_btn = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
 
 
359
 
360
  vis_timestamps_df = gr.DataFrame(
361
  headers=["Start (s)", "End (s)", "Segment"],
362
  datatype=["number", "number", "str"],
363
  wrap=True,
364
- label="Transcription Segments"
365
  )
366
 
367
  # selected_segment_player was defined after download_btn previously, keep it after df for layout
@@ -382,14 +423,14 @@ with gr.Blocks(theme=nvidia_theme) as demo:
382
  mic_transcribe_btn.click(
383
  fn=get_transcripts_and_raw_times,
384
  inputs=[mic_input, session_dir, source_lang_dropdown, target_lang_dropdown],
385
- outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
386
  api_name="transcribe_mic"
387
  )
388
 
389
  file_transcribe_btn.click(
390
  fn=get_transcripts_and_raw_times,
391
  inputs=[file_input, session_dir, source_lang_dropdown, target_lang_dropdown],
392
- outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
393
  api_name="transcribe_file"
394
  )
395
 
 
10
  import os
11
  import gradio.themes as gr_themes
12
  import csv
13
+ import datetime
14
  from supported_languages import SUPPORTED_LANGS_MAP
15
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
25
  DEFAULT_TGT_LANG = "English"
26
 
27
 
28
+ def format_srt_time(seconds: float) -> str:
29
+ """Converts seconds to SRT time format HH:MM:SS,mmm using datetime.timedelta"""
30
+ sanitized_total_seconds = max(0.0, seconds)
31
+ delta = datetime.timedelta(seconds=sanitized_total_seconds)
32
+ total_int_seconds = int(delta.total_seconds())
33
+
34
+ hours = total_int_seconds // 3600
35
+ remainder_seconds_after_hours = total_int_seconds % 3600
36
+ minutes = remainder_seconds_after_hours // 60
37
+ seconds_part = remainder_seconds_after_hours % 60
38
+ milliseconds = delta.microseconds // 1000
39
+
40
+ return f"{hours:02d}:{minutes:02d}:{seconds_part:02d},{milliseconds:03d}"
41
+
42
+ def generate_srt_content(segment_timestamps: list) -> str:
43
+ """Generates SRT formatted string from segment timestamps."""
44
+ srt_content = []
45
+ for i, ts in enumerate(segment_timestamps):
46
+ start_time = format_srt_time(ts['start'])
47
+ end_time = format_srt_time(ts['end'])
48
+ text = ts['segment']
49
+ srt_content.append(str(i + 1))
50
+ srt_content.append(f"{start_time} --> {end_time}")
51
+ srt_content.append(text)
52
+ srt_content.append("")
53
+ return "\n".join(srt_content)
54
+
55
+
56
  def start_session(request: gr.Request):
57
  session_hash = request.session_hash
58
  session_dir = Path(f'/tmp/{session_hash}')
 
129
  def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_lang):
130
  if not audio_path:
131
  gr.Error("No audio file path provided for transcription.", duration=None)
132
+ # Return an update to hide the buttons
133
+ return [], [], None, gr.DownloadButton(label="Download Transcript (CSV)", visible=False), gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
134
 
135
  vis_data = [["N/A", "N/A", "Processing failed"]]
136
  raw_times_data = [[0.0, 0.0]]
137
  processed_audio_path = None
138
  csv_file_path = None
139
+ srt_file_path = None
140
  original_path_name = Path(audio_path).name
141
  audio_name = Path(audio_path).stem
142
 
143
+ # Initialize button states
144
+ csv_button_update = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
145
+ srt_button_update = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
146
+
147
  try:
148
  try:
149
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
 
151
  print('Audio loaded successfully')
152
  except Exception as load_e:
153
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
154
+ return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
155
 
156
  resampled = False
157
  mono = False
 
163
  resampled = True
164
  except Exception as resample_e:
165
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
166
+ return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
167
 
168
  if audio.channels == 2:
169
  try:
 
171
  mono = True
172
  except Exception as mono_e:
173
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
174
+ return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
175
  elif audio.channels > 2:
176
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
177
+ return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
178
 
179
  if resampled or mono:
180
  try:
 
186
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
187
  if processed_audio_path and os.path.exists(processed_audio_path):
188
  os.remove(processed_audio_path)
189
+ return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
190
  else:
191
  transcribe_path = audio_path
192
  info_path_name = original_path_name
193
 
194
  try:
195
  model.to(device)
196
+ if source_lang == target_lang:
197
+ task = "Transcribing"
198
+ else:
199
+ task = "Translating"
200
+ gr.Info(f"{task} {info_path_name} from {source_lang} to {target_lang}", duration=2)
201
+
202
  output = model.transcribe([transcribe_path], timestamps=True, source_lang=SUPPORTED_LANGS_MAP[source_lang], target_lang=SUPPORTED_LANGS_MAP[target_lang])
203
 
204
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
205
+ gr.Error("Prediction failed or produced unexpected output format.", duration=None)
206
+ return [["Error", "Error", "Prediction Format Issue"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
207
 
208
  segment_timestamps = output[0].timestamp['segment']
209
  csv_headers = ["Start (s)", "End (s)", "Segment"]
210
  vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
211
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
212
 
213
+ # CSV file generation
 
214
  try:
215
+ csv_file_path = Path(session_dir, f"{task}_{audio_name}_{source_lang}_{target_lang}.csv")
216
  writer = csv.writer(open(csv_file_path, 'w'))
217
  writer.writerow(csv_headers)
218
  writer.writerows(vis_data)
219
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
220
+ csv_button_update = gr.DownloadButton(value=csv_file_path, visible=True, label="Download Transcript (CSV)")
 
221
  except Exception as csv_e:
222
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
223
  print(f"Error writing CSV: {csv_e}")
 
224
 
225
+ # SRT file generation
226
+ if segment_timestamps:
227
+ try:
228
+ srt_content = generate_srt_content(segment_timestamps)
229
+ srt_file_path = Path(session_dir, f"{task}_{audio_name}_{source_lang}_{target_lang}.srt")
230
+ with open(srt_file_path, 'w', encoding='utf-8') as f:
231
+ f.write(srt_content)
232
+ print(f"SRT transcript saved to temporary file: {srt_file_path}")
233
+ srt_button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Transcript (SRT)")
234
+ except Exception as srt_e:
235
+ gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
236
+ print(f"Error writing SRT: {srt_e}")
237
+
238
+ gr.Info(f"{task} complete.", duration=2)
239
+ return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
240
 
241
  except torch.cuda.OutOfMemoryError as e:
242
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
243
  print(f"CUDA OutOfMemoryError: {e}")
244
  gr.Error(error_msg, duration=None)
245
+ return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
246
 
247
  except FileNotFoundError:
248
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
249
  print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
250
  gr.Error(error_msg, duration=None)
251
+ return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
252
 
253
  except Exception as e:
254
+ error_msg = f"Prediction failed: {e}"
255
+ print(f"Error during prediction processing: {e}")
256
  gr.Error(error_msg, duration=None)
257
  vis_data = [["Error", "Error", error_msg]]
258
  raw_times_data = [[0.0, 0.0]]
259
+ return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
 
260
  finally:
261
  try:
262
  if 'model' in locals() and hasattr(model, 'cpu'):
 
393
  gr.Markdown("---")
394
  gr.HTML("<h3 style='text-align: center'>Ready to dive in? Click on the text to jump to the part you need!</h3>")
395
 
396
+ # Define the DownloadButtons *before* the DataFrame
397
+ with gr.Row():
398
+ download_btn_csv = gr.DownloadButton(label="Download CSV", visible=False)
399
+ download_btn_srt = gr.DownloadButton(label="Download SRT", visible=False)
400
 
401
  vis_timestamps_df = gr.DataFrame(
402
  headers=["Start (s)", "End (s)", "Segment"],
403
  datatype=["number", "number", "str"],
404
  wrap=True,
405
+ label="Segments"
406
  )
407
 
408
  # selected_segment_player was defined after download_btn previously, keep it after df for layout
 
423
  mic_transcribe_btn.click(
424
  fn=get_transcripts_and_raw_times,
425
  inputs=[mic_input, session_dir, source_lang_dropdown, target_lang_dropdown],
426
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
427
  api_name="transcribe_mic"
428
  )
429
 
430
  file_transcribe_btn.click(
431
  fn=get_transcripts_and_raw_times,
432
  inputs=[file_input, session_dir, source_lang_dropdown, target_lang_dropdown],
433
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
434
  api_name="transcribe_file"
435
  )
436