msekoyan committed on
Commit
647d57e
·
2 Parent(s): 4d083e5 41fce2e

Merge branch 'main' of https://huggingface.co/spaces/nvidia/canary-1b-v2

Browse files
Files changed (1) hide show
  1. app.py +79 -38
app.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  import os
11
  import gradio.themes as gr_themes
12
  import csv
 
13
  from supported_languages import SUPPORTED_LANGS_MAP
14
  from datetime import timedelta, datetime
15
 
@@ -25,6 +26,34 @@ AVAILABLE_TGT_LANGS = list(SUPPORTED_LANGS_MAP.keys())
25
  DEFAULT_TGT_LANG = "English"
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def start_session(request: gr.Request):
29
  session_hash = request.session_hash
30
  session_dir = Path(f'/tmp/{session_hash}')
@@ -104,16 +133,21 @@ def get_audio_segment(audio_path, start_second, end_second):
104
  def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_lang):
105
  if not audio_path:
106
  gr.Error("No audio file path provided for transcription.", duration=None)
107
- # Return an update to hide the button
108
- return [], [], None, gr.DownloadButton(visible=False)
109
 
110
  vis_data = [["N/A", "N/A", "Processing failed"]]
111
  raw_times_data = [[0.0, 0.0]]
112
  processed_audio_path = None
113
  csv_file_path = None
 
114
  original_path_name = Path(audio_path).name
115
  audio_name = Path(audio_path).stem
116
 
 
 
 
 
117
  try:
118
  try:
119
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
@@ -121,8 +155,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
121
  print('Audio loaded successfully')
122
  except Exception as load_e:
123
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
124
- # Return an update to hide the button
125
- return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
126
 
127
  resampled = False
128
  mono = False
@@ -134,8 +167,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
134
  resampled = True
135
  except Exception as resample_e:
136
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
137
- # Return an update to hide the button
138
- return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
139
 
140
  if audio.channels == 2:
141
  try:
@@ -143,12 +175,10 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
143
  mono = True
144
  except Exception as mono_e:
145
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
146
- # Return an update to hide the button
147
- return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
148
  elif audio.channels > 2:
149
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
150
- # Return an update to hide the button
151
- return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
152
 
153
  if resampled or mono:
154
  try:
@@ -160,68 +190,77 @@ def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_l
160
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
161
  if processed_audio_path and os.path.exists(processed_audio_path):
162
  os.remove(processed_audio_path)
163
- # Return an update to hide the button
164
- return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
165
  else:
166
  transcribe_path = audio_path
167
  info_path_name = original_path_name
168
 
169
  try:
170
  model.to(device)
171
- gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
 
 
 
 
 
172
  output = model.transcribe([transcribe_path], timestamps=True, source_lang=SUPPORTED_LANGS_MAP[source_lang], target_lang=SUPPORTED_LANGS_MAP[target_lang])
173
 
174
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
175
- gr.Error("Transcription failed or produced unexpected output format.", duration=None)
176
- # Return an update to hide the button
177
- return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
178
 
179
  segment_timestamps = output[0].timestamp['segment']
180
  csv_headers = ["Start (s)", "End (s)", "Segment"]
181
  vis_data = [[f"{sec_to_hrs(ts['start'])}", f"{sec_to_hrs(ts['end'])}", ts['segment']] for ts in segment_timestamps]
182
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
183
 
184
- # Default button update (hidden) in case CSV writing fails
185
- button_update = gr.DownloadButton(visible=False)
186
  try:
187
- csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
188
  writer = csv.writer(open(csv_file_path, 'w'))
189
  writer.writerow(csv_headers)
190
  writer.writerows(vis_data)
191
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
192
- # If CSV is saved, create update to show button with path
193
- button_update = gr.DownloadButton(value=csv_file_path, visible=True)
194
  except Exception as csv_e:
195
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
196
  print(f"Error writing CSV: {csv_e}")
197
- # csv_file_path remains None, button_update remains hidden
198
 
199
- gr.Info("Transcription complete.", duration=2)
200
- # Return the data and the button update dictionary
201
- return vis_data, raw_times_data, audio_path, button_update
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  except torch.cuda.OutOfMemoryError as e:
204
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
205
  print(f"CUDA OutOfMemoryError: {e}")
206
  gr.Error(error_msg, duration=None)
207
- # Return an update to hide the button
208
- return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
209
 
210
  except FileNotFoundError:
211
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
212
  print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
213
  gr.Error(error_msg, duration=None)
214
- # Return an update to hide the button
215
- return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
216
 
217
  except Exception as e:
218
- error_msg = f"Transcription failed: {e}"
219
- print(f"Error during transcription processing: {e}")
220
  gr.Error(error_msg, duration=None)
221
  vis_data = [["Error", "Error", error_msg]]
222
  raw_times_data = [[0.0, 0.0]]
223
- # Return an update to hide the button
224
- return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
225
  finally:
226
  try:
227
  if 'model' in locals() and hasattr(model, 'cpu'):
@@ -364,14 +403,16 @@ with gr.Blocks(theme=nvidia_theme) as demo:
364
  gr.Markdown("---")
365
  gr.HTML("<h3 style='text-align: center'>Ready to dive in? Click on the text to jump to the part you need!</h3>")
366
 
367
- # Define the DownloadButton *before* the DataFrame
368
- download_btn = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
 
 
369
 
370
  vis_timestamps_df = gr.DataFrame(
371
  headers=["Start (s)", "End (s)", "Segment"],
372
  datatype=["number", "number", "str"],
373
  wrap=True,
374
- label="Transcription Segments"
375
  )
376
 
377
  # selected_segment_player was defined after download_btn previously, keep it after df for layout
@@ -392,14 +433,14 @@ with gr.Blocks(theme=nvidia_theme) as demo:
392
  mic_transcribe_btn.click(
393
  fn=get_transcripts_and_raw_times,
394
  inputs=[mic_input, session_dir, source_lang_dropdown, target_lang_dropdown],
395
- outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
396
  api_name="transcribe_mic"
397
  )
398
 
399
  file_transcribe_btn.click(
400
  fn=get_transcripts_and_raw_times,
401
  inputs=[file_input, session_dir, source_lang_dropdown, target_lang_dropdown],
402
- outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
403
  api_name="transcribe_file"
404
  )
405
 
 
10
  import os
11
  import gradio.themes as gr_themes
12
  import csv
13
+ import datetime
14
  from supported_languages import SUPPORTED_LANGS_MAP
15
  from datetime import timedelta, datetime
16
 
 
26
  DEFAULT_TGT_LANG = "English"
27
 
28
 
29
def format_srt_time(seconds: float) -> str:
    """Convert a duration in seconds to the SRT timestamp format ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero. The hour field is not wrapped at 24,
    so long recordings yield values such as ``25:00:00,000``. Sub-millisecond
    precision is truncated, not rounded.

    Args:
        seconds: Offset from the start of the audio, in seconds.

    Returns:
        The timestamp formatted as ``HH:MM:SS,mmm``.
    """
    # Import locally: this module does both `import datetime` and
    # `from datetime import timedelta, datetime`, and the latter rebinds the
    # name `datetime` to the class, so `datetime.timedelta` would raise
    # AttributeError at call time.
    from datetime import timedelta

    delta = timedelta(seconds=max(0.0, seconds))

    # total_seconds() includes whole days, so hours may exceed 23.
    total_int_seconds = int(delta.total_seconds())
    hours, remainder = divmod(total_int_seconds, 3600)
    minutes, seconds_part = divmod(remainder, 60)
    milliseconds = delta.microseconds // 1000

    return f"{hours:02d}:{minutes:02d}:{seconds_part:02d},{milliseconds:03d}"
42
+
43
def generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT subtitle file from timestamped segments.

    Each entry of ``segment_timestamps`` is read through its ``'start'``,
    ``'end'`` and ``'segment'`` keys. Every entry becomes one cue: a 1-based
    index line, a ``start --> end`` line, the segment text, and an empty
    separator line.

    Args:
        segment_timestamps: Sequence of mappings describing the segments.

    Returns:
        All cue lines joined with newlines; an empty string for no segments.
    """
    cue_lines = []
    for cue_number, segment in enumerate(segment_timestamps, start=1):
        time_range = (
            f"{format_srt_time(segment['start'])} --> "
            f"{format_srt_time(segment['end'])}"
        )
        # The trailing "" yields the blank line that separates SRT cues.
        cue_lines.extend((str(cue_number), time_range, segment['segment'], ""))
    return "\n".join(cue_lines)
55
+
56
+
57
  def start_session(request: gr.Request):
58
  session_hash = request.session_hash
59
  session_dir = Path(f'/tmp/{session_hash}')
 
133
  def get_transcripts_and_raw_times(audio_path, session_dir, source_lang, target_lang):
134
  if not audio_path:
135
  gr.Error("No audio file path provided for transcription.", duration=None)
136
+ # Return an update to hide the buttons
137
+ return [], [], None, gr.DownloadButton(label="Download Transcript (CSV)", visible=False), gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
138
 
139
  vis_data = [["N/A", "N/A", "Processing failed"]]
140
  raw_times_data = [[0.0, 0.0]]
141
  processed_audio_path = None
142
  csv_file_path = None
143
+ srt_file_path = None
144
  original_path_name = Path(audio_path).name
145
  audio_name = Path(audio_path).stem
146
 
147
+ # Initialize button states
148
+ csv_button_update = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
149
+ srt_button_update = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
150
+
151
  try:
152
  try:
153
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
 
155
  print('Audio loaded successfully')
156
  except Exception as load_e:
157
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
158
+ return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
159
 
160
  resampled = False
161
  mono = False
 
167
  resampled = True
168
  except Exception as resample_e:
169
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
170
+ return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
171
 
172
  if audio.channels == 2:
173
  try:
 
175
  mono = True
176
  except Exception as mono_e:
177
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
178
+ return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
179
  elif audio.channels > 2:
180
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
181
+ return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
182
 
183
  if resampled or mono:
184
  try:
 
190
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
191
  if processed_audio_path and os.path.exists(processed_audio_path):
192
  os.remove(processed_audio_path)
193
+ return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
194
  else:
195
  transcribe_path = audio_path
196
  info_path_name = original_path_name
197
 
198
  try:
199
  model.to(device)
200
+ if source_lang == target_lang:
201
+ task = "Transcribing"
202
+ else:
203
+ task = "Translating"
204
+ gr.Info(f"{task} {info_path_name} from {source_lang} to {target_lang}", duration=2)
205
+
206
  output = model.transcribe([transcribe_path], timestamps=True, source_lang=SUPPORTED_LANGS_MAP[source_lang], target_lang=SUPPORTED_LANGS_MAP[target_lang])
207
 
208
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
209
+ gr.Error("Prediction failed or produced unexpected output format.", duration=None)
210
+ return [["Error", "Error", "Prediction Format Issue"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
211
 
212
  segment_timestamps = output[0].timestamp['segment']
213
  csv_headers = ["Start (s)", "End (s)", "Segment"]
214
  vis_data = [[f"{sec_to_hrs(ts['start'])}", f"{sec_to_hrs(ts['end'])}", ts['segment']] for ts in segment_timestamps]
215
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
216
 
217
+ # CSV file generation
 
218
  try:
219
+ csv_file_path = Path(session_dir, f"{task}_{audio_name}_{source_lang}_{target_lang}.csv")
220
  writer = csv.writer(open(csv_file_path, 'w'))
221
  writer.writerow(csv_headers)
222
  writer.writerows(vis_data)
223
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
224
+ csv_button_update = gr.DownloadButton(value=csv_file_path, visible=True, label="Download Transcript (CSV)")
 
225
  except Exception as csv_e:
226
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
227
  print(f"Error writing CSV: {csv_e}")
 
228
 
229
+ # SRT file generation
230
+ if segment_timestamps:
231
+ try:
232
+ srt_content = generate_srt_content(segment_timestamps)
233
+ srt_file_path = Path(session_dir, f"{task}_{audio_name}_{source_lang}_{target_lang}.srt")
234
+ with open(srt_file_path, 'w', encoding='utf-8') as f:
235
+ f.write(srt_content)
236
+ print(f"SRT transcript saved to temporary file: {srt_file_path}")
237
+ srt_button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Transcript (SRT)")
238
+ except Exception as srt_e:
239
+ gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
240
+ print(f"Error writing SRT: {srt_e}")
241
+
242
+ gr.Info(f"{task} complete.", duration=2)
243
+ return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
244
 
245
  except torch.cuda.OutOfMemoryError as e:
246
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
247
  print(f"CUDA OutOfMemoryError: {e}")
248
  gr.Error(error_msg, duration=None)
249
+ return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
250
 
251
  except FileNotFoundError:
252
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
253
  print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
254
  gr.Error(error_msg, duration=None)
255
+ return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
256
 
257
  except Exception as e:
258
+ error_msg = f"Prediction failed: {e}"
259
+ print(f"Error during prediction processing: {e}")
260
  gr.Error(error_msg, duration=None)
261
  vis_data = [["Error", "Error", error_msg]]
262
  raw_times_data = [[0.0, 0.0]]
263
+ return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
 
264
  finally:
265
  try:
266
  if 'model' in locals() and hasattr(model, 'cpu'):
 
403
  gr.Markdown("---")
404
  gr.HTML("<h3 style='text-align: center'>Ready to dive in? Click on the text to jump to the part you need!</h3>")
405
 
406
+ # Define the DownloadButtons *before* the DataFrame
407
+ with gr.Row():
408
+ download_btn_csv = gr.DownloadButton(label="Download CSV", visible=False)
409
+ download_btn_srt = gr.DownloadButton(label="Download SRT", visible=False)
410
 
411
  vis_timestamps_df = gr.DataFrame(
412
  headers=["Start (s)", "End (s)", "Segment"],
413
  datatype=["number", "number", "str"],
414
  wrap=True,
415
+ label="Segments"
416
  )
417
 
418
  # selected_segment_player was defined after download_btn previously, keep it after df for layout
 
433
  mic_transcribe_btn.click(
434
  fn=get_transcripts_and_raw_times,
435
  inputs=[mic_input, session_dir, source_lang_dropdown, target_lang_dropdown],
436
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
437
  api_name="transcribe_mic"
438
  )
439
 
440
  file_transcribe_btn.click(
441
  fn=get_transcripts_and_raw_times,
442
  inputs=[file_input, session_dir, source_lang_dropdown, target_lang_dropdown],
443
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
444
  api_name="transcribe_file"
445
  )
446