# Scraped page header: uploaded by jadechoghari, commit 9b9e0ee
# ("add model", verified), file size 4.38 kB.
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
class PatchEmbed_org(nn.Module):
    """Non-overlapping image-to-patch embedding.

    A single ``nn.Conv2d`` whose stride equals its kernel size tiles the
    input into ``patch_size`` blocks and projects each block to
    ``embed_dim`` channels; the result is returned as a token sequence.

    Args:
        img_size: Reference input size, an int or a pair (via ``to_2tuple``).
        patch_size: Kernel size and stride, an int or a pair.
        in_chans: Number of input channels.
        embed_dim: Output channels, i.e. the size of each token.

    Attributes:
        patch_hw: ``(img_size[1] // patch_size[1], img_size[0] // patch_size[0])``.
            Note the index-1-first order; if ``img_size`` is ``(H, W)`` this is
            ``(W', H')``. NOTE(review): confirm callers expect this order.
        num_patches: Product of the two ``patch_hw`` entries.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        size_pair = to_2tuple(img_size)
        kernel_pair = to_2tuple(patch_size)

        cells_a = size_pair[1] // kernel_pair[1]
        cells_b = size_pair[0] // kernel_pair[0]

        self.img_size = size_pair
        self.patch_size = kernel_pair
        self.patch_hw = (cells_a, cells_b)
        self.num_patches = cells_a * cells_b
        # stride == kernel_size -> patches do not overlap.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_pair, stride=kernel_pair
        )

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to tokens (B, H' * W', embed_dim).

        The input size is not checked against ``self.img_size``.
        """
        # Unpacking doubles as a rank check: non-4-D input raises ValueError.
        batch, chans, height, width = x.shape
        feature_map = self.proj(x)  # (B, embed_dim, H', W')
        return feature_map.flatten(2).transpose(1, 2)
class PatchEmbed_new(nn.Module):
    """Flexible image-to-patch embedding with an independent stride.

    Unlike a plain patch embedding, the convolution stride is configurable,
    so patches overlap whenever ``stride < patch_size`` (the defaults, 16 and
    10, overlap). The token-grid size is found by running a dummy input
    through ``self.proj`` once at construction time.

    Args:
        img_size: Reference input size, an int or a pair (via ``to_2tuple``).
        patch_size: Convolution kernel size, an int or a pair.
        in_chans: Number of input channels.
        embed_dim: Output channels, i.e. the size of each token.
        stride: Convolution stride, an int or a pair.

    Attributes:
        patch_hw: ``(h, w)`` of the conv output for an ``img_size`` input.
        num_patches: ``h * w``.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )  # with overlapped patches when stride < patch_size
        _, _, h, w = self.get_output_shape(img_size)  # (1, embed_dim, h, w)
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape of ``self.proj`` applied to a dummy image.

        Args:
            img_size: ``(height, width)`` of the dummy input.

        Returns:
            ``torch.Size`` ``(1, embed_dim, h, w)``.
        """
        # Use the conv's real channel count. A hard-coded 1 channel made
        # construction fail for any in_chans != 1 (including the default 3).
        # torch.randn is kept so the global RNG advances exactly as before
        # for in_chans == 1 configurations.
        dummy = torch.randn(1, self.proj.in_channels, img_size[0], img_size[1])
        with torch.no_grad():  # shape probe only; no autograd graph needed
            return self.proj(dummy).shape

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to tokens (B, h * w, embed_dim).

        NOTE: H and W are not checked against ``self.img_size``. Other sizes
        produce a token count different from ``self.num_patches``.
        """
        # Unpacking doubles as a rank check: non-4-D input raises ValueError.
        B, C, H, W = x.shape
        x = self.proj(x)  # e.g. (32, 1, 1024, 128) -> (32, 768, 101, 12)
        x = x.flatten(2)  # (B, embed_dim, h, w) -> (B, embed_dim, h * w)
        x = x.transpose(1, 2)  # -> (B, h * w, embed_dim)
        return x
class PatchEmbed3D_new(nn.Module):
    """Video-to-patch embedding built on a single ``nn.Conv3d``.

    The convolution slides a ``patch_size`` (time, height, width) kernel
    over the clip with the given ``stride``. Every output location becomes
    one token of size ``embed_dim``. The token grid is measured by pushing a
    dummy clip through ``self.proj`` during construction.

    Args:
        video_size: Reference clip size ``(T, H, W)``; indexed, so it must be
            a sequence of at least three ints.
        patch_size: Conv3d kernel size.
        in_chans: Number of input channels.
        embed_dim: Output channels, i.e. the size of each token.
        stride: Conv3d stride.

    Attributes:
        patch_thw: ``(t, h, w)`` of the conv output for a ``video_size`` clip.
        num_patches: ``t * h * w``.
    """

    def __init__(
        self,
        video_size=(16, 224, 224),
        patch_size=(2, 16, 16),
        in_chans=3,
        embed_dim=768,
        stride=(2, 16, 16),
    ):
        super().__init__()
        self.video_size = video_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

        out_shape = self.get_output_shape(video_size)  # (1, embed_dim, t, h, w)
        grid_t, grid_h, grid_w = out_shape[-3:]
        self.patch_thw = (grid_t, grid_h, grid_w)
        self.num_patches = grid_t * grid_h * grid_w

    def get_output_shape(self, video_size):
        """Return the shape of ``self.proj`` applied to a random dummy clip.

        Args:
            video_size: Sequence whose first three entries give ``(T, H, W)``.

        Returns:
            ``torch.Size`` ``(1, embed_dim, t, h, w)``.
        """
        frames, height, width = video_size[0], video_size[1], video_size[2]
        probe = torch.randn(1, self.in_chans, frames, height, width)
        return self.proj(probe).shape

    def forward(self, x):
        """Map ``x`` of shape (B, C, T, H, W) to tokens (B, t * h * w, embed_dim)."""
        # Unpacking doubles as a rank check: non-5-D input raises ValueError.
        batch, chans, frames, height, width = x.shape
        flat = self.proj(x).flatten(2)  # (B, embed_dim, t * h * w)
        return flat.transpose(1, 2)
if __name__ == "__main__":
    # Smoke test: build a 3-D (video) patch embedding and print the token shape.
    #
    # 2-D example (reference only):
    # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16, 16))
    # spec = torch.rand(8, 1, 1024, 128)
    # print(patch_emb(spec).shape)  # torch.Size([8, 512, 64])
    patch_emb = PatchEmbed3D_new(
        video_size=(6, 224, 224),
        patch_size=(2, 16, 16),
        in_chans=3,
        embed_dim=768,
        stride=(2, 16, 16),
    )
    # Named `video` rather than `input` to avoid shadowing the builtin.
    video = torch.rand(8, 3, 6, 224, 224)
    with torch.no_grad():  # inference-only demo; skip building the autograd graph
        output = patch_emb(video)
    # (6 / 2) * (224 / 16) * (224 / 16) = 3 * 14 * 14 = 588 tokens of size 768.
    print(output.shape)  # torch.Size([8, 588, 768])