# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File   : SGFMT.py
# Design based on the ViT (Vision Transformer)
import torch.nn as nn

from net.IntmdSequential import IntermediateSequential


# Implements multi-head self-attention; serves as the bottleneck of the U-Net.
class SelfAttention(nn.Module):
    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        # Scale dot products by 1/sqrt(head_dim) unless an explicit qk_scale is given.
        self.scale = qk_scale or head_dim ** -0.5

        # A single linear layer produces the concatenated query, key and value projections.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)
    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, C//heads) -> (3, B, heads, N, C//heads)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product attention over the token dimension.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back into the channel dimension and project.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# Residual: adds a skip connection around the wrapped module.
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


# PreNorm: applies LayerNorm before the wrapped module.
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


# PreNormDrop: LayerNorm before the wrapped module, dropout after it.
class PreNormDrop(nn.Module):
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        return self.dropout(self.fn(self.norm(x)))
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        )

    def forward(self, x):
        return self.net(x)
class TransformerModel(nn.Module):
    def __init__(
        self,
        dim,       # 512
        depth,     # 4
        heads,     # 8
        mlp_dim,   # 4096
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend(
                [
                    Residual(
                        PreNormDrop(
                            dim,
                            dropout_rate,
                            SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
                        )
                    ),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
                    ),
                ]
            )
            # dim = dim / 2
        self.net = IntermediateSequential(*layers)

    def forward(self, x):
        return self.net(x)
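

# Minimal usage sketch (an assumption, not part of the original file): it shows
# the tensor shapes expected by TransformerModel when configured with the
# values noted in the constructor comments above (dim=512, depth=4, heads=8,
# mlp_dim=4096). IntermediateSequential is assumed to behave like a Sequential
# container that may also return intermediate features, as suggested by
# net/IntmdSequential.py.
if __name__ == "__main__":
    import torch

    model = TransformerModel(dim=512, depth=4, heads=8, mlp_dim=4096)
    tokens = torch.randn(2, 64, 512)  # (batch, num_tokens, dim)
    out = model(tokens)
    # Depending on IntermediateSequential's configuration, `out` may be a
    # (2, 64, 512) tensor or also include intermediate layer outputs.
    print(type(out))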