from typing import Any, Dict, Optional, Tuple, Union
from dataclasses import dataclass
from functools import partial

import torch
from torch import nn

from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 6
    width: int = 512
    head_width: int = 64
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # fraction of patches to drop during training (0 disables patch dropout; 0.5-0.75 recommended in the paper)
    no_ln_pre: bool = False  # disable the pre-transformer LayerNorm
    pos_embed_type: str = 'none'
    final_ln_after_pool: bool = True  # apply the final LayerNorm after pooling
    pool_type: str = 'none'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None
    img_embed: bool = False
    cls_embed: bool = False
    projection: bool = False
    use_flex: bool = True


def get_cast_dtype(precision: str):
    """Map a precision string to the dtype used for parameter casting."""
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    """Map a precision string to the dtype expected for model inputs."""
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype


def _build_vision_tower(
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
    dropout: float = 0.1,
    num_registers: int = 0,
):
    """Construct the VisionTransformer backbone from a CLIPVisionCfg."""
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    vision_heads = vision_cfg.width // vision_cfg.head_width
    # Keep LayerNorm in fp32 when the rest of the model is cast to half precision.
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    visual = VisionTransformer(
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        output_dim=embed_dim,
        patch_dropout=vision_cfg.patch_dropout,
        no_ln_pre=vision_cfg.no_ln_pre,
        pool_type=vision_cfg.pool_type,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=True,
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
    )
    return visual


class MixedOmicsModel(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        drop_rate: float = 0.25,
        num_registers: int = 0,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.drop_prob = drop_rate
        self.num_registers = num_registers
        vision_cfg.cls_embed = False
        self.visual = _build_vision_tower(
            embed_dim,
            vision_cfg,
            quick_gelu,
            cast_dtype,
            dropout=drop_rate,
            num_registers=0,
        )
        self.image_proj = nn.Linear(embed_dim, embed_dim)
        self.image_proj.apply(self.init_weights)
        self.ln_post = LayerNorm(embed_dim)

    def init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02**2), reset LayerNorm affine params, and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _check_tensor(self, tensor, name):
        """Debug helper: print a tensor's shape and warn if it contains NaN or Inf values."""
        print(name, " : ", tensor.shape)
        if torch.isnan(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains Inf values.")

    def forward(
        self,
        image,
        coords=None,
        im_mask=None,
        *args,
        **kwargs,
    ):
        # Per-token image embeddings from the vision tower; im_mask marks valid tokens,
        # so its inverse is passed as the key padding mask.
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),
            key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous(),
        )
        image_embeds = self.ln_post(image_embeds)

        if im_mask is not None:
            # Masked mean over the token dimension, ignoring padded positions.
            mask = im_mask.unsqueeze(-1).contiguous()
            masked_embeds = image_embeds * mask
            sum_embeds = masked_embeds.sum(dim=1)
            valid_counts = mask.sum(dim=1).clamp(min=1)  # [N, 1]
            mean_embeds = sum_embeds / valid_counts      # [N, dim]
        else:
            mean_embeds = image_embeds.mean(-2)

        image_embeds_final = self.image_proj(mean_embeds)
        return image_embeds_final, image_embeds, mean_embeds


def make_model(
    embed_dim=768,
    droprate=0.1,
    num_registers=0,
    depth=4,
):
    # Build a config instance so the defaults on CLIPVisionCfg are not mutated globally.
    vCfg = CLIPVisionCfg(width=embed_dim, layers=depth)
    model = MixedOmicsModel(
        embed_dim=embed_dim,
        vision_cfg=vCfg,
        drop_rate=droprate,
        num_registers=num_registers,
    )
    return model
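

# A minimal usage sketch, not part of the original module: the tensor shapes below are
# assumptions (a batch of token sequences with per-token 2-D coordinates), since the
# actual input layout is defined by VisionTransformer in models.transformer.
if __name__ == "__main__":
    model = make_model(embed_dim=768, droprate=0.1, depth=4)
    image = torch.randn(2, 16, 768)                # [batch, tokens, embed_dim] (assumed)
    coords = torch.randn(2, 16, 2)                 # per-token coordinates (assumed)
    im_mask = torch.ones(2, 16, dtype=torch.bool)  # True marks valid tokens
    projected, tokens, pooled = model(image, coords=coords, im_mask=im_mask)
    print(projected.shape, tokens.shape, pooled.shape)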