import warnings

import torch
import torch.nn as nn
from super_image import RcanModel, RcanConfig


class CustomRcan(RcanModel):
    """RCAN variant without sub_mean / add_mean normalization.

    Useful for physical variables like wind components (u, v), where image
    normalization is not applicable. Any ``sub_mean`` / ``add_mean``
    submodules created by ``RcanModel.__init__`` are left in place but are
    never called by :meth:`forward`.
    """

    def forward(self, x):
        """Run head -> body (+ global residual) -> tail, skipping mean shifts.

        Args:
            x: Input batch. Presumably ``(batch, channels, H, W)`` matching the
                config's channel count -- TODO confirm against callers.

        Returns:
            The output of ``self.tail`` (the upsampled prediction).
        """
        x = self.head(x)
        res = self.body(x)
        res += x  # global residual connection around the body
        return self.tail(res)


def load_rcan(
    pretrained_repo="lschmidt/rcan-dsc",
    config_file="config.json",
    weight_file="pytorch_model_4x.pt",
    map_location="cpu",
    strict=False,
):
    """Build a :class:`CustomRcan` from a Hugging Face Hub repo and load weights.

    Args:
        pretrained_repo: Hub repository id holding the config and the weights.
        config_file: Config filename inside the repository.
        weight_file: State-dict filename inside the repository.
        map_location: Passed to ``torch.load`` to choose the target device.
        strict: Passed to ``Module.load_state_dict``. With the default
            ``False``, key mismatches are tolerated, but a warning is issued
            for any missing parameter that ``forward`` actually uses.

    Returns:
        The ``CustomRcan`` model with the downloaded weights loaded.
    """
    # Imported inside the function, so huggingface_hub is only needed when
    # actually loading a model.
    from huggingface_hub import hf_hub_download

    config, _ = RcanConfig.from_pretrained(pretrained_repo, config_filename=config_file)
    model = CustomRcan(config)

    state_dict_path = hf_hub_download(repo_id=pretrained_repo, filename=weight_file)
    # The file comes from the network: weights_only=True restricts unpickling
    # to tensors and plain containers, so a tampered checkpoint cannot run
    # arbitrary code during load.
    state_dict = torch.load(state_dict_path, map_location=map_location, weights_only=True)

    result = model.load_state_dict(state_dict, strict=strict)

    # The mean-shift layers are never used by CustomRcan.forward, so their
    # absence from the checkpoint is harmless. Anything else missing means
    # those parameters silently keep their initial values.
    unused_prefixes = ("sub_mean.", "add_mean.")
    missing = [k for k in result.missing_keys if not k.startswith(unused_prefixes)]
    if missing:
        warnings.warn(
            f"load_rcan: {len(missing)} parameter(s) were not found in "
            f"'{weight_file}' and keep their initial values "
            f"(first few: {missing[:5]}); "
            f"{len(result.unexpected_keys)} unexpected checkpoint key(s) were ignored.",
            stacklevel=2,
        )
    return model