avans06 committed
Commit 7339550 · 1 Parent(s): d80a2c7

feat: Add detailed progress bars for single and batch jobs

Files changed (1):
  1. app.py (+51 -9)
app.py CHANGED
```diff
@@ -1158,17 +1158,28 @@ def _transcribe_stem(audio_path: str, base_name: str, temp_dir: str, params: App
 
 
 # --- The core processing engine for a single file ---
-def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
+def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters, progress: gr.Progress = None):
     """
     This is the main processing engine. It takes a file path and a dictionary of all settings,
     and performs the full pipeline: load, separate, transcribe, render, re-merge.
     It is UI-agnostic and returns file paths and data, not Gradio updates.
+    It now accepts a Gradio Progress object to report granular progress.
     """
+    # Helper function to safely update progress
+    def update_progress(fraction, desc):
+        if progress:
+            progress(fraction, desc=desc)
+
     # --- Start timer for this specific file ---
     file_start_time = reqtime.time()
 
     filename = os.path.basename(input_file_path)
     base_name = os.path.splitext(filename)[0]
+
+    # --- Determine file type to select the correct progress timeline ---
+    is_midi_input = filename.lower().endswith(('.mid', '.midi', '.kar'))
+
+    update_progress(0, f"Starting: {filename}")
     print(f"\n{'='*20} Starting Pipeline for: {filename} {'='*20}")
 
     # --- Use the provided timestamp for unique filenames ---
```
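Note the new `progress: gr.Progress = None` parameter and the `update_progress` guard: every report goes through the guard, so the engine stays UI-agnostic. A caller can pass a real `gr.Progress`, any plain callable with the same shape (as the batch wrapper further down does), or nothing at all. A minimal, self-contained sketch of the pattern; `do_work` and `console_progress` are hypothetical stand-ins, not code from app.py:

```python
# Guarded progress-callback pattern: the worker never assumes a UI exists.
def do_work(progress=None):
    def update_progress(fraction, desc):
        # A gr.Progress instance and a plain function are both callable,
        # and None simply skips reporting.
        if progress:
            progress(fraction, desc=desc)

    update_progress(0.0, "starting")
    # ... real work would happen here ...
    update_progress(1.0, "done")

def console_progress(fraction, desc=""):
    print(f"[{fraction:>4.0%}] {desc}")

do_work(console_progress)  # prints two progress lines
do_work()                  # runs silently, no UI required
```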
```diff
@@ -1179,7 +1190,9 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
     other_part_sr = None
 
     # --- Step 1: Check file type and transcribe if necessary ---
-    if filename.lower().endswith(('.mid', '.midi', '.kar')):
+    if is_midi_input:
+        # For MIDI files, we start at 0% and directly proceed to the rendering steps.
+        update_progress(0, "MIDI file detected, skipping transcription...")
         print("MIDI file detected. Skipping transcription. Proceeding directly to rendering.")
         midi_path_for_rendering = input_file_path
     else:
@@ -1187,6 +1200,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
         os.makedirs(temp_dir, exist_ok=True)
 
         # --- Audio Loading ---
+        update_progress(0.1, "Audio file detected, loading...")
         print("Audio file detected. Starting pre-processing...")
         # --- Robust audio loading with ffmpeg fallback ---
         try:
@@ -1196,6 +1210,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
             audio_tensor, native_sample_rate = torchaudio.load(input_file_path)
             print("Torchaudio loading successful.")
         except Exception as e:
+            update_progress(0.15, "Torchaudio failed, trying ffmpeg...")
             print(f"Torchaudio failed: {e}. Attempting fallback with ffmpeg...")
             try:
                 # Define a path for the temporary converted file
@@ -1224,9 +1239,12 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
             # --- Standard Workflow: Transcribe the original full audio ---
             audio_to_transcribe_path = os.path.join(temp_dir, f"{timestamped_base_name}_original.flac")
             torchaudio.save(audio_to_transcribe_path, audio_tensor, native_sample_rate)
+
+            update_progress(0.2, "Transcribing audio to MIDI...")
             midi_path_for_rendering = _transcribe_stem(audio_to_transcribe_path, f"{timestamped_base_name}_original", temp_dir, params)
         else:
             # --- Vocal Separation Workflow ---
+            update_progress(0.2, "Separating vocals with Demucs...")
             # Convert to a common format (stereo, float32) that demucs expects
             audio_tensor = convert_audio(audio_tensor, native_sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
 
@@ -1282,18 +1300,22 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
             # --- Main Branching Logic: Transcribe one or both stems ---
             if not params.transcribe_both_stems:
                 print(f"Transcribing primary target only: {os.path.basename(primary_target_path)}")
+                update_progress(0.4, f"Transcribing primary target: {os.path.basename(primary_target_path)}")
                 midi_path_for_rendering = _transcribe_stem(primary_target_path, os.path.splitext(os.path.basename(primary_target_path))[0], temp_dir, params)
             else:
                 print("Transcribing BOTH stems and merging the MIDI results.")
 
                 # Transcribe the primary target
+                update_progress(0.4, "Transcribing primary stem...")
                 midi_path_primary = _transcribe_stem(primary_target_path, os.path.splitext(os.path.basename(primary_target_path))[0], temp_dir, params)
 
                 # Transcribe the other part
+                update_progress(0.5, "Transcribing second stem...")
                 midi_path_other = _transcribe_stem(other_part_path, os.path.splitext(os.path.basename(other_part_path))[0], temp_dir, params)
 
                 # Merge the two resulting MIDI files
                 if midi_path_primary and midi_path_other:
+                    update_progress(0.55, "Merging transcribed MIDIs...")
                     final_merged_midi_path = os.path.join(temp_dir, f"{base_name}_full_transcription.mid")
                     print(f"Merging transcribed MIDI files into {os.path.basename(final_merged_midi_path)}")
 
@@ -1319,10 +1341,13 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
                return None
 
     # --- Step 2: Render the FINAL MIDI file with selected options ---
+    # The progress values are now conditional based on the input file type.
+    update_progress(0.1 if is_midi_input else 0.6, "Applying MIDI transformations...")
 
     # --- Auto-Recommendation Logic ---
     # If the user selected the auto-recommend option, override the parameters
     if params.s8bit_preset_selector == "Auto-Recommend (Analyze MIDI)":
+        update_progress(0.15 if is_midi_input else 0.65, "Auto-recommending 8-bit parameters...")
         print("Auto-Recommendation is enabled. Analyzing MIDI features...")
         try:
             midi_to_analyze = pretty_midi.PrettyMIDI(midi_path_for_rendering)
@@ -1337,13 +1362,16 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
         except Exception as e:
             print(f"Could not auto-recommend parameters for {filename}: {e}.")
 
+    update_progress(0.2 if is_midi_input else 0.7, "Rendering MIDI to audio...")
     print(f"Proceeding to render MIDI file: {os.path.basename(midi_path_for_rendering)}")
 
     # Call the rendering function, Pass dictionaries directly to Render_MIDI
     results_tuple = Render_MIDI(input_midi_path=midi_path_for_rendering, params=params)
 
     # --- Vocal Re-merging Logic ---
+    # Vocal Re-merging only happens for audio files, so its progress value doesn't need to be conditional.
    if params.separate_vocals and params.remerge_vocals and not params.transcribe_both_stems and other_part_tensor is not None:
+        update_progress(0.8, "Re-merging rendered audio with vocals...")
         print(f"Re-merging the non-transcribed part with newly rendered music...")
 
         # 1. Unpack the original rendered audio from the results
```
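Because the `update_progress` fractions are scattered across several hunks, here are the two timelines they imply, collected in one place. This is a reading of the diff above, not code from app.py:

```python
# Progress milestones implied by the update_progress calls in this commit.
# MIDI inputs skip loading/separation/transcription, so the render stages
# sit at small fractions; audio inputs spend 0.0-0.6 getting to a MIDI file.
TIMELINES = {
    "midi input": {
        0.0: "start / skip transcription", 0.1: "MIDI transformations",
        0.15: "auto-recommend (optional)", 0.2: "render",
        0.9: "save", 1.0: "done",
    },
    "audio input": {
        0.0: "start", 0.1: "load audio",
        0.15: "ffmpeg fallback (only on torchaudio failure)",
        0.2: "transcribe, or separate vocals with Demucs",
        0.4: "transcribe primary stem", 0.5: "transcribe second stem (optional)",
        0.55: "merge MIDIs (optional)", 0.6: "MIDI transformations",
        0.65: "auto-recommend (optional)", 0.7: "render",
        0.8: "re-merge vocals (optional)", 0.9: "save", 1.0: "done",
    },
}
```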
```diff
@@ -1387,6 +1415,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
         print("Re-merging complete.")
 
     # --- Save final audio and return path ---
+    update_progress(0.9, "Saving final files...")
     final_srate, final_audio_data = results_tuple[4]
     final_midi_path_from_render = results_tuple[3] # Get the path of the processed MIDI
 
@@ -1421,6 +1450,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
         "plot": results_tuple[5],
         "description": results_tuple[6]
     }
+    update_progress(1.0, "Done!")
     # Return both the results and the final state of the parameters object
     return results, params
 
@@ -1430,10 +1460,10 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
 # =================================================================================================
 
 # --- Thin wrapper for batch processing ---
-def batch_process_files(input_files, progress=gr.Progress(), *args):
+def batch_process_files(input_files, progress=gr.Progress(track_tqdm=True), *args):
     """
-    Gradio wrapper for batch processing. It packs all UI values into an AppParameters object.
-    It iterates through files, calls the core pipeline, and collects the output file paths.
+    Gradio wrapper for batch processing. It iterates through files, calls the core pipeline,
+    and collects the output file paths. It now provides detailed, nested progress updates.
     """
 
     if not input_files:
@@ -1458,10 +1488,21 @@ def batch_process_files(input_files, progress=gr.Progress(), *args):
     for i, file_obj in enumerate(input_files):
         # The input from gr.File is a tempfile object, we need its path
         input_path = file_obj.name
+        filename = os.path.basename(input_path)
+
+        # --- Nested Progress Logic ---
+        # Define a local function to scale the sub-progress of the pipeline
+        # into the correct slot of the main batch progress bar.
+        def batch_progress_updater(local_fraction, desc):
+            # Calculate the overall progress based on which file we are on (i)
+            # and the progress within that file (local_fraction).
+            progress_per_file = 1 / total_files
+            overall_fraction = (i / total_files) + (local_fraction * progress_per_file)
+            progress(overall_fraction, desc=f"({i+1}/{total_files}) {filename}: {desc}")
         progress(i / total_files, desc=f"Processing {os.path.basename(input_path)} ({i+1}/{total_files})")
 
         # --- Pass the batch_timestamp to the pipeline ---
-        results, _ = run_single_file_pipeline(input_path, batch_timestamp, params)
+        results, _ = run_single_file_pipeline(input_path, batch_timestamp, params, progress=batch_progress_updater)
 
         if results:
             if results.get("final_audio_path"):
```
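The arithmetic in `batch_progress_updater` gives each file an equal slot of width `1 / total_files` on the overall bar and maps the pipeline's local 0-1 fraction into it: overall = i/total_files + local_fraction/total_files. A quick standalone check with made-up values:

```python
# Sanity check of the nested-progress arithmetic (no Gradio needed).
def overall(i, total_files, local_fraction):
    progress_per_file = 1 / total_files
    return (i / total_files) + (local_fraction * progress_per_file)

assert overall(0, 4, 0.0) == 0.0    # first file just started
assert overall(2, 4, 0.5) == 0.625  # halfway through the third of four files
assert overall(3, 4, 1.0) == 1.0    # last file finished
```

One design note: the closure reads `i` and `filename` at call time, which is only safe because `run_single_file_pipeline` is invoked synchronously in the same loop iteration; if the calls were deferred, every updater would see the last file's values.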
```diff
@@ -1482,12 +1523,13 @@ def batch_process_files(input_files, progress=gr.Progress(), *args):
 
 
 # --- The original function is now a thin wrapper for the single file UI ---
-def process_and_render_file(input_file, *args):
+def process_and_render_file(input_file, *args, progress=gr.Progress()):
     """
     Gradio wrapper for the single file processing UI. Packs UI values into an AppParameters object.
     Calls the core pipeline and formats the output for all UI components.
     Main function to handle file processing. It determines the file type and calls the
     appropriate functions for transcription and/or rendering based on user selections.
+    Now includes a progress bar.
     """
     if input_file is None:
         # Return a list of updates to clear all output fields and UI controls
@@ -1503,8 +1545,8 @@ def process_and_render_file(input_file, *args):
     # The first value in *args is s8bit_preset_selector, the rest match the keys
     params = AppParameters(input_file=input_file, **dict(zip(ALL_PARAM_KEYS, args)))
 
-    # Run the core pipeline, Pass the timestamp to the pipeline
-    results, final_params = run_single_file_pipeline(input_file, single_file_timestamp, params)
+    # Run the core pipeline, passing the timestamp and progress to the pipeline
+    results, final_params = run_single_file_pipeline(input_file, single_file_timestamp, params, progress=progress)
 
     if results is None:
         raise gr.Error("File processing failed. Check console for details.")
```
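The commit does not touch the Gradio event wiring, and it does not need to: a handler parameter whose default is `gr.Progress()` is detected and injected by Gradio when the event fires, and `gr.Progress(track_tqdm=True)` (now used by `batch_process_files`) additionally mirrors any tqdm bars raised inside the call onto the UI bar. A self-contained toy sketch of that injection mechanism; the names here are hypothetical and unrelated to app.py:

```python
# Hypothetical sketch -- Gradio fills in `progress` at call time because
# its default value is a gr.Progress() sentinel; calling it moves the bar.
import time
import gradio as gr

def slow_task(text, progress=gr.Progress()):
    steps = 5
    for i in range(steps):
        progress(i / steps, desc=f"Step {i + 1}/{steps}")
        time.sleep(0.2)  # stand-in for real work
    progress(1.0, desc="Done")
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(slow_task, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()
```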