"""Data-loading utilities for IR brightness-temperature gap filling.

Builds PyTorch/Lightning dataloaders from NetCDF files containing a ``Tb``
variable, filling small data gaps, resizing, and normalizing each frame.
"""

import glob
import random

import lightning.pytorch as pl
import numpy as np
import torch
import xarray as xr
from torch.utils.data import DataLoader


class GapFillDataset(torch.utils.data.Dataset):
    def __init__(self, ds):
        self.ds = ds
        self.resize_shape = (480, 1440)  # target (lat, lon) grid size
        self.input_horizon = 1  # consecutive timesteps per sample
        self.norm_range = [200, 350]  # Tb clamp range (brightness temperature, K)

        if "Tb" not in self.ds.data_vars:
            raise ValueError("Dataset must contain 'Tb' variable.")

        if len(self.ds.time) <= self.input_horizon:
            raise ValueError("Dataset too small for input horizon.")

    def __getitem__(self, idx):
        input_data = self.ds.isel(time=slice(idx, idx + self.input_horizon)).astype(
            np.float16
        )

        # Backfill small gaps (up to 5 cells) along each spatial dimension;
        # larger gaps are left for the model to fill.
        chunks = {"time": 1, "lat": 3298, "lon": 9896}
        input_data["Tb"] = (
            input_data["Tb"]
            .chunk(chunks)
            .bfill(dim="lat", limit=5)
            .bfill(dim="lon", limit=5)
        )

        input_data = self.resize_data(input_data, *self.resize_shape)

        # Mask the remaining gaps and zero them out in the input.
        tb_mask = input_data["Tb"].isnull()
        input_data = xr.where(tb_mask, 0, input_data["Tb"])

        # If the frame is entirely missing, fall back to a random valid sample.
        if tb_mask.values.all():
            return self.__getitem__(random.randint(0, len(self) - 1))

        # Shape (1, time, lat, lon): a leading channel dimension on the data,
        # with a matching boolean mask marking cells that still need filling.
        input_tensor = torch.from_numpy(input_data.values).half().unsqueeze(0)
        input_mask = torch.from_numpy(
            np.repeat(
                tb_mask.values[np.newaxis, :, :, :], input_tensor.shape[0], axis=0
            )
        )

        # Clamp to the physical range, then min-max normalize to [0, 1].
        input_tensor = torch.clamp(
            input_tensor, min=self.norm_range[0], max=self.norm_range[1]
        )
        input_tensor = (input_tensor - self.norm_range[0]) / (
            self.norm_range[1] - self.norm_range[0]
        )

        # Encode each timestamp as (year, day of year, hour, minute).
        timestamps = input_data.time
        years = torch.tensor(timestamps.dt.year.values, dtype=torch.int32)
        doy = torch.tensor(timestamps.dt.dayofyear.values, dtype=torch.int32)
        hours = torch.tensor(timestamps.dt.hour.values, dtype=torch.int32)
        minutes = torch.tensor(timestamps.dt.minute.values, dtype=torch.int32)

        return {
            "x": input_tensor,
            "x_mask": input_mask,
            "temporal_pos": [years, doy, hours, minutes],
        }

    def __len__(self):
        # Number of valid starting indices for a window of input_horizon steps;
        # consistent with the constructor guard (len(time) > input_horizon).
        return len(self.ds.time) - self.input_horizon

    @staticmethod
    def resize_data(multi_data, new_lat_size, new_lon_size):
        # Nearest-neighbour regridding onto a uniform lat/lon grid, done one
        # timestep at a time to keep peak memory low.
        new_lat = np.linspace(multi_data.lat.min(), multi_data.lat.max(), new_lat_size)
        new_lon = np.linspace(multi_data.lon.min(), multi_data.lon.max(), new_lon_size)
        if "time" in multi_data.dims:
            return xr.concat(
                [
                    multi_data.isel(time=t)
                    .interp(lat=new_lat, lon=new_lon, method="nearest")
                    .expand_dims(dim="time")
                    for t in range(len(multi_data["time"]))
                ],
                dim="time",
            )
        return multi_data.interp(lat=new_lat, lon=new_lon, method="nearest")

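# Usage sketch (illustrative only): build a tiny synthetic dataset with the
# expected "Tb" variable and inspect one sample. The toy grid and values
# below are assumptions chosen to keep the example fast, not real GOES data.
#
#   times = np.array(
#       ["2020-01-01T00:00", "2020-01-01T00:30"], dtype="datetime64[ns]"
#   )
#   toy = xr.Dataset(
#       {"Tb": (("time", "lat", "lon"), np.full((2, 48, 144), 250.0))},
#       coords={"time": times,
#               "lat": np.linspace(-60, 60, 48),
#               "lon": np.linspace(-180, 180, 144)},
#   )
#   sample = GapFillDataset(toy)[0]
#   print(sample["x"].shape)  # torch.Size([1, 1, 480, 1440])
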
class TrainingDataModule(pl.LightningDataModule):
    def __init__(self, data_path, max_files=None):
        super().__init__()
        self.data_path = data_path

        chunks = {"time": 1, "lat": 3298, "lon": 9896}
        file_batch_size = 10

        # Open the files in small batches so each open_mfdataset call stays
        # cheap, then merge the batches into a single dataset by coordinates.
        goes_files = sorted(glob.glob(self.data_path))[:max_files]
        datasets = []
        for i in range(0, len(goes_files), file_batch_size):
            batch_files = goes_files[i : i + file_batch_size]
            batch_ds = xr.open_mfdataset(
                batch_files,
                combine="by_coords",
                parallel=False,
                chunks=chunks,
                compat="override",
                coords="minimal",
                combine_attrs="override",
                engine="netcdf4",
            )
            datasets.append(batch_ds)

        ds = xr.combine_by_coords(
            datasets, compat="override", coords="minimal", combine_attrs="override"
        )
        # Drop sub-second jitter so timestamps from different files align.
        ds["time"] = ds["time"].dt.floor("s")

        # Chronological 80/20 train/validation split.
        n = len(ds.time)
        train_size = int(0.8 * n)

        self.train_dataset = GapFillDataset(ds.isel(time=slice(0, train_size)))
        self.val_dataset = GapFillDataset(ds.isel(time=slice(train_size, n)))

    def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, shuffle=False)

    def on_epoch_end(self):
        # Note: Lightning does not call this hook on a DataModule; invoke it
        # manually if GPU memory needs to be released between epochs.
        torch.cuda.empty_cache()

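# Sketch of how this DataModule plugs into a Lightning Trainer. GapFillModel
# is a placeholder name for whatever LightningModule consumes these batches;
# it is not defined in this file.
#
#   datamodule = TrainingDataModule("../data/IR/*/*.nc4", max_files=20)
#   trainer = pl.Trainer(max_epochs=1, precision="16-mixed")
#   trainer.fit(GapFillModel(), datamodule=datamodule)
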
def create_dataloader(
    data_path="../data/IR/*/*.nc4",
    max_files=None,
):
    """
    Create a dataloader for gap-fill inference.

    Args:
        data_path: Full path (including wildcards) to the data files.
        max_files: Maximum number of files to load. If None, all files are loaded.

    Returns:
        DataLoader: PyTorch DataLoader object.
    """
    # Initialize distributed processing once when running on multiple GPUs.
    if torch.cuda.device_count() > 1 and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")

    data_module = TrainingDataModule(data_path, max_files=max_files)
    data_module.setup(stage="test")

    # Prefer the held-out validation split for inference; fall back to the
    # training split if no validation dataset was built.
    if hasattr(data_module, "val_dataset") and data_module.val_dataset is not None:
        dataset = data_module.val_dataset
    else:
        dataset = data_module.train_dataset

    return DataLoader(dataset)

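if __name__ == "__main__":
    # Minimal smoke test: assumes .nc4 files matching the default glob exist
    # on disk; adjust data_path for your layout. Batches follow the dict
    # structure produced by GapFillDataset.__getitem__.
    loader = create_dataloader(max_files=10)
    batch = next(iter(loader))
    print(batch["x"].shape, batch["x_mask"].shape)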