Spaces:
Runtime error
Runtime error
Upload glide_text2im/fp16_util.py
Browse files- glide_text2im/fp16_util.py +25 -0
glide_text2im/fp16_util.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Helpers to inference with 16-bit precision.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
def convert_module_to_f16(l):
|
9 |
+
"""
|
10 |
+
Convert primitive modules to float16.
|
11 |
+
"""
|
12 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
13 |
+
l.weight.data = l.weight.data.half()
|
14 |
+
if l.bias is not None:
|
15 |
+
l.bias.data = l.bias.data.half()
|
16 |
+
|
17 |
+
|
18 |
+
def convert_module_to_f32(l):
|
19 |
+
"""
|
20 |
+
Convert primitive modules to float32, undoing convert_module_to_f16().
|
21 |
+
"""
|
22 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
23 |
+
l.weight.data = l.weight.data.float()
|
24 |
+
if l.bias is not None:
|
25 |
+
l.bias.data = l.bias.data.float()
|