Spaces:
Running
Running
Joseph Pollack
committed on
simplifies the interface
Browse files- __pycache__/interface.cpython-313.pyc +0 -0
- interface.py +38 -133
__pycache__/interface.cpython-313.pyc
ADDED
Binary file (30.3 kB). View file
|
|
interface.py
CHANGED
@@ -401,7 +401,8 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
401 |
Read the phrases below and record them. Then start fine-tuning.
|
402 |
""")
|
403 |
|
404 |
-
|
|
|
405 |
|
406 |
# Language selection for NVIDIA Granary phrases
|
407 |
language_selector = gr.Dropdown(
|
@@ -456,7 +457,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
456 |
markdowns = []
|
457 |
recordings = []
|
458 |
for idx in range(max_components):
|
459 |
-
visible =
|
460 |
phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
|
461 |
md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
|
462 |
markdowns.append(md)
|
@@ -469,7 +470,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
469 |
phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
|
470 |
|
471 |
# Add more rows button
|
472 |
-
add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary")
|
473 |
|
474 |
def add_more_rows(current_visible, current_phrases):
|
475 |
"""Add 10 more rows by making them visible"""
|
@@ -491,7 +492,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
491 |
return [new_visible] + markdown_updates + audio_updates
|
492 |
|
493 |
def change_language(language):
|
494 |
-
"""Change the language and reload phrases from multilingual datasets"""
|
495 |
new_phrases = load_multilingual_phrases(language, max_phrases=None)
|
496 |
# Reset visible rows to 10
|
497 |
visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
|
@@ -511,15 +512,19 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
511 |
markdown_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
|
512 |
audio_updates.append(gr.update(visible=False))
|
513 |
|
514 |
-
#
|
515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
|
517 |
-
|
518 |
-
|
519 |
-
change_language,
|
520 |
-
inputs=[language_selector],
|
521 |
-
outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components
|
522 |
-
)
|
523 |
|
524 |
add_rows_btn.click(
|
525 |
add_more_rows,
|
@@ -528,7 +533,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
528 |
)
|
529 |
|
530 |
# Recording dataset creation button
|
531 |
-
record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary")
|
532 |
|
533 |
def create_recording_dataset(*recordings_and_state):
|
534 |
"""Create dataset from visible recordings and phrases"""
|
@@ -569,7 +574,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
569 |
return f"❌ Error creating dataset: {str(e)}"
|
570 |
|
571 |
# Status display for dataset creation
|
572 |
-
dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=
|
573 |
|
574 |
record_dataset_btn.click(
|
575 |
create_recording_dataset,
|
@@ -578,7 +583,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
578 |
)
|
579 |
|
580 |
# Advanced options accordion
|
581 |
-
with gr.Accordion("Advanced options", open=False):
|
582 |
base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
|
583 |
use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
|
584 |
with gr.Row():
|
@@ -608,10 +613,11 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
608 |
lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
|
609 |
return _save_uploaded_dataset(files or [], lines)
|
610 |
|
611 |
-
|
|
|
612 |
|
613 |
# Save recordings button
|
614 |
-
save_rec_btn = gr.Button("Save recordings as dataset")
|
615 |
|
616 |
def _collect_preloaded_recs(*recs_and_texts):
|
617 |
import soundfile as sf
|
@@ -638,128 +644,17 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
638 |
_write_jsonl(rows, jsonl_path)
|
639 |
return str(jsonl_path)
|
640 |
|
641 |
-
save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [
|
642 |
-
|
643 |
-
# Quick sample from multilingual datasets (Common Voice, etc.)
|
644 |
-
with gr.Row():
|
645 |
-
vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "pt", "nl", "ru", "ar", "zh", "ja", "ko", "da", "sv", "fi", "et", "cs", "hr", "bg", "uk", "ro", "hu", "el"], value="en", label="Sample Language")
|
646 |
-
vp_samples = gr.Number(value=20, precision=0, label="Num samples")
|
647 |
-
vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
|
648 |
-
vp_btn = gr.Button("Use Multilingual Dataset Sample")
|
649 |
-
|
650 |
-
def _collect_multilingual_sample(lang_code: str, num_samples: int, split: str):
|
651 |
-
"""Collect sample audio and text from NVIDIA Granary dataset"""
|
652 |
-
from datasets import load_dataset, Audio
|
653 |
-
import random
|
654 |
-
|
655 |
-
# Map language code to Granary format
|
656 |
-
granary_lang_map = {
|
657 |
-
"en": "en", "de": "de", "fr": "fr", "es": "es", "it": "it",
|
658 |
-
"pl": "pl", "pt": "pt", "nl": "nl", "ru": "ru", "ar": "ar",
|
659 |
-
"zh": "zh", "ja": "ja", "ko": "ko", "da": "da", "sv": "sv",
|
660 |
-
"no": "no", "fi": "fi", "et": "et", "lv": "lv", "lt": "lt",
|
661 |
-
"sl": "sl", "sk": "sk", "cs": "cs", "hr": "hr", "bg": "bg",
|
662 |
-
"uk": "uk", "ro": "ro", "hu": "hu", "el": "el", "mt": "mt"
|
663 |
-
}
|
664 |
-
|
665 |
-
granary_lang = granary_lang_map.get(lang_code, "en")
|
666 |
-
|
667 |
-
try:
|
668 |
-
print(f"Collecting {num_samples} samples from NVIDIA Granary dataset for language: {lang_code}")
|
669 |
-
|
670 |
-
# Load Granary dataset with ASR split
|
671 |
-
ds = load_dataset("nvidia/Granary", granary_lang, split="asr", streaming=True)
|
672 |
-
|
673 |
-
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
674 |
-
rows = []
|
675 |
-
texts = []
|
676 |
-
count = 0
|
677 |
-
|
678 |
-
# Sample from the dataset
|
679 |
-
for example in ds:
|
680 |
-
if count >= num_samples:
|
681 |
-
break
|
682 |
-
|
683 |
-
text = example.get("text", "").strip()
|
684 |
-
audio_path = example.get("audio_filepath", "")
|
685 |
-
|
686 |
-
# Filter for quality samples
|
687 |
-
if (text and
|
688 |
-
len(text) > 10 and
|
689 |
-
len(text) < 200 and
|
690 |
-
audio_path): # Must have audio file
|
691 |
-
|
692 |
-
rows.append({
|
693 |
-
"audio_path": audio_path,
|
694 |
-
"text": text
|
695 |
-
})
|
696 |
-
texts.append(text)
|
697 |
-
count += 1
|
698 |
-
|
699 |
-
if rows:
|
700 |
-
jsonl_path = dataset_dir / "data.jsonl"
|
701 |
-
_write_jsonl(rows, jsonl_path)
|
702 |
-
|
703 |
-
print(f"Successfully collected {len(rows)} samples from Granary dataset")
|
704 |
-
|
705 |
-
# Build markdown and audio content updates for on-screen prompts
|
706 |
-
markdown_updates = []
|
707 |
-
audio_updates = []
|
708 |
-
for i in range(MAX_COMPONENTS):
|
709 |
-
t = texts[i] if i < len(texts) else ""
|
710 |
-
if i < len(texts):
|
711 |
-
markdown_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
|
712 |
-
audio_updates.append(gr.update(visible=True))
|
713 |
-
else:
|
714 |
-
markdown_updates.append(gr.update(visible=False))
|
715 |
-
audio_updates.append(gr.update(visible=False))
|
716 |
-
|
717 |
-
combined_updates = markdown_updates + audio_updates
|
718 |
-
|
719 |
-
return (str(jsonl_path), texts, *combined_updates)
|
720 |
-
|
721 |
-
except Exception as e:
|
722 |
-
print(f"Granary sample collection failed for {lang_code}: {e}")
|
723 |
-
|
724 |
-
# Fallback: generate text-only samples if Granary fails
|
725 |
-
print(f"Using fallback: generating text-only samples for {lang_code}")
|
726 |
-
phrases = load_multilingual_phrases(lang_code, max_phrases=num_samples)
|
727 |
-
texts = phrases[:num_samples]
|
728 |
-
|
729 |
-
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
730 |
-
rows = [{"audio_path": "", "text": text} for text in texts]
|
731 |
-
jsonl_path = dataset_dir / "data.jsonl"
|
732 |
-
_write_jsonl(rows, jsonl_path)
|
733 |
|
734 |
-
|
735 |
-
markdown_updates = []
|
736 |
-
audio_updates = []
|
737 |
-
for i in range(MAX_COMPONENTS):
|
738 |
-
t = texts[i] if i < len(texts) else ""
|
739 |
-
if i < len(texts):
|
740 |
-
markdown_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
|
741 |
-
audio_updates.append(gr.update(visible=True))
|
742 |
-
else:
|
743 |
-
markdown_updates.append(gr.update(visible=False))
|
744 |
-
audio_updates.append(gr.update(visible=False))
|
745 |
-
|
746 |
-
combined_updates = markdown_updates + audio_updates
|
747 |
-
|
748 |
-
return (str(jsonl_path), texts, *combined_updates)
|
749 |
-
|
750 |
-
vp_btn.click(
|
751 |
-
_collect_multilingual_sample,
|
752 |
-
[vp_lang, vp_samples, vp_split],
|
753 |
-
[jsonl_out, phrase_texts_state] + phrase_markdowns,
|
754 |
-
)
|
755 |
|
756 |
-
start_btn = gr.Button("Start Fine-tuning")
|
757 |
-
logs_box = gr.Textbox(label="Logs", lines=20)
|
758 |
|
759 |
start_btn.click(
|
760 |
start_voxtral_training,
|
761 |
inputs=[
|
762 |
-
use_lora, base_model, repo_short,
|
763 |
batch_size, grad_accum, learning_rate, epochs,
|
764 |
lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
|
765 |
push_to_hub, deploy_demo,
|
@@ -767,6 +662,16 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
767 |
outputs=[logs_box],
|
768 |
)
|
769 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
770 |
|
771 |
if __name__ == "__main__":
|
772 |
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|
|
|
401 |
Read the phrases below and record them. Then start fine-tuning.
|
402 |
""")
|
403 |
|
404 |
+
# Hidden state to track dataset JSONL path
|
405 |
+
jsonl_path_state = gr.State("")
|
406 |
|
407 |
# Language selection for NVIDIA Granary phrases
|
408 |
language_selector = gr.Dropdown(
|
|
|
457 |
markdowns = []
|
458 |
recordings = []
|
459 |
for idx in range(max_components):
|
460 |
+
visible = False # Initially hidden - will be revealed when language is selected
|
461 |
phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
|
462 |
md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
|
463 |
markdowns.append(md)
|
|
|
470 |
phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
|
471 |
|
472 |
# Add more rows button
|
473 |
+
add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary", visible=False)
|
474 |
|
475 |
def add_more_rows(current_visible, current_phrases):
|
476 |
"""Add 10 more rows by making them visible"""
|
|
|
492 |
return [new_visible] + markdown_updates + audio_updates
|
493 |
|
494 |
def change_language(language):
|
495 |
+
"""Change the language and reload phrases from multilingual datasets, reveal interface"""
|
496 |
new_phrases = load_multilingual_phrases(language, max_phrases=None)
|
497 |
# Reset visible rows to 10
|
498 |
visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
|
|
|
512 |
markdown_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
|
513 |
audio_updates.append(gr.update(visible=False))
|
514 |
|
515 |
+
# Reveal all interface elements when language is selected
|
516 |
+
reveal_updates = [
|
517 |
+
gr.update(visible=True), # add_rows_btn
|
518 |
+
gr.update(visible=True), # record_dataset_btn
|
519 |
+
gr.update(visible=True), # dataset_status
|
520 |
+
gr.update(visible=True), # advanced_accordion
|
521 |
+
gr.update(visible=True), # save_rec_btn
|
522 |
+
gr.update(visible=True), # start_btn
|
523 |
+
gr.update(visible=True), # logs_box
|
524 |
+
]
|
525 |
|
526 |
+
# Return: [phrases_state, visible_state] + markdown_updates + audio_updates + reveal_updates
|
527 |
+
return [new_phrases, visible_count] + markdown_updates + audio_updates + reveal_updates
|
|
|
|
|
|
|
|
|
528 |
|
529 |
add_rows_btn.click(
|
530 |
add_more_rows,
|
|
|
533 |
)
|
534 |
|
535 |
# Recording dataset creation button
|
536 |
+
record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary", visible=False)
|
537 |
|
538 |
def create_recording_dataset(*recordings_and_state):
|
539 |
"""Create dataset from visible recordings and phrases"""
|
|
|
574 |
return f"❌ Error creating dataset: {str(e)}"
|
575 |
|
576 |
# Status display for dataset creation
|
577 |
+
dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=False)
|
578 |
|
579 |
record_dataset_btn.click(
|
580 |
create_recording_dataset,
|
|
|
583 |
)
|
584 |
|
585 |
# Advanced options accordion
|
586 |
+
with gr.Accordion("Advanced options", open=False, visible=False) as advanced_accordion:
|
587 |
base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
|
588 |
use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
|
589 |
with gr.Row():
|
|
|
613 |
lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
|
614 |
return _save_uploaded_dataset(files or [], lines)
|
615 |
|
616 |
+
# Removed - no longer needed since jsonl_out was removed
|
617 |
+
# save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [])
|
618 |
|
619 |
# Save recordings button
|
620 |
+
save_rec_btn = gr.Button("Save recordings as dataset", visible=False)
|
621 |
|
622 |
def _collect_preloaded_recs(*recs_and_texts):
|
623 |
import soundfile as sf
|
|
|
644 |
_write_jsonl(rows, jsonl_path)
|
645 |
return str(jsonl_path)
|
646 |
|
647 |
+
save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_path_state])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
648 |
|
649 |
+
# Removed multilingual dataset sample section - phrases are now loaded automatically when language is selected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
650 |
|
651 |
+
start_btn = gr.Button("Start Fine-tuning", visible=False)
|
652 |
+
logs_box = gr.Textbox(label="Logs", lines=20, visible=False)
|
653 |
|
654 |
start_btn.click(
|
655 |
start_voxtral_training,
|
656 |
inputs=[
|
657 |
+
use_lora, base_model, repo_short, jsonl_path_state, train_count, eval_count,
|
658 |
batch_size, grad_accum, learning_rate, epochs,
|
659 |
lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
|
660 |
push_to_hub, deploy_demo,
|
|
|
662 |
outputs=[logs_box],
|
663 |
)
|
664 |
|
665 |
+
# Connect language change to phrase reloading and interface reveal (placed after all components are defined)
|
666 |
+
language_selector.change(
|
667 |
+
change_language,
|
668 |
+
inputs=[language_selector],
|
669 |
+
outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components + [
|
670 |
+
add_rows_btn, record_dataset_btn, dataset_status, advanced_accordion,
|
671 |
+
save_rec_btn, start_btn, logs_box
|
672 |
+
]
|
673 |
+
)
|
674 |
+
|
675 |
|
676 |
if __name__ == "__main__":
|
677 |
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|