Spaces:
Build error
Build error
init
Browse files- Dockerfile +30 -0
- README.md +1 -3
- app.py +821 -0
- omnigen2/.DS_Store +0 -0
- omnigen2/__init__.py +0 -0
- omnigen2/__pycache__/__init__.cpython-310.pyc +0 -0
- omnigen2/__pycache__/__init__.cpython-38.pyc +0 -0
- omnigen2/models/__init__.py +0 -0
- omnigen2/models/attention_processor.py +239 -0
- omnigen2/models/embeddings.py +126 -0
- omnigen2/models/transformers/__init__.py +3 -0
- omnigen2/models/transformers/block_lumina2.py +246 -0
- omnigen2/models/transformers/components.py +4 -0
- omnigen2/models/transformers/repo.py +129 -0
- omnigen2/models/transformers/transformer_omnigen2.py +639 -0
- omnigen2/ops/.DS_Store +0 -0
- omnigen2/ops/triton/__init__.py +0 -0
- omnigen2/ops/triton/layer_norm.py +1257 -0
- omnigen2/pipelines/__init__.py +0 -0
- omnigen2/pipelines/image_processor.py +266 -0
- omnigen2/pipelines/omnigen2/pipeline_omnigen2.py +720 -0
- omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py +830 -0
- omnigen2/pipelines/pipeline_utils.py +62 -0
- omnigen2/schedulers/__init__.py +0 -0
- omnigen2/schedulers/scheduling_dpmsolver_multistep.py +1052 -0
- omnigen2/schedulers/scheduling_flow_match_euler_discrete.py +229 -0
- omnigen2/utils/__init__.py +0 -0
- omnigen2/utils/img_util.py +31 -0
- omnigen2/utils/vpn_utils.py +69 -0
- requirements.txt +16 -0
Dockerfile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
|
5 |
+
|
6 |
+
WORKDIR /code
|
7 |
+
|
8 |
+
COPY ./requirements.txt /code/requirements.txt
|
9 |
+
|
10 |
+
RUN pip install --upgrade pip wheel setuptools --no-cache-dir && \
|
11 |
+
pip install -r /code/requirements.txt --no-cache-dir && \
|
12 |
+
pip install flash-attn --no-build-isolation --no-cache-dir
|
13 |
+
|
14 |
+
# Set up a new user named "user" with user ID 1000
|
15 |
+
RUN useradd -m -u 1000 user
|
16 |
+
|
17 |
+
# Switch to the "user" user
|
18 |
+
USER user
|
19 |
+
|
20 |
+
# Set home to the user's home directory
|
21 |
+
ENV HOME=/home/user \
|
22 |
+
PATH=/home/user/.local/bin:$PATH
|
23 |
+
|
24 |
+
# Set the working directory to the user's home directory
|
25 |
+
WORKDIR $HOME/app
|
26 |
+
|
27 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
28 |
+
COPY --chown=user . $HOME/app
|
29 |
+
|
30 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
@@ -7,6 +7,4 @@ sdk: docker
|
|
7 |
pinned: false
|
8 |
license: apache-2.0
|
9 |
short_description: 'OmniGen2: Unified Image Understanding and Generation.'
|
10 |
-
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
7 |
pinned: false
|
8 |
license: apache-2.0
|
9 |
short_description: 'OmniGen2: Unified Image Understanding and Generation.'
|
10 |
+
---
|
|
|
|
app.py
ADDED
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dotenv
|
2 |
+
|
3 |
+
dotenv.load_dotenv(override=True)
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
import random
|
10 |
+
from datetime import datetime
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torchvision.transforms.functional import to_pil_image, to_tensor
|
14 |
+
|
15 |
+
from accelerate import Accelerator
|
16 |
+
|
17 |
+
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
|
18 |
+
from omnigen2.utils.img_util import create_collage
|
19 |
+
from omnigen2.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
20 |
+
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
21 |
+
|
22 |
+
NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"
|
23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
|
25 |
+
pipeline = None
|
26 |
+
accelerator = None
|
27 |
+
save_images = False
|
28 |
+
|
29 |
+
def load_pipeline(accelerator, weight_dtype, args):
|
30 |
+
pipeline = OmniGen2Pipeline.from_pretrained(
|
31 |
+
args.model_path,
|
32 |
+
torch_dtype=weight_dtype,
|
33 |
+
trust_remote_code=True,
|
34 |
+
)
|
35 |
+
if args.enable_sequential_cpu_offload:
|
36 |
+
pipeline.enable_sequential_cpu_offload()
|
37 |
+
elif args.enable_model_cpu_offload:
|
38 |
+
pipeline.enable_model_cpu_offload()
|
39 |
+
else:
|
40 |
+
pipeline = pipeline.to(accelerator.device)
|
41 |
+
return pipeline
|
42 |
+
|
43 |
+
|
44 |
+
def run(
|
45 |
+
instruction,
|
46 |
+
width_input,
|
47 |
+
height_input,
|
48 |
+
scheduler,
|
49 |
+
num_inference_steps,
|
50 |
+
image_input_1,
|
51 |
+
image_input_2,
|
52 |
+
image_input_3,
|
53 |
+
negative_prompt,
|
54 |
+
guidance_scale_input,
|
55 |
+
img_guidance_scale_input,
|
56 |
+
cfg_range_start,
|
57 |
+
cfg_range_end,
|
58 |
+
num_images_per_prompt,
|
59 |
+
max_input_image_side_length,
|
60 |
+
max_pixels,
|
61 |
+
seed_input,
|
62 |
+
progress=gr.Progress(),
|
63 |
+
):
|
64 |
+
input_images = [image_input_1, image_input_2, image_input_3]
|
65 |
+
input_images = [img for img in input_images if img is not None]
|
66 |
+
|
67 |
+
if len(input_images) == 0:
|
68 |
+
input_images = None
|
69 |
+
|
70 |
+
if seed_input == -1:
|
71 |
+
seed_input = random.randint(0, 2**16 - 1)
|
72 |
+
|
73 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(seed_input)
|
74 |
+
|
75 |
+
def progress_callback(cur_step, timesteps):
|
76 |
+
frac = (cur_step + 1) / float(timesteps)
|
77 |
+
progress(frac)
|
78 |
+
|
79 |
+
if scheduler == 'euler':
|
80 |
+
pipeline.scheduler = FlowMatchEulerDiscreteScheduler()
|
81 |
+
elif scheduler == 'dpmsolver':
|
82 |
+
pipeline.scheduler = DPMSolverMultistepScheduler(
|
83 |
+
algorithm_type="dpmsolver++",
|
84 |
+
solver_type="midpoint",
|
85 |
+
solver_order=2,
|
86 |
+
prediction_type="flow_prediction",
|
87 |
+
)
|
88 |
+
|
89 |
+
results = pipeline(
|
90 |
+
prompt=instruction,
|
91 |
+
input_images=input_images,
|
92 |
+
width=width_input,
|
93 |
+
height=height_input,
|
94 |
+
max_input_image_side_length=max_input_image_side_length,
|
95 |
+
max_pixels=max_pixels,
|
96 |
+
num_inference_steps=num_inference_steps,
|
97 |
+
max_sequence_length=1024,
|
98 |
+
text_guidance_scale=guidance_scale_input,
|
99 |
+
image_guidance_scale=img_guidance_scale_input,
|
100 |
+
cfg_range=(cfg_range_start, cfg_range_end),
|
101 |
+
negative_prompt=negative_prompt,
|
102 |
+
num_images_per_prompt=num_images_per_prompt,
|
103 |
+
generator=generator,
|
104 |
+
output_type="pil",
|
105 |
+
step_func=progress_callback,
|
106 |
+
)
|
107 |
+
|
108 |
+
progress(1.0)
|
109 |
+
|
110 |
+
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
|
111 |
+
output_image = create_collage(vis_images)
|
112 |
+
|
113 |
+
if save_images:
|
114 |
+
# Create outputs directory if it doesn't exist
|
115 |
+
output_dir = os.path.join(ROOT_DIR, "outputs_gradio")
|
116 |
+
os.makedirs(output_dir, exist_ok=True)
|
117 |
+
|
118 |
+
# Generate unique filename with timestamp
|
119 |
+
timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
|
120 |
+
|
121 |
+
# Generate unique filename with timestamp
|
122 |
+
output_path = os.path.join(output_dir, f"{timestamp}.png")
|
123 |
+
# Save the image
|
124 |
+
output_image.save(output_path)
|
125 |
+
|
126 |
+
# Save All Generated Images
|
127 |
+
if len(results.images) > 1:
|
128 |
+
for i, image in enumerate(results.images):
|
129 |
+
image_name, ext = os.path.splitext(output_path)
|
130 |
+
image.save(f"{image_name}_{i}{ext}")
|
131 |
+
return output_image
|
132 |
+
|
133 |
+
|
134 |
+
def get_example():
|
135 |
+
case = [
|
136 |
+
[
|
137 |
+
"The sun rises slightly, the dew on the rose petals in the garden is clear, a crystal ladybug is crawling to the dew, the background is the early morning garden, macro lens.",
|
138 |
+
1024,
|
139 |
+
1024,
|
140 |
+
'euler',
|
141 |
+
50,
|
142 |
+
None,
|
143 |
+
None,
|
144 |
+
None,
|
145 |
+
NEGATIVE_PROMPT,
|
146 |
+
3.5,
|
147 |
+
1.0,
|
148 |
+
0.0,
|
149 |
+
1.0,
|
150 |
+
1,
|
151 |
+
2048,
|
152 |
+
1024 * 1024,
|
153 |
+
0,
|
154 |
+
],
|
155 |
+
[
|
156 |
+
"A snow maiden with pale translucent skin, frosty white lashes, and a soft expression of longing",
|
157 |
+
1024,
|
158 |
+
1024,
|
159 |
+
'euler',
|
160 |
+
50,
|
161 |
+
None,
|
162 |
+
None,
|
163 |
+
None,
|
164 |
+
NEGATIVE_PROMPT,
|
165 |
+
3.5,
|
166 |
+
1.0,
|
167 |
+
0.0,
|
168 |
+
1.0,
|
169 |
+
1,
|
170 |
+
2048,
|
171 |
+
1024 * 1024,
|
172 |
+
0,
|
173 |
+
],
|
174 |
+
[
|
175 |
+
"Add a fisherman hat to the woman's head",
|
176 |
+
1024,
|
177 |
+
1024,
|
178 |
+
'euler',
|
179 |
+
50,
|
180 |
+
os.path.join(ROOT_DIR, "example_images/flux5.png"),
|
181 |
+
None,
|
182 |
+
None,
|
183 |
+
NEGATIVE_PROMPT,
|
184 |
+
5.0,
|
185 |
+
2.0,
|
186 |
+
0.0,
|
187 |
+
1.0,
|
188 |
+
1,
|
189 |
+
2048,
|
190 |
+
1024 * 1024,
|
191 |
+
0,
|
192 |
+
],
|
193 |
+
[
|
194 |
+
" replace the sword with a hammer.",
|
195 |
+
1024,
|
196 |
+
1024,
|
197 |
+
'euler',
|
198 |
+
50,
|
199 |
+
os.path.join(
|
200 |
+
ROOT_DIR,
|
201 |
+
"example_images/d8f8f44c64106e7715c61b5dfa9d9ca0974314c5d4a4a50418acf7ff373432bb.png",
|
202 |
+
),
|
203 |
+
None,
|
204 |
+
None,
|
205 |
+
NEGATIVE_PROMPT,
|
206 |
+
5.0,
|
207 |
+
2.0,
|
208 |
+
0.0,
|
209 |
+
1.0,
|
210 |
+
1,
|
211 |
+
2048,
|
212 |
+
1024 * 1024,
|
213 |
+
0,
|
214 |
+
],
|
215 |
+
[
|
216 |
+
"Extract the character from the picture and fill the rest of the background with white.",
|
217 |
+
# "Transform the sculpture into jade",
|
218 |
+
1024,
|
219 |
+
1024,
|
220 |
+
'euler',
|
221 |
+
50,
|
222 |
+
os.path.join(
|
223 |
+
ROOT_DIR, "example_images/46e79704-c88e-4e68-97b4-b4c40cd29826.png"
|
224 |
+
),
|
225 |
+
None,
|
226 |
+
None,
|
227 |
+
NEGATIVE_PROMPT,
|
228 |
+
5.0,
|
229 |
+
2.0,
|
230 |
+
0.0,
|
231 |
+
1.0,
|
232 |
+
1,
|
233 |
+
2048,
|
234 |
+
1024 * 1024,
|
235 |
+
0,
|
236 |
+
],
|
237 |
+
[
|
238 |
+
"Make he smile",
|
239 |
+
1024,
|
240 |
+
1024,
|
241 |
+
'euler',
|
242 |
+
50,
|
243 |
+
os.path.join(
|
244 |
+
ROOT_DIR, "example_images/vicky-hladynets-C8Ta0gwPbQg-unsplash.jpg"
|
245 |
+
),
|
246 |
+
None,
|
247 |
+
None,
|
248 |
+
NEGATIVE_PROMPT,
|
249 |
+
5.0,
|
250 |
+
2.0,
|
251 |
+
0.0,
|
252 |
+
1.0,
|
253 |
+
1,
|
254 |
+
2048,
|
255 |
+
1024 * 1024,
|
256 |
+
0,
|
257 |
+
],
|
258 |
+
[
|
259 |
+
"Change the background to classroom",
|
260 |
+
1024,
|
261 |
+
1024,
|
262 |
+
'euler',
|
263 |
+
50,
|
264 |
+
os.path.join(ROOT_DIR, "example_images/ComfyUI_temp_mllvz_00071_.png"),
|
265 |
+
None,
|
266 |
+
None,
|
267 |
+
NEGATIVE_PROMPT,
|
268 |
+
5.0,
|
269 |
+
2.0,
|
270 |
+
0.0,
|
271 |
+
1.0,
|
272 |
+
1,
|
273 |
+
2048,
|
274 |
+
1024 * 1024,
|
275 |
+
0,
|
276 |
+
],
|
277 |
+
[
|
278 |
+
"Raise his hand",
|
279 |
+
1024,
|
280 |
+
1024,
|
281 |
+
'euler',
|
282 |
+
50,
|
283 |
+
os.path.join(
|
284 |
+
ROOT_DIR,
|
285 |
+
"example_images/289089159-a6d7abc142419e63cab0a566eb38e0fb6acb217b340f054c6172139b316f6596.png",
|
286 |
+
),
|
287 |
+
None,
|
288 |
+
None,
|
289 |
+
NEGATIVE_PROMPT,
|
290 |
+
5.0,
|
291 |
+
2.0,
|
292 |
+
0.0,
|
293 |
+
1.0,
|
294 |
+
1,
|
295 |
+
2048,
|
296 |
+
1024 * 1024,
|
297 |
+
0,
|
298 |
+
],
|
299 |
+
[
|
300 |
+
"Generate a photo of an anime-style figurine placed on a desk. The figurine model should be based on the character photo provided in the attachment, accurately replicating the full-body pose, facial expression, and clothing style of the character in the photo, ensuring the entire figurine is fully presented. The overall design should be exquisite and detailed, soft gradient colors and a delicate texture, leaning towards a Japanese anime style, rich in details, with a realistic quality and beautiful visual appeal.",
|
301 |
+
1024,
|
302 |
+
1024,
|
303 |
+
'euler',
|
304 |
+
50,
|
305 |
+
os.path.join(ROOT_DIR, "example_images/RAL_0315.JPG"),
|
306 |
+
None,
|
307 |
+
None,
|
308 |
+
NEGATIVE_PROMPT,
|
309 |
+
5.0,
|
310 |
+
2.0,
|
311 |
+
0.0,
|
312 |
+
1.0,
|
313 |
+
1,
|
314 |
+
2048,
|
315 |
+
1024 * 1024,
|
316 |
+
0,
|
317 |
+
],
|
318 |
+
[
|
319 |
+
"Change the dress to blue.",
|
320 |
+
1024,
|
321 |
+
1024,
|
322 |
+
'euler',
|
323 |
+
50,
|
324 |
+
os.path.join(ROOT_DIR, "example_images/1.png"),
|
325 |
+
None,
|
326 |
+
None,
|
327 |
+
NEGATIVE_PROMPT,
|
328 |
+
5.0,
|
329 |
+
2.0,
|
330 |
+
0.0,
|
331 |
+
1.0,
|
332 |
+
1,
|
333 |
+
2048,
|
334 |
+
1024 * 1024,
|
335 |
+
0,
|
336 |
+
],
|
337 |
+
[
|
338 |
+
"Remove the cat",
|
339 |
+
1024,
|
340 |
+
1024,
|
341 |
+
'euler',
|
342 |
+
50,
|
343 |
+
os.path.join(
|
344 |
+
ROOT_DIR,
|
345 |
+
"example_images/386724677-589d19050d4ea0603aee6831459aede29a24f4d8668c62c049f413db31508a54.png",
|
346 |
+
),
|
347 |
+
None,
|
348 |
+
None,
|
349 |
+
NEGATIVE_PROMPT,
|
350 |
+
5.0,
|
351 |
+
2.0,
|
352 |
+
0.0,
|
353 |
+
1.0,
|
354 |
+
1,
|
355 |
+
2048,
|
356 |
+
1024 * 1024,
|
357 |
+
0,
|
358 |
+
],
|
359 |
+
[
|
360 |
+
"In a cozy café, the anime figure is sitting in front of a laptop, smiling confidently.",
|
361 |
+
1024,
|
362 |
+
1024,
|
363 |
+
'euler',
|
364 |
+
50,
|
365 |
+
os.path.join(ROOT_DIR, "example_images/ComfyUI_00254_.png"),
|
366 |
+
None,
|
367 |
+
None,
|
368 |
+
NEGATIVE_PROMPT,
|
369 |
+
5.0,
|
370 |
+
2.0,
|
371 |
+
0.0,
|
372 |
+
1.0,
|
373 |
+
1,
|
374 |
+
2048,
|
375 |
+
1024 * 1024,
|
376 |
+
0,
|
377 |
+
],
|
378 |
+
[
|
379 |
+
"Create a wedding figure based on the girl in the first image and the man in the second image. Set the background as a wedding hall, with the man dressed in a suit and the girl in a white wedding dress. Ensure that the original faces remain unchanged and are accurately preserved. The man should adopt a realistic style, whereas the girl should maintain their classic anime style.",
|
380 |
+
1024,
|
381 |
+
1024,
|
382 |
+
'euler',
|
383 |
+
50,
|
384 |
+
os.path.join(ROOT_DIR, "example_images/1_20241127203215.png"),
|
385 |
+
os.path.join(ROOT_DIR, "example_images/000050281.jpg"),
|
386 |
+
None,
|
387 |
+
NEGATIVE_PROMPT,
|
388 |
+
5.0,
|
389 |
+
3.0,
|
390 |
+
0.0,
|
391 |
+
1.0,
|
392 |
+
1,
|
393 |
+
2048,
|
394 |
+
1024 * 1024,
|
395 |
+
0,
|
396 |
+
],
|
397 |
+
[
|
398 |
+
"Let the girl and the boy get married in the church. ",
|
399 |
+
1024,
|
400 |
+
1024,
|
401 |
+
'euler',
|
402 |
+
50,
|
403 |
+
os.path.join(ROOT_DIR, "example_images/8FtFUxRzXqaguVRGzkHvN.png"),
|
404 |
+
os.path.join(ROOT_DIR, "example_images/01194-20240127001056_1024x1536.png"),
|
405 |
+
None,
|
406 |
+
NEGATIVE_PROMPT,
|
407 |
+
5.0,
|
408 |
+
3.0,
|
409 |
+
0.0,
|
410 |
+
1.0,
|
411 |
+
1,
|
412 |
+
2048,
|
413 |
+
1024 * 1024,
|
414 |
+
0,
|
415 |
+
],
|
416 |
+
[
|
417 |
+
"Let the man from image1 and the woman from image2 kiss and hug",
|
418 |
+
1024,
|
419 |
+
1024,
|
420 |
+
'euler',
|
421 |
+
50,
|
422 |
+
os.path.join(ROOT_DIR, "example_images/1280X1280.png"),
|
423 |
+
os.path.join(ROOT_DIR, "example_images/000077066.jpg"),
|
424 |
+
None,
|
425 |
+
NEGATIVE_PROMPT,
|
426 |
+
5.0,
|
427 |
+
2.0,
|
428 |
+
0.0,
|
429 |
+
1.0,
|
430 |
+
1,
|
431 |
+
2048,
|
432 |
+
1024 * 1024,
|
433 |
+
0,
|
434 |
+
],
|
435 |
+
[
|
436 |
+
"Please let the person in image 2 hold the toy from the first image in a parking lot.",
|
437 |
+
1024,
|
438 |
+
1024,
|
439 |
+
'euler',
|
440 |
+
50,
|
441 |
+
os.path.join(ROOT_DIR, "example_images/04.jpg"),
|
442 |
+
os.path.join(ROOT_DIR, "example_images/000365954.jpg"),
|
443 |
+
None,
|
444 |
+
NEGATIVE_PROMPT,
|
445 |
+
5.0,
|
446 |
+
2.0,
|
447 |
+
0.0,
|
448 |
+
1.0,
|
449 |
+
1,
|
450 |
+
2048,
|
451 |
+
1024 * 1024,
|
452 |
+
0,
|
453 |
+
],
|
454 |
+
[
|
455 |
+
"Make the girl pray in the second image.",
|
456 |
+
1024,
|
457 |
+
682,
|
458 |
+
'euler',
|
459 |
+
50,
|
460 |
+
os.path.join(ROOT_DIR, "example_images/000440817.jpg"),
|
461 |
+
os.path.join(ROOT_DIR, "example_images/000119733.jpg"),
|
462 |
+
None,
|
463 |
+
NEGATIVE_PROMPT,
|
464 |
+
5.0,
|
465 |
+
2.0,
|
466 |
+
0.0,
|
467 |
+
1.0,
|
468 |
+
1,
|
469 |
+
2048,
|
470 |
+
1024 * 1024,
|
471 |
+
0,
|
472 |
+
],
|
473 |
+
[
|
474 |
+
"Add the bird from image 1 to the desk in image 2",
|
475 |
+
1024,
|
476 |
+
682,
|
477 |
+
'euler',
|
478 |
+
50,
|
479 |
+
os.path.join(
|
480 |
+
ROOT_DIR,
|
481 |
+
"example_images/996e2cf6-daa5-48c4-9ad7-0719af640c17_1748848108409.png",
|
482 |
+
),
|
483 |
+
os.path.join(ROOT_DIR, "example_images/00066-10350085.png"),
|
484 |
+
None,
|
485 |
+
NEGATIVE_PROMPT,
|
486 |
+
5.0,
|
487 |
+
2.0,
|
488 |
+
0.0,
|
489 |
+
1.0,
|
490 |
+
1,
|
491 |
+
2048,
|
492 |
+
1024 * 1024,
|
493 |
+
0,
|
494 |
+
],
|
495 |
+
[
|
496 |
+
"Replace the apple in the first image with the cat from the second image",
|
497 |
+
1024,
|
498 |
+
780,
|
499 |
+
'euler',
|
500 |
+
50,
|
501 |
+
os.path.join(ROOT_DIR, "example_images/apple.png"),
|
502 |
+
os.path.join(
|
503 |
+
ROOT_DIR,
|
504 |
+
"example_images/468404374-d52ec1a44aa7e0dc9c2807ce09d303a111c78f34da3da2401b83ce10815ff872.png",
|
505 |
+
),
|
506 |
+
None,
|
507 |
+
NEGATIVE_PROMPT,
|
508 |
+
5.0,
|
509 |
+
2.0,
|
510 |
+
0.0,
|
511 |
+
1.0,
|
512 |
+
1,
|
513 |
+
2048,
|
514 |
+
1024 * 1024,
|
515 |
+
0,
|
516 |
+
],
|
517 |
+
[
|
518 |
+
"Replace the woman in the second image with the woman from the first image",
|
519 |
+
1024,
|
520 |
+
747,
|
521 |
+
'euler',
|
522 |
+
50,
|
523 |
+
os.path.join(
|
524 |
+
ROOT_DIR, "example_images/byward-outfitters-B97YFrsITyo-unsplash.jpg"
|
525 |
+
),
|
526 |
+
os.path.join(
|
527 |
+
ROOT_DIR, "example_images/6652baf6-4096-40ef-a475-425e4c072daf.png"
|
528 |
+
),
|
529 |
+
None,
|
530 |
+
NEGATIVE_PROMPT,
|
531 |
+
5.0,
|
532 |
+
2.0,
|
533 |
+
0.0,
|
534 |
+
1.0,
|
535 |
+
1,
|
536 |
+
2048,
|
537 |
+
1024 * 1024,
|
538 |
+
0,
|
539 |
+
],
|
540 |
+
]
|
541 |
+
return case
|
542 |
+
|
543 |
+
|
544 |
+
def run_for_examples(
|
545 |
+
instruction,
|
546 |
+
width_input,
|
547 |
+
height_input,
|
548 |
+
scheduler,
|
549 |
+
num_inference_steps,
|
550 |
+
image_input_1,
|
551 |
+
image_input_2,
|
552 |
+
image_input_3,
|
553 |
+
negative_prompt,
|
554 |
+
text_guidance_scale_input,
|
555 |
+
image_guidance_scale_input,
|
556 |
+
cfg_range_start,
|
557 |
+
cfg_range_end,
|
558 |
+
num_images_per_prompt,
|
559 |
+
max_input_image_side_length,
|
560 |
+
max_pixels,
|
561 |
+
seed_input,
|
562 |
+
):
|
563 |
+
return run(
|
564 |
+
instruction,
|
565 |
+
width_input,
|
566 |
+
height_input,
|
567 |
+
scheduler,
|
568 |
+
num_inference_steps,
|
569 |
+
image_input_1,
|
570 |
+
image_input_2,
|
571 |
+
image_input_3,
|
572 |
+
negative_prompt,
|
573 |
+
text_guidance_scale_input,
|
574 |
+
image_guidance_scale_input,
|
575 |
+
cfg_range_start,
|
576 |
+
cfg_range_end,
|
577 |
+
num_images_per_prompt,
|
578 |
+
max_input_image_side_length,
|
579 |
+
max_pixels,
|
580 |
+
seed_input,
|
581 |
+
)
|
582 |
+
|
583 |
+
description = """
|
584 |
+
### 💡 Quick Tips for Best Results (see our [github](https://github.com/VectorSpaceLab/OmniGen2?tab=readme-ov-file#-usage-tips) for more details)
|
585 |
+
- Image Quality: Use high-resolution images (at least 512x512 recommended).
|
586 |
+
- Be Specific: Instead of "Add bird to desk", try "Add the bird from image 1 to the desk in image 2".
|
587 |
+
- Use English: English prompts currently yield better results.
|
588 |
+
- Adjust image_guidance_scale for better consistency with the reference image:
|
589 |
+
- Image Editing: 1.3 - 2.0
|
590 |
+
- In-context Generation: 2.0 - 3.0
|
591 |
+
"""
|
592 |
+
|
593 |
+
article = """
|
594 |
+
citation to be added
|
595 |
+
"""
|
596 |
+
|
597 |
+
def main(args):
|
598 |
+
# Gradio
|
599 |
+
with gr.Blocks() as demo:
|
600 |
+
gr.Markdown(
|
601 |
+
"# OmniGen2: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen2)"
|
602 |
+
)
|
603 |
+
gr.Markdown(description)
|
604 |
+
with gr.Row():
|
605 |
+
with gr.Column():
|
606 |
+
# text prompt
|
607 |
+
instruction = gr.Textbox(
|
608 |
+
label='Enter your prompt. Use "first/second image" or “第一张图/第二张图” as reference.',
|
609 |
+
placeholder="Type your prompt here...",
|
610 |
+
)
|
611 |
+
|
612 |
+
with gr.Row(equal_height=True):
|
613 |
+
# input images
|
614 |
+
image_input_1 = gr.Image(label="First Image", type="pil")
|
615 |
+
image_input_2 = gr.Image(label="Second Image", type="pil")
|
616 |
+
image_input_3 = gr.Image(label="Third Image", type="pil")
|
617 |
+
|
618 |
+
generate_button = gr.Button("Generate Image")
|
619 |
+
|
620 |
+
negative_prompt = gr.Textbox(
|
621 |
+
label="Enter your negative prompt",
|
622 |
+
placeholder="Type your negative prompt here...",
|
623 |
+
value=NEGATIVE_PROMPT,
|
624 |
+
)
|
625 |
+
|
626 |
+
# slider
|
627 |
+
with gr.Row(equal_height=True):
|
628 |
+
height_input = gr.Slider(
|
629 |
+
label="Height", minimum=256, maximum=1024, value=1024, step=128
|
630 |
+
)
|
631 |
+
width_input = gr.Slider(
|
632 |
+
label="Width", minimum=256, maximum=1024, value=1024, step=128
|
633 |
+
)
|
634 |
+
with gr.Row(equal_height=True):
|
635 |
+
text_guidance_scale_input = gr.Slider(
|
636 |
+
label="Text Guidance Scale",
|
637 |
+
minimum=1.0,
|
638 |
+
maximum=8.0,
|
639 |
+
value=5.0,
|
640 |
+
step=0.1,
|
641 |
+
)
|
642 |
+
|
643 |
+
image_guidance_scale_input = gr.Slider(
|
644 |
+
label="Image Guidance Scale",
|
645 |
+
minimum=1.0,
|
646 |
+
maximum=3.0,
|
647 |
+
value=2.0,
|
648 |
+
step=0.1,
|
649 |
+
)
|
650 |
+
with gr.Row(equal_height=True):
|
651 |
+
cfg_range_start = gr.Slider(
|
652 |
+
label="CFG Range Start",
|
653 |
+
minimum=0.0,
|
654 |
+
maximum=1.0,
|
655 |
+
value=0.0,
|
656 |
+
step=0.1,
|
657 |
+
)
|
658 |
+
|
659 |
+
cfg_range_end = gr.Slider(
|
660 |
+
label="CFG Range End",
|
661 |
+
minimum=0.0,
|
662 |
+
maximum=1.0,
|
663 |
+
value=1.0,
|
664 |
+
step=0.1,
|
665 |
+
)
|
666 |
+
|
667 |
+
def adjust_end_slider(start_val, end_val):
|
668 |
+
return max(start_val, end_val)
|
669 |
+
|
670 |
+
def adjust_start_slider(end_val, start_val):
|
671 |
+
return min(end_val, start_val)
|
672 |
+
|
673 |
+
cfg_range_start.input(
|
674 |
+
fn=adjust_end_slider,
|
675 |
+
inputs=[cfg_range_start, cfg_range_end],
|
676 |
+
outputs=[cfg_range_end]
|
677 |
+
)
|
678 |
+
|
679 |
+
cfg_range_end.input(
|
680 |
+
fn=adjust_start_slider,
|
681 |
+
inputs=[cfg_range_end, cfg_range_start],
|
682 |
+
outputs=[cfg_range_start]
|
683 |
+
)
|
684 |
+
|
685 |
+
with gr.Row(equal_height=True):
|
686 |
+
scheduler_input = gr.Dropdown(
|
687 |
+
label="Scheduler",
|
688 |
+
choices=["euler", "dpmsolver"],
|
689 |
+
value="euler",
|
690 |
+
info="The scheduler to use for the model.",
|
691 |
+
)
|
692 |
+
|
693 |
+
num_inference_steps = gr.Slider(
|
694 |
+
label="Inference Steps", minimum=20, maximum=100, value=50, step=1
|
695 |
+
)
|
696 |
+
with gr.Row(equal_height=True):
|
697 |
+
num_images_per_prompt = gr.Slider(
|
698 |
+
label="Number of images per prompt",
|
699 |
+
minimum=1,
|
700 |
+
maximum=4,
|
701 |
+
value=1,
|
702 |
+
step=1,
|
703 |
+
)
|
704 |
+
|
705 |
+
seed_input = gr.Slider(
|
706 |
+
label="Seed", minimum=-1, maximum=2147483647, value=0, step=1
|
707 |
+
)
|
708 |
+
with gr.Row(equal_height=True):
|
709 |
+
max_input_image_side_length = gr.Slider(
|
710 |
+
label="max_input_image_side_length",
|
711 |
+
minimum=256,
|
712 |
+
maximum=2048,
|
713 |
+
value=2048,
|
714 |
+
step=256,
|
715 |
+
)
|
716 |
+
max_pixels = gr.Slider(
|
717 |
+
label="max_pixels",
|
718 |
+
minimum=256 * 256,
|
719 |
+
maximum=1536 * 1536,
|
720 |
+
value=1024 * 1024,
|
721 |
+
step=256 * 256,
|
722 |
+
)
|
723 |
+
|
724 |
+
with gr.Column():
|
725 |
+
with gr.Column():
|
726 |
+
# output image
|
727 |
+
output_image = gr.Image(label="Output Image")
|
728 |
+
global save_images
|
729 |
+
save_images = gr.Checkbox(label="Save generated images", value=False)
|
730 |
+
|
731 |
+
global accelerator
|
732 |
+
global pipeline
|
733 |
+
|
734 |
+
bf16 = True
|
735 |
+
accelerator = Accelerator(mixed_precision="bf16" if bf16 else "no")
|
736 |
+
weight_dtype = torch.bfloat16 if bf16 else torch.float32
|
737 |
+
|
738 |
+
pipeline = load_pipeline(accelerator, weight_dtype, args)
|
739 |
+
|
740 |
+
# click
|
741 |
+
generate_button.click(
|
742 |
+
run,
|
743 |
+
inputs=[
|
744 |
+
instruction,
|
745 |
+
width_input,
|
746 |
+
height_input,
|
747 |
+
scheduler_input,
|
748 |
+
num_inference_steps,
|
749 |
+
image_input_1,
|
750 |
+
image_input_2,
|
751 |
+
image_input_3,
|
752 |
+
negative_prompt,
|
753 |
+
text_guidance_scale_input,
|
754 |
+
image_guidance_scale_input,
|
755 |
+
cfg_range_start,
|
756 |
+
cfg_range_end,
|
757 |
+
num_images_per_prompt,
|
758 |
+
max_input_image_side_length,
|
759 |
+
max_pixels,
|
760 |
+
seed_input,
|
761 |
+
],
|
762 |
+
outputs=output_image,
|
763 |
+
)
|
764 |
+
|
765 |
+
gr.Examples(
|
766 |
+
examples=get_example(),
|
767 |
+
fn=run_for_examples,
|
768 |
+
inputs=[
|
769 |
+
instruction,
|
770 |
+
width_input,
|
771 |
+
height_input,
|
772 |
+
scheduler_input,
|
773 |
+
num_inference_steps,
|
774 |
+
image_input_1,
|
775 |
+
image_input_2,
|
776 |
+
image_input_3,
|
777 |
+
negative_prompt,
|
778 |
+
text_guidance_scale_input,
|
779 |
+
image_guidance_scale_input,
|
780 |
+
cfg_range_start,
|
781 |
+
cfg_range_end,
|
782 |
+
num_images_per_prompt,
|
783 |
+
max_input_image_side_length,
|
784 |
+
max_pixels,
|
785 |
+
seed_input,
|
786 |
+
],
|
787 |
+
outputs=output_image,
|
788 |
+
)
|
789 |
+
|
790 |
+
gr.Markdown(article)
|
791 |
+
# launch
|
792 |
+
demo.launch(share=args.share, server_port=args.port, allowed_paths=[ROOT_DIR])
|
793 |
+
|
794 |
+
def parse_args():
|
795 |
+
parser = argparse.ArgumentParser(description="Run the OmniGen2")
|
796 |
+
parser.add_argument("--share", action="store_true", help="Share the Gradio app")
|
797 |
+
parser.add_argument(
|
798 |
+
"--port", type=int, default=7860, help="Port to use for the Gradio app"
|
799 |
+
)
|
800 |
+
parser.add_argument(
|
801 |
+
"--model_path",
|
802 |
+
type=str,
|
803 |
+
default="OmniGen2/OmniGen2",
|
804 |
+
help="Path or HuggingFace name of the model to load."
|
805 |
+
)
|
806 |
+
parser.add_argument(
|
807 |
+
"--enable_model_cpu_offload",
|
808 |
+
action="store_true",
|
809 |
+
help="Enable model CPU offload."
|
810 |
+
)
|
811 |
+
parser.add_argument(
|
812 |
+
"--enable_sequential_cpu_offload",
|
813 |
+
action="store_true",
|
814 |
+
help="Enable sequential CPU offload."
|
815 |
+
)
|
816 |
+
args = parser.parse_args()
|
817 |
+
return args
|
818 |
+
|
819 |
+
if __name__ == "__main__":
|
820 |
+
args = parse_args()
|
821 |
+
main(args)
|
omnigen2/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
omnigen2/__init__.py
ADDED
File without changes
|
omnigen2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes). View file
|
|
omnigen2/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (139 Bytes). View file
|
|
omnigen2/models/__init__.py
ADDED
File without changes
|
omnigen2/models/attention_processor.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
OmniGen2 Attention Processor Module
|
3 |
+
|
4 |
+
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
5 |
+
|
6 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
you may not use this file except in compliance with the License.
|
8 |
+
You may obtain a copy of the License at
|
9 |
+
|
10 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
|
12 |
+
Unless required by applicable law or agreed to in writing, software
|
13 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
See the License for the specific language governing permissions and
|
16 |
+
limitations under the License.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import Optional, Tuple, Dict, Any
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from einops import repeat
|
25 |
+
from flash_attn import flash_attn_varlen_func
|
26 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
27 |
+
|
28 |
+
from diffusers.models.attention_processor import Attention
|
29 |
+
from .embeddings import apply_rotary_emb
|
30 |
+
|
31 |
+
|
32 |
+
class OmniGen2AttnProcessorFlash2Varlen:
|
33 |
+
"""
|
34 |
+
Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
|
35 |
+
|
36 |
+
This processor is optimized for PyTorch 2.0 and implements:
|
37 |
+
- Flash attention with variable length sequences
|
38 |
+
- Rotary position embeddings (RoPE)
|
39 |
+
- Query-Key normalization
|
40 |
+
- Proportional attention scaling
|
41 |
+
|
42 |
+
Args:
|
43 |
+
None
|
44 |
+
|
45 |
+
Raises:
|
46 |
+
ImportError: If PyTorch version is less than 2.0
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self) -> None:
|
50 |
+
"""Initialize the attention processor."""
|
51 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
52 |
+
raise ImportError(
|
53 |
+
"OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
|
54 |
+
"Please upgrade PyTorch to version 2.0 or later."
|
55 |
+
)
|
56 |
+
|
57 |
+
def _upad_input(
|
58 |
+
self,
|
59 |
+
query_layer: torch.Tensor,
|
60 |
+
key_layer: torch.Tensor,
|
61 |
+
value_layer: torch.Tensor,
|
62 |
+
attention_mask: torch.Tensor,
|
63 |
+
query_length: int,
|
64 |
+
num_heads: int,
|
65 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
66 |
+
"""
|
67 |
+
Unpad the input tensors for flash attention.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
|
71 |
+
key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
|
72 |
+
value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
|
73 |
+
attention_mask: Attention mask tensor of shape (batch_size, seq_len)
|
74 |
+
query_length: Length of the query sequence
|
75 |
+
num_heads: Number of attention heads
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
Tuple containing:
|
79 |
+
- Unpadded query tensor
|
80 |
+
- Unpadded key tensor
|
81 |
+
- Unpadded value tensor
|
82 |
+
- Query indices
|
83 |
+
- Tuple of cumulative sequence lengths for query and key
|
84 |
+
- Tuple of maximum sequence lengths for query and key
|
85 |
+
"""
|
86 |
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
87 |
+
"""Helper function to get unpadding data from attention mask."""
|
88 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
89 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
90 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
91 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
92 |
+
return indices, cu_seqlens, max_seqlen_in_batch
|
93 |
+
|
94 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
95 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
96 |
+
|
97 |
+
# Unpad key and value layers
|
98 |
+
key_layer = index_first_axis(
|
99 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
100 |
+
indices_k,
|
101 |
+
)
|
102 |
+
value_layer = index_first_axis(
|
103 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
104 |
+
indices_k,
|
105 |
+
)
|
106 |
+
|
107 |
+
# Handle different query length cases
|
108 |
+
if query_length == kv_seq_len:
|
109 |
+
query_layer = index_first_axis(
|
110 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
|
111 |
+
indices_k,
|
112 |
+
)
|
113 |
+
cu_seqlens_q = cu_seqlens_k
|
114 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
115 |
+
indices_q = indices_k
|
116 |
+
elif query_length == 1:
|
117 |
+
max_seqlen_in_batch_q = 1
|
118 |
+
cu_seqlens_q = torch.arange(
|
119 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
120 |
+
)
|
121 |
+
indices_q = cu_seqlens_q[:-1]
|
122 |
+
query_layer = query_layer.squeeze(1)
|
123 |
+
else:
|
124 |
+
attention_mask = attention_mask[:, -query_length:]
|
125 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
126 |
+
|
127 |
+
return (
|
128 |
+
query_layer,
|
129 |
+
key_layer,
|
130 |
+
value_layer,
|
131 |
+
indices_q,
|
132 |
+
(cu_seqlens_q, cu_seqlens_k),
|
133 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
134 |
+
)
|
135 |
+
|
136 |
+
def __call__(
|
137 |
+
self,
|
138 |
+
attn: Attention,
|
139 |
+
hidden_states: torch.Tensor,
|
140 |
+
encoder_hidden_states: torch.Tensor,
|
141 |
+
attention_mask: Optional[torch.Tensor] = None,
|
142 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
143 |
+
base_sequence_length: Optional[int] = None,
|
144 |
+
) -> torch.Tensor:
|
145 |
+
"""
|
146 |
+
Process attention computation with flash attention.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
attn: Attention module
|
150 |
+
hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
|
151 |
+
encoder_hidden_states: Encoder hidden states tensor
|
152 |
+
attention_mask: Optional attention mask tensor
|
153 |
+
image_rotary_emb: Optional rotary embeddings for image tokens
|
154 |
+
base_sequence_length: Optional base sequence length for proportional attention
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
torch.Tensor: Processed hidden states after attention computation
|
158 |
+
"""
|
159 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
160 |
+
|
161 |
+
# Get Query-Key-Value Pair
|
162 |
+
query = attn.to_q(hidden_states)
|
163 |
+
key = attn.to_k(encoder_hidden_states)
|
164 |
+
value = attn.to_v(encoder_hidden_states)
|
165 |
+
|
166 |
+
query_dim = query.shape[-1]
|
167 |
+
inner_dim = key.shape[-1]
|
168 |
+
head_dim = query_dim // attn.heads
|
169 |
+
dtype = query.dtype
|
170 |
+
|
171 |
+
# Get key-value heads
|
172 |
+
kv_heads = inner_dim // head_dim
|
173 |
+
|
174 |
+
# Reshape tensors for attention computation
|
175 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
176 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
177 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
178 |
+
|
179 |
+
# Apply Query-Key normalization
|
180 |
+
if attn.norm_q is not None:
|
181 |
+
query = attn.norm_q(query)
|
182 |
+
if attn.norm_k is not None:
|
183 |
+
key = attn.norm_k(key)
|
184 |
+
|
185 |
+
# Apply Rotary Position Embeddings
|
186 |
+
if image_rotary_emb is not None:
|
187 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
188 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
189 |
+
|
190 |
+
query, key = query.to(dtype), key.to(dtype)
|
191 |
+
|
192 |
+
# Calculate attention scale
|
193 |
+
if base_sequence_length is not None:
|
194 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
195 |
+
else:
|
196 |
+
softmax_scale = attn.scale
|
197 |
+
|
198 |
+
# Unpad input for flash attention
|
199 |
+
(
|
200 |
+
query_states,
|
201 |
+
key_states,
|
202 |
+
value_states,
|
203 |
+
indices_q,
|
204 |
+
cu_seq_lens,
|
205 |
+
max_seq_lens,
|
206 |
+
) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
|
207 |
+
|
208 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
209 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
210 |
+
|
211 |
+
# Handle different number of heads
|
212 |
+
if kv_heads < attn.heads:
|
213 |
+
key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
214 |
+
value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
215 |
+
|
216 |
+
# Apply flash attention
|
217 |
+
attn_output_unpad = flash_attn_varlen_func(
|
218 |
+
query_states,
|
219 |
+
key_states,
|
220 |
+
value_states,
|
221 |
+
cu_seqlens_q=cu_seqlens_q,
|
222 |
+
cu_seqlens_k=cu_seqlens_k,
|
223 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
224 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
225 |
+
dropout_p=0.0,
|
226 |
+
causal=False,
|
227 |
+
softmax_scale=softmax_scale,
|
228 |
+
)
|
229 |
+
|
230 |
+
# Pad output and apply final transformations
|
231 |
+
hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
|
232 |
+
hidden_states = hidden_states.flatten(-2)
|
233 |
+
hidden_states = hidden_states.type_as(query)
|
234 |
+
|
235 |
+
# Apply output projection
|
236 |
+
hidden_states = attn.to_out[0](hidden_states)
|
237 |
+
hidden_states = attn.to_out[1](hidden_states)
|
238 |
+
|
239 |
+
return hidden_states
|
omnigen2/models/embeddings.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import List, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
|
20 |
+
from diffusers.models.activations import get_activation
|
21 |
+
|
22 |
+
|
23 |
+
class TimestepEmbedding(nn.Module):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
in_channels: int,
|
27 |
+
time_embed_dim: int,
|
28 |
+
act_fn: str = "silu",
|
29 |
+
out_dim: int = None,
|
30 |
+
post_act_fn: Optional[str] = None,
|
31 |
+
cond_proj_dim=None,
|
32 |
+
sample_proj_bias=True,
|
33 |
+
):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
37 |
+
|
38 |
+
if cond_proj_dim is not None:
|
39 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
40 |
+
else:
|
41 |
+
self.cond_proj = None
|
42 |
+
|
43 |
+
self.act = get_activation(act_fn)
|
44 |
+
|
45 |
+
if out_dim is not None:
|
46 |
+
time_embed_dim_out = out_dim
|
47 |
+
else:
|
48 |
+
time_embed_dim_out = time_embed_dim
|
49 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
50 |
+
|
51 |
+
if post_act_fn is None:
|
52 |
+
self.post_act = None
|
53 |
+
else:
|
54 |
+
self.post_act = get_activation(post_act_fn)
|
55 |
+
|
56 |
+
self.initialize_weights()
|
57 |
+
|
58 |
+
def initialize_weights(self):
|
59 |
+
nn.init.normal_(self.linear_1.weight, std=0.02)
|
60 |
+
nn.init.zeros_(self.linear_1.bias)
|
61 |
+
nn.init.normal_(self.linear_2.weight, std=0.02)
|
62 |
+
nn.init.zeros_(self.linear_2.bias)
|
63 |
+
|
64 |
+
def forward(self, sample, condition=None):
|
65 |
+
if condition is not None:
|
66 |
+
sample = sample + self.cond_proj(condition)
|
67 |
+
sample = self.linear_1(sample)
|
68 |
+
|
69 |
+
if self.act is not None:
|
70 |
+
sample = self.act(sample)
|
71 |
+
|
72 |
+
sample = self.linear_2(sample)
|
73 |
+
|
74 |
+
if self.post_act is not None:
|
75 |
+
sample = self.post_act(sample)
|
76 |
+
return sample
|
77 |
+
|
78 |
+
|
79 |
+
def apply_rotary_emb(
|
80 |
+
x: torch.Tensor,
|
81 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
82 |
+
use_real: bool = True,
|
83 |
+
use_real_unbind_dim: int = -1,
|
84 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
85 |
+
"""
|
86 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
87 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
88 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
89 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
x (`torch.Tensor`):
|
93 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
94 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
98 |
+
"""
|
99 |
+
if use_real:
|
100 |
+
cos, sin = freqs_cis # [S, D]
|
101 |
+
cos = cos[None, None]
|
102 |
+
sin = sin[None, None]
|
103 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
104 |
+
|
105 |
+
if use_real_unbind_dim == -1:
|
106 |
+
# Used for flux, cogvideox, hunyuan-dit
|
107 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
108 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
109 |
+
elif use_real_unbind_dim == -2:
|
110 |
+
# Used for Stable Audio, OmniGen and CogView4
|
111 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
112 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
113 |
+
else:
|
114 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
115 |
+
|
116 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
117 |
+
|
118 |
+
return out
|
119 |
+
else:
|
120 |
+
# used for lumina
|
121 |
+
# x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
122 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
|
123 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
124 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
125 |
+
|
126 |
+
return x_out.type_as(x)
|
omnigen2/models/transformers/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .transformer_omnigen2 import OmniGen2Transformer2DModel
|
2 |
+
|
3 |
+
__all__ = ["OmniGen2Transformer2DModel"]
|
omnigen2/models/transformers/block_lumina2.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import warnings
|
17 |
+
import itertools
|
18 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from diffusers.models.embeddings import Timesteps
|
25 |
+
from ..embeddings import TimestepEmbedding
|
26 |
+
from .components import swiglu
|
27 |
+
|
28 |
+
try:
|
29 |
+
# from apex.normalization import FusedRMSNorm
|
30 |
+
# from flash_attn.ops.rms_norm import RMSNorm as FusedRMSNorm
|
31 |
+
# from flash_attn.ops.triton.layer_norm import RMSNorm as FusedRMSNorm
|
32 |
+
from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm
|
33 |
+
FUSEDRMSNORM_AVALIBLE = True
|
34 |
+
except ImportError:
|
35 |
+
FUSEDRMSNORM_AVALIBLE = False
|
36 |
+
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
37 |
+
|
38 |
+
try:
|
39 |
+
from flash_attn.ops.activations import swiglu as fused_swiglu
|
40 |
+
FUSEDSWIGLU_AVALIBLE = True
|
41 |
+
except ImportError:
|
42 |
+
|
43 |
+
FUSEDSWIGLU_AVALIBLE = False
|
44 |
+
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
45 |
+
|
46 |
+
class LuminaRMSNormZero(nn.Module):
|
47 |
+
"""
|
48 |
+
Norm layer adaptive RMS normalization zero.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
embedding_dim (`int`): The size of each embedding vector.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
embedding_dim: int,
|
57 |
+
norm_eps: float,
|
58 |
+
norm_elementwise_affine: bool,
|
59 |
+
use_fused_rms_norm: bool = False,
|
60 |
+
):
|
61 |
+
super().__init__()
|
62 |
+
self.silu = nn.SiLU()
|
63 |
+
self.linear = nn.Linear(
|
64 |
+
min(embedding_dim, 1024),
|
65 |
+
4 * embedding_dim,
|
66 |
+
bias=True,
|
67 |
+
)
|
68 |
+
if use_fused_rms_norm:
|
69 |
+
assert FUSEDRMSNORM_AVALIBLE
|
70 |
+
self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps)
|
71 |
+
else:
|
72 |
+
self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps)
|
73 |
+
|
74 |
+
def forward(
|
75 |
+
self,
|
76 |
+
x: torch.Tensor,
|
77 |
+
emb: Optional[torch.Tensor] = None,
|
78 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
79 |
+
emb = self.linear(self.silu(emb))
|
80 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
81 |
+
x = self.norm(x) * (1 + scale_msa[:, None])
|
82 |
+
# x_norm = self.norm(x)
|
83 |
+
# print(f"{x.shape=} {x.dtype=} {x_norm.shape=} {x_norm.dtype=}")
|
84 |
+
# print(f"{scale_msa.shape=} {scale_msa.dtype=}")
|
85 |
+
# print(f"{scale_msa[:, None].shape=} {scale_msa[:, None].dtype=}")
|
86 |
+
# x = x_norm * (1 + scale_msa[:, None])
|
87 |
+
|
88 |
+
return x, gate_msa, scale_mlp, gate_mlp
|
89 |
+
|
90 |
+
|
91 |
+
class LuminaLayerNormContinuous(nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
embedding_dim: int,
|
95 |
+
conditioning_embedding_dim: int,
|
96 |
+
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
97 |
+
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
98 |
+
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
99 |
+
# However, this is how it was implemented in the original code, and it's rather likely you should
|
100 |
+
# set `elementwise_affine` to False.
|
101 |
+
elementwise_affine=True,
|
102 |
+
eps=1e-5,
|
103 |
+
bias=True,
|
104 |
+
norm_type="layer_norm",
|
105 |
+
out_dim: Optional[int] = None,
|
106 |
+
use_fused_rms_norm: bool = False
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
# AdaLN
|
111 |
+
self.silu = nn.SiLU()
|
112 |
+
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
113 |
+
|
114 |
+
if norm_type == "layer_norm":
|
115 |
+
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
116 |
+
elif norm_type == "rms_norm":
|
117 |
+
if use_fused_rms_norm:
|
118 |
+
assert FUSEDRMSNORM_AVALIBLE
|
119 |
+
self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
120 |
+
else:
|
121 |
+
self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
122 |
+
else:
|
123 |
+
raise ValueError(f"unknown norm_type {norm_type}")
|
124 |
+
|
125 |
+
self.linear_2 = None
|
126 |
+
if out_dim is not None:
|
127 |
+
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
x: torch.Tensor,
|
132 |
+
conditioning_embedding: torch.Tensor,
|
133 |
+
) -> torch.Tensor:
|
134 |
+
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
135 |
+
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
136 |
+
scale = emb
|
137 |
+
x = self.norm(x) * (1 + scale)[:, None, :]
|
138 |
+
|
139 |
+
if self.linear_2 is not None:
|
140 |
+
x = self.linear_2(x)
|
141 |
+
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
class LuminaFeedForward(nn.Module):
|
146 |
+
r"""
|
147 |
+
A feed-forward layer.
|
148 |
+
|
149 |
+
Parameters:
|
150 |
+
hidden_size (`int`):
|
151 |
+
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
152 |
+
hidden representations.
|
153 |
+
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
154 |
+
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
155 |
+
of this value.
|
156 |
+
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
157 |
+
dimension. Defaults to None.
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
dim: int,
|
163 |
+
inner_dim: int,
|
164 |
+
multiple_of: Optional[int] = 256,
|
165 |
+
ffn_dim_multiplier: Optional[float] = None,
|
166 |
+
use_fused_swiglu: bool = False
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
self.use_fused_swiglu = use_fused_swiglu
|
170 |
+
|
171 |
+
if use_fused_swiglu:
|
172 |
+
assert FUSEDSWIGLU_AVALIBLE
|
173 |
+
self.swiglu = fused_swiglu
|
174 |
+
else:
|
175 |
+
self.swiglu = swiglu
|
176 |
+
|
177 |
+
# custom hidden_size factor multiplier
|
178 |
+
if ffn_dim_multiplier is not None:
|
179 |
+
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
180 |
+
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
181 |
+
|
182 |
+
self.linear_1 = nn.Linear(
|
183 |
+
dim,
|
184 |
+
inner_dim,
|
185 |
+
bias=False,
|
186 |
+
)
|
187 |
+
self.linear_2 = nn.Linear(
|
188 |
+
inner_dim,
|
189 |
+
dim,
|
190 |
+
bias=False,
|
191 |
+
)
|
192 |
+
self.linear_3 = nn.Linear(
|
193 |
+
dim,
|
194 |
+
inner_dim,
|
195 |
+
bias=False,
|
196 |
+
)
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
h1, h2 = self.linear_1(x), self.linear_3(x)
|
200 |
+
return self.linear_2(self.swiglu(h1, h2))
|
201 |
+
|
202 |
+
|
203 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
hidden_size: int = 4096,
|
207 |
+
text_feat_dim: int = 2048,
|
208 |
+
frequency_embedding_size: int = 256,
|
209 |
+
norm_eps: float = 1e-5,
|
210 |
+
timestep_scale: float = 1.0,
|
211 |
+
use_fused_rms_norm: bool = False
|
212 |
+
) -> None:
|
213 |
+
super().__init__()
|
214 |
+
|
215 |
+
self.time_proj = Timesteps(
|
216 |
+
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
|
217 |
+
)
|
218 |
+
|
219 |
+
self.timestep_embedder = TimestepEmbedding(
|
220 |
+
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
|
221 |
+
)
|
222 |
+
|
223 |
+
if use_fused_rms_norm:
|
224 |
+
assert FUSEDRMSNORM_AVALIBLE
|
225 |
+
RMSNorm = FusedRMSNorm
|
226 |
+
else:
|
227 |
+
RMSNorm = nn.RMSNorm
|
228 |
+
|
229 |
+
self.caption_embedder = nn.Sequential(
|
230 |
+
RMSNorm(text_feat_dim, eps=norm_eps),
|
231 |
+
nn.Linear(text_feat_dim, hidden_size, bias=True),
|
232 |
+
)
|
233 |
+
|
234 |
+
self._initialize_weights()
|
235 |
+
|
236 |
+
def _initialize_weights(self):
|
237 |
+
nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
|
238 |
+
nn.init.zeros_(self.caption_embedder[1].bias)
|
239 |
+
|
240 |
+
def forward(
|
241 |
+
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
|
242 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
243 |
+
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
244 |
+
time_embed = self.timestep_embedder(timestep_proj)
|
245 |
+
caption_embed = self.caption_embedder(text_hidden_states)
|
246 |
+
return time_embed, caption_embed
|
omnigen2/models/transformers/components.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
|
3 |
+
def swiglu(x, y):
|
4 |
+
return F.silu(x.float(), inplace=False).to(x.dtype) * y
|
omnigen2/models/transformers/repo.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from einops import repeat
|
7 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
8 |
+
|
9 |
+
class OmniGen2RotaryPosEmbed(nn.Module):
|
10 |
+
def __init__(self, theta: int,
|
11 |
+
axes_dim: Tuple[int, int, int],
|
12 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
13 |
+
patch_size: int = 2):
|
14 |
+
super().__init__()
|
15 |
+
self.theta = theta
|
16 |
+
self.axes_dim = axes_dim
|
17 |
+
self.axes_lens = axes_lens
|
18 |
+
self.patch_size = patch_size
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def get_freqs_cis(axes_dim: Tuple[int, int, int],
|
22 |
+
axes_lens: Tuple[int, int, int],
|
23 |
+
theta: int) -> List[torch.Tensor]:
|
24 |
+
freqs_cis = []
|
25 |
+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
26 |
+
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
|
27 |
+
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
|
28 |
+
freqs_cis.append(emb)
|
29 |
+
return freqs_cis
|
30 |
+
|
31 |
+
def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
|
32 |
+
device = ids.device
|
33 |
+
if ids.device.type == "mps":
|
34 |
+
ids = ids.to("cpu")
|
35 |
+
|
36 |
+
result = []
|
37 |
+
for i in range(len(self.axes_dim)):
|
38 |
+
freqs = freqs_cis[i].to(ids.device)
|
39 |
+
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
40 |
+
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
|
41 |
+
return torch.cat(result, dim=-1).to(device)
|
42 |
+
|
43 |
+
def forward(
|
44 |
+
self,
|
45 |
+
freqs_cis,
|
46 |
+
attention_mask,
|
47 |
+
l_effective_ref_img_len,
|
48 |
+
l_effective_img_len,
|
49 |
+
ref_img_sizes,
|
50 |
+
img_sizes,
|
51 |
+
device
|
52 |
+
):
|
53 |
+
batch_size = len(attention_mask)
|
54 |
+
p = self.patch_size
|
55 |
+
|
56 |
+
encoder_seq_len = attention_mask.shape[1]
|
57 |
+
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
|
58 |
+
|
59 |
+
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
|
60 |
+
|
61 |
+
max_seq_len = max(seq_lengths)
|
62 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
63 |
+
max_img_len = max(l_effective_img_len)
|
64 |
+
|
65 |
+
# Create position IDs
|
66 |
+
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
67 |
+
|
68 |
+
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
69 |
+
# add text position ids
|
70 |
+
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
|
71 |
+
|
72 |
+
pe_shift = cap_seq_len
|
73 |
+
pe_shift_len = cap_seq_len
|
74 |
+
|
75 |
+
if ref_img_sizes[i] is not None:
|
76 |
+
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
77 |
+
H, W = ref_img_size
|
78 |
+
ref_H_tokens, ref_W_tokens = H // p, W // p
|
79 |
+
assert ref_H_tokens * ref_W_tokens == ref_img_len
|
80 |
+
# add image position ids
|
81 |
+
|
82 |
+
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
|
83 |
+
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
|
84 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
85 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
86 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
87 |
+
|
88 |
+
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
89 |
+
pe_shift_len += ref_img_len
|
90 |
+
|
91 |
+
H, W = img_sizes[i]
|
92 |
+
H_tokens, W_tokens = H // p, W // p
|
93 |
+
assert H_tokens * W_tokens == l_effective_img_len[i]
|
94 |
+
|
95 |
+
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
|
96 |
+
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
|
97 |
+
|
98 |
+
assert pe_shift_len + l_effective_img_len[i] == seq_len
|
99 |
+
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
100 |
+
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
101 |
+
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
102 |
+
|
103 |
+
# Get combined rotary embeddings
|
104 |
+
freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
|
105 |
+
|
106 |
+
# create separate rotary embeddings for captions and images
|
107 |
+
cap_freqs_cis = torch.zeros(
|
108 |
+
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
109 |
+
)
|
110 |
+
ref_img_freqs_cis = torch.zeros(
|
111 |
+
batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
112 |
+
)
|
113 |
+
img_freqs_cis = torch.zeros(
|
114 |
+
batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
115 |
+
)
|
116 |
+
|
117 |
+
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
|
118 |
+
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
119 |
+
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
|
120 |
+
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
|
121 |
+
|
122 |
+
return (
|
123 |
+
cap_freqs_cis,
|
124 |
+
ref_img_freqs_cis,
|
125 |
+
img_freqs_cis,
|
126 |
+
freqs_cis,
|
127 |
+
l_effective_cap_len,
|
128 |
+
seq_lengths,
|
129 |
+
)
|
omnigen2/models/transformers/transformer_omnigen2.py
ADDED
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
import itertools
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.loaders import PeftAdapterMixin
|
12 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
13 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
14 |
+
from diffusers.models.attention_processor import Attention
|
15 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
16 |
+
from diffusers.models.modeling_utils import ModelMixin
|
17 |
+
|
18 |
+
from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen
|
19 |
+
from .repo import OmniGen2RotaryPosEmbed
|
20 |
+
from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
|
21 |
+
|
22 |
+
try:
|
23 |
+
from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm
|
24 |
+
FUSEDRMSNORM_AVALIBLE = True
|
25 |
+
except ImportError:
|
26 |
+
FUSEDRMSNORM_AVALIBLE = False
|
27 |
+
warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation")
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class OmniGen2TransformerBlock(nn.Module):
|
33 |
+
"""
|
34 |
+
Transformer block for OmniGen2 model.
|
35 |
+
|
36 |
+
This block implements a transformer layer with:
|
37 |
+
- Multi-head attention with flash attention
|
38 |
+
- Feed-forward network with SwiGLU activation
|
39 |
+
- RMS normalization
|
40 |
+
- Optional modulation for conditional generation
|
41 |
+
|
42 |
+
Args:
|
43 |
+
dim: Dimension of the input and output tensors
|
44 |
+
num_attention_heads: Number of attention heads
|
45 |
+
num_kv_heads: Number of key-value heads
|
46 |
+
multiple_of: Multiple of which the hidden dimension should be
|
47 |
+
ffn_dim_multiplier: Multiplier for the feed-forward network dimension
|
48 |
+
norm_eps: Epsilon value for normalization layers
|
49 |
+
modulation: Whether to use modulation for conditional generation
|
50 |
+
use_fused_rms_norm: Whether to use fused RMS normalization
|
51 |
+
use_fused_swiglu: Whether to use fused SwiGLU activation
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
dim: int,
|
57 |
+
num_attention_heads: int,
|
58 |
+
num_kv_heads: int,
|
59 |
+
multiple_of: int,
|
60 |
+
ffn_dim_multiplier: float,
|
61 |
+
norm_eps: float,
|
62 |
+
modulation: bool = True,
|
63 |
+
use_fused_rms_norm: bool = True,
|
64 |
+
use_fused_swiglu: bool = True,
|
65 |
+
) -> None:
|
66 |
+
"""Initialize the transformer block."""
|
67 |
+
super().__init__()
|
68 |
+
self.head_dim = dim // num_attention_heads
|
69 |
+
self.modulation = modulation
|
70 |
+
|
71 |
+
# Initialize attention layer
|
72 |
+
self.attn = Attention(
|
73 |
+
query_dim=dim,
|
74 |
+
cross_attention_dim=None,
|
75 |
+
dim_head=dim // num_attention_heads,
|
76 |
+
qk_norm="rms_norm",
|
77 |
+
heads=num_attention_heads,
|
78 |
+
kv_heads=num_kv_heads,
|
79 |
+
eps=1e-5,
|
80 |
+
bias=False,
|
81 |
+
out_bias=False,
|
82 |
+
processor=OmniGen2AttnProcessorFlash2Varlen(),
|
83 |
+
)
|
84 |
+
|
85 |
+
# Initialize feed-forward network
|
86 |
+
self.feed_forward = LuminaFeedForward(
|
87 |
+
dim=dim,
|
88 |
+
inner_dim=4 * dim,
|
89 |
+
multiple_of=multiple_of,
|
90 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
91 |
+
use_fused_swiglu=use_fused_swiglu,
|
92 |
+
)
|
93 |
+
|
94 |
+
# Initialize normalization layers
|
95 |
+
if modulation:
|
96 |
+
self.norm1 = LuminaRMSNormZero(
|
97 |
+
embedding_dim=dim,
|
98 |
+
norm_eps=norm_eps,
|
99 |
+
norm_elementwise_affine=True,
|
100 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
if use_fused_rms_norm:
|
104 |
+
if not FUSEDRMSNORM_AVALIBLE:
|
105 |
+
raise ImportError("FusedRMSNorm is not available")
|
106 |
+
self.norm1 = FusedRMSNorm(dim, eps=norm_eps)
|
107 |
+
else:
|
108 |
+
self.norm1 = nn.RMSNorm(dim, eps=norm_eps)
|
109 |
+
|
110 |
+
if use_fused_rms_norm:
|
111 |
+
if not FUSEDRMSNORM_AVALIBLE:
|
112 |
+
raise ImportError("FusedRMSNorm is not available")
|
113 |
+
self.ffn_norm1 = FusedRMSNorm(dim, eps=norm_eps)
|
114 |
+
self.norm2 = FusedRMSNorm(dim, eps=norm_eps)
|
115 |
+
self.ffn_norm2 = FusedRMSNorm(dim, eps=norm_eps)
|
116 |
+
else:
|
117 |
+
self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps)
|
118 |
+
self.norm2 = nn.RMSNorm(dim, eps=norm_eps)
|
119 |
+
self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps)
|
120 |
+
|
121 |
+
self.initialize_weights()
|
122 |
+
|
123 |
+
def initialize_weights(self) -> None:
|
124 |
+
"""
|
125 |
+
Initialize the weights of the transformer block.
|
126 |
+
|
127 |
+
Uses Xavier uniform initialization for linear layers and zero initialization for biases.
|
128 |
+
"""
|
129 |
+
nn.init.xavier_uniform_(self.attn.to_q.weight)
|
130 |
+
nn.init.xavier_uniform_(self.attn.to_k.weight)
|
131 |
+
nn.init.xavier_uniform_(self.attn.to_v.weight)
|
132 |
+
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
|
133 |
+
|
134 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
|
135 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
|
136 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
|
137 |
+
|
138 |
+
if self.modulation:
|
139 |
+
nn.init.zeros_(self.norm1.linear.weight)
|
140 |
+
nn.init.zeros_(self.norm1.linear.bias)
|
141 |
+
|
142 |
+
def forward(
|
143 |
+
self,
|
144 |
+
hidden_states: torch.Tensor,
|
145 |
+
attention_mask: torch.Tensor,
|
146 |
+
image_rotary_emb: torch.Tensor,
|
147 |
+
temb: Optional[torch.Tensor] = None,
|
148 |
+
) -> torch.Tensor:
|
149 |
+
"""
|
150 |
+
Forward pass of the transformer block.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
hidden_states: Input hidden states tensor
|
154 |
+
attention_mask: Attention mask tensor
|
155 |
+
image_rotary_emb: Rotary embeddings for image tokens
|
156 |
+
temb: Optional timestep embedding tensor
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
torch.Tensor: Output hidden states after transformer block processing
|
160 |
+
"""
|
161 |
+
if self.modulation:
|
162 |
+
if temb is None:
|
163 |
+
raise ValueError("temb must be provided when modulation is enabled")
|
164 |
+
|
165 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
166 |
+
attn_output = self.attn(
|
167 |
+
hidden_states=norm_hidden_states,
|
168 |
+
encoder_hidden_states=norm_hidden_states,
|
169 |
+
attention_mask=attention_mask,
|
170 |
+
image_rotary_emb=image_rotary_emb,
|
171 |
+
)
|
172 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
173 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
174 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
175 |
+
else:
|
176 |
+
norm_hidden_states = self.norm1(hidden_states)
|
177 |
+
attn_output = self.attn(
|
178 |
+
hidden_states=norm_hidden_states,
|
179 |
+
encoder_hidden_states=norm_hidden_states,
|
180 |
+
attention_mask=attention_mask,
|
181 |
+
image_rotary_emb=image_rotary_emb,
|
182 |
+
)
|
183 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
184 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
185 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
186 |
+
|
187 |
+
return hidden_states
|
188 |
+
|
189 |
+
|
190 |
+
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
191 |
+
"""
|
192 |
+
OmniGen2 Transformer 2D Model.
|
193 |
+
|
194 |
+
A transformer-based diffusion model for image generation with:
|
195 |
+
- Patch-based image processing
|
196 |
+
- Rotary position embeddings
|
197 |
+
- Multi-head attention
|
198 |
+
- Conditional generation support
|
199 |
+
|
200 |
+
Args:
|
201 |
+
patch_size: Size of image patches
|
202 |
+
in_channels: Number of input channels
|
203 |
+
out_channels: Number of output channels (defaults to in_channels)
|
204 |
+
hidden_size: Size of hidden layers
|
205 |
+
num_layers: Number of transformer layers
|
206 |
+
num_refiner_layers: Number of refiner layers
|
207 |
+
num_attention_heads: Number of attention heads
|
208 |
+
num_kv_heads: Number of key-value heads
|
209 |
+
multiple_of: Multiple of which the hidden dimension should be
|
210 |
+
ffn_dim_multiplier: Multiplier for feed-forward network dimension
|
211 |
+
norm_eps: Epsilon value for normalization layers
|
212 |
+
axes_dim_rope: Dimensions for rotary position embeddings
|
213 |
+
axes_lens: Lengths for rotary position embeddings
|
214 |
+
text_feat_dim: Dimension of text features
|
215 |
+
timestep_scale: Scale factor for timestep embeddings
|
216 |
+
use_fused_rms_norm: Whether to use fused RMS normalization
|
217 |
+
use_fused_swiglu: Whether to use fused SwiGLU activation
|
218 |
+
"""
|
219 |
+
|
220 |
+
_supports_gradient_checkpointing = True
|
221 |
+
_no_split_modules = ["Omnigen2TransformerBlock"]
|
222 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
223 |
+
|
224 |
+
@register_to_config
|
225 |
+
def __init__(
|
226 |
+
self,
|
227 |
+
patch_size: int = 2,
|
228 |
+
in_channels: int = 16,
|
229 |
+
out_channels: Optional[int] = None,
|
230 |
+
hidden_size: int = 2304,
|
231 |
+
num_layers: int = 26,
|
232 |
+
num_refiner_layers: int = 2,
|
233 |
+
num_attention_heads: int = 24,
|
234 |
+
num_kv_heads: int = 8,
|
235 |
+
multiple_of: int = 256,
|
236 |
+
ffn_dim_multiplier: Optional[float] = None,
|
237 |
+
norm_eps: float = 1e-5,
|
238 |
+
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
239 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
240 |
+
text_feat_dim: int = 1024,
|
241 |
+
timestep_scale: float = 1.0,
|
242 |
+
use_fused_rms_norm: bool = True,
|
243 |
+
use_fused_swiglu: bool = True,
|
244 |
+
) -> None:
|
245 |
+
"""Initialize the OmniGen2 transformer model."""
|
246 |
+
super().__init__()
|
247 |
+
|
248 |
+
# Validate configuration
|
249 |
+
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
|
250 |
+
raise ValueError(
|
251 |
+
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
|
252 |
+
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
|
253 |
+
)
|
254 |
+
|
255 |
+
self.out_channels = out_channels or in_channels
|
256 |
+
|
257 |
+
# Initialize embeddings
|
258 |
+
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
259 |
+
theta=10000,
|
260 |
+
axes_dim=axes_dim_rope,
|
261 |
+
axes_lens=axes_lens,
|
262 |
+
patch_size=patch_size,
|
263 |
+
)
|
264 |
+
|
265 |
+
self.x_embedder = nn.Linear(
|
266 |
+
in_features=patch_size * patch_size * in_channels,
|
267 |
+
out_features=hidden_size,
|
268 |
+
)
|
269 |
+
|
270 |
+
self.ref_image_patch_embedder = nn.Linear(
|
271 |
+
in_features=patch_size * patch_size * in_channels,
|
272 |
+
out_features=hidden_size,
|
273 |
+
)
|
274 |
+
|
275 |
+
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
276 |
+
hidden_size=hidden_size,
|
277 |
+
text_feat_dim=text_feat_dim,
|
278 |
+
norm_eps=norm_eps,
|
279 |
+
timestep_scale=timestep_scale,
|
280 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
281 |
+
)
|
282 |
+
|
283 |
+
# Initialize transformer blocks
|
284 |
+
self.noise_refiner = nn.ModuleList([
|
285 |
+
OmniGen2TransformerBlock(
|
286 |
+
hidden_size,
|
287 |
+
num_attention_heads,
|
288 |
+
num_kv_heads,
|
289 |
+
multiple_of,
|
290 |
+
ffn_dim_multiplier,
|
291 |
+
norm_eps,
|
292 |
+
modulation=True,
|
293 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
294 |
+
use_fused_swiglu=use_fused_swiglu,
|
295 |
+
)
|
296 |
+
for _ in range(num_refiner_layers)
|
297 |
+
])
|
298 |
+
|
299 |
+
self.ref_image_refiner = nn.ModuleList([
|
300 |
+
OmniGen2TransformerBlock(
|
301 |
+
hidden_size,
|
302 |
+
num_attention_heads,
|
303 |
+
num_kv_heads,
|
304 |
+
multiple_of,
|
305 |
+
ffn_dim_multiplier,
|
306 |
+
norm_eps,
|
307 |
+
modulation=True,
|
308 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
309 |
+
use_fused_swiglu=use_fused_swiglu,
|
310 |
+
)
|
311 |
+
for _ in range(num_refiner_layers)
|
312 |
+
])
|
313 |
+
|
314 |
+
self.context_refiner = nn.ModuleList(
|
315 |
+
[
|
316 |
+
OmniGen2TransformerBlock(
|
317 |
+
hidden_size,
|
318 |
+
num_attention_heads,
|
319 |
+
num_kv_heads,
|
320 |
+
multiple_of,
|
321 |
+
ffn_dim_multiplier,
|
322 |
+
norm_eps,
|
323 |
+
modulation=False,
|
324 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
325 |
+
use_fused_swiglu=use_fused_swiglu
|
326 |
+
)
|
327 |
+
for _ in range(num_refiner_layers)
|
328 |
+
]
|
329 |
+
)
|
330 |
+
|
331 |
+
# 3. Transformer blocks
|
332 |
+
self.layers = nn.ModuleList(
|
333 |
+
[
|
334 |
+
OmniGen2TransformerBlock(
|
335 |
+
hidden_size,
|
336 |
+
num_attention_heads,
|
337 |
+
num_kv_heads,
|
338 |
+
multiple_of,
|
339 |
+
ffn_dim_multiplier,
|
340 |
+
norm_eps,
|
341 |
+
modulation=True,
|
342 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
343 |
+
use_fused_swiglu=use_fused_swiglu
|
344 |
+
)
|
345 |
+
for _ in range(num_layers)
|
346 |
+
]
|
347 |
+
)
|
348 |
+
|
349 |
+
# 4. Output norm & projection
|
350 |
+
self.norm_out = LuminaLayerNormContinuous(
|
351 |
+
embedding_dim=hidden_size,
|
352 |
+
conditioning_embedding_dim=min(hidden_size, 1024),
|
353 |
+
elementwise_affine=False,
|
354 |
+
eps=1e-6,
|
355 |
+
bias=True,
|
356 |
+
out_dim=patch_size * patch_size * self.out_channels,
|
357 |
+
use_fused_rms_norm=use_fused_rms_norm,
|
358 |
+
)
|
359 |
+
|
360 |
+
# Add learnable embeddings to distinguish different images
|
361 |
+
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
|
362 |
+
|
363 |
+
self.gradient_checkpointing = False
|
364 |
+
|
365 |
+
self.initialize_weights()
|
366 |
+
|
367 |
+
def initialize_weights(self) -> None:
|
368 |
+
"""
|
369 |
+
Initialize the weights of the model.
|
370 |
+
|
371 |
+
Uses Xavier uniform initialization for linear layers.
|
372 |
+
"""
|
373 |
+
nn.init.xavier_uniform_(self.x_embedder.weight)
|
374 |
+
nn.init.constant_(self.x_embedder.bias, 0.0)
|
375 |
+
|
376 |
+
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
|
377 |
+
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
|
378 |
+
|
379 |
+
nn.init.zeros_(self.norm_out.linear_1.weight)
|
380 |
+
nn.init.zeros_(self.norm_out.linear_1.bias)
|
381 |
+
nn.init.zeros_(self.norm_out.linear_2.weight)
|
382 |
+
nn.init.zeros_(self.norm_out.linear_2.bias)
|
383 |
+
|
384 |
+
nn.init.normal_(self.image_index_embedding, std=0.02)
|
385 |
+
|
386 |
+
def img_patch_embed_and_refine(
|
387 |
+
self,
|
388 |
+
hidden_states,
|
389 |
+
ref_image_hidden_states,
|
390 |
+
padded_img_mask,
|
391 |
+
padded_ref_img_mask,
|
392 |
+
noise_rotary_emb,
|
393 |
+
ref_img_rotary_emb,
|
394 |
+
l_effective_ref_img_len,
|
395 |
+
l_effective_img_len,
|
396 |
+
temb
|
397 |
+
):
|
398 |
+
batch_size = len(hidden_states)
|
399 |
+
max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
|
400 |
+
|
401 |
+
hidden_states = self.x_embedder(hidden_states)
|
402 |
+
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
403 |
+
|
404 |
+
for i in range(batch_size):
|
405 |
+
shift = 0
|
406 |
+
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
407 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
|
408 |
+
shift += ref_img_len
|
409 |
+
|
410 |
+
for layer in self.noise_refiner:
|
411 |
+
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
412 |
+
|
413 |
+
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
|
414 |
+
num_ref_images = len(flat_l_effective_ref_img_len)
|
415 |
+
max_ref_img_len = max(flat_l_effective_ref_img_len)
|
416 |
+
|
417 |
+
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
|
418 |
+
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
|
419 |
+
batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
|
420 |
+
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
|
421 |
+
|
422 |
+
# sequence of ref imgs to batch
|
423 |
+
idx = 0
|
424 |
+
for i in range(batch_size):
|
425 |
+
shift = 0
|
426 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
427 |
+
batch_ref_img_mask[idx, :ref_img_len] = True
|
428 |
+
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
|
429 |
+
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
|
430 |
+
batch_temb[idx] = temb[i]
|
431 |
+
shift += ref_img_len
|
432 |
+
idx += 1
|
433 |
+
|
434 |
+
# refine ref imgs separately
|
435 |
+
for layer in self.ref_image_refiner:
|
436 |
+
batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
|
437 |
+
|
438 |
+
# batch of ref imgs to sequence
|
439 |
+
idx = 0
|
440 |
+
for i in range(batch_size):
|
441 |
+
shift = 0
|
442 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
443 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
|
444 |
+
shift += ref_img_len
|
445 |
+
idx += 1
|
446 |
+
|
447 |
+
combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
|
448 |
+
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
|
449 |
+
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
|
450 |
+
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
|
451 |
+
|
452 |
+
return combined_img_hidden_states
|
453 |
+
|
454 |
+
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
455 |
+
batch_size = len(hidden_states)
|
456 |
+
p = self.config.patch_size
|
457 |
+
device = hidden_states[0].device
|
458 |
+
|
459 |
+
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
460 |
+
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
461 |
+
|
462 |
+
if ref_image_hidden_states is not None:
|
463 |
+
ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
|
464 |
+
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
465 |
+
else:
|
466 |
+
ref_img_sizes = [None for _ in range(batch_size)]
|
467 |
+
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
468 |
+
|
469 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
470 |
+
max_img_len = max(l_effective_img_len)
|
471 |
+
|
472 |
+
# ref image patch embeddings
|
473 |
+
flat_ref_img_hidden_states = []
|
474 |
+
for i in range(batch_size):
|
475 |
+
if ref_img_sizes[i] is not None:
|
476 |
+
imgs = []
|
477 |
+
for ref_img in ref_image_hidden_states[i]:
|
478 |
+
C, H, W = ref_img.size()
|
479 |
+
ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
480 |
+
imgs.append(ref_img)
|
481 |
+
|
482 |
+
img = torch.cat(imgs, dim=0)
|
483 |
+
flat_ref_img_hidden_states.append(img)
|
484 |
+
else:
|
485 |
+
flat_ref_img_hidden_states.append(None)
|
486 |
+
|
487 |
+
# image patch embeddings
|
488 |
+
flat_hidden_states = []
|
489 |
+
for i in range(batch_size):
|
490 |
+
img = hidden_states[i]
|
491 |
+
C, H, W = img.size()
|
492 |
+
|
493 |
+
img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
494 |
+
flat_hidden_states.append(img)
|
495 |
+
|
496 |
+
padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
497 |
+
padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
|
498 |
+
for i in range(batch_size):
|
499 |
+
if ref_img_sizes[i] is not None:
|
500 |
+
padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
|
501 |
+
padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
|
502 |
+
|
503 |
+
padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
504 |
+
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
|
505 |
+
for i in range(batch_size):
|
506 |
+
padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
|
507 |
+
padded_img_mask[i, :l_effective_img_len[i]] = True
|
508 |
+
|
509 |
+
return (
|
510 |
+
padded_hidden_states,
|
511 |
+
padded_ref_img_hidden_states,
|
512 |
+
padded_img_mask,
|
513 |
+
padded_ref_img_mask,
|
514 |
+
l_effective_ref_img_len,
|
515 |
+
l_effective_img_len,
|
516 |
+
ref_img_sizes,
|
517 |
+
img_sizes,
|
518 |
+
)
|
519 |
+
|
520 |
+
def forward(
|
521 |
+
self,
|
522 |
+
hidden_states: Union[torch.Tensor, List[torch.Tensor]],
|
523 |
+
timestep: torch.Tensor,
|
524 |
+
text_hidden_states: torch.Tensor,
|
525 |
+
freqs_cis: torch.Tensor,
|
526 |
+
text_attention_mask: torch.Tensor,
|
527 |
+
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
|
528 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
529 |
+
return_dict: bool = False,
|
530 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
531 |
+
if attention_kwargs is not None:
|
532 |
+
attention_kwargs = attention_kwargs.copy()
|
533 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
534 |
+
else:
|
535 |
+
lora_scale = 1.0
|
536 |
+
|
537 |
+
if USE_PEFT_BACKEND:
|
538 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
539 |
+
scale_lora_layers(self, lora_scale)
|
540 |
+
else:
|
541 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
542 |
+
logger.warning(
|
543 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
544 |
+
)
|
545 |
+
|
546 |
+
# 1. Condition, positional & patch embedding
|
547 |
+
batch_size = len(hidden_states)
|
548 |
+
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
|
549 |
+
|
550 |
+
if is_hidden_states_tensor:
|
551 |
+
assert hidden_states.ndim == 4
|
552 |
+
hidden_states = [_hidden_states for _hidden_states in hidden_states]
|
553 |
+
|
554 |
+
device = hidden_states[0].device
|
555 |
+
|
556 |
+
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
557 |
+
|
558 |
+
(
|
559 |
+
hidden_states,
|
560 |
+
ref_image_hidden_states,
|
561 |
+
img_mask,
|
562 |
+
ref_img_mask,
|
563 |
+
l_effective_ref_img_len,
|
564 |
+
l_effective_img_len,
|
565 |
+
ref_img_sizes,
|
566 |
+
img_sizes,
|
567 |
+
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
568 |
+
|
569 |
+
(
|
570 |
+
context_rotary_emb,
|
571 |
+
ref_img_rotary_emb,
|
572 |
+
noise_rotary_emb,
|
573 |
+
rotary_emb,
|
574 |
+
encoder_seq_lengths,
|
575 |
+
seq_lengths,
|
576 |
+
) = self.rope_embedder(
|
577 |
+
freqs_cis,
|
578 |
+
text_attention_mask,
|
579 |
+
l_effective_ref_img_len,
|
580 |
+
l_effective_img_len,
|
581 |
+
ref_img_sizes,
|
582 |
+
img_sizes,
|
583 |
+
device,
|
584 |
+
)
|
585 |
+
|
586 |
+
# 2. Context refinement
|
587 |
+
for layer in self.context_refiner:
|
588 |
+
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
589 |
+
|
590 |
+
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
591 |
+
hidden_states,
|
592 |
+
ref_image_hidden_states,
|
593 |
+
img_mask,
|
594 |
+
ref_img_mask,
|
595 |
+
noise_rotary_emb,
|
596 |
+
ref_img_rotary_emb,
|
597 |
+
l_effective_ref_img_len,
|
598 |
+
l_effective_img_len,
|
599 |
+
temb,
|
600 |
+
)
|
601 |
+
|
602 |
+
# 3. Joint Transformer blocks
|
603 |
+
max_seq_len = max(seq_lengths)
|
604 |
+
|
605 |
+
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
606 |
+
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
607 |
+
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
608 |
+
attention_mask[i, :seq_len] = True
|
609 |
+
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
|
610 |
+
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
|
611 |
+
|
612 |
+
hidden_states = joint_hidden_states
|
613 |
+
|
614 |
+
for layer_idx, layer in enumerate(self.layers):
|
615 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
616 |
+
hidden_states = self._gradient_checkpointing_func(
|
617 |
+
layer, hidden_states, attention_mask, rotary_emb, temb
|
618 |
+
)
|
619 |
+
else:
|
620 |
+
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
621 |
+
|
622 |
+
# 4. Output norm & projection
|
623 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
624 |
+
|
625 |
+
p = self.config.patch_size
|
626 |
+
output = []
|
627 |
+
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
|
628 |
+
height, width = img_size
|
629 |
+
output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
|
630 |
+
if is_hidden_states_tensor:
|
631 |
+
output = torch.stack(output, dim=0)
|
632 |
+
|
633 |
+
if USE_PEFT_BACKEND:
|
634 |
+
# remove `lora_scale` from each PEFT layer
|
635 |
+
unscale_lora_layers(self, lora_scale)
|
636 |
+
|
637 |
+
if not return_dict:
|
638 |
+
return output
|
639 |
+
return Transformer2DModelOutput(sample=output)
|
omnigen2/ops/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
omnigen2/ops/triton/__init__.py
ADDED
File without changes
|
omnigen2/ops/triton/layer_norm.py
ADDED
@@ -0,0 +1,1257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Implement dropout + residual + layer_norm / rms_norm.
|
3 |
+
|
4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
8 |
+
|
9 |
+
import math
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
import triton
|
15 |
+
import triton.language as tl
|
16 |
+
|
17 |
+
|
18 |
+
from typing import Callable
|
19 |
+
|
20 |
+
|
21 |
+
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
|
22 |
+
def decorator(*args, **kwargs):
|
23 |
+
if cuda_amp_deprecated:
|
24 |
+
kwargs["device_type"] = "cuda"
|
25 |
+
return dec(*args, **kwargs)
|
26 |
+
return decorator
|
27 |
+
|
28 |
+
|
29 |
+
if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
|
30 |
+
deprecated = True
|
31 |
+
from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
|
32 |
+
else:
|
33 |
+
deprecated = False
|
34 |
+
from torch.cuda.amp import custom_fwd, custom_bwd
|
35 |
+
|
36 |
+
custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
|
37 |
+
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
|
38 |
+
|
39 |
+
|
40 |
+
def triton_autotune_configs():
|
41 |
+
# Return configs with a valid warp count for the current device
|
42 |
+
configs=[]
|
43 |
+
# Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
|
44 |
+
max_threads_per_block=1024
|
45 |
+
# Default to warp size 32 if not defined by device
|
46 |
+
warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
|
47 |
+
# Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
|
48 |
+
warp_count=1
|
49 |
+
while warp_count*warp_size <= max_threads_per_block:
|
50 |
+
configs.append(triton.Config({}, num_warps=warp_count))
|
51 |
+
warp_count*=2
|
52 |
+
return configs
|
53 |
+
|
54 |
+
def layer_norm_ref(
|
55 |
+
x,
|
56 |
+
weight,
|
57 |
+
bias,
|
58 |
+
residual=None,
|
59 |
+
x1=None,
|
60 |
+
weight1=None,
|
61 |
+
bias1=None,
|
62 |
+
eps=1e-6,
|
63 |
+
dropout_p=0.0,
|
64 |
+
rowscale=None,
|
65 |
+
prenorm=False,
|
66 |
+
zero_centered_weight=False,
|
67 |
+
dropout_mask=None,
|
68 |
+
dropout_mask1=None,
|
69 |
+
upcast=False,
|
70 |
+
):
|
71 |
+
dtype = x.dtype
|
72 |
+
if upcast:
|
73 |
+
x = x.float()
|
74 |
+
weight = weight.float()
|
75 |
+
bias = bias.float() if bias is not None else None
|
76 |
+
residual = residual.float() if residual is not None else residual
|
77 |
+
x1 = x1.float() if x1 is not None else None
|
78 |
+
weight1 = weight1.float() if weight1 is not None else None
|
79 |
+
bias1 = bias1.float() if bias1 is not None else None
|
80 |
+
if zero_centered_weight:
|
81 |
+
weight = weight + 1.0
|
82 |
+
if weight1 is not None:
|
83 |
+
weight1 = weight1 + 1.0
|
84 |
+
if x1 is not None:
|
85 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
86 |
+
if rowscale is not None:
|
87 |
+
x = x * rowscale[..., None]
|
88 |
+
if dropout_p > 0.0:
|
89 |
+
if dropout_mask is not None:
|
90 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
91 |
+
else:
|
92 |
+
x = F.dropout(x, p=dropout_p)
|
93 |
+
if x1 is not None:
|
94 |
+
if dropout_mask1 is not None:
|
95 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
96 |
+
else:
|
97 |
+
x1 = F.dropout(x1, p=dropout_p)
|
98 |
+
if x1 is not None:
|
99 |
+
x = x + x1
|
100 |
+
if residual is not None:
|
101 |
+
x = (x + residual).to(x.dtype)
|
102 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
103 |
+
dtype
|
104 |
+
)
|
105 |
+
if weight1 is None:
|
106 |
+
return out if not prenorm else (out, x)
|
107 |
+
else:
|
108 |
+
out1 = F.layer_norm(
|
109 |
+
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
|
110 |
+
).to(dtype)
|
111 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
112 |
+
|
113 |
+
|
114 |
+
def rms_norm_ref(
|
115 |
+
x,
|
116 |
+
weight,
|
117 |
+
bias,
|
118 |
+
residual=None,
|
119 |
+
x1=None,
|
120 |
+
weight1=None,
|
121 |
+
bias1=None,
|
122 |
+
eps=1e-6,
|
123 |
+
dropout_p=0.0,
|
124 |
+
rowscale=None,
|
125 |
+
prenorm=False,
|
126 |
+
zero_centered_weight=False,
|
127 |
+
dropout_mask=None,
|
128 |
+
dropout_mask1=None,
|
129 |
+
upcast=False,
|
130 |
+
):
|
131 |
+
dtype = x.dtype
|
132 |
+
if upcast:
|
133 |
+
x = x.float()
|
134 |
+
weight = weight.float()
|
135 |
+
bias = bias.float() if bias is not None else None
|
136 |
+
residual = residual.float() if residual is not None else residual
|
137 |
+
x1 = x1.float() if x1 is not None else None
|
138 |
+
weight1 = weight1.float() if weight1 is not None else None
|
139 |
+
bias1 = bias1.float() if bias1 is not None else None
|
140 |
+
if zero_centered_weight:
|
141 |
+
weight = weight + 1.0
|
142 |
+
if weight1 is not None:
|
143 |
+
weight1 = weight1 + 1.0
|
144 |
+
if x1 is not None:
|
145 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
146 |
+
if rowscale is not None:
|
147 |
+
x = x * rowscale[..., None]
|
148 |
+
if dropout_p > 0.0:
|
149 |
+
if dropout_mask is not None:
|
150 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
151 |
+
else:
|
152 |
+
x = F.dropout(x, p=dropout_p)
|
153 |
+
if x1 is not None:
|
154 |
+
if dropout_mask1 is not None:
|
155 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
156 |
+
else:
|
157 |
+
x1 = F.dropout(x1, p=dropout_p)
|
158 |
+
if x1 is not None:
|
159 |
+
x = x + x1
|
160 |
+
if residual is not None:
|
161 |
+
x = (x + residual).to(x.dtype)
|
162 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
163 |
+
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
|
164 |
+
if weight1 is None:
|
165 |
+
return out if not prenorm else (out, x)
|
166 |
+
else:
|
167 |
+
out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
|
168 |
+
dtype
|
169 |
+
)
|
170 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
171 |
+
|
172 |
+
|
173 |
+
@triton.autotune(
|
174 |
+
configs=triton_autotune_configs(),
|
175 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
176 |
+
)
|
177 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
178 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
179 |
+
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
|
180 |
+
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
|
181 |
+
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
|
182 |
+
@triton.jit
|
183 |
+
def _layer_norm_fwd_1pass_kernel(
|
184 |
+
X, # pointer to the input
|
185 |
+
Y, # pointer to the output
|
186 |
+
W, # pointer to the weights
|
187 |
+
B, # pointer to the biases
|
188 |
+
RESIDUAL, # pointer to the residual
|
189 |
+
X1,
|
190 |
+
W1,
|
191 |
+
B1,
|
192 |
+
Y1,
|
193 |
+
RESIDUAL_OUT, # pointer to the residual
|
194 |
+
ROWSCALE,
|
195 |
+
SEEDS, # Dropout seeds for each row
|
196 |
+
DROPOUT_MASK,
|
197 |
+
Mean, # pointer to the mean
|
198 |
+
Rstd, # pointer to the 1/std
|
199 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
200 |
+
stride_y_row,
|
201 |
+
stride_res_row,
|
202 |
+
stride_res_out_row,
|
203 |
+
stride_x1_row,
|
204 |
+
stride_y1_row,
|
205 |
+
M, # number of rows in X
|
206 |
+
N, # number of columns in X
|
207 |
+
eps, # epsilon to avoid division by zero
|
208 |
+
dropout_p, # Dropout probability
|
209 |
+
zero_centered_weight, # If true, add 1.0 to the weight
|
210 |
+
IS_RMS_NORM: tl.constexpr,
|
211 |
+
BLOCK_N: tl.constexpr,
|
212 |
+
HAS_RESIDUAL: tl.constexpr,
|
213 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
214 |
+
HAS_BIAS: tl.constexpr,
|
215 |
+
HAS_DROPOUT: tl.constexpr,
|
216 |
+
STORE_DROPOUT_MASK: tl.constexpr,
|
217 |
+
HAS_ROWSCALE: tl.constexpr,
|
218 |
+
HAS_X1: tl.constexpr,
|
219 |
+
HAS_W1: tl.constexpr,
|
220 |
+
HAS_B1: tl.constexpr,
|
221 |
+
):
|
222 |
+
# Map the program id to the row of X and Y it should compute.
|
223 |
+
row = tl.program_id(0)
|
224 |
+
X += row * stride_x_row
|
225 |
+
Y += row * stride_y_row
|
226 |
+
if HAS_RESIDUAL:
|
227 |
+
RESIDUAL += row * stride_res_row
|
228 |
+
if STORE_RESIDUAL_OUT:
|
229 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
230 |
+
if HAS_X1:
|
231 |
+
X1 += row * stride_x1_row
|
232 |
+
if HAS_W1:
|
233 |
+
Y1 += row * stride_y1_row
|
234 |
+
# Compute mean and variance
|
235 |
+
cols = tl.arange(0, BLOCK_N)
|
236 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
237 |
+
if HAS_ROWSCALE:
|
238 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
239 |
+
x *= rowscale
|
240 |
+
if HAS_DROPOUT:
|
241 |
+
# Compute dropout mask
|
242 |
+
# 7 rounds is good enough, and reduces register pressure
|
243 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
244 |
+
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
|
245 |
+
if STORE_DROPOUT_MASK:
|
246 |
+
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
|
247 |
+
if HAS_X1:
|
248 |
+
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
|
249 |
+
if HAS_ROWSCALE:
|
250 |
+
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
|
251 |
+
x1 *= rowscale
|
252 |
+
if HAS_DROPOUT:
|
253 |
+
# Compute dropout mask
|
254 |
+
# 7 rounds is good enough, and reduces register pressure
|
255 |
+
keep_mask = (
|
256 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
257 |
+
)
|
258 |
+
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
|
259 |
+
if STORE_DROPOUT_MASK:
|
260 |
+
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
|
261 |
+
x += x1
|
262 |
+
if HAS_RESIDUAL:
|
263 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
264 |
+
x += residual
|
265 |
+
if STORE_RESIDUAL_OUT:
|
266 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
267 |
+
if not IS_RMS_NORM:
|
268 |
+
mean = tl.sum(x, axis=0) / N
|
269 |
+
tl.store(Mean + row, mean)
|
270 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
271 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
272 |
+
else:
|
273 |
+
xbar = tl.where(cols < N, x, 0.0)
|
274 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
275 |
+
rstd = 1 / tl.sqrt(var + eps)
|
276 |
+
tl.store(Rstd + row, rstd)
|
277 |
+
# Normalize and apply linear transformation
|
278 |
+
mask = cols < N
|
279 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
280 |
+
if zero_centered_weight:
|
281 |
+
w += 1.0
|
282 |
+
if HAS_BIAS:
|
283 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
284 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
285 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
286 |
+
# Write output
|
287 |
+
tl.store(Y + cols, y, mask=mask)
|
288 |
+
if HAS_W1:
|
289 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
290 |
+
if zero_centered_weight:
|
291 |
+
w1 += 1.0
|
292 |
+
if HAS_B1:
|
293 |
+
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
|
294 |
+
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
|
295 |
+
tl.store(Y1 + cols, y1, mask=mask)
|
296 |
+
|
297 |
+
|
298 |
+
def _layer_norm_fwd(
|
299 |
+
x,
|
300 |
+
weight,
|
301 |
+
bias,
|
302 |
+
eps,
|
303 |
+
residual=None,
|
304 |
+
x1=None,
|
305 |
+
weight1=None,
|
306 |
+
bias1=None,
|
307 |
+
dropout_p=0.0,
|
308 |
+
rowscale=None,
|
309 |
+
out_dtype=None,
|
310 |
+
residual_dtype=None,
|
311 |
+
zero_centered_weight=False,
|
312 |
+
is_rms_norm=False,
|
313 |
+
return_dropout_mask=False,
|
314 |
+
out=None,
|
315 |
+
residual_out=None
|
316 |
+
):
|
317 |
+
if residual is not None:
|
318 |
+
residual_dtype = residual.dtype
|
319 |
+
M, N = x.shape
|
320 |
+
assert x.stride(-1) == 1
|
321 |
+
if residual is not None:
|
322 |
+
assert residual.stride(-1) == 1
|
323 |
+
assert residual.shape == (M, N)
|
324 |
+
assert weight.shape == (N,)
|
325 |
+
assert weight.stride(-1) == 1
|
326 |
+
if bias is not None:
|
327 |
+
assert bias.stride(-1) == 1
|
328 |
+
assert bias.shape == (N,)
|
329 |
+
if x1 is not None:
|
330 |
+
assert x1.shape == x.shape
|
331 |
+
assert rowscale is None
|
332 |
+
assert x1.stride(-1) == 1
|
333 |
+
if weight1 is not None:
|
334 |
+
assert weight1.shape == (N,)
|
335 |
+
assert weight1.stride(-1) == 1
|
336 |
+
if bias1 is not None:
|
337 |
+
assert bias1.shape == (N,)
|
338 |
+
assert bias1.stride(-1) == 1
|
339 |
+
if rowscale is not None:
|
340 |
+
assert rowscale.is_contiguous()
|
341 |
+
assert rowscale.shape == (M,)
|
342 |
+
# allocate output
|
343 |
+
if out is None:
|
344 |
+
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
345 |
+
else:
|
346 |
+
assert out.shape == x.shape
|
347 |
+
assert out.stride(-1) == 1
|
348 |
+
if weight1 is not None:
|
349 |
+
y1 = torch.empty_like(out)
|
350 |
+
assert y1.stride(-1) == 1
|
351 |
+
else:
|
352 |
+
y1 = None
|
353 |
+
if (
|
354 |
+
residual is not None
|
355 |
+
or (residual_dtype is not None and residual_dtype != x.dtype)
|
356 |
+
or dropout_p > 0.0
|
357 |
+
or rowscale is not None
|
358 |
+
or x1 is not None
|
359 |
+
):
|
360 |
+
if residual_out is None:
|
361 |
+
residual_out = torch.empty(
|
362 |
+
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
assert residual_out.shape == x.shape
|
366 |
+
assert residual_out.stride(-1) == 1
|
367 |
+
else:
|
368 |
+
residual_out = None
|
369 |
+
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
370 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
371 |
+
if dropout_p > 0.0:
|
372 |
+
seeds = torch.randint(
|
373 |
+
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
seeds = None
|
377 |
+
if return_dropout_mask and dropout_p > 0.0:
|
378 |
+
dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
|
379 |
+
else:
|
380 |
+
dropout_mask = None
|
381 |
+
# Less than 64KB per feature: enqueue fused kernel
|
382 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
383 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
384 |
+
if N > BLOCK_N:
|
385 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
386 |
+
with torch.cuda.device(x.device.index):
|
387 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
388 |
+
x,
|
389 |
+
out,
|
390 |
+
weight,
|
391 |
+
bias,
|
392 |
+
residual,
|
393 |
+
x1,
|
394 |
+
weight1,
|
395 |
+
bias1,
|
396 |
+
y1,
|
397 |
+
residual_out,
|
398 |
+
rowscale,
|
399 |
+
seeds,
|
400 |
+
dropout_mask,
|
401 |
+
mean,
|
402 |
+
rstd,
|
403 |
+
x.stride(0),
|
404 |
+
out.stride(0),
|
405 |
+
residual.stride(0) if residual is not None else 0,
|
406 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
407 |
+
x1.stride(0) if x1 is not None else 0,
|
408 |
+
y1.stride(0) if y1 is not None else 0,
|
409 |
+
M,
|
410 |
+
N,
|
411 |
+
eps,
|
412 |
+
dropout_p,
|
413 |
+
zero_centered_weight,
|
414 |
+
is_rms_norm,
|
415 |
+
BLOCK_N,
|
416 |
+
residual is not None,
|
417 |
+
residual_out is not None,
|
418 |
+
bias is not None,
|
419 |
+
dropout_p > 0.0,
|
420 |
+
dropout_mask is not None,
|
421 |
+
rowscale is not None,
|
422 |
+
)
|
423 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
|
424 |
+
if dropout_mask is not None and x1 is not None:
|
425 |
+
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
|
426 |
+
else:
|
427 |
+
dropout_mask1 = None
|
428 |
+
return (
|
429 |
+
out,
|
430 |
+
y1,
|
431 |
+
mean,
|
432 |
+
rstd,
|
433 |
+
residual_out if residual_out is not None else x,
|
434 |
+
seeds,
|
435 |
+
dropout_mask,
|
436 |
+
dropout_mask1,
|
437 |
+
)
|
438 |
+
|
439 |
+
|
440 |
+
@triton.autotune(
|
441 |
+
configs=triton_autotune_configs(),
|
442 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
|
443 |
+
)
|
444 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
445 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
446 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
447 |
+
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
448 |
+
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
449 |
+
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
450 |
+
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
451 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
452 |
+
@triton.jit
|
453 |
+
def _layer_norm_bwd_kernel(
|
454 |
+
X, # pointer to the input
|
455 |
+
W, # pointer to the weights
|
456 |
+
B, # pointer to the biases
|
457 |
+
Y, # pointer to the output to be recomputed
|
458 |
+
DY, # pointer to the output gradient
|
459 |
+
DX, # pointer to the input gradient
|
460 |
+
DW, # pointer to the partial sum of weights gradient
|
461 |
+
DB, # pointer to the partial sum of biases gradient
|
462 |
+
DRESIDUAL,
|
463 |
+
W1,
|
464 |
+
DY1,
|
465 |
+
DX1,
|
466 |
+
DW1,
|
467 |
+
DB1,
|
468 |
+
DRESIDUAL_IN,
|
469 |
+
ROWSCALE,
|
470 |
+
SEEDS,
|
471 |
+
Mean, # pointer to the mean
|
472 |
+
Rstd, # pointer to the 1/std
|
473 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
474 |
+
stride_y_row,
|
475 |
+
stride_dy_row,
|
476 |
+
stride_dx_row,
|
477 |
+
stride_dres_row,
|
478 |
+
stride_dy1_row,
|
479 |
+
stride_dx1_row,
|
480 |
+
stride_dres_in_row,
|
481 |
+
M, # number of rows in X
|
482 |
+
N, # number of columns in X
|
483 |
+
eps, # epsilon to avoid division by zero
|
484 |
+
dropout_p,
|
485 |
+
zero_centered_weight,
|
486 |
+
rows_per_program,
|
487 |
+
IS_RMS_NORM: tl.constexpr,
|
488 |
+
BLOCK_N: tl.constexpr,
|
489 |
+
HAS_DRESIDUAL: tl.constexpr,
|
490 |
+
STORE_DRESIDUAL: tl.constexpr,
|
491 |
+
HAS_BIAS: tl.constexpr,
|
492 |
+
HAS_DROPOUT: tl.constexpr,
|
493 |
+
HAS_ROWSCALE: tl.constexpr,
|
494 |
+
HAS_DY1: tl.constexpr,
|
495 |
+
HAS_DX1: tl.constexpr,
|
496 |
+
HAS_B1: tl.constexpr,
|
497 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
498 |
+
):
|
499 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
500 |
+
row_block_id = tl.program_id(0)
|
501 |
+
row_start = row_block_id * rows_per_program
|
502 |
+
# Do not early exit if row_start >= M, because we need to write DW and DB
|
503 |
+
cols = tl.arange(0, BLOCK_N)
|
504 |
+
mask = cols < N
|
505 |
+
X += row_start * stride_x_row
|
506 |
+
if HAS_DRESIDUAL:
|
507 |
+
DRESIDUAL += row_start * stride_dres_row
|
508 |
+
if STORE_DRESIDUAL:
|
509 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
510 |
+
DY += row_start * stride_dy_row
|
511 |
+
DX += row_start * stride_dx_row
|
512 |
+
if HAS_DY1:
|
513 |
+
DY1 += row_start * stride_dy1_row
|
514 |
+
if HAS_DX1:
|
515 |
+
DX1 += row_start * stride_dx1_row
|
516 |
+
if RECOMPUTE_OUTPUT:
|
517 |
+
Y += row_start * stride_y_row
|
518 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
519 |
+
if zero_centered_weight:
|
520 |
+
w += 1.0
|
521 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
522 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
523 |
+
if HAS_DY1:
|
524 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
525 |
+
if zero_centered_weight:
|
526 |
+
w1 += 1.0
|
527 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
528 |
+
if HAS_BIAS:
|
529 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
530 |
+
if HAS_DY1:
|
531 |
+
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
532 |
+
if HAS_B1:
|
533 |
+
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
534 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
535 |
+
for row in range(row_start, row_end):
|
536 |
+
# Load data to SRAM
|
537 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
538 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
539 |
+
if HAS_DY1:
|
540 |
+
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
541 |
+
if not IS_RMS_NORM:
|
542 |
+
mean = tl.load(Mean + row)
|
543 |
+
rstd = tl.load(Rstd + row)
|
544 |
+
# Compute dx
|
545 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
546 |
+
xhat = tl.where(mask, xhat, 0.0)
|
547 |
+
if RECOMPUTE_OUTPUT:
|
548 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
549 |
+
tl.store(Y + cols, y, mask=mask)
|
550 |
+
wdy = w * dy
|
551 |
+
dw += dy * xhat
|
552 |
+
if HAS_BIAS:
|
553 |
+
db += dy
|
554 |
+
if HAS_DY1:
|
555 |
+
wdy += w1 * dy1
|
556 |
+
dw1 += dy1 * xhat
|
557 |
+
if HAS_B1:
|
558 |
+
db1 += dy1
|
559 |
+
if not IS_RMS_NORM:
|
560 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
561 |
+
c2 = tl.sum(wdy, axis=0) / N
|
562 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
563 |
+
else:
|
564 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
565 |
+
dx = (wdy - xhat * c1) * rstd
|
566 |
+
if HAS_DRESIDUAL:
|
567 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
568 |
+
dx += dres
|
569 |
+
# Write dx
|
570 |
+
if STORE_DRESIDUAL:
|
571 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
572 |
+
if HAS_DX1:
|
573 |
+
if HAS_DROPOUT:
|
574 |
+
keep_mask = (
|
575 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
576 |
+
)
|
577 |
+
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
578 |
+
else:
|
579 |
+
dx1 = dx
|
580 |
+
tl.store(DX1 + cols, dx1, mask=mask)
|
581 |
+
if HAS_DROPOUT:
|
582 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
583 |
+
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
584 |
+
if HAS_ROWSCALE:
|
585 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
586 |
+
dx *= rowscale
|
587 |
+
tl.store(DX + cols, dx, mask=mask)
|
588 |
+
|
589 |
+
X += stride_x_row
|
590 |
+
if HAS_DRESIDUAL:
|
591 |
+
DRESIDUAL += stride_dres_row
|
592 |
+
if STORE_DRESIDUAL:
|
593 |
+
DRESIDUAL_IN += stride_dres_in_row
|
594 |
+
if RECOMPUTE_OUTPUT:
|
595 |
+
Y += stride_y_row
|
596 |
+
DY += stride_dy_row
|
597 |
+
DX += stride_dx_row
|
598 |
+
if HAS_DY1:
|
599 |
+
DY1 += stride_dy1_row
|
600 |
+
if HAS_DX1:
|
601 |
+
DX1 += stride_dx1_row
|
602 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
603 |
+
if HAS_BIAS:
|
604 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
605 |
+
if HAS_DY1:
|
606 |
+
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
607 |
+
if HAS_B1:
|
608 |
+
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
|
609 |
+
|
610 |
+
|
611 |
+
def _layer_norm_bwd(
|
612 |
+
dy,
|
613 |
+
x,
|
614 |
+
weight,
|
615 |
+
bias,
|
616 |
+
eps,
|
617 |
+
mean,
|
618 |
+
rstd,
|
619 |
+
dresidual=None,
|
620 |
+
dy1=None,
|
621 |
+
weight1=None,
|
622 |
+
bias1=None,
|
623 |
+
seeds=None,
|
624 |
+
dropout_p=0.0,
|
625 |
+
rowscale=None,
|
626 |
+
has_residual=False,
|
627 |
+
has_x1=False,
|
628 |
+
zero_centered_weight=False,
|
629 |
+
is_rms_norm=False,
|
630 |
+
x_dtype=None,
|
631 |
+
recompute_output=False,
|
632 |
+
):
|
633 |
+
M, N = x.shape
|
634 |
+
assert x.stride(-1) == 1
|
635 |
+
assert dy.stride(-1) == 1
|
636 |
+
assert dy.shape == (M, N)
|
637 |
+
if dresidual is not None:
|
638 |
+
assert dresidual.stride(-1) == 1
|
639 |
+
assert dresidual.shape == (M, N)
|
640 |
+
assert weight.shape == (N,)
|
641 |
+
assert weight.stride(-1) == 1
|
642 |
+
if bias is not None:
|
643 |
+
assert bias.stride(-1) == 1
|
644 |
+
assert bias.shape == (N,)
|
645 |
+
if dy1 is not None:
|
646 |
+
assert weight1 is not None
|
647 |
+
assert dy1.shape == dy.shape
|
648 |
+
assert dy1.stride(-1) == 1
|
649 |
+
if weight1 is not None:
|
650 |
+
assert weight1.shape == (N,)
|
651 |
+
assert weight1.stride(-1) == 1
|
652 |
+
if bias1 is not None:
|
653 |
+
assert bias1.shape == (N,)
|
654 |
+
assert bias1.stride(-1) == 1
|
655 |
+
if seeds is not None:
|
656 |
+
assert seeds.is_contiguous()
|
657 |
+
assert seeds.shape == (M if not has_x1 else M * 2,)
|
658 |
+
if rowscale is not None:
|
659 |
+
assert rowscale.is_contiguous()
|
660 |
+
assert rowscale.shape == (M,)
|
661 |
+
# allocate output
|
662 |
+
dx = (
|
663 |
+
torch.empty_like(x)
|
664 |
+
if x_dtype is None
|
665 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
666 |
+
)
|
667 |
+
dresidual_in = (
|
668 |
+
torch.empty_like(x)
|
669 |
+
if has_residual
|
670 |
+
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
|
671 |
+
else None
|
672 |
+
)
|
673 |
+
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
|
674 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
675 |
+
if recompute_output:
|
676 |
+
assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
|
677 |
+
|
678 |
+
# Less than 64KB per feature: enqueue fused kernel
|
679 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
680 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
681 |
+
if N > BLOCK_N:
|
682 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
683 |
+
# Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
|
684 |
+
# latency of the gmem reads/writes, but will increase the time of summing up dw / db.
|
685 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
|
686 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
687 |
+
_db = (
|
688 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
689 |
+
if bias is not None
|
690 |
+
else None
|
691 |
+
)
|
692 |
+
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
|
693 |
+
_db1 = torch.empty_like(_db) if bias1 is not None else None
|
694 |
+
rows_per_program = math.ceil(M / sm_count)
|
695 |
+
grid = (sm_count,)
|
696 |
+
with torch.cuda.device(x.device.index):
|
697 |
+
_layer_norm_bwd_kernel[grid](
|
698 |
+
x,
|
699 |
+
weight,
|
700 |
+
bias,
|
701 |
+
y,
|
702 |
+
dy,
|
703 |
+
dx,
|
704 |
+
_dw,
|
705 |
+
_db,
|
706 |
+
dresidual,
|
707 |
+
weight1,
|
708 |
+
dy1,
|
709 |
+
dx1,
|
710 |
+
_dw1,
|
711 |
+
_db1,
|
712 |
+
dresidual_in,
|
713 |
+
rowscale,
|
714 |
+
seeds,
|
715 |
+
mean,
|
716 |
+
rstd,
|
717 |
+
x.stride(0),
|
718 |
+
0 if not recompute_output else y.stride(0),
|
719 |
+
dy.stride(0),
|
720 |
+
dx.stride(0),
|
721 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
722 |
+
dy1.stride(0) if dy1 is not None else 0,
|
723 |
+
dx1.stride(0) if dx1 is not None else 0,
|
724 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
725 |
+
M,
|
726 |
+
N,
|
727 |
+
eps,
|
728 |
+
dropout_p,
|
729 |
+
zero_centered_weight,
|
730 |
+
rows_per_program,
|
731 |
+
is_rms_norm,
|
732 |
+
BLOCK_N,
|
733 |
+
dresidual is not None,
|
734 |
+
dresidual_in is not None,
|
735 |
+
bias is not None,
|
736 |
+
dropout_p > 0.0,
|
737 |
+
)
|
738 |
+
dw = _dw.sum(0).to(weight.dtype)
|
739 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
740 |
+
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
|
741 |
+
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
|
742 |
+
# Don't need to compute dresidual_in separately in this case
|
743 |
+
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
|
744 |
+
dresidual_in = dx
|
745 |
+
if has_x1 and dropout_p == 0.0:
|
746 |
+
dx1 = dx
|
747 |
+
return (
|
748 |
+
(dx, dw, db, dresidual_in, dx1, dw1, db1)
|
749 |
+
if not recompute_output
|
750 |
+
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
|
751 |
+
)
|
752 |
+
|
753 |
+
|
754 |
+
class LayerNormFn(torch.autograd.Function):
|
755 |
+
@staticmethod
|
756 |
+
def forward(
|
757 |
+
ctx,
|
758 |
+
x,
|
759 |
+
weight,
|
760 |
+
bias,
|
761 |
+
residual=None,
|
762 |
+
x1=None,
|
763 |
+
weight1=None,
|
764 |
+
bias1=None,
|
765 |
+
eps=1e-6,
|
766 |
+
dropout_p=0.0,
|
767 |
+
rowscale=None,
|
768 |
+
prenorm=False,
|
769 |
+
residual_in_fp32=False,
|
770 |
+
zero_centered_weight=False,
|
771 |
+
is_rms_norm=False,
|
772 |
+
return_dropout_mask=False,
|
773 |
+
out=None,
|
774 |
+
residual_out=None
|
775 |
+
):
|
776 |
+
x_shape_og = x.shape
|
777 |
+
# Check for zero sequence length
|
778 |
+
if x.numel() == 0:
|
779 |
+
ctx.zero_seq_length = True
|
780 |
+
# Only save minimal required tensors for backward
|
781 |
+
# ctx.save_for_backward(weight, bias, weight1, bias1)
|
782 |
+
ctx.x_shape_og = x_shape_og
|
783 |
+
ctx.weight_shape = weight.shape
|
784 |
+
ctx.weight_dtype = weight.dtype
|
785 |
+
ctx.weight_device = weight.device
|
786 |
+
|
787 |
+
ctx.has_bias = bias is not None
|
788 |
+
ctx.bias_shape = bias.shape if bias is not None else None
|
789 |
+
ctx.bias_dtype = bias.dtype if bias is not None else None
|
790 |
+
ctx.bias_device = bias.device if bias is not None else None
|
791 |
+
|
792 |
+
ctx.has_weight1 = weight1 is not None
|
793 |
+
ctx.weight1_shape = weight1.shape if weight1 is not None else None
|
794 |
+
ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
|
795 |
+
ctx.weight1_device = weight1.device if weight1 is not None else None
|
796 |
+
|
797 |
+
ctx.has_bias1 = bias1 is not None
|
798 |
+
ctx.bias1_shape = bias1.shape if bias1 is not None else None
|
799 |
+
ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
|
800 |
+
ctx.bias1_device = bias1.device if bias1 is not None else None
|
801 |
+
|
802 |
+
ctx.has_residual = residual is not None
|
803 |
+
ctx.has_x1 = x1 is not None
|
804 |
+
ctx.dropout_p = dropout_p
|
805 |
+
|
806 |
+
# Handle output tensors with correct dtype
|
807 |
+
y = x # Preserve input tensor properties
|
808 |
+
y1 = torch.empty_like(x) if x1 is not None else None
|
809 |
+
|
810 |
+
# Only create residual_out if prenorm is True
|
811 |
+
residual_out = torch.empty(x.shape,
|
812 |
+
dtype=torch.float32 if residual_in_fp32 else x.dtype,
|
813 |
+
device=x.device) if prenorm else None
|
814 |
+
|
815 |
+
# Handle dropout masks
|
816 |
+
dropout_mask = None
|
817 |
+
dropout_mask1 = None
|
818 |
+
if return_dropout_mask:
|
819 |
+
dropout_mask = torch.empty_like(x, dtype=torch.uint8)
|
820 |
+
if x1 is not None:
|
821 |
+
dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
|
822 |
+
|
823 |
+
# Return based on configuration
|
824 |
+
if not return_dropout_mask:
|
825 |
+
if weight1 is None:
|
826 |
+
return y if not prenorm else (y, residual_out)
|
827 |
+
else:
|
828 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
829 |
+
else:
|
830 |
+
if weight1 is None:
|
831 |
+
return ((y, dropout_mask, dropout_mask1) if not prenorm
|
832 |
+
else (y, residual_out, dropout_mask, dropout_mask1))
|
833 |
+
else:
|
834 |
+
return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
|
835 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1))
|
836 |
+
|
837 |
+
ctx.zero_seq_length = False
|
838 |
+
# reshape input data into 2D tensor
|
839 |
+
x = x.reshape(-1, x.shape[-1])
|
840 |
+
if x.stride(-1) != 1:
|
841 |
+
x = x.contiguous()
|
842 |
+
if residual is not None:
|
843 |
+
assert residual.shape == x_shape_og
|
844 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
845 |
+
if residual.stride(-1) != 1:
|
846 |
+
residual = residual.contiguous()
|
847 |
+
if x1 is not None:
|
848 |
+
assert x1.shape == x_shape_og
|
849 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
850 |
+
x1 = x1.reshape(-1, x1.shape[-1])
|
851 |
+
if x1.stride(-1) != 1:
|
852 |
+
x1 = x1.contiguous()
|
853 |
+
weight = weight.contiguous()
|
854 |
+
if bias is not None:
|
855 |
+
bias = bias.contiguous()
|
856 |
+
if weight1 is not None:
|
857 |
+
weight1 = weight1.contiguous()
|
858 |
+
if bias1 is not None:
|
859 |
+
bias1 = bias1.contiguous()
|
860 |
+
if rowscale is not None:
|
861 |
+
rowscale = rowscale.reshape(-1).contiguous()
|
862 |
+
residual_dtype = (
|
863 |
+
residual.dtype
|
864 |
+
if residual is not None
|
865 |
+
else (torch.float32 if residual_in_fp32 else None)
|
866 |
+
)
|
867 |
+
if out is not None:
|
868 |
+
out = out.reshape(-1, out.shape[-1])
|
869 |
+
if residual_out is not None:
|
870 |
+
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
|
871 |
+
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
|
872 |
+
x,
|
873 |
+
weight,
|
874 |
+
bias,
|
875 |
+
eps,
|
876 |
+
residual,
|
877 |
+
x1,
|
878 |
+
weight1,
|
879 |
+
bias1,
|
880 |
+
dropout_p=dropout_p,
|
881 |
+
rowscale=rowscale,
|
882 |
+
residual_dtype=residual_dtype,
|
883 |
+
zero_centered_weight=zero_centered_weight,
|
884 |
+
is_rms_norm=is_rms_norm,
|
885 |
+
return_dropout_mask=return_dropout_mask,
|
886 |
+
out=out,
|
887 |
+
residual_out=residual_out
|
888 |
+
)
|
889 |
+
ctx.save_for_backward(
|
890 |
+
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
|
891 |
+
)
|
892 |
+
ctx.x_shape_og = x_shape_og
|
893 |
+
ctx.eps = eps
|
894 |
+
ctx.dropout_p = dropout_p
|
895 |
+
ctx.is_rms_norm = is_rms_norm
|
896 |
+
ctx.has_residual = residual is not None
|
897 |
+
ctx.has_x1 = x1 is not None
|
898 |
+
ctx.prenorm = prenorm
|
899 |
+
ctx.x_dtype = x.dtype
|
900 |
+
ctx.zero_centered_weight = zero_centered_weight
|
901 |
+
y = y.reshape(x_shape_og)
|
902 |
+
y1 = y1.reshape(x_shape_og) if y1 is not None else None
|
903 |
+
residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
|
904 |
+
dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
|
905 |
+
dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
|
906 |
+
if not return_dropout_mask:
|
907 |
+
if weight1 is None:
|
908 |
+
return y if not prenorm else (y, residual_out)
|
909 |
+
else:
|
910 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
911 |
+
else:
|
912 |
+
if weight1 is None:
|
913 |
+
return (
|
914 |
+
(y, dropout_mask, dropout_mask1)
|
915 |
+
if not prenorm
|
916 |
+
else (y, residual_out, dropout_mask, dropout_mask1)
|
917 |
+
)
|
918 |
+
else:
|
919 |
+
return (
|
920 |
+
(y, y1, dropout_mask, dropout_mask1)
|
921 |
+
if not prenorm
|
922 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1)
|
923 |
+
)
|
924 |
+
|
925 |
+
@staticmethod
|
926 |
+
def backward(ctx, dy, *args):
|
927 |
+
if ctx.zero_seq_length:
|
928 |
+
return (
|
929 |
+
torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
|
930 |
+
torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
|
931 |
+
torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
|
932 |
+
torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
|
933 |
+
torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
|
934 |
+
torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
|
935 |
+
torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
|
936 |
+
None,
|
937 |
+
None,
|
938 |
+
None,
|
939 |
+
None,
|
940 |
+
None,
|
941 |
+
None,
|
942 |
+
None,
|
943 |
+
None,
|
944 |
+
None,
|
945 |
+
None,
|
946 |
+
)
|
947 |
+
|
948 |
+
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
|
949 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
950 |
+
if dy.stride(-1) != 1:
|
951 |
+
dy = dy.contiguous()
|
952 |
+
assert dy.shape == x.shape
|
953 |
+
if weight1 is not None:
|
954 |
+
dy1, args = args[0], args[1:]
|
955 |
+
dy1 = dy1.reshape(-1, dy1.shape[-1])
|
956 |
+
if dy1.stride(-1) != 1:
|
957 |
+
dy1 = dy1.contiguous()
|
958 |
+
assert dy1.shape == x.shape
|
959 |
+
else:
|
960 |
+
dy1 = None
|
961 |
+
if ctx.prenorm:
|
962 |
+
dresidual = args[0]
|
963 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
964 |
+
if dresidual.stride(-1) != 1:
|
965 |
+
dresidual = dresidual.contiguous()
|
966 |
+
assert dresidual.shape == x.shape
|
967 |
+
else:
|
968 |
+
dresidual = None
|
969 |
+
|
970 |
+
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
|
971 |
+
dy,
|
972 |
+
x,
|
973 |
+
weight,
|
974 |
+
bias,
|
975 |
+
ctx.eps,
|
976 |
+
mean,
|
977 |
+
rstd,
|
978 |
+
dresidual,
|
979 |
+
dy1,
|
980 |
+
weight1,
|
981 |
+
bias1,
|
982 |
+
seeds,
|
983 |
+
ctx.dropout_p,
|
984 |
+
rowscale,
|
985 |
+
ctx.has_residual,
|
986 |
+
ctx.has_x1,
|
987 |
+
ctx.zero_centered_weight,
|
988 |
+
ctx.is_rms_norm,
|
989 |
+
x_dtype=ctx.x_dtype,
|
990 |
+
)
|
991 |
+
return (
|
992 |
+
dx.reshape(ctx.x_shape_og),
|
993 |
+
dw,
|
994 |
+
db,
|
995 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
996 |
+
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
|
997 |
+
dw1,
|
998 |
+
db1,
|
999 |
+
None,
|
1000 |
+
None,
|
1001 |
+
None,
|
1002 |
+
None,
|
1003 |
+
None,
|
1004 |
+
None,
|
1005 |
+
None,
|
1006 |
+
None,
|
1007 |
+
None,
|
1008 |
+
None,
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
|
1012 |
+
def layer_norm_fn(
|
1013 |
+
x,
|
1014 |
+
weight,
|
1015 |
+
bias,
|
1016 |
+
residual=None,
|
1017 |
+
x1=None,
|
1018 |
+
weight1=None,
|
1019 |
+
bias1=None,
|
1020 |
+
eps=1e-6,
|
1021 |
+
dropout_p=0.0,
|
1022 |
+
rowscale=None,
|
1023 |
+
prenorm=False,
|
1024 |
+
residual_in_fp32=False,
|
1025 |
+
zero_centered_weight=False,
|
1026 |
+
is_rms_norm=False,
|
1027 |
+
return_dropout_mask=False,
|
1028 |
+
out=None,
|
1029 |
+
residual_out=None
|
1030 |
+
):
|
1031 |
+
return LayerNormFn.apply(
|
1032 |
+
x,
|
1033 |
+
weight,
|
1034 |
+
bias,
|
1035 |
+
residual,
|
1036 |
+
x1,
|
1037 |
+
weight1,
|
1038 |
+
bias1,
|
1039 |
+
eps,
|
1040 |
+
dropout_p,
|
1041 |
+
rowscale,
|
1042 |
+
prenorm,
|
1043 |
+
residual_in_fp32,
|
1044 |
+
zero_centered_weight,
|
1045 |
+
is_rms_norm,
|
1046 |
+
return_dropout_mask,
|
1047 |
+
out,
|
1048 |
+
residual_out
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
|
1052 |
+
def rms_norm_fn(
|
1053 |
+
x,
|
1054 |
+
weight,
|
1055 |
+
bias,
|
1056 |
+
residual=None,
|
1057 |
+
x1=None,
|
1058 |
+
weight1=None,
|
1059 |
+
bias1=None,
|
1060 |
+
eps=1e-6,
|
1061 |
+
dropout_p=0.0,
|
1062 |
+
rowscale=None,
|
1063 |
+
prenorm=False,
|
1064 |
+
residual_in_fp32=False,
|
1065 |
+
zero_centered_weight=False,
|
1066 |
+
return_dropout_mask=False,
|
1067 |
+
out=None,
|
1068 |
+
residual_out=None
|
1069 |
+
):
|
1070 |
+
return LayerNormFn.apply(
|
1071 |
+
x,
|
1072 |
+
weight,
|
1073 |
+
bias,
|
1074 |
+
residual,
|
1075 |
+
x1,
|
1076 |
+
weight1,
|
1077 |
+
bias1,
|
1078 |
+
eps,
|
1079 |
+
dropout_p,
|
1080 |
+
rowscale,
|
1081 |
+
prenorm,
|
1082 |
+
residual_in_fp32,
|
1083 |
+
zero_centered_weight,
|
1084 |
+
True,
|
1085 |
+
return_dropout_mask,
|
1086 |
+
out,
|
1087 |
+
residual_out
|
1088 |
+
)
|
1089 |
+
|
1090 |
+
|
1091 |
+
class RMSNorm(torch.nn.Module):
|
1092 |
+
|
1093 |
+
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
|
1094 |
+
device=None, dtype=None):
|
1095 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
1096 |
+
super().__init__()
|
1097 |
+
self.eps = eps
|
1098 |
+
if dropout_p > 0.0:
|
1099 |
+
self.drop = torch.nn.Dropout(dropout_p)
|
1100 |
+
else:
|
1101 |
+
self.drop = None
|
1102 |
+
self.zero_centered_weight = zero_centered_weight
|
1103 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
1104 |
+
self.register_parameter("bias", None)
|
1105 |
+
self.reset_parameters()
|
1106 |
+
|
1107 |
+
def reset_parameters(self):
|
1108 |
+
if not self.zero_centered_weight:
|
1109 |
+
torch.nn.init.ones_(self.weight)
|
1110 |
+
else:
|
1111 |
+
torch.nn.init.zeros_(self.weight)
|
1112 |
+
|
1113 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
1114 |
+
return rms_norm_fn(
|
1115 |
+
x,
|
1116 |
+
self.weight,
|
1117 |
+
self.bias,
|
1118 |
+
residual=residual,
|
1119 |
+
eps=self.eps,
|
1120 |
+
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
|
1121 |
+
prenorm=prenorm,
|
1122 |
+
residual_in_fp32=residual_in_fp32,
|
1123 |
+
zero_centered_weight=self.zero_centered_weight,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
|
1127 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
1128 |
+
@staticmethod
|
1129 |
+
@custom_fwd
|
1130 |
+
def forward(
|
1131 |
+
ctx,
|
1132 |
+
x,
|
1133 |
+
norm_weight,
|
1134 |
+
norm_bias,
|
1135 |
+
linear_weight,
|
1136 |
+
linear_bias,
|
1137 |
+
residual=None,
|
1138 |
+
eps=1e-6,
|
1139 |
+
prenorm=False,
|
1140 |
+
residual_in_fp32=False,
|
1141 |
+
is_rms_norm=False,
|
1142 |
+
):
|
1143 |
+
x_shape_og = x.shape
|
1144 |
+
# reshape input data into 2D tensor
|
1145 |
+
x = x.reshape(-1, x.shape[-1])
|
1146 |
+
if x.stride(-1) != 1:
|
1147 |
+
x = x.contiguous()
|
1148 |
+
if residual is not None:
|
1149 |
+
assert residual.shape == x_shape_og
|
1150 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
1151 |
+
if residual.stride(-1) != 1:
|
1152 |
+
residual = residual.contiguous()
|
1153 |
+
norm_weight = norm_weight.contiguous()
|
1154 |
+
if norm_bias is not None:
|
1155 |
+
norm_bias = norm_bias.contiguous()
|
1156 |
+
residual_dtype = (
|
1157 |
+
residual.dtype
|
1158 |
+
if residual is not None
|
1159 |
+
else (torch.float32 if residual_in_fp32 else None)
|
1160 |
+
)
|
1161 |
+
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
1162 |
+
x,
|
1163 |
+
norm_weight,
|
1164 |
+
norm_bias,
|
1165 |
+
eps,
|
1166 |
+
residual,
|
1167 |
+
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
|
1168 |
+
residual_dtype=residual_dtype,
|
1169 |
+
is_rms_norm=is_rms_norm,
|
1170 |
+
)
|
1171 |
+
y = y.reshape(x_shape_og)
|
1172 |
+
dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
|
1173 |
+
linear_weight = linear_weight.to(dtype)
|
1174 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
1175 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
1176 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
1177 |
+
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
1178 |
+
ctx.x_shape_og = x_shape_og
|
1179 |
+
ctx.eps = eps
|
1180 |
+
ctx.is_rms_norm = is_rms_norm
|
1181 |
+
ctx.has_residual = residual is not None
|
1182 |
+
ctx.prenorm = prenorm
|
1183 |
+
ctx.x_dtype = x.dtype
|
1184 |
+
ctx.linear_bias_is_none = linear_bias is None
|
1185 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
1186 |
+
|
1187 |
+
@staticmethod
|
1188 |
+
@custom_bwd
|
1189 |
+
def backward(ctx, dout, *args):
|
1190 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
1191 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
1192 |
+
dy = F.linear(dout, linear_weight.t())
|
1193 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
1194 |
+
if dy.stride(-1) != 1:
|
1195 |
+
dy = dy.contiguous()
|
1196 |
+
assert dy.shape == x.shape
|
1197 |
+
if ctx.prenorm:
|
1198 |
+
dresidual = args[0]
|
1199 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
1200 |
+
if dresidual.stride(-1) != 1:
|
1201 |
+
dresidual = dresidual.contiguous()
|
1202 |
+
assert dresidual.shape == x.shape
|
1203 |
+
else:
|
1204 |
+
dresidual = None
|
1205 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
1206 |
+
dy,
|
1207 |
+
x,
|
1208 |
+
norm_weight,
|
1209 |
+
norm_bias,
|
1210 |
+
ctx.eps,
|
1211 |
+
mean,
|
1212 |
+
rstd,
|
1213 |
+
dresidual=dresidual,
|
1214 |
+
has_residual=ctx.has_residual,
|
1215 |
+
is_rms_norm=ctx.is_rms_norm,
|
1216 |
+
x_dtype=ctx.x_dtype,
|
1217 |
+
recompute_output=True,
|
1218 |
+
)
|
1219 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
1220 |
+
return (
|
1221 |
+
dx.reshape(ctx.x_shape_og),
|
1222 |
+
dnorm_weight,
|
1223 |
+
dnorm_bias,
|
1224 |
+
dlinear_weight,
|
1225 |
+
dlinear_bias,
|
1226 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
1227 |
+
None,
|
1228 |
+
None,
|
1229 |
+
None,
|
1230 |
+
None,
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
|
1234 |
+
def layer_norm_linear_fn(
|
1235 |
+
x,
|
1236 |
+
norm_weight,
|
1237 |
+
norm_bias,
|
1238 |
+
linear_weight,
|
1239 |
+
linear_bias,
|
1240 |
+
residual=None,
|
1241 |
+
eps=1e-6,
|
1242 |
+
prenorm=False,
|
1243 |
+
residual_in_fp32=False,
|
1244 |
+
is_rms_norm=False,
|
1245 |
+
):
|
1246 |
+
return LayerNormLinearFn.apply(
|
1247 |
+
x,
|
1248 |
+
norm_weight,
|
1249 |
+
norm_bias,
|
1250 |
+
linear_weight,
|
1251 |
+
linear_bias,
|
1252 |
+
residual,
|
1253 |
+
eps,
|
1254 |
+
prenorm,
|
1255 |
+
residual_in_fp32,
|
1256 |
+
is_rms_norm,
|
1257 |
+
)
|
omnigen2/pipelines/__init__.py
ADDED
File without changes
|
omnigen2/pipelines/image_processor.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import warnings
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import PIL.Image
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
|
24 |
+
from diffusers.configuration_utils import register_to_config
|
25 |
+
|
26 |
+
class OmniGen2ImageProcessor(VaeImageProcessor):
|
27 |
+
"""
|
28 |
+
Image processor for PixArt image resize and crop.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
32 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
33 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
34 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
35 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
36 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
37 |
+
Resampling filter to use when resizing the image.
|
38 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to normalize the image to [-1,1].
|
40 |
+
do_binarize (`bool`, *optional*, defaults to `False`):
|
41 |
+
Whether to binarize the image to 0/1.
|
42 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
43 |
+
Whether to convert the images to RGB format.
|
44 |
+
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
45 |
+
Whether to convert the images to grayscale format.
|
46 |
+
"""
|
47 |
+
|
48 |
+
@register_to_config
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
do_resize: bool = True,
|
52 |
+
vae_scale_factor: int = 16,
|
53 |
+
resample: str = "lanczos",
|
54 |
+
max_pixels: Optional[int] = None,
|
55 |
+
max_side_length: Optional[int] = None,
|
56 |
+
do_normalize: bool = True,
|
57 |
+
do_binarize: bool = False,
|
58 |
+
do_convert_grayscale: bool = False,
|
59 |
+
):
|
60 |
+
super().__init__(
|
61 |
+
do_resize=do_resize,
|
62 |
+
vae_scale_factor=vae_scale_factor,
|
63 |
+
resample=resample,
|
64 |
+
do_normalize=do_normalize,
|
65 |
+
do_binarize=do_binarize,
|
66 |
+
do_convert_grayscale=do_convert_grayscale,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.max_pixels = max_pixels
|
70 |
+
self.max_side_length = max_side_length
|
71 |
+
|
72 |
+
def get_new_height_width(
|
73 |
+
self,
|
74 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
75 |
+
height: Optional[int] = None,
|
76 |
+
width: Optional[int] = None,
|
77 |
+
max_pixels: Optional[int] = None,
|
78 |
+
max_side_length: Optional[int] = None,
|
79 |
+
) -> Tuple[int, int]:
|
80 |
+
r"""
|
81 |
+
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
85 |
+
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
|
86 |
+
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
|
87 |
+
tensor, it should have shape `[batch, channels, height, width]`.
|
88 |
+
height (`Optional[int]`, *optional*, defaults to `None`):
|
89 |
+
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
|
90 |
+
width (`Optional[int]`, *optional*, defaults to `None`):
|
91 |
+
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
`Tuple[int, int]`:
|
95 |
+
A tuple containing the height and width, both resized to the nearest integer multiple of
|
96 |
+
`vae_scale_factor`.
|
97 |
+
"""
|
98 |
+
|
99 |
+
if height is None:
|
100 |
+
if isinstance(image, PIL.Image.Image):
|
101 |
+
height = image.height
|
102 |
+
elif isinstance(image, torch.Tensor):
|
103 |
+
height = image.shape[2]
|
104 |
+
else:
|
105 |
+
height = image.shape[1]
|
106 |
+
|
107 |
+
if width is None:
|
108 |
+
if isinstance(image, PIL.Image.Image):
|
109 |
+
width = image.width
|
110 |
+
elif isinstance(image, torch.Tensor):
|
111 |
+
width = image.shape[3]
|
112 |
+
else:
|
113 |
+
width = image.shape[2]
|
114 |
+
|
115 |
+
if max_side_length is None:
|
116 |
+
max_side_length = self.max_side_length
|
117 |
+
|
118 |
+
if max_pixels is None:
|
119 |
+
max_pixels = self.max_pixels
|
120 |
+
|
121 |
+
ratio = 1.0
|
122 |
+
if max_side_length is not None:
|
123 |
+
if height > width:
|
124 |
+
max_side_length_ratio = max_side_length / height
|
125 |
+
else:
|
126 |
+
max_side_length_ratio = max_side_length / width
|
127 |
+
|
128 |
+
cur_pixels = height * width
|
129 |
+
max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5
|
130 |
+
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image
|
131 |
+
|
132 |
+
new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
|
133 |
+
return new_height, new_width
|
134 |
+
|
135 |
+
def preprocess(
|
136 |
+
self,
|
137 |
+
image: PipelineImageInput,
|
138 |
+
height: Optional[int] = None,
|
139 |
+
width: Optional[int] = None,
|
140 |
+
max_pixels: Optional[int] = None,
|
141 |
+
max_side_length: Optional[int] = None,
|
142 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
143 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
144 |
+
) -> torch.Tensor:
|
145 |
+
"""
|
146 |
+
Preprocess the image input.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
image (`PipelineImageInput`):
|
150 |
+
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
151 |
+
supported formats.
|
152 |
+
height (`int`, *optional*):
|
153 |
+
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
154 |
+
height.
|
155 |
+
width (`int`, *optional*):
|
156 |
+
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
157 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
158 |
+
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
159 |
+
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
|
160 |
+
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
161 |
+
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
162 |
+
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
163 |
+
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
164 |
+
supported for PIL image input.
|
165 |
+
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
166 |
+
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
`torch.Tensor`:
|
170 |
+
The preprocessed image.
|
171 |
+
"""
|
172 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
173 |
+
|
174 |
+
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
175 |
+
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
176 |
+
if isinstance(image, torch.Tensor):
|
177 |
+
# if image is a pytorch tensor could have 2 possible shapes:
|
178 |
+
# 1. batch x height x width: we should insert the channel dimension at position 1
|
179 |
+
# 2. channel x height x width: we should insert batch dimension at position 0,
|
180 |
+
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
181 |
+
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
182 |
+
image = image.unsqueeze(1)
|
183 |
+
else:
|
184 |
+
# if it is a numpy array, it could have 2 possible shapes:
|
185 |
+
# 1. batch x height x width: insert channel dimension on last position
|
186 |
+
# 2. height x width x channel: insert batch dimension on first position
|
187 |
+
if image.shape[-1] == 1:
|
188 |
+
image = np.expand_dims(image, axis=0)
|
189 |
+
else:
|
190 |
+
image = np.expand_dims(image, axis=-1)
|
191 |
+
|
192 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
193 |
+
warnings.warn(
|
194 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
195 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
196 |
+
FutureWarning,
|
197 |
+
)
|
198 |
+
image = np.concatenate(image, axis=0)
|
199 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
200 |
+
warnings.warn(
|
201 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
202 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
203 |
+
FutureWarning,
|
204 |
+
)
|
205 |
+
image = torch.cat(image, axis=0)
|
206 |
+
|
207 |
+
if not is_valid_image_imagelist(image):
|
208 |
+
raise ValueError(
|
209 |
+
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
210 |
+
)
|
211 |
+
if not isinstance(image, list):
|
212 |
+
image = [image]
|
213 |
+
|
214 |
+
if isinstance(image[0], PIL.Image.Image):
|
215 |
+
if crops_coords is not None:
|
216 |
+
image = [i.crop(crops_coords) for i in image]
|
217 |
+
if self.config.do_resize:
|
218 |
+
height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
|
219 |
+
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
220 |
+
if self.config.do_convert_rgb:
|
221 |
+
image = [self.convert_to_rgb(i) for i in image]
|
222 |
+
elif self.config.do_convert_grayscale:
|
223 |
+
image = [self.convert_to_grayscale(i) for i in image]
|
224 |
+
image = self.pil_to_numpy(image) # to np
|
225 |
+
image = self.numpy_to_pt(image) # to pt
|
226 |
+
|
227 |
+
elif isinstance(image[0], np.ndarray):
|
228 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
229 |
+
|
230 |
+
image = self.numpy_to_pt(image)
|
231 |
+
|
232 |
+
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
233 |
+
if self.config.do_resize:
|
234 |
+
image = self.resize(image, height, width)
|
235 |
+
|
236 |
+
elif isinstance(image[0], torch.Tensor):
|
237 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
238 |
+
|
239 |
+
if self.config.do_convert_grayscale and image.ndim == 3:
|
240 |
+
image = image.unsqueeze(1)
|
241 |
+
|
242 |
+
channel = image.shape[1]
|
243 |
+
# don't need any preprocess if the image is latents
|
244 |
+
if channel == self.config.vae_latent_channels:
|
245 |
+
return image
|
246 |
+
|
247 |
+
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
248 |
+
if self.config.do_resize:
|
249 |
+
image = self.resize(image, height, width)
|
250 |
+
|
251 |
+
# expected range [0,1], normalize to [-1,1]
|
252 |
+
do_normalize = self.config.do_normalize
|
253 |
+
if do_normalize and image.min() < 0:
|
254 |
+
warnings.warn(
|
255 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
256 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
257 |
+
FutureWarning,
|
258 |
+
)
|
259 |
+
do_normalize = False
|
260 |
+
if do_normalize:
|
261 |
+
image = self.normalize(image)
|
262 |
+
|
263 |
+
if self.config.do_binarize:
|
264 |
+
image = self.binarize(image)
|
265 |
+
|
266 |
+
return image
|
omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
ADDED
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
OmniGen2 Diffusion Pipeline
|
3 |
+
|
4 |
+
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
5 |
+
|
6 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
you may not use this file except in compliance with the License.
|
8 |
+
You may obtain a copy of the License at
|
9 |
+
|
10 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
|
12 |
+
Unless required by applicable law or agreed to in writing, software
|
13 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
See the License for the specific language governing permissions and
|
16 |
+
limitations under the License.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import inspect
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import math
|
23 |
+
|
24 |
+
from PIL import Image
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
from transformers import Qwen2_5_VLForConditionalGeneration
|
30 |
+
|
31 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
32 |
+
from ...models.transformers import OmniGen2Transformer2DModel
|
33 |
+
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
|
34 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
35 |
+
from diffusers.utils import (
|
36 |
+
is_torch_xla_available,
|
37 |
+
logging,
|
38 |
+
)
|
39 |
+
from diffusers.utils.torch_utils import randn_tensor
|
40 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
41 |
+
|
42 |
+
from dataclasses import dataclass
|
43 |
+
|
44 |
+
import PIL.Image
|
45 |
+
|
46 |
+
from diffusers.utils import BaseOutput
|
47 |
+
|
48 |
+
from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
|
49 |
+
|
50 |
+
if is_torch_xla_available():
|
51 |
+
import torch_xla.core.xla_model as xm
|
52 |
+
|
53 |
+
XLA_AVAILABLE = True
|
54 |
+
else:
|
55 |
+
XLA_AVAILABLE = False
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class FMPipelineOutput(BaseOutput):
|
62 |
+
"""
|
63 |
+
Output class for OmniGen2 pipeline.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
images (Union[List[PIL.Image.Image], np.ndarray]):
|
67 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape
|
68 |
+
`(batch_size, height, width, num_channels)`. Contains the generated images.
|
69 |
+
"""
|
70 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
71 |
+
|
72 |
+
|
73 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
74 |
+
def retrieve_timesteps(
|
75 |
+
scheduler,
|
76 |
+
num_inference_steps: Optional[int] = None,
|
77 |
+
device: Optional[Union[str, torch.device]] = None,
|
78 |
+
timesteps: Optional[List[int]] = None,
|
79 |
+
**kwargs,
|
80 |
+
):
|
81 |
+
"""
|
82 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
83 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
scheduler (`SchedulerMixin`):
|
87 |
+
The scheduler to get timesteps from.
|
88 |
+
num_inference_steps (`int`):
|
89 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
90 |
+
must be `None`.
|
91 |
+
device (`str` or `torch.device`, *optional*):
|
92 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
93 |
+
timesteps (`List[int]`, *optional*):
|
94 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
95 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
96 |
+
sigmas (`List[float]`, *optional*):
|
97 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
98 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
102 |
+
second element is the number of inference steps.
|
103 |
+
"""
|
104 |
+
if timesteps is not None:
|
105 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
106 |
+
if not accepts_timesteps:
|
107 |
+
raise ValueError(
|
108 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
109 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
110 |
+
)
|
111 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
112 |
+
timesteps = scheduler.timesteps
|
113 |
+
num_inference_steps = len(timesteps)
|
114 |
+
else:
|
115 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
116 |
+
timesteps = scheduler.timesteps
|
117 |
+
return timesteps, num_inference_steps
|
118 |
+
|
119 |
+
|
120 |
+
class OmniGen2Pipeline(DiffusionPipeline):
|
121 |
+
"""
|
122 |
+
Pipeline for text-to-image generation using OmniGen2.
|
123 |
+
|
124 |
+
This pipeline implements a text-to-image generation model that uses:
|
125 |
+
- Qwen2.5-VL for text encoding
|
126 |
+
- A custom transformer architecture for image generation
|
127 |
+
- VAE for image encoding/decoding
|
128 |
+
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
129 |
+
|
130 |
+
Args:
|
131 |
+
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
132 |
+
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
133 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
134 |
+
text_encoder (Qwen2_5_VLModel): The text encoder model.
|
135 |
+
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
|
136 |
+
"""
|
137 |
+
|
138 |
+
model_cpu_offload_seq = "mllm->transformer->vae"
|
139 |
+
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
transformer: OmniGen2Transformer2DModel,
|
143 |
+
vae: AutoencoderKL,
|
144 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
145 |
+
mllm: Qwen2_5_VLForConditionalGeneration,
|
146 |
+
processor,
|
147 |
+
) -> None:
|
148 |
+
"""
|
149 |
+
Initialize the OmniGen2 pipeline.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
transformer: The transformer model for image generation.
|
153 |
+
vae: The VAE model for image encoding/decoding.
|
154 |
+
scheduler: The scheduler for noise scheduling.
|
155 |
+
text_encoder: The text encoder model.
|
156 |
+
tokenizer: The tokenizer for text processing.
|
157 |
+
"""
|
158 |
+
super().__init__()
|
159 |
+
|
160 |
+
self.register_modules(
|
161 |
+
transformer=transformer,
|
162 |
+
vae=vae,
|
163 |
+
scheduler=scheduler,
|
164 |
+
mllm=mllm,
|
165 |
+
processor=processor
|
166 |
+
)
|
167 |
+
self.vae_scale_factor = (
|
168 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
169 |
+
)
|
170 |
+
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
171 |
+
self.default_sample_size = 128
|
172 |
+
|
173 |
+
def prepare_latents(
|
174 |
+
self,
|
175 |
+
batch_size: int,
|
176 |
+
num_channels_latents: int,
|
177 |
+
height: int,
|
178 |
+
width: int,
|
179 |
+
dtype: torch.dtype,
|
180 |
+
device: torch.device,
|
181 |
+
generator: Optional[torch.Generator],
|
182 |
+
latents: Optional[torch.FloatTensor] = None,
|
183 |
+
) -> torch.FloatTensor:
|
184 |
+
"""
|
185 |
+
Prepare the initial latents for the diffusion process.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
batch_size: The number of images to generate.
|
189 |
+
num_channels_latents: The number of channels in the latent space.
|
190 |
+
height: The height of the generated image.
|
191 |
+
width: The width of the generated image.
|
192 |
+
dtype: The data type of the latents.
|
193 |
+
device: The device to place the latents on.
|
194 |
+
generator: The random number generator to use.
|
195 |
+
latents: Optional pre-computed latents to use instead of random initialization.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
torch.FloatTensor: The prepared latents tensor.
|
199 |
+
"""
|
200 |
+
height = int(height) // self.vae_scale_factor
|
201 |
+
width = int(width) // self.vae_scale_factor
|
202 |
+
|
203 |
+
shape = (batch_size, num_channels_latents, height, width)
|
204 |
+
|
205 |
+
if latents is None:
|
206 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
207 |
+
else:
|
208 |
+
latents = latents.to(device)
|
209 |
+
return latents
|
210 |
+
|
211 |
+
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
|
212 |
+
"""
|
213 |
+
Encode an image into the VAE latent space.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
img: The input image tensor to encode.
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
torch.FloatTensor: The encoded latent representation.
|
220 |
+
"""
|
221 |
+
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
|
222 |
+
if self.vae.config.shift_factor is not None:
|
223 |
+
z0 = z0 - self.vae.config.shift_factor
|
224 |
+
if self.vae.config.scaling_factor is not None:
|
225 |
+
z0 = z0 * self.vae.config.scaling_factor
|
226 |
+
z0 = z0.to(dtype=self.vae.dtype)
|
227 |
+
return z0
|
228 |
+
|
229 |
+
def prepare_image(
|
230 |
+
self,
|
231 |
+
images: Union[List[PIL.Image.Image], PIL.Image.Image],
|
232 |
+
batch_size: int,
|
233 |
+
num_images_per_prompt: int,
|
234 |
+
max_pixels: int,
|
235 |
+
max_side_length: int,
|
236 |
+
device: torch.device,
|
237 |
+
dtype: torch.dtype,
|
238 |
+
) -> List[Optional[torch.FloatTensor]]:
|
239 |
+
"""
|
240 |
+
Prepare input images for processing by encoding them into the VAE latent space.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
images: Single image or list of images to process.
|
244 |
+
batch_size: The number of images to generate per prompt.
|
245 |
+
num_images_per_prompt: The number of images to generate for each prompt.
|
246 |
+
device: The device to place the encoded latents on.
|
247 |
+
dtype: The data type of the encoded latents.
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
|
251 |
+
"""
|
252 |
+
if batch_size == 1:
|
253 |
+
images = [images]
|
254 |
+
latents = []
|
255 |
+
for i, img in enumerate(images):
|
256 |
+
if img is not None and len(img) > 0:
|
257 |
+
ref_latents = []
|
258 |
+
for j, img_j in enumerate(img):
|
259 |
+
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
|
260 |
+
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
|
261 |
+
else:
|
262 |
+
ref_latents = None
|
263 |
+
for _ in range(num_images_per_prompt):
|
264 |
+
latents.append(ref_latents)
|
265 |
+
|
266 |
+
return latents
|
267 |
+
|
268 |
+
def _get_qwen2_prompt_embeds(
|
269 |
+
self,
|
270 |
+
prompt: Union[str, List[str]],
|
271 |
+
device: Optional[torch.device] = None,
|
272 |
+
max_sequence_length: int = 256,
|
273 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
274 |
+
"""
|
275 |
+
Get prompt embeddings from the Qwen2 text encoder.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
prompt: The prompt or list of prompts to encode.
|
279 |
+
device: The device to place the embeddings on. If None, uses the pipeline's device.
|
280 |
+
max_sequence_length: Maximum sequence length for tokenization.
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
284 |
+
- The prompt embeddings tensor
|
285 |
+
- The attention mask tensor
|
286 |
+
|
287 |
+
Raises:
|
288 |
+
Warning: If the input text is truncated due to sequence length limitations.
|
289 |
+
"""
|
290 |
+
device = device or self._execution_device
|
291 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
292 |
+
text_inputs = self.processor.tokenizer(
|
293 |
+
prompt,
|
294 |
+
padding="max_length",
|
295 |
+
max_length=max_sequence_length,
|
296 |
+
truncation=True,
|
297 |
+
return_tensors="pt",
|
298 |
+
)
|
299 |
+
|
300 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
301 |
+
untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
302 |
+
|
303 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
304 |
+
removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
305 |
+
logger.warning(
|
306 |
+
"The following part of your input was truncated because Gemma can only handle sequences up to"
|
307 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
308 |
+
)
|
309 |
+
|
310 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
311 |
+
prompt_embeds = self.mllm(
|
312 |
+
text_input_ids,
|
313 |
+
attention_mask=prompt_attention_mask,
|
314 |
+
output_hidden_states=True,
|
315 |
+
).hidden_states[-1]
|
316 |
+
|
317 |
+
if self.mllm is not None:
|
318 |
+
dtype = self.mllm.dtype
|
319 |
+
elif self.transformer is not None:
|
320 |
+
dtype = self.transformer.dtype
|
321 |
+
else:
|
322 |
+
dtype = None
|
323 |
+
|
324 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
325 |
+
|
326 |
+
return prompt_embeds, prompt_attention_mask
|
327 |
+
|
328 |
+
def _apply_chat_template(self, prompt: str):
|
329 |
+
prompt = [
|
330 |
+
{
|
331 |
+
"role": "system",
|
332 |
+
"content": "You are a helpful assistant that generates high-quality images based on user instructions.",
|
333 |
+
},
|
334 |
+
{"role": "user", "content": prompt},
|
335 |
+
]
|
336 |
+
prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
|
337 |
+
return prompt
|
338 |
+
|
339 |
+
def encode_prompt(
|
340 |
+
self,
|
341 |
+
prompt: Union[str, List[str]],
|
342 |
+
do_classifier_free_guidance: bool = True,
|
343 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
344 |
+
num_images_per_prompt: int = 1,
|
345 |
+
device: Optional[torch.device] = None,
|
346 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
347 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
348 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
349 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
350 |
+
max_sequence_length: int = 256,
|
351 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
352 |
+
r"""
|
353 |
+
Encodes the prompt into text encoder hidden states.
|
354 |
+
|
355 |
+
Args:
|
356 |
+
prompt (`str` or `List[str]`, *optional*):
|
357 |
+
prompt to be encoded
|
358 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
359 |
+
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
360 |
+
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
361 |
+
Lumina-T2I, this should be "".
|
362 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
363 |
+
whether to use classifier free guidance or not
|
364 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
365 |
+
number of images that should be generated per prompt
|
366 |
+
device: (`torch.device`, *optional*):
|
367 |
+
torch device to place the resulting embeddings on
|
368 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
369 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
370 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
371 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
372 |
+
Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
|
373 |
+
max_sequence_length (`int`, defaults to `256`):
|
374 |
+
Maximum sequence length to use for the prompt.
|
375 |
+
"""
|
376 |
+
device = device or self._execution_device
|
377 |
+
|
378 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
379 |
+
prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
|
380 |
+
|
381 |
+
if prompt is not None:
|
382 |
+
batch_size = len(prompt)
|
383 |
+
else:
|
384 |
+
batch_size = prompt_embeds.shape[0]
|
385 |
+
if prompt_embeds is None:
|
386 |
+
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
387 |
+
prompt=prompt,
|
388 |
+
device=device,
|
389 |
+
max_sequence_length=max_sequence_length
|
390 |
+
)
|
391 |
+
|
392 |
+
batch_size, seq_len, _ = prompt_embeds.shape
|
393 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
394 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
395 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
396 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
397 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
398 |
+
|
399 |
+
# Get negative embeddings for classifier free guidance
|
400 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
401 |
+
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
402 |
+
|
403 |
+
# Normalize str to list
|
404 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
405 |
+
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
|
406 |
+
|
407 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
408 |
+
raise TypeError(
|
409 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
410 |
+
f" {type(prompt)}."
|
411 |
+
)
|
412 |
+
elif isinstance(negative_prompt, str):
|
413 |
+
negative_prompt = [negative_prompt]
|
414 |
+
elif batch_size != len(negative_prompt):
|
415 |
+
raise ValueError(
|
416 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
417 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
418 |
+
" the batch size of `prompt`."
|
419 |
+
)
|
420 |
+
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
421 |
+
prompt=negative_prompt,
|
422 |
+
device=device,
|
423 |
+
max_sequence_length=max_sequence_length,
|
424 |
+
)
|
425 |
+
|
426 |
+
batch_size, seq_len, _ = negative_prompt_embeds.shape
|
427 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
428 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
429 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
430 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
431 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
432 |
+
batch_size * num_images_per_prompt, -1
|
433 |
+
)
|
434 |
+
|
435 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
436 |
+
|
437 |
+
@property
|
438 |
+
def num_timesteps(self):
|
439 |
+
return self._num_timesteps
|
440 |
+
|
441 |
+
@property
|
442 |
+
def text_guidance_scale(self):
|
443 |
+
return self._text_guidance_scale
|
444 |
+
|
445 |
+
@property
|
446 |
+
def image_guidance_scale(self):
|
447 |
+
return self._image_guidance_scale
|
448 |
+
|
449 |
+
@property
|
450 |
+
def cfg_range(self):
|
451 |
+
return self._cfg_range
|
452 |
+
|
453 |
+
@torch.no_grad()
|
454 |
+
def __call__(
|
455 |
+
self,
|
456 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
457 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
458 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
459 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
460 |
+
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
461 |
+
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
462 |
+
max_sequence_length: Optional[int] = None,
|
463 |
+
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
464 |
+
input_images: Optional[List[PIL.Image.Image]] = None,
|
465 |
+
num_images_per_prompt: int = 1,
|
466 |
+
height: Optional[int] = None,
|
467 |
+
width: Optional[int] = None,
|
468 |
+
max_pixels: int = 1024 * 1024,
|
469 |
+
max_input_image_side_length: int = 1024,
|
470 |
+
align_res: bool = True,
|
471 |
+
num_inference_steps: int = 28,
|
472 |
+
text_guidance_scale: float = 4.0,
|
473 |
+
image_guidance_scale: float = 1.0,
|
474 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
475 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
476 |
+
timesteps: List[int] = None,
|
477 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
478 |
+
latents: Optional[torch.FloatTensor] = None,
|
479 |
+
output_type: Optional[str] = "pil",
|
480 |
+
return_dict: bool = True,
|
481 |
+
verbose: bool = False,
|
482 |
+
step_func=None,
|
483 |
+
):
|
484 |
+
|
485 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
486 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
487 |
+
|
488 |
+
self._text_guidance_scale = text_guidance_scale
|
489 |
+
self._image_guidance_scale = image_guidance_scale
|
490 |
+
self._cfg_range = cfg_range
|
491 |
+
self._attention_kwargs = attention_kwargs
|
492 |
+
|
493 |
+
# 2. Define call parameters
|
494 |
+
if prompt is not None and isinstance(prompt, str):
|
495 |
+
batch_size = 1
|
496 |
+
elif prompt is not None and isinstance(prompt, list):
|
497 |
+
batch_size = len(prompt)
|
498 |
+
else:
|
499 |
+
batch_size = prompt_embeds.shape[0]
|
500 |
+
|
501 |
+
device = self._execution_device
|
502 |
+
|
503 |
+
# 3. Encode input prompt
|
504 |
+
(
|
505 |
+
prompt_embeds,
|
506 |
+
prompt_attention_mask,
|
507 |
+
negative_prompt_embeds,
|
508 |
+
negative_prompt_attention_mask,
|
509 |
+
) = self.encode_prompt(
|
510 |
+
prompt,
|
511 |
+
self.text_guidance_scale > 1.0,
|
512 |
+
negative_prompt=negative_prompt,
|
513 |
+
num_images_per_prompt=num_images_per_prompt,
|
514 |
+
device=device,
|
515 |
+
prompt_embeds=prompt_embeds,
|
516 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
517 |
+
prompt_attention_mask=prompt_attention_mask,
|
518 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
519 |
+
max_sequence_length=max_sequence_length,
|
520 |
+
)
|
521 |
+
|
522 |
+
dtype = self.vae.dtype
|
523 |
+
# 3. Prepare control image
|
524 |
+
ref_latents = self.prepare_image(
|
525 |
+
images=input_images,
|
526 |
+
batch_size=batch_size,
|
527 |
+
num_images_per_prompt=num_images_per_prompt,
|
528 |
+
max_pixels=max_pixels,
|
529 |
+
max_side_length=max_input_image_side_length,
|
530 |
+
device=device,
|
531 |
+
dtype=dtype,
|
532 |
+
)
|
533 |
+
|
534 |
+
if input_images is None:
|
535 |
+
input_images = []
|
536 |
+
|
537 |
+
if len(input_images) == 1 and align_res:
|
538 |
+
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
|
539 |
+
ori_width, ori_height = width, height
|
540 |
+
else:
|
541 |
+
ori_width, ori_height = width, height
|
542 |
+
|
543 |
+
cur_pixels = height * width
|
544 |
+
ratio = (max_pixels / cur_pixels) ** 0.5
|
545 |
+
ratio = min(ratio, 1.0)
|
546 |
+
|
547 |
+
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
|
548 |
+
|
549 |
+
if len(input_images) == 0:
|
550 |
+
self._image_guidance_scale = 1
|
551 |
+
|
552 |
+
# 4. Prepare latents.
|
553 |
+
latent_channels = self.transformer.config.in_channels
|
554 |
+
latents = self.prepare_latents(
|
555 |
+
batch_size * num_images_per_prompt,
|
556 |
+
latent_channels,
|
557 |
+
height,
|
558 |
+
width,
|
559 |
+
prompt_embeds.dtype,
|
560 |
+
device,
|
561 |
+
generator,
|
562 |
+
latents,
|
563 |
+
)
|
564 |
+
|
565 |
+
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
|
566 |
+
self.transformer.config.axes_dim_rope,
|
567 |
+
self.transformer.config.axes_lens,
|
568 |
+
theta=10000,
|
569 |
+
)
|
570 |
+
|
571 |
+
image = self.processing(
|
572 |
+
latents=latents,
|
573 |
+
ref_latents=ref_latents,
|
574 |
+
prompt_embeds=prompt_embeds,
|
575 |
+
freqs_cis=freqs_cis,
|
576 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
577 |
+
prompt_attention_mask=prompt_attention_mask,
|
578 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
579 |
+
num_inference_steps=num_inference_steps,
|
580 |
+
timesteps=timesteps,
|
581 |
+
device=device,
|
582 |
+
dtype=dtype,
|
583 |
+
verbose=verbose,
|
584 |
+
step_func=step_func,
|
585 |
+
)
|
586 |
+
|
587 |
+
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
|
588 |
+
|
589 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
590 |
+
|
591 |
+
# Offload all models
|
592 |
+
self.maybe_free_model_hooks()
|
593 |
+
|
594 |
+
if not return_dict:
|
595 |
+
return image
|
596 |
+
else:
|
597 |
+
return FMPipelineOutput(images=image)
|
598 |
+
|
599 |
+
def processing(
|
600 |
+
self,
|
601 |
+
latents,
|
602 |
+
ref_latents,
|
603 |
+
prompt_embeds,
|
604 |
+
freqs_cis,
|
605 |
+
negative_prompt_embeds,
|
606 |
+
prompt_attention_mask,
|
607 |
+
negative_prompt_attention_mask,
|
608 |
+
num_inference_steps,
|
609 |
+
timesteps,
|
610 |
+
device,
|
611 |
+
dtype,
|
612 |
+
verbose,
|
613 |
+
step_func=None
|
614 |
+
):
|
615 |
+
batch_size = latents.shape[0]
|
616 |
+
|
617 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
618 |
+
self.scheduler,
|
619 |
+
num_inference_steps,
|
620 |
+
device,
|
621 |
+
timesteps,
|
622 |
+
num_tokens=latents.shape[-2] * latents.shape[-1]
|
623 |
+
)
|
624 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
625 |
+
self._num_timesteps = len(timesteps)
|
626 |
+
|
627 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
628 |
+
for i, t in enumerate(timesteps):
|
629 |
+
model_pred = self.predict(
|
630 |
+
t=t,
|
631 |
+
latents=latents,
|
632 |
+
prompt_embeds=prompt_embeds,
|
633 |
+
freqs_cis=freqs_cis,
|
634 |
+
prompt_attention_mask=prompt_attention_mask,
|
635 |
+
ref_image_hidden_states=ref_latents,
|
636 |
+
)
|
637 |
+
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
638 |
+
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
639 |
+
|
640 |
+
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
641 |
+
model_pred_ref = self.predict(
|
642 |
+
t=t,
|
643 |
+
latents=latents,
|
644 |
+
prompt_embeds=negative_prompt_embeds,
|
645 |
+
freqs_cis=freqs_cis,
|
646 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
647 |
+
ref_image_hidden_states=ref_latents,
|
648 |
+
)
|
649 |
+
|
650 |
+
if image_guidance_scale != 1:
|
651 |
+
model_pred_uncond = self.predict(
|
652 |
+
t=t,
|
653 |
+
latents=latents,
|
654 |
+
prompt_embeds=negative_prompt_embeds,
|
655 |
+
freqs_cis=freqs_cis,
|
656 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
657 |
+
ref_image_hidden_states=None,
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
model_pred_uncond = torch.zeros_like(model_pred)
|
661 |
+
|
662 |
+
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
|
663 |
+
text_guidance_scale * (model_pred - model_pred_ref)
|
664 |
+
elif text_guidance_scale > 1.0:
|
665 |
+
model_pred_uncond = self.predict(
|
666 |
+
t=t,
|
667 |
+
latents=latents,
|
668 |
+
prompt_embeds=negative_prompt_embeds,
|
669 |
+
freqs_cis=freqs_cis,
|
670 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
671 |
+
ref_image_hidden_states=None,
|
672 |
+
)
|
673 |
+
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
|
674 |
+
|
675 |
+
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
676 |
+
|
677 |
+
latents = latents.to(dtype=dtype)
|
678 |
+
|
679 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
680 |
+
progress_bar.update()
|
681 |
+
|
682 |
+
if step_func is not None:
|
683 |
+
step_func(i, self._num_timesteps)
|
684 |
+
|
685 |
+
latents = latents.to(dtype=dtype)
|
686 |
+
if self.vae.config.scaling_factor is not None:
|
687 |
+
latents = latents / self.vae.config.scaling_factor
|
688 |
+
if self.vae.config.shift_factor is not None:
|
689 |
+
latents = latents + self.vae.config.shift_factor
|
690 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
691 |
+
|
692 |
+
return image
|
693 |
+
|
694 |
+
def predict(
|
695 |
+
self,
|
696 |
+
t,
|
697 |
+
latents,
|
698 |
+
prompt_embeds,
|
699 |
+
freqs_cis,
|
700 |
+
prompt_attention_mask,
|
701 |
+
ref_image_hidden_states,
|
702 |
+
):
|
703 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
704 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
705 |
+
|
706 |
+
batch_size, num_channels_latents, height, width = latents.shape
|
707 |
+
|
708 |
+
optional_kwargs = {}
|
709 |
+
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
|
710 |
+
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
|
711 |
+
|
712 |
+
model_pred = self.transformer(
|
713 |
+
latents,
|
714 |
+
timestep,
|
715 |
+
prompt_embeds,
|
716 |
+
freqs_cis,
|
717 |
+
prompt_attention_mask,
|
718 |
+
**optional_kwargs
|
719 |
+
)
|
720 |
+
return model_pred
|
omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py
ADDED
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
OmniGen2 Diffusion Pipeline
|
3 |
+
|
4 |
+
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
5 |
+
|
6 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
you may not use this file except in compliance with the License.
|
8 |
+
You may obtain a copy of the License at
|
9 |
+
|
10 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
|
12 |
+
Unless required by applicable law or agreed to in writing, software
|
13 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
See the License for the specific language governing permissions and
|
16 |
+
limitations under the License.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import inspect
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import math
|
23 |
+
|
24 |
+
from PIL import Image
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
from transformers import Qwen2_5_VLForConditionalGeneration
|
30 |
+
|
31 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
32 |
+
from ...models.transformers import OmniGen2Transformer2DModel
|
33 |
+
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
|
34 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
35 |
+
from diffusers.utils import (
|
36 |
+
is_torch_xla_available,
|
37 |
+
logging,
|
38 |
+
)
|
39 |
+
from diffusers.utils.torch_utils import randn_tensor
|
40 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
41 |
+
|
42 |
+
from dataclasses import dataclass
|
43 |
+
|
44 |
+
import PIL.Image
|
45 |
+
|
46 |
+
from diffusers.utils import BaseOutput
|
47 |
+
|
48 |
+
from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
|
49 |
+
|
50 |
+
if is_torch_xla_available():
|
51 |
+
import torch_xla.core.xla_model as xm
|
52 |
+
|
53 |
+
XLA_AVAILABLE = True
|
54 |
+
else:
|
55 |
+
XLA_AVAILABLE = False
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class OmniGen2PipelineOutput(BaseOutput):
|
62 |
+
"""
|
63 |
+
Output class for OmniGen2 pipeline.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
images (Union[List[PIL.Image.Image], np.ndarray]):
|
67 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape
|
68 |
+
`(batch_size, height, width, num_channels)`. Contains the generated images.
|
69 |
+
"""
|
70 |
+
text: str
|
71 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
72 |
+
|
73 |
+
|
74 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
75 |
+
def retrieve_timesteps(
|
76 |
+
scheduler,
|
77 |
+
num_inference_steps: Optional[int] = None,
|
78 |
+
device: Optional[Union[str, torch.device]] = None,
|
79 |
+
timesteps: Optional[List[int]] = None,
|
80 |
+
**kwargs,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
84 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
scheduler (`SchedulerMixin`):
|
88 |
+
The scheduler to get timesteps from.
|
89 |
+
num_inference_steps (`int`):
|
90 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
91 |
+
must be `None`.
|
92 |
+
device (`str` or `torch.device`, *optional*):
|
93 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
94 |
+
timesteps (`List[int]`, *optional*):
|
95 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
96 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
97 |
+
sigmas (`List[float]`, *optional*):
|
98 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
99 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
103 |
+
second element is the number of inference steps.
|
104 |
+
"""
|
105 |
+
if timesteps is not None:
|
106 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
107 |
+
if not accepts_timesteps:
|
108 |
+
raise ValueError(
|
109 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
110 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
111 |
+
)
|
112 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
113 |
+
timesteps = scheduler.timesteps
|
114 |
+
num_inference_steps = len(timesteps)
|
115 |
+
else:
|
116 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
117 |
+
timesteps = scheduler.timesteps
|
118 |
+
return timesteps, num_inference_steps
|
119 |
+
|
120 |
+
|
121 |
+
class OmniGen2ChatPipeline(DiffusionPipeline):
|
122 |
+
"""
|
123 |
+
Pipeline for text-to-image generation using OmniGen2.
|
124 |
+
|
125 |
+
This pipeline implements a text-to-image generation model that uses:
|
126 |
+
- Qwen2.5-VL for text encoding
|
127 |
+
- A custom transformer architecture for image generation
|
128 |
+
- VAE for image encoding/decoding
|
129 |
+
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
130 |
+
|
131 |
+
Args:
|
132 |
+
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
133 |
+
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
134 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
135 |
+
text_encoder (Qwen2_5_VLModel): The text encoder model.
|
136 |
+
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
|
137 |
+
"""
|
138 |
+
|
139 |
+
model_cpu_offload_seq = "mllm->transformer->vae"
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
transformer: OmniGen2Transformer2DModel,
|
143 |
+
vae: AutoencoderKL,
|
144 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
145 |
+
mllm: Qwen2_5_VLForConditionalGeneration,
|
146 |
+
processor,
|
147 |
+
) -> None:
|
148 |
+
"""
|
149 |
+
Initialize the OmniGen2 pipeline.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
transformer: The transformer model for image generation.
|
153 |
+
vae: The VAE model for image encoding/decoding.
|
154 |
+
scheduler: The scheduler for noise scheduling.
|
155 |
+
text_encoder: The text encoder model.
|
156 |
+
tokenizer: The tokenizer for text processing.
|
157 |
+
"""
|
158 |
+
super().__init__()
|
159 |
+
|
160 |
+
self.register_modules(
|
161 |
+
transformer=transformer,
|
162 |
+
vae=vae,
|
163 |
+
scheduler=scheduler,
|
164 |
+
mllm=mllm,
|
165 |
+
processor=processor
|
166 |
+
)
|
167 |
+
self.vae_scale_factor = (
|
168 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
169 |
+
)
|
170 |
+
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
171 |
+
self.default_sample_size = 128
|
172 |
+
|
173 |
+
def prepare_latents(
|
174 |
+
self,
|
175 |
+
batch_size: int,
|
176 |
+
num_channels_latents: int,
|
177 |
+
height: int,
|
178 |
+
width: int,
|
179 |
+
dtype: torch.dtype,
|
180 |
+
device: torch.device,
|
181 |
+
generator: Optional[torch.Generator],
|
182 |
+
latents: Optional[torch.FloatTensor] = None,
|
183 |
+
) -> torch.FloatTensor:
|
184 |
+
"""
|
185 |
+
Prepare the initial latents for the diffusion process.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
batch_size: The number of images to generate.
|
189 |
+
num_channels_latents: The number of channels in the latent space.
|
190 |
+
height: The height of the generated image.
|
191 |
+
width: The width of the generated image.
|
192 |
+
dtype: The data type of the latents.
|
193 |
+
device: The device to place the latents on.
|
194 |
+
generator: The random number generator to use.
|
195 |
+
latents: Optional pre-computed latents to use instead of random initialization.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
torch.FloatTensor: The prepared latents tensor.
|
199 |
+
"""
|
200 |
+
height = int(height) // self.vae_scale_factor
|
201 |
+
width = int(width) // self.vae_scale_factor
|
202 |
+
|
203 |
+
shape = (batch_size, num_channels_latents, height, width)
|
204 |
+
|
205 |
+
if latents is None:
|
206 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
207 |
+
else:
|
208 |
+
latents = latents.to(device)
|
209 |
+
return latents
|
210 |
+
|
211 |
+
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
|
212 |
+
"""
|
213 |
+
Encode an image into the VAE latent space.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
img: The input image tensor to encode.
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
torch.FloatTensor: The encoded latent representation.
|
220 |
+
"""
|
221 |
+
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
|
222 |
+
if self.vae.config.shift_factor is not None:
|
223 |
+
z0 = z0 - self.vae.config.shift_factor
|
224 |
+
if self.vae.config.scaling_factor is not None:
|
225 |
+
z0 = z0 * self.vae.config.scaling_factor
|
226 |
+
z0 = z0.to(dtype=self.vae.dtype)
|
227 |
+
return z0
|
228 |
+
|
229 |
+
def prepare_image(
|
230 |
+
self,
|
231 |
+
images: Union[List[PIL.Image.Image], PIL.Image.Image],
|
232 |
+
batch_size: int,
|
233 |
+
num_images_per_prompt: int,
|
234 |
+
max_pixels: int,
|
235 |
+
max_side_length: int,
|
236 |
+
device: torch.device,
|
237 |
+
dtype: torch.dtype,
|
238 |
+
) -> List[Optional[torch.FloatTensor]]:
|
239 |
+
"""
|
240 |
+
Prepare input images for processing by encoding them into the VAE latent space.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
images: Single image or list of images to process.
|
244 |
+
batch_size: The number of images to generate per prompt.
|
245 |
+
num_images_per_prompt: The number of images to generate for each prompt.
|
246 |
+
device: The device to place the encoded latents on.
|
247 |
+
dtype: The data type of the encoded latents.
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
|
251 |
+
"""
|
252 |
+
if batch_size == 1:
|
253 |
+
images = [images]
|
254 |
+
latents = []
|
255 |
+
for i, img in enumerate(images):
|
256 |
+
if img is not None and len(img) > 0:
|
257 |
+
ref_latents = []
|
258 |
+
for j, img_j in enumerate(img):
|
259 |
+
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
|
260 |
+
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
|
261 |
+
else:
|
262 |
+
ref_latents = None
|
263 |
+
for _ in range(num_images_per_prompt):
|
264 |
+
latents.append(ref_latents)
|
265 |
+
|
266 |
+
return latents
|
267 |
+
|
268 |
+
def _apply_chat_template(self, prompt: str, images: List = None):
|
269 |
+
if images is not None:
|
270 |
+
prompt += "".join(
|
271 |
+
[
|
272 |
+
f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>"
|
273 |
+
for i in range(1, len(images) + 1)
|
274 |
+
]
|
275 |
+
)
|
276 |
+
prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
277 |
+
return prompt
|
278 |
+
|
279 |
+
def _get_qwen2_prompt_embeds(
|
280 |
+
self,
|
281 |
+
prompt: Union[str, List[str]],
|
282 |
+
input_images = None,
|
283 |
+
device: Optional[torch.device] = None,
|
284 |
+
use_only_text_hidden_states: bool = True,
|
285 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
286 |
+
"""
|
287 |
+
Get prompt embeddings from the Qwen2 text encoder.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
prompt: The prompt or list of prompts to encode.
|
291 |
+
device: The device to place the embeddings on. If None, uses the pipeline's device.
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
295 |
+
- The prompt embeddings tensor
|
296 |
+
- The attention mask tensor
|
297 |
+
|
298 |
+
Raises:
|
299 |
+
Warning: If the input text is truncated due to sequence length limitations.
|
300 |
+
"""
|
301 |
+
device = device or self._execution_device
|
302 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
303 |
+
|
304 |
+
inputs = self.processor(
|
305 |
+
text=prompt,
|
306 |
+
images=input_images,
|
307 |
+
videos=None,
|
308 |
+
padding=True,
|
309 |
+
return_tensors="pt",
|
310 |
+
)
|
311 |
+
inputs = inputs.to(device)
|
312 |
+
|
313 |
+
prompt_embeds = self.mllm(
|
314 |
+
**inputs,
|
315 |
+
output_hidden_states=True,
|
316 |
+
).hidden_states[-1]
|
317 |
+
|
318 |
+
text_input_ids = inputs.input_ids
|
319 |
+
text_mask = inputs.attention_mask
|
320 |
+
if use_only_text_hidden_states:
|
321 |
+
mask = text_input_ids != self.mllm.config.image_token_id
|
322 |
+
mask = mask & text_mask
|
323 |
+
mask = mask.bool()
|
324 |
+
|
325 |
+
text_l = mask.sum(dim=-1)
|
326 |
+
max_l = text_l.max()
|
327 |
+
text_batch_size = prompt_embeds.size(0)
|
328 |
+
new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
329 |
+
new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
|
330 |
+
for i in range(text_batch_size):
|
331 |
+
new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
|
332 |
+
new_text_mask[i, :text_l[i]] = 1
|
333 |
+
|
334 |
+
prompt_embeds = new_prompt_embeds
|
335 |
+
text_mask = new_text_mask
|
336 |
+
|
337 |
+
prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
|
338 |
+
return prompt_embeds, text_mask
|
339 |
+
|
340 |
+
|
341 |
+
def encode_prompt(
|
342 |
+
self,
|
343 |
+
prompt: Union[str, List[str]],
|
344 |
+
input_images: Optional[Union[str, List[str]]] = None,
|
345 |
+
do_classifier_free_guidance: bool = True,
|
346 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
347 |
+
num_images_per_prompt: int = 1,
|
348 |
+
device: Optional[torch.device] = None,
|
349 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
350 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
351 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
352 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
353 |
+
max_sequence_length: int = 256,
|
354 |
+
use_text_encoder_penultimate_layer_feats: bool = False
|
355 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
356 |
+
r"""
|
357 |
+
Encodes the prompt into text encoder hidden states.
|
358 |
+
|
359 |
+
Args:
|
360 |
+
prompt (`str` or `List[str]`, *optional*):
|
361 |
+
prompt to be encoded
|
362 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
363 |
+
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
364 |
+
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
365 |
+
Lumina-T2I, this should be "".
|
366 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
367 |
+
whether to use classifier free guidance or not
|
368 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
369 |
+
number of images that should be generated per prompt
|
370 |
+
device: (`torch.device`, *optional*):
|
371 |
+
torch device to place the resulting embeddings on
|
372 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
373 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
374 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
375 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
376 |
+
Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
|
377 |
+
max_sequence_length (`int`, defaults to `256`):
|
378 |
+
Maximum sequence length to use for the prompt.
|
379 |
+
"""
|
380 |
+
device = device or self._execution_device
|
381 |
+
|
382 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
383 |
+
|
384 |
+
if prompt is not None:
|
385 |
+
batch_size = len(prompt)
|
386 |
+
else:
|
387 |
+
batch_size = prompt_embeds.shape[0]
|
388 |
+
if prompt_embeds is None:
|
389 |
+
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
390 |
+
prompt=prompt,
|
391 |
+
input_images=input_images,
|
392 |
+
device=device,
|
393 |
+
)
|
394 |
+
|
395 |
+
batch_size, seq_len, _ = prompt_embeds.shape
|
396 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
397 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
398 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
399 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
400 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
401 |
+
|
402 |
+
# Get negative embeddings for classifier free guidance
|
403 |
+
negative_prompt_embeds, negative_prompt_attention_mask = None, None
|
404 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
405 |
+
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
406 |
+
|
407 |
+
# Normalize str to list
|
408 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
409 |
+
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
|
410 |
+
|
411 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
412 |
+
raise TypeError(
|
413 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
414 |
+
f" {type(prompt)}."
|
415 |
+
)
|
416 |
+
elif isinstance(negative_prompt, str):
|
417 |
+
negative_prompt = [negative_prompt]
|
418 |
+
elif batch_size != len(negative_prompt):
|
419 |
+
raise ValueError(
|
420 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
421 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
422 |
+
" the batch size of `prompt`."
|
423 |
+
)
|
424 |
+
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
425 |
+
prompt=negative_prompt,
|
426 |
+
device=device,
|
427 |
+
)
|
428 |
+
|
429 |
+
batch_size, seq_len, _ = negative_prompt_embeds.shape
|
430 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
431 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
432 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
433 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
434 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
435 |
+
batch_size * num_images_per_prompt, -1
|
436 |
+
)
|
437 |
+
|
438 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
439 |
+
|
440 |
+
@property
|
441 |
+
def num_timesteps(self):
|
442 |
+
return self._num_timesteps
|
443 |
+
|
444 |
+
@property
|
445 |
+
def text_guidance_scale(self):
|
446 |
+
return self._text_guidance_scale
|
447 |
+
|
448 |
+
@property
|
449 |
+
def image_guidance_scale(self):
|
450 |
+
return self._image_guidance_scale
|
451 |
+
|
452 |
+
@property
|
453 |
+
def cfg_range(self):
|
454 |
+
return self._cfg_range
|
455 |
+
|
456 |
+
def prepare_inputs_for_text_generation(self, prompts, input_images, device):
|
457 |
+
if isinstance(prompts, str):
|
458 |
+
prompts = [prompts]
|
459 |
+
|
460 |
+
ori_padding_side = self.processor.tokenizer.padding_side
|
461 |
+
self.processor.tokenizer.padding_side = "left"
|
462 |
+
inputs = self.processor(
|
463 |
+
text=prompts,
|
464 |
+
images=input_images,
|
465 |
+
videos=None,
|
466 |
+
padding=True,
|
467 |
+
return_tensors="pt",
|
468 |
+
).to(device)
|
469 |
+
self.processor.tokenizer.padding_side = ori_padding_side
|
470 |
+
return inputs
|
471 |
+
|
472 |
+
def generate_text(self, prompt, input_images):
|
473 |
+
inputs = self.prepare_inputs_for_text_generation(
|
474 |
+
prompt, input_images, self.mllm.device
|
475 |
+
)
|
476 |
+
generated_ids = self.mllm.generate(
|
477 |
+
**inputs,
|
478 |
+
tokenizer=self.processor.tokenizer,
|
479 |
+
max_new_tokens=256,
|
480 |
+
stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
|
481 |
+
) # stop_words=[151643, 151645, 151665]
|
482 |
+
generated_ids_trimmed = [
|
483 |
+
out_ids[len(in_ids) :]
|
484 |
+
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
485 |
+
]
|
486 |
+
output_texts = self.processor.batch_decode(
|
487 |
+
generated_ids_trimmed,
|
488 |
+
# skip_special_tokens=True,
|
489 |
+
skip_special_tokens=False,
|
490 |
+
clean_up_tokenization_spaces=False,
|
491 |
+
)
|
492 |
+
return output_texts
|
493 |
+
|
494 |
+
def generate_image(
|
495 |
+
self,
|
496 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
497 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
498 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
499 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
500 |
+
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
501 |
+
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
502 |
+
use_text_encoder_penultimate_layer_feats: bool = False,
|
503 |
+
max_sequence_length: Optional[int] = None,
|
504 |
+
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
505 |
+
input_images: Optional[List[PIL.Image.Image]] = None,
|
506 |
+
num_images_per_prompt: int = 1,
|
507 |
+
height: Optional[int] = None,
|
508 |
+
width: Optional[int] = None,
|
509 |
+
max_pixels: int = 1024 * 1024,
|
510 |
+
max_input_image_side_length: int = 1024,
|
511 |
+
align_res: bool = True,
|
512 |
+
num_inference_steps: int = 28,
|
513 |
+
text_guidance_scale: float = 4.0,
|
514 |
+
image_guidance_scale: float = 1.0,
|
515 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
516 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
517 |
+
timesteps: List[int] = None,
|
518 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
519 |
+
latents: Optional[torch.FloatTensor] = None,
|
520 |
+
output_type: Optional[str] = "pil",
|
521 |
+
return_dict: bool = True,
|
522 |
+
verbose: bool = False,
|
523 |
+
step_func=None,
|
524 |
+
):
|
525 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
526 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
527 |
+
|
528 |
+
self._text_guidance_scale = text_guidance_scale
|
529 |
+
self._image_guidance_scale = image_guidance_scale
|
530 |
+
self._cfg_range = cfg_range
|
531 |
+
self._attention_kwargs = attention_kwargs
|
532 |
+
|
533 |
+
# 2. Define call parameters
|
534 |
+
if prompt is not None and isinstance(prompt, str):
|
535 |
+
batch_size = 1
|
536 |
+
elif prompt is not None and isinstance(prompt, list):
|
537 |
+
batch_size = len(prompt)
|
538 |
+
else:
|
539 |
+
batch_size = prompt_embeds.shape[0]
|
540 |
+
|
541 |
+
device = self._execution_device
|
542 |
+
|
543 |
+
# 3. Encode input promptb
|
544 |
+
(
|
545 |
+
prompt_embeds,
|
546 |
+
prompt_attention_mask,
|
547 |
+
negative_prompt_embeds,
|
548 |
+
negative_prompt_attention_mask,
|
549 |
+
) = self.encode_prompt(
|
550 |
+
prompt,
|
551 |
+
input_images,
|
552 |
+
self.text_guidance_scale > 1.0,
|
553 |
+
negative_prompt=negative_prompt,
|
554 |
+
num_images_per_prompt=num_images_per_prompt,
|
555 |
+
device=device,
|
556 |
+
prompt_embeds=prompt_embeds,
|
557 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
558 |
+
prompt_attention_mask=prompt_attention_mask,
|
559 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
560 |
+
max_sequence_length=max_sequence_length,
|
561 |
+
use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
|
562 |
+
)
|
563 |
+
|
564 |
+
dtype = self.vae.dtype
|
565 |
+
# 3. Prepare control image
|
566 |
+
ref_latents = self.prepare_image(
|
567 |
+
images=input_images,
|
568 |
+
batch_size=batch_size,
|
569 |
+
num_images_per_prompt=num_images_per_prompt,
|
570 |
+
max_pixels=max_pixels,
|
571 |
+
max_side_length=max_input_image_side_length,
|
572 |
+
device=device,
|
573 |
+
dtype=dtype,
|
574 |
+
)
|
575 |
+
|
576 |
+
if input_images is None:
|
577 |
+
input_images = []
|
578 |
+
|
579 |
+
if len(input_images) == 1 and align_res:
|
580 |
+
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
|
581 |
+
ori_width, ori_height = width, height
|
582 |
+
else:
|
583 |
+
ori_width, ori_height = width, height
|
584 |
+
|
585 |
+
cur_pixels = height * width
|
586 |
+
ratio = (max_pixels / cur_pixels) ** 0.5
|
587 |
+
ratio = min(ratio, 1.0)
|
588 |
+
|
589 |
+
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
|
590 |
+
|
591 |
+
if len(input_images) == 0:
|
592 |
+
self._image_guidance_scale = 1
|
593 |
+
|
594 |
+
# 4. Prepare latents.
|
595 |
+
latent_channels = self.transformer.config.in_channels
|
596 |
+
latents = self.prepare_latents(
|
597 |
+
batch_size * num_images_per_prompt,
|
598 |
+
latent_channels,
|
599 |
+
height,
|
600 |
+
width,
|
601 |
+
prompt_embeds.dtype,
|
602 |
+
device,
|
603 |
+
generator,
|
604 |
+
latents,
|
605 |
+
)
|
606 |
+
|
607 |
+
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
|
608 |
+
self.transformer.config.axes_dim_rope,
|
609 |
+
self.transformer.config.axes_lens,
|
610 |
+
theta=10000,
|
611 |
+
)
|
612 |
+
|
613 |
+
image = self.processing(
|
614 |
+
latents=latents,
|
615 |
+
ref_latents=ref_latents,
|
616 |
+
prompt_embeds=prompt_embeds,
|
617 |
+
freqs_cis=freqs_cis,
|
618 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
619 |
+
prompt_attention_mask=prompt_attention_mask,
|
620 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
621 |
+
num_inference_steps=num_inference_steps,
|
622 |
+
timesteps=timesteps,
|
623 |
+
device=device,
|
624 |
+
dtype=dtype,
|
625 |
+
verbose=verbose,
|
626 |
+
step_func=step_func,
|
627 |
+
)
|
628 |
+
|
629 |
+
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
|
630 |
+
|
631 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
632 |
+
|
633 |
+
# Offload all models
|
634 |
+
self.maybe_free_model_hooks()
|
635 |
+
return image
|
636 |
+
|
637 |
+
@torch.no_grad()
|
638 |
+
def __call__(
|
639 |
+
self,
|
640 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
641 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
642 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
643 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
644 |
+
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
645 |
+
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
646 |
+
use_text_encoder_penultimate_layer_feats: bool = False,
|
647 |
+
max_sequence_length: Optional[int] = None,
|
648 |
+
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
649 |
+
input_images: Optional[List[PIL.Image.Image]] = None,
|
650 |
+
num_images_per_prompt: int = 1,
|
651 |
+
height: Optional[int] = 1024,
|
652 |
+
width: Optional[int] = 1024,
|
653 |
+
max_pixels: Optional[int] = 1024 * 1024,
|
654 |
+
max_input_image_side_length: int = 1024,
|
655 |
+
align_res: bool = True,
|
656 |
+
num_inference_steps: int = 28,
|
657 |
+
text_guidance_scale: float = 4.0,
|
658 |
+
image_guidance_scale: float = 1.0,
|
659 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
660 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
661 |
+
timesteps: List[int] = None,
|
662 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
663 |
+
latents: Optional[torch.FloatTensor] = None,
|
664 |
+
output_type: Optional[str] = "pil",
|
665 |
+
return_dict: bool = True,
|
666 |
+
verbose: bool = False,
|
667 |
+
step_func=None,
|
668 |
+
):
|
669 |
+
assert isinstance(prompt, str), "prompt must be a string since chat mode only support one prompt per turn"
|
670 |
+
|
671 |
+
# input_images = self.preprocess_images(input_images, max_input_image_size)
|
672 |
+
prompt = self._apply_chat_template(prompt, input_images)
|
673 |
+
generated_text = self.generate_text(prompt, input_images)[0]
|
674 |
+
|
675 |
+
images = None
|
676 |
+
if generated_text.startswith("<|img|>"):
|
677 |
+
#TODO: reuse the hidden state when generate text instead of re-generating
|
678 |
+
prompt = prompt + generated_text.split("<|img|>")[0]
|
679 |
+
images = self.generate_image(
|
680 |
+
prompt=prompt,
|
681 |
+
negative_prompt=negative_prompt,
|
682 |
+
use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
|
683 |
+
max_sequence_length=max_sequence_length,
|
684 |
+
input_images=input_images,
|
685 |
+
num_images_per_prompt=num_images_per_prompt,
|
686 |
+
height=height,
|
687 |
+
width=width,
|
688 |
+
max_pixels=max_pixels,
|
689 |
+
max_input_image_side_length=max_input_image_side_length,
|
690 |
+
align_res=align_res,
|
691 |
+
num_inference_steps=num_inference_steps,
|
692 |
+
text_guidance_scale=text_guidance_scale,
|
693 |
+
image_guidance_scale=image_guidance_scale,
|
694 |
+
cfg_range=cfg_range,
|
695 |
+
timesteps=timesteps,
|
696 |
+
generator=generator,
|
697 |
+
latents=latents,
|
698 |
+
return_dict=False,
|
699 |
+
verbose=verbose,
|
700 |
+
step_func=step_func,
|
701 |
+
)
|
702 |
+
|
703 |
+
generated_text = generated_text.replace("<|im_end|>", "")
|
704 |
+
if not return_dict:
|
705 |
+
return generated_text, images
|
706 |
+
else:
|
707 |
+
return OmniGen2PipelineOutput(text=generated_text, images=images)
|
708 |
+
|
709 |
+
def processing(
|
710 |
+
self,
|
711 |
+
latents,
|
712 |
+
ref_latents,
|
713 |
+
prompt_embeds,
|
714 |
+
freqs_cis,
|
715 |
+
negative_prompt_embeds,
|
716 |
+
prompt_attention_mask,
|
717 |
+
negative_prompt_attention_mask,
|
718 |
+
num_inference_steps,
|
719 |
+
timesteps,
|
720 |
+
device,
|
721 |
+
dtype,
|
722 |
+
verbose,
|
723 |
+
step_func=None
|
724 |
+
):
|
725 |
+
batch_size = latents.shape[0]
|
726 |
+
|
727 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
728 |
+
self.scheduler,
|
729 |
+
num_inference_steps,
|
730 |
+
device,
|
731 |
+
timesteps,
|
732 |
+
num_tokens=latents.shape[-2] * latents.shape[-1]
|
733 |
+
)
|
734 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
735 |
+
self._num_timesteps = len(timesteps)
|
736 |
+
|
737 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
738 |
+
for i, t in enumerate(timesteps):
|
739 |
+
model_pred = self.predict(
|
740 |
+
t=t,
|
741 |
+
latents=latents,
|
742 |
+
prompt_embeds=prompt_embeds,
|
743 |
+
freqs_cis=freqs_cis,
|
744 |
+
prompt_attention_mask=prompt_attention_mask,
|
745 |
+
ref_image_hidden_states=ref_latents,
|
746 |
+
)
|
747 |
+
|
748 |
+
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
749 |
+
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
750 |
+
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
751 |
+
model_pred_ref = self.predict(
|
752 |
+
t=t,
|
753 |
+
latents=latents,
|
754 |
+
prompt_embeds=negative_prompt_embeds,
|
755 |
+
freqs_cis=freqs_cis,
|
756 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
757 |
+
ref_image_hidden_states=ref_latents,
|
758 |
+
)
|
759 |
+
|
760 |
+
if image_guidance_scale != 1:
|
761 |
+
model_pred_uncond = self.predict(
|
762 |
+
t=t,
|
763 |
+
latents=latents,
|
764 |
+
prompt_embeds=negative_prompt_embeds,
|
765 |
+
freqs_cis=freqs_cis,
|
766 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
767 |
+
ref_image_hidden_states=None,
|
768 |
+
)
|
769 |
+
else:
|
770 |
+
model_pred_uncond = torch.zeros_like(model_pred)
|
771 |
+
|
772 |
+
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
|
773 |
+
text_guidance_scale * (model_pred - model_pred_ref)
|
774 |
+
elif text_guidance_scale > 1.0:
|
775 |
+
model_pred_uncond = self.predict(
|
776 |
+
t=t,
|
777 |
+
latents=latents,
|
778 |
+
prompt_embeds=negative_prompt_embeds,
|
779 |
+
freqs_cis=freqs_cis,
|
780 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
781 |
+
ref_image_hidden_states=None,
|
782 |
+
)
|
783 |
+
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
|
784 |
+
|
785 |
+
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
786 |
+
|
787 |
+
latents = latents.to(dtype=dtype)
|
788 |
+
|
789 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
790 |
+
progress_bar.update()
|
791 |
+
|
792 |
+
if step_func is not None:
|
793 |
+
step_func(i, self._num_timesteps)
|
794 |
+
|
795 |
+
latents = latents.to(dtype=dtype)
|
796 |
+
if self.vae.config.scaling_factor is not None:
|
797 |
+
latents = latents / self.vae.config.scaling_factor
|
798 |
+
if self.vae.config.shift_factor is not None:
|
799 |
+
latents = latents + self.vae.config.shift_factor
|
800 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
801 |
+
|
802 |
+
return image
|
803 |
+
|
804 |
+
def predict(
|
805 |
+
self,
|
806 |
+
t,
|
807 |
+
latents,
|
808 |
+
prompt_embeds,
|
809 |
+
freqs_cis,
|
810 |
+
prompt_attention_mask,
|
811 |
+
ref_image_hidden_states,
|
812 |
+
):
|
813 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
814 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
815 |
+
|
816 |
+
batch_size, num_channels_latents, height, width = latents.shape
|
817 |
+
|
818 |
+
optional_kwargs = {}
|
819 |
+
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
|
820 |
+
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
|
821 |
+
|
822 |
+
model_pred = self.transformer(
|
823 |
+
latents,
|
824 |
+
timestep,
|
825 |
+
prompt_embeds,
|
826 |
+
freqs_cis,
|
827 |
+
prompt_attention_mask,
|
828 |
+
**optional_kwargs
|
829 |
+
)
|
830 |
+
return model_pred
|
omnigen2/pipelines/pipeline_utils.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
|
5 |
+
""" Get pipeline embeds for prompts bigger than the maxlength of the pipe
|
6 |
+
:param pipeline:
|
7 |
+
:param prompt:
|
8 |
+
:param negative_prompt:
|
9 |
+
:param device:
|
10 |
+
:return:
|
11 |
+
"""
|
12 |
+
max_length = pipeline.tokenizer.model_max_length
|
13 |
+
|
14 |
+
# simple way to determine length of tokens
|
15 |
+
# count_prompt = len(prompt.split(" "))
|
16 |
+
# count_negative_prompt = len(negative_prompt.split(" "))
|
17 |
+
|
18 |
+
# create the tensor based on which prompt is longer
|
19 |
+
# if count_prompt >= count_negative_prompt:
|
20 |
+
input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device)
|
21 |
+
# input_ids = pipeline.tokenizer(prompt, padding="max_length",
|
22 |
+
# max_length=pipeline.tokenizer.model_max_length,
|
23 |
+
# truncation=True,
|
24 |
+
# return_tensors="pt",).input_ids.to(device)
|
25 |
+
shape_max_length = input_ids.shape[-1]
|
26 |
+
|
27 |
+
if negative_prompt is not None:
|
28 |
+
negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
|
29 |
+
max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
|
30 |
+
|
31 |
+
# else:
|
32 |
+
# negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
|
33 |
+
# shape_max_length = negative_ids.shape[-1]
|
34 |
+
# input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
|
35 |
+
# max_length=shape_max_length).input_ids.to(device)
|
36 |
+
|
37 |
+
concat_embeds = []
|
38 |
+
neg_embeds = []
|
39 |
+
for i in range(0, shape_max_length, max_length):
|
40 |
+
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
|
41 |
+
attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device)
|
42 |
+
else:
|
43 |
+
attention_mask = None
|
44 |
+
concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
|
45 |
+
attention_mask=attention_mask)[0])
|
46 |
+
|
47 |
+
if negative_prompt is not None:
|
48 |
+
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
|
49 |
+
attention_mask = negative_ids[:, i: i + max_length].attention_mask.to(device)
|
50 |
+
else:
|
51 |
+
attention_mask = None
|
52 |
+
neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
|
53 |
+
attention_mask=attention_mask)[0])
|
54 |
+
|
55 |
+
concat_embeds = torch.cat(concat_embeds, dim=1)
|
56 |
+
|
57 |
+
if negative_prompt is not None:
|
58 |
+
neg_embeds = torch.cat(neg_embeds, dim=1)
|
59 |
+
else:
|
60 |
+
neg_embeds = None
|
61 |
+
|
62 |
+
return concat_embeds, neg_embeds
|
omnigen2/schedulers/__init__.py
ADDED
File without changes
|
omnigen2/schedulers/scheduling_dpmsolver_multistep.py
ADDED
@@ -0,0 +1,1052 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.utils import deprecate, is_scipy_available
|
25 |
+
from diffusers.utils.torch_utils import randn_tensor
|
26 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
27 |
+
|
28 |
+
|
29 |
+
if is_scipy_available():
|
30 |
+
import scipy.stats
|
31 |
+
|
32 |
+
|
33 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
34 |
+
def betas_for_alpha_bar(
|
35 |
+
num_diffusion_timesteps,
|
36 |
+
max_beta=0.999,
|
37 |
+
alpha_transform_type="cosine",
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
41 |
+
(1-beta) over time from t = [0,1].
|
42 |
+
|
43 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
44 |
+
to that part of the diffusion process.
|
45 |
+
|
46 |
+
|
47 |
+
Args:
|
48 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
49 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
50 |
+
prevent singularities.
|
51 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
52 |
+
Choose from `cosine` or `exp`
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
56 |
+
"""
|
57 |
+
if alpha_transform_type == "cosine":
|
58 |
+
|
59 |
+
def alpha_bar_fn(t):
|
60 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
61 |
+
|
62 |
+
elif alpha_transform_type == "exp":
|
63 |
+
|
64 |
+
def alpha_bar_fn(t):
|
65 |
+
return math.exp(t * -12.0)
|
66 |
+
|
67 |
+
else:
|
68 |
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
69 |
+
|
70 |
+
betas = []
|
71 |
+
for i in range(num_diffusion_timesteps):
|
72 |
+
t1 = i / num_diffusion_timesteps
|
73 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
74 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
75 |
+
return torch.tensor(betas, dtype=torch.float32)
|
76 |
+
|
77 |
+
|
78 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
79 |
+
def rescale_zero_terminal_snr(betas):
|
80 |
+
"""
|
81 |
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
82 |
+
|
83 |
+
|
84 |
+
Args:
|
85 |
+
betas (`torch.Tensor`):
|
86 |
+
the betas that the scheduler is being initialized with.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
`torch.Tensor`: rescaled betas with zero terminal SNR
|
90 |
+
"""
|
91 |
+
# Convert betas to alphas_bar_sqrt
|
92 |
+
alphas = 1.0 - betas
|
93 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
94 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
95 |
+
|
96 |
+
# Store old values.
|
97 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
98 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
99 |
+
|
100 |
+
# Shift so the last timestep is zero.
|
101 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
102 |
+
|
103 |
+
# Scale so the first timestep is back to the old value.
|
104 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
105 |
+
|
106 |
+
# Convert alphas_bar_sqrt to betas
|
107 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
108 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
109 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
110 |
+
betas = 1 - alphas
|
111 |
+
|
112 |
+
return betas
|
113 |
+
|
114 |
+
|
115 |
+
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
116 |
+
"""
|
117 |
+
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
118 |
+
|
119 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
120 |
+
methods the library implements for all schedulers such as loading and saving.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
num_train_timesteps (`int`, defaults to 1000):
|
124 |
+
The number of diffusion steps to train the model.
|
125 |
+
beta_start (`float`, defaults to 0.0001):
|
126 |
+
The starting `beta` value of inference.
|
127 |
+
beta_end (`float`, defaults to 0.02):
|
128 |
+
The final `beta` value.
|
129 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
130 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
131 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
132 |
+
trained_betas (`np.ndarray`, *optional*):
|
133 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
134 |
+
solver_order (`int`, defaults to 2):
|
135 |
+
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
136 |
+
sampling, and `solver_order=3` for unconditional sampling.
|
137 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
138 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
139 |
+
`sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
|
140 |
+
Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
|
141 |
+
thresholding (`bool`, defaults to `False`):
|
142 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
143 |
+
as Stable Diffusion.
|
144 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
145 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
146 |
+
sample_max_value (`float`, defaults to 1.0):
|
147 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
148 |
+
`algorithm_type="dpmsolver++"`.
|
149 |
+
algorithm_type (`str`, defaults to `dpmsolver++`):
|
150 |
+
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
151 |
+
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
152 |
+
paper, and the `dpmsolver++` type implements the algorithms in the
|
153 |
+
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
154 |
+
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
155 |
+
solver_type (`str`, defaults to `midpoint`):
|
156 |
+
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
157 |
+
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
158 |
+
lower_order_final (`bool`, defaults to `True`):
|
159 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
160 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
161 |
+
euler_at_final (`bool`, defaults to `False`):
|
162 |
+
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
163 |
+
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
164 |
+
steps, but sometimes may result in blurring.
|
165 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
166 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
167 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
168 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
169 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
170 |
+
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
171 |
+
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
172 |
+
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
173 |
+
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
|
174 |
+
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
|
175 |
+
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
176 |
+
`lambda(t)`.
|
177 |
+
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
|
178 |
+
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
179 |
+
flow_shift (`float`, *optional*, defaults to 1.0):
|
180 |
+
The shift value for the timestep schedule for flow matching.
|
181 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
182 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
183 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
184 |
+
lambda_min_clipped (`float`, defaults to `-inf`):
|
185 |
+
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
186 |
+
cosine (`squaredcos_cap_v2`) noise schedule.
|
187 |
+
variance_type (`str`, *optional*):
|
188 |
+
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
189 |
+
contains the predicted Gaussian variance.
|
190 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
191 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
192 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
193 |
+
steps_offset (`int`, defaults to 0):
|
194 |
+
An offset added to the inference steps, as required by some model families.
|
195 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
196 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
197 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
198 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
199 |
+
"""
|
200 |
+
|
201 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
202 |
+
order = 1
|
203 |
+
|
204 |
+
@register_to_config
|
205 |
+
def __init__(
|
206 |
+
self,
|
207 |
+
num_train_timesteps: int = 1000,
|
208 |
+
beta_start: float = 0.0001,
|
209 |
+
beta_end: float = 0.02,
|
210 |
+
beta_schedule: str = "linear",
|
211 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
212 |
+
solver_order: int = 2,
|
213 |
+
prediction_type: str = "epsilon",
|
214 |
+
thresholding: bool = False,
|
215 |
+
dynamic_thresholding_ratio: float = 0.995,
|
216 |
+
sample_max_value: float = 1.0,
|
217 |
+
algorithm_type: str = "dpmsolver++",
|
218 |
+
solver_type: str = "midpoint",
|
219 |
+
lower_order_final: bool = True,
|
220 |
+
euler_at_final: bool = False,
|
221 |
+
final_sigmas_type: str = 'zero',
|
222 |
+
dynamic_time_shift: bool = True
|
223 |
+
):
|
224 |
+
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
225 |
+
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
226 |
+
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
227 |
+
|
228 |
+
if trained_betas is not None:
|
229 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
230 |
+
elif beta_schedule == "linear":
|
231 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
232 |
+
elif beta_schedule == "scaled_linear":
|
233 |
+
# this schedule is very specific to the latent diffusion model.
|
234 |
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
235 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
236 |
+
# Glide cosine schedule
|
237 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
238 |
+
else:
|
239 |
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
240 |
+
self.alphas = 1.0 - self.betas
|
241 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
242 |
+
|
243 |
+
# Currently we only support VP-type noise schedule
|
244 |
+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
245 |
+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
246 |
+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
247 |
+
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
248 |
+
|
249 |
+
# standard deviation of the initial noise distribution
|
250 |
+
self.init_noise_sigma = 1.0
|
251 |
+
|
252 |
+
# settings for DPM-Solver
|
253 |
+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
254 |
+
if algorithm_type == "deis":
|
255 |
+
self.register_to_config(algorithm_type="dpmsolver++")
|
256 |
+
else:
|
257 |
+
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
258 |
+
|
259 |
+
if solver_type not in ["midpoint", "heun"]:
|
260 |
+
if solver_type in ["logrho", "bh1", "bh2"]:
|
261 |
+
self.register_to_config(solver_type="midpoint")
|
262 |
+
else:
|
263 |
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
264 |
+
|
265 |
+
# if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
266 |
+
# raise ValueError(
|
267 |
+
# f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
|
268 |
+
# )
|
269 |
+
|
270 |
+
# setable values
|
271 |
+
self.num_inference_steps = None
|
272 |
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
273 |
+
self.timesteps = torch.from_numpy(timesteps)
|
274 |
+
self.model_outputs = [None] * solver_order
|
275 |
+
self.lower_order_nums = 0
|
276 |
+
self._step_index = None
|
277 |
+
self._begin_index = None
|
278 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
279 |
+
|
280 |
+
@property
|
281 |
+
def step_index(self):
|
282 |
+
"""
|
283 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
284 |
+
"""
|
285 |
+
return self._step_index
|
286 |
+
|
287 |
+
@property
|
288 |
+
def begin_index(self):
|
289 |
+
"""
|
290 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
291 |
+
"""
|
292 |
+
return self._begin_index
|
293 |
+
|
294 |
+
def set_begin_index(self, begin_index: int = 0):
|
295 |
+
"""
|
296 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
begin_index (`int`):
|
300 |
+
The begin index for the scheduler.
|
301 |
+
"""
|
302 |
+
self._begin_index = begin_index
|
303 |
+
|
304 |
+
def set_timesteps(
|
305 |
+
self,
|
306 |
+
num_inference_steps: int = None,
|
307 |
+
device: Union[str, torch.device] = None,
|
308 |
+
timesteps: Optional[List[int]] = None,
|
309 |
+
num_tokens: Optional[int] = None
|
310 |
+
):
|
311 |
+
if timesteps is None:
|
312 |
+
self.num_inference_steps = num_inference_steps
|
313 |
+
timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
|
314 |
+
if self.config.dynamic_time_shift and num_tokens is not None:
|
315 |
+
m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
|
316 |
+
timesteps = timesteps / (m - m * timesteps + timesteps)
|
317 |
+
|
318 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
319 |
+
sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)])
|
320 |
+
|
321 |
+
self.sigmas = sigmas
|
322 |
+
self.timesteps = timesteps
|
323 |
+
|
324 |
+
self.num_inference_steps = len(timesteps)
|
325 |
+
|
326 |
+
self.model_outputs = [
|
327 |
+
None,
|
328 |
+
] * self.config.solver_order
|
329 |
+
self.lower_order_nums = 0
|
330 |
+
|
331 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
332 |
+
self._step_index = None
|
333 |
+
self._begin_index = None
|
334 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
335 |
+
|
336 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
337 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
338 |
+
"""
|
339 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
340 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
341 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
342 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
343 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
344 |
+
|
345 |
+
https://arxiv.org/abs/2205.11487
|
346 |
+
"""
|
347 |
+
dtype = sample.dtype
|
348 |
+
batch_size, channels, *remaining_dims = sample.shape
|
349 |
+
|
350 |
+
if dtype not in (torch.float32, torch.float64):
|
351 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
352 |
+
|
353 |
+
# Flatten sample for doing quantile calculation along each image
|
354 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
355 |
+
|
356 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
357 |
+
|
358 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
359 |
+
s = torch.clamp(
|
360 |
+
s, min=1, max=self.config.sample_max_value
|
361 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
362 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
363 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
364 |
+
|
365 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
366 |
+
sample = sample.to(dtype)
|
367 |
+
|
368 |
+
return sample
|
369 |
+
|
370 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
371 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
372 |
+
# get log sigma
|
373 |
+
log_sigma = np.log(np.maximum(sigma, 1e-10))
|
374 |
+
|
375 |
+
# get distribution
|
376 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
377 |
+
|
378 |
+
# get sigmas range
|
379 |
+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
380 |
+
high_idx = low_idx + 1
|
381 |
+
|
382 |
+
low = log_sigmas[low_idx]
|
383 |
+
high = log_sigmas[high_idx]
|
384 |
+
|
385 |
+
# interpolate sigmas
|
386 |
+
w = (low - log_sigma) / (low - high)
|
387 |
+
w = np.clip(w, 0, 1)
|
388 |
+
|
389 |
+
# transform interpolation to time range
|
390 |
+
t = (1 - w) * low_idx + w * high_idx
|
391 |
+
t = t.reshape(sigma.shape)
|
392 |
+
return t
|
393 |
+
|
394 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
395 |
+
alpha_t = 1 - sigma
|
396 |
+
sigma_t = sigma
|
397 |
+
|
398 |
+
return alpha_t, sigma_t
|
399 |
+
|
400 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
401 |
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
402 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
403 |
+
|
404 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
405 |
+
# TODO: Add this logic to the other schedulers
|
406 |
+
if hasattr(self.config, "sigma_min"):
|
407 |
+
sigma_min = self.config.sigma_min
|
408 |
+
else:
|
409 |
+
sigma_min = None
|
410 |
+
|
411 |
+
if hasattr(self.config, "sigma_max"):
|
412 |
+
sigma_max = self.config.sigma_max
|
413 |
+
else:
|
414 |
+
sigma_max = None
|
415 |
+
|
416 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
417 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
418 |
+
|
419 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
420 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
421 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
422 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
423 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
424 |
+
return sigmas
|
425 |
+
|
426 |
+
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
427 |
+
"""Constructs the noise schedule of Lu et al. (2022)."""
|
428 |
+
|
429 |
+
lambda_min: float = in_lambdas[-1].item()
|
430 |
+
lambda_max: float = in_lambdas[0].item()
|
431 |
+
|
432 |
+
rho = 1.0 # 1.0 is the value used in the paper
|
433 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
434 |
+
min_inv_rho = lambda_min ** (1 / rho)
|
435 |
+
max_inv_rho = lambda_max ** (1 / rho)
|
436 |
+
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
437 |
+
return lambdas
|
438 |
+
|
439 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
|
440 |
+
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
441 |
+
"""Constructs an exponential noise schedule."""
|
442 |
+
|
443 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
444 |
+
# TODO: Add this logic to the other schedulers
|
445 |
+
if hasattr(self.config, "sigma_min"):
|
446 |
+
sigma_min = self.config.sigma_min
|
447 |
+
else:
|
448 |
+
sigma_min = None
|
449 |
+
|
450 |
+
if hasattr(self.config, "sigma_max"):
|
451 |
+
sigma_max = self.config.sigma_max
|
452 |
+
else:
|
453 |
+
sigma_max = None
|
454 |
+
|
455 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
456 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
457 |
+
|
458 |
+
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
459 |
+
return sigmas
|
460 |
+
|
461 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
462 |
+
def _convert_to_beta(
|
463 |
+
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
464 |
+
) -> torch.Tensor:
|
465 |
+
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
466 |
+
|
467 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
468 |
+
# TODO: Add this logic to the other schedulers
|
469 |
+
if hasattr(self.config, "sigma_min"):
|
470 |
+
sigma_min = self.config.sigma_min
|
471 |
+
else:
|
472 |
+
sigma_min = None
|
473 |
+
|
474 |
+
if hasattr(self.config, "sigma_max"):
|
475 |
+
sigma_max = self.config.sigma_max
|
476 |
+
else:
|
477 |
+
sigma_max = None
|
478 |
+
|
479 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
480 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
481 |
+
|
482 |
+
sigmas = np.array(
|
483 |
+
[
|
484 |
+
sigma_min + (ppf * (sigma_max - sigma_min))
|
485 |
+
for ppf in [
|
486 |
+
scipy.stats.beta.ppf(timestep, alpha, beta)
|
487 |
+
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
488 |
+
]
|
489 |
+
]
|
490 |
+
)
|
491 |
+
return sigmas
|
492 |
+
|
493 |
+
def convert_model_output(
|
494 |
+
self,
|
495 |
+
model_output: torch.Tensor,
|
496 |
+
*args,
|
497 |
+
sample: torch.Tensor = None,
|
498 |
+
**kwargs,
|
499 |
+
) -> torch.Tensor:
|
500 |
+
"""
|
501 |
+
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
502 |
+
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
503 |
+
integral of the data prediction model.
|
504 |
+
|
505 |
+
<Tip>
|
506 |
+
|
507 |
+
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
|
508 |
+
prediction and data prediction models.
|
509 |
+
|
510 |
+
</Tip>
|
511 |
+
|
512 |
+
Args:
|
513 |
+
model_output (`torch.Tensor`):
|
514 |
+
The direct output from the learned diffusion model.
|
515 |
+
sample (`torch.Tensor`):
|
516 |
+
A current instance of a sample created by the diffusion process.
|
517 |
+
|
518 |
+
Returns:
|
519 |
+
`torch.Tensor`:
|
520 |
+
The converted model output.
|
521 |
+
"""
|
522 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
523 |
+
if sample is None:
|
524 |
+
if len(args) > 1:
|
525 |
+
sample = args[1]
|
526 |
+
else:
|
527 |
+
raise ValueError("missing `sample` as a required keyward argument")
|
528 |
+
if timestep is not None:
|
529 |
+
deprecate(
|
530 |
+
"timesteps",
|
531 |
+
"1.0.0",
|
532 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
533 |
+
)
|
534 |
+
|
535 |
+
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
536 |
+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
537 |
+
if self.config.prediction_type == "epsilon":
|
538 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
539 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
540 |
+
model_output = model_output[:, :3]
|
541 |
+
sigma = self.sigmas[self.step_index]
|
542 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
543 |
+
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
544 |
+
elif self.config.prediction_type == "sample":
|
545 |
+
x0_pred = model_output
|
546 |
+
elif self.config.prediction_type == "v_prediction":
|
547 |
+
sigma = self.sigmas[self.step_index]
|
548 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
549 |
+
x0_pred = alpha_t * sample - sigma_t * model_output
|
550 |
+
elif self.config.prediction_type == "flow_prediction":
|
551 |
+
sigma_t = self.sigmas[self.step_index]
|
552 |
+
x0_pred = sample + sigma_t * model_output
|
553 |
+
else:
|
554 |
+
raise ValueError(
|
555 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
|
556 |
+
"`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
|
557 |
+
)
|
558 |
+
|
559 |
+
if self.config.thresholding:
|
560 |
+
x0_pred = self._threshold_sample(x0_pred)
|
561 |
+
|
562 |
+
return x0_pred
|
563 |
+
|
564 |
+
# DPM-Solver needs to solve an integral of the noise prediction model.
|
565 |
+
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
566 |
+
if self.config.prediction_type == "epsilon":
|
567 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
568 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
569 |
+
epsilon = model_output[:, :3]
|
570 |
+
else:
|
571 |
+
epsilon = model_output
|
572 |
+
elif self.config.prediction_type == "sample":
|
573 |
+
sigma = self.sigmas[self.step_index]
|
574 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
575 |
+
epsilon = (sample - alpha_t * model_output) / sigma_t
|
576 |
+
elif self.config.prediction_type == "v_prediction":
|
577 |
+
sigma = self.sigmas[self.step_index]
|
578 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
579 |
+
epsilon = alpha_t * model_output + sigma_t * sample
|
580 |
+
else:
|
581 |
+
raise ValueError(
|
582 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
583 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
584 |
+
)
|
585 |
+
|
586 |
+
if self.config.thresholding:
|
587 |
+
sigma = self.sigmas[self.step_index]
|
588 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
589 |
+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
590 |
+
x0_pred = self._threshold_sample(x0_pred)
|
591 |
+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
592 |
+
|
593 |
+
return epsilon
|
594 |
+
|
595 |
+
def dpm_solver_first_order_update(
|
596 |
+
self,
|
597 |
+
model_output: torch.Tensor,
|
598 |
+
*args,
|
599 |
+
sample: torch.Tensor = None,
|
600 |
+
noise: Optional[torch.Tensor] = None,
|
601 |
+
**kwargs,
|
602 |
+
) -> torch.Tensor:
|
603 |
+
"""
|
604 |
+
One step for the first-order DPMSolver (equivalent to DDIM).
|
605 |
+
|
606 |
+
Args:
|
607 |
+
model_output (`torch.Tensor`):
|
608 |
+
The direct output from the learned diffusion model.
|
609 |
+
sample (`torch.Tensor`):
|
610 |
+
A current instance of a sample created by the diffusion process.
|
611 |
+
|
612 |
+
Returns:
|
613 |
+
`torch.Tensor`:
|
614 |
+
The sample tensor at the previous timestep.
|
615 |
+
"""
|
616 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
617 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
618 |
+
if sample is None:
|
619 |
+
if len(args) > 2:
|
620 |
+
sample = args[2]
|
621 |
+
else:
|
622 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
623 |
+
if timestep is not None:
|
624 |
+
deprecate(
|
625 |
+
"timesteps",
|
626 |
+
"1.0.0",
|
627 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
628 |
+
)
|
629 |
+
|
630 |
+
if prev_timestep is not None:
|
631 |
+
deprecate(
|
632 |
+
"prev_timestep",
|
633 |
+
"1.0.0",
|
634 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
635 |
+
)
|
636 |
+
|
637 |
+
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
638 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
639 |
+
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
640 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
641 |
+
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
642 |
+
|
643 |
+
h = lambda_t - lambda_s
|
644 |
+
if self.config.algorithm_type == "dpmsolver++":
|
645 |
+
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
646 |
+
elif self.config.algorithm_type == "dpmsolver":
|
647 |
+
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
648 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
649 |
+
assert noise is not None
|
650 |
+
x_t = (
|
651 |
+
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
652 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
653 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
654 |
+
)
|
655 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
656 |
+
assert noise is not None
|
657 |
+
x_t = (
|
658 |
+
(alpha_t / alpha_s) * sample
|
659 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
660 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
661 |
+
)
|
662 |
+
return x_t
|
663 |
+
|
664 |
+
def multistep_dpm_solver_second_order_update(
|
665 |
+
self,
|
666 |
+
model_output_list: List[torch.Tensor],
|
667 |
+
*args,
|
668 |
+
sample: torch.Tensor = None,
|
669 |
+
noise: Optional[torch.Tensor] = None,
|
670 |
+
**kwargs,
|
671 |
+
) -> torch.Tensor:
|
672 |
+
"""
|
673 |
+
One step for the second-order multistep DPMSolver.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
model_output_list (`List[torch.Tensor]`):
|
677 |
+
The direct outputs from learned diffusion model at current and latter timesteps.
|
678 |
+
sample (`torch.Tensor`):
|
679 |
+
A current instance of a sample created by the diffusion process.
|
680 |
+
|
681 |
+
Returns:
|
682 |
+
`torch.Tensor`:
|
683 |
+
The sample tensor at the previous timestep.
|
684 |
+
"""
|
685 |
+
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
686 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
687 |
+
if sample is None:
|
688 |
+
if len(args) > 2:
|
689 |
+
sample = args[2]
|
690 |
+
else:
|
691 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
692 |
+
if timestep_list is not None:
|
693 |
+
deprecate(
|
694 |
+
"timestep_list",
|
695 |
+
"1.0.0",
|
696 |
+
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
697 |
+
)
|
698 |
+
|
699 |
+
if prev_timestep is not None:
|
700 |
+
deprecate(
|
701 |
+
"prev_timestep",
|
702 |
+
"1.0.0",
|
703 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
704 |
+
)
|
705 |
+
|
706 |
+
sigma_t, sigma_s0, sigma_s1 = (
|
707 |
+
self.sigmas[self.step_index + 1],
|
708 |
+
self.sigmas[self.step_index],
|
709 |
+
self.sigmas[self.step_index - 1],
|
710 |
+
)
|
711 |
+
|
712 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
713 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
714 |
+
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
715 |
+
|
716 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
717 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
718 |
+
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
719 |
+
|
720 |
+
m0, m1 = model_output_list[-1], model_output_list[-2]
|
721 |
+
|
722 |
+
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
723 |
+
r0 = h_0 / h
|
724 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
725 |
+
if self.config.algorithm_type == "dpmsolver++":
|
726 |
+
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
727 |
+
if self.config.solver_type == "midpoint":
|
728 |
+
x_t = (
|
729 |
+
(sigma_t / sigma_s0) * sample
|
730 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
731 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
732 |
+
)
|
733 |
+
elif self.config.solver_type == "heun":
|
734 |
+
x_t = (
|
735 |
+
(sigma_t / sigma_s0) * sample
|
736 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
737 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
738 |
+
)
|
739 |
+
elif self.config.algorithm_type == "dpmsolver":
|
740 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
741 |
+
if self.config.solver_type == "midpoint":
|
742 |
+
x_t = (
|
743 |
+
(alpha_t / alpha_s0) * sample
|
744 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
745 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
746 |
+
)
|
747 |
+
elif self.config.solver_type == "heun":
|
748 |
+
x_t = (
|
749 |
+
(alpha_t / alpha_s0) * sample
|
750 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
751 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
752 |
+
)
|
753 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
754 |
+
assert noise is not None
|
755 |
+
if self.config.solver_type == "midpoint":
|
756 |
+
x_t = (
|
757 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
758 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
759 |
+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
760 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
761 |
+
)
|
762 |
+
elif self.config.solver_type == "heun":
|
763 |
+
x_t = (
|
764 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
765 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
766 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
767 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
768 |
+
)
|
769 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
770 |
+
assert noise is not None
|
771 |
+
if self.config.solver_type == "midpoint":
|
772 |
+
x_t = (
|
773 |
+
(alpha_t / alpha_s0) * sample
|
774 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
775 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
776 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
777 |
+
)
|
778 |
+
elif self.config.solver_type == "heun":
|
779 |
+
x_t = (
|
780 |
+
(alpha_t / alpha_s0) * sample
|
781 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
782 |
+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
783 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
784 |
+
)
|
785 |
+
return x_t
|
786 |
+
|
787 |
+
def multistep_dpm_solver_third_order_update(
|
788 |
+
self,
|
789 |
+
model_output_list: List[torch.Tensor],
|
790 |
+
*args,
|
791 |
+
sample: torch.Tensor = None,
|
792 |
+
noise: Optional[torch.Tensor] = None,
|
793 |
+
**kwargs,
|
794 |
+
) -> torch.Tensor:
|
795 |
+
"""
|
796 |
+
One step for the third-order multistep DPMSolver.
|
797 |
+
|
798 |
+
Args:
|
799 |
+
model_output_list (`List[torch.Tensor]`):
|
800 |
+
The direct outputs from learned diffusion model at current and latter timesteps.
|
801 |
+
sample (`torch.Tensor`):
|
802 |
+
A current instance of a sample created by diffusion process.
|
803 |
+
|
804 |
+
Returns:
|
805 |
+
`torch.Tensor`:
|
806 |
+
The sample tensor at the previous timestep.
|
807 |
+
"""
|
808 |
+
|
809 |
+
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
810 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
811 |
+
if sample is None:
|
812 |
+
if len(args) > 2:
|
813 |
+
sample = args[2]
|
814 |
+
else:
|
815 |
+
raise ValueError(" missing`sample` as a required keyward argument")
|
816 |
+
if timestep_list is not None:
|
817 |
+
deprecate(
|
818 |
+
"timestep_list",
|
819 |
+
"1.0.0",
|
820 |
+
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
821 |
+
)
|
822 |
+
|
823 |
+
if prev_timestep is not None:
|
824 |
+
deprecate(
|
825 |
+
"prev_timestep",
|
826 |
+
"1.0.0",
|
827 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
828 |
+
)
|
829 |
+
|
830 |
+
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
831 |
+
self.sigmas[self.step_index + 1],
|
832 |
+
self.sigmas[self.step_index],
|
833 |
+
self.sigmas[self.step_index - 1],
|
834 |
+
self.sigmas[self.step_index - 2],
|
835 |
+
)
|
836 |
+
|
837 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
838 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
839 |
+
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
840 |
+
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
841 |
+
|
842 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
843 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
844 |
+
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
845 |
+
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
846 |
+
|
847 |
+
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
848 |
+
|
849 |
+
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
850 |
+
r0, r1 = h_0 / h, h_1 / h
|
851 |
+
D0 = m0
|
852 |
+
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
853 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
854 |
+
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
855 |
+
if self.config.algorithm_type == "dpmsolver++":
|
856 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
857 |
+
x_t = (
|
858 |
+
(sigma_t / sigma_s0) * sample
|
859 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
860 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
861 |
+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
862 |
+
)
|
863 |
+
elif self.config.algorithm_type == "dpmsolver":
|
864 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
865 |
+
x_t = (
|
866 |
+
(alpha_t / alpha_s0) * sample
|
867 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
868 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
869 |
+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
870 |
+
)
|
871 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
872 |
+
assert noise is not None
|
873 |
+
x_t = (
|
874 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
875 |
+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
|
876 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
877 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
|
878 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
879 |
+
)
|
880 |
+
return x_t
|
881 |
+
|
882 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
883 |
+
if schedule_timesteps is None:
|
884 |
+
schedule_timesteps = self.timesteps
|
885 |
+
|
886 |
+
index_candidates = (schedule_timesteps == timestep).nonzero()
|
887 |
+
|
888 |
+
if len(index_candidates) == 0:
|
889 |
+
step_index = len(self.timesteps) - 1
|
890 |
+
# The sigma index that is taken for the **very** first `step`
|
891 |
+
# is always the second index (or the last index if there is only 1)
|
892 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
893 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
894 |
+
elif len(index_candidates) > 1:
|
895 |
+
step_index = index_candidates[1].item()
|
896 |
+
else:
|
897 |
+
step_index = index_candidates[0].item()
|
898 |
+
|
899 |
+
return step_index
|
900 |
+
|
901 |
+
def _init_step_index(self, timestep):
|
902 |
+
"""
|
903 |
+
Initialize the step_index counter for the scheduler.
|
904 |
+
"""
|
905 |
+
|
906 |
+
if self.begin_index is None:
|
907 |
+
if isinstance(timestep, torch.Tensor):
|
908 |
+
timestep = timestep.to(self.timesteps.device)
|
909 |
+
self._step_index = self.index_for_timestep(timestep)
|
910 |
+
else:
|
911 |
+
self._step_index = self._begin_index
|
912 |
+
|
913 |
+
def step(
|
914 |
+
self,
|
915 |
+
model_output: torch.Tensor,
|
916 |
+
timestep: Union[int, torch.Tensor],
|
917 |
+
sample: torch.Tensor,
|
918 |
+
generator=None,
|
919 |
+
variance_noise: Optional[torch.Tensor] = None,
|
920 |
+
return_dict: bool = True,
|
921 |
+
) -> Union[SchedulerOutput, Tuple]:
|
922 |
+
"""
|
923 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
924 |
+
the multistep DPMSolver.
|
925 |
+
|
926 |
+
Args:
|
927 |
+
model_output (`torch.Tensor`):
|
928 |
+
The direct output from learned diffusion model.
|
929 |
+
timestep (`int`):
|
930 |
+
The current discrete timestep in the diffusion chain.
|
931 |
+
sample (`torch.Tensor`):
|
932 |
+
A current instance of a sample created by the diffusion process.
|
933 |
+
generator (`torch.Generator`, *optional*):
|
934 |
+
A random number generator.
|
935 |
+
variance_noise (`torch.Tensor`):
|
936 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
937 |
+
itself. Useful for methods such as [`LEdits++`].
|
938 |
+
return_dict (`bool`):
|
939 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
940 |
+
|
941 |
+
Returns:
|
942 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
943 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
944 |
+
tuple is returned where the first element is the sample tensor.
|
945 |
+
|
946 |
+
"""
|
947 |
+
if self.num_inference_steps is None:
|
948 |
+
raise ValueError(
|
949 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
950 |
+
)
|
951 |
+
|
952 |
+
if self.step_index is None:
|
953 |
+
self._init_step_index(timestep)
|
954 |
+
|
955 |
+
# Improve numerical stability for small number of steps
|
956 |
+
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
|
957 |
+
self.config.euler_at_final
|
958 |
+
or (self.config.lower_order_final and len(self.timesteps) < 15)
|
959 |
+
or self.config.final_sigmas_type == "zero"
|
960 |
+
)
|
961 |
+
lower_order_second = (
|
962 |
+
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
963 |
+
)
|
964 |
+
|
965 |
+
model_output = self.convert_model_output(model_output, sample=sample)
|
966 |
+
for i in range(self.config.solver_order - 1):
|
967 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
968 |
+
self.model_outputs[-1] = model_output
|
969 |
+
|
970 |
+
# Upcast to avoid precision issues when computing prev_sample
|
971 |
+
sample = sample.to(torch.float32)
|
972 |
+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
973 |
+
noise = randn_tensor(
|
974 |
+
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
975 |
+
)
|
976 |
+
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
977 |
+
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
978 |
+
else:
|
979 |
+
noise = None
|
980 |
+
|
981 |
+
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
982 |
+
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
983 |
+
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
984 |
+
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
985 |
+
else:
|
986 |
+
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
|
987 |
+
|
988 |
+
if self.lower_order_nums < self.config.solver_order:
|
989 |
+
self.lower_order_nums += 1
|
990 |
+
|
991 |
+
# Cast sample back to expected dtype
|
992 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
993 |
+
|
994 |
+
# upon completion increase step index by one
|
995 |
+
self._step_index += 1
|
996 |
+
|
997 |
+
if not return_dict:
|
998 |
+
return (prev_sample,)
|
999 |
+
|
1000 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
1001 |
+
|
1002 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
1003 |
+
"""
|
1004 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
1005 |
+
current timestep.
|
1006 |
+
|
1007 |
+
Args:
|
1008 |
+
sample (`torch.Tensor`):
|
1009 |
+
The input sample.
|
1010 |
+
|
1011 |
+
Returns:
|
1012 |
+
`torch.Tensor`:
|
1013 |
+
A scaled input sample.
|
1014 |
+
"""
|
1015 |
+
return sample
|
1016 |
+
|
1017 |
+
def add_noise(
|
1018 |
+
self,
|
1019 |
+
original_samples: torch.Tensor,
|
1020 |
+
noise: torch.Tensor,
|
1021 |
+
timesteps: torch.IntTensor,
|
1022 |
+
) -> torch.Tensor:
|
1023 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
1024 |
+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
1025 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
1026 |
+
# mps does not support float64
|
1027 |
+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
1028 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
1029 |
+
else:
|
1030 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
1031 |
+
timesteps = timesteps.to(original_samples.device)
|
1032 |
+
|
1033 |
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
1034 |
+
if self.begin_index is None:
|
1035 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
1036 |
+
elif self.step_index is not None:
|
1037 |
+
# add_noise is called after first denoising step (for inpainting)
|
1038 |
+
step_indices = [self.step_index] * timesteps.shape[0]
|
1039 |
+
else:
|
1040 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
1041 |
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
1042 |
+
|
1043 |
+
sigma = sigmas[step_indices].flatten()
|
1044 |
+
while len(sigma.shape) < len(original_samples.shape):
|
1045 |
+
sigma = sigma.unsqueeze(-1)
|
1046 |
+
|
1047 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
1048 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
1049 |
+
return noisy_samples
|
1050 |
+
|
1051 |
+
def __len__(self):
|
1052 |
+
return self.config.num_train_timesteps
|
omnigen2/schedulers/scheduling_flow_match_euler_discrete.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
32 |
+
"""
|
33 |
+
Output class for the scheduler's `step` function output.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
37 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
38 |
+
denoising loop.
|
39 |
+
"""
|
40 |
+
|
41 |
+
prev_sample: torch.FloatTensor
|
42 |
+
|
43 |
+
|
44 |
+
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
45 |
+
"""
|
46 |
+
Euler scheduler.
|
47 |
+
|
48 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
49 |
+
methods the library implements for all schedulers such as loading and saving.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
num_train_timesteps (`int`, defaults to 1000):
|
53 |
+
The number of diffusion steps to train the model.
|
54 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
55 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
56 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
57 |
+
shift (`float`, defaults to 1.0):
|
58 |
+
The shift value for the timestep schedule.
|
59 |
+
"""
|
60 |
+
|
61 |
+
_compatibles = []
|
62 |
+
order = 1
|
63 |
+
|
64 |
+
@register_to_config
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
num_train_timesteps: int = 1000,
|
68 |
+
dynamic_time_shift: bool = True
|
69 |
+
):
|
70 |
+
timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
|
71 |
+
|
72 |
+
self.timesteps = timesteps
|
73 |
+
|
74 |
+
self._step_index = None
|
75 |
+
self._begin_index = None
|
76 |
+
|
77 |
+
@property
|
78 |
+
def step_index(self):
|
79 |
+
"""
|
80 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
81 |
+
"""
|
82 |
+
return self._step_index
|
83 |
+
|
84 |
+
@property
|
85 |
+
def begin_index(self):
|
86 |
+
"""
|
87 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
88 |
+
"""
|
89 |
+
return self._begin_index
|
90 |
+
|
91 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
92 |
+
def set_begin_index(self, begin_index: int = 0):
|
93 |
+
"""
|
94 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
begin_index (`int`):
|
98 |
+
The begin index for the scheduler.
|
99 |
+
"""
|
100 |
+
self._begin_index = begin_index
|
101 |
+
|
102 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
103 |
+
if schedule_timesteps is None:
|
104 |
+
schedule_timesteps = self._timesteps
|
105 |
+
|
106 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
107 |
+
|
108 |
+
# The sigma index that is taken for the **very** first `step`
|
109 |
+
# is always the second index (or the last index if there is only 1)
|
110 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
111 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
112 |
+
pos = 1 if len(indices) > 1 else 0
|
113 |
+
|
114 |
+
return indices[pos].item()
|
115 |
+
|
116 |
+
# def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
117 |
+
# return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
118 |
+
|
119 |
+
def set_timesteps(
|
120 |
+
self,
|
121 |
+
num_inference_steps: int = None,
|
122 |
+
device: Union[str, torch.device] = None,
|
123 |
+
timesteps: Optional[List[float]] = None,
|
124 |
+
num_tokens: Optional[int] = None
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
128 |
+
|
129 |
+
Args:
|
130 |
+
num_inference_steps (`int`):
|
131 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
132 |
+
device (`str` or `torch.device`, *optional*):
|
133 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
134 |
+
"""
|
135 |
+
|
136 |
+
if timesteps is None:
|
137 |
+
self.num_inference_steps = num_inference_steps
|
138 |
+
timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
|
139 |
+
if self.config.dynamic_time_shift and num_tokens is not None:
|
140 |
+
m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
|
141 |
+
timesteps = timesteps / (m - m * timesteps + timesteps)
|
142 |
+
|
143 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
144 |
+
_timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
|
145 |
+
|
146 |
+
self.timesteps = timesteps
|
147 |
+
self._timesteps = _timesteps
|
148 |
+
self._step_index = None
|
149 |
+
self._begin_index = None
|
150 |
+
|
151 |
+
def _init_step_index(self, timestep):
|
152 |
+
if self.begin_index is None:
|
153 |
+
if isinstance(timestep, torch.Tensor):
|
154 |
+
timestep = timestep.to(self.timesteps.device)
|
155 |
+
self._step_index = self.index_for_timestep(timestep)
|
156 |
+
else:
|
157 |
+
self._step_index = self._begin_index
|
158 |
+
|
159 |
+
def step(
|
160 |
+
self,
|
161 |
+
model_output: torch.FloatTensor,
|
162 |
+
timestep: Union[float, torch.FloatTensor],
|
163 |
+
sample: torch.FloatTensor,
|
164 |
+
generator: Optional[torch.Generator] = None,
|
165 |
+
return_dict: bool = True,
|
166 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
167 |
+
"""
|
168 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
169 |
+
process from the learned model outputs (most often the predicted noise).
|
170 |
+
|
171 |
+
Args:
|
172 |
+
model_output (`torch.FloatTensor`):
|
173 |
+
The direct output from learned diffusion model.
|
174 |
+
timestep (`float`):
|
175 |
+
The current discrete timestep in the diffusion chain.
|
176 |
+
sample (`torch.FloatTensor`):
|
177 |
+
A current instance of a sample created by the diffusion process.
|
178 |
+
s_churn (`float`):
|
179 |
+
s_tmin (`float`):
|
180 |
+
s_tmax (`float`):
|
181 |
+
s_noise (`float`, defaults to 1.0):
|
182 |
+
Scaling factor for noise added to the sample.
|
183 |
+
generator (`torch.Generator`, *optional*):
|
184 |
+
A random number generator.
|
185 |
+
return_dict (`bool`):
|
186 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
187 |
+
tuple.
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
191 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
192 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
193 |
+
"""
|
194 |
+
|
195 |
+
if (
|
196 |
+
isinstance(timestep, int)
|
197 |
+
or isinstance(timestep, torch.IntTensor)
|
198 |
+
or isinstance(timestep, torch.LongTensor)
|
199 |
+
):
|
200 |
+
raise ValueError(
|
201 |
+
(
|
202 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
203 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
204 |
+
" one of the `scheduler.timesteps` as a timestep."
|
205 |
+
),
|
206 |
+
)
|
207 |
+
|
208 |
+
if self.step_index is None:
|
209 |
+
self._init_step_index(timestep)
|
210 |
+
# Upcast to avoid precision issues when computing prev_sample
|
211 |
+
sample = sample.to(torch.float32)
|
212 |
+
t = self._timesteps[self.step_index]
|
213 |
+
t_next = self._timesteps[self.step_index + 1]
|
214 |
+
|
215 |
+
prev_sample = sample + (t_next - t) * model_output
|
216 |
+
|
217 |
+
# Cast sample back to model compatible dtype
|
218 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
219 |
+
|
220 |
+
# upon completion increase step index by one
|
221 |
+
self._step_index += 1
|
222 |
+
|
223 |
+
if not return_dict:
|
224 |
+
return (prev_sample,)
|
225 |
+
|
226 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
227 |
+
|
228 |
+
def __len__(self):
|
229 |
+
return self.config.num_train_timesteps
|
omnigen2/utils/__init__.py
ADDED
File without changes
|
omnigen2/utils/img_util.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torchvision.transforms.functional import to_pil_image
|
7 |
+
|
8 |
+
def resize_image(image, max_pixels, img_scale_num):
|
9 |
+
width, height = image.size
|
10 |
+
cur_pixels = height * width
|
11 |
+
ratio = (max_pixels / cur_pixels) ** 0.5
|
12 |
+
ratio = min(ratio, 1.0) # do not upscale input image
|
13 |
+
|
14 |
+
new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num
|
15 |
+
|
16 |
+
image = image.resize((new_width, new_height), resample=Image.BICUBIC)
|
17 |
+
return image
|
18 |
+
|
19 |
+
def create_collage(images: List[torch.Tensor]) -> Image.Image:
|
20 |
+
"""Create a horizontal collage from a list of images."""
|
21 |
+
max_height = max(img.shape[-2] for img in images)
|
22 |
+
total_width = sum(img.shape[-1] for img in images)
|
23 |
+
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
|
24 |
+
|
25 |
+
current_x = 0
|
26 |
+
for img in images:
|
27 |
+
h, w = img.shape[-2:]
|
28 |
+
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
|
29 |
+
current_x += w
|
30 |
+
|
31 |
+
return to_pil_image(canvas)
|
omnigen2/utils/vpn_utils.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import json
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
import requests
|
7 |
+
|
8 |
+
class VPNManager:
|
9 |
+
def __init__(self, config_path: str = '/etc/mihomo/config.yaml'):
|
10 |
+
with open(config_path, 'r') as f:
|
11 |
+
config = yaml.safe_load(f)
|
12 |
+
self.external_controller = config['external-controller']
|
13 |
+
self.external_controller = self.external_controller.replace('0.0.0.0', '127.0.0.1')
|
14 |
+
self.secret = config['secret']
|
15 |
+
|
16 |
+
self.headers = {"Authorization": f"Bearer {self.secret}"}
|
17 |
+
|
18 |
+
self.unavailable_nodes = set()
|
19 |
+
|
20 |
+
@property
|
21 |
+
def current_node(self):
|
22 |
+
url = f"http://{self.external_controller}/group/Proxy"
|
23 |
+
r = requests.request("GET", url, headers=self.headers)
|
24 |
+
return r.json()['now']
|
25 |
+
|
26 |
+
@property
|
27 |
+
def available_nodes(self):
|
28 |
+
return list(self.get_available_vpn_nodes() - self.unavailable_nodes)
|
29 |
+
|
30 |
+
def switch_vpn_node(self, node_name):
|
31 |
+
url = f"http://{self.external_controller}/proxies/Proxy"
|
32 |
+
|
33 |
+
payload = json.dumps({
|
34 |
+
"name": node_name
|
35 |
+
})
|
36 |
+
headers = self.headers.copy()
|
37 |
+
headers.update({'Content-Type': 'application/json'})
|
38 |
+
r = requests.request("PUT", url, headers=headers, data=payload)
|
39 |
+
if r.status_code != 204:
|
40 |
+
raise Warning(f"Failed to switch to {node_name}")
|
41 |
+
return r.status_code == 204
|
42 |
+
|
43 |
+
def get_random_available_vpn_node(self):
|
44 |
+
return random.choice(self.available_nodes)
|
45 |
+
|
46 |
+
def random_switch_vpn_node(self):
|
47 |
+
node_name = self.get_random_available_vpn_node()
|
48 |
+
print(f"Switching to {node_name}")
|
49 |
+
self.switch_vpn_node(node_name)
|
50 |
+
# self.current_node = node_name
|
51 |
+
return node_name
|
52 |
+
|
53 |
+
def get_vpn_nodes(self):
|
54 |
+
url = f"http://{self.external_controller}/group/Proxy"
|
55 |
+
delay_res = requests.get(url, headers=self.headers)
|
56 |
+
return delay_res.json()['all']
|
57 |
+
|
58 |
+
def get_available_vpn_nodes(self):
|
59 |
+
url = f"http://{self.external_controller}/group/Proxy/delay?timeout=5000&url=http://www.gstatic.com/generate_204"
|
60 |
+
delay_res = requests.get(url, headers=self.headers)
|
61 |
+
return set(delay_res.json().keys())
|
62 |
+
|
63 |
+
def get_current_vpn_node_ip(self):
|
64 |
+
url = "http://ifconfig.me"
|
65 |
+
r = requests.request("GET", url)
|
66 |
+
return r.text
|
67 |
+
|
68 |
+
def add_unavailable_node(self, node_name):
|
69 |
+
self.unavailable_nodes.add(node_name)
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.6.0
|
2 |
+
torchvision==0.21.0
|
3 |
+
timm
|
4 |
+
einops
|
5 |
+
accelerate
|
6 |
+
transformers==4.51.3
|
7 |
+
diffusers
|
8 |
+
opencv-python-headless
|
9 |
+
scipy
|
10 |
+
wandb
|
11 |
+
matplotlib
|
12 |
+
Pillow
|
13 |
+
tqdm
|
14 |
+
omegaconf
|
15 |
+
python-dotenv
|
16 |
+
ninja
|