Steveeeeeeen (HF Staff) committed on
Commit
643a0c1
·
1 Parent(s): 97ae18a

add lora training script

Files changed (1)
  1. train_lora.py +201 -0
train_lora.py ADDED
@@ -0,0 +1,201 @@
+ #!/usr/bin/env python3
+
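+ # LoRA fine-tuning of Voxtral-Mini for speech-to-text on a small VoxPopuli subset.
+ # Assumed environment (not pinned by this commit): torch, transformers, datasets, peft.
+ # Run with: python train_lora.py
+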
+ import torch
+ from datasets import load_dataset, Audio
+ from transformers import (
+     VoxtralForConditionalGeneration,
+     VoxtralProcessor,
+     Trainer,
+     TrainingArguments,
+ )
+ from peft import LoraConfig, get_peft_model
+
+
+ class VoxtralDataCollator:
+     """Data collator for Voxtral STT training - processes audio and text."""
+
+     def __init__(self, processor, model_id):
+         self.processor = processor
+         self.model_id = model_id
+         self.pad_id = processor.tokenizer.pad_token_id
+
+     def __call__(self, features):
+         """
+         Each feature should have:
+         - "audio": raw audio (whatever your processor expects)
+         - "text": transcription string
+         """
+         texts = [f["text"] for f in features]
+         audios = [f["audio"]["array"] for f in features]
+
+         # 1) Build the PROMPT part: [AUDIO]…[AUDIO] <transcribe>
+         prompt = self.processor.apply_transcription_request(
+             language="en",
+             model_id=self.model_id,
+             audio=audios,
+             format=["WAV"] * len(audios),
+             return_tensors="pt",
+         )
+         # prompt["input_ids"]: shape [B, L_prompt]
+         # Keep any extra fields (e.g., audio features) to pass through to the model
+         passthrough = {k: v for k, v in prompt.items()
+                        if k not in ("input_ids", "attention_mask")}
+
+         prompt_ids = prompt["input_ids"]        # [B, Lp]
+         prompt_attn = prompt["attention_mask"]  # [B, Lp]
+         B = prompt_ids.size(0)
+
+         tok = self.processor.tokenizer
+         # 2) Tokenize transcriptions WITHOUT padding; we'll pad after concatenation
+         text_tok = tok(
+             texts,
+             add_special_tokens=False,
+             padding=False,
+             truncation=True,
+             max_length=256,
+             return_tensors=None,
+         )
+         text_ids_list = text_tok["input_ids"]
+
+         # 3) Concatenate: input_ids = [PROMPT] + [TEXT]
+         input_ids, attention_mask, labels = [], [], []
+         for i in range(B):
+             p_ids = prompt_ids[i].tolist()
+             p_att = prompt_attn[i].tolist()
+             # Append EOS so the model learns to terminate the transcription
+             t_ids = text_ids_list[i] + [tok.eos_token_id]
+
+             ids = p_ids + t_ids
+             attn = p_att + [1] * len(t_ids)
+             # labels: mask prompt tokens, learn only on text tokens
+             lab = [-100] * len(p_ids) + t_ids
+
+             input_ids.append(ids)
+             attention_mask.append(attn)
+             labels.append(lab)
+
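+         # Layout per example (illustrative):
+         #   input_ids: [ audio/prompt tokens ... <transcribe> | text tokens ... EOS ]
+         #   labels:    [ -100 ... -100                        | text tokens ... EOS ]
+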
+         # 4) Pad to the max length in the batch
+         pad_id = self.pad_id if self.pad_id is not None else tok.eos_token_id
+         max_len = max(len(x) for x in input_ids)
+
+         def pad_to(seq, fill, L):
+             return seq + [fill] * (L - len(seq))
+
+         input_ids = [pad_to(x, pad_id, max_len) for x in input_ids]
+         attention_mask = [pad_to(x, 0, max_len) for x in attention_mask]
+         labels = [pad_to(x, -100, max_len) for x in labels]
+
+         batch = {
+             "input_ids": torch.tensor(input_ids, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+             "labels": torch.tensor(labels, dtype=torch.long),
+         }
+         # 5) Include processor outputs needed by the model (e.g., audio features)
+         for k, v in passthrough.items():
+             batch[k] = v
+
+         return batch
+
+
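+ # Quick sanity check of the collator (illustrative, not run as part of training):
+ #   processor = VoxtralProcessor.from_pretrained("mistralai/Voxtral-Mini-3B-2507")
+ #   collator = VoxtralDataCollator(processor, "mistralai/Voxtral-Mini-3B-2507")
+ #   batch = collator([train_dataset[0], train_dataset[1]])
+ #   print(batch["input_ids"].shape, batch["labels"].shape)  # both [2, max_len]
+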
+ def load_and_prepare_dataset():
+     """Load and prepare dataset for training."""
+     dataset_name = "hf-audio/esb-datasets-test-only-sorted"
+     dataset_config = "voxpopuli"
+
+     print(f"Loading dataset: {dataset_name}/{dataset_config}")
+     dataset = load_dataset(dataset_name, dataset_config, split="test")
+
+     # Cast audio to 16 kHz (required by the Voxtral feature extractor)
+     dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+
+     # Small demo split: first 100 rows for training, next 50 for evaluation
+     train_dataset = dataset.select(range(100))
+     eval_dataset = dataset.select(range(100, 150))
+
+     return train_dataset, eval_dataset
+
+
+ def main():
+     # Configuration
+     model_checkpoint = "mistralai/Voxtral-Mini-3B-2507"
+     output_dir = "./voxtral-finetuned"
+
+     # Report the device (device_map="auto" below handles actual placement)
+     torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {torch_device}")
+
+     # Load processor and model
+     print("Loading processor and model...")
+     processor = VoxtralProcessor.from_pretrained(model_checkpoint)
+
+     # LoRA configuration: adapt only the attention projections of the decoder
+     config = LoraConfig(
+         r=8,  # LoRA rank
+         lora_alpha=32,
+         lora_dropout=0.0,
+         bias="none",
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+         task_type="SEQ_2_SEQ_LM",
+     )
+     model = VoxtralForConditionalGeneration.from_pretrained(
+         model_checkpoint,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+     # Freeze the audio encoder (model.audio_tower); only the LoRA adapters train
+     for param in model.audio_tower.parameters():
+         param.requires_grad = False
+
+     model = get_peft_model(model, config)
+     model.print_trainable_parameters()
+
+     # Load and prepare dataset
+     train_dataset, eval_dataset = load_and_prepare_dataset()
+
+     # Set up the data collator
+     data_collator = VoxtralDataCollator(processor, model_checkpoint)
+
+     # Training arguments
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         per_device_train_batch_size=2,
+         per_device_eval_batch_size=4,
+         gradient_accumulation_steps=4,
+         learning_rate=5e-5,
+         num_train_epochs=3,
+         bf16=True,
+         logging_steps=10,
+         eval_steps=50 if eval_dataset else None,
+         save_steps=50,
+         eval_strategy="steps" if eval_dataset else "no",
+         save_strategy="steps",
+         report_to="none",
+         remove_unused_columns=False,  # keep "audio"/"text" columns for the collator
+         dataloader_num_workers=1,
+     )
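+     # Effective batch size: 2 (per device) x 4 (gradient accumulation) = 8 examples per optimizer step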
+
+     # Set up the trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         data_collator=data_collator,
+     )
+
+     # Start training
+     print("Starting training...")
+     trainer.train()
+
+     # Save model and processor
+     print(f"Saving model to {output_dir}")
+     trainer.save_model()
+     processor.save_pretrained(output_dir)
+
+     # Final evaluation
+     if eval_dataset:
+         results = trainer.evaluate()
+         print(f"Final evaluation results: {results}")
+
+     print("Training completed successfully!")
+
+
+ if __name__ == "__main__":
+     main()