snowclipsed
commited on
Commit
·
9fd23e4
1
Parent(s):
9308dfe
fix weights.py bfloat16 error
Browse files- weights.py +1 -1
weights.py
CHANGED
@@ -29,7 +29,7 @@ def safetensors_open(safetensors_file: str):
|
|
29 |
|
30 |
def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
|
31 |
"""Internal function to load weights using a tensor getter function."""
|
32 |
-
model = model.to(dtype=torch.
|
33 |
|
34 |
vision = model.vision
|
35 |
region = model.region
|
|
|
29 |
|
30 |
def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
|
31 |
"""Internal function to load weights using a tensor getter function."""
|
32 |
+
model = model.to(dtype=torch.float16)
|
33 |
|
34 |
vision = model.vision
|
35 |
region = model.region
|