elsayedissa commited on
Commit
dab7ac0
1 Parent(s): 57479a4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -0
README.md CHANGED
@@ -76,6 +76,93 @@ The following hyperparameters were used during training:
76
  | 0.1337 | 0.83 | 24000 | 0.1472 | 0.0854 |
77
  | 0.1289 | 0.87 | 25000 | 0.1466 | 0.0855 |
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  ### Framework versions
81
 
 
76
  | 0.1337 | 0.83 | 24000 | 0.1472 | 0.0854 |
77
  | 0.1289 | 0.87 | 25000 | 0.1466 | 0.0855 |
78
 
79
+ ### Transcription:
80
+
81
+ ```python
82
+ from datasets import load_dataset, Audio
83
+ import torch
84
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
85
+
86
+ # device
87
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+
89
+ # load the model
90
+ processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish")
91
+ model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish").to(device)
92
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")
93
+
94
+ # load the dataset
95
+ commonvoice_eval = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="validation", streaming=True)
96
+ commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=16000))
97
+ sample = next(iter(commonvoice_eval))["audio"]
98
+
99
+ # features and generate token ids
100
+ input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
101
+ predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
102
+
103
+ # decode
104
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
105
+
106
+ print(transcription)
107
+
108
+ ```
109
+
110
+ ### Evaluation:
111
+
112
+ Evaluates this model on `mozilla-foundation/common_voice_11_0` test split.
113
+
114
+ ```python
115
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
116
+ from datasets import load_dataset, Audio
117
+ import evaluate
118
+ import torch
119
+ import re
120
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
121
+
122
+ # device
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+
125
+ # metric
126
+ wer_metric = evaluate.load("wer")
127
+
128
+ # model
129
+ processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish")
130
+ model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish")
131
+
132
+ # dataset
133
+ dataset = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", )#cache_dir=args.cache_dir
134
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
135
+
136
+ #for debuggings: it gets some examples
137
+ #dataset = dataset.shard(num_shards=10000, index=0)
138
+ #print(dataset)
139
+
140
+ def normalize(batch):
141
+ batch["gold_text"] = whisper_norm(batch['sentence'])
142
+ return batch
143
+
144
+ def map_wer(batch):
145
+ model.to(device)
146
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language = "es", task = "transcribe")
147
+ inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
148
+ with torch.no_grad():
149
+ generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
150
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
151
+ batch["predicted_text"] = whisper_norm(transcription)
152
+ return batch
153
+
154
+ # process GOLD text
155
+ processed_dataset = dataset.map(normalize)
156
+ # get predictions
157
+ predicted = processed_dataset.map(map_wer)
158
+
159
+ # word error rate
160
+ wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
161
+ wer = round(100 * wer, 2)
162
+ print("WER:", wer)
163
+
164
+
165
+ ```
166
 
167
  ### Framework versions
168