Spaces:
Running
Running
feat: add functions
Browse files
dev/inference/wandb-backend.ipynb
CHANGED
|
@@ -2,13 +2,15 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
| 9 |
"source": [
|
| 10 |
"import csv\n",
|
| 11 |
"import tempfile\n",
|
|
|
|
|
|
|
| 12 |
"import wandb\n",
|
| 13 |
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
| 14 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
|
@@ -42,26 +44,82 @@
|
|
| 42 |
},
|
| 43 |
{
|
| 44 |
"cell_type": "code",
|
| 45 |
-
"execution_count":
|
| 46 |
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
| 47 |
"metadata": {},
|
| 48 |
"outputs": [],
|
| 49 |
"source": [
|
| 50 |
"with open('samples.csv', newline='', encoding='utf8') as f:\n",
|
| 51 |
-
" reader = csv.
|
|
|
|
| 52 |
" for row in reader:\n",
|
| 53 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
]
|
| 55 |
},
|
| 56 |
{
|
| 57 |
"cell_type": "code",
|
| 58 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
"id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
|
| 60 |
"metadata": {},
|
| 61 |
"outputs": [],
|
| 62 |
"source": [
|
|
|
|
| 63 |
"wandb_run = wandb_runs[0]\n",
|
| 64 |
-
"
|
| 65 |
]
|
| 66 |
},
|
| 67 |
{
|
|
@@ -280,27 +338,30 @@
|
|
| 280 |
},
|
| 281 |
{
|
| 282 |
"cell_type": "code",
|
| 283 |
-
"execution_count":
|
| 284 |
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
| 285 |
"metadata": {},
|
| 286 |
"outputs": [],
|
| 287 |
-
"source": [
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
| 304 |
},
|
| 305 |
{
|
| 306 |
"cell_type": "code",
|
|
@@ -323,7 +384,7 @@
|
|
| 323 |
{
|
| 324 |
"cell_type": "code",
|
| 325 |
"execution_count": null,
|
| 326 |
-
"id": "
|
| 327 |
"metadata": {},
|
| 328 |
"outputs": [],
|
| 329 |
"source": []
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 197,
|
| 6 |
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
| 9 |
"source": [
|
| 10 |
"import csv\n",
|
| 11 |
"import tempfile\n",
|
| 12 |
+
"from functools import partial\n",
|
| 13 |
+
"import jax\n",
|
| 14 |
"import wandb\n",
|
| 15 |
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
| 16 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
|
|
|
| 44 |
},
|
| 45 |
{
|
| 46 |
"cell_type": "code",
|
| 47 |
+
"execution_count": 245,
|
| 48 |
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
| 49 |
"metadata": {},
|
| 50 |
"outputs": [],
|
| 51 |
"source": [
|
| 52 |
"with open('samples.csv', newline='', encoding='utf8') as f:\n",
|
| 53 |
+
" reader = csv.DictReader(f)\n",
|
| 54 |
+
" samples = []\n",
|
| 55 |
" for row in reader:\n",
|
| 56 |
+
" samples.append(row)"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": 246,
|
| 62 |
+
"id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [
|
| 65 |
+
{
|
| 66 |
+
"data": {
|
| 67 |
+
"text/plain": [
|
| 68 |
+
"101"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
"execution_count": 246,
|
| 72 |
+
"metadata": {},
|
| 73 |
+
"output_type": "execute_result"
|
| 74 |
+
}
|
| 75 |
+
],
|
| 76 |
+
"source": [
|
| 77 |
+
"len(samples)"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 248,
|
| 83 |
+
"id": "2ea0b166-a20c-4d78-bffb-b792ca512d17",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [
|
| 86 |
+
{
|
| 87 |
+
"data": {
|
| 88 |
+
"text/plain": [
|
| 89 |
+
"104"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
"execution_count": 248,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"output_type": "execute_result"
|
| 95 |
+
}
|
| 96 |
+
],
|
| 97 |
+
"source": [
|
| 98 |
+
"samples_to_add = ['empty'] * (-len(samples) % 8)\n",
|
| 99 |
+
"samples.extend(samples_to_add)\n",
|
| 100 |
+
"len(samples)"
|
| 101 |
]
|
| 102 |
},
|
| 103 |
{
|
| 104 |
"cell_type": "code",
|
| 105 |
"execution_count": null,
|
| 106 |
+
"id": "a2c629e9-1a82-40c6-a260-ca1780c19a2e",
|
| 107 |
+
"metadata": {},
|
| 108 |
+
"outputs": [],
|
| 109 |
+
"source": [
|
| 110 |
+
"api = wandb.Api()"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"execution_count": 204,
|
| 116 |
"id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
|
| 117 |
"metadata": {},
|
| 118 |
"outputs": [],
|
| 119 |
"source": [
|
| 120 |
+
"# TODO: iterate on runs\n",
|
| 121 |
"wandb_run = wandb_runs[0]\n",
|
| 122 |
+
"functions_pmapped = False"
|
| 123 |
]
|
| 124 |
},
|
| 125 |
{
|
|
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "code",
|
| 341 |
+
"execution_count": 207,
|
| 342 |
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
| 343 |
"metadata": {},
|
| 344 |
"outputs": [],
|
| 345 |
+
"source": [
|
| 346 |
+
"# function to generate encoded images\n",
|
| 347 |
+
"# we should generate this function only once per run\n",
|
| 348 |
+
"if not functions_pmapped:\n",
|
| 349 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 350 |
+
" def p_generate(tokenized_prompt, key, params):\n",
|
| 351 |
+
" return model.generate(\n",
|
| 352 |
+
" **tokenized_prompt,\n",
|
| 353 |
+
" do_sample=True,\n",
|
| 354 |
+
" num_beams=1,\n",
|
| 355 |
+
" prng_key=key,\n",
|
| 356 |
+
" params=params\n",
|
| 357 |
+
" )\n",
|
| 358 |
+
" \n",
|
| 359 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 360 |
+
" def p_decode(indices, params):\n",
|
| 361 |
+
" return vqgan.decode_code(indices, params=params)\n",
|
| 362 |
+
" \n",
|
| 363 |
+
" functions_pmapped = False"
|
| 364 |
+
]
|
| 365 |
},
|
| 366 |
{
|
| 367 |
"cell_type": "code",
|
|
|
|
| 384 |
{
|
| 385 |
"cell_type": "code",
|
| 386 |
"execution_count": null,
|
| 387 |
+
"id": "e79ac8f2-adc2-4a16-970c-dadcceadd566",
|
| 388 |
"metadata": {},
|
| 389 |
"outputs": [],
|
| 390 |
"source": []
|