Automatic Speech Recognition
Transformers
Safetensors
Swahili
English
whisper
Generated from Trainer
# Import required libraries
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from functools import partial
import evaluate

# Load the dataset
dataset = load_dataset("")          # Specify Data Repo on HF
dataset

# Split the dataset into train and test sets (80-20 split)
split_dataset = dataset['train'].train_test_split(test_size=0.2)
split_dataset

# Select only the relevant columns for training
split_dataset['train'] = split_dataset['train'].select_columns(["audio", "sentence"])
split_dataset['train']

# Initialize the Whisper processor for Swahili transcription
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", 
    language="swahili", 
    task="transcribe"
)

# Print audio features before and after resampling to match Whisper's expected sampling rate
print('BEFORE>>> ', split_dataset['train'].features['audio'])
sampling_rate = processor.feature_extractor.sampling_rate
split_dataset['train'] = split_dataset['train'].cast_column(
    "audio", 
    Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['train'].features['audio'])

# Do the same for the test set
print('BEFORE>>> ', split_dataset['test'].features['audio'])
split_dataset['test'] = split_dataset['test'].cast_column(
    "audio", 
    Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['test'].features['audio'])

def prepare_dataset(example):
    """Preprocess audio and text data for Whisper model training"""
    audio = example["audio"]
    
    # Process audio and text using Whisper processor
    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["sentence"],
    )
    
    # Compute input length of audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    
    return example

# Apply preprocessing to train and test sets
split_dataset['train'] = split_dataset['train'].map(
    prepare_dataset, 
    remove_columns=split_dataset['train'].column_names, 
    num_proc=4  # Use 4 processes for faster preprocessing
)

split_dataset['test'] = split_dataset['test'].map(
    prepare_dataset, 
    remove_columns=split_dataset['test'].column_names, 
    num_proc=1
)

# Filter out audio samples longer than 30 seconds
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length

split_dataset['train'] = split_dataset['train'].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Custom data collator for Whisper speech-to-sequence tasks with padding"""
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they need different padding methods
        # First process audio inputs
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Process label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove BOS token if it was appended previously
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Initialize data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
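# Optional sanity check (a sketch, not part of the original script): collate two
# preprocessed training examples and inspect the padded tensor shapes.
# Assumes the mapped train split above is non-empty.
# sample_batch = data_collator([split_dataset['train'][i] for i in range(2)])
# print(sample_batch["input_features"].shape, sample_batch["labels"].shape)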

# Load evaluation metric (Word Error Rate)
metric = evaluate.load("wer")

# Initialize text normalizer for evaluation
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
normalizer = BasicTextNormalizer()

def compute_metrics(pred):
    """Compute WER (Word Error Rate) metrics for evaluation"""
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and labels
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Compute orthographic WER
    wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)

    # Compute normalized WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]
    
    # Keep only samples whose normalized reference is non-empty (WER is undefined otherwise)
    pred_str_norm = [
        pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i] for i in range(len(label_str_norm)) if len(label_str_norm[i]) > 0
    ]

    wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer}

# Load pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Disable cache during training (incompatible with gradient checkpointing)
model.config.use_cache = False

# Configure generation settings (re-enable cache for generation)
model.generate = partial(
    model.generate, 
    language="swahili", 
    task="transcribe", 
    use_cache=True
)

# Set up training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./model",                      # Output directory
    per_device_train_batch_size=16,            # Batch size for training
    gradient_accumulation_steps=1,             # Number of steps before gradient update
    learning_rate=1e-6,                        # Learning rate
    lr_scheduler_type="constant_with_warmup",  # Learning rate scheduler
    warmup_steps=50,                           # Warmup steps
    max_steps=10000,                           # Total training steps
    gradient_checkpointing=True,               # Use gradient checkpointing
    fp16=True,                                 # Use mixed precision training
    fp16_full_eval=True,                       # Use mixed precision evaluation
    evaluation_strategy="steps",               # Evaluation strategy
    per_device_eval_batch_size=16,             # Batch size for evaluation
    predict_with_generate=True,                # Use generation for evaluation
    generation_max_length=225,                 # Maximum generation length
    save_steps=500,                            # Save checkpoint every N steps
    eval_steps=500,                            # Evaluate every N steps
    logging_steps=100,                         # Log metrics every N steps
    report_to=["tensorboard", "wandb"],        # Logging integrations
    load_best_model_at_end=True,               # Load best model at end of training
    metric_for_best_model="wer",               # Metric for selecting best model
    greater_is_better=False,                   # Lower WER is better
    push_to_hub=True,                          # Push to Hugging Face Hub
    save_total_limit=3,                        # Maximum number of checkpoints to keep
)

# Initialize trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=split_dataset['train'],
    eval_dataset=split_dataset['test'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,  # Older transformers API; newer versions name this argument processing_class
)

# Start training
trainer.train()
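
# --- Optional inference sketch (not part of the original training script) ---
# A minimal example of transcribing one audio file with the fine-tuned model,
# assuming training completed and checkpoints were written to "./model"
# (the output_dir above). "sample.wav" is a placeholder path.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="./model",                      # or the Hub repo id if pushed
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
result = asr(
    "sample.wav",
    generate_kwargs={"language": "swahili", "task": "transcribe"},
)
print(result["text"])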