snowclipsed commited on
Commit
7090fbf
·
1 Parent(s): fcbc9bf

only remove is_quantized

Browse files
Files changed (1) hide show
  1. weights.py +36 -1
weights.py CHANGED
@@ -161,6 +161,25 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
161
  """Load weights from a safetensors file into a MoondreamModel instance."""
162
  with safetensors_open(weights_file) as get_tensor:
163
  all_keys = get_tensor.keys()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  if (
165
  "vision.blocks.0.attn.proj.bias" in all_keys
166
  or "model.vision.blocks.0.attn.proj.bias" in all_keys
@@ -173,7 +192,8 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
173
  name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
174
  _load_weights(
175
  lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
176
- model
 
177
  )
178
 
179
 
@@ -181,6 +201,21 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
181
  """Load weights from a PyTorch file into a MoondreamModel instance."""
182
  tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
183
  all_keys = tensors.keys()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  if (
186
  "vision.blocks.0.attn.proj.bias" in all_keys
 
161
  """Load weights from a safetensors file into a MoondreamModel instance."""
162
  with safetensors_open(weights_file) as get_tensor:
163
  all_keys = get_tensor.keys()
164
+
165
+ # is_quantized = any(
166
+ # ".qweight" in key or "_quantized" in key or "quant." in key
167
+ # for key in all_keys
168
+ # )
169
+
170
+ if "text_model.transformer.h.0.ln.weight" in all_keys:
171
+ layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
172
+ else:
173
+ layernorm_dtype = torch.float16
174
+
175
+ # linear_dtype = torch.int8 if is_quantized else torch.float16
176
+
177
+ model.text = build_text_model(
178
+ TextConfig
179
+ )
180
+ if model.setup_caches_flag:
181
+ model._setup_caches()
182
+
183
  if (
184
  "vision.blocks.0.attn.proj.bias" in all_keys
185
  or "model.vision.blocks.0.attn.proj.bias" in all_keys
 
192
  name_map = {k.replace("._orig_mod", ""): k for k in all_keys}
193
  _load_weights(
194
  lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
195
+ model,
196
+ # is_quantized,
197
  )
198
 
199
 
 
201
  """Load weights from a PyTorch file into a MoondreamModel instance."""
202
  tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
203
  all_keys = tensors.keys()
204
+ # is_quantized = any(
205
+ # ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
206
+ # )
207
+
208
+ if "text.blocks.0.ln.weight" in all_keys:
209
+ layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
210
+ else:
211
+ layernorm_dtype = torch.float16
212
+
213
+ # linear_dtype = torch.int8 if is_quantized else torch.float16
214
+ model.text = build_text_model(
215
+ TextConfig
216
+ )
217
+ if model.setup_caches_flag:
218
+ model._setup_caches()
219
 
220
  if (
221
  "vision.blocks.0.attn.proj.bias" in all_keys