# EXAONE-Path-MSI / models/aggregator.py
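"""Aggregator model for EXAONE-Path-MSI.

Summary inferred from the code in this file (not from upstream documentation):
a small VisionTransformer from models.transformer mixes pre-extracted patch
embeddings, the tokens are mean-pooled (optionally with a padding mask), and a
linear head projects the pooled feature.
"""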
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn

from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 6
width: int = 512
head_width: int = 64
mlp_ratio: float = 4.0
ls_init_value: Optional[float] = None # layer scale initial value
    patch_dropout: float = 0.  # fraction of patches to drop during training (0 disables it); 0.5-0.75 is recommended in the paper
no_ln_pre: bool = False # disable pre transformer LayerNorm
pos_embed_type: str = 'none'
final_ln_after_pool: bool = True # apply final LayerNorm after pooling
pool_type: str = 'none'
output_tokens: bool = False
act_kwargs: Optional[dict] = None
norm_kwargs: Optional[dict] = None
img_embed: bool = False
cls_embed: bool = False
    projection: bool = False
    use_flex: bool = True
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def get_input_dtype(precision: str):
input_dtype = None
if precision in ('bf16', 'pure_bf16'):
input_dtype = torch.bfloat16
elif precision in ('fp16', 'pure_fp16'):
input_dtype = torch.float16
return input_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
dropout: float = 0.1,
num_registers: int = 0,
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
act_layer = QuickGELU if quick_gelu else nn.GELU
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if vision_cfg.norm_kwargs:
norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
if vision_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **vision_cfg.act_kwargs)
visual = VisionTransformer(
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
output_dim=embed_dim,
patch_dropout=vision_cfg.patch_dropout,
no_ln_pre=vision_cfg.no_ln_pre,
pool_type=vision_cfg.pool_type,
final_ln_after_pool=vision_cfg.final_ln_after_pool,
act_layer=act_layer,
norm_layer=norm_layer,
output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=True,
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
)
return visual
class MixedOmicsModel(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
drop_rate: float = 0.25,
num_registers: int = 0,
*args,
**kwargs,
):
super().__init__()
self.drop_prob = drop_rate
self.num_registers = num_registers
        # Convert a dict config to CLIPVisionCfg before touching attributes,
        # and disable the extra class-token embedding for this aggregator.
        if isinstance(vision_cfg, dict):
            vision_cfg = CLIPVisionCfg(**vision_cfg)
        vision_cfg.cls_embed = False
        self.visual = _build_vision_tower(
            embed_dim,
            vision_cfg,
            quick_gelu,
            cast_dtype,
            dropout=drop_rate,
            num_registers=0,
        )
self.image_proj = nn.Linear(embed_dim, embed_dim)
self.image_proj.apply(self.init_weights)
self.ln_post = LayerNorm(embed_dim)
def init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
    def _check_tensor(self, tensor, name):
        # Debug helper: report the tensor shape and flag NaN / Inf values.
        print(name, " : ", tensor.shape)
        if torch.isnan(tensor).any():
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(f"Tensor {name} contains Inf values.")
def forward(
self,
image,
coords=None,
im_mask=None,
*args,
**kwargs,
):
        ## image embedding
        key_padding_mask = None if im_mask is None else (~im_mask.bool()).contiguous()
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),
            key_padding_mask=key_padding_mask,
        )
        image_embeds = self.ln_post(image_embeds)
        ## masked mean pooling over the patch dimension (padded positions are ignored)
        if im_mask is not None:
            mask = im_mask.unsqueeze(-1).contiguous()    # [N, L, 1]
            masked_embeds = image_embeds * mask
            sum_embeds = masked_embeds.sum(dim=1)
            valid_counts = mask.sum(dim=1).clamp(min=1)  # [N, 1]
            mean_embeds = sum_embeds / valid_counts      # [N, dim]
        else:
            mean_embeds = image_embeds.mean(dim=-2)
        image_embeds_final = self.image_proj(mean_embeds)
        return image_embeds_final, image_embeds, mean_embeds
def make_model(
embed_dim=768,
droprate=0.1,
num_registers=0,
depth=4,
):
    # Build a config instance rather than mutating the CLIPVisionCfg class itself.
    vision_cfg = CLIPVisionCfg(width=embed_dim, layers=depth)
    model = MixedOmicsModel(
        embed_dim=embed_dim,
        vision_cfg=vision_cfg,
        drop_rate=droprate,
        num_registers=num_registers,
    )
return model
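# Minimal usage sketch, assuming the aggregator consumes pre-extracted patch
# features of shape [B, N, embed_dim], patch coordinates of shape [B, N, 2],
# and a boolean validity mask of shape [B, N]; the exact expectations are
# defined by models.transformer.VisionTransformer, so these shapes are an
# assumption, not confirmed by this file.
if __name__ == "__main__":
    model = make_model(embed_dim=768, droprate=0.1, depth=4)
    model.eval()

    B, N, D = 2, 16, 768
    feats = torch.randn(B, N, D)
    coords = torch.randint(0, 1000, (B, N, 2)).float()
    mask = torch.ones(B, N, dtype=torch.bool)
    mask[:, -4:] = False  # treat the last four positions as padding

    with torch.no_grad():
        projected, tokens, pooled = model(feats, coords=coords, im_mask=mask)
    print(projected.shape, tokens.shape, pooled.shape)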