Joseph Pollack committed
Commit e83891f · unverified · 1 Parent(s): bc0e217

adds additional components to the interface for recording

Files changed (1):
  1. interface.py (+189 −68)
interface.py CHANGED
@@ -254,9 +254,9 @@ def start_voxtral_training(
 def load_multilingual_phrases(language="en", max_phrases=None, split="train"):
     """Load phrases from various multilingual speech datasets.
 
-    Tries multiple datasets in order of preference:
-    1. Common Voice (most reliable and up-to-date)
-    2. FLEURS (Google's multilingual dataset)
+    Uses datasets that work with current library versions:
+    1. ML Commons Speech (modern format)
+    2. Multilingual LibriSpeech (modern format)
     3. Fallback to basic phrases
 
     Args:
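The fallback order documented here boils down to a simple loop: try each source in turn and keep the first non-empty result. A minimal self-contained sketch of that pattern (the helper and loader names are illustrative, not from interface.py):

```python
FALLBACK_PHRASES = ["the quick brown fox jumps over the lazy dog"]

def first_available(loaders, fallback):
    """Return phrases from the first loader that succeeds, else the fallback."""
    for loader in loaders:
        try:
            phrases = loader()
            if phrases:
                return phrases
        except Exception as exc:  # a failed hub download should not be fatal
            print(f"{loader.__name__} failed: {exc}")
    return fallback

print(first_available([], FALLBACK_PHRASES))  # no loaders -> fallback phrases
```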
@@ -272,70 +272,97 @@ def load_multilingual_phrases(language="en", max_phrases=None, split="train"):
 
     # Language code mapping for different datasets
     lang_mappings = {
-        "en": {"common_voice": "en", "fleurs": "en_us"},
-        "de": {"common_voice": "de", "fleurs": "de_de"},
-        "fr": {"common_voice": "fr", "fleurs": "fr_fr"},
-        "es": {"common_voice": "es", "fleurs": "es_419"},
-        "it": {"common_voice": "it", "fleurs": "it_it"},
-        "pt": {"common_voice": "pt", "fleurs": "pt_br"},
-        "pl": {"common_voice": "pl", "fleurs": "pl_pl"},
-        "nl": {"common_voice": "nl", "fleurs": "nl_nl"},
-        "ru": {"common_voice": "ru", "fleurs": "ru_ru"},
-        "ar": {"common_voice": "ar", "fleurs": "ar_eg"},
-        "zh": {"common_voice": "zh-CN", "fleurs": "zh_cn"},
-        "ja": {"common_voice": "ja", "fleurs": "ja_jp"},
-        "ko": {"common_voice": "ko", "fleurs": "ko_kr"},
+        "en": {"ml_speech": "en", "librispeech": "clean"},
+        "de": {"ml_speech": "de", "librispeech": None},
+        "fr": {"ml_speech": "fr", "librispeech": None},
+        "es": {"ml_speech": "es", "librispeech": None},
+        "it": {"ml_speech": "it", "librispeech": None},
+        "pt": {"ml_speech": "pt", "librispeech": None},
+        "pl": {"ml_speech": "pl", "librispeech": None},
+        "nl": {"ml_speech": "nl", "librispeech": None},
+        "ru": {"ml_speech": "ru", "librispeech": None},
+        "ar": {"ml_speech": "ar", "librispeech": None},
+        "zh": {"ml_speech": "zh", "librispeech": None},
+        "ja": {"ml_speech": "ja", "librispeech": None},
+        "ko": {"ml_speech": "ko", "librispeech": None},
     }
 
-    lang_config = lang_mappings.get(language, {"common_voice": language, "fleurs": f"{language}_{language}"})
+    lang_config = lang_mappings.get(language, {"ml_speech": language, "librispeech": None})
 
-    # Try Common Voice first (most reliable)
+    # Try ML Commons Speech first (modern format)
     try:
-        print(f"Trying Common Voice dataset for language: {language}")
-        cv_lang = lang_config["common_voice"]
-        ds = load_dataset("mozilla-foundation/common_voice_11_0", cv_lang, split=split, streaming=True)
+        print(f"Trying ML Commons Speech dataset for language: {language}")
+        ml_lang = lang_config["ml_speech"]
+        ds = load_dataset("mlcommons/ml_spoken_words", f"speech_commands_{ml_lang}", split=split, streaming=True)
 
         phrases = []
         count = 0
+        seen_words = set()
+
         for example in ds:
             if max_phrases and count >= max_phrases:
                 break
-            text = example.get("sentence", "").strip()
-            if text and len(text) > 10:  # Filter out very short phrases
-                phrases.append(text)
+            word = example.get("word", "").strip()
+            if word and len(word) > 2 and word not in seen_words:  # Filter duplicates and short words
+                phrases.append(word)
+                seen_words.add(word)
             count += 1
 
         if phrases:
-            print(f"Successfully loaded {len(phrases)} phrases from Common Voice")
+            print(f"Successfully loaded {len(phrases)} phrases from ML Commons Speech")
             random.shuffle(phrases)
            return phrases
 
     except Exception as e:
-        print(f"Common Voice failed: {e}")
+        print(f"ML Commons Speech failed: {e}")
+
+    # Try Multilingual LibriSpeech as backup
+    try:
+        if lang_config["librispeech"]:
+            print(f"Trying Multilingual LibriSpeech dataset for language: {language}")
+            librispeech_lang = lang_config["librispeech"]
+            ds = load_dataset("facebook/multilingual_librispeech", f"{language}", split=split, streaming=True)
+
+            phrases = []
+            count = 0
+            for example in ds:
+                if max_phrases and count >= max_phrases:
+                    break
+                text = example.get("text", "").strip()
+                if text and len(text) > 10:  # Filter out very short phrases
+                    phrases.append(text)
+                count += 1
+
+            if phrases:
+                print(f"Successfully loaded {len(phrases)} phrases from Multilingual LibriSpeech")
+                random.shuffle(phrases)
+                return phrases
+
+    except Exception as e:
+        print(f"Multilingual LibriSpeech failed: {e}")
 
-    # Try FLEURS as backup
+    # Try TED Talk translations (works for many languages)
     try:
-        print(f"Trying FLEURS dataset for language: {language}")
-        fleurs_lang = lang_config["fleurs"]
-        ds = load_dataset("google/fleurs", fleurs_lang, split=split, streaming=True)
+        print(f"Trying TED Talk translations for language: {language}")
+        ds = load_dataset("ted_talks_iwslt", language=[f"{language}_en"], split=split, streaming=True)
 
         phrases = []
         count = 0
         for example in ds:
             if max_phrases and count >= max_phrases:
                 break
-            text = example.get("transcription", "").strip()
+            text = example.get("translation", {}).get(language, "").strip()
             if text and len(text) > 10:  # Filter out very short phrases
                 phrases.append(text)
             count += 1
 
         if phrases:
-            print(f"Successfully loaded {len(phrases)} phrases from FLEURS")
+            print(f"Successfully loaded {len(phrases)} phrases from TED Talks")
             random.shuffle(phrases)
             return phrases
 
     except Exception as e:
-        print(f"FLEURS failed: {e}")
+        print(f"TED Talks failed: {e}")
 
     # Final fallback to basic phrases
     print("All dataset loading attempts failed, using fallback phrases")
@@ -434,17 +461,20 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
     # Recording grid with dynamic text readouts
     phrase_texts_state = gr.State(ALL_PHRASES)
     visible_rows_state = gr.State(10)  # Start with 10 visible rows
-    max_rows = len(ALL_PHRASES)  # No cap on total rows
+    MAX_COMPONENTS = 100  # Fixed maximum number of components
+
+    # Create fixed number of components upfront
     phrase_markdowns: list[gr.Markdown] = []
     rec_components = []
 
-    def create_recording_grid(phrases, visible_count=10):
-        """Create recording grid components dynamically"""
+    def create_recording_grid(max_components=MAX_COMPONENTS):
+        """Create recording grid components with fixed maximum"""
         markdowns = []
         recordings = []
-        for idx, phrase in enumerate(phrases):
-            visible = idx < visible_count
-            md = gr.Markdown(f"**{idx+1}. {phrase}**", visible=visible)
+        for idx in range(max_components):
+            visible = idx < 10  # Only first 10 visible initially
+            phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
+            md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
             markdowns.append(md)
             comp = gr.Audio(sources="microphone", type="numpy", label=f"Recording {idx+1}", visible=visible)
             recordings.append(comp)
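The fixed pool exists because Gradio cannot add components to a running app; the grid therefore pre-builds all `MAX_COMPONENTS` rows and only toggles visibility afterwards. A self-contained sketch of that pattern (component names and pool size are illustrative):

```python
import gradio as gr

MAX_COMPONENTS = 4  # small pool for the sketch

with gr.Blocks() as sketch:
    # Build every row up front; only the first two start visible.
    boxes = [gr.Textbox(label=f"Row {i + 1}", visible=(i < 2))
             for i in range(MAX_COMPONENTS)]
    show_all = gr.Button("Show all rows")
    # Return one gr.update per output component, in the same order.
    show_all.click(lambda: [gr.update(visible=True)] * MAX_COMPONENTS,
                   outputs=boxes)

if __name__ == "__main__":
    sketch.launch()
```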
@@ -452,44 +482,41 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
 
     # Initial grid creation
     with gr.Column():
-        phrase_markdowns, rec_components = create_recording_grid(ALL_PHRASES, 10)
+        phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
 
     # Add more rows button
     add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary")
 
     def add_more_rows(current_visible, current_phrases):
         """Add 10 more rows by making them visible"""
-        new_visible = min(current_visible + 10, len(current_phrases))
+        new_visible = min(current_visible + 10, MAX_COMPONENTS, len(current_phrases))
+
+        # Create updates for all MAX_COMPONENTS
         visibility_updates = []
-        for i in range(len(current_phrases)):
-            if i < new_visible:
+        for i in range(MAX_COMPONENTS):
+            if i < len(current_phrases) and i < new_visible:
                 visibility_updates.append(gr.update(visible=True))
             else:
                 visibility_updates.append(gr.update(visible=False))
+
         return [new_visible] + visibility_updates
 
     def change_language(language):
         """Change the language and reload phrases from multilingual datasets"""
         new_phrases = load_multilingual_phrases(language, max_phrases=None)
         # Reset visible rows to 10
-        visible_count = min(10, len(new_phrases))
+        visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
 
-        # Create combined updates for existing components (up to current length)
-        current_len = len(phrase_markdowns)
+        # Create updates for all MAX_COMPONENTS
         combined_updates = []
-
-        # Update existing components
-        for i in range(current_len):
-            if i < len(new_phrases):
-                if i < visible_count:
-                    combined_updates.append(gr.update(value=f"**{i+1}. {new_phrases[i]}**", visible=True))
-                else:
-                    combined_updates.append(gr.update(visible=False))
+        for i in range(MAX_COMPONENTS):
+            if i < len(new_phrases) and i < visible_count:
+                combined_updates.append(gr.update(value=f"**{i+1}. {new_phrases[i]}**", visible=True))
+            elif i < len(new_phrases):
+                combined_updates.append(gr.update(value=f"**{i+1}. {new_phrases[i]}**", visible=False))
             else:
-                combined_updates.append(gr.update(visible=False))
+                combined_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
 
-        # If we have more phrases than components, we can't update them via Gradio
-        # The interface will need to be reloaded for significantly different phrase counts
         return [new_phrases, visible_count] + combined_updates
 
     # Connect language change to phrase reloading
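Both handlers follow the same contract: an event function must return exactly one value per wired output, so the update lists are always padded to the full pool (state value first, then one `gr.update` per component). A minimal sketch of that contract (names illustrative):

```python
import gradio as gr

POOL = 3  # pool size for the sketch

def reveal(n):
    # First return value feeds the gr.State output, the rest feed the rows.
    return [n] + [gr.update(visible=(i < n)) for i in range(POOL)]

with gr.Blocks() as sketch:
    count = gr.State(1)
    rows = [gr.Markdown(f"row {i + 1}", visible=(i == 0)) for i in range(POOL)]
    gr.Button("Reveal two rows").click(lambda: reveal(2),
                                       outputs=[count] + rows)
```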
@@ -505,6 +532,56 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
         outputs=[visible_rows_state] + phrase_markdowns + rec_components
     )
 
+    # Recording dataset creation button
+    record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary")
+
+    def create_recording_dataset(*recordings_and_state):
+        """Create dataset from visible recordings and phrases"""
+        try:
+            import soundfile as sf
+
+            # Extract recordings and state
+            recordings = recordings_and_state[:-1]  # All except the last item (phrases)
+            phrases = recordings_and_state[-1]  # Last item is phrases
+
+            dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+            wav_dir = dataset_dir / "wavs"
+            wav_dir.mkdir(parents=True, exist_ok=True)
+
+            rows = []
+            successful_recordings = 0
+
+            # Process each recording
+            for i, rec in enumerate(recordings):
+                if rec is not None and i < len(phrases):
+                    try:
+                        sr, data = rec
+                        out_path = wav_dir / f"recording_{i:04d}.wav"
+                        sf.write(str(out_path), data, sr)
+                        rows.append({"audio_path": str(out_path), "text": phrases[i]})
+                        successful_recordings += 1
+                    except Exception as e:
+                        print(f"Error processing recording {i}: {e}")
+
+            if rows:
+                jsonl_path = dataset_dir / "recorded_data.jsonl"
+                _write_jsonl(rows, jsonl_path)
+                return f"✅ Dataset created successfully! {successful_recordings} recordings saved to {jsonl_path}"
+            else:
+                return "❌ No recordings found. Please record some audio first."
+
+        except Exception as e:
+            return f"❌ Error creating dataset: {str(e)}"
+
+    # Status display for dataset creation
+    dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=True)
+
+    record_dataset_btn.click(
+        create_recording_dataset,
+        inputs=rec_components + [phrase_texts_state],
+        outputs=[dataset_status]
+    )
+
     # Advanced options accordion
     with gr.Accordion("Advanced options", open=False):
         base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
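The new callback receives each `gr.Audio(type="numpy")` value as a `(sample_rate, array)` tuple and persists it with soundfile. A reduced sketch of the same steps for a single stand-in recording (paths illustrative; the repo's `_write_jsonl` helper is inlined here as a plain JSON-lines append):

```python
import json
from pathlib import Path

import numpy as np
import soundfile as sf

# Stand-in for one gr.Audio(type="numpy") value: (sample_rate, samples).
sr, data = 16000, np.zeros(16000, dtype=np.float32)

wav_dir = Path("datasets/voxtral_user/wavs")  # illustrative location
wav_dir.mkdir(parents=True, exist_ok=True)
out_path = wav_dir / "recording_0000.wav"
sf.write(str(out_path), data, sr)  # soundfile takes (file, data, samplerate)

# One JSON object per line, mirroring the rows _write_jsonl receives.
with open(wav_dir.parent / "recorded_data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps({"audio_path": str(out_path), "text": "hello"}) + "\n")
```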
@@ -576,22 +653,66 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
     vp_btn = gr.Button("Use Multilingual Dataset Sample")
 
     def _collect_multilingual_sample(lang_code: str, num_samples: int, split: str):
-        """Load sample from multilingual datasets (Common Voice preferred)"""
+        """Load sample from multilingual datasets (ML Commons preferred)"""
         from datasets import load_dataset, Audio
         import random
 
-        # Language code mapping for Common Voice
-        cv_lang_map = {
+        # Language code mapping for ML Commons Speech
+        ml_lang_map = {
             "en": "en", "de": "de", "fr": "fr", "es": "es", "it": "it",
             "pl": "pl", "pt": "pt", "nl": "nl", "ru": "ru", "ar": "ar",
-            "zh": "zh-CN", "ja": "ja", "ko": "ko"
+            "zh": "zh", "ja": "ja", "ko": "ko"
         }
 
-        cv_lang = cv_lang_map.get(lang_code, lang_code)
+        ml_lang = ml_lang_map.get(lang_code, lang_code)
+
+        try:
+            # Try ML Commons Speech first
+            ds = load_dataset("mlcommons/ml_spoken_words", f"speech_commands_{ml_lang}", split=split, streaming=True)
+            ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+            dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+            rows: list[dict] = []
+            texts: list[str] = []
+
+            count = 0
+            seen_words = set()
+
+            for ex in ds:
+                if count >= num_samples:
+                    break
+
+                audio = ex.get("audio") or {}
+                path = audio.get("path")
+                word = ex.get("word", "").strip()
+
+                if path and word and len(word) > 2 and word not in seen_words:
+                    rows.append({"audio_path": path, "text": word})
+                    texts.append(str(word))
+                    seen_words.add(word)
+                    count += 1
+
+            if rows:
+                jsonl_path = dataset_dir / "data.jsonl"
+                _write_jsonl(rows, jsonl_path)
+
+                # Build markdown content updates for on-screen prompts
+                combined_updates = []
+                for i in range(MAX_COMPONENTS):
+                    t = texts[i] if i < len(texts) else ""
+                    if i < len(texts):
+                        combined_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
+                    else:
+                        combined_updates.append(gr.update(visible=False))
+
+                return (str(jsonl_path), texts, *combined_updates)
+
+        except Exception as e:
+            print(f"ML Commons Speech sample loading failed: {e}")
 
+        # Try Multilingual LibriSpeech as backup
         try:
-            # Try Common Voice first
-            ds = load_dataset("mozilla-foundation/common_voice_11_0", cv_lang, split=split, streaming=True)
+            ds = load_dataset("facebook/multilingual_librispeech", f"{lang_code}", split=split, streaming=True)
             ds = ds.cast_column("audio", Audio(sampling_rate=16000))
 
             dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
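The `cast_column` call is what normalizes the sample rate before export: it re-decodes each example's audio at the requested rate as it streams past. A small sketch, assuming the `datasets` package and an illustrative config name:

```python
from datasets import Audio, load_dataset

ds = load_dataset("facebook/multilingual_librispeech", "german",
                  split="train", streaming=True)
# Re-decode every example's audio at 16 kHz while streaming.
ds = ds.cast_column("audio", Audio(sampling_rate=16000))

first = next(iter(ds))
print(first["audio"]["sampling_rate"], len(first["audio"]["array"]))
```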
@@ -605,7 +726,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
 
             audio = ex.get("audio") or {}
             path = audio.get("path")
-            text = ex.get("sentence", "").strip()
+            text = ex.get("text", "").strip()
 
             if path and text and len(text) > 10:
                 rows.append({"audio_path": path, "text": text})
@@ -618,7 +739,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
 
             # Build markdown content updates for on-screen prompts
             combined_updates = []
-            for i in range(len(phrase_markdowns)):
+            for i in range(MAX_COMPONENTS):
                 t = texts[i] if i < len(texts) else ""
                 if i < len(texts):
                     combined_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
@@ -628,7 +749,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
             return (str(jsonl_path), texts, *combined_updates)
 
         except Exception as e:
-            print(f"Common Voice sample loading failed: {e}")
+            print(f"Multilingual LibriSpeech failed: {e}")
 
         # Fallback: generate synthetic samples with text only
         print("Using fallback: generating text-only samples")
@@ -642,7 +763,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
 
         # Build markdown content updates for on-screen prompts
         combined_updates = []
-        for i in range(len(phrase_markdowns)):
+        for i in range(MAX_COMPONENTS):
             t = texts[i] if i < len(texts) else ""
             if i < len(texts):
                 combined_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
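For reference, the JSONL files written above hold one `{"audio_path", "text"}` object per line. A short sketch of reading one back (path illustrative):

```python
import json
from pathlib import Path

jsonl_path = Path("datasets/voxtral_user/data.jsonl")  # illustrative path
if jsonl_path.exists():
    rows = [json.loads(line)
            for line in jsonl_path.read_text(encoding="utf-8").splitlines()
            if line.strip()]
    print(f"{len(rows)} rows; first text: {rows[0]['text']!r}")
```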
 