Spaces:
Running
Running
fix: style
Browse files- tools/inference/inference_pipeline.ipynb +46 -19
- tools/train/train.py +5 -5
tools/inference/inference_pipeline.ipynb
CHANGED
|
@@ -70,15 +70,15 @@
|
|
| 70 |
"# Model references\n",
|
| 71 |
"\n",
|
| 72 |
"# dalle-mini\n",
|
| 73 |
-
"DALLE_MODEL =
|
| 74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
| 75 |
"\n",
|
| 76 |
"# VQGAN model\n",
|
| 77 |
-
"VQGAN_REPO =
|
| 78 |
-
"VQGAN_COMMIT_ID =
|
| 79 |
"\n",
|
| 80 |
"# CLIP model\n",
|
| 81 |
-
"CLIP_REPO =
|
| 82 |
"CLIP_COMMIT_ID = None"
|
| 83 |
]
|
| 84 |
},
|
|
@@ -121,18 +121,28 @@
|
|
| 121 |
"import wandb\n",
|
| 122 |
"\n",
|
| 123 |
"# Load dalle-mini\n",
|
| 124 |
-
"if
|
| 125 |
" # wandb artifact\n",
|
| 126 |
" artifact = wandb.Api().artifact(DALLE_MODEL)\n",
|
| 127 |
" # we only download required files (no need for opt_state which is large)\n",
|
| 128 |
-
" model_files = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
" for f in model_files:\n",
|
| 130 |
-
" artifact.get_path(f).download(
|
| 131 |
-
" model = DalleBart.from_pretrained(
|
| 132 |
-
" tokenizer = AutoTokenizer.from_pretrained(
|
| 133 |
"else:\n",
|
| 134 |
" # local folder or 🤗 Hub\n",
|
| 135 |
-
" model = DalleBart.from_pretrained(
|
|
|
|
|
|
|
| 136 |
" tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
|
| 137 |
"\n",
|
| 138 |
"# Load VQGAN\n",
|
|
@@ -191,7 +201,7 @@
|
|
| 191 |
"from functools import partial\n",
|
| 192 |
"\n",
|
| 193 |
"# model inference\n",
|
| 194 |
-
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3,4))\n",
|
| 195 |
"def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
|
| 196 |
" return model.generate(\n",
|
| 197 |
" **tokenized_prompt,\n",
|
|
@@ -203,11 +213,13 @@
|
|
| 203 |
" top_p=top_p\n",
|
| 204 |
" )\n",
|
| 205 |
"\n",
|
|
|
|
| 206 |
"# decode images\n",
|
| 207 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 208 |
"def p_decode(indices, params):\n",
|
| 209 |
" return vqgan.decode_code(indices, params=params)\n",
|
| 210 |
"\n",
|
|
|
|
| 211 |
"# score images\n",
|
| 212 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 213 |
"def p_clip(inputs, params):\n",
|
|
@@ -235,7 +247,7 @@
|
|
| 235 |
"import random\n",
|
| 236 |
"\n",
|
| 237 |
"# create a random key\n",
|
| 238 |
-
"seed = random.randint(0, 2**32-1)\n",
|
| 239 |
"key = jax.random.PRNGKey(seed)"
|
| 240 |
]
|
| 241 |
},
|
|
@@ -287,7 +299,7 @@
|
|
| 287 |
},
|
| 288 |
"outputs": [],
|
| 289 |
"source": [
|
| 290 |
-
"prompt =
|
| 291 |
]
|
| 292 |
},
|
| 293 |
{
|
|
@@ -323,7 +335,13 @@
|
|
| 323 |
"repeated_prompts = [processed_prompt] * jax.device_count()\n",
|
| 324 |
"\n",
|
| 325 |
"# tokenize\n",
|
| 326 |
-
"tokenized_prompt = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
"tokenized_prompt"
|
| 328 |
]
|
| 329 |
},
|
|
@@ -408,12 +426,14 @@
|
|
| 408 |
" # get a new key\n",
|
| 409 |
" key, subkey = jax.random.split(key)\n",
|
| 410 |
" # generate images\n",
|
| 411 |
-
" encoded_images = p_generate(
|
|
|
|
|
|
|
| 412 |
" # remove BOS\n",
|
| 413 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 414 |
" # decode images\n",
|
| 415 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
| 416 |
-
" decoded_images = decoded_images.clip(0
|
| 417 |
" for img in decoded_images:\n",
|
| 418 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
| 419 |
]
|
|
@@ -436,7 +456,14 @@
|
|
| 436 |
"outputs": [],
|
| 437 |
"source": [
|
| 438 |
"# get clip scores\n",
|
| 439 |
-
"clip_inputs = processor(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
| 441 |
"logits = logits.squeeze().flatten()"
|
| 442 |
]
|
|
@@ -458,10 +485,10 @@
|
|
| 458 |
},
|
| 459 |
"outputs": [],
|
| 460 |
"source": [
|
| 461 |
-
"print(f
|
| 462 |
"for idx in logits.argsort()[::-1]:\n",
|
| 463 |
" display(images[idx])\n",
|
| 464 |
-
" print(f
|
| 465 |
]
|
| 466 |
}
|
| 467 |
],
|
|
|
|
| 70 |
"# Model references\n",
|
| 71 |
"\n",
|
| 72 |
"# dalle-mini\n",
|
| 73 |
+
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
|
| 74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
| 75 |
"\n",
|
| 76 |
"# VQGAN model\n",
|
| 77 |
+
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
| 78 |
+
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
|
| 79 |
"\n",
|
| 80 |
"# CLIP model\n",
|
| 81 |
+
"CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
|
| 82 |
"CLIP_COMMIT_ID = None"
|
| 83 |
]
|
| 84 |
},
|
|
|
|
| 121 |
"import wandb\n",
|
| 122 |
"\n",
|
| 123 |
"# Load dalle-mini\n",
|
| 124 |
+
"if \":\" in DALLE_MODEL:\n",
|
| 125 |
" # wandb artifact\n",
|
| 126 |
" artifact = wandb.Api().artifact(DALLE_MODEL)\n",
|
| 127 |
" # we only download required files (no need for opt_state which is large)\n",
|
| 128 |
+
" model_files = [\n",
|
| 129 |
+
" \"config.json\",\n",
|
| 130 |
+
" \"flax_model.msgpack\",\n",
|
| 131 |
+
" \"merges.txt\",\n",
|
| 132 |
+
" \"special_tokens_map.json\",\n",
|
| 133 |
+
" \"tokenizer.json\",\n",
|
| 134 |
+
" \"tokenizer_config.json\",\n",
|
| 135 |
+
" \"vocab.json\",\n",
|
| 136 |
+
" ]\n",
|
| 137 |
" for f in model_files:\n",
|
| 138 |
+
" artifact.get_path(f).download(\"model\")\n",
|
| 139 |
+
" model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
|
| 140 |
+
" tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
|
| 141 |
"else:\n",
|
| 142 |
" # local folder or 🤗 Hub\n",
|
| 143 |
+
" model = DalleBart.from_pretrained(\n",
|
| 144 |
+
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
|
| 145 |
+
" )\n",
|
| 146 |
" tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
|
| 147 |
"\n",
|
| 148 |
"# Load VQGAN\n",
|
|
|
|
| 201 |
"from functools import partial\n",
|
| 202 |
"\n",
|
| 203 |
"# model inference\n",
|
| 204 |
+
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
|
| 205 |
"def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
|
| 206 |
" return model.generate(\n",
|
| 207 |
" **tokenized_prompt,\n",
|
|
|
|
| 213 |
" top_p=top_p\n",
|
| 214 |
" )\n",
|
| 215 |
"\n",
|
| 216 |
+
"\n",
|
| 217 |
"# decode images\n",
|
| 218 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 219 |
"def p_decode(indices, params):\n",
|
| 220 |
" return vqgan.decode_code(indices, params=params)\n",
|
| 221 |
"\n",
|
| 222 |
+
"\n",
|
| 223 |
"# score images\n",
|
| 224 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 225 |
"def p_clip(inputs, params):\n",
|
|
|
|
| 247 |
"import random\n",
|
| 248 |
"\n",
|
| 249 |
"# create a random key\n",
|
| 250 |
+
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
| 251 |
"key = jax.random.PRNGKey(seed)"
|
| 252 |
]
|
| 253 |
},
|
|
|
|
| 299 |
},
|
| 300 |
"outputs": [],
|
| 301 |
"source": [
|
| 302 |
+
"prompt = \"a red T-shirt\""
|
| 303 |
]
|
| 304 |
},
|
| 305 |
{
|
|
|
|
| 335 |
"repeated_prompts = [processed_prompt] * jax.device_count()\n",
|
| 336 |
"\n",
|
| 337 |
"# tokenize\n",
|
| 338 |
+
"tokenized_prompt = tokenizer(\n",
|
| 339 |
+
" repeated_prompts,\n",
|
| 340 |
+
" return_tensors=\"jax\",\n",
|
| 341 |
+
" padding=\"max_length\",\n",
|
| 342 |
+
" truncation=True,\n",
|
| 343 |
+
" max_length=128,\n",
|
| 344 |
+
").data\n",
|
| 345 |
"tokenized_prompt"
|
| 346 |
]
|
| 347 |
},
|
|
|
|
| 426 |
" # get a new key\n",
|
| 427 |
" key, subkey = jax.random.split(key)\n",
|
| 428 |
" # generate images\n",
|
| 429 |
+
" encoded_images = p_generate(\n",
|
| 430 |
+
" tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
|
| 431 |
+
" )\n",
|
| 432 |
" # remove BOS\n",
|
| 433 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 434 |
" # decode images\n",
|
| 435 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
| 436 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
|
| 437 |
" for img in decoded_images:\n",
|
| 438 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
| 439 |
]
|
|
|
|
| 456 |
"outputs": [],
|
| 457 |
"source": [
|
| 458 |
"# get clip scores\n",
|
| 459 |
+
"clip_inputs = processor(\n",
|
| 460 |
+
" text=[prompt] * jax.device_count(),\n",
|
| 461 |
+
" images=images,\n",
|
| 462 |
+
" return_tensors=\"np\",\n",
|
| 463 |
+
" padding=\"max_length\",\n",
|
| 464 |
+
" max_length=77,\n",
|
| 465 |
+
" truncation=True,\n",
|
| 466 |
+
").data\n",
|
| 467 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
| 468 |
"logits = logits.squeeze().flatten()"
|
| 469 |
]
|
|
|
|
| 485 |
},
|
| 486 |
"outputs": [],
|
| 487 |
"source": [
|
| 488 |
+
"print(f\"Prompt: {prompt}\\n\")\n",
|
| 489 |
"for idx in logits.argsort()[::-1]:\n",
|
| 490 |
" display(images[idx])\n",
|
| 491 |
+
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
| 492 |
]
|
| 493 |
}
|
| 494 |
],
|
tools/train/train.py
CHANGED
|
@@ -219,9 +219,7 @@ class TrainingArguments:
|
|
| 219 |
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
| 220 |
},
|
| 221 |
)
|
| 222 |
-
weight_decay: float = field(
|
| 223 |
-
default=None, metadata={"help": "Weight decay."}
|
| 224 |
-
)
|
| 225 |
beta1: float = field(
|
| 226 |
default=0.9,
|
| 227 |
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
|
@@ -237,13 +235,15 @@ class TrainingArguments:
|
|
| 237 |
default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
|
| 238 |
)
|
| 239 |
block_size: int = field(
|
| 240 |
-
default=1024,
|
|
|
|
| 241 |
)
|
| 242 |
preconditioning_compute_steps: int = field(
|
| 243 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
| 244 |
)
|
| 245 |
skip_preconditioning_dim_size_gt: int = field(
|
| 246 |
-
default=4096,
|
|
|
|
| 247 |
)
|
| 248 |
optim_quantized: bool = field(
|
| 249 |
default=False,
|
|
|
|
| 219 |
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
| 220 |
},
|
| 221 |
)
|
| 222 |
+
weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
|
|
|
|
|
|
|
| 223 |
beta1: float = field(
|
| 224 |
default=0.9,
|
| 225 |
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
|
|
|
| 235 |
default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
|
| 236 |
)
|
| 237 |
block_size: int = field(
|
| 238 |
+
default=1024,
|
| 239 |
+
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
| 240 |
)
|
| 241 |
preconditioning_compute_steps: int = field(
|
| 242 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
| 243 |
)
|
| 244 |
skip_preconditioning_dim_size_gt: int = field(
|
| 245 |
+
default=4096,
|
| 246 |
+
metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
|
| 247 |
)
|
| 248 |
optim_quantized: bool = field(
|
| 249 |
default=False,
|