{ "cells": [ { "cell_type": "markdown", "id": "b0ed0610-9643-4e48-a05d-281703e16998", "metadata": {}, "source": [ "### Import Required Packages" ] }, { "cell_type": "code", "execution_count": 1, "id": "0c14de93-d962-4225-86c5-c2a44ad716ab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/utils/_pytree.py:185: FutureWarning: optree is installed but the version is too old to support PyTorch Dynamo in C++ pytree. C++ pytree support is disabled. Please consider upgrading optree using `python3 -m pip install --upgrade 'optree>=0.13.0'`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2025-04-25 13:20:34,206] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -c /tmp/tmpr6vshjte/test.c -o /tmp/tmpr6vshjte/test.o\n", "INFO:root:x86_64-linux-gnu-gcc /tmp/tmpr6vshjte/test.o -laio -o /tmp/tmpr6vshjte/a.out\n", "/usr/bin/ld: cannot find -laio: No such file or directory\n", "collect2: error: ld returned 1 exit status\n", "INFO:root:x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -c /tmp/tmp6obl3v8t/test.c -o /tmp/tmp6obl3v8t/test.o\n", "INFO:root:x86_64-linux-gnu-gcc /tmp/tmp6obl3v8t/test.o -L/usr/local/cuda -L/usr/local/cuda/lib64 -lcufile -o /tmp/tmp6obl3v8t/a.out\n", "INFO:root:x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -c /tmp/tmpkrdiql1w/test.c -o /tmp/tmpkrdiql1w/test.o\n", "INFO:root:x86_64-linux-gnu-gcc /tmp/tmpkrdiql1w/test.o -laio -o /tmp/tmpkrdiql1w/a.out\n", "/usr/bin/ld: cannot find -laio: No such file or directory\n", "collect2: error: ld returned 1 exit status\n" ] } ], "source": [ "import sys\n", "sys.path.append('../.')\n", "\n", "import torch\n", "import numpy as np\n", "from tqdm import tqdm\n", "import matplotlib.pyplot as plt\n", "import cartopy.crs as ccrs\n", "import cartopy.feature as cfeature\n", "\n", "from downstream.gap_fill.gf_dataloader import create_dataloader\n", "from downstream.gap_fill.lightning_presetup import GAIAGapFill" ] }, { "cell_type": "markdown", "id": "9a442874-1f90-41c1-90a1-c7a364bce1a0", "metadata": {}, "source": [ "# Load model from checkpoint" ] }, { "cell_type": "code", "execution_count": 2, "id": "2c41511f-d7e5-48c5-9ff3-31c93f5a85b1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/utils/_pytree.py:185: FutureWarning: optree is installed but the version is too old to support PyTorch Dynamo in C++ pytree. C++ pytree support is disabled. Please consider upgrading optree using `python3 -m pip install --upgrade 'optree>=0.13.0'`.\n", " warnings.warn(\n", "INFO:downstream.gap_fill.lightning_wrapper:Mask Ratio: 0.1\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config_path = \"../configs/gap_fill_config.yaml\"\n", "\n", "# Create model using config file\n", "model = GAIAGapFill(\n", " config_path=config_path,\n", ")\n", " \n", "model.configure_model()\n", "\n", "checkpoint_dir = \"../checkpoints/gaia-gapfill-v1.pt\"\n", "state_dict = torch.load(checkpoint_dir)\n", "model.load_state_dict(state_dict)" ] }, { "cell_type": "markdown", "id": "2ebc874a-cad4-4baf-8c9b-b751dc8f7f23", "metadata": {}, "source": [ "### Run Inference" ] }, { "cell_type": "code", "execution_count": 4, "id": "3cb316ff-d7cb-43fc-8f6c-dbe2e5b95508", "metadata": {}, "outputs": [], "source": [ "def run_inference(model, num_samples, device='cuda', dtype=torch.float32, mask_ratio=0.5, \n", " data_path=\"../data/IR/*/*.nc4\", **dataloader_kwargs):\n", " \"\"\"\n", " Run inference and plot ground truth, masked input, and reconstruction.\n", "\n", " Args:\n", " model: PyTorch model with .model.encoder and .model.decoder.\n", " num_samples: Number of samples to visualize.\n", " device: Device to use ('cuda' or 'cpu').\n", " dtype: Data type to cast inputs.\n", " mask_ratio: Ratio of input to mask.\n", " data_path: Full path to the data files (including wildcards).\n", " **dataloader_kwargs: Additional keyword arguments for create_dataloader.\n", " \"\"\"\n", " model.eval()\n", " model.to(device)\n", " \n", " loader = create_dataloader(\n", " data_path=data_path, \n", " **dataloader_kwargs\n", " )\n", " samples_collected = 0\n", "\n", " with torch.no_grad():\n", " for batch in tqdm(loader, desc=\"Running Inference\", unit=\"sample\"):\n", " x = batch['x'][:, :1, :1].to(device=device, dtype=dtype)\n", " x_mask = batch['x_mask'][:, :1, :1].to(device=device, dtype=dtype)\n", " _temporal_pos = [t.to(device) for t in batch['temporal_pos']] if 'temporal_pos' in batch else None\n", " temporal_pos = None # do not include temporal encodings at inference time\n", " \n", " latent, mask, ids_restore, _ = model.model.encoder(\n", " x, x_mask, mask_ratio=mask_ratio, temporal_pos=temporal_pos\n", " )\n", " pred = model.model.decoder(latent, ids_restore, temporal_pos=temporal_pos)\n", " pred_unpatched = model.model.encoder.unpatchify(pred)\n", " \n", " s, p, q = model.model.encoder.patch_embed.patch_size\n", " mask = model.model.encoder.unpatchify(mask.unsqueeze(-1).repeat(1, 1, s * p * q))\n", " \n", " for i in range(min(x.size(0), num_samples - samples_collected)):\n", " plot_reconstructions(\n", " ground_truth=x[i],\n", " masked_input=x[i] * (1 - mask[i]),\n", " reconstruction=pred_unpatched[i]\n", " )\n", " samples_collected += 1\n", " if samples_collected >= num_samples:\n", " return\n", "\n", "\n", "def plot_reconstructions(ground_truth, masked_input, reconstruction):\n", " \"\"\"\n", " Display ground truth, masked input, and reconstruction\n", " as global maps with borders and coastlines, flipped vertically.\n", " \"\"\"\n", " projection = ccrs.PlateCarree()\n", " borders = cfeature.BORDERS\n", "\n", " fig = plt.figure(figsize=(15, 5))\n", " images = [ground_truth, masked_input, reconstruction]\n", " titles = [\"Ground Truth\", \"Masked Input\", \"Reconstruction\"]\n", "\n", " for idx, (img, title) in enumerate(zip(images, titles), start=1):\n", " ax = fig.add_subplot(1, 3, idx, projection=projection)\n", " ax.set_extent([-180, 180, -60, 60], crs=projection)\n", "\n", " # get the numpy array and flip it vertically\n", " arr = img.squeeze().cpu().numpy()\n", " arr = np.flipud(arr) \n", "\n", " im = ax.imshow(\n", " arr,\n", " transform=projection,\n", " extent=[-180, 180, -60, 60],\n", " cmap='RdBu_r',\n", " vmin=0,\n", " vmax=1\n", " )\n", " ax.add_feature(borders, linewidth=0.5, edgecolor='black')\n", " ax.coastlines(resolution='50m', linewidth=0.5)\n", " ax.set_title(title)\n", " ax.axis('off')\n", "\n", " # shared colorbar\n", " cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])\n", " plt.colorbar(im, cax=cbar_ax, label='Normalized value')\n", "\n", " plt.tight_layout(rect=[0, 0, 0.9, 1])\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "id": "56bdc450-135a-4797-b622-9312208fd683", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Running Inference: 0%| | 0/6 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Running Inference: 0%| | 0/6 [00:07