snowclipsed committed
Commit 9308dfe · 1 Parent(s): 595a8a6

remove any text model initialization in weights.py

- moondream.py +1 -1
- weights.py +34 -112
moondream.py CHANGED

@@ -75,7 +75,7 @@ class MoondreamModel(nn.Module):
             "vikhyatk/moondream2", revision="2025-01-09"
         )
         self.vision = build_vision_model(config.vision, dtype)
-        self.text = build_text_model(config.text
+        self.text = build_text_model(config.text)
 
         # Region Model
         linear_cls = (
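Net effect of the moondream.py side: the text module is now always built in MoondreamModel.__init__, so the weight loaders below can assume model.text exists and only copy tensors into it. A minimal sketch of the intended call order (import paths, config name, and file name are illustrative assumptions, not taken from this diff):

    # hypothetical imports; adjust to wherever these live in the repo
    from config import MoondreamConfig
    from moondream import MoondreamModel
    from weights import load_weights_into_model

    model = MoondreamModel(MoondreamConfig())  # __init__ now builds model.text itself
    load_weights_into_model("model.safetensors", model)  # loaders only copy tensors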
weights.py CHANGED

@@ -1,26 +1,10 @@
 import safetensors
 import torch
 import torch.nn as nn
-import re
 
 from contextlib import contextmanager
 from typing import Callable, List
 
-from .text import build_text_model
-from .config import TextConfig
-
-
-# Our custom linear has an module named linear, so we add linear to the name
-def add_linear_to_key(k: str) -> str:
-    k = k.replace("model.", "")
-    if k.startswith("text.") and ".linear." not in k:
-        k = re.sub(
-            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
-            r"\1.linear.\2",
-            k,
-        )
-    return k
-
 
 @contextmanager
 def safetensors_open(safetensors_file: str):
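For context on what was dropped: the deleted add_linear_to_key helper rewrote checkpoint keys so they target the inner linear module of the custom linear wrapper. A standalone repro of its behavior, copied from the removed code above:

    import re

    def add_linear_to_key(k: str) -> str:
        k = k.replace("model.", "")
        if k.startswith("text.") and ".linear." not in k:
            k = re.sub(
                r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
                r"\1.linear.\2",
                k,
            )
        return k

    print(add_linear_to_key("model.text.blocks.0.attn.qkv.weight"))
    # -> text.blocks.0.attn.qkv.linear.weight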
@@ -43,17 +27,12 @@ def safetensors_open(safetensors_file: str):
     yield get_tensor
 
 
-def _load_weights(
-    get_tensor: Callable[[str], torch.Tensor],
-    model: nn.Module,
-    is_quantized: bool = False,
-) -> None:
+def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
     """Internal function to load weights using a tensor getter function."""
-    model = model.to(dtype=torch.
+    model = model.to(dtype=torch.bfloat16)
 
     vision = model.vision
     region = model.region
-
     weight_map = {
         "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
             "patch_emb"
@@ -111,42 +90,23 @@ def _load_weights(
         }
     )
 
-    if not is_quantized:
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.weight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
-    else:  # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight)
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.qweight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
+    for i in range(len(model.text["blocks"])):
+        prefix = f"text_model.transformer.h.{i}"
+        blk = model.text["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.ln.weight": blk["ln"].weight,
+                f"{prefix}.ln.bias": blk["ln"].bias,
+                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )
 
     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
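The weight_map approach above pairs each checkpoint key with a destination parameter and copies in place via tensor.data.copy_(...), which keeps the existing Parameter objects (and anything holding references to them) intact. A toy, self-contained version of the same pattern, with made-up keys rather than the real checkpoint names:

    import torch
    import torch.nn as nn

    dst = nn.Linear(4, 4)

    # stand-in for get_tensor: a checkpoint dict with fabricated keys
    checkpoint = {
        "mixer.Wqkv.weight": torch.randn(4, 4),
        "mixer.Wqkv.bias": torch.randn(4),
    }

    # map checkpoint keys onto destination parameters, then copy in place
    weight_map = {
        "mixer.Wqkv.weight": dst.weight,
        "mixer.Wqkv.bias": dst.bias,
    }
    for key, tensor in weight_map.items():
        tensor.data.copy_(checkpoint[key])  # dst.weight stays the same object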
@@ -160,75 +120,37 @@ def _load_weights(
 def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     """Load weights from a safetensors file into a MoondreamModel instance."""
     with safetensors_open(weights_file) as get_tensor:
-        all_keys = get_tensor.keys()
-
-        is_quantized = any(
-            ".qweight" in key or "_quantized" in key or "quant." in key
-            for key in all_keys
-        )
-
-        if "text_model.transformer.h.0.ln.weight" in all_keys:
-            layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
-        else:
-            layernorm_dtype = torch.float16
-
-        linear_dtype = torch.int8 if is_quantized else torch.float16
-
-        model.text = build_text_model(
-            TextConfig
-        )
-        if model.setup_caches_flag:
-            model._setup_caches()
-
         if (
-            "vision.blocks.0.attn.proj.bias" in
-            or "model.vision.blocks.0.attn.proj.bias" in
+            "vision.blocks.0.attn.proj.bias" in get_tensor.keys()
+            or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys()
         ):
             with safetensors_open(weights_file) as get_tensor:
-                tensors = {
+                tensors = {
+                    k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys()
+                }
             model.load_state_dict(tensors, strict=False)
         else:
             # Wrap the get_tensor function to handle key normalization
-            name_map = {k.replace("._orig_mod", ""): k for k in
+            name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
             _load_weights(
-                lambda x: get_tensor(name_map[x]).to(dtype=torch.
-                model,
-                is_quantized,
+                lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model
             )
 
 
 def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
     """Load weights from a PyTorch file into a MoondreamModel instance."""
-
-
-
-
-
-
-    if "text.blocks.0.ln.weight" in all_keys:
-        layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
-    else:
-        layernorm_dtype = torch.float16
-
-    linear_dtype = torch.int8 if is_quantized else torch.float16
-    model.text = build_text_model(
-        TextConfig
-    )
-    if model.setup_caches_flag:
-        model._setup_caches()
-
-    if (
-        "vision.blocks.0.attn.proj.bias" in all_keys
-        or "model.vision.blocks.0.attn.proj.bias" in all_keys
-    ):
-        tensors = {add_linear_to_key(k): v for k, v in tensors.items()}
-        model.load_state_dict(tensors, strict=False)
+    device = str(torch.empty(0).device)
+    tensors = torch.load(weights_file, map_location=device, weights_only=True)
+    if "vision.blocks.0.attn.proj.bias" in tensors.keys():
+        missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False)
+        print("Missing keys:", missing_keys)
+        print("Unexpected keys:", unexpected_keys)
     else:
         tensors = {
-            k.replace("._orig_mod", ""): v.to(dtype=torch.
+            k.replace("._orig_mod", ""): v.to(dtype=torch.bfloat16)
            for k, v in tensors.items()
         }
-        _load_weights(lambda x: tensors[x], model
+        _load_weights(lambda x: tensors[x], model)
 
 
 def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
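The "._orig_mod" replacement above exists because torch.compile wraps a module and prefixes its state_dict keys with _orig_mod.; normalizing the names lets a checkpoint saved from a compiled model load into an uncompiled one. A standalone illustration (requires a PyTorch version that ships torch.compile):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(2, 2))
    compiled = torch.compile(net)
    print(list(compiled.state_dict().keys()))  # ['_orig_mod.0.weight', '_orig_mod.0.bias']

    # strip the compile wrapper's prefix so the plain module accepts the keys
    clean = {k.replace("_orig_mod.", ""): v for k, v in compiled.state_dict().items()}
    net.load_state_dict(clean)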
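Two notes on the new load_weights_from_pt path: torch.load(..., weights_only=True) restricts unpickling to tensors and plain containers, which is the safer way to load checkpoints you did not produce yourself, and the torch.empty(0).device idiom resolves whatever default device is currently active. A self-contained check of that load path (the file name is arbitrary):

    import torch

    device = str(torch.empty(0).device)  # current default device, e.g. "cpu"
    torch.save({"w": torch.ones(2, 2)}, "demo.pt")
    tensors = torch.load("demo.pt", map_location=device, weights_only=True)
    print(device, tensors["w"].shape)  # cpu torch.Size([2, 2])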