"""STT Training Script.py

Fine-tunes openai/whisper-small for Swahili speech-to-text (automatic speech
recognition) with the Hugging Face Transformers Seq2SeqTrainer.
"""
# Import required libraries
from datasets import load_dataset, Audio
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer
)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from functools import partial
import evaluate
# Load the dataset
dataset = load_dataset("") # Specify Data Repo on HF
print(dataset)
# Split the dataset into train and test sets (80-20 split)
split_dataset = dataset['train'].train_test_split(test_size=0.2)
print(split_dataset)
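# Quick sanity check on one raw example (this assumes the source dataset exposes
# "audio" and "sentence" columns, as used throughout this script).
sample = split_dataset['train'][0]
print('Transcript:', sample['sentence'])
print('Audio keys:', list(sample['audio'].keys()))  # typically: path, array, sampling_rate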
# Select only the relevant columns for training
split_dataset['train'] = split_dataset['train'].select_columns(["audio", "sentence"])
print(split_dataset['train'])
# Initialize the Whisper processor for Swahili transcription
processor = WhisperProcessor.from_pretrained(
"openai/whisper-small",
language="swahili",
task="transcribe"
)
# Print audio features before and after resampling to match Whisper's expected sampling rate
print('BEFORE>>> ', split_dataset['train'].features['audio'])
sampling_rate = processor.feature_extractor.sampling_rate
split_dataset['train'] = split_dataset['train'].cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['train'].features['audio'])
# Do the same for the test set
print('BEFORE>>> ', split_dataset['test'].features['audio'])
split_dataset['test'] = split_dataset['test'].cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['test'].features['audio'])
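# Verify the cast worked: datasets resamples lazily on access, so decoding one
# example should now yield audio at Whisper's 16 kHz rate.
decoded = split_dataset['train'][0]['audio']
assert decoded['sampling_rate'] == sampling_rate
print(f"Decoded at {decoded['sampling_rate']} Hz, "
      f"{len(decoded['array']) / decoded['sampling_rate']:.1f} s of audio")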
def prepare_dataset(example):
"""Preprocess audio and text data for Whisper model training"""
audio = example["audio"]
# Process audio and text using Whisper processor
example = processor(
audio=audio["array"],
sampling_rate=audio["sampling_rate"],
text=example["sentence"],
)
# Compute input length of audio sample in seconds
example["input_length"] = len(audio["array"]) / audio["sampling_rate"]
return example
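# Smoke-test the preprocessing on a single example before the full (and slow) map;
# the processor should return log-Mel "input_features" and tokenized "labels".
print(list(prepare_dataset(split_dataset['train'][0]).keys()))
# expected: ['input_features', 'labels', 'input_length']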
# Apply preprocessing to train and test sets
split_dataset['train'] = split_dataset['train'].map(
prepare_dataset,
remove_columns=split_dataset['train'].column_names,
num_proc=4 # Use 4 processes for faster preprocessing
)
split_dataset['test'] = split_dataset['test'].map(
prepare_dataset,
remove_columns=split_dataset['test'].column_names,
num_proc=1
)
# Filter out audio samples longer than 30 seconds
max_input_length = 30.0
def is_audio_in_length_range(length):
return length < max_input_length
split_dataset['train'] = split_dataset['train'].filter(
is_audio_in_length_range,
input_columns=["input_length"],
)
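# Report how many training examples survive the 30-second cutoff (the test set is
# left unfiltered here so evaluation covers the full split).
print(f"Train examples after length filter: {split_dataset['train'].num_rows}")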
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
"""Custom data collator for Whisper speech-to-sequence tasks with padding"""
processor: Any
def __call__(
self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
# Split inputs and labels since they need different padding methods
# First process audio inputs
input_features = [
{"input_features": feature["input_features"][0]} for feature in features
]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# Process label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Replace padding token ids with -100 so the loss ignores padded positions
labels = labels_batch["input_ids"].masked_fill(
labels_batch.attention_mask.ne(1), -100
)
        # If a BOS token was prepended during tokenization, cut it here; it is appended again later anyway
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
# Initialize data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
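# Smoke-test the collator on two processed examples: input features pad to a fixed
# (batch, 80, 3000) log-Mel shape for whisper-small, and labels pad to the longest
# sequence in the batch.
example_batch = data_collator([split_dataset['train'][i] for i in range(2)])
print({k: v.shape for k, v in example_batch.items()})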
# Load evaluation metric (Word Error Rate)
metric = evaluate.load("wer")
# Initialize text normalizer for evaluation
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
normalizer = BasicTextNormalizer()
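# Toy illustration of the metric and normalizer (the strings here are made up): the
# normalizer lowercases and strips punctuation, and one substituted word out of
# three yields a WER of ~0.33.
print(normalizer("Habari ya asubuhi!"))
print(metric.compute(predictions=["habari ya asubuhi"], references=["habari za asubuhi"]))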
def compute_metrics(pred):
"""Compute WER (Word Error Rate) metrics for evaluation"""
pred_ids = pred.predictions
label_ids = pred.label_ids
# Replace -100 with pad_token_id
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
# Decode predictions and labels
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
# Compute orthographic WER
wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
# Compute normalized WER
pred_str_norm = [normalizer(pred) for pred in pred_str]
label_str_norm = [normalizer(label) for label in label_str]
    # Keep only samples whose normalized reference is non-empty (WER is undefined for empty references)
pred_str_norm = [
pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
]
label_str_norm = [
label_str_norm[i] for i in range(len(label_str_norm)) if len(label_str_norm[i]) > 0
]
wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)
return {"wer_ortho": wer_ortho, "wer": wer}
# Load pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Disable cache during training (incompatible with gradient checkpointing)
model.config.use_cache = False
# Configure generation settings (re-enable cache for generation)
model.generate = partial(
model.generate,
language="swahili",
task="transcribe",
use_cache=True
)
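# Optional pre-training sanity check: transcribe one held-out example with the base
# model (slow on CPU; expect rough output before fine-tuning).
with torch.no_grad():
    features = torch.tensor(split_dataset['test'][0]['input_features'])
    pred_ids = model.generate(input_features=features)
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])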
# Set up training arguments
training_args = Seq2SeqTrainingArguments(
output_dir="./model", # Output directory
per_device_train_batch_size=16, # Batch size for training
gradient_accumulation_steps=1, # Number of steps before gradient update
learning_rate=1e-6, # Learning rate
lr_scheduler_type="constant_with_warmup", # Learning rate scheduler
warmup_steps=50, # Warmup steps
max_steps=10000, # Total training steps
gradient_checkpointing=True, # Use gradient checkpointing
fp16=True, # Use mixed precision training
fp16_full_eval=True, # Use mixed precision evaluation
    evaluation_strategy="steps",     # Evaluation strategy (renamed to eval_strategy in newer transformers)
per_device_eval_batch_size=16, # Batch size for evaluation
predict_with_generate=True, # Use generation for evaluation
generation_max_length=225, # Maximum generation length
save_steps=500, # Save checkpoint every N steps
eval_steps=500, # Evaluate every N steps
logging_steps=100, # Log metrics every N steps
report_to=["tensorboard", "wandb"], # Logging integrations
load_best_model_at_end=True, # Load best model at end of training
metric_for_best_model="wer", # Metric for selecting best model
greater_is_better=False, # Lower WER is better
push_to_hub=True, # Push to Hugging Face Hub
save_total_limit=3, # Maximum number of checkpoints to keep
)
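# Note: the effective batch size is per_device_train_batch_size *
# gradient_accumulation_steps * number of devices (16 * 1 per GPU here). If memory
# is tight, halve the batch size and double gradient_accumulation_steps to compensate.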
# Initialize trainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=split_dataset['train'],
eval_dataset=split_dataset['test'],
data_collator=data_collator,
compute_metrics=compute_metrics,
    tokenizer=processor,  # Pass the full processor so the feature extractor and tokenizer are saved with checkpoints (newer transformers versions take processing_class instead)
)
# Start training
trainer.train()
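# After training, persist the processor alongside the best checkpoint and make a
# final push (with push_to_hub=True the Trainer only syncs at checkpoint saves).
trainer.save_model()
processor.save_pretrained(training_args.output_dir)
trainer.push_to_hub()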