Mariam-Elz committed on
Commit c96dc51 · verified · 1 Parent(s): 7ba4e32

Upload util/utils.py with huggingface_hub

Files changed (1)
  1. util/utils.py +194 -194
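The commit message refers to the huggingface_hub upload flow. For context, a minimal sketch of that call (the repo id is a placeholder, not taken from this page, and the token is assumed to come from a prior `huggingface-cli login`):

from huggingface_hub import HfApi

api = HfApi()  # uses the cached token from `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="util/utils.py",       # local file to push
    path_in_repo="util/utils.py",          # destination path inside the repo
    repo_id="Mariam-Elz/<repo-name>",      # placeholder repo id
    commit_message="Upload util/utils.py with huggingface_hub",
)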
util/utils.py CHANGED
@@ -1,194 +1,194 @@
import numpy as np
import torch
import random


# Matches gluPerspective / glm::perspective, but parameterized by the
# horizontal field of view (fovx, in radians) rather than fovy.
def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    # y = np.tan(fovy / 2)
    x = np.tan(fovx / 2)
    return torch.tensor([[1/x,         0,             0,               0],
                         [  0, -aspect/x,             0,               0],
                         [  0,         0,  -(f+n)/(f-n),  -(2*f*n)/(f-n)],
                         [  0,         0,            -1,               0]], dtype=torch.float32, device=device)


def translate(x, y, z, device=None):
    return torch.tensor([[1, 0, 0, x],
                         [0, 1, 0, y],
                         [0, 0, 1, z],
                         [0, 0, 0, 1]], dtype=torch.float32, device=device)


def rotate_x(a, device=None):
    s, c = np.sin(a), np.cos(a)
    return torch.tensor([[1, 0,  0, 0],
                         [0, c, -s, 0],
                         [0, s,  c, 0],
                         [0, 0,  0, 1]], dtype=torch.float32, device=device)


def rotate_y(a, device=None):
    s, c = np.sin(a), np.cos(a)
    return torch.tensor([[ c, 0, s, 0],
                         [ 0, 1, 0, 0],
                         [-s, 0, c, 0],
                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)


def rotate_z(a, device=None):
    s, c = np.sin(a), np.cos(a)
    return torch.tensor([[c, -s, 0, 0],
                         [s,  c, 0, 0],
                         [0,  0, 1, 0],
                         [0,  0, 0, 1]], dtype=torch.float32, device=device)


@torch.no_grad()
def batch_random_rotation_translation(b, t, device=None):
    # b random orthonormal bases built via cross products, padded to 4x4
    # homogeneous transforms with a uniform translation in [-t, t].
    m = np.random.normal(size=[b, 3, 3])
    m[:, 1] = np.cross(m[:, 0], m[:, 2])
    m[:, 2] = np.cross(m[:, 0], m[:, 1])
    m = m / np.linalg.norm(m, axis=2, keepdims=True)
    m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant')
    m[:, 3, 3] = 1.0
    m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3])
    return torch.tensor(m, dtype=torch.float32, device=device)


@torch.no_grad()
def random_rotation_translation(t, device=None):
    m = np.random.normal(size=[3, 3])
    m[1] = np.cross(m[0], m[2])
    m[2] = np.cross(m[0], m[1])
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
    m[3, 3] = 1.0
    m[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(m, dtype=torch.float32, device=device)


@torch.no_grad()
def random_rotation(device=None):
    m = np.random.normal(size=[3, 3])
    m[1] = np.cross(m[0], m[2])
    m[2] = np.cross(m[0], m[1])
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
    m[3, 3] = 1.0
    m[:3, 3] = np.array([0, 0, 0]).astype(np.float32)
    return torch.tensor(m, dtype=torch.float32, device=device)


def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.sum(x*y, -1, keepdim=True)


def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    return torch.sqrt(torch.clamp(dot(x, x), min=eps))  # Clamp to avoid NaN gradients because grad(sqrt(0)) = NaN


def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    return x / length(x, eps)


def lr_schedule(iter, warmup_iter, scheduler_decay):
    # Linear warmup, then exponential decay.
    if iter < warmup_iter:
        return iter / warmup_iter
    return max(0.0, 10 ** (-(iter - warmup_iter) * scheduler_decay))


def trans_depth(depth):
    # Normalize the valid (positive) depths of the first batch element to [0, 255].
    depth = depth[0].detach().cpu().numpy()
    valid = depth > 0
    depth[valid] -= depth[valid].min()
    depth[valid] = (depth[valid] / depth[valid].max()) * 255
    return depth.astype('uint8')


def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
    assert isinstance(input, torch.Tensor)
    if posinf is None:
        posinf = torch.finfo(input.dtype).max
    if neginf is None:
        neginf = torch.finfo(input.dtype).min
    assert nan == 0
    return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)


def load_item(filepath):
    with open(filepath, 'r') as f:
        items = [name.strip() for name in f.readlines()]
    return set(items)


def load_prompt(filepath):
    uuid2prompt = {}
    with open(filepath, 'r') as f:
        for line in f.readlines():
            list_line = line.split(',')
            uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip()
    return uuid2prompt


def resize_and_center_image(image_tensor, scale=0.95, c=0, shift=0, rgb=False, aug_shift=0):
    if scale == 1:
        return image_tensor
    B, C, H, W = image_tensor.shape
    new_H, new_W = int(H * scale), int(W * scale)
    resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False).squeeze(0)
    background = torch.zeros_like(image_tensor) + c
    start_y, start_x = (H - new_H) // 2, (W - new_W) // 2
    if shift == 0:
        background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image
    else:
        for i in range(B):
            randx = random.randint(-shift, shift)
            randy = random.randint(-shift, shift)
            if rgb:
                if i == 0 or i == 2 or i == 4:
                    randx = 0
                    randy = 0
            background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i]
    if aug_shift == 0:
        return background
    for i in range(B):
        for j in range(C):
            background[i, j, :, :] += (random.random() - 0.5) * 2 * aug_shift / 255
    return background


def get_tri(triview_color, dim=1, blender=True, c=0, scale=0.95, shift=0, fix=False, rgb=False, aug_shift=0):
    # triview_color: [6, C, H, W]
    # rgb is useful when shift is not 0
    triview_color = resize_and_center_image(triview_color, scale=scale, c=c, shift=shift, rgb=rgb, aug_shift=aug_shift)
    if blender is False:
        triview_color0 = torch.rot90(triview_color[0], k=2, dims=[1, 2])
        triview_color1 = torch.rot90(triview_color[4], k=1, dims=[1, 2]).flip(2).flip(1)
        triview_color2 = torch.rot90(triview_color[5], k=1, dims=[1, 2]).flip(2)
        triview_color3 = torch.rot90(triview_color[3], k=2, dims=[1, 2]).flip(2)
        triview_color4 = torch.rot90(triview_color[1], k=3, dims=[1, 2]).flip(1)
        triview_color5 = torch.rot90(triview_color[2], k=3, dims=[1, 2]).flip(1).flip(2)
    else:
        triview_color0 = torch.rot90(triview_color[2], k=2, dims=[1, 2])
        triview_color1 = torch.rot90(triview_color[4], k=0, dims=[1, 2]).flip(2).flip(1)
        triview_color2 = torch.rot90(torch.rot90(triview_color[0], k=3, dims=[1, 2]).flip(2), k=2, dims=[1, 2])
        triview_color3 = torch.rot90(torch.rot90(triview_color[5], k=2, dims=[1, 2]).flip(2), k=2, dims=[1, 2])
        triview_color4 = torch.rot90(triview_color[1], k=2, dims=[1, 2]).flip(1).flip(1).flip(2)
        triview_color5 = torch.rot90(triview_color[3], k=1, dims=[1, 2]).flip(1).flip(2)
    if fix:
        # Zero out channels so each view keeps only a single non-zero channel.
        triview_color0[1] = triview_color0[1] * 0
        triview_color0[2] = triview_color0[2] * 0
        triview_color3[1] = triview_color3[1] * 0
        triview_color3[2] = triview_color3[2] * 0

        triview_color1[0] = triview_color1[0] * 0
        triview_color1[1] = triview_color1[1] * 0
        triview_color4[0] = triview_color4[0] * 0
        triview_color4[1] = triview_color4[1] * 0

        triview_color2[0] = triview_color2[0] * 0
        triview_color2[2] = triview_color2[2] * 0
        triview_color5[0] = triview_color5[0] * 0
        triview_color5[2] = triview_color5[2] * 0
    color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2)
    color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2)
    color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim=dim)
    return color_tensor_gt
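A minimal usage sketch for get_tri, not part of utils.py; it assumes the file is importable as util.utils and uses six 256×256 RGB views with the default arguments:

import torch
from util.utils import get_tri

views = torch.rand(6, 3, 256, 256)         # [6, C, H, W] multi-view stack
tri = get_tri(views, dim=1, blender=True)  # two rows of three views, concatenated along height
print(tri.shape)                           # torch.Size([3, 512, 768])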
 