Automatic Speech Recognition
Transformers
Safetensors
Swahili
English
whisper
Generated from Trainer
Jacaranda committed on
Commit 00cacca · verified · 1 Parent(s): fb01186

Upload STT Training Script.py

Files changed (1)
  1. STT Training Script.py +216 -0
STT Training Script.py ADDED
@@ -0,0 +1,216 @@
# Import required libraries
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from functools import partial
import evaluate

# Load the dataset
dataset = load_dataset("")  # Specify Data Repo on HF
dataset

# Split the dataset into train and test sets (80-20 split)
split_dataset = dataset['train'].train_test_split(test_size=0.2)
split_dataset

# Select only the relevant columns for training
split_dataset['train'] = split_dataset['train'].select_columns(["audio", "sentence"])
split_dataset['train']

# Initialize the Whisper processor for Swahili transcription
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="swahili",
    task="transcribe"
)
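
# Note (illustrative): the WhisperProcessor bundles a feature extractor
# (waveform -> log-mel spectrogram) and a tokenizer (text -> label ids);
# the language/task arguments only configure the tokenizer's prompt tokens.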

# Print audio features before and after resampling to match Whisper's expected sampling rate
print('BEFORE>>> ', split_dataset['train'].features['audio'])
sampling_rate = processor.feature_extractor.sampling_rate
split_dataset['train'] = split_dataset['train'].cast_column(
    "audio",
    Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['train'].features['audio'])

# Do the same for the test set
print('BEFORE>>> ', split_dataset['test'].features['audio'])
split_dataset['test'] = split_dataset['test'].cast_column(
    "audio",
    Audio(sampling_rate=sampling_rate)
)
print('AFTER>>> ', split_dataset['test'].features['audio'])

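# Sanity-check sketch (assumes the dataset has at least one example): after the
# cast, datasets resamples each clip lazily on access, so both splits should
# report Whisper's 16 kHz rate.
# assert split_dataset['train'][0]["audio"]["sampling_rate"] == sampling_rate
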
def prepare_dataset(example):
    """Preprocess audio and text data for Whisper model training"""
    audio = example["audio"]

    # Process audio and text using Whisper processor
    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["sentence"],
    )

    # Compute input length of audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example

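# Illustrative check of what prepare_dataset returns (hypothetical sample index):
# a dict with log-mel "input_features", tokenized "labels", and "input_length".
# example = prepare_dataset(split_dataset['train'][0])
# print(example.keys(), example["input_length"])
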
# Apply preprocessing to train and test sets
split_dataset['train'] = split_dataset['train'].map(
    prepare_dataset,
    remove_columns=split_dataset['train'].column_names,
    num_proc=4  # Use 4 processes for faster preprocessing
)

split_dataset['test'] = split_dataset['test'].map(
    prepare_dataset,
    remove_columns=split_dataset['test'].column_names,
    num_proc=1
)

# Filter out audio samples longer than 30 seconds
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length

split_dataset['train'] = split_dataset['train'].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

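# Context note: Whisper's feature extractor pads or truncates every clip to a
# 30-second log-mel window, so clips over 30 s would be transcribed only
# partially; dropping them keeps audio and reference text consistent.
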
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Custom data collator for Whisper speech-to-sequence tasks with padding"""
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they need different padding methods
        # First process audio inputs
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Process label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so padded tokens are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove the BOS token if it was prepended during tokenization
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Initialize data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

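# Usage sketch (assumes at least two prepared training examples): the collator
# pads a list of features into one batch of tensors.
# batch = data_collator([split_dataset['train'][i] for i in range(2)])
# batch["input_features"].shape  # (2, 80, 3000) log-mel frames for whisper-small
# batch["labels"].shape          # (2, longest_label_length), padding set to -100
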
# Load evaluation metric (Word Error Rate)
metric = evaluate.load("wer")

# Initialize text normalizer for evaluation
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
normalizer = BasicTextNormalizer()

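# WER refresher: WER = (substitutions + deletions + insertions) / reference words.
# e.g. metric.compute(predictions=["habari dunia"], references=["habari yako dunia"])
# scores one deletion over three reference words, i.e. ~0.33.
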
def compute_metrics(pred):
    """Compute WER (Word Error Rate) metrics for evaluation"""
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and labels
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Compute orthographic WER
    wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)

    # Compute normalized WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]

    # Keep only samples whose normalized reference is non-empty
    pred_str_norm = [
        pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i] for i in range(len(label_str_norm)) if len(label_str_norm[i]) > 0
    ]

    wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer}

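# Example of the two scores: orthographic WER penalises casing and punctuation
# mismatches ("Habari, dunia!" vs "habari dunia"), while the normalized WER
# applies BasicTextNormalizer to both sides first, so that pair scores 0.
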
# Load pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Disable cache during training (incompatible with gradient checkpointing)
model.config.use_cache = False

# Configure generation settings (re-enable cache for generation)
model.generate = partial(
    model.generate,
    language="swahili",
    task="transcribe",
    use_cache=True
)

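# Note: pinning language="swahili" and task="transcribe" makes every
# evaluation-time generate() call decode with the Swahili transcription prompt
# (<|sw|><|transcribe|>) rather than auto-detecting the language per sample.
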
# Set up training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./model",  # Output directory
    per_device_train_batch_size=16,  # Batch size for training
    gradient_accumulation_steps=1,  # Number of steps before gradient update
    learning_rate=1e-6,  # Learning rate
    lr_scheduler_type="constant_with_warmup",  # Learning rate scheduler
    warmup_steps=50,  # Warmup steps
    max_steps=10000,  # Total training steps
    gradient_checkpointing=True,  # Use gradient checkpointing
    fp16=True,  # Use mixed precision training
    fp16_full_eval=True,  # Use mixed precision evaluation
    evaluation_strategy="steps",  # Evaluation strategy
    per_device_eval_batch_size=16,  # Batch size for evaluation
    predict_with_generate=True,  # Use generation for evaluation
    generation_max_length=225,  # Maximum generation length
    save_steps=500,  # Save checkpoint every N steps
    eval_steps=500,  # Evaluate every N steps
    logging_steps=100,  # Log metrics every N steps
    report_to=["tensorboard", "wandb"],  # Logging integrations
    load_best_model_at_end=True,  # Load best model at end of training
    metric_for_best_model="wer",  # Metric for selecting best model
    greater_is_better=False,  # Lower WER is better
    push_to_hub=True,  # Push to Hugging Face Hub
    save_total_limit=3,  # Maximum number of checkpoints to keep
)

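# Rough arithmetic (single GPU assumed): 16 examples/device x 1 accumulation
# step = 16 examples per optimizer step, so max_steps=10000 covers ~160,000
# training examples, with an evaluation and checkpoint every 500 steps.
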
# Initialize trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=split_dataset['train'],
    eval_dataset=split_dataset['test'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,  # Changed from processing_class to tokenizer
)

# Start training
trainer.train()
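
# Optional follow-up (sketch): persist the processor next to the model and push
# the final checkpoint, matching push_to_hub=True above.
# processor.save_pretrained(training_args.output_dir)
# trainer.push_to_hub()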