Joseph Pollack
committed on
adds interface and dataset and auto push and demo
- interface.py +444 -0
- scripts/deploy_demo_space.py +952 -0
- scripts/generate_model_card.py +221 -0
- scripts/push_to_huggingface.py +700 -0
- train_lora.py → scripts/train.py +92 -62
- train.py → scripts/train_lora.py +107 -47
- templates/datasets/readme.md +171 -0
- templates/model_card.md +345 -0
- templates/spaces/demo_voxtral/README.md +23 -0
- templates/spaces/demo_voxtral/app.py +35 -0
- templates/spaces/demo_voxtral/requirements.txt +7 -0
interface.py
ADDED
@@ -0,0 +1,444 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Voxtral ASR Fine-tuning Interface
|
4 |
+
|
5 |
+
Features:
|
6 |
+
- Collect a personal voice dataset (upload WAV/FLAC + transcripts or record mic audio)
|
7 |
+
- Build a JSONL dataset ({audio_path, text}) at 16kHz
|
8 |
+
- Fine-tune Voxtral (LoRA or full) with streamed logs
|
9 |
+
- Push model to Hugging Face Hub
|
10 |
+
- Deploy a Voxtral ASR demo Space
|
11 |
+
|
12 |
+
Env tokens (optional):
|
13 |
+
- HF_WRITE_TOKEN or HF_TOKEN: write access token
|
14 |
+
- HF_READ_TOKEN: optional read token
|
15 |
+
- HF_USERNAME: fallback username if not derivable from token
|
16 |
+
"""
|
17 |
+
|
18 |
+
from __future__ import annotations
|
19 |
+
|
20 |
+
import os
|
21 |
+
import json
|
22 |
+
from pathlib import Path
|
23 |
+
from datetime import datetime
|
24 |
+
from typing import Any, Dict, Generator, Optional, Tuple
|
25 |
+
|
26 |
+
import gradio as gr
|
27 |
+
|
28 |
+
PROJECT_ROOT = Path(__file__).resolve().parent
|
29 |
+
|
30 |
+
|
31 |
+
def get_python() -> str:
|
32 |
+
import sys
|
33 |
+
return sys.executable or "python"
|
34 |
+
|
35 |
+
|
36 |
+
def get_username_from_token(token: str) -> Optional[str]:
|
37 |
+
try:
|
38 |
+
from huggingface_hub import HfApi # type: ignore
|
39 |
+
api = HfApi(token=token)
|
40 |
+
info = api.whoami()
|
41 |
+
if isinstance(info, dict):
|
42 |
+
return info.get("name") or info.get("username")
|
43 |
+
if isinstance(info, str):
|
44 |
+
return info
|
45 |
+
except Exception:
|
46 |
+
return None
|
47 |
+
return None
|
48 |
+
|
49 |
+
|
50 |
+
def run_command_stream(args: list[str], env: Dict[str, str], cwd: Optional[Path] = None) -> Generator[str, None, int]:
|
51 |
+
import subprocess
|
52 |
+
import shlex
|
53 |
+
yield f"$ {' '.join(shlex.quote(a) for a in ([get_python()] + args))}"
|
54 |
+
process = subprocess.Popen(
|
55 |
+
[get_python()] + args,
|
56 |
+
stdout=subprocess.PIPE,
|
57 |
+
stderr=subprocess.STDOUT,
|
58 |
+
text=True,
|
59 |
+
env=env,
|
60 |
+
cwd=str(cwd or PROJECT_ROOT),
|
61 |
+
bufsize=1,
|
62 |
+
universal_newlines=True,
|
63 |
+
)
|
64 |
+
assert process.stdout is not None
|
65 |
+
for line in iter(process.stdout.readline, ""):
|
66 |
+
yield line.rstrip()
|
67 |
+
process.stdout.close()
|
68 |
+
code = process.wait()
|
69 |
+
yield f"[exit_code={code}]"
|
70 |
+
return code
|
71 |
+
|
72 |
+
|
73 |
+
def detect_nvidia_driver() -> Tuple[bool, str]:
|
74 |
+
"""Detect NVIDIA driver/GPU presence with multiple strategies.
|
75 |
+
|
76 |
+
Returns (available, human_message).
|
77 |
+
"""
|
78 |
+
# 1) Try torch CUDA
|
79 |
+
try:
|
80 |
+
import torch # type: ignore
|
81 |
+
if torch.cuda.is_available():
|
82 |
+
try:
|
83 |
+
num = torch.cuda.device_count()
|
84 |
+
names = [torch.cuda.get_device_name(i) for i in range(num)]
|
85 |
+
return True, f"NVIDIA GPU detected: {', '.join(names)}"
|
86 |
+
except Exception:
|
87 |
+
return True, "NVIDIA GPU detected (torch.cuda available)"
|
88 |
+
except Exception:
|
89 |
+
pass
|
90 |
+
|
91 |
+
# 2) Try NVML via pynvml
|
92 |
+
try:
|
93 |
+
import pynvml # type: ignore
|
94 |
+
try:
|
95 |
+
pynvml.nvmlInit()
|
96 |
+
cnt = pynvml.nvmlDeviceGetCount()
|
97 |
+
names = []
|
98 |
+
for i in range(cnt):
|
99 |
+
h = pynvml.nvmlDeviceGetHandleByIndex(i)
|
100 |
+
names.append(pynvml.nvmlDeviceGetName(h).decode("utf-8", errors="ignore"))
|
101 |
+
drv = pynvml.nvmlSystemGetDriverVersion().decode("utf-8", errors="ignore")
|
102 |
+
pynvml.nvmlShutdown()
|
103 |
+
if cnt > 0:
|
104 |
+
return True, f"NVIDIA driver {drv}; GPUs: {', '.join(names)}"
|
105 |
+
except Exception:
|
106 |
+
pass
|
107 |
+
except Exception:
|
108 |
+
pass
|
109 |
+
|
110 |
+
# 3) Try nvidia-smi
|
111 |
+
try:
|
112 |
+
import subprocess
|
113 |
+
res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=3)
|
114 |
+
if res.returncode == 0 and res.stdout.strip():
|
115 |
+
return True, res.stdout.strip().splitlines()[0]
|
116 |
+
except Exception:
|
117 |
+
pass
|
118 |
+
|
119 |
+
return False, "No NVIDIA driver/GPU detected"
|
120 |
+
|
121 |
+
|
122 |
+
def duplicate_space_hint() -> str:
|
123 |
+
space_id = os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")
|
124 |
+
if space_id:
|
125 |
+
space_url = f"https://huggingface.co/spaces/{space_id}"
|
126 |
+
dup_url = f"{space_url}?duplicate=true"
|
127 |
+
return (
|
128 |
+
f"ℹ️ No NVIDIA driver detected. If you're on Hugging Face Spaces, "
|
129 |
+
f"please duplicate this Space to GPU hardware: [Duplicate this Space]({dup_url})."
|
130 |
+
)
|
131 |
+
return (
|
132 |
+
"ℹ️ No NVIDIA driver detected. To enable training, run on a machine with an NVIDIA GPU/driver "
|
133 |
+
"or duplicate this Space on Hugging Face with GPU hardware."
|
134 |
+
)
|
135 |
+
|
136 |
+
|
137 |
+
def _write_jsonl(rows: list[dict], path: Path) -> Path:
|
138 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
139 |
+
with open(path, "w", encoding="utf-8") as f:
|
140 |
+
for r in rows:
|
141 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
142 |
+
return path
|
143 |
+
|
144 |
+
|
145 |
+
def _save_uploaded_dataset(files: list, transcripts: list[str]) -> str:
|
146 |
+
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
147 |
+
dataset_dir.mkdir(parents=True, exist_ok=True)
|
148 |
+
rows: list[dict] = []
|
149 |
+
for i, fpath in enumerate(files or []):
|
150 |
+
if i >= len(transcripts):
|
151 |
+
break
|
152 |
+
rows.append({"audio_path": fpath, "text": transcripts[i] or ""})
|
153 |
+
jsonl_path = dataset_dir / "data.jsonl"
|
154 |
+
_write_jsonl(rows, jsonl_path)
|
155 |
+
return str(jsonl_path)
|
156 |
+
|
157 |
+
|
158 |
+
def _save_recordings(recordings: list[tuple[int, list]], transcripts: list[str]) -> str:
|
159 |
+
import soundfile as sf
|
160 |
+
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
161 |
+
wav_dir = dataset_dir / "wavs"
|
162 |
+
wav_dir.mkdir(parents=True, exist_ok=True)
|
163 |
+
rows: list[dict] = []
|
164 |
+
for i, rec in enumerate(recordings or []):
|
165 |
+
if rec is None:
|
166 |
+
continue
|
167 |
+
if i >= len(transcripts):
|
168 |
+
break
|
169 |
+
sr, data = rec
|
170 |
+
out_path = wav_dir / f"rec_{i:04d}.wav"
|
171 |
+
sf.write(str(out_path), data, sr)
|
172 |
+
rows.append({"audio_path": str(out_path), "text": transcripts[i] or ""})
|
173 |
+
jsonl_path = dataset_dir / "data.jsonl"
|
174 |
+
_write_jsonl(rows, jsonl_path)
|
175 |
+
return str(jsonl_path)
|
176 |
+
|
177 |
+
|
178 |
+
def start_voxtral_training(
|
179 |
+
use_lora: bool,
|
180 |
+
base_model: str,
|
181 |
+
repo_short: str,
|
182 |
+
jsonl_path: str,
|
183 |
+
train_count: int,
|
184 |
+
eval_count: int,
|
185 |
+
batch_size: int,
|
186 |
+
grad_accum: int,
|
187 |
+
learning_rate: float,
|
188 |
+
epochs: float,
|
189 |
+
lora_r: int,
|
190 |
+
lora_alpha: int,
|
191 |
+
lora_dropout: float,
|
192 |
+
freeze_audio_tower: bool,
|
193 |
+
push_to_hub: bool,
|
194 |
+
deploy_demo: bool,
|
195 |
+
) -> Generator[str, None, None]:
|
196 |
+
env = os.environ.copy()
|
197 |
+
write_token = env.get("HF_WRITE_TOKEN") or env.get("HF_TOKEN")
|
198 |
+
read_token = env.get("HF_READ_TOKEN")
|
199 |
+
username = get_username_from_token(write_token or "") or env.get("HF_USERNAME") or ""
|
200 |
+
output_dir = PROJECT_ROOT / "outputs" / repo_short
|
201 |
+
|
202 |
+
# 1) Train
|
203 |
+
script = PROJECT_ROOT / ("scripts/train_lora.py" if use_lora else "scripts/train.py")
|
204 |
+
args = [str(script)]
|
205 |
+
if jsonl_path:
|
206 |
+
args += ["--dataset-jsonl", jsonl_path]
|
207 |
+
args += [
|
208 |
+
"--model-checkpoint", base_model,
|
209 |
+
"--train-count", str(train_count),
|
210 |
+
"--eval-count", str(eval_count),
|
211 |
+
"--batch-size", str(batch_size),
|
212 |
+
"--grad-accum", str(grad_accum),
|
213 |
+
"--learning-rate", str(learning_rate),
|
214 |
+
"--epochs", str(epochs),
|
215 |
+
"--output-dir", str(output_dir),
|
216 |
+
"--save-steps", "50",
|
217 |
+
]
|
218 |
+
if use_lora:
|
219 |
+
args += [
|
220 |
+
"--lora-r", str(lora_r),
|
221 |
+
"--lora-alpha", str(lora_alpha),
|
222 |
+
"--lora-dropout", str(lora_dropout),
|
223 |
+
]
|
224 |
+
if freeze_audio_tower:
|
225 |
+
args += ["--freeze-audio-tower"]
|
226 |
+
for line in run_command_stream(args, env):
|
227 |
+
yield line
|
228 |
+
|
229 |
+
# 2) Push to Hub
|
230 |
+
if push_to_hub:
|
231 |
+
repo_name = f"{username}/{repo_short}" if username else repo_short
|
232 |
+
push_args = [
|
233 |
+
str(PROJECT_ROOT / "scripts/push_to_huggingface.py"),
|
234 |
+
str(output_dir),
|
235 |
+
repo_name,
|
236 |
+
]
|
237 |
+
for line in run_command_stream(push_args, env):
|
238 |
+
yield line
|
239 |
+
|
240 |
+
# 3) Deploy demo Space
|
241 |
+
if deploy_demo and username:
|
242 |
+
deploy_args = [
|
243 |
+
str(PROJECT_ROOT / "scripts/deploy_demo_space.py"),
|
244 |
+
"--hf-token", write_token or "",
|
245 |
+
"--hf-username", username,
|
246 |
+
"--model-id", f"{username}/{repo_short}",
|
247 |
+
"--demo-type", "voxtral",
|
248 |
+
"--space-name", f"{repo_short}-demo",
|
249 |
+
]
|
250 |
+
for line in run_command_stream(deploy_args, env):
|
251 |
+
yield line
|
252 |
+
|
253 |
+
|
254 |
+
PHRASES = [
|
255 |
+
"The quick brown fox jumps over the lazy dog.",
|
256 |
+
"Please say your full name.",
|
257 |
+
"Today is a good day to learn something new.",
|
258 |
+
"Artificial intelligence helps with many tasks.",
|
259 |
+
"I enjoy reading books and listening to music.",
|
260 |
+
"This is a sample sentence for testing speech.",
|
261 |
+
"Speak clearly and at a normal pace.",
|
262 |
+
"Numbers like one, two, three are easy to say.",
|
263 |
+
"The weather is sunny with a chance of rain.",
|
264 |
+
"Thank you for taking the time to help.",
|
265 |
+
]
|
266 |
+
|
267 |
+
with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
|
268 |
+
has_gpu, gpu_msg = detect_nvidia_driver()
|
269 |
+
if has_gpu:
|
270 |
+
gr.HTML(
|
271 |
+
f"""
|
272 |
+
<div style="background-color: rgba(59, 130, 246, 0.1); border: 1px solid rgba(59, 130, 246, 0.3); border-radius: 8px; padding: 12px; margin-bottom: 16px; text-align: center;">
|
273 |
+
<p style="color: rgb(59, 130, 246); margin: 0; font-size: 14px; font-weight: 600;">
|
274 |
+
✅ NVIDIA GPU ready — {gpu_msg}
|
275 |
+
</p>
|
276 |
+
<p style="color: rgb(59, 130, 246); margin: 6px 0 0; font-size: 12px;">
|
277 |
+
Set HF_WRITE_TOKEN/HF_TOKEN in environment to enable Hub push.
|
278 |
+
</p>
|
279 |
+
</div>
|
280 |
+
"""
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
hint_md = duplicate_space_hint()
|
284 |
+
gr.HTML(
|
285 |
+
f"""
|
286 |
+
<div style="background-color: rgba(245, 158, 11, 0.1); border: 1px solid rgba(245, 158, 11, 0.3); border-radius: 8px; padding: 12px; margin-bottom: 16px; text-align: center;">
|
287 |
+
<p style="color: rgb(234, 88, 12); margin: 0; font-size: 14px; font-weight: 600;">
|
288 |
+
⚠️ No NVIDIA GPU/driver detected — training requires a GPU runtime
|
289 |
+
</p>
|
290 |
+
<p style="color: rgb(234, 88, 12); margin: 6px 0 0; font-size: 12px;">
|
291 |
+
{hint_md}
|
292 |
+
</p>
|
293 |
+
</div>
|
294 |
+
"""
|
295 |
+
)
|
296 |
+
|
297 |
+
gr.Markdown("""
|
298 |
+
# 🎙️ Voxtral ASR Fine-tuning
|
299 |
+
Read the phrases below and record them. Then start fine-tuning.
|
300 |
+
""")
|
301 |
+
|
302 |
+
jsonl_out = gr.Textbox(label="Dataset JSONL path", interactive=False, visible=True)
|
303 |
+
|
304 |
+
# Recording grid with dynamic text readouts
|
305 |
+
phrase_texts_state = gr.State(PHRASES)
|
306 |
+
phrase_markdowns: list[gr.Markdown] = []
|
307 |
+
rec_components = []
|
308 |
+
with gr.Column():
|
309 |
+
for idx, phrase in enumerate(PHRASES):
|
310 |
+
md = gr.Markdown(f"**{idx+1}. {phrase}**")
|
311 |
+
phrase_markdowns.append(md)
|
312 |
+
comp = gr.Audio(sources="microphone", type="numpy", label=f"Recording {idx+1}")
|
313 |
+
rec_components.append(comp)
|
314 |
+
|
315 |
+
# Advanced options accordion
|
316 |
+
with gr.Accordion("Advanced options", open=False):
|
317 |
+
base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
|
318 |
+
use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
|
319 |
+
with gr.Row():
|
320 |
+
batch_size = gr.Number(value=2, precision=0, label="Batch size")
|
321 |
+
grad_accum = gr.Number(value=4, precision=0, label="Grad accum")
|
322 |
+
with gr.Row():
|
323 |
+
learning_rate = gr.Number(value=5e-5, precision=6, label="Learning rate")
|
324 |
+
epochs = gr.Number(value=3.0, precision=2, label="Epochs")
|
325 |
+
with gr.Accordion("LoRA settings", open=False):
|
326 |
+
lora_r = gr.Number(value=8, precision=0, label="LoRA r")
|
327 |
+
lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
|
328 |
+
lora_dropout = gr.Number(value=0.0, precision=3, label="LoRA dropout")
|
329 |
+
freeze_audio_tower = gr.Checkbox(value=True, label="Freeze audio tower")
|
330 |
+
with gr.Row():
|
331 |
+
train_count = gr.Number(value=100, precision=0, label="Train samples")
|
332 |
+
eval_count = gr.Number(value=50, precision=0, label="Eval samples")
|
333 |
+
repo_short = gr.Textbox(value=f"voxtral-finetune-{datetime.now().strftime('%Y%m%d_%H%M%S')}", label="Model repo (short)")
|
334 |
+
push_to_hub = gr.Checkbox(value=True, label="Push to HF Hub after training")
|
335 |
+
deploy_demo = gr.Checkbox(value=True, label="Deploy demo Space after push")
|
336 |
+
|
337 |
+
gr.Markdown("### Upload audio + transcripts (optional)")
|
338 |
+
upload_audio = gr.File(file_count="multiple", type="filepath", label="Upload WAV/FLAC files (optional)")
|
339 |
+
transcripts_box = gr.Textbox(lines=6, label="Transcripts (one per line, aligned with files)")
|
340 |
+
save_upload_btn = gr.Button("Save uploaded dataset")
|
341 |
+
|
342 |
+
def _collect_upload(files, txt):
|
343 |
+
lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
|
344 |
+
return _save_uploaded_dataset(files or [], lines)
|
345 |
+
|
346 |
+
save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [jsonl_out])
|
347 |
+
|
348 |
+
# Save recordings button
|
349 |
+
save_rec_btn = gr.Button("Save recordings as dataset")
|
350 |
+
|
351 |
+
def _collect_preloaded_recs(*recs_and_texts):
|
352 |
+
import soundfile as sf
|
353 |
+
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
354 |
+
wav_dir = dataset_dir / "wavs"
|
355 |
+
wav_dir.mkdir(parents=True, exist_ok=True)
|
356 |
+
rows: list[dict] = []
|
357 |
+
if not recs_and_texts:
|
358 |
+
jsonl_path = dataset_dir / "data.jsonl"
|
359 |
+
_write_jsonl(rows, jsonl_path)
|
360 |
+
return str(jsonl_path)
|
361 |
+
texts = recs_and_texts[-1]
|
362 |
+
recs = recs_and_texts[:-1]
|
363 |
+
for i, rec in enumerate(recs):
|
364 |
+
if rec is None:
|
365 |
+
continue
|
366 |
+
sr, data = rec
|
367 |
+
out_path = wav_dir / f"rec_{i:04d}.wav"
|
368 |
+
sf.write(str(out_path), data, sr)
|
369 |
+
label_text = (texts[i] if isinstance(texts, list) and i < len(texts) else (PHRASES[i] if i < len(PHRASES) else ""))
|
370 |
+
rows.append({"audio_path": str(out_path), "text": label_text})
|
371 |
+
jsonl_path = dataset_dir / "data.jsonl"
|
372 |
+
_write_jsonl(rows, jsonl_path)
|
373 |
+
return str(jsonl_path)
|
374 |
+
|
375 |
+
save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_out])
|
376 |
+
|
377 |
+
# Quick sample from VoxPopuli (few random rows)
|
378 |
+
with gr.Row():
|
379 |
+
vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"], value="en", label="VoxPopuli language")
|
380 |
+
vp_samples = gr.Number(value=20, precision=0, label="Num samples")
|
381 |
+
vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
|
382 |
+
vp_btn = gr.Button("Use VoxPopuli sample")
|
383 |
+
|
384 |
+
def _collect_voxpopuli(lang_code: str, num_samples: int, split: str):
|
385 |
+
import sys
|
386 |
+
# Workaround for dill on Python 3.13 expecting __main__ during import
|
387 |
+
if "__main__" not in sys.modules:
|
388 |
+
sys.modules["__main__"] = sys.modules[__name__]
|
389 |
+
from datasets import load_dataset, Audio # type: ignore
|
390 |
+
import random
|
391 |
+
ds = load_dataset("facebook/voxpopuli", lang_code, split=split)
|
392 |
+
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
|
393 |
+
# shuffle and select
|
394 |
+
total = len(ds)
|
395 |
+
k = max(1, min(int(num_samples or 1), total))
|
396 |
+
ds = ds.shuffle(seed=random.randint(1, 10_000))
|
397 |
+
ds_sel = ds.select(range(k))
|
398 |
+
|
399 |
+
dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
|
400 |
+
rows: list[dict] = []
|
401 |
+
texts: list[str] = []
|
402 |
+
for ex in ds_sel:
|
403 |
+
audio = ex.get("audio") or {}
|
404 |
+
path = audio.get("path")
|
405 |
+
text = ex.get("normalized_text") or ex.get("raw_text") or ""
|
406 |
+
if path and text is not None:
|
407 |
+
rows.append({"audio_path": path, "text": text})
|
408 |
+
texts.append(str(text))
|
409 |
+
jsonl_path = dataset_dir / "data.jsonl"
|
410 |
+
_write_jsonl(rows, jsonl_path)
|
411 |
+
# Build markdown content updates for on-screen prompts
|
412 |
+
md_updates = []
|
413 |
+
for i in range(len(phrase_markdowns)):
|
414 |
+
t = texts[i] if i < len(texts) else ""
|
415 |
+
md_updates.append(f"**{i+1}. {t}**")
|
416 |
+
return (str(jsonl_path), texts, *md_updates)
|
417 |
+
|
418 |
+
vp_btn.click(
|
419 |
+
_collect_voxpopuli,
|
420 |
+
[vp_lang, vp_samples, vp_split],
|
421 |
+
[jsonl_out, phrase_texts_state] + phrase_markdowns,
|
422 |
+
)
|
423 |
+
|
424 |
+
start_btn = gr.Button("Start Fine-tuning")
|
425 |
+
logs_box = gr.Textbox(label="Logs", lines=20)
|
426 |
+
|
427 |
+
start_btn.click(
|
428 |
+
start_voxtral_training,
|
429 |
+
inputs=[
|
430 |
+
use_lora, base_model, repo_short, jsonl_out, train_count, eval_count,
|
431 |
+
batch_size, grad_accum, learning_rate, epochs,
|
432 |
+
lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
|
433 |
+
push_to_hub, deploy_demo,
|
434 |
+
],
|
435 |
+
outputs=[logs_box],
|
436 |
+
)
|
437 |
+
|
438 |
+
|
439 |
+
if __name__ == "__main__":
|
440 |
+
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|
441 |
+
server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
|
442 |
+
demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)
|
443 |
+
|
444 |
+
|
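For reference, each row that _write_jsonl emits to datasets/voxtral_user/data.jsonl is one JSON object per line with audio_path and text keys. A minimal sketch of writing and reading back such a file; the sample row below is illustrative, not taken from this commit:

    import json
    from pathlib import Path

    # Illustrative row; real rows point at 16 kHz WAV/FLAC files on disk.
    rows = [{"audio_path": "datasets/voxtral_user/wavs/rec_0000.wav",
             "text": "The quick brown fox jumps over the lazy dog."}]

    path = Path("datasets/voxtral_user/data.jsonl")
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            # Same one-object-per-line format that _write_jsonl produces.
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    with open(path, encoding="utf-8") as f:
        for line in f:
            print(json.loads(line)["text"])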
scripts/deploy_demo_space.py
ADDED
@@ -0,0 +1,952 @@
+#!/usr/bin/env python3
+"""
+Demo Space Deployment Script
+Deploys a Gradio demo space to Hugging Face Spaces for testing the fine-tuned model.
+"""
+
+import os
+import sys
+import json
+import logging
+import argparse
+import subprocess
+import requests
+import tempfile
+import shutil
+from pathlib import Path
+from typing import Optional, Dict, Any
+import time
+
+# Import Hugging Face Hub API
+try:
+    from huggingface_hub import HfApi, create_repo, upload_file
+    HF_HUB_AVAILABLE = True
+except ImportError:
+    HF_HUB_AVAILABLE = False
+    print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")
+
+# Add src to path for imports
+sys.path.append(str(Path(__file__).parent.parent / "src"))
+
+from config import SmolLM3Config
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+class DemoSpaceDeployer:
+    """Deploy demo space to Hugging Face Spaces"""
+
+    def __init__(
+        self,
+        # Token used for API actions that create/update the Space (write perms)
+        hf_token: str,
+        hf_username: str,
+        model_id: str,
+        subfolder: str = "int4",
+        space_name: Optional[str] = None,
+        demo_type: Optional[str] = None,
+        config_file: Optional[str] = None,
+        # Optional token used as the Space's HF_TOKEN secret (read-only recommended)
+        space_secret_token: Optional[str] = None,
+        # Examples configuration
+        examples_type: Optional[str] = None,
+        disable_examples: Optional[bool] = None,
+        examples_json: Optional[str] = None,
+        # Branding overrides
+        brand_owner_name: Optional[str] = None,
+        brand_team_name: Optional[str] = None,
+        brand_discord_url: Optional[str] = None,
+        brand_hf_org: Optional[str] = None,
+        brand_hf_label: Optional[str] = None,
+        brand_hf_url: Optional[str] = None,
+        brand_gh_org: Optional[str] = None,
+        brand_gh_label: Optional[str] = None,
+        brand_gh_url: Optional[str] = None,
+        brand_project_name: Optional[str] = None,
+        brand_project_url: Optional[str] = None,
+    ):
+        self.hf_token = hf_token
+        # The token we will store in the Space secrets. Defaults to hf_token if not provided
+        self.space_secret_token = space_secret_token or hf_token
+        self.hf_username = hf_username
+        # Allow passing just a repo name without username and auto-prefix
+        self.model_id = model_id if "/" in model_id else f"{hf_username}/{model_id}"
+        self.subfolder = subfolder
+        self.space_name = space_name or f"{self.model_id.split('/')[-1]}-demo"
+        self.space_id = f"{hf_username}/{self.space_name}"
+        self.space_url = f"https://huggingface.co/spaces/{self.space_id}"
+        self.config_file = config_file
+
+        # Config-derived context
+        self.system_message: Optional[str] = None
+        self.developer_message: Optional[str] = None
+        self.model_identity: Optional[str] = None
+        self.reasoning_effort: Optional[str] = None
+        # Examples context
+        self.examples_type: Optional[str] = (examples_type or None)
+        self.disable_examples: Optional[bool] = (disable_examples if disable_examples is not None else None)
+        self.examples_json: Optional[str] = (examples_json or None)
+
+        # Determine demo type from model_id if not provided
+        if demo_type is None:
+            demo_type = self._detect_demo_type(model_id)
+
+        # Template paths based on model type
+        self.demo_type = demo_type
+        self.template_dir = Path(__file__).parent.parent / "templates" / "spaces" / f"demo_{demo_type}"
+        self.workspace_dir = Path.cwd()
+
+        # Initialize HF API
+        if HF_HUB_AVAILABLE:
+            self.api = HfApi(token=self.hf_token)
+        else:
+            self.api = None
+            logger.warning("huggingface_hub not available, using CLI fallback")
+
+        # Load optional config-specified messages
+        try:
+            self._load_config_messages()
+        except Exception as e:
+            logger.warning(f"Could not load config messages: {e}")
+
+        # Branding defaults (can be overridden via CLI)
+        self.brand_owner_name = brand_owner_name or self.hf_username or "Tonic"
+        self.brand_team_name = brand_team_name or f"Team{self.brand_owner_name}"
+        self.brand_discord_url = brand_discord_url or "https://discord.gg/qdfnvSPcqP"
+        # HF org/link
+        _default_hf_org = brand_hf_org or self.hf_username or "MultiTransformer"
+        self.brand_hf_org = _default_hf_org
+        self.brand_hf_label = brand_hf_label or self.brand_hf_org
+        self.brand_hf_url = brand_hf_url or f"https://huggingface.co/{self.brand_hf_org}"
+        # GitHub org/link
+        _default_gh_org = brand_gh_org or self.hf_username or "tonic-ai"
+        self.brand_gh_org = _default_gh_org
+        self.brand_gh_label = brand_gh_label or self.brand_gh_org
+        self.brand_gh_url = brand_gh_url or f"https://github.com/{self.brand_gh_org}"
+        # Project link
+        self.brand_project_name = brand_project_name or "MultiTonic"
+        self.brand_project_url = brand_project_url or "https://github.com/MultiTonic"
+
+    def _load_config_messages(self) -> None:
+        """Load system/developer/model_identity from a training config file if provided."""
+        if not self.config_file:
+            return
+        cfg_path = Path(self.config_file)
+        if not cfg_path.exists():
+            logger.warning(f"Config file not found: {cfg_path}")
+            return
+
+        # Ensure project root and config dir are importable for relative imports inside config
+        project_root = Path(__file__).parent.parent
+        if str(project_root) not in sys.path:
+            sys.path.insert(0, str(project_root))
+        cfg_dir = project_root / "config"
+        if str(cfg_dir) not in sys.path:
+            sys.path.insert(0, str(cfg_dir))
+
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("config_module", str(cfg_path))
+        if not spec or not spec.loader:
+            return
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)  # type: ignore
+        cfg = getattr(module, "config", None)
+        if cfg is None:
+            return
+        self.system_message = getattr(cfg, "system_message", None)
+        self.developer_message = getattr(cfg, "developer_message", None)
+        chat_kwargs = getattr(cfg, "chat_template_kwargs", None)
+        if isinstance(chat_kwargs, dict):
+            self.model_identity = chat_kwargs.get("model_identity")
+            self.reasoning_effort = chat_kwargs.get("reasoning_effort")
+
+    def _detect_demo_type(self, model_id: str) -> str:
+        """Detect the appropriate demo type based on model ID"""
+        model_id_lower = model_id.lower()
+
+        # Voxtral ASR models
+        if "voxtral" in model_id_lower:
+            logger.info("Detected Voxtral model, using demo_voxtral template")
+            return "voxtral"
+
+        # Check for GPT-OSS models
+        if "gpt-oss" in model_id_lower or "gpt_oss" in model_id_lower:
+            logger.info("Detected GPT-OSS model, using demo_gpt template")
+            return "gpt"
+
+        # Check for SmolLM models (default)
+        elif "smollm" in model_id_lower or "smol" in model_id_lower:
+            logger.info("Detected SmolLM model, using demo_smol template")
+            return "smol"
+
+        # Default to SmolLM for unknown models
+        else:
+            logger.info("Unknown model type, defaulting to demo_smol template")
+            return "smol"
+
+    def _generate_env_setup(self) -> str:
+        """Generate environment variable setup based on demo type and model"""
+        if self.demo_type == "gpt":
+            # For GPT-OSS models, we need more sophisticated environment setup
+            model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+            import json as _json
+            env_setup = f"""
+# Environment variables for GPT-OSS model configuration
+import os
+os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
+os.environ['LORA_MODEL_ID'] = {_json.dumps(self.model_id)}
+os.environ['BASE_MODEL_ID'] = 'openai/gpt-oss-20b'
+os.environ['MODEL_SUBFOLDER'] = {_json.dumps(self.subfolder if self.subfolder else "")}
+os.environ['MODEL_NAME'] = {_json.dumps(model_name)}
+os.environ['MODEL_IDENTITY'] = {_json.dumps(self.model_identity or "")}
+os.environ['SYSTEM_MESSAGE'] = {_json.dumps(self.system_message or (self.model_identity or ""))}
+os.environ['DEVELOPER_MESSAGE'] = {_json.dumps(self.developer_message or "")}
+os.environ['REASONING_EFFORT'] = {_json.dumps((self.reasoning_effort or "medium"))}
+{"os.environ['EXAMPLES_TYPE'] = " + _json.dumps(self.examples_type) + "\n" if self.examples_type else ''}
+{"os.environ['DISABLE_EXAMPLES'] = 'true'\n" if self.disable_examples else ("os.environ['DISABLE_EXAMPLES'] = 'false'\n" if self.disable_examples is not None else '')}
+{"os.environ['EXAMPLES_JSON'] = " + _json.dumps(self.examples_json) + "\n" if self.examples_json else ''}
+
+# Branding/owner variables
+os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
+os.environ['BRAND_OWNER_NAME'] = {_json.dumps(self.brand_owner_name)}
+os.environ['BRAND_TEAM_NAME'] = {_json.dumps(self.brand_team_name)}
+os.environ['BRAND_DISCORD_URL'] = {_json.dumps(self.brand_discord_url)}
+os.environ['BRAND_HF_ORG'] = {_json.dumps(self.brand_hf_org)}
+os.environ['BRAND_HF_LABEL'] = {_json.dumps(self.brand_hf_label)}
+os.environ['BRAND_HF_URL'] = {_json.dumps(self.brand_hf_url)}
+os.environ['BRAND_GH_ORG'] = {_json.dumps(self.brand_gh_org)}
+os.environ['BRAND_GH_LABEL'] = {_json.dumps(self.brand_gh_label)}
+os.environ['BRAND_GH_URL'] = {_json.dumps(self.brand_gh_url)}
+os.environ['BRAND_PROJECT_NAME'] = {_json.dumps(self.brand_project_name)}
+os.environ['BRAND_PROJECT_URL'] = {_json.dumps(self.brand_project_url)}
+
+"""
+        elif self.demo_type == "voxtral":
+            import json as _json
+            env_setup = f"""
+# Environment variables for Voxtral ASR demo
+import os
+os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
+os.environ['MODEL_NAME'] = {_json.dumps(self.model_id.split('/')[-1])}
+os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
+"""
+        else:
+            # For SmolLM models, use simpler setup
+            import json as _json
+            env_setup = f"""
+# Environment variables for model configuration
+import os
+os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
+os.environ['MODEL_SUBFOLDER'] = {_json.dumps(self.subfolder if self.subfolder else "")}
+os.environ['MODEL_NAME'] = {_json.dumps(self.model_id.split("/")[-1])}
+os.environ['MODEL_IDENTITY'] = {_json.dumps(self.model_identity or "")}
+os.environ['SYSTEM_MESSAGE'] = {_json.dumps(self.system_message or (self.model_identity or ""))}
+os.environ['DEVELOPER_MESSAGE'] = {_json.dumps(self.developer_message or "")}
+os.environ['REASONING_EFFORT'] = {_json.dumps((self.reasoning_effort or "medium"))}
+{"os.environ['EXAMPLES_TYPE'] = " + _json.dumps(self.examples_type) + "\n" if self.examples_type else ''}
+{"os.environ['DISABLE_EXAMPLES'] = 'true'\n" if self.disable_examples else ("os.environ['DISABLE_EXAMPLES'] = 'false'\n" if self.disable_examples is not None else '')}
+{"os.environ['EXAMPLES_JSON'] = " + _json.dumps(self.examples_json) + "\n" if self.examples_json else ''}
+
+# Branding/owner variables
+os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
+os.environ['BRAND_OWNER_NAME'] = {_json.dumps(self.brand_owner_name)}
+os.environ['BRAND_TEAM_NAME'] = {_json.dumps(self.brand_team_name)}
+os.environ['BRAND_DISCORD_URL'] = {_json.dumps(self.brand_discord_url)}
+os.environ['BRAND_HF_ORG'] = {_json.dumps(self.brand_hf_org)}
+os.environ['BRAND_HF_LABEL'] = {_json.dumps(self.brand_hf_label)}
+os.environ['BRAND_HF_URL'] = {_json.dumps(self.brand_hf_url)}
+os.environ['BRAND_GH_ORG'] = {_json.dumps(self.brand_gh_org)}
+os.environ['BRAND_GH_LABEL'] = {_json.dumps(self.brand_gh_label)}
+os.environ['BRAND_GH_URL'] = {_json.dumps(self.brand_gh_url)}
+os.environ['BRAND_PROJECT_NAME'] = {_json.dumps(self.brand_project_name)}
+os.environ['BRAND_PROJECT_URL'] = {_json.dumps(self.brand_project_url)}
+
+"""
+        return env_setup
+
+    def _set_model_variables(self):
+        """Set model-specific environment variables in the space"""
+        try:
+            # Common variables for all models
+            self.api.add_space_variable(
+                repo_id=self.space_id,
+                key="HF_MODEL_ID",
+                value=self.model_id,
+                description="Model ID for the demo"
+            )
+            logger.info(f"✅ Successfully set HF_MODEL_ID variable: {self.model_id}")
+
+            if self.subfolder and self.subfolder.strip():
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="MODEL_SUBFOLDER",
+                    value=self.subfolder,
+                    description="Model subfolder for the demo"
+                )
+                logger.info(f"✅ Successfully set MODEL_SUBFOLDER variable: {self.subfolder}")
+            else:
+                logger.info("ℹ️ No subfolder specified, using main model")
+
+            # GPT-OSS specific variables
+            if self.demo_type == "gpt":
+                model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="LORA_MODEL_ID",
+                    value=self.model_id,
+                    description="LoRA/Fine-tuned model ID"
+                )
+                logger.info(f"✅ Successfully set LORA_MODEL_ID variable: {self.model_id}")
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="BASE_MODEL_ID",
+                    value="openai/gpt-oss-20b",
+                    description="Base model ID for GPT-OSS"
+                )
+                logger.info("✅ Successfully set BASE_MODEL_ID variable: openai/gpt-oss-20b")
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="MODEL_NAME",
+                    value=model_name,
+                    description="Display name for the model"
+                )
+                logger.info(f"✅ Successfully set MODEL_NAME variable: {model_name}")
+
+            # Voxtral-specific variables
+            elif self.demo_type == "voxtral":
+                # HF_MODEL_ID was already set above; set a readable MODEL_NAME
+                vox_model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="MODEL_NAME",
+                    value=vox_model_name,
+                    description="Display name for the Voxtral model"
+                )
+                logger.info(f"✅ Set Voxtral MODEL_NAME variable: {vox_model_name}")
+
+            # Optional context variables
+            if self.model_identity:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="MODEL_IDENTITY",
+                    value=self.model_identity,
+                    description="Default model identity/system persona"
+                )
+                logger.info("✅ Set MODEL_IDENTITY variable")
+            if self.system_message or self.model_identity:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="SYSTEM_MESSAGE",
+                    value=self.system_message or self.model_identity or "",
+                    description="Default system message"
+                )
+                logger.info("✅ Set SYSTEM_MESSAGE variable")
+            if self.developer_message:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="DEVELOPER_MESSAGE",
+                    value=self.developer_message,
+                    description="Default developer message"
+                )
+                logger.info("✅ Set DEVELOPER_MESSAGE variable")
+            if self.reasoning_effort:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="REASONING_EFFORT",
+                    value=self.reasoning_effort,
+                    description="Default reasoning effort (low|medium|high)"
+                )
+                logger.info("✅ Set REASONING_EFFORT variable")
+
+            # Branding variables
+            branding_vars = {
+                "HF_USERNAME": self.hf_username,
+                "BRAND_OWNER_NAME": self.brand_owner_name,
+                "BRAND_TEAM_NAME": self.brand_team_name,
+                "BRAND_DISCORD_URL": self.brand_discord_url,
+                "BRAND_HF_ORG": self.brand_hf_org,
+                "BRAND_HF_LABEL": self.brand_hf_label,
+                "BRAND_HF_URL": self.brand_hf_url,
+                "BRAND_GH_ORG": self.brand_gh_org,
+                "BRAND_GH_LABEL": self.brand_gh_label,
+                "BRAND_GH_URL": self.brand_gh_url,
+                "BRAND_PROJECT_NAME": self.brand_project_name,
+                "BRAND_PROJECT_URL": self.brand_project_url,
+            }
+            for key, value in branding_vars.items():
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key=key,
+                    value=value,
+                    description=f"Branding: {key}"
+                )
+            logger.info("✅ Set branding variables")
+
+            # Examples variables
+            if self.examples_type:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="EXAMPLES_TYPE",
+                    value=self.examples_type,
+                    description="Examples pack type (e.g., general|medical)"
+                )
+                logger.info(f"✅ Set EXAMPLES_TYPE={self.examples_type}")
+            if self.disable_examples is not None:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="DISABLE_EXAMPLES",
+                    value=("true" if self.disable_examples else "false"),
+                    description="Disable built-in examples"
+                )
+                logger.info(f"✅ Set DISABLE_EXAMPLES={self.disable_examples}")
+            if self.examples_json:
+                self.api.add_space_variable(
+                    repo_id=self.space_id,
+                    key="EXAMPLES_JSON",
+                    value=self.examples_json,
+                    description="Custom examples JSON override"
+                )
+                logger.info("✅ Set EXAMPLES_JSON override")
+
+        except Exception as e:
+            logger.error(f"❌ Failed to set model variables: {e}")
+
+    def validate_model_exists(self) -> bool:
+        """Validate that the model exists on Hugging Face Hub"""
+        try:
+            logger.info(f"Validating model: {self.model_id}")
+
+            if HF_HUB_AVAILABLE:
+                # Use HF Hub API
+                try:
+                    model_info = self.api.model_info(self.model_id)
+                    logger.info(f"✅ Model {self.model_id} exists and is accessible")
+                    return True
+                except Exception as e:
+                    logger.error(f"❌ Model {self.model_id} not found via API: {e}")
+                    return False
+            else:
+                # Fallback to requests
+                url = f"https://huggingface.co/api/models/{self.model_id}"
+                headers = {"Authorization": f"Bearer {self.hf_token}"}
+                response = requests.get(url, headers=headers, timeout=30)
+
+                if response.status_code == 200:
+                    logger.info(f"✅ Model {self.model_id} exists and is accessible")
+                    return True
+                else:
+                    logger.error(f"❌ Model {self.model_id} not found or not accessible")
+                    return False
+
+        except Exception as e:
+            logger.error(f"❌ Error validating model: {e}")
+            return False
+
+    def create_space_repository(self) -> bool:
+        """Create the space repository on Hugging Face Hub"""
+        try:
+            logger.info(f"Creating Space: {self.space_name}")
+
+            if not HF_HUB_AVAILABLE:
+                logger.warning("huggingface_hub not available, falling back to CLI")
+                return self._create_space_cli()
+
+            # Use the latest HF Hub API to create space
+            try:
+                # Create the space using the API
+                create_repo(
+                    repo_id=self.space_id,
+                    token=self.hf_token,
+                    repo_type="space",
+                    exist_ok=True,
+                    private=False,  # Spaces are typically public
+                    space_sdk="gradio",  # Specify Gradio SDK
+                    space_hardware="cpu-basic"  # Use basic CPU
+                )
+
+                logger.info(f"✅ Space created successfully: {self.space_url}")
+                return True
+
+            except Exception as api_error:
+                logger.error(f"API creation failed: {api_error}")
+                logger.info("Falling back to CLI method...")
+                return self._create_space_cli()
+
+        except Exception as e:
+            logger.error(f"❌ Error creating space: {e}")
+            return False
+
+    def _create_space_cli(self) -> bool:
+        """Fallback method using CLI commands"""
+        try:
+            logger.info("Using CLI fallback method...")
+
+            # Set HF token for CLI
+            os.environ['HF_TOKEN'] = self.hf_token
+
+            # Create space using Hugging Face CLI
+            cmd = [
+                "hf", "repo", "create",
+                self.space_id,
+                "--type", "space"
+            ]
+
+            logger.info(f"Running command: {' '.join(cmd)}")
+            result = subprocess.run(cmd, capture_output=True, text=True)
+
+            if result.returncode != 0:
+                logger.warning(f"First attempt failed: {result.stderr}")
+                # Try alternative approach without space-specific flags
+                logger.info("Retrying with basic space creation...")
+                cmd = [
+                    "hf", "repo", "create",
+                    self.space_id
+                ]
+                result = subprocess.run(cmd, capture_output=True, text=True)
+
+            if result.returncode == 0:
+                logger.info(f"✅ Space created successfully: {self.space_url}")
+                return True
+            else:
+                logger.error(f"❌ Failed to create space: {result.stderr}")
+                return False
+
+        except Exception as e:
+            logger.error(f"❌ Error creating space with CLI: {e}")
+            return False
+
+    def prepare_space_files(self) -> str:
+        """Prepare all necessary files for the Space in a temporary directory"""
+        try:
+            logger.info("Preparing Space files...")
+
+            # Create temporary directory
+            temp_dir = tempfile.mkdtemp()
+            logger.info(f"Created temporary directory: {temp_dir}")
+
+            # Copy template files
+            copied_files = []
+            for file_path in self.template_dir.iterdir():
+                if file_path.is_file():
+                    dest_path = Path(temp_dir) / file_path.name
+                    shutil.copy2(file_path, dest_path)
+                    copied_files.append(file_path.name)
+                    logger.info(f"✅ Copied {file_path.name} to temp directory")
+
+            # Update app.py with environment variables
+            app_file = Path(temp_dir) / "app.py"
+            if app_file.exists():
+                with open(app_file, 'r', encoding='utf-8') as f:
+                    content = f.read()
+
+                # Add environment variable setup at the top
+                env_setup = self._generate_env_setup()
+
+                # Insert after imports
+                lines = content.split('\n')
+                import_end = 0
+                for i, line in enumerate(lines):
+                    if line.startswith('import ') or line.startswith('from '):
+                        import_end = i + 1
+                    elif line.strip() == '' and import_end > 0:
+                        break
+
+                lines.insert(import_end, env_setup)
+                content = '\n'.join(lines)
+
+                with open(app_file, 'w', encoding='utf-8') as f:
+                    f.write(content)
+
+                logger.info("✅ Updated app.py with model configuration")
+
+            # YAML front matter required by Hugging Face Spaces
+            yaml_front_matter = (
+                f"---\n"
+                f"title: {'GPT-OSS Demo' if self.demo_type == 'gpt' else 'SmolLM3 Demo'}\n"
+                f"emoji: {'🌟' if self.demo_type == 'gpt' else '💃🏻'}\n"
+                f"colorFrom: {'blue' if self.demo_type == 'gpt' else 'green'}\n"
+                f"colorTo: {'pink' if self.demo_type == 'gpt' else 'purple'}\n"
+                f"sdk: gradio\n"
+                f"sdk_version: 5.40.0\n"
+                f"app_file: app.py\n"
+                f"pinned: false\n"
+                f"short_description: Interactive demo for {self.model_id}\n"
+                + ("license: mit\n" if self.demo_type != 'gpt' else "") +
+                f"---\n\n"
+            )
+
+            # Create README.md for the space (include configuration details)
+            readme_content = (
+                yaml_front_matter
+                + f"# Demo: {self.model_id}\n\n"
+                + f"This is an interactive demo for the fine-tuned model {self.model_id}.\n\n"
+                + "## Features\n"
+                "- Interactive chat interface\n"
+                "- Customizable system & developer prompts\n"
+                "- Advanced generation parameters\n"
+                "- Thinking mode support\n\n"
+                + "## Model Information\n"
+                f"- **Model ID**: {self.model_id}\n"
+                f"- **Subfolder**: {self.subfolder if self.subfolder and self.subfolder.strip() else 'main'}\n"
+                f"- **Deployed by**: {self.hf_username}\n"
+                + ("- **Base Model**: openai/gpt-oss-20b\n" if self.demo_type == 'gpt' else "")
+                + "\n"
+                + "## Configuration\n"
+                "- **Model Identity**:\n\n"
+                f"```\n{self.model_identity or 'Not set'}\n```\n\n"
+                "- **System Message** (default):\n\n"
+                f"```\n{(self.system_message or self.model_identity) or 'Not set'}\n```\n\n"
+                "- **Developer Message** (default):\n\n"
+                f"```\n{self.developer_message or 'Not set'}\n```\n\n"
+                "These defaults come from the selected training configuration and can be adjusted in the UI when you run the demo.\n\n"
+                + "## Usage\n"
+                "Simply start chatting with the model using the interface below!\n\n"
+                + "---\n"
+                "*This demo was automatically deployed by the SmolFactory Fine-tuning Pipeline*\n"
+            )
+
+            with open(Path(temp_dir) / "README.md", 'w', encoding='utf-8') as f:
+                f.write(readme_content)
+
+            logger.info(f"✅ Prepared {len(copied_files)} files in temporary directory")
+            return temp_dir
+
+        except Exception as e:
+            logger.error(f"❌ Error preparing files: {e}")
+            return None
+
+    def upload_files_to_space(self, temp_dir: str) -> bool:
+        """Upload files to the Space using HF Hub API directly"""
+        try:
+            logger.info("Uploading files to Space using HF Hub API...")
+
+            if not HF_HUB_AVAILABLE:
+                logger.error("❌ huggingface_hub not available for file upload")
+                return self._upload_files_cli(temp_dir)
+
+            # Upload each file using the HF Hub API
+            temp_path = Path(temp_dir)
+            uploaded_files = []
+
+            for file_path in temp_path.iterdir():
+                if file_path.is_file():
+                    try:
+                        # Upload file to the space
+                        upload_file(
+                            path_or_fileobj=str(file_path),
+                            path_in_repo=file_path.name,
+                            repo_id=self.space_id,
+                            repo_type="space",
+                            token=self.hf_token
+                        )
+                        uploaded_files.append(file_path.name)
+                        logger.info(f"✅ Uploaded {file_path.name}")
+                    except Exception as e:
+                        logger.error(f"❌ Failed to upload {file_path.name}: {e}")
+                        return False
+
+            logger.info(f"✅ Successfully uploaded {len(uploaded_files)} files to Space")
+            return True
+
+        except Exception as e:
+            logger.error(f"❌ Error uploading files: {e}")
+            return self._upload_files_cli(temp_dir)
+
+    def _upload_files_cli(self, temp_dir: str) -> bool:
+        """Fallback method using CLI for file upload"""
+        try:
+            logger.info("Using CLI fallback for file upload...")
+
+            # Set HF token for CLI
+            os.environ['HF_TOKEN'] = self.hf_token
+
+            # Initialize git repository
+            subprocess.run(["git", "init"], cwd=temp_dir, check=True)
+            subprocess.run(["git", "config", "user.name", "Demo Deployer"], cwd=temp_dir, check=True)
+            subprocess.run(["git", "config", "user.email", "[email protected]"], cwd=temp_dir, check=True)
+
+            # Add files
+            subprocess.run(["git", "add", "."], cwd=temp_dir, check=True)
+            subprocess.run(["git", "commit", "-m", f"Deploy demo for {self.model_id}"], cwd=temp_dir, check=True)
+
+            # Add remote and push
+            remote_url = f"https://{self.hf_token}@huggingface.co/spaces/{self.space_id}"
+            subprocess.run(["git", "remote", "add", "origin", remote_url], cwd=temp_dir, check=True)
+            subprocess.run(["git", "push", "-u", "origin", "main"], cwd=temp_dir, check=True)
+
+            logger.info(f"✅ Successfully pushed files to space: {self.space_id}")
+            return True
+
+        except subprocess.CalledProcessError as e:
+            logger.error(f"❌ Git operation failed: {e}")
+            return False
+        except Exception as e:
+            logger.error(f"❌ Error pushing to space: {e}")
+            return False
+
+    def set_space_secrets(self) -> bool:
+        """Set environment variables/secrets for the Space using HF Hub API"""
+        try:
+            logger.info("Setting Space secrets using HF Hub API...")
+
+            if not HF_HUB_AVAILABLE:
+                logger.warning("❌ huggingface_hub not available for setting secrets")
+                return self._manual_secret_setup()
+
+            # Set the HF_TOKEN secret for the space using the API
+            try:
+                self.api.add_space_secret(
+                    repo_id=self.space_id,
+                    key="HF_TOKEN",
+                    value=self.space_secret_token,
+                    description="Hugging Face token for model access"
+                )
+                logger.info("✅ Successfully set HF_TOKEN secret via API")
+
+                # Set model-specific environment variables
+                self._set_model_variables()
+
+                return True
+
+            except Exception as api_error:
+                logger.error(f"❌ Failed to set secrets via API: {api_error}")
+                logger.info("Falling back to manual setup...")
+                return self._manual_secret_setup()
+
+        except Exception as e:
+            logger.error(f"❌ Error setting space secrets: {e}")
+            return self._manual_secret_setup()
+
+    def _manual_secret_setup(self) -> bool:
+        """Fallback method for manual secret setup"""
+        logger.info("📝 Manual Space Secrets Configuration:")
+        logger.info(f" HF_TOKEN=<hidden>")
+        logger.info(f" HF_MODEL_ID={self.model_id}")
+        if self.subfolder and self.subfolder.strip():
+            logger.info(f" MODEL_SUBFOLDER={self.subfolder}")
+        else:
+            logger.info(" MODEL_SUBFOLDER=(empty - using main model)")
+
+        # GPT-OSS specific variables
+        if self.demo_type == "gpt":
+            model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+            logger.info(f" LORA_MODEL_ID={self.model_id}")
+            logger.info(f" BASE_MODEL_ID=openai/gpt-oss-20b")
+            logger.info(f" MODEL_NAME={model_name}")
+        if self.model_identity:
+            logger.info(f" MODEL_IDENTITY={self.model_identity}")
+        if self.system_message:
+            logger.info(f" SYSTEM_MESSAGE={self.system_message}")
+        if self.developer_message:
+            logger.info(f" DEVELOPER_MESSAGE={self.developer_message}")
+        # Branding variables
+        logger.info(f" HF_USERNAME={self.hf_username}")
+        logger.info(f" BRAND_OWNER_NAME={self.brand_owner_name}")
+        logger.info(f" BRAND_TEAM_NAME={self.brand_team_name}")
+        logger.info(f" BRAND_DISCORD_URL={self.brand_discord_url}")
+        logger.info(f" BRAND_HF_ORG={self.brand_hf_org}")
+        logger.info(f" BRAND_HF_LABEL={self.brand_hf_label}")
+        logger.info(f" BRAND_HF_URL={self.brand_hf_url}")
+        logger.info(f" BRAND_GH_ORG={self.brand_gh_org}")
+        logger.info(f" BRAND_GH_LABEL={self.brand_gh_label}")
+        logger.info(f" BRAND_GH_URL={self.brand_gh_url}")
+        logger.info(f" BRAND_PROJECT_NAME={self.brand_project_name}")
+        logger.info(f" BRAND_PROJECT_URL={self.brand_project_url}")
+
+        # Examples variables
+        if self.examples_type:
+            logger.info(f" EXAMPLES_TYPE={self.examples_type}")
+        if self.disable_examples is not None:
+            logger.info(f" DISABLE_EXAMPLES={'true' if self.disable_examples else 'false'}")
+        if self.examples_json:
+            logger.info(f" EXAMPLES_JSON={self.examples_json}")
+
+        logger.info(f"\n🔧 To set secrets in your Space:")
+        logger.info(f"1. Go to your Space settings: {self.space_url}/settings")
+        logger.info("2. Navigate to the 'Repository secrets' section")
+        logger.info("3. Add the following secrets:")
+        logger.info(f" Name: HF_TOKEN")
+        logger.info(f" Value: <your token>")
+        logger.info(f" Name: HF_MODEL_ID")
+        logger.info(f" Value: {self.model_id}")
+        if self.subfolder and self.subfolder.strip():
+            logger.info(f" Name: MODEL_SUBFOLDER")
+            logger.info(f" Value: {self.subfolder}")
+        else:
+            logger.info(" Name: MODEL_SUBFOLDER")
+            logger.info(" Value: (leave empty)")
+
+        # GPT-OSS specific variables
+        if self.demo_type == "gpt":
+            model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+            logger.info(f" Name: LORA_MODEL_ID")
+            logger.info(f" Value: {self.model_id}")
+            logger.info(f" Name: BASE_MODEL_ID")
+            logger.info(f" Value: openai/gpt-oss-20b")
+            logger.info(f" Name: MODEL_NAME")
+            logger.info(f" Value: {model_name}")
+
+        logger.info("4. Save the secrets")
+
+        return True
+
+    def test_space(self) -> bool:
+        """Test if the Space is working correctly"""
+        try:
+            logger.info("Testing Space...")
+
+            # Wait a bit for the space to build
+            logger.info("Waiting 180 seconds for Space to build...")
+            time.sleep(180)
+
+            # Try to access the space
+            response = requests.get(self.space_url, timeout=30)
+
+            if response.status_code == 200:
+                logger.info(f"✅ Space is accessible: {self.space_url}")
+                return True
+            else:
+                logger.warning(f"⚠️ Space returned status code: {response.status_code}")
|
811 |
+
logger.warning(f"Response: {response.text[:500]}...")
|
812 |
+
return False
|
813 |
+
|
814 |
+
except Exception as e:
|
815 |
+
logger.error(f"❌ Error testing space: {e}")
|
816 |
+
return False
|
817 |
+
|
818 |
+
def deploy(self) -> bool:
|
819 |
+
"""Main deployment method"""
|
820 |
+
logger.info(f"🚀 Starting demo space deployment for {self.model_id}")
|
821 |
+
|
822 |
+
# Step 1: Validate model exists
|
823 |
+
if not self.validate_model_exists():
|
824 |
+
return False
|
825 |
+
|
826 |
+
# Step 2: Create space repository
|
827 |
+
if not self.create_space_repository():
|
828 |
+
return False
|
829 |
+
|
830 |
+
# Step 3: Prepare files
|
831 |
+
temp_dir = self.prepare_space_files()
|
832 |
+
if not temp_dir:
|
833 |
+
return False
|
834 |
+
|
835 |
+
# Step 4: Upload files
|
836 |
+
if not self.upload_files_to_space(temp_dir):
|
837 |
+
return False
|
838 |
+
|
839 |
+
# Step 5: Set space secrets
|
840 |
+
if not self.set_space_secrets():
|
841 |
+
return False
|
842 |
+
|
843 |
+
# Step 6: Clean up temp directory
|
844 |
+
try:
|
845 |
+
shutil.rmtree(temp_dir)
|
846 |
+
logger.info("✅ Cleaned up temporary directory")
|
847 |
+
except Exception as e:
|
848 |
+
logger.warning(f"⚠️ Warning: Could not clean up temp directory: {e}")
|
849 |
+
|
850 |
+
# Step 7: Test space
|
851 |
+
if not self.test_space():
|
852 |
+
logger.warning("⚠️ Space created but may need more time to build")
|
853 |
+
logger.info("Please check the Space manually in a few minutes")
|
854 |
+
|
855 |
+
logger.info(f"🎉 Demo space deployment completed!")
|
856 |
+
logger.info(f"📊 Space URL: {self.space_url}")
|
857 |
+
logger.info(f"🔧 Space configuration: {self.space_url}/settings")
|
858 |
+
|
859 |
+
return True
|
860 |
+
|
861 |
+
def main():
|
862 |
+
"""Main function for command line usage"""
|
863 |
+
print("Demo Space Deployment Script")
|
864 |
+
print("=" * 40)
|
865 |
+
|
866 |
+
parser = argparse.ArgumentParser(description="Deploy demo space to Hugging Face Spaces")
|
867 |
+
parser.add_argument("--hf-token", required=True, help="Hugging Face token")
|
868 |
+
parser.add_argument(
|
869 |
+
"--space-secret-token",
|
870 |
+
required=False,
|
871 |
+
help="Token to store as Space secret HF_TOKEN (defaults to --hf-token). Use a READ token here for least privilege.",
|
872 |
+
)
|
873 |
+
parser.add_argument("--hf-username", required=True, help="Hugging Face username")
|
874 |
+
parser.add_argument("--model-id", required=True, help="Model ID to deploy demo for")
|
875 |
+
parser.add_argument("--subfolder", default="int4", help="Model subfolder (default: int4)")
|
876 |
+
parser.add_argument("--space-name", help="Custom space name (optional)")
|
877 |
+
parser.add_argument("--demo-type", choices=["smol", "gpt"], help="Demo type: 'smol' for SmolLM, 'gpt' for GPT-OSS (auto-detected if not specified)")
|
878 |
+
parser.add_argument("--config-file", help="Path to the training config file to import context (system/developer/model_identity)")
|
879 |
+
# Examples configuration
|
880 |
+
parser.add_argument("--examples-type", choices=["general", "medical"], help="Examples pack to enable in the demo UI")
|
881 |
+
parser.add_argument("--disable-examples", action="store_true", help="Disable rendering of example prompts in the UI")
|
882 |
+
parser.add_argument("--examples-json", help="Custom examples JSON (list[str]) to override built-in examples")
|
883 |
+
# Branding customization
|
884 |
+
parser.add_argument("--brand-owner-name", help="Owner name shown in the UI title (defaults to HF username)")
|
885 |
+
parser.add_argument("--brand-team-name", help="Team name shown in Join Us (defaults to Team<owner>)")
|
886 |
+
parser.add_argument("--brand-discord-url", help="Discord invite URL for Join Us section")
|
887 |
+
parser.add_argument("--brand-hf-org", help="Hugging Face org/username to link in Join Us")
|
888 |
+
parser.add_argument("--brand-hf-label", help="Label for the HF link (defaults to org)")
|
889 |
+
parser.add_argument("--brand-hf-url", help="Custom HF link URL (defaults to https://huggingface.co/<org>)")
|
890 |
+
parser.add_argument("--brand-gh-org", help="GitHub org/username to link in Join Us")
|
891 |
+
parser.add_argument("--brand-gh-label", help="Label for the GitHub link (defaults to org)")
|
892 |
+
parser.add_argument("--brand-gh-url", help="Custom GitHub link URL (defaults to https://github.com/<org>)")
|
893 |
+
parser.add_argument("--brand-project-name", help="Project name to link in Join Us")
|
894 |
+
parser.add_argument("--brand-project-url", help="Project URL to link in Join Us")
|
895 |
+
|
896 |
+
args = parser.parse_args()
|
897 |
+
|
898 |
+
deployer = DemoSpaceDeployer(
|
899 |
+
hf_token=args.hf_token,
|
900 |
+
space_secret_token=(args.space_secret_token or None),
|
901 |
+
hf_username=args.hf_username,
|
902 |
+
model_id=args.model_id,
|
903 |
+
subfolder=args.subfolder,
|
904 |
+
space_name=args.space_name,
|
905 |
+
demo_type=args.demo_type,
|
906 |
+
config_file=args.config_file,
|
907 |
+
examples_type=args.examples_type,
|
908 |
+
disable_examples=(True if getattr(args, 'disable_examples', False) else None),
|
909 |
+
examples_json=args.examples_json,
|
910 |
+
brand_owner_name=args.brand_owner_name,
|
911 |
+
brand_team_name=args.brand_team_name,
|
912 |
+
brand_discord_url=args.brand_discord_url,
|
913 |
+
brand_hf_org=args.brand_hf_org,
|
914 |
+
brand_hf_label=args.brand_hf_label,
|
915 |
+
brand_hf_url=args.brand_hf_url,
|
916 |
+
brand_gh_org=args.brand_gh_org,
|
917 |
+
brand_gh_label=args.brand_gh_label,
|
918 |
+
brand_gh_url=args.brand_gh_url,
|
919 |
+
brand_project_name=args.brand_project_name,
|
920 |
+
brand_project_url=args.brand_project_url,
|
921 |
+
)
|
922 |
+
|
923 |
+
success = deployer.deploy()
|
924 |
+
|
925 |
+
if success:
|
926 |
+
print("\n✅ Deployment successful!")
|
927 |
+
print(f"🌐 Your Demo Space: {deployer.space_url}")
|
928 |
+
print(f"👤 Username: {deployer.hf_username}")
|
929 |
+
print(f"🤖 Model: {deployer.model_id}")
|
930 |
+
print("\nNext steps:")
|
931 |
+
print("1. Wait for the Space to build (usually 2-5 minutes)")
|
932 |
+
print("2. Secrets have been automatically set via API")
|
933 |
+
print("3. Test the interface by visiting the Space URL")
|
934 |
+
print("4. Share your demo with others!")
|
935 |
+
print("\nIf the Space doesn't work immediately, check:")
|
936 |
+
print("- The Space logs at the Space URL")
|
937 |
+
print("- That all files were uploaded correctly")
|
938 |
+
print("- That the HF token has write permissions")
|
939 |
+
print("- That the secrets were set correctly in Space settings")
|
940 |
+
else:
|
941 |
+
print("\n❌ Deployment failed!")
|
942 |
+
print("Check the error messages above and try again.")
|
943 |
+
print("\nTroubleshooting:")
|
944 |
+
print("1. Verify your HF token has write permissions")
|
945 |
+
print("2. Check that the space name is available")
|
946 |
+
print("3. Verify the model exists and is accessible")
|
947 |
+
print("4. Try creating the space manually on HF first")
|
948 |
+
|
949 |
+
sys.exit(0 if success else 1)
|
950 |
+
|
951 |
+
if __name__ == "__main__":
|
952 |
+
main()
|
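The deployer can also be driven programmatically. A minimal sketch, assuming the class is importable from `scripts/deploy_demo_space.py` and that the keyword arguments omitted here carry the same defaults `main()` relies on; the token strings are placeholders:

```python
# Minimal sketch: drive DemoSpaceDeployer directly instead of via the CLI above.
# Assumes the omitted keyword arguments (subfolder, space_name, branding, etc.)
# default sensibly, as when main() passes them through; tokens are placeholders.
from scripts.deploy_demo_space import DemoSpaceDeployer

deployer = DemoSpaceDeployer(
    hf_token="hf_xxx_write",           # write token used to create and push the Space
    space_secret_token="hf_xxx_read",  # read token stored as the Space's HF_TOKEN secret
    hf_username="your-username",
    model_id="your-username/voxtral-finetuned",
)

if deployer.deploy():
    print(f"Space ready at {deployer.space_url}")
```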
scripts/generate_model_card.py
ADDED
@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
Generate unified model card from template
Handles template variables and conditional sections for quantized models
"""

import os
import re
import argparse
import logging
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime

logger = logging.getLogger(__name__)

class ModelCardGenerator:
    """Generate unified model cards from templates"""

    def __init__(self, template_path: str = "templates/model_card.md"):
        self.template_path = Path(template_path)
        if not self.template_path.exists():
            raise FileNotFoundError(f"Template not found: {self.template_path}")

    def load_template(self) -> str:
        """Load the model card template"""
        with open(self.template_path, 'r', encoding='utf-8') as f:
            return f.read()

    def process_conditionals(self, content: str, variables: Dict[str, Any]) -> str:
        """Process conditional sections in the template"""
        # Handle {{#if variable}}...{{/if}} blocks
        pattern = r'\{\{#if\s+(\w+)\}\}(.*?)\{\{/if\}\}'

        def replace_conditional(match):
            variable_name = match.group(1)
            conditional_content = match.group(2)

            # Check if variable exists and is truthy
            if variable_name in variables and variables[variable_name]:
                return conditional_content
            else:
                return ""

        return re.sub(pattern, replace_conditional, content, flags=re.DOTALL)

    def replace_variables(self, content: str, variables: Dict[str, Any]) -> str:
        """Replace template variables with actual values"""
        for key, value in variables.items():
            placeholder = f"{{{{{key}}}}}"
            content = content.replace(placeholder, str(value))

        return content

    def generate_model_card(self, variables: Dict[str, Any]) -> str:
        """Generate the complete model card"""
        # Load template
        content = self.load_template()

        # Process conditionals first
        content = self.process_conditionals(content, variables)

        # Replace variables
        content = self.replace_variables(content, variables)

        return content

    def save_model_card(self, content: str, output_path: str) -> bool:
        """Save the generated model card"""
        try:
            output_file = Path(output_path)
            output_file.parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(content)

            logger.info(f"✅ Model card saved to: {output_file}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to save model card: {e}")
            return False

def create_default_variables() -> Dict[str, Any]:
    """Create default variables for the model card"""
    return {
        "model_name": "SmolLM3 Fine-tuned Model",
        "model_description": "A fine-tuned version of SmolLM3-3B for improved text generation and conversation capabilities.",
        "repo_name": "your-username/model-name",
        "base_model": "HuggingFaceTB/SmolLM3-3B",
        "dataset_name": "OpenHermes-FR",
        "training_config_type": "Custom Configuration",
        "trainer_type": "SFTTrainer",
        "batch_size": "8",
        "gradient_accumulation_steps": "16",
        "learning_rate": "5e-6",
        "max_epochs": "3",
        "max_seq_length": "2048",
        "hardware_info": "GPU (H100/A100)",
        "experiment_name": "smollm3-experiment",
        "trackio_url": "https://trackio.space/experiment",
        "dataset_repo": "tonic/trackio-experiments",
        "dataset_size": "~80K samples",
        "dataset_format": "Chat format",
        "author_name": "Your Name",
        "model_name_slug": "smollm3-fine-tuned",
        "quantized_models": False,
        "dataset_sample_size": None,
        "training_loss": "N/A",
        "validation_loss": "N/A",
        "perplexity": "N/A"
    }

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Generate unified model card")
    parser.add_argument("--template", default="templates/model_card.md",
                        help="Path to model card template")
    parser.add_argument("--output", default="README.md",
                        help="Output path for generated model card")
    parser.add_argument("--repo-name", required=True,
                        help="Hugging Face repository name")
    parser.add_argument("--model-name", help="Model name")
    parser.add_argument("--experiment-name", help="Experiment name")
    parser.add_argument("--dataset-name", help="Dataset name")
    parser.add_argument("--training-config", help="Training configuration type")
    parser.add_argument("--trainer-type", help="Trainer type")
    parser.add_argument("--batch-size", help="Batch size")
    parser.add_argument("--learning-rate", help="Learning rate")
    parser.add_argument("--max-epochs", help="Maximum epochs")
    parser.add_argument("--max-seq-length", help="Maximum sequence length")
    parser.add_argument("--hardware-info", help="Hardware information")
    parser.add_argument("--trackio-url", help="Trackio URL")
    parser.add_argument("--dataset-repo", help="Dataset repository")
    parser.add_argument("--author-name", help="Author name")
    parser.add_argument("--quantized-models", action="store_true",
                        help="Include quantized models")
    parser.add_argument("--dataset-sample-size", help="Dataset sample size")
    parser.add_argument("--training-loss", help="Training loss value")
    parser.add_argument("--validation-loss", help="Validation loss value")
    parser.add_argument("--perplexity", help="Perplexity value")

    return parser.parse_args()

def main():
    """Main function"""
    args = parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    try:
        # Create generator
        generator = ModelCardGenerator(args.template)

        # Create variables dictionary
        variables = create_default_variables()

        # Override with command line arguments
        if args.repo_name:
            variables["repo_name"] = args.repo_name
        if args.model_name:
            variables["model_name"] = args.model_name
        if args.experiment_name:
            variables["experiment_name"] = args.experiment_name
        if args.dataset_name:
            variables["dataset_name"] = args.dataset_name
        if args.training_config:
            variables["training_config_type"] = args.training_config
        if args.trainer_type:
            variables["trainer_type"] = args.trainer_type
        if args.batch_size:
            variables["batch_size"] = args.batch_size
        if args.learning_rate:
            variables["learning_rate"] = args.learning_rate
        if args.max_epochs:
            variables["max_epochs"] = args.max_epochs
        if args.max_seq_length:
            variables["max_seq_length"] = args.max_seq_length
        if args.hardware_info:
            variables["hardware_info"] = args.hardware_info
        if args.trackio_url:
            variables["trackio_url"] = args.trackio_url
        if args.dataset_repo:
            variables["dataset_repo"] = args.dataset_repo
        if args.author_name:
            variables["author_name"] = args.author_name
        if args.quantized_models:
            variables["quantized_models"] = True
        if args.dataset_sample_size:
            variables["dataset_sample_size"] = args.dataset_sample_size
        if args.training_loss:
            variables["training_loss"] = args.training_loss
        if args.validation_loss:
            variables["validation_loss"] = args.validation_loss
        if args.perplexity:
            variables["perplexity"] = args.perplexity

        # Generate model card
        print("🔄 Generating model card...")
        content = generator.generate_model_card(variables)

        # Save model card
        if generator.save_model_card(content, args.output):
            print("✅ Model card generated successfully!")
            print(f"📄 Output: {args.output}")
        else:
            print("❌ Failed to generate model card")
            return 1

        return 0

    except Exception as e:
        logger.error(f"❌ Error generating model card: {e}")
        return 1

if __name__ == "__main__":
    exit(main())
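For completeness, a minimal in-process sketch of the generator above; the repo id and author values are placeholders, and everything else comes straight from the API shown:

```python
# Minimal sketch: generate a model card without going through the CLI.
from scripts.generate_model_card import ModelCardGenerator, create_default_variables

generator = ModelCardGenerator("templates/model_card.md")
variables = create_default_variables()
variables.update({
    "repo_name": "your-username/voxtral-finetuned",  # placeholder repo id
    "model_name": "voxtral-finetuned",
    "author_name": "Your Name",
})
card = generator.generate_model_card(variables)
generator.save_model_card(card, "README.md")
```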
scripts/push_to_huggingface.py
ADDED
@@ -0,0 +1,700 @@
#!/usr/bin/env python3
"""
Push Trained Model and Results to Hugging Face Hub
Integrates with Trackio monitoring and HF Datasets for complete model deployment
"""

import os
import json
import argparse
import logging
import time
from pathlib import Path
from typing import Dict, Any, Optional, List
from datetime import datetime
import subprocess
import shutil
import platform

# Set timeout for HF operations to prevent hanging
os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '300'
os.environ['HF_HUB_UPLOAD_TIMEOUT'] = '600'

try:
    from huggingface_hub import HfApi, create_repo, upload_file
    from huggingface_hub import snapshot_download, hf_hub_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")

try:
    import sys
    import os
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
    from monitoring import SmolLM3Monitor
    MONITORING_AVAILABLE = True
except ImportError:
    MONITORING_AVAILABLE = False
    print("Warning: monitoring module not available")

logger = logging.getLogger(__name__)

class TimeoutError(Exception):
    """Custom timeout exception"""
    pass

def timeout_handler(signum, frame):
    """Signal handler for timeout"""
    raise TimeoutError("Operation timed out")

class HuggingFacePusher:
    """Push trained models and results to Hugging Face Hub with HF Datasets integration"""

    def __init__(
        self,
        model_path: str,
        repo_name: str,
        token: Optional[str] = None,
        private: bool = False,
        trackio_url: Optional[str] = None,
        experiment_name: Optional[str] = None,
        dataset_repo: Optional[str] = None,
        hf_token: Optional[str] = None,
        author_name: Optional[str] = None,
        model_description: Optional[str] = None,
        training_config_type: Optional[str] = None,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        batch_size: Optional[str] = None,
        learning_rate: Optional[str] = None,
        max_epochs: Optional[str] = None,
        max_seq_length: Optional[str] = None,
        trainer_type: Optional[str] = None
    ):
        self.model_path = Path(model_path)
        # Original user input (may be just the repo name without username)
        self.repo_name = repo_name
        self.token = token or hf_token or os.getenv('HF_TOKEN')
        self.private = private
        self.trackio_url = trackio_url
        self.experiment_name = experiment_name
        self.author_name = author_name
        self.model_description = model_description

        # Training configuration details for model card generation
        self.training_config_type = training_config_type
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.max_seq_length = max_seq_length
        self.trainer_type = trainer_type

        # HF Datasets configuration
        self.dataset_repo = dataset_repo or os.getenv('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
        self.hf_token = hf_token or os.getenv('HF_TOKEN')

        # Initialize HF API
        if HF_AVAILABLE:
            self.api = HfApi(token=self.token)
        else:
            raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")

        # Resolve the full repo id (username/repo) if user only provided repo name
        self.repo_id = self._resolve_repo_id(self.repo_name)

        # Initialize monitoring if available
        self.monitor = None
        if MONITORING_AVAILABLE:
            self.monitor = SmolLM3Monitor(
                experiment_name=experiment_name or "model_push",
                trackio_url=trackio_url,
                enable_tracking=bool(trackio_url),
                hf_token=self.hf_token,
                dataset_repo=self.dataset_repo
            )

        logger.info(f"Initialized HuggingFacePusher for {self.repo_id}")
        logger.info(f"Dataset repository: {self.dataset_repo}")

    def _resolve_repo_id(self, repo_name: str) -> str:
        """Return a fully-qualified repo id in the form username/repo.

        If the provided name already contains a '/', it is returned unchanged.
        Otherwise, we attempt to derive the username from the authenticated token
        or from the HF_USERNAME environment variable.
        """
        try:
            if "/" in repo_name:
                return repo_name

            # Need a username. Prefer API whoami(), fallback to env HF_USERNAME
            username: Optional[str] = None
            if self.token:
                try:
                    user_info = self.api.whoami()
                    username = user_info.get("name") or user_info.get("username")
                except Exception:
                    username = None

            if not username:
                username = os.getenv("HF_USERNAME")

            if not username:
                raise ValueError(
                    "Username could not be determined. Provide a token or set HF_USERNAME, "
                    "or pass a fully-qualified repo id 'username/repo'."
                )

            return f"{username}/{repo_name}"
        except Exception as resolve_error:
            logger.error(f"Failed to resolve full repo id for '{repo_name}': {resolve_error}")
            # Fall back to provided value (may fail later at create/upload)
            return repo_name

    def create_repository(self) -> bool:
        """Create the Hugging Face repository"""
        try:
            logger.info(f"Creating repository: {self.repo_id}")

            # Create repository with timeout handling
            try:
                # Create repository
                create_repo(
                    repo_id=self.repo_id,
                    token=self.token,
                    private=self.private,
                    exist_ok=True
                )

                logger.info(f"✅ Repository created: https://huggingface.co/{self.repo_id}")
                return True

            except Exception as e:
                logger.error(f"❌ Repository creation failed: {e}")
                return False

        except Exception as e:
            logger.error(f"❌ Failed to create repository: {e}")
            return False

    def validate_model_path(self) -> bool:
        """Validate that the model path contains required files"""
        # Support both safetensors and pytorch formats
        required_files = [
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json"
        ]

        # Check for model files (either safetensors or pytorch)
        model_files = [
            "model.safetensors.index.json",  # Safetensors format
            "pytorch_model.bin"  # PyTorch format
        ]

        missing_files = []
        for file in required_files:
            if not (self.model_path / file).exists():
                missing_files.append(file)

        # Check if at least one model file exists
        model_file_exists = any((self.model_path / file).exists() for file in model_files)
        if not model_file_exists:
            missing_files.extend(model_files)

        if missing_files:
            logger.error(f"❌ Missing required files: {missing_files}")
            return False

        logger.info("✅ Model files validated")
        return True

    def create_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
        """Create a comprehensive model card using the generate_model_card.py script"""
        try:
            # Import the model card generator
            import sys
            sys.path.append(os.path.join(os.path.dirname(__file__)))
            from generate_model_card import ModelCardGenerator, create_default_variables

            # Create generator
            generator = ModelCardGenerator()

            # Create variables for the model card
            variables = create_default_variables()

            # Update with actual values
            variables.update({
                "repo_name": self.repo_id,
                "model_name": self.repo_id.split('/')[-1],
                "experiment_name": self.experiment_name or "model_push",
                "dataset_repo": self.dataset_repo,
                "author_name": self.author_name or "Model Author",
                "model_description": self.model_description or "A fine-tuned version of SmolLM3-3B for improved text generation capabilities.",
                "training_config_type": self.training_config_type or "Custom Configuration",
                "base_model": self.model_name or "HuggingFaceTB/SmolLM3-3B",
                "dataset_name": self.dataset_name or "Custom Dataset",
                "trainer_type": self.trainer_type or "SFTTrainer",
                "batch_size": str(self.batch_size) if self.batch_size else "8",
                "learning_rate": str(self.learning_rate) if self.learning_rate else "5e-6",
                "max_epochs": str(self.max_epochs) if self.max_epochs else "3",
                "max_seq_length": str(self.max_seq_length) if self.max_seq_length else "2048",
                "hardware_info": self._get_hardware_info(),
                "trackio_url": self.trackio_url or "N/A",
                "training_loss": str(results.get('train_loss', 'N/A')),
                "validation_loss": str(results.get('eval_loss', 'N/A')),
                "perplexity": str(results.get('perplexity', 'N/A')),
                "quantized_models": False  # Set to True if quantized models are available
            })

            # Generate the model card
            model_card_content = generator.generate_model_card(variables)

            logger.info("✅ Model card generated using generate_model_card.py")
            return model_card_content

        except Exception as e:
            logger.error(f"❌ Failed to generate model card with generator: {e}")
            logger.info("🔄 Falling back to simple model card")
            return self._create_simple_model_card(training_config, results)

    def _create_simple_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
        """Create a simple model card without complex YAML to avoid formatting issues"""
        return f"""---
language:
- en
- fr
license: apache-2.0
tags:
- smollm3
- fine-tuned
- causal-lm
- text-generation
pipeline_tag: text-generation
base_model: HuggingFaceTB/SmolLM3-3B
---

# {self.repo_id.split('/')[-1]}

This is a fine-tuned SmolLM3 model based on the HuggingFaceTB/SmolLM3-3B architecture.

## Model Details

- **Base Model**: HuggingFaceTB/SmolLM3-3B
- **Fine-tuning Method**: Supervised Fine-tuning
- **Training Date**: {datetime.now().strftime('%Y-%m-%d')}
- **Model Size**: {self._get_model_size():.1f} GB
- **Dataset Repository**: {self.dataset_repo}
- **Hardware**: {self._get_hardware_info()}

## Training Configuration

```json
{json.dumps(training_config, indent=2)}
```

## Training Results

```json
{json.dumps(results, indent=2)}
```

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("{self.repo_id}")
tokenizer = AutoTokenizer.from_pretrained("{self.repo_id}")

# Generate text
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Information

- **Base Model**: HuggingFaceTB/SmolLM3-3B
- **Hardware**: {self._get_hardware_info()}
- **Training Time**: {results.get('training_time_hours', 'Unknown')} hours
- **Final Loss**: {results.get('final_loss', 'Unknown')}
- **Final Accuracy**: {results.get('final_accuracy', 'Unknown')}
- **Dataset Repository**: {self.dataset_repo}

## Model Performance

- **Training Loss**: {results.get('train_loss', 'Unknown')}
- **Validation Loss**: {results.get('eval_loss', 'Unknown')}
- **Training Steps**: {results.get('total_steps', 'Unknown')}

## Experiment Tracking

This model was trained with experiment tracking enabled. Training metrics and configuration are stored in the HF Dataset repository: `{self.dataset_repo}`

## Limitations and Biases

This model is fine-tuned for specific tasks and may not generalize well to all use cases. Please evaluate the model's performance on your specific task before deployment.

## License

This model is licensed under the Apache 2.0 License.
"""

    def _get_model_size(self) -> float:
        """Get model size in GB"""
        try:
            total_size = 0
            for file in self.model_path.rglob("*"):
                if file.is_file():
                    total_size += file.stat().st_size
            return total_size / (1024**3)  # Convert to GB
        except:
            return 0.0

    def _get_hardware_info(self) -> str:
        """Get hardware information"""
        try:
            import torch
            if torch.cuda.is_available():
                gpu_name = torch.cuda.get_device_name(0)
                return f"GPU: {gpu_name}"
            else:
                return "CPU"
        except:
            return "Unknown"

    def upload_model_files(self) -> bool:
        """Upload model files to Hugging Face Hub with timeout protection"""
        try:
            logger.info("Uploading model files...")

            # Upload all files in the model directory
            for file_path in self.model_path.rglob("*"):
                if file_path.is_file():
                    relative_path = file_path.relative_to(self.model_path)
                    remote_path = str(relative_path)

                    logger.info(f"Uploading {relative_path}")

                    try:
                        upload_file(
                            path_or_fileobj=str(file_path),
                            path_in_repo=remote_path,
                            repo_id=self.repo_id,
                            token=self.token
                        )
                        logger.info(f"✅ Uploaded {relative_path}")

                    except Exception as e:
                        logger.error(f"❌ Failed to upload {relative_path}: {e}")
                        return False

            logger.info("✅ Model files uploaded successfully")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to upload model files: {e}")
            return False

    def upload_training_results(self, results_path: str) -> bool:
        """Upload training results and logs"""
        try:
            logger.info("Uploading training results...")

            results_files = [
                "train_results.json",
                "eval_results.json",
                "training_config.json",
                "training.log"
            ]

            for file_name in results_files:
                file_path = Path(results_path) / file_name
                if file_path.exists():
                    logger.info(f"Uploading {file_name}")
                    upload_file(
                        path_or_fileobj=str(file_path),
                        path_in_repo=f"training_results/{file_name}",
                        repo_id=self.repo_id,
                        token=self.token
                    )

            logger.info("✅ Training results uploaded successfully")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to upload training results: {e}")
            return False

    def create_readme(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> bool:
        """Create and upload README.md"""
        try:
            logger.info("Creating README.md...")

            readme_content = f"""# {self.repo_id.split('/')[-1]}

A fine-tuned SmolLM3 model for text generation tasks.

## Quick Start

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{self.repo_id}")
tokenizer = AutoTokenizer.from_pretrained("{self.repo_id}")

# Generate text
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Model Information

- **Base Model**: HuggingFaceTB/SmolLM3-3B
- **Fine-tuning Date**: {datetime.now().strftime('%Y-%m-%d')}
- **Model Size**: {self._get_model_size():.1f} GB
- **Training Steps**: {results.get('total_steps', 'Unknown')}
- **Final Loss**: {results.get('final_loss', 'Unknown')}
- **Dataset Repository**: {self.dataset_repo}

## Training Configuration

```json
{json.dumps(training_config, indent=2)}
```

## Performance Metrics

```json
{json.dumps(results, indent=2)}
```

## Experiment Tracking

Training metrics and configuration are stored in the HF Dataset repository: `{self.dataset_repo}`

## Files

- `model.safetensors.index.json`: Model weights (safetensors format)
- `config.json`: Model configuration
- `tokenizer.json`: Tokenizer configuration
- `training_results/`: Training logs and results

## License

MIT License
"""

            # Write README to temporary file
            readme_path = Path("temp_readme.md")
            with open(readme_path, "w") as f:
                f.write(readme_content)

            # Upload README
            upload_file(
                path_or_fileobj=str(readme_path),
                path_in_repo="README.md",
                token=self.token,
                repo_id=self.repo_id
            )

            # Clean up
            readme_path.unlink()

            logger.info("✅ README.md uploaded successfully")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to create README: {e}")
            return False

    def log_to_trackio(self, action: str, details: Dict[str, Any]):
        """Log push action to Trackio and HF Datasets"""
        if self.monitor:
            try:
                # Log to Trackio
                self.monitor.log_metrics({
                    "push_action": action,
                    "repo_name": self.repo_id,
                    "model_size_gb": self._get_model_size(),
                    "dataset_repo": self.dataset_repo,
                    **details
                })

                # Log training summary
                self.monitor.log_training_summary({
                    "model_push": True,
                    "model_repo": self.repo_id,
                    "dataset_repo": self.dataset_repo,
                    "push_date": datetime.now().isoformat(),
                    **details
                })

                logger.info(f"✅ Logged {action} to Trackio and HF Datasets")
            except Exception as e:
                logger.error(f"❌ Failed to log to Trackio: {e}")

    def push_model(self, training_config: Optional[Dict[str, Any]] = None,
                   results: Optional[Dict[str, Any]] = None) -> bool:
        """Complete model push process with HF Datasets integration"""
        logger.info(f"🚀 Starting model push to {self.repo_id}")
        logger.info(f"📊 Dataset repository: {self.dataset_repo}")

        # Validate model path
        if not self.validate_model_path():
            return False

        # Create repository
        if not self.create_repository():
            return False

        # Load training config and results if not provided
        if training_config is None:
            training_config = self._load_training_config()

        if results is None:
            results = self._load_training_results()

        # Create and upload model card
        model_card = self.create_model_card(training_config, results)
        model_card_path = Path("temp_model_card.md")
        with open(model_card_path, "w") as f:
            f.write(model_card)

        try:
            upload_file(
                path_or_fileobj=str(model_card_path),
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token
            )
        finally:
            model_card_path.unlink()

        # Upload model files
        if not self.upload_model_files():
            return False

        # Upload training results
        if results:
            self.upload_training_results(str(self.model_path))

        # Log to Trackio and HF Datasets
        self.log_to_trackio("model_push", {
            "model_path": str(self.model_path),
            "repo_name": self.repo_name,
            "private": self.private,
            "training_config": training_config,
            "results": results
        })

        logger.info(f"🎉 Model successfully pushed to: https://huggingface.co/{self.repo_id}")
        logger.info(f"📊 Experiment data stored in: {self.dataset_repo}")
        return True

    def _load_training_config(self) -> Dict[str, Any]:
        """Load training configuration"""
        config_path = self.model_path / "training_config.json"
        if config_path.exists():
            with open(config_path, "r") as f:
                return json.load(f)
        return {"model_name": "HuggingFaceTB/SmolLM3-3B"}

    def _load_training_results(self) -> Dict[str, Any]:
        """Load training results"""
        results_path = self.model_path / "train_results.json"
        if results_path.exists():
            with open(results_path, "r") as f:
                return json.load(f)
        return {"final_loss": "Unknown", "total_steps": "Unknown"}

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Push trained model to Hugging Face Hub')

    # Required arguments
    parser.add_argument('model_path', type=str, help='Path to trained model directory')
    parser.add_argument('repo_name', type=str, help='Hugging Face repository name (repo-name). Username will be auto-detected from your token.')

    # Optional arguments
    parser.add_argument('--token', type=str, default=None, help='Hugging Face token')
    parser.add_argument('--hf-token', type=str, default=None, help='Hugging Face token (alternative to --token)')
    parser.add_argument('--private', action='store_true', help='Make repository private')
    parser.add_argument('--trackio-url', type=str, default=None, help='Trackio Space URL for logging')
    parser.add_argument('--experiment-name', type=str, default=None, help='Experiment name for Trackio')
    parser.add_argument('--dataset-repo', type=str, default=None, help='HF Dataset repository for experiment storage')
    parser.add_argument('--author-name', type=str, default=None, help='Author name for model card')
    parser.add_argument('--model-description', type=str, default=None, help='Model description for model card')
    parser.add_argument('--training-config-type', type=str, default=None, help='Training configuration type')
    parser.add_argument('--model-name', type=str, default=None, help='Base model name')
    parser.add_argument('--dataset-name', type=str, default=None, help='Dataset name')
    parser.add_argument('--batch-size', type=str, default=None, help='Batch size')
    parser.add_argument('--learning-rate', type=str, default=None, help='Learning rate')
    parser.add_argument('--max-epochs', type=str, default=None, help='Maximum epochs')
    parser.add_argument('--max-seq-length', type=str, default=None, help='Maximum sequence length')
    parser.add_argument('--trainer-type', type=str, default=None, help='Trainer type')

    return parser.parse_args()

def main():
    """Main function"""
    args = parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    logger.info("Starting model push to Hugging Face Hub")

    # Initialize pusher
    try:
        pusher = HuggingFacePusher(
            model_path=args.model_path,
            repo_name=args.repo_name,
            token=args.token,
            private=args.private,
            trackio_url=args.trackio_url,
            experiment_name=args.experiment_name,
            dataset_repo=args.dataset_repo,
            hf_token=args.hf_token,
            author_name=args.author_name,
            model_description=args.model_description,
            training_config_type=args.training_config_type,
            model_name=args.model_name,
            dataset_name=args.dataset_name,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_epochs=args.max_epochs,
            max_seq_length=args.max_seq_length,
            trainer_type=args.trainer_type
        )

        # Push model
        success = pusher.push_model()

        if success:
            logger.info("✅ Model push completed successfully!")
            logger.info(f"🌐 View your model at: https://huggingface.co/{args.repo_name}")
            if args.dataset_repo:
                logger.info(f"📊 View experiment data at: https://huggingface.co/datasets/{args.dataset_repo}")
        else:
            logger.error("❌ Model push failed!")
            return 1

    except Exception as e:
        logger.error(f"❌ Error during model push: {e}")
        return 1

    return 0

if __name__ == "__main__":
    exit(main())
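Likewise, a minimal sketch of calling the pusher in-process; the paths and token are placeholders, and `repo_name` may omit the username since `_resolve_repo_id()` derives it from the token:

```python
# Minimal sketch: push a trained model programmatically (mirrors the CLI in main()).
from scripts.push_to_huggingface import HuggingFacePusher

pusher = HuggingFacePusher(
    model_path="./voxtral-finetuned",  # directory holding config.json, tokenizer files, weights
    repo_name="voxtral-finetuned",     # username auto-resolved from the token
    token="hf_xxx_write",              # placeholder write token
)

if pusher.push_model():
    print(f"Pushed to https://huggingface.co/{pusher.repo_id}")
```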
train_lora.py → scripts/train.py
RENAMED
@@ -1,14 +1,16 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
|
|
|
|
|
|
|
3 |
import torch
|
4 |
-
from datasets import load_dataset, Audio
|
5 |
from transformers import (
|
6 |
VoxtralForConditionalGeneration,
|
7 |
VoxtralProcessor,
|
8 |
Trainer,
|
9 |
TrainingArguments,
|
10 |
)
|
11 |
-
from peft import LoraConfig, get_peft_model
|
12 |
|
13 |
|
14 |
class VoxtralDataCollator:
|
@@ -95,82 +97,114 @@ class VoxtralDataCollator:
|
|
95 |
|
96 |
return batch
|
97 |
|
98 |
-
def
|
99 |
-
"""Load
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
return train_dataset, eval_dataset
|
113 |
|
114 |
|
115 |
def main():
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
122 |
print(f"Using device: {torch_device}")
|
123 |
-
|
124 |
-
# Load processor and model
|
125 |
print("Loading processor and model...")
|
126 |
processor = VoxtralProcessor.from_pretrained(model_checkpoint)
|
127 |
-
# Load model with LoRA configuration
|
128 |
-
config = LoraConfig(
|
129 |
-
r=8, # Rank of LoRA
|
130 |
-
lora_alpha=32,
|
131 |
-
lora_dropout=0.0,
|
132 |
-
bias="none",
|
133 |
-
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
134 |
-
task_type="SEQ_2_SEQ_LM",
|
135 |
-
)
|
136 |
-
# print number of parameters in model
|
137 |
model = VoxtralForConditionalGeneration.from_pretrained(
|
138 |
model_checkpoint,
|
139 |
torch_dtype=torch.bfloat16,
|
140 |
device_map="auto"
|
141 |
)
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
# Setup data collator
|
152 |
data_collator = VoxtralDataCollator(processor, model_checkpoint)
|
153 |
-
|
154 |
-
# Simple training arguments
|
155 |
training_args = TrainingArguments(
|
156 |
output_dir=output_dir,
|
157 |
-
per_device_train_batch_size=
|
158 |
-
per_device_eval_batch_size=
|
159 |
-
gradient_accumulation_steps=
|
160 |
-
learning_rate=
|
161 |
-
num_train_epochs=
|
162 |
bf16=True,
|
163 |
-
logging_steps=
|
164 |
-
eval_steps=
|
165 |
-
save_steps=
|
166 |
eval_strategy="steps" if eval_dataset else "no",
|
167 |
save_strategy="steps",
|
168 |
report_to="none",
|
169 |
remove_unused_columns=False,
|
170 |
dataloader_num_workers=1,
|
171 |
)
|
172 |
-
|
173 |
-
# Setup trainer
|
174 |
trainer = Trainer(
|
175 |
model=model,
|
176 |
args=training_args,
|
@@ -178,22 +212,18 @@ def main():
|
|
178 |
eval_dataset=eval_dataset,
|
179 |
data_collator=data_collator,
|
180 |
)
|
181 |
-
|
182 |
-
# Start training
|
183 |
print("Starting training...")
|
184 |
trainer.train()
|
185 |
|
186 |
-
|
187 |
-
# Save model and processor
|
188 |
print(f"Saving model to {output_dir}")
|
189 |
trainer.save_model()
|
190 |
processor.save_pretrained(output_dir)
|
191 |
-
|
192 |
-
# Final evaluation
|
193 |
if eval_dataset:
|
194 |
results = trainer.evaluate()
|
195 |
print(f"Final evaluation results: {results}")
|
196 |
-
|
197 |
print("Training completed successfully!")
|
198 |
|
199 |
if __name__ == "__main__":
|
|
|
#!/usr/bin/env python3

import argparse
import json
from pathlib import Path
import torch
from datasets import load_dataset, Audio, Dataset
from transformers import (
    VoxtralForConditionalGeneration,
    VoxtralProcessor,
    Trainer,
    TrainingArguments,
)


class VoxtralDataCollator:
    ...
        return batch


def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
    """Load local JSONL with fields {audio_path, text} into a Dataset with audio column."""
    records = []
    jsonl_file = Path(jsonl_path)
    if not jsonl_file.exists():
        raise FileNotFoundError(f"Dataset jsonl not found: {jsonl_path}")
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            audio_path = obj.get("audio_path") or obj.get("audio")
            text = obj.get("text")
            if not audio_path or text is None:
                continue
            records.append({"audio": audio_path, "text": text})
    if not records:
        raise ValueError("No valid records found in JSONL. Expect keys: audio_path, text")
    ds = Dataset.from_list(records)
    # Cast the audio column from file paths and resample to 16kHz
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    return ds


def load_and_prepare_dataset(dataset_jsonl: str | None, dataset_name: str | None, dataset_config: str | None,
                             train_count: int, eval_count: int):
    """Load and prepare dataset for training.

    Priority: local JSONL > HF dataset name/config > fallback tiny sample.
    """
    if dataset_jsonl:
        print(f"Loading local JSONL dataset: {dataset_jsonl}")
        ds = _load_jsonl_dataset(dataset_jsonl)
    else:
        ds_name = dataset_name or "hf-audio/esb-datasets-test-only-sorted"
        ds_cfg = dataset_config or "voxpopuli"
        print(f"Loading dataset: {ds_name}/{ds_cfg}")
        ds = load_dataset(ds_name, ds_cfg, split="test")
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))

    total = len(ds)
    train_end = min(train_count, total)
    eval_end = min(train_end + eval_count, total)
    train_dataset = ds.select(range(train_end))
    eval_dataset = ds.select(range(train_end, eval_end)) if eval_end > train_end else None
    return train_dataset, eval_dataset


def main():
    parser = argparse.ArgumentParser(description="Full fine-tune Voxtral for ASR")
    parser.add_argument("--dataset-jsonl", type=str, default=None, help="Path to local JSONL with {audio_path, text}")
    parser.add_argument("--dataset-name", type=str, default=None, help="HF dataset repo (if not using JSONL)")
    parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
    parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
    parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
    parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
    parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--epochs", type=float, default=3)
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument("--save-steps", type=int, default=50)
    args = parser.parse_args()

    model_checkpoint = args.model_checkpoint
    output_dir = args.output_dir

    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")

    print("Loading processor and model...")
    processor = VoxtralProcessor.from_pretrained(model_checkpoint)
    model = VoxtralForConditionalGeneration.from_pretrained(
        model_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    train_dataset, eval_dataset = load_and_prepare_dataset(
        dataset_jsonl=args.dataset_jsonl,
        dataset_name=args.dataset_name,
        dataset_config=args.dataset_config,
        train_count=args.train_count,
        eval_count=args.eval_count,
    )

    data_collator = VoxtralDataCollator(processor, model_checkpoint)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        bf16=True,
        logging_steps=args.logging_steps,
        eval_steps=args.save_steps if eval_dataset else None,
        save_steps=args.save_steps,
        eval_strategy="steps" if eval_dataset else "no",
        save_strategy="steps",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=1,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {output_dir}")
    trainer.save_model()
    processor.save_pretrained(output_dir)

    if eval_dataset:
        results = trainer.evaluate()
        print(f"Final evaluation results: {results}")

    print("Training completed successfully!")


if __name__ == "__main__":
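For reference, here is a minimal sketch of the JSONL layout `_load_jsonl_dataset` expects; the file names and transcripts below are made up for illustration:

```python
import json

# Each line pairs an audio file with its transcript, using the
# {audio_path, text} keys the loader looks for ("audio" also works).
samples = [
    {"audio_path": "recordings/clip_001.wav", "text": "hello world"},
    {"audio_path": "recordings/clip_002.wav", "text": "testing the microphone"},
]

with open("dataset.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

# The script can then be pointed at the file, e.g.:
#   python scripts/train.py --dataset-jsonl dataset.jsonl --train-count 100
```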
train.py → scripts/train_lora.py RENAMED
@@ -1,13 +1,17 @@
#!/usr/bin/env python3

import torch
-from datasets import load_dataset, Audio
from transformers import (
    VoxtralForConditionalGeneration,
    VoxtralProcessor,
    Trainer,
    TrainingArguments,
)


class VoxtralDataCollator:
@@ -94,68 +98,128 @@ class VoxtralDataCollator:

        return batch

-def
-    """Load
    return train_dataset, eval_dataset


def main():
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
-
-    # Load processor and model
    print("Loading processor and model...")
    processor = VoxtralProcessor.from_pretrained(model_checkpoint)
    model = VoxtralForConditionalGeneration.from_pretrained(
        model_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    data_collator = VoxtralDataCollator(processor, model_checkpoint)
-
-    # Simple training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
-        per_device_train_batch_size=
-        per_device_eval_batch_size=
-        gradient_accumulation_steps=
-        learning_rate=
-        num_train_epochs=
        bf16=True,
-        logging_steps=
-        eval_steps=
-        save_steps=
        eval_strategy="steps" if eval_dataset else "no",
        save_strategy="steps",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=1,
    )
-
-    # Setup trainer
    trainer = Trainer(
        model=model,
        args=training_args,
@@ -163,22 +227,18 @@ def main():
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
-
-    # Start training
    print("Starting training...")
    trainer.train()

-
-    # Save model and processor
    print(f"Saving model to {output_dir}")
    trainer.save_model()
    processor.save_pretrained(output_dir)
-
-    # Final evaluation
    if eval_dataset:
        results = trainer.evaluate()
        print(f"Final evaluation results: {results}")
-
    print("Training completed successfully!")

if __name__ == "__main__":


#!/usr/bin/env python3

import argparse
import json
from pathlib import Path
import torch
from datasets import load_dataset, Audio, Dataset
from transformers import (
    VoxtralForConditionalGeneration,
    VoxtralProcessor,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model


class VoxtralDataCollator:
    ...
        return batch


def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
    """Load local JSONL with fields {audio_path, text} into a Dataset with audio column."""
    records = []
    jsonl_file = Path(jsonl_path)
    if not jsonl_file.exists():
        raise FileNotFoundError(f"Dataset jsonl not found: {jsonl_path}")
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            audio_path = obj.get("audio_path") or obj.get("audio")
            text = obj.get("text")
            if not audio_path or text is None:
                continue
            records.append({"audio": audio_path, "text": text})
    if not records:
        raise ValueError("No valid records found in JSONL. Expect keys: audio_path, text")
    ds = Dataset.from_list(records)
    # Cast the audio column from file paths and resample to 16kHz
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    return ds


def load_and_prepare_dataset(dataset_jsonl: str | None, dataset_name: str | None, dataset_config: str | None,
                             train_count: int, eval_count: int):
    """Load and prepare dataset for training (JSONL or HF hub)."""
    if dataset_jsonl:
        print(f"Loading local JSONL dataset: {dataset_jsonl}")
        ds = _load_jsonl_dataset(dataset_jsonl)
    else:
        ds_name = dataset_name or "hf-audio/esb-datasets-test-only-sorted"
        ds_cfg = dataset_config or "voxpopuli"
        print(f"Loading dataset: {ds_name}/{ds_cfg}")
        ds = load_dataset(ds_name, ds_cfg, split="test")
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))

    total = len(ds)
    train_end = min(train_count, total)
    eval_end = min(train_end + eval_count, total)
    train_dataset = ds.select(range(train_end))
    eval_dataset = ds.select(range(train_end, eval_end)) if eval_end > train_end else None
    return train_dataset, eval_dataset


def main():
    parser = argparse.ArgumentParser(description="LoRA fine-tune Voxtral for ASR")
    parser.add_argument("--dataset-jsonl", type=str, default=None, help="Path to local JSONL with {audio_path, text}")
    parser.add_argument("--dataset-name", type=str, default=None, help="HF dataset repo (if not using JSONL)")
    parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
    parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
    parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
    parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
    parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--epochs", type=float, default=3)
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument("--save-steps", type=int, default=50)
    parser.add_argument("--lora-r", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.0)
    parser.add_argument("--freeze-audio-tower", action="store_true", help="Freeze audio encoder parameters")
    args = parser.parse_args()

    model_checkpoint = args.model_checkpoint
    output_dir = args.output_dir

    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")

    print("Loading processor and model...")
    processor = VoxtralProcessor.from_pretrained(model_checkpoint)
    lora_cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="SEQ_2_SEQ_LM",
    )
    model = VoxtralForConditionalGeneration.from_pretrained(
        model_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    if args.freeze_audio_tower:
        for param in model.audio_tower.parameters():
            param.requires_grad = False
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    train_dataset, eval_dataset = load_and_prepare_dataset(
        dataset_jsonl=args.dataset_jsonl,
        dataset_name=args.dataset_name,
        dataset_config=args.dataset_config,
        train_count=args.train_count,
        eval_count=args.eval_count,
    )

    data_collator = VoxtralDataCollator(processor, model_checkpoint)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        bf16=True,
        logging_steps=args.logging_steps,
        eval_steps=args.save_steps if eval_dataset else None,
        save_steps=args.save_steps,
        eval_strategy="steps" if eval_dataset else "no",
        save_strategy="steps",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=1,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {output_dir}")
    trainer.save_model()
    processor.save_pretrained(output_dir)

    if eval_dataset:
        results = trainer.evaluate()
        print(f"Final evaluation results: {results}")

    print("Training completed successfully!")


if __name__ == "__main__":
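One way to use the result of a LoRA run for inference, sketched under the assumption that the default `./voxtral-finetuned` output directory holds the PEFT adapter written by `trainer.save_model()`:

```python
import torch
from peft import PeftModel
from transformers import VoxtralForConditionalGeneration, VoxtralProcessor

BASE = "mistralai/Voxtral-Mini-3B-2507"
ADAPTER_DIR = "./voxtral-finetuned"  # --output-dir of the LoRA run (assumption)

processor = VoxtralProcessor.from_pretrained(BASE)
model = VoxtralForConditionalGeneration.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map="auto"
)
# Attach the trained adapter, then fold the LoRA deltas into the base
# weights so the model can be saved or pushed as a standalone checkpoint.
model = PeftModel.from_pretrained(model, ADAPTER_DIR)
model = model.merge_and_unload()
```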
templates/datasets/readme.md ADDED
@@ -0,0 +1,171 @@
---
dataset_info:
  features:
  - name: experiment_id
    dtype: string
  - name: name
    dtype: string
  - name: description
    dtype: string
  - name: created_at
    dtype: string
  - name: status
    dtype: string
  - name: metrics
    dtype: string
  - name: parameters
    dtype: string
  - name: artifacts
    dtype: string
  - name: logs
    dtype: string
  - name: last_updated
    dtype: string
  splits:
  - name: train
    num_bytes: 4945
    num_examples: 2
  download_size: 15529
  dataset_size: 4945
configs:
- config_name: default
  data_files:
  - split: train
    path: data/train-*
tags:
- track tonic
- tonic
- experiment tracking
- smollm3
- fine-tuning
- legml
- hermes
---

# Trackio Experiments Dataset

This dataset stores experiment tracking data for ML training runs, particularly focused on SmolLM3 fine-tuning experiments with comprehensive metrics tracking.

## Dataset Structure

The dataset contains the following columns:

- **experiment_id**: Unique identifier for each experiment
- **name**: Human-readable name for the experiment
- **description**: Detailed description of the experiment
- **created_at**: Timestamp when the experiment was created
- **status**: Current status (running, completed, failed, paused)
- **metrics**: JSON string containing training metrics over time
- **parameters**: JSON string containing experiment configuration
- **artifacts**: JSON string containing experiment artifacts
- **logs**: JSON string containing experiment logs
- **last_updated**: Timestamp of last update

## Metrics Structure

The metrics field contains JSON arrays with the following structure:

```json
[
  {
    "timestamp": "2025-07-20T11:20:01.780908",
    "step": 25,
    "metrics": {
      "loss": 1.1659,
      "accuracy": 0.759,
      "learning_rate": 7e-08,
      "grad_norm": 10.3125,
      "epoch": 0.004851130919895701,

      // Advanced Training Metrics
      "total_tokens": 1642080.0,
      "truncated_tokens": 128,
      "padding_tokens": 256,
      "throughput": 3284160.0,
      "step_time": 0.5,
      "batch_size": 8,
      "seq_len": 2048,
      "token_acc": 0.759,

      // Custom Losses
      "train/gate_ortho": 0.0234,
      "train/center": 0.0156,

      // System Metrics
      "gpu_memory_allocated": 17.202261447906494,
      "gpu_memory_reserved": 75.474609375,
      "gpu_utilization": 85.2,
      "cpu_percent": 2.7,
      "memory_percent": 10.1
    }
  }
]
```

## Supported Metrics

### Core Training Metrics
- **loss**: Training loss value
- **accuracy**: Model accuracy
- **learning_rate**: Current learning rate
- **grad_norm**: Gradient norm
- **epoch**: Current epoch progress

### Advanced Token Metrics
- **total_tokens**: Total tokens processed in the batch
- **truncated_tokens**: Number of tokens truncated during processing
- **padding_tokens**: Number of padding tokens added
- **throughput**: Tokens processed per second
- **step_time**: Time taken for the current training step
- **batch_size**: Current batch size
- **seq_len**: Sequence length
- **token_acc**: Token-level accuracy

### Custom Losses (SmolLM3-specific)
- **train/gate_ortho**: Gate orthogonality loss
- **train/center**: Center loss component

### System Performance Metrics
- **gpu_memory_allocated**: GPU memory currently allocated (GB)
- **gpu_memory_reserved**: GPU memory reserved (GB)
- **gpu_utilization**: GPU utilization percentage
- **cpu_percent**: CPU usage percentage
- **memory_percent**: System memory usage percentage

## Usage

This dataset is automatically used by the Trackio monitoring system to store and retrieve experiment data. It provides persistent storage for experiment tracking across different training runs.

## Integration

The dataset is used by:
- Trackio Spaces for experiment visualization
- Training scripts for logging metrics and parameters
- Monitoring systems for experiment tracking
- SmolLM3 fine-tuning pipeline for comprehensive metrics capture

## Privacy

This dataset is private by default to ensure experiment data security. Only users with appropriate permissions can access the data.

## Examples

### Sample Experiment Entry
```json
{
  "experiment_id": "exp_20250720_130853",
  "name": "smollm3_finetune",
  "description": "SmolLM3 fine-tuning experiment with comprehensive metrics",
  "created_at": "2025-07-20T11:20:01.780908",
  "status": "running",
  "metrics": "[{\"timestamp\": \"2025-07-20T11:20:01.780908\", \"step\": 25, \"metrics\": {\"loss\": 1.1659, \"accuracy\": 0.759, \"total_tokens\": 1642080.0, \"throughput\": 3284160.0, \"train/gate_ortho\": 0.0234, \"train/center\": 0.0156}}]",
  "parameters": "{\"model_name\": \"HuggingFaceTB/SmolLM3-3B\", \"batch_size\": 8, \"learning_rate\": 3.5e-06, \"max_seq_length\": 12288}",
  "artifacts": "[]",
  "logs": "[]",
  "last_updated": "2025-07-20T11:20:01.780908"
}
```

## License

This dataset is part of the Trackio experiment tracking system and follows the same license as the main project.
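Because `metrics` and `parameters` are stored as JSON strings, a consumer has to decode them after loading. A small sketch with the `datasets` library; the repo id is a placeholder for whatever experiments dataset the tooling creates:

```python
import json
from datasets import load_dataset

# Placeholder repo id; substitute the actual experiments dataset.
ds = load_dataset("username/trackio-experiments", split="train")

for row in ds:
    history = json.loads(row["metrics"])  # list of {timestamp, step, metrics}
    for entry in history:
        print(row["experiment_id"], entry["step"], entry["metrics"].get("loss"))
```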
templates/model_card.md ADDED
@@ -0,0 +1,345 @@
---
language:
- en
- fr
license: apache-2.0
library_name: transformers
tags:
- smollm3
- fine-tuned
- causal-lm
- text-generation
- tonic
- legml
{{#if quantized_models}}- quantized{{/if}}
pipeline_tag: text-generation
base_model: {{base_model}}
{{#if dataset_name}}
datasets:
- {{dataset_name}}
{{/if}}
{{#if quantized_models}}
model-index:
- name: {{model_name}}
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Training Loss
      type: loss
      value: "{{training_loss|default:'N/A'}}"
    - name: Validation Loss
      type: loss
      value: "{{validation_loss|default:'N/A'}}"
    - name: Perplexity
      type: perplexity
      value: "{{perplexity|default:'N/A'}}"
- name: {{model_name}} (int8 quantized)
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Memory Reduction
      type: memory_efficiency
      value: "~50%"
    - name: Inference Speed
      type: speed
      value: "Faster"
- name: {{model_name}} (int4 quantized)
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Memory Reduction
      type: memory_efficiency
      value: "~75%"
    - name: Inference Speed
      type: speed
      value: "Significantly Faster"
{{else}}
model-index:
- name: {{model_name}}
  results:
  - task:
      type: text-generation
    dataset:
      name: {{dataset_name}}
      type: {{dataset_name}}
    metrics:
    - name: Training Loss
      type: loss
      value: "{{training_loss|default:'N/A'}}"
    - name: Validation Loss
      type: loss
      value: "{{validation_loss|default:'N/A'}}"
    - name: Perplexity
      type: perplexity
      value: "{{perplexity|default:'N/A'}}"
{{/if}}
{{#if author_name}}
author: {{author_name}}
{{/if}}
{{#if experiment_name}}
experiment_name: {{experiment_name}}
{{/if}}
{{#if trackio_url}}
trackio_url: {{trackio_url}}
{{/if}}
{{#if dataset_repo}}
dataset_repo: {{dataset_repo}}
{{/if}}
{{#if hardware_info}}
hardware: "{{hardware_info}}"
{{/if}}
{{#if training_config_type}}
training_config: {{training_config_type}}
{{/if}}
{{#if trainer_type}}
trainer_type: {{trainer_type}}
{{/if}}
{{#if batch_size}}
batch_size: {{batch_size}}
{{/if}}
{{#if learning_rate}}
learning_rate: {{learning_rate}}
{{/if}}
{{#if max_epochs}}
max_epochs: {{max_epochs}}
{{/if}}
{{#if max_seq_length}}
max_seq_length: {{max_seq_length}}
{{/if}}
{{#if dataset_sample_size}}
dataset_sample_size: {{dataset_sample_size}}
{{/if}}
{{#if dataset_size}}
dataset_size: {{dataset_size}}
{{/if}}
{{#if dataset_format}}
dataset_format: {{dataset_format}}
{{/if}}
{{#if gradient_accumulation_steps}}
gradient_accumulation_steps: {{gradient_accumulation_steps}}
{{/if}}
---

# {{model_name}}

{{model_description}}

## Model Details

- **Base Model**: SmolLM3-3B
- **Model Type**: Causal Language Model
- **Languages**: English, French
- **License**: Apache 2.0
- **Fine-tuned**: Yes
{{#if quantized_models}}
- **Quantized Versions**: Available in subdirectories
{{/if}}

## Usage

### Main Model

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the main model
model = AutoModelForCausalLM.from_pretrained(
    "{{repo_name}}",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")

# Generate text
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
output = model.generate(**input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Training Information

### Training Configuration
- **Base Model**: {{base_model}}
- **Dataset**: {{dataset_name}}
- **Training Config**: {{training_config_type}}
- **Trainer Type**: {{trainer_type}}
{{#if dataset_sample_size}}
- **Dataset Sample Size**: {{dataset_sample_size}}
{{/if}}

### Training Parameters
- **Batch Size**: {{batch_size}}
- **Gradient Accumulation**: {{gradient_accumulation_steps}}
- **Learning Rate**: {{learning_rate}}
- **Max Epochs**: {{max_epochs}}
- **Sequence Length**: {{max_seq_length}}

### Training Infrastructure
- **Hardware**: {{hardware_info}}
- **Monitoring**: Trackio integration
- **Experiment**: {{experiment_name}}

## Model Architecture

This is a fine-tuned version of the SmolLM3-3B model with the following specifications:

- **Base Model**: SmolLM3-3B
- **Parameters**: ~3B
- **Context Length**: {{max_seq_length}}
- **Languages**: English, French
- **Architecture**: Transformer-based causal language model

## Performance

The model provides:
- **Text Generation**: High-quality text generation capabilities
- **Conversation**: Natural conversation abilities
- **Multilingual**: Support for English and French
{{#if quantized_models}}
- **Quantized Versions**: Optimized for different deployment scenarios
{{/if}}

## Limitations

1. **Context Length**: Limited by the model's maximum sequence length
2. **Bias**: May inherit biases from the training data
3. **Factual Accuracy**: May generate incorrect or outdated information
4. **Safety**: Should be used responsibly with appropriate safeguards
{{#if quantized_models}}
5. **Quantization**: Quantized versions may have slightly reduced accuracy
{{/if}}

## Training Data

The model was fine-tuned on:
- **Dataset**: {{dataset_name}}
- **Size**: {{dataset_size}}
- **Format**: {{dataset_format}}
- **Languages**: English, French

## Evaluation

The model was evaluated using:
- **Metrics**: Loss, perplexity, and qualitative assessment
- **Monitoring**: Real-time tracking via Trackio
- **Validation**: Regular validation during training

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{{{model_name_slug}},
  title={{{{model_name}}}},
  author={{{author_name}}},
  year={2024},
  url={https://huggingface.co/{{repo_name}}}
}
```

## License

This model is licensed under the Apache 2.0 License.

## Acknowledgments

- **Base Model**: SmolLM3-3B by HuggingFaceTB
- **Training Framework**: PyTorch, Transformers, PEFT
- **Monitoring**: Trackio integration
- **Quantization**: torchao library

## Support

For questions and support:
- Open an issue on the Hugging Face repository
- Check the model documentation
- Review the training logs and configuration

## Repository Structure

```
{{repo_name}}/
├── README.md (this file)
├── config.json
├── pytorch_model.bin
├── tokenizer.json
└── tokenizer_config.json
```

## Usage Examples

### Text Generation
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{{repo_name}}")
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")

text = "The future of artificial intelligence is"
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### Conversation
```python
def chat_with_model(prompt, max_length=100):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

response = chat_with_model("Hello, how are you today?")
print(response)
```

### Advanced Usage
```python
# With generation parameters
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)
```

## Monitoring and Tracking

This model was trained with comprehensive monitoring:
- **Trackio Space**: {{trackio_url}}
- **Experiment**: {{experiment_name}}
- **Dataset Repository**: https://huggingface.co/datasets/{{dataset_repo}}
- **Training Logs**: Available in the experiment data

## Deployment

### Requirements
```bash
pip install torch transformers accelerate
{{#if quantized_models}}
pip install torchao  # For quantized models
{{/if}}
```

### Hardware Requirements
- **Main Model**: GPU with 8GB+ VRAM recommended

## Changelog

- **v1.0.0**: Initial release with fine-tuned model
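The front matter above relies on `{{var}}` substitutions and `{{#if}}` blocks. How the repo's `scripts/generate_model_card.py` resolves these is not shown here, so this is a sketch only: a simplified renderer for the syntax (no `{{else}}` or nesting support) might look like the following.

```python
import re

def render_template(template: str, context: dict) -> str:
    # Drop or keep {{#if key}}...{{/if}} blocks based on truthiness of key.
    def _if(m: re.Match) -> str:
        key, body = m.group(1), m.group(2)
        return body if context.get(key) else ""
    out = re.sub(r"\{\{#if (\w+)\}\}(.*?)\{\{/if\}\}", _if, template, flags=re.S)

    # Substitute {{key}} and {{key|default:'fallback'}} variables.
    def _var(m: re.Match) -> str:
        key, fallback = m.group(1), m.group(2)
        value = context.get(key)
        return str(value) if value is not None else (fallback or "")
    return re.sub(r"\{\{(\w+)(?:\|default:'([^']*)')?\}\}", _var, out)

print(render_template(
    "# {{model_name}} (loss: {{training_loss|default:'N/A'}})",
    {"model_name": "my-voxtral-finetune"},
))  # -> "# my-voxtral-finetune (loss: N/A)"
```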
templates/spaces/demo_voxtral/README.md ADDED
@@ -0,0 +1,23 @@
---
title: Voxtral ASR Demo
emoji: 🎙️
colorFrom: indigo
colorTo: cyan
sdk: gradio
sdk_version: 5.42.0
app_file: app.py
pinned: false
short_description: Interactive ASR demo for a fine-tuned Voxtral model
---

This Space serves a Voxtral ASR model for speech-to-text transcription.

Usage:

- Click Record and read the displayed phrase aloud.
- Stop recording to see the transcription.
- Works best with ~16 kHz audio; internal processing follows Voxtral's processor expectations.

Environment variables expected:

- `HF_MODEL_ID`: The model repo to load (e.g., `username/voxtral-finetune-YYYYMMDD_HHMMSS`)
- `MODEL_NAME`: Display name
- `HF_USERNAME`: For branding
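These variables have to exist on the deployed Space. The repo's `scripts/deploy_demo_space.py` presumably sets them during deployment; for manual setup, a sketch using `huggingface_hub` (the Space id and values are placeholders, and a write token is assumed in `HF_TOKEN`):

```python
import os
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])  # write-scoped token (assumption)
space_id = "username/voxtral-asr-demo"     # placeholder Space repo id

for key, value in {
    "HF_MODEL_ID": "username/voxtral-finetune-20250101_000000",
    "MODEL_NAME": "My Voxtral ASR",
    "HF_USERNAME": "username",
}.items():
    # add_space_variable creates or updates a public Space environment variable
    api.add_space_variable(repo_id=space_id, key=key, value=value)
```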
templates/spaces/demo_voxtral/app.py ADDED
@@ -0,0 +1,35 @@
import os
import gradio as gr
import torch
from transformers import AutoProcessor, VoxtralForConditionalGeneration

HF_MODEL_ID = os.getenv("HF_MODEL_ID", "mistralai/Voxtral-Mini-3B-2507")
MODEL_NAME = os.getenv("MODEL_NAME", HF_MODEL_ID.split("/")[-1])
HF_USERNAME = os.getenv("HF_USERNAME", "")

processor = AutoProcessor.from_pretrained(HF_MODEL_ID)
# Load Voxtral with its dedicated class, matching the training scripts;
# the generic AutoModelForSeq2SeqLM mapping does not cover Voxtral.
model = VoxtralForConditionalGeneration.from_pretrained(HF_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16)

def transcribe(audio_tuple):
    if audio_tuple is None:
        return "No audio provided"
    sr, data = audio_tuple
    inputs = processor.apply_transcription_request(language="en", model_id=HF_MODEL_ID, audio=[data], format=["WAV"], return_tensors="pt")
    inputs = {k: (v.to(model.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=256)
    # Voxtral returns the full sequence; decode and strip special tokens
    text = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return text

with gr.Blocks() as demo:
    gr.Markdown(f"# 🎙️ Voxtral ASR Demo — {MODEL_NAME}")
    # The label promises upload support, so expose both input sources.
    audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Record or upload audio")
    btn = gr.Button("Transcribe")
    out = gr.Textbox(label="Transcription", lines=4)
    btn.click(transcribe, inputs=[audio], outputs=[out])

if __name__ == "__main__":
    demo.launch(mcp_server=True)
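The README above notes that ~16 kHz input works best, while browser microphones usually record at 44.1 or 48 kHz, so `transcribe` may receive audio at a different rate. A sketch of a resampling helper the app could call on `(sr, data)` before building the request; this helper is an addition for illustration, not part of app.py, and it uses librosa, which requirements.txt already includes:

```python
import numpy as np
import librosa

def to_16k(sr: int, data: np.ndarray) -> np.ndarray:
    """Convert a Gradio (sample_rate, samples) payload to 16 kHz mono float32."""
    if np.issubdtype(data.dtype, np.integer):
        # Gradio microphone audio arrives as integer PCM; scale to [-1, 1].
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    if data.ndim > 1:
        data = data.mean(axis=1)  # down-mix stereo to mono
    if sr != 16000:
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    return data
```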
templates/spaces/demo_voxtral/requirements.txt ADDED
@@ -0,0 +1,7 @@
gradio>=5.38.2
torch
transformers
datasets
soundfile
librosa