Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	style: reformat
Browse files
    	
        tools/inference/inference_pipeline.ipynb
    CHANGED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tools/inference/log_inference_samples.ipynb
    CHANGED
    
    | @@ -31,11 +31,14 @@ | |
| 31 | 
             
               "metadata": {},
         | 
| 32 | 
             
               "outputs": [],
         | 
| 33 | 
             
               "source": [
         | 
| 34 | 
            -
                "run_ids = [ | 
| 35 | 
            -
                "ENTITY, PROJECT =  | 
| 36 | 
            -
                "VQGAN_REPO, VQGAN_COMMIT_ID =  | 
| 37 | 
            -
                " | 
| 38 | 
            -
                " | 
|  | |
|  | |
|  | |
| 39 | 
             
                "add_clip_32 = False"
         | 
| 40 | 
             
               ]
         | 
| 41 | 
             
              },
         | 
| @@ -63,8 +66,8 @@ | |
| 63 | 
             
                "num_images = 128\n",
         | 
| 64 | 
             
                "top_k = 8\n",
         | 
| 65 | 
             
                "text_normalizer = TextNormalizer()\n",
         | 
| 66 | 
            -
                "padding_item =  | 
| 67 | 
            -
                "seed = random.randint(0, 2**32-1)\n",
         | 
| 68 | 
             
                "key = jax.random.PRNGKey(seed)\n",
         | 
| 69 | 
             
                "api = wandb.Api()"
         | 
| 70 | 
             
               ]
         | 
| @@ -100,12 +103,15 @@ | |
| 100 | 
             
                "def p_decode(indices, params):\n",
         | 
| 101 | 
             
                "    return vqgan.decode_code(indices, params=params)\n",
         | 
| 102 | 
             
                "\n",
         | 
|  | |
| 103 | 
             
                "@partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 104 | 
             
                "def p_clip16(inputs, params):\n",
         | 
| 105 | 
             
                "    logits = clip16(params=params, **inputs).logits_per_image\n",
         | 
| 106 | 
             
                "    return logits\n",
         | 
| 107 | 
             
                "\n",
         | 
|  | |
| 108 | 
             
                "if add_clip_32:\n",
         | 
|  | |
| 109 | 
             
                "    @partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 110 | 
             
                "    def p_clip32(inputs, params):\n",
         | 
| 111 | 
             
                "        logits = clip32(params=params, **inputs).logits_per_image\n",
         | 
| @@ -119,13 +125,13 @@ | |
| 119 | 
             
               "metadata": {},
         | 
| 120 | 
             
               "outputs": [],
         | 
| 121 | 
             
               "source": [
         | 
| 122 | 
            -
                "with open( | 
| 123 | 
             
                "    samples = [l.strip() for l in f.readlines()]\n",
         | 
| 124 | 
             
                "    # make list multiple of batch_size by adding elements\n",
         | 
| 125 | 
             
                "    samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
         | 
| 126 | 
             
                "    samples.extend(samples_to_add)\n",
         | 
| 127 | 
             
                "    # reshape\n",
         | 
| 128 | 
            -
                "    samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
         | 
| 129 | 
             
               ]
         | 
| 130 | 
             
              },
         | 
| 131 | 
             
              {
         | 
| @@ -138,9 +144,17 @@ | |
| 138 | 
             
                "def get_artifact_versions(run_id, latest_only=False):\n",
         | 
| 139 | 
             
                "    try:\n",
         | 
| 140 | 
             
                "        if latest_only:\n",
         | 
| 141 | 
            -
                "            return [ | 
|  | |
|  | |
|  | |
|  | |
| 142 | 
             
                "        else:\n",
         | 
| 143 | 
            -
                "            return api.artifact_versions( | 
|  | |
|  | |
|  | |
|  | |
| 144 | 
             
                "    except:\n",
         | 
| 145 | 
             
                "        return []"
         | 
| 146 | 
             
               ]
         | 
| @@ -153,7 +167,7 @@ | |
| 153 | 
             
               "outputs": [],
         | 
| 154 | 
             
               "source": [
         | 
| 155 | 
             
                "def get_training_config(run_id):\n",
         | 
| 156 | 
            -
                "    training_run = api.run(f | 
| 157 | 
             
                "    config = training_run.config\n",
         | 
| 158 | 
             
                "    return config"
         | 
| 159 | 
             
               ]
         | 
| @@ -168,8 +182,8 @@ | |
| 168 | 
             
                "# retrieve inference run details\n",
         | 
| 169 | 
             
                "def get_last_inference_version(run_id):\n",
         | 
| 170 | 
             
                "    try:\n",
         | 
| 171 | 
            -
                "        inference_run = api.run(f | 
| 172 | 
            -
                "        return inference_run.summary.get( | 
| 173 | 
             
                "    except:\n",
         | 
| 174 | 
             
                "        return None"
         | 
| 175 | 
             
               ]
         | 
| @@ -183,7 +197,6 @@ | |
| 183 | 
             
               "source": [
         | 
| 184 | 
             
                "# compile functions - needed only once per run\n",
         | 
| 185 | 
             
                "def pmap_model_function(model):\n",
         | 
| 186 | 
            -
                "    \n",
         | 
| 187 | 
             
                "    @partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 188 | 
             
                "    def _generate(tokenized_prompt, key, params):\n",
         | 
| 189 | 
             
                "        return model.generate(\n",
         | 
| @@ -195,7 +208,7 @@ | |
| 195 | 
             
                "            top_k=gen_top_k,\n",
         | 
| 196 | 
             
                "            top_p=gen_top_p\n",
         | 
| 197 | 
             
                "        )\n",
         | 
| 198 | 
            -
                " | 
| 199 | 
             
                "    return _generate"
         | 
| 200 | 
             
               ]
         | 
| 201 | 
             
              },
         | 
| @@ -222,13 +235,21 @@ | |
| 222 | 
             
                "training_config = get_training_config(run_id)\n",
         | 
| 223 | 
             
                "run = None\n",
         | 
| 224 | 
             
                "p_generate = None\n",
         | 
| 225 | 
            -
                "model_files = [ | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 226 | 
             
                "for artifact in artifact_versions:\n",
         | 
| 227 | 
            -
                "    print(f | 
| 228 | 
             
                "    version = int(artifact.version[1:])\n",
         | 
| 229 | 
             
                "    results16, results32 = [], []\n",
         | 
| 230 | 
            -
                "    columns = [ | 
| 231 | 
            -
                " | 
| 232 | 
             
                "    if latest_only:\n",
         | 
| 233 | 
             
                "        assert last_inference_version is None or version > last_inference_version\n",
         | 
| 234 | 
             
                "    else:\n",
         | 
| @@ -236,14 +257,23 @@ | |
| 236 | 
             
                "            # we should start from v0\n",
         | 
| 237 | 
             
                "            assert version == 0\n",
         | 
| 238 | 
             
                "        elif version <= last_inference_version:\n",
         | 
| 239 | 
            -
                "            print( | 
|  | |
|  | |
| 240 | 
             
                "        else:\n",
         | 
| 241 | 
             
                "            # check we are logging the correct version\n",
         | 
| 242 | 
             
                "            assert version == last_inference_version + 1\n",
         | 
| 243 | 
             
                "\n",
         | 
| 244 | 
             
                "    # start/resume corresponding run\n",
         | 
| 245 | 
             
                "    if run is None:\n",
         | 
| 246 | 
            -
                "        run = wandb.init( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 247 | 
             
                "\n",
         | 
| 248 | 
             
                "    # work in temporary directory\n",
         | 
| 249 | 
             
                "    with tempfile.TemporaryDirectory() as tmp:\n",
         | 
| @@ -264,64 +294,109 @@ | |
| 264 | 
             
                "\n",
         | 
| 265 | 
             
                "        # process one batch of captions\n",
         | 
| 266 | 
             
                "        for batch in tqdm(samples):\n",
         | 
| 267 | 
            -
                "            processed_prompts =  | 
|  | |
|  | |
|  | |
|  | |
| 268 | 
             
                "\n",
         | 
| 269 | 
             
                "            # repeat the prompts to distribute over each device and tokenize\n",
         | 
| 270 | 
             
                "            processed_prompts = processed_prompts * jax.device_count()\n",
         | 
| 271 | 
            -
                "            tokenized_prompt = tokenizer( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 272 | 
             
                "            tokenized_prompt = shard(tokenized_prompt)\n",
         | 
| 273 | 
             
                "\n",
         | 
| 274 | 
             
                "            # generate images\n",
         | 
| 275 | 
             
                "            images = []\n",
         | 
| 276 | 
            -
                "            pbar = tqdm( | 
|  | |
|  | |
|  | |
|  | |
| 277 | 
             
                "            for i in pbar:\n",
         | 
| 278 | 
             
                "                key, subkey = jax.random.split(key)\n",
         | 
| 279 | 
            -
                "                encoded_images = p_generate( | 
|  | |
|  | |
| 280 | 
             
                "                encoded_images = encoded_images.sequences[..., 1:]\n",
         | 
| 281 | 
             
                "                decoded_images = p_decode(encoded_images, vqgan_params)\n",
         | 
| 282 | 
            -
                "                decoded_images = decoded_images.clip(0 | 
|  | |
|  | |
| 283 | 
             
                "                for img in decoded_images:\n",
         | 
| 284 | 
            -
                "                    images.append( | 
|  | |
|  | |
| 285 | 
             
                "\n",
         | 
| 286 | 
            -
                "            def add_clip_results(results, processor, p_clip, clip_params) | 
| 287 | 
            -
                "                clip_inputs = processor( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 288 | 
             
                "                # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
         | 
| 289 | 
            -
                "                images_per_prompt_indices = np.asarray( | 
| 290 | 
            -
                " | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 291 | 
             
                "                clip_inputs = shard(clip_inputs)\n",
         | 
| 292 | 
             
                "                logits = p_clip(clip_inputs, clip_params)\n",
         | 
| 293 | 
             
                "                logits = logits.reshape(-1, num_images)\n",
         | 
| 294 | 
             
                "                top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
         | 
| 295 | 
             
                "                logits = jax.device_get(logits)\n",
         | 
| 296 | 
             
                "                # add to results table\n",
         | 
| 297 | 
            -
                "                for i, (idx, scores, sample) in enumerate( | 
| 298 | 
            -
                "                     | 
|  | |
|  | |
|  | |
| 299 | 
             
                "                    cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
         | 
| 300 | 
            -
                "                    top_images = [ | 
|  | |
|  | |
|  | |
| 301 | 
             
                "                    results.append([sample] + top_images)\n",
         | 
| 302 | 
            -
                " | 
| 303 | 
             
                "            # get clip scores\n",
         | 
| 304 | 
            -
                "            pbar.set_description( | 
| 305 | 
             
                "            add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
         | 
| 306 | 
            -
                " | 
| 307 | 
             
                "            # get clip 32 scores\n",
         | 
| 308 | 
             
                "            if add_clip_32:\n",
         | 
| 309 | 
            -
                "                pbar.set_description( | 
| 310 | 
             
                "                add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
         | 
| 311 | 
             
                "\n",
         | 
| 312 | 
             
                "            pbar.close()\n",
         | 
| 313 | 
             
                "\n",
         | 
| 314 | 
            -
                "                \n",
         | 
| 315 | 
            -
                "\n",
         | 
| 316 | 
             
                "    # log results\n",
         | 
| 317 | 
             
                "    table = wandb.Table(columns=columns, data=results16)\n",
         | 
| 318 | 
            -
                "    run.log({ | 
| 319 | 
             
                "    wandb.finish()\n",
         | 
| 320 | 
            -
                " | 
| 321 | 
            -
                "    if add_clip_32 | 
| 322 | 
            -
                "        run = wandb.init( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 323 | 
             
                "        table = wandb.Table(columns=columns, data=results32)\n",
         | 
| 324 | 
            -
                "        run.log({ | 
| 325 | 
             
                "        wandb.finish()\n",
         | 
| 326 | 
             
                "        run = None  # ensure we don't log on this run"
         | 
| 327 | 
             
               ]
         | 
|  | |
| 31 | 
             
               "metadata": {},
         | 
| 32 | 
             
               "outputs": [],
         | 
| 33 | 
             
               "source": [
         | 
| 34 | 
            +
                "run_ids = [\"63otg87g\"]\n",
         | 
| 35 | 
            +
                "ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\"  # used only for training run\n",
         | 
| 36 | 
            +
                "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
         | 
| 37 | 
            +
                "    \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
         | 
| 38 | 
            +
                "    \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
         | 
| 39 | 
            +
                ")\n",
         | 
| 40 | 
            +
                "latest_only = True  # log only latest or all versions\n",
         | 
| 41 | 
            +
                "suffix = \"\"  # mainly for duplicate inference runs with a deleted version\n",
         | 
| 42 | 
             
                "add_clip_32 = False"
         | 
| 43 | 
             
               ]
         | 
| 44 | 
             
              },
         | 
|  | |
| 66 | 
             
                "num_images = 128\n",
         | 
| 67 | 
             
                "top_k = 8\n",
         | 
| 68 | 
             
                "text_normalizer = TextNormalizer()\n",
         | 
| 69 | 
            +
                "padding_item = \"NONE\"\n",
         | 
| 70 | 
            +
                "seed = random.randint(0, 2 ** 32 - 1)\n",
         | 
| 71 | 
             
                "key = jax.random.PRNGKey(seed)\n",
         | 
| 72 | 
             
                "api = wandb.Api()"
         | 
| 73 | 
             
               ]
         | 
|  | |
| 103 | 
             
                "def p_decode(indices, params):\n",
         | 
| 104 | 
             
                "    return vqgan.decode_code(indices, params=params)\n",
         | 
| 105 | 
             
                "\n",
         | 
| 106 | 
            +
                "\n",
         | 
| 107 | 
             
                "@partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 108 | 
             
                "def p_clip16(inputs, params):\n",
         | 
| 109 | 
             
                "    logits = clip16(params=params, **inputs).logits_per_image\n",
         | 
| 110 | 
             
                "    return logits\n",
         | 
| 111 | 
             
                "\n",
         | 
| 112 | 
            +
                "\n",
         | 
| 113 | 
             
                "if add_clip_32:\n",
         | 
| 114 | 
            +
                "\n",
         | 
| 115 | 
             
                "    @partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 116 | 
             
                "    def p_clip32(inputs, params):\n",
         | 
| 117 | 
             
                "        logits = clip32(params=params, **inputs).logits_per_image\n",
         | 
|  | |
| 125 | 
             
               "metadata": {},
         | 
| 126 | 
             
               "outputs": [],
         | 
| 127 | 
             
               "source": [
         | 
| 128 | 
            +
                "with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
         | 
| 129 | 
             
                "    samples = [l.strip() for l in f.readlines()]\n",
         | 
| 130 | 
             
                "    # make list multiple of batch_size by adding elements\n",
         | 
| 131 | 
             
                "    samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
         | 
| 132 | 
             
                "    samples.extend(samples_to_add)\n",
         | 
| 133 | 
             
                "    # reshape\n",
         | 
| 134 | 
            +
                "    samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
         | 
| 135 | 
             
               ]
         | 
| 136 | 
             
              },
         | 
| 137 | 
             
              {
         | 
|  | |
| 144 | 
             
                "def get_artifact_versions(run_id, latest_only=False):\n",
         | 
| 145 | 
             
                "    try:\n",
         | 
| 146 | 
             
                "        if latest_only:\n",
         | 
| 147 | 
            +
                "            return [\n",
         | 
| 148 | 
            +
                "                api.artifact(\n",
         | 
| 149 | 
            +
                "                    type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
         | 
| 150 | 
            +
                "                )\n",
         | 
| 151 | 
            +
                "            ]\n",
         | 
| 152 | 
             
                "        else:\n",
         | 
| 153 | 
            +
                "            return api.artifact_versions(\n",
         | 
| 154 | 
            +
                "                type_name=\"bart_model\",\n",
         | 
| 155 | 
            +
                "                name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
         | 
| 156 | 
            +
                "                per_page=10000,\n",
         | 
| 157 | 
            +
                "            )\n",
         | 
| 158 | 
             
                "    except:\n",
         | 
| 159 | 
             
                "        return []"
         | 
| 160 | 
             
               ]
         | 
|  | |
| 167 | 
             
               "outputs": [],
         | 
| 168 | 
             
               "source": [
         | 
| 169 | 
             
                "def get_training_config(run_id):\n",
         | 
| 170 | 
            +
                "    training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
         | 
| 171 | 
             
                "    config = training_run.config\n",
         | 
| 172 | 
             
                "    return config"
         | 
| 173 | 
             
               ]
         | 
|  | |
| 182 | 
             
                "# retrieve inference run details\n",
         | 
| 183 | 
             
                "def get_last_inference_version(run_id):\n",
         | 
| 184 | 
             
                "    try:\n",
         | 
| 185 | 
            +
                "        inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
         | 
| 186 | 
            +
                "        return inference_run.summary.get(\"version\", None)\n",
         | 
| 187 | 
             
                "    except:\n",
         | 
| 188 | 
             
                "        return None"
         | 
| 189 | 
             
               ]
         | 
|  | |
| 197 | 
             
               "source": [
         | 
| 198 | 
             
                "# compile functions - needed only once per run\n",
         | 
| 199 | 
             
                "def pmap_model_function(model):\n",
         | 
|  | |
| 200 | 
             
                "    @partial(jax.pmap, axis_name=\"batch\")\n",
         | 
| 201 | 
             
                "    def _generate(tokenized_prompt, key, params):\n",
         | 
| 202 | 
             
                "        return model.generate(\n",
         | 
|  | |
| 208 | 
             
                "            top_k=gen_top_k,\n",
         | 
| 209 | 
             
                "            top_p=gen_top_p\n",
         | 
| 210 | 
             
                "        )\n",
         | 
| 211 | 
            +
                "\n",
         | 
| 212 | 
             
                "    return _generate"
         | 
| 213 | 
             
               ]
         | 
| 214 | 
             
              },
         | 
|  | |
| 235 | 
             
                "training_config = get_training_config(run_id)\n",
         | 
| 236 | 
             
                "run = None\n",
         | 
| 237 | 
             
                "p_generate = None\n",
         | 
| 238 | 
            +
                "model_files = [\n",
         | 
| 239 | 
            +
                "    \"config.json\",\n",
         | 
| 240 | 
            +
                "    \"flax_model.msgpack\",\n",
         | 
| 241 | 
            +
                "    \"merges.txt\",\n",
         | 
| 242 | 
            +
                "    \"special_tokens_map.json\",\n",
         | 
| 243 | 
            +
                "    \"tokenizer.json\",\n",
         | 
| 244 | 
            +
                "    \"tokenizer_config.json\",\n",
         | 
| 245 | 
            +
                "    \"vocab.json\",\n",
         | 
| 246 | 
            +
                "]\n",
         | 
| 247 | 
             
                "for artifact in artifact_versions:\n",
         | 
| 248 | 
            +
                "    print(f\"Processing artifact: {artifact.name}\")\n",
         | 
| 249 | 
             
                "    version = int(artifact.version[1:])\n",
         | 
| 250 | 
             
                "    results16, results32 = [], []\n",
         | 
| 251 | 
            +
                "    columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
         | 
| 252 | 
            +
                "\n",
         | 
| 253 | 
             
                "    if latest_only:\n",
         | 
| 254 | 
             
                "        assert last_inference_version is None or version > last_inference_version\n",
         | 
| 255 | 
             
                "    else:\n",
         | 
|  | |
| 257 | 
             
                "            # we should start from v0\n",
         | 
| 258 | 
             
                "            assert version == 0\n",
         | 
| 259 | 
             
                "        elif version <= last_inference_version:\n",
         | 
| 260 | 
            +
                "            print(\n",
         | 
| 261 | 
            +
                "                f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
         | 
| 262 | 
            +
                "            )\n",
         | 
| 263 | 
             
                "        else:\n",
         | 
| 264 | 
             
                "            # check we are logging the correct version\n",
         | 
| 265 | 
             
                "            assert version == last_inference_version + 1\n",
         | 
| 266 | 
             
                "\n",
         | 
| 267 | 
             
                "    # start/resume corresponding run\n",
         | 
| 268 | 
             
                "    if run is None:\n",
         | 
| 269 | 
            +
                "        run = wandb.init(\n",
         | 
| 270 | 
            +
                "            job_type=\"inference\",\n",
         | 
| 271 | 
            +
                "            entity=\"dalle-mini\",\n",
         | 
| 272 | 
            +
                "            project=\"dalle-mini\",\n",
         | 
| 273 | 
            +
                "            config=training_config,\n",
         | 
| 274 | 
            +
                "            id=f\"{run_id}-clip16{suffix}\",\n",
         | 
| 275 | 
            +
                "            resume=\"allow\",\n",
         | 
| 276 | 
            +
                "        )\n",
         | 
| 277 | 
             
                "\n",
         | 
| 278 | 
             
                "    # work in temporary directory\n",
         | 
| 279 | 
             
                "    with tempfile.TemporaryDirectory() as tmp:\n",
         | 
|  | |
| 294 | 
             
                "\n",
         | 
| 295 | 
             
                "        # process one batch of captions\n",
         | 
| 296 | 
             
                "        for batch in tqdm(samples):\n",
         | 
| 297 | 
            +
                "            processed_prompts = (\n",
         | 
| 298 | 
            +
                "                [text_normalizer(x) for x in batch]\n",
         | 
| 299 | 
            +
                "                if model.config.normalize_text\n",
         | 
| 300 | 
            +
                "                else list(batch)\n",
         | 
| 301 | 
            +
                "            )\n",
         | 
| 302 | 
             
                "\n",
         | 
| 303 | 
             
                "            # repeat the prompts to distribute over each device and tokenize\n",
         | 
| 304 | 
             
                "            processed_prompts = processed_prompts * jax.device_count()\n",
         | 
| 305 | 
            +
                "            tokenized_prompt = tokenizer(\n",
         | 
| 306 | 
            +
                "                processed_prompts,\n",
         | 
| 307 | 
            +
                "                return_tensors=\"jax\",\n",
         | 
| 308 | 
            +
                "                padding=\"max_length\",\n",
         | 
| 309 | 
            +
                "                truncation=True,\n",
         | 
| 310 | 
            +
                "                max_length=128,\n",
         | 
| 311 | 
            +
                "            ).data\n",
         | 
| 312 | 
             
                "            tokenized_prompt = shard(tokenized_prompt)\n",
         | 
| 313 | 
             
                "\n",
         | 
| 314 | 
             
                "            # generate images\n",
         | 
| 315 | 
             
                "            images = []\n",
         | 
| 316 | 
            +
                "            pbar = tqdm(\n",
         | 
| 317 | 
            +
                "                range(num_images // jax.device_count()),\n",
         | 
| 318 | 
            +
                "                desc=\"Generating Images\",\n",
         | 
| 319 | 
            +
                "                leave=True,\n",
         | 
| 320 | 
            +
                "            )\n",
         | 
| 321 | 
             
                "            for i in pbar:\n",
         | 
| 322 | 
             
                "                key, subkey = jax.random.split(key)\n",
         | 
| 323 | 
            +
                "                encoded_images = p_generate(\n",
         | 
| 324 | 
            +
                "                    tokenized_prompt, shard_prng_key(subkey), model_params\n",
         | 
| 325 | 
            +
                "                )\n",
         | 
| 326 | 
             
                "                encoded_images = encoded_images.sequences[..., 1:]\n",
         | 
| 327 | 
             
                "                decoded_images = p_decode(encoded_images, vqgan_params)\n",
         | 
| 328 | 
            +
                "                decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
         | 
| 329 | 
            +
                "                    (-1, 256, 256, 3)\n",
         | 
| 330 | 
            +
                "                )\n",
         | 
| 331 | 
             
                "                for img in decoded_images:\n",
         | 
| 332 | 
            +
                "                    images.append(\n",
         | 
| 333 | 
            +
                "                        Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
         | 
| 334 | 
            +
                "                    )\n",
         | 
| 335 | 
             
                "\n",
         | 
| 336 | 
            +
                "            def add_clip_results(results, processor, p_clip, clip_params):\n",
         | 
| 337 | 
            +
                "                clip_inputs = processor(\n",
         | 
| 338 | 
            +
                "                    text=batch,\n",
         | 
| 339 | 
            +
                "                    images=images,\n",
         | 
| 340 | 
            +
                "                    return_tensors=\"np\",\n",
         | 
| 341 | 
            +
                "                    padding=\"max_length\",\n",
         | 
| 342 | 
            +
                "                    max_length=77,\n",
         | 
| 343 | 
            +
                "                    truncation=True,\n",
         | 
| 344 | 
            +
                "                ).data\n",
         | 
| 345 | 
             
                "                # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
         | 
| 346 | 
            +
                "                images_per_prompt_indices = np.asarray(\n",
         | 
| 347 | 
            +
                "                    range(0, len(images), batch_size)\n",
         | 
| 348 | 
            +
                "                )\n",
         | 
| 349 | 
            +
                "                clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
         | 
| 350 | 
            +
                "                    list(\n",
         | 
| 351 | 
            +
                "                        clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
         | 
| 352 | 
            +
                "                        for i in range(batch_size)\n",
         | 
| 353 | 
            +
                "                    )\n",
         | 
| 354 | 
            +
                "                )\n",
         | 
| 355 | 
             
                "                clip_inputs = shard(clip_inputs)\n",
         | 
| 356 | 
             
                "                logits = p_clip(clip_inputs, clip_params)\n",
         | 
| 357 | 
             
                "                logits = logits.reshape(-1, num_images)\n",
         | 
| 358 | 
             
                "                top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
         | 
| 359 | 
             
                "                logits = jax.device_get(logits)\n",
         | 
| 360 | 
             
                "                # add to results table\n",
         | 
| 361 | 
            +
                "                for i, (idx, scores, sample) in enumerate(\n",
         | 
| 362 | 
            +
                "                    zip(top_scores, logits, batch)\n",
         | 
| 363 | 
            +
                "                ):\n",
         | 
| 364 | 
            +
                "                    if sample == padding_item:\n",
         | 
| 365 | 
            +
                "                        continue\n",
         | 
| 366 | 
             
                "                    cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
         | 
| 367 | 
            +
                "                    top_images = [\n",
         | 
| 368 | 
            +
                "                        wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
         | 
| 369 | 
            +
                "                        for x in idx\n",
         | 
| 370 | 
            +
                "                    ]\n",
         | 
| 371 | 
             
                "                    results.append([sample] + top_images)\n",
         | 
| 372 | 
            +
                "\n",
         | 
| 373 | 
             
                "            # get clip scores\n",
         | 
| 374 | 
            +
                "            pbar.set_description(\"Calculating CLIP 16 scores\")\n",
         | 
| 375 | 
             
                "            add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
         | 
| 376 | 
            +
                "\n",
         | 
| 377 | 
             
                "            # get clip 32 scores\n",
         | 
| 378 | 
             
                "            if add_clip_32:\n",
         | 
| 379 | 
            +
                "                pbar.set_description(\"Calculating CLIP 32 scores\")\n",
         | 
| 380 | 
             
                "                add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
         | 
| 381 | 
             
                "\n",
         | 
| 382 | 
             
                "            pbar.close()\n",
         | 
| 383 | 
             
                "\n",
         | 
|  | |
|  | |
| 384 | 
             
                "    # log results\n",
         | 
| 385 | 
             
                "    table = wandb.Table(columns=columns, data=results16)\n",
         | 
| 386 | 
            +
                "    run.log({\"Samples\": table, \"version\": version})\n",
         | 
| 387 | 
             
                "    wandb.finish()\n",
         | 
| 388 | 
            +
                "\n",
         | 
| 389 | 
            +
                "    if add_clip_32:\n",
         | 
| 390 | 
            +
                "        run = wandb.init(\n",
         | 
| 391 | 
            +
                "            job_type=\"inference\",\n",
         | 
| 392 | 
            +
                "            entity=\"dalle-mini\",\n",
         | 
| 393 | 
            +
                "            project=\"dalle-mini\",\n",
         | 
| 394 | 
            +
                "            config=training_config,\n",
         | 
| 395 | 
            +
                "            id=f\"{run_id}-clip32{suffix}\",\n",
         | 
| 396 | 
            +
                "            resume=\"allow\",\n",
         | 
| 397 | 
            +
                "        )\n",
         | 
| 398 | 
             
                "        table = wandb.Table(columns=columns, data=results32)\n",
         | 
| 399 | 
            +
                "        run.log({\"Samples\": table, \"version\": version})\n",
         | 
| 400 | 
             
                "        wandb.finish()\n",
         | 
| 401 | 
             
                "        run = None  # ensure we don't log on this run"
         | 
| 402 | 
             
               ]
         | 

