Steveeeeeeen HF Staff committed on
Commit 97ae18a · 0 Parent(s)

add training script for voxtral

Files changed (1)
  1. train.py +186 -0
train.py ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
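"""Minimal fine-tuning script for Voxtral speech-to-text with the Hugging Face Trainer.

A custom data collator builds prompt+transcription examples; training runs on a small
VoxPopuli subset from hf-audio/esb-datasets-test-only-sorted.
"""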

import torch
from datasets import load_dataset, Audio
from transformers import (
    VoxtralForConditionalGeneration,
    VoxtralProcessor,
    Trainer,
    TrainingArguments,
)


class VoxtralDataCollator:
    """Data collator for Voxtral STT training - processes audio and text."""

    def __init__(self, processor, model_id):
        self.processor = processor
        self.model_id = model_id
        self.pad_id = processor.tokenizer.pad_token_id

    def __call__(self, features):
        """
        Each feature should have:
          - "audio": a decoded audio dict with an "array" field (as produced by datasets.Audio)
          - "text": the reference transcription string
        """
        texts = [f["text"] for f in features]
        audios = [f["audio"]["array"] for f in features]

        # 1) Build the PROMPT part: [AUDIO]…[AUDIO] <transcribe>
        prompt = self.processor.apply_transcription_request(
            language="en",
            model_id=self.model_id,
            audio=audios,
            format=["WAV"] * len(audios),
            return_tensors="pt",
        )
        # prompt["input_ids"]: shape [B, L_prompt]
        # keep any extra fields (e.g., audio features) to pass through to the model
        passthrough = {k: v for k, v in prompt.items()
                       if k not in ("input_ids", "attention_mask")}
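        # With VoxtralProcessor, this passthrough is expected to carry the audio encoder
        # inputs (e.g. "input_features"); the exact key names depend on the processor version.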

        prompt_ids = prompt["input_ids"]        # [B, Lp]
        prompt_attn = prompt["attention_mask"]  # [B, Lp]
        B = prompt_ids.size(0)

        tok = self.processor.tokenizer
        # 2) Tokenize transcriptions WITHOUT padding; we'll pad after concatenation
        text_tok = tok(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=256,
            return_tensors=None,
        )
        text_ids_list = text_tok["input_ids"]

        # 3) Concatenate: input_ids = [PROMPT] + [TEXT]
        input_ids, attention_mask, labels = [], [], []
        for i in range(B):
            p_ids = prompt_ids[i].tolist()
            p_att = prompt_attn[i].tolist()
            t_ids = text_ids_list[i]

            ids = p_ids + t_ids
            attn = p_att + [1] * len(t_ids)
            # labels: mask prompt tokens, learn only on text tokens
            lab = [-100] * len(p_ids) + t_ids
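            # Labels align 1:1 with input_ids; the model applies the causal shift internally.
            # (Optionally append tok.eos_token_id to t_ids so the model also learns where to stop.)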

            input_ids.append(ids)
            attention_mask.append(attn)
            labels.append(lab)

        # 4) Pad to max length in batch
        pad_id = self.pad_id if self.pad_id is not None else tok.eos_token_id
        max_len = max(len(x) for x in input_ids)

        def pad_to(seq, fill, L):
            return seq + [fill] * (L - len(seq))

        input_ids = [pad_to(x, pad_id, max_len) for x in input_ids]
        attention_mask = [pad_to(x, 0, max_len) for x in attention_mask]
        labels = [pad_to(x, -100, max_len) for x in labels]
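        # attention_mask pads with 0 and labels with -100, so padded positions are ignored by the loss.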

        batch = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
        # 5) Include processor outputs needed by the model (e.g., audio features)
        for k, v in passthrough.items():
            batch[k] = v

        return batch


def load_and_prepare_dataset():
    """Load and prepare dataset for training."""
    dataset_name = "hf-audio/esb-datasets-test-only-sorted"
    dataset_config = "voxpopuli"

    print(f"Loading dataset: {dataset_name}/{dataset_config}")
    dataset = load_dataset(dataset_name, dataset_config, split="test")

    # Cast audio to 16kHz (required for Voxtral)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    train_dataset = dataset.select(range(100))
    eval_dataset = dataset.select(range(100, 150))
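    # Small demo split: the first 100 test examples for training, the next 50 for evaluation.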

    return train_dataset, eval_dataset


def main():
    # Configuration
    model_checkpoint = "mistralai/Voxtral-Mini-3B-2507"
    output_dir = "./voxtral-finetuned"

    # Set device
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")

    # Load processor and model
    print("Loading processor and model...")
    processor = VoxtralProcessor.from_pretrained(model_checkpoint)

    model = VoxtralForConditionalGeneration.from_pretrained(
        model_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
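    # device_map="auto" lets accelerate place the weights (assumes accelerate is installed),
    # so torch_device above is informational only.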

    # Load and prepare dataset
    train_dataset, eval_dataset = load_and_prepare_dataset()

    # Setup data collator
    data_collator = VoxtralDataCollator(processor, model_checkpoint)

    # Simple training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=3,
        bf16=True,
        logging_steps=10,
        eval_steps=50 if eval_dataset else None,
        save_steps=50,
        eval_strategy="steps" if eval_dataset else "no",
        save_strategy="steps",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=1,
    )
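    # remove_unused_columns=False keeps the raw "audio"/"text" columns visible to the collator;
    # bf16=True assumes hardware with bfloat16 support.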

    # Setup trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save model and processor
    print(f"Saving model to {output_dir}")
    trainer.save_model()
    processor.save_pretrained(output_dir)

    # Final evaluation
    if eval_dataset:
        results = trainer.evaluate()
        print(f"Final evaluation results: {results}")

    print("Training completed successfully!")


if __name__ == "__main__":
    main()