"""Export a trained U-Net checkpoint into a ``save_pretrained`` directory.

The script builds a ``UNetTransformerModel`` from a ``UNetConfig``
(1 input channel, 3 output channels, 256-pixel image size), loads the
weights stored in ``unet_epoch20.pth`` into the wrapped ``model.model``
module on CPU, and then writes both the model and its config to
``brain_unet_hf/`` via their ``save_pretrained`` methods.
"""

import torch

from model import UNetConfig, UNetTransformerModel

# Checkpoint to read and directory to write; edit these to export another run.
WEIGHTS_PATH = "unet_epoch20.pth"
OUTPUT_DIR = "brain_unet_hf"

# Load config and model
config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
model = UNetTransformerModel(config)

# Load your existing model weights.
# The checkpoint is fed directly to load_state_dict, so it only needs to
# contain tensors: weights_only=True keeps torch.load from executing arbitrary
# pickled code if the file comes from an untrusted source. map_location="cpu"
# lets the export run on machines without the GPU the model was trained on.
state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
# The wrapper exposes the underlying network as ``model.model``; the
# checkpoint keys are loaded into that inner module, not the wrapper.
model.model.load_state_dict(state_dict)

# Save the model and config in HF-compatible format
model.save_pretrained(OUTPUT_DIR)
# NOTE(review): if UNetTransformerModel.save_pretrained already writes the
# config, this second call is redundant but harmless - confirm against model.py.
config.save_pretrained(OUTPUT_DIR)

print(f"✅ Saved to {OUTPUT_DIR}/")