Joseph Pollack committed
Commit be9aa9f · unverified · 1 Parent(s): ec1abe7

adds interface and dataset and auto push and demo

interface.py ADDED
@@ -0,0 +1,444 @@
+ #!/usr/bin/env python3
+ """
+ Voxtral ASR Fine-tuning Interface
+
+ Features:
+ - Collect a personal voice dataset (upload WAV/FLAC + transcripts or record mic audio)
+ - Build a JSONL dataset ({audio_path, text}) at 16 kHz
+ - Fine-tune Voxtral (LoRA or full) with streamed logs
+ - Push the model to the Hugging Face Hub
+ - Deploy a Voxtral ASR demo Space
+
+ Environment tokens (optional):
+ - HF_WRITE_TOKEN or HF_TOKEN: write-access token
+ - HF_READ_TOKEN: optional read token
+ - HF_USERNAME: fallback username if it cannot be derived from the token
+ """
+
+ from __future__ import annotations
+
+ import os
+ import json
+ from pathlib import Path
+ from datetime import datetime
+ from typing import Any, Dict, Generator, Optional, Tuple
+
+ import gradio as gr
+
+ PROJECT_ROOT = Path(__file__).resolve().parent
+
+
+ def get_python() -> str:
+     import sys
+     return sys.executable or "python"
+
+
+ def get_username_from_token(token: str) -> Optional[str]:
+     try:
+         from huggingface_hub import HfApi  # type: ignore
+         api = HfApi(token=token)
+         info = api.whoami()
+         if isinstance(info, dict):
+             return info.get("name") or info.get("username")
+         if isinstance(info, str):
+             return info
+     except Exception:
+         return None
+     return None
+
+
+ def run_command_stream(args: list[str], env: Dict[str, str], cwd: Optional[Path] = None) -> Generator[str, None, int]:
+     import subprocess
+     import shlex
+     yield f"$ {' '.join(shlex.quote(a) for a in ([get_python()] + args))}"
+     process = subprocess.Popen(
+         [get_python()] + args,
+         stdout=subprocess.PIPE,
+         stderr=subprocess.STDOUT,
+         text=True,
+         env=env,
+         cwd=str(cwd or PROJECT_ROOT),
+         bufsize=1,
+         universal_newlines=True,
+     )
+     assert process.stdout is not None
+     for line in iter(process.stdout.readline, ""):
+         yield line.rstrip()
+     process.stdout.close()
+     code = process.wait()
+     yield f"[exit_code={code}]"
+     return code
+
+
+ def detect_nvidia_driver() -> Tuple[bool, str]:
+     """Detect NVIDIA driver/GPU presence with multiple strategies.
+
+     Returns (available, human_message).
+     """
+     # 1) Try torch CUDA
+     try:
+         import torch  # type: ignore
+         if torch.cuda.is_available():
+             try:
+                 num = torch.cuda.device_count()
+                 names = [torch.cuda.get_device_name(i) for i in range(num)]
+                 return True, f"NVIDIA GPU detected: {', '.join(names)}"
+             except Exception:
+                 return True, "NVIDIA GPU detected (torch.cuda available)"
+     except Exception:
+         pass
+
+     # 2) Try NVML via pynvml
+     try:
+         import pynvml  # type: ignore
+         try:
+             pynvml.nvmlInit()
+             cnt = pynvml.nvmlDeviceGetCount()
+             names = []
+             for i in range(cnt):
+                 h = pynvml.nvmlDeviceGetHandleByIndex(i)
+                 names.append(pynvml.nvmlDeviceGetName(h).decode("utf-8", errors="ignore"))
+             drv = pynvml.nvmlSystemGetDriverVersion().decode("utf-8", errors="ignore")
+             pynvml.nvmlShutdown()
+             if cnt > 0:
+                 return True, f"NVIDIA driver {drv}; GPUs: {', '.join(names)}"
+         except Exception:
+             pass
+     except Exception:
+         pass
+
+     # 3) Try nvidia-smi
+     try:
+         import subprocess
+         res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=3)
+         if res.returncode == 0 and res.stdout.strip():
+             return True, res.stdout.strip().splitlines()[0]
+     except Exception:
+         pass
+
+     return False, "No NVIDIA driver/GPU detected"
+
+
+ def duplicate_space_hint() -> str:
+     space_id = os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")
+     if space_id:
+         space_url = f"https://huggingface.co/spaces/{space_id}"
+         dup_url = f"{space_url}?duplicate=true"
+         return (
+             f"ℹ️ No NVIDIA driver detected. If you're on Hugging Face Spaces, "
+             f"please duplicate this Space to GPU hardware: [Duplicate this Space]({dup_url})."
+         )
+     return (
+         "ℹ️ No NVIDIA driver detected. To enable training, run on a machine with an NVIDIA GPU/driver "
+         "or duplicate this Space on Hugging Face with GPU hardware."
+     )
+
+
+ def _write_jsonl(rows: list[dict], path: Path) -> Path:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with open(path, "w", encoding="utf-8") as f:
+         for r in rows:
+             f.write(json.dumps(r, ensure_ascii=False) + "\n")
+     return path
+
+
+ def _save_uploaded_dataset(files: list, transcripts: list[str]) -> str:
+     dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+     dataset_dir.mkdir(parents=True, exist_ok=True)
+     rows: list[dict] = []
+     for i, fpath in enumerate(files or []):
+         if i >= len(transcripts):
+             break
+         rows.append({"audio_path": fpath, "text": transcripts[i] or ""})
+     jsonl_path = dataset_dir / "data.jsonl"
+     _write_jsonl(rows, jsonl_path)
+     return str(jsonl_path)
+
+
+ def _save_recordings(recordings: list[tuple[int, list]], transcripts: list[str]) -> str:
+     import soundfile as sf
+     dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+     wav_dir = dataset_dir / "wavs"
+     wav_dir.mkdir(parents=True, exist_ok=True)
+     rows: list[dict] = []
+     for i, rec in enumerate(recordings or []):
+         if rec is None:
+             continue
+         if i >= len(transcripts):
+             break
+         sr, data = rec
+         out_path = wav_dir / f"rec_{i:04d}.wav"
+         sf.write(str(out_path), data, sr)
+         rows.append({"audio_path": str(out_path), "text": transcripts[i] or ""})
+     jsonl_path = dataset_dir / "data.jsonl"
+     _write_jsonl(rows, jsonl_path)
+     return str(jsonl_path)
+
+
+ def start_voxtral_training(
+     use_lora: bool,
+     base_model: str,
+     repo_short: str,
+     jsonl_path: str,
+     train_count: int,
+     eval_count: int,
+     batch_size: int,
+     grad_accum: int,
+     learning_rate: float,
+     epochs: float,
+     lora_r: int,
+     lora_alpha: int,
+     lora_dropout: float,
+     freeze_audio_tower: bool,
+     push_to_hub: bool,
+     deploy_demo: bool,
+ ) -> Generator[str, None, None]:
+     env = os.environ.copy()
+     write_token = env.get("HF_WRITE_TOKEN") or env.get("HF_TOKEN")
+     read_token = env.get("HF_READ_TOKEN")
+     username = get_username_from_token(write_token or "") or env.get("HF_USERNAME") or ""
+     output_dir = PROJECT_ROOT / "outputs" / repo_short
+
+     # 1) Train
+     script = PROJECT_ROOT / ("scripts/train_lora.py" if use_lora else "scripts/train.py")
+     args = [str(script)]
+     if jsonl_path:
+         args += ["--dataset-jsonl", jsonl_path]
+     args += [
+         "--model-checkpoint", base_model,
+         "--train-count", str(train_count),
+         "--eval-count", str(eval_count),
+         "--batch-size", str(batch_size),
+         "--grad-accum", str(grad_accum),
+         "--learning-rate", str(learning_rate),
+         "--epochs", str(epochs),
+         "--output-dir", str(output_dir),
+         "--save-steps", "50",
+     ]
+     if use_lora:
+         args += [
+             "--lora-r", str(lora_r),
+             "--lora-alpha", str(lora_alpha),
+             "--lora-dropout", str(lora_dropout),
+         ]
+     if freeze_audio_tower:
+         args += ["--freeze-audio-tower"]
+     for line in run_command_stream(args, env):
+         yield line
+
+     # 2) Push to Hub
+     if push_to_hub:
+         repo_name = f"{username}/{repo_short}" if username else repo_short
+         push_args = [
+             str(PROJECT_ROOT / "scripts/push_to_huggingface.py"),
+             str(output_dir),
+             repo_name,
+         ]
+         for line in run_command_stream(push_args, env):
+             yield line
+
+     # 3) Deploy demo Space
+     if deploy_demo and username:
+         deploy_args = [
+             str(PROJECT_ROOT / "scripts/deploy_demo_space.py"),
+             "--hf-token", write_token or "",
+             "--hf-username", username,
+             "--model-id", f"{username}/{repo_short}",
+             "--demo-type", "voxtral",
+             "--space-name", f"{repo_short}-demo",
+         ]
+         for line in run_command_stream(deploy_args, env):
+             yield line
+
+
+ PHRASES = [
+     "The quick brown fox jumps over the lazy dog.",
+     "Please say your full name.",
+     "Today is a good day to learn something new.",
+     "Artificial intelligence helps with many tasks.",
+     "I enjoy reading books and listening to music.",
+     "This is a sample sentence for testing speech.",
+     "Speak clearly and at a normal pace.",
+     "Numbers like one, two, three are easy to say.",
+     "The weather is sunny with a chance of rain.",
+     "Thank you for taking the time to help.",
+ ]
+
+ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
+     has_gpu, gpu_msg = detect_nvidia_driver()
+     if has_gpu:
+         gr.HTML(
+             f"""
+             <div style="background-color: rgba(59, 130, 246, 0.1); border: 1px solid rgba(59, 130, 246, 0.3); border-radius: 8px; padding: 12px; margin-bottom: 16px; text-align: center;">
+                 <p style="color: rgb(59, 130, 246); margin: 0; font-size: 14px; font-weight: 600;">
+                     ✅ NVIDIA GPU ready — {gpu_msg}
+                 </p>
+                 <p style="color: rgb(59, 130, 246); margin: 6px 0 0; font-size: 12px;">
+                     Set HF_WRITE_TOKEN/HF_TOKEN in environment to enable Hub push.
+                 </p>
+             </div>
+             """
+         )
+     else:
+         hint_md = duplicate_space_hint()
+         gr.HTML(
+             f"""
+             <div style="background-color: rgba(245, 158, 11, 0.1); border: 1px solid rgba(245, 158, 11, 0.3); border-radius: 8px; padding: 12px; margin-bottom: 16px; text-align: center;">
+                 <p style="color: rgb(234, 88, 12); margin: 0; font-size: 14px; font-weight: 600;">
+                     ⚠️ No NVIDIA GPU/driver detected — training requires a GPU runtime
+                 </p>
+                 <p style="color: rgb(234, 88, 12); margin: 6px 0 0; font-size: 12px;">
+                     {hint_md}
+                 </p>
+             </div>
+             """
+         )
+
+     gr.Markdown("""
+     # 🎙️ Voxtral ASR Fine-tuning
+     Read the phrases below and record them. Then start fine-tuning.
+     """)
+
+     jsonl_out = gr.Textbox(label="Dataset JSONL path", interactive=False, visible=True)
+
+     # Recording grid with dynamic text readouts
+     phrase_texts_state = gr.State(PHRASES)
+     phrase_markdowns: list[gr.Markdown] = []
+     rec_components = []
+     with gr.Column():
+         for idx, phrase in enumerate(PHRASES):
+             md = gr.Markdown(f"**{idx+1}. {phrase}**")
+             phrase_markdowns.append(md)
+             comp = gr.Audio(sources="microphone", type="numpy", label=f"Recording {idx+1}")
+             rec_components.append(comp)
+
+     # Advanced options accordion
+     with gr.Accordion("Advanced options", open=False):
+         base_model = gr.Textbox(value="mistralai/Voxtral-Mini-3B-2507", label="Base Voxtral model")
+         use_lora = gr.Checkbox(value=True, label="Use LoRA (parameter-efficient)")
+         with gr.Row():
+             batch_size = gr.Number(value=2, precision=0, label="Batch size")
+             grad_accum = gr.Number(value=4, precision=0, label="Grad accum")
+         with gr.Row():
+             learning_rate = gr.Number(value=5e-5, precision=6, label="Learning rate")
+             epochs = gr.Number(value=3.0, precision=2, label="Epochs")
+         with gr.Accordion("LoRA settings", open=False):
+             lora_r = gr.Number(value=8, precision=0, label="LoRA r")
+             lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
+             lora_dropout = gr.Number(value=0.0, precision=3, label="LoRA dropout")
+             freeze_audio_tower = gr.Checkbox(value=True, label="Freeze audio tower")
+         with gr.Row():
+             train_count = gr.Number(value=100, precision=0, label="Train samples")
+             eval_count = gr.Number(value=50, precision=0, label="Eval samples")
+         repo_short = gr.Textbox(value=f"voxtral-finetune-{datetime.now().strftime('%Y%m%d_%H%M%S')}", label="Model repo (short)")
+         push_to_hub = gr.Checkbox(value=True, label="Push to HF Hub after training")
+         deploy_demo = gr.Checkbox(value=True, label="Deploy demo Space after push")
+
+     gr.Markdown("### Upload audio + transcripts (optional)")
+     upload_audio = gr.File(file_count="multiple", type="filepath", label="Upload WAV/FLAC files (optional)")
+     transcripts_box = gr.Textbox(lines=6, label="Transcripts (one per line, aligned with files)")
+     save_upload_btn = gr.Button("Save uploaded dataset")
+
+     def _collect_upload(files, txt):
+         lines = [s.strip() for s in (txt or "").splitlines() if s.strip()]
+         return _save_uploaded_dataset(files or [], lines)
+
+     save_upload_btn.click(_collect_upload, [upload_audio, transcripts_box], [jsonl_out])
+
+     # Save recordings button
+     save_rec_btn = gr.Button("Save recordings as dataset")
+
+     def _collect_preloaded_recs(*recs_and_texts):
+         import soundfile as sf
+         dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+         wav_dir = dataset_dir / "wavs"
+         wav_dir.mkdir(parents=True, exist_ok=True)
+         rows: list[dict] = []
+         if not recs_and_texts:
+             jsonl_path = dataset_dir / "data.jsonl"
+             _write_jsonl(rows, jsonl_path)
+             return str(jsonl_path)
+         texts = recs_and_texts[-1]
+         recs = recs_and_texts[:-1]
+         for i, rec in enumerate(recs):
+             if rec is None:
+                 continue
+             sr, data = rec
+             out_path = wav_dir / f"rec_{i:04d}.wav"
+             sf.write(str(out_path), data, sr)
+             label_text = (texts[i] if isinstance(texts, list) and i < len(texts) else (PHRASES[i] if i < len(PHRASES) else ""))
+             rows.append({"audio_path": str(out_path), "text": label_text})
+         jsonl_path = dataset_dir / "data.jsonl"
+         _write_jsonl(rows, jsonl_path)
+         return str(jsonl_path)
+
+     save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_out])
+
+     # Quick sample from VoxPopuli (few random rows)
+     with gr.Row():
+         vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "ro", "hu", "cs", "nl", "fi", "hr", "sk", "sl", "et", "lt"], value="en", label="VoxPopuli language")
+         vp_samples = gr.Number(value=20, precision=0, label="Num samples")
+         vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
+         vp_btn = gr.Button("Use VoxPopuli sample")
+
+     def _collect_voxpopuli(lang_code: str, num_samples: int, split: str):
+         import sys
+         # Workaround for dill on Python 3.13 expecting __main__ during import
+         if "__main__" not in sys.modules:
+             sys.modules["__main__"] = sys.modules[__name__]
+         from datasets import load_dataset, Audio  # type: ignore
+         import random
+         ds = load_dataset("facebook/voxpopuli", lang_code, split=split)
+         ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+         # shuffle and select
+         total = len(ds)
+         k = max(1, min(int(num_samples or 1), total))
+         ds = ds.shuffle(seed=random.randint(1, 10_000))
+         ds_sel = ds.select(range(k))
+
+         dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+         rows: list[dict] = []
+         texts: list[str] = []
+         for ex in ds_sel:
+             audio = ex.get("audio") or {}
+             path = audio.get("path")
+             text = ex.get("normalized_text") or ex.get("raw_text") or ""
+             if path and text is not None:
+                 rows.append({"audio_path": path, "text": text})
+                 texts.append(str(text))
+         jsonl_path = dataset_dir / "data.jsonl"
+         _write_jsonl(rows, jsonl_path)
+         # Build markdown content updates for on-screen prompts
+         md_updates = []
+         for i in range(len(phrase_markdowns)):
+             t = texts[i] if i < len(texts) else ""
+             md_updates.append(f"**{i+1}. {t}**")
+         return (str(jsonl_path), texts, *md_updates)
+
+     vp_btn.click(
+         _collect_voxpopuli,
+         [vp_lang, vp_samples, vp_split],
+         [jsonl_out, phrase_texts_state] + phrase_markdowns,
+     )
+
+     start_btn = gr.Button("Start Fine-tuning")
+     logs_box = gr.Textbox(label="Logs", lines=20)
+
+     start_btn.click(
+         start_voxtral_training,
+         inputs=[
+             use_lora, base_model, repo_short, jsonl_out, train_count, eval_count,
+             batch_size, grad_accum, learning_rate, epochs,
+             lora_r, lora_alpha, lora_dropout, freeze_audio_tower,
+             push_to_hub, deploy_demo,
+         ],
+         outputs=[logs_box],
+     )
+
+
+ if __name__ == "__main__":
+     server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
+     server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
+     demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)
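Note on the dataset format: each line of the generated `datasets/voxtral_user/data.jsonl` is one JSON object with `audio_path` and `text` keys, which the training scripts consume via `--dataset-jsonl`. A minimal sketch of producing one such row outside the UI (paths and transcript are illustrative):

    import json
    from pathlib import Path

    # One row per utterance, exactly as _write_jsonl emits it
    rows = [
        {"audio_path": "datasets/voxtral_user/wavs/rec_0000.wav",
         "text": "The quick brown fox jumps over the lazy dog."},
    ]
    path = Path("datasets/voxtral_user/data.jsonl")
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

With the default UI values, `start_voxtral_training` assembles and streams a command of roughly this shape (repo name illustrative; the exact flags are defined by `scripts/train_lora.py`, which is not part of this commit):

    python scripts/train_lora.py --dataset-jsonl datasets/voxtral_user/data.jsonl \
        --model-checkpoint mistralai/Voxtral-Mini-3B-2507 \
        --train-count 100 --eval-count 50 --batch-size 2 --grad-accum 4 \
        --learning-rate 5e-05 --epochs 3.0 \
        --output-dir outputs/voxtral-finetune-20250101_120000 --save-steps 50 \
        --lora-r 8 --lora-alpha 32 --lora-dropout 0.0 --freeze-audio-tower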
scripts/deploy_demo_space.py ADDED
@@ -0,0 +1,952 @@
+ #!/usr/bin/env python3
+ """
+ Demo Space Deployment Script
+ Deploys a Gradio demo space to Hugging Face Spaces for testing the fine-tuned model.
+ """
+
+ import os
+ import sys
+ import json
+ import logging
+ import argparse
+ import subprocess
+ import requests
+ import tempfile
+ import shutil
+ from pathlib import Path
+ from typing import Optional, Dict, Any
+ import time
+
+ # Import Hugging Face Hub API
+ try:
+     from huggingface_hub import HfApi, create_repo, upload_file
+     HF_HUB_AVAILABLE = True
+ except ImportError:
+     HF_HUB_AVAILABLE = False
+     print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")
+
+ # Add src to path for imports
+ sys.path.append(str(Path(__file__).parent.parent / "src"))
+
+ # Optional: training config types; not required for deployment itself,
+ # so don't fail when src/config.py is absent
+ try:
+     from config import SmolLM3Config  # noqa: F401
+ except ImportError:
+     SmolLM3Config = None
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ class DemoSpaceDeployer:
+     """Deploy demo space to Hugging Face Spaces"""
+
+     def __init__(
+         self,
+         # Token used for API actions that create/update the Space (write perms)
+         hf_token: str,
+         hf_username: str,
+         model_id: str,
+         subfolder: str = "int4",
+         space_name: Optional[str] = None,
+         demo_type: Optional[str] = None,
+         config_file: Optional[str] = None,
+         # Optional token used as the Space's HF_TOKEN secret (read-only recommended)
+         space_secret_token: Optional[str] = None,
+         # Examples configuration
+         examples_type: Optional[str] = None,
+         disable_examples: Optional[bool] = None,
+         examples_json: Optional[str] = None,
+         # Branding overrides
+         brand_owner_name: Optional[str] = None,
+         brand_team_name: Optional[str] = None,
+         brand_discord_url: Optional[str] = None,
+         brand_hf_org: Optional[str] = None,
+         brand_hf_label: Optional[str] = None,
+         brand_hf_url: Optional[str] = None,
+         brand_gh_org: Optional[str] = None,
+         brand_gh_label: Optional[str] = None,
+         brand_gh_url: Optional[str] = None,
+         brand_project_name: Optional[str] = None,
+         brand_project_url: Optional[str] = None,
+     ):
+         self.hf_token = hf_token
+         # The token we will store in the Space secrets. Defaults to hf_token if not provided
+         self.space_secret_token = space_secret_token or hf_token
+         self.hf_username = hf_username
+         # Allow passing just a repo name without username and auto-prefix
+         self.model_id = model_id if "/" in model_id else f"{hf_username}/{model_id}"
+         self.subfolder = subfolder
+         self.space_name = space_name or f"{self.model_id.split('/')[-1]}-demo"
+         self.space_id = f"{hf_username}/{self.space_name}"
+         self.space_url = f"https://huggingface.co/spaces/{self.space_id}"
+         self.config_file = config_file
+
+         # Config-derived context
+         self.system_message: Optional[str] = None
+         self.developer_message: Optional[str] = None
+         self.model_identity: Optional[str] = None
+         self.reasoning_effort: Optional[str] = None
+         # Examples context
+         self.examples_type: Optional[str] = (examples_type or None)
+         self.disable_examples: Optional[bool] = (disable_examples if disable_examples is not None else None)
+         self.examples_json: Optional[str] = (examples_json or None)
+
+         # Determine demo type from model_id if not provided
+         if demo_type is None:
+             demo_type = self._detect_demo_type(model_id)
+
+         # Template paths based on model type
+         self.demo_type = demo_type
+         self.template_dir = Path(__file__).parent.parent / "templates" / "spaces" / f"demo_{demo_type}"
+         self.workspace_dir = Path.cwd()
+
+         # Initialize HF API
+         if HF_HUB_AVAILABLE:
+             self.api = HfApi(token=self.hf_token)
+         else:
+             self.api = None
+             logger.warning("huggingface_hub not available, using CLI fallback")
+
+         # Load optional config-specified messages
+         try:
+             self._load_config_messages()
+         except Exception as e:
+             logger.warning(f"Could not load config messages: {e}")
+
+         # Branding defaults (can be overridden via CLI)
+         self.brand_owner_name = brand_owner_name or self.hf_username or "Tonic"
+         self.brand_team_name = brand_team_name or f"Team{self.brand_owner_name}"
+         self.brand_discord_url = brand_discord_url or "https://discord.gg/qdfnvSPcqP"
+         # HF org/link
+         _default_hf_org = brand_hf_org or self.hf_username or "MultiTransformer"
+         self.brand_hf_org = _default_hf_org
+         self.brand_hf_label = brand_hf_label or self.brand_hf_org
+         self.brand_hf_url = brand_hf_url or f"https://huggingface.co/{self.brand_hf_org}"
+         # GitHub org/link
+         _default_gh_org = brand_gh_org or self.hf_username or "tonic-ai"
+         self.brand_gh_org = _default_gh_org
+         self.brand_gh_label = brand_gh_label or self.brand_gh_org
+         self.brand_gh_url = brand_gh_url or f"https://github.com/{self.brand_gh_org}"
+         # Project link
+         self.brand_project_name = brand_project_name or "MultiTonic"
+         self.brand_project_url = brand_project_url or "https://github.com/MultiTonic"
+
+     def _load_config_messages(self) -> None:
+         """Load system/developer/model_identity from a training config file if provided."""
+         if not self.config_file:
+             return
+         cfg_path = Path(self.config_file)
+         if not cfg_path.exists():
+             logger.warning(f"Config file not found: {cfg_path}")
+             return
+
+         # Ensure project root and config dir are importable for relative imports inside config
+         project_root = Path(__file__).parent.parent
+         if str(project_root) not in sys.path:
+             sys.path.insert(0, str(project_root))
+         cfg_dir = project_root / "config"
+         if str(cfg_dir) not in sys.path:
+             sys.path.insert(0, str(cfg_dir))
+
+         import importlib.util
+         spec = importlib.util.spec_from_file_location("config_module", str(cfg_path))
+         if not spec or not spec.loader:
+             return
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)  # type: ignore
+         cfg = getattr(module, "config", None)
+         if cfg is None:
+             return
+         self.system_message = getattr(cfg, "system_message", None)
+         self.developer_message = getattr(cfg, "developer_message", None)
+         chat_kwargs = getattr(cfg, "chat_template_kwargs", None)
+         if isinstance(chat_kwargs, dict):
+             self.model_identity = chat_kwargs.get("model_identity")
+             self.reasoning_effort = chat_kwargs.get("reasoning_effort")
+
+     def _detect_demo_type(self, model_id: str) -> str:
+         """Detect the appropriate demo type based on model ID"""
+         model_id_lower = model_id.lower()
+
+         # Voxtral ASR models
+         if "voxtral" in model_id_lower:
+             logger.info("Detected Voxtral model, using demo_voxtral template")
+             return "voxtral"
+
+         # Check for GPT-OSS models
+         if "gpt-oss" in model_id_lower or "gpt_oss" in model_id_lower:
+             logger.info("Detected GPT-OSS model, using demo_gpt template")
+             return "gpt"
+
+         # Check for SmolLM models (default)
+         elif "smollm" in model_id_lower or "smol" in model_id_lower:
+             logger.info("Detected SmolLM model, using demo_smol template")
+             return "smol"
+
+         # Default to SmolLM for unknown models
+         else:
+             logger.info("Unknown model type, defaulting to demo_smol template")
+             return "smol"
+
+     def _generate_env_setup(self) -> str:
+         """Generate environment variable setup based on demo type and model"""
+         if self.demo_type == "gpt":
+             # For GPT-OSS models, we need more sophisticated environment setup
+             model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+             import json as _json
+             env_setup = f"""
+ # Environment variables for GPT-OSS model configuration
+ import os
+ os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
+ os.environ['LORA_MODEL_ID'] = {_json.dumps(self.model_id)}
+ os.environ['BASE_MODEL_ID'] = 'openai/gpt-oss-20b'
+ os.environ['MODEL_SUBFOLDER'] = {_json.dumps(self.subfolder if self.subfolder else "")}
+ os.environ['MODEL_NAME'] = {_json.dumps(model_name)}
+ os.environ['MODEL_IDENTITY'] = {_json.dumps(self.model_identity or "")}
+ os.environ['SYSTEM_MESSAGE'] = {_json.dumps(self.system_message or (self.model_identity or ""))}
+ os.environ['DEVELOPER_MESSAGE'] = {_json.dumps(self.developer_message or "")}
+ os.environ['REASONING_EFFORT'] = {_json.dumps((self.reasoning_effort or "medium"))}
+ {"os.environ['EXAMPLES_TYPE'] = " + _json.dumps(self.examples_type) + "\n" if self.examples_type else ''}
+ {"os.environ['DISABLE_EXAMPLES'] = 'true'\n" if self.disable_examples else ("os.environ['DISABLE_EXAMPLES'] = 'false'\n" if self.disable_examples is not None else '')}
+ {"os.environ['EXAMPLES_JSON'] = " + _json.dumps(self.examples_json) + "\n" if self.examples_json else ''}
+
+ # Branding/owner variables
+ os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
+ os.environ['BRAND_OWNER_NAME'] = {_json.dumps(self.brand_owner_name)}
+ os.environ['BRAND_TEAM_NAME'] = {_json.dumps(self.brand_team_name)}
+ os.environ['BRAND_DISCORD_URL'] = {_json.dumps(self.brand_discord_url)}
+ os.environ['BRAND_HF_ORG'] = {_json.dumps(self.brand_hf_org)}
+ os.environ['BRAND_HF_LABEL'] = {_json.dumps(self.brand_hf_label)}
+ os.environ['BRAND_HF_URL'] = {_json.dumps(self.brand_hf_url)}
+ os.environ['BRAND_GH_ORG'] = {_json.dumps(self.brand_gh_org)}
+ os.environ['BRAND_GH_LABEL'] = {_json.dumps(self.brand_gh_label)}
+ os.environ['BRAND_GH_URL'] = {_json.dumps(self.brand_gh_url)}
+ os.environ['BRAND_PROJECT_NAME'] = {_json.dumps(self.brand_project_name)}
+ os.environ['BRAND_PROJECT_URL'] = {_json.dumps(self.brand_project_url)}
+
+ """
+         elif self.demo_type == "voxtral":
+             import json as _json
+             env_setup = f"""
+ # Environment variables for Voxtral ASR demo
+ import os
+ os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
+ os.environ['MODEL_NAME'] = {_json.dumps(self.model_id.split('/')[-1])}
+ os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
+ """
+         else:
+             # For SmolLM models, use simpler setup
+             import json as _json
+             env_setup = f"""
+ # Environment variables for model configuration
+ import os
+ os.environ['HF_MODEL_ID'] = {_json.dumps(self.model_id)}
+ os.environ['MODEL_SUBFOLDER'] = {_json.dumps(self.subfolder if self.subfolder else "")}
+ os.environ['MODEL_NAME'] = {_json.dumps(self.model_id.split("/")[-1])}
+ os.environ['MODEL_IDENTITY'] = {_json.dumps(self.model_identity or "")}
+ os.environ['SYSTEM_MESSAGE'] = {_json.dumps(self.system_message or (self.model_identity or ""))}
+ os.environ['DEVELOPER_MESSAGE'] = {_json.dumps(self.developer_message or "")}
+ os.environ['REASONING_EFFORT'] = {_json.dumps((self.reasoning_effort or "medium"))}
+ {"os.environ['EXAMPLES_TYPE'] = " + _json.dumps(self.examples_type) + "\n" if self.examples_type else ''}
+ {"os.environ['DISABLE_EXAMPLES'] = 'true'\n" if self.disable_examples else ("os.environ['DISABLE_EXAMPLES'] = 'false'\n" if self.disable_examples is not None else '')}
+ {"os.environ['EXAMPLES_JSON'] = " + _json.dumps(self.examples_json) + "\n" if self.examples_json else ''}
+
+ # Branding/owner variables
+ os.environ['HF_USERNAME'] = {_json.dumps(self.hf_username)}
+ os.environ['BRAND_OWNER_NAME'] = {_json.dumps(self.brand_owner_name)}
+ os.environ['BRAND_TEAM_NAME'] = {_json.dumps(self.brand_team_name)}
+ os.environ['BRAND_DISCORD_URL'] = {_json.dumps(self.brand_discord_url)}
+ os.environ['BRAND_HF_ORG'] = {_json.dumps(self.brand_hf_org)}
+ os.environ['BRAND_HF_LABEL'] = {_json.dumps(self.brand_hf_label)}
+ os.environ['BRAND_HF_URL'] = {_json.dumps(self.brand_hf_url)}
+ os.environ['BRAND_GH_ORG'] = {_json.dumps(self.brand_gh_org)}
+ os.environ['BRAND_GH_LABEL'] = {_json.dumps(self.brand_gh_label)}
+ os.environ['BRAND_GH_URL'] = {_json.dumps(self.brand_gh_url)}
+ os.environ['BRAND_PROJECT_NAME'] = {_json.dumps(self.brand_project_name)}
+ os.environ['BRAND_PROJECT_URL'] = {_json.dumps(self.brand_project_url)}
+
+ """
+         return env_setup
+
+     def _set_model_variables(self):
+         """Set model-specific environment variables in the space"""
+         try:
+             # Common variables for all models
+             self.api.add_space_variable(
+                 repo_id=self.space_id,
+                 key="HF_MODEL_ID",
+                 value=self.model_id,
+                 description="Model ID for the demo"
+             )
+             logger.info(f"✅ Successfully set HF_MODEL_ID variable: {self.model_id}")
+
+             if self.subfolder and self.subfolder.strip():
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="MODEL_SUBFOLDER",
+                     value=self.subfolder,
+                     description="Model subfolder for the demo"
+                 )
+                 logger.info(f"✅ Successfully set MODEL_SUBFOLDER variable: {self.subfolder}")
+             else:
+                 logger.info("ℹ️ No subfolder specified, using main model")
+
+             # GPT-OSS specific variables
+             if self.demo_type == "gpt":
+                 model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="LORA_MODEL_ID",
+                     value=self.model_id,
+                     description="LoRA/Fine-tuned model ID"
+                 )
+                 logger.info(f"✅ Successfully set LORA_MODEL_ID variable: {self.model_id}")
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="BASE_MODEL_ID",
+                     value="openai/gpt-oss-20b",
+                     description="Base model ID for GPT-OSS"
+                 )
+                 logger.info("✅ Successfully set BASE_MODEL_ID variable: openai/gpt-oss-20b")
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="MODEL_NAME",
+                     value=model_name,
+                     description="Display name for the model"
+                 )
+                 logger.info(f"✅ Successfully set MODEL_NAME variable: {model_name}")
+
+             # Voxtral-specific variables
+             elif self.demo_type == "voxtral":
+                 # HF_MODEL_ID was already set above; set a readable MODEL_NAME
+                 vox_model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="MODEL_NAME",
+                     value=vox_model_name,
+                     description="Display name for the Voxtral model"
+                 )
+                 logger.info(f"✅ Set Voxtral MODEL_NAME variable: {vox_model_name}")
+
+             # Optional context variables
+             if self.model_identity:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="MODEL_IDENTITY",
+                     value=self.model_identity,
+                     description="Default model identity/system persona"
+                 )
+                 logger.info("✅ Set MODEL_IDENTITY variable")
+             if self.system_message or self.model_identity:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="SYSTEM_MESSAGE",
+                     value=self.system_message or self.model_identity or "",
+                     description="Default system message"
+                 )
+                 logger.info("✅ Set SYSTEM_MESSAGE variable")
+             if self.developer_message:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="DEVELOPER_MESSAGE",
+                     value=self.developer_message,
+                     description="Default developer message"
+                 )
+                 logger.info("✅ Set DEVELOPER_MESSAGE variable")
+             if self.reasoning_effort:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="REASONING_EFFORT",
+                     value=self.reasoning_effort,
+                     description="Default reasoning effort (low|medium|high)"
+                 )
+                 logger.info("✅ Set REASONING_EFFORT variable")
+
+             # Branding variables
+             branding_vars = {
+                 "HF_USERNAME": self.hf_username,
+                 "BRAND_OWNER_NAME": self.brand_owner_name,
+                 "BRAND_TEAM_NAME": self.brand_team_name,
+                 "BRAND_DISCORD_URL": self.brand_discord_url,
+                 "BRAND_HF_ORG": self.brand_hf_org,
+                 "BRAND_HF_LABEL": self.brand_hf_label,
+                 "BRAND_HF_URL": self.brand_hf_url,
+                 "BRAND_GH_ORG": self.brand_gh_org,
+                 "BRAND_GH_LABEL": self.brand_gh_label,
+                 "BRAND_GH_URL": self.brand_gh_url,
+                 "BRAND_PROJECT_NAME": self.brand_project_name,
+                 "BRAND_PROJECT_URL": self.brand_project_url,
+             }
+             for key, value in branding_vars.items():
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key=key,
+                     value=value,
+                     description=f"Branding: {key}"
+                 )
+             logger.info("✅ Set branding variables")
+
+             # Examples variables
+             if self.examples_type:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="EXAMPLES_TYPE",
+                     value=self.examples_type,
+                     description="Examples pack type (e.g., general|medical)"
+                 )
+                 logger.info(f"✅ Set EXAMPLES_TYPE={self.examples_type}")
+             if self.disable_examples is not None:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="DISABLE_EXAMPLES",
+                     value=("true" if self.disable_examples else "false"),
+                     description="Disable built-in examples"
+                 )
+                 logger.info(f"✅ Set DISABLE_EXAMPLES={self.disable_examples}")
+             if self.examples_json:
+                 self.api.add_space_variable(
+                     repo_id=self.space_id,
+                     key="EXAMPLES_JSON",
+                     value=self.examples_json,
+                     description="Custom examples JSON override"
+                 )
+                 logger.info("✅ Set EXAMPLES_JSON override")
+
+         except Exception as e:
+             logger.error(f"❌ Failed to set model variables: {e}")
+
+     def validate_model_exists(self) -> bool:
+         """Validate that the model exists on Hugging Face Hub"""
+         try:
+             logger.info(f"Validating model: {self.model_id}")
+
+             if HF_HUB_AVAILABLE:
+                 # Use HF Hub API
+                 try:
+                     model_info = self.api.model_info(self.model_id)
+                     logger.info(f"✅ Model {self.model_id} exists and is accessible")
+                     return True
+                 except Exception as e:
+                     logger.error(f"❌ Model {self.model_id} not found via API: {e}")
+                     return False
+             else:
+                 # Fallback to requests
+                 url = f"https://huggingface.co/api/models/{self.model_id}"
+                 headers = {"Authorization": f"Bearer {self.hf_token}"}
+                 response = requests.get(url, headers=headers, timeout=30)
+
+                 if response.status_code == 200:
+                     logger.info(f"✅ Model {self.model_id} exists and is accessible")
+                     return True
+                 else:
+                     logger.error(f"❌ Model {self.model_id} not found or not accessible")
+                     return False
+
+         except Exception as e:
+             logger.error(f"❌ Error validating model: {e}")
+             return False
+
+     def create_space_repository(self) -> bool:
+         """Create the space repository on Hugging Face Hub"""
+         try:
+             logger.info(f"Creating Space: {self.space_name}")
+
+             if not HF_HUB_AVAILABLE:
+                 logger.warning("huggingface_hub not available, falling back to CLI")
+                 return self._create_space_cli()
+
+             # Use the latest HF Hub API to create space
+             try:
+                 # Create the space using the API
+                 create_repo(
+                     repo_id=self.space_id,
+                     token=self.hf_token,
+                     repo_type="space",
+                     exist_ok=True,
+                     private=False,  # Spaces are typically public
+                     space_sdk="gradio",  # Specify Gradio SDK
+                     space_hardware="cpu-basic"  # Use basic CPU
+                 )
+
+                 logger.info(f"✅ Space created successfully: {self.space_url}")
+                 return True
+
+             except Exception as api_error:
+                 logger.error(f"API creation failed: {api_error}")
+                 logger.info("Falling back to CLI method...")
+                 return self._create_space_cli()
+
+         except Exception as e:
+             logger.error(f"❌ Error creating space: {e}")
+             return False
+
+     def _create_space_cli(self) -> bool:
+         """Fallback method using CLI commands"""
+         try:
+             logger.info("Using CLI fallback method...")
+
+             # Set HF token for CLI
+             os.environ['HF_TOKEN'] = self.hf_token
+
+             # Create space using Hugging Face CLI
+             cmd = [
+                 "hf", "repo", "create",
+                 self.space_id,
+                 "--type", "space"
+             ]
+
+             logger.info(f"Running command: {' '.join(cmd)}")
+             result = subprocess.run(cmd, capture_output=True, text=True)
+
+             if result.returncode != 0:
+                 logger.warning(f"First attempt failed: {result.stderr}")
+                 # Try alternative approach without space-specific flags
+                 logger.info("Retrying with basic space creation...")
+                 cmd = [
+                     "hf", "repo", "create",
+                     self.space_id
+                 ]
+                 result = subprocess.run(cmd, capture_output=True, text=True)
+
+             if result.returncode == 0:
+                 logger.info(f"✅ Space created successfully: {self.space_url}")
+                 return True
+             else:
+                 logger.error(f"❌ Failed to create space: {result.stderr}")
+                 return False
+
+         except Exception as e:
+             logger.error(f"❌ Error creating space with CLI: {e}")
+             return False
+
+     def prepare_space_files(self) -> Optional[str]:
+         """Prepare all necessary files for the Space in a temporary directory"""
+         try:
+             logger.info("Preparing Space files...")
+
+             # Create temporary directory
+             temp_dir = tempfile.mkdtemp()
+             logger.info(f"Created temporary directory: {temp_dir}")
+
+             # Copy template files
+             copied_files = []
+             for file_path in self.template_dir.iterdir():
+                 if file_path.is_file():
+                     dest_path = Path(temp_dir) / file_path.name
+                     shutil.copy2(file_path, dest_path)
+                     copied_files.append(file_path.name)
+                     logger.info(f"✅ Copied {file_path.name} to temp directory")
+
+             # Update app.py with environment variables
+             app_file = Path(temp_dir) / "app.py"
+             if app_file.exists():
+                 with open(app_file, 'r', encoding='utf-8') as f:
+                     content = f.read()
+
+                 # Add environment variable setup at the top
+                 env_setup = self._generate_env_setup()
+
+                 # Insert after imports
+                 lines = content.split('\n')
+                 import_end = 0
+                 for i, line in enumerate(lines):
+                     if line.startswith('import ') or line.startswith('from '):
+                         import_end = i + 1
+                     elif line.strip() == '' and import_end > 0:
+                         break
+
+                 lines.insert(import_end, env_setup)
+                 content = '\n'.join(lines)
+
+                 with open(app_file, 'w', encoding='utf-8') as f:
+                     f.write(content)
+
+                 logger.info("✅ Updated app.py with model configuration")
+
+             # YAML front matter required by Hugging Face Spaces
+             yaml_front_matter = (
+                 f"---\n"
+                 f"title: {'GPT-OSS Demo' if self.demo_type == 'gpt' else ('Voxtral ASR Demo' if self.demo_type == 'voxtral' else 'SmolLM3 Demo')}\n"
+                 f"emoji: {'🌟' if self.demo_type == 'gpt' else '💃🏻'}\n"
+                 f"colorFrom: {'blue' if self.demo_type == 'gpt' else 'green'}\n"
+                 f"colorTo: {'pink' if self.demo_type == 'gpt' else 'purple'}\n"
+                 f"sdk: gradio\n"
+                 f"sdk_version: 5.40.0\n"
+                 f"app_file: app.py\n"
+                 f"pinned: false\n"
+                 f"short_description: Interactive demo for {self.model_id}\n"
+                 + ("license: mit\n" if self.demo_type != 'gpt' else "") +
+                 f"---\n\n"
+             )
+
+             # Create README.md for the space (include configuration details)
+             readme_content = (
+                 yaml_front_matter
+                 + f"# Demo: {self.model_id}\n\n"
+                 + f"This is an interactive demo for the fine-tuned model {self.model_id}.\n\n"
+                 + "## Features\n"
+                 "- Interactive chat interface\n"
+                 "- Customizable system & developer prompts\n"
+                 "- Advanced generation parameters\n"
+                 "- Thinking mode support\n\n"
+                 + "## Model Information\n"
+                 f"- **Model ID**: {self.model_id}\n"
+                 f"- **Subfolder**: {self.subfolder if self.subfolder and self.subfolder.strip() else 'main'}\n"
+                 f"- **Deployed by**: {self.hf_username}\n"
+                 + ("- **Base Model**: openai/gpt-oss-20b\n" if self.demo_type == 'gpt' else "")
+                 + "\n"
+                 + "## Configuration\n"
+                 "- **Model Identity**:\n\n"
+                 f"```\n{self.model_identity or 'Not set'}\n```\n\n"
+                 "- **System Message** (default):\n\n"
+                 f"```\n{(self.system_message or self.model_identity) or 'Not set'}\n```\n\n"
+                 "- **Developer Message** (default):\n\n"
+                 f"```\n{self.developer_message or 'Not set'}\n```\n\n"
+                 "These defaults come from the selected training configuration and can be adjusted in the UI when you run the demo.\n\n"
+                 + "## Usage\n"
+                 "Simply start chatting with the model using the interface below!\n\n"
+                 + "---\n"
+                 "*This demo was automatically deployed by the SmolFactory Fine-tuning Pipeline*\n"
+             )
+
+             with open(Path(temp_dir) / "README.md", 'w', encoding='utf-8') as f:
+                 f.write(readme_content)
+
+             logger.info(f"✅ Prepared {len(copied_files)} files in temporary directory")
+             return temp_dir
+
+         except Exception as e:
+             logger.error(f"❌ Error preparing files: {e}")
+             return None
+
+     def upload_files_to_space(self, temp_dir: str) -> bool:
+         """Upload files to the Space using HF Hub API directly"""
+         try:
+             logger.info("Uploading files to Space using HF Hub API...")
+
+             if not HF_HUB_AVAILABLE:
+                 logger.error("❌ huggingface_hub not available for file upload")
+                 return self._upload_files_cli(temp_dir)
+
+             # Upload each file using the HF Hub API
+             temp_path = Path(temp_dir)
+             uploaded_files = []
+
+             for file_path in temp_path.iterdir():
+                 if file_path.is_file():
+                     try:
+                         # Upload file to the space
+                         upload_file(
+                             path_or_fileobj=str(file_path),
+                             path_in_repo=file_path.name,
+                             repo_id=self.space_id,
+                             repo_type="space",
+                             token=self.hf_token
+                         )
+                         uploaded_files.append(file_path.name)
+                         logger.info(f"✅ Uploaded {file_path.name}")
+                     except Exception as e:
+                         logger.error(f"❌ Failed to upload {file_path.name}: {e}")
+                         return False
+
+             logger.info(f"✅ Successfully uploaded {len(uploaded_files)} files to Space")
+             return True
+
+         except Exception as e:
+             logger.error(f"❌ Error uploading files: {e}")
+             return self._upload_files_cli(temp_dir)
+
+     def _upload_files_cli(self, temp_dir: str) -> bool:
+         """Fallback method using CLI for file upload"""
+         try:
+             logger.info("Using CLI fallback for file upload...")
+
+             # Set HF token for CLI
+             os.environ['HF_TOKEN'] = self.hf_token
+
+             # Initialize git repository
+             subprocess.run(["git", "init"], cwd=temp_dir, check=True)
+             subprocess.run(["git", "config", "user.name", "Demo Deployer"], cwd=temp_dir, check=True)
+             subprocess.run(["git", "config", "user.email", "[email protected]"], cwd=temp_dir, check=True)
+
+             # Add files
+             subprocess.run(["git", "add", "."], cwd=temp_dir, check=True)
+             subprocess.run(["git", "commit", "-m", f"Deploy demo for {self.model_id}"], cwd=temp_dir, check=True)
+
+             # Add remote and push
+             remote_url = f"https://{self.hf_token}@huggingface.co/spaces/{self.space_id}"
+             subprocess.run(["git", "remote", "add", "origin", remote_url], cwd=temp_dir, check=True)
+             subprocess.run(["git", "push", "-u", "origin", "main"], cwd=temp_dir, check=True)
+
+             logger.info(f"✅ Successfully pushed files to space: {self.space_id}")
+             return True
+
+         except subprocess.CalledProcessError as e:
+             logger.error(f"❌ Git operation failed: {e}")
+             return False
+         except Exception as e:
+             logger.error(f"❌ Error pushing to space: {e}")
+             return False
+
+     def set_space_secrets(self) -> bool:
+         """Set environment variables/secrets for the Space using HF Hub API"""
+         try:
+             logger.info("Setting Space secrets using HF Hub API...")
+
+             if not HF_HUB_AVAILABLE:
+                 logger.warning("❌ huggingface_hub not available for setting secrets")
+                 return self._manual_secret_setup()
+
+             # Set the HF_TOKEN secret for the space using the API
+             try:
+                 self.api.add_space_secret(
+                     repo_id=self.space_id,
+                     key="HF_TOKEN",
+                     value=self.space_secret_token,
+                     description="Hugging Face token for model access"
+                 )
+                 logger.info("✅ Successfully set HF_TOKEN secret via API")
+
+                 # Set model-specific environment variables
+                 self._set_model_variables()
+
+                 return True
+
+             except Exception as api_error:
+                 logger.error(f"❌ Failed to set secrets via API: {api_error}")
+                 logger.info("Falling back to manual setup...")
+                 return self._manual_secret_setup()
+
+         except Exception as e:
+             logger.error(f"❌ Error setting space secrets: {e}")
+             return self._manual_secret_setup()
+
+     def _manual_secret_setup(self) -> bool:
+         """Fallback method for manual secret setup"""
+         logger.info("📝 Manual Space Secrets Configuration:")
+         logger.info("   HF_TOKEN=<hidden>")
+         logger.info(f"   HF_MODEL_ID={self.model_id}")
+         if self.subfolder and self.subfolder.strip():
+             logger.info(f"   MODEL_SUBFOLDER={self.subfolder}")
+         else:
+             logger.info("   MODEL_SUBFOLDER=(empty - using main model)")
+
+         # GPT-OSS specific variables
+         if self.demo_type == "gpt":
+             model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+             logger.info(f"   LORA_MODEL_ID={self.model_id}")
+             logger.info("   BASE_MODEL_ID=openai/gpt-oss-20b")
+             logger.info(f"   MODEL_NAME={model_name}")
+         if self.model_identity:
+             logger.info(f"   MODEL_IDENTITY={self.model_identity}")
+         if self.system_message:
+             logger.info(f"   SYSTEM_MESSAGE={self.system_message}")
+         if self.developer_message:
+             logger.info(f"   DEVELOPER_MESSAGE={self.developer_message}")
+         # Branding variables
+         logger.info(f"   HF_USERNAME={self.hf_username}")
+         logger.info(f"   BRAND_OWNER_NAME={self.brand_owner_name}")
+         logger.info(f"   BRAND_TEAM_NAME={self.brand_team_name}")
+         logger.info(f"   BRAND_DISCORD_URL={self.brand_discord_url}")
+         logger.info(f"   BRAND_HF_ORG={self.brand_hf_org}")
+         logger.info(f"   BRAND_HF_LABEL={self.brand_hf_label}")
+         logger.info(f"   BRAND_HF_URL={self.brand_hf_url}")
+         logger.info(f"   BRAND_GH_ORG={self.brand_gh_org}")
+         logger.info(f"   BRAND_GH_LABEL={self.brand_gh_label}")
+         logger.info(f"   BRAND_GH_URL={self.brand_gh_url}")
+         logger.info(f"   BRAND_PROJECT_NAME={self.brand_project_name}")
+         logger.info(f"   BRAND_PROJECT_URL={self.brand_project_url}")
+
+         # Examples variables
+         if self.examples_type:
+             logger.info(f"   EXAMPLES_TYPE={self.examples_type}")
+         if self.disable_examples is not None:
+             logger.info(f"   DISABLE_EXAMPLES={'true' if self.disable_examples else 'false'}")
+         if self.examples_json:
+             logger.info(f"   EXAMPLES_JSON={self.examples_json}")
+
+         logger.info("\n🔧 To set secrets in your Space:")
+         logger.info(f"1. Go to your Space settings: {self.space_url}/settings")
+         logger.info("2. Navigate to the 'Repository secrets' section")
+         logger.info("3. Add the following secrets:")
+         logger.info("   Name: HF_TOKEN")
+         logger.info("   Value: <your token>")
+         logger.info("   Name: HF_MODEL_ID")
+         logger.info(f"   Value: {self.model_id}")
+         if self.subfolder and self.subfolder.strip():
+             logger.info("   Name: MODEL_SUBFOLDER")
+             logger.info(f"   Value: {self.subfolder}")
+         else:
+             logger.info("   Name: MODEL_SUBFOLDER")
+             logger.info("   Value: (leave empty)")
+
+         # GPT-OSS specific variables
+         if self.demo_type == "gpt":
+             model_name = self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
+             logger.info("   Name: LORA_MODEL_ID")
+             logger.info(f"   Value: {self.model_id}")
+             logger.info("   Name: BASE_MODEL_ID")
+             logger.info("   Value: openai/gpt-oss-20b")
+             logger.info("   Name: MODEL_NAME")
+             logger.info(f"   Value: {model_name}")
+
+         logger.info("4. Save the secrets")
+
+         return True
+
+     def test_space(self) -> bool:
+         """Test if the Space is working correctly"""
+         try:
+             logger.info("Testing Space...")
+
+             # Wait a bit for the space to build
+             logger.info("Waiting 180 seconds for Space to build...")
+             time.sleep(180)
+
+             # Try to access the space
+             response = requests.get(self.space_url, timeout=30)
+
+             if response.status_code == 200:
+                 logger.info(f"✅ Space is accessible: {self.space_url}")
+                 return True
+             else:
+                 logger.warning(f"⚠️ Space returned status code: {response.status_code}")
+                 logger.warning(f"Response: {response.text[:500]}...")
+                 return False
+
+         except Exception as e:
+             logger.error(f"❌ Error testing space: {e}")
+             return False
+
+     def deploy(self) -> bool:
+         """Main deployment method"""
+         logger.info(f"🚀 Starting demo space deployment for {self.model_id}")
+
+         # Step 1: Validate model exists
+         if not self.validate_model_exists():
+             return False
+
+         # Step 2: Create space repository
+         if not self.create_space_repository():
+             return False
+
+         # Step 3: Prepare files
+         temp_dir = self.prepare_space_files()
+         if not temp_dir:
+             return False
+
+         # Step 4: Upload files
+         if not self.upload_files_to_space(temp_dir):
+             return False
+
+         # Step 5: Set space secrets
+         if not self.set_space_secrets():
+             return False
+
+         # Step 6: Clean up temp directory
+         try:
+             shutil.rmtree(temp_dir)
+             logger.info("✅ Cleaned up temporary directory")
+         except Exception as e:
+             logger.warning(f"⚠️ Warning: Could not clean up temp directory: {e}")
+
+         # Step 7: Test space
+         if not self.test_space():
+             logger.warning("⚠️ Space created but may need more time to build")
+             logger.info("Please check the Space manually in a few minutes")
+
+         logger.info("🎉 Demo space deployment completed!")
+         logger.info(f"📊 Space URL: {self.space_url}")
+         logger.info(f"🔧 Space configuration: {self.space_url}/settings")
+
+         return True
+
+ def main():
+     """Main function for command line usage"""
+     print("Demo Space Deployment Script")
+     print("=" * 40)
+
+     parser = argparse.ArgumentParser(description="Deploy demo space to Hugging Face Spaces")
+     parser.add_argument("--hf-token", required=True, help="Hugging Face token")
+     parser.add_argument(
+         "--space-secret-token",
+         required=False,
+         help="Token to store as Space secret HF_TOKEN (defaults to --hf-token). Use a READ token here for least privilege.",
+     )
+     parser.add_argument("--hf-username", required=True, help="Hugging Face username")
+     parser.add_argument("--model-id", required=True, help="Model ID to deploy demo for")
+     parser.add_argument("--subfolder", default="int4", help="Model subfolder (default: int4)")
+     parser.add_argument("--space-name", help="Custom space name (optional)")
+     parser.add_argument("--demo-type", choices=["smol", "gpt", "voxtral"], help="Demo type: 'smol' for SmolLM, 'gpt' for GPT-OSS, 'voxtral' for Voxtral ASR (auto-detected if not specified)")
+     parser.add_argument("--config-file", help="Path to the training config file to import context (system/developer/model_identity)")
+     # Examples configuration
+     parser.add_argument("--examples-type", choices=["general", "medical"], help="Examples pack to enable in the demo UI")
+     parser.add_argument("--disable-examples", action="store_true", help="Disable rendering of example prompts in the UI")
+     parser.add_argument("--examples-json", help="Custom examples JSON (list[str]) to override built-in examples")
+     # Branding customization
+     parser.add_argument("--brand-owner-name", help="Owner name shown in the UI title (defaults to HF username)")
+     parser.add_argument("--brand-team-name", help="Team name shown in Join Us (defaults to Team<owner>)")
+     parser.add_argument("--brand-discord-url", help="Discord invite URL for Join Us section")
+     parser.add_argument("--brand-hf-org", help="Hugging Face org/username to link in Join Us")
+     parser.add_argument("--brand-hf-label", help="Label for the HF link (defaults to org)")
+     parser.add_argument("--brand-hf-url", help="Custom HF link URL (defaults to https://huggingface.co/<org>)")
+     parser.add_argument("--brand-gh-org", help="GitHub org/username to link in Join Us")
+     parser.add_argument("--brand-gh-label", help="Label for the GitHub link (defaults to org)")
+     parser.add_argument("--brand-gh-url", help="Custom GitHub link URL (defaults to https://github.com/<org>)")
+     parser.add_argument("--brand-project-name", help="Project name to link in Join Us")
+     parser.add_argument("--brand-project-url", help="Project URL to link in Join Us")
+
+     args = parser.parse_args()
+
+     deployer = DemoSpaceDeployer(
+         hf_token=args.hf_token,
+         space_secret_token=(args.space_secret_token or None),
+         hf_username=args.hf_username,
+         model_id=args.model_id,
+         subfolder=args.subfolder,
+         space_name=args.space_name,
+         demo_type=args.demo_type,
+         config_file=args.config_file,
+         examples_type=args.examples_type,
+         disable_examples=(True if getattr(args, 'disable_examples', False) else None),
+         examples_json=args.examples_json,
+         brand_owner_name=args.brand_owner_name,
+         brand_team_name=args.brand_team_name,
+         brand_discord_url=args.brand_discord_url,
+         brand_hf_org=args.brand_hf_org,
+         brand_hf_label=args.brand_hf_label,
+         brand_hf_url=args.brand_hf_url,
+         brand_gh_org=args.brand_gh_org,
+         brand_gh_label=args.brand_gh_label,
+         brand_gh_url=args.brand_gh_url,
+         brand_project_name=args.brand_project_name,
+         brand_project_url=args.brand_project_url,
+     )
+
+     success = deployer.deploy()
+
+     if success:
+         print("\n✅ Deployment successful!")
+         print(f"🌐 Your Demo Space: {deployer.space_url}")
+         print(f"👤 Username: {deployer.hf_username}")
+         print(f"🤖 Model: {deployer.model_id}")
+         print("\nNext steps:")
+         print("1. Wait for the Space to build (usually 2-5 minutes)")
+         print("2. Secrets have been automatically set via API")
+         print("3. Test the interface by visiting the Space URL")
+         print("4. Share your demo with others!")
+         print("\nIf the Space doesn't work immediately, check:")
+         print("- The Space logs at the Space URL")
+         print("- That all files were uploaded correctly")
+         print("- That the HF token has write permissions")
+         print("- That the secrets were set correctly in Space settings")
+     else:
+         print("\n❌ Deployment failed!")
+         print("Check the error messages above and try again.")
+         print("\nTroubleshooting:")
+         print("1. Verify your HF token has write permissions")
+         print("2. Check that the space name is available")
+         print("3. Verify the model exists and is accessible")
+         print("4. Try creating the space manually on HF first")
+
+     sys.exit(0 if success else 1)
+
+ if __name__ == "__main__":
+     main()
scripts/generate_model_card.py ADDED
@@ -0,0 +1,221 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate unified model card from template
4
+ Handles template variables and conditional sections for quantized models
5
+ """
6
+
7
+ import os
8
+ import re
9
+ import argparse
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Dict, Any, Optional
13
+ from datetime import datetime
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class ModelCardGenerator:
18
+ """Generate unified model cards from templates"""
19
+
20
+ def __init__(self, template_path: str = "templates/model_card.md"):
21
+ self.template_path = Path(template_path)
22
+ if not self.template_path.exists():
23
+ raise FileNotFoundError(f"Template not found: {self.template_path}")
24
+
25
+ def load_template(self) -> str:
26
+ """Load the model card template"""
27
+ with open(self.template_path, 'r', encoding='utf-8') as f:
28
+ return f.read()
29
+
30
+ def process_conditionals(self, content: str, variables: Dict[str, Any]) -> str:
31
+ """Process conditional sections in the template"""
32
+ # Handle {{#if variable}}...{{/if}} blocks
33
+ pattern = r'\{\{#if\s+(\w+)\}\}(.*?)\{\{/if\}\}'
34
+
35
+ def replace_conditional(match):
36
+ variable_name = match.group(1)
37
+ conditional_content = match.group(2)
38
+
39
+ # Check if variable exists and is truthy
40
+ if variable_name in variables and variables[variable_name]:
41
+ return conditional_content
42
+ else:
43
+ return ""
44
+
45
+ return re.sub(pattern, replace_conditional, content, flags=re.DOTALL)
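+
+ # Example (illustrative): with variables {"quantized_models": False},
+ # "base{{#if quantized_models}} + quantized{{/if}}" renders as "base";
+ # with {"quantized_models": True}, the conditional body is kept.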
46
+
47
+ def replace_variables(self, content: str, variables: Dict[str, Any]) -> str:
48
+ """Replace template variables with actual values"""
49
+ for key, value in variables.items():
50
+ placeholder = f"{{{{{key}}}}}"
51
+ content = content.replace(placeholder, str(value))
52
+
53
+ return content
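+
+ # Example (illustrative): replace_variables("Hi {{author_name}}", {"author_name": "Ada"})
+ # -> "Hi Ada"; the f-string "{{{{{key}}}}}" expands to the literal placeholder "{{author_name}}".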
54
+
55
+ def generate_model_card(self, variables: Dict[str, Any]) -> str:
56
+ """Generate the complete model card"""
57
+ # Load template
58
+ content = self.load_template()
59
+
60
+ # Process conditionals first
61
+ content = self.process_conditionals(content, variables)
62
+
63
+ # Replace variables
64
+ content = self.replace_variables(content, variables)
65
+
66
+ return content
67
+
68
+ def save_model_card(self, content: str, output_path: str) -> bool:
69
+ """Save the generated model card"""
70
+ try:
71
+ output_file = Path(output_path)
72
+ output_file.parent.mkdir(parents=True, exist_ok=True)
73
+
74
+ with open(output_file, 'w', encoding='utf-8') as f:
75
+ f.write(content)
76
+
77
+ logger.info(f"✅ Model card saved to: {output_file}")
78
+ return True
79
+
80
+ except Exception as e:
81
+ logger.error(f"❌ Failed to save model card: {e}")
82
+ return False
83
+
84
+ def create_default_variables() -> Dict[str, Any]:
85
+ """Create default variables for the model card"""
86
+ return {
87
+ "model_name": "SmolLM3 Fine-tuned Model",
88
+ "model_description": "A fine-tuned version of SmolLM3-3B for improved text generation and conversation capabilities.",
89
+ "repo_name": "your-username/model-name",
90
+ "base_model": "HuggingFaceTB/SmolLM3-3B",
91
+ "dataset_name": "OpenHermes-FR",
92
+ "training_config_type": "Custom Configuration",
93
+ "trainer_type": "SFTTrainer",
94
+ "batch_size": "8",
95
+ "gradient_accumulation_steps": "16",
96
+ "learning_rate": "5e-6",
97
+ "max_epochs": "3",
98
+ "max_seq_length": "2048",
99
+ "hardware_info": "GPU (H100/A100)",
100
+ "experiment_name": "smollm3-experiment",
101
+ "trackio_url": "https://trackio.space/experiment",
102
+ "dataset_repo": "tonic/trackio-experiments",
103
+ "dataset_size": "~80K samples",
104
+ "dataset_format": "Chat format",
105
+ "author_name": "Your Name",
106
+ "model_name_slug": "smollm3-fine-tuned",
107
+ "quantized_models": False,
108
+ "dataset_sample_size": None,
109
+ "training_loss": "N/A",
110
+ "validation_loss": "N/A",
111
+ "perplexity": "N/A"
112
+ }
113
+
114
+ def parse_args():
115
+ """Parse command line arguments"""
116
+ parser = argparse.ArgumentParser(description="Generate unified model card")
117
+ parser.add_argument("--template", default="templates/model_card.md",
118
+ help="Path to model card template")
119
+ parser.add_argument("--output", default="README.md",
120
+ help="Output path for generated model card")
121
+ parser.add_argument("--repo-name", required=True,
122
+ help="Hugging Face repository name")
123
+ parser.add_argument("--model-name", help="Model name")
124
+ parser.add_argument("--experiment-name", help="Experiment name")
125
+ parser.add_argument("--dataset-name", help="Dataset name")
126
+ parser.add_argument("--training-config", help="Training configuration type")
127
+ parser.add_argument("--trainer-type", help="Trainer type")
128
+ parser.add_argument("--batch-size", help="Batch size")
129
+ parser.add_argument("--learning-rate", help="Learning rate")
130
+ parser.add_argument("--max-epochs", help="Maximum epochs")
131
+ parser.add_argument("--max-seq-length", help="Maximum sequence length")
132
+ parser.add_argument("--hardware-info", help="Hardware information")
133
+ parser.add_argument("--trackio-url", help="Trackio URL")
134
+ parser.add_argument("--dataset-repo", help="Dataset repository")
135
+ parser.add_argument("--author-name", help="Author name")
136
+ parser.add_argument("--quantized-models", action="store_true",
137
+ help="Include quantized models")
138
+ parser.add_argument("--dataset-sample-size", help="Dataset sample size")
139
+ parser.add_argument("--training-loss", help="Training loss value")
140
+ parser.add_argument("--validation-loss", help="Validation loss value")
141
+ parser.add_argument("--perplexity", help="Perplexity value")
142
+
143
+ return parser.parse_args()
144
+
145
+ def main():
146
+ """Main function"""
147
+ args = parse_args()
148
+
149
+ # Setup logging
150
+ logging.basicConfig(
151
+ level=logging.INFO,
152
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
153
+ )
154
+
155
+ try:
156
+ # Create generator
157
+ generator = ModelCardGenerator(args.template)
158
+
159
+ # Create variables dictionary
160
+ variables = create_default_variables()
161
+
162
+ # Override with command line arguments
163
+ if args.repo_name:
164
+ variables["repo_name"] = args.repo_name
165
+ if args.model_name:
166
+ variables["model_name"] = args.model_name
167
+ if args.experiment_name:
168
+ variables["experiment_name"] = args.experiment_name
169
+ if args.dataset_name:
170
+ variables["dataset_name"] = args.dataset_name
171
+ if args.training_config:
172
+ variables["training_config_type"] = args.training_config
173
+ if args.trainer_type:
174
+ variables["trainer_type"] = args.trainer_type
175
+ if args.batch_size:
176
+ variables["batch_size"] = args.batch_size
177
+ if args.learning_rate:
178
+ variables["learning_rate"] = args.learning_rate
179
+ if args.max_epochs:
180
+ variables["max_epochs"] = args.max_epochs
181
+ if args.max_seq_length:
182
+ variables["max_seq_length"] = args.max_seq_length
183
+ if args.hardware_info:
184
+ variables["hardware_info"] = args.hardware_info
185
+ if args.trackio_url:
186
+ variables["trackio_url"] = args.trackio_url
187
+ if args.dataset_repo:
188
+ variables["dataset_repo"] = args.dataset_repo
189
+ if args.author_name:
190
+ variables["author_name"] = args.author_name
191
+ if args.quantized_models:
192
+ variables["quantized_models"] = True
193
+ if args.dataset_sample_size:
194
+ variables["dataset_sample_size"] = args.dataset_sample_size
195
+ if args.training_loss:
196
+ variables["training_loss"] = args.training_loss
197
+ if args.validation_loss:
198
+ variables["validation_loss"] = args.validation_loss
199
+ if args.perplexity:
200
+ variables["perplexity"] = args.perplexity
201
+
202
+ # Generate model card
203
+ print("🔄 Generating model card...")
204
+ content = generator.generate_model_card(variables)
205
+
206
+ # Save model card
207
+ if generator.save_model_card(content, args.output):
208
+ print("✅ Model card generated successfully!")
209
+ print(f"📄 Output: {args.output}")
210
+ else:
211
+ print("❌ Failed to generate model card")
212
+ return 1
213
+
214
+ return 0
215
+
216
+ except Exception as e:
217
+ logger.error(f"❌ Error generating model card: {e}")
218
+ return 1
219
+
220
+ if __name__ == "__main__":
221
+ exit(main())
scripts/push_to_huggingface.py ADDED
@@ -0,0 +1,700 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Push Trained Model and Results to Hugging Face Hub
4
+ Integrates with Trackio monitoring and HF Datasets for complete model deployment
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import argparse
10
+ import logging
11
+ import time
12
+ from pathlib import Path
13
+ from typing import Dict, Any, Optional, List
14
+ from datetime import datetime
15
+ import subprocess
16
+ import shutil
17
+ import platform
18
+
19
+ # Set timeout for HF operations to prevent hanging
20
+ os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '300'
21
+ os.environ['HF_HUB_UPLOAD_TIMEOUT'] = '600'
22
+
23
+ try:
24
+ from huggingface_hub import HfApi, create_repo, upload_file
25
+ from huggingface_hub import snapshot_download, hf_hub_download
26
+ HF_AVAILABLE = True
27
+ except ImportError:
28
+ HF_AVAILABLE = False
29
+ print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")
30
+
31
+ try:
32
+ import sys
33
+ import os
34
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
35
+ from monitoring import SmolLM3Monitor
36
+ MONITORING_AVAILABLE = True
37
+ except ImportError:
38
+ MONITORING_AVAILABLE = False
39
+ print("Warning: monitoring module not available")
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+ class TimeoutError(Exception):
44
+ """Custom timeout exception"""
45
+ pass
46
+
47
+ def timeout_handler(signum, frame):
48
+ """Signal handler for timeout"""
49
+ raise TimeoutError("Operation timed out")
50
+
51
+ class HuggingFacePusher:
52
+ """Push trained models and results to Hugging Face Hub with HF Datasets integration"""
53
+
54
+ def __init__(
55
+ self,
56
+ model_path: str,
57
+ repo_name: str,
58
+ token: Optional[str] = None,
59
+ private: bool = False,
60
+ trackio_url: Optional[str] = None,
61
+ experiment_name: Optional[str] = None,
62
+ dataset_repo: Optional[str] = None,
63
+ hf_token: Optional[str] = None,
64
+ author_name: Optional[str] = None,
65
+ model_description: Optional[str] = None,
66
+ training_config_type: Optional[str] = None,
67
+ model_name: Optional[str] = None,
68
+ dataset_name: Optional[str] = None,
69
+ batch_size: Optional[str] = None,
70
+ learning_rate: Optional[str] = None,
71
+ max_epochs: Optional[str] = None,
72
+ max_seq_length: Optional[str] = None,
73
+ trainer_type: Optional[str] = None
74
+ ):
75
+ self.model_path = Path(model_path)
76
+ # Original user input (may be just the repo name without username)
77
+ self.repo_name = repo_name
78
+ self.token = token or hf_token or os.getenv('HF_TOKEN')
79
+ self.private = private
80
+ self.trackio_url = trackio_url
81
+ self.experiment_name = experiment_name
82
+ self.author_name = author_name
83
+ self.model_description = model_description
84
+
85
+ # Training configuration details for model card generation
86
+ self.training_config_type = training_config_type
87
+ self.model_name = model_name
88
+ self.dataset_name = dataset_name
89
+ self.batch_size = batch_size
90
+ self.learning_rate = learning_rate
91
+ self.max_epochs = max_epochs
92
+ self.max_seq_length = max_seq_length
93
+ self.trainer_type = trainer_type
94
+
95
+ # HF Datasets configuration
96
+ self.dataset_repo = dataset_repo or os.getenv('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
97
+ self.hf_token = hf_token or os.getenv('HF_TOKEN')
98
+
99
+ # Initialize HF API
100
+ if HF_AVAILABLE:
101
+ self.api = HfApi(token=self.token)
102
+ else:
103
+ raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
104
+
105
+ # Resolve the full repo id (username/repo) if user only provided repo name
106
+ self.repo_id = self._resolve_repo_id(self.repo_name)
107
+
108
+ # Initialize monitoring if available
109
+ self.monitor = None
110
+ if MONITORING_AVAILABLE:
111
+ self.monitor = SmolLM3Monitor(
112
+ experiment_name=experiment_name or "model_push",
113
+ trackio_url=trackio_url,
114
+ enable_tracking=bool(trackio_url),
115
+ hf_token=self.hf_token,
116
+ dataset_repo=self.dataset_repo
117
+ )
118
+
119
+ logger.info(f"Initialized HuggingFacePusher for {self.repo_id}")
120
+ logger.info(f"Dataset repository: {self.dataset_repo}")
121
+
122
+ def _resolve_repo_id(self, repo_name: str) -> str:
123
+ """Return a fully-qualified repo id in the form username/repo.
124
+
125
+ If the provided name already contains a '/', it is returned unchanged.
126
+ Otherwise, we attempt to derive the username from the authenticated token
127
+ or from the HF_USERNAME environment variable.
128
+ """
129
+ try:
130
+ if "/" in repo_name:
131
+ return repo_name
132
+
133
+ # Need a username. Prefer API whoami(), fallback to env HF_USERNAME
134
+ username: Optional[str] = None
135
+ if self.token:
136
+ try:
137
+ user_info = self.api.whoami()
138
+ username = user_info.get("name") or user_info.get("username")
139
+ except Exception:
140
+ username = None
141
+
142
+ if not username:
143
+ username = os.getenv("HF_USERNAME")
144
+
145
+ if not username:
146
+ raise ValueError(
147
+ "Username could not be determined. Provide a token or set HF_USERNAME, "
148
+ "or pass a fully-qualified repo id 'username/repo'."
149
+ )
150
+
151
+ return f"{username}/{repo_name}"
152
+ except Exception as resolve_error:
153
+ logger.error(f"Failed to resolve full repo id for '{repo_name}': {resolve_error}")
154
+ # Fall back to provided value (may fail later at create/upload)
155
+ return repo_name
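+
+ # Example (illustrative): _resolve_repo_id("my-model") -> "alice/my-model" when the
+ # token's whoami() (or HF_USERNAME) yields "alice"; "alice/my-model" passes through as-is.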
156
+
157
+ def create_repository(self) -> bool:
158
+ """Create the Hugging Face repository"""
159
+ try:
160
+ logger.info(f"Creating repository: {self.repo_id}")
161
+
162
+ # Create repository with timeout handling
163
+ try:
164
+ # Create repository
165
+ create_repo(
166
+ repo_id=self.repo_id,
167
+ token=self.token,
168
+ private=self.private,
169
+ exist_ok=True
170
+ )
171
+
172
+ logger.info(f"✅ Repository created: https://huggingface.co/{self.repo_id}")
173
+ return True
174
+
175
+ except Exception as e:
176
+ logger.error(f"❌ Repository creation failed: {e}")
177
+ return False
178
+
179
+ except Exception as e:
180
+ logger.error(f"❌ Failed to create repository: {e}")
181
+ return False
182
+
183
+ def validate_model_path(self) -> bool:
184
+ """Validate that the model path contains required files"""
185
+ # Support both safetensors and pytorch formats
186
+ required_files = [
187
+ "config.json",
188
+ "tokenizer.json",
189
+ "tokenizer_config.json"
190
+ ]
191
+
192
+ # Check for model files (either safetensors or pytorch)
193
+ model_files = [
194
+ "model.safetensors.index.json", # Safetensors format
195
+ "pytorch_model.bin" # PyTorch format
196
+ ]
197
+
198
+ missing_files = []
199
+ for file in required_files:
200
+ if not (self.model_path / file).exists():
201
+ missing_files.append(file)
202
+
203
+ # Check if at least one model file exists
204
+ model_file_exists = any((self.model_path / file).exists() for file in model_files)
205
+ if not model_file_exists:
206
+ missing_files.extend(model_files)
207
+
208
+ if missing_files:
209
+ logger.error(f"❌ Missing required files: {missing_files}")
210
+ return False
211
+
212
+ logger.info("✅ Model files validated")
213
+ return True
214
+
215
+ def create_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
216
+ """Create a comprehensive model card using the generate_model_card.py script"""
217
+ try:
218
+ # Import the model card generator
219
+ import sys
220
+ sys.path.append(os.path.join(os.path.dirname(__file__)))
221
+ from generate_model_card import ModelCardGenerator, create_default_variables
222
+
223
+ # Create generator
224
+ generator = ModelCardGenerator()
225
+
226
+ # Create variables for the model card
227
+ variables = create_default_variables()
228
+
229
+ # Update with actual values
230
+ variables.update({
231
+ "repo_name": self.repo_id,
232
+ "model_name": self.repo_id.split('/')[-1],
233
+ "experiment_name": self.experiment_name or "model_push",
234
+ "dataset_repo": self.dataset_repo,
235
+ "author_name": self.author_name or "Model Author",
236
+ "model_description": self.model_description or "A fine-tuned version of SmolLM3-3B for improved text generation capabilities.",
237
+ "training_config_type": self.training_config_type or "Custom Configuration",
238
+ "base_model": self.model_name or "HuggingFaceTB/SmolLM3-3B",
239
+ "dataset_name": self.dataset_name or "Custom Dataset",
240
+ "trainer_type": self.trainer_type or "SFTTrainer",
241
+ "batch_size": str(self.batch_size) if self.batch_size else "8",
242
+ "learning_rate": str(self.learning_rate) if self.learning_rate else "5e-6",
243
+ "max_epochs": str(self.max_epochs) if self.max_epochs else "3",
244
+ "max_seq_length": str(self.max_seq_length) if self.max_seq_length else "2048",
245
+ "hardware_info": self._get_hardware_info(),
246
+ "trackio_url": self.trackio_url or "N/A",
247
+ "training_loss": str(results.get('train_loss', 'N/A')),
248
+ "validation_loss": str(results.get('eval_loss', 'N/A')),
249
+ "perplexity": str(results.get('perplexity', 'N/A')),
250
+ "quantized_models": False # Set to True if quantized models are available
251
+ })
252
+
253
+ # Generate the model card
254
+ model_card_content = generator.generate_model_card(variables)
255
+
256
+ logger.info("✅ Model card generated using generate_model_card.py")
257
+ return model_card_content
258
+
259
+ except Exception as e:
260
+ logger.error(f"❌ Failed to generate model card with generator: {e}")
261
+ logger.info("🔄 Falling back to simple model card")
262
+ return self._create_simple_model_card(training_config, results)
263
+
264
+ def _create_simple_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
265
+ """Create a simple model card without complex YAML to avoid formatting issues"""
266
+ return f"""---
267
+ language:
268
+ - en
269
+ - fr
270
+ license: apache-2.0
271
+ tags:
272
+ - smollm3
273
+ - fine-tuned
274
+ - causal-lm
275
+ - text-generation
276
+ pipeline_tag: text-generation
277
+ base_model: HuggingFaceTB/SmolLM3-3B
278
+ ---
279
+
280
+ # {self.repo_id.split('/')[-1]}
281
+
282
+ This is a fine-tuned SmolLM3 model based on the HuggingFaceTB/SmolLM3-3B architecture.
283
+
284
+ ## Model Details
285
+
286
+ - **Base Model**: HuggingFaceTB/SmolLM3-3B
287
+ - **Fine-tuning Method**: Supervised Fine-tuning
288
+ - **Training Date**: {datetime.now().strftime('%Y-%m-%d')}
289
+ - **Model Size**: {self._get_model_size():.1f} GB
290
+ - **Dataset Repository**: {self.dataset_repo}
291
+ - **Hardware**: {self._get_hardware_info()}
292
+
293
+ ## Training Configuration
294
+
295
+ ```json
296
+ {json.dumps(training_config, indent=2)}
297
+ ```
298
+
299
+ ## Training Results
300
+
301
+ ```json
302
+ {json.dumps(results, indent=2)}
303
+ ```
304
+
305
+ ## Usage
306
+
307
+ ```python
308
+ from transformers import AutoModelForCausalLM, AutoTokenizer
309
+
310
+ # Load model and tokenizer
311
+ model = AutoModelForCausalLM.from_pretrained("{self.repo_id}")
312
+ tokenizer = AutoTokenizer.from_pretrained("{self.repo_id}")
313
+
314
+ # Generate text
315
+ inputs = tokenizer("Hello, how are you?", return_tensors="pt")
316
+ outputs = model.generate(**inputs, max_new_tokens=100)
317
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
318
+ ```
319
+
320
+ ## Training Information
321
+
322
+ - **Base Model**: HuggingFaceTB/SmolLM3-3B
323
+ - **Hardware**: {self._get_hardware_info()}
324
+ - **Training Time**: {results.get('training_time_hours', 'Unknown')} hours
325
+ - **Final Loss**: {results.get('final_loss', 'Unknown')}
326
+ - **Final Accuracy**: {results.get('final_accuracy', 'Unknown')}
327
+ - **Dataset Repository**: {self.dataset_repo}
328
+
329
+ ## Model Performance
330
+
331
+ - **Training Loss**: {results.get('train_loss', 'Unknown')}
332
+ - **Validation Loss**: {results.get('eval_loss', 'Unknown')}
333
+ - **Training Steps**: {results.get('total_steps', 'Unknown')}
334
+
335
+ ## Experiment Tracking
336
+
337
+ This model was trained with experiment tracking enabled. Training metrics and configuration are stored in the HF Dataset repository: `{self.dataset_repo}`
338
+
339
+ ## Limitations and Biases
340
+
341
+ This model is fine-tuned for specific tasks and may not generalize well to all use cases. Please evaluate the model's performance on your specific task before deployment.
342
+
343
+ ## License
344
+
345
+ This model is licensed under the Apache 2.0 License.
346
+ """
347
+
348
+ def _get_model_size(self) -> float:
349
+ """Get model size in GB"""
350
+ try:
351
+ total_size = 0
352
+ for file in self.model_path.rglob("*"):
353
+ if file.is_file():
354
+ total_size += file.stat().st_size
355
+ return total_size / (1024**3) # Convert to GB
356
+ except Exception:
357
+ return 0.0
358
+
359
+ def _get_hardware_info(self) -> str:
360
+ """Get hardware information"""
361
+ try:
362
+ import torch
363
+ if torch.cuda.is_available():
364
+ gpu_name = torch.cuda.get_device_name(0)
365
+ return f"GPU: {gpu_name}"
366
+ else:
367
+ return "CPU"
368
+ except Exception:
369
+ return "Unknown"
370
+
371
+ def upload_model_files(self) -> bool:
372
+ """Upload model files to Hugging Face Hub with timeout protection"""
373
+ try:
374
+ logger.info("Uploading model files...")
375
+
376
+ # Upload all files in the model directory
377
+ for file_path in self.model_path.rglob("*"):
378
+ if file_path.is_file():
379
+ relative_path = file_path.relative_to(self.model_path)
380
+ remote_path = str(relative_path)
381
+
382
+ logger.info(f"Uploading {relative_path}")
383
+
384
+ try:
385
+ upload_file(
386
+ path_or_fileobj=str(file_path),
387
+ path_in_repo=remote_path,
388
+ repo_id=self.repo_id,
389
+ token=self.token
390
+ )
391
+ logger.info(f"✅ Uploaded {relative_path}")
392
+
393
+ except Exception as e:
394
+ logger.error(f"❌ Failed to upload {relative_path}: {e}")
395
+ return False
396
+
397
+ logger.info("✅ Model files uploaded successfully")
398
+ return True
399
+
400
+ except Exception as e:
401
+ logger.error(f"❌ Failed to upload model files: {e}")
402
+ return False
403
+
404
+ def upload_training_results(self, results_path: str) -> bool:
405
+ """Upload training results and logs"""
406
+ try:
407
+ logger.info("Uploading training results...")
408
+
409
+ results_files = [
410
+ "train_results.json",
411
+ "eval_results.json",
412
+ "training_config.json",
413
+ "training.log"
414
+ ]
415
+
416
+ for file_name in results_files:
417
+ file_path = Path(results_path) / file_name
418
+ if file_path.exists():
419
+ logger.info(f"Uploading {file_name}")
420
+ upload_file(
421
+ path_or_fileobj=str(file_path),
422
+ path_in_repo=f"training_results/{file_name}",
423
+ repo_id=self.repo_id,
424
+ token=self.token
425
+ )
426
+
427
+ logger.info("✅ Training results uploaded successfully")
428
+ return True
429
+
430
+ except Exception as e:
431
+ logger.error(f"❌ Failed to upload training results: {e}")
432
+ return False
433
+
434
+ def create_readme(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> bool:
435
+ """Create and upload README.md"""
436
+ try:
437
+ logger.info("Creating README.md...")
438
+
439
+ readme_content = f"""# {self.repo_id.split('/')[-1]}
440
+
441
+ A fine-tuned SmolLM3 model for text generation tasks.
442
+
443
+ ## Quick Start
444
+
445
+ ```python
446
+ from transformers import AutoModelForCausalLM, AutoTokenizer
447
+
448
+ model = AutoModelForCausalLM.from_pretrained("{self.repo_id}")
449
+ tokenizer = AutoTokenizer.from_pretrained("{self.repo_id}")
450
+
451
+ # Generate text
452
+ text = "Hello, how are you?"
453
+ inputs = tokenizer(text, return_tensors="pt")
454
+ outputs = model.generate(**inputs, max_new_tokens=100)
455
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
456
+ ```
457
+
458
+ ## Model Information
459
+
460
+ - **Base Model**: HuggingFaceTB/SmolLM3-3B
461
+ - **Fine-tuning Date**: {datetime.now().strftime('%Y-%m-%d')}
462
+ - **Model Size**: {self._get_model_size():.1f} GB
463
+ - **Training Steps**: {results.get('total_steps', 'Unknown')}
464
+ - **Final Loss**: {results.get('final_loss', 'Unknown')}
465
+ - **Dataset Repository**: {self.dataset_repo}
466
+
467
+ ## Training Configuration
468
+
469
+ ```json
470
+ {json.dumps(training_config, indent=2)}
471
+ ```
472
+
473
+ ## Performance Metrics
474
+
475
+ ```json
476
+ {json.dumps(results, indent=2)}
477
+ ```
478
+
479
+ ## Experiment Tracking
480
+
481
+ Training metrics and configuration are stored in the HF Dataset repository: `{self.dataset_repo}`
482
+
483
+ ## Files
484
+
485
+ - `model.safetensors.index.json`: Safetensors weight index (weights ship in the referenced shards)
486
+ - `config.json`: Model configuration
487
+ - `tokenizer.json`: Tokenizer configuration
488
+ - `training_results/`: Training logs and results
489
+
490
+ ## License
491
+
492
+ Apache 2.0 License
493
+ """
494
+
495
+ # Write README to temporary file
496
+ readme_path = Path("temp_readme.md")
497
+ with open(readme_path, "w") as f:
498
+ f.write(readme_content)
499
+
500
+ # Upload README
501
+ upload_file(
502
+ path_or_fileobj=str(readme_path),
503
+ path_in_repo="README.md",
504
+ token=self.token,
505
+ repo_id=self.repo_id
506
+ )
507
+
508
+ # Clean up
509
+ readme_path.unlink()
510
+
511
+ logger.info("✅ README.md uploaded successfully")
512
+ return True
513
+
514
+ except Exception as e:
515
+ logger.error(f"❌ Failed to create README: {e}")
516
+ return False
517
+
518
+ def log_to_trackio(self, action: str, details: Dict[str, Any]):
519
+ """Log push action to Trackio and HF Datasets"""
520
+ if self.monitor:
521
+ try:
522
+ # Log to Trackio
523
+ self.monitor.log_metrics({
524
+ "push_action": action,
525
+ "repo_name": self.repo_id,
526
+ "model_size_gb": self._get_model_size(),
527
+ "dataset_repo": self.dataset_repo,
528
+ **details
529
+ })
530
+
531
+ # Log training summary
532
+ self.monitor.log_training_summary({
533
+ "model_push": True,
534
+ "model_repo": self.repo_id,
535
+ "dataset_repo": self.dataset_repo,
536
+ "push_date": datetime.now().isoformat(),
537
+ **details
538
+ })
539
+
540
+ logger.info(f"✅ Logged {action} to Trackio and HF Datasets")
541
+ except Exception as e:
542
+ logger.error(f"❌ Failed to log to Trackio: {e}")
543
+
544
+ def push_model(self, training_config: Optional[Dict[str, Any]] = None,
545
+ results: Optional[Dict[str, Any]] = None) -> bool:
546
+ """Complete model push process with HF Datasets integration"""
547
+ logger.info(f"🚀 Starting model push to {self.repo_id}")
548
+ logger.info(f"📊 Dataset repository: {self.dataset_repo}")
549
+
550
+ # Validate model path
551
+ if not self.validate_model_path():
552
+ return False
553
+
554
+ # Create repository
555
+ if not self.create_repository():
556
+ return False
557
+
558
+ # Load training config and results if not provided
559
+ if training_config is None:
560
+ training_config = self._load_training_config()
561
+
562
+ if results is None:
563
+ results = self._load_training_results()
564
+
565
+ # Create and upload model card
566
+ model_card = self.create_model_card(training_config, results)
567
+ model_card_path = Path("temp_model_card.md")
568
+ with open(model_card_path, "w") as f:
569
+ f.write(model_card)
570
+
571
+ try:
572
+ upload_file(
573
+ path_or_fileobj=str(model_card_path),
574
+ path_in_repo="README.md",
575
+ repo_id=self.repo_id,
576
+ token=self.token
577
+ )
578
+ finally:
579
+ model_card_path.unlink()
580
+
581
+ # Upload model files
582
+ if not self.upload_model_files():
583
+ return False
584
+
585
+ # Upload training results
586
+ if results:
587
+ self.upload_training_results(str(self.model_path))
588
+
589
+ # Log to Trackio and HF Datasets
590
+ self.log_to_trackio("model_push", {
591
+ "model_path": str(self.model_path),
592
+ "repo_name": self.repo_name,
593
+ "private": self.private,
594
+ "training_config": training_config,
595
+ "results": results
596
+ })
597
+
598
+ logger.info(f"🎉 Model successfully pushed to: https://huggingface.co/{self.repo_id}")
599
+ logger.info(f"📊 Experiment data stored in: {self.dataset_repo}")
600
+ return True
601
+
602
+ def _load_training_config(self) -> Dict[str, Any]:
603
+ """Load training configuration"""
604
+ config_path = self.model_path / "training_config.json"
605
+ if config_path.exists():
606
+ with open(config_path, "r") as f:
607
+ return json.load(f)
608
+ return {"model_name": "HuggingFaceTB/SmolLM3-3B"}
609
+
610
+ def _load_training_results(self) -> Dict[str, Any]:
611
+ """Load training results"""
612
+ results_path = self.model_path / "train_results.json"
613
+ if results_path.exists():
614
+ with open(results_path, "r") as f:
615
+ return json.load(f)
616
+ return {"final_loss": "Unknown", "total_steps": "Unknown"}
617
+
618
+ def parse_args():
619
+ """Parse command line arguments"""
620
+ parser = argparse.ArgumentParser(description='Push trained model to Hugging Face Hub')
621
+
622
+ # Required arguments
623
+ parser.add_argument('model_path', type=str, help='Path to trained model directory')
624
+ parser.add_argument('repo_name', type=str, help='Hugging Face repository name (repo-name). Username will be auto-detected from your token.')
625
+
626
+ # Optional arguments
627
+ parser.add_argument('--token', type=str, default=None, help='Hugging Face token')
628
+ parser.add_argument('--hf-token', type=str, default=None, help='Hugging Face token (alternative to --token)')
629
+ parser.add_argument('--private', action='store_true', help='Make repository private')
630
+ parser.add_argument('--trackio-url', type=str, default=None, help='Trackio Space URL for logging')
631
+ parser.add_argument('--experiment-name', type=str, default=None, help='Experiment name for Trackio')
632
+ parser.add_argument('--dataset-repo', type=str, default=None, help='HF Dataset repository for experiment storage')
633
+ parser.add_argument('--author-name', type=str, default=None, help='Author name for model card')
634
+ parser.add_argument('--model-description', type=str, default=None, help='Model description for model card')
635
+ parser.add_argument('--training-config-type', type=str, default=None, help='Training configuration type')
636
+ parser.add_argument('--model-name', type=str, default=None, help='Base model name')
637
+ parser.add_argument('--dataset-name', type=str, default=None, help='Dataset name')
638
+ parser.add_argument('--batch-size', type=str, default=None, help='Batch size')
639
+ parser.add_argument('--learning-rate', type=str, default=None, help='Learning rate')
640
+ parser.add_argument('--max-epochs', type=str, default=None, help='Maximum epochs')
641
+ parser.add_argument('--max-seq-length', type=str, default=None, help='Maximum sequence length')
642
+ parser.add_argument('--trainer-type', type=str, default=None, help='Trainer type')
643
+
644
+ return parser.parse_args()
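+
+ # Example invocation (values illustrative):
+ #   python scripts/push_to_huggingface.py ./voxtral-finetuned my-voxtral --private \
+ #     --author-name "Alice" --experiment-name voxtral-asr-run1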
645
+
646
+ def main():
647
+ """Main function"""
648
+ args = parse_args()
649
+
650
+ # Setup logging
651
+ logging.basicConfig(
652
+ level=logging.INFO,
653
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
654
+ )
655
+
656
+ logger.info("Starting model push to Hugging Face Hub")
657
+
658
+ # Initialize pusher
659
+ try:
660
+ pusher = HuggingFacePusher(
661
+ model_path=args.model_path,
662
+ repo_name=args.repo_name,
663
+ token=args.token,
664
+ private=args.private,
665
+ trackio_url=args.trackio_url,
666
+ experiment_name=args.experiment_name,
667
+ dataset_repo=args.dataset_repo,
668
+ hf_token=args.hf_token,
669
+ author_name=args.author_name,
670
+ model_description=args.model_description,
671
+ training_config_type=args.training_config_type,
672
+ model_name=args.model_name,
673
+ dataset_name=args.dataset_name,
674
+ batch_size=args.batch_size,
675
+ learning_rate=args.learning_rate,
676
+ max_epochs=args.max_epochs,
677
+ max_seq_length=args.max_seq_length,
678
+ trainer_type=args.trainer_type
679
+ )
680
+
681
+ # Push model
682
+ success = pusher.push_model()
683
+
684
+ if success:
685
+ logger.info("✅ Model push completed successfully!")
686
+ logger.info(f"🌐 View your model at: https://huggingface.co/{args.repo_name}")
687
+ if args.dataset_repo:
688
+ logger.info(f"📊 View experiment data at: https://huggingface.co/datasets/{args.dataset_repo}")
689
+ else:
690
+ logger.error("❌ Model push failed!")
691
+ return 1
692
+
693
+ except Exception as e:
694
+ logger.error(f"❌ Error during model push: {e}")
695
+ return 1
696
+
697
+ return 0
698
+
699
+ if __name__ == "__main__":
700
+ exit(main())
train_lora.py → scripts/train.py RENAMED
@@ -1,14 +1,16 @@
1
  #!/usr/bin/env python3
2
 
 
 
 
3
  import torch
4
- from datasets import load_dataset, Audio
5
  from transformers import (
6
  VoxtralForConditionalGeneration,
7
  VoxtralProcessor,
8
  Trainer,
9
  TrainingArguments,
10
  )
11
- from peft import LoraConfig, get_peft_model
12
 
13
 
14
  class VoxtralDataCollator:
@@ -95,82 +97,114 @@ class VoxtralDataCollator:
95
 
96
  return batch
97
 
98
- def load_and_prepare_dataset():
99
- """Load and prepare dataset for training."""
100
- dataset_name = "hf-audio/esb-datasets-test-only-sorted"
101
- dataset_config = "voxpopuli"
102
-
103
- print(f"Loading dataset: {dataset_name}/{dataset_config}")
104
- dataset = load_dataset(dataset_name, dataset_config, split="test")
105
-
106
- # Cast audio to 16kHz (required for Voxtral)
107
- dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
108
-
109
- train_dataset = dataset.select(range(100))
110
- eval_dataset = dataset.select(range(100, 150))
111
-
112
  return train_dataset, eval_dataset
113
 
114
 
115
  def main():
116
- # Configuration
117
- model_checkpoint = "mistralai/Voxtral-Mini-3B-2507"
118
- output_dir = "./voxtral-finetuned"
119
-
120
- # Set device
121
  torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
  print(f"Using device: {torch_device}")
123
-
124
- # Load processor and model
125
  print("Loading processor and model...")
126
  processor = VoxtralProcessor.from_pretrained(model_checkpoint)
127
- # Load model with LoRA configuration
128
- config = LoraConfig(
129
- r=8, # Rank of LoRA
130
- lora_alpha=32,
131
- lora_dropout=0.0,
132
- bias="none",
133
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
134
- task_type="SEQ_2_SEQ_LM",
135
- )
136
- # print number of parameters in model
137
  model = VoxtralForConditionalGeneration.from_pretrained(
138
  model_checkpoint,
139
  torch_dtype=torch.bfloat16,
140
  device_map="auto"
141
  )
142
- # Freeze the audio encoder model.audio_tower
143
- for param in model.audio_tower.parameters():
144
- param.requires_grad = False
145
-
146
- model = get_peft_model(model, config)
147
- model.print_trainable_parameters()
148
- # Load and prepare dataset
149
- train_dataset, eval_dataset = load_and_prepare_dataset()
150
-
151
- # Setup data collator
152
  data_collator = VoxtralDataCollator(processor, model_checkpoint)
153
-
154
- # Simple training arguments
155
  training_args = TrainingArguments(
156
  output_dir=output_dir,
157
- per_device_train_batch_size=2,
158
- per_device_eval_batch_size=4,
159
- gradient_accumulation_steps=4,
160
- learning_rate=5e-5,
161
- num_train_epochs=3,
162
  bf16=True,
163
- logging_steps=10,
164
- eval_steps=50 if eval_dataset else None,
165
- save_steps=50,
166
  eval_strategy="steps" if eval_dataset else "no",
167
  save_strategy="steps",
168
  report_to="none",
169
  remove_unused_columns=False,
170
  dataloader_num_workers=1,
171
  )
172
-
173
- # Setup trainer
174
  trainer = Trainer(
175
  model=model,
176
  args=training_args,
@@ -178,22 +212,18 @@ def main():
178
  eval_dataset=eval_dataset,
179
  data_collator=data_collator,
180
  )
181
-
182
- # Start training
183
  print("Starting training...")
184
  trainer.train()
185
 
186
-
187
- # Save model and processor
188
  print(f"Saving model to {output_dir}")
189
  trainer.save_model()
190
  processor.save_pretrained(output_dir)
191
-
192
- # Final evaluation
193
  if eval_dataset:
194
  results = trainer.evaluate()
195
  print(f"Final evaluation results: {results}")
196
-
197
  print("Training completed successfully!")
198
 
199
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
 
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
  import torch
7
+ from datasets import load_dataset, Audio, Dataset
8
  from transformers import (
9
  VoxtralForConditionalGeneration,
10
  VoxtralProcessor,
11
  Trainer,
12
  TrainingArguments,
13
  )
 
14
 
15
 
16
  class VoxtralDataCollator:
 
97
 
98
  return batch
99
 
100
+ def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
101
+ """Load local JSONL with fields {audio_path, text} into a Dataset with audio column."""
102
+ records = []
103
+ jsonl_file = Path(jsonl_path)
104
+ if not jsonl_file.exists():
105
+ raise FileNotFoundError(f"Dataset jsonl not found: {jsonl_path}")
106
+ with open(jsonl_file, "r", encoding="utf-8") as f:
107
+ for line in f:
108
+ if not line.strip():
109
+ continue
110
+ obj = json.loads(line)
111
+ audio_path = obj.get("audio_path") or obj.get("audio")
112
+ text = obj.get("text")
113
+ if not audio_path or text is None:
114
+ continue
115
+ records.append({"audio": audio_path, "text": text})
116
+ if not records:
117
+ raise ValueError("No valid records found in JSONL. Expect keys: audio_path, text")
118
+ ds = Dataset.from_list(records)
119
+ # Cast the audio column from file paths and resample to 16kHz
120
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
121
+ return ds
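+
+ # Example JSONL record (one JSON object per line; path illustrative):
+ # {"audio_path": "clips/utt_0001.wav", "text": "hello world"}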
122
+
123
+
124
+ def load_and_prepare_dataset(dataset_jsonl: str | None, dataset_name: str | None, dataset_config: str | None,
125
+ train_count: int, eval_count: int):
126
+ """Load and prepare dataset for training.
127
+
128
+ Priority: local JSONL > HF dataset name/config > fallback tiny sample.
129
+ """
130
+ if dataset_jsonl:
131
+ print(f"Loading local JSONL dataset: {dataset_jsonl}")
132
+ ds = _load_jsonl_dataset(dataset_jsonl)
133
+ else:
134
+ ds_name = dataset_name or "hf-audio/esb-datasets-test-only-sorted"
135
+ ds_cfg = dataset_config or "voxpopuli"
136
+ print(f"Loading dataset: {ds_name}/{ds_cfg}")
137
+ ds = load_dataset(ds_name, ds_cfg, split="test")
138
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
139
+
140
+ total = len(ds)
141
+ train_end = min(train_count, total)
142
+ eval_end = min(train_end + eval_count, total)
143
+ train_dataset = ds.select(range(train_end))
144
+ eval_dataset = ds.select(range(train_end, eval_end)) if eval_end > train_end else None
145
  return train_dataset, eval_dataset
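+ # e.g. total=150, train_count=100, eval_count=50 -> train rows [0, 100), eval rows [100, 150);
+ # eval_dataset is None when no rows remain after the train slice.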
146
 
147
 
148
  def main():
149
+ parser = argparse.ArgumentParser(description="Full fine-tune Voxtral for ASR")
150
+ parser.add_argument("--dataset-jsonl", type=str, default=None, help="Path to local JSONL with {audio_path, text}")
151
+ parser.add_argument("--dataset-name", type=str, default=None, help="HF dataset repo (if not using JSONL)")
152
+ parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
153
+ parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
154
+ parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
155
+ parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
156
+ parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
157
+ parser.add_argument("--batch-size", type=int, default=2)
158
+ parser.add_argument("--eval-batch-size", type=int, default=4)
159
+ parser.add_argument("--grad-accum", type=int, default=4)
160
+ parser.add_argument("--learning-rate", type=float, default=5e-5)
161
+ parser.add_argument("--epochs", type=float, default=3)
162
+ parser.add_argument("--logging-steps", type=int, default=10)
163
+ parser.add_argument("--save-steps", type=int, default=50)
164
+ args = parser.parse_args()
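+
+ # Example invocation (path illustrative):
+ #   python scripts/train.py --dataset-jsonl datasets/voice.jsonl --train-count 200 --eval-count 20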
165
+
166
+ model_checkpoint = args.model_checkpoint
167
+ output_dir = args.output_dir
168
+
169
  torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
170
  print(f"Using device: {torch_device}")
171
+
 
172
  print("Loading processor and model...")
173
  processor = VoxtralProcessor.from_pretrained(model_checkpoint)
174
  model = VoxtralForConditionalGeneration.from_pretrained(
175
  model_checkpoint,
176
  torch_dtype=torch.bfloat16,
177
  device_map="auto"
178
  )
179
+
180
+ train_dataset, eval_dataset = load_and_prepare_dataset(
181
+ dataset_jsonl=args.dataset_jsonl,
182
+ dataset_name=args.dataset_name,
183
+ dataset_config=args.dataset_config,
184
+ train_count=args.train_count,
185
+ eval_count=args.eval_count,
186
+ )
187
+
 
188
  data_collator = VoxtralDataCollator(processor, model_checkpoint)
189
+
 
190
  training_args = TrainingArguments(
191
  output_dir=output_dir,
192
+ per_device_train_batch_size=args.batch_size,
193
+ per_device_eval_batch_size=args.eval_batch_size,
194
+ gradient_accumulation_steps=args.grad_accum,
195
+ learning_rate=args.learning_rate,
196
+ num_train_epochs=args.epochs,
197
  bf16=True,
198
+ logging_steps=args.logging_steps,
199
+ eval_steps=args.save_steps if eval_dataset else None,
200
+ save_steps=args.save_steps,
201
  eval_strategy="steps" if eval_dataset else "no",
202
  save_strategy="steps",
203
  report_to="none",
204
  remove_unused_columns=False,
205
  dataloader_num_workers=1,
206
  )
207
+
 
208
  trainer = Trainer(
209
  model=model,
210
  args=training_args,
 
212
  eval_dataset=eval_dataset,
213
  data_collator=data_collator,
214
  )
215
+
 
216
  print("Starting training...")
217
  trainer.train()
218
 
 
 
219
  print(f"Saving model to {output_dir}")
220
  trainer.save_model()
221
  processor.save_pretrained(output_dir)
222
+
 
223
  if eval_dataset:
224
  results = trainer.evaluate()
225
  print(f"Final evaluation results: {results}")
226
+
227
  print("Training completed successfully!")
228
 
229
  if __name__ == "__main__":
train.py → scripts/train_lora.py RENAMED
@@ -1,13 +1,17 @@
1
  #!/usr/bin/env python3
2
 
 
 
 
3
  import torch
4
- from datasets import load_dataset, Audio
5
  from transformers import (
6
  VoxtralForConditionalGeneration,
7
  VoxtralProcessor,
8
  Trainer,
9
  TrainingArguments,
10
  )
 
11
 
12
 
13
  class VoxtralDataCollator:
@@ -94,68 +98,128 @@ class VoxtralDataCollator:
94
 
95
  return batch
96
 
97
- def load_and_prepare_dataset():
98
- """Load and prepare dataset for training."""
99
- dataset_name = "hf-audio/esb-datasets-test-only-sorted"
100
- dataset_config = "voxpopuli"
101
-
102
- print(f"Loading dataset: {dataset_name}/{dataset_config}")
103
- dataset = load_dataset(dataset_name, dataset_config, split="test")
104
-
105
- # Cast audio to 16kHz (required for Voxtral)
106
- dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
107
-
108
- train_dataset = dataset.select(range(100))
109
- eval_dataset = dataset.select(range(100, 150))
110
-
111
  return train_dataset, eval_dataset
112
 
113
 
114
  def main():
115
- # Configuration
116
- model_checkpoint = "mistralai/Voxtral-Mini-3B-2507"
117
- output_dir = "./voxtral-finetuned"
118
-
119
- # Set device
120
  torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
121
  print(f"Using device: {torch_device}")
122
-
123
- # Load processor and model
124
  print("Loading processor and model...")
125
  processor = VoxtralProcessor.from_pretrained(model_checkpoint)
126
-
127
  model = VoxtralForConditionalGeneration.from_pretrained(
128
  model_checkpoint,
129
  torch_dtype=torch.bfloat16,
130
  device_map="auto"
131
  )
132
-
133
- # Load and prepare dataset
134
- train_dataset, eval_dataset = load_and_prepare_dataset()
135
-
136
- # Setup data collator
137
  data_collator = VoxtralDataCollator(processor, model_checkpoint)
138
-
139
- # Simple training arguments
140
  training_args = TrainingArguments(
141
  output_dir=output_dir,
142
- per_device_train_batch_size=2,
143
- per_device_eval_batch_size=4,
144
- gradient_accumulation_steps=4,
145
- learning_rate=5e-5,
146
- num_train_epochs=3,
147
  bf16=True,
148
- logging_steps=10,
149
- eval_steps=50 if eval_dataset else None,
150
- save_steps=50,
151
  eval_strategy="steps" if eval_dataset else "no",
152
  save_strategy="steps",
153
  report_to="none",
154
  remove_unused_columns=False,
155
  dataloader_num_workers=1,
156
  )
157
-
158
- # Setup trainer
159
  trainer = Trainer(
160
  model=model,
161
  args=training_args,
@@ -163,22 +227,18 @@ def main():
163
  eval_dataset=eval_dataset,
164
  data_collator=data_collator,
165
  )
166
-
167
- # Start training
168
  print("Starting training...")
169
  trainer.train()
170
 
171
-
172
- # Save model and processor
173
  print(f"Saving model to {output_dir}")
174
  trainer.save_model()
175
  processor.save_pretrained(output_dir)
176
-
177
- # Final evaluation
178
  if eval_dataset:
179
  results = trainer.evaluate()
180
  print(f"Final evaluation results: {results}")
181
-
182
  print("Training completed successfully!")
183
 
184
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
 
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
  import torch
7
+ from datasets import load_dataset, Audio, Dataset
8
  from transformers import (
9
  VoxtralForConditionalGeneration,
10
  VoxtralProcessor,
11
  Trainer,
12
  TrainingArguments,
13
  )
14
+ from peft import LoraConfig, get_peft_model
15
 
16
 
17
  class VoxtralDataCollator:
 
98
 
99
  return batch
100
 
101
+ def _load_jsonl_dataset(jsonl_path: str) -> Dataset:
102
+ """Load local JSONL with fields {audio_path, text} into a Dataset with audio column."""
103
+ records = []
104
+ jsonl_file = Path(jsonl_path)
105
+ if not jsonl_file.exists():
106
+ raise FileNotFoundError(f"Dataset jsonl not found: {jsonl_path}")
107
+ with open(jsonl_file, "r", encoding="utf-8") as f:
108
+ for line in f:
109
+ if not line.strip():
110
+ continue
111
+ obj = json.loads(line)
112
+ audio_path = obj.get("audio_path") or obj.get("audio")
113
+ text = obj.get("text")
114
+ if not audio_path or text is None:
115
+ continue
116
+ records.append({"audio": audio_path, "text": text})
117
+ if not records:
118
+ raise ValueError("No valid records found in JSONL. Expect keys: audio_path, text")
119
+ ds = Dataset.from_list(records)
120
+ # Cast the audio column from file paths and resample to 16kHz
121
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
122
+ return ds
123
+
124
+
125
+ def load_and_prepare_dataset(dataset_jsonl: str | None, dataset_name: str | None, dataset_config: str | None,
126
+ train_count: int, eval_count: int):
127
+ """Load and prepare dataset for training (JSONL or HF hub)."""
128
+ if dataset_jsonl:
129
+ print(f"Loading local JSONL dataset: {dataset_jsonl}")
130
+ ds = _load_jsonl_dataset(dataset_jsonl)
131
+ else:
132
+ ds_name = dataset_name or "hf-audio/esb-datasets-test-only-sorted"
133
+ ds_cfg = dataset_config or "voxpopuli"
134
+ print(f"Loading dataset: {ds_name}/{ds_cfg}")
135
+ ds = load_dataset(ds_name, ds_cfg, split="test")
136
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
137
+
138
+ total = len(ds)
139
+ train_end = min(train_count, total)
140
+ eval_end = min(train_end + eval_count, total)
141
+ train_dataset = ds.select(range(train_end))
142
+ eval_dataset = ds.select(range(train_end, eval_end)) if eval_end > train_end else None
143
  return train_dataset, eval_dataset
144
 
145
 
146
  def main():
147
+ parser = argparse.ArgumentParser(description="LoRA fine-tune Voxtral for ASR")
148
+ parser.add_argument("--dataset-jsonl", type=str, default=None, help="Path to local JSONL with {audio_path, text}")
149
+ parser.add_argument("--dataset-name", type=str, default=None, help="HF dataset repo (if not using JSONL)")
150
+ parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
151
+ parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
152
+ parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
153
+ parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
154
+ parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
155
+ parser.add_argument("--batch-size", type=int, default=2)
156
+ parser.add_argument("--eval-batch-size", type=int, default=4)
157
+ parser.add_argument("--grad-accum", type=int, default=4)
158
+ parser.add_argument("--learning-rate", type=float, default=5e-5)
159
+ parser.add_argument("--epochs", type=float, default=3)
160
+ parser.add_argument("--logging-steps", type=int, default=10)
161
+ parser.add_argument("--save-steps", type=int, default=50)
162
+ parser.add_argument("--lora-r", type=int, default=8)
163
+ parser.add_argument("--lora-alpha", type=int, default=32)
164
+ parser.add_argument("--lora-dropout", type=float, default=0.0)
165
+ parser.add_argument("--freeze-audio-tower", action="store_true", help="Freeze audio encoder parameters")
166
+ args = parser.parse_args()
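+
+ # Example invocation (path illustrative):
+ #   python scripts/train_lora.py --dataset-jsonl datasets/voice.jsonl --lora-r 8 --freeze-audio-tower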
167
+
168
+ model_checkpoint = args.model_checkpoint
169
+ output_dir = args.output_dir
170
+
171
  torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172
  print(f"Using device: {torch_device}")
173
+
 
174
  print("Loading processor and model...")
175
  processor = VoxtralProcessor.from_pretrained(model_checkpoint)
176
+ lora_cfg = LoraConfig(
177
+ r=args.lora_r,
178
+ lora_alpha=args.lora_alpha,
179
+ lora_dropout=args.lora_dropout,
180
+ bias="none",
181
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
182
+ task_type="SEQ_2_SEQ_LM",
183
+ )
184
  model = VoxtralForConditionalGeneration.from_pretrained(
185
  model_checkpoint,
186
  torch_dtype=torch.bfloat16,
187
  device_map="auto"
188
  )
189
+ if args.freeze_audio_tower:
190
+ for param in model.audio_tower.parameters():
191
+ param.requires_grad = False
192
+ model = get_peft_model(model, lora_cfg)
193
+ model.print_trainable_parameters()
194
+
195
+ train_dataset, eval_dataset = load_and_prepare_dataset(
196
+ dataset_jsonl=args.dataset_jsonl,
197
+ dataset_name=args.dataset_name,
198
+ dataset_config=args.dataset_config,
199
+ train_count=args.train_count,
200
+ eval_count=args.eval_count,
201
+ )
202
+
203
  data_collator = VoxtralDataCollator(processor, model_checkpoint)
204
+
 
205
  training_args = TrainingArguments(
206
  output_dir=output_dir,
207
+ per_device_train_batch_size=args.batch_size,
208
+ per_device_eval_batch_size=args.eval_batch_size,
209
+ gradient_accumulation_steps=args.grad_accum,
210
+ learning_rate=args.learning_rate,
211
+ num_train_epochs=args.epochs,
212
  bf16=True,
213
+ logging_steps=args.logging_steps,
214
+ eval_steps=args.save_steps if eval_dataset else None,
215
+ save_steps=args.save_steps,
216
  eval_strategy="steps" if eval_dataset else "no",
217
  save_strategy="steps",
218
  report_to="none",
219
  remove_unused_columns=False,
220
  dataloader_num_workers=1,
221
  )
222
+
 
223
  trainer = Trainer(
224
  model=model,
225
  args=training_args,
 
227
  eval_dataset=eval_dataset,
228
  data_collator=data_collator,
229
  )
230
+
 
231
  print("Starting training...")
232
  trainer.train()
233
 
 
 
234
  print(f"Saving model to {output_dir}")
235
  trainer.save_model()
236
  processor.save_pretrained(output_dir)
237
+
 
238
  if eval_dataset:
239
  results = trainer.evaluate()
240
  print(f"Final evaluation results: {results}")
241
+
242
  print("Training completed successfully!")
243
 
244
  if __name__ == "__main__":
templates/datasets/readme.md ADDED
@@ -0,0 +1,171 @@
+ ---
+ dataset_info:
+   features:
+   - name: experiment_id
+     dtype: string
+   - name: name
+     dtype: string
+   - name: description
+     dtype: string
+   - name: created_at
+     dtype: string
+   - name: status
+     dtype: string
+   - name: metrics
+     dtype: string
+   - name: parameters
+     dtype: string
+   - name: artifacts
+     dtype: string
+   - name: logs
+     dtype: string
+   - name: last_updated
+     dtype: string
+   splits:
+   - name: train
+     num_bytes: 4945
+     num_examples: 2
+   download_size: 15529
+   dataset_size: 4945
+ configs:
+ - config_name: default
+   data_files:
+   - split: train
+     path: data/train-*
+ tags:
+ - track tonic
+ - tonic
+ - experiment tracking
+ - smollm3
+ - fine-tuning
+ - legml
+ - hermes
+ ---
+
+ # Trackio Experiments Dataset
+
+ This dataset stores experiment tracking data for ML training runs, particularly focused on SmolLM3 fine-tuning experiments with comprehensive metrics tracking.
+
+ ## Dataset Structure
+
+ The dataset contains the following columns:
+
+ - **experiment_id**: Unique identifier for each experiment
+ - **name**: Human-readable name for the experiment
+ - **description**: Detailed description of the experiment
+ - **created_at**: Timestamp when the experiment was created
+ - **status**: Current status (running, completed, failed, paused)
+ - **metrics**: JSON string containing training metrics over time
+ - **parameters**: JSON string containing experiment configuration
+ - **artifacts**: JSON string containing experiment artifacts
+ - **logs**: JSON string containing experiment logs
+ - **last_updated**: Timestamp of last update
+
+ ## Metrics Structure
+
+ The metrics field contains JSON arrays with the following structure:
+
+ ```json
+ [
+   {
+     "timestamp": "2025-07-20T11:20:01.780908",
+     "step": 25,
+     "metrics": {
+       "loss": 1.1659,
+       "accuracy": 0.759,
+       "learning_rate": 7e-08,
+       "grad_norm": 10.3125,
+       "epoch": 0.004851130919895701,
+
+       // Advanced Training Metrics
+       "total_tokens": 1642080.0,
+       "truncated_tokens": 128,
+       "padding_tokens": 256,
+       "throughput": 3284160.0,
+       "step_time": 0.5,
+       "batch_size": 8,
+       "seq_len": 2048,
+       "token_acc": 0.759,
+
+       // Custom Losses
+       "train/gate_ortho": 0.0234,
+       "train/center": 0.0156,
+
+       // System Metrics
+       "gpu_memory_allocated": 17.202261447906494,
+       "gpu_memory_reserved": 75.474609375,
+       "gpu_utilization": 85.2,
+       "cpu_percent": 2.7,
+       "memory_percent": 10.1
+     }
+   }
+ ]
+ ```
+
+ ## Supported Metrics
+
+ ### Core Training Metrics
+ - **loss**: Training loss value
+ - **accuracy**: Model accuracy
+ - **learning_rate**: Current learning rate
+ - **grad_norm**: Gradient norm
+ - **epoch**: Current epoch progress
+
+ ### Advanced Token Metrics
+ - **total_tokens**: Total tokens processed in the batch
+ - **truncated_tokens**: Number of tokens truncated during processing
+ - **padding_tokens**: Number of padding tokens added
+ - **throughput**: Tokens processed per second
+ - **step_time**: Time taken for the current training step
+ - **batch_size**: Current batch size
+ - **seq_len**: Sequence length
+ - **token_acc**: Token-level accuracy
+
+ ### Custom Losses (SmolLM3-specific)
+ - **train/gate_ortho**: Gate orthogonality loss
+ - **train/center**: Center loss component
+
+ ### System Performance Metrics
+ - **gpu_memory_allocated**: GPU memory currently allocated (GB)
+ - **gpu_memory_reserved**: GPU memory reserved (GB)
+ - **gpu_utilization**: GPU utilization percentage
+ - **cpu_percent**: CPU usage percentage
+ - **memory_percent**: System memory usage percentage
+
+ ## Usage
+
+ This dataset is automatically used by the Trackio monitoring system to store and retrieve experiment data. It provides persistent storage for experiment tracking across different training runs.
+
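+ A minimal sketch of reading the experiments back with the `datasets` library (the repo id below is a placeholder):
+
+ ```python
+ import json
+ from datasets import load_dataset
+
+ # Placeholder repo id; substitute your own Trackio experiments dataset
+ ds = load_dataset("username/trackio-experiments", split="train")
+
+ # The metrics column is a JSON-encoded list of step entries
+ for row in ds:
+     for entry in json.loads(row["metrics"] or "[]"):
+         print(row["experiment_id"], entry["step"], entry["metrics"].get("loss"))
+ ```
+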
+ ## Integration
+
+ The dataset is used by:
+ - Trackio Spaces for experiment visualization
+ - Training scripts for logging metrics and parameters
+ - Monitoring systems for experiment tracking
+ - SmolLM3 fine-tuning pipeline for comprehensive metrics capture
+
+ ## Privacy
+
+ This dataset is private by default to ensure experiment data security. Only users with appropriate permissions can access the data.
+
+ ## Examples
+
+ ### Sample Experiment Entry
+ ```json
+ {
+   "experiment_id": "exp_20250720_130853",
+   "name": "smollm3_finetune",
+   "description": "SmolLM3 fine-tuning experiment with comprehensive metrics",
+   "created_at": "2025-07-20T11:20:01.780908",
+   "status": "running",
+   "metrics": "[{\"timestamp\": \"2025-07-20T11:20:01.780908\", \"step\": 25, \"metrics\": {\"loss\": 1.1659, \"accuracy\": 0.759, \"total_tokens\": 1642080.0, \"throughput\": 3284160.0, \"train/gate_ortho\": 0.0234, \"train/center\": 0.0156}}]",
+   "parameters": "{\"model_name\": \"HuggingFaceTB/SmolLM3-3B\", \"batch_size\": 8, \"learning_rate\": 3.5e-06, \"max_seq_length\": 12288}",
+   "artifacts": "[]",
+   "logs": "[]",
+   "last_updated": "2025-07-20T11:20:01.780908"
+ }
+ ```
+
+ ## License
+
+ This dataset is part of the Trackio experiment tracking system and follows the same license as the main project.
templates/model_card.md ADDED
@@ -0,0 +1,345 @@
+ ---
+ language:
+ - en
+ - fr
+ license: apache-2.0
+ library_name: transformers
+ tags:
+ - smollm3
+ - fine-tuned
+ - causal-lm
+ - text-generation
+ - tonic
+ - legml
+ {{#if quantized_models}}- quantized{{/if}}
+ pipeline_tag: text-generation
+ base_model: {{base_model}}
+ {{#if dataset_name}}
+ datasets:
+ - {{dataset_name}}
+ {{/if}}
+ {{#if quantized_models}}
+ model-index:
+ - name: {{model_name}}
+   results:
+   - task:
+       type: text-generation
+     dataset:
+       name: {{dataset_name}}
+       type: {{dataset_name}}
+     metrics:
+     - name: Training Loss
+       type: loss
+       value: "{{training_loss|default:'N/A'}}"
+     - name: Validation Loss
+       type: loss
+       value: "{{validation_loss|default:'N/A'}}"
+     - name: Perplexity
+       type: perplexity
+       value: "{{perplexity|default:'N/A'}}"
+ - name: {{model_name}} (int8 quantized)
+   results:
+   - task:
+       type: text-generation
+     dataset:
+       name: {{dataset_name}}
+       type: {{dataset_name}}
+     metrics:
+     - name: Memory Reduction
+       type: memory_efficiency
+       value: "~50%"
+     - name: Inference Speed
+       type: speed
+       value: "Faster"
+ - name: {{model_name}} (int4 quantized)
+   results:
+   - task:
+       type: text-generation
+     dataset:
+       name: {{dataset_name}}
+       type: {{dataset_name}}
+     metrics:
+     - name: Memory Reduction
+       type: memory_efficiency
+       value: "~75%"
+     - name: Inference Speed
+       type: speed
+       value: "Significantly Faster"
+ {{else}}
+ model-index:
+ - name: {{model_name}}
+   results:
+   - task:
+       type: text-generation
+     dataset:
+       name: {{dataset_name}}
+       type: {{dataset_name}}
+     metrics:
+     - name: Training Loss
+       type: loss
+       value: "{{training_loss|default:'N/A'}}"
+     - name: Validation Loss
+       type: loss
+       value: "{{validation_loss|default:'N/A'}}"
+     - name: Perplexity
+       type: perplexity
+       value: "{{perplexity|default:'N/A'}}"
+ {{/if}}
+ {{#if author_name}}
+ author: {{author_name}}
+ {{/if}}
+ {{#if experiment_name}}
+ experiment_name: {{experiment_name}}
+ {{/if}}
+ {{#if trackio_url}}
+ trackio_url: {{trackio_url}}
+ {{/if}}
+ {{#if dataset_repo}}
+ dataset_repo: {{dataset_repo}}
+ {{/if}}
+ {{#if hardware_info}}
+ hardware: "{{hardware_info}}"
+ {{/if}}
+ {{#if training_config_type}}
+ training_config: {{training_config_type}}
+ {{/if}}
+ {{#if trainer_type}}
+ trainer_type: {{trainer_type}}
+ {{/if}}
+ {{#if batch_size}}
+ batch_size: {{batch_size}}
+ {{/if}}
+ {{#if learning_rate}}
+ learning_rate: {{learning_rate}}
+ {{/if}}
+ {{#if max_epochs}}
+ max_epochs: {{max_epochs}}
+ {{/if}}
+ {{#if max_seq_length}}
+ max_seq_length: {{max_seq_length}}
+ {{/if}}
+ {{#if dataset_sample_size}}
+ dataset_sample_size: {{dataset_sample_size}}
+ {{/if}}
+ {{#if dataset_size}}
+ dataset_size: {{dataset_size}}
+ {{/if}}
+ {{#if dataset_format}}
+ dataset_format: {{dataset_format}}
+ {{/if}}
+ {{#if gradient_accumulation_steps}}
+ gradient_accumulation_steps: {{gradient_accumulation_steps}}
+ {{/if}}
+ ---
+
+ # {{model_name}}
+
+ {{model_description}}
+
+ ## Model Details
+
+ - **Base Model**: SmolLM3-3B
+ - **Model Type**: Causal Language Model
+ - **Languages**: English, French
+ - **License**: Apache 2.0
+ - **Fine-tuned**: Yes
+ {{#if quantized_models}}
+ - **Quantized Versions**: Available in subdirectories
+ {{/if}}
+
+ ## Usage
+
+ ### Main Model
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load the main model
+ model = AutoModelForCausalLM.from_pretrained(
+     "{{repo_name}}",
+     device_map="auto",
+     torch_dtype=torch.bfloat16
+ )
+ tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")
+
+ # Generate text
+ input_text = "What are we having for dinner?"
+ input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
+ output = model.generate(**input_ids, max_new_tokens=50)
+ print(tokenizer.decode(output[0], skip_special_tokens=True))
+ ```
+
+ ## Training Information
+
+ ### Training Configuration
+ - **Base Model**: {{base_model}}
+ - **Dataset**: {{dataset_name}}
+ - **Training Config**: {{training_config_type}}
+ - **Trainer Type**: {{trainer_type}}
+ {{#if dataset_sample_size}}
+ - **Dataset Sample Size**: {{dataset_sample_size}}
+ {{/if}}
+
+ ### Training Parameters
+ - **Batch Size**: {{batch_size}}
+ - **Gradient Accumulation**: {{gradient_accumulation_steps}}
+ - **Learning Rate**: {{learning_rate}}
+ - **Max Epochs**: {{max_epochs}}
+ - **Sequence Length**: {{max_seq_length}}
+
+ ### Training Infrastructure
+ - **Hardware**: {{hardware_info}}
+ - **Monitoring**: Trackio integration
+ - **Experiment**: {{experiment_name}}
+
+ ## Model Architecture
+
+ This is a fine-tuned version of the SmolLM3-3B model with the following specifications:
+
+ - **Base Model**: SmolLM3-3B
+ - **Parameters**: ~3B
+ - **Context Length**: {{max_seq_length}}
+ - **Languages**: English, French
+ - **Architecture**: Transformer-based causal language model
+
+ ## Performance
+
+ The model provides:
+ - **Text Generation**: High-quality text generation capabilities
+ - **Conversation**: Natural conversation abilities
+ - **Multilingual**: Support for English and French
+ {{#if quantized_models}}
+ - **Quantized Versions**: Optimized for different deployment scenarios
+ {{/if}}
+
+ ## Limitations
+
+ 1. **Context Length**: Limited by the model's maximum sequence length
+ 2. **Bias**: May inherit biases from the training data
+ 3. **Factual Accuracy**: May generate incorrect or outdated information
+ 4. **Safety**: Should be used responsibly with appropriate safeguards
+ {{#if quantized_models}}
+ 5. **Quantization**: Quantized versions may have slightly reduced accuracy
+ {{/if}}
+
+ ## Training Data
+
+ The model was fine-tuned on:
+ - **Dataset**: {{dataset_name}}
+ - **Size**: {{dataset_size}}
+ - **Format**: {{dataset_format}}
+ - **Languages**: English, French
+
+ ## Evaluation
+
+ The model was evaluated using:
+ - **Metrics**: Loss, perplexity, and qualitative assessment
+ - **Monitoring**: Real-time tracking via Trackio
+ - **Validation**: Regular validation during training
+
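+ For reference, the reported perplexity is simply the exponential of the cross-entropy loss; a minimal sketch (the loss value is illustrative):
+
+ ```python
+ import math
+
+ # Perplexity = exp(mean cross-entropy loss); 1.1659 is an example value
+ training_loss = 1.1659
+ print(f"perplexity ≈ {math.exp(training_loss):.2f}")
+ ```
+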
+ ## Citation
+
+ If you use this model in your research, please cite:
+
+ ```bibtex
+ @misc{{{model_name_slug}},
+   title={{{{model_name}}}},
+   author={{{author_name}}},
+   year={2024},
+   url={https://huggingface.co/{{repo_name}}}
+ }
+ ```
+
+ ## License
+
+ This model is licensed under the Apache 2.0 License.
+
+ ## Acknowledgments
+
+ - **Base Model**: SmolLM3-3B by HuggingFaceTB
+ - **Training Framework**: PyTorch, Transformers, PEFT
+ - **Monitoring**: Trackio integration
+ - **Quantization**: torchao library
+
+ ## Support
+
+ For questions and support:
+ - Open an issue on the Hugging Face repository
+ - Check the model documentation
+ - Review the training logs and configuration
+
+ ## Repository Structure
+
+ ```
+ {{repo_name}}/
+ ├── README.md (this file)
+ ├── config.json
+ ├── pytorch_model.bin
+ ├── tokenizer.json
+ └── tokenizer_config.json
+ ```
+
+ ## Usage Examples
+
+ ### Text Generation
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("{{repo_name}}")
+ tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")
+
+ text = "The future of artificial intelligence is"
+ inputs = tokenizer(text, return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=100)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
+
+ ### Conversation
+ ```python
+ def chat_with_model(prompt, max_length=100):
+     inputs = tokenizer(prompt, return_tensors="pt")
+     outputs = model.generate(**inputs, max_new_tokens=max_length)
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ response = chat_with_model("Hello, how are you today?")
+ print(response)
+ ```
+
+ ### Advanced Usage
+ ```python
+ # With generation parameters
+ outputs = model.generate(
+     **inputs,
+     max_new_tokens=100,
+     temperature=0.7,
+     top_p=0.9,
+     do_sample=True,
+     pad_token_id=tokenizer.eos_token_id
+ )
+ ```
+
+ ## Monitoring and Tracking
+
+ This model was trained with comprehensive monitoring:
+ - **Trackio Space**: {{trackio_url}}
+ - **Experiment**: {{experiment_name}}
+ - **Dataset Repository**: https://huggingface.co/datasets/{{dataset_repo}}
+ - **Training Logs**: Available in the experiment data
+
+ ## Deployment
+
+ ### Requirements
+ ```bash
+ pip install torch transformers accelerate
+ {{#if quantized_models}}
+ pip install torchao  # For quantized models
+ {{/if}}
+ ```
+
+ ### Hardware Requirements
+ - **Main Model**: GPU with 8GB+ VRAM recommended
+
+ ## Changelog
+
+ - **v1.0.0**: Initial release with fine-tuned model
templates/spaces/demo_voxtral/README.md ADDED
@@ -0,0 +1,23 @@
+ ---
+ title: Voxtral ASR Demo
+ emoji: 🎙️
+ colorFrom: indigo
+ colorTo: cyan
+ sdk: gradio
+ sdk_version: 5.42.0
+ app_file: app.py
+ pinned: false
+ short_description: Interactive ASR demo for a fine-tuned Voxtral model
+ ---
+
+ This Space serves a Voxtral ASR model for speech-to-text transcription.
+
+ Usage:
+
+ - Click Record and read the displayed phrase aloud.
+ - Stop recording to see the transcription.
+ - Works best with ~16 kHz audio; internal processing follows Voxtral's processor expectations.
+
+ Environment variables expected:
+
+ - `HF_MODEL_ID`: The model repo to load (e.g., `username/voxtral-finetune-YYYYMMDD_HHMMSS`)
+ - `MODEL_NAME`: Display name
+ - `HF_USERNAME`: For branding
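+
+ A hypothetical sketch of calling the Space programmatically with `gradio_client` (the Space id is a placeholder, and the endpoint name assumes Gradio's auto-generated default for the `transcribe` function):
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ # Placeholder Space id; "/transcribe" assumes Gradio's default endpoint naming
+ client = Client("username/voxtral-asr-demo")
+ text = client.predict(handle_file("sample.wav"), api_name="/transcribe")
+ print(text)
+ ```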
templates/spaces/demo_voxtral/app.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ import gradio as gr
+ import numpy as np
+ import torch
+ import librosa
+ from transformers import AutoProcessor, VoxtralForConditionalGeneration
+
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "mistralai/Voxtral-Mini-3B-2507")
+ MODEL_NAME = os.getenv("MODEL_NAME", HF_MODEL_ID.split("/")[-1])
+ HF_USERNAME = os.getenv("HF_USERNAME", "")
+
+ processor = AutoProcessor.from_pretrained(HF_MODEL_ID)
+ # Voxtral loads through its dedicated conditional-generation class (as in the training script)
+ model = VoxtralForConditionalGeneration.from_pretrained(HF_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16)
+
+ def transcribe(audio_tuple):
+     if audio_tuple is None:
+         return "No audio provided"
+     sr, data = audio_tuple
+     # Gradio microphone audio is typically int16 PCM; convert to mono float32 in [-1, 1]
+     data = np.asarray(data)
+     if np.issubdtype(data.dtype, np.integer):
+         data = data.astype(np.float32) / 32768.0
+     if data.ndim > 1:
+         data = data.mean(axis=1)
+     # Voxtral expects 16 kHz input; resample if the browser recorded at another rate
+     if sr != 16000:
+         data = librosa.resample(data, orig_sr=sr, target_sr=16000)
+     inputs = processor.apply_transcription_request(language="en", model_id=HF_MODEL_ID, audio=[data], format=["WAV"], return_tensors="pt")
+     inputs = {k: (v.to(model.device) if hasattr(v, 'to') else v) for k, v in inputs.items()}
+     with torch.no_grad():
+         output_ids = model.generate(**inputs, max_new_tokens=256)
+     # generate() echoes the prompt tokens first; decode only the newly generated ones
+     new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
+     return processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(f"# 🎙️ Voxtral ASR Demo — {MODEL_NAME}")
+     audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Record or upload audio")
+     btn = gr.Button("Transcribe")
+     out = gr.Textbox(label="Transcription", lines=4)
+     btn.click(transcribe, inputs=[audio], outputs=[out])
+
+ if __name__ == "__main__":
+     demo.launch(mcp_server=True)
templates/spaces/demo_voxtral/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio>=5.38.2
+ torch
+ transformers
+ datasets
+ soundfile
+ librosa