import logging
import random
from copy import deepcopy
from typing import Dict, Optional

import torch
import torchvision.transforms.v2.functional as TVF
from einops import rearrange
from torch.utils.data import DataLoader

logger = logging.getLogger("__main__")


def get_embeddings(data_loader, model, device, subsample_tokens: Optional[float] = None):
    embeddings = []
    labels = []
    if subsample_tokens:
        logger.info(f"Subsampling tokens with ratio {subsample_tokens}")
    model = model.eval()
    with torch.no_grad():
        for batch in data_loader:
            batch_labels = batch.pop("target")
            if "s1" in batch:
                batch["s1"] = batch["s1"].to(device).to(torch.bfloat16)
            if "s2" in batch:
                batch["s2"] = batch["s2"].to(device).to(torch.bfloat16)
            if "months" in batch:
                batch["months"] = batch["months"].to(device).long()
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                batch_embeddings = model(**batch)  # (bsz, dim) or (bsz, tokens, dim)
            if subsample_tokens is not None:
                if len(batch_embeddings.shape) < 3:
                    raise ValueError("subsample_tokens only works for segmentation tasks")
                # keep a random subset of the patch tokens for each instance
                num_tokens_per_instance = batch_embeddings.shape[1]
                num_tokens_to_keep = int(num_tokens_per_instance * subsample_tokens)
                sampled_indices = torch.randperm(num_tokens_per_instance)[:num_tokens_to_keep]
                batch_embeddings = batch_embeddings[:, sampled_indices]
                # regroup the per-pixel labels into one block of pixels per token,
                # so the same token indices can be selected from the labels
                tokens_per_dim = int(num_tokens_per_instance**0.5)
                pixels_per_token_dim = int(batch_labels.shape[1] / tokens_per_dim)
                batch_labels_per_token = rearrange(
                    batch_labels,
                    "b (t_h p_h) (t_w p_w) -> b (t_h t_w) (p_h p_w)",
                    t_h=tokens_per_dim,
                    t_w=tokens_per_dim,
                    p_h=pixels_per_token_dim,
                    p_w=pixels_per_token_dim,
                )
                batch_labels = batch_labels_per_token[:, sampled_indices]
            embeddings.append(batch_embeddings.to(torch.bfloat16).cpu())
            labels.append(batch_labels)
    return torch.cat(embeddings, dim=0), torch.cat(labels, dim=0)
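
# Example (hypothetical sketch; `seg_loader`, `model`, and `device` are assumed to come
# from the caller and are not defined in this module). For a model that returns
# (bsz, tokens, dim), subsample_tokens=0.25 keeps a random quarter of the patch tokens
# per image together with their per-token blocks of pixel labels:
#
#   feats, token_labels = get_embeddings(seg_loader, model, device, subsample_tokens=0.25)
#   # feats: (N, tokens_kept, dim) in bfloat16 on CPU
#   # token_labels: (N, tokens_kept, pixels_per_token)
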
class DownstreamAugs(object):
    """
    For now, let's have no parameters.
    Choose 1 of 8 transformations and apply it to the image and the segmentation map (if needed).
    """

    def __init__(self, enabled: bool):
        self.enabled = enabled
        self.transformations = [
            self.no_transform,     # No transformation
            self.rotate_90,        # 90-degree rotation
            self.rotate_180,       # 180-degree rotation
            self.rotate_270,       # 270-degree rotation
            self.hflip,            # Horizontal flip
            self.vflip,            # Vertical flip
            self.hflip_rotate_90,  # Horizontal flip of 90-degree rotated image
            self.vflip_rotate_90,  # Vertical flip of 90-degree rotated image
        ]

    def no_transform(self, x):
        return x

    def rotate_90(self, x):
        return TVF.rotate(x, 90)

    def rotate_180(self, x):
        return TVF.rotate(x, 180)

    def rotate_270(self, x):
        return TVF.rotate(x, 270)

    def hflip(self, x):
        return TVF.hflip(x)

    def vflip(self, x):
        return TVF.vflip(x)

    def hflip_rotate_90(self, x):
        return TVF.hflip(TVF.rotate(x, 90))

    def vflip_rotate_90(self, x):
        return TVF.vflip(TVF.rotate(x, 90))
    def apply(self, image, target, task_type):
        assert task_type in ["cls", "seg"]
        # image is (H, W, C)
        # target is either (1,) for classification or (H, W) for segmentation
        if not self.enabled:
            return image, target
        # choose 1 of 8 possible augmentations
        transformation = random.choice(self.transformations)
        # transform image and rearrange (torchvision expects channels first)
        image = rearrange(image, "h w c -> c h w")
        image = transformation(image)
        image = rearrange(image, "c h w -> h w c")
        if task_type == "cls":
            return image, target
        else:
            # apply the same transformation to the segmentation map so pixels and labels
            # stay aligned; spatial dims must match (assumes square inputs, since the
            # image has already been transformed above)
            assert target.shape[-1] == image.shape[-2]
            assert target.shape[-2] == image.shape[-3]
            target = rearrange(target, "h w -> 1 h w")
            target = transformation(target)
            target = rearrange(target, "1 h w -> h w")
            return image, target
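
# Minimal sketch of how DownstreamAugs is meant to be used (shapes taken from the
# comments in apply(); the tensors here are synthetic, not from any benchmark):
#
#   augs = DownstreamAugs(enabled=True)
#   image = torch.randn(64, 64, 3)          # (H, W, C)
#   mask = torch.randint(0, 10, (64, 64))   # (H, W)
#   image_aug, mask_aug = augs.apply(image, mask, task_type="seg")
#   # the same randomly chosen transform is applied to both, so pixels and labels stay aligned
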
def get_loaders(
    benchmark,
    config,
    model_name,
    batch_size,
    num_workers,
    eval_type,
    train_partition: Optional[str] = None,
    valtest_partition: Optional[str] = None,
    norm_ops: Optional[Dict] = None,
):
    # train-time augmentations are only used when fine-tuning
    use_train_augs = eval_type == "FT"
    dataclass_kwargs = deepcopy(benchmark["kwargs"])
    if norm_ops is None:
        dataclass_kwargs["norm_operation"] = config["models"][model_name]
    else:
        dataclass_kwargs["norm_operation"] = norm_ops
    train_kwargs = deepcopy(dataclass_kwargs)
    valtest_kwargs = deepcopy(dataclass_kwargs)
    if train_partition is not None:
        train_kwargs["partition"] = train_partition
        if valtest_partition is None:
            valtest_partition = "default"
        valtest_kwargs["partition"] = valtest_partition
    elif valtest_partition:
        raise ValueError("valtest_partition is set but train_partition is None; set train_partition as well")
    return {
        "train": DataLoader(
            benchmark["class"](
                **train_kwargs,
                split="train",
                augmentation=DownstreamAugs(use_train_augs),
            ),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        ),
        "valid": DataLoader(
            benchmark["class"](
                **valtest_kwargs,
                split="valid",
                augmentation=DownstreamAugs(False),
            ),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        ),
        "test": DataLoader(
            benchmark["class"](
                **valtest_kwargs,
                split="test",
                augmentation=DownstreamAugs(False),
            ),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        ),
    }
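
# Example usage (hypothetical sketch, not part of the original module): `benchmark`,
# `config`, `model`, and `device` are assumed to be set up by the caller, e.g. a
# benchmark registry entry with "class"/"kwargs" keys and a config with per-model
# normalization settings; eval_type is only checked against "FT" here, so any other
# value (e.g. a linear-probe/KNN setting) simply disables the train-time augmentations.
#
#   loaders = get_loaders(
#       benchmark, config, model_name="my_model",
#       batch_size=32, num_workers=4, eval_type="KNN",
#   )
#   train_x, train_y = get_embeddings(loaders["train"], model, device)
#   val_x, val_y = get_embeddings(loaders["valid"], model, device)
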