import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from sklearn.metrics import accuracy_score, f1_score

from src.galileo import adjust_learning_rate

from .metrics import mean_iou

PROBING_LRs = {
    "LP": [
        1e-4, 3e-4, 5e-4, 8e-4,
        1e-3, 3e-3, 5e-3, 8e-3,
        1e-2, 3e-2, 5e-2, 8e-2,
        1e-1, 3e-1, 5e-1, 8e-1,
    ],
}


def train_and_eval_probe_cls(lr, config, loaders, in_features, device):
    probe = train_probe_cls(
        data_loader=loaders["train"],
        lr=lr,
        epochs=50,
        in_features=in_features,
        num_classes=config["num_classes"],
        is_multilabel=config["is_multilabel"],
        device=device,
    )
    val_acc = evaluate_probe_cls(
        data_loader=loaders["valid"],
        probe=probe,
        is_multilabel=config["is_multilabel"],
        device=device,
    )
    test_acc = evaluate_probe_cls(
        data_loader=loaders["test"],
        probe=probe,
        is_multilabel=config["is_multilabel"],
        device=device,
    )
    return val_acc, test_acc
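

# Illustrative only: a minimal sketch (not part of the original module) of how the
# learning-rate sweep implied by PROBING_LRs might be driven, assuming `config`,
# `loaders`, `in_features`, and `device` are prepared as for train_and_eval_probe_cls.
# The helper name `sweep_probe_cls_lrs` is hypothetical.
def sweep_probe_cls_lrs(config, loaders, in_features, device):
    best_lr, best_val, best_test = None, -1.0, -1.0
    for lr in PROBING_LRs["LP"]:
        val_metric, test_metric = train_and_eval_probe_cls(
            lr=lr,
            config=config,
            loaders=loaders,
            in_features=in_features,
            device=device,
        )
        # select the learning rate on the validation split; the test metric is only reported
        if val_metric > best_val:
            best_lr, best_val, best_test = lr, val_metric, test_metric
    return best_lr, best_val, best_test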


def train_and_eval_probe_seg(lr, config, loaders, in_features, grid_size, device):
    output_patch_size = math.ceil(config["segmentation_map_height_width"] / grid_size)
    probe = train_probe_seg(
        data_loader=loaders["train"],
        lr=lr,
        epochs=50,
        in_features=in_features,
        num_classes=config["num_classes"],
        patch_size=output_patch_size,
        device=device,
    )
    val_miou = evaluate_probe_seg(
        data_loader=loaders["valid"],
        probe=probe,
        num_classes=config["num_classes"],
        patch_size=output_patch_size,
        device=device,
    )
    test_miou = evaluate_probe_seg(
        data_loader=loaders["test"],
        probe=probe,
        num_classes=config["num_classes"],
        patch_size=output_patch_size,
        device=device,
    )
    return val_miou, test_miou


def train_probe_cls(
    data_loader,
    lr,
    epochs,
    in_features,
    num_classes,
    is_multilabel,
    device,
):
    probe = nn.Sequential(nn.BatchNorm1d(in_features), nn.Linear(in_features, num_classes)).to(
        device
    )
    opt = torch.optim.AdamW(probe.parameters(), lr=lr)
    sched_config = {
        "lr": lr,
        "warmup_epochs": int(epochs * 0.1),
        "min_lr": 1.0e-5,
        "epochs": epochs,
    }
    probe = probe.train()
    if is_multilabel:
        loss_function = nn.MultiLabelSoftMarginLoss()
    else:
        loss_function = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for i, batch in enumerate(data_loader):
            batch_emb, batch_labels = batch  # (bsz, dim), (bsz)
            batch_emb = batch_emb.to(device)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                logits = probe(batch_emb)  # (bsz, num_classes)
                loss = loss_function(logits, batch_labels.to(device))
            loss.backward()
            adjust_learning_rate(
                optimizer=opt,
                epoch=epoch + (i / len(data_loader)),
                total_epochs=sched_config["epochs"],
                warmup_epochs=sched_config["warmup_epochs"],
                max_lr=sched_config["lr"],
                min_lr=sched_config["min_lr"],
            )
            opt.step()
            opt.zero_grad()
    return probe


def evaluate_probe_cls(data_loader, probe, is_multilabel, device):
    probe = probe.eval()
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            batch_emb, batch_labels = batch  # (bsz, dim), (bsz)
            batch_emb = batch_emb.to(device)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                batch_logits = probe(batch_emb)  # (bsz, num_classes)
            all_logits.append(batch_logits.float().cpu())
            all_labels.append(batch_labels)
    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    if is_multilabel:
        all_preds = torch.sigmoid(all_logits) > 0.5
        return f1_score(all_labels, all_preds, average="micro")
    else:
        all_preds = torch.argmax(all_logits, dim=-1)
        return accuracy_score(all_labels, all_preds)
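

# Illustrative only: a minimal, self-contained smoke test (not part of the original
# module) showing the inputs the classification probe expects: precomputed (bsz, dim)
# embeddings and integer class labels. All sizes below are made up for the example.
def _cls_probe_smoke_test(device="cpu"):
    from torch.utils.data import DataLoader, TensorDataset

    in_features, num_classes = 64, 5
    embs = torch.randn(128, in_features)            # (bsz, dim)
    labels = torch.randint(0, num_classes, (128,))  # (bsz,)
    loader = DataLoader(TensorDataset(embs, labels), batch_size=32)
    probe = train_probe_cls(
        data_loader=loader,
        lr=1e-3,
        epochs=2,
        in_features=in_features,
        num_classes=num_classes,
        is_multilabel=False,
        device=device,
    )
    return evaluate_probe_cls(
        data_loader=loader, probe=probe, is_multilabel=False, device=device
    )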


def train_probe_seg(
    data_loader,
    lr,
    epochs,
    in_features,
    num_classes,
    patch_size,
    device,
    # probe_type is not passed by train_and_eval_probe_seg above, so default to the
    # linear probe; "LP" is the only key defined in PROBING_LRs
    probe_type="LP",
):
    logits_per_patch = int(num_classes * patch_size * patch_size)
    assert probe_type in ["LP", "MLP"]
    if probe_type == "LP":
        probe = nn.Sequential(nn.Linear(in_features, logits_per_patch)).to(device)
    else:
        probe = nn.Sequential(
            nn.Linear(in_features, 2048), nn.GELU(), nn.Linear(2048, logits_per_patch)
        ).to(device)
    opt = torch.optim.AdamW(probe.parameters(), lr=lr)
    sched_config = {
        "lr": lr,
        "warmup_epochs": int(epochs * 0.1),
        "min_lr": 1.0e-5,
        "epochs": epochs,
    }
    probe = probe.train()
    loss_function = nn.CrossEntropyLoss(ignore_index=-1)  # for MADOS, but ok for others
    for epoch in range(epochs):
        for i, batch in enumerate(data_loader):
            batch_emb, batch_labels = batch  # (bsz, num_patches, dim), (bsz, H, W)
            spatial_patches_per_dim = int(batch_emb.shape[1] ** 0.5)
            batch_emb = batch_emb.to(device)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                logits = probe(batch_emb)  # (bsz, num_patches, logits_per_patch)
                # this is a bit hacky: square labels mean the embeddings cover the full
                # spatial grid, so unfold each patch's logits back into pixel space
                if batch_labels.shape[1] == batch_labels.shape[2]:
                    logits = rearrange(
                        logits,
                        "b (h w) (c i j) -> b c (h i) (w j)",
                        h=spatial_patches_per_dim,
                        w=spatial_patches_per_dim,
                        c=num_classes,
                        i=patch_size,
                        j=patch_size,
                    )
                    if logits.shape[-2] != batch_labels.shape[-2]:
                        logits = F.interpolate(
                            logits,
                            size=(batch_labels.shape[-2], batch_labels.shape[-1]),
                            mode="bilinear",
                            align_corners=True,
                        )  # (bsz, num_classes, H, W)
                else:
                    # otherwise, we subsampled in the get_embeddings step
                    logits = rearrange(
                        logits,
                        "b t (c i j) -> b c t (i j)",
                        c=num_classes,
                        i=patch_size,
                        j=patch_size,
                    )
                loss = loss_function(logits, batch_labels.to(device))
            loss.backward()
            adjust_learning_rate(
                optimizer=opt,
                epoch=epoch + (i / len(data_loader)),
                total_epochs=sched_config["epochs"],
                warmup_epochs=sched_config["warmup_epochs"],
                max_lr=sched_config["lr"],
                min_lr=sched_config["min_lr"],
            )
            opt.step()
            opt.zero_grad()
    return probe


def evaluate_probe_seg(
    data_loader,
    probe,
    num_classes,
    patch_size,
    device,
):
    probe = probe.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            batch_emb, batch_labels = batch  # (bsz, num_patches, dim), (bsz, H, W)
            spatial_patches_per_dim = int(batch_emb.shape[1] ** 0.5)
            batch_emb = batch_emb.to(device)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                logits = probe(batch_emb)  # (bsz, num_patches, logits_per_patch)
            logits = rearrange(
                logits,
                "b (h w) (c i j) -> b c (h i) (w j)",
                h=spatial_patches_per_dim,
                w=spatial_patches_per_dim,
                c=num_classes,
                i=patch_size,
                j=patch_size,
            )
            if logits.shape[-2] != batch_labels.shape[-2]:
                logits = F.interpolate(
                    logits,
                    size=(batch_labels.shape[-2], batch_labels.shape[-1]),
                    mode="bilinear",
                    align_corners=True,
                )  # (bsz, num_classes, H, W)
            preds = torch.argmax(logits, dim=1).cpu()
            all_preds.append(preds)
            all_labels.append(batch_labels)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    miou = mean_iou(all_preds, all_labels, num_classes=num_classes, ignore_label=-1)
    return miou
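

# Illustrative only: a minimal, self-contained smoke test (not part of the original
# module) showing the tensor shapes the segmentation helpers expect: per-patch
# embeddings of shape (bsz, num_patches, dim) and label maps of shape (bsz, H, W).
# All sizes below (4x4 patch grid, 8x8 labels, 3 classes, 64-dim embeddings) are
# made up for the example.
def _seg_probe_smoke_test(device="cpu"):
    from torch.utils.data import DataLoader, TensorDataset

    in_features, num_patches, num_classes, hw = 64, 16, 3, 8
    grid_size = int(num_patches**0.5)                     # patches per spatial dim
    patch_size = math.ceil(hw / grid_size)                # pixels predicted per patch side
    embs = torch.randn(32, num_patches, in_features)      # (bsz, num_patches, dim)
    labels = torch.randint(0, num_classes, (32, hw, hw))  # (bsz, H, W)
    loader = DataLoader(TensorDataset(embs, labels), batch_size=8)
    probe = train_probe_seg(
        data_loader=loader,
        lr=1e-3,
        epochs=2,
        in_features=in_features,
        num_classes=num_classes,
        patch_size=patch_size,
        probe_type="LP",
        device=device,
    )
    return evaluate_probe_seg(
        data_loader=loader,
        probe=probe,
        num_classes=num_classes,
        patch_size=patch_size,
        device=device,
    )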