snowclipsed commited on
Commit
9fd23e4
·
1 Parent(s): 9308dfe

fix weights.py bfloat16 error

Browse files
Files changed (1) hide show
  1. weights.py +1 -1
weights.py CHANGED
@@ -29,7 +29,7 @@ def safetensors_open(safetensors_file: str):
29
 
30
  def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
31
  """Internal function to load weights using a tensor getter function."""
32
- model = model.to(dtype=torch.bfloat16)
33
 
34
  vision = model.vision
35
  region = model.region
 
29
 
30
  def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
31
  """Internal function to load weights using a tensor getter function."""
32
+ model = model.to(dtype=torch.float16)
33
 
34
  vision = model.vision
35
  region = model.region