snowclipsed commited on
Commit
8f87bef
·
1 Parent(s): fa2446a

safely remove text_model loading in weights.py

Browse files
Files changed (1) hide show
  1. weights.py +0 -34
weights.py CHANGED
@@ -162,24 +162,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
162
  with safetensors_open(weights_file) as get_tensor:
163
  all_keys = get_tensor.keys()
164
 
165
- # is_quantized = any(
166
- # ".qweight" in key or "_quantized" in key or "quant." in key
167
- # for key in all_keys
168
- # )
169
-
170
- if "text_model.transformer.h.0.ln.weight" in all_keys:
171
- layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
172
- else:
173
- layernorm_dtype = torch.float16
174
-
175
- # linear_dtype = torch.int8 if is_quantized else torch.float16
176
-
177
- model.text = build_text_model(
178
- TextConfig
179
- )
180
- if model.setup_caches_flag:
181
- model._setup_caches()
182
-
183
  if (
184
  "vision.blocks.0.attn.proj.bias" in all_keys
185
  or "model.vision.blocks.0.attn.proj.bias" in all_keys
@@ -201,22 +183,6 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
201
  """Load weights from a PyTorch file into a MoondreamModel instance."""
202
  tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
203
  all_keys = tensors.keys()
204
- # is_quantized = any(
205
- # ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
206
- # )
207
-
208
- if "text.blocks.0.ln.weight" in all_keys:
209
- layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
210
- else:
211
- layernorm_dtype = torch.float16
212
-
213
- # linear_dtype = torch.int8 if is_quantized else torch.float16
214
- model.text = build_text_model(
215
- TextConfig
216
- )
217
- if model.setup_caches_flag:
218
- model._setup_caches()
219
-
220
  if (
221
  "vision.blocks.0.attn.proj.bias" in all_keys
222
  or "model.vision.blocks.0.attn.proj.bias" in all_keys
 
162
  with safetensors_open(weights_file) as get_tensor:
163
  all_keys = get_tensor.keys()
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  if (
166
  "vision.blocks.0.attn.proj.bias" in all_keys
167
  or "model.vision.blocks.0.attn.proj.bias" in all_keys
 
183
  """Load weights from a PyTorch file into a MoondreamModel instance."""
184
  tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
185
  all_keys = tensors.keys()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  if (
187
  "vision.blocks.0.attn.proj.bias" in all_keys
188
  or "model.vision.blocks.0.attn.proj.bias" in all_keys