snowclipsed
committed on
Commit
·
fcbc9bf
1
Parent(s):
9fd23e4
add back add_linear_to_key
Browse files- weights.py +76 -33
weights.py
CHANGED
@@ -1,10 +1,26 @@
|
|
1 |
import safetensors
|
2 |
import torch
|
3 |
import torch.nn as nn
|
|
|
4 |
|
5 |
from contextlib import contextmanager
|
6 |
from typing import Callable, List
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
@contextmanager
|
10 |
def safetensors_open(safetensors_file: str):
|
@@ -27,12 +43,17 @@ def safetensors_open(safetensors_file: str):
|
|
27 |
yield get_tensor
|
28 |
|
29 |
|
30 |
-
def _load_weights(
|
|
|
|
|
|
|
|
|
31 |
"""Internal function to load weights using a tensor getter function."""
|
32 |
model = model.to(dtype=torch.float16)
|
33 |
|
34 |
vision = model.vision
|
35 |
region = model.region
|
|
|
36 |
weight_map = {
|
37 |
"vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
|
38 |
"patch_emb"
|
@@ -90,23 +111,42 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -
|
|
90 |
}
|
91 |
)
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
for key, tensor in weight_map.items():
|
112 |
tensor.data.copy_(get_tensor(key))
|
@@ -120,34 +160,37 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -
|
|
120 |
def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
|
121 |
"""Load weights from a safetensors file into a MoondreamModel instance."""
|
122 |
with safetensors_open(weights_file) as get_tensor:
|
|
|
123 |
if (
|
124 |
-
"vision.blocks.0.attn.proj.bias" in
|
125 |
-
or "model.vision.blocks.0.attn.proj.bias" in
|
126 |
):
|
127 |
with safetensors_open(weights_file) as get_tensor:
|
128 |
-
tensors = {
|
129 |
-
k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys()
|
130 |
-
}
|
131 |
model.load_state_dict(tensors, strict=False)
|
132 |
else:
|
133 |
# Wrap the get_tensor function to handle key normalization
|
134 |
-
name_map = {k.replace("._orig_mod", ""): k for k in
|
135 |
_load_weights(
|
136 |
-
lambda x: get_tensor(name_map[x]).to(dtype=torch.
|
|
|
137 |
)
|
138 |
|
139 |
|
140 |
def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
|
141 |
"""Load weights from a PyTorch file into a MoondreamModel instance."""
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
148 |
else:
|
149 |
tensors = {
|
150 |
-
k.replace("._orig_mod", ""): v.to(dtype=torch.
|
151 |
for k, v in tensors.items()
|
152 |
}
|
153 |
_load_weights(lambda x: tensors[x], model)
|
@@ -168,4 +211,4 @@ def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
|
|
168 |
|
169 |
# Make all parameters contiguous
|
170 |
for param in model.parameters():
|
171 |
-
param.data = param.data.contiguous()
|
|
|
1 |
import safetensors
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
+
import re
|
5 |
|
6 |
from contextlib import contextmanager
|
7 |
from typing import Callable, List
|
8 |
|
9 |
+
from .text import build_text_model
|
10 |
+
from .config import TextConfig
|
11 |
+
|
12 |
+
|
13 |
+
# Our custom linear wraps a module named `linear`, so checkpoint keys for those
# layers need ".linear" inserted before the parameter name.
def add_linear_to_key(k: str) -> str:
    """Normalize a checkpoint key to this model's state-dict naming.

    A leading ``model.`` prefix is removed. Then, for keys that start with
    ``text.`` and do not already contain ``.linear.``, a trailing
    ``attn.qkv`` / ``attn.proj`` / ``mlp.fc1`` / ``mlp.fc2`` followed by
    ``.weight`` or ``.bias`` is rewritten to ``<layer>.linear.<param>``.
    All other keys are returned unchanged.

    Args:
        k: Key as stored in the checkpoint.

    Returns:
        The normalized key.
    """
    prefix = "model."
    # Strip only a *leading* prefix: str.replace() would also delete "model."
    # from the middle of a name (e.g. "vision.submodel.weight").
    if k.startswith(prefix):
        k = k[len(prefix):]
    if k.startswith("text.") and ".linear." not in k:
        k = re.sub(
            r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$",
            r"\1.linear.\2",
            k,
        )
    return k
|
23 |
+
|
24 |
|
25 |
@contextmanager
|
26 |
def safetensors_open(safetensors_file: str):
|
|
|
43 |
yield get_tensor
|
44 |
|
45 |
|
46 |
+
def _load_weights(
|
47 |
+
get_tensor: Callable[[str], torch.Tensor],
|
48 |
+
model: nn.Module,
|
49 |
+
is_quantized: bool = False,
|
50 |
+
) -> None:
|
51 |
"""Internal function to load weights using a tensor getter function."""
|
52 |
model = model.to(dtype=torch.float16)
|
53 |
|
54 |
vision = model.vision
|
55 |
region = model.region
|
56 |
+
|
57 |
weight_map = {
|
58 |
"vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
|
59 |
"patch_emb"
|
|
|
111 |
}
|
112 |
)
|
113 |
|
114 |
+
if not is_quantized:
|
115 |
+
for i in range(len(model.text["blocks"])):
|
116 |
+
prefix = f"text_model.transformer.h.{i}"
|
117 |
+
blk = model.text["blocks"][i]
|
118 |
+
weight_map.update(
|
119 |
+
{
|
120 |
+
f"{prefix}.ln.weight": blk["ln"].weight,
|
121 |
+
f"{prefix}.ln.bias": blk["ln"].bias,
|
122 |
+
f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
|
123 |
+
f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
|
124 |
+
f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
|
125 |
+
f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
|
126 |
+
f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
|
127 |
+
f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
|
128 |
+
f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
|
129 |
+
f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
|
130 |
+
}
|
131 |
+
)
|
132 |
+
else: # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight)
|
133 |
+
for i in range(len(model.text["blocks"])):
|
134 |
+
prefix = f"text_model.transformer.h.{i}"
|
135 |
+
blk = model.text["blocks"][i]
|
136 |
+
weight_map.update(
|
137 |
+
{
|
138 |
+
f"{prefix}.ln.qweight": blk["ln"].weight,
|
139 |
+
f"{prefix}.ln.bias": blk["ln"].bias,
|
140 |
+
f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
|
141 |
+
f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
|
142 |
+
f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
|
143 |
+
f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
|
144 |
+
f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
|
145 |
+
f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
|
146 |
+
f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
|
147 |
+
f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
|
148 |
+
}
|
149 |
+
)
|
150 |
|
151 |
for key, tensor in weight_map.items():
|
152 |
tensor.data.copy_(get_tensor(key))
|
|
|
160 |
def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
    """Load weights from a safetensors file into a MoondreamModel instance.

    The key layout is detected by probing for
    ``vision.blocks.0.attn.proj.bias`` (with or without a ``model.`` prefix):

    * If present, every key is normalized with ``add_linear_to_key`` and the
      tensors are loaded via ``model.load_state_dict(..., strict=False)``.
    * Otherwise ``._orig_mod`` is stripped from key names and the tensors are
      handed to ``_load_weights``, cast to ``torch.float16``.

    Args:
        weights_file: Path to the safetensors file.
        model: Module that receives the weights (modified in place).
    """
    with safetensors_open(weights_file) as get_tensor:
        all_keys = get_tensor.keys()
        is_state_dict_layout = (
            "vision.blocks.0.attn.proj.bias" in all_keys
            or "model.vision.blocks.0.attn.proj.bias" in all_keys
        )
        if is_state_dict_layout:
            # Reuse the handle that is already open; reopening the same file
            # in a nested `with` only shadowed `get_tensor` with a second one.
            tensors = {add_linear_to_key(k): get_tensor(k) for k in all_keys}
            model.load_state_dict(tensors, strict=False)
        else:
            # Wrap the get_tensor function to handle key normalization
            name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
            _load_weights(
                lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
                model,
            )
|
178 |
|
179 |
|
180 |
def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
    """Load weights from a PyTorch file into a MoondreamModel instance.

    The checkpoint is read onto the CPU with ``weights_only=True``. When it
    contains ``vision.blocks.0.attn.proj.bias`` (optionally prefixed with
    ``model.``), keys are normalized with ``add_linear_to_key`` and loaded via
    ``load_state_dict(strict=False)``. Otherwise ``._orig_mod`` is removed from
    key names, tensors are cast to ``torch.float16`` and passed to
    ``_load_weights``.
    """
    tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
    all_keys = tensors.keys()

    probe_keys = (
        "vision.blocks.0.attn.proj.bias",
        "model.vision.blocks.0.attn.proj.bias",
    )
    if any(probe in all_keys for probe in probe_keys):
        tensors = {add_linear_to_key(name): value for name, value in tensors.items()}
        model.load_state_dict(tensors, strict=False)
    else:
        half_precision = {}
        for name, value in tensors.items():
            half_precision[name.replace("._orig_mod", "")] = value.to(
                dtype=torch.float16
            )
        tensors = half_precision
        _load_weights(lambda x: tensors[x], model)
|
|
|
211 |
|
212 |
# Make all parameters contiguous
|
213 |
for param in model.parameters():
|
214 |
+
param.data = param.data.contiguous()
|