Mariam-Elz commited on
Commit
6e6abb2
·
verified ·
1 Parent(s): 2981c34

Upload model/archs/unet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model/archs/unet.py +53 -53
model/archs/unet.py CHANGED
@@ -1,53 +1,53 @@
1
- '''
2
- Codes are from:
3
- https://github.com/jaxony/unet-pytorch/blob/master/model.py
4
- '''
5
-
6
- import torch
7
- import torch.nn as nn
8
- from diffusers import UNet2DModel
9
- import einops
10
- class UNetPP(nn.Module):
11
- '''
12
- Wrapper for UNet in diffusers
13
- '''
14
- def __init__(self, in_channels):
15
- super(UNetPP, self).__init__()
16
- self.in_channels = in_channels
17
- self.unet = UNet2DModel(
18
- sample_size=[256, 256*3],
19
- in_channels=in_channels,
20
- out_channels=32,
21
- layers_per_block=2,
22
- block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
23
- down_block_types=(
24
- "DownBlock2D",
25
- "DownBlock2D",
26
- "DownBlock2D",
27
- "AttnDownBlock2D",
28
- "AttnDownBlock2D",
29
- "AttnDownBlock2D",
30
- "DownBlock2D",
31
- ),
32
- up_block_types=(
33
- "UpBlock2D",
34
- "AttnUpBlock2D",
35
- "AttnUpBlock2D",
36
- "AttnUpBlock2D",
37
- "UpBlock2D",
38
- "UpBlock2D",
39
- "UpBlock2D",
40
- ),
41
- )
42
-
43
- # self.unet.enable_xformers_memory_efficient_attention()
44
- if in_channels > 12:
45
- self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3]))
46
-
47
- def forward(self, x, t=256):
48
- learned_plane = self.learned_plane
49
- if x.shape[1] < self.in_channels:
50
- learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
51
- x = torch.cat([x, learned_plane], dim = 1)
52
- return self.unet(x, t).sample
53
-
 
1
+ '''
2
+ Codes are from:
3
+ https://github.com/jaxony/unet-pytorch/blob/master/model.py
4
+ '''
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from diffusers import UNet2DModel
9
+ import einops
10
+ class UNetPP(nn.Module):
11
+ '''
12
+ Wrapper for UNet in diffusers
13
+ '''
14
+ def __init__(self, in_channels):
15
+ super(UNetPP, self).__init__()
16
+ self.in_channels = in_channels
17
+ self.unet = UNet2DModel(
18
+ sample_size=[256, 256*3],
19
+ in_channels=in_channels,
20
+ out_channels=32,
21
+ layers_per_block=2,
22
+ block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
23
+ down_block_types=(
24
+ "DownBlock2D",
25
+ "DownBlock2D",
26
+ "DownBlock2D",
27
+ "AttnDownBlock2D",
28
+ "AttnDownBlock2D",
29
+ "AttnDownBlock2D",
30
+ "DownBlock2D",
31
+ ),
32
+ up_block_types=(
33
+ "UpBlock2D",
34
+ "AttnUpBlock2D",
35
+ "AttnUpBlock2D",
36
+ "AttnUpBlock2D",
37
+ "UpBlock2D",
38
+ "UpBlock2D",
39
+ "UpBlock2D",
40
+ ),
41
+ )
42
+
43
+ self.unet.enable_xformers_memory_efficient_attention()
44
+ if in_channels > 12:
45
+ self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3]))
46
+
47
+ def forward(self, x, t=256):
48
+ learned_plane = self.learned_plane
49
+ if x.shape[1] < self.in_channels:
50
+ learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
51
+ x = torch.cat([x, learned_plane], dim = 1)
52
+ return self.unet(x, t).sample
53
+