import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, JaccardIndex, Metric
from torchmetrics.segmentation import MeanIoU
class MetricsAccuracy(Metric):
"""
Computes the Ovearall Accuracy
Args:
modalities (list): list of modalities used
num_classes (int): number of classes
save_results (bool): if True saves prediction in a csv file
get_classes (bool): if True returns the classwise F1 Score
"""
def __init__(
self,
modalities: list = [],
num_classes: int = 15,
save_results: bool = False,
get_classes: bool = False,
multilabel: bool = False,
):
super().__init__()
self.get_classes = get_classes
        task = "multilabel" if multilabel else "multiclass"
        # num_labels is required by torchmetrics when task == "multilabel"
        self.acc = Accuracy(task=task, num_classes=num_classes, num_labels=num_classes)
self.save_results = save_results
self.multiclass = not (multilabel)
if save_results:
self.results = {}
def update(self, pred, gt):
self.acc(pred, gt["label"])
if self.save_results:
for i, name in enumerate(gt["name"]):
self.results[name] = list(pred.cpu()[i].numpy())
def compute(self):
return {"OA": self.acc.compute()}
class MetricsMonoModal(Metric):
"""
Computes the micro, macro and weighted F1 Score for multi label classification
Args:
modalities (list): list of modalities used
num_classes (int): number of classes
save_results (bool): if True saves prediction in a csv file
get_classes (bool): if True returns the classwise F1 Score
"""
def __init__(
self,
modalities: list = [],
num_classes: int = 15,
save_results: bool = False,
get_classes: bool = False,
multilabel: bool = True,
):
super().__init__()
self.get_classes = get_classes
task = "multilabel" if multilabel else "multiclass"
self.f1 = F1Score(
task=task, average="none", num_labels=num_classes, num_classes=num_classes
)
self.f1_micro = F1Score(
task=task, average="micro", num_labels=num_classes, num_classes=num_classes
)
self.f1_weighted = F1Score(
task=task, average="weighted", num_labels=num_classes, num_classes=num_classes
)
self.save_results = save_results
self.multiclass = not (multilabel)
if save_results:
self.results = {}
def update(self, pred, gt):
if self.multiclass:
gt["label"] = gt["label"].argmax(dim=1)
self.f1(pred, gt["label"])
self.f1_micro(pred, gt["label"])
self.f1_weighted(pred, gt["label"])
if self.save_results:
for i, name in enumerate(gt["name"]):
self.results[name] = list(pred.cpu()[i].numpy())
    def compute(self):
        f1 = self.f1.compute()
        out = {
            "F1_Score_macro": sum(f1) / len(f1),
            "F1_Score_micro": self.f1_micro.compute(),
            "F1_Score_weighted": self.f1_weighted.compute(),
        }
        if self.get_classes:
            for i in range(len(f1)):
                out["_".join(["F1_classe", str(i)])] = f1[i]
        if self.save_results:
            out["results"] = self.results
        return out
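# A minimal multi-label usage sketch for MetricsMonoModal (shapes assumed from the
# torchmetrics multilabel F1Score contract; the gt keys mirror update() above):
#
#   metric = MetricsMonoModal(num_classes=15, get_classes=True)
#   pred = torch.rand(4, 15)                                     # per-label probabilities
#   gt = {"label": torch.randint(0, 2, (4, 15)), "name": ["a", "b", "c", "d"]}
#   metric.update(pred, gt)
#   out = metric.compute()                                       # macro/micro/weighted F1 + per-class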
class NoMetrics(Metric):
"""
Computes no metrics or saves a batch of reconstruction to visualise them
Args:
save_reconstructs (bool): if True saves a batch of reconstructions
modalities (list): list of modalities used
save_dir (str): where to save reconstructions
"""
def __init__(
self,
save_reconstructs: bool = False,
modalities: list = [],
save_dir: str = "",
):
super().__init__()
self.save_dir = save_dir
self.save_recons = save_reconstructs
self.modalities = modalities
if self.save_recons:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
self.saves = {}
for modality in self.modalities:
self.saves[modality] = []
self.saves["_".join(["gt", modality])] = []
    def update(self, pred, gt):
        if self.save_recons:
            recons, _ = pred
            for modality in self.modalities:
                if modality == "aerial":
                    preds = recons["_".join(["reconstruct", modality])]
                    # Aerial reconstructions are compared against a fixed 300x300 crop
                    target = gt[modality][:, :, :300, :300]
                else:
                    # Masked modalities return (reconstruction, mask); keep only the
                    # ground-truth entries selected by the mask
                    preds, mask = recons["_".join(["reconstruct", modality])]
                    target = gt[modality][mask[:, 0], mask[:, 1]]
                # Save one random (reconstruction, ground truth) pair per batch
                indice = torch.randint(0, len(preds), (1,)).item()
                self.saves[modality].append(preds[indice])
                self.saves["_".join(["gt", modality])].append(target[indice])
def compute(self):
if self.save_recons:
for key in self.saves.keys():
for i, tensor in enumerate(self.saves[key]):
torch.save(tensor.cpu(), self.save_dir + key + str(i) + ".pt")
return {}
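# A sketch of the prediction structure NoMetrics.update expects, inferred from the
# unpacking above (all shapes and key names here are illustrative assumptions):
#
#   recons = {"reconstruct_aerial": torch.randn(2, 3, 300, 300)}
#   gt = {"aerial": torch.randn(2, 3, 512, 512)}
#   metric = NoMetrics(save_reconstructs=True, modalities=["aerial"], save_dir="recons/")
#   metric.update((recons, None), gt)                            # stores one random pair
#   metric.compute()                                             # writes .pt files, returns {}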
class MetricsContrastif(Metric):
"""
Computes metrics for contrastive. Given embeddings for all tokens, we compute the cosine similarity matrix.
The metric computed is the accuracy of the M -1 minimum distances of each line (except diagonal of course)
being the same token across other modalities with M the number of modalities.
Args:
modalities (list): list of modalities used
"""
def __init__(
self,
modalities: list = [],
):
super().__init__()
self.modalities = modalities
self.n_k = len(self.modalities)
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
for i in range(len(modalities)):
self.add_state(modalities[i], default=torch.tensor(0.0), dist_reduce_fx="sum")
    def update(self, logits):
        size = len(logits) // self.n_k
        # Row r of labels lists the column indices of token r in every modality block
        labels = torch.arange(size).unsqueeze(1)
        labels = torch.cat([labels + i * len(labels) for i in range(self.n_k)], dim=1)
        labels = torch.cat([labels for _ in range(self.n_k)]).to(logits.device)
        for i in range(self.n_k):
            # Top n_k similarities per row of modality i's block
            _, top_indices = torch.topk(
                logits[i * size : (i + 1) * size], k=self.n_k, dim=1, largest=True
            )
            # Count top-k entries that hit the correct token, drop the trivial
            # self-match, and normalise by the n_k - 1 cross-modal candidates
            hits = torch.sum(
                torch.tensor(
                    [
                        top_indices[row, j] in labels[row]
                        for row in range(top_indices.size(0))
                        for j in range(self.n_k)
                    ]
                )
            )
            current = getattr(self, self.modalities[i])
            setattr(
                self, self.modalities[i], current + (hits - len(top_indices)) / (self.n_k - 1)
            )
        self.count += len(logits)
    def compute(self):
        out = {}
        for i in range(len(self.modalities)):
            out["_".join(["acc", self.modalities[i]])] = (
                getattr(self, self.modalities[i]) / self.count
            )
        return out
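# A usage sketch for MetricsContrastif, assuming logits is the full (M*B) x (M*B)
# cosine-similarity matrix with modality blocks stacked along both axes:
#
#   metric = MetricsContrastif(modalities=["aerial", "s2"])      # M = 2 modalities
#   emb = F.normalize(torch.randn(2 * 8, 32), dim=1)             # B = 8 tokens per modality
#   logits = emb @ emb.T                                         # (16, 16) similarity matrix
#   metric.update(logits)
#   print(metric.compute())                                      # {"acc_aerial": ..., "acc_s2": ...}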
class MetricsContrastifMulti(Metric):
"""
Computes metrics for contrastive. Given embeddings for all tokens, we compute the cosine similarity matrix.
The metric computed is the accuracy of the M -1 minimum distances of each line (except diagonal of course)
being the same token across other modalities with M the number of modalities.
Args:
modalities (list): list of modalities used
"""
def __init__(
self,
modalities: dict = {},
):
super().__init__()
self.modalities = modalities
for dataset in self.modalities.keys():
self.add_state(dataset + "_count", default=torch.tensor(0), dist_reduce_fx="sum")
for i in range(len(modalities[dataset])):
self.add_state(
dataset + "_" + modalities[dataset][i],
default=torch.tensor(0.0),
dist_reduce_fx="sum",
)
    def update(self, logits, dataset):
        modalities = self.modalities[dataset]
        n_modalities = len(modalities)
        size = len(logits) // n_modalities
        # Row r of labels lists the column indices of token r in every modality block
        labels = torch.arange(size).unsqueeze(1)
        labels = torch.cat([labels + i * len(labels) for i in range(n_modalities)], dim=1)
        labels = torch.cat([labels for _ in range(n_modalities)]).to(logits.device)
        for i in range(n_modalities):
            _, top_indices = torch.topk(
                logits[i * size : (i + 1) * size], k=n_modalities, dim=1, largest=True
            )
            # Count top-k entries that hit the correct token, drop the trivial
            # self-match, and normalise by the n_modalities - 1 cross-modal candidates
            hits = torch.sum(
                torch.tensor(
                    [
                        top_indices[row, j] in labels[row]
                        for row in range(top_indices.size(0))
                        for j in range(n_modalities)
                    ]
                )
            )
            state_name = dataset + "_" + modalities[i]
            setattr(
                self,
                state_name,
                getattr(self, state_name) + (hits - len(top_indices)) / (n_modalities - 1),
            )
        setattr(self, dataset + "_count", getattr(self, dataset + "_count") + len(logits))
    def compute(self):
        out = {}
        for dataset in self.modalities.keys():
            for i in range(len(self.modalities[dataset])):
                out["_".join(["acc", dataset, self.modalities[dataset][i]])] = (
                    getattr(self, dataset + "_" + self.modalities[dataset][i])
                    / getattr(self, dataset + "_count")
                )
        return out
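# MetricsContrastifMulti mirrors MetricsContrastif per dataset; a sketch with an
# assumed dataset name "ds1":
#
#   metric = MetricsContrastifMulti(modalities={"ds1": ["aerial", "s2"]})
#   emb = F.normalize(torch.randn(2 * 8, 32), dim=1)
#   metric.update(emb @ emb.T, dataset="ds1")
#   print(metric.compute())                                      # {"acc_ds1_aerial": ..., "acc_ds1_s2": ...}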
class MetricsSemSeg(Metric):
"""
Computes mIoU for semantic segmentation
Args:
modalities (list): list of modalities used
num_classes (int): number of classes
save_results (bool): if True saves prediction in a csv file
get_classes (bool): if True returns the classwise F1 Score
"""
def __init__(
self,
modalities: list = [],
num_classes: int = 15,
save_results: bool = False,
get_classes: bool = False,
):
super().__init__()
self.modality = modalities[0]
self.num_classes = num_classes
self.get_classes = get_classes
self.miou = MeanIoU(num_classes=num_classes, per_class=True)
self.save_results = save_results
if save_results:
self.results = {}
def update(self, pred, gt):
label = gt["label"].flatten(0, 1).long()
self.miou(
torch.nn.functional.one_hot(
pred.flatten(2, 3).permute(0, 2, 1).flatten(0, 1).argmax(dim=1),
num_classes=self.num_classes,
),
torch.nn.functional.one_hot(label, num_classes=self.num_classes),
)
if self.save_results:
for i, name in enumerate(gt["name"]):
self.results[name] = list(pred.cpu()[i].numpy())
    def compute(self):
        miou = self.miou.compute()
        if self.get_classes:
            out = {"mIoU": sum(miou) / len(miou)}
            for i in range(len(miou[:-1])):
                out["_".join(["IoU", str(i)])] = miou[i]
            return out
        out = {
            "IoU": miou[1].item(),  # IoU of the foreground class
        }
        if self.save_results:
            out["results"] = self.results
        return out
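# A usage sketch for MetricsSemSeg; the single flatten(0, 1) applied to gt["label"]
# in update() suggests labels arrive already flattened to (N, H*W), which is
# assumed here (not confirmed by the original file):
#
#   metric = MetricsSemSeg(modalities=["aerial"], num_classes=15)
#   pred = torch.randn(2, 15, 64, 64)                            # (N, K, H, W) logits
#   gt = {"label": torch.randint(0, 15, (2, 64 * 64)), "name": ["a", "b"]}
#   metric.update(pred, gt)
#   print(metric.compute())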
class MetricsSemSegJ(Metric):
"""
Computes the IoU for binary segmentation
Args:
modalities (list): list of modalities used
num_classes (int): number of classes
save_results (bool): if True saves prediction in a csv file
get_classes (bool): if True returns the classwise F1 Score
"""
def __init__(
self,
modalities: list = [],
num_classes: int = 15,
save_results: bool = False,
get_classes: bool = False,
save_dir: str = "",
):
super().__init__()
self.modality = modalities[0]
self.num_classes = num_classes
self.get_classes = get_classes
self.miou = JaccardIndex(task="multiclass", num_classes=2, ignore_index=-1)
self.save_results = save_results
if save_results:
self.save_dir = save_dir
if not os.path.exists(save_dir):
os.makedirs(save_dir)
def update(self, pred, gt):
self.miou(
pred.flatten(2, 3).permute(0, 2, 1).flatten(0, 1).argmax(dim=1),
gt["label"].flatten(1, 2).flatten(0, 1).long(),
)
if self.save_results:
for i, name in enumerate(gt["name"]):
np.save(self.save_dir + str(name) + ".npy", pred.cpu()[i].numpy())
np.save(self.save_dir + str(name) + "_gt.npy", gt["label"].cpu()[i].numpy())
    def compute(self):
        # JaccardIndex with the default macro averaging returns a single scalar,
        # so there is no per-class breakdown to report even when get_classes is set
        return {"IoU": self.miou.compute()}
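# A usage sketch for MetricsSemSegJ (binary segmentation; shapes assumed from the
# flatten calls in update()):
#
#   metric = MetricsSemSegJ(modalities=["aerial"])
#   pred = torch.randn(2, 2, 64, 64)                             # (N, 2, H, W) logits
#   gt = {"label": torch.randint(0, 2, (2, 64, 64)), "name": ["a", "b"]}
#   metric.update(pred, gt)
#   print(metric.compute())                                      # {"IoU": tensor(...)}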
class BaseMeter(object):
    """Base class for meter-style metrics (distinct from ``torchmetrics.Metric``
    imported above).

    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
    """

    def reset(self):
        pass

    def add(self):
        pass

    def value(self):
        pass
class ConfusionMatrix(BaseMeter):
    """Constructs a confusion matrix for multi-class classification problems.

    Does not support multi-label, multi-class problems.

    Keyword arguments:
    - num_classes (int): number of classes in the classification problem.
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.

    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
    """
def __init__(self, num_classes, normalized=False, device="cpu", lazy=True):
super().__init__()
if device == "cpu":
self.conf = np.ndarray((num_classes, num_classes), dtype=np.int64)
else:
self.conf = torch.zeros((num_classes, num_classes)).cuda()
self.normalized = normalized
self.num_classes = num_classes
self.device = device
self.reset()
self.lazy = lazy
def reset(self):
if self.device == "cpu":
self.conf.fill(0)
else:
self.conf = torch.zeros(self.conf.shape).cuda()
def add(self, predicted, target):
"""Computes the confusion matrix
The shape of the confusion matrix is K x K, where K is the number
of classes.
Keyword arguments:
- predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
predicted scores obtained from the model for N examples and K classes,
or an N-tensor/array of integer values between 0 and K-1.
- target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
ground-truth classes for N examples and K classes, or an N-tensor/array
of integer values between 0 and K-1.
"""
# If target and/or predicted are tensors, convert them to numpy arrays
if self.device == "cpu":
if torch.is_tensor(predicted):
predicted = predicted.cpu().numpy()
if torch.is_tensor(target):
target = target.cpu().numpy()
assert (
predicted.shape[0] == target.shape[0]
), "number of targets and predicted outputs do not match"
if len(predicted.shape) != 1:
assert (
predicted.shape[1] == self.num_classes
), "number of predictions does not match size of confusion matrix"
predicted = predicted.argmax(1)
else:
if not self.lazy:
assert (predicted.max() < self.num_classes) and (
predicted.min() >= 0
), "predicted values are not between 0 and k-1"
if len(target.shape) != 1:
if not self.lazy:
assert (
target.shape[1] == self.num_classes
), "Onehot target does not match size of confusion matrix"
assert (target >= 0).all() and (
target <= 1
).all(), "in one-hot encoding, target values should be 0 or 1"
assert (target.sum(1) == 1).all(), "multi-label setting is not supported"
target = target.argmax(1)
else:
if not self.lazy:
assert (target.max() < self.num_classes) and (
target.min() >= 0
), "target values are not between 0 and k-1"
# hack for bincounting 2 arrays together
x = predicted + self.num_classes * target
if self.device == "cpu":
bincount_2d = np.bincount(x.astype(np.int64), minlength=self.num_classes**2)
assert bincount_2d.size == self.num_classes**2
conf = bincount_2d.reshape((self.num_classes, self.num_classes))
else:
bincount_2d = torch.bincount(x, minlength=self.num_classes**2)
conf = bincount_2d.view((self.num_classes, self.num_classes))
self.conf += conf
def value(self):
"""
Returns:
        Confusion matrix of K rows and K columns, where rows correspond
        to ground-truth targets and columns correspond to predicted
        targets.
"""
if self.normalized:
conf = self.conf.astype(np.float32)
return conf / conf.sum(1).clip(min=1e-12)[:, None]
else:
return self.conf
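# A usage sketch for ConfusionMatrix with integer predictions and targets:
#
#   cm = ConfusionMatrix(num_classes=3)
#   cm.add(np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2]))
#   print(cm.value())                                            # 3 x 3 count matrix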
class IoU(BaseMeter):
"""Computes the intersection over union (IoU) per class and corresponding
mean (mIoU).
Intersection over union (IoU) is a common evaluation metric for semantic
segmentation. The predictions are first accumulated in a confusion matrix
and the IoU is computed from it as follows:
IoU = true_positive / (true_positive + false_positive + false_negative).
Keyword arguments:
- num_classes (int): number of classes in the classification problem
- normalized (boolean, optional): Determines whether or not the confusion
matrix is normalized or not. Default: False.
- ignore_index (int or iterable, optional): Index of the classes to ignore
when computing the IoU. Can be an int, or any iterable of ints.
"""
def __init__(
self,
num_classes,
normalized=False,
ignore_index=None,
cm_device="cpu",
lazy=True,
visu=False,
save_dir="",
):
super().__init__()
self.conf_metric = ConfusionMatrix(num_classes, normalized, device=cm_device, lazy=lazy)
self.lazy = lazy
self.visu = visu
if self.visu:
self.save_dir = save_dir
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if ignore_index is None:
self.ignore_index = None
elif isinstance(ignore_index, int):
self.ignore_index = (ignore_index,)
else:
try:
self.ignore_index = tuple(ignore_index)
except TypeError:
raise ValueError("'ignore_index' must be an int or iterable")
def reset(self):
self.conf_metric.reset()
def update(self, predicted, tg):
"""Adds the predicted and target pair to the IoU metric.
Keyword arguments:
- predicted (Tensor): Can be a (N, K, H, W) tensor of
predicted scores obtained from the model for N examples and K classes,
or (N, H, W) tensor of integer values between 0 and K-1.
- target (Tensor): Can be a (N, K, H, W) tensor of
target scores for N examples and K classes, or (N, H, W) tensor of
integer values between 0 and K-1.
"""
target = tg["label"]
# Dimensions check
assert predicted.size(0) == target.size(
0
), "number of targets and predicted outputs do not match"
assert (
predicted.dim() == 3 or predicted.dim() == 4
), "predictions must be of dimension (N, H, W) or (N, K, H, W)"
assert (
target.dim() == 3 or target.dim() == 4
), "targets must be of dimension (N, H, W) or (N, K, H, W)"
# If the tensor is in categorical format convert it to integer format
if predicted.dim() == 4:
_, predicted = predicted.max(1)
if target.dim() == 4:
_, target = target.max(1)
self.conf_metric.add(predicted.view(-1), target.view(-1))
if self.visu:
for i, name in enumerate(tg["name"]):
np.save(self.save_dir + str(name) + ".npy", predicted.cpu()[i].numpy())
np.save(self.save_dir + str(name) + "_gt.npy", target.cpu()[i].numpy())
def value(self):
"""Computes the IoU and mean IoU.
The mean computation ignores NaN elements of the IoU array.
Returns:
            Tuple: (IoU, mIoU). The first output is the per-class IoU; for K
            classes it is a numpy.ndarray with K elements. The second output is
            the mean IoU.
"""
conf_matrix = self.conf_metric.value()
if self.ignore_index is not None:
conf_matrix[:, self.ignore_index] = 0
conf_matrix[self.ignore_index, :] = 0
true_positive = np.diag(conf_matrix)
false_positive = np.sum(conf_matrix, 0) - true_positive
false_negative = np.sum(conf_matrix, 1) - true_positive
# Just in case we get a division by 0, ignore/hide the error
with np.errstate(divide="ignore", invalid="ignore"):
iou = true_positive / (true_positive + false_positive + false_negative)
return iou, np.nanmean(iou)
def compute(self):
conf_matrix = self.conf_metric.value()
if torch.is_tensor(conf_matrix):
conf_matrix = conf_matrix.cpu().numpy()
if self.ignore_index is not None:
conf_matrix[:, self.ignore_index] = 0
conf_matrix[self.ignore_index, :] = 0
true_positive = np.diag(conf_matrix)
false_positive = np.sum(conf_matrix, 0) - true_positive
false_negative = np.sum(conf_matrix, 1) - true_positive
# Just in case we get a division by 0, ignore/hide the error
with np.errstate(divide="ignore", invalid="ignore"):
iou = true_positive / (true_positive + false_positive + false_negative)
miou = float(np.nanmean(iou))
acc = float(np.diag(conf_matrix).sum() / conf_matrix.sum() * 100)
out = {"mIoU": miou, "acc": acc}
return out
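# A usage sketch for IoU; the target is passed as a dict with a "label" key,
# matching update() above:
#
#   iou = IoU(num_classes=3, cm_device="cpu")
#   pred = torch.randint(0, 3, (2, 8, 8))                        # (N, H, W) class indices
#   tg = {"label": torch.randint(0, 3, (2, 8, 8)), "name": ["a", "b"]}
#   iou.update(pred, tg)
#   print(iou.compute())                                         # {"mIoU": ..., "acc": ...}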
class MetricsBinarySemSeg(BaseMeter):
"""
Computes IoU Score for binary segmentation tasks
Args:
modalities (list): list of modalities used
save_results (bool): if True saves prediction in a csv file
threshold (float): threshold for binary prediction (default: 0.5)
"""
def __init__(self, modalities: list = [], save_results: bool = False, threshold: float = 0.5):
super().__init__()
self.modality = modalities[0]
self.threshold = threshold
self.miou = MeanIoU(num_classes=2, per_class=True) # Binary: 2 classes (0 and 1)
self.save_results = save_results
if save_results:
self.results = {}
    def update(self, pred, gt):
        # Convert predictions to binary using the threshold
        pred_binary = (pred.sigmoid() > self.threshold).float()
        # Convert to one-hot encoding; pred is assumed to be (N, 1, H, W) logits, so
        # the flattened class indices are squeezed to a 1-d tensor before one_hot
        pred_one_hot = torch.nn.functional.one_hot(
            pred_binary.flatten(2, 3).permute(0, 2, 1).flatten(0, 1).squeeze(-1).long(),
            num_classes=2,
        )
        gt_one_hot = torch.nn.functional.one_hot(
            gt["label"].flatten(1, 2).flatten(0, 1).long(), num_classes=2
        )
        self.miou(pred_one_hot, gt_one_hot)
        if self.save_results:
            for i, name in enumerate(gt["name"]):
                self.results[name] = pred_binary.cpu()[i].numpy()
def compute(self):
miou = self.miou.compute()
# For binary segmentation, we typically care about IoU of class 1 (foreground)
out = {
"mIoU": miou[1].item(), # IoU of the foreground class
"IoU_background": miou[0].item(),
"IoU_foreground": miou[1].item(),
}
if self.save_results:
out["results"] = self.results
return out
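# A usage sketch for MetricsBinarySemSeg, assuming single-channel logits of shape
# (N, 1, H, W) as noted in update() above:
#
#   metric = MetricsBinarySemSeg(modalities=["aerial"])
#   pred = torch.randn(2, 1, 64, 64)
#   gt = {"label": torch.randint(0, 2, (2, 64, 64)), "name": ["a", "b"]}
#   metric.update(pred, gt)
#   print(metric.compute())                                      # mIoU + per-class IoU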
class MetricsReg(BaseMeter):
"""
Computes the Root Mean Square Error (RMSE) for regression tasks by applying a softplus activation
to the predictions before computing the MSE loss.
Args:
modalities (list): List of modalities used in the model. Currently not used in the implementation
but kept for consistency with other metric classes.
    Attributes:
        softplus (nn.Softplus): Softplus activation function applied to predictions
        mse (torch.Tensor): Accumulated per-batch mean squared error
        total_samples (torch.Tensor): Counter for the number of batches processed
    """
def __init__(
self,
modalities: list = [],
):
super().__init__()
self.softplus = nn.Softplus()
self.mse = torch.tensor(0.0)
self.total_samples = torch.tensor(0)
    def update(self, pred, gt):
        # Accumulate the batch-mean MSE; compute() averages over batches
        self.mse += F.mse_loss(self.softplus(pred), gt["label"]).cpu()
        self.total_samples += 1
def compute(self):
rmse = torch.sqrt(self.mse / self.total_samples)
out = {"RMSE": rmse.item()}
return out
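# A usage sketch for MetricsReg; targets are assumed non-negative since predictions
# pass through a softplus before the MSE:
#
#   metric = MetricsReg()
#   pred = torch.randn(4, 1)
#   gt = {"label": torch.rand(4, 1)}
#   metric.update(pred, gt)
#   print(metric.compute())                                      # {"RMSE": ...}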
class SegPangaea(BaseMeter):
"""
SegPangaea is a class for evaluating segmentation models using a confusion matrix approach.
Attributes:
num_classes (int): Number of classes in the segmentation task
ignore_index (int): Index value to ignore when computing metrics
confusion_matrix (torch.Tensor): Matrix of shape (num_classes, num_classes) to store predictions
Methods:
update(pred, gt):
Updates the confusion matrix with new predictions and ground truth.
Args:
pred (torch.Tensor): Model predictions
gt (dict): Dictionary containing ground truth labels under 'label' key
compute():
Computes various metrics from the accumulated confusion matrix.
Returns:
dict: Dictionary containing the following metrics:
- mIoU: Mean Intersection over Union across all classes
- mF1: Mean F1 score across all classes
                - mAcc: Overall pixel accuracy
"""
def __init__(self, num_classes, ignore_index):
self.num_classes = num_classes
self.ignore_index = ignore_index
self.confusion_matrix = torch.zeros(num_classes, num_classes)
    def update(self, pred, gt):
        label = gt["label"].flatten(1, 2)
        pred = torch.argmax(pred, dim=1).flatten(1, 2)
        valid_mask = label != self.ignore_index
        pred, target = pred[valid_mask], label[valid_mask]
        # Bincount trick: encode each (pred, target) pair as a single index, then
        # reshape the counts into a (num_classes, num_classes) confusion matrix
        count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes**2)
        self.confusion_matrix = self.confusion_matrix.to(pred.device)
        self.confusion_matrix += count.view(self.num_classes, self.num_classes)
    def compute(self):
        # Rows of the confusion matrix index predictions and columns index targets
        # (see update), so row sums are predicted totals and column sums are
        # ground-truth totals per class
        intersection = torch.diag(self.confusion_matrix)
        union = self.confusion_matrix.sum(dim=1) + self.confusion_matrix.sum(dim=0) - intersection
        iou = intersection / (union + 1e-6)
        # Precision uses predicted totals (rows); recall uses ground-truth totals (columns)
        precision = intersection / (self.confusion_matrix.sum(dim=1) + 1e-6)
        recall = intersection / (self.confusion_matrix.sum(dim=0) + 1e-6)
        # F1-score per class
        f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
        # Mean IoU, mean F1-score, and overall pixel accuracy
        miou = iou.mean().item()
        mf1 = f1.mean().item()
        macc = (intersection.sum() / (self.confusion_matrix.sum() + 1e-6)).item()
        metrics = {
            "mIoU": miou,
            "mF1": mf1,
            "mAcc": macc,
        }
        return metrics
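# A usage sketch for SegPangaea (shapes assumed from the flatten calls in update()):
#
#   metric = SegPangaea(num_classes=5, ignore_index=-1)
#   pred = torch.randn(2, 5, 32, 32)                             # (N, K, H, W) logits
#   gt = {"label": torch.randint(0, 5, (2, 32, 32))}
#   metric.update(pred, gt)
#   print(metric.compute())                                      # {"mIoU", "mF1", "mAcc"}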