Joseph Pollack committed on
Commit
b9f51a0
·
unverified ·
1 Parent(s): fb12450

simplifies the interface

Browse files
Files changed (2) hide show
  1. __pycache__/interface.cpython-313.pyc +0 -0
  2. interface.py +38 -133
__pycache__/interface.cpython-313.pyc ADDED
Binary file (30.3 kB). View file
 
interface.py CHANGED
@@ -401,7 +401,8 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
401
  Read the phrases below and record them. Then start fine-tuning.
402
  """)
403
 
404
- jsonl_out = gr.Textbox(label="Dataset JSONL path", interactive=False, visible=True)
 
405
 
406
  # Language selection for NVIDIA Granary phrases
407
  language_selector = gr.Dropdown(
@@ -456,7 +457,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
456
  markdowns = []
457
  recordings = []
458
  for idx in range(max_components):
459
- visible = idx < 10 # Only first 10 visible initially
460
  phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
461
  md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
462
  markdowns.append(md)
@@ -469,7 +470,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
469
  phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
470
 
471
  # Add more rows button
472
- add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary")
473
 
474
  def add_more_rows(current_visible, current_phrases):
475
  """Add 10 more rows by making them visible"""
@@ -491,7 +492,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
491
  return [new_visible] + markdown_updates + audio_updates
492
 
493
  def change_language(language):
494
- """Change the language and reload phrases from multilingual datasets"""
495
  new_phrases = load_multilingual_phrases(language, max_phrases=None)
496
  # Reset visible rows to 10
497
  visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
@@ -511,15 +512,19 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
511
  markdown_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
512
  audio_updates.append(gr.update(visible=False))
513
 
514
- # Return: [phrases_state, visible_state] + markdown_updates + audio_updates
515
- return [new_phrases, visible_count] + markdown_updates + audio_updates
 
 
 
 
 
 
 
 
516
 
517
- # Connect language change to phrase reloading
518
- language_selector.change(
519
- change_language,
520
- inputs=[language_selector],
521
- outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components
522
- )
523
 
524
  add_rows_btn.click(
525
  add_more_rows,
@@ -528,7 +533,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
528
  )
529
 
530
  # Recording dataset creation button
531
- record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary")
532
 
533
  def create_recording_dataset(*recordings_and_state):
534
  """Create dataset from visible recordings and phrases"""
@@ -569,7 +574,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
569
  return f"❌ Error creating dataset: {str(e)}"
570
 
571
  # Status display for dataset creation
572
- dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=True)
573
 
574
  record_dataset_btn.click(
575
  create_recording_dataset,
@@ -578,7 +583,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
578
  )
579
 
580
  # Advanced options accordion
581
- with gr.Accordion("Advanced options", open=False):
582
  base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
583
  use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
584
  with gr.Row():
@@ -608,10 +613,11 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
608
  lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
609
  return _save_uploaded_dataset(files or [], lines)
610
 
611
- save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [jsonl_out])
 
612
 
613
  # Save recordings button
614
- save_rec_btn = gr.Button("Save recordings as dataset")
615
 
616
  def _collect_preloaded_recs(*recs_and_texts):
617
  import soundfile as sf
@@ -638,128 +644,17 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
638
  _write_jsonl(rows, jsonl_path)
639
  return str(jsonl_path)
640
 
641
- save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_out])
642
-
643
- # Quick sample from multilingual datasets (Common Voice, etc.)
644
- with gr.Row():
645
- vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "pt", "nl", "ru", "ar", "zh", "ja", "ko", "da", "sv", "fi", "et", "cs", "hr", "bg", "uk", "ro", "hu", "el"], value="en", label="Sample Language")
646
- vp_samples = gr.Number(value=20, precision=0, label="Num samples")
647
- vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
648
- vp_btn = gr.Button("Use Multilingual Dataset Sample")
649
-
650
- def _collect_multilingual_sample(lang_code: str, num_samples: int, split: str):
651
- """Collect sample audio and text from NVIDIA Granary dataset"""
652
- from datasets import load_dataset, Audio
653
- import random
654
-
655
- # Map language code to Granary format
656
- granary_lang_map = {
657
- "en": "en", "de": "de", "fr": "fr", "es": "es", "it": "it",
658
- "pl": "pl", "pt": "pt", "nl": "nl", "ru": "ru", "ar": "ar",
659
- "zh": "zh", "ja": "ja", "ko": "ko", "da": "da", "sv": "sv",
660
- "no": "no", "fi": "fi", "et": "et", "lv": "lv", "lt": "lt",
661
- "sl": "sl", "sk": "sk", "cs": "cs", "hr": "hr", "bg": "bg",
662
- "uk": "uk", "ro": "ro", "hu": "hu", "el": "el", "mt": "mt"
663
- }
664
-
665
- granary_lang = granary_lang_map.get(lang_code, "en")
666
-
667
- try:
668
- print(f"Collecting {num_samples} samples from NVIDIA Granary dataset for language: {lang_code}")
669
-
670
- # Load Granary dataset with ASR split
671
- ds = load_dataset("nvidia/Granary", granary_lang, split="asr", streaming=True)
672
-
673
- dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
674
- rows = []
675
- texts = []
676
- count = 0
677
-
678
- # Sample from the dataset
679
- for example in ds:
680
- if count >= num_samples:
681
- break
682
-
683
- text = example.get("text", "").strip()
684
- audio_path = example.get("audio_filepath", "")
685
-
686
- # Filter for quality samples
687
- if (text and
688
- len(text) > 10 and
689
- len(text) < 200 and
690
- audio_path): # Must have audio file
691
-
692
- rows.append({
693
- "audio_path": audio_path,
694
- "text": text
695
- })
696
- texts.append(text)
697
- count += 1
698
-
699
- if rows:
700
- jsonl_path = dataset_dir / "data.jsonl"
701
- _write_jsonl(rows, jsonl_path)
702
-
703
- print(f"Successfully collected {len(rows)} samples from Granary dataset")
704
-
705
- # Build markdown and audio content updates for on-screen prompts
706
- markdown_updates = []
707
- audio_updates = []
708
- for i in range(MAX_COMPONENTS):
709
- t = texts[i] if i < len(texts) else ""
710
- if i < len(texts):
711
- markdown_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
712
- audio_updates.append(gr.update(visible=True))
713
- else:
714
- markdown_updates.append(gr.update(visible=False))
715
- audio_updates.append(gr.update(visible=False))
716
-
717
- combined_updates = markdown_updates + audio_updates
718
-
719
- return (str(jsonl_path), texts, *combined_updates)
720
-
721
- except Exception as e:
722
- print(f"Granary sample collection failed for {lang_code}: {e}")
723
-
724
- # Fallback: generate text-only samples if Granary fails
725
- print(f"Using fallback: generating text-only samples for {lang_code}")
726
- phrases = load_multilingual_phrases(lang_code, max_phrases=num_samples)
727
- texts = phrases[:num_samples]
728
-
729
- dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
730
- rows = [{"audio_path": "", "text": text} for text in texts]
731
- jsonl_path = dataset_dir / "data.jsonl"
732
- _write_jsonl(rows, jsonl_path)
733
 
734
- # Build markdown and audio content updates for on-screen prompts
735
- markdown_updates = []
736
- audio_updates = []
737
- for i in range(MAX_COMPONENTS):
738
- t = texts[i] if i < len(texts) else ""
739
- if i < len(texts):
740
- markdown_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
741
- audio_updates.append(gr.update(visible=True))
742
- else:
743
- markdown_updates.append(gr.update(visible=False))
744
- audio_updates.append(gr.update(visible=False))
745
-
746
- combined_updates = markdown_updates + audio_updates
747
-
748
- return (str(jsonl_path), texts, *combined_updates)
749
-
750
- vp_btn.click(
751
- _collect_multilingual_sample,
752
- [vp_lang, vp_samples, vp_split],
753
- [jsonl_out, phrase_texts_state] + phrase_markdowns,
754
- )
755
 
756
- start_btn = gr.Button("Start Fine-tuning")
757
- logs_box = gr.Textbox(label="Logs", lines=20)
758
 
759
  start_btn.click(
760
  start_voxtral_training,
761
  inputs=[
762
- use_lora, base_model, repo_short, jsonl_out, train_count, eval_count,
763
  batch_size, grad_accum, learning_rate, epochs,
764
  lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
765
  push_to_hub, deploy_demo,
@@ -767,6 +662,16 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
767
  outputs=[logs_box],
768
  )
769
 
 
 
 
 
 
 
 
 
 
 
770
 
771
  if __name__ == "__main__":
772
  server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
 
401
  Read the phrases below and record them. Then start fine-tuning.
402
  """)
403
 
404
+ # Hidden state to track dataset JSONL path
405
+ jsonl_path_state = gr.State("")
406
 
407
  # Language selection for NVIDIA Granary phrases
408
  language_selector = gr.Dropdown(
 
457
  markdowns = []
458
  recordings = []
459
  for idx in range(max_components):
460
+ visible = False # Initially hidden - will be revealed when language is selected
461
  phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
462
  md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
463
  markdowns.append(md)
 
470
  phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
471
 
472
  # Add more rows button
473
+ add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary", visible=False)
474
 
475
  def add_more_rows(current_visible, current_phrases):
476
  """Add 10 more rows by making them visible"""
 
492
  return [new_visible] + markdown_updates + audio_updates
493
 
494
  def change_language(language):
495
+ """Change the language and reload phrases from multilingual datasets, reveal interface"""
496
  new_phrases = load_multilingual_phrases(language, max_phrases=None)
497
  # Reset visible rows to 10
498
  visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
 
512
  markdown_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
513
  audio_updates.append(gr.update(visible=False))
514
 
515
+ # Reveal all interface elements when language is selected
516
+ reveal_updates = [
517
+ gr.update(visible=True), # add_rows_btn
518
+ gr.update(visible=True), # record_dataset_btn
519
+ gr.update(visible=True), # dataset_status
520
+ gr.update(visible=True), # advanced_accordion
521
+ gr.update(visible=True), # save_rec_btn
522
+ gr.update(visible=True), # start_btn
523
+ gr.update(visible=True), # logs_box
524
+ ]
525
 
526
+ # Return: [phrases_state, visible_state] + markdown_updates + audio_updates + reveal_updates
527
+ return [new_phrases, visible_count] + markdown_updates + audio_updates + reveal_updates
 
 
 
 
528
 
529
  add_rows_btn.click(
530
  add_more_rows,
 
533
  )
534
 
535
  # Recording dataset creation button
536
+ record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary", visible=False)
537
 
538
  def create_recording_dataset(*recordings_and_state):
539
  """Create dataset from visible recordings and phrases"""
 
574
  return f"❌ Error creating dataset: {str(e)}"
575
 
576
  # Status display for dataset creation
577
+ dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=False)
578
 
579
  record_dataset_btn.click(
580
  create_recording_dataset,
 
583
  )
584
 
585
  # Advanced options accordion
586
+ with gr.Accordion("Advanced options", open=False, visible=False) as advanced_accordion:
587
  base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
588
  use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
589
  with gr.Row():
 
613
  lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
614
  return _save_uploaded_dataset(files or [], lines)
615
 
616
+ # Removed - no longer needed since jsonl_out was removed
617
+ # save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [])
618
 
619
  # Save recordings button
620
+ save_rec_btn = gr.Button("Save recordings as dataset", visible=False)
621
 
622
  def _collect_preloaded_recs(*recs_and_texts):
623
  import soundfile as sf
 
644
  _write_jsonl(rows, jsonl_path)
645
  return str(jsonl_path)
646
 
647
+ save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_path_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
 
649
+ # Removed multilingual dataset sample section - phrases are now loaded automatically when language is selected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
 
651
+ start_btn = gr.Button("Start Fine-tuning", visible=False)
652
+ logs_box = gr.Textbox(label="Logs", lines=20, visible=False)
653
 
654
  start_btn.click(
655
  start_voxtral_training,
656
  inputs=[
657
+ use_lora, base_model, repo_short, jsonl_path_state, train_count, eval_count,
658
  batch_size, grad_accum, learning_rate, epochs,
659
  lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
660
  push_to_hub, deploy_demo,
 
662
  outputs=[logs_box],
663
  )
664
 
665
+ # Connect language change to phrase reloading and interface reveal (placed after all components are defined)
666
+ language_selector.change(
667
+ change_language,
668
+ inputs=[language_selector],
669
+ outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components + [
670
+ add_rows_btn, record_dataset_btn, dataset_status, advanced_accordion,
671
+ save_rec_btn, start_btn, logs_box
672
+ ]
673
+ )
674
+
675
 
676
  if __name__ == "__main__":
677
  server_port = int(os.environ.get("INTERFACE_PORT", "7860"))