Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
·
86ba774
1
Parent(s):
95d2faf
* Prepend [bos] to image encodings, rename to "labels".
Browse files
- model/data-pipeline.ipynb +32 -13
model/data-pipeline.ipynb
CHANGED
|
@@ -161,7 +161,8 @@
|
|
| 161 |
"source": [
|
| 162 |
"# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
|
| 163 |
"max_length = 256 # Read from data_args.max_source_length\n",
|
| 164 |
-
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')"
|
|
|
|
| 165 |
]
|
| 166 |
},
|
| 167 |
{
|
|
@@ -178,7 +179,7 @@
|
|
| 178 |
" inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
|
| 179 |
" )\n",
|
| 180 |
"\n",
|
| 181 |
-
" model_inputs[\"
|
| 182 |
"\n",
|
| 183 |
" return model_inputs"
|
| 184 |
]
|
|
@@ -192,10 +193,10 @@
|
|
| 192 |
"source": [
|
| 193 |
"num_workers = 48 # We have 96 processors in the TPU\n",
|
| 194 |
"column_names = dataset.column_names\n",
|
| 195 |
-
"
|
| 196 |
-
"
|
| 197 |
-
"
|
| 198 |
-
"
|
| 199 |
")"
|
| 200 |
]
|
| 201 |
},
|
|
@@ -240,7 +241,7 @@
|
|
| 240 |
"text": [
|
| 241 |
"INFO:absl:Starting the local TPU driver.\n",
|
| 242 |
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
| 243 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are:
|
| 244 |
]
|
| 245 |
}
|
| 246 |
],
|
|
@@ -257,7 +258,7 @@
|
|
| 257 |
"metadata": {},
|
| 258 |
"outputs": [],
|
| 259 |
"source": [
|
| 260 |
-
"loader = data_loader(rng,
|
| 261 |
]
|
| 262 |
},
|
| 263 |
{
|
|
@@ -279,7 +280,7 @@
|
|
| 279 |
{
|
| 280 |
"data": {
|
| 281 |
"text/plain": [
|
| 282 |
-
"dict_keys(['attention_mask', '
|
| 283 |
]
|
| 284 |
},
|
| 285 |
"execution_count": 13,
|
|
@@ -309,7 +310,7 @@
|
|
| 309 |
}
|
| 310 |
],
|
| 311 |
"source": [
|
| 312 |
-
"len(superbatch[\"
|
| 313 |
]
|
| 314 |
},
|
| 315 |
{
|
|
@@ -321,7 +322,7 @@
|
|
| 321 |
{
|
| 322 |
"data": {
|
| 323 |
"text/plain": [
|
| 324 |
-
"(8, 64,
|
| 325 |
]
|
| 326 |
},
|
| 327 |
"execution_count": 15,
|
|
@@ -330,15 +331,33 @@
|
|
| 330 |
}
|
| 331 |
],
|
| 332 |
"source": [
|
| 333 |
-
"superbatch[\"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
]
|
| 335 |
},
|
| 336 |
{
|
| 337 |
"cell_type": "code",
|
| 338 |
-
"execution_count":
|
| 339 |
"id": "cfe23a71",
|
| 340 |
"metadata": {},
|
| 341 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
"source": []
|
| 343 |
}
|
| 344 |
],
|
|
|
|
| 161 |
"source": [
|
| 162 |
"# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
|
| 163 |
"max_length = 256 # Read from data_args.max_source_length\n",
|
| 164 |
+
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
|
| 165 |
+
"image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
|
| 166 |
]
|
| 167 |
},
|
| 168 |
{
|
|
|
|
| 179 |
" inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
|
| 180 |
" )\n",
|
| 181 |
"\n",
|
| 182 |
+
" model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
|
| 183 |
"\n",
|
| 184 |
" return model_inputs"
|
| 185 |
]
|
|
|
|
| 193 |
"source": [
|
| 194 |
"num_workers = 48 # We have 96 processors in the TPU\n",
|
| 195 |
"column_names = dataset.column_names\n",
|
| 196 |
+
"input_dataset = dataset.map(preprocess_function,\n",
|
| 197 |
+
" remove_columns=column_names,\n",
|
| 198 |
+
" batched=True,\n",
|
| 199 |
+
" num_proc=48\n",
|
| 200 |
")"
|
| 201 |
]
|
| 202 |
},
|
|
|
|
| 241 |
"text": [
|
| 242 |
"INFO:absl:Starting the local TPU driver.\n",
|
| 243 |
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
| 244 |
+
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
|
| 245 |
]
|
| 246 |
}
|
| 247 |
],
|
|
|
|
| 258 |
"metadata": {},
|
| 259 |
"outputs": [],
|
| 260 |
"source": [
|
| 261 |
+
"loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
|
| 262 |
]
|
| 263 |
},
|
| 264 |
{
|
|
|
|
| 280 |
{
|
| 281 |
"data": {
|
| 282 |
"text/plain": [
|
| 283 |
+
"dict_keys(['attention_mask', 'input_ids', 'labels'])"
|
| 284 |
]
|
| 285 |
},
|
| 286 |
"execution_count": 13,
|
|
|
|
| 310 |
}
|
| 311 |
],
|
| 312 |
"source": [
|
| 313 |
+
"len(superbatch[\"labels\"])"
|
| 314 |
]
|
| 315 |
},
|
| 316 |
{
|
|
|
|
| 322 |
{
|
| 323 |
"data": {
|
| 324 |
"text/plain": [
|
| 325 |
+
"(8, 64, 257)"
|
| 326 |
]
|
| 327 |
},
|
| 328 |
"execution_count": 15,
|
|
|
|
| 331 |
}
|
| 332 |
],
|
| 333 |
"source": [
|
| 334 |
+
"superbatch[\"labels\"].shape"
|
| 335 |
+
]
|
| 336 |
+
},
|
| 337 |
+
{
|
| 338 |
+
"cell_type": "markdown",
|
| 339 |
+
"id": "6800153b",
|
| 340 |
+
"metadata": {},
|
| 341 |
+
"source": [
|
| 342 |
+
"Any image sequence should begin with `image_bos`:"
|
| 343 |
]
|
| 344 |
},
|
| 345 |
{
|
| 346 |
"cell_type": "code",
|
| 347 |
+
"execution_count": 16,
|
| 348 |
"id": "cfe23a71",
|
| 349 |
"metadata": {},
|
| 350 |
"outputs": [],
|
| 351 |
+
"source": [
|
| 352 |
+
"assert superbatch[\"labels\"][1][5][0].item() == image_bos"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"cell_type": "code",
|
| 357 |
+
"execution_count": null,
|
| 358 |
+
"id": "0fb899b4",
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"outputs": [],
|
| 361 |
"source": []
|
| 362 |
}
|
| 363 |
],
|