```python
import torch
import torch.nn as nn


class UpsampleBlock(nn.Module):
    """Conv -> 2x transposed-conv upsample -> conv; doubles spatial resolution while remapping channels."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_in = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.GELU()
        )
        # stride-2 transposed conv does the actual 2x upsample
        self.conv_up = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.GELU()
        )
        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.conv_in(x)
        x = self.conv_up(x)
        x = self.conv_out(x)
        return x
```
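As a quick sanity check on the block above, here is a minimal shape probe (an assumed example, not part of the original listing); with the 768→192 channel configuration used further down, a 56×56 feature map should come out at 112×112:

```python
# Illustrative shape check for UpsampleBlock (assumed values, not from the original source).
block = UpsampleBlock(in_channels=768, out_channels=192)
feat = torch.randn(1, 768, 56, 56)
out = block(feat)
print(out.shape)  # torch.Size([1, 192, 112, 112]) - spatial size doubled, channels reduced
```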
```python
class CLIPImagePreProcessor(nn.Module):
    def __init__(
            self,
            input_size=896,
            clip_input_size=224,
            downscale_factor: int = 16,
    ):
        super().__init__()
        # make sure they are evenly divisible
        assert input_size % clip_input_size == 0
        in_channels = 3

        self.input_size = input_size
        self.clip_input_size = clip_input_size
        self.downscale_factor = downscale_factor

        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 16 ** 2 = 768
        channels = subpixel_channels
        upscale_factor = downscale_factor / int(input_size / clip_input_size)  # 16 / (896 / 224) = 4
        num_upsample_blocks = int(upscale_factor // 2)  # 4 // 2 = 2

        # make the residual down/up blocks
        self.upsample_blocks = nn.ModuleList()
        self.subpixel_blocks = nn.ModuleList()
        current_channels = channels
        current_downscale = downscale_factor
        for _ in range(num_upsample_blocks):
            # determine the reshuffled channel count for this resolution
            output_downscale = current_downscale // 2
            out_channels = in_channels * output_downscale ** 2
            # out_channels = current_channels // 2
            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
            current_channels = out_channels
            current_downscale = output_downscale
            self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
        # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
        # (bs, 192, 112, 112) -> (bs, 48, 224, 224)

        self.conv_out = nn.Conv2d(
            current_channels,
            out_channels=3,
            kernel_size=3,
            padding=1
        )  # (bs, 48, 224, 224) -> (bs, 3, 224, 224)

        # pooling layer that downscales the resized input by input_size // clip_input_size
        # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
        kernel_size = input_size // clip_input_size
        self.res_down = nn.AvgPool2d(
            kernel_size=kernel_size,
            stride=kernel_size
        )  # (bs, 3, 896, 896) -> (bs, 3, 224, 224)

        # blending weight for the output residual, initialized near 0
        self.res_blend = nn.Parameter(torch.tensor(0.001))

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 896, 896) -> (bs, 768, 56, 56)

        self.conv_in = nn.Sequential(
            nn.Conv2d(
                subpixel_channels,
                channels,
                kernel_size=3,
                padding=1
            ),
            nn.GELU()
        )  # (bs, 768, 56, 56) -> (bs, 768, 56, 56)

    def forward(self, x):
        # resize to input_size x input_size first, so the residual paths below see matching shapes
        x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
        inputs = x
        res = self.res_down(inputs)

        x = self.unshuffle(x)
        x = self.conv_in(x)
        for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
            x = up(x)
            block_res = subpixel(inputs)
            x = x + block_res
        x = self.conv_out(x)
        # blend residual
        x = x * self.res_blend + res
        return x
```
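A usage sketch, assuming the default 896→224 configuration; the batch size and the 512×512 input resolution are arbitrary illustration values, the point being that any input is resized internally and comes out at the 3×224×224 shape a CLIP vision encoder expects:

```python
# Usage sketch (illustrative values; not from the original source).
if __name__ == "__main__":
    pre = CLIPImagePreProcessor(input_size=896, clip_input_size=224, downscale_factor=16)
    images = torch.randn(2, 3, 512, 512)  # any resolution; forward() resizes to 896x896
    clip_ready = pre(images)
    print(clip_ready.shape)  # torch.Size([2, 3, 224, 224])
```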
