Feature Extraction
Transformers
Safetensors
custom_code
gheinrich committed on
Commit
a50b54b
1 Parent(s): 2f947ba

Upload model

Browse files
Files changed (4) hide show
  1. adaptor_generic.py +29 -0
  2. adaptor_mlp.py +150 -0
  3. adaptor_registry.py +37 -0
  4. hf_model.py +5 -1
adaptor_generic.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
15
+ from .adaptor_mlp import create_mlp_from_state
16
+
17
+
18
class GenericAdaptor(AdaptorBase):
    """Adaptor that rebuilds a summary head and a feature head from a state dict.

    Both heads are created by ``create_mlp_from_state`` using
    ``main_config.mlp_version``. The summary head is built from the keys under
    the ``'summary.'`` prefix and the feature head from the keys under
    ``'feature.'``. ``adaptor_config`` is accepted but not used by this class.
    """

    def __init__(self, main_config: Namespace, adaptor_config, state):
        super().__init__()

        mlp_version = main_config.mlp_version
        self.head_mlp = create_mlp_from_state(mlp_version, state, 'summary.')
        self.feat_mlp = create_mlp_from_state(mlp_version, state, 'feature.')

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Pass ``input.summary`` and ``input.features`` through their heads."""
        return RadioOutput(
            self.head_mlp(input.summary),
            self.feat_mlp(input.features),
        )
adaptor_mlp.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ import math
9
+ from typing import Dict
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Block
16
+
17
+
18
class MLP(nn.Module):
    """Feed-forward head: Linear -> LayerNorm -> ReLU -> inner stack -> Linear.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of ``fc1`` and of every inner layer.
        output_size: Size of the last dimension produced by ``fc2``.
        num_inner: Number of extra Linear/LayerNorm/ReLU triples placed
            between the first activation and ``fc2``.
        device: Device passed to every parameterised layer at construction.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0, device: torch.device = None, **kwargs):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
        self.norm = nn.LayerNorm(hidden_size, device=device)
        self.relu = nn.ReLU()

        # The inner layers are kept in ONE flat list, so the resulting
        # Sequential numbers them 0, 1, 2, 3, ... across all triples.
        inner_layers = []
        for _ in range(num_inner):
            inner_layers.append(nn.Linear(hidden_size, hidden_size, device=device))
            inner_layers.append(nn.LayerNorm(hidden_size, device=device))
            inner_layers.append(nn.ReLU())
        self.inner = nn.Sequential(*inner_layers) if inner_layers else nn.Identity()

        self.fc2 = nn.Linear(hidden_size, output_size, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the head to ``x`` and return the projected tensor."""
        hidden = self.relu(self.norm(self.fc1(x)))
        return self.fc2(self.inner(hidden))
47
+
48
+
49
class MLP2(nn.Module):
    """Residual MLP head with optional input pre-norm and token upsampling.

    Layout: optional (LayerNorm, GELU) on the input, ``fc1``, ``num_inner``
    residual blocks of (LayerNorm, GELU, Linear), then a final
    (LayerNorm, GELU, Linear) projection.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Base hidden width; multiplied by ``upsample_factor``.
        output_size: Per-token output width after upsampling. The final
            Linear emits ``output_size * upsample_factor ** 2`` channels.
        num_inner: Number of residual blocks.
        pre_norm: If true, apply LayerNorm + GELU to the input before ``fc1``.
        device: Device passed to every parameterised layer at construction.
        upsample_factor: When greater than 1, ``forward`` redistributes the
            extra output channels into additional tokens.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0,
                 pre_norm: bool = False, device: torch.device = None,
                 upsample_factor: int = 1,
                 **kwargs):
        super().__init__()

        # The pre-norm LayerNorm receives `device` like every other
        # parameterised layer, so the whole module is built on one device.
        self.pre_norm = nn.Sequential(
            nn.LayerNorm(input_size, device=device),
            nn.GELU(),
        ) if pre_norm else nn.Identity()

        self.upsample_factor = upsample_factor
        # Remember the caller-visible channel count before it is scaled below.
        self._real_output_dim = output_size

        hidden_size *= upsample_factor
        output_size *= (upsample_factor ** 2)

        self.fc1 = nn.Linear(input_size, hidden_size, device=device)

        blocks = []
        for _ in range(num_inner):
            blocks.append(nn.Sequential(
                nn.LayerNorm(hidden_size, device=device),
                nn.GELU(),
                nn.Linear(hidden_size, hidden_size, device=device),
            ))
        self.blocks = nn.ModuleList(blocks)

        self.final = nn.Sequential(
            nn.LayerNorm(hidden_size, device=device),
            nn.GELU(),
            nn.Linear(hidden_size, output_size, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the head to ``x``; with upsampling, expand the token axis."""
        x = self.pre_norm(x)
        x = self.fc1(x)
        for block in self.blocks:
            x = x + block(x)
        x = self.final(x)

        if self.upsample_factor > 1:
            # Treats dim 1 as a square h*w token grid.
            # NOTE(review): a non-square token count is not handled here;
            # presumably callers guarantee a square grid - confirm.
            h = w = int(math.sqrt(x.shape[1]))
            # NOTE(review): output token order is (u1 h u2 w), not
            # (h u1 w u2); kept as-is - confirm it is intentional.
            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (u1 h u2 w) c',
                          h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                          c=self._real_output_dim)

        return x
99
+
100
+
101
# Lookup table from MLP version string to the class implementing that version.
MLP_FACTORY = dict(v1=MLP, v2=MLP2)
105
+
106
+
107
def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
    """Return a new dict with only the keys starting with ``prefix``, prefix removed.

    The input mapping is not modified; values are carried over unchanged.
    """
    offset = len(prefix)
    stripped = {}
    for key, value in state.items():
        if key.startswith(prefix):
            stripped[key[offset:]] = value
    return stripped
114
+
115
+
116
def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
    """Infer MLP constructor sizes from the tensors in a state dict.

    Args:
        version: MLP layout identifier, ``'v1'`` or ``'v2'``.
        state: Mapping from parameter names to tensors.
        prefix: Only keys starting with this prefix are considered; the
            prefix is stripped before lookup.

    Returns:
        ``(input_dim, hidden_dim, output_dim, num_inner)`` where the
        dimensions are read from the ``fc1`` and output-projection weight
        shapes and ``num_inner`` is the number of consecutive inner layers
        found.

    Raises:
        ValueError: if ``version`` is not a supported layout.
        KeyError: if a required weight is missing from ``state``.
    """
    state = strip_prefix(state, prefix)

    if version == 'v1':
        # v1 keeps its inner layers in one flat nn.Sequential of
        # (Linear, LayerNorm, ReLU) triples, so the i-th inner Linear lives
        # at index 3 * i: 'inner.0.weight', 'inner.3.weight', ...
        output_key = 'fc2.weight'
        inner_key_fmt = 'inner.{}.weight'
        inner_stride = 3
    elif version == 'v2':
        # v2 keeps one Sequential per residual block inside a ModuleList;
        # index 0 of each block is its LayerNorm.
        output_key = 'final.2.weight'
        inner_key_fmt = 'blocks.{}.0.weight'
        inner_stride = 1
    else:
        raise ValueError(f'Unsupported MLP version: {version}')

    hidden_dim, input_dim = state['fc1.weight'].shape
    output_dim = state[output_key].shape[0]

    # Count consecutive inner layers, stopping at the first missing index.
    num_inner = 0
    while inner_key_fmt.format(num_inner * inner_stride) in state:
        num_inner += 1

    return input_dim, hidden_dim, output_dim, num_inner
139
+
140
+
141
def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
    """Build the MLP described by ``state`` and load its weights.

    Keys are first restricted to those under ``prefix`` (with the prefix
    removed). The layer sizes are then inferred from that sub-dict, the class
    registered for ``version`` in ``MLP_FACTORY`` is instantiated with them,
    and the sub-dict is loaded into the new module.
    """
    mlp_state = strip_prefix(state, prefix)

    dims = get_mlp_info_from_state(version, mlp_state)

    mlp_cls = MLP_FACTORY[version]
    # dims is (input_dim, hidden_dim, output_dim, num_inner), matching the
    # leading positional parameters of the MLP constructors.
    ret: nn.Module = mlp_cls(*dims)

    ret.load_state_dict(mlp_state)

    return ret
adaptor_registry.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+ from typing import Dict, Any
10
+
11
+ import torch
12
+
13
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
14
+
15
# Alias for a string-keyed dictionary with arbitrary values.
dict_t = Dict[str, Any]
# Alias for a string-keyed dictionary of tensors.
state_t = Dict[str, torch.Tensor]
17
+
18
+
19
class AdaptorRegistry:
    """Name-keyed registry of adaptor factory functions."""

    def __init__(self):
        # Maps adaptor name -> factory callable.
        self._registry = {}

    def register_adaptor(self, name):
        """Return a decorator that registers the decorated factory under ``name``.

        The decorator returns the factory unchanged so it can still be used
        directly.

        Raises:
            ValueError: when the decorator is applied and ``name`` is already
                registered.
        """
        def decorator(factory_function):
            if name in self._registry:
                raise ValueError(f"Adaptor '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function
        return decorator

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Create an adaptor via the factory registered for ``name``.

        Names with no registered factory fall back to ``GenericAdaptor``.
        The factory is called with ``(main_config, adaptor_config, state)``.
        """
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)
35
+
36
# Module-level registry instance; other modules import this object to
# register factories and create adaptors.
adaptor_registry = AdaptorRegistry()
hf_model.py CHANGED
@@ -21,7 +21,11 @@ from transformers import PretrainedConfig, PreTrainedModel
21
 
22
  from .common import RESOURCE_MAP, DEFAULT_VERSION
23
 
24
- # Force import of eradio_model in order to register it.
 
 
 
 
25
  from .eradio_model import eradio
26
  from .radio_model import create_model_from_args
27
  from .radio_model import RADIOModel as RADIOModelBase, Resolution
 
21
 
22
  from .common import RESOURCE_MAP, DEFAULT_VERSION
23
 
24
+ # Import all required modules.
25
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
26
+ from .adaptor_registry import adaptor_registry
27
+ from .enable_cpe_support import enable_cpe
28
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
29
  from .eradio_model import eradio
30
  from .radio_model import create_model_from_args
31
  from .radio_model import RADIOModel as RADIOModelBase, Resolution