Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -3,28 +3,23 @@ import torch.nn as nn
|
|
| 3 |
import math
|
| 4 |
|
| 5 |
|
| 6 |
-
class
|
| 7 |
-
def __init__(self,
|
| 8 |
super().__init__()
|
| 9 |
|
| 10 |
self.patch_size = patch_size
|
| 11 |
-
self.d_model = d_model
|
| 12 |
-
|
| 13 |
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
| 14 |
-
self.proj = nn.Linear(in_channels * patch_size * patch_size, d_model)
|
| 15 |
|
| 16 |
def forward(self, x):
|
| 17 |
batch_size, c, h, w = x.shape
|
| 18 |
|
| 19 |
-
# Unfold
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# Apply linear projection to each patch: (batch_size, num_patches, in_channels * patch_size * patch_size) -> (batch_size, num_patches, d_model)
|
| 27 |
-
return self.proj(patches)
|
| 28 |
|
| 29 |
|
| 30 |
# Positional Encoding
|
|
@@ -139,7 +134,7 @@ class PositionwiseFeedForward(nn.Module):
|
|
| 139 |
|
| 140 |
self.ffn = nn.Sequential(
|
| 141 |
nn.Linear(in_features=d_model, out_features=(d_model * 4)),
|
| 142 |
-
nn.
|
| 143 |
nn.Linear(in_features=(d_model * 4), out_features=d_model),
|
| 144 |
nn.Dropout(p=dropout),
|
| 145 |
)
|
|
@@ -218,9 +213,8 @@ class Encoder(nn.Module):
|
|
| 218 |
|
| 219 |
self.patch_size = patch_size
|
| 220 |
|
| 221 |
-
self.
|
| 222 |
-
|
| 223 |
-
)
|
| 224 |
|
| 225 |
seq_length = (image_size // patch_size) ** 2
|
| 226 |
|
|
@@ -245,7 +239,7 @@ class Encoder(nn.Module):
|
|
| 245 |
|
| 246 |
# Extract the patches and apply a linear layer
|
| 247 |
batch_size = src.shape[0]
|
| 248 |
-
src = self.
|
| 249 |
|
| 250 |
# Add the learned positional embedding
|
| 251 |
src = src + self.pos_embedding
|
|
|
|
| 3 |
import math
|
| 4 |
|
| 5 |
|
| 6 |
+
class ExtractPatches(nn.Module):
|
| 7 |
+
def __init__(self, patch_size: int = 16):
|
| 8 |
super().__init__()
|
| 9 |
|
| 10 |
self.patch_size = patch_size
|
|
|
|
|
|
|
| 11 |
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
|
|
|
| 12 |
|
| 13 |
def forward(self, x):
|
| 14 |
batch_size, c, h, w = x.shape
|
| 15 |
|
| 16 |
+
# Unfold applies a slding window to generate patches
|
| 17 |
+
# The transpose and reshape change the shape to (batch_size, num_patches, 3 * patch_size * patch_size), flattening the patches
|
| 18 |
+
return (
|
| 19 |
+
self.unfold(x)
|
| 20 |
+
.transpose(1, 2)
|
| 21 |
+
.reshape(batch_size, -1, c * self.patch_size * self.patch_size)
|
| 22 |
+
)
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
# Positional Encoding
|
|
|
|
| 134 |
|
| 135 |
self.ffn = nn.Sequential(
|
| 136 |
nn.Linear(in_features=d_model, out_features=(d_model * 4)),
|
| 137 |
+
nn.ReLU(),
|
| 138 |
nn.Linear(in_features=(d_model * 4), out_features=d_model),
|
| 139 |
nn.Dropout(p=dropout),
|
| 140 |
)
|
|
|
|
| 213 |
|
| 214 |
self.patch_size = patch_size
|
| 215 |
|
| 216 |
+
self.extract_patches = ExtractPatches(patch_size=patch_size)
|
| 217 |
+
self.fc_in = nn.Linear(in_channels * patch_size * patch_size, d_model)
|
|
|
|
| 218 |
|
| 219 |
seq_length = (image_size // patch_size) ** 2
|
| 220 |
|
|
|
|
| 239 |
|
| 240 |
# Extract the patches and apply a linear layer
|
| 241 |
batch_size = src.shape[0]
|
| 242 |
+
src = self.fc_in(self.extract_patches(src))
|
| 243 |
|
| 244 |
# Add the learned positional embedding
|
| 245 |
src = src + self.pos_embedding
|