snowclipsed
committed on
Commit
·
7090fbf
1
Parent(s):
fcbc9bf
only remove is_quantized
Browse files- weights.py +36 -1
weights.py
CHANGED
@@ -161,6 +161,25 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
|
|
161 |
"""Load weights from a safetensors file into a MoondreamModel instance."""
|
162 |
with safetensors_open(weights_file) as get_tensor:
|
163 |
all_keys = get_tensor.keys()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
if (
|
165 |
"vision.blocks.0.attn.proj.bias" in all_keys
|
166 |
or "model.vision.blocks.0.attn.proj.bias" in all_keys
|
@@ -173,7 +192,8 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
|
|
173 |
name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
|
174 |
_load_weights(
|
175 |
lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
|
176 |
-
model
|
|
|
177 |
)
|
178 |
|
179 |
|
@@ -181,6 +201,21 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
|
|
181 |
"""Load weights from a PyTorch file into a MoondreamModel instance."""
|
182 |
tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
|
183 |
all_keys = tensors.keys()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
if (
|
186 |
"vision.blocks.0.attn.proj.bias" in all_keys
|
|
|
161 |
"""Load weights from a safetensors file into a MoondreamModel instance."""
|
162 |
with safetensors_open(weights_file) as get_tensor:
|
163 |
all_keys = get_tensor.keys()
|
164 |
+
|
165 |
+
# is_quantized = any(
|
166 |
+
# ".qweight" in key or "_quantized" in key or "quant." in key
|
167 |
+
# for key in all_keys
|
168 |
+
# )
|
169 |
+
|
170 |
+
if "text_model.transformer.h.0.ln.weight" in all_keys:
|
171 |
+
layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
|
172 |
+
else:
|
173 |
+
layernorm_dtype = torch.float16
|
174 |
+
|
175 |
+
# linear_dtype = torch.int8 if is_quantized else torch.float16
|
176 |
+
|
177 |
+
model.text = build_text_model(
|
178 |
+
TextConfig
|
179 |
+
)
|
180 |
+
if model.setup_caches_flag:
|
181 |
+
model._setup_caches()
|
182 |
+
|
183 |
if (
|
184 |
"vision.blocks.0.attn.proj.bias" in all_keys
|
185 |
or "model.vision.blocks.0.attn.proj.bias" in all_keys
|
|
|
192 |
name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
|
193 |
_load_weights(
|
194 |
lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
|
195 |
+
model,
|
196 |
+
# is_quantized,
|
197 |
)
|
198 |
|
199 |
|
|
|
201 |
"""Load weights from a PyTorch file into a MoondreamModel instance."""
|
202 |
tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
|
203 |
all_keys = tensors.keys()
|
204 |
+
# is_quantized = any(
|
205 |
+
# ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
|
206 |
+
# )
|
207 |
+
|
208 |
+
if "text.blocks.0.ln.weight" in all_keys:
|
209 |
+
layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
|
210 |
+
else:
|
211 |
+
layernorm_dtype = torch.float16
|
212 |
+
|
213 |
+
# linear_dtype = torch.int8 if is_quantized else torch.float16
|
214 |
+
model.text = build_text_model(
|
215 |
+
TextConfig
|
216 |
+
)
|
217 |
+
if model.setup_caches_flag:
|
218 |
+
model._setup_caches()
|
219 |
|
220 |
if (
|
221 |
"vision.blocks.0.attn.proj.bias" in all_keys
|