Spaces: Joseph Pollack committed "adds interface and dataset and auto push and demo"
Files changed:
- interface.py +444 -0
- scripts/deploy_demo_space.py +952 -0
- scripts/generate_model_card.py +221 -0
- scripts/push_to_huggingface.py +700 -0
- train_lora.py → scripts/train.py +92 -62
- train.py → scripts/train_lora.py +107 -47
- templates/datasets/readme.md +171 -0
- templates/model_card.md +345 -0
- templates/spaces/demo_voxtral/README.md +23 -0
- templates/spaces/demo_voxtral/app.py +35 -0
- templates/spaces/demo_voxtral/requirements.txt +7 -0
interface.py ADDED
@@ -0,0 +1,444 @@
#!/usr/bin/env python3
"""
Voxtral ASR Fine-tuning Interface

Features:
- Collect a personal voice dataset (upload WAV/FLAC + transcripts or record mic audio)
- Build a JSONL dataset ({audio_path, text}) at 16kHz
- Fine-tune Voxtral (LoRA or full) with streamed logs
- Push model to Hugging Face Hub
- Deploy a Voxtral ASR demo Space

Env tokens (optional):
- HF_WRITE_TOKEN or HF_TOKEN: write access token
- HF_READ_TOKEN: optional read token
- HF_USERNAME: fallback username if not derivable from token
"""
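
# Sketch of the JSONL dataset layout that the builders below write to
# datasets/voxtral_user/data.jsonl, one JSON object per line (the paths and
# transcripts shown here are illustrative examples, not committed data):
#   {"audio_path": "datasets/voxtral_user/wavs/rec_0000.wav", "text": "Please say your full name."}
#   {"audio_path": "datasets/voxtral_user/wavs/rec_0001.wav", "text": "Today is a good day to learn something new."}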

from __future__ import annotations

import os
import json
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, Generator, Optional, Tuple

import gradio as gr

PROJECT_ROOT = Path(__file__).resolve().parent


def get_python() -> str:
    import sys
    return sys.executable or "python"


def get_username_from_token(token: str) -> Optional[str]:
    try:
        from huggingface_hub import HfApi  # type: ignore
        api = HfApi(token=token)
        info = api.whoami()
        if isinstance(info, dict):
            return info.get("name") or info.get("username")
        if isinstance(info, str):
            return info
    except Exception:
        return None
    return None


def run_command_stream(args: list[str], env: Dict[str, str], cwd: Optional[Path] = None) -> Generator[str, None, int]:
    import subprocess
    import shlex
    yield f"$ {' '.join(shlex.quote(a) for a in ([get_python()] + args))}"
    process = subprocess.Popen(
        [get_python()] + args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env,
        cwd=str(cwd or PROJECT_ROOT),
        bufsize=1,
        universal_newlines=True,
    )
    assert process.stdout is not None
    for line in iter(process.stdout.readline, ""):
        yield line.rstrip()
    process.stdout.close()
    code = process.wait()
    yield f"[exit_code={code}]"
    return code


def detect_nvidia_driver() -> Tuple[bool, str]:
    """Detect NVIDIA driver/GPU presence with multiple strategies.

    Returns (available, human_message).
    """
    # 1) Try torch CUDA
    try:
        import torch  # type: ignore
        if torch.cuda.is_available():
            try:
                num = torch.cuda.device_count()
                names = [torch.cuda.get_device_name(i) for i in range(num)]
                return True, f"NVIDIA GPU detected: {', '.join(names)}"
            except Exception:
                return True, "NVIDIA GPU detected (torch.cuda available)"
    except Exception:
        pass

    # 2) Try NVML via pynvml
    try:
        import pynvml  # type: ignore
        try:
            pynvml.nvmlInit()
            cnt = pynvml.nvmlDeviceGetCount()
            names = []
            for i in range(cnt):
                h = pynvml.nvmlDeviceGetHandleByIndex(i)
                names.append(pynvml.nvmlDeviceGetName(h).decode("utf-8", errors="ignore"))
            drv = pynvml.nvmlSystemGetDriverVersion().decode("utf-8", errors="ignore")
            pynvml.nvmlShutdown()
            if cnt > 0:
                return True, f"NVIDIA driver {drv}; GPUs: {', '.join(names)}"
        except Exception:
            pass
    except Exception:
        pass

    # 3) Try nvidia-smi
    try:
        import subprocess
        res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=3)
        if res.returncode == 0 and res.stdout.strip():
            return True, res.stdout.strip().splitlines()[0]
    except Exception:
        pass

    return False, "No NVIDIA driver/GPU detected"


def duplicate_space_hint() -> str:
    space_id = os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")
    if space_id:
        space_url = f"https://huggingface.co/spaces/{space_id}"
        dup_url = f"{space_url}?duplicate=true"
        return (
            f"ℹ️ No NVIDIA driver detected. If you're on Hugging Face Spaces, "
            f"please duplicate this Space to GPU hardware: [Duplicate this Space]({dup_url})."
        )
    return (
        "ℹ️ No NVIDIA driver detected. To enable training, run on a machine with an NVIDIA GPU/driver "
        "or duplicate this Space on Hugging Face with GPU hardware."
    )


def _write_jsonl(rows: list[dict], path: Path) -> Path:
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    return path


def _save_uploaded_dataset(files: list, transcripts: list[str]) -> str:
    dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    for i, fpath in enumerate(files or []):
        if i >= len(transcripts):
            break
        rows.append({"audio_path": fpath, "text": transcripts[i] or ""})
    jsonl_path = dataset_dir / "data.jsonl"
    _write_jsonl(rows, jsonl_path)
    return str(jsonl_path)


def _save_recordings(recordings: list[tuple[int, list]], transcripts: list[str]) -> str:
    import soundfile as sf
    dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
    wav_dir = dataset_dir / "wavs"
    wav_dir.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    for i, rec in enumerate(recordings or []):
        if rec is None:
            continue
        if i >= len(transcripts):
            break
        sr, data = rec
        out_path = wav_dir / f"rec_{i:04d}.wav"
        sf.write(str(out_path), data, sr)
        rows.append({"audio_path": str(out_path), "text": transcripts[i] or ""})
    jsonl_path = dataset_dir / "data.jsonl"
    _write_jsonl(rows, jsonl_path)
    return str(jsonl_path)


def start_voxtral_training(
    use_lora: bool,
    base_model: str,
    repo_short: str,
    jsonl_path: str,
    train_count: int,
    eval_count: int,
    batch_size: int,
    grad_accum: int,
    learning_rate: float,
    epochs: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    freeze_audio_tower: bool,
    push_to_hub: bool,
    deploy_demo: bool,
) -> Generator[str, None, None]:
    env = os.environ.copy()
    write_token = env.get("HF_WRITE_TOKEN") or env.get("HF_TOKEN")
    read_token = env.get("HF_READ_TOKEN")
    username = get_username_from_token(write_token or "") or env.get("HF_USERNAME") or ""
    output_dir = PROJECT_ROOT / "outputs" / repo_short

    # 1) Train
    script = PROJECT_ROOT / ("scripts/train_lora.py" if use_lora else "scripts/train.py")
    args = [str(script)]
    if jsonl_path:
        args += ["--dataset-jsonl", jsonl_path]
    args += [
        "--model-checkpoint", base_model,
        "--train-count", str(train_count),
        "--eval-count", str(eval_count),
        "--batch-size", str(batch_size),
        "--grad-accum", str(grad_accum),
        "--learning-rate", str(learning_rate),
        "--epochs", str(epochs),
        "--output-dir", str(output_dir),
        "--save-steps", "50",
    ]
    if use_lora:
        args += [
            "--lora-r", str(lora_r),
            "--lora-alpha", str(lora_alpha),
            "--lora-dropout", str(lora_dropout),
        ]
    if freeze_audio_tower:
        args += ["--freeze-audio-tower"]
    for line in run_command_stream(args, env):
        yield line

    # 2) Push to Hub
    if push_to_hub:
        repo_name = f"{username}/{repo_short}" if username else repo_short
        push_args = [
            str(PROJECT_ROOT / "scripts/push_to_huggingface.py"),
            str(output_dir),
            repo_name,
        ]
        for line in run_command_stream(push_args, env):
            yield line

    # 3) Deploy demo Space
    if deploy_demo and username:
        deploy_args = [
            str(PROJECT_ROOT / "scripts/deploy_demo_space.py"),
            "--hf-token", write_token or "",
            "--hf-username", username,
            "--model-id", f"{username}/{repo_short}",
            "--demo-type", "voxtral",
            "--space-name", f"{repo_short}-demo",
        ]
        for line in run_command_stream(deploy_args, env):
            yield line


PHRASES = [
    "The quick brown fox jumps over the lazy dog.",
    "Please say your full name.",
    "Today is a good day to learn something new.",
    "Artificial intelligence helps with many tasks.",
    "I enjoy reading books and listening to music.",
    "This is a sample sentence for testing speech.",
    "Speak clearly and at a normal pace.",
    "Numbers like one, two, three are easy to say.",
    "The weather is sunny with a chance of rain.",
    "Thank you for taking the time to help.",
]

with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
    has_gpu, gpu_msg = detect_nvidia_driver()
    if has_gpu:
        gr.HTML(
            f"""
            <div style="background-color: rgba(59, 130, 246, 0.1); border: 1px solid rgba(59, 130, 246, 0.3); border-radius: 8px; padding: 12px; margin-bottom: 16px; text-align: center;">
                <p style="color: rgb(59, 130, 246); margin: 0; font-size: 14px; font-weight: 600;">
                    ✅ NVIDIA GPU ready — {gpu_msg}
                </p>
                <p style="color: rgb(59, 130, 246); margin: 6px 0 0; font-size: 12px;">
                    Set HF_WRITE_TOKEN/HF_TOKEN in environment to enable Hub push.
                </p>
            </div>
            """
        )
    else:
        hint_md = duplicate_space_hint()
        gr.HTML(
            f"""
            <div style="background-color: rgba(245, 158, 11, 0.1); border: 1px solid rgba(245, 158, 11, 0.3); border-radius: 8px; padding: 12px; margin-bottom: 16px; text-align: center;">
                <p style="color: rgb(234, 88, 12); margin: 0; font-size: 14px; font-weight: 600;">
                    ⚠️ No NVIDIA GPU/driver detected — training requires a GPU runtime
                </p>
                <p style="color: rgb(234, 88, 12); margin: 6px 0 0; font-size: 12px;">
                    {hint_md}
                </p>
            </div>
            """
        )

    gr.Markdown("""
    # 🎙️ Voxtral ASR Fine-tuning
    Read the phrases below and record them. Then start fine-tuning.
    """)

    jsonl_out = gr.Textbox(label="Dataset JSONL path", interactive=False, visible=True)

    # Recording grid with dynamic text readouts
    phrase_texts_state = gr.State(PHRASES)
    phrase_markdowns: list[gr.Markdown] = []
    rec_components = []
    with gr.Column():
        for idx, phrase in enumerate(PHRASES):
            md = gr.Markdown(f"**{idx+1}. {phrase}**")
            phrase_markdowns.append(md)
            comp = gr.Audio(sources="microphone", type="numpy", label=f"Recording {idx+1}")
            rec_components.append(comp)

    # Advanced options accordion
    with gr.Accordion("Advanced options", open=False):
        base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
        use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
        with gr.Row():
            batch_size = gr.Number(value=2, precision=0, label="Batch size")
            grad_accum = gr.Number(value=4, precision=0, label="Grad accum")
        with gr.Row():
            learning_rate = gr.Number(value=5e-5, precision=6, label="Learning rate")
            epochs = gr.Number(value=3.0, precision=2, label="Epochs")
        with gr.Accordion("LoRA settings", open=False):
            lora_r = gr.Number(value=8, precision=0, label="LoRA r")
            lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
            lora_dropout = gr.Number(value=0.0, precision=3, label="LoRA dropout")
            freeze_audio_tower = gr.Checkbox(value=True, label="Freeze audio tower")
        with gr.Row():
            train_count = gr.Number(value=100, precision=0, label="Train samples")
            eval_count = gr.Number(value=50, precision=0, label="Eval samples")
        repo_short = gr.Textbox(value=f"voxtral-finetune-{datetime.now().strftime('%Y%m%d_%H%M%S')}", label="Model repo (short)")
        push_to_hub = gr.Checkbox(value=True, label="Push to HF Hub after training")
        deploy_demo = gr.Checkbox(value=True, label="Deploy demo Space after push")

    gr.Markdown("### Upload audio + transcripts (optional)")
    upload_audio = gr.File(file_count="multiple", type="filepath", label="Upload WAV/FLAC files (optional)")
    transcripts_box = gr.Textbox(lines=6, label="Transcripts (one per line, aligned with files)")
    save_upload_btn = gr.Button("Save uploaded dataset")

    def _collect_upload(files, txt):
        lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
        return _save_uploaded_dataset(files or [], lines)

    save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [jsonl_out])

    # Save recordings button
    save_rec_btn = gr.Button("Save recordings as dataset")

    def _collect_preloaded_recs(*recs_and_texts):
        import soundfile as sf
        dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
        wav_dir = dataset_dir / "wavs"
        wav_dir.mkdir(parents=True, exist_ok=True)
        rows: list[dict] = []
        if not recs_and_texts:
            jsonl_path = dataset_dir / "data.jsonl"
            _write_jsonl(rows, jsonl_path)
            return str(jsonl_path)
        texts = recs_and_texts[-1]
        recs = recs_and_texts[:-1]
        for i, rec in enumerate(recs):
            if rec is None:
                continue
            sr, data = rec
            out_path = wav_dir / f"rec_{i:04d}.wav"
            sf.write(str(out_path), data, sr)
            label_text = (texts[i] if isinstance(texts, list) and i < len(texts) else (PHRASES[i] if i < len(PHRASES) else ""))
            rows.append({"audio_path": str(out_path), "text": label_text})
        jsonl_path = dataset_dir / "data.jsonl"
        _write_jsonl(rows, jsonl_path)
        return str(jsonl_path)

    save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_out])

    # Quick sample from VoxPopuli (few random rows)
    with gr.Row():
        vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"], value="en", label="VoxPopuli language")
        vp_samples = gr.Number(value=20, precision=0, label="Num samples")
        vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
        vp_btn = gr.Button("Use VoxPopuli sample")

    def _collect_voxpopuli(lang_code: str, num_samples: int, split: str):
        import sys
        # Workaround for dill on Python 3.13 expecting __main__ during import
        if "__main__" not in sys.modules:
            sys.modules["__main__"] = sys.modules[__name__]
        from datasets import load_dataset, Audio  # type: ignore
        import random
        ds = load_dataset("facebook/voxpopuli", lang_code, split=split)
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
        # shuffle and select
        total = len(ds)
        k = max(1, min(int(num_samples or 1), total))
        ds = ds.shuffle(seed=random.randint(1, 10_000))
        ds_sel = ds.select(range(k))

        dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
        rows: list[dict] = []
        texts: list[str] = []
        for ex in ds_sel:
            audio = ex.get("audio") or {}
            path = audio.get("path")
            text = ex.get("normalized_text") or ex.get("raw_text") or ""
            if path and text is not None:
                rows.append({"audio_path": path, "text": text})
                texts.append(str(text))
        jsonl_path = dataset_dir / "data.jsonl"
        _write_jsonl(rows, jsonl_path)
        # Build markdown content updates for on-screen prompts
        md_updates = []
        for i in range(len(phrase_markdowns)):
            t = texts[i] if i < len(texts) else ""
            md_updates.append(f"**{i+1}. {t}**")
        return (str(jsonl_path), texts, *md_updates)

    vp_btn.click(
        _collect_voxpopuli,
        [vp_lang, vp_samples, vp_split],
        [jsonl_out, phrase_texts_state] + phrase_markdowns,
    )

    start_btn = gr.Button("Start Fine-tuning")
    logs_box = gr.Textbox(label="Logs", lines=20)

    start_btn.click(
        start_voxtral_training,
        inputs=[
            use_lora, base_model, repo_short, jsonl_out, train_count, eval_count,
            batch_size, grad_accum, learning_rate, epochs,
            lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
            push_to_hub, deploy_demo,
        ],
        outputs=[logs_box],
    )


if __name__ == "__main__":
    server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
    server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
    demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)
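
For reference, with the defaults above, start_voxtral_training chains the three scripts as subprocesses via run_command_stream; the echoed commands look roughly like the following sketch (username, token, and timestamp values are illustrative, not taken from the commit):

$ python scripts/train_lora.py --dataset-jsonl datasets/voxtral_user/data.jsonl --model-checkpoint mistralai/Voxtral-Mini-3B-2507 --train-count 100 --eval-count 50 --batch-size 2 --grad-accum 4 --learning-rate 5e-05 --epochs 3.0 --output-dir outputs/voxtral-finetune-20250101_120000 --save-steps 50 --lora-r 8 --lora-alpha 32 --lora-dropout 0.0 --freeze-audio-tower
$ python scripts/push_to_huggingface.py outputs/voxtral-finetune-20250101_120000 user/voxtral-finetune-20250101_120000
$ python scripts/deploy_demo_space.py --hf-token *** --hf-username user --model-id user/voxtral-finetune-20250101_120000 --demo-type voxtral --space-name voxtral-finetune-20250101_120000-demo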
scripts/deploy_demo_space.py ADDED
@@ -0,0 +1,952 @@
#!/usr/bin/env python3
"""
Demo Space Deployment Script
Deploys a Gradio demo space to Hugging Face Spaces for testing the fine-tuned model.
"""

import os
import sys
import json
import logging
import argparse
import subprocess
import requests
import tempfile
import shutil
from pathlib import Path
from typing import Optional, Dict, Any
import time

# Import Hugging Face Hub API
try:
    from huggingface_hub import HfApi, create_repo, upload_file
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")

# Add src to path for imports
sys.path.append(str(Path(__file__).parent.parent / "src"))

from config import SmolLM3Config

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class DemoSpaceDeployer:
    """Deploy demo space to Hugging Face Spaces"""

    def __init__(
        self,
        hf_token: str,
        # Token used for API actions that create/update the Space (write perms)
        hf_username: str,
        model_id: str,
        subfolder: str = "int4",
        space_name: Optional[str] = None,
        demo_type: Optional[str] = None,
        config_file: Optional[str] = None,
        # Optional token used as the Space's HF_TOKEN secret (read-only recommended)
        space_secret_token: Optional[str] = None,
        # Examples configuration
        examples_type: Optional[str] = None,
        disable_examples: Optional[bool] = None,
        examples_json: Optional[str] = None,
        # Branding overrides
        brand_owner_name: Optional[str] = None,
        brand_team_name: Optional[str] = None,
        brand_discord_url: Optional[str] = None,
        brand_hf_org: Optional[str] = None,
        brand_hf_label: Optional[str] = None,
        brand_hf_url: Optional[str] = None,
        brand_gh_org: Optional[str] = None,
        brand_gh_label: Optional[str] = None,
        brand_gh_url: Optional[str] = None,
        brand_project_name: Optional[str] = None,
        brand_project_url: Optional[str] = None,
    ):
        self.hf_token = hf_token
        # The token we will store in the Space secrets. Defaults to hf_token if not provided
        self.space_secret_token = space_secret_token or hf_token
        self.hf_username = hf_username
        # Allow passing just a repo name without username and auto-prefix
        self.model_id = model_id if "/" in model_id else f"{hf_username}/{model_id}"
        self.subfolder = subfolder
        self.space_name = space_name or f"{self.model_id.split('/')[-1]}-demo"
        self.space_id = f"{hf_username}/{self.space_name}"
        self.space_url = f"https://huggingface.co/spaces/{self.space_id}"
        self.config_file = config_file

        # Config-derived context
        self.system_message: Optional[str] = None
        self.developer_message: Optional[str] = None
        self.model_identity: Optional[str] = None
        self.reasoning_effort: Optional[str] = None
        # Examples context
        self.examples_type: Optional[str] = (examples_type or None)
        self.disable_examples: Optional[bool] = (disable_examples if disable_examples is not None else None)
        self.examples_json: Optional[str] = (examples_json or None)

        # Determine demo type from model_id if not provided
        if demo_type is None:
            demo_type = self._detect_demo_type(model_id)

        # Template paths based on model type
        self.demo_type = demo_type
        self.template_dir = Path(__file__).parent.parent / "templates" / "spaces" / f"demo_{demo_type}"
        self.workspace_dir = Path.cwd()

        # Initialize HF API
        if HF_HUB_AVAILABLE:
            self.api = HfApi(token=self.hf_token)
        else:
            self.api = None
            logger.warning("huggingface_hub not available, using CLI fallback")

        # Load optional config-specified messages
        try:
            self._load_config_messages()
        except Exception as e:
            logger.warning(f"Could not load config messages: {e}")

        # Branding defaults (can be overridden via CLI)
        self.brand_owner_name = brand_owner_name or self.hf_username or "Tonic"
        self.brand_team_name = brand_team_name or f"Team{self.brand_owner_name}"
        self.brand_discord_url = brand_discord_url or "https://discord.gg/qdfnvSPcqP"
        # HF org/link
        _default_hf_org = brand_hf_org or self.hf_username or "MultiTransformer"
        self.brand_hf_org = _default_hf_org
        self.brand_hf_label = brand_hf_label or self.brand_hf_org
        self.brand_hf_url = brand_hf_url or f"https://huggingface.co/{self.brand_hf_org}"
        # GitHub org/link
        _default_gh_org = brand_gh_org or self.hf_username or "tonic-ai"
        self.brand_gh_org = _default_gh_org
        self.brand_gh_label = brand_gh_label or self.brand_gh_org
        self.brand_gh_url = brand_gh_url or f"https://github.com/{self.brand_gh_org}"
        # Project link
        self.brand_project_name = brand_project_name or "MultiTonic"
        self.brand_project_url = brand_project_url or "https://github.com/MultiTonic"

    def _load_config_messages(self) -> None:
        """Load system/developer/model_identity from a training config file if provided."""
        if not self.config_file:
            return
        cfg_path = Path(self.config_file)
        if not cfg_path.exists():
            logger.warning(f"Config file not found: {cfg_path}")
            return

        # Ensure project root and config dir are importable for relative imports inside config
        project_root = Path(__file__).parent.parent
        if str(project_root) not in sys.path:
            sys.path.insert(0, str(project_root))
        cfg_dir = project_root / "config"
        if str(cfg_dir) not in sys.path:
            sys.path.insert(0, str(cfg_dir))

        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", str(cfg_path))
        if not spec or not spec.loader:
            return
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)  # type: ignore
        cfg = getattr(module, "config", None)
        if cfg is None:
            return
        self.system_message = getattr(cfg, "system_message", None)
        self.developer_message = getattr(cfg, "developer_message", None)
        chat_kwargs = getattr(cfg, "chat_template_kwargs", None)
        if isinstance(chat_kwargs, dict):
            self.model_identity = chat_kwargs.get("model_identity")
            self.reasoning_effort = chat_kwargs.get("reasoning_effort")

    def _detect_demo_type(self, model_id: str) -> str:
        """Detect the appropriate demo type based on model ID"""
        model_id_lower = model_id.lower()

        # Voxtral ASR models
        if "voxtral" in model_id_lower:
            logger.info("Detected Voxtral model, using demo_voxtral template")
            return "voxtral"

        # Check for GPT-OSS models
        if "gpt-oss" in model_id_lower or "gpt_oss" in model_id_lower:
            logger.info("Detected GPT-OSS model, using demo_gpt template")
            return "gpt"

        # Check for SmolLM models (default)
        elif "smollm" in model_id_lower or "smol" in model_id_lower:
            logger.info("Detected SmolLM model, using demo_smol template")
            return "smol"

        # Default to SmolLM for unknown models
        else:
            logger.info("Unknown model type, defaulting to demo_smol template")
            return "smol"

    def _generate_env_setup(self) -> str:
        """Generate environment variable setup based on demo type and model"""
        if self.demo_type == "gpt":
            # For GPT-OSS models, we need more sophisticated environment setup
            model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
            import json as _json
            env_setup = f"""
# Environment variables for GPT-OSS model configuration
import os
os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
os.environ['LORA_MODEL_ID'] = {_json.dumps(self.model_id)}
os.environ['BASE_MODEL_ID'] = 'openai/gpt-oss-20b'
os.environ['MODEL_SUBFOLDER'] = {_json.dumps(self.subfolder if self.subfolder else "")}
os.environ['MODEL_NAME'] = {_json.dumps(model_name)}
os.environ['MODEL_IDENTITY'] = {_json.dumps(self.model_identity or "")}
os.environ['SYSTEM_MESSAGE'] = {_json.dumps(self.system_message or (self.model_identity or ""))}
os.environ['DEVELOPER_MESSAGE'] = {_json.dumps(self.developer_message or "")}
os.environ['REASONING_EFFORT'] = {_json.dumps((self.reasoning_effort or "medium"))}
{"os.environ['EXAMPLES_TYPE'] = " + _json.dumps(self.examples_type) + "\n" if self.examples_type else ''}
{"os.environ['DISABLE_EXAMPLES'] = 'true'\n" if self.disable_examples else ("os.environ['DISABLE_EXAMPLES'] = 'false'\n" if self.disable_examples is not None else '')}
{"os.environ['EXAMPLES_JSON'] = " + _json.dumps(self.examples_json) + "\n" if self.examples_json else ''}

# Branding/owner variables
os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
os.environ['BRAND_OWNER_NAME'] = {_json.dumps(self.brand_owner_name)}
os.environ['BRAND_TEAM_NAME'] = {_json.dumps(self.brand_team_name)}
os.environ['BRAND_DISCORD_URL'] = {_json.dumps(self.brand_discord_url)}
os.environ['BRAND_HF_ORG'] = {_json.dumps(self.brand_hf_org)}
os.environ['BRAND_HF_LABEL'] = {_json.dumps(self.brand_hf_label)}
os.environ['BRAND_HF_URL'] = {_json.dumps(self.brand_hf_url)}
os.environ['BRAND_GH_ORG'] = {_json.dumps(self.brand_gh_org)}
os.environ['BRAND_GH_LABEL'] = {_json.dumps(self.brand_gh_label)}
os.environ['BRAND_GH_URL'] = {_json.dumps(self.brand_gh_url)}
os.environ['BRAND_PROJECT_NAME'] = {_json.dumps(self.brand_project_name)}
os.environ['BRAND_PROJECT_URL'] = {_json.dumps(self.brand_project_url)}

"""
        elif self.demo_type == "voxtral":
            import json as _json
            env_setup = f"""
# Environment variables for Voxtral ASR demo
import os
os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
os.environ['MODEL_NAME'] = {_json.dumps(self.model_id.split('/')[-1])}
os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
"""
        else:
            # For SmolLM models, use simpler setup
            import json as _json
            env_setup = f"""
# Environment variables for model configuration
import os
os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
os.environ['MODEL_SUBFOLDER'] = {_json.dumps(self.subfolder if self.subfolder else "")}
os.environ['MODEL_NAME'] = {_json.dumps(self.model_id.split("/")[-1])}
os.environ['MODEL_IDENTITY'] = {_json.dumps(self.model_identity or "")}
os.environ['SYSTEM_MESSAGE'] = {_json.dumps(self.system_message or (self.model_identity or ""))}
os.environ['DEVELOPER_MESSAGE'] = {_json.dumps(self.developer_message or "")}
os.environ['REASONING_EFFORT'] = {_json.dumps((self.reasoning_effort or "medium"))}
{"os.environ['EXAMPLES_TYPE'] = " + _json.dumps(self.examples_type) + "\n" if self.examples_type else ''}
{"os.environ['DISABLE_EXAMPLES'] = 'true'\n" if self.disable_examples else ("os.environ['DISABLE_EXAMPLES'] = 'false'\n" if self.disable_examples is not None else '')}
{"os.environ['EXAMPLES_JSON'] = " + _json.dumps(self.examples_json) + "\n" if self.examples_json else ''}

# Branding/owner variables
os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
os.environ['BRAND_OWNER_NAME'] = {_json.dumps(self.brand_owner_name)}
os.environ['BRAND_TEAM_NAME'] = {_json.dumps(self.brand_team_name)}
os.environ['BRAND_DISCORD_URL'] = {_json.dumps(self.brand_discord_url)}
os.environ['BRAND_HF_ORG'] = {_json.dumps(self.brand_hf_org)}
os.environ['BRAND_HF_LABEL'] = {_json.dumps(self.brand_hf_label)}
os.environ['BRAND_HF_URL'] = {_json.dumps(self.brand_hf_url)}
os.environ['BRAND_GH_ORG'] = {_json.dumps(self.brand_gh_org)}
os.environ['BRAND_GH_LABEL'] = {_json.dumps(self.brand_gh_label)}
os.environ['BRAND_GH_URL'] = {_json.dumps(self.brand_gh_url)}
os.environ['BRAND_PROJECT_NAME'] = {_json.dumps(self.brand_project_name)}
os.environ['BRAND_PROJECT_URL'] = {_json.dumps(self.brand_project_url)}

"""
        return env_setup

    def _set_model_variables(self):
        """Set model-specific environment variables in the space"""
        try:
            # Common variables for all models
            self.api.add_space_variable(
                repo_id=self.space_id,
                key="HF_MODEL_ID",
                value=self.model_id,
                description="Model ID for the demo"
            )
            logger.info(f"✅ Successfully set HF_MODEL_ID variable: {self.model_id}")

            if self.subfolder and self.subfolder.strip():
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="MODEL_SUBFOLDER",
                    value=self.subfolder,
                    description="Model subfolder for the demo"
                )
                logger.info(f"✅ Successfully set MODEL_SUBFOLDER variable: {self.subfolder}")
            else:
                logger.info("ℹ️ No subfolder specified, using main model")

            # GPT-OSS specific variables
            if self.demo_type == "gpt":
                model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="LORA_MODEL_ID",
                    value=self.model_id,
                    description="LoRA/Fine-tuned model ID"
                )
                logger.info(f"✅ Successfully set LORA_MODEL_ID variable: {self.model_id}")
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="BASE_MODEL_ID",
                    value="openai/gpt-oss-20b",
                    description="Base model ID for GPT-OSS"
                )
                logger.info("✅ Successfully set BASE_MODEL_ID variable: openai/gpt-oss-20b")
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="MODEL_NAME",
                    value=model_name,
                    description="Display name for the model"
                )
                logger.info(f"✅ Successfully set MODEL_NAME variable: {model_name}")

            # Voxtral-specific variables
            elif self.demo_type == "voxtral":
                # HF_MODEL_ID was already set above; set a readable MODEL_NAME
                vox_model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="MODEL_NAME",
                    value=vox_model_name,
                    description="Display name for the Voxtral model"
                )
                logger.info(f"✅ Set Voxtral MODEL_NAME variable: {vox_model_name}")

            # Optional context variables
            if self.model_identity:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="MODEL_IDENTITY",
                    value=self.model_identity,
                    description="Default model identity/system persona"
                )
                logger.info("✅ Set MODEL_IDENTITY variable")
            if self.system_message or self.model_identity:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="SYSTEM_MESSAGE",
                    value=self.system_message or self.model_identity or "",
                    description="Default system message"
                )
                logger.info("✅ Set SYSTEM_MESSAGE variable")
            if self.developer_message:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="DEVELOPER_MESSAGE",
                    value=self.developer_message,
                    description="Default developer message"
                )
                logger.info("✅ Set DEVELOPER_MESSAGE variable")
            if self.reasoning_effort:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="REASONING_EFFORT",
                    value=self.reasoning_effort,
                    description="Default reasoning effort (low|medium|high)"
                )
                logger.info("✅ Set REASONING_EFFORT variable")

            # Branding variables
            branding_vars = {
                "HF_USERNAME": self.hf_username,
                "BRAND_OWNER_NAME": self.brand_owner_name,
                "BRAND_TEAM_NAME": self.brand_team_name,
                "BRAND_DISCORD_URL": self.brand_discord_url,
                "BRAND_HF_ORG": self.brand_hf_org,
                "BRAND_HF_LABEL": self.brand_hf_label,
                "BRAND_HF_URL": self.brand_hf_url,
                "BRAND_GH_ORG": self.brand_gh_org,
                "BRAND_GH_LABEL": self.brand_gh_label,
                "BRAND_GH_URL": self.brand_gh_url,
                "BRAND_PROJECT_NAME": self.brand_project_name,
                "BRAND_PROJECT_URL": self.brand_project_url,
            }
            for key, value in branding_vars.items():
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key=key,
                    value=value,
                    description=f"Branding: {key}"
                )
            logger.info("✅ Set branding variables")

            # Examples variables
            if self.examples_type:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="EXAMPLES_TYPE",
                    value=self.examples_type,
                    description="Examples pack type (e.g., general|medical)"
                )
                logger.info(f"✅ Set EXAMPLES_TYPE={self.examples_type}")
            if self.disable_examples is not None:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="DISABLE_EXAMPLES",
                    value=("true" if self.disable_examples else "false"),
                    description="Disable built-in examples"
                )
                logger.info(f"✅ Set DISABLE_EXAMPLES={self.disable_examples}")
            if self.examples_json:
                self.api.add_space_variable(
                    repo_id=self.space_id,
                    key="EXAMPLES_JSON",
                    value=self.examples_json,
                    description="Custom examples JSON override"
                )
                logger.info("✅ Set EXAMPLES_JSON override")

        except Exception as e:
            logger.error(f"❌ Failed to set model variables: {e}")

    def validate_model_exists(self) -> bool:
        """Validate that the model exists on Hugging Face Hub"""
        try:
            logger.info(f"Validating model: {self.model_id}")

            if HF_HUB_AVAILABLE:
                # Use HF Hub API
                try:
                    model_info = self.api.model_info(self.model_id)
                    logger.info(f"✅ Model {self.model_id} exists and is accessible")
                    return True
                except Exception as e:
                    logger.error(f"❌ Model {self.model_id} not found via API: {e}")
                    return False
            else:
                # Fallback to requests
                url = f"https://huggingface.co/api/models/{self.model_id}"
                headers = {"Authorization": f"Bearer {self.hf_token}"}
                response = requests.get(url, headers=headers, timeout=30)

                if response.status_code == 200:
                    logger.info(f"✅ Model {self.model_id} exists and is accessible")
                    return True
                else:
                    logger.error(f"❌ Model {self.model_id} not found or not accessible")
                    return False

        except Exception as e:
            logger.error(f"❌ Error validating model: {e}")
            return False

    def create_space_repository(self) -> bool:
        """Create the space repository on Hugging Face Hub"""
        try:
            logger.info(f"Creating Space: {self.space_name}")

            if not HF_HUB_AVAILABLE:
                logger.warning("huggingface_hub not available, falling back to CLI")
                return self._create_space_cli()

            # Use the latest HF Hub API to create space
            try:
                # Create the space using the API
                create_repo(
                    repo_id=self.space_id,
                    token=self.hf_token,
                    repo_type="space",
                    exist_ok=True,
                    private=False,  # Spaces are typically public
                    space_sdk="gradio",  # Specify Gradio SDK
                    space_hardware="cpu-basic"  # Use basic CPU
                )

                logger.info(f"✅ Space created successfully: {self.space_url}")
                return True

            except Exception as api_error:
                logger.error(f"API creation failed: {api_error}")
                logger.info("Falling back to CLI method...")
                return self._create_space_cli()

        except Exception as e:
            logger.error(f"❌ Error creating space: {e}")
            return False

    def _create_space_cli(self) -> bool:
        """Fallback method using CLI commands"""
        try:
            logger.info("Using CLI fallback method...")

            # Set HF token for CLI
            os.environ['HF_TOKEN'] = self.hf_token

            # Create space using Hugging Face CLI
            cmd = [
                "hf", "repo", "create",
                self.space_id,
                "--type", "space"
            ]

            logger.info(f"Running command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                logger.warning(f"First attempt failed: {result.stderr}")
                # Try alternative approach without space-specific flags
                logger.info("Retrying with basic space creation...")
                cmd = [
                    "hf", "repo", "create",
                    self.space_id
                ]
                result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                logger.info(f"✅ Space created successfully: {self.space_url}")
                return True
            else:
                logger.error(f"❌ Failed to create space: {result.stderr}")
                return False

        except Exception as e:
            logger.error(f"❌ Error creating space with CLI: {e}")
            return False

    def prepare_space_files(self) -> Optional[str]:
        """Prepare all necessary files for the Space in a temporary directory"""
        try:
            logger.info("Preparing Space files...")

            # Create temporary directory
            temp_dir = tempfile.mkdtemp()
            logger.info(f"Created temporary directory: {temp_dir}")

            # Copy template files
            copied_files = []
            for file_path in self.template_dir.iterdir():
                if file_path.is_file():
                    dest_path = Path(temp_dir) / file_path.name
                    shutil.copy2(file_path, dest_path)
                    copied_files.append(file_path.name)
                    logger.info(f"✅ Copied {file_path.name} to temp directory")

            # Update app.py with environment variables
            app_file = Path(temp_dir) / "app.py"
            if app_file.exists():
                with open(app_file, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Add environment variable setup at the top
                env_setup = self._generate_env_setup()

                # Insert after imports
                lines = content.split('\n')
                import_end = 0
                for i, line in enumerate(lines):
                    if line.startswith('import ') or line.startswith('from '):
                        import_end = i + 1
                    elif line.strip() == '' and import_end > 0:
                        break

                lines.insert(import_end, env_setup)
                content = '\n'.join(lines)

                with open(app_file, 'w', encoding='utf-8') as f:
                    f.write(content)

                logger.info("✅ Updated app.py with model configuration")

            # YAML front matter required by Hugging Face Spaces
            yaml_front_matter = (
                f"---\n"
                f"title: {'GPT-OSS Demo' if self.demo_type == 'gpt' else 'SmolLM3 Demo'}\n"
                f"emoji: {'🌟' if self.demo_type == 'gpt' else '💃🏻'}\n"
                f"colorFrom: {'blue' if self.demo_type == 'gpt' else 'green'}\n"
                f"colorTo: {'pink' if self.demo_type == 'gpt' else 'purple'}\n"
                f"sdk: gradio\n"
                f"sdk_version: 5.40.0\n"
                f"app_file: app.py\n"
                f"pinned: false\n"
                f"short_description: Interactive demo for {self.model_id}\n"
                + ("license: mit\n" if self.demo_type != 'gpt' else "") +
                f"---\n\n"
            )

            # Create README.md for the space (include configuration details)
            readme_content = (
                yaml_front_matter
                + f"# Demo: {self.model_id}\n\n"
                + f"This is an interactive demo for the fine-tuned model {self.model_id}.\n\n"
                + "## Features\n"
                "- Interactive chat interface\n"
                "- Customizable system & developer prompts\n"
                "- Advanced generation parameters\n"
                "- Thinking mode support\n\n"
                + "## Model Information\n"
                f"- **Model ID**: {self.model_id}\n"
                f"- **Subfolder**: {self.subfolder if self.subfolder and self.subfolder.strip() else 'main'}\n"
                f"- **Deployed by**: {self.hf_username}\n"
                + ("- **Base Model**: openai/gpt-oss-20b\n" if self.demo_type == 'gpt' else "")
                + "\n"
                + "## Configuration\n"
                "- **Model Identity**:\n\n"
                f"```\n{self.model_identity or 'Not set'}\n```\n\n"
                "- **System Message** (default):\n\n"
                f"```\n{(self.system_message or self.model_identity) or 'Not set'}\n```\n\n"
                "- **Developer Message** (default):\n\n"
                f"```\n{self.developer_message or 'Not set'}\n```\n\n"
                "These defaults come from the selected training configuration and can be adjusted in the UI when you run the demo.\n\n"
                + "## Usage\n"
                "Simply start chatting with the model using the interface below!\n\n"
                + "---\n"
                "*This demo was automatically deployed by the SmolFactory Fine-tuning Pipeline*\n"
            )

            with open(Path(temp_dir) / "README.md", 'w', encoding='utf-8') as f:
                f.write(readme_content)

            logger.info(f"✅ Prepared {len(copied_files)} files in temporary directory")
            return temp_dir

        except Exception as e:
            logger.error(f"❌ Error preparing files: {e}")
            return None

    def upload_files_to_space(self, temp_dir: str) -> bool:
        """Upload files to the Space using HF Hub API directly"""
        try:
            logger.info("Uploading files to Space using HF Hub API...")

            if not HF_HUB_AVAILABLE:
                logger.error("❌ huggingface_hub not available for file upload")
                return self._upload_files_cli(temp_dir)

            # Upload each file using the HF Hub API
            temp_path = Path(temp_dir)
            uploaded_files = []

            for file_path in temp_path.iterdir():
                if file_path.is_file():
                    try:
                        # Upload file to the space
                        upload_file(
                            path_or_fileobj=str(file_path),
                            path_in_repo=file_path.name,
                            repo_id=self.space_id,
                            repo_type="space",
                            token=self.hf_token
                        )
                        uploaded_files.append(file_path.name)
                        logger.info(f"✅ Uploaded {file_path.name}")
                    except Exception as e:
                        logger.error(f"❌ Failed to upload {file_path.name}: {e}")
                        return False

            logger.info(f"✅ Successfully uploaded {len(uploaded_files)} files to Space")
            return True

        except Exception as e:
            logger.error(f"❌ Error uploading files: {e}")
            return self._upload_files_cli(temp_dir)

    def _upload_files_cli(self, temp_dir: str) -> bool:
        """Fallback method using CLI for file upload"""
        try:
            logger.info("Using CLI fallback for file upload...")

            # Set HF token for CLI
            os.environ['HF_TOKEN'] = self.hf_token

            # Initialize git repository
            subprocess.run(["git", "init"], cwd=temp_dir, check=True)
            subprocess.run(["git", "config", "user.name", "Demo Deployer"], cwd=temp_dir, check=True)
            subprocess.run(["git", "config", "user.email", "[email protected]"], cwd=temp_dir, check=True)

            # Add files
            subprocess.run(["git", "add", "."], cwd=temp_dir, check=True)
            subprocess.run(["git", "commit", "-m", f"Deploy demo for {self.model_id}"], cwd=temp_dir, check=True)

            # Add remote and push
            remote_url = f"https://{self.hf_token}@huggingface.co/spaces/{self.space_id}"
            subprocess.run(["git", "remote", "add", "origin", remote_url], cwd=temp_dir, check=True)
            subprocess.run(["git", "push", "-u", "origin", "main"], cwd=temp_dir, check=True)

            logger.info(f"✅ Successfully pushed files to space: {self.space_id}")
            return True

        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Git operation failed: {e}")
            return False
        except Exception as e:
            logger.error(f"❌ Error pushing to space: {e}")
            return False

    def set_space_secrets(self) -> bool:
        """Set environment variables/secrets for the Space using HF Hub API"""
        try:
            logger.info("Setting Space secrets using HF Hub API...")

            if not HF_HUB_AVAILABLE:
                logger.warning("❌ huggingface_hub not available for setting secrets")
                return self._manual_secret_setup()

            # Set the HF_TOKEN secret for the space using the API
|
| 698 |
+
try:
|
| 699 |
+
self.api.add_space_secret(
|
| 700 |
+
repo_id=self.space_id,
|
| 701 |
+
key="HF_TOKEN",
|
| 702 |
+
value=self.space_secret_token,
|
| 703 |
+
description="Hugging Face token for model access"
|
| 704 |
+
)
|
| 705 |
+
logger.info("✅ Successfully set HF_TOKEN secret via API")
|
| 706 |
+
|
| 707 |
+
# Set model-specific environment variables
|
| 708 |
+
self._set_model_variables()
|
| 709 |
+
|
| 710 |
+
return True
|
| 711 |
+
|
| 712 |
+
except Exception as api_error:
|
| 713 |
+
logger.error(f"❌ Failed to set secrets via API: {api_error}")
|
| 714 |
+
logger.info("Falling back to manual setup...")
|
| 715 |
+
return self._manual_secret_setup()
|
| 716 |
+
|
| 717 |
+
except Exception as e:
|
| 718 |
+
logger.error(f"❌ Error setting space secrets: {e}")
|
| 719 |
+
return self._manual_secret_setup()
|
| 720 |
+
|
| 721 |
+
def _manual_secret_setup(self) -> bool:
|
| 722 |
+
"""Fallback method for manual secret setup"""
|
| 723 |
+
logger.info("📝 Manual Space Secrets Configuration:")
|
| 724 |
+
logger.info(f" HF_TOKEN=<hidden>")
|
| 725 |
+
logger.info(f" HF_MODEL_ID={self.model_id}")
|
| 726 |
+
if self.subfolder and self.subfolder.strip():
|
| 727 |
+
logger.info(f" MODEL_SUBFOLDER={self.subfolder}")
|
| 728 |
+
else:
|
| 729 |
+
logger.info(" MODEL_SUBFOLDER=(empty - using main model)")
|
| 730 |
+
|
| 731 |
+
# GPT-OSS specific variables
|
| 732 |
+
if self.demo_type == "gpt":
|
| 733 |
+
model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
|
| 734 |
+
logger.info(f" LORA_MODEL_ID={self.model_id}")
|
| 735 |
+
logger.info(f" BASE_MODEL_ID=openai/gpt-oss-20b")
|
| 736 |
+
logger.info(f" MODEL_NAME={model_name}")
|
| 737 |
+
if self.model_identity:
|
| 738 |
+
logger.info(f" MODEL_IDENTITY={self.model_identity}")
|
| 739 |
+
if self.system_message:
|
| 740 |
+
logger.info(f" SYSTEM_MESSAGE={self.system_message}")
|
| 741 |
+
if self.developer_message:
|
| 742 |
+
logger.info(f" DEVELOPER_MESSAGE={self.developer_message}")
|
| 743 |
+
# Branding variables
|
| 744 |
+
logger.info(f" HF_USERNAME={self.hf_username}")
|
| 745 |
+
logger.info(f" BRAND_OWNER_NAME={self.brand_owner_name}")
|
| 746 |
+
logger.info(f" BRAND_TEAM_NAME={self.brand_team_name}")
|
| 747 |
+
logger.info(f" BRAND_DISCORD_URL={self.brand_discord_url}")
|
| 748 |
+
logger.info(f" BRAND_HF_ORG={self.brand_hf_org}")
|
| 749 |
+
logger.info(f" BRAND_HF_LABEL={self.brand_hf_label}")
|
| 750 |
+
logger.info(f" BRAND_HF_URL={self.brand_hf_url}")
|
| 751 |
+
logger.info(f" BRAND_GH_ORG={self.brand_gh_org}")
|
| 752 |
+
logger.info(f" BRAND_GH_LABEL={self.brand_gh_label}")
|
| 753 |
+
logger.info(f" BRAND_GH_URL={self.brand_gh_url}")
|
| 754 |
+
logger.info(f" BRAND_PROJECT_NAME={self.brand_project_name}")
|
| 755 |
+
logger.info(f" BRAND_PROJECT_URL={self.brand_project_url}")
|
| 756 |
+
|
| 757 |
+
# Examples variables
|
| 758 |
+
if self.examples_type:
|
| 759 |
+
logger.info(f" EXAMPLES_TYPE={self.examples_type}")
|
| 760 |
+
if self.disable_examples is not None:
|
| 761 |
+
logger.info(f" DISABLE_EXAMPLES={'true' if self.disable_examples else 'false'}")
|
| 762 |
+
if self.examples_json:
|
| 763 |
+
logger.info(f" EXAMPLES_JSON={self.examples_json}")
|
| 764 |
+
|
| 765 |
+
logger.info(f"\n🔧 To set secrets in your Space:")
|
| 766 |
+
logger.info(f"1. Go to your Space settings: {self.space_url}/settings")
|
| 767 |
+
logger.info("2. Navigate to the 'Repository secrets' section")
|
| 768 |
+
logger.info("3. Add the following secrets:")
|
| 769 |
+
logger.info(f" Name: HF_TOKEN")
|
| 770 |
+
logger.info(f" Value: <your token>")
|
| 771 |
+
logger.info(f" Name: HF_MODEL_ID")
|
| 772 |
+
logger.info(f" Value: {self.model_id}")
|
| 773 |
+
if self.subfolder and self.subfolder.strip():
|
| 774 |
+
logger.info(f" Name: MODEL_SUBFOLDER")
|
| 775 |
+
logger.info(f" Value: {self.subfolder}")
|
| 776 |
+
else:
|
| 777 |
+
logger.info(" Name: MODEL_SUBFOLDER")
|
| 778 |
+
logger.info(" Value: (leave empty)")
|
| 779 |
+
|
| 780 |
+
# GPT-OSS specific variables
|
| 781 |
+
if self.demo_type == "gpt":
|
| 782 |
+
model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
|
| 783 |
+
logger.info(f" Name: LORA_MODEL_ID")
|
| 784 |
+
logger.info(f" Value: {self.model_id}")
|
| 785 |
+
logger.info(f" Name: BASE_MODEL_ID")
|
| 786 |
+
logger.info(f" Value: openai/gpt-oss-20b")
|
| 787 |
+
logger.info(f" Name: MODEL_NAME")
|
| 788 |
+
logger.info(f" Value: {model_name}")
|
| 789 |
+
|
| 790 |
+
logger.info("4. Save the secrets")
|
| 791 |
+
|
| 792 |
+
return True
|
| 793 |
+
|
| 794 |
+
def test_space(self) -> bool:
|
| 795 |
+
"""Test if the Space is working correctly"""
|
| 796 |
+
try:
|
| 797 |
+
logger.info("Testing Space...")
|
| 798 |
+
|
| 799 |
+
# Wait a bit for the space to build
|
| 800 |
+
logger.info("Waiting 180 seconds for Space to build...")
|
| 801 |
+
time.sleep(180)
|
| 802 |
+
|
| 803 |
+
# Try to access the space
|
| 804 |
+
response = requests.get(self.space_url, timeout=30)
|
| 805 |
+
|
| 806 |
+
if response.status_code == 200:
|
| 807 |
+
logger.info(f"✅ Space is accessible: {self.space_url}")
|
| 808 |
+
return True
|
| 809 |
+
else:
|
| 810 |
+
logger.warning(f"⚠️ Space returned status code: {response.status_code}")
|
| 811 |
+
logger.warning(f"Response: {response.text[:500]}...")
|
| 812 |
+
return False
|
| 813 |
+
|
| 814 |
+
except Exception as e:
|
| 815 |
+
logger.error(f"❌ Error testing space: {e}")
|
| 816 |
+
return False
|
| 817 |
+
|
| 818 |
+
def deploy(self) -> bool:
|
| 819 |
+
"""Main deployment method"""
|
| 820 |
+
logger.info(f"🚀 Starting demo space deployment for {self.model_id}")
|
| 821 |
+
|
| 822 |
+
# Step 1: Validate model exists
|
| 823 |
+
if not self.validate_model_exists():
|
| 824 |
+
return False
|
| 825 |
+
|
| 826 |
+
# Step 2: Create space repository
|
| 827 |
+
if not self.create_space_repository():
|
| 828 |
+
return False
|
| 829 |
+
|
| 830 |
+
# Step 3: Prepare files
|
| 831 |
+
temp_dir = self.prepare_space_files()
|
| 832 |
+
if not temp_dir:
|
| 833 |
+
return False
|
| 834 |
+
|
| 835 |
+
# Step 4: Upload files
|
| 836 |
+
if not self.upload_files_to_space(temp_dir):
|
| 837 |
+
return False
|
| 838 |
+
|
| 839 |
+
# Step 5: Set space secrets
|
| 840 |
+
if not self.set_space_secrets():
|
| 841 |
+
return False
|
| 842 |
+
|
| 843 |
+
# Step 6: Clean up temp directory
|
| 844 |
+
try:
|
| 845 |
+
shutil.rmtree(temp_dir)
|
| 846 |
+
logger.info("✅ Cleaned up temporary directory")
|
| 847 |
+
except Exception as e:
|
| 848 |
+
logger.warning(f"⚠️ Warning: Could not clean up temp directory: {e}")
|
| 849 |
+
|
| 850 |
+
# Step 7: Test space
|
| 851 |
+
if not self.test_space():
|
| 852 |
+
logger.warning("⚠️ Space created but may need more time to build")
|
| 853 |
+
logger.info("Please check the Space manually in a few minutes")
|
| 854 |
+
|
| 855 |
+
logger.info(f"🎉 Demo space deployment completed!")
|
| 856 |
+
logger.info(f"📊 Space URL: {self.space_url}")
|
| 857 |
+
logger.info(f"🔧 Space configuration: {self.space_url}/settings")
|
| 858 |
+
|
| 859 |
+
return True
|
| 860 |
+
|
| 861 |
+
def main():
|
| 862 |
+
"""Main function for command line usage"""
|
| 863 |
+
print("Demo Space Deployment Script")
|
| 864 |
+
print("=" * 40)
|
| 865 |
+
|
| 866 |
+
parser = argparse.ArgumentParser(description="Deploy demo space to Hugging Face Spaces")
|
| 867 |
+
parser.add_argument("--hf-token", required=True, help="Hugging Face token")
|
| 868 |
+
parser.add_argument(
|
| 869 |
+
"--space-secret-token",
|
| 870 |
+
required=False,
|
| 871 |
+
help="Token to store as Space secret HF_TOKEN (defaults to --hf-token). Use a READ token here for least privilege.",
|
| 872 |
+
)
|
| 873 |
+
parser.add_argument("--hf-username", required=True, help="Hugging Face username")
|
| 874 |
+
parser.add_argument("--model-id", required=True, help="Model ID to deploy demo for")
|
| 875 |
+
parser.add_argument("--subfolder", default="int4", help="Model subfolder (default: int4)")
|
| 876 |
+
parser.add_argument("--space-name", help="Custom space name (optional)")
|
| 877 |
+
parser.add_argument("--demo-type", choices=["smol", "gpt"], help="Demo type: 'smol' for SmolLM, 'gpt' for GPT-OSS (auto-detected if not specified)")
|
| 878 |
+
parser.add_argument("--config-file", help="Path to the training config file to import context (system/developer/model_identity)")
|
| 879 |
+
# Examples configuration
|
| 880 |
+
parser.add_argument("--examples-type", choices=["general", "medical"], help="Examples pack to enable in the demo UI")
|
| 881 |
+
parser.add_argument("--disable-examples", action="store_true", help="Disable rendering of example prompts in the UI")
|
| 882 |
+
parser.add_argument("--examples-json", help="Custom examples JSON (list[str]) to override built-in examples")
|
| 883 |
+
# Branding customization
|
| 884 |
+
parser.add_argument("--brand-owner-name", help="Owner name shown in the UI title (defaults to HF username)")
|
| 885 |
+
parser.add_argument("--brand-team-name", help="Team name shown in Join Us (defaults to Team<owner>)")
|
| 886 |
+
parser.add_argument("--brand-discord-url", help="Discord invite URL for Join Us section")
|
| 887 |
+
parser.add_argument("--brand-hf-org", help="Hugging Face org/username to link in Join Us")
|
| 888 |
+
parser.add_argument("--brand-hf-label", help="Label for the HF link (defaults to org)")
|
| 889 |
+
parser.add_argument("--brand-hf-url", help="Custom HF link URL (defaults to https://huggingface.co/<org>)")
|
| 890 |
+
parser.add_argument("--brand-gh-org", help="GitHub org/username to link in Join Us")
|
| 891 |
+
parser.add_argument("--brand-gh-label", help="Label for the GitHub link (defaults to org)")
|
| 892 |
+
parser.add_argument("--brand-gh-url", help="Custom GitHub link URL (defaults to https://github.com/<org>)")
|
| 893 |
+
parser.add_argument("--brand-project-name", help="Project name to link in Join Us")
|
| 894 |
+
parser.add_argument("--brand-project-url", help="Project URL to link in Join Us")
|
| 895 |
+
|
| 896 |
+
args = parser.parse_args()
|
| 897 |
+
|
| 898 |
+
deployer = DemoSpaceDeployer(
|
| 899 |
+
hf_token=args.hf_token,
|
| 900 |
+
space_secret_token=(args.space_secret_token or None),
|
| 901 |
+
hf_username=args.hf_username,
|
| 902 |
+
model_id=args.model_id,
|
| 903 |
+
subfolder=args.subfolder,
|
| 904 |
+
space_name=args.space_name,
|
| 905 |
+
demo_type=args.demo_type,
|
| 906 |
+
config_file=args.config_file,
|
| 907 |
+
examples_type=args.examples_type,
|
| 908 |
+
disable_examples=(True if getattr(args, 'disable_examples', False) else None),
|
| 909 |
+
examples_json=args.examples_json,
|
| 910 |
+
brand_owner_name=args.brand_owner_name,
|
| 911 |
+
brand_team_name=args.brand_team_name,
|
| 912 |
+
brand_discord_url=args.brand_discord_url,
|
| 913 |
+
brand_hf_org=args.brand_hf_org,
|
| 914 |
+
brand_hf_label=args.brand_hf_label,
|
| 915 |
+
brand_hf_url=args.brand_hf_url,
|
| 916 |
+
brand_gh_org=args.brand_gh_org,
|
| 917 |
+
brand_gh_label=args.brand_gh_label,
|
| 918 |
+
brand_gh_url=args.brand_gh_url,
|
| 919 |
+
brand_project_name=args.brand_project_name,
|
| 920 |
+
brand_project_url=args.brand_project_url,
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
success = deployer.deploy()
|
| 924 |
+
|
| 925 |
+
if success:
|
| 926 |
+
print("\n✅ Deployment successful!")
|
| 927 |
+
print(f"🌐 Your Demo Space: {deployer.space_url}")
|
| 928 |
+
print(f"👤 Username: {deployer.hf_username}")
|
| 929 |
+
print(f"🤖 Model: {deployer.model_id}")
|
| 930 |
+
print("\nNext steps:")
|
| 931 |
+
print("1. Wait for the Space to build (usually 2-5 minutes)")
|
| 932 |
+
print("2. Secrets have been automatically set via API")
|
| 933 |
+
print("3. Test the interface by visiting the Space URL")
|
| 934 |
+
print("4. Share your demo with others!")
|
| 935 |
+
print("\nIf the Space doesn't work immediately, check:")
|
| 936 |
+
print("- The Space logs at the Space URL")
|
| 937 |
+
print("- That all files were uploaded correctly")
|
| 938 |
+
print("- That the HF token has write permissions")
|
| 939 |
+
print("- That the secrets were set correctly in Space settings")
|
| 940 |
+
else:
|
| 941 |
+
print("\n❌ Deployment failed!")
|
| 942 |
+
print("Check the error messages above and try again.")
|
| 943 |
+
print("\nTroubleshooting:")
|
| 944 |
+
print("1. Verify your HF token has write permissions")
|
| 945 |
+
print("2. Check that the space name is available")
|
| 946 |
+
print("3. Verify the model exists and is accessible")
|
| 947 |
+
print("4. Try creating the space manually on HF first")
|
| 948 |
+
|
| 949 |
+
sys.exit(0 if success else 1)
|
| 950 |
+
|
| 951 |
+
if __name__ == "__main__":
|
| 952 |
+
main()
|
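The deployer above is normally driven through `main()` and argparse, but it can also be constructed directly. Below is a minimal programmatic sketch, not part of this commit: the token, username, and model id are placeholders, and it assumes the constructor's remaining keyword arguments are optional (in `main()` all of them are passed explicitly).

# Minimal sketch of driving DemoSpaceDeployer programmatically (placeholder values).
from deploy_demo_space import DemoSpaceDeployer  # assumes scripts/ is on sys.path

deployer = DemoSpaceDeployer(
    hf_token="hf_...",                    # placeholder WRITE token for repo creation/upload
    space_secret_token=None,              # defaults to hf_token; a READ token is safer as the Space secret
    hf_username="your-username",          # placeholder
    model_id="your-username/my-model",    # placeholder fine-tuned model repo
    subfolder="int4",
    demo_type="smol",
)
# deploy() runs the same pipeline as the CLI: validate -> create Space ->
# prepare files -> upload -> set secrets -> cleanup -> test.
if deployer.deploy():
    print(f"Demo live at {deployer.space_url}")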
scripts/generate_model_card.py
ADDED
@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
Generate unified model card from template
Handles template variables and conditional sections for quantized models
"""

import os
import re
import argparse
import logging
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime

logger = logging.getLogger(__name__)

class ModelCardGenerator:
    """Generate unified model cards from templates"""

    def __init__(self, template_path: str = "templates/model_card.md"):
        self.template_path = Path(template_path)
        if not self.template_path.exists():
            raise FileNotFoundError(f"Template not found: {self.template_path}")

    def load_template(self) -> str:
        """Load the model card template"""
        with open(self.template_path, 'r', encoding='utf-8') as f:
            return f.read()

    def process_conditionals(self, content: str, variables: Dict[str, Any]) -> str:
        """Process conditional sections in the template"""
        # Handle {{#if variable}}...{{/if}} blocks
        pattern = r'\{\{#if\s+(\w+)\}\}(.*?)\{\{/if\}\}'

        def replace_conditional(match):
            variable_name = match.group(1)
            conditional_content = match.group(2)

            # Check if variable exists and is truthy
            if variable_name in variables and variables[variable_name]:
                return conditional_content
            else:
                return ""

        return re.sub(pattern, replace_conditional, content, flags=re.DOTALL)

    def replace_variables(self, content: str, variables: Dict[str, Any]) -> str:
        """Replace template variables with actual values"""
        for key, value in variables.items():
            placeholder = f"{{{{{key}}}}}"
            content = content.replace(placeholder, str(value))

        return content

    def generate_model_card(self, variables: Dict[str, Any]) -> str:
        """Generate the complete model card"""
        # Load template
        content = self.load_template()

        # Process conditionals first
        content = self.process_conditionals(content, variables)

        # Replace variables
        content = self.replace_variables(content, variables)

        return content

    def save_model_card(self, content: str, output_path: str) -> bool:
        """Save the generated model card"""
        try:
            output_file = Path(output_path)
            output_file.parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(content)

            logger.info(f"✅ Model card saved to: {output_file}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to save model card: {e}")
            return False

def create_default_variables() -> Dict[str, Any]:
    """Create default variables for the model card"""
    return {
        "model_name": "SmolLM3 Fine-tuned Model",
        "model_description": "A fine-tuned version of SmolLM3-3B for improved text generation and conversation capabilities.",
        "repo_name": "your-username/model-name",
        "base_model": "HuggingFaceTB/SmolLM3-3B",
        "dataset_name": "OpenHermes-FR",
        "training_config_type": "Custom Configuration",
        "trainer_type": "SFTTrainer",
        "batch_size": "8",
        "gradient_accumulation_steps": "16",
        "learning_rate": "5e-6",
        "max_epochs": "3",
        "max_seq_length": "2048",
        "hardware_info": "GPU (H100/A100)",
        "experiment_name": "smollm3-experiment",
        "trackio_url": "https://trackio.space/experiment",
        "dataset_repo": "tonic/trackio-experiments",
        "dataset_size": "~80K samples",
        "dataset_format": "Chat format",
        "author_name": "Your Name",
        "model_name_slug": "smollm3-fine-tuned",
        "quantized_models": False,
        "dataset_sample_size": None,
        "training_loss": "N/A",
        "validation_loss": "N/A",
        "perplexity": "N/A"
    }

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Generate unified model card")
    parser.add_argument("--template", default="templates/model_card.md",
                        help="Path to model card template")
    parser.add_argument("--output", default="README.md",
                        help="Output path for generated model card")
    parser.add_argument("--repo-name", required=True,
                        help="Hugging Face repository name")
    parser.add_argument("--model-name", help="Model name")
    parser.add_argument("--experiment-name", help="Experiment name")
    parser.add_argument("--dataset-name", help="Dataset name")
    parser.add_argument("--training-config", help="Training configuration type")
    parser.add_argument("--trainer-type", help="Trainer type")
    parser.add_argument("--batch-size", help="Batch size")
    parser.add_argument("--learning-rate", help="Learning rate")
    parser.add_argument("--max-epochs", help="Maximum epochs")
    parser.add_argument("--max-seq-length", help="Maximum sequence length")
    parser.add_argument("--hardware-info", help="Hardware information")
    parser.add_argument("--trackio-url", help="Trackio URL")
    parser.add_argument("--dataset-repo", help="Dataset repository")
    parser.add_argument("--author-name", help="Author name")
    parser.add_argument("--quantized-models", action="store_true",
                        help="Include quantized models")
    parser.add_argument("--dataset-sample-size", help="Dataset sample size")
    parser.add_argument("--training-loss", help="Training loss value")
    parser.add_argument("--validation-loss", help="Validation loss value")
    parser.add_argument("--perplexity", help="Perplexity value")

    return parser.parse_args()

def main():
    """Main function"""
    args = parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    try:
        # Create generator
        generator = ModelCardGenerator(args.template)

        # Create variables dictionary
        variables = create_default_variables()

        # Override with command line arguments
        if args.repo_name:
            variables["repo_name"] = args.repo_name
        if args.model_name:
            variables["model_name"] = args.model_name
        if args.experiment_name:
            variables["experiment_name"] = args.experiment_name
        if args.dataset_name:
            variables["dataset_name"] = args.dataset_name
        if args.training_config:
            variables["training_config_type"] = args.training_config
        if args.trainer_type:
            variables["trainer_type"] = args.trainer_type
        if args.batch_size:
            variables["batch_size"] = args.batch_size
        if args.learning_rate:
            variables["learning_rate"] = args.learning_rate
        if args.max_epochs:
            variables["max_epochs"] = args.max_epochs
        if args.max_seq_length:
            variables["max_seq_length"] = args.max_seq_length
        if args.hardware_info:
            variables["hardware_info"] = args.hardware_info
        if args.trackio_url:
            variables["trackio_url"] = args.trackio_url
        if args.dataset_repo:
            variables["dataset_repo"] = args.dataset_repo
        if args.author_name:
            variables["author_name"] = args.author_name
        if args.quantized_models:
            variables["quantized_models"] = True
        if args.dataset_sample_size:
            variables["dataset_sample_size"] = args.dataset_sample_size
        if args.training_loss:
            variables["training_loss"] = args.training_loss
        if args.validation_loss:
            variables["validation_loss"] = args.validation_loss
        if args.perplexity:
            variables["perplexity"] = args.perplexity

        # Generate model card
        print("🔄 Generating model card...")
        content = generator.generate_model_card(variables)

        # Save model card
        if generator.save_model_card(content, args.output):
            print("✅ Model card generated successfully!")
            print(f"📄 Output: {args.output}")
        else:
            print("❌ Failed to generate model card")
            return 1

        return 0

    except Exception as e:
        logger.error(f"❌ Error generating model card: {e}")
        return 1

if __name__ == "__main__":
    exit(main())
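To make the template mini-language concrete: `replace_variables` swaps `{{name}}` placeholders for values, and `process_conditionals` keeps or drops `{{#if name}}...{{/if}}` blocks based on the variable's truthiness. A minimal sketch follows; the inline template is illustrative, not the shipped `templates/model_card.md`, and the import assumes `scripts/` is on `sys.path`.

# Minimal sketch of the generator's template mini-language (illustrative template).
import tempfile
from pathlib import Path

from generate_model_card import ModelCardGenerator  # assumes scripts/ is on sys.path

template_path = Path(tempfile.mkdtemp()) / "card.md"
template_path.write_text(
    "# {{model_name}}\n"
    "Base model: {{base_model}}\n"
    "{{#if quantized_models}}Quantized variants are available.\n{{/if}}",
    encoding="utf-8",
)

generator = ModelCardGenerator(str(template_path))
card = generator.generate_model_card({
    "model_name": "SmolLM3 Fine-tuned",
    "base_model": "HuggingFaceTB/SmolLM3-3B",
    "quantized_models": False,  # falsy, so the {{#if}} block is dropped
})
print(card)  # -> "# SmolLM3 Fine-tuned\nBase model: HuggingFaceTB/SmolLM3-3B\n"

Note that conditionals are processed before variable substitution, so a `{{name}}` inside a dropped `{{#if}}` block never needs a value.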
scripts/push_to_huggingface.py
ADDED
@@ -0,0 +1,700 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Push Trained Model and Results to Hugging Face Hub
|
| 4 |
+
Integrates with Trackio monitoring and HF Datasets for complete model deployment
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import argparse
|
| 10 |
+
import logging
|
| 11 |
+
import time
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, Any, Optional, List
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
import subprocess
|
| 16 |
+
import shutil
|
| 17 |
+
import platform
|
| 18 |
+
|
| 19 |
+
# Set timeout for HF operations to prevent hanging
|
| 20 |
+
os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '300'
|
| 21 |
+
os.environ['HF_HUB_UPLOAD_TIMEOUT'] = '600'
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from huggingface_hub import HfApi, create_repo, upload_file
|
| 25 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
| 26 |
+
HF_AVAILABLE = True
|
| 27 |
+
except ImportError:
|
| 28 |
+
HF_AVAILABLE = False
|
| 29 |
+
print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
import sys
|
| 33 |
+
import os
|
| 34 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
|
| 35 |
+
from monitoring import SmolLM3Monitor
|
| 36 |
+
MONITORING_AVAILABLE = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
MONITORING_AVAILABLE = False
|
| 39 |
+
print("Warning: monitoring module not available")
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
class TimeoutError(Exception):
|
| 44 |
+
"""Custom timeout exception"""
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
def timeout_handler(signum, frame):
|
| 48 |
+
"""Signal handler for timeout"""
|
| 49 |
+
raise TimeoutError("Operation timed out")
|
| 50 |
+
|
| 51 |
+
class HuggingFacePusher:
|
| 52 |
+
"""Push trained models and results to Hugging Face Hub with HF Datasets integration"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
model_path: str,
|
| 57 |
+
repo_name: str,
|
| 58 |
+
token: Optional[str] = None,
|
| 59 |
+
private: bool = False,
|
| 60 |
+
trackio_url: Optional[str] = None,
|
| 61 |
+
experiment_name: Optional[str] = None,
|
| 62 |
+
dataset_repo: Optional[str] = None,
|
| 63 |
+
hf_token: Optional[str] = None,
|
| 64 |
+
author_name: Optional[str] = None,
|
| 65 |
+
model_description: Optional[str] = None,
|
| 66 |
+
training_config_type: Optional[str] = None,
|
| 67 |
+
model_name: Optional[str] = None,
|
| 68 |
+
dataset_name: Optional[str] = None,
|
| 69 |
+
batch_size: Optional[str] = None,
|
| 70 |
+
learning_rate: Optional[str] = None,
|
| 71 |
+
max_epochs: Optional[str] = None,
|
| 72 |
+
max_seq_length: Optional[str] = None,
|
| 73 |
+
trainer_type: Optional[str] = None
|
| 74 |
+
):
|
| 75 |
+
self.model_path = Path(model_path)
|
| 76 |
+
# Original user input (may be just the repo name without username)
|
| 77 |
+
self.repo_name = repo_name
|
| 78 |
+
self.token = token or hf_token or os.getenv('HF_TOKEN')
|
| 79 |
+
self.private = private
|
| 80 |
+
self.trackio_url = trackio_url
|
| 81 |
+
self.experiment_name = experiment_name
|
| 82 |
+
self.author_name = author_name
|
| 83 |
+
self.model_description = model_description
|
| 84 |
+
|
| 85 |
+
# Training configuration details for model card generation
|
| 86 |
+
self.training_config_type = training_config_type
|
| 87 |
+
self.model_name = model_name
|
| 88 |
+
self.dataset_name = dataset_name
|
| 89 |
+
self.batch_size = batch_size
|
| 90 |
+
self.learning_rate = learning_rate
|
| 91 |
+
self.max_epochs = max_epochs
|
| 92 |
+
self.max_seq_length = max_seq_length
|
| 93 |
+
self.trainer_type = trainer_type
|
| 94 |
+
|
| 95 |
+
# HF Datasets configuration
|
| 96 |
+
self.dataset_repo = dataset_repo or os.getenv('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
|
| 97 |
+
self.hf_token = hf_token or os.getenv('HF_TOKEN')
|
| 98 |
+
|
| 99 |
+
# Initialize HF API
|
| 100 |
+
if HF_AVAILABLE:
|
| 101 |
+
self.api = HfApi(token=self.token)
|
| 102 |
+
else:
|
| 103 |
+
raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
|
| 104 |
+
|
| 105 |
+
# Resolve the full repo id (username/repo) if user only provided repo name
|
| 106 |
+
self.repo_id = self._resolve_repo_id(self.repo_name)
|
| 107 |
+
|
| 108 |
+
# Initialize monitoring if available
|
| 109 |
+
self.monitor = None
|
| 110 |
+
if MONITORING_AVAILABLE:
|
| 111 |
+
self.monitor = SmolLM3Monitor(
|
| 112 |
+
experiment_name=experiment_name or "model_push",
|
| 113 |
+
trackio_url=trackio_url,
|
| 114 |
+
enable_tracking=bool(trackio_url),
|
| 115 |
+
hf_token=self.hf_token,
|
| 116 |
+
dataset_repo=self.dataset_repo
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
logger.info(f"Initialized HuggingFacePusher for {self.repo_id}")
|
| 120 |
+
logger.info(f"Dataset repository: {self.dataset_repo}")
|
| 121 |
+
|
| 122 |
+
def _resolve_repo_id(self, repo_name: str) -> str:
|
| 123 |
+
"""Return a fully-qualified repo id in the form username/repo.
|
| 124 |
+
|
| 125 |
+
If the provided name already contains a '/', it is returned unchanged.
|
| 126 |
+
Otherwise, we attempt to derive the username from the authenticated token
|
| 127 |
+
or from the HF_USERNAME environment variable.
|
| 128 |
+
"""
|
| 129 |
+
try:
|
| 130 |
+
if "/" in repo_name:
|
| 131 |
+
return repo_name
|
| 132 |
+
|
| 133 |
+
# Need a username. Prefer API whoami(), fallback to env HF_USERNAME
|
| 134 |
+
username: Optional[str] = None
|
| 135 |
+
if self.token:
|
| 136 |
+
try:
|
| 137 |
+
user_info = self.api.whoami()
|
| 138 |
+
username = user_info.get("name") or user_info.get("username")
|
| 139 |
+
except Exception:
|
| 140 |
+
username = None
|
| 141 |
+
|
| 142 |
+
if not username:
|
| 143 |
+
username = os.getenv("HF_USERNAME")
|
| 144 |
+
|
| 145 |
+
if not username:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
"Username could not be determined. Provide a token or set HF_USERNAME, "
|
| 148 |
+
"or pass a fully-qualified repo id 'username/repo'."
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
return f"{username}/{repo_name}"
|
| 152 |
+
except Exception as resolve_error:
|
| 153 |
+
logger.error(f"Failed to resolve full repo id for '{repo_name}': {resolve_error}")
|
| 154 |
+
# Fall back to provided value (may fail later at create/upload)
|
| 155 |
+
return repo_name
|
| 156 |
+
|
| 157 |
+
def create_repository(self) -> bool:
|
| 158 |
+
"""Create the Hugging Face repository"""
|
| 159 |
+
try:
|
| 160 |
+
logger.info(f"Creating repository: {self.repo_id}")
|
| 161 |
+
|
| 162 |
+
# Create repository with timeout handling
|
| 163 |
+
try:
|
| 164 |
+
# Create repository
|
| 165 |
+
create_repo(
|
| 166 |
+
repo_id=self.repo_id,
|
| 167 |
+
token=self.token,
|
| 168 |
+
private=self.private,
|
| 169 |
+
exist_ok=True
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
logger.info(f"✅ Repository created: https://huggingface.co/{self.repo_id}")
|
| 173 |
+
return True
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
logger.error(f"❌ Repository creation failed: {e}")
|
| 177 |
+
return False
|
| 178 |
+
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"❌ Failed to create repository: {e}")
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
def validate_model_path(self) -> bool:
|
| 184 |
+
"""Validate that the model path contains required files"""
|
| 185 |
+
# Support both safetensors and pytorch formats
|
| 186 |
+
required_files = [
|
| 187 |
+
"config.json",
|
| 188 |
+
"tokenizer.json",
|
| 189 |
+
"tokenizer_config.json"
|
| 190 |
+
]
|
| 191 |
+
|
| 192 |
+
# Check for model files (either safetensors or pytorch)
|
| 193 |
+
model_files = [
|
| 194 |
+
"model.safetensors.index.json", # Safetensors format
|
| 195 |
+
"pytorch_model.bin" # PyTorch format
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
missing_files = []
|
| 199 |
+
for file in required_files:
|
| 200 |
+
if not (self.model_path / file).exists():
|
| 201 |
+
missing_files.append(file)
|
| 202 |
+
|
| 203 |
+
# Check if at least one model file exists
|
| 204 |
+
model_file_exists = any((self.model_path / file).exists() for file in model_files)
|
| 205 |
+
if not model_file_exists:
|
| 206 |
+
missing_files.extend(model_files)
|
| 207 |
+
|
| 208 |
+
if missing_files:
|
| 209 |
+
logger.error(f"❌ Missing required files: {missing_files}")
|
| 210 |
+
return False
|
| 211 |
+
|
| 212 |
+
logger.info("✅ Model files validated")
|
| 213 |
+
return True
|
| 214 |
+
|
| 215 |
+
def create_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
|
| 216 |
+
"""Create a comprehensive model card using the generate_model_card.py script"""
|
| 217 |
+
try:
|
| 218 |
+
# Import the model card generator
|
| 219 |
+
import sys
|
| 220 |
+
sys.path.append(os.path.join(os.path.dirname(__file__)))
|
| 221 |
+
from generate_model_card import ModelCardGenerator, create_default_variables
|
| 222 |
+
|
| 223 |
+
# Create generator
|
| 224 |
+
generator = ModelCardGenerator()
|
| 225 |
+
|
| 226 |
+
# Create variables for the model card
|
| 227 |
+
variables = create_default_variables()
|
| 228 |
+
|
| 229 |
+
# Update with actual values
|
| 230 |
+
variables.update({
|
| 231 |
+
"repo_name": self.repo_id,
|
| 232 |
+
"model_name": self.repo_id.split('/')[-1],
|
| 233 |
+
"experiment_name": self.experiment_name or "model_push",
|
| 234 |
+
"dataset_repo": self.dataset_repo,
|
| 235 |
+
"author_name": self.author_name or "Model Author",
|
| 236 |
+
"model_description": self.model_description or "A fine-tuned version of SmolLM3-3B for improved text generation capabilities.",
|
| 237 |
+
"training_config_type": self.training_config_type or "Custom Configuration",
|
| 238 |
+
"base_model": self.model_name or "HuggingFaceTB/SmolLM3-3B",
|
| 239 |
+
"dataset_name": self.dataset_name or "Custom Dataset",
|
| 240 |
+
"trainer_type": self.trainer_type or "SFTTrainer",
|
| 241 |
+
"batch_size": str(self.batch_size) if self.batch_size else "8",
|
| 242 |
+
"learning_rate": str(self.learning_rate) if self.learning_rate else "5e-6",
|
| 243 |
+
"max_epochs": str(self.max_epochs) if self.max_epochs else "3",
|
| 244 |
+
"max_seq_length": str(self.max_seq_length) if self.max_seq_length else "2048",
|
| 245 |
+
"hardware_info": self._get_hardware_info(),
|
| 246 |
+
"trackio_url": self.trackio_url or "N/A",
|
| 247 |
+
"training_loss": str(results.get('train_loss', 'N/A')),
|
| 248 |
+
"validation_loss": str(results.get('eval_loss', 'N/A')),
|
| 249 |
+
"perplexity": str(results.get('perplexity', 'N/A')),
|
| 250 |
+
"quantized_models": False # Set to True if quantized models are available
|
| 251 |
+
})
|
| 252 |
+
|
| 253 |
+
# Generate the model card
|
| 254 |
+
model_card_content = generator.generate_model_card(variables)
|
| 255 |
+
|
| 256 |
+
logger.info("✅ Model card generated using generate_model_card.py")
|
| 257 |
+
return model_card_content
|
| 258 |
+
|
| 259 |
+
except Exception as e:
|
| 260 |
+
logger.error(f"❌ Failed to generate model card with generator: {e}")
|
| 261 |
+
logger.info("🔄 Falling back to simple model card")
|
| 262 |
+
return self._create_simple_model_card(training_config, results)
|
| 263 |
+
|
| 264 |
+
def _create_simple_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
|
| 265 |
+
"""Create a simple model card without complex YAML to avoid formatting issues"""
|
| 266 |
+
return f"""---
|
| 267 |
+
language:
|
| 268 |
+
- en
|
| 269 |
+
- fr
|
| 270 |
+
license: apache-2.0
|
| 271 |
+
tags:
|
| 272 |
+
- smollm3
|
| 273 |
+
- fine-tuned
|
| 274 |
+
- causal-lm
|
| 275 |
+
- text-generation
|
| 276 |
+
pipeline_tag: text-generation
|
| 277 |
+
base_model: HuggingFaceTB/SmolLM3-3B
|
| 278 |
+
---
|
| 279 |
+
|
| 280 |
+
# {self.repo_id.split('/')[-1]}
|
| 281 |
+
|
| 282 |
+
This is a fine-tuned SmolLM3 model based on the HuggingFaceTB/SmolLM3-3B architecture.
|
| 283 |
+
|
| 284 |
+
## Model Details
|
| 285 |
+
|
| 286 |
+
- **Base Model**: HuggingFaceTB/SmolLM3-3B
|
| 287 |
+
- **Fine-tuning Method**: Supervised Fine-tuning
|
| 288 |
+
- **Training Date**: {datetime.now().strftime('%Y-%m-%d')}
|
| 289 |
+
- **Model Size**: {self._get_model_size():.1f} GB
|
| 290 |
+
- **Dataset Repository**: {self.dataset_repo}
|
| 291 |
+
- **Hardware**: {self._get_hardware_info()}
|
| 292 |
+
|
| 293 |
+
## Training Configuration
|
| 294 |
+
|
| 295 |
+
```json
|
| 296 |
+
{json.dumps(training_config, indent=2)}
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
## Training Results
|
| 300 |
+
|
| 301 |
+
```json
|
| 302 |
+
{json.dumps(results, indent=2)}
|
| 303 |
+
```
|
| 304 |
+
|
| 305 |
+
## Usage
|
| 306 |
+
|
| 307 |
+
```python
|
| 308 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 309 |
+
|
| 310 |
+
# Load model and tokenizer
|
| 311 |
+
model = AutoModelForCausalLM.from_pretrained("{self.repo_id}")
|
| 312 |
+
tokenizer = AutoTokenizer.from_pretrained("{self.repo_id}")
|
| 313 |
+
|
| 314 |
+
# Generate text
|
| 315 |
+
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
|
| 316 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
| 317 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
## Training Information
|
| 321 |
+
|
| 322 |
+
- **Base Model**: HuggingFaceTB/SmolLM3-3B
|
| 323 |
+
- **Hardware**: {self._get_hardware_info()}
|
| 324 |
+
- **Training Time**: {results.get('training_time_hours', 'Unknown')} hours
|
| 325 |
+
- **Final Loss**: {results.get('final_loss', 'Unknown')}
|
| 326 |
+
- **Final Accuracy**: {results.get('final_accuracy', 'Unknown')}
|
| 327 |
+
- **Dataset Repository**: {self.dataset_repo}
|
| 328 |
+
|
| 329 |
+
## Model Performance
|
| 330 |
+
|
| 331 |
+
- **Training Loss**: {results.get('train_loss', 'Unknown')}
|
| 332 |
+
- **Validation Loss**: {results.get('eval_loss', 'Unknown')}
|
| 333 |
+
- **Training Steps**: {results.get('total_steps', 'Unknown')}
|
| 334 |
+
|
| 335 |
+
## Experiment Tracking
|
| 336 |
+
|
| 337 |
+
This model was trained with experiment tracking enabled. Training metrics and configuration are stored in the HF Dataset repository: `{self.dataset_repo}`
|
| 338 |
+
|
| 339 |
+
## Limitations and Biases
|
| 340 |
+
|
| 341 |
+
This model is fine-tuned for specific tasks and may not generalize well to all use cases. Please evaluate the model's performance on your specific task before deployment.
|
| 342 |
+
|
| 343 |
+
## License
|
| 344 |
+
|
| 345 |
+
This model is licensed under the Apache 2.0 License.
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def _get_model_size(self) -> float:
|
| 349 |
+
"""Get model size in GB"""
|
| 350 |
+
try:
|
| 351 |
+
total_size = 0
|
| 352 |
+
for file in self.model_path.rglob("*"):
|
| 353 |
+
if file.is_file():
|
| 354 |
+
total_size += file.stat().st_size
|
| 355 |
+
return total_size / (1024**3) # Convert to GB
|
| 356 |
+
except:
|
| 357 |
+
return 0.0
|
| 358 |
+
|
| 359 |
+
def _get_hardware_info(self) -> str:
|
| 360 |
+
"""Get hardware information"""
|
| 361 |
+
try:
|
| 362 |
+
import torch
|
| 363 |
+
if torch.cuda.is_available():
|
| 364 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 365 |
+
return f"GPU: {gpu_name}"
|
| 366 |
+
else:
|
| 367 |
+
return "CPU"
|
| 368 |
+
except:
|
| 369 |
+
return "Unknown"
|
| 370 |
+
|
| 371 |
+
def upload_model_files(self) -> bool:
|
| 372 |
+
"""Upload model files to Hugging Face Hub with timeout protection"""
|
| 373 |
+
try:
|
| 374 |
+
logger.info("Uploading model files...")
|
| 375 |
+
|
| 376 |
+
# Upload all files in the model directory
|
| 377 |
+
for file_path in self.model_path.rglob("*"):
|
| 378 |
+
if file_path.is_file():
|
| 379 |
+
relative_path = file_path.relative_to(self.model_path)
|
| 380 |
+
remote_path = str(relative_path)
|
| 381 |
+
|
| 382 |
+
logger.info(f"Uploading {relative_path}")
|
| 383 |
+
|
| 384 |
+
try:
|
| 385 |
+
upload_file(
|
| 386 |
+
path_or_fileobj=str(file_path),
|
| 387 |
+
path_in_repo=remote_path,
|
| 388 |
+
repo_id=self.repo_id,
|
| 389 |
+
token=self.token
|
| 390 |
+
)
|
| 391 |
+
logger.info(f"✅ Uploaded {relative_path}")
|
| 392 |
+
|
| 393 |
+
except Exception as e:
|
| 394 |
+
logger.error(f"❌ Failed to upload {relative_path}: {e}")
|
| 395 |
+
return False
|
| 396 |
+
|
| 397 |
+
logger.info("✅ Model files uploaded successfully")
|
| 398 |
+
return True
|
| 399 |
+
|
| 400 |
+
except Exception as e:
|
| 401 |
+
logger.error(f"❌ Failed to upload model files: {e}")
|
| 402 |
+
return False
|
| 403 |
+
|
| 404 |
+
def upload_training_results(self, results_path: str) -> bool:
|
| 405 |
+
"""Upload training results and logs"""
|
| 406 |
+
try:
|
| 407 |
+
logger.info("Uploading training results...")
|
| 408 |
+
|
| 409 |
+
results_files = [
|
| 410 |
+
"train_results.json",
|
| 411 |
+
"eval_results.json",
|
| 412 |
+
"training_config.json",
|
| 413 |
+
"training.log"
|
| 414 |
+
]
|
| 415 |
+
|
| 416 |
+
for file_name in results_files:
|
| 417 |
+
file_path = Path(results_path) / file_name
|
| 418 |
+
if file_path.exists():
|
| 419 |
+
logger.info(f"Uploading {file_name}")
|
| 420 |
+
upload_file(
|
| 421 |
+
path_or_fileobj=str(file_path),
|
| 422 |
+
path_in_repo=f"training_results/{file_name}",
|
| 423 |
+
repo_id=self.repo_id,
|
| 424 |
+
token=self.token
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
logger.info("✅ Training results uploaded successfully")
|
| 428 |
+
return True
|
| 429 |
+
|
| 430 |
+
except Exception as e:
|
| 431 |
+
logger.error(f"❌ Failed to upload training results: {e}")
|
| 432 |
+
return False
|
| 433 |
+
|
| 434 |
+
def create_readme(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> bool:
|
| 435 |
+
"""Create and upload README.md"""
|
| 436 |
+
try:
|
| 437 |
+
logger.info("Creating README.md...")
|
| 438 |
+
|
| 439 |
+
readme_content = f"""# {self.repo_id.split('/')[-1]}
|
| 440 |
+
|
| 441 |
+
A fine-tuned SmolLM3 model for text generation tasks.
|
| 442 |
+
|
| 443 |
+
## Quick Start
|
| 444 |
+
|
| 445 |
+
```python
|
| 446 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 447 |
+
|
| 448 |
+
model = AutoModelForCausalLM.from_pretrained("{self.repo_id}")
|
| 449 |
+
tokenizer = AutoTokenizer.from_pretrained("{self.repo_id}")
|
| 450 |
+
|
| 451 |
+
# Generate text
|
| 452 |
+
text = "Hello, how are you?"
|
| 453 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 454 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
| 455 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 456 |
+
```
|
| 457 |
+
|
| 458 |
+
## Model Information
|
| 459 |
+
|
| 460 |
+
- **Base Model**: HuggingFaceTB/SmolLM3-3B
|
| 461 |
+
- **Fine-tuning Date**: {datetime.now().strftime('%Y-%m-%d')}
|
| 462 |
+
- **Model Size**: {self._get_model_size():.1f} GB
|
| 463 |
+
- **Training Steps**: {results.get('total_steps', 'Unknown')}
|
| 464 |
+
- **Final Loss**: {results.get('final_loss', 'Unknown')}
|
| 465 |
+
- **Dataset Repository**: {self.dataset_repo}
|
| 466 |
+
|
| 467 |
+
## Training Configuration
|
| 468 |
+
|
| 469 |
+
```json
|
| 470 |
+
{json.dumps(training_config, indent=2)}
|
| 471 |
+
```
|
| 472 |
+
|
| 473 |
+
## Performance Metrics
|
| 474 |
+
|
| 475 |
+
```json
|
| 476 |
+
{json.dumps(results, indent=2)}
|
| 477 |
+
```
|
| 478 |
+
|
| 479 |
+
## Experiment Tracking
|
| 480 |
+
|
| 481 |
+
Training metrics and configuration are stored in the HF Dataset repository: `{self.dataset_repo}`
|
| 482 |
+
|
| 483 |
+
## Files
|
| 484 |
+
|
| 485 |
+
- `model.safetensors.index.json`: Model weights (safetensors format)
|
| 486 |
+
- `config.json`: Model configuration
|
| 487 |
+
- `tokenizer.json`: Tokenizer configuration
|
| 488 |
+
- `training_results/`: Training logs and results
|
| 489 |
+
|
| 490 |
+
## License
|
| 491 |
+
|
| 492 |
+
MIT License
|
| 493 |
+
"""
|
| 494 |
+
|
| 495 |
+
# Write README to temporary file
|
| 496 |
+
readme_path = Path("temp_readme.md")
|
| 497 |
+
with open(readme_path, "w") as f:
|
| 498 |
+
f.write(readme_content)
|
| 499 |
+
|
| 500 |
+
# Upload README
|
| 501 |
+
upload_file(
|
| 502 |
+
path_or_fileobj=str(readme_path),
|
| 503 |
+
path_in_repo="README.md",
|
| 504 |
+
token=self.token,
|
| 505 |
+
repo_id=self.repo_id
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# Clean up
|
| 509 |
+
readme_path.unlink()
|
| 510 |
+
|
| 511 |
+
logger.info("✅ README.md uploaded successfully")
|
| 512 |
+
return True
|
| 513 |
+
|
| 514 |
+
except Exception as e:
|
| 515 |
+
logger.error(f"❌ Failed to create README: {e}")
|
| 516 |
+
return False
|
| 517 |
+
|
| 518 |
+
def log_to_trackio(self, action: str, details: Dict[str, Any]):
|
| 519 |
+
"""Log push action to Trackio and HF Datasets"""
|
| 520 |
+
if self.monitor:
|
| 521 |
+
try:
|
| 522 |
+
# Log to Trackio
|
| 523 |
+
self.monitor.log_metrics({
|
| 524 |
+
"push_action": action,
|
| 525 |
+
"repo_name": self.repo_id,
|
| 526 |
+
"model_size_gb": self._get_model_size(),
|
| 527 |
+
"dataset_repo": self.dataset_repo,
|
| 528 |
+
**details
|
| 529 |
+
})
|
| 530 |
+
|
| 531 |
+
# Log training summary
|
| 532 |
+
self.monitor.log_training_summary({
|
| 533 |
+
"model_push": True,
|
| 534 |
+
"model_repo": self.repo_id,
|
| 535 |
+
"dataset_repo": self.dataset_repo,
|
| 536 |
+
"push_date": datetime.now().isoformat(),
|
| 537 |
+
**details
|
| 538 |
+
})
|
| 539 |
+
|
| 540 |
+
logger.info(f"✅ Logged {action} to Trackio and HF Datasets")
|
| 541 |
+
except Exception as e:
|
| 542 |
+
logger.error(f"❌ Failed to log to Trackio: {e}")
|
| 543 |
+
|
| 544 |
+
def push_model(self, training_config: Optional[Dict[str, Any]] = None,
|
| 545 |
+
results: Optional[Dict[str, Any]] = None) -> bool:
|
| 546 |
+
"""Complete model push process with HF Datasets integration"""
|
| 547 |
+
logger.info(f"🚀 Starting model push to {self.repo_id}")
|
| 548 |
+
        logger.info(f"📊 Dataset repository: {self.dataset_repo}")

        # Validate model path
        if not self.validate_model_path():
            return False

        # Create repository
        if not self.create_repository():
            return False

        # Load training config and results if not provided
        if training_config is None:
            training_config = self._load_training_config()

        if results is None:
            results = self._load_training_results()

        # Create and upload model card
        model_card = self.create_model_card(training_config, results)
        model_card_path = Path("temp_model_card.md")
        with open(model_card_path, "w") as f:
            f.write(model_card)

        try:
            upload_file(
                path_or_fileobj=str(model_card_path),
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token
            )
        finally:
            model_card_path.unlink()

        # Upload model files
        if not self.upload_model_files():
            return False

        # Upload training results
        if results:
            self.upload_training_results(str(self.model_path))

        # Log to Trackio and HF Datasets
        self.log_to_trackio("model_push", {
            "model_path": str(self.model_path),
            "repo_name": self.repo_name,
            "private": self.private,
            "training_config": training_config,
            "results": results
        })

        logger.info(f"🎉 Model successfully pushed to: https://huggingface.co/{self.repo_id}")
        logger.info(f"📊 Experiment data stored in: {self.dataset_repo}")
        return True

    def _load_training_config(self) -> Dict[str, Any]:
        """Load training configuration"""
        config_path = self.model_path / "training_config.json"
        if config_path.exists():
            with open(config_path, "r") as f:
                return json.load(f)
        return {"model_name": "HuggingFaceTB/SmolLM3-3B"}

    def _load_training_results(self) -> Dict[str, Any]:
        """Load training results"""
        results_path = self.model_path / "train_results.json"
        if results_path.exists():
            with open(results_path, "r") as f:
                return json.load(f)
        return {"final_loss": "Unknown", "total_steps": "Unknown"}


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Push trained model to Hugging Face Hub')

    # Required arguments
    parser.add_argument('model_path', type=str, help='Path to trained model directory')
    parser.add_argument('repo_name', type=str, help='Hugging Face repository name (repo-name). Username will be auto-detected from your token.')

    # Optional arguments
    parser.add_argument('--token', type=str, default=None, help='Hugging Face token')
    parser.add_argument('--hf-token', type=str, default=None, help='Hugging Face token (alternative to --token)')
    parser.add_argument('--private', action='store_true', help='Make repository private')
    parser.add_argument('--trackio-url', type=str, default=None, help='Trackio Space URL for logging')
    parser.add_argument('--experiment-name', type=str, default=None, help='Experiment name for Trackio')
    parser.add_argument('--dataset-repo', type=str, default=None, help='HF Dataset repository for experiment storage')
    parser.add_argument('--author-name', type=str, default=None, help='Author name for model card')
    parser.add_argument('--model-description', type=str, default=None, help='Model description for model card')
    parser.add_argument('--training-config-type', type=str, default=None, help='Training configuration type')
    parser.add_argument('--model-name', type=str, default=None, help='Base model name')
    parser.add_argument('--dataset-name', type=str, default=None, help='Dataset name')
    parser.add_argument('--batch-size', type=str, default=None, help='Batch size')
    parser.add_argument('--learning-rate', type=str, default=None, help='Learning rate')
    parser.add_argument('--max-epochs', type=str, default=None, help='Maximum epochs')
    parser.add_argument('--max-seq-length', type=str, default=None, help='Maximum sequence length')
    parser.add_argument('--trainer-type', type=str, default=None, help='Trainer type')

    return parser.parse_args()


def main():
    """Main function"""
    args = parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    logger.info("Starting model push to Hugging Face Hub")

    # Initialize pusher
    try:
        pusher = HuggingFacePusher(
            model_path=args.model_path,
            repo_name=args.repo_name,
            token=args.token,
            private=args.private,
            trackio_url=args.trackio_url,
            experiment_name=args.experiment_name,
            dataset_repo=args.dataset_repo,
            hf_token=args.hf_token,
            author_name=args.author_name,
            model_description=args.model_description,
            training_config_type=args.training_config_type,
            model_name=args.model_name,
            dataset_name=args.dataset_name,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_epochs=args.max_epochs,
            max_seq_length=args.max_seq_length,
            trainer_type=args.trainer_type
        )

        # Push model
        success = pusher.push_model()

        if success:
            logger.info("✅ Model push completed successfully!")
            logger.info(f"🌐 View your model at: https://huggingface.co/{args.repo_name}")
            if args.dataset_repo:
                logger.info(f"📊 View experiment data at: https://huggingface.co/datasets/{args.dataset_repo}")
        else:
            logger.error("❌ Model push failed!")
            return 1

    except Exception as e:
        logger.error(f"❌ Error during model push: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())
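The pusher can also be driven from Python instead of the CLI. A minimal sketch, assuming the repo root is on `PYTHONPATH` so the module is importable, and that a token is available via `HF_TOKEN`; the paths and repo name below are illustrative:

```python
# Sketch of programmatic use; mirrors what main() wires up from argparse.
from scripts.push_to_huggingface import HuggingFacePusher

pusher = HuggingFacePusher(
    model_path="./voxtral-finetuned",  # illustrative: directory saved by the training script
    repo_name="my-voxtral-asr",        # username is auto-detected from the token
    private=True,
)
success = pusher.push_model()
print("pushed" if success else "push failed")
```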
train_lora.py → scripts/train.py
RENAMED

@@ -1,14 +1,16 @@
 #!/usr/bin/env python3
 
+import argparse
+import json
+from pathlib import Path
 import torch
-from datasets import load_dataset, Audio
+from datasets import load_dataset, Audio, Dataset
 from transformers import (
     VoxtralForConditionalGeneration,
     VoxtralProcessor,
     Trainer,
     TrainingArguments,
 )
-from peft import LoraConfig, get_peft_model
 
 
 class VoxtralDataCollator:

@@ -95,82 +97,114 @@ class VoxtralDataCollator:
 
     return batch
 
+def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
+    """Load local JSONL with fields {audio_path, text} into a Dataset with audio column."""
+    records = []
+    jsonl_file = Path(jsonl_path)
+    if not jsonl_file.exists():
+        raise FileNotFoundError(f"Dataset jsonl not found: {jsonl_path}")
+    with open(jsonl_file, "r", encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue
+            obj = json.loads(line)
+            audio_path = obj.get("audio_path") or obj.get("audio")
+            text = obj.get("text")
+            if not audio_path or text is None:
+                continue
+            records.append({"audio": audio_path, "text": text})
+    if not records:
+        raise ValueError("No valid records found in JSONL. Expect keys: audio_path, text")
+    ds = Dataset.from_list(records)
+    # Cast the audio column from file paths and resample to 16kHz
+    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+    return ds
+
+
+def load_and_prepare_dataset(dataset_jsonl: str | None, dataset_name: str | None, dataset_config: str | None,
+                             train_count: int, eval_count: int):
+    """Load and prepare dataset for training.
+
+    Priority: local JSONL > HF dataset name/config > fallback tiny sample.
+    """
+    if dataset_jsonl:
+        print(f"Loading local JSONL dataset: {dataset_jsonl}")
+        ds = _load_jsonl_dataset(dataset_jsonl)
+    else:
+        ds_name = dataset_name or "hf-audio/esb-datasets-test-only-sorted"
+        ds_cfg = dataset_config or "voxpopuli"
+        print(f"Loading dataset: {ds_name}/{ds_cfg}")
+        ds = load_dataset(ds_name, ds_cfg, split="test")
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+    total = len(ds)
+    train_end = min(train_count, total)
+    eval_end = min(train_end + eval_count, total)
+    train_dataset = ds.select(range(train_end))
+    eval_dataset = ds.select(range(train_end, eval_end)) if eval_end > train_end else None
     return train_dataset, eval_dataset
 
 
 def main():
+    parser = argparse.ArgumentParser(description="Full fine-tune Voxtral for ASR")
+    parser.add_argument("--dataset-jsonl", type=str, default=None, help="Path to local JSONL with {audio_path, text}")
+    parser.add_argument("--dataset-name", type=str, default=None, help="HF dataset repo (if not using JSONL)")
+    parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
+    parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
+    parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
+    parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
+    parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
+    parser.add_argument("--batch-size", type=int, default=2)
+    parser.add_argument("--eval-batch-size", type=int, default=4)
+    parser.add_argument("--grad-accum", type=int, default=4)
+    parser.add_argument("--learning-rate", type=float, default=5e-5)
+    parser.add_argument("--epochs", type=float, default=3)
+    parser.add_argument("--logging-steps", type=int, default=10)
+    parser.add_argument("--save-steps", type=int, default=50)
+    args = parser.parse_args()
+
+    model_checkpoint = args.model_checkpoint
+    output_dir = args.output_dir
+
     torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {torch_device}")
-
-    # Load processor and model
+
     print("Loading processor and model...")
     processor = VoxtralProcessor.from_pretrained(model_checkpoint)
-    # Load model with LoRA configuration
-    config = LoraConfig(
-        r=8,  # Rank of LoRA
-        lora_alpha=32,
-        lora_dropout=0.0,
-        bias="none",
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-        task_type="SEQ_2_SEQ_LM",
-    )
-    # print number of parameters in model
     model = VoxtralForConditionalGeneration.from_pretrained(
         model_checkpoint,
         torch_dtype=torch.bfloat16,
         device_map="auto"
     )
+
+    train_dataset, eval_dataset = load_and_prepare_dataset(
+        dataset_jsonl=args.dataset_jsonl,
+        dataset_name=args.dataset_name,
+        dataset_config=args.dataset_config,
+        train_count=args.train_count,
+        eval_count=args.eval_count,
+    )
+
-    # Setup data collator
     data_collator = VoxtralDataCollator(processor, model_checkpoint)
-
-    # Simple training arguments
+
     training_args = TrainingArguments(
         output_dir=output_dir,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.eval_batch_size,
+        gradient_accumulation_steps=args.grad_accum,
+        learning_rate=args.learning_rate,
+        num_train_epochs=args.epochs,
         bf16=True,
+        logging_steps=args.logging_steps,
+        eval_steps=args.save_steps if eval_dataset else None,
+        save_steps=args.save_steps,
         eval_strategy="steps" if eval_dataset else "no",
         save_strategy="steps",
         report_to="none",
         remove_unused_columns=False,
         dataloader_num_workers=1,
     )
-
-    # Setup trainer
+
     trainer = Trainer(
         model=model,
         args=training_args,

@@ -178,22 +212,18 @@ def main():
         eval_dataset=eval_dataset,
         data_collator=data_collator,
     )
-
-    # Start training
+
     print("Starting training...")
     trainer.train()
 
-
-    # Save model and processor
+
     print(f"Saving model to {output_dir}")
     trainer.save_model()
     processor.save_pretrained(output_dir)
-
-    # Final evaluation
+
     if eval_dataset:
         results = trainer.evaluate()
         print(f"Final evaluation results: {results}")
-
+
     print("Training completed successfully!")
 
 if __name__ == "__main__":
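The `--dataset-jsonl` option expects one JSON object per line with the keys `_load_jsonl_dataset` looks for (`audio_path`, or `audio`, plus `text`). A short sketch of producing such a file; the file names below are illustrative:

```python
# Write a train.jsonl that scripts/train.py accepts via --dataset-jsonl.
import json

rows = [
    {"audio_path": "recordings/clip_0001.wav", "text": "hello voxtral"},
    {"audio_path": "recordings/clip_0002.wav", "text": "testing one two three"},
]
with open("train.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
```

Training would then be invoked as, e.g., `python scripts/train.py --dataset-jsonl train.jsonl --epochs 1`.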
train.py → scripts/train_lora.py
RENAMED

@@ -1,13 +1,17 @@
 #!/usr/bin/env python3
 
+import argparse
+import json
+from pathlib import Path
 import torch
-from datasets import load_dataset, Audio
+from datasets import load_dataset, Audio, Dataset
 from transformers import (
     VoxtralForConditionalGeneration,
     VoxtralProcessor,
     Trainer,
     TrainingArguments,
 )
+from peft import LoraConfig, get_peft_model
 
 
 class VoxtralDataCollator:

@@ -94,68 +98,128 @@ class VoxtralDataCollator:
 
     return batch
 
+def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
+    """Load local JSONL with fields {audio_path, text} into a Dataset with audio column."""
+    records = []
+    jsonl_file = Path(jsonl_path)
+    if not jsonl_file.exists():
+        raise FileNotFoundError(f"Dataset jsonl not found: {jsonl_path}")
+    with open(jsonl_file, "r", encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue
+            obj = json.loads(line)
+            audio_path = obj.get("audio_path") or obj.get("audio")
+            text = obj.get("text")
+            if not audio_path or text is None:
+                continue
+            records.append({"audio": audio_path, "text": text})
+    if not records:
+        raise ValueError("No valid records found in JSONL. Expect keys: audio_path, text")
+    ds = Dataset.from_list(records)
+    # Cast the audio column from file paths and resample to 16kHz
+    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+    return ds
+
+
+def load_and_prepare_dataset(dataset_jsonl: str | None, dataset_name: str | None, dataset_config: str | None,
+                             train_count: int, eval_count: int):
+    """Load and prepare dataset for training (JSONL or HF hub)."""
+    if dataset_jsonl:
+        print(f"Loading local JSONL dataset: {dataset_jsonl}")
+        ds = _load_jsonl_dataset(dataset_jsonl)
+    else:
+        ds_name = dataset_name or "hf-audio/esb-datasets-test-only-sorted"
+        ds_cfg = dataset_config or "voxpopuli"
+        print(f"Loading dataset: {ds_name}/{ds_cfg}")
+        ds = load_dataset(ds_name, ds_cfg, split="test")
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+    total = len(ds)
+    train_end = min(train_count, total)
+    eval_end = min(train_end + eval_count, total)
+    train_dataset = ds.select(range(train_end))
+    eval_dataset = ds.select(range(train_end, eval_end)) if eval_end > train_end else None
     return train_dataset, eval_dataset
 
 
 def main():
+    parser = argparse.ArgumentParser(description="LoRA fine-tune Voxtral for ASR")
+    parser.add_argument("--dataset-jsonl", type=str, default=None, help="Path to local JSONL with {audio_path, text}")
+    parser.add_argument("--dataset-name", type=str, default=None, help="HF dataset repo (if not using JSONL)")
+    parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
+    parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
+    parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
+    parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
+    parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
+    parser.add_argument("--batch-size", type=int, default=2)
+    parser.add_argument("--eval-batch-size", type=int, default=4)
+    parser.add_argument("--grad-accum", type=int, default=4)
+    parser.add_argument("--learning-rate", type=float, default=5e-5)
+    parser.add_argument("--epochs", type=float, default=3)
+    parser.add_argument("--logging-steps", type=int, default=10)
+    parser.add_argument("--save-steps", type=int, default=50)
+    parser.add_argument("--lora-r", type=int, default=8)
+    parser.add_argument("--lora-alpha", type=int, default=32)
+    parser.add_argument("--lora-dropout", type=float, default=0.0)
+    parser.add_argument("--freeze-audio-tower", action="store_true", help="Freeze audio encoder parameters")
+    args = parser.parse_args()
+
+    model_checkpoint = args.model_checkpoint
+    output_dir = args.output_dir
+
     torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {torch_device}")
-
-    # Load processor and model
+
     print("Loading processor and model...")
     processor = VoxtralProcessor.from_pretrained(model_checkpoint)
+    lora_cfg = LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        bias="none",
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        task_type="SEQ_2_SEQ_LM",
+    )
     model = VoxtralForConditionalGeneration.from_pretrained(
         model_checkpoint,
         torch_dtype=torch.bfloat16,
         device_map="auto"
     )
+    if args.freeze_audio_tower:
+        for param in model.audio_tower.parameters():
+            param.requires_grad = False
+    model = get_peft_model(model, lora_cfg)
+    model.print_trainable_parameters()
+
+    train_dataset, eval_dataset = load_and_prepare_dataset(
+        dataset_jsonl=args.dataset_jsonl,
+        dataset_name=args.dataset_name,
+        dataset_config=args.dataset_config,
+        train_count=args.train_count,
+        eval_count=args.eval_count,
+    )
+
-    # Setup data collator
     data_collator = VoxtralDataCollator(processor, model_checkpoint)
-
-    # Simple training arguments
+
     training_args = TrainingArguments(
         output_dir=output_dir,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.eval_batch_size,
+        gradient_accumulation_steps=args.grad_accum,
+        learning_rate=args.learning_rate,
+        num_train_epochs=args.epochs,
         bf16=True,
+        logging_steps=args.logging_steps,
+        eval_steps=args.save_steps if eval_dataset else None,
+        save_steps=args.save_steps,
         eval_strategy="steps" if eval_dataset else "no",
         save_strategy="steps",
         report_to="none",
         remove_unused_columns=False,
         dataloader_num_workers=1,
     )
-
-    # Setup trainer
+
     trainer = Trainer(
         model=model,
         args=training_args,

@@ -163,22 +227,18 @@ def main():
         eval_dataset=eval_dataset,
         data_collator=data_collator,
     )
-
-    # Start training
+
     print("Starting training...")
     trainer.train()
 
-
-    # Save model and processor
+
     print(f"Saving model to {output_dir}")
     trainer.save_model()
     processor.save_pretrained(output_dir)
-
-    # Final evaluation
+
     if eval_dataset:
         results = trainer.evaluate()
         print(f"Final evaluation results: {results}")
-
+
     print("Training completed successfully!")
 
 if __name__ == "__main__":
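Because `trainer.save_model()` on a PEFT-wrapped model writes adapter weights rather than a full checkpoint, inference re-attaches the adapter to the base model. A minimal sketch, assuming the output directory contains the PEFT adapter files saved above; paths are illustrative:

```python
# Load the LoRA adapter produced by scripts/train_lora.py for inference.
import torch
from peft import PeftModel
from transformers import VoxtralForConditionalGeneration, VoxtralProcessor

base = VoxtralForConditionalGeneration.from_pretrained(
    "mistralai/Voxtral-Mini-3B-2507", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "./voxtral-finetuned")   # adapter directory
processor = VoxtralProcessor.from_pretrained("./voxtral-finetuned")  # saved alongside the adapter
```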
templates/datasets/readme.md
ADDED

---
dataset_info:
  features:
  - name: experiment_id
    dtype: string
  - name: name
    dtype: string
  - name: description
    dtype: string
  - name: created_at
    dtype: string
  - name: status
    dtype: string
  - name: metrics
    dtype: string
  - name: parameters
    dtype: string
  - name: artifacts
    dtype: string
  - name: logs
    dtype: string
  - name: last_updated
    dtype: string
  splits:
  - name: train
    num_bytes: 4945
    num_examples: 2
  download_size: 15529
  dataset_size: 4945
configs:
- config_name: default
  data_files:
  - split: train
    path: data/train-*
tags:
- track tonic
- tonic
- experiment tracking
- smollm3
- fine-tuning
- legml
- hermes
---

# Trackio Experiments Dataset

This dataset stores experiment tracking data for ML training runs, particularly focused on SmolLM3 fine-tuning experiments with comprehensive metrics tracking.

## Dataset Structure

The dataset contains the following columns:

- **experiment_id**: Unique identifier for each experiment
- **name**: Human-readable name for the experiment
- **description**: Detailed description of the experiment
- **created_at**: Timestamp when the experiment was created
- **status**: Current status (running, completed, failed, paused)
- **metrics**: JSON string containing training metrics over time
- **parameters**: JSON string containing experiment configuration
- **artifacts**: JSON string containing experiment artifacts
- **logs**: JSON string containing experiment logs
- **last_updated**: Timestamp of last update

## Metrics Structure

The metrics field contains JSON arrays with the following structure:

```json
[
  {
    "timestamp": "2025-07-20T11:20:01.780908",
    "step": 25,
    "metrics": {
      "loss": 1.1659,
      "accuracy": 0.759,
      "learning_rate": 7e-08,
      "grad_norm": 10.3125,
      "epoch": 0.004851130919895701,

      // Advanced Training Metrics
      "total_tokens": 1642080.0,
      "truncated_tokens": 128,
      "padding_tokens": 256,
      "throughput": 3284160.0,
      "step_time": 0.5,
      "batch_size": 8,
      "seq_len": 2048,
      "token_acc": 0.759,

      // Custom Losses
      "train/gate_ortho": 0.0234,
      "train/center": 0.0156,

      // System Metrics
      "gpu_memory_allocated": 17.202261447906494,
      "gpu_memory_reserved": 75.474609375,
      "gpu_utilization": 85.2,
      "cpu_percent": 2.7,
      "memory_percent": 10.1
    }
  }
]
```

## Supported Metrics

### Core Training Metrics
- **loss**: Training loss value
- **accuracy**: Model accuracy
- **learning_rate**: Current learning rate
- **grad_norm**: Gradient norm
- **epoch**: Current epoch progress

### Advanced Token Metrics
- **total_tokens**: Total tokens processed in the batch
- **truncated_tokens**: Number of tokens truncated during processing
- **padding_tokens**: Number of padding tokens added
- **throughput**: Tokens processed per second
- **step_time**: Time taken for the current training step
- **batch_size**: Current batch size
- **seq_len**: Sequence length
- **token_acc**: Token-level accuracy

### Custom Losses (SmolLM3-specific)
- **train/gate_ortho**: Gate orthogonality loss
- **train/center**: Center loss component

### System Performance Metrics
- **gpu_memory_allocated**: GPU memory currently allocated (GB)
- **gpu_memory_reserved**: GPU memory reserved (GB)
- **gpu_utilization**: GPU utilization percentage
- **cpu_percent**: CPU usage percentage
- **memory_percent**: System memory usage percentage

## Usage

This dataset is automatically used by the Trackio monitoring system to store and retrieve experiment data. It provides persistent storage for experiment tracking across different training runs.

## Integration

The dataset is used by:
- Trackio Spaces for experiment visualization
- Training scripts for logging metrics and parameters
- Monitoring systems for experiment tracking
- SmolLM3 fine-tuning pipeline for comprehensive metrics capture

## Privacy

This dataset is private by default to ensure experiment data security. Only users with appropriate permissions can access the data.

## Examples

### Sample Experiment Entry
```json
{
  "experiment_id": "exp_20250720_130853",
  "name": "smollm3_finetune",
  "description": "SmolLM3 fine-tuning experiment with comprehensive metrics",
  "created_at": "2025-07-20T11:20:01.780908",
  "status": "running",
  "metrics": "[{\"timestamp\": \"2025-07-20T11:20:01.780908\", \"step\": 25, \"metrics\": {\"loss\": 1.1659, \"accuracy\": 0.759, \"total_tokens\": 1642080.0, \"throughput\": 3284160.0, \"train/gate_ortho\": 0.0234, \"train/center\": 0.0156}}]",
  "parameters": "{\"model_name\": \"HuggingFaceTB/SmolLM3-3B\", \"batch_size\": 8, \"learning_rate\": 3.5e-06, \"max_seq_length\": 12288}",
  "artifacts": "[]",
  "logs": "[]",
  "last_updated": "2025-07-20T11:20:01.780908"
}
```

## License

This dataset is part of the Trackio experiment tracking system and follows the same license as the main project.
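Since metrics and parameters are stored as JSON strings, reading them back requires one `json.loads` per field. A short sketch, assuming access to the (private) experiments repo; the repo id below is illustrative:

```python
# Read the latest logged loss per experiment from the tracking dataset.
import json
from datasets import load_dataset

ds = load_dataset("username/trackio-experiments", split="train")  # hypothetical repo id
for row in ds:
    history = json.loads(row["metrics"])          # list of {timestamp, step, metrics}
    last = history[-1]["metrics"] if history else {}
    print(row["experiment_id"], row["status"], last.get("loss"))
```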
templates/model_card.md
ADDED

---
language:
- en
- fr
license: apache-2.0
library_name: transformers
tags:
- smollm3
- fine-tuned
- causal-lm
- text-generation
- tonic
- legml
{{#if quantized_models}}- quantized{{/if}}
pipeline_tag: text-generation
base_model: {{base_model}}
{{#if dataset_name}}
datasets:
- {{dataset_name}}
{{/if}}
{{#if quantized_models}}
model-index:
- name: {{model_name}}
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Training Loss
      type: loss
      value: "{{training_loss|default:'N/A'}}"
    - name: Validation Loss
      type: loss
      value: "{{validation_loss|default:'N/A'}}"
    - name: Perplexity
      type: perplexity
      value: "{{perplexity|default:'N/A'}}"
- name: {{model_name}} (int8 quantized)
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Memory Reduction
      type: memory_efficiency
      value: "~50%"
    - name: Inference Speed
      type: speed
      value: "Faster"
- name: {{model_name}} (int4 quantized)
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Memory Reduction
      type: memory_efficiency
      value: "~75%"
    - name: Inference Speed
      type: speed
      value: "Significantly Faster"
{{else}}
model-index:
- name: {{model_name}}
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Training Loss
      type: loss
      value: "{{training_loss|default:'N/A'}}"
    - name: Validation Loss
      type: loss
      value: "{{validation_loss|default:'N/A'}}"
    - name: Perplexity
      type: perplexity
      value: "{{perplexity|default:'N/A'}}"
{{/if}}
{{#if author_name}}
author: {{author_name}}
{{/if}}
{{#if experiment_name}}
experiment_name: {{experiment_name}}
{{/if}}
{{#if trackio_url}}
trackio_url: {{trackio_url}}
{{/if}}
{{#if dataset_repo}}
dataset_repo: {{dataset_repo}}
{{/if}}
{{#if hardware_info}}
hardware: "{{hardware_info}}"
{{/if}}
{{#if training_config_type}}
training_config: {{training_config_type}}
{{/if}}
{{#if trainer_type}}
trainer_type: {{trainer_type}}
{{/if}}
{{#if batch_size}}
batch_size: {{batch_size}}
{{/if}}
{{#if learning_rate}}
learning_rate: {{learning_rate}}
{{/if}}
{{#if max_epochs}}
max_epochs: {{max_epochs}}
{{/if}}
{{#if max_seq_length}}
max_seq_length: {{max_seq_length}}
{{/if}}
{{#if dataset_sample_size}}
dataset_sample_size: {{dataset_sample_size}}
{{/if}}
{{#if dataset_size}}
dataset_size: {{dataset_size}}
{{/if}}
{{#if dataset_format}}
dataset_format: {{dataset_format}}
{{/if}}
{{#if gradient_accumulation_steps}}
gradient_accumulation_steps: {{gradient_accumulation_steps}}
{{/if}}
---

# {{model_name}}

{{model_description}}

## Model Details

- **Base Model**: SmolLM3-3B
- **Model Type**: Causal Language Model
- **Languages**: English, French
- **License**: Apache 2.0
- **Fine-tuned**: Yes
{{#if quantized_models}}
- **Quantized Versions**: Available in subdirectories
{{/if}}

## Usage

### Main Model

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the main model
model = AutoModelForCausalLM.from_pretrained(
    "{{repo_name}}",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")

# Generate text
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
output = model.generate(**input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Training Information

### Training Configuration
- **Base Model**: {{base_model}}
- **Dataset**: {{dataset_name}}
- **Training Config**: {{training_config_type}}
- **Trainer Type**: {{trainer_type}}
{{#if dataset_sample_size}}
- **Dataset Sample Size**: {{dataset_sample_size}}
{{/if}}

### Training Parameters
- **Batch Size**: {{batch_size}}
- **Gradient Accumulation**: {{gradient_accumulation_steps}}
- **Learning Rate**: {{learning_rate}}
- **Max Epochs**: {{max_epochs}}
- **Sequence Length**: {{max_seq_length}}

### Training Infrastructure
- **Hardware**: {{hardware_info}}
- **Monitoring**: Trackio integration
- **Experiment**: {{experiment_name}}

## Model Architecture

This is a fine-tuned version of the SmolLM3-3B model with the following specifications:

- **Base Model**: SmolLM3-3B
- **Parameters**: ~3B
- **Context Length**: {{max_seq_length}}
- **Languages**: English, French
- **Architecture**: Transformer-based causal language model

## Performance

The model provides:
- **Text Generation**: High-quality text generation capabilities
- **Conversation**: Natural conversation abilities
- **Multilingual**: Support for English and French
{{#if quantized_models}}
- **Quantized Versions**: Optimized for different deployment scenarios
{{/if}}

## Limitations

1. **Context Length**: Limited by the model's maximum sequence length
2. **Bias**: May inherit biases from the training data
3. **Factual Accuracy**: May generate incorrect or outdated information
4. **Safety**: Should be used responsibly with appropriate safeguards
{{#if quantized_models}}
5. **Quantization**: Quantized versions may have slightly reduced accuracy
{{/if}}

## Training Data

The model was fine-tuned on:
- **Dataset**: {{dataset_name}}
- **Size**: {{dataset_size}}
- **Format**: {{dataset_format}}
- **Languages**: English, French

## Evaluation

The model was evaluated using:
- **Metrics**: Loss, perplexity, and qualitative assessment
- **Monitoring**: Real-time tracking via Trackio
- **Validation**: Regular validation during training

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{{{model_name_slug}},
  title={{{{model_name}}}},
  author={{{author_name}}},
  year={2024},
  url={https://huggingface.co/{{repo_name}}}
}
```

## License

This model is licensed under the Apache 2.0 License.

## Acknowledgments

- **Base Model**: SmolLM3-3B by HuggingFaceTB
- **Training Framework**: PyTorch, Transformers, PEFT
- **Monitoring**: Trackio integration
- **Quantization**: torchao library

## Support

For questions and support:
- Open an issue on the Hugging Face repository
- Check the model documentation
- Review the training logs and configuration

## Repository Structure

```
{{repo_name}}/
├── README.md (this file)
├── config.json
├── pytorch_model.bin
├── tokenizer.json
└── tokenizer_config.json
```

## Usage Examples

### Text Generation
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{{repo_name}}")
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")

text = "The future of artificial intelligence is"
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### Conversation
```python
def chat_with_model(prompt, max_length=100):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

response = chat_with_model("Hello, how are you today?")
print(response)
```

### Advanced Usage
```python
# With generation parameters
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)
```

## Monitoring and Tracking

This model was trained with comprehensive monitoring:
- **Trackio Space**: {{trackio_url}}
- **Experiment**: {{experiment_name}}
- **Dataset Repository**: https://huggingface.co/datasets/{{dataset_repo}}
- **Training Logs**: Available in the experiment data

## Deployment

### Requirements
```bash
pip install torch transformers accelerate
{{#if quantized_models}}
pip install torchao  # For quantized models
{{/if}}
```

### Hardware Requirements
- **Main Model**: GPU with 8GB+ VRAM recommended

## Changelog

- **v1.0.0**: Initial release with fine-tuned model
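The template mixes plain `{{key}}` placeholders with `{{#if}}` conditionals; the repo's generator script handles the conditionals, but the plain keys can be filled with a simple substitution. A minimal sketch (the values are illustrative, and `{{#if}}` blocks are left untouched here):

```python
# Fill only the flat {{key}} placeholders in templates/model_card.md.
import re

def render(template: str, values: dict) -> str:
    # Leave unknown keys (and {{#if ...}} markers) as-is.
    return re.sub(r"\{\{(\w+)\}\}",
                  lambda m: str(values.get(m.group(1), m.group(0))),
                  template)

with open("templates/model_card.md", encoding="utf-8") as f:
    card = render(f.read(), {
        "model_name": "my-voxtral-asr",                  # hypothetical values
        "repo_name": "username/my-voxtral-asr",
        "base_model": "mistralai/Voxtral-Mini-3B-2507",
    })
```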
templates/spaces/demo_voxtral/README.md
ADDED

---
title: Voxtral ASR Demo
emoji: 🎙️
colorFrom: indigo
colorTo: cyan
sdk: gradio
sdk_version: 5.42.0
app_file: app.py
pinned: false
short_description: Interactive ASR demo for a fine-tuned Voxtral model
---

This Space serves a Voxtral ASR model for speech-to-text transcription.

Usage:

- Click Record and read the displayed phrase aloud.
- Stop recording to see the transcription.
- Works best with ~16 kHz audio; internal processing follows Voxtral's processor expectations.

Environment variables expected:

- `HF_MODEL_ID`: The model repo to load (e.g., `username/voxtral-finetune-YYYYMMDD_HHMMSS`)
- `MODEL_NAME`: Display name
- `HF_USERNAME`: For branding
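For local testing, the same variables can be set in-process before the app is imported; the values below are illustrative:

```python
# Exercise the Space's configuration locally (app.py reads these at import time).
import os

os.environ["HF_MODEL_ID"] = "username/voxtral-finetune-20250101_120000"  # hypothetical repo
os.environ["MODEL_NAME"] = "Voxtral ASR (fine-tuned)"
os.environ["HF_USERNAME"] = "username"

import app  # launches nothing by itself; run app.py directly to start the demo
```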
templates/spaces/demo_voxtral/app.py
ADDED

import os
import gradio as gr
import torch
# Voxtral checkpoints load with VoxtralForConditionalGeneration (as in the
# training scripts); the generic AutoModelForSeq2SeqLM mapping does not cover Voxtral.
from transformers import AutoProcessor, VoxtralForConditionalGeneration

HF_MODEL_ID = os.getenv("HF_MODEL_ID", "mistralai/Voxtral-Mini-3B-2507")
MODEL_NAME = os.getenv("MODEL_NAME", HF_MODEL_ID.split("/")[-1])
HF_USERNAME = os.getenv("HF_USERNAME", "")

processor = AutoProcessor.from_pretrained(HF_MODEL_ID)
model = VoxtralForConditionalGeneration.from_pretrained(HF_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16)

def transcribe(audio_tuple):
    if audio_tuple is None:
        return "No audio provided"
    sr, data = audio_tuple
    inputs = processor.apply_transcription_request(language="en", model_id=HF_MODEL_ID, audio=[data], format=["WAV"], return_tensors="pt")
    inputs = {k: (v.to(model.device) if hasattr(v, 'to') else v) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=256)
    # Voxtral returns the full sequence; decode and strip special tokens
    text = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return text

with gr.Blocks() as demo:
    gr.Markdown(f"# 🎙️ Voxtral ASR Demo — {MODEL_NAME}")
    audio = gr.Audio(sources="microphone", type="numpy", label="Record or upload audio")
    btn = gr.Button("Transcribe")
    out = gr.Textbox(label="Transcription", lines=4)
    btn.click(transcribe, inputs=[audio], outputs=[out])

if __name__ == "__main__":
    demo.launch(mcp_server=True)
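Once the Space is running, it can be called remotely. A sketch using `gradio_client`; the Space id is illustrative and the endpoint name may differ (check the Space's "Use via API" panel):

```python
# Call the deployed demo Space from Python.
from gradio_client import Client, handle_file

client = Client("username/voxtral-asr-demo")          # hypothetical Space id
result = client.predict(handle_file("sample.wav"),    # local audio file to transcribe
                        api_name="/transcribe")       # assumed endpoint name
print(result)
```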
templates/spaces/demo_voxtral/requirements.txt
ADDED

gradio>=5.38.2
torch
transformers
datasets
soundfile
librosa