Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload model/archs/unet.py with huggingface_hub
Browse files- model/archs/unet.py +53 -0
    	
        model/archs/unet.py
    ADDED
    
    | @@ -0,0 +1,53 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            '''
         | 
| 2 | 
            +
            Codes are from:
         | 
| 3 | 
            +
            https://github.com/jaxony/unet-pytorch/blob/master/model.py
         | 
| 4 | 
            +
            '''
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import torch.nn as nn
         | 
| 8 | 
            +
            from diffusers import UNet2DModel
         | 
| 9 | 
            +
            import einops
         | 
| 10 | 
            +
            class UNetPP(nn.Module):
         | 
| 11 | 
            +
                '''
         | 
| 12 | 
            +
                    Wrapper for UNet in diffusers
         | 
| 13 | 
            +
                '''
         | 
| 14 | 
            +
                def __init__(self, in_channels):
         | 
| 15 | 
            +
                    super(UNetPP, self).__init__()
         | 
| 16 | 
            +
                    self.in_channels = in_channels
         | 
| 17 | 
            +
                    self.unet = UNet2DModel(
         | 
| 18 | 
            +
                            sample_size=[256, 256*3],
         | 
| 19 | 
            +
                            in_channels=in_channels,
         | 
| 20 | 
            +
                            out_channels=32,
         | 
| 21 | 
            +
                            layers_per_block=2,
         | 
| 22 | 
            +
                            block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
         | 
| 23 | 
            +
                            down_block_types=(
         | 
| 24 | 
            +
                                "DownBlock2D",
         | 
| 25 | 
            +
                                "DownBlock2D",
         | 
| 26 | 
            +
                                "DownBlock2D",
         | 
| 27 | 
            +
                                "AttnDownBlock2D",
         | 
| 28 | 
            +
                                "AttnDownBlock2D",
         | 
| 29 | 
            +
                                "AttnDownBlock2D",
         | 
| 30 | 
            +
                                "DownBlock2D",
         | 
| 31 | 
            +
                            ),
         | 
| 32 | 
            +
                            up_block_types=(
         | 
| 33 | 
            +
                                "UpBlock2D",
         | 
| 34 | 
            +
                                "AttnUpBlock2D",
         | 
| 35 | 
            +
                                "AttnUpBlock2D",
         | 
| 36 | 
            +
                                "AttnUpBlock2D",
         | 
| 37 | 
            +
                                "UpBlock2D",
         | 
| 38 | 
            +
                                "UpBlock2D",
         | 
| 39 | 
            +
                                "UpBlock2D",
         | 
| 40 | 
            +
                            ),
         | 
| 41 | 
            +
                        )
         | 
| 42 | 
            +
                         
         | 
| 43 | 
            +
                    # self.unet.enable_xformers_memory_efficient_attention()    
         | 
| 44 | 
            +
                    if in_channels > 12:
         | 
| 45 | 
            +
                        self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3]))
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def forward(self, x, t=256):
         | 
| 48 | 
            +
                    learned_plane = self.learned_plane
         | 
| 49 | 
            +
                    if x.shape[1] < self.in_channels:
         | 
| 50 | 
            +
                        learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
         | 
| 51 | 
            +
                        x = torch.cat([x, learned_plane], dim = 1)
         | 
| 52 | 
            +
                    return self.unet(x, t).sample
         | 
| 53 | 
            +
             | 
