Spaces:
Running
Running
feat: create a table
Browse files
dev/inference/wandb-backend.ipynb
CHANGED
|
@@ -46,7 +46,8 @@
|
|
| 46 |
"batch_size = 8\n",
|
| 47 |
"num_images = 128\n",
|
| 48 |
"top_k = 8\n",
|
| 49 |
-
"text_normalizer = TextNormalizer() if normalize_text else None"
|
|
|
|
| 50 |
]
|
| 51 |
},
|
| 52 |
{
|
|
@@ -95,8 +96,8 @@
|
|
| 95 |
" samples = []\n",
|
| 96 |
" for row in reader:\n",
|
| 97 |
" samples.append(row)\n",
|
| 98 |
-
" # make list multiple of batch_size by adding \
|
| 99 |
-
" samples_to_add = [{'Caption':
|
| 100 |
" samples.extend(samples_to_add)\n",
|
| 101 |
" # reshape\n",
|
| 102 |
" samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
|
|
@@ -388,7 +389,6 @@
|
|
| 388 |
" def p_clip(inputs):\n",
|
| 389 |
" logits = clip(**inputs).logits_per_image\n",
|
| 390 |
" return logits\n",
|
| 391 |
-
" scores = jax.nn.softmax(logits, axis=0).squeeze() \n",
|
| 392 |
" \n",
|
| 393 |
" functions_pmapped = False"
|
| 394 |
]
|
|
@@ -649,7 +649,8 @@
|
|
| 649 |
"outputs": [],
|
| 650 |
"source": [
|
| 651 |
"results = []\n",
|
| 652 |
-
"columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]"
|
|
|
|
| 653 |
]
|
| 654 |
},
|
| 655 |
{
|
|
@@ -660,12 +661,23 @@
|
|
| 660 |
"outputs": [],
|
| 661 |
"source": [
|
| 662 |
"for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
|
|
|
|
| 663 |
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
| 664 |
" top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
|
| 665 |
-
" top_scores = [
|
| 666 |
" results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
|
| 667 |
]
|
| 668 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
{
|
| 670 |
"cell_type": "code",
|
| 671 |
"execution_count": null,
|
|
|
|
| 46 |
"batch_size = 8\n",
|
| 47 |
"num_images = 128\n",
|
| 48 |
"top_k = 8\n",
|
| 49 |
+
"text_normalizer = TextNormalizer() if normalize_text else None\n",
|
| 50 |
+
"padding_item = 'NONE'"
|
| 51 |
]
|
| 52 |
},
|
| 53 |
{
|
|
|
|
| 96 |
" samples = []\n",
|
| 97 |
" for row in reader:\n",
|
| 98 |
" samples.append(row)\n",
|
| 99 |
+
" # make list multiple of batch_size by adding elements\n",
|
| 100 |
+
" samples_to_add = [{'Caption':padding_item, 'Theme':padding_item}] * (-len(samples) % batch_size)\n",
|
| 101 |
" samples.extend(samples_to_add)\n",
|
| 102 |
" # reshape\n",
|
| 103 |
" samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
|
|
|
|
| 389 |
" def p_clip(inputs):\n",
|
| 390 |
" logits = clip(**inputs).logits_per_image\n",
|
| 391 |
" return logits\n",
|
|
|
|
| 392 |
" \n",
|
| 393 |
" functions_pmapped = False"
|
| 394 |
]
|
|
|
|
| 649 |
"outputs": [],
|
| 650 |
"source": [
|
| 651 |
"results = []\n",
|
| 652 |
+
"columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
|
| 653 |
+
"logits = jax.device_get(logits)"
|
| 654 |
]
|
| 655 |
},
|
| 656 |
{
|
|
|
|
| 661 |
"outputs": [],
|
| 662 |
"source": [
|
| 663 |
"for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
|
| 664 |
+
" if sample['Caption'] == padding_item: continue\n",
|
| 665 |
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
| 666 |
" top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
|
| 667 |
+
" top_scores = [scores[x] for x in idx]\n",
|
| 668 |
" results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
|
| 669 |
]
|
| 670 |
},
|
| 671 |
+
{
|
| 672 |
+
"cell_type": "code",
|
| 673 |
+
"execution_count": null,
|
| 674 |
+
"id": "4bf40461-99d3-4d36-b7cc-e0129a3c9053",
|
| 675 |
+
"metadata": {},
|
| 676 |
+
"outputs": [],
|
| 677 |
+
"source": [
|
| 678 |
+
"table = wandb.Table(columns=columns, data=results)"
|
| 679 |
+
]
|
| 680 |
+
},
|
| 681 |
{
|
| 682 |
"cell_type": "code",
|
| 683 |
"execution_count": null,
|