#!/usr/bin/env python3
import torch
from datasets import load_dataset, Audio
from transformers import (
    VoxtralForConditionalGeneration,
    VoxtralProcessor,
    Trainer,
    TrainingArguments,
)
import jiwer
from peft import LoraConfig, get_peft_model


class VoxtralDataCollator:
    """Data collator for Voxtral STT training - processes audio and text."""

    def __init__(self, processor, model_id):
        self.processor = processor
        self.model_id = model_id
        self.pad_id = processor.tokenizer.pad_token_id

    def __call__(self, features):
        """
        Each feature should have:
          - "audio": raw audio (whatever your processor expects)
          - "text": transcription string
        """
        texts = [f["text"] for f in features]
        audios = [f["audio"]["array"] for f in features]
        # 1) Build the PROMPT part: [AUDIO]…[AUDIO] <transcribe>
        prompt = self.processor.apply_transcription_request(
            language="en",
            model_id=self.model_id,
            audio=audios,
            format=["WAV"] * len(audios),
            return_tensors="pt",
        )
        # prompt["input_ids"]: shape [B, L_prompt]
        # Keep any extra fields (e.g., audio features) to pass through to the model.
        passthrough = {
            k: v for k, v in prompt.items()
            if k not in ("input_ids", "attention_mask")
        }
        prompt_ids = prompt["input_ids"]        # [B, Lp]
        prompt_attn = prompt["attention_mask"]  # [B, Lp]
        B = prompt_ids.size(0)
        tok = self.processor.tokenizer
        # 2) Tokenize transcriptions WITHOUT padding; we'll pad after concatenation.
        text_tok = tok(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=256,
            return_tensors=None,
        )
        text_ids_list = text_tok["input_ids"]

        # 3) Concatenate: input_ids = [PROMPT] + [TEXT] + EOS
        #    (appending EOS teaches the model to stop after the transcription).
        input_ids, attention_mask, labels = [], [], []
        for i in range(B):
            p_ids = prompt_ids[i].tolist()
            p_att = prompt_attn[i].tolist()
            t_ids = text_ids_list[i] + [tok.eos_token_id]

            ids = p_ids + t_ids
            attn = p_att + [1] * len(t_ids)
            # Labels: mask prompt tokens with -100 so the loss is computed
            # only on the transcription tokens.
            lab = [-100] * len(p_ids) + t_ids

            input_ids.append(ids)
            attention_mask.append(attn)
            labels.append(lab)
        # 4) Pad to the max length in the batch.
        pad_id = self.pad_id if self.pad_id is not None else tok.eos_token_id
        max_len = max(len(x) for x in input_ids)

        def pad_to(seq, fill, L):
            return seq + [fill] * (L - len(seq))

        input_ids = [pad_to(x, pad_id, max_len) for x in input_ids]
        attention_mask = [pad_to(x, 0, max_len) for x in attention_mask]
        labels = [pad_to(x, -100, max_len) for x in labels]

        batch = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
        # 5) Include processor outputs needed by the model (e.g., audio features).
        for k, v in passthrough.items():
            batch[k] = v
        return batch
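

# A minimal usage sketch for the collator (assumptions: `processor` and
# `model_checkpoint` as created in main() below, rows shaped like the
# voxpopuli split loaded in load_and_prepare_dataset). Useful for
# sanity-checking shapes before training:
#
#   collator = VoxtralDataCollator(processor, model_checkpoint)
#   batch = collator([train_dataset[0], train_dataset[1]])
#   batch["input_ids"].shape          # [2, L_prompt + L_text (padded)]
#   (batch["labels"] == -100).sum()   # prompt and pad positions are masked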


def load_and_prepare_dataset():
    """Load and prepare the dataset for training."""
    dataset_name = "hf-audio/esb-datasets-test-only-sorted"
    dataset_config = "voxpopuli"

    print(f"Loading dataset: {dataset_name}/{dataset_config}")
    dataset = load_dataset(dataset_name, dataset_config, split="test")

    # Cast audio to 16 kHz (required for Voxtral).
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    train_dataset = dataset.select(range(100))
    eval_dataset = dataset.select(range(100, 150))
    return train_dataset, eval_dataset
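

# Hedged sketch, not part of the original training loop: a post-training WER
# check that puts the jiwer import to work. It reuses the same
# apply_transcription_request call as the collator; `max_new_tokens` and the
# greedy generation settings are assumptions, not tuned values.
def evaluate_wer(model, processor, dataset, model_id, max_new_tokens=128):
    """Transcribe each sample and score hypotheses against references with jiwer."""
    model.eval()
    references, hypotheses = [], []
    for sample in dataset:
        inputs = processor.apply_transcription_request(
            language="en",
            model_id=model_id,
            audio=[sample["audio"]["array"]],
            format=["WAV"],
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Decode only the newly generated tokens (everything after the prompt).
        new_tokens = out[0, inputs["input_ids"].shape[1]:]
        hypotheses.append(processor.tokenizer.decode(new_tokens, skip_special_tokens=True))
        references.append(sample["text"])
    return jiwer.wer(references, hypotheses)


# Example (after trainer.train() in main()):
#   print(f"WER: {evaluate_wer(model, processor, eval_dataset, model_checkpoint):.3f}")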


def main():
    # Configuration
    model_checkpoint = "mistralai/Voxtral-Mini-3B-2507"
    output_dir = "./voxtral-finetuned"

    # Set device
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")

    # Load processor and model
    print("Loading processor and model...")
    processor = VoxtralProcessor.from_pretrained(model_checkpoint)

    # LoRA configuration, applied to the decoder's attention projections
    config = LoraConfig(
        r=8,  # LoRA rank
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",  # Voxtral's language model is causal, not encoder-decoder
    )
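
    # The adapter update is W + (lora_alpha / r) * B @ A, so the effective
    # scaling here is 32 / 8 = 4.0; raising alpha strengthens the adapter
    # without adding parameters.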
    model = VoxtralForConditionalGeneration.from_pretrained(
        model_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Freeze the audio encoder (model.audio_tower); only the LoRA adapters train.
    for param in model.audio_tower.parameters():
        param.requires_grad = False

    model = get_peft_model(model, config)
    # Print the number of trainable parameters after LoRA wrapping.
    model.print_trainable_parameters()
    # Load and prepare dataset
    train_dataset, eval_dataset = load_and_prepare_dataset()

    # Setup data collator
    data_collator = VoxtralDataCollator(processor, model_checkpoint)

    # Simple training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=3,
        bf16=True,
        logging_steps=10,
        eval_steps=50 if eval_dataset else None,
        save_steps=50,
        eval_strategy="steps" if eval_dataset else "no",
        save_strategy="steps",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=1,
    )
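
    # Effective train batch size = per_device_train_batch_size (2)
    # * gradient_accumulation_steps (4) = 8 samples per optimizer step per device.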
    # Setup trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save model and processor
    print(f"Saving model to {output_dir}")
    trainer.save_model()
    processor.save_pretrained(output_dir)

    # Final evaluation
    if eval_dataset:
        results = trainer.evaluate()
        print(f"Final evaluation results: {results}")

    print("Training completed successfully!")


if __name__ == "__main__":
    main()