snowclipsed committed
Commit 9308dfe · 1 Parent(s): 595a8a6

remove any text model initialization in weights.py

Files changed (2)
  1. moondream.py +1 -1
  2. weights.py +34 -112
moondream.py CHANGED
@@ -75,7 +75,7 @@ class MoondreamModel(nn.Module):
             "vikhyatk/moondream2", revision="2025-01-09"
         )
         self.vision = build_vision_model(config.vision, dtype)
-        self.text = build_text_model(config.text, dtype)
+        self.text = build_text_model(config.text)

         # Region Model
         linear_cls = (
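After this change the text module is always built inside `MoondreamModel.__init__`, and `build_text_model` takes only the text config; the dtype cast happens once, at weight-loading time in `_load_weights`. A minimal sketch of the resulting flow, with a stand-in module since the real `build_text_model` lives in text.py:

    import torch
    import torch.nn as nn

    # Stand-in for text.build_text_model; after this commit it takes only
    # the text config, with no dtype parameter.
    def build_text_model(config) -> nn.Module:
        return nn.ModuleDict({"wte": nn.Embedding(8, 4)})

    text = build_text_model(config=None)   # built in the default dtype (float32)
    text = text.to(dtype=torch.bfloat16)   # the loader applies the dtype later
    print(text["wte"].weight.dtype)        # torch.bfloat16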
weights.py CHANGED
@@ -1,26 +1,10 @@
 import safetensors
 import torch
 import torch.nn as nn
-import re

 from contextlib import contextmanager
 from typing import Callable, List

-from .text import build_text_model
-from .config import TextConfig
-
-
-# Our custom linear has an module named linear, so we add linear to the name
-def add_linear_to_key(k: str) -> str:
-    k = k.replace("model.", "")
-    if k.startswith("text.") and ".linear." not in k:
-        k = re.sub(
-            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
-            r"\1.linear.\2",
-            k,
-        )
-    return k
-

 @contextmanager
 def safetensors_open(safetensors_file: str):
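For reference, this is the key rewrite the deleted `add_linear_to_key` helper performed (reproduced from the removed lines above): it inserted `.linear` into text-model attention and MLP keys so checkpoints matched the custom Linear wrapper. After this commit the safetensors path only strips the `model.` prefix:

    import re

    def add_linear_to_key(k: str) -> str:
        # Reproduced from the removed code, for illustration only.
        k = k.replace("model.", "")
        if k.startswith("text.") and ".linear." not in k:
            k = re.sub(
                r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
                r"\1.linear.\2",
                k,
            )
        return k

    print(add_linear_to_key("model.text.blocks.0.attn.qkv.weight"))
    # text.blocks.0.attn.qkv.linear.weight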
@@ -43,17 +27,12 @@ def safetensors_open(safetensors_file: str):
     yield get_tensor


-def _load_weights(
-    get_tensor: Callable[[str], torch.Tensor],
-    model: nn.Module,
-    is_quantized: bool = False,
-) -> None:
+def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
     """Internal function to load weights using a tensor getter function."""
-    model = model.to(dtype=torch.float16)
+    model = model.to(dtype=torch.bfloat16)

     vision = model.vision
     region = model.region
-
     weight_map = {
         "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
             "patch_emb"
@@ -111,42 +90,23 @@ def _load_weights(
         }
     )

-    if not is_quantized:
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.weight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
-    else:  # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight)
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.qweight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
+    for i in range(len(model.text["blocks"])):
+        prefix = f"text_model.transformer.h.{i}"
+        blk = model.text["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.ln.weight": blk["ln"].weight,
+                f"{prefix}.ln.bias": blk["ln"].bias,
+                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )

     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
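The deleted branch was the only consumer of the `is_quantized` flag: it remapped bitblas-style `.qweight` checkpoint keys onto the same parameters. The detection heuristic both loaders used, reproduced here from the lines removed in the next hunk, is gone as well, so quantized checkpoints now need a separate loading path:

    def looks_quantized(keys) -> bool:
        # The heuristic removed from both loaders in the next hunk.
        return any(
            ".qweight" in k or "_quantized" in k or "quant." in k for k in keys
        )

    print(looks_quantized(["text_model.transformer.h.0.mixer.Wqkv.qweight"]))  # True
    print(looks_quantized(["text_model.transformer.h.0.mixer.Wqkv.weight"]))   # False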
@@ -160,75 +120,37 @@ def _load_weights(
 def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     """Load weights from a safetensors file into a MoondreamModel instance."""
     with safetensors_open(weights_file) as get_tensor:
-        all_keys = get_tensor.keys()
-
-        is_quantized = any(
-            ".qweight" in key or "_quantized" in key or "quant." in key
-            for key in all_keys
-        )
-
-        if "text_model.transformer.h.0.ln.weight" in all_keys:
-            layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
-        else:
-            layernorm_dtype = torch.float16
-
-        linear_dtype = torch.int8 if is_quantized else torch.float16
-
-        model.text = build_text_model(
-            TextConfig
-        )
-        if model.setup_caches_flag:
-            model._setup_caches()
-
         if (
-            "vision.blocks.0.attn.proj.bias" in all_keys
-            or "model.vision.blocks.0.attn.proj.bias" in all_keys
+            "vision.blocks.0.attn.proj.bias" in get_tensor.keys()
+            or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys()
         ):
             with safetensors_open(weights_file) as get_tensor:
-                tensors = {add_linear_to_key(k): get_tensor(k) for k in all_keys}
+                tensors = {
+                    k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys()
+                }
             model.load_state_dict(tensors, strict=False)
         else:
             # Wrap the get_tensor function to handle key normalization
-            name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
+            name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
             _load_weights(
-                lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
-                model,
-                is_quantized,
+                lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model
             )


 def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
     """Load weights from a PyTorch file into a MoondreamModel instance."""
-    tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
-    all_keys = tensors.keys()
-    is_quantized = any(
-        ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
-    )
-
-    if "text.blocks.0.ln.weight" in all_keys:
-        layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
-    else:
-        layernorm_dtype = torch.float16
-
-    linear_dtype = torch.int8 if is_quantized else torch.float16
-    model.text = build_text_model(
-        TextConfig
-    )
-    if model.setup_caches_flag:
-        model._setup_caches()
-
-    if (
-        "vision.blocks.0.attn.proj.bias" in all_keys
-        or "model.vision.blocks.0.attn.proj.bias" in all_keys
-    ):
-        tensors = {add_linear_to_key(k): v for k, v in tensors.items()}
-        model.load_state_dict(tensors, strict=False)
+    device = str(torch.empty(0).device)
+    tensors = torch.load(weights_file, map_location=device, weights_only=True)
+    if "vision.blocks.0.attn.proj.bias" in tensors.keys():
+        missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False)
+        print("Missing keys:", missing_keys)
+        print("Unexpected keys:", unexpected_keys)
     else:
         tensors = {
-            k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
+            k.replace("._orig_mod", ""): v.to(dtype=torch.bfloat16)
             for k, v in tensors.items()
         }
-        _load_weights(lambda x: tensors[x], model, is_quantized)
+        _load_weights(lambda x: tensors[x], model)


 def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
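Taken together, the loaders no longer rebuild the text model or sniff for quantized key names; they copy tensors into an already-constructed model and cast it to bfloat16. A hypothetical end-to-end sketch (the import paths, config class, and constructor arguments are assumptions; the commit only shows the loader side):

    import torch
    from moondream import MoondreamModel          # assumed import path
    from config import MoondreamConfig            # assumed config class
    from weights import load_weights_into_model

    config = MoondreamConfig()
    # The text module is now built in the constructor, not in weights.py.
    model = MoondreamModel(config)
    load_weights_into_model("model.safetensors", model)  # placeholder filename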