DocWolle committed on
Commit
461155f
·
verified ·
1 Parent(s): d9d43ce

Add new model for top world languages supported by Whisper

Browse files
Generate_tflite_for_whisper_base_TOP_WORLD_version.ipynb ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "c5g9NTF_Ixad"
7
+ },
8
+ "source": [
9
+ "## Install Transformers and datasets"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "w4VPaSlnHUvT",
17
+ "collapsed": true
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "!pip install transformers==4.33.0\n",
22
+ "!pip install tensorflow==2.14.0\n",
23
+ "!pip install numpy==1.26.4"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {
30
+ "id": "ClniiYCWHK4b",
31
+ "collapsed": true
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "! pip install datasets"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {
41
+ "id": "pljpioLsJOtb"
42
+ },
43
+ "source": [
44
+ "## Load pre-trained TF Whisper Base model"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {
51
+ "id": "BJNOxn5vHaGi"
52
+ },
53
+ "outputs": [],
54
+ "source": [
55
+ "import tensorflow as tf\n",
56
+ "from transformers import TFWhisperModel, WhisperFeatureExtractor\n",
57
+ "from datasets import load_dataset\n",
58
+ "\n",
59
+ "model = TFWhisperModel.from_pretrained(\"openai/whisper-base\")\n",
60
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-base\")\n",
61
+ "\n",
62
+ "ds = load_dataset(\"google/fleurs\", \"fr_fr\", split=\"test\")\n",
63
+ "inputs = feature_extractor(\n",
64
+ " ds[0][\"audio\"][\"array\"], sampling_rate=ds[0][\"audio\"][\"sampling_rate\"], return_tensors=\"tf\"\n",
65
+ ")\n",
66
+ "input_features = inputs.input_features\n",
67
+ "print(input_features)\n",
68
+ "decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id\n",
69
+ "last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state\n",
70
+ "list(last_hidden_state.shape)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {
76
+ "id": "W9XP25uhJl44"
77
+ },
78
+ "source": [
79
+ "## Generate Saved model"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "id": "vpYwMmgyHf0B"
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "model.save('/content/tf_whisper_saved')"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {
96
+ "id": "TY_79jFEJYyJ"
97
+ },
98
+ "source": [
99
+ "## Convert saved model to TFLite model"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {
106
+ "id": "owez2zvzHl-p"
107
+ },
108
+ "outputs": [],
109
+ "source": [
110
+ "import tensorflow as tf\n",
111
+ "\n",
112
+ "saved_model_dir = '/content/tf_whisper_saved'\n",
113
+ "tflite_model_path = 'whisper.tflite'\n",
114
+ "\n",
115
+ "# Convert the model\n",
116
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
117
+ "converter.target_spec.supported_ops = [\n",
118
+ " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n",
119
+ " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n",
120
+ "]\n",
121
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
122
+ "tflite_model = converter.convert()\n",
123
+ "\n",
124
+ "# Save the model\n",
125
+ "with open(tflite_model_path, 'wb') as f:\n",
126
+ " f.write(tflite_model)"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "metadata": {
133
+ "id": "tFkzUrjIbNcH"
134
+ },
135
+ "outputs": [],
136
+ "source": [
137
+ "%ls -la"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {
143
+ "id": "fpEnWZt7iQJK"
144
+ },
145
+ "source": [
146
+ "## Evaluate TF model"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {
153
+ "id": "-RuFFohHg2ho"
154
+ },
155
+ "outputs": [],
156
+ "source": [
157
+ "import tensorflow as tf\n",
158
+ "from transformers import WhisperProcessor, TFWhisperForConditionalGeneration\n",
159
+ "from datasets import load_dataset\n",
160
+ "\n",
161
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-base\")\n",
162
+ "model = TFWhisperForConditionalGeneration.from_pretrained(\"openai/whisper-base\")\n",
163
+ "\n",
164
+ "ds = load_dataset(\"google/fleurs\", \"fr_fr\", split=\"test\")\n",
165
+ "\n",
166
+ "inputs = processor(ds[0][\"audio\"][\"array\"], return_tensors=\"tf\")\n",
167
+ "input_features = inputs.input_features\n",
168
+ "\n",
169
+ "generated_ids = model.generate(input_features)\n",
170
+ "\n",
171
+ "transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
172
+ "transcription"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {
178
+ "id": "U-eKuy_cG4u0"
179
+ },
180
+ "source": [
181
+ "## Evaluate TF Lite model (naive)\n",
182
+ "\n",
183
+ "We can load the model as defined above... but the model is useless on its own. Generation is much more complex than a model forward pass."
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {
190
+ "id": "wnfHirgyG0W4"
191
+ },
192
+ "outputs": [],
193
+ "source": [
194
+ "tflite_model_path = 'whisper.tflite'\n",
195
+ "interpreter = tf.lite.Interpreter(tflite_model_path)"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "metadata": {
201
+ "id": "a8VJQuHJKzl4"
202
+ },
203
+ "source": [
204
+ "## Create generation-enabled TF Lite model\n",
205
+ "\n",
206
+ "The solution consists in defining a model whose serving function is the generation call. Here's an example of how to do it:"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "markdown",
211
+ "metadata": {
212
+ "id": "JmIgqWVgVBZN"
213
+ },
214
+ "source": [
215
+ "Now with a monkey-patch of TFForceTokensLogitsProcessor that fixes the NaN errors caused by -inf values in the exported TFLite model"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 3,
221
+ "metadata": {
222
+ "id": "e5P8s66yU7Kv"
223
+ },
224
+ "outputs": [],
225
+ "source": [
226
+ "import tensorflow as tf\n",
227
+ "import numpy as np\n",
228
+ "from transformers import TFForceTokensLogitsProcessor, TFLogitsProcessor\n",
229
+ "from typing import List, Optional, Union, Any\n",
230
+ "\n",
231
+ "# Patching methods of class TFForceTokensLogitsProcessor(TFLogitsProcessor):\n",
232
+ "\n",
233
+ "def my__init__(self, force_token_map: List[List[int]]):\n",
234
+ " force_token_map = dict(force_token_map)\n",
235
+ " # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the\n",
236
+ " # index of the array corresponds to the index of the token to be forced, for XLA compatibility.\n",
237
+ " # Indexes without forced tokens will have a negative value.\n",
238
+ " force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1\n",
239
+ " for index, token in force_token_map.items():\n",
240
+ " if token is not None:\n",
241
+ " force_token_array[index] = token\n",
242
+ " self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)\n",
243
+ "\n",
244
+ "def my__call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:\n",
245
+ " def _force_token(generation_idx):\n",
246
+ " batch_size = scores.shape[0]\n",
247
+ " current_token = self.force_token_array[generation_idx]\n",
248
+ "\n",
249
+ " # Original code below generates NaN values when the model is exported to tflite\n",
250
+ " # it just needs to be a negative number so that the forced token's value of 0 is the largest\n",
251
+ " # so it will get chosen\n",
252
+ " #new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(\"inf\")\n",
253
+ " new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(1)\n",
254
+ " indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)\n",
255
+ " updates = tf.zeros((batch_size,), dtype=scores.dtype)\n",
256
+ " new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)\n",
257
+ " return new_scores\n",
258
+ "\n",
259
+ " scores = tf.cond(\n",
260
+ " tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),\n",
261
+ " # If the current length is geq than the length of force_token_array, the processor does nothing.\n",
262
+ " lambda: tf.identity(scores),\n",
263
+ " # Otherwise, it may force a certain token.\n",
264
+ " lambda: tf.cond(\n",
265
+ " tf.greater_equal(self.force_token_array[cur_len], 0),\n",
266
+ " # Only valid (non-negative) token ids are forced\n",
267
+ " lambda: _force_token(cur_len),\n",
268
+ " # Otherwise, the processor does nothing.\n",
269
+ " lambda: scores,\n",
270
+ " ),\n",
271
+ " )\n",
272
+ " return scores\n",
273
+ "\n",
274
+ "TFForceTokensLogitsProcessor.__init__ = my__init__\n",
275
+ "TFForceTokensLogitsProcessor.__call__ = my__call__"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {
282
+ "id": "rIkUCdiyU7ZT"
283
+ },
284
+ "outputs": [],
285
+ "source": [
286
+ "import tensorflow as tf\n",
287
+ "\n",
288
+ "class GenerateModel(tf.Module):\n",
289
+ " def __init__(self, model):\n",
290
+ " super(GenerateModel, self).__init__()\n",
291
+ " self.model = model\n",
292
+ "\n",
293
+ " @tf.function(\n",
294
+ " input_signature=[\n",
295
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
296
+ " tf.TensorSpec((1,), tf.int32, name=\"lang_token\"),\n",
297
+ " ],\n",
298
+ " )\n",
299
+ " def transcribe_lang(self, input_features, lang_token):\n",
300
+ " if lang_token == 50259:\n",
301
+ " outputs = self.model.generate(\n",
302
+ " input_features,\n",
303
+ " max_new_tokens=450,\n",
304
+ " return_dict_in_generate=True,\n",
305
+ " forced_decoder_ids=[[1, 50259], [2, 50359], [3, 50363]],\n",
306
+ " )\n",
307
+ "\n",
308
+ " elif lang_token == 50260:\n",
309
+ " outputs = self.model.generate(\n",
310
+ " input_features,\n",
311
+ " max_new_tokens=450,\n",
312
+ " return_dict_in_generate=True,\n",
313
+ " forced_decoder_ids=[[1, 50260], [2, 50359], [3, 50363]],\n",
314
+ " )\n",
315
+ "\n",
316
+ " elif lang_token == 50261:\n",
317
+ " outputs = self.model.generate(\n",
318
+ " input_features,\n",
319
+ " max_new_tokens=450,\n",
320
+ " return_dict_in_generate=True,\n",
321
+ " forced_decoder_ids=[[1, 50261], [2, 50359], [3, 50363]],\n",
322
+ " )\n",
323
+ "\n",
324
+ " elif lang_token == 50262:\n",
325
+ " outputs = self.model.generate(\n",
326
+ " input_features,\n",
327
+ " max_new_tokens=450,\n",
328
+ " return_dict_in_generate=True,\n",
329
+ " forced_decoder_ids=[[1, 50262], [2, 50359], [3, 50363]],\n",
330
+ " )\n",
331
+ "\n",
332
+ " elif lang_token == 50263:\n",
333
+ " outputs = self.model.generate(\n",
334
+ " input_features,\n",
335
+ " max_new_tokens=450,\n",
336
+ " return_dict_in_generate=True,\n",
337
+ " forced_decoder_ids=[[1, 50263], [2, 50359], [3, 50363]],\n",
338
+ " )\n",
339
+ "\n",
340
+ " elif lang_token == 50264:\n",
341
+ " outputs = self.model.generate(\n",
342
+ " input_features,\n",
343
+ " max_new_tokens=450,\n",
344
+ " return_dict_in_generate=True,\n",
345
+ " forced_decoder_ids=[[1, 50264], [2, 50359], [3, 50363]],\n",
346
+ " )\n",
347
+ "\n",
348
+ " elif lang_token == 50265:\n",
349
+ " outputs = self.model.generate(\n",
350
+ " input_features,\n",
351
+ " max_new_tokens=450,\n",
352
+ " return_dict_in_generate=True,\n",
353
+ " forced_decoder_ids=[[1, 50265], [2, 50359], [3, 50363]],\n",
354
+ " )\n",
355
+ "\n",
356
+ " elif lang_token == 50266:\n",
357
+ " outputs = self.model.generate(\n",
358
+ " input_features,\n",
359
+ " max_new_tokens=450,\n",
360
+ " return_dict_in_generate=True,\n",
361
+ " forced_decoder_ids=[[1, 50266], [2, 50359], [3, 50363]],\n",
362
+ " )\n",
363
+ "\n",
364
+ " elif lang_token == 50267:\n",
365
+ " outputs = self.model.generate(\n",
366
+ " input_features,\n",
367
+ " max_new_tokens=450,\n",
368
+ " return_dict_in_generate=True,\n",
369
+ " forced_decoder_ids=[[1, 50267], [2, 50359], [3, 50363]],\n",
370
+ " )\n",
371
+ "\n",
372
+ " elif lang_token == 50268:\n",
373
+ " outputs = self.model.generate(\n",
374
+ " input_features,\n",
375
+ " max_new_tokens=450,\n",
376
+ " return_dict_in_generate=True,\n",
377
+ " forced_decoder_ids=[[1, 50268], [2, 50359], [3, 50363]],\n",
378
+ " )\n",
379
+ "\n",
380
+ " elif lang_token == 50269:\n",
381
+ " outputs = self.model.generate(\n",
382
+ " input_features,\n",
383
+ " max_new_tokens=450,\n",
384
+ " return_dict_in_generate=True,\n",
385
+ " forced_decoder_ids=[[1, 50269], [2, 50359], [3, 50363]],\n",
386
+ " )\n",
387
+ "\n",
388
+ " elif lang_token == 50270:\n",
389
+ " outputs = self.model.generate(\n",
390
+ " input_features,\n",
391
+ " max_new_tokens=450,\n",
392
+ " return_dict_in_generate=True,\n",
393
+ " forced_decoder_ids=[[1, 50270], [2, 50359], [3, 50363]],\n",
394
+ " )\n",
395
+ "\n",
396
+ " elif lang_token == 50271:\n",
397
+ " outputs = self.model.generate(\n",
398
+ " input_features,\n",
399
+ " max_new_tokens=450,\n",
400
+ " return_dict_in_generate=True,\n",
401
+ " forced_decoder_ids=[[1, 50271], [2, 50359], [3, 50363]],\n",
402
+ " )\n",
403
+ "\n",
404
+ " elif lang_token == 50272:\n",
405
+ " outputs = self.model.generate(\n",
406
+ " input_features,\n",
407
+ " max_new_tokens=450,\n",
408
+ " return_dict_in_generate=True,\n",
409
+ " forced_decoder_ids=[[1, 50272], [2, 50359], [3, 50363]],\n",
410
+ " )\n",
411
+ "\n",
412
+ " elif lang_token == 50273:\n",
413
+ " outputs = self.model.generate(\n",
414
+ " input_features,\n",
415
+ " max_new_tokens=450,\n",
416
+ " return_dict_in_generate=True,\n",
417
+ " forced_decoder_ids=[[1, 50273], [2, 50359], [3, 50363]],\n",
418
+ " )\n",
419
+ "\n",
420
+ " elif lang_token == 50274:\n",
421
+ " outputs = self.model.generate(\n",
422
+ " input_features,\n",
423
+ " max_new_tokens=450,\n",
424
+ " return_dict_in_generate=True,\n",
425
+ " forced_decoder_ids=[[1, 50274], [2, 50359], [3, 50363]],\n",
426
+ " )\n",
427
+ "\n",
428
+ " elif lang_token == 50275:\n",
429
+ " outputs = self.model.generate(\n",
430
+ " input_features,\n",
431
+ " max_new_tokens=450,\n",
432
+ " return_dict_in_generate=True,\n",
433
+ " forced_decoder_ids=[[1, 50275], [2, 50359], [3, 50363]],\n",
434
+ " )\n",
435
+ "\n",
436
+ " elif lang_token == 50277:\n",
437
+ " outputs = self.model.generate(\n",
438
+ " input_features,\n",
439
+ " max_new_tokens=450,\n",
440
+ " return_dict_in_generate=True,\n",
441
+ " forced_decoder_ids=[[1, 50277], [2, 50359], [3, 50363]],\n",
442
+ " )\n",
443
+ "\n",
444
+ " elif lang_token == 50278:\n",
445
+ " outputs = self.model.generate(\n",
446
+ " input_features,\n",
447
+ " max_new_tokens=450,\n",
448
+ " return_dict_in_generate=True,\n",
449
+ " forced_decoder_ids=[[1, 50278], [2, 50359], [3, 50363]],\n",
450
+ " )\n",
451
+ "\n",
452
+ " elif lang_token == 50279:\n",
453
+ " outputs = self.model.generate(\n",
454
+ " input_features,\n",
455
+ " max_new_tokens=450,\n",
456
+ " return_dict_in_generate=True,\n",
457
+ " forced_decoder_ids=[[1, 50279], [2, 50359], [3, 50363]],\n",
458
+ " )\n",
459
+ "\n",
460
+ " elif lang_token == 50280:\n",
461
+ " outputs = self.model.generate(\n",
462
+ " input_features,\n",
463
+ " max_new_tokens=450,\n",
464
+ " return_dict_in_generate=True,\n",
465
+ " forced_decoder_ids=[[1, 50280], [2, 50359], [3, 50363]],\n",
466
+ " )\n",
467
+ "\n",
468
+ " elif lang_token == 50281:\n",
469
+ " outputs = self.model.generate(\n",
470
+ " input_features,\n",
471
+ " max_new_tokens=450,\n",
472
+ " return_dict_in_generate=True,\n",
473
+ " forced_decoder_ids=[[1, 50281], [2, 50359], [3, 50363]],\n",
474
+ " )\n",
475
+ "\n",
476
+ " elif lang_token == 50282:\n",
477
+ " outputs = self.model.generate(\n",
478
+ " input_features,\n",
479
+ " max_new_tokens=450,\n",
480
+ " return_dict_in_generate=True,\n",
481
+ " forced_decoder_ids=[[1, 50282], [2, 50359], [3, 50363]],\n",
482
+ " )\n",
483
+ "\n",
484
+ " elif lang_token == 50283:\n",
485
+ " outputs = self.model.generate(\n",
486
+ " input_features,\n",
487
+ " max_new_tokens=450,\n",
488
+ " return_dict_in_generate=True,\n",
489
+ " forced_decoder_ids=[[1, 50283], [2, 50359], [3, 50363]],\n",
490
+ " )\n",
491
+ "\n",
492
+ " elif lang_token == 50284:\n",
493
+ " outputs = self.model.generate(\n",
494
+ " input_features,\n",
495
+ " max_new_tokens=450,\n",
496
+ " return_dict_in_generate=True,\n",
497
+ " forced_decoder_ids=[[1, 50284], [2, 50359], [3, 50363]],\n",
498
+ " )\n",
499
+ "\n",
500
+ " elif lang_token == 50285:\n",
501
+ " outputs = self.model.generate(\n",
502
+ " input_features,\n",
503
+ " max_new_tokens=450,\n",
504
+ " return_dict_in_generate=True,\n",
505
+ " forced_decoder_ids=[[1, 50285], [2, 50359], [3, 50363]],\n",
506
+ " )\n",
507
+ "\n",
508
+ " elif lang_token == 50286:\n",
509
+ " outputs = self.model.generate(\n",
510
+ " input_features,\n",
511
+ " max_new_tokens=450,\n",
512
+ " return_dict_in_generate=True,\n",
513
+ " forced_decoder_ids=[[1, 50286], [2, 50359], [3, 50363]],\n",
514
+ " )\n",
515
+ "\n",
516
+ " elif lang_token == 50287:\n",
517
+ " outputs = self.model.generate(\n",
518
+ " input_features,\n",
519
+ " max_new_tokens=450,\n",
520
+ " return_dict_in_generate=True,\n",
521
+ " forced_decoder_ids=[[1, 50287], [2, 50359], [3, 50363]],\n",
522
+ " )\n",
523
+ "\n",
524
+ " elif lang_token == 50288:\n",
525
+ " outputs = self.model.generate(\n",
526
+ " input_features,\n",
527
+ " max_new_tokens=450,\n",
528
+ " return_dict_in_generate=True,\n",
529
+ " forced_decoder_ids=[[1, 50288], [2, 50359], [3, 50363]],\n",
530
+ " )\n",
531
+ "\n",
532
+ " elif lang_token == 50289:\n",
533
+ " outputs = self.model.generate(\n",
534
+ " input_features,\n",
535
+ " max_new_tokens=450,\n",
536
+ " return_dict_in_generate=True,\n",
537
+ " forced_decoder_ids=[[1, 50289], [2, 50359], [3, 50363]],\n",
538
+ " )\n",
539
+ "\n",
540
+ " elif lang_token == 50290:\n",
541
+ " outputs = self.model.generate(\n",
542
+ " input_features,\n",
543
+ " max_new_tokens=450,\n",
544
+ " return_dict_in_generate=True,\n",
545
+ " forced_decoder_ids=[[1, 50290], [2, 50359], [3, 50363]],\n",
546
+ " )\n",
547
+ "\n",
548
+ " elif lang_token == 50291:\n",
549
+ " outputs = self.model.generate(\n",
550
+ " input_features,\n",
551
+ " max_new_tokens=450,\n",
552
+ " return_dict_in_generate=True,\n",
553
+ " forced_decoder_ids=[[1, 50291], [2, 50359], [3, 50363]],\n",
554
+ " )\n",
555
+ "\n",
556
+ " elif lang_token == 50292:\n",
557
+ " outputs = self.model.generate(\n",
558
+ " input_features,\n",
559
+ " max_new_tokens=450,\n",
560
+ " return_dict_in_generate=True,\n",
561
+ " forced_decoder_ids=[[1, 50292], [2, 50359], [3, 50363]],\n",
562
+ " )\n",
563
+ "\n",
564
+ " elif lang_token == 50293:\n",
565
+ " outputs = self.model.generate(\n",
566
+ " input_features,\n",
567
+ " max_new_tokens=450,\n",
568
+ " return_dict_in_generate=True,\n",
569
+ " forced_decoder_ids=[[1, 50293], [2, 50359], [3, 50363]],\n",
570
+ " )\n",
571
+ "\n",
572
+ " elif lang_token == 50298:\n",
573
+ " outputs = self.model.generate(\n",
574
+ " input_features,\n",
575
+ " max_new_tokens=450,\n",
576
+ " return_dict_in_generate=True,\n",
577
+ " forced_decoder_ids=[[1, 50298], [2, 50359], [3, 50363]],\n",
578
+ " )\n",
579
+ "\n",
580
+ " elif lang_token == 50301:\n",
581
+ " outputs = self.model.generate(\n",
582
+ " input_features,\n",
583
+ " max_new_tokens=450,\n",
584
+ " return_dict_in_generate=True,\n",
585
+ " forced_decoder_ids=[[1, 50301], [2, 50359], [3, 50363]],\n",
586
+ " )\n",
587
+ "\n",
588
+ " elif lang_token == 50305:\n",
589
+ " outputs = self.model.generate(\n",
590
+ " input_features,\n",
591
+ " max_new_tokens=450,\n",
592
+ " return_dict_in_generate=True,\n",
593
+ " forced_decoder_ids=[[1, 50305], [2, 50359], [3, 50363]],\n",
594
+ " )\n",
595
+ "\n",
596
+ " elif lang_token == 50307:\n",
597
+ " outputs = self.model.generate(\n",
598
+ " input_features,\n",
599
+ " max_new_tokens=450,\n",
600
+ " return_dict_in_generate=True,\n",
601
+ " forced_decoder_ids=[[1, 50307], [2, 50359], [3, 50363]],\n",
602
+ " )\n",
603
+ "\n",
604
+ " elif lang_token == 50343:\n",
605
+ " outputs = self.model.generate(\n",
606
+ " input_features,\n",
607
+ " max_new_tokens=450,\n",
608
+ " return_dict_in_generate=True,\n",
609
+ " forced_decoder_ids=[[1, 50343], [2, 50359], [3, 50363]],\n",
610
+ " )\n",
611
+ "\n",
612
+ " elif lang_token == 50345:\n",
613
+ " outputs = self.model.generate(\n",
614
+ " input_features,\n",
615
+ " max_new_tokens=450,\n",
616
+ " return_dict_in_generate=True,\n",
617
+ " forced_decoder_ids=[[1, 50345], [2, 50359], [3, 50363]],\n",
618
+ " )\n",
619
+ "\n",
620
+ " else:\n",
621
+ " outputs = self.model.generate(\n",
622
+ " input_features,\n",
623
+ " max_new_tokens=450, # change as needed\n",
624
+ " return_dict_in_generate=True,\n",
625
+ " forced_decoder_ids=[[2, 50359], [3, 50363]],\n",
626
+ " )\n",
627
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
628
+ "\n",
629
+ "\n",
630
+ " @tf.function(\n",
631
+ " input_signature=[\n",
632
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
633
+ " ],\n",
634
+ " )\n",
635
+ " def transcribe(self, input_features):\n",
636
+ " outputs = self.model.generate(\n",
637
+ " input_features,\n",
638
+ " max_new_tokens=450, # change as needed\n",
639
+ " return_dict_in_generate=True,\n",
640
+ " forced_decoder_ids=[[2, 50359], [3, 50363]],\n",
641
+ " )\n",
642
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
643
+ "\n",
644
+ " @tf.function(\n",
645
+ " input_signature=[\n",
646
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
647
+ " ],\n",
648
+ " )\n",
649
+ " def translate(self, input_features):\n",
650
+ " outputs = self.model.generate(\n",
651
+ " input_features,\n",
652
+ " max_new_tokens=450, # change as needed\n",
653
+ " return_dict_in_generate=True,\n",
654
+ " forced_decoder_ids=[[2, 50358], [3, 50363]],\n",
655
+ " )\n",
656
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
657
+ "\n",
658
+ "# Assuming `model` is already defined and loaded\n",
659
+ "saved_model_dir = '/content/tf_whisper_saved'\n",
660
+ "tflite_model_path = 'whisper.tflite'\n",
661
+ "\n",
662
+ "generate_model = GenerateModel(model=model)\n",
663
+ "tf.saved_model.save(generate_model, saved_model_dir, signatures={\n",
664
+ " \"serving_default\": generate_model.transcribe,\n",
665
+ " \"serving_transcribe\": generate_model.transcribe,\n",
666
+ " \"serving_translate\": generate_model.translate,\n",
667
+ " \"serving_transcribe_lang\": generate_model.transcribe_lang,\n",
668
+ "\n",
669
+ "})\n",
670
+ "\n",
671
+ "# Convert the model\n",
672
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
673
+ "converter.target_spec.supported_ops = [\n",
674
+ " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n",
675
+ " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n",
676
+ "]\n",
677
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
678
+ "tflite_model = converter.convert()\n",
679
+ "\n",
680
+ "# Save the model\n",
681
+ "with open(tflite_model_path, 'wb') as f:\n",
682
+ " f.write(tflite_model)"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "source": [
688
+ "pwd"
689
+ ],
690
+ "metadata": {
691
+ "id": "llf-5421rZ-G"
692
+ },
693
+ "execution_count": null,
694
+ "outputs": []
695
+ },
696
+ {
697
+ "cell_type": "code",
698
+ "source": [
699
+ "!zip -r /content/tf_whisper_saved.zip /content/tf_whisper_saved/"
700
+ ],
701
+ "metadata": {
702
+ "colab": {
703
+ "base_uri": "https://localhost:8080/"
704
+ },
705
+ "collapsed": true,
706
+ "id": "7pnAWtGZp6MJ",
707
+ "outputId": "42d6c775-1af9-4482-837a-eb3537d5e2c0"
708
+ },
709
+ "execution_count": null,
710
+ "outputs": [
711
+ {
712
+ "output_type": "stream",
713
+ "name": "stdout",
714
+ "text": [
715
+ " adding: content/tf_whisper_saved/ (stored 0%)\n",
716
+ " adding: content/tf_whisper_saved/assets/ (stored 0%)\n",
717
+ " adding: content/tf_whisper_saved/variables/ (stored 0%)\n",
718
+ " adding: content/tf_whisper_saved/variables/variables.data-00000-of-00001 (deflated 41%)\n",
719
+ " adding: content/tf_whisper_saved/variables/variables.index (deflated 79%)\n",
720
+ " adding: content/tf_whisper_saved/fingerprint.pb (stored 0%)\n",
721
+ " adding: content/tf_whisper_saved/keras_metadata.pb (deflated 96%)\n",
722
+ " adding: content/tf_whisper_saved/saved_model.pb (deflated 93%)\n"
723
+ ]
724
+ }
725
+ ]
726
+ },
727
+ {
728
+ "cell_type": "code",
729
+ "execution_count": null,
730
+ "metadata": {
731
+ "id": "u9MustgMU7oI"
732
+ },
733
+ "outputs": [],
734
+ "source": [
735
+ "# loaded model... now with generate!\n",
736
+ "tflite_model_path = 'whisper.tflite'\n",
737
+ "interpreter = tf.lite.Interpreter(tflite_model_path)\n",
738
+ "\n",
739
+ "tflite_generate = interpreter.get_signature_runner('serving_translate')\n",
740
+ "lang_token = tf.constant([50261], dtype=tf.int32)\n",
741
+ "generated_ids = tflite_generate(input_features=input_features)[\"sequences\"]\n",
742
+ "transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
743
+ "transcription\n",
744
+ "\n",
745
+ "\n"
746
+ ]
747
+ }
748
+ ],
749
+ "metadata": {
750
+ "colab": {
751
+ "machine_shape": "hm",
752
+ "provenance": []
753
+ },
754
+ "kernelspec": {
755
+ "display_name": "Python 3",
756
+ "name": "python3"
757
+ },
758
+ "language_info": {
759
+ "name": "python"
760
+ }
761
+ },
762
+ "nbformat": 4,
763
+ "nbformat_minor": 0
764
+ }
whisper-base.TOP_WORLD.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d48fcde4fda3d2c4b68e6f156c6cf4bc66fc4e3d817c9c7b87849ebd5c8ee72
3
+ size 107564368
whisper-base.TOP_WORLD.tokens ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [50259, 50260, 50261, 50262,50263, 50264, 50265, 50266, 50267, 50268, 50269,50270, 50271,50272, 50273, 50274,50275, 50277,50278,50279,50280, 50281,50282, 50283, 50284, 50285, 50286,50287, 50288, 50289,50290, 50291, 50292, 50293, 50298, 50301, 50305, 50307, 50343, 50345]
2
+
3
+