TorchGeo

The model weights were extracted with the following script:

import os
import hashlib

import torch
import segmentation_models_pytorch as smp


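# Download the original AI4G Flood checkpoint and keep only the U-Net weights,
# dropping the "model.model." prefix under which the checkpoint stores them.
url = "https://github.com/microsoft/ai4g-flood/raw/refs/heads/main/models/ai4g_sar_model.ckpt"
state_dict = torch.hub.load_state_dict_from_url(url, weights_only=False, map_location="cpu")["state_dict"]
state_dict = {k.replace("model.model.", ""): v for k, v in state_dict.items() if "model.model." in k}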
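# Rebuild the architecture: a U-Net with a MobileNetV2 encoder,
# 2 input channels (Sentinel-1 SAR) and 2 output classes.
model = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights=None,
    in_channels=2,
    classes=2,
)
# strict=True fails loudly if any key is missing or unexpected.
model.load_state_dict(state_dict, strict=True)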

# Save the weights, then append the first 8 hex digits of the file's MD5 hash to the filename.
filename = "unet_mobilenetv2_sentinel1_ai4g_flood.pth"
torch.save(model.state_dict(), filename)
with open(filename, "rb") as f:
    md5 = hashlib.md5(f.read()).hexdigest()[:8]
os.rename(filename, filename.replace(".pth", f"-{md5}.pth"))
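The two bookkeeping steps in the script, keeping only the checkpoint entries under the `model.model.` prefix and appending a short MD5 digest to the output filename, can be sketched with the standard library alone so they are easy to check without downloading anything. The helper names `strip_prefix` and `hashed_filename` and the toy checkpoint keys are hypothetical, chosen only for illustration; the real checkpoint's keys are not reproduced here.

```python
import hashlib


def strip_prefix(state_dict, prefix="model.model."):
    """Keep only entries whose key contains `prefix`, with `prefix` removed.

    Mirrors the dict comprehension in the script: `str.replace` removes
    every occurrence of the prefix, not only a leading one.
    """
    return {k.replace(prefix, ""): v for k, v in state_dict.items() if prefix in k}


def hashed_filename(filename, data, n=8):
    """Insert the first `n` hex digits of the MD5 of `data` before ".pth"."""
    digest = hashlib.md5(data).hexdigest()[:n]
    return filename.replace(".pth", f"-{digest}.pth")


# Toy checkpoint (hypothetical keys): one entry lies outside the prefixed model.
ckpt = {
    "model.model.encoder.features.0.0.weight": 1,
    "model.model.segmentation_head.0.bias": 2,
    "criterion.weight": 3,
}
print(strip_prefix(ckpt))
# {'encoder.features.0.0.weight': 1, 'segmentation_head.0.bias': 2}
print(hashed_filename("unet.pth", b"example bytes"))
```

Hashing the saved bytes, rather than the in-memory tensors, means the suffix identifies the exact file that was written, so a downloaded copy can be checked by recomputing the MD5 of the file.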