Upload glide_text2im/clip/utils.py
glide_text2im/clip/utils.py  ADDED  (+97 -0)
@@ -0,0 +1,97 @@
import math
from typing import Callable, Optional

import attr
import torch
import torch.nn as nn
import torch.nn.functional as F

FilterFn = Callable[[torch.Tensor], torch.Tensor]


# Identity in the forward pass; the backward pass zeroes the gradient of the middle of
# three equal chunks (the key slice of a fused q/k/v bias, going by the name).
class ZeroKeyBiasGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, output_grad):
        output_grad = output_grad.clone()
        output_grad.chunk(3)[1].zero_()
        return output_grad


def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
    return ZeroKeyBiasGrad.apply(x)


# LayerNorm whose gain/bias are excluded from weight decay and which normalizes in float32.
@attr.s(eq=False, repr=False)
class LayerNorm(nn.Module):
    n_state: int = attr.ib()
    eps: float = attr.ib(default=1e-6)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device))
        self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device))
        self.g.weight_decay_level = "disable"  # type: ignore
        self.b.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps
        )


# Linear layer created directly on the target device; the bias can be routed through
# bias_filter_fn (e.g. zero_key_bias_grad) in the forward pass.
@attr.s(eq=False, repr=False)
class Affine(nn.Module):
    n_in: int = attr.ib()
    n_out: int = attr.ib()
    use_bias: bool = attr.ib(default=True)
    use_admnet_init: bool = attr.ib(default=False)
    std: Optional[float] = attr.ib(default=None)
    extra_init_scale: Optional[float] = attr.ib(default=None)
    bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        if not self.use_admnet_init:
            self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out))
            self.std = (
                self.std if self.extra_init_scale is None else self.std * self.extra_init_scale
            )

            # The weight is allocated uninitialized (torch.empty); in practice it is
            # presumably filled from a pretrained checkpoint.
            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
            self.w = nn.Parameter(w)

            if self.use_bias:
                self.b = nn.Parameter(
                    torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)
                )
                self.b.weight_decay_level = "disable"  # type: ignore
        else:
            if self.extra_init_scale is not None:
                raise ValueError("extra_init_scale incompatible with admnet init")

            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)

            if self.use_bias:
                b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device)

            self.w = nn.Parameter(w)

            if self.use_bias:
                self.b = nn.Parameter(b)
                self.b.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast the parameters to the input dtype and optionally filter the bias before
        # the linear projection.
        w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype)
        b = (
            self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype))
            if self.use_bias
            else None
        )
        return F.linear(x, w, b)
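
A minimal usage sketch, not part of the uploaded file, under stated assumptions: the package is importable as glide_text2im, the layers are built on CPU via the device argument, the feature width of 512 and the batch/sequence sizes are arbitrary, and the torch.nn.init.normal_ call stands in for loading real weights (the file allocates self.w with torch.empty).

import torch

from glide_text2im.clip.utils import Affine, LayerNorm, zero_key_bias_grad

cpu = torch.device("cpu")

# Hypothetical sizes: 512-wide features projected to a fused q/k/v of width 3 * 512.
ln = LayerNorm(n_state=512, device=cpu)
proj = Affine(n_in=512, n_out=3 * 512, bias_filter_fn=zero_key_bias_grad, device=cpu)
torch.nn.init.normal_(proj.w, std=proj.std)  # stand-in for checkpoint weights

x = torch.randn(2, 16, 512)
qkv = proj(ln(x))
qkv.sum().backward()

# The middle third of the fused bias (the key bias) receives no gradient.
print(proj.b.grad.chunk(3)[1].abs().max())  # tensor(0.)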