# GAIA-v1 / downstream/gap_fill/gf_dataloader.py
# GAIA: A Foundation Model for Operational Atmospheric Dynamics
import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl
import xarray as xr
import numpy as np
import glob
import random
class GapFillDataset(torch.utils.data.Dataset):
    """Dataset of brightness-temperature (Tb) frames prepared for gap filling."""

    def __init__(self, ds):
        self.ds = ds
        self.resize_shape = (480, 1440)
        self.input_horizon = 1
        self.norm_range = [200, 350]  # bounds (K) used to clamp and normalise Tb
        if "Tb" not in self.ds.data_vars:
            raise ValueError("Dataset must contain 'Tb' variable.")
        if len(self.ds.time) <= self.input_horizon:
            raise ValueError("Dataset too small for input horizon.")

    def __getitem__(self, idx):
        input_data = self.ds.isel(time=slice(idx, idx + self.input_horizon)).astype(
            np.float16
        )
        # Fill small missing pockets
        chunks = {"time": 1, "lat": 3298, "lon": 9896}
        input_data["Tb"] = (
            input_data["Tb"]
            .chunk(chunks)
            .bfill(dim="lat", limit=5)
            .bfill(dim="lon", limit=5)
        )
        input_data = self.resize_data(input_data, *self.resize_shape)
        tb_mask = input_data["Tb"].isnull()
        input_data = xr.where(tb_mask, 0, input_data["Tb"])
        if tb_mask.values.all():
            # Frame is entirely missing: retry with a random valid sample index.
            return self.__getitem__(random.randint(0, len(self) - 1))
        input_tensor = torch.from_numpy(input_data.values).half().unsqueeze(0)
        input_mask = torch.from_numpy(
            np.repeat(
                tb_mask.values[np.newaxis, :, :, :], input_tensor.shape[0], axis=0
            )
        )
        # Clamp to the expected Tb range, then rescale to [0, 1].
        input_tensor = torch.clamp(
            input_tensor, min=self.norm_range[0], max=self.norm_range[1]
        )
        input_tensor = (input_tensor - self.norm_range[0]) / (
            self.norm_range[1] - self.norm_range[0]
        )
        timestamps = input_data.time  # an xarray DataArray with datetime64[ns]
        years = torch.tensor(timestamps.dt.year.values, dtype=torch.int32)
        doy = torch.tensor(timestamps.dt.dayofyear.values, dtype=torch.int32)
        hours = torch.tensor(timestamps.dt.hour.values, dtype=torch.int32)
        minutes = torch.tensor(timestamps.dt.minute.values, dtype=torch.int32)
        return {
            "x": input_tensor,
            "x_mask": input_mask,
            "temporal_pos": [years, doy, hours, minutes],
        }

    def __len__(self):
        return len(self.ds.time) - 1 - self.input_horizon

    @staticmethod
    def resize_data(multi_data, new_lat_size, new_lon_size):
        new_lat = np.linspace(multi_data.lat.min(), multi_data.lat.max(), new_lat_size)
        new_lon = np.linspace(multi_data.lon.min(), multi_data.lon.max(), new_lon_size)
        if "time" in multi_data.dims:
            return xr.concat(
                [
                    multi_data.isel(time=t)
                    .interp(lat=new_lat, lon=new_lon, method="nearest")
                    .expand_dims(dim="time")
                    for t in range(len(multi_data["time"]))
                ],
                dim="time",
            )
        else:
            return multi_data.interp(lat=new_lat, lon=new_lon, method="nearest")

class TrainingDataModule(pl.LightningDataModule):
    def __init__(self, data_path, max_files=None):
        super().__init__()
        self.data_path = data_path
        chunks = {"time": 1, "lat": 3298, "lon": 9896}
        file_batch_size = 10
        goes_files = sorted(glob.glob(self.data_path))[:max_files]
        datasets = []
        # Open the files in batches, then combine the batches into one dataset by coordinates.
        for i in range(0, len(goes_files), file_batch_size):
            batch_files = goes_files[i : i + file_batch_size]
            batch_ds = xr.open_mfdataset(
                batch_files,
                combine="by_coords",
                parallel=False,
                chunks=chunks,
                compat="override",
                coords="minimal",
                combine_attrs="override",
                engine="netcdf4",
            )
            datasets.append(batch_ds)
        ds = xr.combine_by_coords(
            datasets, compat="override", coords="minimal", combine_attrs="override"
        )
        ds["time"] = ds["time"].dt.floor("s")
        n = len(ds.time)
        train_size = int(0.8 * n)
        self.train_dataset = GapFillDataset(ds.isel(time=slice(0, train_size)))
        self.val_dataset = GapFillDataset(ds.isel(time=slice(train_size, n)))

    def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, shuffle=False)

    def on_epoch_end(self):
        torch.cuda.empty_cache()
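
# Illustrative use of TrainingDataModule with a Lightning Trainer (a sketch only,
# not this repo's training entrypoint). `GapFillModel` is a hypothetical placeholder
# for whatever LightningModule consumes these batches:
#
#     dm = TrainingDataModule(data_path="../data/IR/*/*.nc4", max_files=20)
#     model = GapFillModel()
#     trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=1)
#     trainer.fit(model, datamodule=dm)
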
def create_dataloader(
    data_path="../data/IR/*/*.nc4",
    max_files=None,
):
    """
    Create a dataloader for gapfill inference.

    Args:
        data_path: Full path (including wildcards) to the data files.
        max_files: Maximum number of files to load. If None, all files are loaded.

    Returns:
        DataLoader: PyTorch DataLoader object.
    """
    if torch.cuda.device_count() > 1 and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    data_module = TrainingDataModule(data_path, max_files=max_files)
    data_module.setup(stage="test")
    # Use validation dataset if available, else use training dataset.
    if hasattr(data_module, "val_dataset") and data_module.val_dataset is not None:
        dataset = data_module.val_dataset
    else:
        dataset = data_module.train_dataset
    loader = DataLoader(dataset)
    return loader
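
# Minimal smoke test (a sketch, assuming NetCDF files matching the default glob
# exist locally): build the inference dataloader over a couple of files and
# inspect the first batch produced by GapFillDataset.
if __name__ == "__main__":
    loader = create_dataloader(data_path="../data/IR/*/*.nc4", max_files=2)
    batch = next(iter(loader))
    # "x" is the normalised Tb frame, "x_mask" marks originally-missing pixels,
    # "temporal_pos" carries [year, day-of-year, hour, minute] tensors.
    print("x:", batch["x"].shape, batch["x"].dtype)
    print("x_mask:", batch["x_mask"].shape, batch["x_mask"].dtype)
    print("temporal_pos:", [t.tolist() for t in batch["temporal_pos"]])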