snowclipsed committed
Commit fcbc9bf · 1 Parent(s): 9fd23e4

add back add_linear_to_key

Files changed (1)
  1. weights.py +76 -33
weights.py CHANGED
@@ -1,10 +1,26 @@
 import safetensors
 import torch
 import torch.nn as nn
+import re

 from contextlib import contextmanager
 from typing import Callable, List

+from .text import build_text_model
+from .config import TextConfig
+
+
+# Our custom linear has a module named linear, so we add linear to the name
+def add_linear_to_key(k: str) -> str:
+    k = k.replace("model.", "")
+    if k.startswith("text.") and ".linear." not in k:
+        k = re.sub(
+            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
+            r"\1.linear.\2",
+            k,
+        )
+    return k
+

 @contextmanager
 def safetensors_open(safetensors_file: str):
@@ -27,12 +43,17 @@ def safetensors_open(safetensors_file: str):
         yield get_tensor


-def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
+def _load_weights(
+    get_tensor: Callable[[str], torch.Tensor],
+    model: nn.Module,
+    is_quantized: bool = False,
+) -> None:
     """Internal function to load weights using a tensor getter function."""
     model = model.to(dtype=torch.float16)

     vision = model.vision
     region = model.region
+
     weight_map = {
         "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
             "patch_emb"
@@ -90,23 +111,42 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
         }
     )

-    for i in range(len(model.text["blocks"])):
-        prefix = f"text_model.transformer.h.{i}"
-        blk = model.text["blocks"][i]
-        weight_map.update(
-            {
-                f"{prefix}.ln.weight": blk["ln"].weight,
-                f"{prefix}.ln.bias": blk["ln"].bias,
-                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
-                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
-                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
-                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
-                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-            }
-        )
+    if not is_quantized:
+        for i in range(len(model.text["blocks"])):
+            prefix = f"text_model.transformer.h.{i}"
+            blk = model.text["blocks"][i]
+            weight_map.update(
+                {
+                    f"{prefix}.ln.weight": blk["ln"].weight,
+                    f"{prefix}.ln.bias": blk["ln"].bias,
+                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+                }
+            )
+    else:  # add special quantized path; this is specific to how bitblas expects weights to be loaded (.qweight)
+        for i in range(len(model.text["blocks"])):
+            prefix = f"text_model.transformer.h.{i}"
+            blk = model.text["blocks"][i]
+            weight_map.update(
+                {
+                    f"{prefix}.ln.qweight": blk["ln"].weight,
+                    f"{prefix}.ln.bias": blk["ln"].bias,
+                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
+                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
+                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
+                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
+                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+                }
+            )

     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
@@ -120,34 +160,37 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
 def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     """Load weights from a safetensors file into a MoondreamModel instance."""
     with safetensors_open(weights_file) as get_tensor:
+        all_keys = get_tensor.keys()
         if (
-            "vision.blocks.0.attn.proj.bias" in get_tensor.keys()
-            or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys()
+            "vision.blocks.0.attn.proj.bias" in all_keys
+            or "model.vision.blocks.0.attn.proj.bias" in all_keys
         ):
             with safetensors_open(weights_file) as get_tensor:
-                tensors = {
-                    k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys()
-                }
+                tensors = {add_linear_to_key(k): get_tensor(k) for k in all_keys}
                 model.load_state_dict(tensors, strict=False)
         else:
             # Wrap the get_tensor function to handle key normalization
-            name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
+            name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
             _load_weights(
-                lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model
+                lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
+                model
             )


 def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
     """Load weights from a PyTorch file into a MoondreamModel instance."""
-    device = str(torch.empty(0).device)
-    tensors = torch.load(weights_file, map_location=device, weights_only=True)
-    if "vision.blocks.0.attn.proj.bias" in tensors.keys():
-        missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False)
-        print("Missing keys:", missing_keys)
-        print("Unexpected keys:", unexpected_keys)
+    tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
+    all_keys = tensors.keys()
+
+    if (
+        "vision.blocks.0.attn.proj.bias" in all_keys
+        or "model.vision.blocks.0.attn.proj.bias" in all_keys
+    ):
+        tensors = {add_linear_to_key(k): v for k, v in tensors.items()}
+        model.load_state_dict(tensors, strict=False)
     else:
         tensors = {
-            k.replace("._orig_mod", ""): v.to(dtype=torch.bfloat16)
+            k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
             for k, v in tensors.items()
         }
         _load_weights(lambda x: tensors[x], model)
@@ -168,4 +211,4 @@ def load_weights_into_model(weights_file: str, model: nn.Module) -> None:

     # Make all parameters contiguous
     for param in model.parameters():
-        param.data = param.data.contiguous()
+        param.data = param.data.contiguous()
 
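Note: the following is a minimal, self-contained sketch of the key remapping that add_linear_to_key performs. The regex is the one introduced in this commit; the checkpoint keys below are illustrative (block indices chosen arbitrarily), not copied from a real checkpoint.

import re

def add_linear_to_key(k: str) -> str:
    # Same remapping as in weights.py: strip an optional "model." prefix, then
    # point text-block attention/MLP parameters at their ".linear." submodule.
    k = k.replace("model.", "")
    if k.startswith("text.") and ".linear." not in k:
        k = re.sub(
            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
            r"\1.linear.\2",
            k,
        )
    return k

# Illustrative keys (indices are arbitrary):
assert add_linear_to_key("model.text.blocks.0.attn.qkv.weight") == "text.blocks.0.attn.qkv.linear.weight"
assert add_linear_to_key("text.blocks.7.mlp.fc1.bias") == "text.blocks.7.mlp.fc1.linear.bias"
# Non-text keys (e.g. vision) pass through unchanged:
assert add_linear_to_key("vision.blocks.0.attn.proj.bias") == "vision.blocks.0.attn.proj.bias"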
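Similarly, a small sketch of the key normalization and dtype cast that the non-remapped branch of load_weights_from_pt now applies before calling _load_weights. The "._orig_mod" placement in the toy key is an assumption meant to illustrate torch.compile key residue; the base key name comes from the weight_map in this diff.

import torch

# Toy stand-in for a torch.load() result (key suffix placement is illustrative):
tensors = {
    "text_model.transformer.h.0.ln.weight._orig_mod": torch.zeros(4, dtype=torch.float32),
}

# Same normalization as the else branch: drop "._orig_mod" and cast to float16
# (the previous revision cast to bfloat16 instead).
tensors = {
    k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
    for k, v in tensors.items()
}

assert "text_model.transformer.h.0.ln.weight" in tensors
assert tensors["text_model.transformer.h.0.ln.weight"].dtype == torch.float16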