# Import required libraries
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from functools import partial
import evaluate

# Load the dataset
dataset = load_dataset("")  # Specify the dataset repo on the Hugging Face Hub

# Split the dataset into train and test sets (80-20 split)
split_dataset = dataset['train'].train_test_split(test_size=0.2)
split_dataset

# Select only the relevant columns for training
split_dataset['train'] = split_dataset['train'].select_columns(["audio", "sentence"])
split_dataset['train']

# Initialize the Whisper processor for Swahili transcription
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="swahili", task="transcribe"
)

# Print audio features before and after resampling to match Whisper's expected sampling rate
print('BEFORE>>> ', split_dataset['train'].features['audio'])
sampling_rate = processor.feature_extractor.sampling_rate
split_dataset['train'] = split_dataset['train'].cast_column(
    "audio", Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['train'].features['audio'])

# Do the same for the test set
print('BEFORE>>> ', split_dataset['test'].features['audio'])
split_dataset['test'] = split_dataset['test'].cast_column(
    "audio", Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['test'].features['audio'])


def prepare_dataset(example):
    """Preprocess audio and text data for Whisper model training."""
    audio = example["audio"]

    # Process audio and text with the Whisper processor
    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["sentence"],
    )

    # Compute input length of the audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example


# Apply preprocessing to the train and test sets
split_dataset['train'] = split_dataset['train'].map(
    prepare_dataset,
    remove_columns=split_dataset['train'].column_names,
    num_proc=4,  # Use 4 processes for faster preprocessing
)
split_dataset['test'] = split_dataset['test'].map(
    prepare_dataset,
    remove_columns=split_dataset['test'].column_names,
    num_proc=1,
)

# Filter out audio samples longer than 30 seconds
max_input_length = 30.0


def is_audio_in_length_range(length):
    return length < max_input_length


split_dataset['train'] = split_dataset['train'].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Custom data collator for Whisper speech-to-sequence tasks with padding."""

    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they need different padding methods
        # First pad the audio inputs
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove the BOS token if it was appended during tokenization
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
batch["labels"] = labels return batch # Initialize data collator data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor) # Load evaluation metric (Word Error Rate) metric = evaluate.load("wer") # Initialize text normalizer for evaluation from transformers.models.whisper.english_normalizer import BasicTextNormalizer normalizer = BasicTextNormalizer() def compute_metrics(pred): """Compute WER (Word Error Rate) metrics for evaluation""" pred_ids = pred.predictions label_ids = pred.label_ids # Replace -100 with pad_token_id label_ids[label_ids == -100] = processor.tokenizer.pad_token_id # Decode predictions and labels pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True) label_str = processor.batch_decode(label_ids, skip_special_tokens=True) # Compute orthographic WER wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str) # Compute normalized WER pred_str_norm = [normalizer(pred) for pred in pred_str] label_str_norm = [normalizer(label) for label in label_str] # Filter samples with non-zero references pred_str_norm = [ pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0 ] label_str_norm = [ label_str_norm[i] for i in range(len(label_str_norm)) if len(label_str_norm[i]) > 0 ] wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm) return {"wer_ortho": wer_ortho, "wer": wer} # Load pre-trained Whisper model model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") # Disable cache during training (incompatible with gradient checkpointing) model.config.use_cache = False # Configure generation settings (re-enable cache for generation) model.generate = partial( model.generate, language="swahili", task="transcribe", use_cache=True ) # Set up training arguments training_args = Seq2SeqTrainingArguments( output_dir="./model", # Output directory per_device_train_batch_size=16, # Batch size for training gradient_accumulation_steps=1, # Number of steps before gradient update learning_rate=1e-6, # Learning rate lr_scheduler_type="constant_with_warmup", # Learning rate scheduler warmup_steps=50, # Warmup steps max_steps=10000, # Total training steps gradient_checkpointing=True, # Use gradient checkpointing fp16=True, # Use mixed precision training fp16_full_eval=True, # Use mixed precision evaluation evaluation_strategy="steps", # Evaluation strategy per_device_eval_batch_size=16, # Batch size for evaluation predict_with_generate=True, # Use generation for evaluation generation_max_length=225, # Maximum generation length save_steps=500, # Save checkpoint every N steps eval_steps=500, # Evaluate every N steps logging_steps=100, # Log metrics every N steps report_to=["tensorboard", "wandb"], # Logging integrations load_best_model_at_end=True, # Load best model at end of training metric_for_best_model="wer", # Metric for selecting best model greater_is_better=False, # Lower WER is better push_to_hub=True, # Push to Hugging Face Hub save_total_limit=3, # Maximum number of checkpoints to keep ) # Initialize trainer trainer = Seq2SeqTrainer( args=training_args, model=model, train_dataset=split_dataset['train'], eval_dataset=split_dataset['test'], data_collator=data_collator, compute_metrics=compute_metrics, tokenizer=processor, # Changed from processing_class to tokenizer ) # Start training trainer.train()