sienna223 committed
Commit 119e1fd · 1 Parent(s): e22a98e
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --upgrade pip wheel setuptools --no-cache-dir && \
+     pip install -r /code/requirements.txt --no-cache-dir && \
+     pip install flash-attn --no-build-isolation --no-cache-dir
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -7,6 +7,4 @@ sdk: docker
  pinned: false
  license: apache-2.0
  short_description: 'OmniGen2: Unified Image Understanding and Generation.'
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
app.py ADDED
@@ -0,0 +1,821 @@
1
+ import dotenv
2
+
3
+ dotenv.load_dotenv(override=True)
4
+
5
+ import gradio as gr
6
+
7
+ import os
8
+ import argparse
9
+ import random
10
+ from datetime import datetime
11
+
12
+ import torch
13
+ from torchvision.transforms.functional import to_pil_image, to_tensor
14
+
15
+ from accelerate import Accelerator
16
+
17
+ from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
18
+ from omnigen2.utils.img_util import create_collage
19
+ from omnigen2.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
20
+ from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
21
+
22
+ NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+
25
+ pipeline = None
26
+ accelerator = None
27
+ save_images = False
28
+
29
+ def load_pipeline(accelerator, weight_dtype, args):
30
+ pipeline = OmniGen2Pipeline.from_pretrained(
31
+ args.model_path,
32
+ torch_dtype=weight_dtype,
33
+ trust_remote_code=True,
34
+ )
35
+ if args.enable_sequential_cpu_offload:
36
+ pipeline.enable_sequential_cpu_offload()
37
+ elif args.enable_model_cpu_offload:
38
+ pipeline.enable_model_cpu_offload()
39
+ else:
40
+ pipeline = pipeline.to(accelerator.device)
41
+ return pipeline
42
+
43
+
44
+ def run(
45
+ instruction,
46
+ width_input,
47
+ height_input,
48
+ scheduler,
49
+ num_inference_steps,
50
+ image_input_1,
51
+ image_input_2,
52
+ image_input_3,
53
+ negative_prompt,
54
+ guidance_scale_input,
55
+ img_guidance_scale_input,
56
+ cfg_range_start,
57
+ cfg_range_end,
58
+ num_images_per_prompt,
59
+ max_input_image_side_length,
60
+ max_pixels,
61
+ seed_input,
62
+ progress=gr.Progress(),
63
+ ):
64
+ input_images = [image_input_1, image_input_2, image_input_3]
65
+ input_images = [img for img in input_images if img is not None]
66
+
67
+ if len(input_images) == 0:
68
+ input_images = None
69
+
70
+ if seed_input == -1:
71
+ seed_input = random.randint(0, 2**16 - 1)
72
+
73
+ generator = torch.Generator(device=accelerator.device).manual_seed(seed_input)
74
+
75
+ def progress_callback(cur_step, timesteps):
76
+ frac = (cur_step + 1) / float(timesteps)
77
+ progress(frac)
78
+
79
+ if scheduler == 'euler':
80
+ pipeline.scheduler = FlowMatchEulerDiscreteScheduler()
81
+ elif scheduler == 'dpmsolver':
82
+ pipeline.scheduler = DPMSolverMultistepScheduler(
83
+ algorithm_type="dpmsolver++",
84
+ solver_type="midpoint",
85
+ solver_order=2,
86
+ prediction_type="flow_prediction",
87
+ )
88
+
89
+ results = pipeline(
90
+ prompt=instruction,
91
+ input_images=input_images,
92
+ width=width_input,
93
+ height=height_input,
94
+ max_input_image_side_length=max_input_image_side_length,
95
+ max_pixels=max_pixels,
96
+ num_inference_steps=num_inference_steps,
97
+ max_sequence_length=1024,
98
+ text_guidance_scale=guidance_scale_input,
99
+ image_guidance_scale=img_guidance_scale_input,
100
+ cfg_range=(cfg_range_start, cfg_range_end),
101
+ negative_prompt=negative_prompt,
102
+ num_images_per_prompt=num_images_per_prompt,
103
+ generator=generator,
104
+ output_type="pil",
105
+ step_func=progress_callback,
106
+ )
107
+
108
+ progress(1.0)
109
+
110
+ vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
111
+ output_image = create_collage(vis_images)
112
+
113
+ if save_images:
114
+ # Create outputs directory if it doesn't exist
115
+ output_dir = os.path.join(ROOT_DIR, "outputs_gradio")
116
+ os.makedirs(output_dir, exist_ok=True)
117
+
118
+ # Generate unique filename with timestamp
119
+ timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
120
+
121
+ # Generate unique filename with timestamp
122
+ output_path = os.path.join(output_dir, f"{timestamp}.png")
123
+ # Save the image
124
+ output_image.save(output_path)
125
+
126
+ # Save All Generated Images
127
+ if len(results.images) > 1:
128
+ for i, image in enumerate(results.images):
129
+ image_name, ext = os.path.splitext(output_path)
130
+ image.save(f"{image_name}_{i}{ext}")
131
+ return output_image
132
+
133
+
134
+ def get_example():
135
+ case = [
136
+ [
137
+ "The sun rises slightly, the dew on the rose petals in the garden is clear, a crystal ladybug is crawling to the dew, the background is the early morning garden, macro lens.",
138
+ 1024,
139
+ 1024,
140
+ 'euler',
141
+ 50,
142
+ None,
143
+ None,
144
+ None,
145
+ NEGATIVE_PROMPT,
146
+ 3.5,
147
+ 1.0,
148
+ 0.0,
149
+ 1.0,
150
+ 1,
151
+ 2048,
152
+ 1024 * 1024,
153
+ 0,
154
+ ],
155
+ [
156
+ "A snow maiden with pale translucent skin, frosty white lashes, and a soft expression of longing",
157
+ 1024,
158
+ 1024,
159
+ 'euler',
160
+ 50,
161
+ None,
162
+ None,
163
+ None,
164
+ NEGATIVE_PROMPT,
165
+ 3.5,
166
+ 1.0,
167
+ 0.0,
168
+ 1.0,
169
+ 1,
170
+ 2048,
171
+ 1024 * 1024,
172
+ 0,
173
+ ],
174
+ [
175
+ "Add a fisherman hat to the woman's head",
176
+ 1024,
177
+ 1024,
178
+ 'euler',
179
+ 50,
180
+ os.path.join(ROOT_DIR, "example_images/flux5.png"),
181
+ None,
182
+ None,
183
+ NEGATIVE_PROMPT,
184
+ 5.0,
185
+ 2.0,
186
+ 0.0,
187
+ 1.0,
188
+ 1,
189
+ 2048,
190
+ 1024 * 1024,
191
+ 0,
192
+ ],
193
+ [
194
+ " replace the sword with a hammer.",
195
+ 1024,
196
+ 1024,
197
+ 'euler',
198
+ 50,
199
+ os.path.join(
200
+ ROOT_DIR,
201
+ "example_images/d8f8f44c64106e7715c61b5dfa9d9ca0974314c5d4a4a50418acf7ff373432bb.png",
202
+ ),
203
+ None,
204
+ None,
205
+ NEGATIVE_PROMPT,
206
+ 5.0,
207
+ 2.0,
208
+ 0.0,
209
+ 1.0,
210
+ 1,
211
+ 2048,
212
+ 1024 * 1024,
213
+ 0,
214
+ ],
215
+ [
216
+ "Extract the character from the picture and fill the rest of the background with white.",
217
+ # "Transform the sculpture into jade",
218
+ 1024,
219
+ 1024,
220
+ 'euler',
221
+ 50,
222
+ os.path.join(
223
+ ROOT_DIR, "example_images/46e79704-c88e-4e68-97b4-b4c40cd29826.png"
224
+ ),
225
+ None,
226
+ None,
227
+ NEGATIVE_PROMPT,
228
+ 5.0,
229
+ 2.0,
230
+ 0.0,
231
+ 1.0,
232
+ 1,
233
+ 2048,
234
+ 1024 * 1024,
235
+ 0,
236
+ ],
237
+ [
238
+ "Make he smile",
239
+ 1024,
240
+ 1024,
241
+ 'euler',
242
+ 50,
243
+ os.path.join(
244
+ ROOT_DIR, "example_images/vicky-hladynets-C8Ta0gwPbQg-unsplash.jpg"
245
+ ),
246
+ None,
247
+ None,
248
+ NEGATIVE_PROMPT,
249
+ 5.0,
250
+ 2.0,
251
+ 0.0,
252
+ 1.0,
253
+ 1,
254
+ 2048,
255
+ 1024 * 1024,
256
+ 0,
257
+ ],
258
+ [
259
+ "Change the background to classroom",
260
+ 1024,
261
+ 1024,
262
+ 'euler',
263
+ 50,
264
+ os.path.join(ROOT_DIR, "example_images/ComfyUI_temp_mllvz_00071_.png"),
265
+ None,
266
+ None,
267
+ NEGATIVE_PROMPT,
268
+ 5.0,
269
+ 2.0,
270
+ 0.0,
271
+ 1.0,
272
+ 1,
273
+ 2048,
274
+ 1024 * 1024,
275
+ 0,
276
+ ],
277
+ [
278
+ "Raise his hand",
279
+ 1024,
280
+ 1024,
281
+ 'euler',
282
+ 50,
283
+ os.path.join(
284
+ ROOT_DIR,
285
+ "example_images/289089159-a6d7abc142419e63cab0a566eb38e0fb6acb217b340f054c6172139b316f6596.png",
286
+ ),
287
+ None,
288
+ None,
289
+ NEGATIVE_PROMPT,
290
+ 5.0,
291
+ 2.0,
292
+ 0.0,
293
+ 1.0,
294
+ 1,
295
+ 2048,
296
+ 1024 * 1024,
297
+ 0,
298
+ ],
299
+ [
300
+ "Generate a photo of an anime-style figurine placed on a desk. The figurine model should be based on the character photo provided in the attachment, accurately replicating the full-body pose, facial expression, and clothing style of the character in the photo, ensuring the entire figurine is fully presented. The overall design should be exquisite and detailed, soft gradient colors and a delicate texture, leaning towards a Japanese anime style, rich in details, with a realistic quality and beautiful visual appeal.",
301
+ 1024,
302
+ 1024,
303
+ 'euler',
304
+ 50,
305
+ os.path.join(ROOT_DIR, "example_images/RAL_0315.JPG"),
306
+ None,
307
+ None,
308
+ NEGATIVE_PROMPT,
309
+ 5.0,
310
+ 2.0,
311
+ 0.0,
312
+ 1.0,
313
+ 1,
314
+ 2048,
315
+ 1024 * 1024,
316
+ 0,
317
+ ],
318
+ [
319
+ "Change the dress to blue.",
320
+ 1024,
321
+ 1024,
322
+ 'euler',
323
+ 50,
324
+ os.path.join(ROOT_DIR, "example_images/1.png"),
325
+ None,
326
+ None,
327
+ NEGATIVE_PROMPT,
328
+ 5.0,
329
+ 2.0,
330
+ 0.0,
331
+ 1.0,
332
+ 1,
333
+ 2048,
334
+ 1024 * 1024,
335
+ 0,
336
+ ],
337
+ [
338
+ "Remove the cat",
339
+ 1024,
340
+ 1024,
341
+ 'euler',
342
+ 50,
343
+ os.path.join(
344
+ ROOT_DIR,
345
+ "example_images/386724677-589d19050d4ea0603aee6831459aede29a24f4d8668c62c049f413db31508a54.png",
346
+ ),
347
+ None,
348
+ None,
349
+ NEGATIVE_PROMPT,
350
+ 5.0,
351
+ 2.0,
352
+ 0.0,
353
+ 1.0,
354
+ 1,
355
+ 2048,
356
+ 1024 * 1024,
357
+ 0,
358
+ ],
359
+ [
360
+ "In a cozy café, the anime figure is sitting in front of a laptop, smiling confidently.",
361
+ 1024,
362
+ 1024,
363
+ 'euler',
364
+ 50,
365
+ os.path.join(ROOT_DIR, "example_images/ComfyUI_00254_.png"),
366
+ None,
367
+ None,
368
+ NEGATIVE_PROMPT,
369
+ 5.0,
370
+ 2.0,
371
+ 0.0,
372
+ 1.0,
373
+ 1,
374
+ 2048,
375
+ 1024 * 1024,
376
+ 0,
377
+ ],
378
+ [
379
+ "Create a wedding figure based on the girl in the first image and the man in the second image. Set the background as a wedding hall, with the man dressed in a suit and the girl in a white wedding dress. Ensure that the original faces remain unchanged and are accurately preserved. The man should adopt a realistic style, whereas the girl should maintain their classic anime style.",
380
+ 1024,
381
+ 1024,
382
+ 'euler',
383
+ 50,
384
+ os.path.join(ROOT_DIR, "example_images/1_20241127203215.png"),
385
+ os.path.join(ROOT_DIR, "example_images/000050281.jpg"),
386
+ None,
387
+ NEGATIVE_PROMPT,
388
+ 5.0,
389
+ 3.0,
390
+ 0.0,
391
+ 1.0,
392
+ 1,
393
+ 2048,
394
+ 1024 * 1024,
395
+ 0,
396
+ ],
397
+ [
398
+ "Let the girl and the boy get married in the church. ",
399
+ 1024,
400
+ 1024,
401
+ 'euler',
402
+ 50,
403
+ os.path.join(ROOT_DIR, "example_images/8FtFUxRzXqaguVRGzkHvN.png"),
404
+ os.path.join(ROOT_DIR, "example_images/01194-20240127001056_1024x1536.png"),
405
+ None,
406
+ NEGATIVE_PROMPT,
407
+ 5.0,
408
+ 3.0,
409
+ 0.0,
410
+ 1.0,
411
+ 1,
412
+ 2048,
413
+ 1024 * 1024,
414
+ 0,
415
+ ],
416
+ [
417
+ "Let the man from image1 and the woman from image2 kiss and hug",
418
+ 1024,
419
+ 1024,
420
+ 'euler',
421
+ 50,
422
+ os.path.join(ROOT_DIR, "example_images/1280X1280.png"),
423
+ os.path.join(ROOT_DIR, "example_images/000077066.jpg"),
424
+ None,
425
+ NEGATIVE_PROMPT,
426
+ 5.0,
427
+ 2.0,
428
+ 0.0,
429
+ 1.0,
430
+ 1,
431
+ 2048,
432
+ 1024 * 1024,
433
+ 0,
434
+ ],
435
+ [
436
+ "Please let the person in image 2 hold the toy from the first image in a parking lot.",
437
+ 1024,
438
+ 1024,
439
+ 'euler',
440
+ 50,
441
+ os.path.join(ROOT_DIR, "example_images/04.jpg"),
442
+ os.path.join(ROOT_DIR, "example_images/000365954.jpg"),
443
+ None,
444
+ NEGATIVE_PROMPT,
445
+ 5.0,
446
+ 2.0,
447
+ 0.0,
448
+ 1.0,
449
+ 1,
450
+ 2048,
451
+ 1024 * 1024,
452
+ 0,
453
+ ],
454
+ [
455
+ "Make the girl pray in the second image.",
456
+ 1024,
457
+ 682,
458
+ 'euler',
459
+ 50,
460
+ os.path.join(ROOT_DIR, "example_images/000440817.jpg"),
461
+ os.path.join(ROOT_DIR, "example_images/000119733.jpg"),
462
+ None,
463
+ NEGATIVE_PROMPT,
464
+ 5.0,
465
+ 2.0,
466
+ 0.0,
467
+ 1.0,
468
+ 1,
469
+ 2048,
470
+ 1024 * 1024,
471
+ 0,
472
+ ],
473
+ [
474
+ "Add the bird from image 1 to the desk in image 2",
475
+ 1024,
476
+ 682,
477
+ 'euler',
478
+ 50,
479
+ os.path.join(
480
+ ROOT_DIR,
481
+ "example_images/996e2cf6-daa5-48c4-9ad7-0719af640c17_1748848108409.png",
482
+ ),
483
+ os.path.join(ROOT_DIR, "example_images/00066-10350085.png"),
484
+ None,
485
+ NEGATIVE_PROMPT,
486
+ 5.0,
487
+ 2.0,
488
+ 0.0,
489
+ 1.0,
490
+ 1,
491
+ 2048,
492
+ 1024 * 1024,
493
+ 0,
494
+ ],
495
+ [
496
+ "Replace the apple in the first image with the cat from the second image",
497
+ 1024,
498
+ 780,
499
+ 'euler',
500
+ 50,
501
+ os.path.join(ROOT_DIR, "example_images/apple.png"),
502
+ os.path.join(
503
+ ROOT_DIR,
504
+ "example_images/468404374-d52ec1a44aa7e0dc9c2807ce09d303a111c78f34da3da2401b83ce10815ff872.png",
505
+ ),
506
+ None,
507
+ NEGATIVE_PROMPT,
508
+ 5.0,
509
+ 2.0,
510
+ 0.0,
511
+ 1.0,
512
+ 1,
513
+ 2048,
514
+ 1024 * 1024,
515
+ 0,
516
+ ],
517
+ [
518
+ "Replace the woman in the second image with the woman from the first image",
519
+ 1024,
520
+ 747,
521
+ 'euler',
522
+ 50,
523
+ os.path.join(
524
+ ROOT_DIR, "example_images/byward-outfitters-B97YFrsITyo-unsplash.jpg"
525
+ ),
526
+ os.path.join(
527
+ ROOT_DIR, "example_images/6652baf6-4096-40ef-a475-425e4c072daf.png"
528
+ ),
529
+ None,
530
+ NEGATIVE_PROMPT,
531
+ 5.0,
532
+ 2.0,
533
+ 0.0,
534
+ 1.0,
535
+ 1,
536
+ 2048,
537
+ 1024 * 1024,
538
+ 0,
539
+ ],
540
+ ]
541
+ return case
542
+
543
+
544
+ def run_for_examples(
545
+ instruction,
546
+ width_input,
547
+ height_input,
548
+ scheduler,
549
+ num_inference_steps,
550
+ image_input_1,
551
+ image_input_2,
552
+ image_input_3,
553
+ negative_prompt,
554
+ text_guidance_scale_input,
555
+ image_guidance_scale_input,
556
+ cfg_range_start,
557
+ cfg_range_end,
558
+ num_images_per_prompt,
559
+ max_input_image_side_length,
560
+ max_pixels,
561
+ seed_input,
562
+ ):
563
+ return run(
564
+ instruction,
565
+ width_input,
566
+ height_input,
567
+ scheduler,
568
+ num_inference_steps,
569
+ image_input_1,
570
+ image_input_2,
571
+ image_input_3,
572
+ negative_prompt,
573
+ text_guidance_scale_input,
574
+ image_guidance_scale_input,
575
+ cfg_range_start,
576
+ cfg_range_end,
577
+ num_images_per_prompt,
578
+ max_input_image_side_length,
579
+ max_pixels,
580
+ seed_input,
581
+ )
582
+
583
+ description = """
584
+ ### 💡 Quick Tips for Best Results (see our [github](https://github.com/VectorSpaceLab/OmniGen2?tab=readme-ov-file#-usage-tips) for more details)
585
+ - Image Quality: Use high-resolution images (at least 512x512 recommended).
586
+ - Be Specific: Instead of "Add bird to desk", try "Add the bird from image 1 to the desk in image 2".
587
+ - Use English: English prompts currently yield better results.
588
+ - Adjust image_guidance_scale for better consistency with the reference image:
589
+ - Image Editing: 1.3 - 2.0
590
+ - In-context Generation: 2.0 - 3.0
591
+ """
592
+
593
+ article = """
594
+ citation to be added
595
+ """
596
+
597
+ def main(args):
598
+ # Gradio
599
+ with gr.Blocks() as demo:
600
+ gr.Markdown(
601
+ "# OmniGen2: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen2)"
602
+ )
603
+ gr.Markdown(description)
604
+ with gr.Row():
605
+ with gr.Column():
606
+ # text prompt
607
+ instruction = gr.Textbox(
608
+ label='Enter your prompt. Use "first/second image" or “第一张图/第二张图” as reference.',
609
+ placeholder="Type your prompt here...",
610
+ )
611
+
612
+ with gr.Row(equal_height=True):
613
+ # input images
614
+ image_input_1 = gr.Image(label="First Image", type="pil")
615
+ image_input_2 = gr.Image(label="Second Image", type="pil")
616
+ image_input_3 = gr.Image(label="Third Image", type="pil")
617
+
618
+ generate_button = gr.Button("Generate Image")
619
+
620
+ negative_prompt = gr.Textbox(
621
+ label="Enter your negative prompt",
622
+ placeholder="Type your negative prompt here...",
623
+ value=NEGATIVE_PROMPT,
624
+ )
625
+
626
+ # slider
627
+ with gr.Row(equal_height=True):
628
+ height_input = gr.Slider(
629
+ label="Height", minimum=256, maximum=1024, value=1024, step=128
630
+ )
631
+ width_input = gr.Slider(
632
+ label="Width", minimum=256, maximum=1024, value=1024, step=128
633
+ )
634
+ with gr.Row(equal_height=True):
635
+ text_guidance_scale_input = gr.Slider(
636
+ label="Text Guidance Scale",
637
+ minimum=1.0,
638
+ maximum=8.0,
639
+ value=5.0,
640
+ step=0.1,
641
+ )
642
+
643
+ image_guidance_scale_input = gr.Slider(
644
+ label="Image Guidance Scale",
645
+ minimum=1.0,
646
+ maximum=3.0,
647
+ value=2.0,
648
+ step=0.1,
649
+ )
650
+ with gr.Row(equal_height=True):
651
+ cfg_range_start = gr.Slider(
652
+ label="CFG Range Start",
653
+ minimum=0.0,
654
+ maximum=1.0,
655
+ value=0.0,
656
+ step=0.1,
657
+ )
658
+
659
+ cfg_range_end = gr.Slider(
660
+ label="CFG Range End",
661
+ minimum=0.0,
662
+ maximum=1.0,
663
+ value=1.0,
664
+ step=0.1,
665
+ )
666
+
667
+ def adjust_end_slider(start_val, end_val):
668
+ return max(start_val, end_val)
669
+
670
+ def adjust_start_slider(end_val, start_val):
671
+ return min(end_val, start_val)
672
+
673
+ cfg_range_start.input(
674
+ fn=adjust_end_slider,
675
+ inputs=[cfg_range_start, cfg_range_end],
676
+ outputs=[cfg_range_end]
677
+ )
678
+
679
+ cfg_range_end.input(
680
+ fn=adjust_start_slider,
681
+ inputs=[cfg_range_end, cfg_range_start],
682
+ outputs=[cfg_range_start]
683
+ )
684
+
685
+ with gr.Row(equal_height=True):
686
+ scheduler_input = gr.Dropdown(
687
+ label="Scheduler",
688
+ choices=["euler", "dpmsolver"],
689
+ value="euler",
690
+ info="The scheduler to use for the model.",
691
+ )
692
+
693
+ num_inference_steps = gr.Slider(
694
+ label="Inference Steps", minimum=20, maximum=100, value=50, step=1
695
+ )
696
+ with gr.Row(equal_height=True):
697
+ num_images_per_prompt = gr.Slider(
698
+ label="Number of images per prompt",
699
+ minimum=1,
700
+ maximum=4,
701
+ value=1,
702
+ step=1,
703
+ )
704
+
705
+ seed_input = gr.Slider(
706
+ label="Seed", minimum=-1, maximum=2147483647, value=0, step=1
707
+ )
708
+ with gr.Row(equal_height=True):
709
+ max_input_image_side_length = gr.Slider(
710
+ label="max_input_image_side_length",
711
+ minimum=256,
712
+ maximum=2048,
713
+ value=2048,
714
+ step=256,
715
+ )
716
+ max_pixels = gr.Slider(
717
+ label="max_pixels",
718
+ minimum=256 * 256,
719
+ maximum=1536 * 1536,
720
+ value=1024 * 1024,
721
+ step=256 * 256,
722
+ )
723
+
724
+ with gr.Column():
725
+ with gr.Column():
726
+ # output image
727
+ output_image = gr.Image(label="Output Image")
728
+ global save_images
729
+ save_images = gr.Checkbox(label="Save generated images", value=False)
730
+
731
+ global accelerator
732
+ global pipeline
733
+
734
+ bf16 = True
735
+ accelerator = Accelerator(mixed_precision="bf16" if bf16 else "no")
736
+ weight_dtype = torch.bfloat16 if bf16 else torch.float32
737
+
738
+ pipeline = load_pipeline(accelerator, weight_dtype, args)
739
+
740
+ # click
741
+ generate_button.click(
742
+ run,
743
+ inputs=[
744
+ instruction,
745
+ width_input,
746
+ height_input,
747
+ scheduler_input,
748
+ num_inference_steps,
749
+ image_input_1,
750
+ image_input_2,
751
+ image_input_3,
752
+ negative_prompt,
753
+ text_guidance_scale_input,
754
+ image_guidance_scale_input,
755
+ cfg_range_start,
756
+ cfg_range_end,
757
+ num_images_per_prompt,
758
+ max_input_image_side_length,
759
+ max_pixels,
760
+ seed_input,
761
+ ],
762
+ outputs=output_image,
763
+ )
764
+
765
+ gr.Examples(
766
+ examples=get_example(),
767
+ fn=run_for_examples,
768
+ inputs=[
769
+ instruction,
770
+ width_input,
771
+ height_input,
772
+ scheduler_input,
773
+ num_inference_steps,
774
+ image_input_1,
775
+ image_input_2,
776
+ image_input_3,
777
+ negative_prompt,
778
+ text_guidance_scale_input,
779
+ image_guidance_scale_input,
780
+ cfg_range_start,
781
+ cfg_range_end,
782
+ num_images_per_prompt,
783
+ max_input_image_side_length,
784
+ max_pixels,
785
+ seed_input,
786
+ ],
787
+ outputs=output_image,
788
+ )
789
+
790
+ gr.Markdown(article)
791
+ # launch
792
+ demo.launch(share=args.share, server_port=args.port, allowed_paths=[ROOT_DIR])
793
+
794
+ def parse_args():
795
+ parser = argparse.ArgumentParser(description="Run the OmniGen2")
796
+ parser.add_argument("--share", action="store_true", help="Share the Gradio app")
797
+ parser.add_argument(
798
+ "--port", type=int, default=7860, help="Port to use for the Gradio app"
799
+ )
800
+ parser.add_argument(
801
+ "--model_path",
802
+ type=str,
803
+ default="OmniGen2/OmniGen2",
804
+ help="Path or HuggingFace name of the model to load."
805
+ )
806
+ parser.add_argument(
807
+ "--enable_model_cpu_offload",
808
+ action="store_true",
809
+ help="Enable model CPU offload."
810
+ )
811
+ parser.add_argument(
812
+ "--enable_sequential_cpu_offload",
813
+ action="store_true",
814
+ help="Enable sequential CPU offload."
815
+ )
816
+ args = parser.parse_args()
817
+ return args
818
+
819
+ if __name__ == "__main__":
820
+ args = parse_args()
821
+ main(args)
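Note: the Gradio UI above is only a thin wrapper around OmniGen2Pipeline. The following is a minimal headless sketch of the same call path used in run(), assuming a CUDA device and the packages installed by the Dockerfile; the keyword arguments mirror the pipeline call in app.py and the values are the demo defaults.

import torch
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

# Load the pipeline the same way load_pipeline() does, without CPU offload.
pipeline = OmniGen2Pipeline.from_pretrained(
    "OmniGen2/OmniGen2",               # default --model_path in parse_args()
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

# Text-to-image call mirroring the keyword arguments used in run().
results = pipeline(
    prompt="A crystal ladybug crawling toward a dew drop on a rose petal, macro lens",
    width=1024,
    height=1024,
    num_inference_steps=50,
    max_sequence_length=1024,
    text_guidance_scale=5.0,
    image_guidance_scale=2.0,
    negative_prompt="blurry, deformed",  # shortened; app.py uses the longer NEGATIVE_PROMPT string
    num_images_per_prompt=1,
    generator=torch.Generator(device="cuda").manual_seed(0),
    output_type="pil",
)
results.images[0].save("output.png")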
omnigen2/.DS_Store ADDED
Binary file (6.15 kB).
 
omnigen2/__init__.py ADDED
File without changes
omnigen2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes).
 
omnigen2/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (139 Bytes).
 
omnigen2/models/__init__.py ADDED
File without changes
omnigen2/models/attention_processor.py ADDED
@@ -0,0 +1,239 @@
1
+ """
2
+ OmniGen2 Attention Processor Module
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ import math
20
+ from typing import Optional, Tuple, Dict, Any
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from einops import repeat
25
+ from flash_attn import flash_attn_varlen_func
26
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
27
+
28
+ from diffusers.models.attention_processor import Attention
29
+ from .embeddings import apply_rotary_emb
30
+
31
+
32
+ class OmniGen2AttnProcessorFlash2Varlen:
33
+ """
34
+ Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
35
+
36
+ This processor is optimized for PyTorch 2.0 and implements:
37
+ - Flash attention with variable length sequences
38
+ - Rotary position embeddings (RoPE)
39
+ - Query-Key normalization
40
+ - Proportional attention scaling
41
+
42
+ Args:
43
+ None
44
+
45
+ Raises:
46
+ ImportError: If PyTorch version is less than 2.0
47
+ """
48
+
49
+ def __init__(self) -> None:
50
+ """Initialize the attention processor."""
51
+ if not hasattr(F, "scaled_dot_product_attention"):
52
+ raise ImportError(
53
+ "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
54
+ "Please upgrade PyTorch to version 2.0 or later."
55
+ )
56
+
57
+ def _upad_input(
58
+ self,
59
+ query_layer: torch.Tensor,
60
+ key_layer: torch.Tensor,
61
+ value_layer: torch.Tensor,
62
+ attention_mask: torch.Tensor,
63
+ query_length: int,
64
+ num_heads: int,
65
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
66
+ """
67
+ Unpad the input tensors for flash attention.
68
+
69
+ Args:
70
+ query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
71
+ key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
72
+ value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
73
+ attention_mask: Attention mask tensor of shape (batch_size, seq_len)
74
+ query_length: Length of the query sequence
75
+ num_heads: Number of attention heads
76
+
77
+ Returns:
78
+ Tuple containing:
79
+ - Unpadded query tensor
80
+ - Unpadded key tensor
81
+ - Unpadded value tensor
82
+ - Query indices
83
+ - Tuple of cumulative sequence lengths for query and key
84
+ - Tuple of maximum sequence lengths for query and key
85
+ """
86
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
87
+ """Helper function to get unpadding data from attention mask."""
88
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
89
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
90
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
91
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
92
+ return indices, cu_seqlens, max_seqlen_in_batch
93
+
94
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
95
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
96
+
97
+ # Unpad key and value layers
98
+ key_layer = index_first_axis(
99
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
100
+ indices_k,
101
+ )
102
+ value_layer = index_first_axis(
103
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
104
+ indices_k,
105
+ )
106
+
107
+ # Handle different query length cases
108
+ if query_length == kv_seq_len:
109
+ query_layer = index_first_axis(
110
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
111
+ indices_k,
112
+ )
113
+ cu_seqlens_q = cu_seqlens_k
114
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
115
+ indices_q = indices_k
116
+ elif query_length == 1:
117
+ max_seqlen_in_batch_q = 1
118
+ cu_seqlens_q = torch.arange(
119
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
120
+ )
121
+ indices_q = cu_seqlens_q[:-1]
122
+ query_layer = query_layer.squeeze(1)
123
+ else:
124
+ attention_mask = attention_mask[:, -query_length:]
125
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
126
+
127
+ return (
128
+ query_layer,
129
+ key_layer,
130
+ value_layer,
131
+ indices_q,
132
+ (cu_seqlens_q, cu_seqlens_k),
133
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
134
+ )
135
+
136
+ def __call__(
137
+ self,
138
+ attn: Attention,
139
+ hidden_states: torch.Tensor,
140
+ encoder_hidden_states: torch.Tensor,
141
+ attention_mask: Optional[torch.Tensor] = None,
142
+ image_rotary_emb: Optional[torch.Tensor] = None,
143
+ base_sequence_length: Optional[int] = None,
144
+ ) -> torch.Tensor:
145
+ """
146
+ Process attention computation with flash attention.
147
+
148
+ Args:
149
+ attn: Attention module
150
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
151
+ encoder_hidden_states: Encoder hidden states tensor
152
+ attention_mask: Optional attention mask tensor
153
+ image_rotary_emb: Optional rotary embeddings for image tokens
154
+ base_sequence_length: Optional base sequence length for proportional attention
155
+
156
+ Returns:
157
+ torch.Tensor: Processed hidden states after attention computation
158
+ """
159
+ batch_size, sequence_length, _ = hidden_states.shape
160
+
161
+ # Get Query-Key-Value Pair
162
+ query = attn.to_q(hidden_states)
163
+ key = attn.to_k(encoder_hidden_states)
164
+ value = attn.to_v(encoder_hidden_states)
165
+
166
+ query_dim = query.shape[-1]
167
+ inner_dim = key.shape[-1]
168
+ head_dim = query_dim // attn.heads
169
+ dtype = query.dtype
170
+
171
+ # Get key-value heads
172
+ kv_heads = inner_dim // head_dim
173
+
174
+ # Reshape tensors for attention computation
175
+ query = query.view(batch_size, -1, attn.heads, head_dim)
176
+ key = key.view(batch_size, -1, kv_heads, head_dim)
177
+ value = value.view(batch_size, -1, kv_heads, head_dim)
178
+
179
+ # Apply Query-Key normalization
180
+ if attn.norm_q is not None:
181
+ query = attn.norm_q(query)
182
+ if attn.norm_k is not None:
183
+ key = attn.norm_k(key)
184
+
185
+ # Apply Rotary Position Embeddings
186
+ if image_rotary_emb is not None:
187
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
188
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
189
+
190
+ query, key = query.to(dtype), key.to(dtype)
191
+
192
+ # Calculate attention scale
193
+ if base_sequence_length is not None:
194
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
195
+ else:
196
+ softmax_scale = attn.scale
197
+
198
+ # Unpad input for flash attention
199
+ (
200
+ query_states,
201
+ key_states,
202
+ value_states,
203
+ indices_q,
204
+ cu_seq_lens,
205
+ max_seq_lens,
206
+ ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
207
+
208
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
209
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
210
+
211
+ # Handle different number of heads
212
+ if kv_heads < attn.heads:
213
+ key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
214
+ value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
215
+
216
+ # Apply flash attention
217
+ attn_output_unpad = flash_attn_varlen_func(
218
+ query_states,
219
+ key_states,
220
+ value_states,
221
+ cu_seqlens_q=cu_seqlens_q,
222
+ cu_seqlens_k=cu_seqlens_k,
223
+ max_seqlen_q=max_seqlen_in_batch_q,
224
+ max_seqlen_k=max_seqlen_in_batch_k,
225
+ dropout_p=0.0,
226
+ causal=False,
227
+ softmax_scale=softmax_scale,
228
+ )
229
+
230
+ # Pad output and apply final transformations
231
+ hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
232
+ hidden_states = hidden_states.flatten(-2)
233
+ hidden_states = hidden_states.type_as(query)
234
+
235
+ # Apply output projection
236
+ hidden_states = attn.to_out[0](hidden_states)
237
+ hidden_states = attn.to_out[1](hidden_states)
238
+
239
+ return hidden_states
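Aside from the varlen unpadding, the one non-standard detail in this processor is the proportional softmax scale: when base_sequence_length is given, the usual 1/sqrt(head_dim) factor is multiplied by sqrt(log_base(seq_len)), so longer token sequences get a slightly sharper attention. A small illustration follows; head_dim = 128 and a base length of 1024 are assumed values, and in diffusers attn.scale is dim_head ** -0.5.

import math

head_dim = 128                      # assumed head dimension, for illustration only
attn_scale = head_dim ** -0.5       # diffusers' Attention sets scale = dim_head ** -0.5
base_sequence_length = 1024         # assumed base length

for sequence_length in (1024, 2048, 4096):
    softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn_scale
    print(sequence_length, round(softmax_scale, 4))
# The scale grows mildly with sequence length: ~0.0884, ~0.0927, ~0.0968 here.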
omnigen2/models/embeddings.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Optional, Tuple, Union
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+
20
+ from diffusers.models.activations import get_activation
21
+
22
+
23
+ class TimestepEmbedding(nn.Module):
24
+ def __init__(
25
+ self,
26
+ in_channels: int,
27
+ time_embed_dim: int,
28
+ act_fn: str = "silu",
29
+ out_dim: int = None,
30
+ post_act_fn: Optional[str] = None,
31
+ cond_proj_dim=None,
32
+ sample_proj_bias=True,
33
+ ):
34
+ super().__init__()
35
+
36
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
37
+
38
+ if cond_proj_dim is not None:
39
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
40
+ else:
41
+ self.cond_proj = None
42
+
43
+ self.act = get_activation(act_fn)
44
+
45
+ if out_dim is not None:
46
+ time_embed_dim_out = out_dim
47
+ else:
48
+ time_embed_dim_out = time_embed_dim
49
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
50
+
51
+ if post_act_fn is None:
52
+ self.post_act = None
53
+ else:
54
+ self.post_act = get_activation(post_act_fn)
55
+
56
+ self.initialize_weights()
57
+
58
+ def initialize_weights(self):
59
+ nn.init.normal_(self.linear_1.weight, std=0.02)
60
+ nn.init.zeros_(self.linear_1.bias)
61
+ nn.init.normal_(self.linear_2.weight, std=0.02)
62
+ nn.init.zeros_(self.linear_2.bias)
63
+
64
+ def forward(self, sample, condition=None):
65
+ if condition is not None:
66
+ sample = sample + self.cond_proj(condition)
67
+ sample = self.linear_1(sample)
68
+
69
+ if self.act is not None:
70
+ sample = self.act(sample)
71
+
72
+ sample = self.linear_2(sample)
73
+
74
+ if self.post_act is not None:
75
+ sample = self.post_act(sample)
76
+ return sample
77
+
78
+
79
+ def apply_rotary_emb(
80
+ x: torch.Tensor,
81
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
82
+ use_real: bool = True,
83
+ use_real_unbind_dim: int = -1,
84
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
85
+ """
86
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
87
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
88
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
89
+ tensors contain rotary embeddings and are returned as real tensors.
90
+
91
+ Args:
92
+ x (`torch.Tensor`):
93
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
94
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
95
+
96
+ Returns:
97
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
98
+ """
99
+ if use_real:
100
+ cos, sin = freqs_cis # [S, D]
101
+ cos = cos[None, None]
102
+ sin = sin[None, None]
103
+ cos, sin = cos.to(x.device), sin.to(x.device)
104
+
105
+ if use_real_unbind_dim == -1:
106
+ # Used for flux, cogvideox, hunyuan-dit
107
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
108
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
109
+ elif use_real_unbind_dim == -2:
110
+ # Used for Stable Audio, OmniGen and CogView4
111
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
112
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
113
+ else:
114
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
115
+
116
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
117
+
118
+ return out
119
+ else:
120
+ # used for lumina
121
+ # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
122
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
123
+ freqs_cis = freqs_cis.unsqueeze(2)
124
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
125
+
126
+ return x_out.type_as(x)
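Per the docstring, the use_real=True branch of apply_rotary_emb expects x shaped [B, H, S, D] and a (cos, sin) pair shaped [S, D], and returns a tensor with the same shape and dtype as x. A quick shape-only sketch, where random cos/sin stand in for a real rotary table and the import path assumes this repository layout:

import torch
from omnigen2.models.embeddings import apply_rotary_emb

B, H, S, D = 1, 4, 16, 64
x = torch.randn(B, H, S, D)
cos = torch.randn(S, D)   # placeholder; normally produced by a rotary embedding table
sin = torch.randn(S, D)

out = apply_rotary_emb(x, (cos, sin), use_real=True, use_real_unbind_dim=-1)
assert out.shape == x.shape and out.dtype == x.dtype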
omnigen2/models/transformers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .transformer_omnigen2 import OmniGen2Transformer2DModel
+
+ __all__ = ["OmniGen2Transformer2DModel"]
omnigen2/models/transformers/block_lumina2.py ADDED
@@ -0,0 +1,246 @@
1
+
2
+ # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warnings
17
+ import itertools
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from diffusers.models.embeddings import Timesteps
25
+ from ..embeddings import TimestepEmbedding
26
+ from .components import swiglu
27
+
28
+ try:
29
+ # from apex.normalization import FusedRMSNorm
30
+ # from flash_attn.ops.rms_norm import RMSNorm as FusedRMSNorm
31
+ # from flash_attn.ops.triton.layer_norm import RMSNorm as FusedRMSNorm
32
+ from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm
33
+ FUSEDRMSNORM_AVALIBLE = True
34
+ except ImportError:
35
+ FUSEDRMSNORM_AVALIBLE = False
36
+ warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation")
37
+
38
+ try:
39
+ from flash_attn.ops.activations import swiglu as fused_swiglu
40
+ FUSEDSWIGLU_AVALIBLE = True
41
+ except ImportError:
42
+
43
+ FUSEDSWIGLU_AVALIBLE = False
44
+ warnings.warn("Cannot import fused swiglu from flash_attn, falling back to vanilla implementation")
45
+
46
+ class LuminaRMSNormZero(nn.Module):
47
+ """
48
+ Norm layer adaptive RMS normalization zero.
49
+
50
+ Parameters:
51
+ embedding_dim (`int`): The size of each embedding vector.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ embedding_dim: int,
57
+ norm_eps: float,
58
+ norm_elementwise_affine: bool,
59
+ use_fused_rms_norm: bool = False,
60
+ ):
61
+ super().__init__()
62
+ self.silu = nn.SiLU()
63
+ self.linear = nn.Linear(
64
+ min(embedding_dim, 1024),
65
+ 4 * embedding_dim,
66
+ bias=True,
67
+ )
68
+ if use_fused_rms_norm:
69
+ assert FUSEDRMSNORM_AVALIBLE
70
+ self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps)
71
+ else:
72
+ self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps)
73
+
74
+ def forward(
75
+ self,
76
+ x: torch.Tensor,
77
+ emb: Optional[torch.Tensor] = None,
78
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
79
+ emb = self.linear(self.silu(emb))
80
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
81
+ x = self.norm(x) * (1 + scale_msa[:, None])
82
+ # x_norm = self.norm(x)
83
+ # print(f"{x.shape=} {x.dtype=} {x_norm.shape=} {x_norm.dtype=}")
84
+ # print(f"{scale_msa.shape=} {scale_msa.dtype=}")
85
+ # print(f"{scale_msa[:, None].shape=} {scale_msa[:, None].dtype=}")
86
+ # x = x_norm * (1 + scale_msa[:, None])
87
+
88
+ return x, gate_msa, scale_mlp, gate_mlp
89
+
90
+
91
+ class LuminaLayerNormContinuous(nn.Module):
92
+ def __init__(
93
+ self,
94
+ embedding_dim: int,
95
+ conditioning_embedding_dim: int,
96
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
97
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
98
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
99
+ # However, this is how it was implemented in the original code, and it's rather likely you should
100
+ # set `elementwise_affine` to False.
101
+ elementwise_affine=True,
102
+ eps=1e-5,
103
+ bias=True,
104
+ norm_type="layer_norm",
105
+ out_dim: Optional[int] = None,
106
+ use_fused_rms_norm: bool = False
107
+ ):
108
+ super().__init__()
109
+
110
+ # AdaLN
111
+ self.silu = nn.SiLU()
112
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
113
+
114
+ if norm_type == "layer_norm":
115
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
116
+ elif norm_type == "rms_norm":
117
+ if use_fused_rms_norm:
118
+ assert FUSEDRMSNORM_AVALIBLE
119
+ self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
120
+ else:
121
+ self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
122
+ else:
123
+ raise ValueError(f"unknown norm_type {norm_type}")
124
+
125
+ self.linear_2 = None
126
+ if out_dim is not None:
127
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
128
+
129
+ def forward(
130
+ self,
131
+ x: torch.Tensor,
132
+ conditioning_embedding: torch.Tensor,
133
+ ) -> torch.Tensor:
134
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
135
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
136
+ scale = emb
137
+ x = self.norm(x) * (1 + scale)[:, None, :]
138
+
139
+ if self.linear_2 is not None:
140
+ x = self.linear_2(x)
141
+
142
+ return x
143
+
144
+
145
+ class LuminaFeedForward(nn.Module):
146
+ r"""
147
+ A feed-forward layer.
148
+
149
+ Parameters:
150
+ hidden_size (`int`):
151
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
152
+ hidden representations.
153
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
154
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
155
+ of this value.
156
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
157
+ dimension. Defaults to None.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ dim: int,
163
+ inner_dim: int,
164
+ multiple_of: Optional[int] = 256,
165
+ ffn_dim_multiplier: Optional[float] = None,
166
+ use_fused_swiglu: bool = False
167
+ ):
168
+ super().__init__()
169
+ self.use_fused_swiglu = use_fused_swiglu
170
+
171
+ if use_fused_swiglu:
172
+ assert FUSEDSWIGLU_AVALIBLE
173
+ self.swiglu = fused_swiglu
174
+ else:
175
+ self.swiglu = swiglu
176
+
177
+ # custom hidden_size factor multiplier
178
+ if ffn_dim_multiplier is not None:
179
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
180
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
181
+
182
+ self.linear_1 = nn.Linear(
183
+ dim,
184
+ inner_dim,
185
+ bias=False,
186
+ )
187
+ self.linear_2 = nn.Linear(
188
+ inner_dim,
189
+ dim,
190
+ bias=False,
191
+ )
192
+ self.linear_3 = nn.Linear(
193
+ dim,
194
+ inner_dim,
195
+ bias=False,
196
+ )
197
+
198
+ def forward(self, x):
199
+ h1, h2 = self.linear_1(x), self.linear_3(x)
200
+ return self.linear_2(self.swiglu(h1, h2))
201
+
202
+
203
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
204
+ def __init__(
205
+ self,
206
+ hidden_size: int = 4096,
207
+ text_feat_dim: int = 2048,
208
+ frequency_embedding_size: int = 256,
209
+ norm_eps: float = 1e-5,
210
+ timestep_scale: float = 1.0,
211
+ use_fused_rms_norm: bool = False
212
+ ) -> None:
213
+ super().__init__()
214
+
215
+ self.time_proj = Timesteps(
216
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
217
+ )
218
+
219
+ self.timestep_embedder = TimestepEmbedding(
220
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
221
+ )
222
+
223
+ if use_fused_rms_norm:
224
+ assert FUSEDRMSNORM_AVALIBLE
225
+ RMSNorm = FusedRMSNorm
226
+ else:
227
+ RMSNorm = nn.RMSNorm
228
+
229
+ self.caption_embedder = nn.Sequential(
230
+ RMSNorm(text_feat_dim, eps=norm_eps),
231
+ nn.Linear(text_feat_dim, hidden_size, bias=True),
232
+ )
233
+
234
+ self._initialize_weights()
235
+
236
+ def _initialize_weights(self):
237
+ nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
238
+ nn.init.zeros_(self.caption_embedder[1].bias)
239
+
240
+ def forward(
241
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
242
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
243
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
244
+ time_embed = self.timestep_embedder(timestep_proj)
245
+ caption_embed = self.caption_embedder(text_hidden_states)
246
+ return time_embed, caption_embed
omnigen2/models/transformers/components.py ADDED
@@ -0,0 +1,4 @@
+ import torch.nn.functional as F
+
+ def swiglu(x, y):
+     return F.silu(x.float(), inplace=False).to(x.dtype) * y
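For reference, swiglu here is simply the SiLU-gated product used by LuminaFeedForward, with the activation computed in float32 before casting back to the input dtype. A quick equivalence check, assuming this repository layout:

import torch
import torch.nn.functional as F
from omnigen2.models.transformers.components import swiglu

x = torch.randn(2, 8, dtype=torch.bfloat16)
y = torch.randn(2, 8, dtype=torch.bfloat16)
out = swiglu(x, y)
# Same as gating y with SiLU(x): the activation runs in float32, the result keeps the input dtype.
assert out.dtype == torch.bfloat16
assert torch.allclose(out, F.silu(x.float()).to(x.dtype) * y)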
omnigen2/models/transformers/repo.py ADDED
@@ -0,0 +1,129 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from einops import repeat
7
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
8
+
9
+ class OmniGen2RotaryPosEmbed(nn.Module):
10
+ def __init__(self, theta: int,
11
+ axes_dim: Tuple[int, int, int],
12
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
13
+ patch_size: int = 2):
14
+ super().__init__()
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+ self.axes_lens = axes_lens
18
+ self.patch_size = patch_size
19
+
20
+ @staticmethod
21
+ def get_freqs_cis(axes_dim: Tuple[int, int, int],
22
+ axes_lens: Tuple[int, int, int],
23
+ theta: int) -> List[torch.Tensor]:
24
+ freqs_cis = []
25
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
26
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
27
+ emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
28
+ freqs_cis.append(emb)
29
+ return freqs_cis
30
+
31
+ def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
32
+ device = ids.device
33
+ if ids.device.type == "mps":
34
+ ids = ids.to("cpu")
35
+
36
+ result = []
37
+ for i in range(len(self.axes_dim)):
38
+ freqs = freqs_cis[i].to(ids.device)
39
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
40
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
41
+ return torch.cat(result, dim=-1).to(device)
42
+
43
+ def forward(
44
+ self,
45
+ freqs_cis,
46
+ attention_mask,
47
+ l_effective_ref_img_len,
48
+ l_effective_img_len,
49
+ ref_img_sizes,
50
+ img_sizes,
51
+ device
52
+ ):
53
+ batch_size = len(attention_mask)
54
+ p = self.patch_size
55
+
56
+ encoder_seq_len = attention_mask.shape[1]
57
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
58
+
59
+ seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
60
+
61
+ max_seq_len = max(seq_lengths)
62
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
63
+ max_img_len = max(l_effective_img_len)
64
+
65
+ # Create position IDs
66
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
67
+
68
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
69
+ # add text position ids
70
+ position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
71
+
72
+ pe_shift = cap_seq_len
73
+ pe_shift_len = cap_seq_len
74
+
75
+ if ref_img_sizes[i] is not None:
76
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
77
+ H, W = ref_img_size
78
+ ref_H_tokens, ref_W_tokens = H // p, W // p
79
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
80
+ # add image position ids
81
+
82
+ row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
83
+ col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
84
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
85
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
86
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
87
+
88
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
89
+ pe_shift_len += ref_img_len
90
+
91
+ H, W = img_sizes[i]
92
+ H_tokens, W_tokens = H // p, W // p
93
+ assert H_tokens * W_tokens == l_effective_img_len[i]
94
+
95
+ row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
96
+ col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
97
+
98
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
99
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
100
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
101
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
102
+
103
+ # Get combined rotary embeddings
104
+ freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
105
+
106
+ # create separate rotary embeddings for captions and images
107
+ cap_freqs_cis = torch.zeros(
108
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
109
+ )
110
+ ref_img_freqs_cis = torch.zeros(
111
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
112
+ )
113
+ img_freqs_cis = torch.zeros(
114
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
115
+ )
116
+
117
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
118
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
119
+ ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
120
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
121
+
122
+ return (
123
+ cap_freqs_cis,
124
+ ref_img_freqs_cis,
125
+ img_freqs_cis,
126
+ freqs_cis,
127
+ l_effective_cap_len,
128
+ seq_lengths,
129
+ )
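The per-axis frequency tables consumed by forward() are precomputed once via the static get_freqs_cis helper. A small sketch of how they are built; axes_dim has no default in the constructor, so the 3-way split of the per-head dimension below is purely illustrative, while axes_lens matches the default (300, 512, 512).

from omnigen2.models.transformers.repo import OmniGen2RotaryPosEmbed

# One rotary table per positional axis: (sequence index, image row, image column).
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
    axes_dim=(32, 16, 16),      # illustrative split of the per-head dimension across the 3 axes
    axes_lens=(300, 512, 512),  # default maximum number of positions per axis
    theta=10000,
)
for axis_table in freqs_cis:
    # Complex-valued table from get_1d_rotary_pos_embed, one row per position along that axis.
    print(axis_table.shape, axis_table.dtype)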
omnigen2/models/transformers/transformer_omnigen2.py ADDED
@@ -0,0 +1,639 @@
1
+ import warnings
2
+ import itertools
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from einops import rearrange
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.loaders import PeftAdapterMixin
12
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
13
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
14
+ from diffusers.models.attention_processor import Attention
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+
18
+ from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen
19
+ from .repo import OmniGen2RotaryPosEmbed
20
+ from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
21
+
22
+ try:
23
+ from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm
24
+ FUSEDRMSNORM_AVALIBLE = True
25
+ except ImportError:
26
+ FUSEDRMSNORM_AVALIBLE = False
27
+ warnings.warn("Cannot import FusedRMSNorm, falling back to vanilla implementation")
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class OmniGen2TransformerBlock(nn.Module):
33
+ """
34
+ Transformer block for OmniGen2 model.
35
+
36
+ This block implements a transformer layer with:
37
+ - Multi-head attention with flash attention
38
+ - Feed-forward network with SwiGLU activation
39
+ - RMS normalization
40
+ - Optional modulation for conditional generation
41
+
42
+ Args:
43
+ dim: Dimension of the input and output tensors
44
+ num_attention_heads: Number of attention heads
45
+ num_kv_heads: Number of key-value heads
46
+ multiple_of: Multiple of which the hidden dimension should be
47
+ ffn_dim_multiplier: Multiplier for the feed-forward network dimension
48
+ norm_eps: Epsilon value for normalization layers
49
+ modulation: Whether to use modulation for conditional generation
50
+ use_fused_rms_norm: Whether to use fused RMS normalization
51
+ use_fused_swiglu: Whether to use fused SwiGLU activation
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ dim: int,
57
+ num_attention_heads: int,
58
+ num_kv_heads: int,
59
+ multiple_of: int,
60
+ ffn_dim_multiplier: float,
61
+ norm_eps: float,
62
+ modulation: bool = True,
63
+ use_fused_rms_norm: bool = True,
64
+ use_fused_swiglu: bool = True,
65
+ ) -> None:
66
+ """Initialize the transformer block."""
67
+ super().__init__()
68
+ self.head_dim = dim // num_attention_heads
69
+ self.modulation = modulation
70
+
71
+ # Initialize attention layer
72
+ self.attn = Attention(
73
+ query_dim=dim,
74
+ cross_attention_dim=None,
75
+ dim_head=dim // num_attention_heads,
76
+ qk_norm="rms_norm",
77
+ heads=num_attention_heads,
78
+ kv_heads=num_kv_heads,
79
+ eps=1e-5,
80
+ bias=False,
81
+ out_bias=False,
82
+ processor=OmniGen2AttnProcessorFlash2Varlen(),
83
+ )
84
+
85
+ # Initialize feed-forward network
86
+ self.feed_forward = LuminaFeedForward(
87
+ dim=dim,
88
+ inner_dim=4 * dim,
89
+ multiple_of=multiple_of,
90
+ ffn_dim_multiplier=ffn_dim_multiplier,
91
+ use_fused_swiglu=use_fused_swiglu,
92
+ )
93
+
94
+ # Initialize normalization layers
95
+ if modulation:
96
+ self.norm1 = LuminaRMSNormZero(
97
+ embedding_dim=dim,
98
+ norm_eps=norm_eps,
99
+ norm_elementwise_affine=True,
100
+ use_fused_rms_norm=use_fused_rms_norm,
101
+ )
102
+ else:
103
+ if use_fused_rms_norm:
104
+ if not FUSEDRMSNORM_AVAILABLE:
105
+ raise ImportError("FusedRMSNorm is not available")
106
+ self.norm1 = FusedRMSNorm(dim, eps=norm_eps)
107
+ else:
108
+ self.norm1 = nn.RMSNorm(dim, eps=norm_eps)
109
+
110
+ if use_fused_rms_norm:
111
+ if not FUSEDRMSNORM_AVAILABLE:
112
+ raise ImportError("FusedRMSNorm is not available")
113
+ self.ffn_norm1 = FusedRMSNorm(dim, eps=norm_eps)
114
+ self.norm2 = FusedRMSNorm(dim, eps=norm_eps)
115
+ self.ffn_norm2 = FusedRMSNorm(dim, eps=norm_eps)
116
+ else:
117
+ self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps)
118
+ self.norm2 = nn.RMSNorm(dim, eps=norm_eps)
119
+ self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps)
120
+
121
+ self.initialize_weights()
122
+
123
+ def initialize_weights(self) -> None:
124
+ """
125
+ Initialize the weights of the transformer block.
126
+
127
+ Uses Xavier uniform initialization for the attention and feed-forward projections, and zero initialization for the modulation linear layer.
128
+ """
129
+ nn.init.xavier_uniform_(self.attn.to_q.weight)
130
+ nn.init.xavier_uniform_(self.attn.to_k.weight)
131
+ nn.init.xavier_uniform_(self.attn.to_v.weight)
132
+ nn.init.xavier_uniform_(self.attn.to_out[0].weight)
133
+
134
+ nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
135
+ nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
136
+ nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
137
+
138
+ if self.modulation:
139
+ nn.init.zeros_(self.norm1.linear.weight)
140
+ nn.init.zeros_(self.norm1.linear.bias)
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: torch.Tensor,
146
+ image_rotary_emb: torch.Tensor,
147
+ temb: Optional[torch.Tensor] = None,
148
+ ) -> torch.Tensor:
149
+ """
150
+ Forward pass of the transformer block.
151
+
152
+ Args:
153
+ hidden_states: Input hidden states tensor
154
+ attention_mask: Attention mask tensor
155
+ image_rotary_emb: Rotary embeddings for image tokens
156
+ temb: Optional timestep embedding tensor
157
+
158
+ Returns:
159
+ torch.Tensor: Output hidden states after transformer block processing
160
+ """
161
+ if self.modulation:
162
+ if temb is None:
163
+ raise ValueError("temb must be provided when modulation is enabled")
164
+
165
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
166
+ attn_output = self.attn(
167
+ hidden_states=norm_hidden_states,
168
+ encoder_hidden_states=norm_hidden_states,
169
+ attention_mask=attention_mask,
170
+ image_rotary_emb=image_rotary_emb,
171
+ )
172
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
173
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
174
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
175
+ else:
176
+ norm_hidden_states = self.norm1(hidden_states)
177
+ attn_output = self.attn(
178
+ hidden_states=norm_hidden_states,
179
+ encoder_hidden_states=norm_hidden_states,
180
+ attention_mask=attention_mask,
181
+ image_rotary_emb=image_rotary_emb,
182
+ )
183
+ hidden_states = hidden_states + self.norm2(attn_output)
184
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
185
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
186
+
187
+ return hidden_states
188
+
189
+
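(Editorial aside, not part of the commit.) A minimal construction sketch of the block defined above. The import path is a guess based on this file's relative imports, and it assumes the omnigen2 package and its dependencies (diffusers, flash-attn) are installed; the fused kernels are switched off so the check does not require Triton.

import torch
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2TransformerBlock  # hypothetical import path

block = OmniGen2TransformerBlock(
    dim=2304,                   # hidden_size of the default config below
    num_attention_heads=24,     # head_dim = 2304 // 24 = 96 = sum(axes_dim_rope) = 32 + 32 + 32
    num_kv_heads=8,
    multiple_of=256,
    ffn_dim_multiplier=None,
    norm_eps=1e-5,
    modulation=True,
    use_fused_rms_norm=False,   # fall back to nn.RMSNorm instead of the Triton kernel
    use_fused_swiglu=False,
)
print(f"{sum(p.numel() for p in block.parameters()) / 1e6:.1f}M parameters")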
190
+ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
191
+ """
192
+ OmniGen2 Transformer 2D Model.
193
+
194
+ A transformer-based diffusion model for image generation with:
195
+ - Patch-based image processing
196
+ - Rotary position embeddings
197
+ - Multi-head attention
198
+ - Conditional generation support
199
+
200
+ Args:
201
+ patch_size: Size of image patches
202
+ in_channels: Number of input channels
203
+ out_channels: Number of output channels (defaults to in_channels)
204
+ hidden_size: Size of hidden layers
205
+ num_layers: Number of transformer layers
206
+ num_refiner_layers: Number of refiner layers
207
+ num_attention_heads: Number of attention heads
208
+ num_kv_heads: Number of key-value heads
209
+ multiple_of: The feed-forward hidden dimension is rounded up to a multiple of this value
210
+ ffn_dim_multiplier: Multiplier for feed-forward network dimension
211
+ norm_eps: Epsilon value for normalization layers
212
+ axes_dim_rope: Dimensions for rotary position embeddings
213
+ axes_lens: Lengths for rotary position embeddings
214
+ text_feat_dim: Dimension of text features
215
+ timestep_scale: Scale factor for timestep embeddings
216
+ use_fused_rms_norm: Whether to use fused RMS normalization
217
+ use_fused_swiglu: Whether to use fused SwiGLU activation
218
+ """
219
+
220
+ _supports_gradient_checkpointing = True
221
+ _no_split_modules = ["OmniGen2TransformerBlock"]
222
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
223
+
224
+ @register_to_config
225
+ def __init__(
226
+ self,
227
+ patch_size: int = 2,
228
+ in_channels: int = 16,
229
+ out_channels: Optional[int] = None,
230
+ hidden_size: int = 2304,
231
+ num_layers: int = 26,
232
+ num_refiner_layers: int = 2,
233
+ num_attention_heads: int = 24,
234
+ num_kv_heads: int = 8,
235
+ multiple_of: int = 256,
236
+ ffn_dim_multiplier: Optional[float] = None,
237
+ norm_eps: float = 1e-5,
238
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
239
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
240
+ text_feat_dim: int = 1024,
241
+ timestep_scale: float = 1.0,
242
+ use_fused_rms_norm: bool = True,
243
+ use_fused_swiglu: bool = True,
244
+ ) -> None:
245
+ """Initialize the OmniGen2 transformer model."""
246
+ super().__init__()
247
+
248
+ # Validate configuration
249
+ if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
250
+ raise ValueError(
251
+ f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
252
+ f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
253
+ )
254
+
255
+ self.out_channels = out_channels or in_channels
256
+
257
+ # Initialize embeddings
258
+ self.rope_embedder = OmniGen2RotaryPosEmbed(
259
+ theta=10000,
260
+ axes_dim=axes_dim_rope,
261
+ axes_lens=axes_lens,
262
+ patch_size=patch_size,
263
+ )
264
+
265
+ self.x_embedder = nn.Linear(
266
+ in_features=patch_size * patch_size * in_channels,
267
+ out_features=hidden_size,
268
+ )
269
+
270
+ self.ref_image_patch_embedder = nn.Linear(
271
+ in_features=patch_size * patch_size * in_channels,
272
+ out_features=hidden_size,
273
+ )
274
+
275
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
276
+ hidden_size=hidden_size,
277
+ text_feat_dim=text_feat_dim,
278
+ norm_eps=norm_eps,
279
+ timestep_scale=timestep_scale,
280
+ use_fused_rms_norm=use_fused_rms_norm,
281
+ )
282
+
283
+ # Initialize transformer blocks
284
+ self.noise_refiner = nn.ModuleList([
285
+ OmniGen2TransformerBlock(
286
+ hidden_size,
287
+ num_attention_heads,
288
+ num_kv_heads,
289
+ multiple_of,
290
+ ffn_dim_multiplier,
291
+ norm_eps,
292
+ modulation=True,
293
+ use_fused_rms_norm=use_fused_rms_norm,
294
+ use_fused_swiglu=use_fused_swiglu,
295
+ )
296
+ for _ in range(num_refiner_layers)
297
+ ])
298
+
299
+ self.ref_image_refiner = nn.ModuleList([
300
+ OmniGen2TransformerBlock(
301
+ hidden_size,
302
+ num_attention_heads,
303
+ num_kv_heads,
304
+ multiple_of,
305
+ ffn_dim_multiplier,
306
+ norm_eps,
307
+ modulation=True,
308
+ use_fused_rms_norm=use_fused_rms_norm,
309
+ use_fused_swiglu=use_fused_swiglu,
310
+ )
311
+ for _ in range(num_refiner_layers)
312
+ ])
313
+
314
+ self.context_refiner = nn.ModuleList(
315
+ [
316
+ OmniGen2TransformerBlock(
317
+ hidden_size,
318
+ num_attention_heads,
319
+ num_kv_heads,
320
+ multiple_of,
321
+ ffn_dim_multiplier,
322
+ norm_eps,
323
+ modulation=False,
324
+ use_fused_rms_norm=use_fused_rms_norm,
325
+ use_fused_swiglu=use_fused_swiglu
326
+ )
327
+ for _ in range(num_refiner_layers)
328
+ ]
329
+ )
330
+
331
+ # 3. Transformer blocks
332
+ self.layers = nn.ModuleList(
333
+ [
334
+ OmniGen2TransformerBlock(
335
+ hidden_size,
336
+ num_attention_heads,
337
+ num_kv_heads,
338
+ multiple_of,
339
+ ffn_dim_multiplier,
340
+ norm_eps,
341
+ modulation=True,
342
+ use_fused_rms_norm=use_fused_rms_norm,
343
+ use_fused_swiglu=use_fused_swiglu
344
+ )
345
+ for _ in range(num_layers)
346
+ ]
347
+ )
348
+
349
+ # 4. Output norm & projection
350
+ self.norm_out = LuminaLayerNormContinuous(
351
+ embedding_dim=hidden_size,
352
+ conditioning_embedding_dim=min(hidden_size, 1024),
353
+ elementwise_affine=False,
354
+ eps=1e-6,
355
+ bias=True,
356
+ out_dim=patch_size * patch_size * self.out_channels,
357
+ use_fused_rms_norm=use_fused_rms_norm,
358
+ )
359
+
360
+ # Add learnable embeddings to distinguish different images
361
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
362
+
363
+ self.gradient_checkpointing = False
364
+
365
+ self.initialize_weights()
366
+
367
+ def initialize_weights(self) -> None:
368
+ """
369
+ Initialize the weights of the model.
370
+
371
+ Uses Xavier uniform initialization for the patch embedders, zero initialization for the output projection, and a small normal initialization for the image index embeddings.
372
+ """
373
+ nn.init.xavier_uniform_(self.x_embedder.weight)
374
+ nn.init.constant_(self.x_embedder.bias, 0.0)
375
+
376
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
377
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
378
+
379
+ nn.init.zeros_(self.norm_out.linear_1.weight)
380
+ nn.init.zeros_(self.norm_out.linear_1.bias)
381
+ nn.init.zeros_(self.norm_out.linear_2.weight)
382
+ nn.init.zeros_(self.norm_out.linear_2.bias)
383
+
384
+ nn.init.normal_(self.image_index_embedding, std=0.02)
385
+
386
+ def img_patch_embed_and_refine(
387
+ self,
388
+ hidden_states,
389
+ ref_image_hidden_states,
390
+ padded_img_mask,
391
+ padded_ref_img_mask,
392
+ noise_rotary_emb,
393
+ ref_img_rotary_emb,
394
+ l_effective_ref_img_len,
395
+ l_effective_img_len,
396
+ temb
397
+ ):
398
+ batch_size = len(hidden_states)
399
+ max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
400
+
401
+ hidden_states = self.x_embedder(hidden_states)
402
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
403
+
404
+ for i in range(batch_size):
405
+ shift = 0
406
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
407
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
408
+ shift += ref_img_len
409
+
410
+ for layer in self.noise_refiner:
411
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
412
+
413
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
414
+ num_ref_images = len(flat_l_effective_ref_img_len)
415
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
416
+
417
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
418
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
419
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
420
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
421
+
422
+ # sequence of ref imgs to batch
423
+ idx = 0
424
+ for i in range(batch_size):
425
+ shift = 0
426
+ for ref_img_len in l_effective_ref_img_len[i]:
427
+ batch_ref_img_mask[idx, :ref_img_len] = True
428
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
429
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
430
+ batch_temb[idx] = temb[i]
431
+ shift += ref_img_len
432
+ idx += 1
433
+
434
+ # refine ref imgs separately
435
+ for layer in self.ref_image_refiner:
436
+ batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
437
+
438
+ # batch of ref imgs to sequence
439
+ idx = 0
440
+ for i in range(batch_size):
441
+ shift = 0
442
+ for ref_img_len in l_effective_ref_img_len[i]:
443
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
444
+ shift += ref_img_len
445
+ idx += 1
446
+
447
+ combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
448
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
449
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
450
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
451
+
452
+ return combined_img_hidden_states
453
+
454
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
455
+ batch_size = len(hidden_states)
456
+ p = self.config.patch_size
457
+ device = hidden_states[0].device
458
+
459
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
460
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
461
+
462
+ if ref_image_hidden_states is not None:
463
+ ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
464
+ l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
465
+ else:
466
+ ref_img_sizes = [None for _ in range(batch_size)]
467
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
468
+
469
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
470
+ max_img_len = max(l_effective_img_len)
471
+
472
+ # ref image patch embeddings
473
+ flat_ref_img_hidden_states = []
474
+ for i in range(batch_size):
475
+ if ref_img_sizes[i] is not None:
476
+ imgs = []
477
+ for ref_img in ref_image_hidden_states[i]:
478
+ C, H, W = ref_img.size()
479
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
480
+ imgs.append(ref_img)
481
+
482
+ img = torch.cat(imgs, dim=0)
483
+ flat_ref_img_hidden_states.append(img)
484
+ else:
485
+ flat_ref_img_hidden_states.append(None)
486
+
487
+ # image patch embeddings
488
+ flat_hidden_states = []
489
+ for i in range(batch_size):
490
+ img = hidden_states[i]
491
+ C, H, W = img.size()
492
+
493
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
494
+ flat_hidden_states.append(img)
495
+
496
+ padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
497
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
498
+ for i in range(batch_size):
499
+ if ref_img_sizes[i] is not None:
500
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
501
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
502
+
503
+ padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
504
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
505
+ for i in range(batch_size):
506
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
507
+ padded_img_mask[i, :l_effective_img_len[i]] = True
508
+
509
+ return (
510
+ padded_hidden_states,
511
+ padded_ref_img_hidden_states,
512
+ padded_img_mask,
513
+ padded_ref_img_mask,
514
+ l_effective_ref_img_len,
515
+ l_effective_img_len,
516
+ ref_img_sizes,
517
+ img_sizes,
518
+ )
519
+
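(Editorial aside, not part of the commit.) To make the packing above concrete, this is how the effective per-sample token lengths come out for a toy batch with patch_size p = 2; each sample is then padded up to the batch maximum.

p = 2
img_sizes = [(64, 64), (32, 48)]   # (H, W) of each latent in the batch
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
print(l_effective_img_len)         # [1024, 384]; padded_hidden_states is padded to max(...) = 1024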
520
+ def forward(
521
+ self,
522
+ hidden_states: Union[torch.Tensor, List[torch.Tensor]],
523
+ timestep: torch.Tensor,
524
+ text_hidden_states: torch.Tensor,
525
+ freqs_cis: torch.Tensor,
526
+ text_attention_mask: torch.Tensor,
527
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
528
+ attention_kwargs: Optional[Dict[str, Any]] = None,
529
+ return_dict: bool = False,
530
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
531
+ if attention_kwargs is not None:
532
+ attention_kwargs = attention_kwargs.copy()
533
+ lora_scale = attention_kwargs.pop("scale", 1.0)
534
+ else:
535
+ lora_scale = 1.0
536
+
537
+ if USE_PEFT_BACKEND:
538
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
539
+ scale_lora_layers(self, lora_scale)
540
+ else:
541
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
542
+ logger.warning(
543
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
544
+ )
545
+
546
+ # 1. Condition, positional & patch embedding
547
+ batch_size = len(hidden_states)
548
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
549
+
550
+ if is_hidden_states_tensor:
551
+ assert hidden_states.ndim == 4
552
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
553
+
554
+ device = hidden_states[0].device
555
+
556
+ temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
557
+
558
+ (
559
+ hidden_states,
560
+ ref_image_hidden_states,
561
+ img_mask,
562
+ ref_img_mask,
563
+ l_effective_ref_img_len,
564
+ l_effective_img_len,
565
+ ref_img_sizes,
566
+ img_sizes,
567
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
568
+
569
+ (
570
+ context_rotary_emb,
571
+ ref_img_rotary_emb,
572
+ noise_rotary_emb,
573
+ rotary_emb,
574
+ encoder_seq_lengths,
575
+ seq_lengths,
576
+ ) = self.rope_embedder(
577
+ freqs_cis,
578
+ text_attention_mask,
579
+ l_effective_ref_img_len,
580
+ l_effective_img_len,
581
+ ref_img_sizes,
582
+ img_sizes,
583
+ device,
584
+ )
585
+
586
+ # 2. Context refinement
587
+ for layer in self.context_refiner:
588
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
589
+
590
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
591
+ hidden_states,
592
+ ref_image_hidden_states,
593
+ img_mask,
594
+ ref_img_mask,
595
+ noise_rotary_emb,
596
+ ref_img_rotary_emb,
597
+ l_effective_ref_img_len,
598
+ l_effective_img_len,
599
+ temb,
600
+ )
601
+
602
+ # 3. Joint Transformer blocks
603
+ max_seq_len = max(seq_lengths)
604
+
605
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
606
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
607
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
608
+ attention_mask[i, :seq_len] = True
609
+ joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
610
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
611
+
612
+ hidden_states = joint_hidden_states
613
+
614
+ for layer_idx, layer in enumerate(self.layers):
615
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
616
+ hidden_states = self._gradient_checkpointing_func(
617
+ layer, hidden_states, attention_mask, rotary_emb, temb
618
+ )
619
+ else:
620
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
621
+
622
+ # 4. Output norm & projection
623
+ hidden_states = self.norm_out(hidden_states, temb)
624
+
625
+ p = self.config.patch_size
626
+ output = []
627
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
628
+ height, width = img_size
629
+ output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
630
+ if is_hidden_states_tensor:
631
+ output = torch.stack(output, dim=0)
632
+
633
+ if USE_PEFT_BACKEND:
634
+ # remove `lora_scale` from each PEFT layer
635
+ unscale_lora_layers(self, lora_scale)
636
+
637
+ if not return_dict:
638
+ return output
639
+ return Transformer2DModelOutput(sample=output)
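(Editorial aside, not part of the commit.) The patchify step used in flat_and_pad_to_seq and the inverse used for the final output are plain einops reshapes, so the round trip is exact:

import torch
from einops import rearrange

p = 2
latent = torch.randn(16, 64, 64)   # (C, H, W), matching in_channels=16
tokens = rearrange(latent, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
print(tokens.shape)                # torch.Size([1024, 64]) = (H//p * W//p, p*p*C)
restored = rearrange(tokens, '(h w) (p1 p2 c) -> c (h p1) (w p2)',
                     h=64 // p, w=64 // p, p1=p, p2=p)
assert torch.equal(latent, restored)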
omnigen2/ops/.DS_Store ADDED
Binary file (6.15 kB).
 
omnigen2/ops/triton/__init__.py ADDED
File without changes
omnigen2/ops/triton/layer_norm.py ADDED
@@ -0,0 +1,1257 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ import triton
15
+ import triton.language as tl
16
+
17
+
18
+ from typing import Callable
19
+
20
+
21
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
22
+ def decorator(*args, **kwargs):
23
+ if cuda_amp_deprecated:
24
+ kwargs["device_type"] = "cuda"
25
+ return dec(*args, **kwargs)
26
+ return decorator
27
+
28
+
29
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
30
+ deprecated = True
31
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
32
+ else:
33
+ deprecated = False
34
+ from torch.cuda.amp import custom_fwd, custom_bwd
35
+
36
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
37
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
38
+
39
+
40
+ def triton_autotune_configs():
41
+ # Return configs with a valid warp count for the current device
42
+ configs = []
43
+ # Maximum threads per block is architecture-dependent in theory, but in practice it is 1024
44
+ max_threads_per_block = 1024
45
+ # Default to warp size 32 if not defined by the device
46
+ warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
47
+ # Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
48
+ warp_count = 1
49
+ while warp_count * warp_size <= max_threads_per_block:
50
+ configs.append(triton.Config({}, num_warps=warp_count))
51
+ warp_count *= 2
52
+ return configs
53
+
54
+ def layer_norm_ref(
55
+ x,
56
+ weight,
57
+ bias,
58
+ residual=None,
59
+ x1=None,
60
+ weight1=None,
61
+ bias1=None,
62
+ eps=1e-6,
63
+ dropout_p=0.0,
64
+ rowscale=None,
65
+ prenorm=False,
66
+ zero_centered_weight=False,
67
+ dropout_mask=None,
68
+ dropout_mask1=None,
69
+ upcast=False,
70
+ ):
71
+ dtype = x.dtype
72
+ if upcast:
73
+ x = x.float()
74
+ weight = weight.float()
75
+ bias = bias.float() if bias is not None else None
76
+ residual = residual.float() if residual is not None else residual
77
+ x1 = x1.float() if x1 is not None else None
78
+ weight1 = weight1.float() if weight1 is not None else None
79
+ bias1 = bias1.float() if bias1 is not None else None
80
+ if zero_centered_weight:
81
+ weight = weight + 1.0
82
+ if weight1 is not None:
83
+ weight1 = weight1 + 1.0
84
+ if x1 is not None:
85
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
86
+ if rowscale is not None:
87
+ x = x * rowscale[..., None]
88
+ if dropout_p > 0.0:
89
+ if dropout_mask is not None:
90
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
91
+ else:
92
+ x = F.dropout(x, p=dropout_p)
93
+ if x1 is not None:
94
+ if dropout_mask1 is not None:
95
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
96
+ else:
97
+ x1 = F.dropout(x1, p=dropout_p)
98
+ if x1 is not None:
99
+ x = x + x1
100
+ if residual is not None:
101
+ x = (x + residual).to(x.dtype)
102
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
103
+ dtype
104
+ )
105
+ if weight1 is None:
106
+ return out if not prenorm else (out, x)
107
+ else:
108
+ out1 = F.layer_norm(
109
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
110
+ ).to(dtype)
111
+ return (out, out1) if not prenorm else (out, out1, x)
112
+
113
+
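(Editorial aside, not part of the commit.) A quick check of the reference semantics above, assuming layer_norm_ref from this file is in scope: with prenorm=True the function also returns the pre-norm stream x + residual, which is what the fused kernel later stores as RESIDUAL_OUT.

import torch

x = torch.randn(2, 8)
res = torch.randn(2, 8)
w, b = torch.ones(8), torch.zeros(8)
out, pre = layer_norm_ref(x, w, b, residual=res, prenorm=True)  # layer_norm_ref is the reference defined above
torch.testing.assert_close(pre, x + res)                        # returned residual stream is x + residual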
114
+ def rms_norm_ref(
115
+ x,
116
+ weight,
117
+ bias,
118
+ residual=None,
119
+ x1=None,
120
+ weight1=None,
121
+ bias1=None,
122
+ eps=1e-6,
123
+ dropout_p=0.0,
124
+ rowscale=None,
125
+ prenorm=False,
126
+ zero_centered_weight=False,
127
+ dropout_mask=None,
128
+ dropout_mask1=None,
129
+ upcast=False,
130
+ ):
131
+ dtype = x.dtype
132
+ if upcast:
133
+ x = x.float()
134
+ weight = weight.float()
135
+ bias = bias.float() if bias is not None else None
136
+ residual = residual.float() if residual is not None else residual
137
+ x1 = x1.float() if x1 is not None else None
138
+ weight1 = weight1.float() if weight1 is not None else None
139
+ bias1 = bias1.float() if bias1 is not None else None
140
+ if zero_centered_weight:
141
+ weight = weight + 1.0
142
+ if weight1 is not None:
143
+ weight1 = weight1 + 1.0
144
+ if x1 is not None:
145
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
146
+ if rowscale is not None:
147
+ x = x * rowscale[..., None]
148
+ if dropout_p > 0.0:
149
+ if dropout_mask is not None:
150
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
151
+ else:
152
+ x = F.dropout(x, p=dropout_p)
153
+ if x1 is not None:
154
+ if dropout_mask1 is not None:
155
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
156
+ else:
157
+ x1 = F.dropout(x1, p=dropout_p)
158
+ if x1 is not None:
159
+ x = x + x1
160
+ if residual is not None:
161
+ x = (x + residual).to(x.dtype)
162
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
163
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
164
+ if weight1 is None:
165
+ return out if not prenorm else (out, x)
166
+ else:
167
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
168
+ dtype
169
+ )
170
+ return (out, out1) if not prenorm else (out, out1, x)
171
+
172
+
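(Editorial aside, not part of the commit.) Likewise, a sanity check that the RMSNorm reference above matches the plain formula x / sqrt(mean(x^2) + eps) * w, assuming rms_norm_ref from this file is in scope:

import torch

x = torch.randn(4, 512)
w = torch.randn(512)
rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + 1e-6)
torch.testing.assert_close(rms_norm_ref(x, w, bias=None, eps=1e-6), x * rstd * w)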
173
+ @triton.autotune(
174
+ configs=triton_autotune_configs(),
175
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
176
+ )
177
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
178
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
179
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
180
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
181
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
182
+ @triton.jit
183
+ def _layer_norm_fwd_1pass_kernel(
184
+ X, # pointer to the input
185
+ Y, # pointer to the output
186
+ W, # pointer to the weights
187
+ B, # pointer to the biases
188
+ RESIDUAL, # pointer to the residual
189
+ X1,
190
+ W1,
191
+ B1,
192
+ Y1,
193
+ RESIDUAL_OUT, # pointer to the residual
194
+ ROWSCALE,
195
+ SEEDS, # Dropout seeds for each row
196
+ DROPOUT_MASK,
197
+ Mean, # pointer to the mean
198
+ Rstd, # pointer to the 1/std
199
+ stride_x_row, # how much to increase the pointer when moving by 1 row
200
+ stride_y_row,
201
+ stride_res_row,
202
+ stride_res_out_row,
203
+ stride_x1_row,
204
+ stride_y1_row,
205
+ M, # number of rows in X
206
+ N, # number of columns in X
207
+ eps, # epsilon to avoid division by zero
208
+ dropout_p, # Dropout probability
209
+ zero_centered_weight, # If true, add 1.0 to the weight
210
+ IS_RMS_NORM: tl.constexpr,
211
+ BLOCK_N: tl.constexpr,
212
+ HAS_RESIDUAL: tl.constexpr,
213
+ STORE_RESIDUAL_OUT: tl.constexpr,
214
+ HAS_BIAS: tl.constexpr,
215
+ HAS_DROPOUT: tl.constexpr,
216
+ STORE_DROPOUT_MASK: tl.constexpr,
217
+ HAS_ROWSCALE: tl.constexpr,
218
+ HAS_X1: tl.constexpr,
219
+ HAS_W1: tl.constexpr,
220
+ HAS_B1: tl.constexpr,
221
+ ):
222
+ # Map the program id to the row of X and Y it should compute.
223
+ row = tl.program_id(0)
224
+ X += row * stride_x_row
225
+ Y += row * stride_y_row
226
+ if HAS_RESIDUAL:
227
+ RESIDUAL += row * stride_res_row
228
+ if STORE_RESIDUAL_OUT:
229
+ RESIDUAL_OUT += row * stride_res_out_row
230
+ if HAS_X1:
231
+ X1 += row * stride_x1_row
232
+ if HAS_W1:
233
+ Y1 += row * stride_y1_row
234
+ # Compute mean and variance
235
+ cols = tl.arange(0, BLOCK_N)
236
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
237
+ if HAS_ROWSCALE:
238
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
239
+ x *= rowscale
240
+ if HAS_DROPOUT:
241
+ # Compute dropout mask
242
+ # 7 rounds is good enough, and reduces register pressure
243
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
244
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
245
+ if STORE_DROPOUT_MASK:
246
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
247
+ if HAS_X1:
248
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
249
+ if HAS_ROWSCALE:
250
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
251
+ x1 *= rowscale
252
+ if HAS_DROPOUT:
253
+ # Compute dropout mask
254
+ # 7 rounds is good enough, and reduces register pressure
255
+ keep_mask = (
256
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
257
+ )
258
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
259
+ if STORE_DROPOUT_MASK:
260
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
261
+ x += x1
262
+ if HAS_RESIDUAL:
263
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
264
+ x += residual
265
+ if STORE_RESIDUAL_OUT:
266
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
267
+ if not IS_RMS_NORM:
268
+ mean = tl.sum(x, axis=0) / N
269
+ tl.store(Mean + row, mean)
270
+ xbar = tl.where(cols < N, x - mean, 0.0)
271
+ var = tl.sum(xbar * xbar, axis=0) / N
272
+ else:
273
+ xbar = tl.where(cols < N, x, 0.0)
274
+ var = tl.sum(xbar * xbar, axis=0) / N
275
+ rstd = 1 / tl.sqrt(var + eps)
276
+ tl.store(Rstd + row, rstd)
277
+ # Normalize and apply linear transformation
278
+ mask = cols < N
279
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
280
+ if zero_centered_weight:
281
+ w += 1.0
282
+ if HAS_BIAS:
283
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
284
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
285
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
286
+ # Write output
287
+ tl.store(Y + cols, y, mask=mask)
288
+ if HAS_W1:
289
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
290
+ if zero_centered_weight:
291
+ w1 += 1.0
292
+ if HAS_B1:
293
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
294
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
295
+ tl.store(Y1 + cols, y1, mask=mask)
296
+
297
+
298
+ def _layer_norm_fwd(
299
+ x,
300
+ weight,
301
+ bias,
302
+ eps,
303
+ residual=None,
304
+ x1=None,
305
+ weight1=None,
306
+ bias1=None,
307
+ dropout_p=0.0,
308
+ rowscale=None,
309
+ out_dtype=None,
310
+ residual_dtype=None,
311
+ zero_centered_weight=False,
312
+ is_rms_norm=False,
313
+ return_dropout_mask=False,
314
+ out=None,
315
+ residual_out=None
316
+ ):
317
+ if residual is not None:
318
+ residual_dtype = residual.dtype
319
+ M, N = x.shape
320
+ assert x.stride(-1) == 1
321
+ if residual is not None:
322
+ assert residual.stride(-1) == 1
323
+ assert residual.shape == (M, N)
324
+ assert weight.shape == (N,)
325
+ assert weight.stride(-1) == 1
326
+ if bias is not None:
327
+ assert bias.stride(-1) == 1
328
+ assert bias.shape == (N,)
329
+ if x1 is not None:
330
+ assert x1.shape == x.shape
331
+ assert rowscale is None
332
+ assert x1.stride(-1) == 1
333
+ if weight1 is not None:
334
+ assert weight1.shape == (N,)
335
+ assert weight1.stride(-1) == 1
336
+ if bias1 is not None:
337
+ assert bias1.shape == (N,)
338
+ assert bias1.stride(-1) == 1
339
+ if rowscale is not None:
340
+ assert rowscale.is_contiguous()
341
+ assert rowscale.shape == (M,)
342
+ # allocate output
343
+ if out is None:
344
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
345
+ else:
346
+ assert out.shape == x.shape
347
+ assert out.stride(-1) == 1
348
+ if weight1 is not None:
349
+ y1 = torch.empty_like(out)
350
+ assert y1.stride(-1) == 1
351
+ else:
352
+ y1 = None
353
+ if (
354
+ residual is not None
355
+ or (residual_dtype is not None and residual_dtype != x.dtype)
356
+ or dropout_p > 0.0
357
+ or rowscale is not None
358
+ or x1 is not None
359
+ ):
360
+ if residual_out is None:
361
+ residual_out = torch.empty(
362
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
363
+ )
364
+ else:
365
+ assert residual_out.shape == x.shape
366
+ assert residual_out.stride(-1) == 1
367
+ else:
368
+ residual_out = None
369
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
370
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
+ if dropout_p > 0.0:
372
+ seeds = torch.randint(
373
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
+ )
375
+ else:
376
+ seeds = None
377
+ if return_dropout_mask and dropout_p > 0.0:
378
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
379
+ else:
380
+ dropout_mask = None
381
+ # Less than 64KB per feature: enqueue fused kernel
382
+ MAX_FUSED_SIZE = 65536 // x.element_size()
383
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
384
+ if N > BLOCK_N:
385
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
386
+ with torch.cuda.device(x.device.index):
387
+ _layer_norm_fwd_1pass_kernel[(M,)](
388
+ x,
389
+ out,
390
+ weight,
391
+ bias,
392
+ residual,
393
+ x1,
394
+ weight1,
395
+ bias1,
396
+ y1,
397
+ residual_out,
398
+ rowscale,
399
+ seeds,
400
+ dropout_mask,
401
+ mean,
402
+ rstd,
403
+ x.stride(0),
404
+ out.stride(0),
405
+ residual.stride(0) if residual is not None else 0,
406
+ residual_out.stride(0) if residual_out is not None else 0,
407
+ x1.stride(0) if x1 is not None else 0,
408
+ y1.stride(0) if y1 is not None else 0,
409
+ M,
410
+ N,
411
+ eps,
412
+ dropout_p,
413
+ zero_centered_weight,
414
+ is_rms_norm,
415
+ BLOCK_N,
416
+ residual is not None,
417
+ residual_out is not None,
418
+ bias is not None,
419
+ dropout_p > 0.0,
420
+ dropout_mask is not None,
421
+ rowscale is not None,
422
+ )
423
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
424
+ if dropout_mask is not None and x1 is not None:
425
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
426
+ else:
427
+ dropout_mask1 = None
428
+ return (
429
+ out,
430
+ y1,
431
+ mean,
432
+ rstd,
433
+ residual_out if residual_out is not None else x,
434
+ seeds,
435
+ dropout_mask,
436
+ dropout_mask1,
437
+ )
438
+
439
+
440
+ @triton.autotune(
441
+ configs=triton_autotune_configs(),
442
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
443
+ )
444
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
445
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
446
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
447
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
448
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
449
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
450
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
451
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
452
+ @triton.jit
453
+ def _layer_norm_bwd_kernel(
454
+ X, # pointer to the input
455
+ W, # pointer to the weights
456
+ B, # pointer to the biases
457
+ Y, # pointer to the output to be recomputed
458
+ DY, # pointer to the output gradient
459
+ DX, # pointer to the input gradient
460
+ DW, # pointer to the partial sum of weights gradient
461
+ DB, # pointer to the partial sum of biases gradient
462
+ DRESIDUAL,
463
+ W1,
464
+ DY1,
465
+ DX1,
466
+ DW1,
467
+ DB1,
468
+ DRESIDUAL_IN,
469
+ ROWSCALE,
470
+ SEEDS,
471
+ Mean, # pointer to the mean
472
+ Rstd, # pointer to the 1/std
473
+ stride_x_row, # how much to increase the pointer when moving by 1 row
474
+ stride_y_row,
475
+ stride_dy_row,
476
+ stride_dx_row,
477
+ stride_dres_row,
478
+ stride_dy1_row,
479
+ stride_dx1_row,
480
+ stride_dres_in_row,
481
+ M, # number of rows in X
482
+ N, # number of columns in X
483
+ eps, # epsilon to avoid division by zero
484
+ dropout_p,
485
+ zero_centered_weight,
486
+ rows_per_program,
487
+ IS_RMS_NORM: tl.constexpr,
488
+ BLOCK_N: tl.constexpr,
489
+ HAS_DRESIDUAL: tl.constexpr,
490
+ STORE_DRESIDUAL: tl.constexpr,
491
+ HAS_BIAS: tl.constexpr,
492
+ HAS_DROPOUT: tl.constexpr,
493
+ HAS_ROWSCALE: tl.constexpr,
494
+ HAS_DY1: tl.constexpr,
495
+ HAS_DX1: tl.constexpr,
496
+ HAS_B1: tl.constexpr,
497
+ RECOMPUTE_OUTPUT: tl.constexpr,
498
+ ):
499
+ # Map the program id to the elements of X, DX, and DY it should compute.
500
+ row_block_id = tl.program_id(0)
501
+ row_start = row_block_id * rows_per_program
502
+ # Do not early exit if row_start >= M, because we need to write DW and DB
503
+ cols = tl.arange(0, BLOCK_N)
504
+ mask = cols < N
505
+ X += row_start * stride_x_row
506
+ if HAS_DRESIDUAL:
507
+ DRESIDUAL += row_start * stride_dres_row
508
+ if STORE_DRESIDUAL:
509
+ DRESIDUAL_IN += row_start * stride_dres_in_row
510
+ DY += row_start * stride_dy_row
511
+ DX += row_start * stride_dx_row
512
+ if HAS_DY1:
513
+ DY1 += row_start * stride_dy1_row
514
+ if HAS_DX1:
515
+ DX1 += row_start * stride_dx1_row
516
+ if RECOMPUTE_OUTPUT:
517
+ Y += row_start * stride_y_row
518
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
519
+ if zero_centered_weight:
520
+ w += 1.0
521
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
522
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
523
+ if HAS_DY1:
524
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
525
+ if zero_centered_weight:
526
+ w1 += 1.0
527
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
528
+ if HAS_BIAS:
529
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
530
+ if HAS_DY1:
531
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
532
+ if HAS_B1:
533
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
534
+ row_end = min((row_block_id + 1) * rows_per_program, M)
535
+ for row in range(row_start, row_end):
536
+ # Load data to SRAM
537
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
538
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
539
+ if HAS_DY1:
540
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
541
+ if not IS_RMS_NORM:
542
+ mean = tl.load(Mean + row)
543
+ rstd = tl.load(Rstd + row)
544
+ # Compute dx
545
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
546
+ xhat = tl.where(mask, xhat, 0.0)
547
+ if RECOMPUTE_OUTPUT:
548
+ y = xhat * w + b if HAS_BIAS else xhat * w
549
+ tl.store(Y + cols, y, mask=mask)
550
+ wdy = w * dy
551
+ dw += dy * xhat
552
+ if HAS_BIAS:
553
+ db += dy
554
+ if HAS_DY1:
555
+ wdy += w1 * dy1
556
+ dw1 += dy1 * xhat
557
+ if HAS_B1:
558
+ db1 += dy1
559
+ if not IS_RMS_NORM:
560
+ c1 = tl.sum(xhat * wdy, axis=0) / N
561
+ c2 = tl.sum(wdy, axis=0) / N
562
+ dx = (wdy - (xhat * c1 + c2)) * rstd
563
+ else:
564
+ c1 = tl.sum(xhat * wdy, axis=0) / N
565
+ dx = (wdy - xhat * c1) * rstd
566
+ if HAS_DRESIDUAL:
567
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
568
+ dx += dres
569
+ # Write dx
570
+ if STORE_DRESIDUAL:
571
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
572
+ if HAS_DX1:
573
+ if HAS_DROPOUT:
574
+ keep_mask = (
575
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
576
+ )
577
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
578
+ else:
579
+ dx1 = dx
580
+ tl.store(DX1 + cols, dx1, mask=mask)
581
+ if HAS_DROPOUT:
582
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
583
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
584
+ if HAS_ROWSCALE:
585
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
586
+ dx *= rowscale
587
+ tl.store(DX + cols, dx, mask=mask)
588
+
589
+ X += stride_x_row
590
+ if HAS_DRESIDUAL:
591
+ DRESIDUAL += stride_dres_row
592
+ if STORE_DRESIDUAL:
593
+ DRESIDUAL_IN += stride_dres_in_row
594
+ if RECOMPUTE_OUTPUT:
595
+ Y += stride_y_row
596
+ DY += stride_dy_row
597
+ DX += stride_dx_row
598
+ if HAS_DY1:
599
+ DY1 += stride_dy1_row
600
+ if HAS_DX1:
601
+ DX1 += stride_dx1_row
602
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
603
+ if HAS_BIAS:
604
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
605
+ if HAS_DY1:
606
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
607
+ if HAS_B1:
608
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
609
+
610
+
611
+ def _layer_norm_bwd(
612
+ dy,
613
+ x,
614
+ weight,
615
+ bias,
616
+ eps,
617
+ mean,
618
+ rstd,
619
+ dresidual=None,
620
+ dy1=None,
621
+ weight1=None,
622
+ bias1=None,
623
+ seeds=None,
624
+ dropout_p=0.0,
625
+ rowscale=None,
626
+ has_residual=False,
627
+ has_x1=False,
628
+ zero_centered_weight=False,
629
+ is_rms_norm=False,
630
+ x_dtype=None,
631
+ recompute_output=False,
632
+ ):
633
+ M, N = x.shape
634
+ assert x.stride(-1) == 1
635
+ assert dy.stride(-1) == 1
636
+ assert dy.shape == (M, N)
637
+ if dresidual is not None:
638
+ assert dresidual.stride(-1) == 1
639
+ assert dresidual.shape == (M, N)
640
+ assert weight.shape == (N,)
641
+ assert weight.stride(-1) == 1
642
+ if bias is not None:
643
+ assert bias.stride(-1) == 1
644
+ assert bias.shape == (N,)
645
+ if dy1 is not None:
646
+ assert weight1 is not None
647
+ assert dy1.shape == dy.shape
648
+ assert dy1.stride(-1) == 1
649
+ if weight1 is not None:
650
+ assert weight1.shape == (N,)
651
+ assert weight1.stride(-1) == 1
652
+ if bias1 is not None:
653
+ assert bias1.shape == (N,)
654
+ assert bias1.stride(-1) == 1
655
+ if seeds is not None:
656
+ assert seeds.is_contiguous()
657
+ assert seeds.shape == (M if not has_x1 else M * 2,)
658
+ if rowscale is not None:
659
+ assert rowscale.is_contiguous()
660
+ assert rowscale.shape == (M,)
661
+ # allocate output
662
+ dx = (
663
+ torch.empty_like(x)
664
+ if x_dtype is None
665
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
666
+ )
667
+ dresidual_in = (
668
+ torch.empty_like(x)
669
+ if has_residual
670
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
671
+ else None
672
+ )
673
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
674
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
675
+ if recompute_output:
676
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
677
+
678
+ # Less than 64KB per feature: enqueue fused kernel
679
+ MAX_FUSED_SIZE = 65536 // x.element_size()
680
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
681
+ if N > BLOCK_N:
682
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
683
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
684
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
685
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
686
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
687
+ _db = (
688
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
689
+ if bias is not None
690
+ else None
691
+ )
692
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
693
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
694
+ rows_per_program = math.ceil(M / sm_count)
695
+ grid = (sm_count,)
696
+ with torch.cuda.device(x.device.index):
697
+ _layer_norm_bwd_kernel[grid](
698
+ x,
699
+ weight,
700
+ bias,
701
+ y,
702
+ dy,
703
+ dx,
704
+ _dw,
705
+ _db,
706
+ dresidual,
707
+ weight1,
708
+ dy1,
709
+ dx1,
710
+ _dw1,
711
+ _db1,
712
+ dresidual_in,
713
+ rowscale,
714
+ seeds,
715
+ mean,
716
+ rstd,
717
+ x.stride(0),
718
+ 0 if not recompute_output else y.stride(0),
719
+ dy.stride(0),
720
+ dx.stride(0),
721
+ dresidual.stride(0) if dresidual is not None else 0,
722
+ dy1.stride(0) if dy1 is not None else 0,
723
+ dx1.stride(0) if dx1 is not None else 0,
724
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
725
+ M,
726
+ N,
727
+ eps,
728
+ dropout_p,
729
+ zero_centered_weight,
730
+ rows_per_program,
731
+ is_rms_norm,
732
+ BLOCK_N,
733
+ dresidual is not None,
734
+ dresidual_in is not None,
735
+ bias is not None,
736
+ dropout_p > 0.0,
737
+ )
738
+ dw = _dw.sum(0).to(weight.dtype)
739
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
740
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
741
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
742
+ # Don't need to compute dresidual_in separately in this case
743
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
744
+ dresidual_in = dx
745
+ if has_x1 and dropout_p == 0.0:
746
+ dx1 = dx
747
+ return (
748
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
749
+ if not recompute_output
750
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
751
+ )
752
+
753
+
754
+ class LayerNormFn(torch.autograd.Function):
755
+ @staticmethod
756
+ def forward(
757
+ ctx,
758
+ x,
759
+ weight,
760
+ bias,
761
+ residual=None,
762
+ x1=None,
763
+ weight1=None,
764
+ bias1=None,
765
+ eps=1e-6,
766
+ dropout_p=0.0,
767
+ rowscale=None,
768
+ prenorm=False,
769
+ residual_in_fp32=False,
770
+ zero_centered_weight=False,
771
+ is_rms_norm=False,
772
+ return_dropout_mask=False,
773
+ out=None,
774
+ residual_out=None
775
+ ):
776
+ x_shape_og = x.shape
777
+ # Check for zero sequence length
778
+ if x.numel() == 0:
779
+ ctx.zero_seq_length = True
780
+ # Only save minimal required tensors for backward
781
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
782
+ ctx.x_shape_og = x_shape_og
783
+ ctx.weight_shape = weight.shape
784
+ ctx.weight_dtype = weight.dtype
785
+ ctx.weight_device = weight.device
786
+
787
+ ctx.has_bias = bias is not None
788
+ ctx.bias_shape = bias.shape if bias is not None else None
789
+ ctx.bias_dtype = bias.dtype if bias is not None else None
790
+ ctx.bias_device = bias.device if bias is not None else None
791
+
792
+ ctx.has_weight1 = weight1 is not None
793
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
794
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
795
+ ctx.weight1_device = weight1.device if weight1 is not None else None
796
+
797
+ ctx.has_bias1 = bias1 is not None
798
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
799
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
800
+ ctx.bias1_device = bias1.device if bias1 is not None else None
801
+
802
+ ctx.has_residual = residual is not None
803
+ ctx.has_x1 = x1 is not None
804
+ ctx.dropout_p = dropout_p
805
+
806
+ # Handle output tensors with correct dtype
807
+ y = x # Preserve input tensor properties
808
+ y1 = torch.empty_like(x) if x1 is not None else None
809
+
810
+ # Only create residual_out if prenorm is True
811
+ residual_out = torch.empty(x.shape,
812
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
813
+ device=x.device) if prenorm else None
814
+
815
+ # Handle dropout masks
816
+ dropout_mask = None
817
+ dropout_mask1 = None
818
+ if return_dropout_mask:
819
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
820
+ if x1 is not None:
821
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
822
+
823
+ # Return based on configuration
824
+ if not return_dropout_mask:
825
+ if weight1 is None:
826
+ return y if not prenorm else (y, residual_out)
827
+ else:
828
+ return (y, y1) if not prenorm else (y, y1, residual_out)
829
+ else:
830
+ if weight1 is None:
831
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
832
+ else (y, residual_out, dropout_mask, dropout_mask1))
833
+ else:
834
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
835
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
836
+
837
+ ctx.zero_seq_length = False
838
+ # reshape input data into 2D tensor
839
+ x = x.reshape(-1, x.shape[-1])
840
+ if x.stride(-1) != 1:
841
+ x = x.contiguous()
842
+ if residual is not None:
843
+ assert residual.shape == x_shape_og
844
+ residual = residual.reshape(-1, residual.shape[-1])
845
+ if residual.stride(-1) != 1:
846
+ residual = residual.contiguous()
847
+ if x1 is not None:
848
+ assert x1.shape == x_shape_og
849
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
850
+ x1 = x1.reshape(-1, x1.shape[-1])
851
+ if x1.stride(-1) != 1:
852
+ x1 = x1.contiguous()
853
+ weight = weight.contiguous()
854
+ if bias is not None:
855
+ bias = bias.contiguous()
856
+ if weight1 is not None:
857
+ weight1 = weight1.contiguous()
858
+ if bias1 is not None:
859
+ bias1 = bias1.contiguous()
860
+ if rowscale is not None:
861
+ rowscale = rowscale.reshape(-1).contiguous()
862
+ residual_dtype = (
863
+ residual.dtype
864
+ if residual is not None
865
+ else (torch.float32 if residual_in_fp32 else None)
866
+ )
867
+ if out is not None:
868
+ out = out.reshape(-1, out.shape[-1])
869
+ if residual_out is not None:
870
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
871
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
872
+ x,
873
+ weight,
874
+ bias,
875
+ eps,
876
+ residual,
877
+ x1,
878
+ weight1,
879
+ bias1,
880
+ dropout_p=dropout_p,
881
+ rowscale=rowscale,
882
+ residual_dtype=residual_dtype,
883
+ zero_centered_weight=zero_centered_weight,
884
+ is_rms_norm=is_rms_norm,
885
+ return_dropout_mask=return_dropout_mask,
886
+ out=out,
887
+ residual_out=residual_out
888
+ )
889
+ ctx.save_for_backward(
890
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
891
+ )
892
+ ctx.x_shape_og = x_shape_og
893
+ ctx.eps = eps
894
+ ctx.dropout_p = dropout_p
895
+ ctx.is_rms_norm = is_rms_norm
896
+ ctx.has_residual = residual is not None
897
+ ctx.has_x1 = x1 is not None
898
+ ctx.prenorm = prenorm
899
+ ctx.x_dtype = x.dtype
900
+ ctx.zero_centered_weight = zero_centered_weight
901
+ y = y.reshape(x_shape_og)
902
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
903
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
904
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
905
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
906
+ if not return_dropout_mask:
907
+ if weight1 is None:
908
+ return y if not prenorm else (y, residual_out)
909
+ else:
910
+ return (y, y1) if not prenorm else (y, y1, residual_out)
911
+ else:
912
+ if weight1 is None:
913
+ return (
914
+ (y, dropout_mask, dropout_mask1)
915
+ if not prenorm
916
+ else (y, residual_out, dropout_mask, dropout_mask1)
917
+ )
918
+ else:
919
+ return (
920
+ (y, y1, dropout_mask, dropout_mask1)
921
+ if not prenorm
922
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
923
+ )
924
+
925
+ @staticmethod
926
+ def backward(ctx, dy, *args):
927
+ if ctx.zero_seq_length:
928
+ return (
929
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
930
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
931
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
932
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
933
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
934
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
935
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
936
+ None,
937
+ None,
938
+ None,
939
+ None,
940
+ None,
941
+ None,
942
+ None,
943
+ None,
944
+ None,
945
+ None,
946
+ )
947
+
948
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
949
+ dy = dy.reshape(-1, dy.shape[-1])
950
+ if dy.stride(-1) != 1:
951
+ dy = dy.contiguous()
952
+ assert dy.shape == x.shape
953
+ if weight1 is not None:
954
+ dy1, args = args[0], args[1:]
955
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
956
+ if dy1.stride(-1) != 1:
957
+ dy1 = dy1.contiguous()
958
+ assert dy1.shape == x.shape
959
+ else:
960
+ dy1 = None
961
+ if ctx.prenorm:
962
+ dresidual = args[0]
963
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
964
+ if dresidual.stride(-1) != 1:
965
+ dresidual = dresidual.contiguous()
966
+ assert dresidual.shape == x.shape
967
+ else:
968
+ dresidual = None
969
+
970
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
971
+ dy,
972
+ x,
973
+ weight,
974
+ bias,
975
+ ctx.eps,
976
+ mean,
977
+ rstd,
978
+ dresidual,
979
+ dy1,
980
+ weight1,
981
+ bias1,
982
+ seeds,
983
+ ctx.dropout_p,
984
+ rowscale,
985
+ ctx.has_residual,
986
+ ctx.has_x1,
987
+ ctx.zero_centered_weight,
988
+ ctx.is_rms_norm,
989
+ x_dtype=ctx.x_dtype,
990
+ )
991
+ return (
992
+ dx.reshape(ctx.x_shape_og),
993
+ dw,
994
+ db,
995
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
996
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
997
+ dw1,
998
+ db1,
999
+ None,
1000
+ None,
1001
+ None,
1002
+ None,
1003
+ None,
1004
+ None,
1005
+ None,
1006
+ None,
1007
+ None,
1008
+ None,
1009
+ )
1010
+
1011
+
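The long run of `None`s in the return tuple above is not arbitrary: `torch.autograd.Function.backward` must return exactly one gradient (or `None` for non-differentiable inputs such as flags, seeds, and `eps`) per argument of `forward`. A self-contained toy example of that convention, unrelated to the Triton kernels in this file:

import torch

class ScaleByGamma(torch.autograd.Function):
    # Toy function y = gamma * x with one extra non-tensor flag argument.
    @staticmethod
    def forward(ctx, x, gamma, verbose):
        ctx.save_for_backward(x, gamma)
        if verbose:
            print("forward with gamma =", float(gamma))
        return gamma * x

    @staticmethod
    def backward(ctx, dy):
        x, gamma = ctx.saved_tensors
        # One entry per forward argument, in order: d/dx, d/dgamma, None for the flag.
        return dy * gamma, (dy * x).sum(), None

x = torch.randn(4, requires_grad=True)
gamma = torch.tensor(2.0, requires_grad=True)
ScaleByGamma.apply(x, gamma, False).sum().backward()
print(x.grad, gamma.grad)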
1012
+ def layer_norm_fn(
1013
+ x,
1014
+ weight,
1015
+ bias,
1016
+ residual=None,
1017
+ x1=None,
1018
+ weight1=None,
1019
+ bias1=None,
1020
+ eps=1e-6,
1021
+ dropout_p=0.0,
1022
+ rowscale=None,
1023
+ prenorm=False,
1024
+ residual_in_fp32=False,
1025
+ zero_centered_weight=False,
1026
+ is_rms_norm=False,
1027
+ return_dropout_mask=False,
1028
+ out=None,
1029
+ residual_out=None
1030
+ ):
1031
+ return LayerNormFn.apply(
1032
+ x,
1033
+ weight,
1034
+ bias,
1035
+ residual,
1036
+ x1,
1037
+ weight1,
1038
+ bias1,
1039
+ eps,
1040
+ dropout_p,
1041
+ rowscale,
1042
+ prenorm,
1043
+ residual_in_fp32,
1044
+ zero_centered_weight,
1045
+ is_rms_norm,
1046
+ return_dropout_mask,
1047
+ out,
1048
+ residual_out
1049
+ )
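A quick numerical sanity check one might run against the fused path above. This is a sketch only: it assumes a CUDA device with the Triton kernels from this file available, and `layer_norm_fn` refers to the function defined directly above.

import torch
import torch.nn.functional as F

def check_layer_norm_fn(layer_norm_fn, hidden=512, rows=64):
    # Compare the fused Triton kernel with PyTorch's reference layer norm.
    x = torch.randn(rows, hidden, device="cuda", dtype=torch.float16)
    w = torch.randn(hidden, device="cuda", dtype=torch.float16)
    b = torch.randn(hidden, device="cuda", dtype=torch.float16)
    y_fused = layer_norm_fn(x, w, b, eps=1e-6)
    y_ref = F.layer_norm(x.float(), (hidden,), w.float(), b.float(), eps=1e-6)
    print("max abs diff:", (y_fused.float() - y_ref).abs().max().item())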
1050
+
1051
+
1052
+ def rms_norm_fn(
1053
+ x,
1054
+ weight,
1055
+ bias,
1056
+ residual=None,
1057
+ x1=None,
1058
+ weight1=None,
1059
+ bias1=None,
1060
+ eps=1e-6,
1061
+ dropout_p=0.0,
1062
+ rowscale=None,
1063
+ prenorm=False,
1064
+ residual_in_fp32=False,
1065
+ zero_centered_weight=False,
1066
+ return_dropout_mask=False,
1067
+ out=None,
1068
+ residual_out=None
1069
+ ):
1070
+ return LayerNormFn.apply(
1071
+ x,
1072
+ weight,
1073
+ bias,
1074
+ residual,
1075
+ x1,
1076
+ weight1,
1077
+ bias1,
1078
+ eps,
1079
+ dropout_p,
1080
+ rowscale,
1081
+ prenorm,
1082
+ residual_in_fp32,
1083
+ zero_centered_weight,
1084
+ True,
1085
+ return_dropout_mask,
1086
+ out,
1087
+ residual_out
1088
+ )
1089
+
1090
+
1091
+ class RMSNorm(torch.nn.Module):
1092
+
1093
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
1094
+ device=None, dtype=None):
1095
+ factory_kwargs = {"device": device, "dtype": dtype}
1096
+ super().__init__()
1097
+ self.eps = eps
1098
+ if dropout_p > 0.0:
1099
+ self.drop = torch.nn.Dropout(dropout_p)
1100
+ else:
1101
+ self.drop = None
1102
+ self.zero_centered_weight = zero_centered_weight
1103
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1104
+ self.register_parameter("bias", None)
1105
+ self.reset_parameters()
1106
+
1107
+ def reset_parameters(self):
1108
+ if not self.zero_centered_weight:
1109
+ torch.nn.init.ones_(self.weight)
1110
+ else:
1111
+ torch.nn.init.zeros_(self.weight)
1112
+
1113
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1114
+ return rms_norm_fn(
1115
+ x,
1116
+ self.weight,
1117
+ self.bias,
1118
+ residual=residual,
1119
+ eps=self.eps,
1120
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1121
+ prenorm=prenorm,
1122
+ residual_in_fp32=residual_in_fp32,
1123
+ zero_centered_weight=self.zero_centered_weight,
1124
+ )
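A minimal sketch of how this fused RMSNorm is typically used in a pre-norm residual block: with `prenorm=True` the kernel adds the incoming residual, returns the normalized activations, and hands back the updated residual stream in one call (CUDA only; `RMSNorm` refers to the class above).

def prenorm_block(norm, sublayer, hidden_states, residual=None):
    # The addition hidden_states + residual happens inside the fused kernel;
    # `normed` feeds the sublayer, `residual` carries the (fp32) residual stream.
    normed, residual = norm(hidden_states, residual=residual,
                            prenorm=True, residual_in_fp32=True)
    return sublayer(normed), residual

# Example construction (hidden size is illustrative):
# norm = RMSNorm(4096, eps=1e-5, device="cuda", dtype=torch.bfloat16)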
1125
+
1126
+
1127
+ class LayerNormLinearFn(torch.autograd.Function):
1128
+ @staticmethod
1129
+ @custom_fwd
1130
+ def forward(
1131
+ ctx,
1132
+ x,
1133
+ norm_weight,
1134
+ norm_bias,
1135
+ linear_weight,
1136
+ linear_bias,
1137
+ residual=None,
1138
+ eps=1e-6,
1139
+ prenorm=False,
1140
+ residual_in_fp32=False,
1141
+ is_rms_norm=False,
1142
+ ):
1143
+ x_shape_og = x.shape
1144
+ # reshape input data into 2D tensor
1145
+ x = x.reshape(-1, x.shape[-1])
1146
+ if x.stride(-1) != 1:
1147
+ x = x.contiguous()
1148
+ if residual is not None:
1149
+ assert residual.shape == x_shape_og
1150
+ residual = residual.reshape(-1, residual.shape[-1])
1151
+ if residual.stride(-1) != 1:
1152
+ residual = residual.contiguous()
1153
+ norm_weight = norm_weight.contiguous()
1154
+ if norm_bias is not None:
1155
+ norm_bias = norm_bias.contiguous()
1156
+ residual_dtype = (
1157
+ residual.dtype
1158
+ if residual is not None
1159
+ else (torch.float32 if residual_in_fp32 else None)
1160
+ )
1161
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1162
+ x,
1163
+ norm_weight,
1164
+ norm_bias,
1165
+ eps,
1166
+ residual,
1167
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
1168
+ residual_dtype=residual_dtype,
1169
+ is_rms_norm=is_rms_norm,
1170
+ )
1171
+ y = y.reshape(x_shape_og)
1172
+ dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
1173
+ linear_weight = linear_weight.to(dtype)
1174
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1175
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1176
+ # We don't store y, will be recomputed in the backward pass to save memory
1177
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
1178
+ ctx.x_shape_og = x_shape_og
1179
+ ctx.eps = eps
1180
+ ctx.is_rms_norm = is_rms_norm
1181
+ ctx.has_residual = residual is not None
1182
+ ctx.prenorm = prenorm
1183
+ ctx.x_dtype = x.dtype
1184
+ ctx.linear_bias_is_none = linear_bias is None
1185
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1186
+
1187
+ @staticmethod
1188
+ @custom_bwd
1189
+ def backward(ctx, dout, *args):
1190
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1191
+ dout = dout.reshape(-1, dout.shape[-1])
1192
+ dy = F.linear(dout, linear_weight.t())
1193
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1194
+ if dy.stride(-1) != 1:
1195
+ dy = dy.contiguous()
1196
+ assert dy.shape == x.shape
1197
+ if ctx.prenorm:
1198
+ dresidual = args[0]
1199
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1200
+ if dresidual.stride(-1) != 1:
1201
+ dresidual = dresidual.contiguous()
1202
+ assert dresidual.shape == x.shape
1203
+ else:
1204
+ dresidual = None
1205
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1206
+ dy,
1207
+ x,
1208
+ norm_weight,
1209
+ norm_bias,
1210
+ ctx.eps,
1211
+ mean,
1212
+ rstd,
1213
+ dresidual=dresidual,
1214
+ has_residual=ctx.has_residual,
1215
+ is_rms_norm=ctx.is_rms_norm,
1216
+ x_dtype=ctx.x_dtype,
1217
+ recompute_output=True,
1218
+ )
1219
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1220
+ return (
1221
+ dx.reshape(ctx.x_shape_og),
1222
+ dnorm_weight,
1223
+ dnorm_bias,
1224
+ dlinear_weight,
1225
+ dlinear_bias,
1226
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1227
+ None,
1228
+ None,
1229
+ None,
1230
+ None,
1231
+ )
1232
+
1233
+
1234
+ def layer_norm_linear_fn(
1235
+ x,
1236
+ norm_weight,
1237
+ norm_bias,
1238
+ linear_weight,
1239
+ linear_bias,
1240
+ residual=None,
1241
+ eps=1e-6,
1242
+ prenorm=False,
1243
+ residual_in_fp32=False,
1244
+ is_rms_norm=False,
1245
+ ):
1246
+ return LayerNormLinearFn.apply(
1247
+ x,
1248
+ norm_weight,
1249
+ norm_bias,
1250
+ linear_weight,
1251
+ linear_bias,
1252
+ residual,
1253
+ eps,
1254
+ prenorm,
1255
+ residual_in_fp32,
1256
+ is_rms_norm,
1257
+ )
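As with `layer_norm_fn`, a hedged reference check for the fused norm-plus-linear path above; note that its backward pass recomputes the normalized output instead of storing it, which is where the memory saving comes from. Assumes a CUDA device; `layer_norm_linear_fn` refers to the function defined above.

import torch
import torch.nn.functional as F

def check_layer_norm_linear(layer_norm_linear_fn, d_in=512, d_out=1024, rows=32):
    x = torch.randn(rows, d_in, device="cuda", dtype=torch.float16)
    norm_w = torch.randn(d_in, device="cuda", dtype=torch.float16)
    norm_b = torch.randn(d_in, device="cuda", dtype=torch.float16)
    lin_w = torch.randn(d_out, d_in, device="cuda", dtype=torch.float16)
    out_fused = layer_norm_linear_fn(x, norm_w, norm_b, lin_w, None, eps=1e-6)
    out_ref = F.linear(F.layer_norm(x, (d_in,), norm_w, norm_b, eps=1e-6), lin_w)
    print("max abs diff:", (out_fused - out_ref).abs().max().item())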
omnigen2/pipelines/__init__.py ADDED
File without changes
omnigen2/pipelines/image_processor.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+
23
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
24
+ from diffusers.configuration_utils import register_to_config
25
+
26
+ class OmniGen2ImageProcessor(VaeImageProcessor):
27
+ """
28
+ Image processor for OmniGen2 image resize and crop.
29
+
30
+ Args:
31
+ do_resize (`bool`, *optional*, defaults to `True`):
32
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
33
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
34
+ vae_scale_factor (`int`, *optional*, defaults to `16`):
35
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
36
+ resample (`str`, *optional*, defaults to `lanczos`):
37
+ Resampling filter to use when resizing the image.
38
+ do_normalize (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the image to [-1,1].
40
+ do_binarize (`bool`, *optional*, defaults to `False`):
41
+ Whether to binarize the image to 0/1.
42
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
43
+ Whether to convert the images to RGB format.
44
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
45
+ Whether to convert the images to grayscale format.
46
+ """
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ do_resize: bool = True,
52
+ vae_scale_factor: int = 16,
53
+ resample: str = "lanczos",
54
+ max_pixels: Optional[int] = None,
55
+ max_side_length: Optional[int] = None,
56
+ do_normalize: bool = True,
57
+ do_binarize: bool = False,
58
+ do_convert_grayscale: bool = False,
59
+ ):
60
+ super().__init__(
61
+ do_resize=do_resize,
62
+ vae_scale_factor=vae_scale_factor,
63
+ resample=resample,
64
+ do_normalize=do_normalize,
65
+ do_binarize=do_binarize,
66
+ do_convert_grayscale=do_convert_grayscale,
67
+ )
68
+
69
+ self.max_pixels = max_pixels
70
+ self.max_side_length = max_side_length
71
+
72
+ def get_new_height_width(
73
+ self,
74
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
75
+ height: Optional[int] = None,
76
+ width: Optional[int] = None,
77
+ max_pixels: Optional[int] = None,
78
+ max_side_length: Optional[int] = None,
79
+ ) -> Tuple[int, int]:
80
+ r"""
81
+ Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
82
+
83
+ Args:
84
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
85
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
86
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
87
+ tensor, it should have shape `[batch, channels, height, width]`.
88
+ height (`Optional[int]`, *optional*, defaults to `None`):
89
+ The height of the preprocessed image. If `None`, the height of the `image` input will be used.
90
+ width (`Optional[int]`, *optional*, defaults to `None`):
91
+ The width of the preprocessed image. If `None`, the width of the `image` input will be used.
92
+
93
+ Returns:
94
+ `Tuple[int, int]`:
95
+ A tuple containing the height and width, both resized to the nearest integer multiple of
96
+ `vae_scale_factor`.
97
+ """
98
+
99
+ if height is None:
100
+ if isinstance(image, PIL.Image.Image):
101
+ height = image.height
102
+ elif isinstance(image, torch.Tensor):
103
+ height = image.shape[2]
104
+ else:
105
+ height = image.shape[1]
106
+
107
+ if width is None:
108
+ if isinstance(image, PIL.Image.Image):
109
+ width = image.width
110
+ elif isinstance(image, torch.Tensor):
111
+ width = image.shape[3]
112
+ else:
113
+ width = image.shape[2]
114
+
115
+ if max_side_length is None:
116
+ max_side_length = self.max_side_length
117
+
118
+ if max_pixels is None:
119
+ max_pixels = self.max_pixels
120
+
121
+ ratio = 1.0
122
+ if max_side_length is not None:
123
+ if height > width:
124
+ max_side_length_ratio = max_side_length / height
125
+ else:
126
+ max_side_length_ratio = max_side_length / width
127
+
128
+ cur_pixels = height * width
129
+ max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5
130
+ ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image
131
+
132
+ new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
133
+ return new_height, new_width
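A worked example of the resize rule above, assuming `max_pixels` is set (the pipelines below pass `max_pixels=1024*1024`) and `vae_scale_factor=16`; the input size is illustrative:

# 1536x2048 input, max_pixels=1024*1024, max_side_length=1024, vae_scale_factor=16
height, width, vae_scale_factor = 1536, 2048, 16
max_pixels, max_side_length = 1024 * 1024, 1024

max_side_length_ratio = max_side_length / max(height, width)        # 1024/2048 = 0.5
max_pixels_ratio = (max_pixels / (height * width)) ** 0.5            # ~0.577
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0)            # never upscale -> 0.5
new_height = int(height * ratio) // vae_scale_factor * vae_scale_factor   # 768
new_width = int(width * ratio) // vae_scale_factor * vae_scale_factor     # 1024
print(new_height, new_width)  # 768 1024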
134
+
135
+ def preprocess(
136
+ self,
137
+ image: PipelineImageInput,
138
+ height: Optional[int] = None,
139
+ width: Optional[int] = None,
140
+ max_pixels: Optional[int] = None,
141
+ max_side_length: Optional[int] = None,
142
+ resize_mode: str = "default", # "default", "fill", "crop"
143
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
144
+ ) -> torch.Tensor:
145
+ """
146
+ Preprocess the image input.
147
+
148
+ Args:
149
+ image (`PipelineImageInput`):
150
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
151
+ supported formats.
152
+ height (`int`, *optional*):
153
+ The height of the preprocessed image. If `None`, `get_new_height_width()` is used to derive a default
154
+ height from the input image.
155
+ width (`int`, *optional*):
156
+ The width of the preprocessed image. If `None`, `get_new_height_width()` is used to derive the default width.
157
+ resize_mode (`str`, *optional*, defaults to `default`):
158
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
159
+ the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
160
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
161
+ center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
162
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
163
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
164
+ supported for PIL image input.
165
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
166
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
167
+
168
+ Returns:
169
+ `torch.Tensor`:
170
+ The preprocessed image.
171
+ """
172
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
173
+
174
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
175
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
176
+ if isinstance(image, torch.Tensor):
177
+ # if image is a pytorch tensor could have 2 possible shapes:
178
+ # 1. batch x height x width: we should insert the channel dimension at position 1
179
+ # 2. channel x height x width: we should insert batch dimension at position 0,
180
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
181
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
182
+ image = image.unsqueeze(1)
183
+ else:
184
+ # if it is a numpy array, it could have 2 possible shapes:
185
+ # 1. batch x height x width: insert channel dimension on last position
186
+ # 2. height x width x channel: insert batch dimension on first position
187
+ if image.shape[-1] == 1:
188
+ image = np.expand_dims(image, axis=0)
189
+ else:
190
+ image = np.expand_dims(image, axis=-1)
191
+
192
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
193
+ warnings.warn(
194
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
195
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
196
+ FutureWarning,
197
+ )
198
+ image = np.concatenate(image, axis=0)
199
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
200
+ warnings.warn(
201
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
202
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
203
+ FutureWarning,
204
+ )
205
+ image = torch.cat(image, axis=0)
206
+
207
+ if not is_valid_image_imagelist(image):
208
+ raise ValueError(
209
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
210
+ )
211
+ if not isinstance(image, list):
212
+ image = [image]
213
+
214
+ if isinstance(image[0], PIL.Image.Image):
215
+ if crops_coords is not None:
216
+ image = [i.crop(crops_coords) for i in image]
217
+ if self.config.do_resize:
218
+ height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
219
+ image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
220
+ if self.config.do_convert_rgb:
221
+ image = [self.convert_to_rgb(i) for i in image]
222
+ elif self.config.do_convert_grayscale:
223
+ image = [self.convert_to_grayscale(i) for i in image]
224
+ image = self.pil_to_numpy(image) # to np
225
+ image = self.numpy_to_pt(image) # to pt
226
+
227
+ elif isinstance(image[0], np.ndarray):
228
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
229
+
230
+ image = self.numpy_to_pt(image)
231
+
232
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
233
+ if self.config.do_resize:
234
+ image = self.resize(image, height, width)
235
+
236
+ elif isinstance(image[0], torch.Tensor):
237
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
238
+
239
+ if self.config.do_convert_grayscale and image.ndim == 3:
240
+ image = image.unsqueeze(1)
241
+
242
+ channel = image.shape[1]
243
+ # don't need any preprocess if the image is latents
244
+ if channel == self.config.vae_latent_channels:
245
+ return image
246
+
247
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
248
+ if self.config.do_resize:
249
+ image = self.resize(image, height, width)
250
+
251
+ # expected range [0,1], normalize to [-1,1]
252
+ do_normalize = self.config.do_normalize
253
+ if do_normalize and image.min() < 0:
254
+ warnings.warn(
255
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
256
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
257
+ FutureWarning,
258
+ )
259
+ do_normalize = False
260
+ if do_normalize:
261
+ image = self.normalize(image)
262
+
263
+ if self.config.do_binarize:
264
+ image = self.binarize(image)
265
+
266
+ return image
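A minimal usage sketch of the processor above. The file name is a placeholder; `vae_scale_factor=16` matches how the pipelines below construct it (`self.vae_scale_factor * 2`).

import PIL.Image

processor = OmniGen2ImageProcessor(vae_scale_factor=16)
img = PIL.Image.open("example.png").convert("RGB")   # placeholder input path
pixel_values = processor.preprocess(
    img, max_pixels=1024 * 1024, max_side_length=1024
)
print(pixel_values.shape)   # (1, 3, H, W), H and W rounded down to multiples of 16,
                            # values normalized to [-1, 1]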
omnigen2/pipelines/omnigen2/pipeline_omnigen2.py ADDED
@@ -0,0 +1,720 @@
1
+ """
2
+ OmniGen2 Diffusion Pipeline
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import math
23
+
24
+ from PIL import Image
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+ from transformers import Qwen2_5_VLForConditionalGeneration
30
+
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import OmniGen2Transformer2DModel
33
+ from ...models.transformers.repo import OmniGen2RotaryPosEmbed
34
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.utils import (
36
+ is_torch_xla_available,
37
+ logging,
38
+ )
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
41
+
42
+ from dataclasses import dataclass
43
+
44
+ import PIL.Image
45
+
46
+ from diffusers.utils import BaseOutput
47
+
48
+ from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ @dataclass
61
+ class FMPipelineOutput(BaseOutput):
62
+ """
63
+ Output class for OmniGen2 pipeline.
64
+
65
+ Args:
66
+ images (Union[List[PIL.Image.Image], np.ndarray]):
67
+ List of denoised PIL images of length `batch_size` or numpy array of shape
68
+ `(batch_size, height, width, num_channels)`. Contains the generated images.
69
+ """
70
+ images: Union[List[PIL.Image.Image], np.ndarray]
71
+
72
+
73
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
74
+ def retrieve_timesteps(
75
+ scheduler,
76
+ num_inference_steps: Optional[int] = None,
77
+ device: Optional[Union[str, torch.device]] = None,
78
+ timesteps: Optional[List[int]] = None,
79
+ **kwargs,
80
+ ):
81
+ """
82
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
83
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
84
+
85
+ Args:
86
+ scheduler (`SchedulerMixin`):
87
+ The scheduler to get timesteps from.
88
+ num_inference_steps (`int`):
89
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
90
+ must be `None`.
91
+ device (`str` or `torch.device`, *optional*):
92
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
93
+ timesteps (`List[int]`, *optional*):
94
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
95
+ `num_inference_steps` and `sigmas` must be `None`.
96
+ sigmas (`List[float]`, *optional*):
97
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
98
+ `num_inference_steps` and `timesteps` must be `None`.
99
+
100
+ Returns:
101
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
102
+ second element is the number of inference steps.
103
+ """
104
+ if timesteps is not None:
105
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
106
+ if not accepts_timesteps:
107
+ raise ValueError(
108
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
109
+ f" timestep schedules. Please check whether you are using the correct scheduler."
110
+ )
111
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
112
+ timesteps = scheduler.timesteps
113
+ num_inference_steps = len(timesteps)
114
+ else:
115
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
116
+ timesteps = scheduler.timesteps
117
+ return timesteps, num_inference_steps
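An illustrative call of the helper above with the flow-matching scheduler this pipeline uses. The scheduler is constructed with default settings here; the pipeline itself forwards an extra `num_tokens` keyword to its own scheduler configuration, which is omitted in this sketch.

from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=28, device="cpu")
print(num_steps, timesteps[:3])   # 28 and the first few (descending) timesteps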
118
+
119
+
120
+ class OmniGen2Pipeline(DiffusionPipeline):
121
+ """
122
+ Pipeline for text-to-image generation using OmniGen2.
123
+
124
+ This pipeline implements a text-to-image generation model that uses:
125
+ - Qwen2.5-VL for text encoding
126
+ - A custom transformer architecture for image generation
127
+ - VAE for image encoding/decoding
128
+ - FlowMatchEulerDiscreteScheduler for noise scheduling
129
+
130
+ Args:
131
+ transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
132
+ vae (AutoencoderKL): The VAE model for image encoding/decoding.
133
+ scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
134
+ mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used to encode prompts.
135
+ processor: The processor that provides the tokenizer and chat template for the mllm.
136
+ """
137
+
138
+ model_cpu_offload_seq = "mllm->transformer->vae"
139
+
140
+ def __init__(
141
+ self,
142
+ transformer: OmniGen2Transformer2DModel,
143
+ vae: AutoencoderKL,
144
+ scheduler: FlowMatchEulerDiscreteScheduler,
145
+ mllm: Qwen2_5_VLForConditionalGeneration,
146
+ processor,
147
+ ) -> None:
148
+ """
149
+ Initialize the OmniGen2 pipeline.
150
+
151
+ Args:
152
+ transformer: The transformer model for image generation.
153
+ vae: The VAE model for image encoding/decoding.
154
+ scheduler: The scheduler for noise scheduling.
155
+ mllm: The multimodal LLM (Qwen2.5-VL) used to encode prompts.
156
+ processor: The processor that provides the tokenizer for the mllm.
157
+ """
158
+ super().__init__()
159
+
160
+ self.register_modules(
161
+ transformer=transformer,
162
+ vae=vae,
163
+ scheduler=scheduler,
164
+ mllm=mllm,
165
+ processor=processor
166
+ )
167
+ self.vae_scale_factor = (
168
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
169
+ )
170
+ self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
171
+ self.default_sample_size = 128
172
+
173
+ def prepare_latents(
174
+ self,
175
+ batch_size: int,
176
+ num_channels_latents: int,
177
+ height: int,
178
+ width: int,
179
+ dtype: torch.dtype,
180
+ device: torch.device,
181
+ generator: Optional[torch.Generator],
182
+ latents: Optional[torch.FloatTensor] = None,
183
+ ) -> torch.FloatTensor:
184
+ """
185
+ Prepare the initial latents for the diffusion process.
186
+
187
+ Args:
188
+ batch_size: The number of images to generate.
189
+ num_channels_latents: The number of channels in the latent space.
190
+ height: The height of the generated image.
191
+ width: The width of the generated image.
192
+ dtype: The data type of the latents.
193
+ device: The device to place the latents on.
194
+ generator: The random number generator to use.
195
+ latents: Optional pre-computed latents to use instead of random initialization.
196
+
197
+ Returns:
198
+ torch.FloatTensor: The prepared latents tensor.
199
+ """
200
+ height = int(height) // self.vae_scale_factor
201
+ width = int(width) // self.vae_scale_factor
202
+
203
+ shape = (batch_size, num_channels_latents, height, width)
204
+
205
+ if latents is None:
206
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
207
+ else:
208
+ latents = latents.to(device)
209
+ return latents
210
+
211
+ def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
212
+ """
213
+ Encode an image into the VAE latent space.
214
+
215
+ Args:
216
+ img: The input image tensor to encode.
217
+
218
+ Returns:
219
+ torch.FloatTensor: The encoded latent representation.
220
+ """
221
+ z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
222
+ if self.vae.config.shift_factor is not None:
223
+ z0 = z0 - self.vae.config.shift_factor
224
+ if self.vae.config.scaling_factor is not None:
225
+ z0 = z0 * self.vae.config.scaling_factor
226
+ z0 = z0.to(dtype=self.vae.dtype)
227
+ return z0
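For reference, the shift/scale applied here is undone in reverse order before decoding (see the end of `processing()` below); a compact sketch of the round trip, where `vae` stands for any diffusers `AutoencoderKL` with these config fields:

def to_latents(vae, image):
    # encode: sample, then (z - shift) * scale
    z = vae.encode(image.to(vae.dtype)).latent_dist.sample()
    if vae.config.shift_factor is not None:
        z = z - vae.config.shift_factor
    if vae.config.scaling_factor is not None:
        z = z * vae.config.scaling_factor
    return z

def to_image(vae, z):
    # decode: undo the scale, then the shift
    if vae.config.scaling_factor is not None:
        z = z / vae.config.scaling_factor
    if vae.config.shift_factor is not None:
        z = z + vae.config.shift_factor
    return vae.decode(z, return_dict=False)[0]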
228
+
229
+ def prepare_image(
230
+ self,
231
+ images: Union[List[PIL.Image.Image], PIL.Image.Image],
232
+ batch_size: int,
233
+ num_images_per_prompt: int,
234
+ max_pixels: int,
235
+ max_side_length: int,
236
+ device: torch.device,
237
+ dtype: torch.dtype,
238
+ ) -> List[Optional[torch.FloatTensor]]:
239
+ """
240
+ Prepare input images for processing by encoding them into the VAE latent space.
241
+
242
+ Args:
243
+ images: Single image or list of images to process.
244
+ batch_size: The number of images to generate per prompt.
245
+ num_images_per_prompt: The number of images to generate for each prompt.
246
+ device: The device to place the encoded latents on.
247
+ dtype: The data type of the encoded latents.
248
+
249
+ Returns:
250
+ List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
251
+ """
252
+ if batch_size == 1:
253
+ images = [images]
254
+ latents = []
255
+ for i, img in enumerate(images):
256
+ if img is not None and len(img) > 0:
257
+ ref_latents = []
258
+ for j, img_j in enumerate(img):
259
+ img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
260
+ ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
261
+ else:
262
+ ref_latents = None
263
+ for _ in range(num_images_per_prompt):
264
+ latents.append(ref_latents)
265
+
266
+ return latents
267
+
268
+ def _get_qwen2_prompt_embeds(
269
+ self,
270
+ prompt: Union[str, List[str]],
271
+ device: Optional[torch.device] = None,
272
+ max_sequence_length: int = 256,
273
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
274
+ """
275
+ Get prompt embeddings from the Qwen2 text encoder.
276
+
277
+ Args:
278
+ prompt: The prompt or list of prompts to encode.
279
+ device: The device to place the embeddings on. If None, uses the pipeline's device.
280
+ max_sequence_length: Maximum sequence length for tokenization.
281
+
282
+ Returns:
283
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
284
+ - The prompt embeddings tensor
285
+ - The attention mask tensor
286
+
287
+ Raises:
288
+ Warning: If the input text is truncated due to sequence length limitations.
289
+ """
290
+ device = device or self._execution_device
291
+ prompt = [prompt] if isinstance(prompt, str) else prompt
292
+ text_inputs = self.processor.tokenizer(
293
+ prompt,
294
+ padding="max_length",
295
+ max_length=max_sequence_length,
296
+ truncation=True,
297
+ return_tensors="pt",
298
+ )
299
+
300
+ text_input_ids = text_inputs.input_ids.to(device)
301
+ untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
302
+
303
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
304
+ removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
305
+ logger.warning(
306
+ "The following part of your input was truncated because Qwen2.5-VL can only handle sequences up to"
307
+ f" {max_sequence_length} tokens: {removed_text}"
308
+ )
309
+
310
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
311
+ prompt_embeds = self.mllm(
312
+ text_input_ids,
313
+ attention_mask=prompt_attention_mask,
314
+ output_hidden_states=True,
315
+ ).hidden_states[-1]
316
+
317
+ if self.mllm is not None:
318
+ dtype = self.mllm.dtype
319
+ elif self.transformer is not None:
320
+ dtype = self.transformer.dtype
321
+ else:
322
+ dtype = None
323
+
324
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
325
+
326
+ return prompt_embeds, prompt_attention_mask
327
+
328
+ def _apply_chat_template(self, prompt: str):
329
+ prompt = [
330
+ {
331
+ "role": "system",
332
+ "content": "You are a helpful assistant that generates high-quality images based on user instructions.",
333
+ },
334
+ {"role": "user", "content": prompt},
335
+ ]
336
+ prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
337
+ return prompt
338
+
339
+ def encode_prompt(
340
+ self,
341
+ prompt: Union[str, List[str]],
342
+ do_classifier_free_guidance: bool = True,
343
+ negative_prompt: Optional[Union[str, List[str]]] = None,
344
+ num_images_per_prompt: int = 1,
345
+ device: Optional[torch.device] = None,
346
+ prompt_embeds: Optional[torch.Tensor] = None,
347
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
348
+ prompt_attention_mask: Optional[torch.Tensor] = None,
349
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
350
+ max_sequence_length: int = 256,
351
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
352
+ r"""
353
+ Encodes the prompt into text encoder hidden states.
354
+
355
+ Args:
356
+ prompt (`str` or `List[str]`, *optional*):
357
+ prompt to be encoded
358
+ negative_prompt (`str` or `List[str]`, *optional*):
359
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
360
+ instead. Ignored when not using guidance (i.e., ignored if `text_guidance_scale` is `1` or lower).
361
+ Defaults to the empty string "".
362
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
363
+ whether to use classifier free guidance or not
364
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
365
+ number of images that should be generated per prompt
366
+ device: (`torch.device`, *optional*):
367
+ torch device to place the resulting embeddings on
368
+ prompt_embeds (`torch.Tensor`, *optional*):
369
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
370
+ provided, text embeddings will be generated from `prompt` input argument.
371
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
372
+ Pre-generated negative text embeddings. For OmniGen2, these should be the embeddings of the "" string.
373
+ max_sequence_length (`int`, defaults to `256`):
374
+ Maximum sequence length to use for the prompt.
375
+ """
376
+ device = device or self._execution_device
377
+
378
+ prompt = [prompt] if isinstance(prompt, str) else prompt
379
+ prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
380
+
381
+ if prompt is not None:
382
+ batch_size = len(prompt)
383
+ else:
384
+ batch_size = prompt_embeds.shape[0]
385
+ if prompt_embeds is None:
386
+ prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
387
+ prompt=prompt,
388
+ device=device,
389
+ max_sequence_length=max_sequence_length
390
+ )
391
+
392
+ batch_size, seq_len, _ = prompt_embeds.shape
393
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
394
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
395
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
396
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
397
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
398
+
399
+ # Get negative embeddings for classifier free guidance
400
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
401
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
402
+
403
+ # Normalize str to list
404
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
405
+ negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
406
+
407
+ if prompt is not None and type(prompt) is not type(negative_prompt):
408
+ raise TypeError(
409
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
410
+ f" {type(prompt)}."
411
+ )
412
+ elif isinstance(negative_prompt, str):
413
+ negative_prompt = [negative_prompt]
414
+ elif batch_size != len(negative_prompt):
415
+ raise ValueError(
416
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
417
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
418
+ " the batch size of `prompt`."
419
+ )
420
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
421
+ prompt=negative_prompt,
422
+ device=device,
423
+ max_sequence_length=max_sequence_length,
424
+ )
425
+
426
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
427
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
428
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
429
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
430
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
431
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
432
+ batch_size * num_images_per_prompt, -1
433
+ )
434
+
435
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
436
+
437
+ @property
438
+ def num_timesteps(self):
439
+ return self._num_timesteps
440
+
441
+ @property
442
+ def text_guidance_scale(self):
443
+ return self._text_guidance_scale
444
+
445
+ @property
446
+ def image_guidance_scale(self):
447
+ return self._image_guidance_scale
448
+
449
+ @property
450
+ def cfg_range(self):
451
+ return self._cfg_range
452
+
453
+ @torch.no_grad()
454
+ def __call__(
455
+ self,
456
+ prompt: Optional[Union[str, List[str]]] = None,
457
+ negative_prompt: Optional[Union[str, List[str]]] = None,
458
+ prompt_embeds: Optional[torch.FloatTensor] = None,
459
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
460
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
461
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
462
+ max_sequence_length: Optional[int] = None,
463
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
464
+ input_images: Optional[List[PIL.Image.Image]] = None,
465
+ num_images_per_prompt: int = 1,
466
+ height: Optional[int] = None,
467
+ width: Optional[int] = None,
468
+ max_pixels: int = 1024 * 1024,
469
+ max_input_image_side_length: int = 1024,
470
+ align_res: bool = True,
471
+ num_inference_steps: int = 28,
472
+ text_guidance_scale: float = 4.0,
473
+ image_guidance_scale: float = 1.0,
474
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
475
+ attention_kwargs: Optional[Dict[str, Any]] = None,
476
+ timesteps: List[int] = None,
477
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
478
+ latents: Optional[torch.FloatTensor] = None,
479
+ output_type: Optional[str] = "pil",
480
+ return_dict: bool = True,
481
+ verbose: bool = False,
482
+ step_func=None,
483
+ ):
484
+
485
+ height = height or self.default_sample_size * self.vae_scale_factor
486
+ width = width or self.default_sample_size * self.vae_scale_factor
487
+
488
+ self._text_guidance_scale = text_guidance_scale
489
+ self._image_guidance_scale = image_guidance_scale
490
+ self._cfg_range = cfg_range
491
+ self._attention_kwargs = attention_kwargs
492
+
493
+ # 2. Define call parameters
494
+ if prompt is not None and isinstance(prompt, str):
495
+ batch_size = 1
496
+ elif prompt is not None and isinstance(prompt, list):
497
+ batch_size = len(prompt)
498
+ else:
499
+ batch_size = prompt_embeds.shape[0]
500
+
501
+ device = self._execution_device
502
+
503
+ # 3. Encode input prompt
504
+ (
505
+ prompt_embeds,
506
+ prompt_attention_mask,
507
+ negative_prompt_embeds,
508
+ negative_prompt_attention_mask,
509
+ ) = self.encode_prompt(
510
+ prompt,
511
+ self.text_guidance_scale > 1.0,
512
+ negative_prompt=negative_prompt,
513
+ num_images_per_prompt=num_images_per_prompt,
514
+ device=device,
515
+ prompt_embeds=prompt_embeds,
516
+ negative_prompt_embeds=negative_prompt_embeds,
517
+ prompt_attention_mask=prompt_attention_mask,
518
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
519
+ max_sequence_length=max_sequence_length,
520
+ )
521
+
522
+ dtype = self.vae.dtype
523
+ # 3. Prepare control image
524
+ ref_latents = self.prepare_image(
525
+ images=input_images,
526
+ batch_size=batch_size,
527
+ num_images_per_prompt=num_images_per_prompt,
528
+ max_pixels=max_pixels,
529
+ max_side_length=max_input_image_side_length,
530
+ device=device,
531
+ dtype=dtype,
532
+ )
533
+
534
+ if input_images is None:
535
+ input_images = []
536
+
537
+ if len(input_images) == 1 and align_res:
538
+ width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
539
+ ori_width, ori_height = width, height
540
+ else:
541
+ ori_width, ori_height = width, height
542
+
543
+ cur_pixels = height * width
544
+ ratio = (max_pixels / cur_pixels) ** 0.5
545
+ ratio = min(ratio, 1.0)
546
+
547
+ height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
548
+
549
+ if len(input_images) == 0:
550
+ self._image_guidance_scale = 1
551
+
552
+ # 4. Prepare latents.
553
+ latent_channels = self.transformer.config.in_channels
554
+ latents = self.prepare_latents(
555
+ batch_size * num_images_per_prompt,
556
+ latent_channels,
557
+ height,
558
+ width,
559
+ prompt_embeds.dtype,
560
+ device,
561
+ generator,
562
+ latents,
563
+ )
564
+
565
+ freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
566
+ self.transformer.config.axes_dim_rope,
567
+ self.transformer.config.axes_lens,
568
+ theta=10000,
569
+ )
570
+
571
+ image = self.processing(
572
+ latents=latents,
573
+ ref_latents=ref_latents,
574
+ prompt_embeds=prompt_embeds,
575
+ freqs_cis=freqs_cis,
576
+ negative_prompt_embeds=negative_prompt_embeds,
577
+ prompt_attention_mask=prompt_attention_mask,
578
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
579
+ num_inference_steps=num_inference_steps,
580
+ timesteps=timesteps,
581
+ device=device,
582
+ dtype=dtype,
583
+ verbose=verbose,
584
+ step_func=step_func,
585
+ )
586
+
587
+ image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
588
+
589
+ image = self.image_processor.postprocess(image, output_type=output_type)
590
+
591
+ # Offload all models
592
+ self.maybe_free_model_hooks()
593
+
594
+ if not return_dict:
595
+ return image
596
+ else:
597
+ return FMPipelineOutput(images=image)
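A hedged end-to-end sketch of calling this pipeline. The Hub id, dtype, and prompt are placeholders, and loading assumes the checkpoint's `model_index.json` lists the same components (`transformer`, `vae`, `scheduler`, `mllm`, `processor`) registered above.

import torch

pipe = OmniGen2Pipeline.from_pretrained(
    "OmniGen2/OmniGen2",              # assumed checkpoint id, adjust as needed
    torch_dtype=torch.bfloat16,
).to("cuda")

result = pipe(
    prompt="A photo of a corgi surfing a wave at sunset",
    num_inference_steps=28,
    text_guidance_scale=4.0,
    height=1024,
    width=1024,
)
result.images[0].save("corgi.png")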
598
+
599
+ def processing(
600
+ self,
601
+ latents,
602
+ ref_latents,
603
+ prompt_embeds,
604
+ freqs_cis,
605
+ negative_prompt_embeds,
606
+ prompt_attention_mask,
607
+ negative_prompt_attention_mask,
608
+ num_inference_steps,
609
+ timesteps,
610
+ device,
611
+ dtype,
612
+ verbose,
613
+ step_func=None
614
+ ):
615
+ batch_size = latents.shape[0]
616
+
617
+ timesteps, num_inference_steps = retrieve_timesteps(
618
+ self.scheduler,
619
+ num_inference_steps,
620
+ device,
621
+ timesteps,
622
+ num_tokens=latents.shape[-2] * latents.shape[-1]
623
+ )
624
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
625
+ self._num_timesteps = len(timesteps)
626
+
627
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
628
+ for i, t in enumerate(timesteps):
629
+ model_pred = self.predict(
630
+ t=t,
631
+ latents=latents,
632
+ prompt_embeds=prompt_embeds,
633
+ freqs_cis=freqs_cis,
634
+ prompt_attention_mask=prompt_attention_mask,
635
+ ref_image_hidden_states=ref_latents,
636
+ )
637
+ text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
638
+ image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
639
+
640
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
641
+ model_pred_ref = self.predict(
642
+ t=t,
643
+ latents=latents,
644
+ prompt_embeds=negative_prompt_embeds,
645
+ freqs_cis=freqs_cis,
646
+ prompt_attention_mask=negative_prompt_attention_mask,
647
+ ref_image_hidden_states=ref_latents,
648
+ )
649
+
650
+ if image_guidance_scale != 1:
651
+ model_pred_uncond = self.predict(
652
+ t=t,
653
+ latents=latents,
654
+ prompt_embeds=negative_prompt_embeds,
655
+ freqs_cis=freqs_cis,
656
+ prompt_attention_mask=negative_prompt_attention_mask,
657
+ ref_image_hidden_states=None,
658
+ )
659
+ else:
660
+ model_pred_uncond = torch.zeros_like(model_pred)
661
+
662
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
663
+ text_guidance_scale * (model_pred - model_pred_ref)
664
+ elif text_guidance_scale > 1.0:
665
+ model_pred_uncond = self.predict(
666
+ t=t,
667
+ latents=latents,
668
+ prompt_embeds=negative_prompt_embeds,
669
+ freqs_cis=freqs_cis,
670
+ prompt_attention_mask=negative_prompt_attention_mask,
671
+ ref_image_hidden_states=None,
672
+ )
673
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
674
+
675
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
676
+
677
+ latents = latents.to(dtype=dtype)
678
+
679
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
680
+ progress_bar.update()
681
+
682
+ if step_func is not None:
683
+ step_func(i, self._num_timesteps)
684
+
685
+ latents = latents.to(dtype=dtype)
686
+ if self.vae.config.scaling_factor is not None:
687
+ latents = latents / self.vae.config.scaling_factor
688
+ if self.vae.config.shift_factor is not None:
689
+ latents = latents + self.vae.config.shift_factor
690
+ image = self.vae.decode(latents, return_dict=False)[0]
691
+
692
+ return image
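The guidance arithmetic in the denoising loop above combines three predictions from the same transformer; restated compactly (this mirrors the code, it is not an additional code path):

def combine_guidance(pred_uncond, pred_ref, pred_full, image_scale, text_scale):
    # pred_uncond: negative/empty prompt, no reference images
    # pred_ref:    negative/empty prompt, with reference images
    # pred_full:   actual prompt, with reference images
    return (
        pred_uncond
        + image_scale * (pred_ref - pred_uncond)   # pull toward the reference images
        + text_scale * (pred_full - pred_ref)      # then pull toward the text prompt
    )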
693
+
694
+ def predict(
695
+ self,
696
+ t,
697
+ latents,
698
+ prompt_embeds,
699
+ freqs_cis,
700
+ prompt_attention_mask,
701
+ ref_image_hidden_states,
702
+ ):
703
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
704
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
705
+
706
+ batch_size, num_channels_latents, height, width = latents.shape
707
+
708
+ optional_kwargs = {}
709
+ if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
710
+ optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
711
+
712
+ model_pred = self.transformer(
713
+ latents,
714
+ timestep,
715
+ prompt_embeds,
716
+ freqs_cis,
717
+ prompt_attention_mask,
718
+ **optional_kwargs
719
+ )
720
+ return model_pred
omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py ADDED
@@ -0,0 +1,830 @@
1
+ """
2
+ OmniGen2 Diffusion Pipeline
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import math
23
+
24
+ from PIL import Image
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+ from transformers import Qwen2_5_VLForConditionalGeneration
30
+
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import OmniGen2Transformer2DModel
33
+ from ...models.transformers.repo import OmniGen2RotaryPosEmbed
34
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.utils import (
36
+ is_torch_xla_available,
37
+ logging,
38
+ )
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
41
+
42
+ from dataclasses import dataclass
43
+
44
+ import PIL.Image
45
+
46
+ from diffusers.utils import BaseOutput
47
+
48
+ from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ @dataclass
61
+ class OmniGen2PipelineOutput(BaseOutput):
62
+ """
63
+ Output class for the OmniGen2 chat pipeline, containing the generated text and any generated images.
64
+
65
+ Args:
66
+ images (Union[List[PIL.Image.Image], np.ndarray]):
67
+ List of denoised PIL images of length `batch_size` or numpy array of shape
68
+ `(batch_size, height, width, num_channels)`. Contains the generated images.
69
+ """
70
+ text: str
71
+ images: Union[List[PIL.Image.Image], np.ndarray]
72
+
73
+
74
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
75
+ def retrieve_timesteps(
76
+ scheduler,
77
+ num_inference_steps: Optional[int] = None,
78
+ device: Optional[Union[str, torch.device]] = None,
79
+ timesteps: Optional[List[int]] = None,
80
+ **kwargs,
81
+ ):
82
+ """
83
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
+
86
+ Args:
87
+ scheduler (`SchedulerMixin`):
88
+ The scheduler to get timesteps from.
89
+ num_inference_steps (`int`):
90
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
+ must be `None`.
92
+ device (`str` or `torch.device`, *optional*):
93
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94
+ timesteps (`List[int]`, *optional*):
95
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
+ `num_inference_steps` and `sigmas` must be `None`.
97
+ sigmas (`List[float]`, *optional*):
98
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
+ `num_inference_steps` and `timesteps` must be `None`.
100
+
101
+ Returns:
102
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
+ second element is the number of inference steps.
104
+ """
105
+ if timesteps is not None:
106
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
107
+ if not accepts_timesteps:
108
+ raise ValueError(
109
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
110
+ f" timestep schedules. Please check whether you are using the correct scheduler."
111
+ )
112
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
113
+ timesteps = scheduler.timesteps
114
+ num_inference_steps = len(timesteps)
115
+ else:
116
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
117
+ timesteps = scheduler.timesteps
118
+ return timesteps, num_inference_steps
119
+
120
+
121
+ class OmniGen2ChatPipeline(DiffusionPipeline):
122
+ """
123
+ Pipeline for text-to-image generation using OmniGen2.
124
+
125
+ This pipeline implements a text-to-image generation model that uses:
126
+ - Qwen2.5-VL for text encoding
127
+ - A custom transformer architecture for image generation
128
+ - VAE for image encoding/decoding
129
+ - FlowMatchEulerDiscreteScheduler for noise scheduling
130
+
131
+ Args:
132
+ transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
133
+ vae (AutoencoderKL): The VAE model for image encoding/decoding.
134
+ scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
135
+ mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used to encode prompts.
136
+ processor: The processor that provides the tokenizer and chat template for the mllm.
137
+ """
138
+
139
+ model_cpu_offload_seq = "mllm->transformer->vae"
140
+ def __init__(
141
+ self,
142
+ transformer: OmniGen2Transformer2DModel,
143
+ vae: AutoencoderKL,
144
+ scheduler: FlowMatchEulerDiscreteScheduler,
145
+ mllm: Qwen2_5_VLForConditionalGeneration,
146
+ processor,
147
+ ) -> None:
148
+ """
149
+ Initialize the OmniGen2 pipeline.
150
+
151
+ Args:
152
+ transformer: The transformer model for image generation.
153
+ vae: The VAE model for image encoding/decoding.
154
+ scheduler: The scheduler for noise scheduling.
155
+ mllm: The multimodal LLM (Qwen2.5-VL) used to encode prompts.
156
+ processor: The processor that provides the tokenizer for the mllm.
157
+ """
158
+ super().__init__()
159
+
160
+ self.register_modules(
161
+ transformer=transformer,
162
+ vae=vae,
163
+ scheduler=scheduler,
164
+ mllm=mllm,
165
+ processor=processor
166
+ )
167
+ self.vae_scale_factor = (
168
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
169
+ )
170
+ self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
171
+ self.default_sample_size = 128
172
+
173
+ def prepare_latents(
174
+ self,
175
+ batch_size: int,
176
+ num_channels_latents: int,
177
+ height: int,
178
+ width: int,
179
+ dtype: torch.dtype,
180
+ device: torch.device,
181
+ generator: Optional[torch.Generator],
182
+ latents: Optional[torch.FloatTensor] = None,
183
+ ) -> torch.FloatTensor:
184
+ """
185
+ Prepare the initial latents for the diffusion process.
186
+
187
+ Args:
188
+ batch_size: The number of images to generate.
189
+ num_channels_latents: The number of channels in the latent space.
190
+ height: The height of the generated image.
191
+ width: The width of the generated image.
192
+ dtype: The data type of the latents.
193
+ device: The device to place the latents on.
194
+ generator: The random number generator to use.
195
+ latents: Optional pre-computed latents to use instead of random initialization.
196
+
197
+ Returns:
198
+ torch.FloatTensor: The prepared latents tensor.
199
+ """
200
+ height = int(height) // self.vae_scale_factor
201
+ width = int(width) // self.vae_scale_factor
202
+
203
+ shape = (batch_size, num_channels_latents, height, width)
204
+
205
+ if latents is None:
206
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
207
+ else:
208
+ latents = latents.to(device)
209
+ return latents
210
+
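+ # Shape note (illustrative): with a 4-stage VAE the scale factor computed in
+ # __init__ is 8, so a 1024x1024 request yields latents of shape
+ # (batch_size, num_channels_latents, 1024 // 8, 1024 // 8) == (B, C, 128, 128).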
211
+ def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
212
+ """
213
+ Encode an image into the VAE latent space.
214
+
215
+ Args:
216
+ img: The input image tensor to encode.
217
+
218
+ Returns:
219
+ torch.FloatTensor: The encoded latent representation.
220
+ """
221
+ z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
222
+ if self.vae.config.shift_factor is not None:
223
+ z0 = z0 - self.vae.config.shift_factor
224
+ if self.vae.config.scaling_factor is not None:
225
+ z0 = z0 * self.vae.config.scaling_factor
226
+ z0 = z0.to(dtype=self.vae.dtype)
227
+ return z0
228
+
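+ # Note: the decode path in `processing` applies the inverse mapping
+ # (latents / scaling_factor, then + shift_factor) before `vae.decode`, so the
+ # two methods round-trip the VAE latent normalization.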
229
+ def prepare_image(
230
+ self,
231
+ images: Union[List[PIL.Image.Image], PIL.Image.Image],
232
+ batch_size: int,
233
+ num_images_per_prompt: int,
234
+ max_pixels: int,
235
+ max_side_length: int,
236
+ device: torch.device,
237
+ dtype: torch.dtype,
238
+ ) -> List[Optional[torch.FloatTensor]]:
239
+ """
240
+ Prepare input images for processing by encoding them into the VAE latent space.
241
+
242
+ Args:
243
+ images: Single image or list of images to process.
244
+ batch_size: The number of prompts in the batch.
245
+ num_images_per_prompt: The number of images to generate for each prompt.
246
+ device: The device to place the encoded latents on.
247
+ dtype: The data type of the encoded latents.
248
+
249
+ Returns:
250
+ List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
251
+ """
252
+ if batch_size == 1:
253
+ images = [images]
254
+ latents = []
255
+ for i, img in enumerate(images):
256
+ if img is not None and len(img) > 0:
257
+ ref_latents = []
258
+ for j, img_j in enumerate(img):
259
+ img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
260
+ ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
261
+ else:
262
+ ref_latents = None
263
+ for _ in range(num_images_per_prompt):
264
+ latents.append(ref_latents)
265
+
266
+ return latents
267
+
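+ # The returned list has one entry per generated sample (batch_size * num_images_per_prompt);
+ # each entry is either None (no reference images for that prompt) or a list of
+ # per-reference latent tensors of shape (C, h, w).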
268
+ def _apply_chat_template(self, prompt: str, images: List = None):
269
+ if images is not None:
270
+ prompt += "".join(
271
+ [
272
+ f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>"
273
+ for i in range(1, len(images) + 1)
274
+ ]
275
+ )
276
+ prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
277
+ return prompt
278
+
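+ # Example (illustrative): for prompt "Make the sky pink" with one reference image,
+ # the user turn becomes
+ #     "Make the sky pink<img1>: <|vision_start|><|image_pad|><|vision_end|>"
+ # wrapped in the system / user / assistant chat markers built above.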
279
+ def _get_qwen2_prompt_embeds(
280
+ self,
281
+ prompt: Union[str, List[str]],
282
+ input_images = None,
283
+ device: Optional[torch.device] = None,
284
+ use_only_text_hidden_states: bool = True,
285
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
286
+ """
287
+ Get prompt embeddings from the Qwen2 text encoder.
288
+
289
+ Args:
290
+ prompt: The prompt or list of prompts to encode.
291
+ device: The device to place the embeddings on. If None, uses the pipeline's device.
292
+
293
+ Returns:
294
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
295
+ - The prompt embeddings tensor
296
+ - The attention mask tensor
297
+
298
+ Raises:
299
+ Warning: If the input text is truncated due to sequence length limitations.
300
+ """
301
+ device = device or self._execution_device
302
+ prompt = [prompt] if isinstance(prompt, str) else prompt
303
+
304
+ inputs = self.processor(
305
+ text=prompt,
306
+ images=input_images,
307
+ videos=None,
308
+ padding=True,
309
+ return_tensors="pt",
310
+ )
311
+ inputs = inputs.to(device)
312
+
313
+ prompt_embeds = self.mllm(
314
+ **inputs,
315
+ output_hidden_states=True,
316
+ ).hidden_states[-1]
317
+
318
+ text_input_ids = inputs.input_ids
319
+ text_mask = inputs.attention_mask
320
+ if use_only_text_hidden_states:
321
+ mask = text_input_ids != self.mllm.config.image_token_id
322
+ mask = mask & text_mask
323
+ mask = mask.bool()
324
+
325
+ text_l = mask.sum(dim=-1)
326
+ max_l = text_l.max()
327
+ text_batch_size = prompt_embeds.size(0)
328
+ new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
329
+ new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
330
+ for i in range(text_batch_size):
331
+ new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
332
+ new_text_mask[i, :text_l[i]] = 1
333
+
334
+ prompt_embeds = new_prompt_embeds
335
+ text_mask = new_text_mask
336
+
337
+ prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
338
+ return prompt_embeds, text_mask
339
+
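+ # Compaction example (illustrative): if a row holds 10 tokens of which 4 are image
+ # pads, `use_only_text_hidden_states` keeps the 6 text hidden states left-aligned
+ # and its mask row becomes [1] * 6 followed by zeros up to the batch max length.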
340
+
341
+ def encode_prompt(
342
+ self,
343
+ prompt: Union[str, List[str]],
344
+ input_images: Optional[Union[str, List[str]]] = None,
345
+ do_classifier_free_guidance: bool = True,
346
+ negative_prompt: Optional[Union[str, List[str]]] = None,
347
+ num_images_per_prompt: int = 1,
348
+ device: Optional[torch.device] = None,
349
+ prompt_embeds: Optional[torch.Tensor] = None,
350
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
351
+ prompt_attention_mask: Optional[torch.Tensor] = None,
352
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
353
+ max_sequence_length: int = 256,
354
+ use_text_encoder_penultimate_layer_feats: bool = False
355
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
356
+ r"""
357
+ Encodes the prompt into text encoder hidden states.
358
+
359
+ Args:
360
+ prompt (`str` or `List[str]`, *optional*):
361
+ prompt to be encoded
362
+ negative_prompt (`str` or `List[str]`, *optional*):
363
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
364
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
365
+ OmniGen2, this should be "".
366
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
367
+ whether to use classifier free guidance or not
368
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
369
+ number of images that should be generated per prompt
370
+ device: (`torch.device`, *optional*):
371
+ torch device to place the resulting embeddings on
372
+ prompt_embeds (`torch.Tensor`, *optional*):
373
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
374
+ provided, text embeddings will be generated from `prompt` input argument.
375
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
376
+ Pre-generated negative text embeddings. For OmniGen2, these should be the embeddings of the empty string "".
377
+ max_sequence_length (`int`, defaults to `256`):
378
+ Maximum sequence length to use for the prompt.
379
+ """
380
+ device = device or self._execution_device
381
+
382
+ prompt = [prompt] if isinstance(prompt, str) else prompt
383
+
384
+ if prompt is not None:
385
+ batch_size = len(prompt)
386
+ else:
387
+ batch_size = prompt_embeds.shape[0]
388
+ if prompt_embeds is None:
389
+ prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
390
+ prompt=prompt,
391
+ input_images=input_images,
392
+ device=device,
393
+ )
394
+
395
+ batch_size, seq_len, _ = prompt_embeds.shape
396
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
397
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
398
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
399
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
400
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
401
+
402
+ # Get negative embeddings for classifier free guidance
403
+ negative_prompt_embeds, negative_prompt_attention_mask = None, None
404
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
405
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
406
+
407
+ # Normalize str to list
408
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
409
+ negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
410
+
411
+ if prompt is not None and type(prompt) is not type(negative_prompt):
412
+ raise TypeError(
413
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
414
+ f" {type(prompt)}."
415
+ )
416
+ elif isinstance(negative_prompt, str):
417
+ negative_prompt = [negative_prompt]
418
+ elif batch_size != len(negative_prompt):
419
+ raise ValueError(
420
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
421
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
422
+ " the batch size of `prompt`."
423
+ )
424
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
425
+ prompt=negative_prompt,
426
+ device=device,
427
+ )
428
+
429
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
430
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
431
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
432
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
433
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
434
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
435
+ batch_size * num_images_per_prompt, -1
436
+ )
437
+
438
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
439
+
440
+ @property
441
+ def num_timesteps(self):
442
+ return self._num_timesteps
443
+
444
+ @property
445
+ def text_guidance_scale(self):
446
+ return self._text_guidance_scale
447
+
448
+ @property
449
+ def image_guidance_scale(self):
450
+ return self._image_guidance_scale
451
+
452
+ @property
453
+ def cfg_range(self):
454
+ return self._cfg_range
455
+
456
+ def prepare_inputs_for_text_generation(self, prompts, input_images, device):
457
+ if isinstance(prompts, str):
458
+ prompts = [prompts]
459
+
460
+ ori_padding_side = self.processor.tokenizer.padding_side
461
+ self.processor.tokenizer.padding_side = "left"
462
+ inputs = self.processor(
463
+ text=prompts,
464
+ images=input_images,
465
+ videos=None,
466
+ padding=True,
467
+ return_tensors="pt",
468
+ ).to(device)
469
+ self.processor.tokenizer.padding_side = ori_padding_side
470
+ return inputs
471
+
472
+ def generate_text(self, prompt, input_images):
473
+ inputs = self.prepare_inputs_for_text_generation(
474
+ prompt, input_images, self.mllm.device
475
+ )
476
+ generated_ids = self.mllm.generate(
477
+ **inputs,
478
+ tokenizer=self.processor.tokenizer,
479
+ max_new_tokens=256,
480
+ stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
481
+ ) # stop_words=[151643, 151645, 151665]
482
+ generated_ids_trimmed = [
483
+ out_ids[len(in_ids) :]
484
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
485
+ ]
486
+ output_texts = self.processor.batch_decode(
487
+ generated_ids_trimmed,
488
+ # skip_special_tokens=True,
489
+ skip_special_tokens=False,
490
+ clean_up_tokenization_spaces=False,
491
+ )
492
+ return output_texts
493
+
494
+ def generate_image(
495
+ self,
496
+ prompt: Optional[Union[str, List[str]]] = None,
497
+ negative_prompt: Optional[Union[str, List[str]]] = None,
498
+ prompt_embeds: Optional[torch.FloatTensor] = None,
499
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
500
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
501
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
502
+ use_text_encoder_penultimate_layer_feats: bool = False,
503
+ max_sequence_length: Optional[int] = None,
504
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
505
+ input_images: Optional[List[PIL.Image.Image]] = None,
506
+ num_images_per_prompt: int = 1,
507
+ height: Optional[int] = None,
508
+ width: Optional[int] = None,
509
+ max_pixels: int = 1024 * 1024,
510
+ max_input_image_side_length: int = 1024,
511
+ align_res: bool = True,
512
+ num_inference_steps: int = 28,
513
+ text_guidance_scale: float = 4.0,
514
+ image_guidance_scale: float = 1.0,
515
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
516
+ attention_kwargs: Optional[Dict[str, Any]] = None,
517
+ timesteps: List[int] = None,
518
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
519
+ latents: Optional[torch.FloatTensor] = None,
520
+ output_type: Optional[str] = "pil",
521
+ return_dict: bool = True,
522
+ verbose: bool = False,
523
+ step_func=None,
524
+ ):
525
+ height = height or self.default_sample_size * self.vae_scale_factor
526
+ width = width or self.default_sample_size * self.vae_scale_factor
527
+
528
+ self._text_guidance_scale = text_guidance_scale
529
+ self._image_guidance_scale = image_guidance_scale
530
+ self._cfg_range = cfg_range
531
+ self._attention_kwargs = attention_kwargs
532
+
533
+ # 2. Define call parameters
534
+ if prompt is not None and isinstance(prompt, str):
535
+ batch_size = 1
536
+ elif prompt is not None and isinstance(prompt, list):
537
+ batch_size = len(prompt)
538
+ else:
539
+ batch_size = prompt_embeds.shape[0]
540
+
541
+ device = self._execution_device
542
+
543
+ # 3. Encode input prompt
544
+ (
545
+ prompt_embeds,
546
+ prompt_attention_mask,
547
+ negative_prompt_embeds,
548
+ negative_prompt_attention_mask,
549
+ ) = self.encode_prompt(
550
+ prompt,
551
+ input_images,
552
+ self.text_guidance_scale > 1.0,
553
+ negative_prompt=negative_prompt,
554
+ num_images_per_prompt=num_images_per_prompt,
555
+ device=device,
556
+ prompt_embeds=prompt_embeds,
557
+ negative_prompt_embeds=negative_prompt_embeds,
558
+ prompt_attention_mask=prompt_attention_mask,
559
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
560
+ max_sequence_length=max_sequence_length,
561
+ use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
562
+ )
563
+
564
+ dtype = self.vae.dtype
565
+ # 4. Prepare reference image latents
566
+ ref_latents = self.prepare_image(
567
+ images=input_images,
568
+ batch_size=batch_size,
569
+ num_images_per_prompt=num_images_per_prompt,
570
+ max_pixels=max_pixels,
571
+ max_side_length=max_input_image_side_length,
572
+ device=device,
573
+ dtype=dtype,
574
+ )
575
+
576
+ if input_images is None:
577
+ input_images = []
578
+
579
+ if len(input_images) == 1 and align_res:
580
+ width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
581
+ ori_width, ori_height = width, height
582
+ else:
583
+ ori_width, ori_height = width, height
584
+
585
+ cur_pixels = height * width
586
+ ratio = (max_pixels / cur_pixels) ** 0.5
587
+ ratio = min(ratio, 1.0)
588
+
589
+ height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
590
+
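+ # Worked example (illustrative): height = width = 2048 with max_pixels = 1024 * 1024
+ # gives ratio = sqrt(0.25) = 0.5, so both sides become 1024 (already a multiple of 16);
+ # the final snap always rounds each side down to a multiple of 16.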
591
+ if len(input_images) == 0:
592
+ self._image_guidance_scale = 1
593
+
594
+ # 5. Prepare latents.
595
+ latent_channels = self.transformer.config.in_channels
596
+ latents = self.prepare_latents(
597
+ batch_size * num_images_per_prompt,
598
+ latent_channels,
599
+ height,
600
+ width,
601
+ prompt_embeds.dtype,
602
+ device,
603
+ generator,
604
+ latents,
605
+ )
606
+
607
+ freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
608
+ self.transformer.config.axes_dim_rope,
609
+ self.transformer.config.axes_lens,
610
+ theta=10000,
611
+ )
612
+
613
+ image = self.processing(
614
+ latents=latents,
615
+ ref_latents=ref_latents,
616
+ prompt_embeds=prompt_embeds,
617
+ freqs_cis=freqs_cis,
618
+ negative_prompt_embeds=negative_prompt_embeds,
619
+ prompt_attention_mask=prompt_attention_mask,
620
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
621
+ num_inference_steps=num_inference_steps,
622
+ timesteps=timesteps,
623
+ device=device,
624
+ dtype=dtype,
625
+ verbose=verbose,
626
+ step_func=step_func,
627
+ )
628
+
629
+ image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
630
+
631
+ image = self.image_processor.postprocess(image, output_type=output_type)
632
+
633
+ # Offload all models
634
+ self.maybe_free_model_hooks()
635
+ return image
636
+
637
+ @torch.no_grad()
638
+ def __call__(
639
+ self,
640
+ prompt: Optional[Union[str, List[str]]] = None,
641
+ negative_prompt: Optional[Union[str, List[str]]] = None,
642
+ prompt_embeds: Optional[torch.FloatTensor] = None,
643
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
644
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
645
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
646
+ use_text_encoder_penultimate_layer_feats: bool = False,
647
+ max_sequence_length: Optional[int] = None,
648
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
649
+ input_images: Optional[List[PIL.Image.Image]] = None,
650
+ num_images_per_prompt: int = 1,
651
+ height: Optional[int] = 1024,
652
+ width: Optional[int] = 1024,
653
+ max_pixels: Optional[int] = 1024 * 1024,
654
+ max_input_image_side_length: int = 1024,
655
+ align_res: bool = True,
656
+ num_inference_steps: int = 28,
657
+ text_guidance_scale: float = 4.0,
658
+ image_guidance_scale: float = 1.0,
659
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
660
+ attention_kwargs: Optional[Dict[str, Any]] = None,
661
+ timesteps: List[int] = None,
662
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
663
+ latents: Optional[torch.FloatTensor] = None,
664
+ output_type: Optional[str] = "pil",
665
+ return_dict: bool = True,
666
+ verbose: bool = False,
667
+ step_func=None,
668
+ ):
669
+ assert isinstance(prompt, str), "prompt must be a string since chat mode only supports one prompt per turn"
670
+
671
+ # input_images = self.preprocess_images(input_images, max_input_image_size)
672
+ prompt = self._apply_chat_template(prompt, input_images)
673
+ generated_text = self.generate_text(prompt, input_images)[0]
674
+
675
+ images = None
676
+ if generated_text.startswith("<|img|>"):
677
+ #TODO: reuse the hidden state when generate text instead of re-generating
678
+ prompt = prompt + generated_text.split("<|img|>")[0]
679
+ images = self.generate_image(
680
+ prompt=prompt,
681
+ negative_prompt=negative_prompt,
682
+ use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
683
+ max_sequence_length=max_sequence_length,
684
+ input_images=input_images,
685
+ num_images_per_prompt=num_images_per_prompt,
686
+ height=height,
687
+ width=width,
688
+ max_pixels=max_pixels,
689
+ max_input_image_side_length=max_input_image_side_length,
690
+ align_res=align_res,
691
+ num_inference_steps=num_inference_steps,
692
+ text_guidance_scale=text_guidance_scale,
693
+ image_guidance_scale=image_guidance_scale,
694
+ cfg_range=cfg_range,
695
+ timesteps=timesteps,
696
+ generator=generator,
697
+ latents=latents,
698
+ return_dict=False,
699
+ verbose=verbose,
700
+ step_func=step_func,
701
+ )
702
+
703
+ generated_text = generated_text.replace("<|im_end|>", "")
704
+ if not return_dict:
705
+ return generated_text, images
706
+ else:
707
+ return OmniGen2PipelineOutput(text=generated_text, images=images)
708
+
709
+ def processing(
710
+ self,
711
+ latents,
712
+ ref_latents,
713
+ prompt_embeds,
714
+ freqs_cis,
715
+ negative_prompt_embeds,
716
+ prompt_attention_mask,
717
+ negative_prompt_attention_mask,
718
+ num_inference_steps,
719
+ timesteps,
720
+ device,
721
+ dtype,
722
+ verbose,
723
+ step_func=None
724
+ ):
725
+ batch_size = latents.shape[0]
726
+
727
+ timesteps, num_inference_steps = retrieve_timesteps(
728
+ self.scheduler,
729
+ num_inference_steps,
730
+ device,
731
+ timesteps,
732
+ num_tokens=latents.shape[-2] * latents.shape[-1]
733
+ )
734
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
735
+ self._num_timesteps = len(timesteps)
736
+
737
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
738
+ for i, t in enumerate(timesteps):
739
+ model_pred = self.predict(
740
+ t=t,
741
+ latents=latents,
742
+ prompt_embeds=prompt_embeds,
743
+ freqs_cis=freqs_cis,
744
+ prompt_attention_mask=prompt_attention_mask,
745
+ ref_image_hidden_states=ref_latents,
746
+ )
747
+
748
+ text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
749
+ image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
750
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
751
+ model_pred_ref = self.predict(
752
+ t=t,
753
+ latents=latents,
754
+ prompt_embeds=negative_prompt_embeds,
755
+ freqs_cis=freqs_cis,
756
+ prompt_attention_mask=negative_prompt_attention_mask,
757
+ ref_image_hidden_states=ref_latents,
758
+ )
759
+
760
+ if image_guidance_scale != 1:
761
+ model_pred_uncond = self.predict(
762
+ t=t,
763
+ latents=latents,
764
+ prompt_embeds=negative_prompt_embeds,
765
+ freqs_cis=freqs_cis,
766
+ prompt_attention_mask=negative_prompt_attention_mask,
767
+ ref_image_hidden_states=None,
768
+ )
769
+ else:
770
+ model_pred_uncond = torch.zeros_like(model_pred)
771
+
772
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
773
+ text_guidance_scale * (model_pred - model_pred_ref)
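+                     # Guidance recap (illustrative restatement of the line above): with s_img = image_guidance_scale
+                     # and s_txt = text_guidance_scale,
+                     #     pred = model_pred_uncond + s_img * (model_pred_ref - model_pred_uncond) + s_txt * (model_pred - model_pred_ref),
+                     # so s_img = s_txt = 1 reduces to the fully conditional prediction, and s_img = 1
+                     # collapses to text guidance taken around the reference-conditioned prediction.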
774
+ elif text_guidance_scale > 1.0:
775
+ model_pred_uncond = self.predict(
776
+ t=t,
777
+ latents=latents,
778
+ prompt_embeds=negative_prompt_embeds,
779
+ freqs_cis=freqs_cis,
780
+ prompt_attention_mask=negative_prompt_attention_mask,
781
+ ref_image_hidden_states=None,
782
+ )
783
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
784
+
785
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
786
+
787
+ latents = latents.to(dtype=dtype)
788
+
789
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
790
+ progress_bar.update()
791
+
792
+ if step_func is not None:
793
+ step_func(i, self._num_timesteps)
794
+
795
+ latents = latents.to(dtype=dtype)
796
+ if self.vae.config.scaling_factor is not None:
797
+ latents = latents / self.vae.config.scaling_factor
798
+ if self.vae.config.shift_factor is not None:
799
+ latents = latents + self.vae.config.shift_factor
800
+ image = self.vae.decode(latents, return_dict=False)[0]
801
+
802
+ return image
803
+
804
+ def predict(
805
+ self,
806
+ t,
807
+ latents,
808
+ prompt_embeds,
809
+ freqs_cis,
810
+ prompt_attention_mask,
811
+ ref_image_hidden_states,
812
+ ):
813
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
814
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
815
+
816
+ batch_size, num_channels_latents, height, width = latents.shape
817
+
818
+ optional_kwargs = {}
819
+ if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
820
+ optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
821
+
822
+ model_pred = self.transformer(
823
+ latents,
824
+ timestep,
825
+ prompt_embeds,
826
+ freqs_cis,
827
+ prompt_attention_mask,
828
+ **optional_kwargs
829
+ )
830
+ return model_pred
omnigen2/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+
3
+
4
+ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
5
+ """Get prompt embeddings for prompts longer than the tokenizer's max length by chunked encoding.
6
+ :param pipeline: a diffusers pipeline exposing `tokenizer` and `text_encoder`
7
+ :param prompt: the (possibly long) positive prompt
8
+ :param negative_prompt: the negative prompt, or None
9
+ :param device: device on which to run the text encoder
10
+ :return: tuple of (prompt embeddings, negative prompt embeddings or None)
11
+ """
12
+ max_length = pipeline.tokenizer.model_max_length
13
+
14
+ # simple way to determine length of tokens
15
+ # count_prompt = len(prompt.split(" "))
16
+ # count_negative_prompt = len(negative_prompt.split(" "))
17
+
18
+ # create the tensor based on which prompt is longer
19
+ # if count_prompt >= count_negative_prompt:
20
+ input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device)
21
+ # input_ids = pipeline.tokenizer(prompt, padding="max_length",
22
+ # max_length=pipeline.tokenizer.model_max_length,
23
+ # truncation=True,
24
+ # return_tensors="pt",).input_ids.to(device)
25
+ shape_max_length = input_ids.shape[-1]
26
+
27
+ if negative_prompt is not None:
28
+ negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
29
+ max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
30
+
31
+ # else:
32
+ # negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
33
+ # shape_max_length = negative_ids.shape[-1]
34
+ # input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
35
+ # max_length=shape_max_length).input_ids.to(device)
36
+
37
+ concat_embeds = []
38
+ neg_embeds = []
39
+ for i in range(0, shape_max_length, max_length):
40
+ if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
41
+ attention_mask = (input_ids[:, i: i + max_length] != pipeline.tokenizer.pad_token_id).long().to(device)  # tensor slices carry no .attention_mask; rebuild it from pad positions
42
+ else:
43
+ attention_mask = None
44
+ concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
45
+ attention_mask=attention_mask)[0])
46
+
47
+ if negative_prompt is not None:
48
+ if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
49
+ attention_mask = (negative_ids[:, i: i + max_length] != pipeline.tokenizer.pad_token_id).long().to(device)  # same fix as above for the negative prompt chunks
50
+ else:
51
+ attention_mask = None
52
+ neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
53
+ attention_mask=attention_mask)[0])
54
+
55
+ concat_embeds = torch.cat(concat_embeds, dim=1)
56
+
57
+ if negative_prompt is not None:
58
+ neg_embeds = torch.cat(neg_embeds, dim=1)
59
+ else:
60
+ neg_embeds = None
61
+
62
+ return concat_embeds, neg_embeds
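+ # Usage sketch (assumes a CLIP-style `pipeline.tokenizer` / `pipeline.text_encoder`
+ # pair, e.g. a Stable Diffusion pipeline; not something this repo wires up itself):
+ #
+ #     prompt_embeds, negative_embeds = get_pipeline_embeds(pipe, long_prompt, negative_prompt, "cuda")
+ #     image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images[0]
+ #
+ # Both tensors share the same sequence length because the negative prompt is padded or
+ # truncated to the tokenized length of the positive prompt before chunked encoding.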
omnigen2/schedulers/__init__.py ADDED
File without changes
omnigen2/schedulers/scheduling_dpmsolver_multistep.py ADDED
@@ -0,0 +1,1052 @@
1
+ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import deprecate, is_scipy_available
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+
29
+ if is_scipy_available():
30
+ import scipy.stats
31
+
32
+
33
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
34
+ def betas_for_alpha_bar(
35
+ num_diffusion_timesteps,
36
+ max_beta=0.999,
37
+ alpha_transform_type="cosine",
38
+ ):
39
+ """
40
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
41
+ (1-beta) over time from t = [0,1].
42
+
43
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
44
+ to that part of the diffusion process.
45
+
46
+
47
+ Args:
48
+ num_diffusion_timesteps (`int`): the number of betas to produce.
49
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
50
+ prevent singularities.
51
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
52
+ Choose from `cosine` or `exp`
53
+
54
+ Returns:
55
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
56
+ """
57
+ if alpha_transform_type == "cosine":
58
+
59
+ def alpha_bar_fn(t):
60
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
61
+
62
+ elif alpha_transform_type == "exp":
63
+
64
+ def alpha_bar_fn(t):
65
+ return math.exp(t * -12.0)
66
+
67
+ else:
68
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
69
+
70
+ betas = []
71
+ for i in range(num_diffusion_timesteps):
72
+ t1 = i / num_diffusion_timesteps
73
+ t2 = (i + 1) / num_diffusion_timesteps
74
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
75
+ return torch.tensor(betas, dtype=torch.float32)
76
+
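+ # In formula form: for i = 0, ..., N - 1 the schedule sets
+ #     beta_i = min(1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), max_beta),
+ # i.e. the discrete betas whose cumulative product of (1 - beta) matches alpha_bar.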
77
+
78
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
79
+ def rescale_zero_terminal_snr(betas):
80
+ """
81
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
82
+
83
+
84
+ Args:
85
+ betas (`torch.Tensor`):
86
+ the betas that the scheduler is being initialized with.
87
+
88
+ Returns:
89
+ `torch.Tensor`: rescaled betas with zero terminal SNR
90
+ """
91
+ # Convert betas to alphas_bar_sqrt
92
+ alphas = 1.0 - betas
93
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
94
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
95
+
96
+ # Store old values.
97
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
98
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
99
+
100
+ # Shift so the last timestep is zero.
101
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
102
+
103
+ # Scale so the first timestep is back to the old value.
104
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
105
+
106
+ # Convert alphas_bar_sqrt to betas
107
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
108
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
109
+ alphas = torch.cat([alphas_bar[0:1], alphas])
110
+ betas = 1 - alphas
111
+
112
+ return betas
113
+
114
+
115
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
116
+ """
117
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
118
+
119
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
120
+ methods the library implements for all schedulers such as loading and saving.
121
+
122
+ Args:
123
+ num_train_timesteps (`int`, defaults to 1000):
124
+ The number of diffusion steps to train the model.
125
+ beta_start (`float`, defaults to 0.0001):
126
+ The starting `beta` value of inference.
127
+ beta_end (`float`, defaults to 0.02):
128
+ The final `beta` value.
129
+ beta_schedule (`str`, defaults to `"linear"`):
130
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
131
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
132
+ trained_betas (`np.ndarray`, *optional*):
133
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
134
+ solver_order (`int`, defaults to 2):
135
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
136
+ sampling, and `solver_order=3` for unconditional sampling.
137
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
138
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
139
+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
140
+ Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
141
+ thresholding (`bool`, defaults to `False`):
142
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
143
+ as Stable Diffusion.
144
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
145
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
146
+ sample_max_value (`float`, defaults to 1.0):
147
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
148
+ `algorithm_type="dpmsolver++"`.
149
+ algorithm_type (`str`, defaults to `dpmsolver++`):
150
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
151
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
152
+ paper, and the `dpmsolver++` type implements the algorithms in the
153
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
154
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
155
+ solver_type (`str`, defaults to `midpoint`):
156
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
157
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
158
+ lower_order_final (`bool`, defaults to `True`):
159
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
160
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
161
+ euler_at_final (`bool`, defaults to `False`):
162
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
163
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
164
+ steps, but sometimes may result in blurring.
165
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
166
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
167
+ the sigmas are determined according to a sequence of noise levels {σi}.
168
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
169
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
170
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
171
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
172
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
173
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
174
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
175
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
176
+ `lambda(t)`.
177
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
178
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
179
+ flow_shift (`float`, *optional*, defaults to 1.0):
180
+ The shift value for the timestep schedule for flow matching.
181
+ final_sigmas_type (`str`, defaults to `"zero"`):
182
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
183
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
184
+ lambda_min_clipped (`float`, defaults to `-inf`):
185
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
186
+ cosine (`squaredcos_cap_v2`) noise schedule.
187
+ variance_type (`str`, *optional*):
188
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
189
+ contains the predicted Gaussian variance.
190
+ timestep_spacing (`str`, defaults to `"linspace"`):
191
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
192
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
193
+ steps_offset (`int`, defaults to 0):
194
+ An offset added to the inference steps, as required by some model families.
195
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
196
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
197
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
198
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
199
+ """
200
+
201
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
202
+ order = 1
203
+
204
+ @register_to_config
205
+ def __init__(
206
+ self,
207
+ num_train_timesteps: int = 1000,
208
+ beta_start: float = 0.0001,
209
+ beta_end: float = 0.02,
210
+ beta_schedule: str = "linear",
211
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
212
+ solver_order: int = 2,
213
+ prediction_type: str = "epsilon",
214
+ thresholding: bool = False,
215
+ dynamic_thresholding_ratio: float = 0.995,
216
+ sample_max_value: float = 1.0,
217
+ algorithm_type: str = "dpmsolver++",
218
+ solver_type: str = "midpoint",
219
+ lower_order_final: bool = True,
220
+ euler_at_final: bool = False,
221
+ final_sigmas_type: str = 'zero',
222
+ dynamic_time_shift: bool = True
223
+ ):
224
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
225
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
226
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
227
+
228
+ if trained_betas is not None:
229
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
230
+ elif beta_schedule == "linear":
231
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
232
+ elif beta_schedule == "scaled_linear":
233
+ # this schedule is very specific to the latent diffusion model.
234
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
235
+ elif beta_schedule == "squaredcos_cap_v2":
236
+ # Glide cosine schedule
237
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
238
+ else:
239
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
240
+ self.alphas = 1.0 - self.betas
241
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
242
+
243
+ # Currently we only support VP-type noise schedule
244
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
245
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
246
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
247
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
248
+
249
+ # standard deviation of the initial noise distribution
250
+ self.init_noise_sigma = 1.0
251
+
252
+ # settings for DPM-Solver
253
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
254
+ if algorithm_type == "deis":
255
+ self.register_to_config(algorithm_type="dpmsolver++")
256
+ else:
257
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
258
+
259
+ if solver_type not in ["midpoint", "heun"]:
260
+ if solver_type in ["logrho", "bh1", "bh2"]:
261
+ self.register_to_config(solver_type="midpoint")
262
+ else:
263
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
264
+
265
+ # if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
266
+ # raise ValueError(
267
+ # f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
268
+ # )
269
+
270
+ # setable values
271
+ self.num_inference_steps = None
272
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
273
+ self.timesteps = torch.from_numpy(timesteps)
274
+ self.model_outputs = [None] * solver_order
275
+ self.lower_order_nums = 0
276
+ self._step_index = None
277
+ self._begin_index = None
278
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
279
+
280
+ @property
281
+ def step_index(self):
282
+ """
283
+ The index counter for current timestep. It will increase 1 after each scheduler step.
284
+ """
285
+ return self._step_index
286
+
287
+ @property
288
+ def begin_index(self):
289
+ """
290
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
291
+ """
292
+ return self._begin_index
293
+
294
+ def set_begin_index(self, begin_index: int = 0):
295
+ """
296
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
297
+
298
+ Args:
299
+ begin_index (`int`):
300
+ The begin index for the scheduler.
301
+ """
302
+ self._begin_index = begin_index
303
+
304
+ def set_timesteps(
305
+ self,
306
+ num_inference_steps: int = None,
307
+ device: Union[str, torch.device] = None,
308
+ timesteps: Optional[List[int]] = None,
309
+ num_tokens: Optional[int] = None
310
+ ):
311
+ if timesteps is None:
312
+ self.num_inference_steps = num_inference_steps
313
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
314
+ if self.config.dynamic_time_shift and num_tokens is not None:
315
+ m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
316
+ timesteps = timesteps / (m - m * timesteps + timesteps)
317
+
318
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
319
+ sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)])
320
+
321
+ self.sigmas = sigmas
322
+ self.timesteps = timesteps
323
+
324
+ self.num_inference_steps = len(timesteps)
325
+
326
+ self.model_outputs = [
327
+ None,
328
+ ] * self.config.solver_order
329
+ self.lower_order_nums = 0
330
+
331
+ # add an index counter for schedulers that allow duplicated timesteps
332
+ self._step_index = None
333
+ self._begin_index = None
334
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
335
+
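+     # Worked example for the dynamic time shift (illustrative): with num_tokens = 128 * 128,
+     # m = sqrt(16384) / 40 = 3.2, so an unshifted t = 0.5 maps to 0.5 / (3.2 - 1.6 + 0.5) ≈ 0.24;
+     # larger latents therefore spend more of the trajectory at high noise (sigma = 1 - t near 1).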
336
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
337
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
338
+ """
339
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
340
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
341
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
342
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
343
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
344
+
345
+ https://arxiv.org/abs/2205.11487
346
+ """
347
+ dtype = sample.dtype
348
+ batch_size, channels, *remaining_dims = sample.shape
349
+
350
+ if dtype not in (torch.float32, torch.float64):
351
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
352
+
353
+ # Flatten sample for doing quantile calculation along each image
354
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
355
+
356
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
357
+
358
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
359
+ s = torch.clamp(
360
+ s, min=1, max=self.config.sample_max_value
361
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
362
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
363
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
364
+
365
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
366
+ sample = sample.to(dtype)
367
+
368
+ return sample
369
+
370
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
371
+ def _sigma_to_t(self, sigma, log_sigmas):
372
+ # get log sigma
373
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
374
+
375
+ # get distribution
376
+ dists = log_sigma - log_sigmas[:, np.newaxis]
377
+
378
+ # get sigmas range
379
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
380
+ high_idx = low_idx + 1
381
+
382
+ low = log_sigmas[low_idx]
383
+ high = log_sigmas[high_idx]
384
+
385
+ # interpolate sigmas
386
+ w = (low - log_sigma) / (low - high)
387
+ w = np.clip(w, 0, 1)
388
+
389
+ # transform interpolation to time range
390
+ t = (1 - w) * low_idx + w * high_idx
391
+ t = t.reshape(sigma.shape)
392
+ return t
393
+
394
+ def _sigma_to_alpha_sigma_t(self, sigma):
395
+ alpha_t = 1 - sigma
396
+ sigma_t = sigma
397
+
398
+ return alpha_t, sigma_t
399
+
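+     # Note: this linear parameterization (alpha_t = 1 - sigma_t) matches a rectified-flow style
+     # interpolation x_t = (1 - sigma) * x_0 + sigma * noise, unlike the VP-SDE sqrt(alphas_cumprod)
+     # quantities defined in __init__.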
400
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
401
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
402
+ """Constructs the noise schedule of Karras et al. (2022)."""
403
+
404
+ # Hack to make sure that other schedulers which copy this function don't break
405
+ # TODO: Add this logic to the other schedulers
406
+ if hasattr(self.config, "sigma_min"):
407
+ sigma_min = self.config.sigma_min
408
+ else:
409
+ sigma_min = None
410
+
411
+ if hasattr(self.config, "sigma_max"):
412
+ sigma_max = self.config.sigma_max
413
+ else:
414
+ sigma_max = None
415
+
416
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
417
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
418
+
419
+ rho = 7.0 # 7.0 is the value used in the paper
420
+ ramp = np.linspace(0, 1, num_inference_steps)
421
+ min_inv_rho = sigma_min ** (1 / rho)
422
+ max_inv_rho = sigma_max ** (1 / rho)
423
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
424
+ return sigmas
425
+
426
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
427
+ """Constructs the noise schedule of Lu et al. (2022)."""
428
+
429
+ lambda_min: float = in_lambdas[-1].item()
430
+ lambda_max: float = in_lambdas[0].item()
431
+
432
+ rho = 1.0 # 1.0 is the value used in the paper
433
+ ramp = np.linspace(0, 1, num_inference_steps)
434
+ min_inv_rho = lambda_min ** (1 / rho)
435
+ max_inv_rho = lambda_max ** (1 / rho)
436
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
437
+ return lambdas
438
+
439
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
440
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
441
+ """Constructs an exponential noise schedule."""
442
+
443
+ # Hack to make sure that other schedulers which copy this function don't break
444
+ # TODO: Add this logic to the other schedulers
445
+ if hasattr(self.config, "sigma_min"):
446
+ sigma_min = self.config.sigma_min
447
+ else:
448
+ sigma_min = None
449
+
450
+ if hasattr(self.config, "sigma_max"):
451
+ sigma_max = self.config.sigma_max
452
+ else:
453
+ sigma_max = None
454
+
455
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
456
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
457
+
458
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
459
+ return sigmas
460
+
461
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
462
+ def _convert_to_beta(
463
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
464
+ ) -> torch.Tensor:
465
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
466
+
467
+ # Hack to make sure that other schedulers which copy this function don't break
468
+ # TODO: Add this logic to the other schedulers
469
+ if hasattr(self.config, "sigma_min"):
470
+ sigma_min = self.config.sigma_min
471
+ else:
472
+ sigma_min = None
473
+
474
+ if hasattr(self.config, "sigma_max"):
475
+ sigma_max = self.config.sigma_max
476
+ else:
477
+ sigma_max = None
478
+
479
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
480
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
481
+
482
+ sigmas = np.array(
483
+ [
484
+ sigma_min + (ppf * (sigma_max - sigma_min))
485
+ for ppf in [
486
+ scipy.stats.beta.ppf(timestep, alpha, beta)
487
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
488
+ ]
489
+ ]
490
+ )
491
+ return sigmas
492
+
493
+ def convert_model_output(
494
+ self,
495
+ model_output: torch.Tensor,
496
+ *args,
497
+ sample: torch.Tensor = None,
498
+ **kwargs,
499
+ ) -> torch.Tensor:
500
+ """
501
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
502
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
503
+ integral of the data prediction model.
504
+
505
+ <Tip>
506
+
507
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
508
+ prediction and data prediction models.
509
+
510
+ </Tip>
511
+
512
+ Args:
513
+ model_output (`torch.Tensor`):
514
+ The direct output from the learned diffusion model.
515
+ sample (`torch.Tensor`):
516
+ A current instance of a sample created by the diffusion process.
517
+
518
+ Returns:
519
+ `torch.Tensor`:
520
+ The converted model output.
521
+ """
522
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
523
+ if sample is None:
524
+ if len(args) > 1:
525
+ sample = args[1]
526
+ else:
527
+ raise ValueError("missing `sample` as a required keyword argument")
528
+ if timestep is not None:
529
+ deprecate(
530
+ "timesteps",
531
+ "1.0.0",
532
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
533
+ )
534
+
535
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
536
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
537
+ if self.config.prediction_type == "epsilon":
538
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
539
+ if self.config.variance_type in ["learned", "learned_range"]:
540
+ model_output = model_output[:, :3]
541
+ sigma = self.sigmas[self.step_index]
542
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
543
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
544
+ elif self.config.prediction_type == "sample":
545
+ x0_pred = model_output
546
+ elif self.config.prediction_type == "v_prediction":
547
+ sigma = self.sigmas[self.step_index]
548
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
549
+ x0_pred = alpha_t * sample - sigma_t * model_output
550
+ elif self.config.prediction_type == "flow_prediction":
551
+ sigma_t = self.sigmas[self.step_index]
552
+ x0_pred = sample + sigma_t * model_output
553
+ else:
554
+ raise ValueError(
555
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
556
+ "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
557
+ )
558
+
559
+ if self.config.thresholding:
560
+ x0_pred = self._threshold_sample(x0_pred)
561
+
562
+ return x0_pred
563
+
564
+ # DPM-Solver needs to solve an integral of the noise prediction model.
565
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
566
+ if self.config.prediction_type == "epsilon":
567
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
568
+ if self.config.variance_type in ["learned", "learned_range"]:
569
+ epsilon = model_output[:, :3]
570
+ else:
571
+ epsilon = model_output
572
+ elif self.config.prediction_type == "sample":
573
+ sigma = self.sigmas[self.step_index]
574
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
575
+ epsilon = (sample - alpha_t * model_output) / sigma_t
576
+ elif self.config.prediction_type == "v_prediction":
577
+ sigma = self.sigmas[self.step_index]
578
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
579
+ epsilon = alpha_t * model_output + sigma_t * sample
580
+ else:
581
+ raise ValueError(
582
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
583
+ " `v_prediction` for the DPMSolverMultistepScheduler."
584
+ )
585
+
586
+ if self.config.thresholding:
587
+ sigma = self.sigmas[self.step_index]
588
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
589
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
590
+ x0_pred = self._threshold_sample(x0_pred)
591
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
592
+
593
+ return epsilon
594
+
595
+ def dpm_solver_first_order_update(
596
+ self,
597
+ model_output: torch.Tensor,
598
+ *args,
599
+ sample: torch.Tensor = None,
600
+ noise: Optional[torch.Tensor] = None,
601
+ **kwargs,
602
+ ) -> torch.Tensor:
603
+ """
604
+ One step for the first-order DPMSolver (equivalent to DDIM).
605
+
606
+ Args:
607
+ model_output (`torch.Tensor`):
608
+ The direct output from the learned diffusion model.
609
+ sample (`torch.Tensor`):
610
+ A current instance of a sample created by the diffusion process.
611
+
612
+ Returns:
613
+ `torch.Tensor`:
614
+ The sample tensor at the previous timestep.
615
+ """
616
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
617
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
618
+ if sample is None:
619
+ if len(args) > 2:
620
+ sample = args[2]
621
+ else:
622
+ raise ValueError("missing `sample` as a required keyword argument")
623
+ if timestep is not None:
624
+ deprecate(
625
+ "timesteps",
626
+ "1.0.0",
627
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
628
+ )
629
+
630
+ if prev_timestep is not None:
631
+ deprecate(
632
+ "prev_timestep",
633
+ "1.0.0",
634
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
635
+ )
636
+
637
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
638
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
639
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
640
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
641
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
642
+
643
+ h = lambda_t - lambda_s
644
+ if self.config.algorithm_type == "dpmsolver++":
645
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
646
+ elif self.config.algorithm_type == "dpmsolver":
647
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
648
+ elif self.config.algorithm_type == "sde-dpmsolver++":
649
+ assert noise is not None
650
+ x_t = (
651
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
652
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
653
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
654
+ )
655
+ elif self.config.algorithm_type == "sde-dpmsolver":
656
+ assert noise is not None
657
+ x_t = (
658
+ (alpha_t / alpha_s) * sample
659
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
660
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
661
+ )
662
+ return x_t
663
+
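A minimal sketch of the "dpmsolver++" branch of the first-order update above, with toy alpha/sigma values standing in for the outputs of `_sigma_to_alpha_sigma_t` (all tensors and values here are hypothetical):

    import torch

    # lambda = log(alpha) - log(sigma); the update contracts the sample by sigma_t / sigma_s
    # and adds the data prediction scaled by -alpha_t * (exp(-h) - 1).
    alpha_s, sigma_s = torch.tensor(0.70), torch.tensor(0.71)
    alpha_t, sigma_t = torch.tensor(0.90), torch.tensor(0.44)
    h = (torch.log(alpha_t) - torch.log(sigma_t)) - (torch.log(alpha_s) - torch.log(sigma_s))

    sample = torch.randn(1, 4)     # current latent at sigma_s
    x0_pred = torch.randn(1, 4)    # model output after convert_model_output
    x_t = (sigma_t / sigma_s) * sample - alpha_t * (torch.exp(-h) - 1.0) * x0_pred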
664
+ def multistep_dpm_solver_second_order_update(
665
+ self,
666
+ model_output_list: List[torch.Tensor],
667
+ *args,
668
+ sample: torch.Tensor = None,
669
+ noise: Optional[torch.Tensor] = None,
670
+ **kwargs,
671
+ ) -> torch.Tensor:
672
+ """
673
+ One step for the second-order multistep DPMSolver.
674
+
675
+ Args:
676
+ model_output_list (`List[torch.Tensor]`):
677
+ The direct outputs from learned diffusion model at current and latter timesteps.
678
+ sample (`torch.Tensor`):
679
+ A current instance of a sample created by the diffusion process.
680
+
681
+ Returns:
682
+ `torch.Tensor`:
683
+ The sample tensor at the previous timestep.
684
+ """
685
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
686
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
687
+ if sample is None:
688
+ if len(args) > 2:
689
+ sample = args[2]
690
+ else:
691
+ raise ValueError("missing `sample` as a required keyword argument")
692
+ if timestep_list is not None:
693
+ deprecate(
694
+ "timestep_list",
695
+ "1.0.0",
696
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
697
+ )
698
+
699
+ if prev_timestep is not None:
700
+ deprecate(
701
+ "prev_timestep",
702
+ "1.0.0",
703
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
704
+ )
705
+
706
+ sigma_t, sigma_s0, sigma_s1 = (
707
+ self.sigmas[self.step_index + 1],
708
+ self.sigmas[self.step_index],
709
+ self.sigmas[self.step_index - 1],
710
+ )
711
+
712
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
713
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
714
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
715
+
716
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
717
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
718
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
719
+
720
+ m0, m1 = model_output_list[-1], model_output_list[-2]
721
+
722
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
723
+ r0 = h_0 / h
724
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
725
+ if self.config.algorithm_type == "dpmsolver++":
726
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
727
+ if self.config.solver_type == "midpoint":
728
+ x_t = (
729
+ (sigma_t / sigma_s0) * sample
730
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
731
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
732
+ )
733
+ elif self.config.solver_type == "heun":
734
+ x_t = (
735
+ (sigma_t / sigma_s0) * sample
736
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
737
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
738
+ )
739
+ elif self.config.algorithm_type == "dpmsolver":
740
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
741
+ if self.config.solver_type == "midpoint":
742
+ x_t = (
743
+ (alpha_t / alpha_s0) * sample
744
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
745
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
746
+ )
747
+ elif self.config.solver_type == "heun":
748
+ x_t = (
749
+ (alpha_t / alpha_s0) * sample
750
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
751
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
752
+ )
753
+ elif self.config.algorithm_type == "sde-dpmsolver++":
754
+ assert noise is not None
755
+ if self.config.solver_type == "midpoint":
756
+ x_t = (
757
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
758
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
759
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
760
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
761
+ )
762
+ elif self.config.solver_type == "heun":
763
+ x_t = (
764
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
765
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
766
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
767
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
768
+ )
769
+ elif self.config.algorithm_type == "sde-dpmsolver":
770
+ assert noise is not None
771
+ if self.config.solver_type == "midpoint":
772
+ x_t = (
773
+ (alpha_t / alpha_s0) * sample
774
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
775
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
776
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
777
+ )
778
+ elif self.config.solver_type == "heun":
779
+ x_t = (
780
+ (alpha_t / alpha_s0) * sample
781
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
782
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
783
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
784
+ )
785
+ return x_t
786
+
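The second-order update combines the latest data prediction with a backward finite difference of the previous one. A small sketch of just those multistep quantities, using made-up step sizes rather than a real schedule:

    import torch

    m0, m1 = torch.randn(1, 4), torch.randn(1, 4)   # model_outputs[-1], model_outputs[-2]
    h = torch.tensor(0.5)                           # lambda_t  - lambda_s0
    h_0 = torch.tensor(0.4)                         # lambda_s0 - lambda_s1
    r0 = h_0 / h
    D0 = m0                                         # zeroth-order term
    D1 = (m0 - m1) / r0                             # first-order finite-difference term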
787
+ def multistep_dpm_solver_third_order_update(
788
+ self,
789
+ model_output_list: List[torch.Tensor],
790
+ *args,
791
+ sample: torch.Tensor = None,
792
+ noise: Optional[torch.Tensor] = None,
793
+ **kwargs,
794
+ ) -> torch.Tensor:
795
+ """
796
+ One step for the third-order multistep DPMSolver.
797
+
798
+ Args:
799
+ model_output_list (`List[torch.Tensor]`):
800
+ The direct outputs from learned diffusion model at current and latter timesteps.
801
+ sample (`torch.Tensor`):
802
+ A current instance of a sample created by the diffusion process.
803
+
804
+ Returns:
805
+ `torch.Tensor`:
806
+ The sample tensor at the previous timestep.
807
+ """
808
+
809
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
810
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
811
+ if sample is None:
812
+ if len(args) > 2:
813
+ sample = args[2]
814
+ else:
815
+ raise ValueError("missing `sample` as a required keyword argument")
816
+ if timestep_list is not None:
817
+ deprecate(
818
+ "timestep_list",
819
+ "1.0.0",
820
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
821
+ )
822
+
823
+ if prev_timestep is not None:
824
+ deprecate(
825
+ "prev_timestep",
826
+ "1.0.0",
827
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
828
+ )
829
+
830
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
831
+ self.sigmas[self.step_index + 1],
832
+ self.sigmas[self.step_index],
833
+ self.sigmas[self.step_index - 1],
834
+ self.sigmas[self.step_index - 2],
835
+ )
836
+
837
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
838
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
839
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
840
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
841
+
842
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
843
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
844
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
845
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
846
+
847
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
848
+
849
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
850
+ r0, r1 = h_0 / h, h_1 / h
851
+ D0 = m0
852
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
853
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
854
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
855
+ if self.config.algorithm_type == "dpmsolver++":
856
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
857
+ x_t = (
858
+ (sigma_t / sigma_s0) * sample
859
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
860
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
861
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
862
+ )
863
+ elif self.config.algorithm_type == "dpmsolver":
864
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
865
+ x_t = (
866
+ (alpha_t / alpha_s0) * sample
867
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
868
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
869
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
870
+ )
871
+ elif self.config.algorithm_type == "sde-dpmsolver++":
872
+ assert noise is not None
873
+ x_t = (
874
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
875
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
876
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
877
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
878
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
879
+ )
880
+ return x_t
881
+
882
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
883
+ if schedule_timesteps is None:
884
+ schedule_timesteps = self.timesteps
885
+
886
+ index_candidates = (schedule_timesteps == timestep).nonzero()
887
+
888
+ if len(index_candidates) == 0:
889
+ step_index = len(self.timesteps) - 1
890
+ # The sigma index that is taken for the **very** first `step`
891
+ # is always the second index (or the last index if there is only 1)
892
+ # This way we can ensure we don't accidentally skip a sigma in
893
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
894
+ elif len(index_candidates) > 1:
895
+ step_index = index_candidates[1].item()
896
+ else:
897
+ step_index = index_candidates[0].item()
898
+
899
+ return step_index
900
+
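A toy illustration of the duplicate-timestep rule documented above: when a timestep value occurs twice in the schedule, the second occurrence is chosen so no sigma is skipped. The schedule values here are made up:

    import torch

    schedule = torch.tensor([999, 750, 750, 500, 250])
    candidates = (schedule == 750).nonzero()
    step_index = candidates[1].item() if len(candidates) > 1 else candidates[0].item()
    assert step_index == 2   # the second occurrence of 750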
901
+ def _init_step_index(self, timestep):
902
+ """
903
+ Initialize the step_index counter for the scheduler.
904
+ """
905
+
906
+ if self.begin_index is None:
907
+ if isinstance(timestep, torch.Tensor):
908
+ timestep = timestep.to(self.timesteps.device)
909
+ self._step_index = self.index_for_timestep(timestep)
910
+ else:
911
+ self._step_index = self._begin_index
912
+
913
+ def step(
914
+ self,
915
+ model_output: torch.Tensor,
916
+ timestep: Union[int, torch.Tensor],
917
+ sample: torch.Tensor,
918
+ generator=None,
919
+ variance_noise: Optional[torch.Tensor] = None,
920
+ return_dict: bool = True,
921
+ ) -> Union[SchedulerOutput, Tuple]:
922
+ """
923
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
924
+ the multistep DPMSolver.
925
+
926
+ Args:
927
+ model_output (`torch.Tensor`):
928
+ The direct output from learned diffusion model.
929
+ timestep (`int`):
930
+ The current discrete timestep in the diffusion chain.
931
+ sample (`torch.Tensor`):
932
+ A current instance of a sample created by the diffusion process.
933
+ generator (`torch.Generator`, *optional*):
934
+ A random number generator.
935
+ variance_noise (`torch.Tensor`):
936
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
937
+ itself. Useful for methods such as [`LEdits++`].
938
+ return_dict (`bool`):
939
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
940
+
941
+ Returns:
942
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
943
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
944
+ tuple is returned where the first element is the sample tensor.
945
+
946
+ """
947
+ if self.num_inference_steps is None:
948
+ raise ValueError(
949
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
950
+ )
951
+
952
+ if self.step_index is None:
953
+ self._init_step_index(timestep)
954
+
955
+ # Improve numerical stability for small number of steps
956
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
957
+ self.config.euler_at_final
958
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
959
+ or self.config.final_sigmas_type == "zero"
960
+ )
961
+ lower_order_second = (
962
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
963
+ )
964
+
965
+ model_output = self.convert_model_output(model_output, sample=sample)
966
+ for i in range(self.config.solver_order - 1):
967
+ self.model_outputs[i] = self.model_outputs[i + 1]
968
+ self.model_outputs[-1] = model_output
969
+
970
+ # Upcast to avoid precision issues when computing prev_sample
971
+ sample = sample.to(torch.float32)
972
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
973
+ noise = randn_tensor(
974
+ model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
975
+ )
976
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
977
+ noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
978
+ else:
979
+ noise = None
980
+
981
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
982
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
983
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
984
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
985
+ else:
986
+ prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
987
+
988
+ if self.lower_order_nums < self.config.solver_order:
989
+ self.lower_order_nums += 1
990
+
991
+ # Cast sample back to expected dtype
992
+ prev_sample = prev_sample.to(model_output.dtype)
993
+
994
+ # upon completion increase step index by one
995
+ self._step_index += 1
996
+
997
+ if not return_dict:
998
+ return (prev_sample,)
999
+
1000
+ return SchedulerOutput(prev_sample=prev_sample)
1001
+
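As a usage sketch, `step` is normally driven from a sampling loop after `set_timesteps`. The `denoiser` callable, latent shape, and step count below are hypothetical placeholders, not part of this diff:

    import torch

    def sample_loop(scheduler, denoiser, shape, num_inference_steps=20, device="cpu"):
        # Generic multistep loop: fix the schedule once, then advance one step per timestep.
        scheduler.set_timesteps(num_inference_steps, device=device)
        latents = torch.randn(shape, device=device)
        for t in scheduler.timesteps:
            model_output = denoiser(latents, t)                      # hypothetical model call
            latents = scheduler.step(model_output, t, latents).prev_sample
        return latents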
1002
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1003
+ """
1004
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
1005
+ current timestep.
1006
+
1007
+ Args:
1008
+ sample (`torch.Tensor`):
1009
+ The input sample.
1010
+
1011
+ Returns:
1012
+ `torch.Tensor`:
1013
+ A scaled input sample.
1014
+ """
1015
+ return sample
1016
+
1017
+ def add_noise(
1018
+ self,
1019
+ original_samples: torch.Tensor,
1020
+ noise: torch.Tensor,
1021
+ timesteps: torch.IntTensor,
1022
+ ) -> torch.Tensor:
1023
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
1024
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
1025
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
1026
+ # mps does not support float64
1027
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
1028
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
1029
+ else:
1030
+ schedule_timesteps = self.timesteps.to(original_samples.device)
1031
+ timesteps = timesteps.to(original_samples.device)
1032
+
1033
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
1034
+ if self.begin_index is None:
1035
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
1036
+ elif self.step_index is not None:
1037
+ # add_noise is called after first denoising step (for inpainting)
1038
+ step_indices = [self.step_index] * timesteps.shape[0]
1039
+ else:
1040
+ # add noise is called before first denoising step to create initial latent(img2img)
1041
+ step_indices = [self.begin_index] * timesteps.shape[0]
1042
+
1043
+ sigma = sigmas[step_indices].flatten()
1044
+ while len(sigma.shape) < len(original_samples.shape):
1045
+ sigma = sigma.unsqueeze(-1)
1046
+
1047
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
1048
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
1049
+ return noisy_samples
1050
+
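`add_noise` applies the standard forward noising rule; a short sketch with toy values in place of the schedule lookup:

    import torch

    # x_t = alpha_t * x_0 + sigma_t * noise, with (alpha_t, sigma_t) taken from the sigma schedule.
    alpha_t, sigma_t = torch.tensor(0.9), torch.tensor(0.44)
    x0, noise = torch.randn(2, 3), torch.randn(2, 3)
    noisy = alpha_t * x0 + sigma_t * noise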
1051
+ def __len__(self):
1052
+ return self.config.num_train_timesteps
omnigen2/schedulers/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Euler scheduler.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ dynamic_time_shift (`bool`, defaults to `True`):
+ Whether to shift the inference timesteps based on the number of image tokens (see `set_timesteps`).
59
+ """
60
+
61
+ _compatibles = []
62
+ order = 1
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ dynamic_time_shift: bool = True
69
+ ):
70
+ timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
71
+
72
+ self.timesteps = timesteps
73
+
74
+ self._step_index = None
75
+ self._begin_index = None
76
+
77
+ @property
78
+ def step_index(self):
79
+ """
80
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
81
+ """
82
+ return self._step_index
83
+
84
+ @property
85
+ def begin_index(self):
86
+ """
87
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
88
+ """
89
+ return self._begin_index
90
+
91
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
92
+ def set_begin_index(self, begin_index: int = 0):
93
+ """
94
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
95
+
96
+ Args:
97
+ begin_index (`int`):
98
+ The begin index for the scheduler.
99
+ """
100
+ self._begin_index = begin_index
101
+
102
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
103
+ if schedule_timesteps is None:
104
+ schedule_timesteps = self._timesteps
105
+
106
+ indices = (schedule_timesteps == timestep).nonzero()
107
+
108
+ # The sigma index that is taken for the **very** first `step`
109
+ # is always the second index (or the last index if there is only 1)
110
+ # This way we can ensure we don't accidentally skip a sigma in
111
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
112
+ pos = 1 if len(indices) > 1 else 0
113
+
114
+ return indices[pos].item()
115
+
116
+ # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
117
+ # return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
118
+
119
+ def set_timesteps(
120
+ self,
121
+ num_inference_steps: int = None,
122
+ device: Union[str, torch.device] = None,
123
+ timesteps: Optional[List[float]] = None,
124
+ num_tokens: Optional[int] = None
125
+ ):
126
+ """
127
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
128
+
129
+ Args:
130
+ num_inference_steps (`int`):
131
+ The number of diffusion steps used when generating samples with a pre-trained model.
132
+ device (`str` or `torch.device`, *optional*):
133
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
134
+ """
135
+
136
+ if timesteps is None:
137
+ self.num_inference_steps = num_inference_steps
138
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
139
+ if self.config.dynamic_time_shift and num_tokens is not None:
140
+ m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
141
+ timesteps = timesteps / (m - m * timesteps + timesteps)
142
+
143
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
144
+ _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
145
+
146
+ self.timesteps = timesteps
147
+ self._timesteps = _timesteps
148
+ self._step_index = None
149
+ self._begin_index = None
150
+
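A sketch of the dynamic time shift applied above. With m = sqrt(num_tokens) / 40, the schedule is remapped as t -> t / (m - m*t + t); for m > 1 (larger inputs, e.g. the 1024 x 1024 case mentioned in the comment) every timestep is pushed toward smaller values. The numbers below are illustrative:

    import numpy as np

    num_inference_steps = 8
    num_tokens = 16384                        # roughly the 1024 x 1024 case (m = 3.2)
    t = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
    m = np.sqrt(num_tokens) / 40
    shifted = t / (m - m * t + t)             # every entry <= the unshifted t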
151
+ def _init_step_index(self, timestep):
152
+ if self.begin_index is None:
153
+ if isinstance(timestep, torch.Tensor):
154
+ timestep = timestep.to(self.timesteps.device)
155
+ self._step_index = self.index_for_timestep(timestep)
156
+ else:
157
+ self._step_index = self._begin_index
158
+
159
+ def step(
160
+ self,
161
+ model_output: torch.FloatTensor,
162
+ timestep: Union[float, torch.FloatTensor],
163
+ sample: torch.FloatTensor,
164
+ generator: Optional[torch.Generator] = None,
165
+ return_dict: bool = True,
166
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
167
+ """
168
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
169
+ process from the learned model outputs (most often the predicted noise).
170
+
171
+ Args:
172
+ model_output (`torch.FloatTensor`):
173
+ The direct output from learned diffusion model.
174
+ timestep (`float`):
175
+ The current discrete timestep in the diffusion chain.
176
+ sample (`torch.FloatTensor`):
177
+ A current instance of a sample created by the diffusion process.
178
+ s_churn (`float`):
179
+ s_tmin (`float`):
180
+ s_tmax (`float`):
181
+ s_noise (`float`, defaults to 1.0):
182
+ Scaling factor for noise added to the sample.
183
+ generator (`torch.Generator`, *optional*):
184
+ A random number generator.
185
+ return_dict (`bool`):
186
+ Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or
187
+ tuple.
188
+
189
+ Returns:
190
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
191
+ If return_dict is `True`, [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
192
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
193
+ """
194
+
195
+ if (
196
+ isinstance(timestep, int)
197
+ or isinstance(timestep, torch.IntTensor)
198
+ or isinstance(timestep, torch.LongTensor)
199
+ ):
200
+ raise ValueError(
201
+ (
202
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
203
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
204
+ " one of the `scheduler.timesteps` as a timestep."
205
+ ),
206
+ )
207
+
208
+ if self.step_index is None:
209
+ self._init_step_index(timestep)
210
+ # Upcast to avoid precision issues when computing prev_sample
211
+ sample = sample.to(torch.float32)
212
+ t = self._timesteps[self.step_index]
213
+ t_next = self._timesteps[self.step_index + 1]
214
+
215
+ prev_sample = sample + (t_next - t) * model_output
216
+
217
+ # Cast sample back to model compatible dtype
218
+ prev_sample = prev_sample.to(model_output.dtype)
219
+
220
+ # upon completion increase step index by one
221
+ self._step_index += 1
222
+
223
+ if not return_dict:
224
+ return (prev_sample,)
225
+
226
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
227
+
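The update above is a single Euler step on the flow ODE, x_{t_next} = x_t + (t_next - t) * v, where the model output is treated as a velocity. A toy sketch with hypothetical values:

    import torch

    t, t_next = torch.tensor(0.25), torch.tensor(0.50)
    sample = torch.randn(1, 4)
    velocity = torch.randn(1, 4)              # stands in for model_output
    prev_sample = sample + (t_next - t) * velocity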
228
+ def __len__(self):
229
+ return self.config.num_train_timesteps
omnigen2/utils/__init__.py ADDED
File without changes
omnigen2/utils/img_util.py ADDED
@@ -0,0 +1,31 @@
1
+ from typing import List
2
+
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from torchvision.transforms.functional import to_pil_image
7
+
8
+ def resize_image(image, max_pixels, img_scale_num):
9
+ width, height = image.size
10
+ cur_pixels = height * width
11
+ ratio = (max_pixels / cur_pixels) ** 0.5
12
+ ratio = min(ratio, 1.0) # do not upscale input image
13
+
14
+ new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num
15
+
16
+ image = image.resize((new_width, new_height), resample=Image.BICUBIC)
17
+ return image
18
+
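A hedged usage sketch of `resize_image`: the input size, pixel budget, and scale factor below are made-up values. The helper only ever downscales and snaps both sides to a multiple of `img_scale_num`:

    from PIL import Image
    from omnigen2.utils.img_util import resize_image

    img = Image.new("RGB", (2048, 1536))                        # hypothetical input
    resized = resize_image(img, max_pixels=1024 * 1024, img_scale_num=16)
    w, h = resized.size
    assert w % 16 == 0 and h % 16 == 0 and w * h <= 1024 * 1024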
19
+ def create_collage(images: List[torch.Tensor]) -> Image.Image:
20
+ """Create a horizontal collage from a list of images."""
21
+ max_height = max(img.shape[-2] for img in images)
22
+ total_width = sum(img.shape[-1] for img in images)
23
+ canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
24
+
25
+ current_x = 0
26
+ for img in images:
27
+ h, w = img.shape[-2:]
28
+ canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
29
+ current_x += w
30
+
31
+ return to_pil_image(canvas)
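`create_collage` expects CHW float tensors in [-1, 1] (it rescales by 0.5 * x + 0.5 before converting to PIL). A short usage sketch with random tensors and a hypothetical output path:

    import torch
    from omnigen2.utils.img_util import create_collage

    imgs = [torch.rand(3, 256, 256) * 2 - 1, torch.rand(3, 128, 192) * 2 - 1]
    collage = create_collage(imgs)            # 256-high, 448-wide PIL image
    collage.save("collage_preview.png")       # hypothetical output path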
omnigen2/utils/vpn_utils.py ADDED
@@ -0,0 +1,69 @@
1
+ import random
2
+
3
+ import json
4
+ import yaml
5
+
6
+ import requests
7
+
8
+ class VPNManager:
9
+ def __init__(self, config_path: str = '/etc/mihomo/config.yaml'):
10
+ with open(config_path, 'r') as f:
11
+ config = yaml.safe_load(f)
12
+ self.external_controller = config['external-controller']
13
+ self.external_controller = self.external_controller.replace('0.0.0.0', '127.0.0.1')
14
+ self.secret = config['secret']
15
+
16
+ self.headers = {"Authorization": f"Bearer {self.secret}"}
17
+
18
+ self.unavailable_nodes = set()
19
+
20
+ @property
21
+ def current_node(self):
22
+ url = f"http://{self.external_controller}/group/Proxy"
23
+ r = requests.request("GET", url, headers=self.headers)
24
+ return r.json()['now']
25
+
26
+ @property
27
+ def available_nodes(self):
28
+ return list(self.get_available_vpn_nodes() - self.unavailable_nodes)
29
+
30
+ def switch_vpn_node(self, node_name):
31
+ url = f"http://{self.external_controller}/proxies/Proxy"
32
+
33
+ payload = json.dumps({
34
+ "name": node_name
35
+ })
36
+ headers = self.headers.copy()
37
+ headers.update({'Content-Type': 'application/json'})
38
+ r = requests.request("PUT", url, headers=headers, data=payload)
39
+ if r.status_code != 204:
40
+ raise Warning(f"Failed to switch to {node_name}")
41
+ return r.status_code == 204
42
+
43
+ def get_random_available_vpn_node(self):
44
+ return random.choice(self.available_nodes)
45
+
46
+ def random_switch_vpn_node(self):
47
+ node_name = self.get_random_available_vpn_node()
48
+ print(f"Switching to {node_name}")
49
+ self.switch_vpn_node(node_name)
50
+ # self.current_node = node_name
51
+ return node_name
52
+
53
+ def get_vpn_nodes(self):
54
+ url = f"http://{self.external_controller}/group/Proxy"
55
+ delay_res = requests.get(url, headers=self.headers)
56
+ return delay_res.json()['all']
57
+
58
+ def get_available_vpn_nodes(self):
59
+ url = f"http://{self.external_controller}/group/Proxy/delay?timeout=5000&url=http://www.gstatic.com/generate_204"
60
+ delay_res = requests.get(url, headers=self.headers)
61
+ return set(delay_res.json().keys())
62
+
63
+ def get_current_vpn_node_ip(self):
64
+ url = "http://ifconfig.me"
65
+ r = requests.request("GET", url)
66
+ return r.text
67
+
68
+ def add_unavailable_node(self, node_name):
69
+ self.unavailable_nodes.add(node_name)
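A hedged usage sketch for `VPNManager`. It assumes a running mihomo instance whose external controller address and secret can be read from the default config path; the node handling mirrors the class's own failure behavior:

    from omnigen2.utils.vpn_utils import VPNManager

    manager = VPNManager(config_path="/etc/mihomo/config.yaml")
    print("current node:", manager.current_node)

    node = manager.get_random_available_vpn_node()
    try:
        manager.switch_vpn_node(node)
    except Warning:                            # switch_vpn_node raises Warning on failure
        manager.add_unavailable_node(node)
    print("exit IP:", manager.get_current_vpn_node_ip())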
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ timm
4
+ einops
5
+ accelerate
6
+ transformers==4.51.3
7
+ diffusers
8
+ opencv-python-headless
9
+ scipy
10
+ wandb
11
+ matplotlib
12
+ Pillow
13
+ tqdm
14
+ omegaconf
15
+ python-dotenv
16
+ ninja