Spaces:
Running
on
Zero
Running
on
Zero
RageshAntony
commited on
added onediffusion files
Browse files- onediffusion/dataset/multitask/multiview.py +277 -0
- onediffusion/dataset/raydiff_utils.py +739 -0
- onediffusion/dataset/transforms.py +133 -0
- onediffusion/dataset/utils.py +175 -0
- onediffusion/diffusion/pipelines/image_processor.py +674 -0
- onediffusion/diffusion/pipelines/onediffusion.py +1080 -0
- onediffusion/models/denoiser/__init__.py +3 -0
- onediffusion/models/denoiser/nextdit/__init__.py +1 -0
- onediffusion/models/denoiser/nextdit/layers.py +132 -0
- onediffusion/models/denoiser/nextdit/modeling_nextdit.py +571 -0
onediffusion/dataset/multitask/multiview.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from typing import List, Tuple, Union
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from torchvision import transforms
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from onediffusion.dataset.utils import *
|
11 |
+
import glob
|
12 |
+
|
13 |
+
from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
|
14 |
+
from onediffusion.dataset.transforms import CenterCropResizeImage
|
15 |
+
from pytorch3d.renderer import PerspectiveCameras
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
def _cameras_from_opencv_projection(
|
20 |
+
R: torch.Tensor,
|
21 |
+
tvec: torch.Tensor,
|
22 |
+
camera_matrix: torch.Tensor,
|
23 |
+
image_size: torch.Tensor,
|
24 |
+
do_normalize_cameras,
|
25 |
+
normalize_scale,
|
26 |
+
) -> PerspectiveCameras:
|
27 |
+
focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
|
28 |
+
principal_point = camera_matrix[:, :2, 2]
|
29 |
+
|
30 |
+
# Retype the image_size correctly and flip to width, height.
|
31 |
+
image_size_wh = image_size.to(R).flip(dims=(1,))
|
32 |
+
|
33 |
+
# Screen to NDC conversion:
|
34 |
+
# For non square images, we scale the points such that smallest side
|
35 |
+
# has range [-1, 1] and the largest side has range [-u, u], with u > 1.
|
36 |
+
# This convention is consistent with the PyTorch3D renderer, as well as
|
37 |
+
# the transformation function `get_ndc_to_screen_transform`.
|
38 |
+
scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
|
39 |
+
scale = scale.expand(-1, 2)
|
40 |
+
c0 = image_size_wh / 2.0
|
41 |
+
|
42 |
+
# Get the PyTorch3D focal length and principal point.
|
43 |
+
focal_pytorch3d = focal_length / scale
|
44 |
+
p0_pytorch3d = -(principal_point - c0) / scale
|
45 |
+
|
46 |
+
# For R, T we flip x, y axes (opencv screen space has an opposite
|
47 |
+
# orientation of screen axes).
|
48 |
+
# We also transpose R (opencv multiplies points from the opposite=left side).
|
49 |
+
R_pytorch3d = R.clone().permute(0, 2, 1)
|
50 |
+
T_pytorch3d = tvec.clone()
|
51 |
+
R_pytorch3d[:, :, :2] *= -1
|
52 |
+
T_pytorch3d[:, :2] *= -1
|
53 |
+
|
54 |
+
cams = PerspectiveCameras(
|
55 |
+
R=R_pytorch3d,
|
56 |
+
T=T_pytorch3d,
|
57 |
+
focal_length=focal_pytorch3d,
|
58 |
+
principal_point=p0_pytorch3d,
|
59 |
+
image_size=image_size,
|
60 |
+
device=R.device,
|
61 |
+
)
|
62 |
+
|
63 |
+
if do_normalize_cameras:
|
64 |
+
cams, _ = normalize_cameras(cams, scale=normalize_scale)
|
65 |
+
|
66 |
+
cams = first_camera_transform(cams, rotation_only=False)
|
67 |
+
return cams
|
68 |
+
|
69 |
+
def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
|
70 |
+
cameras = _cameras_from_opencv_projection(
|
71 |
+
R=Rs,
|
72 |
+
tvec=Ts,
|
73 |
+
camera_matrix=Ks,
|
74 |
+
image_size=sizes,
|
75 |
+
do_normalize_cameras=do_normalize_cameras,
|
76 |
+
normalize_scale=normalize_scale
|
77 |
+
)
|
78 |
+
|
79 |
+
rays_embedding = cameras_to_rays(
|
80 |
+
cameras=cameras,
|
81 |
+
num_patches_x=target_size,
|
82 |
+
num_patches_y=target_size,
|
83 |
+
crop_parameters=None,
|
84 |
+
use_plucker=use_plucker
|
85 |
+
)
|
86 |
+
|
87 |
+
return rays_embedding.rays
|
88 |
+
|
89 |
+
def convert_rgba_to_rgb_white_bg(image):
|
90 |
+
"""Convert RGBA image to RGB with white background"""
|
91 |
+
if image.mode == 'RGBA':
|
92 |
+
# Create a white background
|
93 |
+
background = Image.new('RGBA', image.size, (255, 255, 255, 255))
|
94 |
+
# Composite the image onto the white background
|
95 |
+
return Image.alpha_composite(background, image).convert('RGB')
|
96 |
+
return image.convert('RGB')
|
97 |
+
|
98 |
+
class MultiviewDataset(Dataset):
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
scene_folders: str,
|
102 |
+
samples_per_set: Union[int, Tuple[int, int]], # Changed from samples_per_set to samples_range
|
103 |
+
transform=None,
|
104 |
+
caption_keys: Union[str, List] = "caption",
|
105 |
+
multiscale=False,
|
106 |
+
aspect_ratio_type=ASPECT_RATIO_512,
|
107 |
+
c2w_scaling=1.7,
|
108 |
+
default_max_distance=1, # default max distance from all camera of a scene ,
|
109 |
+
do_normalize=True, # whether normalize translation of c2w with max_distance
|
110 |
+
swap_xz=False, # whether swap x and z axis of 3D scenes
|
111 |
+
valid_paths: str = "",
|
112 |
+
frame_sliding_windows: float = None # limit all sampled frames to be within this window, so that camera poses won't be too different
|
113 |
+
):
|
114 |
+
if not isinstance(samples_per_set, tuple) and not isinstance(samples_per_set, list):
|
115 |
+
samples_per_set = (samples_per_set, samples_per_set)
|
116 |
+
self.samples_range = samples_per_set # Tuple of (min_samples, max_samples)
|
117 |
+
self.transform = transform
|
118 |
+
self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
|
119 |
+
self.aspect_ratio = aspect_ratio_type
|
120 |
+
self.scene_folders = sorted(glob.glob(scene_folders))
|
121 |
+
# filter out scene folders that do not have transforms.json
|
122 |
+
self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))
|
123 |
+
|
124 |
+
# if valid_paths.txt exists, only use paths in that file
|
125 |
+
if os.path.exists(valid_paths):
|
126 |
+
with open(valid_paths, 'r') as f:
|
127 |
+
valid_scene_folders = f.read().splitlines()
|
128 |
+
self.scene_folders = sorted(valid_scene_folders)
|
129 |
+
|
130 |
+
self.c2w_scaling = c2w_scaling
|
131 |
+
self.do_normalize = do_normalize
|
132 |
+
self.default_max_distance = default_max_distance
|
133 |
+
self.swap_xz = swap_xz
|
134 |
+
self.frame_sliding_windows = frame_sliding_windows
|
135 |
+
|
136 |
+
if multiscale:
|
137 |
+
assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
|
138 |
+
if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
|
139 |
+
self.interpolate_model = T.InterpolationMode.LANCZOS
|
140 |
+
self.ratio_index = {}
|
141 |
+
self.ratio_nums = {}
|
142 |
+
for k, v in self.aspect_ratio.items():
|
143 |
+
self.ratio_index[float(k)] = [] # used for self.getitem
|
144 |
+
self.ratio_nums[float(k)] = 0 # used for batch-sampler
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return len(self.scene_folders)
|
148 |
+
|
149 |
+
def __getitem__(self, idx):
|
150 |
+
try:
|
151 |
+
scene_path = self.scene_folders[idx]
|
152 |
+
|
153 |
+
if os.path.exists(os.path.join(scene_path, "images")):
|
154 |
+
image_folder = os.path.join(scene_path, "images")
|
155 |
+
downscale_factor = 1
|
156 |
+
elif os.path.exists(os.path.join(scene_path, "images_4")):
|
157 |
+
image_folder = os.path.join(scene_path, "images_4")
|
158 |
+
downscale_factor = 1 / 4
|
159 |
+
elif os.path.exists(os.path.join(scene_path, "images_8")):
|
160 |
+
image_folder = os.path.join(scene_path, "images_8")
|
161 |
+
downscale_factor = 1 / 8
|
162 |
+
else:
|
163 |
+
raise NotImplementedError
|
164 |
+
|
165 |
+
json_path = os.path.join(scene_path, "transforms.json")
|
166 |
+
caption_path = os.path.join(scene_path, "caption.json")
|
167 |
+
image_files = os.listdir(image_folder)
|
168 |
+
|
169 |
+
with open(json_path, 'r') as f:
|
170 |
+
json_data = json.load(f)
|
171 |
+
height, width = json_data['h'], json_data['w']
|
172 |
+
|
173 |
+
dh, dw = int(height * downscale_factor), int(width * downscale_factor)
|
174 |
+
fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
|
175 |
+
cx = dw // 2
|
176 |
+
cy = dh // 2
|
177 |
+
|
178 |
+
frame_list = json_data['frames']
|
179 |
+
|
180 |
+
# Randomly select number of samples
|
181 |
+
|
182 |
+
samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
|
183 |
+
|
184 |
+
# uniformly for all scenes
|
185 |
+
if self.frame_sliding_windows is None:
|
186 |
+
selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
|
187 |
+
# limit the multiview to be in a sliding window (to avoid catastrophic difference in camera angles)
|
188 |
+
else:
|
189 |
+
# Determine the starting index of the sliding window
|
190 |
+
if len(frame_list) <= self.frame_sliding_windows:
|
191 |
+
# If the frame list is smaller than or equal to X, use the entire list
|
192 |
+
window_start = 0
|
193 |
+
window_end = len(frame_list)
|
194 |
+
else:
|
195 |
+
# Randomly select a starting point for the window
|
196 |
+
window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
|
197 |
+
window_end = window_start + self.frame_sliding_windows
|
198 |
+
|
199 |
+
# Get the indices within the sliding window
|
200 |
+
window_indices = list(range(window_start, window_end))
|
201 |
+
|
202 |
+
# Randomly sample indices from the window
|
203 |
+
selected_indices = random.sample(window_indices, samples_per_set)
|
204 |
+
|
205 |
+
image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
|
206 |
+
image_paths = [os.path.join(image_folder, file) for file in image_files]
|
207 |
+
|
208 |
+
# Load images and convert RGBA to RGB with white background
|
209 |
+
images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]
|
210 |
+
|
211 |
+
if self.transform:
|
212 |
+
images = [self.transform(image) for image in images]
|
213 |
+
else:
|
214 |
+
closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
|
215 |
+
closest_size = tuple(map(int, closest_size))
|
216 |
+
transform = T.Compose([
|
217 |
+
T.ToTensor(),
|
218 |
+
CenterCropResizeImage(closest_size),
|
219 |
+
T.Normalize([.5], [.5]),
|
220 |
+
])
|
221 |
+
images = [transform(image) for image in images]
|
222 |
+
images = torch.stack(images)
|
223 |
+
|
224 |
+
c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
|
225 |
+
c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
|
226 |
+
# max_distance = json_data.get('max_distance', self.default_max_distance)
|
227 |
+
# if 'max_distance' not in json_data.keys():
|
228 |
+
# print(f"not found `max_distance` in json path: {json_path}")
|
229 |
+
|
230 |
+
if self.swap_xz:
|
231 |
+
swap_xz = torch.tensor([[[0, 0, 1., 0],
|
232 |
+
[0, 1., 0, 0],
|
233 |
+
[-1., 0, 0, 0],
|
234 |
+
[0, 0, 0, 1.]]])
|
235 |
+
c2ws = swap_xz @ c2ws
|
236 |
+
|
237 |
+
# OPENGL to OPENCV
|
238 |
+
c2ws[:, 0:3, 1:3] *= -1
|
239 |
+
c2ws = c2ws[:, [1, 0, 2, 3], :]
|
240 |
+
c2ws[:, 2, :] *= -1
|
241 |
+
|
242 |
+
w2cs = torch.inverse(c2ws)
|
243 |
+
K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
|
244 |
+
Rs = w2cs[:, :3, :3]
|
245 |
+
Ts = w2cs[:, :3, 3]
|
246 |
+
sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)
|
247 |
+
|
248 |
+
# get ray embedding and padding last dimension to 16 (num channels of VAE)
|
249 |
+
# rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
|
250 |
+
rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
|
251 |
+
rays = rays.reshape(samples_per_set, closest_size[0] // 8, closest_size[1] // 8, 6)
|
252 |
+
# padding = (0, 10) # pad the last dimension to 16
|
253 |
+
# rays = torch.nn.functional.pad(rays, padding, "constant", 0)
|
254 |
+
rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658
|
255 |
+
|
256 |
+
if os.path.exists(caption_path):
|
257 |
+
with open(caption_path, 'r') as f:
|
258 |
+
caption_key = random.choice(self.caption_keys)
|
259 |
+
caption = json.load(f).get(caption_key, "")
|
260 |
+
else:
|
261 |
+
caption = ""
|
262 |
+
|
263 |
+
caption = "[[multiview]] " + caption if caption else "[[multiview]]"
|
264 |
+
|
265 |
+
return {
|
266 |
+
'pixel_values': images,
|
267 |
+
'rays': rays,
|
268 |
+
'aspect_ratio': closest_ratio,
|
269 |
+
'caption': caption,
|
270 |
+
'height': dh,
|
271 |
+
'width': dw,
|
272 |
+
# 'origins': rays_od[..., :3],
|
273 |
+
# 'dirs': rays_od[..., 3:6]
|
274 |
+
}
|
275 |
+
except Exception as e:
|
276 |
+
return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
|
277 |
+
|
onediffusion/dataset/raydiff_utils.py
ADDED
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
Adapted from code originally written by David Novotny.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from pytorch3d.transforms import Rotate, Translate
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from pytorch3d.renderer import PerspectiveCameras, RayBundle
|
13 |
+
|
14 |
+
def intersect_skew_line_groups(p, r, mask):
|
15 |
+
# p, r both of shape (B, N, n_intersected_lines, 3)
|
16 |
+
# mask of shape (B, N, n_intersected_lines)
|
17 |
+
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
|
18 |
+
if p_intersect is None:
|
19 |
+
return None, None, None, None
|
20 |
+
_, p_line_intersect = point_line_distance(
|
21 |
+
p, r, p_intersect[..., None, :].expand_as(p)
|
22 |
+
)
|
23 |
+
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
|
24 |
+
dim=-1
|
25 |
+
)
|
26 |
+
return p_intersect, p_line_intersect, intersect_dist_squared, r
|
27 |
+
|
28 |
+
|
29 |
+
def intersect_skew_lines_high_dim(p, r, mask=None):
|
30 |
+
# Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
|
31 |
+
dim = p.shape[-1]
|
32 |
+
# make sure the heading vectors are l2-normed
|
33 |
+
if mask is None:
|
34 |
+
mask = torch.ones_like(p[..., 0])
|
35 |
+
r = torch.nn.functional.normalize(r, dim=-1)
|
36 |
+
|
37 |
+
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
|
38 |
+
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
|
39 |
+
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
|
40 |
+
|
41 |
+
# I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10
|
42 |
+
# p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0]
|
43 |
+
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
44 |
+
|
45 |
+
# I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3])
|
46 |
+
# sum_proj: torch.Size([1, 1, 3, 1])
|
47 |
+
|
48 |
+
# p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0]
|
49 |
+
|
50 |
+
if torch.any(torch.isnan(p_intersect)):
|
51 |
+
print(p_intersect)
|
52 |
+
return None, None
|
53 |
+
ipdb.set_trace()
|
54 |
+
assert False
|
55 |
+
return p_intersect, r
|
56 |
+
|
57 |
+
|
58 |
+
def point_line_distance(p1, r1, p2):
|
59 |
+
df = p2 - p1
|
60 |
+
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
|
61 |
+
line_pt_nearest = p2 - proj_vector
|
62 |
+
d = (proj_vector).norm(dim=-1)
|
63 |
+
return d, line_pt_nearest
|
64 |
+
|
65 |
+
|
66 |
+
def compute_optical_axis_intersection(cameras):
|
67 |
+
centers = cameras.get_camera_center()
|
68 |
+
principal_points = cameras.principal_point
|
69 |
+
|
70 |
+
one_vec = torch.ones((len(cameras), 1), device=centers.device)
|
71 |
+
optical_axis = torch.cat((principal_points, one_vec), -1)
|
72 |
+
|
73 |
+
# optical_axis = torch.cat(
|
74 |
+
# (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1
|
75 |
+
# )
|
76 |
+
|
77 |
+
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
|
78 |
+
pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
|
79 |
+
|
80 |
+
directions = pp2 - centers
|
81 |
+
centers = centers.unsqueeze(0).unsqueeze(0)
|
82 |
+
directions = directions.unsqueeze(0).unsqueeze(0)
|
83 |
+
|
84 |
+
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
|
85 |
+
p=centers, r=directions, mask=None
|
86 |
+
)
|
87 |
+
|
88 |
+
if p_intersect is None:
|
89 |
+
dist = None
|
90 |
+
else:
|
91 |
+
p_intersect = p_intersect.squeeze().unsqueeze(0)
|
92 |
+
dist = (p_intersect - centers).norm(dim=-1)
|
93 |
+
|
94 |
+
return p_intersect, dist, p_line_intersect, pp2, r
|
95 |
+
|
96 |
+
|
97 |
+
def normalize_cameras(cameras, scale=1.0):
|
98 |
+
"""
|
99 |
+
Normalizes cameras such that the optical axes point to the origin, the rotation is
|
100 |
+
identity, and the norm of the translation of the first camera is 1.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
cameras (pytorch3d.renderer.cameras.CamerasBase).
|
104 |
+
scale (float): Norm of the translation of the first camera.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
|
108 |
+
undo_transform (function): Function that undoes the normalization.
|
109 |
+
"""
|
110 |
+
|
111 |
+
# Let distance from first camera to origin be unit
|
112 |
+
new_cameras = cameras.clone()
|
113 |
+
new_transform = (
|
114 |
+
new_cameras.get_world_to_view_transform()
|
115 |
+
) # potential R is not valid matrix
|
116 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
|
117 |
+
cameras
|
118 |
+
)
|
119 |
+
|
120 |
+
if p_intersect is None:
|
121 |
+
print("Warning: optical axes code has a nan. Returning identity cameras.")
|
122 |
+
new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
|
123 |
+
new_cameras.T[:] = torch.tensor(
|
124 |
+
[0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
|
125 |
+
)
|
126 |
+
return new_cameras, lambda x: x
|
127 |
+
|
128 |
+
d = dist.squeeze(dim=1).squeeze(dim=0)[0]
|
129 |
+
# Degenerate case
|
130 |
+
if d == 0:
|
131 |
+
print(cameras.T)
|
132 |
+
print(new_transform.get_matrix()[:, 3, :3])
|
133 |
+
assert False
|
134 |
+
assert d != 0
|
135 |
+
|
136 |
+
# Can't figure out how to make scale part of the transform too without messing up R.
|
137 |
+
# Ideally, we would just wrap it all in a single Pytorch3D transform so that it
|
138 |
+
# would work with any structure (eg PointClouds, Meshes).
|
139 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
|
140 |
+
tT = Translate(p_intersect)
|
141 |
+
t = tR.compose(tT)
|
142 |
+
|
143 |
+
new_transform = t.compose(new_transform)
|
144 |
+
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
|
145 |
+
new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale
|
146 |
+
|
147 |
+
def undo_transform(cameras):
|
148 |
+
cameras_copy = cameras.clone()
|
149 |
+
cameras_copy.T *= d / scale
|
150 |
+
new_t = (
|
151 |
+
t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
|
152 |
+
)
|
153 |
+
cameras_copy.R = new_t[:, :3, :3]
|
154 |
+
cameras_copy.T = new_t[:, 3, :3]
|
155 |
+
return cameras_copy
|
156 |
+
|
157 |
+
return new_cameras, undo_transform
|
158 |
+
|
159 |
+
def first_camera_transform(cameras, rotation_only=True):
|
160 |
+
new_cameras = cameras.clone()
|
161 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
162 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0))
|
163 |
+
if rotation_only:
|
164 |
+
t = tR.inverse()
|
165 |
+
else:
|
166 |
+
tT = Translate(new_cameras.T[0].unsqueeze(0))
|
167 |
+
t = tR.compose(tT).inverse()
|
168 |
+
|
169 |
+
new_transform = t.compose(new_transform)
|
170 |
+
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
|
171 |
+
new_cameras.T = new_transform.get_matrix()[:, 3, :3]
|
172 |
+
|
173 |
+
return new_cameras
|
174 |
+
|
175 |
+
|
176 |
+
def get_identity_cameras_with_intrinsics(cameras):
|
177 |
+
D = len(cameras)
|
178 |
+
device = cameras.R.device
|
179 |
+
|
180 |
+
new_cameras = cameras.clone()
|
181 |
+
new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1))
|
182 |
+
new_cameras.T = torch.zeros((D, 3), device=device)
|
183 |
+
|
184 |
+
return new_cameras
|
185 |
+
|
186 |
+
|
187 |
+
def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False):
|
188 |
+
new_cameras = []
|
189 |
+
undo_transforms = []
|
190 |
+
for cam in cameras:
|
191 |
+
if normalize_first_camera:
|
192 |
+
# Normalize cameras such that first camera is identity and origin is at
|
193 |
+
# first camera center.
|
194 |
+
normalized_cameras = first_camera_transform(cam, rotation_only=False)
|
195 |
+
undo_transform = None
|
196 |
+
else:
|
197 |
+
normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale)
|
198 |
+
new_cameras.append(normalized_cameras)
|
199 |
+
undo_transforms.append(undo_transform)
|
200 |
+
return new_cameras, undo_transforms
|
201 |
+
|
202 |
+
|
203 |
+
class Rays(object):
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
rays=None,
|
207 |
+
origins=None,
|
208 |
+
directions=None,
|
209 |
+
moments=None,
|
210 |
+
is_plucker=False,
|
211 |
+
moments_rescale=1.0,
|
212 |
+
ndc_coordinates=None,
|
213 |
+
crop_parameters=None,
|
214 |
+
num_patches_x=16,
|
215 |
+
num_patches_y=16,
|
216 |
+
):
|
217 |
+
"""
|
218 |
+
Ray class to keep track of current ray representation.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
rays: (..., 6).
|
222 |
+
origins: (..., 3).
|
223 |
+
directions: (..., 3).
|
224 |
+
moments: (..., 3).
|
225 |
+
is_plucker: If True, rays are in plucker coordinates (Default: False).
|
226 |
+
moments_rescale: Rescale the moment component of the rays by a scalar.
|
227 |
+
ndc_coordinates: (..., 2): NDC coordinates of each ray.
|
228 |
+
"""
|
229 |
+
if rays is not None:
|
230 |
+
self.rays = rays
|
231 |
+
self._is_plucker = is_plucker
|
232 |
+
elif origins is not None and directions is not None:
|
233 |
+
self.rays = torch.cat((origins, directions), dim=-1)
|
234 |
+
self._is_plucker = False
|
235 |
+
elif directions is not None and moments is not None:
|
236 |
+
self.rays = torch.cat((directions, moments), dim=-1)
|
237 |
+
self._is_plucker = True
|
238 |
+
else:
|
239 |
+
raise Exception("Invalid combination of arguments")
|
240 |
+
|
241 |
+
if moments_rescale != 1.0:
|
242 |
+
self.rescale_moments(moments_rescale)
|
243 |
+
|
244 |
+
if ndc_coordinates is not None:
|
245 |
+
self.ndc_coordinates = ndc_coordinates
|
246 |
+
elif crop_parameters is not None:
|
247 |
+
# (..., H, W, 2)
|
248 |
+
xy_grid = compute_ndc_coordinates(
|
249 |
+
crop_parameters,
|
250 |
+
num_patches_x=num_patches_x,
|
251 |
+
num_patches_y=num_patches_y,
|
252 |
+
)[..., :2]
|
253 |
+
xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
|
254 |
+
self.ndc_coordinates = xy_grid
|
255 |
+
else:
|
256 |
+
self.ndc_coordinates = None
|
257 |
+
|
258 |
+
def __getitem__(self, index):
|
259 |
+
return Rays(
|
260 |
+
rays=self.rays[index],
|
261 |
+
is_plucker=self._is_plucker,
|
262 |
+
ndc_coordinates=(
|
263 |
+
self.ndc_coordinates[index]
|
264 |
+
if self.ndc_coordinates is not None
|
265 |
+
else None
|
266 |
+
),
|
267 |
+
)
|
268 |
+
|
269 |
+
def to_spatial(self, include_ndc_coordinates=False):
|
270 |
+
"""
|
271 |
+
Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
torch.Tensor: (..., 6, H, W)
|
275 |
+
"""
|
276 |
+
rays = self.to_plucker().rays
|
277 |
+
*batch_dims, P, D = rays.shape
|
278 |
+
H = W = int(np.sqrt(P))
|
279 |
+
assert H * W == P
|
280 |
+
rays = torch.transpose(rays, -1, -2) # (..., 6, H * W)
|
281 |
+
rays = rays.reshape(*batch_dims, D, H, W)
|
282 |
+
if include_ndc_coordinates:
|
283 |
+
ndc_coords = self.ndc_coordinates.transpose(-1, -2) # (..., 2, H * W)
|
284 |
+
ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
|
285 |
+
rays = torch.cat((rays, ndc_coords), dim=-3)
|
286 |
+
return rays
|
287 |
+
|
288 |
+
def rescale_moments(self, scale):
|
289 |
+
"""
|
290 |
+
Rescale the moment component of the rays by a scalar. Might be desirable since
|
291 |
+
moments may come from a very narrow distribution.
|
292 |
+
|
293 |
+
Note that this modifies in place!
|
294 |
+
"""
|
295 |
+
if self.is_plucker:
|
296 |
+
self.rays[..., 3:] *= scale
|
297 |
+
return self
|
298 |
+
else:
|
299 |
+
return self.to_plucker().rescale_moments(scale)
|
300 |
+
|
301 |
+
@classmethod
|
302 |
+
def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None):
|
303 |
+
"""
|
304 |
+
Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)
|
305 |
+
|
306 |
+
Args:
|
307 |
+
rays: (..., 6, H, W)
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
Rays: (..., H * W, 6)
|
311 |
+
"""
|
312 |
+
*batch_dims, D, H, W = rays.shape
|
313 |
+
rays = rays.reshape(*batch_dims, D, H * W)
|
314 |
+
rays = torch.transpose(rays, -1, -2)
|
315 |
+
return cls(
|
316 |
+
rays=rays,
|
317 |
+
is_plucker=True,
|
318 |
+
moments_rescale=moments_rescale,
|
319 |
+
ndc_coordinates=ndc_coordinates,
|
320 |
+
)
|
321 |
+
|
322 |
+
def to_point_direction(self, normalize_moment=True):
|
323 |
+
"""
|
324 |
+
Convert to point direction representation <O, D>.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
rays: (..., 6).
|
328 |
+
"""
|
329 |
+
if self._is_plucker:
|
330 |
+
direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
|
331 |
+
moment = self.rays[..., 3:]
|
332 |
+
if normalize_moment:
|
333 |
+
c = torch.linalg.norm(direction, dim=-1, keepdim=True)
|
334 |
+
moment = moment / c
|
335 |
+
points = torch.cross(direction, moment, dim=-1)
|
336 |
+
return Rays(
|
337 |
+
rays=torch.cat((points, direction), dim=-1),
|
338 |
+
is_plucker=False,
|
339 |
+
ndc_coordinates=self.ndc_coordinates,
|
340 |
+
)
|
341 |
+
else:
|
342 |
+
return self
|
343 |
+
|
344 |
+
def to_plucker(self):
|
345 |
+
"""
|
346 |
+
Convert to plucker representation <D, OxD>.
|
347 |
+
"""
|
348 |
+
if self.is_plucker:
|
349 |
+
return self
|
350 |
+
else:
|
351 |
+
ray = self.rays.clone()
|
352 |
+
ray_origins = ray[..., :3]
|
353 |
+
ray_directions = ray[..., 3:]
|
354 |
+
# Normalize ray directions to unit vectors
|
355 |
+
ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
|
356 |
+
plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
|
357 |
+
new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
|
358 |
+
return Rays(
|
359 |
+
rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates
|
360 |
+
)
|
361 |
+
|
362 |
+
def get_directions(self, normalize=True):
|
363 |
+
if self.is_plucker:
|
364 |
+
directions = self.rays[..., :3]
|
365 |
+
else:
|
366 |
+
directions = self.rays[..., 3:]
|
367 |
+
if normalize:
|
368 |
+
directions = torch.nn.functional.normalize(directions, dim=-1)
|
369 |
+
return directions
|
370 |
+
|
371 |
+
def get_origins(self):
|
372 |
+
if self.is_plucker:
|
373 |
+
origins = self.to_point_direction().get_origins()
|
374 |
+
else:
|
375 |
+
origins = self.rays[..., :3]
|
376 |
+
return origins
|
377 |
+
|
378 |
+
def get_moments(self):
|
379 |
+
if self.is_plucker:
|
380 |
+
moments = self.rays[..., 3:]
|
381 |
+
else:
|
382 |
+
moments = self.to_plucker().get_moments()
|
383 |
+
return moments
|
384 |
+
|
385 |
+
def get_ndc_coordinates(self):
|
386 |
+
return self.ndc_coordinates
|
387 |
+
|
388 |
+
@property
|
389 |
+
def is_plucker(self):
|
390 |
+
return self._is_plucker
|
391 |
+
|
392 |
+
@property
|
393 |
+
def device(self):
|
394 |
+
return self.rays.device
|
395 |
+
|
396 |
+
def __repr__(self, *args, **kwargs):
|
397 |
+
ray_str = self.rays.__repr__(*args, **kwargs)[6:] # remove "tensor"
|
398 |
+
if self._is_plucker:
|
399 |
+
return "PluRay" + ray_str
|
400 |
+
else:
|
401 |
+
return "DirRay" + ray_str
|
402 |
+
|
403 |
+
def to(self, device):
|
404 |
+
self.rays = self.rays.to(device)
|
405 |
+
|
406 |
+
def clone(self):
|
407 |
+
return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker)
|
408 |
+
|
409 |
+
@property
|
410 |
+
def shape(self):
|
411 |
+
return self.rays.shape
|
412 |
+
|
413 |
+
def visualize(self):
|
414 |
+
directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
|
415 |
+
moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
|
416 |
+
return (directions + 1) / 2, (moments + 1) / 2
|
417 |
+
|
418 |
+
def to_ray_bundle(self, length=0.3, recenter=True):
|
419 |
+
lengths = torch.ones_like(self.get_origins()[..., :2]) * length
|
420 |
+
lengths[..., 0] = 0
|
421 |
+
if recenter:
|
422 |
+
centers, _ = intersect_skew_lines_high_dim(
|
423 |
+
self.get_origins(), self.get_directions()
|
424 |
+
)
|
425 |
+
centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
|
426 |
+
else:
|
427 |
+
centers = self.get_origins()
|
428 |
+
return RayBundle(
|
429 |
+
origins=centers,
|
430 |
+
directions=self.get_directions(),
|
431 |
+
lengths=lengths,
|
432 |
+
xys=self.get_directions(),
|
433 |
+
)
|
434 |
+
|
435 |
+
|
436 |
+
def cameras_to_rays(
|
437 |
+
cameras,
|
438 |
+
crop_parameters,
|
439 |
+
use_half_pix=True,
|
440 |
+
use_plucker=True,
|
441 |
+
num_patches_x=16,
|
442 |
+
num_patches_y=16,
|
443 |
+
):
|
444 |
+
"""
|
445 |
+
Unprojects rays from camera center to grid on image plane.
|
446 |
+
|
447 |
+
Args:
|
448 |
+
cameras: Pytorch3D cameras to unproject. Can be batched.
|
449 |
+
crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
|
450 |
+
Shape is (B, 4).
|
451 |
+
use_half_pix: If True, use half pixel offset (Default: True).
|
452 |
+
use_plucker: If True, return rays in plucker coordinates (Default: False).
|
453 |
+
num_patches_x: Number of patches in x direction (Default: 16).
|
454 |
+
num_patches_y: Number of patches in y direction (Default: 16).
|
455 |
+
"""
|
456 |
+
unprojected = []
|
457 |
+
crop_parameters_list = (
|
458 |
+
crop_parameters if crop_parameters is not None else [None for _ in cameras]
|
459 |
+
)
|
460 |
+
for camera, crop_param in zip(cameras, crop_parameters_list):
|
461 |
+
xyd_grid = compute_ndc_coordinates(
|
462 |
+
crop_parameters=crop_param,
|
463 |
+
use_half_pix=use_half_pix,
|
464 |
+
num_patches_x=num_patches_x,
|
465 |
+
num_patches_y=num_patches_y,
|
466 |
+
)
|
467 |
+
|
468 |
+
unprojected.append(
|
469 |
+
camera.unproject_points(
|
470 |
+
xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
|
471 |
+
)
|
472 |
+
)
|
473 |
+
unprojected = torch.stack(unprojected, dim=0) # (N, P, 3)
|
474 |
+
origins = cameras.get_camera_center().unsqueeze(1) # (N, 1, 3)
|
475 |
+
origins = origins.repeat(1, num_patches_x * num_patches_y, 1) # (N, P, 3)
|
476 |
+
directions = unprojected - origins
|
477 |
+
|
478 |
+
rays = Rays(
|
479 |
+
origins=origins,
|
480 |
+
directions=directions,
|
481 |
+
crop_parameters=crop_parameters,
|
482 |
+
num_patches_x=num_patches_x,
|
483 |
+
num_patches_y=num_patches_y,
|
484 |
+
)
|
485 |
+
if use_plucker:
|
486 |
+
return rays.to_plucker()
|
487 |
+
return rays
|
488 |
+
|
489 |
+
|
490 |
+
def rays_to_cameras(
|
491 |
+
rays,
|
492 |
+
crop_parameters,
|
493 |
+
num_patches_x=16,
|
494 |
+
num_patches_y=16,
|
495 |
+
use_half_pix=True,
|
496 |
+
sampled_ray_idx=None,
|
497 |
+
cameras=None,
|
498 |
+
focal_length=(3.453,),
|
499 |
+
):
|
500 |
+
"""
|
501 |
+
If cameras are provided, will use those intrinsics. Otherwise will use the provided
|
502 |
+
focal_length(s). Dataset default is 3.32.
|
503 |
+
|
504 |
+
Args:
|
505 |
+
rays (Rays): (N, P, 6)
|
506 |
+
crop_parameters (torch.Tensor): (N, 4)
|
507 |
+
"""
|
508 |
+
device = rays.device
|
509 |
+
origins = rays.get_origins()
|
510 |
+
directions = rays.get_directions()
|
511 |
+
camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
|
512 |
+
|
513 |
+
# Retrieve target rays
|
514 |
+
if cameras is None:
|
515 |
+
if len(focal_length) == 1:
|
516 |
+
focal_length = focal_length * rays.shape[0]
|
517 |
+
I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
|
518 |
+
else:
|
519 |
+
# Use same intrinsics but reset to identity extrinsics.
|
520 |
+
I_camera = cameras.clone()
|
521 |
+
I_camera.R[:] = torch.eye(3, device=device)
|
522 |
+
I_camera.T[:] = torch.zeros(3, device=device)
|
523 |
+
I_patch_rays = cameras_to_rays(
|
524 |
+
cameras=I_camera,
|
525 |
+
num_patches_x=num_patches_x,
|
526 |
+
num_patches_y=num_patches_y,
|
527 |
+
use_half_pix=use_half_pix,
|
528 |
+
crop_parameters=crop_parameters,
|
529 |
+
).get_directions()
|
530 |
+
|
531 |
+
if sampled_ray_idx is not None:
|
532 |
+
I_patch_rays = I_patch_rays[:, sampled_ray_idx]
|
533 |
+
|
534 |
+
# Compute optimal rotation to align rays
|
535 |
+
R = torch.zeros_like(I_camera.R)
|
536 |
+
for i in range(len(I_camera)):
|
537 |
+
R[i] = compute_optimal_rotation_alignment(
|
538 |
+
I_patch_rays[i],
|
539 |
+
directions[i],
|
540 |
+
)
|
541 |
+
|
542 |
+
# Construct and return rotated camera
|
543 |
+
cam = I_camera.clone()
|
544 |
+
cam.R = R
|
545 |
+
cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
|
546 |
+
return cam
|
547 |
+
|
548 |
+
|
549 |
+
# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
|
550 |
+
def ql_decomposition(A):
|
551 |
+
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
|
552 |
+
A_tilde = torch.matmul(A, P)
|
553 |
+
Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
|
554 |
+
Q = torch.matmul(Q_tilde, P)
|
555 |
+
L = torch.matmul(torch.matmul(P, R_tilde), P)
|
556 |
+
d = torch.diag(L)
|
557 |
+
Q[:, 0] *= torch.sign(d[0])
|
558 |
+
Q[:, 1] *= torch.sign(d[1])
|
559 |
+
Q[:, 2] *= torch.sign(d[2])
|
560 |
+
L[0] *= torch.sign(d[0])
|
561 |
+
L[1] *= torch.sign(d[1])
|
562 |
+
L[2] *= torch.sign(d[2])
|
563 |
+
return Q, L
|
564 |
+
|
565 |
+
|
566 |
+
def rays_to_cameras_homography(
|
567 |
+
rays,
|
568 |
+
crop_parameters,
|
569 |
+
num_patches_x=16,
|
570 |
+
num_patches_y=16,
|
571 |
+
use_half_pix=True,
|
572 |
+
sampled_ray_idx=None,
|
573 |
+
reproj_threshold=0.2,
|
574 |
+
):
|
575 |
+
"""
|
576 |
+
Args:
|
577 |
+
rays (Rays): (N, P, 6)
|
578 |
+
crop_parameters (torch.Tensor): (N, 4)
|
579 |
+
"""
|
580 |
+
device = rays.device
|
581 |
+
origins = rays.get_origins()
|
582 |
+
directions = rays.get_directions()
|
583 |
+
camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
|
584 |
+
|
585 |
+
# Retrieve target rays
|
586 |
+
I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
|
587 |
+
I_patch_rays = cameras_to_rays(
|
588 |
+
cameras=I_camera,
|
589 |
+
num_patches_x=num_patches_x,
|
590 |
+
num_patches_y=num_patches_y,
|
591 |
+
use_half_pix=use_half_pix,
|
592 |
+
crop_parameters=crop_parameters,
|
593 |
+
).get_directions()
|
594 |
+
|
595 |
+
if sampled_ray_idx is not None:
|
596 |
+
I_patch_rays = I_patch_rays[:, sampled_ray_idx]
|
597 |
+
|
598 |
+
# Compute optimal rotation to align rays
|
599 |
+
Rs = []
|
600 |
+
focal_lengths = []
|
601 |
+
principal_points = []
|
602 |
+
for i in range(rays.shape[-3]):
|
603 |
+
R, f, pp = compute_optimal_rotation_intrinsics(
|
604 |
+
I_patch_rays[i],
|
605 |
+
directions[i],
|
606 |
+
reproj_threshold=reproj_threshold,
|
607 |
+
)
|
608 |
+
Rs.append(R)
|
609 |
+
focal_lengths.append(f)
|
610 |
+
principal_points.append(pp)
|
611 |
+
|
612 |
+
R = torch.stack(Rs)
|
613 |
+
focal_lengths = torch.stack(focal_lengths)
|
614 |
+
principal_points = torch.stack(principal_points)
|
615 |
+
T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
|
616 |
+
return PerspectiveCameras(
|
617 |
+
R=R,
|
618 |
+
T=T,
|
619 |
+
focal_length=focal_lengths,
|
620 |
+
principal_point=principal_points,
|
621 |
+
device=device,
|
622 |
+
)
|
623 |
+
|
624 |
+
|
625 |
+
def compute_optimal_rotation_alignment(A, B):
|
626 |
+
"""
|
627 |
+
Compute optimal R that minimizes: || A - B @ R ||_F
|
628 |
+
|
629 |
+
Args:
|
630 |
+
A (torch.Tensor): (N, 3)
|
631 |
+
B (torch.Tensor): (N, 3)
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
R (torch.tensor): (3, 3)
|
635 |
+
"""
|
636 |
+
# normally with R @ B, this would be A @ B.T
|
637 |
+
H = B.T @ A
|
638 |
+
U, _, Vh = torch.linalg.svd(H, full_matrices=True)
|
639 |
+
s = torch.linalg.det(U @ Vh)
|
640 |
+
S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
|
641 |
+
return U @ S_prime @ Vh
|
642 |
+
|
643 |
+
|
644 |
+
def compute_optimal_rotation_intrinsics(
|
645 |
+
rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
|
646 |
+
):
|
647 |
+
"""
|
648 |
+
Note: for some reason, f seems to be 1/f.
|
649 |
+
|
650 |
+
Args:
|
651 |
+
rays_origin (torch.Tensor): (N, 3)
|
652 |
+
rays_target (torch.Tensor): (N, 3)
|
653 |
+
z_threshold (float): Threshold for z value to be considered valid.
|
654 |
+
|
655 |
+
Returns:
|
656 |
+
R (torch.tensor): (3, 3)
|
657 |
+
focal_length (torch.tensor): (2,)
|
658 |
+
principal_point (torch.tensor): (2,)
|
659 |
+
"""
|
660 |
+
device = rays_origin.device
|
661 |
+
z_mask = torch.logical_and(
|
662 |
+
torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
|
663 |
+
)[:, 2]
|
664 |
+
rays_target = rays_target[z_mask]
|
665 |
+
rays_origin = rays_origin[z_mask]
|
666 |
+
rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
|
667 |
+
rays_target = rays_target[:, :2] / rays_target[:, -1:]
|
668 |
+
|
669 |
+
A, _ = cv2.findHomography(
|
670 |
+
rays_origin.cpu().numpy(),
|
671 |
+
rays_target.cpu().numpy(),
|
672 |
+
cv2.RANSAC,
|
673 |
+
reproj_threshold,
|
674 |
+
)
|
675 |
+
A = torch.from_numpy(A).float().to(device)
|
676 |
+
|
677 |
+
if torch.linalg.det(A) < 0:
|
678 |
+
A = -A
|
679 |
+
|
680 |
+
R, L = ql_decomposition(A)
|
681 |
+
L = L / L[2][2]
|
682 |
+
|
683 |
+
f = torch.stack((L[0][0], L[1][1]))
|
684 |
+
pp = torch.stack((L[2][0], L[2][1]))
|
685 |
+
return R, f, pp
|
686 |
+
|
687 |
+
|
688 |
+
def compute_ndc_coordinates(
|
689 |
+
crop_parameters=None,
|
690 |
+
use_half_pix=True,
|
691 |
+
num_patches_x=16,
|
692 |
+
num_patches_y=16,
|
693 |
+
device=None,
|
694 |
+
):
|
695 |
+
"""
|
696 |
+
Computes NDC Grid using crop_parameters. If crop_parameters is not provided,
|
697 |
+
then it assumes that the crop is the entire image (corresponding to an NDC grid
|
698 |
+
where top left corner is (1, 1) and bottom right corner is (-1, -1)).
|
699 |
+
"""
|
700 |
+
if crop_parameters is None:
|
701 |
+
cc_x, cc_y, width = 0, 0, 2
|
702 |
+
else:
|
703 |
+
if len(crop_parameters.shape) > 1:
|
704 |
+
return torch.stack(
|
705 |
+
[
|
706 |
+
compute_ndc_coordinates(
|
707 |
+
crop_parameters=crop_param,
|
708 |
+
use_half_pix=use_half_pix,
|
709 |
+
num_patches_x=num_patches_x,
|
710 |
+
num_patches_y=num_patches_y,
|
711 |
+
)
|
712 |
+
for crop_param in crop_parameters
|
713 |
+
],
|
714 |
+
dim=0,
|
715 |
+
)
|
716 |
+
device = crop_parameters.device
|
717 |
+
cc_x, cc_y, width, _ = crop_parameters
|
718 |
+
|
719 |
+
dx = 1 / num_patches_x
|
720 |
+
dy = 1 / num_patches_y
|
721 |
+
if use_half_pix:
|
722 |
+
min_y = 1 - dy
|
723 |
+
max_y = -min_y
|
724 |
+
min_x = 1 - dx
|
725 |
+
max_x = -min_x
|
726 |
+
else:
|
727 |
+
min_y = min_x = 1
|
728 |
+
max_y = -1 + 2 * dy
|
729 |
+
max_x = -1 + 2 * dx
|
730 |
+
|
731 |
+
y, x = torch.meshgrid(
|
732 |
+
torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
|
733 |
+
torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
|
734 |
+
indexing="ij",
|
735 |
+
)
|
736 |
+
x_prime = x * width / 2 - cc_x
|
737 |
+
y_prime = y * width / 2 - cc_y
|
738 |
+
xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1)
|
739 |
+
return xyd_grid
|
onediffusion/dataset/transforms.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
def crop(image, i, j, h, w):
|
5 |
+
"""
|
6 |
+
Args:
|
7 |
+
image (torch.tensor): Image to be cropped. Size is (C, H, W)
|
8 |
+
"""
|
9 |
+
if len(image.size()) != 3:
|
10 |
+
raise ValueError("image should be a 3D tensor")
|
11 |
+
return image[..., i : i + h, j : j + w]
|
12 |
+
|
13 |
+
def resize(image, target_size, interpolation_mode):
|
14 |
+
if len(target_size) != 2:
|
15 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
16 |
+
return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=False).squeeze(0)
|
17 |
+
|
18 |
+
def resize_scale(image, target_size, interpolation_mode):
|
19 |
+
if len(target_size) != 2:
|
20 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
21 |
+
H, W = image.size(-2), image.size(-1)
|
22 |
+
scale_ = target_size[0] / min(H, W)
|
23 |
+
return F.interpolate(image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=False).squeeze(0)
|
24 |
+
|
25 |
+
def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
|
26 |
+
"""
|
27 |
+
Do spatial cropping and resizing to the image
|
28 |
+
Args:
|
29 |
+
image (torch.tensor): Image to be cropped. Size is (C, H, W)
|
30 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
31 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
32 |
+
h (int): Height of the cropped region.
|
33 |
+
w (int): Width of the cropped region.
|
34 |
+
size (tuple(int, int)): height and width of resized image
|
35 |
+
Returns:
|
36 |
+
image (torch.tensor): Resized and cropped image. Size is (C, H, W)
|
37 |
+
"""
|
38 |
+
if len(image.size()) != 3:
|
39 |
+
raise ValueError("image should be a 3D torch.tensor")
|
40 |
+
image = crop(image, i, j, h, w)
|
41 |
+
image = resize(image, size, interpolation_mode)
|
42 |
+
return image
|
43 |
+
|
44 |
+
def center_crop(image, crop_size):
|
45 |
+
if len(image.size()) != 3:
|
46 |
+
raise ValueError("image should be a 3D torch.tensor")
|
47 |
+
h, w = image.size(-2), image.size(-1)
|
48 |
+
th, tw = crop_size
|
49 |
+
if h < th or w < tw:
|
50 |
+
raise ValueError("height and width must be no smaller than crop_size")
|
51 |
+
i = int(round((h - th) / 2.0))
|
52 |
+
j = int(round((w - tw) / 2.0))
|
53 |
+
return crop(image, i, j, th, tw)
|
54 |
+
|
55 |
+
def center_crop_using_short_edge(image):
|
56 |
+
if len(image.size()) != 3:
|
57 |
+
raise ValueError("image should be a 3D torch.tensor")
|
58 |
+
h, w = image.size(-2), image.size(-1)
|
59 |
+
if h < w:
|
60 |
+
th, tw = h, h
|
61 |
+
i = 0
|
62 |
+
j = int(round((w - tw) / 2.0))
|
63 |
+
else:
|
64 |
+
th, tw = w, w
|
65 |
+
i = int(round((h - th) / 2.0))
|
66 |
+
j = 0
|
67 |
+
return crop(image, i, j, th, tw)
|
68 |
+
|
69 |
+
class CenterCropResizeImage:
|
70 |
+
"""
|
71 |
+
Resize the image while maintaining aspect ratio, and then crop it to the desired size.
|
72 |
+
The resizing is done such that the area of padding/cropping is minimized.
|
73 |
+
"""
|
74 |
+
def __init__(self, size, interpolation_mode="bilinear"):
|
75 |
+
if isinstance(size, tuple):
|
76 |
+
if len(size) != 2:
|
77 |
+
raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
|
78 |
+
self.size = size
|
79 |
+
else:
|
80 |
+
self.size = (size, size)
|
81 |
+
self.interpolation_mode = interpolation_mode
|
82 |
+
|
83 |
+
def __call__(self, image):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)
|
90 |
+
"""
|
91 |
+
target_height, target_width = self.size
|
92 |
+
target_aspect = target_width / target_height
|
93 |
+
|
94 |
+
# Get current image shape and aspect ratio
|
95 |
+
_, height, width = image.shape
|
96 |
+
height, width = float(height), float(width)
|
97 |
+
current_aspect = width / height
|
98 |
+
|
99 |
+
# Calculate crop dimensions
|
100 |
+
if current_aspect > target_aspect:
|
101 |
+
# Image is wider than target, crop width
|
102 |
+
crop_height = height
|
103 |
+
crop_width = height * target_aspect
|
104 |
+
else:
|
105 |
+
# Image is taller than target, crop height
|
106 |
+
crop_height = width / target_aspect
|
107 |
+
crop_width = width
|
108 |
+
|
109 |
+
# Calculate crop coordinates (center crop)
|
110 |
+
y1 = (height - crop_height) / 2
|
111 |
+
x1 = (width - crop_width) / 2
|
112 |
+
|
113 |
+
# Perform the crop
|
114 |
+
cropped_image = crop(image, int(y1), int(x1), int(crop_height), int(crop_width))
|
115 |
+
|
116 |
+
# Resize the cropped image to the target size
|
117 |
+
resized_image = resize(cropped_image, self.size, self.interpolation_mode)
|
118 |
+
|
119 |
+
return resized_image
|
120 |
+
|
121 |
+
# Example usage
|
122 |
+
if __name__ == "__main__":
|
123 |
+
# Create a sample image tensor
|
124 |
+
sample_image = torch.rand(3, 480, 640) # (C, H, W)
|
125 |
+
|
126 |
+
# Initialize the transform
|
127 |
+
transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
|
128 |
+
|
129 |
+
# Apply the transform
|
130 |
+
transformed_image = transform(sample_image)
|
131 |
+
|
132 |
+
print(f"Original image shape: {sample_image.shape}")
|
133 |
+
print(f"Transformed image shape: {transformed_image.shape}")
|
onediffusion/dataset/utils.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
ASPECT_RATIO_2880 = {
|
3 |
+
'0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
|
4 |
+
'0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
|
5 |
+
'0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
|
6 |
+
'0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
|
7 |
+
'0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
|
8 |
+
'1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
|
9 |
+
'1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
|
10 |
+
'1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
|
11 |
+
'2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
|
12 |
+
'3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
|
13 |
+
}
|
14 |
+
|
15 |
+
ASPECT_RATIO_2048 = {
|
16 |
+
'0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
|
17 |
+
'0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
|
18 |
+
'0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
|
19 |
+
'0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
|
20 |
+
'0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
|
21 |
+
'1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
|
22 |
+
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
|
23 |
+
'1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
|
24 |
+
'2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
|
25 |
+
'3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
|
26 |
+
}
|
27 |
+
|
28 |
+
ASPECT_RATIO_1024 = {
|
29 |
+
'0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
|
30 |
+
'0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
|
31 |
+
'0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
|
32 |
+
'0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
|
33 |
+
'0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
|
34 |
+
'1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
|
35 |
+
'1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
|
36 |
+
'1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
|
37 |
+
'2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
|
38 |
+
'3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
|
39 |
+
}
|
40 |
+
|
41 |
+
ASPECT_RATIO_512 = {
|
42 |
+
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
|
43 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
44 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
45 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
46 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
47 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
48 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
49 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
50 |
+
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
|
51 |
+
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
ASPECT_RATIO_384 = {
|
56 |
+
'0.25': [192.0, 768.0],
|
57 |
+
'0.26': [192.0, 736.0],
|
58 |
+
'0.27': [208.0, 768.0],
|
59 |
+
'0.28': [208.0, 736.0],
|
60 |
+
'0.33': [240.0, 720.0],
|
61 |
+
'0.4': [256.0, 640.0],
|
62 |
+
'0.42': [304.0, 720.0],
|
63 |
+
'0.48': [368.0, 768.0],
|
64 |
+
'0.5': [384.0, 768.0],
|
65 |
+
'0.52': [384.0, 736.0],
|
66 |
+
'0.57': [384.0, 672.0],
|
67 |
+
'0.6': [384.0, 640.0],
|
68 |
+
'0.73': [384.0, 528.0],
|
69 |
+
'0.77': [384.0, 496.0],
|
70 |
+
'0.83': [384.0, 464.0],
|
71 |
+
'0.89': [384.0, 432.0],
|
72 |
+
'0.92': [384.0, 416.0],
|
73 |
+
'1.0': [384.0, 384.0],
|
74 |
+
'1.09': [384.0, 352.0],
|
75 |
+
'1.14': [384.0, 336.0],
|
76 |
+
'1.2': [384.0, 320.0],
|
77 |
+
'1.26': [384.0, 304.0],
|
78 |
+
'1.33': [384.0, 288.0],
|
79 |
+
'1.41': [384.0, 272.0],
|
80 |
+
'1.6': [384.0, 240.0],
|
81 |
+
'1.71': [384.0, 224.0],
|
82 |
+
'2.0': [384.0, 192.0],
|
83 |
+
'2.4': [384.0, 160.0],
|
84 |
+
'2.88': [368.0, 128.0],
|
85 |
+
'3.0': [384.0, 128.0],
|
86 |
+
'3.43': [384.0, 112.0],
|
87 |
+
'4.0': [384.0, 96.0]
|
88 |
+
}
|
89 |
+
|
90 |
+
ASPECT_RATIO_256 = {
|
91 |
+
'0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
|
92 |
+
'0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
|
93 |
+
'0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
|
94 |
+
'0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
|
95 |
+
'0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
|
96 |
+
'1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
|
97 |
+
'1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
|
98 |
+
'1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
|
99 |
+
'2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
|
100 |
+
'3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
|
101 |
+
}
|
102 |
+
|
103 |
+
ASPECT_RATIO_256_TEST = {
|
104 |
+
'0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
|
105 |
+
'0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
|
106 |
+
'0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
|
107 |
+
'0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
|
108 |
+
'0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
|
109 |
+
'1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
|
110 |
+
'1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
|
111 |
+
'1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
|
112 |
+
'2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
|
113 |
+
'4.0': [512.0, 128.0]
|
114 |
+
}
|
115 |
+
|
116 |
+
ASPECT_RATIO_512_TEST = {
|
117 |
+
'0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
|
118 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
119 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
120 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
121 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
122 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
123 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
124 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
125 |
+
'2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
|
126 |
+
'4.0': [1024.0, 256.0]
|
127 |
+
}
|
128 |
+
|
129 |
+
ASPECT_RATIO_1024_TEST = {
|
130 |
+
'0.25': [512., 2048.], '0.28': [512., 1856.],
|
131 |
+
'0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
|
132 |
+
'0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
|
133 |
+
'0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
|
134 |
+
'0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
|
135 |
+
'1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
|
136 |
+
'1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
|
137 |
+
'1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
|
138 |
+
'2.5': [1600., 640.], '3.0': [1728., 576.],
|
139 |
+
'4.0': [2048., 512.],
|
140 |
+
}
|
141 |
+
|
142 |
+
ASPECT_RATIO_2048_TEST = {
|
143 |
+
'0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
|
144 |
+
'0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
|
145 |
+
'0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
|
146 |
+
'0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
|
147 |
+
'0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
|
148 |
+
'1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
|
149 |
+
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
|
150 |
+
'1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
|
151 |
+
'2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
|
152 |
+
'4.0': [4096.0, 1024.0]
|
153 |
+
}
|
154 |
+
|
155 |
+
ASPECT_RATIO_2880_TEST = {
|
156 |
+
'0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
|
157 |
+
'0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
|
158 |
+
'0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
|
159 |
+
'0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
|
160 |
+
'0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
|
161 |
+
'1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
|
162 |
+
'1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
|
163 |
+
'1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
|
164 |
+
'2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
|
165 |
+
'4.0': [8192.0, 2048.0],
|
166 |
+
}
|
167 |
+
|
168 |
+
def get_chunks(lst, n):
|
169 |
+
for i in range(0, len(lst), n):
|
170 |
+
yield lst[i:i + n]
|
171 |
+
|
172 |
+
def get_closest_ratio(height: float, width: float, ratios: dict):
|
173 |
+
aspect_ratio = height / width
|
174 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
175 |
+
return ratios[closest_ratio], float(closest_ratio)
|
onediffusion/diffusion/pipelines/image_processor.py
ADDED
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import warnings
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import PIL.Image
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torchvision.transforms as T
|
24 |
+
from PIL import Image, ImageFilter, ImageOps
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
28 |
+
|
29 |
+
from onediffusion.dataset.transforms import CenterCropResizeImage
|
30 |
+
|
31 |
+
PipelineImageInput = Union[
|
32 |
+
PIL.Image.Image,
|
33 |
+
np.ndarray,
|
34 |
+
torch.Tensor,
|
35 |
+
List[PIL.Image.Image],
|
36 |
+
List[np.ndarray],
|
37 |
+
List[torch.Tensor],
|
38 |
+
]
|
39 |
+
|
40 |
+
PipelineDepthInput = PipelineImageInput
|
41 |
+
|
42 |
+
|
43 |
+
def is_valid_image(image):
|
44 |
+
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
|
45 |
+
|
46 |
+
|
47 |
+
def is_valid_image_imagelist(images):
|
48 |
+
# check if the image input is one of the supported formats for image and image list:
|
49 |
+
# it can be either one of below 3
|
50 |
+
# (1) a 4d pytorch tensor or numpy array,
|
51 |
+
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
|
52 |
+
# (3) a list of valid image
|
53 |
+
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
|
54 |
+
return True
|
55 |
+
elif is_valid_image(images):
|
56 |
+
return True
|
57 |
+
elif isinstance(images, list):
|
58 |
+
return all(is_valid_image(image) for image in images)
|
59 |
+
return False
|
60 |
+
|
61 |
+
|
62 |
+
class VaeImageProcessorOneDiffuser(ConfigMixin):
|
63 |
+
"""
|
64 |
+
Image processor for VAE.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
69 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
70 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
71 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
72 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
73 |
+
Resampling filter to use when resizing the image.
|
74 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
75 |
+
Whether to normalize the image to [-1,1].
|
76 |
+
do_binarize (`bool`, *optional*, defaults to `False`):
|
77 |
+
Whether to binarize the image to 0/1.
|
78 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
79 |
+
Whether to convert the images to RGB format.
|
80 |
+
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
81 |
+
Whether to convert the images to grayscale format.
|
82 |
+
"""
|
83 |
+
|
84 |
+
config_name = CONFIG_NAME
|
85 |
+
|
86 |
+
@register_to_config
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
do_resize: bool = True,
|
90 |
+
vae_scale_factor: int = 8,
|
91 |
+
vae_latent_channels: int = 4,
|
92 |
+
resample: str = "lanczos",
|
93 |
+
do_normalize: bool = True,
|
94 |
+
do_binarize: bool = False,
|
95 |
+
do_convert_rgb: bool = False,
|
96 |
+
do_convert_grayscale: bool = False,
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
if do_convert_rgb and do_convert_grayscale:
|
100 |
+
raise ValueError(
|
101 |
+
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
|
102 |
+
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
|
103 |
+
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
|
104 |
+
)
|
105 |
+
|
106 |
+
@staticmethod
|
107 |
+
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
108 |
+
"""
|
109 |
+
Convert a numpy image or a batch of images to a PIL image.
|
110 |
+
"""
|
111 |
+
if images.ndim == 3:
|
112 |
+
images = images[None, ...]
|
113 |
+
images = (images * 255).round().astype("uint8")
|
114 |
+
if images.shape[-1] == 1:
|
115 |
+
# special case for grayscale (single channel) images
|
116 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
117 |
+
else:
|
118 |
+
pil_images = [Image.fromarray(image) for image in images]
|
119 |
+
|
120 |
+
return pil_images
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
124 |
+
"""
|
125 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
126 |
+
"""
|
127 |
+
if not isinstance(images, list):
|
128 |
+
images = [images]
|
129 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
130 |
+
images = np.stack(images, axis=0)
|
131 |
+
|
132 |
+
return images
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
|
136 |
+
"""
|
137 |
+
Convert a NumPy image to a PyTorch tensor.
|
138 |
+
"""
|
139 |
+
if images.ndim == 3:
|
140 |
+
images = images[..., None]
|
141 |
+
|
142 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
143 |
+
return images
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
|
147 |
+
"""
|
148 |
+
Convert a PyTorch tensor to a NumPy image.
|
149 |
+
"""
|
150 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
151 |
+
return images
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
155 |
+
"""
|
156 |
+
Normalize an image array to [-1,1].
|
157 |
+
"""
|
158 |
+
return 2.0 * images - 1.0
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
162 |
+
"""
|
163 |
+
Denormalize an image array to [0,1].
|
164 |
+
"""
|
165 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
169 |
+
"""
|
170 |
+
Converts a PIL image to RGB format.
|
171 |
+
"""
|
172 |
+
image = image.convert("RGB")
|
173 |
+
|
174 |
+
return image
|
175 |
+
|
176 |
+
@staticmethod
|
177 |
+
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
178 |
+
"""
|
179 |
+
Converts a PIL image to grayscale format.
|
180 |
+
"""
|
181 |
+
image = image.convert("L")
|
182 |
+
|
183 |
+
return image
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
187 |
+
"""
|
188 |
+
Applies Gaussian blur to an image.
|
189 |
+
"""
|
190 |
+
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
191 |
+
|
192 |
+
return image
|
193 |
+
|
194 |
+
@staticmethod
|
195 |
+
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
196 |
+
"""
|
197 |
+
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
|
198 |
+
ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
|
199 |
+
processing are 512x512, the region will be expanded to 128x128.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
mask_image (PIL.Image.Image): Mask image.
|
203 |
+
width (int): Width of the image to be processed.
|
204 |
+
height (int): Height of the image to be processed.
|
205 |
+
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
|
209 |
+
matches the original aspect ratio.
|
210 |
+
"""
|
211 |
+
|
212 |
+
mask_image = mask_image.convert("L")
|
213 |
+
mask = np.array(mask_image)
|
214 |
+
|
215 |
+
# 1. find a rectangular region that contains all masked ares in an image
|
216 |
+
h, w = mask.shape
|
217 |
+
crop_left = 0
|
218 |
+
for i in range(w):
|
219 |
+
if not (mask[:, i] == 0).all():
|
220 |
+
break
|
221 |
+
crop_left += 1
|
222 |
+
|
223 |
+
crop_right = 0
|
224 |
+
for i in reversed(range(w)):
|
225 |
+
if not (mask[:, i] == 0).all():
|
226 |
+
break
|
227 |
+
crop_right += 1
|
228 |
+
|
229 |
+
crop_top = 0
|
230 |
+
for i in range(h):
|
231 |
+
if not (mask[i] == 0).all():
|
232 |
+
break
|
233 |
+
crop_top += 1
|
234 |
+
|
235 |
+
crop_bottom = 0
|
236 |
+
for i in reversed(range(h)):
|
237 |
+
if not (mask[i] == 0).all():
|
238 |
+
break
|
239 |
+
crop_bottom += 1
|
240 |
+
|
241 |
+
# 2. add padding to the crop region
|
242 |
+
x1, y1, x2, y2 = (
|
243 |
+
int(max(crop_left - pad, 0)),
|
244 |
+
int(max(crop_top - pad, 0)),
|
245 |
+
int(min(w - crop_right + pad, w)),
|
246 |
+
int(min(h - crop_bottom + pad, h)),
|
247 |
+
)
|
248 |
+
|
249 |
+
# 3. expands crop region to match the aspect ratio of the image to be processed
|
250 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
251 |
+
ratio_processing = width / height
|
252 |
+
|
253 |
+
if ratio_crop_region > ratio_processing:
|
254 |
+
desired_height = (x2 - x1) / ratio_processing
|
255 |
+
desired_height_diff = int(desired_height - (y2 - y1))
|
256 |
+
y1 -= desired_height_diff // 2
|
257 |
+
y2 += desired_height_diff - desired_height_diff // 2
|
258 |
+
if y2 >= mask_image.height:
|
259 |
+
diff = y2 - mask_image.height
|
260 |
+
y2 -= diff
|
261 |
+
y1 -= diff
|
262 |
+
if y1 < 0:
|
263 |
+
y2 -= y1
|
264 |
+
y1 -= y1
|
265 |
+
if y2 >= mask_image.height:
|
266 |
+
y2 = mask_image.height
|
267 |
+
else:
|
268 |
+
desired_width = (y2 - y1) * ratio_processing
|
269 |
+
desired_width_diff = int(desired_width - (x2 - x1))
|
270 |
+
x1 -= desired_width_diff // 2
|
271 |
+
x2 += desired_width_diff - desired_width_diff // 2
|
272 |
+
if x2 >= mask_image.width:
|
273 |
+
diff = x2 - mask_image.width
|
274 |
+
x2 -= diff
|
275 |
+
x1 -= diff
|
276 |
+
if x1 < 0:
|
277 |
+
x2 -= x1
|
278 |
+
x1 -= x1
|
279 |
+
if x2 >= mask_image.width:
|
280 |
+
x2 = mask_image.width
|
281 |
+
|
282 |
+
return x1, y1, x2, y2
|
283 |
+
|
284 |
+
def _resize_and_fill(
|
285 |
+
self,
|
286 |
+
image: PIL.Image.Image,
|
287 |
+
width: int,
|
288 |
+
height: int,
|
289 |
+
) -> PIL.Image.Image:
|
290 |
+
"""
|
291 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
292 |
+
the image within the dimensions, filling empty with data from image.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
image: The image to resize.
|
296 |
+
width: The width to resize the image to.
|
297 |
+
height: The height to resize the image to.
|
298 |
+
"""
|
299 |
+
|
300 |
+
ratio = width / height
|
301 |
+
src_ratio = image.width / image.height
|
302 |
+
|
303 |
+
src_w = width if ratio < src_ratio else image.width * height // image.height
|
304 |
+
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
305 |
+
|
306 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
307 |
+
res = Image.new("RGB", (width, height))
|
308 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
309 |
+
|
310 |
+
if ratio < src_ratio:
|
311 |
+
fill_height = height // 2 - src_h // 2
|
312 |
+
if fill_height > 0:
|
313 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
314 |
+
res.paste(
|
315 |
+
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
316 |
+
box=(0, fill_height + src_h),
|
317 |
+
)
|
318 |
+
elif ratio > src_ratio:
|
319 |
+
fill_width = width // 2 - src_w // 2
|
320 |
+
if fill_width > 0:
|
321 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
322 |
+
res.paste(
|
323 |
+
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
324 |
+
box=(fill_width + src_w, 0),
|
325 |
+
)
|
326 |
+
|
327 |
+
return res
|
328 |
+
|
329 |
+
def _resize_and_crop(
|
330 |
+
self,
|
331 |
+
image: PIL.Image.Image,
|
332 |
+
width: int,
|
333 |
+
height: int,
|
334 |
+
) -> PIL.Image.Image:
|
335 |
+
"""
|
336 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
337 |
+
the image within the dimensions, cropping the excess.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
image: The image to resize.
|
341 |
+
width: The width to resize the image to.
|
342 |
+
height: The height to resize the image to.
|
343 |
+
"""
|
344 |
+
ratio = width / height
|
345 |
+
src_ratio = image.width / image.height
|
346 |
+
|
347 |
+
src_w = width if ratio > src_ratio else image.width * height // image.height
|
348 |
+
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
349 |
+
|
350 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
351 |
+
res = Image.new("RGB", (width, height))
|
352 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
353 |
+
return res
|
354 |
+
|
355 |
+
def resize(
|
356 |
+
self,
|
357 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
358 |
+
height: int,
|
359 |
+
width: int,
|
360 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
361 |
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
362 |
+
"""
|
363 |
+
Resize image.
|
364 |
+
|
365 |
+
Args:
|
366 |
+
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
367 |
+
The image input, can be a PIL image, numpy array or pytorch tensor.
|
368 |
+
height (`int`):
|
369 |
+
The height to resize to.
|
370 |
+
width (`int`):
|
371 |
+
The width to resize to.
|
372 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
373 |
+
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
|
374 |
+
within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
|
375 |
+
will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
|
376 |
+
then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
|
377 |
+
the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
378 |
+
the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
379 |
+
supported for PIL image input.
|
380 |
+
|
381 |
+
Returns:
|
382 |
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
383 |
+
The resized image.
|
384 |
+
"""
|
385 |
+
if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
|
386 |
+
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
|
387 |
+
if isinstance(image, PIL.Image.Image):
|
388 |
+
if resize_mode == "default":
|
389 |
+
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
390 |
+
elif resize_mode == "fill":
|
391 |
+
image = self._resize_and_fill(image, width, height)
|
392 |
+
elif resize_mode == "crop":
|
393 |
+
image = self._resize_and_crop(image, width, height)
|
394 |
+
else:
|
395 |
+
raise ValueError(f"resize_mode {resize_mode} is not supported")
|
396 |
+
|
397 |
+
elif isinstance(image, torch.Tensor):
|
398 |
+
image = torch.nn.functional.interpolate(
|
399 |
+
image,
|
400 |
+
size=(height, width),
|
401 |
+
)
|
402 |
+
elif isinstance(image, np.ndarray):
|
403 |
+
image = self.numpy_to_pt(image)
|
404 |
+
image = torch.nn.functional.interpolate(
|
405 |
+
image,
|
406 |
+
size=(height, width),
|
407 |
+
)
|
408 |
+
image = self.pt_to_numpy(image)
|
409 |
+
return image
|
410 |
+
|
411 |
+
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
412 |
+
"""
|
413 |
+
Create a mask.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
image (`PIL.Image.Image`):
|
417 |
+
The image input, should be a PIL image.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
`PIL.Image.Image`:
|
421 |
+
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
422 |
+
"""
|
423 |
+
image[image < 0.5] = 0
|
424 |
+
image[image >= 0.5] = 1
|
425 |
+
|
426 |
+
return image
|
427 |
+
|
428 |
+
def get_default_height_width(
|
429 |
+
self,
|
430 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
431 |
+
height: Optional[int] = None,
|
432 |
+
width: Optional[int] = None,
|
433 |
+
) -> Tuple[int, int]:
|
434 |
+
"""
|
435 |
+
This function return the height and width that are downscaled to the next integer multiple of
|
436 |
+
`vae_scale_factor`.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
440 |
+
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
|
441 |
+
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
|
442 |
+
have shape `[batch, channel, height, width]`.
|
443 |
+
height (`int`, *optional*, defaults to `None`):
|
444 |
+
The height in preprocessed image. If `None`, will use the height of `image` input.
|
445 |
+
width (`int`, *optional*`, defaults to `None`):
|
446 |
+
The width in preprocessed. If `None`, will use the width of the `image` input.
|
447 |
+
"""
|
448 |
+
|
449 |
+
if height is None:
|
450 |
+
if isinstance(image, PIL.Image.Image):
|
451 |
+
height = image.height
|
452 |
+
elif isinstance(image, torch.Tensor):
|
453 |
+
height = image.shape[2]
|
454 |
+
else:
|
455 |
+
height = image.shape[1]
|
456 |
+
|
457 |
+
if width is None:
|
458 |
+
if isinstance(image, PIL.Image.Image):
|
459 |
+
width = image.width
|
460 |
+
elif isinstance(image, torch.Tensor):
|
461 |
+
width = image.shape[3]
|
462 |
+
else:
|
463 |
+
width = image.shape[2]
|
464 |
+
|
465 |
+
width, height = (
|
466 |
+
x - x % self.config.vae_scale_factor for x in (width, height)
|
467 |
+
) # resize to integer multiple of vae_scale_factor
|
468 |
+
|
469 |
+
return height, width
|
470 |
+
|
471 |
+
def preprocess(
|
472 |
+
self,
|
473 |
+
image: PipelineImageInput,
|
474 |
+
height: Optional[int] = None,
|
475 |
+
width: Optional[int] = None,
|
476 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
477 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
478 |
+
do_crop: bool = True,
|
479 |
+
) -> torch.Tensor:
|
480 |
+
"""
|
481 |
+
Preprocess the image input.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
image (`pipeline_image_input`):
|
485 |
+
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
486 |
+
supported formats.
|
487 |
+
height (`int`, *optional*, defaults to `None`):
|
488 |
+
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
489 |
+
height.
|
490 |
+
width (`int`, *optional*`, defaults to `None`):
|
491 |
+
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
492 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
493 |
+
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
494 |
+
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
|
495 |
+
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
496 |
+
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
497 |
+
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
498 |
+
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
499 |
+
supported for PIL image input.
|
500 |
+
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
501 |
+
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
502 |
+
"""
|
503 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
504 |
+
|
505 |
+
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
506 |
+
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
507 |
+
if isinstance(image, torch.Tensor):
|
508 |
+
# if image is a pytorch tensor could have 2 possible shapes:
|
509 |
+
# 1. batch x height x width: we should insert the channel dimension at position 1
|
510 |
+
# 2. channel x height x width: we should insert batch dimension at position 0,
|
511 |
+
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
512 |
+
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
513 |
+
image = image.unsqueeze(1)
|
514 |
+
else:
|
515 |
+
# if it is a numpy array, it could have 2 possible shapes:
|
516 |
+
# 1. batch x height x width: insert channel dimension on last position
|
517 |
+
# 2. height x width x channel: insert batch dimension on first position
|
518 |
+
if image.shape[-1] == 1:
|
519 |
+
image = np.expand_dims(image, axis=0)
|
520 |
+
else:
|
521 |
+
image = np.expand_dims(image, axis=-1)
|
522 |
+
|
523 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
524 |
+
warnings.warn(
|
525 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
526 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
527 |
+
FutureWarning,
|
528 |
+
)
|
529 |
+
image = np.concatenate(image, axis=0)
|
530 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
531 |
+
warnings.warn(
|
532 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
533 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
534 |
+
FutureWarning,
|
535 |
+
)
|
536 |
+
image = torch.cat(image, axis=0)
|
537 |
+
|
538 |
+
if not is_valid_image_imagelist(image):
|
539 |
+
raise ValueError(
|
540 |
+
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
541 |
+
)
|
542 |
+
if not isinstance(image, list):
|
543 |
+
image = [image]
|
544 |
+
|
545 |
+
if isinstance(image[0], PIL.Image.Image):
|
546 |
+
pass
|
547 |
+
elif isinstance(image[0], np.ndarray):
|
548 |
+
image = self.numpy_to_pil(image)
|
549 |
+
elif isinstance(image[0], torch.Tensor):
|
550 |
+
image = self.pt_to_numpy(image)
|
551 |
+
image = self.numpy_to_pil(image)
|
552 |
+
|
553 |
+
if do_crop:
|
554 |
+
transforms = T.Compose([
|
555 |
+
T.Lambda(lambda image: image.convert('RGB')),
|
556 |
+
T.ToTensor(),
|
557 |
+
CenterCropResizeImage((height, width)),
|
558 |
+
T.Normalize([.5], [.5]),
|
559 |
+
])
|
560 |
+
else:
|
561 |
+
transforms = T.Compose([
|
562 |
+
T.Lambda(lambda image: image.convert('RGB')),
|
563 |
+
T.ToTensor(),
|
564 |
+
T.Resize((height, width)),
|
565 |
+
T.Normalize([.5], [.5]),
|
566 |
+
])
|
567 |
+
image = torch.stack([transforms(i) for i in image])
|
568 |
+
|
569 |
+
# expected range [0,1], normalize to [-1,1]
|
570 |
+
do_normalize = self.config.do_normalize
|
571 |
+
if do_normalize and image.min() < 0:
|
572 |
+
warnings.warn(
|
573 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
574 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
575 |
+
FutureWarning,
|
576 |
+
)
|
577 |
+
do_normalize = False
|
578 |
+
if do_normalize:
|
579 |
+
image = self.normalize(image)
|
580 |
+
|
581 |
+
if self.config.do_binarize:
|
582 |
+
image = self.binarize(image)
|
583 |
+
|
584 |
+
return image
|
585 |
+
|
586 |
+
def postprocess(
|
587 |
+
self,
|
588 |
+
image: torch.Tensor,
|
589 |
+
output_type: str = "pil",
|
590 |
+
do_denormalize: Optional[List[bool]] = None,
|
591 |
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
592 |
+
"""
|
593 |
+
Postprocess the image output from tensor to `output_type`.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
image (`torch.Tensor`):
|
597 |
+
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
598 |
+
output_type (`str`, *optional*, defaults to `pil`):
|
599 |
+
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
600 |
+
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
601 |
+
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
602 |
+
`VaeImageProcessor` config.
|
603 |
+
|
604 |
+
Returns:
|
605 |
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
606 |
+
The postprocessed image.
|
607 |
+
"""
|
608 |
+
if not isinstance(image, torch.Tensor):
|
609 |
+
raise ValueError(
|
610 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
611 |
+
)
|
612 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
613 |
+
deprecation_message = (
|
614 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
615 |
+
"`pil`, `np`, `pt`, `latent`"
|
616 |
+
)
|
617 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
618 |
+
output_type = "np"
|
619 |
+
|
620 |
+
if output_type == "latent":
|
621 |
+
return image
|
622 |
+
|
623 |
+
if do_denormalize is None:
|
624 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
625 |
+
|
626 |
+
image = torch.stack(
|
627 |
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
628 |
+
)
|
629 |
+
|
630 |
+
if output_type == "pt":
|
631 |
+
return image
|
632 |
+
|
633 |
+
image = self.pt_to_numpy(image)
|
634 |
+
|
635 |
+
if output_type == "np":
|
636 |
+
return image
|
637 |
+
|
638 |
+
if output_type == "pil":
|
639 |
+
return self.numpy_to_pil(image)
|
640 |
+
|
641 |
+
def apply_overlay(
|
642 |
+
self,
|
643 |
+
mask: PIL.Image.Image,
|
644 |
+
init_image: PIL.Image.Image,
|
645 |
+
image: PIL.Image.Image,
|
646 |
+
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
647 |
+
) -> PIL.Image.Image:
|
648 |
+
"""
|
649 |
+
overlay the inpaint output to the original image
|
650 |
+
"""
|
651 |
+
|
652 |
+
width, height = image.width, image.height
|
653 |
+
|
654 |
+
init_image = self.resize(init_image, width=width, height=height)
|
655 |
+
mask = self.resize(mask, width=width, height=height)
|
656 |
+
|
657 |
+
init_image_masked = PIL.Image.new("RGBa", (width, height))
|
658 |
+
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
|
659 |
+
init_image_masked = init_image_masked.convert("RGBA")
|
660 |
+
|
661 |
+
if crop_coords is not None:
|
662 |
+
x, y, x2, y2 = crop_coords
|
663 |
+
w = x2 - x
|
664 |
+
h = y2 - y
|
665 |
+
base_image = PIL.Image.new("RGBA", (width, height))
|
666 |
+
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
667 |
+
base_image.paste(image, (x, y))
|
668 |
+
image = base_image.convert("RGB")
|
669 |
+
|
670 |
+
image = image.convert("RGBA")
|
671 |
+
image.alpha_composite(init_image_masked)
|
672 |
+
image = image.convert("RGB")
|
673 |
+
|
674 |
+
return image
|
onediffusion/diffusion/pipelines/onediffusion.py
ADDED
@@ -0,0 +1,1080 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
import inspect
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import os
|
7 |
+
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
10 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
11 |
+
from diffusers.utils import (
|
12 |
+
CONFIG_NAME,
|
13 |
+
DEPRECATED_REVISION_ARGS,
|
14 |
+
BaseOutput,
|
15 |
+
PushToHubMixin,
|
16 |
+
deprecate,
|
17 |
+
is_accelerate_available,
|
18 |
+
is_accelerate_version,
|
19 |
+
is_torch_npu_available,
|
20 |
+
is_torch_version,
|
21 |
+
logging,
|
22 |
+
numpy_to_pil,
|
23 |
+
replace_example_docstring,
|
24 |
+
)
|
25 |
+
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
27 |
+
from diffusers.utils import BaseOutput
|
28 |
+
# from diffusers.image_processor import VaeImageProcessor
|
29 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
30 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
31 |
+
from PIL import Image
|
32 |
+
|
33 |
+
from onediffusion.models.denoiser.nextdit import NextDiT
|
34 |
+
from onediffusion.dataset.utils import *
|
35 |
+
from onediffusion.dataset.multitask.multiview import calculate_rays
|
36 |
+
from onediffusion.diffusion.pipelines.image_processor import VaeImageProcessorOneDiffuser
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
SUPPORTED_DEVICE_MAP = ["balanced"]
|
41 |
+
|
42 |
+
EXAMPLE_DOC_STRING = """
|
43 |
+
Examples:
|
44 |
+
```py
|
45 |
+
>>> import torch
|
46 |
+
>>> from one_diffusion import OneDiffusionPipeline
|
47 |
+
|
48 |
+
>>> pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model")
|
49 |
+
>>> pipe = pipe.to("cuda")
|
50 |
+
|
51 |
+
>>> prompt = "A beautiful sunset over the ocean"
|
52 |
+
>>> image = pipe(prompt).images[0]
|
53 |
+
>>> image.save("beautiful_sunset.png")
|
54 |
+
```
|
55 |
+
"""
|
56 |
+
|
57 |
+
def create_c2w_matrix(azimuth_deg, elevation_deg, distance=1.0, target=np.array([0, 0, 0])):
|
58 |
+
"""
|
59 |
+
Create a Camera-to-World (C2W) matrix from azimuth and elevation angles.
|
60 |
+
|
61 |
+
Parameters:
|
62 |
+
- azimuth_deg: Azimuth angle in degrees.
|
63 |
+
- elevation_deg: Elevation angle in degrees.
|
64 |
+
- distance: Distance from the target point.
|
65 |
+
- target: The point the camera is looking at in world coordinates.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
- C2W: A 4x4 NumPy array representing the Camera-to-World transformation matrix.
|
69 |
+
"""
|
70 |
+
# Convert angles from degrees to radians
|
71 |
+
azimuth = np.deg2rad(azimuth_deg)
|
72 |
+
elevation = np.deg2rad(elevation_deg)
|
73 |
+
|
74 |
+
# Spherical to Cartesian conversion for camera position
|
75 |
+
x = distance * np.cos(elevation) * np.cos(azimuth)
|
76 |
+
y = distance * np.cos(elevation) * np.sin(azimuth)
|
77 |
+
z = distance * np.sin(elevation)
|
78 |
+
camera_position = np.array([x, y, z])
|
79 |
+
|
80 |
+
# Define the forward vector (from camera to target)
|
81 |
+
target = 2*camera_position - target
|
82 |
+
forward = target - camera_position
|
83 |
+
forward /= np.linalg.norm(forward)
|
84 |
+
|
85 |
+
# Define the world up vector
|
86 |
+
world_up = np.array([0, 0, 1])
|
87 |
+
|
88 |
+
# Compute the right vector
|
89 |
+
right = np.cross(world_up, forward)
|
90 |
+
if np.linalg.norm(right) < 1e-6:
|
91 |
+
# Handle the singularity when forward is parallel to world_up
|
92 |
+
world_up = np.array([0, 1, 0])
|
93 |
+
right = np.cross(world_up, forward)
|
94 |
+
right /= np.linalg.norm(right)
|
95 |
+
|
96 |
+
# Recompute the orthogonal up vector
|
97 |
+
up = np.cross(forward, right)
|
98 |
+
|
99 |
+
# Construct the rotation matrix
|
100 |
+
rotation = np.vstack([right, up, forward]).T # 3x3
|
101 |
+
|
102 |
+
# Construct the full C2W matrix
|
103 |
+
C2W = np.eye(4)
|
104 |
+
C2W[:3, :3] = rotation
|
105 |
+
C2W[:3, 3] = camera_position
|
106 |
+
|
107 |
+
return C2W
|
108 |
+
|
109 |
+
@dataclass
|
110 |
+
class OneDiffusionPipelineOutput(BaseOutput):
|
111 |
+
"""
|
112 |
+
Output class for Stable Diffusion pipelines.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
116 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
117 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
118 |
+
"""
|
119 |
+
|
120 |
+
images: Union[List[Image.Image], np.ndarray]
|
121 |
+
latents: Optional[torch.Tensor] = None
|
122 |
+
|
123 |
+
|
124 |
+
def retrieve_latents(
|
125 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
126 |
+
):
|
127 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
128 |
+
return encoder_output.latent_dist.sample(generator)
|
129 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
130 |
+
return encoder_output.latent_dist.mode()
|
131 |
+
elif hasattr(encoder_output, "latents"):
|
132 |
+
return encoder_output.latents
|
133 |
+
else:
|
134 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
135 |
+
|
136 |
+
|
137 |
+
def calculate_shift(
|
138 |
+
image_seq_len,
|
139 |
+
base_seq_len: int = 256,
|
140 |
+
max_seq_len: int = 4096,
|
141 |
+
base_shift: float = 0.5,
|
142 |
+
max_shift: float = 1.16,
|
143 |
+
# max_clip: float = 1.5,
|
144 |
+
):
|
145 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) # 0.000169270833
|
146 |
+
b = base_shift - m * base_seq_len # 0.5-0.0433333332
|
147 |
+
mu = image_seq_len * m + b
|
148 |
+
# mu = min(mu, max_clip)
|
149 |
+
return mu
|
150 |
+
|
151 |
+
|
152 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
153 |
+
def retrieve_timesteps(
|
154 |
+
scheduler,
|
155 |
+
num_inference_steps: Optional[int] = None,
|
156 |
+
device: Optional[Union[str, torch.device]] = None,
|
157 |
+
timesteps: Optional[List[int]] = None,
|
158 |
+
sigmas: Optional[List[float]] = None,
|
159 |
+
**kwargs,
|
160 |
+
):
|
161 |
+
"""
|
162 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
163 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
scheduler (`SchedulerMixin`):
|
167 |
+
The scheduler to get timesteps from.
|
168 |
+
num_inference_steps (`int`):
|
169 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
170 |
+
must be `None`.
|
171 |
+
device (`str` or `torch.device`, *optional*):
|
172 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
173 |
+
timesteps (`List[int]`, *optional*):
|
174 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
175 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
176 |
+
sigmas (`List[float]`, *optional*):
|
177 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
178 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
182 |
+
second element is the number of inference steps.
|
183 |
+
"""
|
184 |
+
if timesteps is not None and sigmas is not None:
|
185 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
186 |
+
if timesteps is not None:
|
187 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
188 |
+
if not accepts_timesteps:
|
189 |
+
raise ValueError(
|
190 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
191 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
192 |
+
)
|
193 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
194 |
+
timesteps = scheduler.timesteps
|
195 |
+
num_inference_steps = len(timesteps)
|
196 |
+
elif sigmas is not None:
|
197 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
198 |
+
if not accept_sigmas:
|
199 |
+
raise ValueError(
|
200 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
201 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
202 |
+
)
|
203 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
204 |
+
timesteps = scheduler.timesteps
|
205 |
+
num_inference_steps = len(timesteps)
|
206 |
+
else:
|
207 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
208 |
+
timesteps = scheduler.timesteps
|
209 |
+
return timesteps, num_inference_steps
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
class OneDiffusionPipeline(DiffusionPipeline):
|
214 |
+
r"""
|
215 |
+
Pipeline for text-to-image generation using OneDiffuser.
|
216 |
+
|
217 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
218 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
219 |
+
|
220 |
+
Args:
|
221 |
+
transformer ([`NextDiT`]):
|
222 |
+
Conditional transformer (NextDiT) architecture to denoise the encoded image latents.
|
223 |
+
vae ([`AutoencoderKL`]):
|
224 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
225 |
+
text_encoder ([`T5EncoderModel`]):
|
226 |
+
Frozen text-encoder. OneDiffuser uses the T5 model as text encoder.
|
227 |
+
tokenizer (`T5Tokenizer`):
|
228 |
+
Tokenizer of class T5Tokenizer.
|
229 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
230 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
231 |
+
"""
|
232 |
+
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
transformer: NextDiT,
|
236 |
+
vae: AutoencoderKL,
|
237 |
+
text_encoder: T5EncoderModel,
|
238 |
+
tokenizer: T5Tokenizer,
|
239 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
240 |
+
):
|
241 |
+
super().__init__()
|
242 |
+
self.register_modules(
|
243 |
+
transformer=transformer,
|
244 |
+
vae=vae,
|
245 |
+
text_encoder=text_encoder,
|
246 |
+
tokenizer=tokenizer,
|
247 |
+
scheduler=scheduler,
|
248 |
+
)
|
249 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
250 |
+
self.image_processor = VaeImageProcessorOneDiffuser(vae_scale_factor=self.vae_scale_factor)
|
251 |
+
|
252 |
+
def enable_vae_slicing(self):
|
253 |
+
self.vae.enable_slicing()
|
254 |
+
|
255 |
+
def disable_vae_slicing(self):
|
256 |
+
self.vae.disable_slicing()
|
257 |
+
|
258 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
259 |
+
if is_accelerate_available():
|
260 |
+
from accelerate import cpu_offload
|
261 |
+
else:
|
262 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
263 |
+
|
264 |
+
device = torch.device(f"cuda:{gpu_id}")
|
265 |
+
|
266 |
+
for cpu_offloaded_model in [self.transformer, self.text_encoder, self.vae]:
|
267 |
+
if cpu_offloaded_model is not None:
|
268 |
+
cpu_offload(cpu_offloaded_model, device)
|
269 |
+
|
270 |
+
@property
|
271 |
+
def _execution_device(self):
|
272 |
+
if self.device != torch.device("meta") or not hasattr(self.transformer, "_hf_hook"):
|
273 |
+
return self.device
|
274 |
+
for module in self.transformer.modules():
|
275 |
+
if (
|
276 |
+
hasattr(module, "_hf_hook")
|
277 |
+
and hasattr(module._hf_hook, "execution_device")
|
278 |
+
and module._hf_hook.execution_device is not None
|
279 |
+
):
|
280 |
+
return torch.device(module._hf_hook.execution_device)
|
281 |
+
return self.device
|
282 |
+
|
283 |
+
def encode_prompt(
|
284 |
+
self,
|
285 |
+
prompt,
|
286 |
+
device,
|
287 |
+
num_images_per_prompt,
|
288 |
+
do_classifier_free_guidance,
|
289 |
+
negative_prompt=None,
|
290 |
+
max_length=300,
|
291 |
+
):
|
292 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
293 |
+
|
294 |
+
text_inputs = self.tokenizer(
|
295 |
+
prompt,
|
296 |
+
padding="max_length",
|
297 |
+
max_length=max_length,
|
298 |
+
truncation=True,
|
299 |
+
add_special_tokens=True,
|
300 |
+
return_tensors="pt",
|
301 |
+
)
|
302 |
+
text_input_ids = text_inputs.input_ids
|
303 |
+
attention_mask = text_inputs.attention_mask
|
304 |
+
|
305 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
306 |
+
|
307 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
308 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
309 |
+
logger.warning(
|
310 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
311 |
+
f" {max_length} tokens: {removed_text}"
|
312 |
+
)
|
313 |
+
|
314 |
+
text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
|
315 |
+
prompt_embeds = text_encoder_output[0].to(torch.float32)
|
316 |
+
|
317 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
318 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
319 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
320 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
321 |
+
|
322 |
+
# duplicate attention mask for each generation per prompt
|
323 |
+
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
|
324 |
+
attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, -1)
|
325 |
+
|
326 |
+
# get unconditional embeddings for classifier free guidance
|
327 |
+
if do_classifier_free_guidance:
|
328 |
+
uncond_tokens: List[str]
|
329 |
+
if negative_prompt is None:
|
330 |
+
uncond_tokens = [""] * batch_size
|
331 |
+
elif type(prompt) is not type(negative_prompt):
|
332 |
+
raise TypeError(
|
333 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
334 |
+
f" {type(prompt)}."
|
335 |
+
)
|
336 |
+
elif isinstance(negative_prompt, str):
|
337 |
+
uncond_tokens = [negative_prompt]
|
338 |
+
elif batch_size != len(negative_prompt):
|
339 |
+
raise ValueError(
|
340 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
341 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
342 |
+
" the batch size of `prompt`."
|
343 |
+
)
|
344 |
+
else:
|
345 |
+
uncond_tokens = negative_prompt
|
346 |
+
|
347 |
+
max_length = text_input_ids.shape[-1]
|
348 |
+
uncond_input = self.tokenizer(
|
349 |
+
uncond_tokens,
|
350 |
+
padding="max_length",
|
351 |
+
max_length=max_length,
|
352 |
+
truncation=True,
|
353 |
+
return_tensors="pt",
|
354 |
+
)
|
355 |
+
|
356 |
+
uncond_encoder_output = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device))
|
357 |
+
negative_prompt_embeds = uncond_encoder_output[0].to(torch.float32)
|
358 |
+
|
359 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
360 |
+
seq_len = negative_prompt_embeds.shape[1]
|
361 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
362 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
363 |
+
|
364 |
+
# duplicate unconditional attention mask for each generation per prompt
|
365 |
+
uncond_attention_mask = uncond_input.attention_mask.repeat(1, num_images_per_prompt)
|
366 |
+
uncond_attention_mask = uncond_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
367 |
+
|
368 |
+
# For classifier free guidance, we need to do two forward passes.
|
369 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
370 |
+
# to avoid doing two forward passes
|
371 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
372 |
+
attention_mask = torch.cat([uncond_attention_mask, attention_mask])
|
373 |
+
|
374 |
+
return prompt_embeds.to(device), attention_mask.to(device)
|
375 |
+
|
376 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
377 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
378 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
379 |
+
raise ValueError(
|
380 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
381 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
382 |
+
)
|
383 |
+
|
384 |
+
if latents is None:
|
385 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
386 |
+
else:
|
387 |
+
latents = latents.to(device)
|
388 |
+
|
389 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
390 |
+
latents = latents * self.scheduler.init_noise_sigma
|
391 |
+
return latents
|
392 |
+
|
393 |
+
@torch.no_grad()
|
394 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
395 |
+
def __call__(
|
396 |
+
self,
|
397 |
+
prompt: Union[str, List[str]] = None,
|
398 |
+
height: Optional[int] = None,
|
399 |
+
width: Optional[int] = None,
|
400 |
+
num_inference_steps: int = 50,
|
401 |
+
guidance_scale: float = 5.0,
|
402 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
403 |
+
num_images_per_prompt: Optional[int] = 1,
|
404 |
+
eta: float = 0.0,
|
405 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
406 |
+
latents: Optional[torch.FloatTensor] = None,
|
407 |
+
output_type: Optional[str] = "pil",
|
408 |
+
return_dict: bool = True,
|
409 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
410 |
+
callback_steps: int = 1,
|
411 |
+
forward_kwargs: Optional[Dict[str, Any]] = {},
|
412 |
+
**kwargs,
|
413 |
+
):
|
414 |
+
r"""
|
415 |
+
Function invoked when calling the pipeline for generation.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
prompt (`str` or `List[str]`, *optional*):
|
419 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
420 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_size):
|
421 |
+
The height in pixels of the generated image.
|
422 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_size):
|
423 |
+
The width in pixels of the generated image.
|
424 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
425 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
426 |
+
expense of slower inference.
|
427 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
428 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
429 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
430 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
431 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
432 |
+
usually at the expense of lower image quality.
|
433 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
434 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
435 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
436 |
+
less than `1`).
|
437 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
438 |
+
The number of images to generate per prompt.
|
439 |
+
eta (`float`, *optional*, defaults to 0.0):
|
440 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
441 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
442 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
443 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
444 |
+
to make generation deterministic.
|
445 |
+
latents (`torch.FloatTensor`, *optional*):
|
446 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
447 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
448 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
449 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
450 |
+
The output format of the generate image. Choose between
|
451 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
452 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
453 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
454 |
+
plain tuple.
|
455 |
+
callback (`Callable`, *optional*):
|
456 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
457 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
458 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
459 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
460 |
+
called at every step.
|
461 |
+
|
462 |
+
Examples:
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
466 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
467 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
468 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
469 |
+
(nsfw) content, according to the `safety_checker`.
|
470 |
+
"""
|
471 |
+
height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
|
472 |
+
width = width or self.transformer.config.input_size[-1] * 8
|
473 |
+
|
474 |
+
# check inputs. Raise error if not correct
|
475 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
476 |
+
|
477 |
+
# define call parameters
|
478 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
479 |
+
device = self._execution_device
|
480 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
481 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf
|
482 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
483 |
+
|
484 |
+
encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
|
485 |
+
prompt,
|
486 |
+
device,
|
487 |
+
num_images_per_prompt,
|
488 |
+
do_classifier_free_guidance,
|
489 |
+
negative_prompt,
|
490 |
+
)
|
491 |
+
|
492 |
+
# set timesteps
|
493 |
+
# # self.scheduler.set_timesteps(num_inference_steps, device=device)
|
494 |
+
# timesteps = self.scheduler.timesteps
|
495 |
+
timesteps = None
|
496 |
+
|
497 |
+
# prepare latent variables
|
498 |
+
num_channels_latents = self.transformer.config.in_channels
|
499 |
+
latents = self.prepare_latents(
|
500 |
+
batch_size * num_images_per_prompt,
|
501 |
+
num_channels_latents,
|
502 |
+
height,
|
503 |
+
width,
|
504 |
+
self.dtype,
|
505 |
+
device,
|
506 |
+
generator,
|
507 |
+
latents,
|
508 |
+
)
|
509 |
+
|
510 |
+
# prepare extra step kwargs
|
511 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
512 |
+
|
513 |
+
# 5. Prepare timesteps
|
514 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
515 |
+
image_seq_len = latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
|
516 |
+
mu = calculate_shift(
|
517 |
+
image_seq_len,
|
518 |
+
self.scheduler.config.base_image_seq_len,
|
519 |
+
self.scheduler.config.max_image_seq_len,
|
520 |
+
self.scheduler.config.base_shift,
|
521 |
+
self.scheduler.config.max_shift,
|
522 |
+
)
|
523 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
524 |
+
self.scheduler,
|
525 |
+
num_inference_steps,
|
526 |
+
device,
|
527 |
+
timesteps,
|
528 |
+
sigmas,
|
529 |
+
mu=mu,
|
530 |
+
)
|
531 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
532 |
+
self._num_timesteps = len(timesteps)
|
533 |
+
|
534 |
+
# denoising loop
|
535 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
536 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
537 |
+
for i, t in enumerate(timesteps):
|
538 |
+
# expand the latents if we are doing classifier free guidance
|
539 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
540 |
+
# latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
541 |
+
|
542 |
+
# predict the noise residual
|
543 |
+
noise_pred = self.transformer(
|
544 |
+
samples=latent_model_input.to(self.dtype),
|
545 |
+
timesteps=torch.tensor([t] * latent_model_input.shape[0], device=device),
|
546 |
+
encoder_hidden_states=encoder_hidden_states.to(self.dtype),
|
547 |
+
encoder_attention_mask=encoder_attention_mask,
|
548 |
+
**forward_kwargs
|
549 |
+
)
|
550 |
+
|
551 |
+
# perform guidance
|
552 |
+
if do_classifier_free_guidance:
|
553 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
554 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
555 |
+
|
556 |
+
# compute the previous noisy sample x_t -> x_t-1
|
557 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
558 |
+
|
559 |
+
# call the callback, if provided
|
560 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
561 |
+
progress_bar.update()
|
562 |
+
if callback is not None and i % callback_steps == 0:
|
563 |
+
callback(i, t, latents)
|
564 |
+
|
565 |
+
# scale and decode the image latents with vae
|
566 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
567 |
+
if latents.ndim == 5:
|
568 |
+
latents = latents.squeeze(1)
|
569 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
570 |
+
|
571 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
572 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
573 |
+
|
574 |
+
if output_type == "pil":
|
575 |
+
image = self.numpy_to_pil(image)
|
576 |
+
|
577 |
+
if not return_dict:
|
578 |
+
return (image, None)
|
579 |
+
|
580 |
+
return OneDiffusionPipelineOutput(images=image)
|
581 |
+
|
582 |
+
@torch.no_grad()
|
583 |
+
def img2img(
|
584 |
+
self,
|
585 |
+
prompt: Union[str, List[str]] = None,
|
586 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
|
587 |
+
height: Optional[int] = None,
|
588 |
+
width: Optional[int] = None,
|
589 |
+
num_inference_steps: int = 50,
|
590 |
+
guidance_scale: float = 5.0,
|
591 |
+
denoise_mask: Optional[List[int]] = [1, 0],
|
592 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
593 |
+
num_images_per_prompt: Optional[int] = 1,
|
594 |
+
eta: float = 0.0,
|
595 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
596 |
+
latents: Optional[torch.FloatTensor] = None,
|
597 |
+
output_type: Optional[str] = "pil",
|
598 |
+
return_dict: bool = True,
|
599 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
600 |
+
callback_steps: int = 1,
|
601 |
+
do_crop: bool = True,
|
602 |
+
is_multiview: bool = False,
|
603 |
+
multiview_azimuths: Optional[List[int]] = [0, 30, 60, 90],
|
604 |
+
multiview_elevations: Optional[List[int]] = [0, 0, 0, 0],
|
605 |
+
multiview_distances: float = 1.7,
|
606 |
+
multiview_c2ws: Optional[List[torch.Tensor]] = None,
|
607 |
+
multiview_intrinsics: Optional[torch.Tensor] = None,
|
608 |
+
multiview_focal_length: float = 1.3887,
|
609 |
+
forward_kwargs: Optional[Dict[str, Any]] = {},
|
610 |
+
noise_scale: float = 1.0,
|
611 |
+
**kwargs,
|
612 |
+
):
|
613 |
+
# Convert single image to list for consistent handling
|
614 |
+
if isinstance(image, PIL.Image.Image):
|
615 |
+
image = [image]
|
616 |
+
|
617 |
+
if height is None or width is None:
|
618 |
+
closest_ar = get_closest_ratio(height=image[0].size[1], width=image[0].size[0], ratios=ASPECT_RATIO_512)
|
619 |
+
height, width = int(closest_ar[0][0]), int(closest_ar[0][1])
|
620 |
+
|
621 |
+
if not isinstance(multiview_distances, list) and not isinstance(multiview_distances, tuple):
|
622 |
+
multiview_distances = [multiview_distances] * len(multiview_azimuths)
|
623 |
+
|
624 |
+
# height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
|
625 |
+
# width = width or self.transformer.config.input_size[-1] * 8
|
626 |
+
|
627 |
+
# 1. check inputs. Raise error if not correct
|
628 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
629 |
+
|
630 |
+
# Additional input validation for image list
|
631 |
+
if not all(isinstance(img, PIL.Image.Image) for img in image):
|
632 |
+
raise ValueError("All elements in image list must be PIL.Image objects")
|
633 |
+
|
634 |
+
# 2. define call parameters
|
635 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
636 |
+
device = self._execution_device
|
637 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
638 |
+
|
639 |
+
# 3. Encode input prompt
|
640 |
+
encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
|
641 |
+
prompt,
|
642 |
+
device,
|
643 |
+
num_images_per_prompt,
|
644 |
+
do_classifier_free_guidance,
|
645 |
+
negative_prompt,
|
646 |
+
)
|
647 |
+
|
648 |
+
# 4. Preprocess all images
|
649 |
+
if image is not None and len(image) > 0:
|
650 |
+
processed_image = self.image_processor.preprocess(image, height=height, width=width, do_crop=do_crop)
|
651 |
+
else:
|
652 |
+
processed_image = None
|
653 |
+
|
654 |
+
# # Stack processed images along the sequence dimension
|
655 |
+
# if len(processed_images) > 1:
|
656 |
+
# processed_image = torch.cat(processed_images, dim=0)
|
657 |
+
# else:
|
658 |
+
# processed_image = processed_images[0]
|
659 |
+
|
660 |
+
timesteps = None
|
661 |
+
|
662 |
+
# 6. prepare latent variables
|
663 |
+
num_channels_latents = self.transformer.config.in_channels
|
664 |
+
if processed_image is not None:
|
665 |
+
cond_latents = self.prepare_latents(
|
666 |
+
batch_size * num_images_per_prompt,
|
667 |
+
num_channels_latents,
|
668 |
+
height,
|
669 |
+
width,
|
670 |
+
self.dtype,
|
671 |
+
device,
|
672 |
+
generator,
|
673 |
+
latents,
|
674 |
+
image=processed_image,
|
675 |
+
)
|
676 |
+
else:
|
677 |
+
cond_latents = None
|
678 |
+
|
679 |
+
# 7. prepare extra step kwargs
|
680 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
681 |
+
denoise_mask = torch.tensor(denoise_mask, device=device)
|
682 |
+
denoise_indices = torch.where(denoise_mask == 1)[0]
|
683 |
+
cond_indices = torch.where(denoise_mask == 0)[0]
|
684 |
+
seq_length = denoise_mask.shape[0]
|
685 |
+
|
686 |
+
latents = self.prepare_init_latents(
|
687 |
+
batch_size * num_images_per_prompt,
|
688 |
+
seq_length,
|
689 |
+
num_channels_latents,
|
690 |
+
height,
|
691 |
+
width,
|
692 |
+
self.dtype,
|
693 |
+
device,
|
694 |
+
generator,
|
695 |
+
)
|
696 |
+
|
697 |
+
# 5. Prepare timesteps
|
698 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
699 |
+
# image_seq_len = latents.shape[1] * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
|
700 |
+
image_seq_len = noise_scale * sum(denoise_mask) * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
|
701 |
+
# image_seq_len = 256
|
702 |
+
mu = calculate_shift(
|
703 |
+
image_seq_len,
|
704 |
+
self.scheduler.config.base_image_seq_len,
|
705 |
+
self.scheduler.config.max_image_seq_len,
|
706 |
+
self.scheduler.config.base_shift,
|
707 |
+
self.scheduler.config.max_shift,
|
708 |
+
)
|
709 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
710 |
+
self.scheduler,
|
711 |
+
num_inference_steps,
|
712 |
+
device,
|
713 |
+
timesteps,
|
714 |
+
sigmas,
|
715 |
+
mu=mu,
|
716 |
+
)
|
717 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
718 |
+
self._num_timesteps = len(timesteps)
|
719 |
+
|
720 |
+
if is_multiview:
|
721 |
+
cond_indices_images = [index // 2 for index in cond_indices if index % 2 == 0]
|
722 |
+
cond_indices_rays = [index // 2 for index in cond_indices if index % 2 == 1]
|
723 |
+
|
724 |
+
multiview_elevations = [element for element in multiview_elevations if element is not None]
|
725 |
+
multiview_azimuths = [element for element in multiview_azimuths if element is not None]
|
726 |
+
multiview_distances = [element for element in multiview_distances if element is not None]
|
727 |
+
|
728 |
+
if multiview_c2ws is None:
|
729 |
+
multiview_c2ws = [
|
730 |
+
torch.tensor(create_c2w_matrix(azimuth, elevation, distance)) for azimuth, elevation, distance in zip(multiview_azimuths, multiview_elevations, multiview_distances)
|
731 |
+
]
|
732 |
+
c2ws = torch.stack(multiview_c2ws).float()
|
733 |
+
else:
|
734 |
+
c2ws = torch.Tensor(multiview_c2ws).float()
|
735 |
+
|
736 |
+
c2ws[:, 0:3, 1:3] *= -1
|
737 |
+
c2ws = c2ws[:, [1, 0, 2, 3], :]
|
738 |
+
c2ws[:, 2, :] *= -1
|
739 |
+
|
740 |
+
w2cs = torch.inverse(c2ws)
|
741 |
+
if multiview_intrinsics is None:
|
742 |
+
multiview_intrinsics = torch.Tensor([[[multiview_focal_length, 0, 0.5], [0, multiview_focal_length, 0.5], [0, 0, 1]]]).repeat(c2ws.shape[0], 1, 1)
|
743 |
+
K = multiview_intrinsics
|
744 |
+
Rs = w2cs[:, :3, :3]
|
745 |
+
Ts = w2cs[:, :3, 3]
|
746 |
+
sizes = torch.Tensor([[1, 1]]).repeat(c2ws.shape[0], 1)
|
747 |
+
|
748 |
+
assert height == width
|
749 |
+
cond_rays = calculate_rays(K, sizes, Rs, Ts, height // 8)
|
750 |
+
cond_rays = cond_rays.reshape(-1, height // 8, width // 8, 6)
|
751 |
+
# padding = (0, 10)
|
752 |
+
# cond_rays = torch.nn.functional.pad(cond_rays, padding, "constant", 0)
|
753 |
+
cond_rays = torch.cat([cond_rays, cond_rays, cond_rays[..., :4]], dim=-1) * 1.658
|
754 |
+
cond_rays = cond_rays[None].repeat(batch_size * num_images_per_prompt, 1, 1, 1, 1)
|
755 |
+
cond_rays = cond_rays.permute(0, 1, 4, 2, 3)
|
756 |
+
cond_rays = cond_rays.to(device, dtype=self.dtype)
|
757 |
+
|
758 |
+
latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
|
759 |
+
if cond_latents is not None:
|
760 |
+
latents[:, cond_indices_images, 0] = cond_latents
|
761 |
+
latents[:, cond_indices_rays, 1] = cond_rays
|
762 |
+
latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
|
763 |
+
else:
|
764 |
+
if cond_latents is not None:
|
765 |
+
latents[:, cond_indices] = cond_latents
|
766 |
+
|
767 |
+
# denoising loop
|
768 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
769 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
770 |
+
for i, t in enumerate(timesteps):
|
771 |
+
# expand the latents if we are doing classifier free guidance
|
772 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
773 |
+
input_t = torch.broadcast_to(einops.repeat(torch.Tensor([t]).to(device), "1 -> 1 f 1 1 1", f=latent_model_input.shape[1]), latent_model_input.shape).clone()
|
774 |
+
|
775 |
+
if is_multiview:
|
776 |
+
input_t = einops.rearrange(input_t, "b (f n) c h w -> b f n c h w", n=2)
|
777 |
+
input_t[:, cond_indices_images, 0] = self.scheduler.timesteps[-1]
|
778 |
+
input_t[:, cond_indices_rays, 1] = self.scheduler.timesteps[-1]
|
779 |
+
input_t = einops.rearrange(input_t, "b f n c h w -> b (f n) c h w")
|
780 |
+
else:
|
781 |
+
input_t[:, cond_indices] = self.scheduler.timesteps[-1]
|
782 |
+
|
783 |
+
# predict the noise residual
|
784 |
+
noise_pred = self.transformer(
|
785 |
+
samples=latent_model_input.to(self.dtype),
|
786 |
+
timesteps=input_t,
|
787 |
+
encoder_hidden_states=encoder_hidden_states.to(self.dtype),
|
788 |
+
encoder_attention_mask=encoder_attention_mask,
|
789 |
+
**forward_kwargs
|
790 |
+
)
|
791 |
+
|
792 |
+
# perform guidance
|
793 |
+
if do_classifier_free_guidance:
|
794 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
795 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
796 |
+
|
797 |
+
# compute the previous noisy sample x_t -> x_t-1
|
798 |
+
bs, n_frame = noise_pred.shape[:2]
|
799 |
+
noise_pred = einops.rearrange(noise_pred, "b f c h w -> (b f) c h w")
|
800 |
+
latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
|
801 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
802 |
+
latents = einops.rearrange(latents, "(b f) c h w -> b f c h w", b=bs, f=n_frame)
|
803 |
+
if is_multiview:
|
804 |
+
latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
|
805 |
+
if cond_latents is not None:
|
806 |
+
latents[:, cond_indices_images, 0] = cond_latents
|
807 |
+
latents[:, cond_indices_rays, 1] = cond_rays
|
808 |
+
latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
|
809 |
+
else:
|
810 |
+
if cond_latents is not None:
|
811 |
+
latents[:, cond_indices] = cond_latents
|
812 |
+
|
813 |
+
# call the callback, if provided
|
814 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
815 |
+
progress_bar.update()
|
816 |
+
if callback is not None and i % callback_steps == 0:
|
817 |
+
callback(i, t, latents)
|
818 |
+
|
819 |
+
decoded_latents = latents / 1.658
|
820 |
+
# scale and decode the image latents with vae
|
821 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
822 |
+
if latents.ndim == 5:
|
823 |
+
latents = latents[:, denoise_indices]
|
824 |
+
latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
|
825 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
826 |
+
|
827 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
828 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
829 |
+
|
830 |
+
if output_type == "pil":
|
831 |
+
image = self.numpy_to_pil(image)
|
832 |
+
|
833 |
+
if not return_dict:
|
834 |
+
return (image, None)
|
835 |
+
|
836 |
+
return OneDiffusionPipelineOutput(images=image, latents=decoded_latents)
|
837 |
+
|
838 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
839 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
840 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
841 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
842 |
+
# and should be between [0, 1]
|
843 |
+
|
844 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
845 |
+
extra_step_kwargs = {}
|
846 |
+
if accepts_eta:
|
847 |
+
extra_step_kwargs["eta"] = eta
|
848 |
+
|
849 |
+
# check if the scheduler accepts generator
|
850 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
851 |
+
if accepts_generator:
|
852 |
+
extra_step_kwargs["generator"] = generator
|
853 |
+
return extra_step_kwargs
|
854 |
+
|
855 |
+
def check_inputs(self, prompt, height, width, callback_steps):
|
856 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
857 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
858 |
+
|
859 |
+
if height % 16 != 0 or width % 16 != 0:
|
860 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
861 |
+
|
862 |
+
if (callback_steps is None) or (
|
863 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
864 |
+
):
|
865 |
+
raise ValueError(
|
866 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
867 |
+
f" {type(callback_steps)}."
|
868 |
+
)
|
869 |
+
|
870 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
871 |
+
# get the original timestep using init_timestep
|
872 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
873 |
+
|
874 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
875 |
+
timesteps = self.scheduler.timesteps[t_start:]
|
876 |
+
|
877 |
+
return timesteps, num_inference_steps - t_start
|
878 |
+
|
879 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, image=None):
|
880 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
881 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
882 |
+
raise ValueError(
|
883 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
884 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
885 |
+
)
|
886 |
+
|
887 |
+
if latents is None:
|
888 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
889 |
+
else:
|
890 |
+
latents = latents.to(device)
|
891 |
+
|
892 |
+
if image is None:
|
893 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
894 |
+
# latents = latents * self.scheduler.init_noise_sigma
|
895 |
+
return latents
|
896 |
+
|
897 |
+
image = image.to(device=device, dtype=dtype)
|
898 |
+
|
899 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
900 |
+
raise ValueError(
|
901 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
902 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
903 |
+
)
|
904 |
+
elif isinstance(generator, list):
|
905 |
+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
|
906 |
+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
|
907 |
+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
|
908 |
+
raise ValueError(
|
909 |
+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
|
910 |
+
)
|
911 |
+
init_latents = [
|
912 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
913 |
+
for i in range(batch_size)
|
914 |
+
]
|
915 |
+
init_latents = torch.cat(init_latents, dim=0)
|
916 |
+
else:
|
917 |
+
init_latents = retrieve_latents(self.vae.encode(image.to(self.vae.dtype)), generator=generator)
|
918 |
+
|
919 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
920 |
+
init_latents = init_latents.to(device=device, dtype=dtype)
|
921 |
+
|
922 |
+
init_latents = einops.rearrange(init_latents, "(bs views) c h w -> bs views c h w", bs=batch_size, views=init_latents.shape[0]//batch_size)
|
923 |
+
# latents = einops.rearrange(latents, "b c h w -> b 1 c h w")
|
924 |
+
# latents = torch.concat([latents, init_latents], dim=1)
|
925 |
+
return init_latents
|
926 |
+
|
927 |
+
def prepare_init_latents(self, batch_size, seq_length, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
928 |
+
shape = (batch_size, seq_length, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
929 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
930 |
+
raise ValueError(
|
931 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
932 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
933 |
+
)
|
934 |
+
|
935 |
+
if latents is None:
|
936 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
937 |
+
else:
|
938 |
+
latents = latents.to(device)
|
939 |
+
|
940 |
+
return latents
|
941 |
+
|
942 |
+
@torch.no_grad()
|
943 |
+
def generate(
|
944 |
+
self,
|
945 |
+
prompt: Union[str, List[str]],
|
946 |
+
num_inference_steps: int = 50,
|
947 |
+
guidance_scale: float = 5.0,
|
948 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
949 |
+
num_images_per_prompt: Optional[int] = 1,
|
950 |
+
height: Optional[int] = None,
|
951 |
+
width: Optional[int] = None,
|
952 |
+
eta: float = 0.0,
|
953 |
+
generator: Optional[torch.Generator] = None,
|
954 |
+
latents: Optional[torch.FloatTensor] = None,
|
955 |
+
output_type: Optional[str] = "pil",
|
956 |
+
return_dict: bool = True,
|
957 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
958 |
+
callback_steps: Optional[int] = 1,
|
959 |
+
):
|
960 |
+
"""
|
961 |
+
Function for image generation using the OneDiffusionPipeline.
|
962 |
+
"""
|
963 |
+
return self(
|
964 |
+
prompt=prompt,
|
965 |
+
num_inference_steps=num_inference_steps,
|
966 |
+
guidance_scale=guidance_scale,
|
967 |
+
negative_prompt=negative_prompt,
|
968 |
+
num_images_per_prompt=num_images_per_prompt,
|
969 |
+
height=height,
|
970 |
+
width=width,
|
971 |
+
eta=eta,
|
972 |
+
generator=generator,
|
973 |
+
latents=latents,
|
974 |
+
output_type=output_type,
|
975 |
+
return_dict=return_dict,
|
976 |
+
callback=callback,
|
977 |
+
callback_steps=callback_steps,
|
978 |
+
)
|
979 |
+
|
980 |
+
@staticmethod
|
981 |
+
def numpy_to_pil(images):
|
982 |
+
"""
|
983 |
+
Convert a numpy image or a batch of images to a PIL image.
|
984 |
+
"""
|
985 |
+
if images.ndim == 3:
|
986 |
+
images = images[None, ...]
|
987 |
+
images = (images * 255).round().astype("uint8")
|
988 |
+
if images.shape[-1] == 1:
|
989 |
+
# special case for grayscale (single channel) images
|
990 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
991 |
+
else:
|
992 |
+
pil_images = [Image.fromarray(image) for image in images]
|
993 |
+
|
994 |
+
return pil_images
|
995 |
+
|
996 |
+
@classmethod
|
997 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
998 |
+
model_path = pretrained_model_name_or_path
|
999 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
1000 |
+
force_download = kwargs.pop("force_download", False)
|
1001 |
+
proxies = kwargs.pop("proxies", None)
|
1002 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
1003 |
+
token = kwargs.pop("token", None)
|
1004 |
+
revision = kwargs.pop("revision", None)
|
1005 |
+
from_flax = kwargs.pop("from_flax", False)
|
1006 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
1007 |
+
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
1008 |
+
custom_revision = kwargs.pop("custom_revision", None)
|
1009 |
+
provider = kwargs.pop("provider", None)
|
1010 |
+
sess_options = kwargs.pop("sess_options", None)
|
1011 |
+
device_map = kwargs.pop("device_map", None)
|
1012 |
+
max_memory = kwargs.pop("max_memory", None)
|
1013 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
1014 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
1015 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
1016 |
+
variant = kwargs.pop("variant", None)
|
1017 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
1018 |
+
use_onnx = kwargs.pop("use_onnx", None)
|
1019 |
+
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
1020 |
+
|
1021 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
1022 |
+
low_cpu_mem_usage = False
|
1023 |
+
logger.warning(
|
1024 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
1025 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
1026 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
1027 |
+
" install accelerate\n```\n."
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
1031 |
+
raise NotImplementedError(
|
1032 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1033 |
+
" `low_cpu_mem_usage=False`."
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
1037 |
+
raise NotImplementedError(
|
1038 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1039 |
+
" `device_map=None`."
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
if device_map is not None and not is_accelerate_available():
|
1043 |
+
raise NotImplementedError(
|
1044 |
+
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
if device_map is not None and not isinstance(device_map, str):
|
1048 |
+
raise ValueError("`device_map` must be a string.")
|
1049 |
+
|
1050 |
+
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
|
1051 |
+
raise NotImplementedError(
|
1052 |
+
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
|
1053 |
+
)
|
1054 |
+
|
1055 |
+
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
|
1056 |
+
if is_accelerate_version("<", "0.28.0"):
|
1057 |
+
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
|
1058 |
+
|
1059 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
1060 |
+
raise ValueError(
|
1061 |
+
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
1062 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
transformer = NextDiT.from_pretrained(f"{model_path}", subfolder="transformer", torch_dtype=torch.float32, cache_dir=cache_dir)
|
1066 |
+
vae = AutoencoderKL.from_pretrained(f"{model_path}", subfolder="vae", cache_dir=cache_dir)
|
1067 |
+
text_encoder = T5EncoderModel.from_pretrained(f"{model_path}", subfolder="text_encoder", torch_dtype=torch.float16, cache_dir=cache_dir)
|
1068 |
+
tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", cache_dir=cache_dir)
|
1069 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler", cache_dir=cache_dir)
|
1070 |
+
|
1071 |
+
pipeline = cls(
|
1072 |
+
transformer=transformer,
|
1073 |
+
vae=vae,
|
1074 |
+
text_encoder=text_encoder,
|
1075 |
+
tokenizer=tokenizer,
|
1076 |
+
scheduler=scheduler,
|
1077 |
+
**kwargs
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
return pipeline
|
onediffusion/models/denoiser/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import (
|
2 |
+
nextdit
|
3 |
+
)
|
onediffusion/models/denoiser/nextdit/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .modeling_nextdit import NextDiT
|
onediffusion/models/denoiser/nextdit/layers.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from typing import Callable, Optional
|
6 |
+
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
try:
|
13 |
+
from apex.normalization import FusedRMSNorm as RMSNorm
|
14 |
+
except ImportError:
|
15 |
+
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
16 |
+
|
17 |
+
|
18 |
+
class RMSNorm(torch.nn.Module):
|
19 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
20 |
+
"""
|
21 |
+
Initialize the RMSNorm normalization layer.
|
22 |
+
Args:
|
23 |
+
dim (int): The dimension of the input tensor.
|
24 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
25 |
+
Attributes:
|
26 |
+
eps (float): A small value added to the denominator for numerical stability.
|
27 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self.eps = eps
|
31 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
32 |
+
|
33 |
+
def _norm(self, x):
|
34 |
+
"""
|
35 |
+
Apply the RMSNorm normalization to the input tensor.
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): The input tensor.
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: The normalized tensor.
|
40 |
+
"""
|
41 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
"""
|
45 |
+
Forward pass through the RMSNorm layer.
|
46 |
+
Args:
|
47 |
+
x (torch.Tensor): The input tensor.
|
48 |
+
Returns:
|
49 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
50 |
+
"""
|
51 |
+
output = self._norm(x.float()).type_as(x)
|
52 |
+
return output * self.weight
|
53 |
+
|
54 |
+
|
55 |
+
def modulate(x, scale):
|
56 |
+
return x * (1 + scale.unsqueeze(1))
|
57 |
+
|
58 |
+
class LLamaFeedForward(nn.Module):
|
59 |
+
"""
|
60 |
+
Corresponds to the FeedForward layer in Next DiT.
|
61 |
+
"""
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
dim: int,
|
65 |
+
hidden_dim: int,
|
66 |
+
multiple_of: int,
|
67 |
+
ffn_dim_multiplier: Optional[float] = None,
|
68 |
+
zeros_initialize: bool = True,
|
69 |
+
dtype: torch.dtype = torch.float32,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
self.dim = dim
|
73 |
+
self.hidden_dim = hidden_dim
|
74 |
+
self.multiple_of = multiple_of
|
75 |
+
self.ffn_dim_multiplier = ffn_dim_multiplier
|
76 |
+
self.zeros_initialize = zeros_initialize
|
77 |
+
self.dtype = dtype
|
78 |
+
|
79 |
+
# Compute hidden_dim based on the given formula
|
80 |
+
hidden_dim_calculated = int(2 * self.hidden_dim / 3)
|
81 |
+
if self.ffn_dim_multiplier is not None:
|
82 |
+
hidden_dim_calculated = int(self.ffn_dim_multiplier * hidden_dim_calculated)
|
83 |
+
hidden_dim_calculated = self.multiple_of * ((hidden_dim_calculated + self.multiple_of - 1) // self.multiple_of)
|
84 |
+
|
85 |
+
# Define linear layers
|
86 |
+
self.w1 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
|
87 |
+
self.w2 = nn.Linear(hidden_dim_calculated, self.dim, bias=False)
|
88 |
+
self.w3 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
|
89 |
+
|
90 |
+
# Initialize weights
|
91 |
+
if self.zeros_initialize:
|
92 |
+
nn.init.zeros_(self.w2.weight)
|
93 |
+
else:
|
94 |
+
nn.init.xavier_uniform_(self.w2.weight)
|
95 |
+
nn.init.xavier_uniform_(self.w1.weight)
|
96 |
+
nn.init.xavier_uniform_(self.w3.weight)
|
97 |
+
|
98 |
+
def _forward_silu_gating(self, x1, x3):
|
99 |
+
return F.silu(x1) * x3
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
103 |
+
|
104 |
+
class FinalLayer(nn.Module):
|
105 |
+
"""
|
106 |
+
The final layer of Next-DiT.
|
107 |
+
"""
|
108 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
109 |
+
super().__init__()
|
110 |
+
self.hidden_size = hidden_size
|
111 |
+
self.patch_size = patch_size
|
112 |
+
self.out_channels = out_channels
|
113 |
+
|
114 |
+
# LayerNorm without learnable parameters (elementwise_affine=False)
|
115 |
+
self.norm_final = nn.LayerNorm(self.hidden_size, eps=1e-6, elementwise_affine=False)
|
116 |
+
self.linear = nn.Linear(self.hidden_size, np.prod(self.patch_size) * self.out_channels, bias=True)
|
117 |
+
nn.init.zeros_(self.linear.weight)
|
118 |
+
nn.init.zeros_(self.linear.bias)
|
119 |
+
|
120 |
+
self.adaLN_modulation = nn.Sequential(
|
121 |
+
nn.SiLU(),
|
122 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
123 |
+
)
|
124 |
+
# Initialize the last layer with zeros
|
125 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
126 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
127 |
+
|
128 |
+
def forward(self, x, c):
|
129 |
+
scale = self.adaLN_modulation(c)
|
130 |
+
x = modulate(self.norm_final(x), scale)
|
131 |
+
x = self.linear(x)
|
132 |
+
return x
|
onediffusion/models/denoiser/nextdit/modeling_nextdit.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
import einops
|
7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
9 |
+
from typing import Any, Tuple, Optional
|
10 |
+
from flash_attn import flash_attn_varlen_func
|
11 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
12 |
+
|
13 |
+
from .layers import LLamaFeedForward, RMSNorm
|
14 |
+
|
15 |
+
# import frasch
|
16 |
+
|
17 |
+
|
18 |
+
def modulate(x, scale):
|
19 |
+
return x * (1 + scale)
|
20 |
+
|
21 |
+
class TimestepEmbedder(nn.Module):
|
22 |
+
"""
|
23 |
+
Embeds scalar timesteps into vector representations.
|
24 |
+
"""
|
25 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
26 |
+
super().__init__()
|
27 |
+
self.hidden_size = hidden_size
|
28 |
+
self.frequency_embedding_size = frequency_embedding_size
|
29 |
+
self.mlp = nn.Sequential(
|
30 |
+
nn.Linear(self.frequency_embedding_size, self.hidden_size),
|
31 |
+
nn.SiLU(),
|
32 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
33 |
+
)
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def timestep_embedding(t, dim, max_period=10000):
|
37 |
+
"""
|
38 |
+
Create sinusoidal timestep embeddings.
|
39 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
40 |
+
:param dim: the dimension of the output.
|
41 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
42 |
+
:return: an (N, D) Tensor of positional embeddings.
|
43 |
+
"""
|
44 |
+
half = dim // 2
|
45 |
+
freqs = torch.exp(
|
46 |
+
-np.log(max_period) * torch.arange(0, half, dtype=t.dtype) / half
|
47 |
+
).to(t.device)
|
48 |
+
args = t[:, :, None] * freqs[None, :]
|
49 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
50 |
+
if dim % 2:
|
51 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :, :1])], dim=-1)
|
52 |
+
return embedding
|
53 |
+
|
54 |
+
def forward(self, t):
|
55 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
56 |
+
t_freq = t_freq.to(self.mlp[0].weight.dtype)
|
57 |
+
return self.mlp(t_freq)
|
58 |
+
|
59 |
+
class FinalLayer(nn.Module):
|
60 |
+
def __init__(self, hidden_size, num_patches, out_channels):
|
61 |
+
super().__init__()
|
62 |
+
self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
|
63 |
+
self.linear = nn.Linear(hidden_size, num_patches * out_channels)
|
64 |
+
self.adaLN_modulation = nn.Sequential(
|
65 |
+
nn.SiLU(),
|
66 |
+
nn.Linear(min(hidden_size, 1024), hidden_size),
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x, c):
|
70 |
+
scale = self.adaLN_modulation(c)
|
71 |
+
x = modulate(self.norm_final(x), scale)
|
72 |
+
x = self.linear(x)
|
73 |
+
return x
|
74 |
+
|
75 |
+
class Attention(nn.Module):
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
dim,
|
79 |
+
n_heads,
|
80 |
+
n_kv_heads=None,
|
81 |
+
qk_norm=False,
|
82 |
+
y_dim=0,
|
83 |
+
base_seqlen=None,
|
84 |
+
proportional_attn=False,
|
85 |
+
attention_dropout=0.0,
|
86 |
+
max_position_embeddings=384,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
self.dim = dim
|
90 |
+
self.n_heads = n_heads
|
91 |
+
self.n_kv_heads = n_kv_heads or n_heads
|
92 |
+
self.qk_norm = qk_norm
|
93 |
+
self.y_dim = y_dim
|
94 |
+
self.base_seqlen = base_seqlen
|
95 |
+
self.proportional_attn = proportional_attn
|
96 |
+
self.attention_dropout = attention_dropout
|
97 |
+
self.max_position_embeddings = max_position_embeddings
|
98 |
+
|
99 |
+
self.head_dim = dim // n_heads
|
100 |
+
|
101 |
+
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
|
102 |
+
self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
103 |
+
self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
104 |
+
|
105 |
+
if y_dim > 0:
|
106 |
+
self.wk_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
|
107 |
+
self.wv_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
|
108 |
+
self.gate = nn.Parameter(torch.zeros(n_heads))
|
109 |
+
|
110 |
+
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
|
111 |
+
|
112 |
+
if qk_norm:
|
113 |
+
self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
|
114 |
+
self.k_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim)
|
115 |
+
if y_dim > 0:
|
116 |
+
self.ky_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim, eps=1e-6)
|
117 |
+
else:
|
118 |
+
self.ky_norm = nn.Identity()
|
119 |
+
else:
|
120 |
+
self.q_norm = nn.Identity()
|
121 |
+
self.k_norm = nn.Identity()
|
122 |
+
self.ky_norm = nn.Identity()
|
123 |
+
|
124 |
+
|
125 |
+
@staticmethod
|
126 |
+
def apply_rotary_emb(xq, xk, freqs_cis):
|
127 |
+
# xq, xk: [batch_size, seq_len, n_heads, head_dim]
|
128 |
+
# freqs_cis: [1, seq_len, 1, head_dim]
|
129 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
|
130 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
|
131 |
+
|
132 |
+
xq_complex = torch.view_as_complex(xq_)
|
133 |
+
xk_complex = torch.view_as_complex(xk_)
|
134 |
+
|
135 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
136 |
+
|
137 |
+
# Apply freqs_cis
|
138 |
+
xq_out = xq_complex * freqs_cis
|
139 |
+
xk_out = xk_complex * freqs_cis
|
140 |
+
|
141 |
+
# Convert back to real numbers
|
142 |
+
xq_out = torch.view_as_real(xq_out).flatten(-2)
|
143 |
+
xk_out = torch.view_as_real(xk_out).flatten(-2)
|
144 |
+
|
145 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
146 |
+
|
147 |
+
# copied from huggingface modeling_llama.py
|
148 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
149 |
+
def _get_unpad_data(attention_mask):
|
150 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
151 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
152 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
153 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
154 |
+
return (
|
155 |
+
indices,
|
156 |
+
cu_seqlens,
|
157 |
+
max_seqlen_in_batch,
|
158 |
+
)
|
159 |
+
|
160 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
161 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
162 |
+
|
163 |
+
key_layer = index_first_axis(
|
164 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
165 |
+
indices_k,
|
166 |
+
)
|
167 |
+
value_layer = index_first_axis(
|
168 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
169 |
+
indices_k,
|
170 |
+
)
|
171 |
+
if query_length == kv_seq_len:
|
172 |
+
query_layer = index_first_axis(
|
173 |
+
query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
|
174 |
+
indices_k,
|
175 |
+
)
|
176 |
+
cu_seqlens_q = cu_seqlens_k
|
177 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
178 |
+
indices_q = indices_k
|
179 |
+
elif query_length == 1:
|
180 |
+
max_seqlen_in_batch_q = 1
|
181 |
+
cu_seqlens_q = torch.arange(
|
182 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
183 |
+
) # There is a memcpy here, that is very bad.
|
184 |
+
indices_q = cu_seqlens_q[:-1]
|
185 |
+
query_layer = query_layer.squeeze(1)
|
186 |
+
else:
|
187 |
+
# The -q_len: slice assumes left padding.
|
188 |
+
attention_mask = attention_mask[:, -query_length:]
|
189 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
190 |
+
|
191 |
+
return (
|
192 |
+
query_layer,
|
193 |
+
key_layer,
|
194 |
+
value_layer,
|
195 |
+
indices_q,
|
196 |
+
(cu_seqlens_q, cu_seqlens_k),
|
197 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
198 |
+
)
|
199 |
+
|
200 |
+
def forward(
|
201 |
+
self,
|
202 |
+
x,
|
203 |
+
x_mask,
|
204 |
+
freqs_cis,
|
205 |
+
y=None,
|
206 |
+
y_mask=None,
|
207 |
+
init_cache=False,
|
208 |
+
):
|
209 |
+
bsz, seqlen, _ = x.size()
|
210 |
+
xq = self.wq(x)
|
211 |
+
xk = self.wk(x)
|
212 |
+
xv = self.wv(x)
|
213 |
+
|
214 |
+
if x_mask is None:
|
215 |
+
x_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device=x.device)
|
216 |
+
inp_dtype = xq.dtype
|
217 |
+
|
218 |
+
xq = self.q_norm(xq)
|
219 |
+
xk = self.k_norm(xk)
|
220 |
+
|
221 |
+
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
222 |
+
xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
223 |
+
xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
224 |
+
|
225 |
+
if self.n_kv_heads != self.n_heads:
|
226 |
+
n_rep = self.n_heads // self.n_kv_heads
|
227 |
+
xk = xk.repeat_interleave(n_rep, dim=2)
|
228 |
+
xv = xv.repeat_interleave(n_rep, dim=2)
|
229 |
+
|
230 |
+
freqs_cis = freqs_cis.to(xq.device)
|
231 |
+
xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis)
|
232 |
+
|
233 |
+
if inp_dtype in [torch.float16, torch.bfloat16]:
|
234 |
+
# begin var_len flash attn
|
235 |
+
(
|
236 |
+
query_states,
|
237 |
+
key_states,
|
238 |
+
value_states,
|
239 |
+
indices_q,
|
240 |
+
cu_seq_lens,
|
241 |
+
max_seq_lens,
|
242 |
+
) = self._upad_input(xq, xk, xv, x_mask, seqlen)
|
243 |
+
|
244 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
245 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
246 |
+
|
247 |
+
attn_output_unpad = flash_attn_varlen_func(
|
248 |
+
query_states.to(inp_dtype),
|
249 |
+
key_states.to(inp_dtype),
|
250 |
+
value_states.to(inp_dtype),
|
251 |
+
cu_seqlens_q=cu_seqlens_q,
|
252 |
+
cu_seqlens_k=cu_seqlens_k,
|
253 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
254 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
255 |
+
dropout_p=0.0,
|
256 |
+
causal=False,
|
257 |
+
softmax_scale=None,
|
258 |
+
softcap=30,
|
259 |
+
)
|
260 |
+
output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
|
261 |
+
else:
|
262 |
+
output = (
|
263 |
+
F.scaled_dot_product_attention(
|
264 |
+
xq.permute(0, 2, 1, 3),
|
265 |
+
xk.permute(0, 2, 1, 3),
|
266 |
+
xv.permute(0, 2, 1, 3),
|
267 |
+
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_heads, seqlen, -1),
|
268 |
+
scale=None,
|
269 |
+
)
|
270 |
+
.permute(0, 2, 1, 3)
|
271 |
+
.to(inp_dtype)
|
272 |
+
) #ok
|
273 |
+
|
274 |
+
|
275 |
+
if hasattr(self, "wk_y"):
|
276 |
+
yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_kv_heads, self.head_dim)
|
277 |
+
yv = self.wv_y(y).view(bsz, -1, self.n_kv_heads, self.head_dim)
|
278 |
+
n_rep = self.n_heads // self.n_kv_heads
|
279 |
+
# if n_rep >= 1:
|
280 |
+
# yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
281 |
+
# yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
282 |
+
if n_rep >= 1:
|
283 |
+
yk = einops.repeat(yk, "b l h d -> b l (repeat h) d", repeat=n_rep)
|
284 |
+
yv = einops.repeat(yv, "b l h d -> b l (repeat h) d", repeat=n_rep)
|
285 |
+
output_y = F.scaled_dot_product_attention(
|
286 |
+
xq.permute(0, 2, 1, 3),
|
287 |
+
yk.permute(0, 2, 1, 3),
|
288 |
+
yv.permute(0, 2, 1, 3),
|
289 |
+
y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_heads, seqlen, -1).to(torch.bool),
|
290 |
+
).permute(0, 2, 1, 3)
|
291 |
+
output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
|
292 |
+
output = output + output_y
|
293 |
+
|
294 |
+
output = output.flatten(-2)
|
295 |
+
output = self.wo(output)
|
296 |
+
|
297 |
+
return output.to(inp_dtype)
|
298 |
+
|
299 |
+
class TransformerBlock(nn.Module):
|
300 |
+
"""
|
301 |
+
Corresponds to the Transformer block in the JAX code.
|
302 |
+
"""
|
303 |
+
def __init__(
|
304 |
+
self,
|
305 |
+
dim,
|
306 |
+
n_heads,
|
307 |
+
n_kv_heads,
|
308 |
+
multiple_of,
|
309 |
+
ffn_dim_multiplier,
|
310 |
+
norm_eps,
|
311 |
+
qk_norm,
|
312 |
+
y_dim,
|
313 |
+
max_position_embeddings,
|
314 |
+
):
|
315 |
+
super().__init__()
|
316 |
+
self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim=y_dim, max_position_embeddings=max_position_embeddings)
|
317 |
+
self.feed_forward = LLamaFeedForward(
|
318 |
+
dim=dim,
|
319 |
+
hidden_dim=4 * dim,
|
320 |
+
multiple_of=multiple_of,
|
321 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
322 |
+
)
|
323 |
+
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
324 |
+
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
325 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
326 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
327 |
+
self.adaLN_modulation = nn.Sequential(
|
328 |
+
nn.SiLU(),
|
329 |
+
nn.Linear(min(dim, 1024), 4 * dim),
|
330 |
+
)
|
331 |
+
self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
x,
|
336 |
+
x_mask,
|
337 |
+
freqs_cis,
|
338 |
+
y,
|
339 |
+
y_mask,
|
340 |
+
adaln_input=None,
|
341 |
+
):
|
342 |
+
if adaln_input is not None:
|
343 |
+
scales_gates = self.adaLN_modulation(adaln_input)
|
344 |
+
# TODO: Duong - check the dimension of chunking
|
345 |
+
# scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
|
346 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
|
347 |
+
x = x + torch.tanh(gate_msa) * self.attention_norm2(
|
348 |
+
self.attention(
|
349 |
+
modulate(self.attention_norm1(x), scale_msa), # ok
|
350 |
+
x_mask,
|
351 |
+
freqs_cis,
|
352 |
+
self.attention_y_norm(y), # ok
|
353 |
+
y_mask,
|
354 |
+
)
|
355 |
+
)
|
356 |
+
x = x + torch.tanh(gate_mlp) * self.ffn_norm2(
|
357 |
+
self.feed_forward(
|
358 |
+
modulate(self.ffn_norm1(x), scale_mlp),
|
359 |
+
)
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
x = x + self.attention_norm2(
|
363 |
+
self.attention(
|
364 |
+
self.attention_norm1(x),
|
365 |
+
x_mask,
|
366 |
+
freqs_cis,
|
367 |
+
self.attention_y_norm(y),
|
368 |
+
y_mask,
|
369 |
+
)
|
370 |
+
)
|
371 |
+
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
372 |
+
return x
|
373 |
+
|
374 |
+
|
375 |
+
class NextDiT(ModelMixin, ConfigMixin):
|
376 |
+
"""
|
377 |
+
Diffusion model with a Transformer backbone for joint image-video training.
|
378 |
+
"""
|
379 |
+
@register_to_config
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
input_size=(1, 32, 32),
|
383 |
+
patch_size=(1, 2, 2),
|
384 |
+
in_channels=16,
|
385 |
+
hidden_size=4096,
|
386 |
+
depth=32,
|
387 |
+
num_heads=32,
|
388 |
+
num_kv_heads=None,
|
389 |
+
multiple_of=256,
|
390 |
+
ffn_dim_multiplier=None,
|
391 |
+
norm_eps=1e-5,
|
392 |
+
pred_sigma=False,
|
393 |
+
caption_channels=4096,
|
394 |
+
qk_norm=False,
|
395 |
+
norm_type="rms",
|
396 |
+
model_max_length=120,
|
397 |
+
rotary_max_length=384,
|
398 |
+
rotary_max_length_t=None
|
399 |
+
):
|
400 |
+
super().__init__()
|
401 |
+
self.input_size = input_size
|
402 |
+
self.patch_size = patch_size
|
403 |
+
self.in_channels = in_channels
|
404 |
+
self.hidden_size = hidden_size
|
405 |
+
self.depth = depth
|
406 |
+
self.num_heads = num_heads
|
407 |
+
self.num_kv_heads = num_kv_heads or num_heads
|
408 |
+
self.multiple_of = multiple_of
|
409 |
+
self.ffn_dim_multiplier = ffn_dim_multiplier
|
410 |
+
self.norm_eps = norm_eps
|
411 |
+
self.pred_sigma = pred_sigma
|
412 |
+
self.caption_channels = caption_channels
|
413 |
+
self.qk_norm = qk_norm
|
414 |
+
self.norm_type = norm_type
|
415 |
+
self.model_max_length = model_max_length
|
416 |
+
self.rotary_max_length = rotary_max_length
|
417 |
+
self.rotary_max_length_t = rotary_max_length_t
|
418 |
+
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
419 |
+
|
420 |
+
self.x_embedder = nn.Linear(np.prod(self.patch_size) * in_channels, hidden_size)
|
421 |
+
|
422 |
+
self.t_embedder = TimestepEmbedder(min(hidden_size, 1024))
|
423 |
+
self.y_embedder = nn.Sequential(
|
424 |
+
nn.LayerNorm(caption_channels, eps=1e-6),
|
425 |
+
nn.Linear(caption_channels, min(hidden_size, 1024)),
|
426 |
+
)
|
427 |
+
|
428 |
+
self.layers = nn.ModuleList([
|
429 |
+
TransformerBlock(
|
430 |
+
dim=hidden_size,
|
431 |
+
n_heads=num_heads,
|
432 |
+
n_kv_heads=self.num_kv_heads,
|
433 |
+
multiple_of=multiple_of,
|
434 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
435 |
+
norm_eps=norm_eps,
|
436 |
+
qk_norm=qk_norm,
|
437 |
+
y_dim=caption_channels,
|
438 |
+
max_position_embeddings=rotary_max_length,
|
439 |
+
)
|
440 |
+
for _ in range(depth)
|
441 |
+
])
|
442 |
+
|
443 |
+
self.final_layer = FinalLayer(
|
444 |
+
hidden_size=hidden_size,
|
445 |
+
num_patches=np.prod(patch_size),
|
446 |
+
out_channels=self.out_channels,
|
447 |
+
)
|
448 |
+
|
449 |
+
assert (hidden_size // num_heads) % 6 == 0, "3d rope needs head dim to be divisible by 6"
|
450 |
+
|
451 |
+
self.freqs_cis = self.precompute_freqs_cis(
|
452 |
+
hidden_size // num_heads,
|
453 |
+
self.rotary_max_length,
|
454 |
+
end_t=self.rotary_max_length_t
|
455 |
+
)
|
456 |
+
|
457 |
+
def to(self, *args, **kwargs):
|
458 |
+
self = super().to(*args, **kwargs)
|
459 |
+
# self.freqs_cis = self.freqs_cis.to(*args, **kwargs)
|
460 |
+
return self
|
461 |
+
|
462 |
+
@staticmethod
|
463 |
+
def precompute_freqs_cis(
|
464 |
+
dim: int,
|
465 |
+
end: int,
|
466 |
+
end_t: int = None,
|
467 |
+
theta: float = 10000.0,
|
468 |
+
scale_factor: float = 1.0,
|
469 |
+
scale_watershed: float = 1.0,
|
470 |
+
timestep: float = 1.0,
|
471 |
+
):
|
472 |
+
if timestep < scale_watershed:
|
473 |
+
linear_factor = scale_factor
|
474 |
+
ntk_factor = 1.0
|
475 |
+
else:
|
476 |
+
linear_factor = 1.0
|
477 |
+
ntk_factor = scale_factor
|
478 |
+
|
479 |
+
theta = theta * ntk_factor
|
480 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
|
481 |
+
|
482 |
+
timestep = torch.arange(end, dtype=torch.float32)
|
483 |
+
freqs = torch.outer(timestep, freqs).float()
|
484 |
+
freqs_cis = torch.exp(1j * freqs)
|
485 |
+
|
486 |
+
if end_t is not None:
|
487 |
+
freqs_t = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
|
488 |
+
timestep_t = torch.arange(end_t, dtype=torch.float32)
|
489 |
+
freqs_t = torch.outer(timestep_t, freqs_t).float()
|
490 |
+
freqs_cis_t = torch.exp(1j * freqs_t)
|
491 |
+
freqs_cis_t = freqs_cis_t.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
|
492 |
+
else:
|
493 |
+
end_t = end
|
494 |
+
freqs_cis_t = freqs_cis.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
|
495 |
+
|
496 |
+
freqs_cis_h = freqs_cis.view(1, end, 1, dim // 6).repeat(end_t, 1, end, 1)
|
497 |
+
freqs_cis_w = freqs_cis.view(1, 1, end, dim // 6).repeat(end_t, end, 1, 1)
|
498 |
+
freqs_cis = torch.cat([freqs_cis_t, freqs_cis_h, freqs_cis_w], dim=-1).view(end_t, end, end, -1)
|
499 |
+
return freqs_cis
|
500 |
+
|
501 |
+
def forward(
|
502 |
+
self,
|
503 |
+
samples,
|
504 |
+
timesteps,
|
505 |
+
encoder_hidden_states,
|
506 |
+
encoder_attention_mask,
|
507 |
+
scale_factor: float = 1.0, # scale_factor for rotary embedding
|
508 |
+
scale_watershed: float = 1.0, # scale_watershed for rotary embedding
|
509 |
+
):
|
510 |
+
if samples.ndim == 4: # B C H W
|
511 |
+
samples = samples[:, None, ...] # B F C H W
|
512 |
+
|
513 |
+
precomputed_freqs_cis = None
|
514 |
+
if scale_factor != 1 or scale_watershed != 1:
|
515 |
+
precomputed_freqs_cis = self.precompute_freqs_cis(
|
516 |
+
self.hidden_size // self.num_heads,
|
517 |
+
self.rotary_max_length,
|
518 |
+
end_t=self.rotary_max_length_t,
|
519 |
+
scale_factor=scale_factor,
|
520 |
+
scale_watershed=scale_watershed,
|
521 |
+
timestep=torch.max(timesteps.cpu()).item()
|
522 |
+
)
|
523 |
+
|
524 |
+
if len(timesteps.shape) == 5:
|
525 |
+
t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
|
526 |
+
timesteps = t.mean(dim=-1)
|
527 |
+
elif len(timesteps.shape) == 1:
|
528 |
+
timesteps = timesteps[:, None, None, None, None].expand_as(samples)
|
529 |
+
t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
|
530 |
+
timesteps = t.mean(dim=-1)
|
531 |
+
samples, T, H, W, freqs_cis = self.patchify(samples, precomputed_freqs_cis)
|
532 |
+
samples = self.x_embedder(samples)
|
533 |
+
t = self.t_embedder(timesteps)
|
534 |
+
|
535 |
+
encoder_attention_mask_float = encoder_attention_mask[..., None].float()
|
536 |
+
encoder_hidden_states_pool = (encoder_hidden_states * encoder_attention_mask_float).sum(dim=1) / (encoder_attention_mask_float.sum(dim=1) + 1e-8)
|
537 |
+
encoder_hidden_states_pool = encoder_hidden_states_pool.to(samples.dtype)
|
538 |
+
y = self.y_embedder(encoder_hidden_states_pool)
|
539 |
+
y = y.unsqueeze(1).expand(-1, samples.size(1), -1)
|
540 |
+
|
541 |
+
adaln_input = t + y
|
542 |
+
|
543 |
+
for block in self.layers:
|
544 |
+
samples = block(samples, None, freqs_cis, encoder_hidden_states, encoder_attention_mask, adaln_input)
|
545 |
+
|
546 |
+
samples = self.final_layer(samples, adaln_input)
|
547 |
+
samples = self.unpatchify(samples, T, H, W)
|
548 |
+
|
549 |
+
return samples
|
550 |
+
|
551 |
+
def patchify(self, x, precompute_freqs_cis=None):
|
552 |
+
# pytorch is C, H, W
|
553 |
+
B, T, C, H, W = x.size()
|
554 |
+
pT, pH, pW = self.patch_size
|
555 |
+
x = x.view(B, T // pT, pT, C, H // pH, pH, W // pW, pW)
|
556 |
+
x = x.permute(0, 1, 4, 6, 2, 5, 7, 3)
|
557 |
+
x = x.reshape(B, -1, pT * pH * pW * C)
|
558 |
+
if precompute_freqs_cis is None:
|
559 |
+
freqs_cis = self.freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * self.freqs_cis.shape[3:])[None].to(x.device)
|
560 |
+
else:
|
561 |
+
freqs_cis = precompute_freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * precompute_freqs_cis.shape[3:])[None].to(x.device)
|
562 |
+
return x, T // pT, H // pH, W // pW, freqs_cis
|
563 |
+
|
564 |
+
def unpatchify(self, x, T, H, W):
|
565 |
+
B = x.size(0)
|
566 |
+
C = self.out_channels
|
567 |
+
pT, pH, pW = self.patch_size
|
568 |
+
x = x.view(B, T, H, W, pT, pH, pW, C)
|
569 |
+
x = x.permute(0, 1, 4, 7, 2, 5, 3, 6)
|
570 |
+
x = x.reshape(B, T * pT, C, H * pH, W * pW)
|
571 |
+
return x
|