lcipolina commited on
Commit
f9a5e67
1 Parent(s): 9c9ed77

Upload glide_text2im/fp16_util.py

Browse files
Files changed (1) hide show
  1. glide_text2im/fp16_util.py +25 -0
glide_text2im/fp16_util.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers to inference with 16-bit precision.
3
+ """
4
+
5
+ import torch.nn as nn
6
+
7
+
8
+ def convert_module_to_f16(l):
9
+ """
10
+ Convert primitive modules to float16.
11
+ """
12
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
13
+ l.weight.data = l.weight.data.half()
14
+ if l.bias is not None:
15
+ l.bias.data = l.bias.data.half()
16
+
17
+
18
+ def convert_module_to_f32(l):
19
+ """
20
+ Convert primitive modules to float32, undoing convert_module_to_f16().
21
+ """
22
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
23
+ l.weight.data = l.weight.data.float()
24
+ if l.bias is not None:
25
+ l.bias.data = l.bias.data.float()