Commit
·
1f50acd
1
Parent(s):
2ca24ef
fix ops
Browse files: modeling_siglip.py +5 -4
modeling_siglip.py
CHANGED
|
@@ -95,11 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
|
| 95 |
|
| 96 |
# Use inverse cdf transform for normal distribution to get truncated
|
| 97 |
# standard normal
|
| 98 |
-
if tensor.dtype == torch.float16:
|
| 99 |
-
# The `erfinv_` op is not (yet?) defined in float16
|
|
|
|
| 100 |
tensor = tensor.to(torch.float32)
|
| 101 |
tensor.erfinv_()
|
| 102 |
-
tensor = tensor.to(torch.float16)
|
| 103 |
else:
|
| 104 |
tensor.erfinv_()
|
| 105 |
|
|
@@ -109,7 +110,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
|
|
| 109 |
|
| 110 |
# Clamp to ensure it's in the proper range
|
| 111 |
if tensor.dtype == torch.float16:
|
| 112 |
-
# The `clamp_` op is not (yet?) defined in float16
|
| 113 |
tensor = tensor.to(torch.float32)
|
| 114 |
tensor.clamp_(min=a, max=b)
|
| 115 |
tensor = tensor.to(torch.float16)
|
|
|
|
| 95 |
|
| 96 |
# Use inverse cdf transform for normal distribution to get truncated
|
| 97 |
# standard normal
|
| 98 |
+
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
| 99 |
+
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
| 100 |
+
og_dtype = tensor.dtype
|
| 101 |
tensor = tensor.to(torch.float32)
|
| 102 |
tensor.erfinv_()
|
| 103 |
+
tensor = tensor.to(og_dtype)
|
| 104 |
else:
|
| 105 |
tensor.erfinv_()
|
| 106 |
|
|
|
|
| 110 |
|
| 111 |
# Clamp to ensure it's in the proper range
|
| 112 |
if tensor.dtype == torch.float16:
|
| 113 |
+
# The `clamp_` op is not (yet?) defined in float16+cpu
|
| 114 |
tensor = tensor.to(torch.float32)
|
| 115 |
tensor.clamp_(min=a, max=b)
|
| 116 |
tensor = tensor.to(torch.float16)
|