Feature Extraction
Transformers
Safetensors
custom_code
gheinrich committed on
Commit bed3eea
1 Parent(s): b836cd6

Upload model

Files changed (2)
  1. hf_model.py +2 -1
  2. vitdet.py +173 -0
hf_model.py CHANGED
@@ -30,7 +30,8 @@ from .eradio_model import eradio
 from .radio_model import create_model_from_args
 from .radio_model import RADIOModel as RADIOModelBase, Resolution
 from .input_conditioner import get_default_conditioner, InputConditioner
-
+from .vit_patch_generator import ViTPatchGenerator
+from .vitdet import apply_vitdet_arch, VitDetArgs
 
 # Register extra models
 from .extra_timm_models import *
vitdet.py ADDED
@@ -0,0 +1,173 @@
+from collections import defaultdict
+from contextlib import contextmanager
+from logging import getLogger
+import math
+import sys
+from typing import List, Union, Iterable
+
+import numpy as np
+import torch
+from torch import nn
+
+from timm.models import VisionTransformer
+from einops import rearrange
+
+DEFAULT_NUM_WINDOWED = 5
+
+
+class VitDetArgs:
+    def __init__(self,
+                 window_size: int,
+                 num_summary_tokens: int,
+                 num_windowed: int = DEFAULT_NUM_WINDOWED,
+    ):
+        self.window_size = window_size
+        self.num_summary_tokens = num_summary_tokens
+        self.num_windowed = num_windowed
+
+
+def apply_vitdet_arch(model: VisionTransformer, args: VitDetArgs):
+    if isinstance(model, VisionTransformer):
+        patch_embed = getattr(model, 'patch_generator', model.patch_embed)
+
+        return ViTDetHook(patch_embed, model.blocks, args)
+    else:
+        print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr)
+
+
+class ViTDetHook:
+    def __init__(self,
+                 embedder: nn.Module,
+                 blocks: nn.Sequential,
+                 args: VitDetArgs,
+    ):
+        self.blocks = blocks
+        self.num_summary_tokens = args.num_summary_tokens
+        self.window_size = args.window_size
+
+        self._input_resolution = None
+        self._num_windows = None
+        self._cls_patch = None
+        self._order_cache = dict()
+
+        embedder.register_forward_pre_hook(self._enter_model)
+
+        # This will decide if we window-fy the patches
+        # and enable vit-det for this iteration, and if so,
+        # rearrange the patches for efficient mode switching
+        blocks.register_forward_pre_hook(self._enter_blocks)
+
+        is_global = True
+        period = args.num_windowed + 1
+        for i, layer in enumerate(blocks[:-1]):
+            ctr = i % period
+            if ctr == 0:
+                layer.register_forward_pre_hook(self._to_windows)
+                is_global = False
+            elif ctr == args.num_windowed:
+                layer.register_forward_pre_hook(self._to_global)
+                is_global = True
+
+        # Always ensure the final layer is a global layer
+        if not is_global:
+            blocks[-1].register_forward_pre_hook(self._to_global)
+
+        blocks.register_forward_hook(self._exit_model)
+
+    def _enter_model(self, _, input: List[torch.Tensor]):
+        self._input_resolution = input[0].shape[-2:]
+
+    def _enter_blocks(self, _, input: List[torch.Tensor]):
+        # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)
+
+        patches = input[0]
+        patches = self._rearrange_patches(patches)
+
+        return (patches,) + input[1:]
+
+    def _to_windows(self, _, input: List[torch.Tensor]):
+        patches = input[0]
+
+        if self.num_summary_tokens:
+            self._cls_patch = patches[:, :self.num_summary_tokens]
+            patches = patches[:, self.num_summary_tokens:]
+
+        patches = rearrange(
+            patches, 'b (p t) c -> (b p) t c',
+            p=self._num_windows, t=self.window_size ** 2,
+        )
+
+        return (patches,) + input[1:]
+
+    def _to_global(self, _, input: List[torch.Tensor]):
+        patches = input[0]
+
+        patches = rearrange(
+            patches, '(b p) t c -> b (p t) c',
+            p=self._num_windows, t=self.window_size ** 2,
+            b=patches.shape[0] // self._num_windows,
+        )
+
+        if self.num_summary_tokens:
+            patches = torch.cat([
+                self._cls_patch,
+                patches,
+            ], dim=1)
+
+        return (patches,) + input[1:]
+
+    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
+        # Return patches to their original order
+        patch_order = self._order_cache[self._input_resolution][0]
+        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
+
+        ret_patches = torch.empty_like(patches)
+        ret_patches = torch.scatter(
+            ret_patches,
+            dim=1,
+            index=patch_order,
+            src=patches,
+        )
+
+        return ret_patches
+
+    def _rearrange_patches(self, patches: torch.Tensor):
+        # We rearrange the patches so that we can efficiently
+        # switch between windowed and global mode by just
+        # reshaping the tensor
+
+        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
+        if patch_order is None:
+            num_feat_patches = patches.shape[1] - self.num_summary_tokens
+            num_pixels = self._input_resolution[0] * self._input_resolution[1]
+
+            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
+            rows = self._input_resolution[-2] // patch_size
+            cols = self._input_resolution[-1] // patch_size
+
+            w_rows = rows // self.window_size
+            w_cols = cols // self.window_size
+
+            patch_order = torch.arange(0, num_feat_patches, device=patches.device)
+
+            patch_order = rearrange(
+                patch_order, '(wy py wx px) -> (wy wx py px)',
+                wy=w_rows, wx=w_cols,
+                py=self.window_size, px=self.window_size,
+            )
+
+            if self.num_summary_tokens:
+                patch_order = torch.cat([
+                    torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
+                    patch_order + self.num_summary_tokens,
+                ])
+
+            self._num_windows = w_rows * w_cols
+            self._order_cache[self._input_resolution] = (
+                patch_order,
+                self._num_windows,
+            )
+
+        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
+        patches = torch.gather(patches, dim=1, index=patch_order)
+        return patches
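
For reference, a minimal usage sketch of the new module (not part of this commit): it registers the ViTDet-style windowed-attention hooks on a stock timm VisionTransformer. The model name and argument values below are illustrative assumptions; the only constraint implied by the code above is that window_size must evenly divide the patch grid (7 divides the 14x14 grid of a 224x224 input with 16x16 patches). It assumes vitdet.py is importable from the working directory.

# Hypothetical usage sketch -- not part of this commit.
import timm
import torch

from vitdet import apply_vitdet_arch, VitDetArgs

# A 224x224 input with 16x16 patches yields a 14x14 patch grid,
# so window_size=7 divides it evenly.
model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Registers forward(-pre) hooks so that most blocks attend within 7x7 windows,
# with a global-attention block after every `num_windowed` windowed blocks
# and at the end of the stack.
hook = apply_vitdet_arch(
    model,
    VitDetArgs(
        window_size=7,         # window edge length, in patches
        num_summary_tokens=1,  # the CLS token
        num_windowed=5,        # default number of windowed blocks per cycle
    ),
)

with torch.no_grad():
    # The forward pass now alternates windowed and global attention.
    out = model(torch.randn(1, 3, 224, 224))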