lcipolina committed
Commit 41e7160
1 Parent(s): 9741d33

Upload glide_text2im/clip/utils.py

Files changed (1)
  1. glide_text2im/clip/utils.py +97 -0
glide_text2im/clip/utils.py ADDED
@@ -0,0 +1,97 @@
+import math
+from typing import Callable, Optional
+
+import attr
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+FilterFn = Callable[[torch.Tensor], torch.Tensor]
+
+
+class ZeroKeyBiasGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        return x
+
+    @staticmethod
+    def backward(ctx, output_grad):
+        output_grad = output_grad.clone()
+        output_grad.chunk(3)[1].zero_()
+        return output_grad
+
+
+def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
+    return ZeroKeyBiasGrad.apply(x)
+
+
+@attr.s(eq=False, repr=False)
+class LayerNorm(nn.Module):
+    n_state: int = attr.ib()
+    eps: float = attr.ib(default=1e-6)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+        self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device))
+        self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device))
+        self.g.weight_decay_level = "disable"  # type: ignore
+        self.b.weight_decay_level = "disable"  # type: ignore
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.layer_norm(
+            x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps
+        )
+
+
+@attr.s(eq=False, repr=False)
+class Affine(nn.Module):
+    n_in: int = attr.ib()
+    n_out: int = attr.ib()
+    use_bias: bool = attr.ib(default=True)
+    use_admnet_init: bool = attr.ib(default=False)
+    std: Optional[float] = attr.ib(default=None)
+    extra_init_scale: Optional[float] = attr.ib(default=None)
+    bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        if not self.use_admnet_init:
+            self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out))
+            self.std = (
+                self.std if self.extra_init_scale is None else self.std * self.extra_init_scale
+            )
+
+            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
+            self.w = nn.Parameter(w)
+
+            if self.use_bias:
+                self.b = nn.Parameter(
+                    torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)
+                )
+                self.b.weight_decay_level = "disable"  # type: ignore
+        else:
+            if self.extra_init_scale is not None:
+                raise ValueError("extra_init_scale incompatible with admnet init")
+
+            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
+
+            if self.use_bias:
+                b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device)
+
+            self.w = nn.Parameter(w)
+
+            if self.use_bias:
+                self.b = nn.Parameter(b)
+                self.b.weight_decay_level = "disable"  # type: ignore
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype)
+        b = (
+            self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype))
+            if self.use_bias
+            else None
+        )
+        return F.linear(x, w, b)
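
For context, zero_key_bias_grad is the identity in the forward pass; in the backward pass it clones the incoming gradient and zeroes its middle third (chunk(3)[1]). Applied to a fused query/key/value bias, this suppresses the gradient of the key bias while leaving the query and value biases trainable. A minimal sketch of that behavior, assuming the file is importable as glide_text2im.clip.utils and using an illustrative width d = 4:

import torch

from glide_text2im.clip.utils import zero_key_bias_grad

d = 4
qkv_bias = torch.zeros(3 * d, requires_grad=True)  # fused (q, k, v) bias

out = zero_key_bias_grad(qkv_bias)  # identity in the forward pass
out.sum().backward()                # gradient of sum() is all ones

q_grad, k_grad, v_grad = qkv_bias.grad.chunk(3)
print(q_grad)  # ones: the query bias receives its gradient
print(k_grad)  # zeros: the key-bias gradient is suppressed
print(v_grad)  # ones: the value bias receives its gradient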
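
Similarly, a minimal sketch of constructing the attrs-based LayerNorm and Affine modules (sizes are illustrative; the device field defaults to CUDA, so it is overridden for CPU here). With use_admnet_init=False, Affine.w is allocated via torch.empty and left uninitialized, presumably to be overwritten by a loaded checkpoint, so the sketch initializes it by hand from the computed std:

import torch

from glide_text2im.clip.utils import Affine, LayerNorm

cpu = torch.device("cpu")

ln = LayerNorm(n_state=8, device=cpu)        # learned gain/bias, computed in float32
proj = Affine(n_in=8, n_out=16, device=cpu)  # w is empty until explicitly initialized

with torch.no_grad():
    proj.w.normal_(std=proj.std)  # illustrative init; std = sqrt(2 / (n_in + n_out))

x = torch.randn(2, 8)
y = proj(ln(x))
print(y.shape)  # torch.Size([2, 16])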