# Import required libraries
from datasets import load_dataset, Audio
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer
)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from functools import partial
import evaluate
# Load the dataset
dataset = load_dataset("")  # Specify the dataset repository ID on the Hugging Face Hub
dataset
# Split the dataset into train and test sets (80-20 split)
split_dataset = dataset['train'].train_test_split(test_size=0.2)
split_dataset
# Select only the relevant columns for training
split_dataset['train'] = split_dataset['train'].select_columns(["audio", "sentence"])
split_dataset['train']
# Initialize the Whisper processor for Swahili transcription
processor = WhisperProcessor.from_pretrained(
"openai/whisper-small",
language="swahili",
task="transcribe"
)
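# Optional check (added illustration, not in the original script): the language/task settings
# above translate into forced decoder prompt ids that steer generation toward Swahili transcription.
print(processor.get_decoder_prompt_ids(language="swahili", task="transcribe"))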
# Print audio features before and after resampling to match Whisper's expected sampling rate
print('BEFORE>>> ', split_dataset['train'].features['audio'])
sampling_rate = processor.feature_extractor.sampling_rate
split_dataset['train'] = split_dataset['train'].cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['train'].features['audio'])
# Do the same for the test set
print('BEFORE>>> ', split_dataset['test'].features['audio'])
split_dataset['test'] = split_dataset['test'].cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['test'].features['audio'])
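# Optional sanity check (added illustration): accessing the audio column decodes and resamples
# on the fly, so the rate printed here should now match Whisper's expected 16 kHz.
sample = split_dataset['train'][0]['audio']
print(sample['sampling_rate'], round(len(sample['array']) / sample['sampling_rate'], 2), 'seconds')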
def prepare_dataset(example):
"""Preprocess audio and text data for Whisper model training"""
audio = example["audio"]
# Process audio and text using Whisper processor
example = processor(
audio=audio["array"],
sampling_rate=audio["sampling_rate"],
text=example["sentence"],
)
# Compute input length of audio sample in seconds
example["input_length"] = len(audio["array"]) / audio["sampling_rate"]
return example
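# Optional dry run on a single example (added illustration, assumes the resampled train split
# above): confirms the keys the collator will expect: input_features, labels and input_length.
print(prepare_dataset(split_dataset['train'][0]).keys())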
# Apply preprocessing to train and test sets
split_dataset['train'] = split_dataset['train'].map(
prepare_dataset,
remove_columns=split_dataset['train'].column_names,
num_proc=4 # Use 4 processes for faster preprocessing
)
split_dataset['test'] = split_dataset['test'].map(
prepare_dataset,
remove_columns=split_dataset['test'].column_names,
num_proc=1
)
# Filter out audio samples longer than 30 seconds
max_input_length = 30.0
def is_audio_in_length_range(length):
return length < max_input_length
split_dataset['train'] = split_dataset['train'].filter(
is_audio_in_length_range,
input_columns=["input_length"],
)
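# Optional: report how many clips survived the 30-second cut-off (added illustration).
print(f"Training examples after length filtering: {split_dataset['train'].num_rows}")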
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
"""Custom data collator for Whisper speech-to-sequence tasks with padding"""
processor: Any
def __call__(
self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
# Split inputs and labels since they need different padding methods
# First process audio inputs
input_features = [
{"input_features": feature["input_features"][0]} for feature in features
]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# Process label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# Replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(
labels_batch.attention_mask.ne(1), -100
)
# Remove BOS token if it was appended previously
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
# Initialize data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
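# Optional sanity check (added illustration, assumes the mapped train split above): collate
# two examples and confirm the padded tensor shapes before launching training.
debug_batch = data_collator([split_dataset['train'][i] for i in range(2)])
print({k: v.shape for k, v in debug_batch.items()})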
# Load evaluation metric (Word Error Rate)
metric = evaluate.load("wer")
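# Quick illustration with toy strings (hypothetical example, not from the dataset):
# one substituted word out of three gives a WER of roughly 0.33.
print(metric.compute(predictions=["habari ya leo"], references=["habari za leo"]))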
# Initialize text normalizer for evaluation
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
normalizer = BasicTextNormalizer()
def compute_metrics(pred):
"""Compute WER (Word Error Rate) metrics for evaluation"""
pred_ids = pred.predictions
label_ids = pred.label_ids
# Replace -100 with pad_token_id
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
# Decode predictions and labels
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
# Compute orthographic WER
wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
# Compute normalized WER
pred_str_norm = [normalizer(pred) for pred in pred_str]
label_str_norm = [normalizer(label) for label in label_str]
# Keep only samples whose normalized reference is non-empty (WER is undefined for empty references)
pred_str_norm = [
pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
]
label_str_norm = [
label_str_norm[i] for i in range(len(label_str_norm)) if len(label_str_norm[i]) > 0
]
wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)
return {"wer_ortho": wer_ortho, "wer": wer}
# Load pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Disable cache during training (incompatible with gradient checkpointing)
model.config.use_cache = False
# Configure generation settings (re-enable cache for generation)
model.generate = partial(
model.generate,
language="swahili",
task="transcribe",
use_cache=True
)
# Set up training arguments
training_args = Seq2SeqTrainingArguments(
output_dir="./model", # Output directory
per_device_train_batch_size=16, # Batch size for training
gradient_accumulation_steps=1, # Number of steps before gradient update
learning_rate=1e-6, # Learning rate
lr_scheduler_type="constant_with_warmup", # Learning rate scheduler
warmup_steps=50, # Warmup steps
max_steps=10000, # Total training steps
gradient_checkpointing=True, # Use gradient checkpointing
fp16=True, # Use mixed precision training
fp16_full_eval=True, # Use mixed precision evaluation
evaluation_strategy="steps", # Evaluation strategy
per_device_eval_batch_size=16, # Batch size for evaluation
predict_with_generate=True, # Use generation for evaluation
generation_max_length=225, # Maximum generation length
save_steps=500, # Save checkpoint every N steps
eval_steps=500, # Evaluate every N steps
logging_steps=100, # Log metrics every N steps
report_to=["tensorboard", "wandb"], # Logging integrations
load_best_model_at_end=True, # Load best model at end of training
metric_for_best_model="wer", # Metric for selecting best model
greater_is_better=False, # Lower WER is better
push_to_hub=True, # Push to Hugging Face Hub
save_total_limit=3, # Maximum number of checkpoints to keep
)
# Initialize trainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=split_dataset['train'],
eval_dataset=split_dataset['test'],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor,  # Pass the full processor so it is saved with checkpoints (newer transformers versions use processing_class instead)
)
# Start training
# Start training
trainer.train()
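# Optional follow-up (a minimal sketch, assuming training has finished and "./model" holds the
# best checkpoint): transcribe a local clip with the fine-tuned model. The file path below is a
# placeholder, not part of the original script.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="./model",
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
print(asr("sample_swahili_clip.wav", generate_kwargs={"language": "swahili", "task": "transcribe"})["text"])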