#!/usr/bin/env python3
"""
Voxtral ASR Fine-tuning Interface

Features:
- Collect a personal voice dataset (upload WAV/FLAC + transcripts or record mic audio)
- Build a JSONL dataset ({audio_path, text}) at 16kHz
- Fine-tune Voxtral (LoRA or full) with streamed logs
- Push model to Hugging Face Hub
- Deploy a Voxtral ASR demo Space

Env tokens (optional):
- HF_WRITE_TOKEN or HF_TOKEN: write access token
- HF_READ_TOKEN: optional read token
- HF_USERNAME: fallback username if not derivable from token
"""

from __future__ import annotations

import os
import json
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, Generator, Optional, Tuple

import gradio as gr


PROJECT_ROOT = Path(__file__).resolve().parent

# Command-line flags whose following value must never be echoed into the UI log.
_SECRET_FLAGS = frozenset({"--hf-token", "--token"})


def get_python() -> str:
    """Return the interpreter used to launch helper scripts (the current one)."""
    import sys

    return sys.executable or "python"


def get_username_from_token(token: str) -> Optional[str]:
    """Best-effort lookup of the Hugging Face username for *token*.

    Returns None if ``huggingface_hub`` is missing, the call fails, or the
    ``whoami`` response has no recognizable name.
    """
    try:
        from huggingface_hub import HfApi  # type: ignore

        api = HfApi(token=token)
        info = api.whoami()
        if isinstance(info, dict):
            return info.get("name") or info.get("username")
        if isinstance(info, str):
            return info
    except Exception:
        return None
    return None


def _redact_args(args: list[str]) -> list[str]:
    """Return a copy of *args* with the values of ``_SECRET_FLAGS`` masked.

    Handles both ``--flag VALUE`` and ``--flag=VALUE`` forms.
    """
    redacted: list[str] = []
    mask_next = False
    for arg in args:
        if mask_next:
            redacted.append("***")
            mask_next = False
            continue
        flag, has_value, _ = arg.partition("=")
        if flag in _SECRET_FLAGS:
            if has_value:
                redacted.append(f"{flag}=***")
            else:
                redacted.append(arg)
                mask_next = True
        else:
            redacted.append(arg)
    return redacted


def run_command_stream(args: list[str], env: Dict[str, str], cwd: Optional[Path] = None) -> Generator[str, None, int]:
    """Run ``<python> *args`` and yield its combined stdout/stderr line by line.

    The first yielded line echoes the command, with secret flag values masked
    so tokens do not end up in the UI. The last yielded line is
    ``[exit_code=N]``. The generator's return value (reachable via
    ``yield from``) is the process exit code.

    If the consumer stops iterating early, the child is killed. Otherwise it
    would block forever on a stdout pipe that nobody reads.
    """
    import subprocess
    import shlex

    cmd = [get_python()] + list(args)
    yield f"$ {' '.join(shlex.quote(a) for a in _redact_args(cmd))}"
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env,
        cwd=str(cwd or PROJECT_ROOT),
        bufsize=1,
    )
    try:
        assert process.stdout is not None
        for line in iter(process.stdout.readline, ""):
            yield line.rstrip()
        code = process.wait()
    finally:
        if process.poll() is None:
            process.kill()
            process.wait()
        if process.stdout is not None:
            process.stdout.close()
    yield f"[exit_code={code}]"
    return code


def _nvml_text(value: Any) -> str:
    """Decode an NVML string result (``bytes`` in older pynvml, ``str`` in newer)."""
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="ignore")
    return str(value)


def detect_nvidia_driver() -> Tuple[bool, str]:
    """Detect NVIDIA driver/GPU presence with multiple strategies.

    Tries, in order, torch CUDA, NVML (pynvml), then ``nvidia-smi -L``.
    Returns (available, human_message).
    """
    # 1) Try torch CUDA
    try:
        import torch  # type: ignore

        if torch.cuda.is_available():
            try:
                num = torch.cuda.device_count()
                names = [torch.cuda.get_device_name(i) for i in range(num)]
                return True, f"NVIDIA GPU detected: {', '.join(names)}"
            except Exception:
                return True, "NVIDIA GPU detected (torch.cuda available)"
    except Exception:
        pass

    # 2) Try NVML via pynvml
    try:
        import pynvml  # type: ignore

        pynvml.nvmlInit()
    except Exception:
        pass  # pynvml not installed, or the NVML library/driver is unavailable
    else:
        try:
            cnt = pynvml.nvmlDeviceGetCount()
            names = [
                _nvml_text(pynvml.nvmlDeviceGetName(pynvml.nvmlDeviceGetHandleByIndex(i)))
                for i in range(cnt)
            ]
            drv = _nvml_text(pynvml.nvmlSystemGetDriverVersion())
            if cnt > 0:
                return True, f"NVIDIA driver {drv}; GPUs: {', '.join(names)}"
        except Exception:
            pass
        finally:
            # Always release NVML, even when a query above failed.
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass

    # 3) Try nvidia-smi
    try:
        import subprocess

        res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=3)
        if res.returncode == 0 and res.stdout.strip():
            return True, res.stdout.strip().splitlines()[0]
    except Exception:
        pass

    return False, "No NVIDIA driver/GPU detected"


def duplicate_space_hint() -> str:
    """Return a hint on getting GPU hardware, with a duplicate link when on Spaces."""
    space_id = os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")
    if space_id:
        space_url = f"https://huggingface.co/spaces/{space_id}"
        dup_url = f"{space_url}?duplicate=true"
        return (
            f"ℹ️ No NVIDIA driver detected. If you're on Hugging Face Spaces, "
            f"please duplicate this Space to GPU hardware: [Duplicate this Space]({dup_url})."
        )
    return (
        "ℹ️ No NVIDIA driver detected. To enable training, run on a machine with an NVIDIA GPU/driver "
        "or duplicate this Space on Hugging Face with GPU hardware."
    )


def _write_jsonl(rows: list[dict], path: Path) -> Path:
    """Write *rows* as UTF-8 JSON Lines to *path*, creating parent dirs; return *path*."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    return path


def _save_uploaded_dataset(files: list, transcripts: list[str]) -> str:
    """Pair uploaded audio file paths with transcripts and write ``data.jsonl``.

    Files without a matching transcript (beyond ``len(transcripts)``) are dropped.
    Returns the JSONL path as a string.
    """
    dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = [
        {"audio_path": str(fpath), "text": text or ""}
        for fpath, text in zip(files or [], transcripts)
    ]
    jsonl_path = dataset_dir / "data.jsonl"
    _write_jsonl(rows, jsonl_path)
    return str(jsonl_path)


def _save_recordings(recordings: list[tuple[int, list]], transcripts: list[str]) -> str:
    """Write ``(sample_rate, samples)`` recordings to WAVs and pair them with transcripts.

    ``None`` recordings are skipped; recordings beyond ``len(transcripts)`` are
    dropped. Returns the JSONL path as a string.
    """
    import soundfile as sf

    dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
    wav_dir = dataset_dir / "wavs"
    wav_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    for i, (rec, text) in enumerate(zip(recordings or [], transcripts)):
        if rec is None:
            continue
        sr, data = rec
        out_path = wav_dir / f"rec_{i:04d}.wav"
        sf.write(str(out_path), data, sr)
        rows.append({"audio_path": str(out_path), "text": text or ""})
    jsonl_path = dataset_dir / "data.jsonl"
    _write_jsonl(rows, jsonl_path)
    return str(jsonl_path)


def start_voxtral_training(
    use_lora: bool,
    base_model: str,
    repo_short: str,
    jsonl_path: str,
    train_count: int,
    eval_count: int,
    batch_size: int,
    grad_accum: int,
    learning_rate: float,
    epochs: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    freeze_audio_tower: bool,
    push_to_hub: bool,
    deploy_demo: bool,
) -> Generator[str, None, None]:
    """Train, then optionally push and deploy, streaming the combined log.

    Each yielded value is the *entire* log so far, because Gradio replaces
    the output Textbox with every yielded value.

    Steps, each skipped once an earlier step exits non-zero:
    1. ``scripts/train_lora.py`` (LoRA) or ``scripts/train.py`` (full).
    2. ``scripts/push_to_huggingface.py`` when *push_to_hub*.
    3. ``scripts/deploy_demo_space.py`` when *deploy_demo* and a username is known.
    """
    env = os.environ.copy()
    write_token = env.get("HF_WRITE_TOKEN") or env.get("HF_TOKEN")
    username = get_username_from_token(write_token or "") or env.get("HF_USERNAME") or ""
    output_dir = PROJECT_ROOT / "outputs" / repo_short

    log_lines: list[str] = []

    def _log(message: str) -> str:
        """Append *message* to the log and return the full log text."""
        log_lines.append(message)
        return "\n".join(log_lines)

    def _run(step_args: list[str]) -> Generator[str, None, int]:
        """Stream one script's output into the log; return its exit code."""
        stream = run_command_stream(step_args, env)
        try:
            while True:
                try:
                    line = next(stream)
                except StopIteration as stop:
                    return 1 if stop.value is None else stop.value
                yield _log(line)
        finally:
            # Ensures the child process is killed if we are closed mid-stream.
            stream.close()

    # 1) Train
    script = PROJECT_ROOT / ("scripts/train_lora.py" if use_lora else "scripts/train.py")
    args = [str(script)]
    if jsonl_path:
        args += ["--dataset-jsonl", jsonl_path]
    # Integer hyper-parameters are coerced so a float from the UI (e.g. 100.0)
    # does not reach an int-typed argparse option as "100.0".
    args += [
        "--model-checkpoint", base_model,
        "--train-count", str(int(train_count)),
        "--eval-count", str(int(eval_count)),
        "--batch-size", str(int(batch_size)),
        "--grad-accum", str(int(grad_accum)),
        "--learning-rate", str(learning_rate),
        "--epochs", str(epochs),
        "--output-dir", str(output_dir),
        "--save-steps", "50",
    ]
    if use_lora:
        args += [
            "--lora-r", str(int(lora_r)),
            "--lora-alpha", str(int(lora_alpha)),
            "--lora-dropout", str(lora_dropout),
        ]
        if freeze_audio_tower:
            args += ["--freeze-audio-tower"]
    code = yield from _run(args)
    if code != 0:
        yield _log(f"[training failed with exit_code={code}; skipping push and deploy]")
        return

    # 2) Push to Hub
    if push_to_hub:
        repo_name = f"{username}/{repo_short}" if username else repo_short
        push_args = [
            str(PROJECT_ROOT / "scripts/push_to_huggingface.py"),
            str(output_dir),
            repo_name,
        ]
        code = yield from _run(push_args)
        if code != 0:
            yield _log(f"[push failed with exit_code={code}; skipping demo deployment]")
            return

    # 3) Deploy demo Space
    if deploy_demo:
        if not username:
            yield _log(
                "[skipping demo deployment: Hugging Face username unknown; "
                "set HF_WRITE_TOKEN/HF_TOKEN or HF_USERNAME]"
            )
            return
        deploy_args = [
            str(PROJECT_ROOT / "scripts/deploy_demo_space.py"),
            "--hf-token", write_token or "",
            "--hf-username", username,
            "--model-id", f"{username}/{repo_short}",
            "--demo-type", "voxtral",
            "--space-name", f"{repo_short}-demo",
        ]
        yield from _run(deploy_args)


# Phrases used when VoxPopuli cannot be loaded or yields nothing usable.
_FALLBACK_PHRASES = (
    "The quick brown fox jumps over the lazy dog.",
    "Please say your full name.",
    "Today is a good day to learn something new.",
    "Artificial intelligence helps with many tasks.",
    "I enjoy reading books and listening to music.",
)


def load_voxpopuli_phrases(language="en", max_phrases=None, split="train"):
    """Load phrases from VoxPopuli dataset.

    Args:
        language: Language code (e.g., 'en', 'de', 'fr', etc.)
        max_phrases: Maximum number of phrases to load (None for all available)
        split: Dataset split to use ('train', 'validation', 'test')

    Returns:
        List of normalized text phrases (longer than 10 characters), shuffled.
        Falls back to a small built-in list if loading fails or finds nothing.
    """
    try:
        from datasets import load_dataset
        import random

        # Load the specified language dataset
        ds = load_dataset("facebook/voxpopuli", language, split=split)

        # Read only the text column so audio is not decoded for every row.
        # ``normalized_text`` entries may be None; treat them as empty.
        phrases = []
        for raw in ds["normalized_text"]:
            text = (raw or "").strip()
            if text and len(text) > 10:  # Filter out very short phrases
                phrases.append(text)

        if not phrases:
            return list(_FALLBACK_PHRASES)

        # Shuffle and limit if specified
        if max_phrases:
            phrases = random.sample(phrases, min(max_phrases, len(phrases)))
        else:
            # If no limit, shuffle the entire list
            random.shuffle(phrases)

        return phrases

    except Exception as e:
        print(f"Error loading VoxPopuli phrases: {e}")
        # Fallback to some basic phrases if loading fails
        return list(_FALLBACK_PHRASES)


# Initialize phrases dynamically
VOXPOPULI_LANGUAGE = "en"  # Default to English
ALL_PHRASES = load_voxpopuli_phrases(VOXPOPULI_LANGUAGE, max_phrases=None)

# Each grid row is a Markdown + an Audio component. A full VoxPopuli split
# holds far too many phrases to render, so the grid is capped.
MAX_GRID_ROWS = 100
INITIAL_VISIBLE_ROWS = 10
ROWS_PER_CLICK = 10

with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
    has_gpu, gpu_msg = detect_nvidia_driver()
    if has_gpu:
        gr.HTML(
            f"""

✅ NVIDIA GPU ready — {gpu_msg}

Set HF_WRITE_TOKEN/HF_TOKEN in environment to enable Hub push.

"""
        )
    else:
        hint_md = duplicate_space_hint()
        gr.HTML(
            f"""

⚠️ No NVIDIA GPU/driver detected — training requires a GPU runtime

{hint_md}

"""
        )

    gr.Markdown(
        """
# 🎙️ Voxtral ASR Fine-tuning
Read the phrases below and record them. Then start fine-tuning.
"""
    )

    jsonl_out = gr.Textbox(label="Dataset JSONL path", interactive=False, visible=True)

    # Language selection for VoxPopuli phrases
    voxpopuli_lang = gr.Dropdown(
        choices=["en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"],
        value="en",
        label="VoxPopuli Language",
        info="Select language for phrases from VoxPopuli dataset",
    )

    # Recording grid with dynamic text readouts
    phrase_texts_state = gr.State(ALL_PHRASES)
    visible_rows_state = gr.State(INITIAL_VISIBLE_ROWS)
    grid_rows = min(len(ALL_PHRASES), MAX_GRID_ROWS)

    def create_recording_grid(phrases, visible_count=INITIAL_VISIBLE_ROWS):
        """Create one (Markdown prompt, microphone Audio) pair per phrase."""
        markdowns = []
        recordings = []
        for idx, phrase in enumerate(phrases):
            visible = idx < visible_count
            markdowns.append(gr.Markdown(f"**{idx+1}. {phrase}**", visible=visible))
            recordings.append(
                gr.Audio(sources="microphone", type="numpy", label=f"Recording {idx+1}", visible=visible)
            )
        return markdowns, recordings

    # Initial grid creation
    with gr.Column():
        phrase_markdowns, rec_components = create_recording_grid(ALL_PHRASES[:grid_rows], INITIAL_VISIBLE_ROWS)

    # Add more rows button
    add_rows_btn = gr.Button("➕ Add 10 More Rows", variant="secondary")

    def _grid_updates(phrases, visible_count, clear_recordings=False):
        """Build updates for every grid component: all markdowns, then all recordings.

        Row i is visible only if it has a phrase and ``i < visible_count``.
        Text is refreshed for every row that has a phrase, so rows revealed
        later never show stale text. Recordings are cleared when the phrases
        they were recorded against have been replaced.
        """
        md_updates = []
        rec_updates = []
        for i in range(len(phrase_markdowns)):
            has_phrase = i < len(phrases)
            visible = has_phrase and i < visible_count
            if has_phrase:
                md_updates.append(gr.update(value=f"**{i+1}. {phrases[i]}**", visible=visible))
            else:
                md_updates.append(gr.update(visible=False))
            if clear_recordings:
                rec_updates.append(gr.update(value=None, visible=visible))
            else:
                rec_updates.append(gr.update(visible=visible))
        return md_updates + rec_updates

    def add_more_rows(current_visible, current_phrases):
        """Reveal ROWS_PER_CLICK more rows, bounded by phrases and grid size."""
        limit = min(len(current_phrases), len(phrase_markdowns))
        new_visible = min(int(current_visible) + ROWS_PER_CLICK, limit)
        return [new_visible] + _grid_updates(current_phrases, new_visible)

    def change_language(language):
        """Reload phrases from VoxPopuli for *language* and reset the grid."""
        new_phrases = load_voxpopuli_phrases(language, max_phrases=None)
        visible_count = min(INITIAL_VISIBLE_ROWS, len(new_phrases))
        # Phrases beyond the grid size cannot be shown without rebuilding the UI.
        return [new_phrases, visible_count] + _grid_updates(new_phrases, visible_count, clear_recordings=True)

    # Connect language change to phrase reloading
    voxpopuli_lang.change(
        change_language,
        inputs=[voxpopuli_lang],
        outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components,
    )

    add_rows_btn.click(
        add_more_rows,
        inputs=[visible_rows_state, phrase_texts_state],
        outputs=[visible_rows_state] + phrase_markdowns + rec_components,
    )

    # Advanced options accordion
    with gr.Accordion("Advanced options", open=False):
        base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
        use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
        with gr.Row():
            batch_size = gr.Number(value=2, precision=0, label="Batch size")
            grad_accum = gr.Number(value=4, precision=0, label="Grad accum")
        with gr.Row():
            learning_rate = gr.Number(value=5e-5, precision=6, label="Learning rate")
            epochs = gr.Number(value=3.0, precision=2, label="Epochs")
        with gr.Accordion("LoRA settings", open=False):
            lora_r = gr.Number(value=8, precision=0, label="LoRA r")
            lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
            lora_dropout = gr.Number(value=0.0, precision=3, label="LoRA dropout")
            freeze_audio_tower = gr.Checkbox(value=True, label="Freeze audio tower")
        with gr.Row():
            train_count = gr.Number(value=100, precision=0, label="Train samples")
            eval_count = gr.Number(value=50, precision=0, label="Eval samples")
        repo_short = gr.Textbox(
            value=f"voxtral-finetune-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            label="Model repo (short)",
        )
        push_to_hub = gr.Checkbox(value=True, label="Push to HF Hub after training")
        deploy_demo = gr.Checkbox(value=True, label="Deploy demo Space after push")

        gr.Markdown("### Upload audio + transcripts (optional)")
        upload_audio = gr.File(file_count="multiple", type="filepath", label="Upload WAV/FLAC files (optional)")
        transcripts_box = gr.Textbox(lines=6, label="Transcripts (one per line, aligned with files)")
        save_upload_btn = gr.Button("Save uploaded dataset")

        def _collect_upload(files, txt):
            """Split non-empty transcript lines and save them with the uploaded files."""
            lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
            return _save_uploaded_dataset(files or [], lines)

        save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [jsonl_out])

        # Save recordings button
        save_rec_btn = gr.Button("Save recordings as dataset")

        def _collect_preloaded_recs(*recs_and_texts):
            """Save grid recordings as WAVs labelled with the on-screen phrases.

            Positional inputs are every recording component followed by the
            phrase list state. A label comes from that list, falling back to
            ALL_PHRASES when the list is missing or too short.
            """
            import soundfile as sf

            dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
            wav_dir = dataset_dir / "wavs"
            wav_dir.mkdir(parents=True, exist_ok=True)
            rows: list[dict] = []
            if not recs_and_texts:
                jsonl_path = dataset_dir / "data.jsonl"
                _write_jsonl(rows, jsonl_path)
                return str(jsonl_path)
            texts = recs_and_texts[-1]
            recs = recs_and_texts[:-1]
            for i, rec in enumerate(recs):
                if rec is None:
                    continue
                sr, data = rec
                out_path = wav_dir / f"rec_{i:04d}.wav"
                sf.write(str(out_path), data, sr)
                if isinstance(texts, list) and i < len(texts):
                    label_text = texts[i]
                elif i < len(ALL_PHRASES):
                    label_text = ALL_PHRASES[i]
                else:
                    label_text = ""
                rows.append({"audio_path": str(out_path), "text": label_text})
            jsonl_path = dataset_dir / "data.jsonl"
            _write_jsonl(rows, jsonl_path)
            return str(jsonl_path)

        save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_out])

        # Quick sample from VoxPopuli (few random rows)
        with gr.Row():
            vp_lang = gr.Dropdown(
                choices=["en", "de", "fr", "es", "it", "pl", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"],
                value="en",
                label="VoxPopuli language",
            )
            vp_samples = gr.Number(value=20, precision=0, label="Num samples")
            vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
            vp_btn = gr.Button("Use VoxPopuli sample")

        def _collect_voxpopuli(lang_code: str, num_samples: int, split: str):
            """Build a dataset from random VoxPopuli rows and show their texts as prompts.

            Each sampled clip is decoded at 16 kHz and written to ``wavs/``, so
            the JSONL always points at readable local 16 kHz files. The
            dataset's own ``path`` is used only when no decoded array is present.
            """
            import sys

            # Workaround for dill on Python 3.13 expecting __main__ during import
            if "__main__" not in sys.modules:
                sys.modules["__main__"] = sys.modules[__name__]

            from datasets import load_dataset, Audio  # type: ignore
            import random
            import soundfile as sf

            ds = load_dataset("facebook/voxpopuli", lang_code, split=split)
            ds = ds.cast_column("audio", Audio(sampling_rate=16000))

            dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
            wav_dir = dataset_dir / "wavs"
            wav_dir.mkdir(parents=True, exist_ok=True)

            rows: list[dict] = []
            texts: list[str] = []
            total = len(ds)
            # Requesting at least one sample, but never more than exist (possibly zero).
            k = min(max(1, int(num_samples or 1)), total)
            if k > 0:
                ds_sel = ds.shuffle(seed=random.randint(1, 10_000)).select(range(k))
                for i, ex in enumerate(ds_sel):
                    audio = ex.get("audio") or {}
                    text = ex.get("normalized_text") or ex.get("raw_text") or ""
                    array = audio.get("array")
                    if array is not None:
                        out_path = wav_dir / f"voxpopuli_{i:04d}.wav"
                        sf.write(str(out_path), array, audio.get("sampling_rate") or 16000)
                        audio_path = str(out_path)
                    else:
                        audio_path = audio.get("path")
                    if not audio_path:
                        continue
                    rows.append({"audio_path": audio_path, "text": text})
                    texts.append(str(text))

            jsonl_path = dataset_dir / "data.jsonl"
            _write_jsonl(rows, jsonl_path)

            # Show every sampled text as a prompt (up to the grid size) and reset recordings.
            visible_count = min(len(texts), len(phrase_markdowns))
            return [str(jsonl_path), texts, visible_count] + _grid_updates(
                texts, visible_count, clear_recordings=True
            )

        vp_btn.click(
            _collect_voxpopuli,
            [vp_lang, vp_samples, vp_split],
            [jsonl_out, phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components,
        )

    start_btn = gr.Button("Start Fine-tuning")
    logs_box = gr.Textbox(label="Logs", lines=20)

    start_btn.click(
        start_voxtral_training,
        inputs=[
            use_lora, base_model, repo_short, jsonl_out, train_count, eval_count,
            batch_size, grad_accum, learning_rate, epochs,
            lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
            push_to_hub, deploy_demo,
        ],
        outputs=[logs_box],
    )


if __name__ == "__main__":
    server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
    server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
    demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True, ssr_mode=False)