#!/usr/bin/env python3
"""
Voxtral ASR Fine-tuning Interface

Features:
- Collect a personal voice dataset (upload WAV/FLAC + transcripts or record mic audio)
- Build a JSONL dataset ({audio_path, text}) at 16kHz
- Fine-tune Voxtral (LoRA or full) with streamed logs
- Push model to Hugging Face Hub
- Deploy a Voxtral ASR demo Space

Env tokens (optional):
- HF_WRITE_TOKEN or HF_TOKEN: write access token
- HF_READ_TOKEN: optional read token
- HF_USERNAME: fallback username if not derivable from token
"""

from __future__ import annotations

import os
import json
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, Generator, Optional, Tuple

import gradio as gr

PROJECT_ROOT = Path(__file__).resolve().parent


def get_python() -> str:
    import sys
    return sys.executable or "python"


def get_username_from_token(token: str) -> Optional[str]:
    try:
        from huggingface_hub import HfApi  # type: ignore

        api = HfApi(token=token)
        info = api.whoami()
        if isinstance(info, dict):
            return info.get("name") or info.get("username")
        if isinstance(info, str):
            return info
    except Exception:
        return None
    return None


def run_command_stream(args: list[str], env: Dict[str, str], cwd: Optional[Path] = None) -> Generator[str, None, int]:
    import subprocess
    import shlex

    yield f"$ {' '.join(shlex.quote(a) for a in ([get_python()] + args))}"
    process = subprocess.Popen(
        [get_python()] + args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env,
        cwd=str(cwd or PROJECT_ROOT),
        bufsize=1,
    )
    assert process.stdout is not None
    for line in iter(process.stdout.readline, ""):
        yield line.rstrip()
    process.stdout.close()
    code = process.wait()
    yield f"[exit_code={code}]"
    return code


def detect_nvidia_driver() -> Tuple[bool, str]:
    """Detect NVIDIA driver/GPU presence with multiple strategies.

    Returns (available, human_message).
    """
    # 1) Try torch CUDA
    try:
        import torch  # type: ignore

        if torch.cuda.is_available():
            try:
                num = torch.cuda.device_count()
                names = [torch.cuda.get_device_name(i) for i in range(num)]
                return True, f"NVIDIA GPU detected: {', '.join(names)}"
            except Exception:
                return True, "NVIDIA GPU detected (torch.cuda available)"
    except Exception:
        pass

    # 2) Try NVML via pynvml
    try:
        import pynvml  # type: ignore

        try:
            pynvml.nvmlInit()
            cnt = pynvml.nvmlDeviceGetCount()
            names = []
            for i in range(cnt):
                h = pynvml.nvmlDeviceGetHandleByIndex(i)
                name = pynvml.nvmlDeviceGetName(h)
                # pynvml returns bytes in older releases and str in newer ones; handle both
                names.append(name.decode("utf-8", errors="ignore") if isinstance(name, bytes) else name)
            drv = pynvml.nvmlSystemGetDriverVersion()
            drv = drv.decode("utf-8", errors="ignore") if isinstance(drv, bytes) else drv
            pynvml.nvmlShutdown()
            if cnt > 0:
                return True, f"NVIDIA driver {drv}; GPUs: {', '.join(names)}"
        except Exception:
            pass
    except Exception:
        pass

    # 3) Try nvidia-smi
    try:
        import subprocess

        res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=3)
        if res.returncode == 0 and res.stdout.strip():
            return True, res.stdout.strip().splitlines()[0]
    except Exception:
        pass

    return False, "No NVIDIA driver/GPU detected"


def duplicate_space_hint() -> str:
    space_id = os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")
    if space_id:
        space_url = f"https://huggingface.co/spaces/{space_id}"
        dup_url = f"{space_url}?duplicate=true"
        return (
            f"ℹ️ No NVIDIA driver detected. If you're on Hugging Face Spaces, "
            f"please duplicate this Space to GPU hardware: [Duplicate this Space]({dup_url})."
        )
    return (
        "ℹ️ No NVIDIA driver detected. To enable training, run on a machine with an NVIDIA GPU/driver "
        "or duplicate this Space on Hugging Face with GPU hardware."
    )
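
# Usage sketch for run_command_stream (illustrative only; the script name and
# flags below are hypothetical). The generator yields one log line at a time,
# which is what lets the UI callbacks further down stream subprocess output live:
#
#   env = os.environ.copy()
#   for line in run_command_stream(["scripts/train.py", "--epochs", "1"], env):
#       print(line)  # e.g. append to a gr.Textbox for live logs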


def _write_jsonl(rows: list[dict], path: Path) -> Path:
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    return path


def _save_uploaded_dataset(files: list, transcripts: list[str]) -> str:
    dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    for i, fpath in enumerate(files or []):
        if i >= len(transcripts):
            break
        rows.append({"audio_path": fpath, "text": transcripts[i] or ""})
    jsonl_path = dataset_dir / "data.jsonl"
    _write_jsonl(rows, jsonl_path)
    return str(jsonl_path)


def _save_recordings(recordings: list[tuple[int, list]], transcripts: list[str]) -> str:
    import soundfile as sf

    dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
    wav_dir = dataset_dir / "wavs"
    wav_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    for i, rec in enumerate(recordings or []):
        if rec is None:
            continue
        if i >= len(transcripts):
            break
        sr, data = rec
        out_path = wav_dir / f"rec_{i:04d}.wav"
        sf.write(str(out_path), data, sr)
        rows.append({"audio_path": str(out_path), "text": transcripts[i] or ""})
    jsonl_path = dataset_dir / "data.jsonl"
    _write_jsonl(rows, jsonl_path)
    return str(jsonl_path)


def start_voxtral_training(
    use_lora: bool,
    base_model: str,
    repo_short: str,
    jsonl_path: str,
    train_count: int,
    eval_count: int,
    batch_size: int,
    grad_accum: int,
    learning_rate: float,
    epochs: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    freeze_audio_tower: bool,
    push_to_hub: bool,
    deploy_demo: bool,
) -> Generator[str, None, None]:
    env = os.environ.copy()
    write_token = env.get("HF_WRITE_TOKEN") or env.get("HF_TOKEN")
    read_token = env.get("HF_READ_TOKEN")
    username = get_username_from_token(write_token or "") or env.get("HF_USERNAME") or ""
    output_dir = PROJECT_ROOT / "outputs" / repo_short

    # 1) Train
    script = PROJECT_ROOT / ("scripts/train_lora.py" if use_lora else "scripts/train.py")
    args = [str(script)]
    if jsonl_path:
        args += ["--dataset-jsonl", jsonl_path]
    args += [
        "--model-checkpoint", base_model,
        "--train-count", str(train_count),
        "--eval-count", str(eval_count),
        "--batch-size", str(batch_size),
        "--grad-accum", str(grad_accum),
        "--learning-rate", str(learning_rate),
        "--epochs", str(epochs),
        "--output-dir", str(output_dir),
        "--save-steps", "50",
    ]
    if use_lora:
        args += [
            "--lora-r", str(lora_r),
            "--lora-alpha", str(lora_alpha),
            "--lora-dropout", str(lora_dropout),
        ]
    if freeze_audio_tower:
        args += ["--freeze-audio-tower"]
    for line in run_command_stream(args, env):
        yield line

    # 2) Push to Hub
    if push_to_hub:
        repo_name = f"{username}/{repo_short}" if username else repo_short
        push_args = [
            str(PROJECT_ROOT / "scripts/push_to_huggingface.py"),
            str(output_dir),
            repo_name,
        ]
        for line in run_command_stream(push_args, env):
            yield line

    # 3) Deploy demo Space
    if deploy_demo and username:
        deploy_args = [
            str(PROJECT_ROOT / "scripts/deploy_demo_space.py"),
            "--hf-token", write_token or "",
            "--hf-username", username,
            "--model-id", f"{username}/{repo_short}",
            "--demo-type", "voxtral",
            "--space-name", f"{repo_short}-demo",
        ]
        for line in run_command_stream(deploy_args, env):
            yield line
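
# Shape of the JSONL consumed via --dataset-jsonl: one JSON object per line,
# pairing an audio path (ideally 16 kHz WAV/FLAC, per the module docstring)
# with its transcript. The path below is illustrative:
#
#   {"audio_path": "datasets/voxtral_user/wavs/rec_0000.wav", "text": "Please say your full name."}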


def load_multilingual_phrases(language="en", max_phrases=None, split="train"):
    """Load phrases from the NVIDIA Granary dataset.

    Uses the high-quality Granary dataset, which contains speech recognition
    and translation data for 25 European languages.

    Args:
        language: Language code (e.g., 'en', 'de', 'fr', etc.)
        max_phrases: Maximum number of phrases to load (None for default 1000)
        split: Dataset split to use ('train', 'validation', 'test')

    Returns:
        List of transcription phrases from the Granary dataset.
    """
    from datasets import load_dataset
    import random

    # Default to 1000 phrases if not specified
    if max_phrases is None:
        max_phrases = 1000

    # Language code mapping for the Granary dataset
    # Granary supports these language codes directly
    granary_supported_langs = {
        "en": "en", "de": "de", "fr": "fr", "es": "es", "it": "it",
        "pl": "pl", "pt": "pt", "nl": "nl", "ru": "ru", "ar": "ar",
        "zh": "zh", "ja": "ja", "ko": "ko", "da": "da", "sv": "sv",
        "no": "no", "fi": "fi", "et": "et", "lv": "lv", "lt": "lt",
        "sl": "sl", "sk": "sk", "cs": "cs", "hr": "hr", "bg": "bg",
        "uk": "uk", "ro": "ro", "hu": "hu", "el": "el", "mt": "mt",
    }

    # Map the input language to a Granary configuration (default to English)
    granary_lang = granary_supported_langs.get(language, "en")

    try:
        print(f"Loading phrases from NVIDIA Granary dataset for language: {language}")
        # Load the Granary dataset's ASR (speech recognition) split.
        # Streaming avoids downloading the full dataset up front.
        ds = load_dataset("nvidia/Granary", granary_lang, split="asr", streaming=True)

        phrases = []
        count = 0
        seen_phrases = set()

        # Sample phrases from the dataset
        for example in ds:
            if count >= max_phrases:
                break
            # Extract the text transcription
            text = example.get("text", "").strip()
            # Filter for quality phrases
            if (
                text
                and len(text) > 10  # minimum length
                and len(text) < 200  # maximum length, to avoid very long utterances
                and text not in seen_phrases  # avoid duplicates
                and not text.isdigit()  # avoid pure numbers
                and not all(c in "0123456789., " for c in text)  # avoid mostly numeric text
            ):
                phrases.append(text)
                seen_phrases.add(text)
                count += 1

        if phrases:
            # Shuffle the phrases for variety
            random.shuffle(phrases)
            print(f"Successfully loaded {len(phrases)} phrases from Granary dataset for {language}")
            return phrases
        else:
            print(f"No suitable phrases found in Granary dataset for {language}")
            raise Exception("No phrases found")
    except Exception as e:
        print(f"Granary dataset loading failed for {language}: {e}")

    # Fallback to basic phrases if Granary fails
    print("Using fallback phrases")
    fallback_phrases = [
        "The quick brown fox jumps over the lazy dog.",
        "Please say your full name.",
        "Today is a good day to learn something new.",
        "Artificial intelligence helps with many tasks.",
        "I enjoy reading books and listening to music.",
        "This is a sample sentence for testing speech.",
        "Speak clearly and at a normal pace.",
        "Numbers like one, two, three are easy to say.",
        "The weather is sunny with a chance of rain.",
        "Thank you for taking the time to help.",
        "Hello, how are you today?",
        "I would like to order a pizza.",
        "The meeting is scheduled for tomorrow.",
        "Please call me back as soon as possible.",
        "Thank you for your assistance.",
        "Can you help me with this problem?",
        "I need to make a reservation.",
        "The weather looks beautiful outside.",
        "Let's go for a walk in the park.",
        "I enjoy listening to classical music.",
    ]
    if max_phrases:
        fallback_phrases = random.sample(fallback_phrases, min(max_phrases, len(fallback_phrases)))
    else:
        random.shuffle(fallback_phrases)
    return fallback_phrases


# Initialize phrases dynamically
DEFAULT_LANGUAGE = "en"  # Default to English
ALL_PHRASES = load_multilingual_phrases(DEFAULT_LANGUAGE, max_phrases=None)
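
# Example: load a smaller prompt set in another language (assumes network
# access to the Hugging Face Hub; falls back to the built-in phrase list
# above if the Granary download fails):
#
#   german_phrases = load_multilingual_phrases("de", max_phrases=50)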

with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
    has_gpu, gpu_msg = detect_nvidia_driver()
    if has_gpu:
        gr.HTML(
            f"""
            <div>
                ✅ NVIDIA GPU ready — {gpu_msg}<br/>
                Set HF_WRITE_TOKEN/HF_TOKEN in environment to enable Hub push.
            </div>
            """
        )
    else:
        hint_md = duplicate_space_hint()
        gr.HTML(
            """
            <div>
                ⚠️ No NVIDIA GPU/driver detected — training requires a GPU runtime
            </div>
            """
        )
        gr.Markdown(hint_md)