Spaces:
Running
Running
feat: add scoring
Browse files- dev/inference/wandb-backend.ipynb +338 -51
dev/inference/wandb-backend.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
@@ -10,7 +10,13 @@
|
|
10 |
"import csv\n",
|
11 |
"import tempfile\n",
|
12 |
"from functools import partial\n",
|
|
|
|
|
|
|
13 |
"import jax\n",
|
|
|
|
|
|
|
14 |
"import wandb\n",
|
15 |
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
16 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
@@ -30,6 +36,30 @@
|
|
30 |
"normalize_text = True"
|
31 |
]
|
32 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
{
|
34 |
"cell_type": "code",
|
35 |
"execution_count": null,
|
@@ -44,7 +74,18 @@
|
|
44 |
},
|
45 |
{
|
46 |
"cell_type": "code",
|
47 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
49 |
"metadata": {},
|
50 |
"outputs": [],
|
@@ -53,51 +94,32 @@
|
|
53 |
" reader = csv.DictReader(f)\n",
|
54 |
" samples = []\n",
|
55 |
" for row in reader:\n",
|
56 |
-
" samples.append(row)"
|
|
|
|
|
|
|
|
|
|
|
57 |
]
|
58 |
},
|
59 |
{
|
60 |
"cell_type": "code",
|
61 |
-
"execution_count":
|
62 |
"id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
|
63 |
"metadata": {},
|
64 |
-
"outputs": [
|
65 |
-
{
|
66 |
-
"data": {
|
67 |
-
"text/plain": [
|
68 |
-
"101"
|
69 |
-
]
|
70 |
-
},
|
71 |
-
"execution_count": 246,
|
72 |
-
"metadata": {},
|
73 |
-
"output_type": "execute_result"
|
74 |
-
}
|
75 |
-
],
|
76 |
"source": [
|
77 |
"len(samples)"
|
78 |
]
|
79 |
},
|
80 |
{
|
81 |
"cell_type": "code",
|
82 |
-
"execution_count":
|
83 |
-
"id": "
|
84 |
"metadata": {},
|
85 |
-
"outputs": [
|
86 |
-
{
|
87 |
-
"data": {
|
88 |
-
"text/plain": [
|
89 |
-
"104"
|
90 |
-
]
|
91 |
-
},
|
92 |
-
"execution_count": 248,
|
93 |
-
"metadata": {},
|
94 |
-
"output_type": "execute_result"
|
95 |
-
}
|
96 |
-
],
|
97 |
"source": [
|
98 |
-
"
|
99 |
-
"samples.extend(samples_to_add)\n",
|
100 |
-
"len(samples)"
|
101 |
]
|
102 |
},
|
103 |
{
|
@@ -112,7 +134,7 @@
|
|
112 |
},
|
113 |
{
|
114 |
"cell_type": "code",
|
115 |
-
"execution_count":
|
116 |
"id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
|
117 |
"metadata": {},
|
118 |
"outputs": [],
|
@@ -148,21 +170,11 @@
|
|
148 |
{
|
149 |
"cell_type": "code",
|
150 |
"execution_count": null,
|
151 |
-
"id": "
|
152 |
-
"metadata": {},
|
153 |
-
"outputs": [],
|
154 |
-
"source": [
|
155 |
-
"versions = sorted(versions, key=lambda x: int(x.version[1:]))"
|
156 |
-
]
|
157 |
-
},
|
158 |
-
{
|
159 |
-
"cell_type": "code",
|
160 |
-
"execution_count": null,
|
161 |
-
"id": "d77159df-1a16-4996-aafd-1df82c5a3509",
|
162 |
"metadata": {},
|
163 |
"outputs": [],
|
164 |
"source": [
|
165 |
-
"versions"
|
166 |
]
|
167 |
},
|
168 |
{
|
@@ -253,6 +265,8 @@
|
|
253 |
"source": [
|
254 |
"if last_version_inference is None:\n",
|
255 |
" assert version == 0\n",
|
|
|
|
|
256 |
"else:\n",
|
257 |
" assert version == last_version_inference + 1"
|
258 |
]
|
@@ -338,7 +352,17 @@
|
|
338 |
},
|
339 |
{
|
340 |
"cell_type": "code",
|
341 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
343 |
"metadata": {},
|
344 |
"outputs": [],
|
@@ -360,6 +384,12 @@
|
|
360 |
" def p_decode(indices, params):\n",
|
361 |
" return vqgan.decode_code(indices, params=params)\n",
|
362 |
" \n",
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
" functions_pmapped = False"
|
364 |
]
|
365 |
},
|
@@ -369,25 +399,282 @@
|
|
369 |
"id": "7a24b903-777b-4e3d-817c-00ed613a7021",
|
370 |
"metadata": {},
|
371 |
"outputs": [],
|
372 |
-
"source": [
|
|
|
|
|
|
|
|
|
|
|
373 |
},
|
374 |
{
|
375 |
"cell_type": "code",
|
376 |
"execution_count": null,
|
377 |
-
"id": "
|
378 |
"metadata": {},
|
379 |
"outputs": [],
|
380 |
"source": [
|
381 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
]
|
383 |
},
|
384 |
{
|
385 |
"cell_type": "code",
|
386 |
"execution_count": null,
|
387 |
-
"id": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
"metadata": {},
|
389 |
"outputs": [],
|
390 |
"source": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
}
|
392 |
],
|
393 |
"metadata": {
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
"id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
|
|
10 |
"import csv\n",
|
11 |
"import tempfile\n",
|
12 |
"from functools import partial\n",
|
13 |
+
"import random\n",
|
14 |
+
"import numpy as np\n",
|
15 |
+
"from PIL import Image\n",
|
16 |
"import jax\n",
|
17 |
+
"import jax.numpy as jnp\n",
|
18 |
+
"from flax.training.common_utils import shard, shard_prng_key\n",
|
19 |
+
"from flax.jax_utils import replicate\n",
|
20 |
"import wandb\n",
|
21 |
"from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
|
22 |
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
|
|
36 |
"normalize_text = True"
|
37 |
]
|
38 |
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": null,
|
42 |
+
"id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
|
43 |
+
"metadata": {},
|
44 |
+
"outputs": [],
|
45 |
+
"source": [
|
46 |
+
"batch_size = 8\n",
|
47 |
+
"num_images = 128\n",
|
48 |
+
"top_k = 8\n",
|
49 |
+
"text_normalizer = TextNormalizer() if normalize_text else None"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": null,
|
55 |
+
"id": "6a045827-3461-4499-8959-38d173bc4e5e",
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"seed = random.randint(0, 2**32-1)\n",
|
60 |
+
"key = jax.random.PRNGKey(seed)"
|
61 |
+
]
|
62 |
+
},
|
63 |
{
|
64 |
"cell_type": "code",
|
65 |
"execution_count": null,
|
|
|
74 |
},
|
75 |
{
|
76 |
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"id": "4927529a-8828-4150-bc76-e1b60d8dee62",
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [],
|
81 |
+
"source": [
|
82 |
+
"clip_params = replicate(clip.params)\n",
|
83 |
+
"vqgan_params = replicate(vqgan.params)"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": null,
|
89 |
"id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
|
90 |
"metadata": {},
|
91 |
"outputs": [],
|
|
|
94 |
" reader = csv.DictReader(f)\n",
|
95 |
" samples = []\n",
|
96 |
" for row in reader:\n",
|
97 |
+
" samples.append(row)\n",
|
98 |
+
" # make list multiple of batch_size by adding \"empty\"\n",
|
99 |
+
" samples_to_add = [{'Caption':'empty', 'Theme':'empty'}] * (-len(samples) % batch_size)\n",
|
100 |
+
" samples.extend(samples_to_add)\n",
|
101 |
+
" # reshape\n",
|
102 |
+
" samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
|
103 |
]
|
104 |
},
|
105 |
{
|
106 |
"cell_type": "code",
|
107 |
+
"execution_count": null,
|
108 |
"id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
|
109 |
"metadata": {},
|
110 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
"source": [
|
112 |
"len(samples)"
|
113 |
]
|
114 |
},
|
115 |
{
|
116 |
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"id": "c48525c9-447a-4430-81d7-4b699f545638",
|
119 |
"metadata": {},
|
120 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
"source": [
|
122 |
+
"samples[-1]"
|
|
|
|
|
123 |
]
|
124 |
},
|
125 |
{
|
|
|
134 |
},
|
135 |
{
|
136 |
"cell_type": "code",
|
137 |
+
"execution_count": null,
|
138 |
"id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
|
139 |
"metadata": {},
|
140 |
"outputs": [],
|
|
|
170 |
{
|
171 |
"cell_type": "code",
|
172 |
"execution_count": null,
|
173 |
+
"id": "ead44aee-52d5-4ca2-8984-c4d267d9e72a",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
"metadata": {},
|
175 |
"outputs": [],
|
176 |
"source": [
|
177 |
+
"versions[0].version"
|
178 |
]
|
179 |
},
|
180 |
{
|
|
|
265 |
"source": [
|
266 |
"if last_version_inference is None:\n",
|
267 |
" assert version == 0\n",
|
268 |
+
"elif last_version_inference >= version:\n",
|
269 |
+
" print(f'Version {version} has already been logged')\n",
|
270 |
"else:\n",
|
271 |
" assert version == last_version_inference + 1"
|
272 |
]
|
|
|
352 |
},
|
353 |
{
|
354 |
"cell_type": "code",
|
355 |
+
"execution_count": null,
|
356 |
+
"id": "320823c9-124a-4fc3-a12c-8c015a128285",
|
357 |
+
"metadata": {},
|
358 |
+
"outputs": [],
|
359 |
+
"source": [
|
360 |
+
"model_params = replicate(model.params)"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"cell_type": "code",
|
365 |
+
"execution_count": null,
|
366 |
"id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
|
367 |
"metadata": {},
|
368 |
"outputs": [],
|
|
|
384 |
" def p_decode(indices, params):\n",
|
385 |
" return vqgan.decode_code(indices, params=params)\n",
|
386 |
" \n",
|
387 |
+
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
388 |
+
" def p_clip(inputs):\n",
|
389 |
+
" logits = clip(**inputs).logits_per_image\n",
|
390 |
+
" return logits\n",
|
391 |
+
" scores = jax.nn.softmax(logits, axis=0).squeeze() \n",
|
392 |
+
" \n",
|
393 |
" functions_pmapped = False"
|
394 |
]
|
395 |
},
|
|
|
399 |
"id": "7a24b903-777b-4e3d-817c-00ed613a7021",
|
400 |
"metadata": {},
|
401 |
"outputs": [],
|
402 |
+
"source": [
|
403 |
+
"# TODO: loop over samples\n",
|
404 |
+
"batch = samples[0]\n",
|
405 |
+
"prompts = [x['Caption'] for x in batch]\n",
|
406 |
+
"processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts"
|
407 |
+
]
|
408 |
},
|
409 |
{
|
410 |
"cell_type": "code",
|
411 |
"execution_count": null,
|
412 |
+
"id": "d77aa785-dc05-4070-aba2-aa007524d20b",
|
413 |
"metadata": {},
|
414 |
"outputs": [],
|
415 |
"source": [
|
416 |
+
"processed_prompts"
|
417 |
+
]
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"cell_type": "code",
|
421 |
+
"execution_count": null,
|
422 |
+
"id": "95db38fb-8948-4814-98ae-c172ca7c6d0a",
|
423 |
+
"metadata": {},
|
424 |
+
"outputs": [],
|
425 |
+
"source": [
|
426 |
+
"repeated_prompts = processed_prompts * jax.device_count()"
|
427 |
+
]
|
428 |
+
},
|
429 |
+
{
|
430 |
+
"cell_type": "code",
|
431 |
+
"execution_count": null,
|
432 |
+
"id": "e948ba9e-3700-4e87-926f-580a10f3e5cd",
|
433 |
+
"metadata": {},
|
434 |
+
"outputs": [],
|
435 |
+
"source": [
|
436 |
+
"tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
437 |
+
"tokenized_prompt = shard(tokenized_prompt)"
|
438 |
+
]
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"cell_type": "code",
|
442 |
+
"execution_count": null,
|
443 |
+
"id": "30d96812-fc17-4acf-bb64-5fdb8d0cd313",
|
444 |
+
"metadata": {},
|
445 |
+
"outputs": [],
|
446 |
+
"source": [
|
447 |
+
"tokenized_prompt['input_ids'].shape"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": null,
|
453 |
+
"id": "92ea034b-2649-4d18-ab6d-877ed04ae5c4",
|
454 |
+
"metadata": {},
|
455 |
+
"outputs": [],
|
456 |
+
"source": [
|
457 |
+
"images = []\n",
|
458 |
+
"for i in range(num_images // jax.device_count()):\n",
|
459 |
+
" key, subkey = jax.random.split(key, 2)\n",
|
460 |
+
" \n",
|
461 |
+
" encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
|
462 |
+
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
463 |
+
" \n",
|
464 |
+
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
465 |
+
" decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
|
466 |
+
" \n",
|
467 |
+
" for img in decoded_images:\n",
|
468 |
+
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
|
469 |
+
" "
|
470 |
+
]
|
471 |
+
},
|
472 |
+
{
|
473 |
+
"cell_type": "code",
|
474 |
+
"execution_count": null,
|
475 |
+
"id": "84d52f30-44c9-4a74-9992-fb2578f19b90",
|
476 |
+
"metadata": {},
|
477 |
+
"outputs": [],
|
478 |
+
"source": [
|
479 |
+
"len(images)"
|
480 |
+
]
|
481 |
+
},
|
482 |
+
{
|
483 |
+
"cell_type": "code",
|
484 |
+
"execution_count": null,
|
485 |
+
"id": "beb594f9-5b91-47fe-98bd-41e68c6b1d73",
|
486 |
+
"metadata": {},
|
487 |
+
"outputs": [],
|
488 |
+
"source": [
|
489 |
+
"images[0]"
|
490 |
+
]
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"cell_type": "code",
|
494 |
+
"execution_count": null,
|
495 |
+
"id": "bb135190-64e5-44af-b416-e688b034da44",
|
496 |
+
"metadata": {},
|
497 |
+
"outputs": [],
|
498 |
+
"source": [
|
499 |
+
"images[1]"
|
500 |
+
]
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": null,
|
505 |
+
"id": "d78a0d92-72c2-4f82-a6ab-b3f5865dd863",
|
506 |
+
"metadata": {},
|
507 |
+
"outputs": [],
|
508 |
+
"source": [
|
509 |
+
"clip_inputs = processor(text=prompts, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data"
|
510 |
]
|
511 |
},
|
512 |
{
|
513 |
"cell_type": "code",
|
514 |
"execution_count": null,
|
515 |
+
"id": "89ff78a6-bfa4-44d9-ad66-07a4a68b4352",
|
516 |
+
"metadata": {},
|
517 |
+
"outputs": [],
|
518 |
+
"source": [
|
519 |
+
"# each shard will have one prompt\n",
|
520 |
+
"clip_inputs['input_ids'].shape"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"cell_type": "code",
|
525 |
+
"execution_count": null,
|
526 |
+
"id": "2cda8984-049c-4c87-96ad-7b0412750656",
|
527 |
+
"metadata": {},
|
528 |
+
"outputs": [],
|
529 |
+
"source": [
|
530 |
+
"# each shard needs to have the images corresponding to a specific prompt\n",
|
531 |
+
"clip_inputs['pixel_values'].shape"
|
532 |
+
]
|
533 |
+
},
|
534 |
+
{
|
535 |
+
"cell_type": "code",
|
536 |
+
"execution_count": null,
|
537 |
+
"id": "0a044e8f-be29-404b-b6c7-8f2395c5efc6",
|
538 |
+
"metadata": {},
|
539 |
+
"outputs": [],
|
540 |
+
"source": [
|
541 |
+
"images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
|
542 |
+
"images_per_prompt_indices"
|
543 |
+
]
|
544 |
+
},
|
545 |
+
{
|
546 |
+
"cell_type": "code",
|
547 |
+
"execution_count": null,
|
548 |
+
"id": "7a6c61b3-12e0-45d8-b39a-830288324d3d",
|
549 |
"metadata": {},
|
550 |
"outputs": [],
|
551 |
"source": []
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"cell_type": "code",
|
555 |
+
"execution_count": null,
|
556 |
+
"id": "7318e67e-4214-46f9-bf60-6d139d4bd00f",
|
557 |
+
"metadata": {},
|
558 |
+
"outputs": [],
|
559 |
+
"source": [
|
560 |
+
"# reorder so each shard will have correct images\n",
|
561 |
+
"clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"cell_type": "code",
|
566 |
+
"execution_count": null,
|
567 |
+
"id": "90c949a2-8e2a-4905-b6d4-92038f1704b8",
|
568 |
+
"metadata": {},
|
569 |
+
"outputs": [],
|
570 |
+
"source": [
|
571 |
+
"clip_inputs = shard(clip_inputs)"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"cell_type": "code",
|
576 |
+
"execution_count": null,
|
577 |
+
"id": "58fa836e-5ebb-45e7-af77-ab10646dfbfb",
|
578 |
+
"metadata": {},
|
579 |
+
"outputs": [],
|
580 |
+
"source": [
|
581 |
+
"logits = p_clip(clip_inputs)"
|
582 |
+
]
|
583 |
+
},
|
584 |
+
{
|
585 |
+
"cell_type": "code",
|
586 |
+
"execution_count": null,
|
587 |
+
"id": "fd7a3f91-3a1f-4a0a-8b3e-3c926cd367fb",
|
588 |
+
"metadata": {},
|
589 |
+
"outputs": [],
|
590 |
+
"source": [
|
591 |
+
"logits.shape"
|
592 |
+
]
|
593 |
+
},
|
594 |
+
{
|
595 |
+
"cell_type": "code",
|
596 |
+
"execution_count": null,
|
597 |
+
"id": "fa406db7-0a21-4e4b-9890-4c7aece4280c",
|
598 |
+
"metadata": {},
|
599 |
+
"outputs": [],
|
600 |
+
"source": [
|
601 |
+
"logits = logits.reshape(-1, num_images)"
|
602 |
+
]
|
603 |
+
},
|
604 |
+
{
|
605 |
+
"cell_type": "code",
|
606 |
+
"execution_count": null,
|
607 |
+
"id": "9c359a8c-2c27-4e68-8775-371857397723",
|
608 |
+
"metadata": {},
|
609 |
+
"outputs": [],
|
610 |
+
"source": [
|
611 |
+
"logits.shape"
|
612 |
+
]
|
613 |
+
},
|
614 |
+
{
|
615 |
+
"cell_type": "code",
|
616 |
+
"execution_count": null,
|
617 |
+
"id": "a56b9f28-dd91-4382-bc47-11e89fda1254",
|
618 |
+
"metadata": {},
|
619 |
+
"outputs": [],
|
620 |
+
"source": [
|
621 |
+
"logits"
|
622 |
+
]
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"cell_type": "code",
|
626 |
+
"execution_count": null,
|
627 |
+
"id": "0bed8167-0a6d-46c1-badf-8bdc20b93c31",
|
628 |
+
"metadata": {},
|
629 |
+
"outputs": [],
|
630 |
+
"source": [
|
631 |
+
"top_idx = logits.argsort()[:, -top_k:][..., ::-1]"
|
632 |
+
]
|
633 |
+
},
|
634 |
+
{
|
635 |
+
"cell_type": "code",
|
636 |
+
"execution_count": null,
|
637 |
+
"id": "188c5333-6b8c-4a17-8cc8-15651c77ef99",
|
638 |
+
"metadata": {},
|
639 |
+
"outputs": [],
|
640 |
+
"source": [
|
641 |
+
"len(images)"
|
642 |
+
]
|
643 |
+
},
|
644 |
+
{
|
645 |
+
"cell_type": "code",
|
646 |
+
"execution_count": null,
|
647 |
+
"id": "babd22b3-e773-467d-8bbb-f0323f57a44b",
|
648 |
+
"metadata": {},
|
649 |
+
"outputs": [],
|
650 |
+
"source": [
|
651 |
+
"results = []\n",
|
652 |
+
"columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]"
|
653 |
+
]
|
654 |
+
},
|
655 |
+
{
|
656 |
+
"cell_type": "code",
|
657 |
+
"execution_count": null,
|
658 |
+
"id": "75976c9f-dea5-48e3-8920-55a1bbfd91c2",
|
659 |
+
"metadata": {},
|
660 |
+
"outputs": [],
|
661 |
+
"source": [
|
662 |
+
"for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
|
663 |
+
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
664 |
+
" top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
|
665 |
+
" top_scores = [logits[x] for x in idx]\n",
|
666 |
+
" results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
|
667 |
+
]
|
668 |
+
},
|
669 |
+
{
|
670 |
+
"cell_type": "code",
|
671 |
+
"execution_count": null,
|
672 |
+
"id": "e1c04761-1016-47e9-925c-3a9ec6fec95a",
|
673 |
+
"metadata": {},
|
674 |
+
"outputs": [],
|
675 |
+
"source": [
|
676 |
+
"wandb.finish()"
|
677 |
+
]
|
678 |
}
|
679 |
],
|
680 |
"metadata": {
|