Julien Blanchon committed
Commit 423995d · Parent(s): 28e3661

torch_dtype

Files changed:
- app.py +2 -1
- tim/models/nvidia_radio/radio/extra_models.py +45 -23
- tim/models/utils/text_encoders.py +2 -2
app.py
CHANGED
@@ -327,7 +327,8 @@ with gr.Blocks(css=css) as demo:
         fn=generate_image,
         inputs=[prompt],
         outputs=[result, seed],
-        cache_examples=
+        cache_examples=True,
+        cache_mode="lazy",
     )
 
     gr.on(
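The app.py change switches example caching on and uses Gradio's lazy mode, so cached outputs are generated the first time an example is clicked rather than at startup. A minimal, self-contained sketch of how these two arguments sit in a gr.Examples call; the layout, example prompt, and placeholder generate_image below are illustrative, not the app's actual code:

import gradio as gr

def generate_image(prompt):
    # Placeholder: the real app runs its generation pipeline here and
    # returns the generated image together with the seed that was used.
    return None, 0

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Image(label="Result")
    seed = gr.Number(label="Seed")

    gr.Examples(
        examples=[["a cat wearing a spacesuit"]],
        fn=generate_image,
        inputs=[prompt],
        outputs=[result, seed],
        cache_examples=True,   # reuse generated outputs instead of rerunning the model
        cache_mode="lazy",     # fill the cache on first request, not at app startup
    )

demo.launch()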
tim/models/nvidia_radio/radio/extra_models.py
CHANGED
@@ -13,7 +13,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .forward_intermediates import forward_intermediates
 from .input_conditioner import InputConditioner
 
-_has_torch_sdpa = hasattr(F,
+_has_torch_sdpa = hasattr(F, "scaled_dot_product_attention")
 
 
 class PaliGemmaWrapper(nn.Module):
@@ -52,18 +52,25 @@ class PaliGemmaWrapper(nn.Module):
         return self(x)
 
 
-def _get_paligemma_model(
-
+def _get_paligemma_model(
+    repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16
+):
+    from transformers import (
+        PaliGemmaForConditionalGeneration,
+        __version__ as tx_version,
+    )
 
-    if LooseVersion(tx_version) > LooseVersion(
-        warnings.warn(
+    if LooseVersion(tx_version) > LooseVersion("4.44.2"):
+        warnings.warn(
+            f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.'
+        )
 
     extra_args = dict()
 
     if dtype is not None:
-        extra_args[
-        rev = str(dtype).split(
-        extra_args[
+        extra_args["dtype"] = dtype
+        rev = str(dtype).split(".")[-1]
+        extra_args["revision"] = rev
 
     model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)
 
@@ -73,22 +80,31 @@ def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype =
 
     return vis_model
 
+
 @register_model
 def paligemma_896_student(**kwargs):
-    model = _get_paligemma_model(
+    model = _get_paligemma_model(
+        "google/paligemma-3b-pt-896", embed_dim=1152, dtype=None
+    )
 
     return model
 
 
 def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
     B, N, C = x.shape
-    qkv =
+    qkv = (
+        self.qkv(x)
+        .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        .permute(2, 0, 3, 1, 4)
+    )
 
     q, k, v = qkv[0], qkv[1], qkv[2]
     x = F.scaled_dot_product_attention(
-        q,
+        q,
+        k,
+        v,
         is_causal=False,
-        dropout_p=self.attn_drop.p if self.training else 0
+        dropout_p=self.attn_drop.p if self.training else 0.0,
         scale=self.scale,
     )
     x = x.transpose(1, 2).reshape(B, N, C)
@@ -96,11 +112,14 @@ def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
     x = self.proj_drop(x)
     return x
 
-
+
+def _load_dino_v2(
+    dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs
+):
     if cache_dir:
         torch.hub.set_dir(cache_dir)
     model: nn.Module = torch.hub.load(
-
+        "facebookresearch/dinov2",
         dino_v2_model,
         pretrained=pretrained,
         # **kwargs,
@@ -108,11 +127,12 @@ def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=Tru
 
     if _has_torch_sdpa:
         for n, m in model.named_modules():
-            if n.endswith(
+            if n.endswith(".attn"):
                 m.forward = MethodType(dv2_sdpa, m)
 
     return model
 
+
 class DinoWrapper(nn.Module):
     def __init__(self, dino_model: nn.Module):
         super().__init__()
@@ -130,11 +150,11 @@ class DinoWrapper(nn.Module):
 
     @property
     def num_cls_tokens(self):
-        return getattr(self.inner,
+        return getattr(self.inner, "num_tokens", 1)
 
     @property
     def num_registers(self):
-        return getattr(self.inner,
+        return getattr(self.inner, "num_register_tokens", 0)
 
     @property
     def num_summary_tokens(self):
@@ -147,8 +167,8 @@ class DinoWrapper(nn.Module):
     def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
         parts = self.inner.forward_features(*args, **kwargs)
 
-        cls_token = parts[
-        features = parts[
+        cls_token = parts["x_norm_clstoken"]
+        features = parts["x_norm_patchtokens"]
 
         return cls_token, features
 
@@ -157,12 +177,13 @@ class DinoWrapper(nn.Module):
         x = self.inner.blocks(x)
         x_norm = self.inner.norm(x)
 
-        return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]
+        return x_norm[:, 0], x_norm[:, self.num_summary_tokens :]
 
     def patchify(self, x: torch.Tensor) -> torch.Tensor:
         return self.inner.prepare_tokens_with_masks(x)
 
-    def forward_intermediates(
+    def forward_intermediates(
+        self,
         x: torch.Tensor,
         norm: bool = False,
         **kwargs,
@@ -199,8 +220,9 @@ def _dino_student(arch: str, **kwargs):
 
 @register_model
 def dino_v2_l_student(**kwargs):
-    return _dino_student(
+    return _dino_student("dinov2_vitl14_reg", **kwargs)
+
 
 @register_model
 def dino_v2_g_student(**kwargs):
-    return _dino_student(
+    return _dino_student("dinov2_vitg14_reg", **kwargs)
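Most of the extra_models.py diff is a formatting reflow, but the pattern underneath is worth spelling out: when torch provides the fused scaled_dot_product_attention kernel, a replacement forward is bound onto every DINOv2 ".attn" submodule at load time via types.MethodType. A minimal, self-contained sketch of that pattern; sdpa_forward and patch_attention are illustrative names (the real code binds dv2_sdpa inside _load_dino_v2):

from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F

_has_torch_sdpa = hasattr(F, "scaled_dot_product_attention")


def sdpa_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same math as a standard ViT attention block, but routed through the
    # fused SDPA kernel; relies on the module's qkv, num_heads, scale,
    # attn_drop, proj and proj_drop attributes (present on DINOv2's Attention).
    B, N, C = x.shape
    qkv = (
        self.qkv(x)
        .reshape(B, N, 3, self.num_heads, C // self.num_heads)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv[0], qkv[1], qkv[2]
    x = F.scaled_dot_product_attention(
        q,
        k,
        v,
        is_causal=False,
        dropout_p=self.attn_drop.p if self.training else 0.0,
        scale=self.scale,
    )
    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    return self.proj_drop(x)


def patch_attention(model: nn.Module) -> nn.Module:
    # Bind the replacement forward onto each attention module, as the diff
    # does for the backbone returned by torch.hub.load("facebookresearch/dinov2", ...).
    if _has_torch_sdpa:
        for name, module in model.named_modules():
            if name.endswith(".attn"):
                module.forward = MethodType(sdpa_forward, module)
    return model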
tim/models/utils/text_encoders.py
CHANGED
@@ -13,14 +13,14 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
             text_encoder_dir,
             attn_implementation="flash_attention_2",
             device_map="cpu",
-
+            dtype=weight_dtype,
         ).model
     elif "t5" in text_encoder_dir:
         text_encoder = T5EncoderModel.from_pretrained(
             text_encoder_dir,
             attn_implementation="sdpa",
             device_map="cpu",
-
+            dtype=weight_dtype,
         )
     else:
         raise NotImplementedError
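The text encoder change is the one the commit title points at: the keyword that sets the loading precision in from_pretrained. Recent transformers releases accept dtype= (the older torch_dtype= spelling is being phased out), so passing weight_dtype through keeps the encoders in the intended precision instead of defaulting to float32. A minimal sketch with an illustrative checkpoint name (the app supplies its own text_encoder_dir and weight_dtype):

import torch
from transformers import T5EncoderModel

text_encoder = T5EncoderModel.from_pretrained(
    "google/t5-v1_1-xxl",         # illustrative; the app passes text_encoder_dir
    attn_implementation="sdpa",
    device_map="cpu",
    dtype=torch.bfloat16,          # newer keyword; older releases spell this torch_dtype
)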