from functools import partial

from torch import nn


def activation_func(activation: str):
    return nn.ModuleDict([
        ['relu', nn.ReLU(inplace=True)],
        ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
        ['selu', nn.SELU(inplace=True)],
        ['none', nn.Identity()],
    ])[activation]


def norm_module(norm: str):
    return {
        'batch': nn.BatchNorm2d,
        'instance': nn.InstanceNorm2d,
    }[norm]
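
# Minimal sanity-check sketch (illustrative, not part of the original listing):
# activation_func returns ready-to-use module instances, while norm_module
# returns the layer *class*, to be instantiated later with a channel count.
assert isinstance(activation_func('leaky_relu'), nn.LeakyReLU)
assert norm_module('instance') is nn.InstanceNorm2d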
class Conv2dAuto(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dynamically set the padding from the kernel size so that odd
        # kernels preserve the spatial dimensions (at stride 1).
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)


conv3x3 = partial(Conv2dAuto, kernel_size=3)
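
# Illustrative check (assumed example, not part of the original listing):
# at stride 1, Conv2dAuto pads odd kernels so the spatial size is preserved.
import torch

_x = torch.randn(1, 32, 56, 56)
assert conv3x3(in_channels=32, out_channels=64)(_x).shape == (1, 64, 56, 56)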
class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, activation: str = 'relu'):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.blocks = nn.Identity()
        self.activate = activation_func(activation)
        self.shortcut = nn.Identity()

    def forward(self, x):
        residual = x
        if self.should_apply_shortcut:
            residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        x = self.activate(x)
        return x

    # Must be a property: forward reads it as a plain attribute, and a bare
    # bound method would always evaluate as truthy there.
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels
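
# Illustrative forward pass (assumed example, not part of the original
# listing): with self.blocks still nn.Identity, shapes pass through unchanged.
assert ResidualBlock(32, 32)(torch.randn(1, 32, 8, 8)).shape == (1, 32, 8, 8)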
class ResNetResidualBlock(ResidualBlock):
    def __init__(
        self, in_channels: int, out_channels: int,
        expansion: int = 1, downsampling: int = 1,
        conv=conv3x3, norm: str = 'batch', *args, **kwargs
    ):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.expansion, self.downsampling = expansion, downsampling
        self.conv, self.norm = conv, norm_module(norm)
        # 1x1 convolution that matches channels (and stride) on the identity path.
        self.shortcut = nn.Sequential(
            nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
                      stride=self.downsampling, bias=False),
            self.norm(self.expanded_channels),
        ) if self.should_apply_shortcut else None

    @property
    def expanded_channels(self):
        return self.out_channels * self.expansion

    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.expanded_channels
def conv_norm(in_channels: int, out_channels: int, conv, norm, *args, **kwargs):
    return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs),
                         norm(out_channels))
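
# Illustrative usage (assumed example, not part of the original listing):
# conv_norm pairs a convolution with its matching normalization layer.
_pair = conv_norm(32, 64, conv=conv3x3, norm=nn.BatchNorm2d, bias=False)
assert _pair(torch.randn(1, 32, 8, 8)).shape == (1, 64, 8, 8)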
class ResNetBasicBlock(ResNetResidualBlock):
    """
    Basic ResNet block composed of two 3x3 conv / norm / activation layers.
    """
    expansion = 1

    def __init__(
        self, in_channels: int, out_channels: int, bias: bool = False, *args, **kwargs
    ):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.blocks = nn.Sequential(
            conv_norm(
                self.in_channels, self.out_channels, conv=self.conv, norm=self.norm,
                bias=bias, stride=self.downsampling,
            ),
            self.activate,
            conv_norm(self.out_channels, self.expanded_channels,
                      conv=self.conv, norm=self.norm, bias=bias),
        )
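
# End-to-end sketch (assumed example, not part of the original listing):
# with downsampling=2 the main path halves the spatial size while the 1x1
# shortcut projects the input so the two branches can be summed.
_block = ResNetBasicBlock(32, 64, downsampling=2)
assert _block(torch.randn(1, 32, 56, 56)).shape == (1, 64, 28, 28)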