Commit 423995d · torch_dtype
Julien Blanchon committed · 1 Parent(s): 28e3661
app.py CHANGED
@@ -327,7 +327,8 @@ with gr.Blocks(css=css) as demo:
         fn=generate_image,
         inputs=[prompt],
         outputs=[result, seed],
-        cache_examples="lazy",
+        cache_examples=True,
+        cache_mode="lazy",
     )
 
     gr.on(
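For context: judging by the keyword arguments (fn/inputs/outputs/cache_examples), the call being edited is most likely a gr.Examples block. In Gradio 5.x the string form cache_examples="lazy" was deprecated in favour of a boolean cache_examples plus a separate cache_mode argument, which is what this hunk switches to. A minimal, self-contained sketch of the new spelling (the component names and the stub generate_image are placeholders, not the app's real ones):

import gradio as gr

# Placeholder for the app's real generator (assumed signature: prompt in, image + seed out).
def generate_image(prompt_text: str):
    return None, 0

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Image(label="Result")
    seed = gr.Number(label="Seed")

    # Gradio >= 5: cache_examples is a plain boolean, and cache_mode="lazy"
    # defers caching each example until the first time a user selects it.
    gr.Examples(
        examples=["a sample prompt"],
        fn=generate_image,
        inputs=[prompt],
        outputs=[result, seed],
        cache_examples=True,
        cache_mode="lazy",
    )

demo.launch()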
tim/models/nvidia_radio/radio/extra_models.py CHANGED
@@ -13,7 +13,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .forward_intermediates import forward_intermediates
 from .input_conditioner import InputConditioner
 
-_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')
+_has_torch_sdpa = hasattr(F, "scaled_dot_product_attention")
 
 
 class PaliGemmaWrapper(nn.Module):
@@ -52,18 +52,25 @@ class PaliGemmaWrapper(nn.Module):
         return self(x)
 
 
-def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):
-    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version
+def _get_paligemma_model(
+    repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16
+):
+    from transformers import (
+        PaliGemmaForConditionalGeneration,
+        __version__ as tx_version,
+    )
 
-    if LooseVersion(tx_version) > LooseVersion('4.44.2'):
-        warnings.warn(f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.')
+    if LooseVersion(tx_version) > LooseVersion("4.44.2"):
+        warnings.warn(
+            f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.'
+        )
 
     extra_args = dict()
 
     if dtype is not None:
-        extra_args['torch_dtype'] = dtype
-        rev = str(dtype).split('.')[-1]
-        extra_args['revision'] = rev
+        extra_args["dtype"] = dtype
+        rev = str(dtype).split(".")[-1]
+        extra_args["revision"] = rev
 
     model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)
 
@@ -73,22 +80,31 @@ def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype =
 
     return vis_model
 
+
 @register_model
 def paligemma_896_student(**kwargs):
-    model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)
+    model = _get_paligemma_model(
+        "google/paligemma-3b-pt-896", embed_dim=1152, dtype=None
+    )
 
     return model
 
 
 def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
     B, N, C = x.shape
-    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+    qkv = (
+        self.qkv(x)
+        .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        .permute(2, 0, 3, 1, 4)
+    )
 
     q, k, v = qkv[0], qkv[1], qkv[2]
     x = F.scaled_dot_product_attention(
-        q, k, v,
+        q,
+        k,
+        v,
         is_causal=False,
-        dropout_p=self.attn_drop.p if self.training else 0.,
+        dropout_p=self.attn_drop.p if self.training else 0.0,
         scale=self.scale,
     )
     x = x.transpose(1, 2).reshape(B, N, C)
@@ -96,11 +112,14 @@ def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
     x = self.proj_drop(x)
     return x
 
-def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
+
+def _load_dino_v2(
+    dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs
+):
     if cache_dir:
         torch.hub.set_dir(cache_dir)
     model: nn.Module = torch.hub.load(
-        'facebookresearch/dinov2',
+        "facebookresearch/dinov2",
         dino_v2_model,
         pretrained=pretrained,
         # **kwargs,
@@ -108,11 +127,12 @@ def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=Tru
 
     if _has_torch_sdpa:
         for n, m in model.named_modules():
-            if n.endswith('.attn'):
+            if n.endswith(".attn"):
                 m.forward = MethodType(dv2_sdpa, m)
 
     return model
 
+
 class DinoWrapper(nn.Module):
     def __init__(self, dino_model: nn.Module):
         super().__init__()
@@ -130,11 +150,11 @@ class DinoWrapper(nn.Module):
 
     @property
     def num_cls_tokens(self):
-        return getattr(self.inner, 'num_tokens', 1)
+        return getattr(self.inner, "num_tokens", 1)
 
     @property
     def num_registers(self):
-        return getattr(self.inner, 'num_register_tokens', 0)
+        return getattr(self.inner, "num_register_tokens", 0)
 
     @property
     def num_summary_tokens(self):
@@ -147,8 +167,8 @@ class DinoWrapper(nn.Module):
     def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
         parts = self.inner.forward_features(*args, **kwargs)
 
-        cls_token = parts['x_norm_clstoken']
-        features = parts['x_norm_patchtokens']
+        cls_token = parts["x_norm_clstoken"]
+        features = parts["x_norm_patchtokens"]
 
         return cls_token, features
 
@@ -157,12 +177,13 @@ class DinoWrapper(nn.Module):
         x = self.inner.blocks(x)
         x_norm = self.inner.norm(x)
 
-        return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]
+        return x_norm[:, 0], x_norm[:, self.num_summary_tokens :]
 
     def patchify(self, x: torch.Tensor) -> torch.Tensor:
         return self.inner.prepare_tokens_with_masks(x)
 
-    def forward_intermediates(self,
+    def forward_intermediates(
+        self,
         x: torch.Tensor,
         norm: bool = False,
         **kwargs,
@@ -199,8 +220,9 @@ def _dino_student(arch: str, **kwargs):
 
 @register_model
 def dino_v2_l_student(**kwargs):
-    return _dino_student('dinov2_vitl14_reg', **kwargs)
+    return _dino_student("dinov2_vitl14_reg", **kwargs)
+
 
 @register_model
 def dino_v2_g_student(**kwargs):
-    return _dino_student('dinov2_vitg14_reg', **kwargs)
+    return _dino_student("dinov2_vitg14_reg", **kwargs)
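One behavioural piece worth highlighting in the hunks above (it predates this commit but survives the reformatting): _load_dino_v2 rebinds the forward of every .attn submodule to dv2_sdpa, so DINOv2 attention runs through F.scaled_dot_product_attention when it is available. A toy sketch of that instance-level rebinding pattern, using a stand-in module instead of the real DINOv2 blocks:

from types import MethodType

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Stand-in for a DINOv2 attention block (hypothetical, for illustration only).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Replacement forward; in extra_models.py this role is played by dv2_sdpa,
    # which routes attention through F.scaled_dot_product_attention.
    return x * 2


model = nn.Sequential(ToyBlock(), ToyBlock())
for name, module in model.named_modules():
    if isinstance(module, ToyBlock):
        # MethodType binds the free function to this specific instance, so
        # module.forward(x) behaves like a normal bound method -- the same
        # trick as `m.forward = MethodType(dv2_sdpa, m)` in the diff.
        module.forward = MethodType(patched_forward, module)

print(model(torch.ones(1)))  # tensor([4.]) -- both patched forwards applied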
tim/models/utils/text_encoders.py CHANGED
@@ -13,14 +13,14 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
             text_encoder_dir,
             attn_implementation="flash_attention_2",
             device_map="cpu",
-            torch_dtype=weight_dtype,
+            dtype=weight_dtype,
         ).model
     elif "t5" in text_encoder_dir:
         text_encoder = T5EncoderModel.from_pretrained(
             text_encoder_dir,
             attn_implementation="sdpa",
             device_map="cpu",
-            torch_dtype=weight_dtype,
+            dtype=weight_dtype,
         )
     else:
         raise NotImplementedError
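This is the same rename that gives the commit its title: every from_pretrained call site now passes dtype= instead of torch_dtype=, since recent transformers releases accept dtype and flag the old torch_dtype spelling as deprecated. A minimal sketch of the new-style call, assuming a recent transformers version and using a small public T5 checkpoint purely for illustration:

import torch
from transformers import T5EncoderModel

# dtype= sets the precision the weights are loaded in; older code spelled the
# same argument torch_dtype= (now deprecated in recent transformers releases).
text_encoder = T5EncoderModel.from_pretrained(
    "google/flan-t5-small",
    dtype=torch.bfloat16,
)
print(next(text_encoder.parameters()).dtype)  # torch.bfloat16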