Spaces:
Running
Running
Joseph Pollack
committed on
simplifies the interface
Browse files
- __pycache__/interface.cpython-313.pyc +0 -0
- interface.py +38 -133
__pycache__/interface.cpython-313.pyc
ADDED
|
Binary file (30.3 kB). View file
|
|
|
interface.py
CHANGED
|
@@ -401,7 +401,8 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 401 |
Read the phrases below and record them. Then start fine-tuning.
|
| 402 |
""")
|
| 403 |
|
| 404 |
-
|
|
|
|
| 405 |
|
| 406 |
# Language selection for NVIDIA Granary phrases
|
| 407 |
language_selector = gr.Dropdown(
|
|
@@ -456,7 +457,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 456 |
markdowns = []
|
| 457 |
recordings = []
|
| 458 |
for idx in range(max_components):
|
| 459 |
-
visible =
|
| 460 |
phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
|
| 461 |
md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
|
| 462 |
markdowns.append(md)
|
|
@@ -469,7 +470,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 469 |
phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
|
| 470 |
|
| 471 |
# Add more rows button
|
| 472 |
-
add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary")
|
| 473 |
|
| 474 |
def add_more_rows(current_visible, current_phrases):
|
| 475 |
"""Add 10 more rows by making them visible"""
|
|
@@ -491,7 +492,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 491 |
return [new_visible] + markdown_updates + audio_updates
|
| 492 |
|
| 493 |
def change_language(language):
|
| 494 |
-
"""Change the language and reload phrases from multilingual datasets"""
|
| 495 |
new_phrases = load_multilingual_phrases(language, max_phrases=None)
|
| 496 |
# Reset visible rows to 10
|
| 497 |
visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
|
|
@@ -511,15 +512,19 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 511 |
markdown_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
|
| 512 |
audio_updates.append(gr.update(visible=False))
|
| 513 |
|
| 514 |
-
#
|
| 515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
change_language,
|
| 520 |
-
inputs=[language_selector],
|
| 521 |
-
outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components
|
| 522 |
-
)
|
| 523 |
|
| 524 |
add_rows_btn.click(
|
| 525 |
add_more_rows,
|
|
@@ -528,7 +533,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 528 |
)
|
| 529 |
|
| 530 |
# Recording dataset creation button
|
| 531 |
-
record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary")
|
| 532 |
|
| 533 |
def create_recording_dataset(*recordings_and_state):
|
| 534 |
"""Create dataset from visible recordings and phrases"""
|
|
@@ -569,7 +574,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 569 |
return f"❌ Error creating dataset: {str(e)}"
|
| 570 |
|
| 571 |
# Status display for dataset creation
|
| 572 |
-
dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=
|
| 573 |
|
| 574 |
record_dataset_btn.click(
|
| 575 |
create_recording_dataset,
|
|
@@ -578,7 +583,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 578 |
)
|
| 579 |
|
| 580 |
# Advanced options accordion
|
| 581 |
-
with gr.Accordion("Advanced options", open=False):
|
| 582 |
base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
|
| 583 |
use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
|
| 584 |
with gr.Row():
|
|
@@ -608,10 +613,11 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 608 |
lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
|
| 609 |
return _save_uploaded_dataset(files or [], lines)
|
| 610 |
|
| 611 |
-
|
|
|
|
| 612 |
|
| 613 |
# Save recordings button
|
| 614 |
-
save_rec_btn = gr.Button("Save recordings as dataset")
|
| 615 |
|
| 616 |
def _collect_preloaded_recs(*recs_and_texts):
|
| 617 |
import soundfile as sf
|
|
@@ -638,128 +644,17 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 638 |
_write_jsonl(rows, jsonl_path)
|
| 639 |
return str(jsonl_path)
|
| 640 |
|
| 641 |
-
save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [
|
| 642 |
-
|
| 643 |
-
# Quick sample from multilingual datasets (Common Voice, etc.)
|
| 644 |
-
with gr.Row():
|
| 645 |
-
vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "pt", "nl", "ru", "ar", "zh", "ja", "ko", "da", "sv", "fi", "et", "cs", "hr", "bg", "uk", "ro", "hu", "el"], value="en", label="Sample Language")
|
| 646 |
-
vp_samples = gr.Number(value=20, precision=0, label="Num samples")
|
| 647 |
-
vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
|
| 648 |
-
vp_btn = gr.Button("Use Multilingual Dataset Sample")
|
| 649 |
-
|
| 650 |
-
def _collect_multilingual_sample(lang_code: str, num_samples: int, split: str):
|
| 651 |
-
"""Collect sample audio and text from NVIDIA Granary dataset"""
|
| 652 |
-
from datasets import load_dataset, Audio
|
| 653 |
-
import random
|
| 654 |
-
|
| 655 |
-
# Map language code to Granary format
|
| 656 |
-
granary_lang_map = {
|
| 657 |
-
"en": "en", "de": "de", "fr": "fr", "es": "es", "it": "it",
|
| 658 |
-
"pl": "pl", "pt": "pt", "nl": "nl", "ru": "ru", "ar": "ar",
|
| 659 |
-
"zh": "zh", "ja": "ja", "ko": "ko", "da": "da", "sv": "sv",
|
| 660 |
-
"no": "no", "fi": "fi", "et": "et", "lv": "lv", "lt": "lt",
|
| 661 |
-
"sl": "sl", "sk": "sk", "cs": "cs", "hr": "hr", "bg": "bg",
|
| 662 |
-
"uk": "uk", "ro": "ro", "hu": "hu", "el": "el", "mt": "mt"
|
| 663 |
-
}
|
| 664 |
-
|
| 665 |
-
granary_lang = granary_lang_map.get(lang_code, "en")
|
| 666 |
-
|
| 667 |
-
try:
|
| 668 |
-
print(f"Collecting {num_samples} samples from NVIDIA Granary dataset for language: {lang_code}")
|
| 669 |
-
|
| 670 |
-
# Load Granary dataset with ASR split
|
| 671 |
-
ds = load_dataset("nvidia/Granary", granary_lang, split="asr", streaming=True)
|
| 672 |
-
|
| 673 |
-
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
| 674 |
-
rows = []
|
| 675 |
-
texts = []
|
| 676 |
-
count = 0
|
| 677 |
-
|
| 678 |
-
# Sample from the dataset
|
| 679 |
-
for example in ds:
|
| 680 |
-
if count >= num_samples:
|
| 681 |
-
break
|
| 682 |
-
|
| 683 |
-
text = example.get("text", "").strip()
|
| 684 |
-
audio_path = example.get("audio_filepath", "")
|
| 685 |
-
|
| 686 |
-
# Filter for quality samples
|
| 687 |
-
if (text and
|
| 688 |
-
len(text) > 10 and
|
| 689 |
-
len(text) < 200 and
|
| 690 |
-
audio_path): # Must have audio file
|
| 691 |
-
|
| 692 |
-
rows.append({
|
| 693 |
-
"audio_path": audio_path,
|
| 694 |
-
"text": text
|
| 695 |
-
})
|
| 696 |
-
texts.append(text)
|
| 697 |
-
count += 1
|
| 698 |
-
|
| 699 |
-
if rows:
|
| 700 |
-
jsonl_path = dataset_dir / "data.jsonl"
|
| 701 |
-
_write_jsonl(rows, jsonl_path)
|
| 702 |
-
|
| 703 |
-
print(f"Successfully collected {len(rows)} samples from Granary dataset")
|
| 704 |
-
|
| 705 |
-
# Build markdown and audio content updates for on-screen prompts
|
| 706 |
-
markdown_updates = []
|
| 707 |
-
audio_updates = []
|
| 708 |
-
for i in range(MAX_COMPONENTS):
|
| 709 |
-
t = texts[i] if i < len(texts) else ""
|
| 710 |
-
if i < len(texts):
|
| 711 |
-
markdown_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
|
| 712 |
-
audio_updates.append(gr.update(visible=True))
|
| 713 |
-
else:
|
| 714 |
-
markdown_updates.append(gr.update(visible=False))
|
| 715 |
-
audio_updates.append(gr.update(visible=False))
|
| 716 |
-
|
| 717 |
-
combined_updates = markdown_updates + audio_updates
|
| 718 |
-
|
| 719 |
-
return (str(jsonl_path), texts, *combined_updates)
|
| 720 |
-
|
| 721 |
-
except Exception as e:
|
| 722 |
-
print(f"Granary sample collection failed for {lang_code}: {e}")
|
| 723 |
-
|
| 724 |
-
# Fallback: generate text-only samples if Granary fails
|
| 725 |
-
print(f"Using fallback: generating text-only samples for {lang_code}")
|
| 726 |
-
phrases = load_multilingual_phrases(lang_code, max_phrases=num_samples)
|
| 727 |
-
texts = phrases[:num_samples]
|
| 728 |
-
|
| 729 |
-
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
| 730 |
-
rows = [{"audio_path": "", "text": text} for text in texts]
|
| 731 |
-
jsonl_path = dataset_dir / "data.jsonl"
|
| 732 |
-
_write_jsonl(rows, jsonl_path)
|
| 733 |
|
| 734 |
-
|
| 735 |
-
markdown_updates = []
|
| 736 |
-
audio_updates = []
|
| 737 |
-
for i in range(MAX_COMPONENTS):
|
| 738 |
-
t = texts[i] if i < len(texts) else ""
|
| 739 |
-
if i < len(texts):
|
| 740 |
-
markdown_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
|
| 741 |
-
audio_updates.append(gr.update(visible=True))
|
| 742 |
-
else:
|
| 743 |
-
markdown_updates.append(gr.update(visible=False))
|
| 744 |
-
audio_updates.append(gr.update(visible=False))
|
| 745 |
-
|
| 746 |
-
combined_updates = markdown_updates + audio_updates
|
| 747 |
-
|
| 748 |
-
return (str(jsonl_path), texts, *combined_updates)
|
| 749 |
-
|
| 750 |
-
vp_btn.click(
|
| 751 |
-
_collect_multilingual_sample,
|
| 752 |
-
[vp_lang, vp_samples, vp_split],
|
| 753 |
-
[jsonl_out, phrase_texts_state] + phrase_markdowns,
|
| 754 |
-
)
|
| 755 |
|
| 756 |
-
start_btn = gr.Button("Start Fine-tuning")
|
| 757 |
-
logs_box = gr.Textbox(label="Logs", lines=20)
|
| 758 |
|
| 759 |
start_btn.click(
|
| 760 |
start_voxtral_training,
|
| 761 |
inputs=[
|
| 762 |
-
use_lora, base_model, repo_short,
|
| 763 |
batch_size, grad_accum, learning_rate, epochs,
|
| 764 |
lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
|
| 765 |
push_to_hub, deploy_demo,
|
|
@@ -767,6 +662,16 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
|
| 767 |
outputs=[logs_box],
|
| 768 |
)
|
| 769 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
if __name__ == "__main__":
|
| 772 |
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|
|
|
|
| 401 |
Read the phrases below and record them. Then start fine-tuning.
|
| 402 |
""")
|
| 403 |
|
| 404 |
+
# Hidden state to track dataset JSONL path
|
| 405 |
+
jsonl_path_state = gr.State("")
|
| 406 |
|
| 407 |
# Language selection for NVIDIA Granary phrases
|
| 408 |
language_selector = gr.Dropdown(
|
|
|
|
| 457 |
markdowns = []
|
| 458 |
recordings = []
|
| 459 |
for idx in range(max_components):
|
| 460 |
+
visible = False # Initially hidden - will be revealed when language is selected
|
| 461 |
phrase_text = ALL_PHRASES[idx] if idx < len(ALL_PHRASES) else ""
|
| 462 |
md = gr.Markdown(f"**{idx+1}. {phrase_text}**", visible=visible)
|
| 463 |
markdowns.append(md)
|
|
|
|
| 470 |
phrase_markdowns, rec_components = create_recording_grid(MAX_COMPONENTS)
|
| 471 |
|
| 472 |
# Add more rows button
|
| 473 |
+
add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary", visible=False)
|
| 474 |
|
| 475 |
def add_more_rows(current_visible, current_phrases):
|
| 476 |
"""Add 10 more rows by making them visible"""
|
|
|
|
| 492 |
return [new_visible] + markdown_updates + audio_updates
|
| 493 |
|
| 494 |
def change_language(language):
|
| 495 |
+
"""Change the language and reload phrases from multilingual datasets, reveal interface"""
|
| 496 |
new_phrases = load_multilingual_phrases(language, max_phrases=None)
|
| 497 |
# Reset visible rows to 10
|
| 498 |
visible_count = min(10, len(new_phrases), MAX_COMPONENTS)
|
|
|
|
| 512 |
markdown_updates.append(gr.update(value=f"**{i+1}. **", visible=False))
|
| 513 |
audio_updates.append(gr.update(visible=False))
|
| 514 |
|
| 515 |
+
# Reveal all interface elements when language is selected
|
| 516 |
+
reveal_updates = [
|
| 517 |
+
gr.update(visible=True), # add_rows_btn
|
| 518 |
+
gr.update(visible=True), # record_dataset_btn
|
| 519 |
+
gr.update(visible=True), # dataset_status
|
| 520 |
+
gr.update(visible=True), # advanced_accordion
|
| 521 |
+
gr.update(visible=True), # save_rec_btn
|
| 522 |
+
gr.update(visible=True), # start_btn
|
| 523 |
+
gr.update(visible=True), # logs_box
|
| 524 |
+
]
|
| 525 |
|
| 526 |
+
# Return: [phrases_state, visible_state] + markdown_updates + audio_updates + reveal_updates
|
| 527 |
+
return [new_phrases, visible_count] + markdown_updates + audio_updates + reveal_updates
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
add_rows_btn.click(
|
| 530 |
add_more_rows,
|
|
|
|
| 533 |
)
|
| 534 |
|
| 535 |
# Recording dataset creation button
|
| 536 |
+
record_dataset_btn = gr.Button("🎙️ Create Dataset from Recordings", variant="primary", visible=False)
|
| 537 |
|
| 538 |
def create_recording_dataset(*recordings_and_state):
|
| 539 |
"""Create dataset from visible recordings and phrases"""
|
|
|
|
| 574 |
return f"❌ Error creating dataset: {str(e)}"
|
| 575 |
|
| 576 |
# Status display for dataset creation
|
| 577 |
+
dataset_status = gr.Textbox(label="Dataset Creation Status", interactive=False, visible=False)
|
| 578 |
|
| 579 |
record_dataset_btn.click(
|
| 580 |
create_recording_dataset,
|
|
|
|
| 583 |
)
|
| 584 |
|
| 585 |
# Advanced options accordion
|
| 586 |
+
with gr.Accordion("Advanced options", open=False, visible=False) as advanced_accordion:
|
| 587 |
base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
|
| 588 |
use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
|
| 589 |
with gr.Row():
|
|
|
|
| 613 |
lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
|
| 614 |
return _save_uploaded_dataset(files or [], lines)
|
| 615 |
|
| 616 |
+
# Removed - no longer needed since jsonl_out was removed
|
| 617 |
+
# save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [])
|
| 618 |
|
| 619 |
# Save recordings button
|
| 620 |
+
save_rec_btn = gr.Button("Save recordings as dataset", visible=False)
|
| 621 |
|
| 622 |
def _collect_preloaded_recs(*recs_and_texts):
|
| 623 |
import soundfile as sf
|
|
|
|
| 644 |
_write_jsonl(rows, jsonl_path)
|
| 645 |
return str(jsonl_path)
|
| 646 |
|
| 647 |
+
save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_path_state])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
+
# Removed multilingual dataset sample section - phrases are now loaded automatically when language is selected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
+
start_btn = gr.Button("Start Fine-tuning", visible=False)
|
| 652 |
+
logs_box = gr.Textbox(label="Logs", lines=20, visible=False)
|
| 653 |
|
| 654 |
start_btn.click(
|
| 655 |
start_voxtral_training,
|
| 656 |
inputs=[
|
| 657 |
+
use_lora, base_model, repo_short, jsonl_path_state, train_count, eval_count,
|
| 658 |
batch_size, grad_accum, learning_rate, epochs,
|
| 659 |
lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
|
| 660 |
push_to_hub, deploy_demo,
|
|
|
|
| 662 |
outputs=[logs_box],
|
| 663 |
)
|
| 664 |
|
| 665 |
+
# Connect language change to phrase reloading and interface reveal (placed after all components are defined)
|
| 666 |
+
language_selector.change(
|
| 667 |
+
change_language,
|
| 668 |
+
inputs=[language_selector],
|
| 669 |
+
outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components + [
|
| 670 |
+
add_rows_btn, record_dataset_btn, dataset_status, advanced_accordion,
|
| 671 |
+
save_rec_btn, start_btn, logs_box
|
| 672 |
+
]
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
|
| 676 |
if __name__ == "__main__":
|
| 677 |
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|