cb1cyf committed
Commit cf4796c · 1 Parent(s): 1869f10

fix: omnigen2

Files changed (46)
  1. omnigen2/__init__.py +0 -0
  2. omnigen2/cache_functions/__init__.py +3 -0
  3. omnigen2/cache_functions/cache_init.py +38 -0
  4. omnigen2/cache_functions/cal_type.py +41 -0
  5. omnigen2/cache_functions/force_scheduler.py +19 -0
  6. omnigen2/dataset/__init__.py +0 -0
  7. omnigen2/dataset/omnigen2_test_dataset.py +153 -0
  8. omnigen2/dataset/omnigen2_train_dataset.py +203 -0
  9. omnigen2/models/__init__.py +0 -0
  10. omnigen2/models/attention_processor.py +357 -0
  11. omnigen2/models/embeddings.py +126 -0
  12. omnigen2/models/transformers/__init__.py +3 -0
  13. omnigen2/models/transformers/block_lumina2.py +218 -0
  14. omnigen2/models/transformers/components.py +4 -0
  15. omnigen2/models/transformers/repo.py +129 -0
  16. omnigen2/models/transformers/transformer_omnigen2.py +716 -0
  17. omnigen2/ops/triton/__init__.py +0 -0
  18. omnigen2/ops/triton/layer_norm.py +1257 -0
  19. omnigen2/optim/__init__.py +0 -0
  20. omnigen2/optim/scheduler/__init__.py +0 -0
  21. omnigen2/optim/scheduler/cosine_lr.py +118 -0
  22. omnigen2/optim/scheduler/scheduler.py +131 -0
  23. omnigen2/optim/scheduler/step_lr.py +63 -0
  24. omnigen2/pipelines/__init__.py +0 -0
  25. omnigen2/pipelines/image_processor.py +266 -0
  26. omnigen2/pipelines/lora_pipeline.py +388 -0
  27. omnigen2/pipelines/omnigen2/pipeline_omnigen2.py +774 -0
  28. omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py +830 -0
  29. omnigen2/pipelines/pipeline_utils.py +62 -0
  30. omnigen2/schedulers/__init__.py +0 -0
  31. omnigen2/schedulers/scheduling_dpmsolver_multistep.py +1052 -0
  32. omnigen2/schedulers/scheduling_flow_match_euler_discrete.py +229 -0
  33. omnigen2/taylorseer_utils/__init__.py +51 -0
  34. omnigen2/training_utils.py +645 -0
  35. omnigen2/transport/__init__.py +74 -0
  36. omnigen2/transport/dpm_solver.py +1386 -0
  37. omnigen2/transport/integrators.py +122 -0
  38. omnigen2/transport/path.py +201 -0
  39. omnigen2/transport/transport.py +545 -0
  40. omnigen2/transport/utils.py +56 -0
  41. omnigen2/utils/__init__.py +0 -0
  42. omnigen2/utils/img_util.py +31 -0
  43. omnigen2/utils/import_utils.py +46 -0
  44. omnigen2/utils/logging_utils.py +15 -0
  45. omnigen2/utils/reproducibility.py +22 -0
  46. omnigen2/utils/teacache_util.py +43 -0
omnigen2/__init__.py ADDED
File without changes
omnigen2/cache_functions/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .cache_init import cache_init
+ from .cal_type import cal_type
+ from .force_scheduler import force_scheduler
omnigen2/cache_functions/cache_init.py ADDED
@@ -0,0 +1,38 @@
+ # Modified from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cache_init.py
+
+ # Type hinting would cause circular import, self should be `OmniGen2Pipeline`
+ def cache_init(self, num_steps: int):
+     '''
+     Initialization for cache.
+     '''
+     cache_dic = {}
+     cache = {}
+     cache_index = {}
+     cache[-1]={}
+     cache_index[-1]={}
+     cache_index['layer_index']={}
+     cache[-1]['layers_stream']={}
+     cache_dic['cache_counter'] = 0
+
+     for j in range(len(self.transformer.layers)):
+         cache[-1]['layers_stream'][j] = {}
+         cache_index[-1][j] = {}
+
+     cache_dic['Delta-DiT'] = False
+     cache_dic['cache_type'] = 'random'
+     cache_dic['cache_index'] = cache_index
+     cache_dic['cache'] = cache
+     cache_dic['fresh_ratio_schedule'] = 'ToCa'
+     cache_dic['fresh_ratio'] = 0.0
+     cache_dic['fresh_threshold'] = 3
+     cache_dic['soft_fresh_weight'] = 0.0
+     cache_dic['taylor_cache'] = True
+     cache_dic['max_order'] = 4
+     cache_dic['first_enhance'] = 5
+
+     current = {}
+     current['activated_steps'] = [0]
+     current['step'] = 0
+     current['num_steps'] = num_steps
+
+     return cache_dic, current
omnigen2/cache_functions/cal_type.py ADDED
@@ -0,0 +1,41 @@
+ # Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cal_type.py
+
+ from .force_scheduler import force_scheduler
+
+ def cal_type(cache_dic, current):
+     '''
+     Determine calculation type for this step
+     '''
+     if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
+         # FORA:Uniform
+         first_step = (current['step'] == 0)
+     else:
+         # ToCa: First enhanced
+         first_step = (current['step'] < cache_dic['first_enhance'])
+
+     if not first_step:
+         fresh_interval = cache_dic['cal_threshold']
+     else:
+         fresh_interval = cache_dic['fresh_threshold']
+
+     if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1):
+         current['type'] = 'full'
+         cache_dic['cache_counter'] = 0
+         current['activated_steps'].append(current['step'])
+         force_scheduler(cache_dic, current)
+
+     elif (cache_dic['taylor_cache']):
+         cache_dic['cache_counter'] += 1
+         current['type'] = 'Taylor'
+
+
+     elif (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
+         cache_dic['cache_counter'] += 1
+         current['type'] = 'ToCa'
+     # 'cache_noise' 'ToCa' 'FORA'
+     elif cache_dic['Delta-DiT']:
+         cache_dic['cache_counter'] += 1
+         current['type'] = 'Delta-Cache'
+     else:
+         cache_dic['cache_counter'] += 1
+         current['type'] = 'ToCa'
omnigen2/cache_functions/force_scheduler.py ADDED
@@ -0,0 +1,19 @@
+ # Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/force_scheduler.py
+
+ import torch
+
+ def force_scheduler(cache_dic, current):
+     if cache_dic['fresh_ratio'] == 0:
+         # FORA
+         linear_step_weight = 0.0
+     else:
+         # TokenCache
+         linear_step_weight = 0.0
+     step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
+     threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)
+
+     # no force constrain for sensitive steps, cause the performance is good enough.
+     # you may have a try.
+
+     cache_dic['cal_threshold'] = threshold
+     #return threshold
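The three helpers above form a small protocol: `cache_init` builds the TaylorSeer bookkeeping once, `cal_type` is called every denoising step to pick between a 'full' recompute and a cached step, and `force_scheduler` refreshes the interval. A minimal sketch of that wiring, not part of the commit: the `SimpleNamespace` stand-in for the pipeline and the step count are assumptions made only for illustration.

# Illustrative sketch only -- not part of this commit.
from types import SimpleNamespace

from omnigen2.cache_functions import cache_init, cal_type

# cache_init only needs an object exposing `.transformer.layers`; a real
# OmniGen2Pipeline would be passed here instead of this dummy stand-in.
dummy_pipe = SimpleNamespace(transformer=SimpleNamespace(layers=[None] * 4))

cache_dic, current = cache_init(dummy_pipe, num_steps=8)
for step in range(current['num_steps']):
    current['step'] = step
    cal_type(cache_dic, current)  # sets current['type'] to 'full' or 'Taylor'
    # a 'full' step recomputes and re-fits the cache; other steps reuse it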
omnigen2/dataset/__init__.py ADDED
File without changes
omnigen2/dataset/omnigen2_test_dataset.py ADDED
@@ -0,0 +1,153 @@
+ from typing import Optional
+
+ import os
+ import random
+ import yaml
+ import glob
+ from PIL import Image
+
+ import torch
+
+ from datasets import load_dataset, concatenate_datasets
+
+ from ..pipelines.omnigen2.pipeline_omnigen2 import OmniGen2ImageProcessor
+
+ class OmniGen2TestDataset(torch.utils.data.Dataset):
+     SYSTEM_PROMPT = "You are a helpful assistant that generates high-quality images based on user instructions."
+
+     def __init__(
+         self,
+         config_path: str,
+         tokenizer,
+         use_chat_template: bool,
+         max_pixels: Optional[int] = None,
+         max_side_length: Optional[int] = None,
+         img_scale_num: int = 16,
+         align_res: bool = True
+     ):
+
+         self.max_pixels = max_pixels
+         self.max_side_length = max_side_length
+         self.img_scale_num = img_scale_num
+         self.align_res = align_res
+
+         with open(config_path, "r") as f:
+             self.config = yaml.load(f, Loader=yaml.FullLoader)
+
+         self.use_chat_template = use_chat_template
+         self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=img_scale_num, do_resize=True)
+
+         data = self._collect_annotations(self.config)
+
+         self.data = data
+         self.tokenizer = tokenizer
+
+     def _collect_annotations(self, config):
+         json_datasets = []
+         for data in config['data']:
+             data_path, data_type = data['path'], data.get("type", "default")
+             if os.path.isdir(data_path):
+                 jsonl_files = list(glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True)) + list(glob.glob(os.path.join(data_path, "**/*.json"), recursive=True))
+                 json_dataset = load_dataset('json', data_files=jsonl_files, cache_dir=None)['train']
+             else:
+                 data_ext = os.path.splitext(data_path)[-1]
+                 if data_ext in [".json", ".jsonl"]:
+                     json_dataset = load_dataset('json', data_files=data_path, cache_dir=None)['train']
+                 elif data_ext in [".yml", ".yaml"]:
+                     with open(data_path, "r") as f:
+                         sub_config = yaml.load(f, Loader=yaml.FullLoader)
+                     json_dataset = self._collect_annotations(sub_config)
+                 else:
+                     raise NotImplementedError(
+                         f'Unknown data file extension: "{data_ext}". '
+                         f"Currently, .json, .jsonl .yml .yaml are supported. "
+                         "If you are using a supported format, please set the file extension so that the proper parsing "
+                         "routine can be called."
+                     )
+             json_datasets.append(json_dataset)
+
+         json_dataset = concatenate_datasets(json_datasets)
+         return json_dataset
+
+     def apply_chat_template(self, instruction, system_prompt):
+         if self.use_chat_template:
+             prompt = [
+                 {
+                     "role": "system",
+                     "content": system_prompt,
+                 },
+                 {"role": "user", "content": instruction},
+             ]
+             instruction = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
+         return instruction
+
+     def process_item(self, data_item):
+         assert data_item['instruction'] is not None
+         if 'input_images' in data_item and data_item['input_images'] is not None:
+             input_images_path = data_item['input_images']
+             input_images = []
+
+             for input_image_path in input_images_path:
+                 input_image = Image.open(input_image_path).convert("RGB")
+                 input_images.append(input_image)
+         else:
+             input_images_path, input_images = None, None
+
+         if input_images is not None and len(input_images) == 1 and self.align_res:
+             target_img_size = (input_images[0].width, input_images[0].height)
+         else:
+             target_img_size = data_item["target_img_size"]
+
+         w, h = target_img_size
+         cur_pixels = w * h
+         ratio = min(1, (self.max_pixels / cur_pixels) ** 0.5)
+
+         target_img_size = (int(w * ratio) // self.img_scale_num * self.img_scale_num, int(h * ratio) // self.img_scale_num * self.img_scale_num)
+
+         data = {
+             'task_type': data_item['task_type'],
+             'instruction': data_item['instruction'],
+             'input_images_path': input_images_path,
+             'input_images': input_images,
+             'target_img_size': target_img_size,
+         }
+         return data
+
+     def __getitem__(self, index):
+         data_item = self.data[index]
+         return self.process_item(data_item)
+
+     def __len__(self):
+         return len(self.data)
+
+ class OmniGen2Collator():
+     def __init__(self, tokenizer, max_token_len):
+         self.tokenizer = tokenizer
+         self.max_token_len = max_token_len
+
+     def __call__(self, batch):
+         task_type = [data['task_type'] for data in batch]
+         instruction = [data['instruction'] for data in batch]
+         input_images_path = [data['input_images_path'] for data in batch]
+         input_images = [data['input_images'] for data in batch]
+         output_image = [data['output_image'] for data in batch]
+         output_image_path = [data['output_image_path'] for data in batch]
+
+         text_inputs = self.tokenizer(
+             instruction,
+             padding="longest",
+             max_length=self.max_token_len,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+         data = {
+             "task_type": task_type,
+             "text_ids": text_inputs.input_ids,
+             "text_mask": text_inputs.attention_mask,
+             "input_images": input_images,
+             "input_images_path": input_images_path,
+             "output_image": output_image,
+             "output_image_path": output_image_path,
+         }
+         return data
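For orientation, a minimal sketch of constructing the test dataset from a YAML config and reading one item; it is not part of the commit, and the config path and tokenizer checkpoint are placeholders, not values used by this repository.

# Illustrative sketch only -- not part of this commit.
from transformers import AutoTokenizer

from omnigen2.dataset.omnigen2_test_dataset import OmniGen2TestDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # placeholder checkpoint
dataset = OmniGen2TestDataset(
    config_path="configs/test_data.yml",  # placeholder; a YAML file with a top-level `data:` list
    tokenizer=tokenizer,
    use_chat_template=True,
    max_pixels=1024 * 1024,
)
sample = dataset[0]
print(sample["instruction"], sample["target_img_size"])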
omnigen2/dataset/omnigen2_train_dataset.py ADDED
@@ -0,0 +1,203 @@
+ from typing import Optional, Union, List
+
+ import os
+ import random
+ import yaml
+ import glob
+ from PIL import Image
+
+ import torch
+ from torchvision import transforms
+
+ from datasets import load_dataset, concatenate_datasets
+
+ from ..pipelines.omnigen2.pipeline_omnigen2 import OmniGen2ImageProcessor
+
+ class OmniGen2TrainDataset(torch.utils.data.Dataset):
+     SYSTEM_PROMPT = "You are a helpful assistant that generates high-quality images based on user instructions."
+     SYSTEM_PROMPT_DROP = "You are a helpful assistant that generates images."
+
+     def __init__(
+         self,
+         config_path: str,
+         tokenizer,
+         use_chat_template: bool,
+         max_input_pixels: Optional[Union[int, List[int]]] = None,
+         max_output_pixels: Optional[int] = None,
+         max_side_length: Optional[int] = None,
+         img_scale_num: int = 16,
+         prompt_dropout_prob: float = 0.0,
+         ref_img_dropout_prob: float = 0.0,
+     ):
+         self.max_input_pixels = max_input_pixels
+         self.max_output_pixels = max_output_pixels
+
+         self.max_side_length = max_side_length
+         self.img_scale_num = img_scale_num
+         self.prompt_dropout_prob = prompt_dropout_prob
+         self.ref_img_dropout_prob = ref_img_dropout_prob
+
+         with open(config_path, "r") as f:
+             self.config = yaml.load(f, Loader=yaml.FullLoader)
+
+         self.use_chat_template = use_chat_template
+         self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=img_scale_num, do_resize=True)
+
+         data = self._collect_annotations(self.config)
+
+         self.data = data
+         self.tokenizer = tokenizer
+
+     def _collect_annotations(self, config):
+         total_samples = 0
+         total_ratio = 0
+         json_datasets = []
+         for data in config['data']:
+             data_path, data_type = data['path'], data.get("type", "default")
+             if os.path.isdir(data_path):
+                 jsonl_files = list(glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True)) + list(glob.glob(os.path.join(data_path, "**/*.json"), recursive=True))
+                 json_dataset = load_dataset('json', data_files=jsonl_files, cache_dir=None)['train']
+             else:
+                 data_ext = os.path.splitext(data_path)[-1]
+                 if data_ext in [".json", ".jsonl"]:
+                     json_dataset = load_dataset('json', data_files=data_path, cache_dir=None)['train']
+                 elif data_ext in [".yml", ".yaml"]:
+                     with open(data_path, "r") as f:
+                         sub_config = yaml.load(f, Loader=yaml.FullLoader)
+                     json_dataset = self._collect_annotations(sub_config)
+                 else:
+                     raise NotImplementedError(
+                         f'Unknown data file extension: "{data_ext}". '
+                         f"Currently, .json, .jsonl .yml .yaml are supported. "
+                         "If you are using a supported format, please set the file extension so that the proper parsing "
+                         "routine can be called."
+                     )
+             total_ratio += data['ratio']
+             total_samples += len(json_dataset)
+             json_datasets.append(json_dataset)
+
+         for json_dataset in json_datasets:
+             target_size = int(len(json_dataset) * data['ratio'] / total_ratio) # normalize the ratio
+             if target_size <= len(json_dataset):
+                 # Random selection without replacement
+                 indices = random.sample(range(len(json_dataset)), target_size)
+             else:
+                 # Oversample with replacement
+                 indices = random.choices(range(len(json_dataset)), k=target_size)
+             json_dataset = json_dataset.select(indices)
+
+         json_dataset = concatenate_datasets(json_datasets)
+         return json_dataset
+
+     def clean_data_item(self, data_item):
+         task_type = data_item['task_type']
+         prefixs = ["The image portrays ", "The image depicts ", "The image captures ", "The image highlights ", "The image shows ", "这张图片展示了"]
+         if "text_to_image" in task_type or "t2i" in task_type:
+             if random.random() < 0.5:
+                 for p in prefixs:
+                     if p in data_item['instruction']:
+                         data_item['instruction'] = data_item['instruction'].replace(p, "")
+                         break
+         return data_item
+
+     def apply_chat_template(self, instruction, system_prompt):
+         if self.use_chat_template:
+             prompt = [
+                 {
+                     "role": "system",
+                     "content": system_prompt,
+                 },
+                 {"role": "user", "content": instruction},
+             ]
+             instruction = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
+         return instruction
+
+     def process_item(self, data_item):
+         assert data_item['instruction'] is not None
+         data_item = self.clean_data_item(data_item)
+
+         drop_prompt = random.random() < self.prompt_dropout_prob
+         drop_ref_img = drop_prompt and random.random() < self.ref_img_dropout_prob
+
+         if drop_prompt:
+             instruction = self.apply_chat_template("", self.SYSTEM_PROMPT_DROP)
+         else:
+             instruction = self.apply_chat_template(data_item['instruction'], self.SYSTEM_PROMPT)
+
+         if not drop_ref_img and 'input_images' in data_item and data_item['input_images'] is not None:
+             input_images_path = data_item['input_images']
+             input_images = []
+
+             max_input_pixels = self.max_input_pixels[len(input_images_path) - 1] if isinstance(self.max_input_pixels, list) else self.max_input_pixels
+
+             for input_image_path in input_images_path:
+                 input_image = Image.open(input_image_path).convert("RGB")
+                 input_image = self.image_processor.preprocess(input_image, max_pixels=max_input_pixels, max_side_length=self.max_side_length)
+                 input_images.append(input_image)
+         else:
+             input_images_path, input_images = None, None
+
+         output_image_path = data_item['output_image']
+         output_image = Image.open(output_image_path).convert("RGB")
+         output_image = self.image_processor.preprocess(output_image, max_pixels=self.max_output_pixels, max_side_length=self.max_side_length)
+
+         data = {
+             'task_type': data_item['task_type'],
+             'instruction': instruction,
+             'input_images_path': input_images_path,
+             'input_images': input_images,
+             'output_image': output_image,
+             'output_image_path': output_image_path,
+         }
+         return data
+
+     def __getitem__(self, index):
+         max_retries = 12
+
+         current_index = index
+         for attempt in range(max_retries):
+             try:
+                 data_item = self.data[current_index]
+                 return self.process_item(data_item)
+             except Exception as e:
+                 if attempt == max_retries - 1:
+                     raise e
+                 else:
+                     # Try a different index for the next attempt
+                     current_index = random.randint(0, len(self.data) - 1)
+                     continue
+
+     def __len__(self):
+         return len(self.data)
+
+ class OmniGen2Collator():
+     def __init__(self, tokenizer, max_token_len):
+         self.tokenizer = tokenizer
+         self.max_token_len = max_token_len
+
+     def __call__(self, batch):
+         task_type = [data['task_type'] for data in batch]
+         instruction = [data['instruction'] for data in batch]
+         input_images_path = [data['input_images_path'] for data in batch]
+         input_images = [data['input_images'] for data in batch]
+         output_image = [data['output_image'] for data in batch]
+         output_image_path = [data['output_image_path'] for data in batch]
+
+         text_inputs = self.tokenizer(
+             instruction,
+             padding="longest",
+             max_length=self.max_token_len,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+         data = {
+             "task_type": task_type,
+             "text_ids": text_inputs.input_ids,
+             "text_mask": text_inputs.attention_mask,
+             "input_images": input_images,
+             "input_images_path": input_images_path,
+             "output_image": output_image,
+             "output_image_path": output_image_path,
+         }
+         return data
omnigen2/models/__init__.py ADDED
File without changes
omnigen2/models/attention_processor.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ OmniGen2 Attention Processor Module
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ import warnings
20
+ import math
21
+ from typing import Optional, Tuple, Dict, Any
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from einops import repeat
26
+
27
+ from ..utils.import_utils import is_flash_attn_available
28
+
29
+ if is_flash_attn_available():
30
+ from flash_attn import flash_attn_varlen_func
31
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
32
+ else:
33
+ warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")
34
+
35
+
36
+ from diffusers.models.attention_processor import Attention
37
+ from .embeddings import apply_rotary_emb
38
+
39
+
40
+ class OmniGen2AttnProcessorFlash2Varlen:
41
+ """
42
+ Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
43
+
44
+ This processor implements:
45
+ - Flash attention with variable length sequences
46
+ - Rotary position embeddings (RoPE)
47
+ - Query-Key normalization
48
+ - Proportional attention scaling
49
+
50
+ Args:
51
+ None
52
+ """
53
+
54
+ def __init__(self) -> None:
55
+ """Initialize the attention processor."""
56
+ if not is_flash_attn_available():
57
+ raise ImportError(
58
+ "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
59
+ "Please install flash_attn."
60
+ )
61
+
62
+ def _upad_input(
63
+ self,
64
+ query_layer: torch.Tensor,
65
+ key_layer: torch.Tensor,
66
+ value_layer: torch.Tensor,
67
+ attention_mask: torch.Tensor,
68
+ query_length: int,
69
+ num_heads: int,
70
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
71
+ """
72
+ Unpad the input tensors for flash attention.
73
+
74
+ Args:
75
+ query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
76
+ key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
77
+ value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
78
+ attention_mask: Attention mask tensor of shape (batch_size, seq_len)
79
+ query_length: Length of the query sequence
80
+ num_heads: Number of attention heads
81
+
82
+ Returns:
83
+ Tuple containing:
84
+ - Unpadded query tensor
85
+ - Unpadded key tensor
86
+ - Unpadded value tensor
87
+ - Query indices
88
+ - Tuple of cumulative sequence lengths for query and key
89
+ - Tuple of maximum sequence lengths for query and key
90
+ """
91
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
92
+ """Helper function to get unpadding data from attention mask."""
93
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
94
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
95
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
96
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
97
+ return indices, cu_seqlens, max_seqlen_in_batch
98
+
99
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
100
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
101
+
102
+ # Unpad key and value layers
103
+ key_layer = index_first_axis(
104
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
105
+ indices_k,
106
+ )
107
+ value_layer = index_first_axis(
108
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
109
+ indices_k,
110
+ )
111
+
112
+ # Handle different query length cases
113
+ if query_length == kv_seq_len:
114
+ query_layer = index_first_axis(
115
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
116
+ indices_k,
117
+ )
118
+ cu_seqlens_q = cu_seqlens_k
119
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
120
+ indices_q = indices_k
121
+ elif query_length == 1:
122
+ max_seqlen_in_batch_q = 1
123
+ cu_seqlens_q = torch.arange(
124
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
125
+ )
126
+ indices_q = cu_seqlens_q[:-1]
127
+ query_layer = query_layer.squeeze(1)
128
+ else:
129
+ attention_mask = attention_mask[:, -query_length:]
130
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
131
+
132
+ return (
133
+ query_layer,
134
+ key_layer,
135
+ value_layer,
136
+ indices_q,
137
+ (cu_seqlens_q, cu_seqlens_k),
138
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
139
+ )
140
+
141
+ def __call__(
142
+ self,
143
+ attn: Attention,
144
+ hidden_states: torch.Tensor,
145
+ encoder_hidden_states: torch.Tensor,
146
+ attention_mask: Optional[torch.Tensor] = None,
147
+ image_rotary_emb: Optional[torch.Tensor] = None,
148
+ base_sequence_length: Optional[int] = None,
149
+ ) -> torch.Tensor:
150
+ """
151
+ Process attention computation with flash attention.
152
+
153
+ Args:
154
+ attn: Attention module
155
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
156
+ encoder_hidden_states: Encoder hidden states tensor
157
+ attention_mask: Optional attention mask tensor
158
+ image_rotary_emb: Optional rotary embeddings for image tokens
159
+ base_sequence_length: Optional base sequence length for proportional attention
160
+
161
+ Returns:
162
+ torch.Tensor: Processed hidden states after attention computation
163
+ """
164
+ batch_size, sequence_length, _ = hidden_states.shape
165
+
166
+ # Get Query-Key-Value Pair
167
+ query = attn.to_q(hidden_states)
168
+ key = attn.to_k(encoder_hidden_states)
169
+ value = attn.to_v(encoder_hidden_states)
170
+
171
+ query_dim = query.shape[-1]
172
+ inner_dim = key.shape[-1]
173
+ head_dim = query_dim // attn.heads
174
+ dtype = query.dtype
175
+
176
+ # Get key-value heads
177
+ kv_heads = inner_dim // head_dim
178
+
179
+ # Reshape tensors for attention computation
180
+ query = query.view(batch_size, -1, attn.heads, head_dim)
181
+ key = key.view(batch_size, -1, kv_heads, head_dim)
182
+ value = value.view(batch_size, -1, kv_heads, head_dim)
183
+
184
+ # Apply Query-Key normalization
185
+ if attn.norm_q is not None:
186
+ query = attn.norm_q(query)
187
+ if attn.norm_k is not None:
188
+ key = attn.norm_k(key)
189
+
190
+ # Apply Rotary Position Embeddings
191
+ if image_rotary_emb is not None:
192
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
193
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
194
+
195
+ query, key = query.to(dtype), key.to(dtype)
196
+
197
+ # Calculate attention scale
198
+ if base_sequence_length is not None:
199
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
200
+ else:
201
+ softmax_scale = attn.scale
202
+
203
+ # Unpad input for flash attention
204
+ (
205
+ query_states,
206
+ key_states,
207
+ value_states,
208
+ indices_q,
209
+ cu_seq_lens,
210
+ max_seq_lens,
211
+ ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
212
+
213
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
214
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
215
+
216
+ # Handle different number of heads
217
+ if kv_heads < attn.heads:
218
+ key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
219
+ value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
220
+
221
+ # Apply flash attention
222
+ attn_output_unpad = flash_attn_varlen_func(
223
+ query_states,
224
+ key_states,
225
+ value_states,
226
+ cu_seqlens_q=cu_seqlens_q,
227
+ cu_seqlens_k=cu_seqlens_k,
228
+ max_seqlen_q=max_seqlen_in_batch_q,
229
+ max_seqlen_k=max_seqlen_in_batch_k,
230
+ dropout_p=0.0,
231
+ causal=False,
232
+ softmax_scale=softmax_scale,
233
+ )
234
+
235
+ # Pad output and apply final transformations
236
+ hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
237
+ hidden_states = hidden_states.flatten(-2)
238
+ hidden_states = hidden_states.type_as(query)
239
+
240
+ # Apply output projection
241
+ hidden_states = attn.to_out[0](hidden_states)
242
+ hidden_states = attn.to_out[1](hidden_states)
243
+
244
+ return hidden_states
245
+
246
+
247
+ class OmniGen2AttnProcessor:
248
+ """
249
+ Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
250
+
251
+ This processor is optimized for PyTorch 2.0 and implements:
252
+ - Flash attention with variable length sequences
253
+ - Rotary position embeddings (RoPE)
254
+ - Query-Key normalization
255
+ - Proportional attention scaling
256
+
257
+ Args:
258
+ None
259
+
260
+ Raises:
261
+ ImportError: If PyTorch version is less than 2.0
262
+ """
263
+
264
+ def __init__(self) -> None:
265
+ """Initialize the attention processor."""
266
+ if not hasattr(F, "scaled_dot_product_attention"):
267
+ raise ImportError(
268
+ "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
269
+ "Please upgrade PyTorch to version 2.0 or later."
270
+ )
271
+
272
+ def __call__(
273
+ self,
274
+ attn: Attention,
275
+ hidden_states: torch.Tensor,
276
+ encoder_hidden_states: torch.Tensor,
277
+ attention_mask: Optional[torch.Tensor] = None,
278
+ image_rotary_emb: Optional[torch.Tensor] = None,
279
+ base_sequence_length: Optional[int] = None,
280
+ ) -> torch.Tensor:
281
+ """
282
+ Process attention computation with flash attention.
283
+
284
+ Args:
285
+ attn: Attention module
286
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
287
+ encoder_hidden_states: Encoder hidden states tensor
288
+ attention_mask: Optional attention mask tensor
289
+ image_rotary_emb: Optional rotary embeddings for image tokens
290
+ base_sequence_length: Optional base sequence length for proportional attention
291
+
292
+ Returns:
293
+ torch.Tensor: Processed hidden states after attention computation
294
+ """
295
+ batch_size, sequence_length, _ = hidden_states.shape
296
+
297
+ # Get Query-Key-Value Pair
298
+ query = attn.to_q(hidden_states)
299
+ key = attn.to_k(encoder_hidden_states)
300
+ value = attn.to_v(encoder_hidden_states)
301
+
302
+ query_dim = query.shape[-1]
303
+ inner_dim = key.shape[-1]
304
+ head_dim = query_dim // attn.heads
305
+ dtype = query.dtype
306
+
307
+ # Get key-value heads
308
+ kv_heads = inner_dim // head_dim
309
+
310
+ # Reshape tensors for attention computation
311
+ query = query.view(batch_size, -1, attn.heads, head_dim)
312
+ key = key.view(batch_size, -1, kv_heads, head_dim)
313
+ value = value.view(batch_size, -1, kv_heads, head_dim)
314
+
315
+ # Apply Query-Key normalization
316
+ if attn.norm_q is not None:
317
+ query = attn.norm_q(query)
318
+ if attn.norm_k is not None:
319
+ key = attn.norm_k(key)
320
+
321
+ # Apply Rotary Position Embeddings
322
+ if image_rotary_emb is not None:
323
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
324
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
325
+
326
+ query, key = query.to(dtype), key.to(dtype)
327
+
328
+ # Calculate attention scale
329
+ if base_sequence_length is not None:
330
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
331
+ else:
332
+ softmax_scale = attn.scale
333
+
334
+ # scaled_dot_product_attention expects attention_mask shape to be
335
+ # (batch, heads, source_length, target_length)
336
+ if attention_mask is not None:
337
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
338
+
339
+ query = query.transpose(1, 2)
340
+ key = key.transpose(1, 2)
341
+ value = value.transpose(1, 2)
342
+
343
+ # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6
344
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
345
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
346
+
347
+ hidden_states = F.scaled_dot_product_attention(
348
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
349
+ )
350
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
351
+ hidden_states = hidden_states.type_as(query)
352
+
353
+ # Apply output projection
354
+ hidden_states = attn.to_out[0](hidden_states)
355
+ hidden_states = attn.to_out[1](hidden_states)
356
+
357
+ return hidden_states
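A minimal sketch of how the SDPA-based processor above can be plugged into a diffusers `Attention` layer for a self-attention call; it is not part of the commit, the layer sizes are arbitrary, and it assumes the constructor arguments shown here (which mirror how the OmniGen2 transformer block builds its attention layer).

# Illustrative sketch only -- not part of this commit.
import torch
from diffusers.models.attention_processor import Attention

from omnigen2.models.attention_processor import OmniGen2AttnProcessor

attn = Attention(
    query_dim=256,
    cross_attention_dim=None,
    dim_head=64,
    heads=4,
    kv_heads=2,             # grouped-query attention; keys/values use fewer heads
    qk_norm="rms_norm",
    eps=1e-5,
    bias=False,
    out_bias=False,
    processor=OmniGen2AttnProcessor(),
)

x = torch.randn(1, 32, 256)  # (batch, sequence, channels)
out = attn(hidden_states=x, encoder_hidden_states=x, attention_mask=None, image_rotary_emb=None)
print(out.shape)  # expected: torch.Size([1, 32, 256])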
omnigen2/models/embeddings.py ADDED
@@ -0,0 +1,126 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+
+
+ from diffusers.models.activations import get_activation
+
+
+ class TimestepEmbedding(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         time_embed_dim: int,
+         act_fn: str = "silu",
+         out_dim: int = None,
+         post_act_fn: Optional[str] = None,
+         cond_proj_dim=None,
+         sample_proj_bias=True,
+     ):
+         super().__init__()
+
+         self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+         if cond_proj_dim is not None:
+             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+         else:
+             self.cond_proj = None
+
+         self.act = get_activation(act_fn)
+
+         if out_dim is not None:
+             time_embed_dim_out = out_dim
+         else:
+             time_embed_dim_out = time_embed_dim
+         self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+         if post_act_fn is None:
+             self.post_act = None
+         else:
+             self.post_act = get_activation(post_act_fn)
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         nn.init.normal_(self.linear_1.weight, std=0.02)
+         nn.init.zeros_(self.linear_1.bias)
+         nn.init.normal_(self.linear_2.weight, std=0.02)
+         nn.init.zeros_(self.linear_2.bias)
+
+     def forward(self, sample, condition=None):
+         if condition is not None:
+             sample = sample + self.cond_proj(condition)
+         sample = self.linear_1(sample)
+
+         if self.act is not None:
+             sample = self.act(sample)
+
+         sample = self.linear_2(sample)
+
+         if self.post_act is not None:
+             sample = self.post_act(sample)
+         return sample
+
+
+ def apply_rotary_emb(
+     x: torch.Tensor,
+     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+     use_real: bool = True,
+     use_real_unbind_dim: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+     reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+     tensors contain rotary embeddings and are returned as real tensors.
+
+     Args:
+         x (`torch.Tensor`):
+             Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+         freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+     """
+     if use_real:
+         cos, sin = freqs_cis  # [S, D]
+         cos = cos[None, None]
+         sin = sin[None, None]
+         cos, sin = cos.to(x.device), sin.to(x.device)
+
+         if use_real_unbind_dim == -1:
+             # Used for flux, cogvideox, hunyuan-dit
+             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+         elif use_real_unbind_dim == -2:
+             # Used for Stable Audio, OmniGen and CogView4
+             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+         else:
+             raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+         out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+         return out
+     else:
+         # used for lumina
+         # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
+         freqs_cis = freqs_cis.unsqueeze(2)
+         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+         return x_out.type_as(x)
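A small sketch of the `use_real=False` branch, which is the path the OmniGen2 attention processors call; it is not part of the commit, and the tensor sizes are arbitrary assumptions chosen only to show the expected shapes.

# Illustrative sketch only -- not part of this commit.
import torch

from omnigen2.models.embeddings import apply_rotary_emb

B, S, H, D = 1, 16, 4, 64
x = torch.randn(B, S, H, D)                              # query or key tensor
angles = torch.randn(B, S, D // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex frequencies, broadcast over heads

out = apply_rotary_emb(x, freqs_cis, use_real=False)
print(out.shape)  # torch.Size([1, 16, 4, 64]); the rotation leaves per-pair norms unchanged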
omnigen2/models/transformers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .transformer_omnigen2 import OmniGen2Transformer2DModel
+
+ __all__ = ["OmniGen2Transformer2DModel"]
omnigen2/models/transformers/block_lumina2.py ADDED
@@ -0,0 +1,218 @@
1
+
2
+ # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warnings
17
+ from typing import Optional, Tuple
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from diffusers.models.embeddings import Timesteps
23
+ from ..embeddings import TimestepEmbedding
24
+
25
+ from ...utils.import_utils import is_flash_attn_available, is_triton_available
26
+
27
+ if is_triton_available():
28
+ from ...ops.triton.layer_norm import RMSNorm
29
+ else:
30
+ from torch.nn import RMSNorm
31
+ warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
32
+
33
+ if is_flash_attn_available():
34
+ from flash_attn.ops.activations import swiglu
35
+ else:
36
+ from .components import swiglu
37
+ warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance")
38
+
39
+ # try:
40
+ # from flash_attn.ops.activations import swiglu as fused_swiglu
41
+ # FUSEDSWIGLU_AVALIBLE = True
42
+ # except ImportError:
43
+
44
+ # FUSEDSWIGLU_AVALIBLE = False
45
+ # warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
46
+
47
+ class LuminaRMSNormZero(nn.Module):
48
+ """
49
+ Norm layer adaptive RMS normalization zero.
50
+
51
+ Parameters:
52
+ embedding_dim (`int`): The size of each embedding vector.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ embedding_dim: int,
58
+ norm_eps: float,
59
+ norm_elementwise_affine: bool,
60
+ ):
61
+ super().__init__()
62
+ self.silu = nn.SiLU()
63
+ self.linear = nn.Linear(
64
+ min(embedding_dim, 1024),
65
+ 4 * embedding_dim,
66
+ bias=True,
67
+ )
68
+
69
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
70
+
71
+ def forward(
72
+ self,
73
+ x: torch.Tensor,
74
+ emb: Optional[torch.Tensor] = None,
75
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
76
+ emb = self.linear(self.silu(emb))
77
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
78
+ x = self.norm(x) * (1 + scale_msa[:, None])
79
+ return x, gate_msa, scale_mlp, gate_mlp
80
+
81
+
82
+ class LuminaLayerNormContinuous(nn.Module):
83
+ def __init__(
84
+ self,
85
+ embedding_dim: int,
86
+ conditioning_embedding_dim: int,
87
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
88
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
89
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
90
+ # However, this is how it was implemented in the original code, and it's rather likely you should
91
+ # set `elementwise_affine` to False.
92
+ elementwise_affine=True,
93
+ eps=1e-5,
94
+ bias=True,
95
+ norm_type="layer_norm",
96
+ out_dim: Optional[int] = None,
97
+ ):
98
+ super().__init__()
99
+
100
+ # AdaLN
101
+ self.silu = nn.SiLU()
102
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
103
+
104
+ if norm_type == "layer_norm":
105
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
106
+ elif norm_type == "rms_norm":
107
+ self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
108
+ else:
109
+ raise ValueError(f"unknown norm_type {norm_type}")
110
+
111
+ self.linear_2 = None
112
+ if out_dim is not None:
113
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
114
+
115
+ def forward(
116
+ self,
117
+ x: torch.Tensor,
118
+ conditioning_embedding: torch.Tensor,
119
+ ) -> torch.Tensor:
120
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
121
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
122
+ scale = emb
123
+ x = self.norm(x) * (1 + scale)[:, None, :]
124
+
125
+ if self.linear_2 is not None:
126
+ x = self.linear_2(x)
127
+
128
+ return x
129
+
130
+
131
+ class LuminaFeedForward(nn.Module):
132
+ r"""
133
+ A feed-forward layer.
134
+
135
+ Parameters:
136
+ hidden_size (`int`):
137
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
138
+ hidden representations.
139
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
140
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
141
+ of this value.
142
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
143
+ dimension. Defaults to None.
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ dim: int,
149
+ inner_dim: int,
150
+ multiple_of: Optional[int] = 256,
151
+ ffn_dim_multiplier: Optional[float] = None,
152
+ ):
153
+ super().__init__()
154
+ self.swiglu = swiglu
155
+
156
+ # custom hidden_size factor multiplier
157
+ if ffn_dim_multiplier is not None:
158
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
159
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
160
+
161
+ self.linear_1 = nn.Linear(
162
+ dim,
163
+ inner_dim,
164
+ bias=False,
165
+ )
166
+ self.linear_2 = nn.Linear(
167
+ inner_dim,
168
+ dim,
169
+ bias=False,
170
+ )
171
+ self.linear_3 = nn.Linear(
172
+ dim,
173
+ inner_dim,
174
+ bias=False,
175
+ )
176
+
177
+ def forward(self, x):
178
+ h1, h2 = self.linear_1(x), self.linear_3(x)
179
+ return self.linear_2(self.swiglu(h1, h2))
180
+
181
+
182
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
183
+ def __init__(
184
+ self,
185
+ hidden_size: int = 4096,
186
+ text_feat_dim: int = 2048,
187
+ frequency_embedding_size: int = 256,
188
+ norm_eps: float = 1e-5,
189
+ timestep_scale: float = 1.0,
190
+ ) -> None:
191
+ super().__init__()
192
+
193
+ self.time_proj = Timesteps(
194
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
195
+ )
196
+
197
+ self.timestep_embedder = TimestepEmbedding(
198
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
199
+ )
200
+
201
+ self.caption_embedder = nn.Sequential(
202
+ RMSNorm(text_feat_dim, eps=norm_eps),
203
+ nn.Linear(text_feat_dim, hidden_size, bias=True),
204
+ )
205
+
206
+ self._initialize_weights()
207
+
208
+ def _initialize_weights(self):
209
+ nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
210
+ nn.init.zeros_(self.caption_embedder[1].bias)
211
+
212
+ def forward(
213
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
214
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
215
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
216
+ time_embed = self.timestep_embedder(timestep_proj)
217
+ caption_embed = self.caption_embedder(text_hidden_states)
218
+ return time_embed, caption_embed
omnigen2/models/transformers/components.py ADDED
@@ -0,0 +1,4 @@
+ import torch.nn.functional as F
+
+ def swiglu(x, y):
+     return F.silu(x.float(), inplace=False).to(x.dtype) * y
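This fallback is the gating used when flash_attn's fused SwiGLU kernel is unavailable; a quick sketch, not part of the commit, checking that it matches the plain silu(x) * y formulation.

# Illustrative sketch only -- not part of this commit.
import torch
import torch.nn.functional as F

from omnigen2.models.transformers.components import swiglu

x, y = torch.randn(2, 8), torch.randn(2, 8)
# The fallback upcasts to float for the SiLU, then casts back before gating with y.
assert torch.allclose(swiglu(x, y), F.silu(x) * y, atol=1e-6)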
omnigen2/models/transformers/repo.py ADDED
@@ -0,0 +1,129 @@
+ from typing import List, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from einops import repeat
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
+
+ class OmniGen2RotaryPosEmbed(nn.Module):
+     def __init__(self, theta: int,
+                  axes_dim: Tuple[int, int, int],
+                  axes_lens: Tuple[int, int, int] = (300, 512, 512),
+                  patch_size: int = 2):
+         super().__init__()
+         self.theta = theta
+         self.axes_dim = axes_dim
+         self.axes_lens = axes_lens
+         self.patch_size = patch_size
+
+     @staticmethod
+     def get_freqs_cis(axes_dim: Tuple[int, int, int],
+                       axes_lens: Tuple[int, int, int],
+                       theta: int) -> List[torch.Tensor]:
+         freqs_cis = []
+         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+             emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
+             freqs_cis.append(emb)
+         return freqs_cis
+
+     def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
+         device = ids.device
+         if ids.device.type == "mps":
+             ids = ids.to("cpu")
+
+         result = []
+         for i in range(len(self.axes_dim)):
+             freqs = freqs_cis[i].to(ids.device)
+             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+         return torch.cat(result, dim=-1).to(device)
+
+     def forward(
+         self,
+         freqs_cis,
+         attention_mask,
+         l_effective_ref_img_len,
+         l_effective_img_len,
+         ref_img_sizes,
+         img_sizes,
+         device
+     ):
+         batch_size = len(attention_mask)
+         p = self.patch_size
+
+         encoder_seq_len = attention_mask.shape[1]
+         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+
+         seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
+
+         max_seq_len = max(seq_lengths)
+         max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
+         max_img_len = max(l_effective_img_len)
+
+         # Create position IDs
+         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
+
+         for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+             # add text position ids
+             position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
+
+             pe_shift = cap_seq_len
+             pe_shift_len = cap_seq_len
+
+             if ref_img_sizes[i] is not None:
+                 for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
+                     H, W = ref_img_size
+                     ref_H_tokens, ref_W_tokens = H // p, W // p
+                     assert ref_H_tokens * ref_W_tokens == ref_img_len
+                     # add image position ids
+
+                     row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
+                     col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
+                     position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
+                     position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
+                     position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
+
+                     pe_shift += max(ref_H_tokens, ref_W_tokens)
+                     pe_shift_len += ref_img_len
+
+             H, W = img_sizes[i]
+             H_tokens, W_tokens = H // p, W // p
+             assert H_tokens * W_tokens == l_effective_img_len[i]
+
+             row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
+             col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
+
+             assert pe_shift_len + l_effective_img_len[i] == seq_len
+             position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
+             position_ids[i, pe_shift_len: seq_len, 1] = row_ids
+             position_ids[i, pe_shift_len: seq_len, 2] = col_ids
+
+         # Get combined rotary embeddings
+         freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
+
+         # create separate rotary embeddings for captions and images
+         cap_freqs_cis = torch.zeros(
+             batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+         )
+         ref_img_freqs_cis = torch.zeros(
+             batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+         )
+         img_freqs_cis = torch.zeros(
+             batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+         )
+
+         for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
+             cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+             ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
+             img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
+
+         return (
+             cap_freqs_cis,
+             ref_img_freqs_cis,
+             img_freqs_cis,
+             freqs_cis,
+             l_effective_cap_len,
+             seq_lengths,
+         )
omnigen2/models/transformers/transformer_omnigen2.py ADDED
@@ -0,0 +1,716 @@
1
+ import warnings
2
+ import itertools
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from einops import rearrange
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import PeftAdapterMixin
14
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
15
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
16
+ from diffusers.models.attention_processor import Attention
17
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+
20
+ from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor
21
+ from .repo import OmniGen2RotaryPosEmbed
22
+ from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
23
+
24
+ from ...utils.import_utils import is_triton_available, is_flash_attn_available
25
+ from ...utils.teacache_util import TeaCacheParams
26
+
27
+ if is_triton_available():
28
+ from ...ops.triton.layer_norm import RMSNorm
29
+ else:
30
+ from torch.nn import RMSNorm
31
+
32
+ from ...taylorseer_utils import derivative_approximation, taylor_formula, taylor_cache_init
33
+ from ...cache_functions import cache_init, cal_type
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ class OmniGen2TransformerBlock(nn.Module):
38
+ """
39
+ Transformer block for OmniGen2 model.
40
+
41
+ This block implements a transformer layer with:
42
+ - Multi-head attention with flash attention
43
+ - Feed-forward network with SwiGLU activation
44
+ - RMS normalization
45
+ - Optional modulation for conditional generation
46
+
47
+ Args:
48
+ dim: Dimension of the input and output tensors
49
+ num_attention_heads: Number of attention heads
50
+ num_kv_heads: Number of key-value heads
51
+ multiple_of: Multiple of which the hidden dimension should be
52
+ ffn_dim_multiplier: Multiplier for the feed-forward network dimension
53
+ norm_eps: Epsilon value for normalization layers
54
+ modulation: Whether to use modulation for conditional generation
55
+ use_fused_rms_norm: Whether to use fused RMS normalization
56
+ use_fused_swiglu: Whether to use fused SwiGLU activation
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ dim: int,
62
+ num_attention_heads: int,
63
+ num_kv_heads: int,
64
+ multiple_of: int,
65
+ ffn_dim_multiplier: float,
66
+ norm_eps: float,
67
+ modulation: bool = True,
68
+ ) -> None:
69
+ """Initialize the transformer block."""
70
+ super().__init__()
71
+ self.head_dim = dim // num_attention_heads
72
+ self.modulation = modulation
73
+
74
+ try:
75
+ processor = OmniGen2AttnProcessorFlash2Varlen()
76
+ except ImportError:
77
+ processor = OmniGen2AttnProcessor()
78
+
79
+ # Initialize attention layer
80
+ self.attn = Attention(
81
+ query_dim=dim,
82
+ cross_attention_dim=None,
83
+ dim_head=dim // num_attention_heads,
84
+ qk_norm="rms_norm",
85
+ heads=num_attention_heads,
86
+ kv_heads=num_kv_heads,
87
+ eps=1e-5,
88
+ bias=False,
89
+ out_bias=False,
90
+ processor=processor,
91
+ )
92
+
93
+ # Initialize feed-forward network
94
+ self.feed_forward = LuminaFeedForward(
95
+ dim=dim,
96
+ inner_dim=4 * dim,
97
+ multiple_of=multiple_of,
98
+ ffn_dim_multiplier=ffn_dim_multiplier
99
+ )
100
+
101
+ # Initialize normalization layers
102
+ if modulation:
103
+ self.norm1 = LuminaRMSNormZero(
104
+ embedding_dim=dim,
105
+ norm_eps=norm_eps,
106
+ norm_elementwise_affine=True
107
+ )
108
+ else:
109
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
110
+
111
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
112
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
113
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
114
+
115
+ self.initialize_weights()
116
+
117
+ def initialize_weights(self) -> None:
118
+ """
119
+ Initialize the weights of the transformer block.
120
+
121
+ Uses Xavier uniform initialization for linear layers and zero initialization for biases.
122
+ """
123
+ nn.init.xavier_uniform_(self.attn.to_q.weight)
124
+ nn.init.xavier_uniform_(self.attn.to_k.weight)
125
+ nn.init.xavier_uniform_(self.attn.to_v.weight)
126
+ nn.init.xavier_uniform_(self.attn.to_out[0].weight)
127
+
128
+ nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
129
+ nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
130
+ nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
131
+
132
+ if self.modulation:
133
+ nn.init.zeros_(self.norm1.linear.weight)
134
+ nn.init.zeros_(self.norm1.linear.bias)
135
+
136
+ def forward(
137
+ self,
138
+ hidden_states: torch.Tensor,
139
+ attention_mask: torch.Tensor,
140
+ image_rotary_emb: torch.Tensor,
141
+ temb: Optional[torch.Tensor] = None,
142
+ ) -> torch.Tensor:
143
+ """
144
+ Forward pass of the transformer block.
145
+
146
+ Args:
147
+ hidden_states: Input hidden states tensor
148
+ attention_mask: Attention mask tensor
149
+ image_rotary_emb: Rotary embeddings for image tokens
150
+ temb: Optional timestep embedding tensor
151
+
152
+ Returns:
153
+ torch.Tensor: Output hidden states after transformer block processing
154
+ """
155
+ enable_taylorseer = getattr(self, 'enable_taylorseer', False)
156
+ if enable_taylorseer:
157
+ if self.modulation:
158
+ if temb is None:
159
+ raise ValueError("temb must be provided when modulation is enabled")
160
+
161
+ if self.current['type'] == 'full':
162
+ self.current['module'] = 'total'
163
+ taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
164
+
165
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
166
+ attn_output = self.attn(
167
+ hidden_states=norm_hidden_states,
168
+ encoder_hidden_states=norm_hidden_states,
169
+ attention_mask=attention_mask,
170
+ image_rotary_emb=image_rotary_emb,
171
+ )
172
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
173
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
174
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
175
+
176
+ derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)
177
+
178
+ elif self.current['type'] == 'Taylor':
179
+ self.current['module'] = 'total'
180
+ hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
181
+ else:
182
+ norm_hidden_states = self.norm1(hidden_states)
183
+ attn_output = self.attn(
184
+ hidden_states=norm_hidden_states,
185
+ encoder_hidden_states=norm_hidden_states,
186
+ attention_mask=attention_mask,
187
+ image_rotary_emb=image_rotary_emb,
188
+ )
189
+ hidden_states = hidden_states + self.norm2(attn_output)
190
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
191
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
192
+ else:
193
+ if self.modulation:
194
+ if temb is None:
195
+ raise ValueError("temb must be provided when modulation is enabled")
196
+
197
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
198
+ attn_output = self.attn(
199
+ hidden_states=norm_hidden_states,
200
+ encoder_hidden_states=norm_hidden_states,
201
+ attention_mask=attention_mask,
202
+ image_rotary_emb=image_rotary_emb,
203
+ )
204
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
205
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
206
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
207
+ else:
208
+ norm_hidden_states = self.norm1(hidden_states)
209
+ attn_output = self.attn(
210
+ hidden_states=norm_hidden_states,
211
+ encoder_hidden_states=norm_hidden_states,
212
+ attention_mask=attention_mask,
213
+ image_rotary_emb=image_rotary_emb,
214
+ )
215
+ hidden_states = hidden_states + self.norm2(attn_output)
216
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
217
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
218
+
219
+ return hidden_states
220
+
221
+
222
+ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
223
+ """
224
+ OmniGen2 Transformer 2D Model.
225
+
226
+ A transformer-based diffusion model for image generation with:
227
+ - Patch-based image processing
228
+ - Rotary position embeddings
229
+ - Multi-head attention
230
+ - Conditional generation support
231
+
232
+ Args:
233
+ patch_size: Size of image patches
234
+ in_channels: Number of input channels
235
+ out_channels: Number of output channels (defaults to in_channels)
236
+ hidden_size: Size of hidden layers
237
+ num_layers: Number of transformer layers
238
+ num_refiner_layers: Number of refiner layers
239
+ num_attention_heads: Number of attention heads
240
+ num_kv_heads: Number of key-value heads
241
+ multiple_of: Value the hidden feed-forward dimension is rounded up to a multiple of
242
+ ffn_dim_multiplier: Multiplier for feed-forward network dimension
243
+ norm_eps: Epsilon value for normalization layers
244
+ axes_dim_rope: Dimensions for rotary position embeddings
245
+ axes_lens: Lengths for rotary position embeddings
246
+ text_feat_dim: Dimension of text features
247
+ timestep_scale: Scale factor for timestep embeddings
248
+ use_fused_rms_norm: Whether to use fused RMS normalization
249
+ use_fused_swiglu: Whether to use fused SwiGLU activation
250
+ """
251
+
252
+ _supports_gradient_checkpointing = True
253
+ _no_split_modules = ["OmniGen2TransformerBlock"]
254
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
255
+
256
+ @register_to_config
257
+ def __init__(
258
+ self,
259
+ patch_size: int = 2,
260
+ in_channels: int = 16,
261
+ out_channels: Optional[int] = None,
262
+ hidden_size: int = 2304,
263
+ num_layers: int = 26,
264
+ num_refiner_layers: int = 2,
265
+ num_attention_heads: int = 24,
266
+ num_kv_heads: int = 8,
267
+ multiple_of: int = 256,
268
+ ffn_dim_multiplier: Optional[float] = None,
269
+ norm_eps: float = 1e-5,
270
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
271
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
272
+ text_feat_dim: int = 1024,
273
+ timestep_scale: float = 1.0
274
+ ) -> None:
275
+ """Initialize the OmniGen2 transformer model."""
276
+ super().__init__()
277
+
278
+ # Validate configuration
279
+ if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
280
+ raise ValueError(
281
+ f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
282
+ f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
283
+ )
284
+
285
+ self.out_channels = out_channels or in_channels
286
+
287
+ # Initialize embeddings
288
+ self.rope_embedder = OmniGen2RotaryPosEmbed(
289
+ theta=10000,
290
+ axes_dim=axes_dim_rope,
291
+ axes_lens=axes_lens,
292
+ patch_size=patch_size,
293
+ )
294
+
295
+ self.x_embedder = nn.Linear(
296
+ in_features=patch_size * patch_size * in_channels,
297
+ out_features=hidden_size,
298
+ )
299
+
300
+ self.ref_image_patch_embedder = nn.Linear(
301
+ in_features=patch_size * patch_size * in_channels,
302
+ out_features=hidden_size,
303
+ )
304
+
305
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
306
+ hidden_size=hidden_size,
307
+ text_feat_dim=text_feat_dim,
308
+ norm_eps=norm_eps,
309
+ timestep_scale=timestep_scale
310
+ )
311
+
312
+ # Initialize transformer blocks
313
+ self.noise_refiner = nn.ModuleList([
314
+ OmniGen2TransformerBlock(
315
+ hidden_size,
316
+ num_attention_heads,
317
+ num_kv_heads,
318
+ multiple_of,
319
+ ffn_dim_multiplier,
320
+ norm_eps,
321
+ modulation=True
322
+ )
323
+ for _ in range(num_refiner_layers)
324
+ ])
325
+
326
+ self.ref_image_refiner = nn.ModuleList([
327
+ OmniGen2TransformerBlock(
328
+ hidden_size,
329
+ num_attention_heads,
330
+ num_kv_heads,
331
+ multiple_of,
332
+ ffn_dim_multiplier,
333
+ norm_eps,
334
+ modulation=True
335
+ )
336
+ for _ in range(num_refiner_layers)
337
+ ])
338
+
339
+ self.context_refiner = nn.ModuleList(
340
+ [
341
+ OmniGen2TransformerBlock(
342
+ hidden_size,
343
+ num_attention_heads,
344
+ num_kv_heads,
345
+ multiple_of,
346
+ ffn_dim_multiplier,
347
+ norm_eps,
348
+ modulation=False
349
+ )
350
+ for _ in range(num_refiner_layers)
351
+ ]
352
+ )
353
+
354
+ # 3. Transformer blocks
355
+ self.layers = nn.ModuleList(
356
+ [
357
+ OmniGen2TransformerBlock(
358
+ hidden_size,
359
+ num_attention_heads,
360
+ num_kv_heads,
361
+ multiple_of,
362
+ ffn_dim_multiplier,
363
+ norm_eps,
364
+ modulation=True
365
+ )
366
+ for _ in range(num_layers)
367
+ ]
368
+ )
369
+
370
+ # 4. Output norm & projection
371
+ self.norm_out = LuminaLayerNormContinuous(
372
+ embedding_dim=hidden_size,
373
+ conditioning_embedding_dim=min(hidden_size, 1024),
374
+ elementwise_affine=False,
375
+ eps=1e-6,
376
+ bias=True,
377
+ out_dim=patch_size * patch_size * self.out_channels
378
+ )
379
+
380
+ # Add learnable embeddings to distinguish different images
381
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
382
+
383
+ self.gradient_checkpointing = False
384
+
385
+ self.initialize_weights()
386
+
387
+ # TeaCache settings
388
+ self.enable_teacache = False
389
+ self.teacache_rel_l1_thresh = 0.05
390
+ self.teacache_params = TeaCacheParams()
391
+
392
+ coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
393
+ self.rescale_func = np.poly1d(coefficients)
394
+
395
+ def initialize_weights(self) -> None:
396
+ """
397
+ Initialize the weights of the model.
398
+
399
+ Uses Xavier uniform initialization for linear layers.
400
+ """
401
+ nn.init.xavier_uniform_(self.x_embedder.weight)
402
+ nn.init.constant_(self.x_embedder.bias, 0.0)
403
+
404
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
405
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
406
+
407
+ nn.init.zeros_(self.norm_out.linear_1.weight)
408
+ nn.init.zeros_(self.norm_out.linear_1.bias)
409
+ nn.init.zeros_(self.norm_out.linear_2.weight)
410
+ nn.init.zeros_(self.norm_out.linear_2.bias)
411
+
412
+ nn.init.normal_(self.image_index_embedding, std=0.02)
413
+
414
+ def img_patch_embed_and_refine(
415
+ self,
416
+ hidden_states,
417
+ ref_image_hidden_states,
418
+ padded_img_mask,
419
+ padded_ref_img_mask,
420
+ noise_rotary_emb,
421
+ ref_img_rotary_emb,
422
+ l_effective_ref_img_len,
423
+ l_effective_img_len,
424
+ temb
425
+ ):
426
+ batch_size = len(hidden_states)
427
+ max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
428
+
429
+ hidden_states = self.x_embedder(hidden_states)
430
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
431
+
432
+ for i in range(batch_size):
433
+ shift = 0
434
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
435
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
436
+ shift += ref_img_len
437
+
438
+ for layer in self.noise_refiner:
439
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
440
+
441
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
442
+ num_ref_images = len(flat_l_effective_ref_img_len)
443
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
444
+
445
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
446
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
447
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
448
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
449
+
450
+ # sequence of ref imgs to batch
451
+ idx = 0
452
+ for i in range(batch_size):
453
+ shift = 0
454
+ for ref_img_len in l_effective_ref_img_len[i]:
455
+ batch_ref_img_mask[idx, :ref_img_len] = True
456
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
457
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
458
+ batch_temb[idx] = temb[i]
459
+ shift += ref_img_len
460
+ idx += 1
461
+
462
+ # refine ref imgs separately
463
+ for layer in self.ref_image_refiner:
464
+ batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
465
+
466
+ # batch of ref imgs to sequence
467
+ idx = 0
468
+ for i in range(batch_size):
469
+ shift = 0
470
+ for ref_img_len in l_effective_ref_img_len[i]:
471
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
472
+ shift += ref_img_len
473
+ idx += 1
474
+
475
+ combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
476
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
477
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
478
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
479
+
480
+ return combined_img_hidden_states
481
+
482
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
483
+ batch_size = len(hidden_states)
484
+ p = self.config.patch_size
485
+ device = hidden_states[0].device
486
+
487
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
488
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
489
+
490
+ if ref_image_hidden_states is not None:
491
+ ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
492
+ l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
493
+ else:
494
+ ref_img_sizes = [None for _ in range(batch_size)]
495
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
496
+
497
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
498
+ max_img_len = max(l_effective_img_len)
499
+
500
+ # ref image patch embeddings
501
+ flat_ref_img_hidden_states = []
502
+ for i in range(batch_size):
503
+ if ref_img_sizes[i] is not None:
504
+ imgs = []
505
+ for ref_img in ref_image_hidden_states[i]:
506
+ C, H, W = ref_img.size()
507
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
508
+ imgs.append(ref_img)
509
+
510
+ img = torch.cat(imgs, dim=0)
511
+ flat_ref_img_hidden_states.append(img)
512
+ else:
513
+ flat_ref_img_hidden_states.append(None)
514
+
515
+ # image patch embeddings
516
+ flat_hidden_states = []
517
+ for i in range(batch_size):
518
+ img = hidden_states[i]
519
+ C, H, W = img.size()
520
+
521
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
522
+ flat_hidden_states.append(img)
523
+
524
+ padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
525
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
526
+ for i in range(batch_size):
527
+ if ref_img_sizes[i] is not None:
528
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
529
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
530
+
531
+ padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
532
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
533
+ for i in range(batch_size):
534
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
535
+ padded_img_mask[i, :l_effective_img_len[i]] = True
536
+
537
+ return (
538
+ padded_hidden_states,
539
+ padded_ref_img_hidden_states,
540
+ padded_img_mask,
541
+ padded_ref_img_mask,
542
+ l_effective_ref_img_len,
543
+ l_effective_img_len,
544
+ ref_img_sizes,
545
+ img_sizes,
546
+ )
547
+
548
+ def forward(
549
+ self,
550
+ hidden_states: Union[torch.Tensor, List[torch.Tensor]],
551
+ timestep: torch.Tensor,
552
+ text_hidden_states: torch.Tensor,
553
+ freqs_cis: torch.Tensor,
554
+ text_attention_mask: torch.Tensor,
555
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
556
+ attention_kwargs: Optional[Dict[str, Any]] = None,
557
+ return_dict: bool = False,
558
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
559
+ enable_taylorseer = getattr(self, 'enable_taylorseer', False)
560
+ if enable_taylorseer:
561
+ cal_type(self.cache_dic, self.current)
562
+
563
+ if attention_kwargs is not None:
564
+ attention_kwargs = attention_kwargs.copy()
565
+ lora_scale = attention_kwargs.pop("scale", 1.0)
566
+ else:
567
+ lora_scale = 1.0
568
+
569
+ if USE_PEFT_BACKEND:
570
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
571
+ scale_lora_layers(self, lora_scale)
572
+ else:
573
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
574
+ logger.warning(
575
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
576
+ )
577
+
578
+ # 1. Condition, positional & patch embedding
579
+ batch_size = len(hidden_states)
580
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
581
+
582
+ if is_hidden_states_tensor:
583
+ assert hidden_states.ndim == 4
584
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
585
+
586
+ device = hidden_states[0].device
587
+
588
+ temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
589
+
590
+ (
591
+ hidden_states,
592
+ ref_image_hidden_states,
593
+ img_mask,
594
+ ref_img_mask,
595
+ l_effective_ref_img_len,
596
+ l_effective_img_len,
597
+ ref_img_sizes,
598
+ img_sizes,
599
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
600
+
601
+ (
602
+ context_rotary_emb,
603
+ ref_img_rotary_emb,
604
+ noise_rotary_emb,
605
+ rotary_emb,
606
+ encoder_seq_lengths,
607
+ seq_lengths,
608
+ ) = self.rope_embedder(
609
+ freqs_cis,
610
+ text_attention_mask,
611
+ l_effective_ref_img_len,
612
+ l_effective_img_len,
613
+ ref_img_sizes,
614
+ img_sizes,
615
+ device,
616
+ )
617
+
618
+ # 2. Context refinement
619
+ for layer in self.context_refiner:
620
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
621
+
622
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
623
+ hidden_states,
624
+ ref_image_hidden_states,
625
+ img_mask,
626
+ ref_img_mask,
627
+ noise_rotary_emb,
628
+ ref_img_rotary_emb,
629
+ l_effective_ref_img_len,
630
+ l_effective_img_len,
631
+ temb,
632
+ )
633
+
634
+ # 3. Joint Transformer blocks
635
+ max_seq_len = max(seq_lengths)
636
+
637
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
638
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
639
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
640
+ attention_mask[i, :seq_len] = True
641
+ joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
642
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
643
+
644
+ hidden_states = joint_hidden_states
645
+
646
+ if self.enable_teacache:
647
+ teacache_hidden_states = hidden_states.clone()
648
+ teacache_temb = temb.clone()
649
+ modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
650
+ if self.teacache_params.is_first_or_last_step:
651
+ should_calc = True
652
+ self.teacache_params.accumulated_rel_l1_distance = 0
653
+ else:
654
+ self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
655
+ ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean() \
656
+ / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
657
+ )
658
+ if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
659
+ should_calc = False
660
+ else:
661
+ should_calc = True
662
+ self.teacache_params.accumulated_rel_l1_distance = 0
663
+ self.teacache_params.previous_modulated_inp = modulated_inp
664
+
665
+ if self.enable_teacache:
666
+ if not should_calc:
667
+ hidden_states += self.teacache_params.previous_residual
668
+ else:
669
+ ori_hidden_states = hidden_states.clone()
670
+ for layer_idx, layer in enumerate(self.layers):
671
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
672
+ hidden_states = self._gradient_checkpointing_func(
673
+ layer, hidden_states, attention_mask, rotary_emb, temb
674
+ )
675
+ else:
676
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
677
+ self.teacache_params.previous_residual = hidden_states - ori_hidden_states
678
+ else:
679
+ if enable_taylorseer:
680
+ self.current['stream'] = 'layers_stream'
681
+
682
+ for layer_idx, layer in enumerate(self.layers):
683
+ if enable_taylorseer:
684
+ layer.current = self.current
685
+ layer.cache_dic = self.cache_dic
686
+ layer.enable_taylorseer = True
687
+ self.current['layer'] = layer_idx
688
+
689
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
690
+ hidden_states = self._gradient_checkpointing_func(
691
+ layer, hidden_states, attention_mask, rotary_emb, temb
692
+ )
693
+ else:
694
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
695
+
696
+ # 4. Output norm & projection
697
+ hidden_states = self.norm_out(hidden_states, temb)
698
+
699
+ p = self.config.patch_size
700
+ output = []
701
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
702
+ height, width = img_size
703
+ output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
704
+ if is_hidden_states_tensor:
705
+ output = torch.stack(output, dim=0)
706
+
707
+ if USE_PEFT_BACKEND:
708
+ # remove `lora_scale` from each PEFT layer
709
+ unscale_lora_layers(self, lora_scale)
710
+
711
+ if enable_taylorseer:
712
+ self.current['step'] += 1
713
+
714
+ if not return_dict:
715
+ return output
716
+ return Transformer2DModelOutput(sample=output)
omnigen2/ops/triton/__init__.py ADDED
File without changes
omnigen2/ops/triton/layer_norm.py ADDED
@@ -0,0 +1,1257 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ import triton
15
+ import triton.language as tl
16
+
17
+
18
+ from typing import Callable
19
+
20
+
21
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
22
+ def decorator(*args, **kwargs):
23
+ if cuda_amp_deprecated:
24
+ kwargs["device_type"] = "cuda"
25
+ return dec(*args, **kwargs)
26
+ return decorator
27
+
28
+
29
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
30
+ deprecated = True
31
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
32
+ else:
33
+ deprecated = False
34
+ from torch.cuda.amp import custom_fwd, custom_bwd
35
+
36
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
37
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
38
+
39
+
40
+ def triton_autotune_configs():
41
+ # Return configs with a valid warp count for the current device
42
+ configs=[]
43
+ # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
44
+ max_threads_per_block=1024
45
+ # Default to warp size 32 if not defined by device
46
+ warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
47
+ # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
48
+ warp_count=1
49
+ while warp_count*warp_size <= max_threads_per_block:
50
+ configs.append(triton.Config({}, num_warps=warp_count))
51
+ warp_count*=2
52
+ return configs
53
+
54
+ def layer_norm_ref(
55
+ x,
56
+ weight,
57
+ bias,
58
+ residual=None,
59
+ x1=None,
60
+ weight1=None,
61
+ bias1=None,
62
+ eps=1e-6,
63
+ dropout_p=0.0,
64
+ rowscale=None,
65
+ prenorm=False,
66
+ zero_centered_weight=False,
67
+ dropout_mask=None,
68
+ dropout_mask1=None,
69
+ upcast=False,
70
+ ):
71
+ dtype = x.dtype
72
+ if upcast:
73
+ x = x.float()
74
+ weight = weight.float()
75
+ bias = bias.float() if bias is not None else None
76
+ residual = residual.float() if residual is not None else residual
77
+ x1 = x1.float() if x1 is not None else None
78
+ weight1 = weight1.float() if weight1 is not None else None
79
+ bias1 = bias1.float() if bias1 is not None else None
80
+ if zero_centered_weight:
81
+ weight = weight + 1.0
82
+ if weight1 is not None:
83
+ weight1 = weight1 + 1.0
84
+ if x1 is not None:
85
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
86
+ if rowscale is not None:
87
+ x = x * rowscale[..., None]
88
+ if dropout_p > 0.0:
89
+ if dropout_mask is not None:
90
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
91
+ else:
92
+ x = F.dropout(x, p=dropout_p)
93
+ if x1 is not None:
94
+ if dropout_mask1 is not None:
95
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
96
+ else:
97
+ x1 = F.dropout(x1, p=dropout_p)
98
+ if x1 is not None:
99
+ x = x + x1
100
+ if residual is not None:
101
+ x = (x + residual).to(x.dtype)
102
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
103
+ dtype
104
+ )
105
+ if weight1 is None:
106
+ return out if not prenorm else (out, x)
107
+ else:
108
+ out1 = F.layer_norm(
109
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
110
+ ).to(dtype)
111
+ return (out, out1) if not prenorm else (out, out1, x)
112
+
113
+
114
+ def rms_norm_ref(
115
+ x,
116
+ weight,
117
+ bias,
118
+ residual=None,
119
+ x1=None,
120
+ weight1=None,
121
+ bias1=None,
122
+ eps=1e-6,
123
+ dropout_p=0.0,
124
+ rowscale=None,
125
+ prenorm=False,
126
+ zero_centered_weight=False,
127
+ dropout_mask=None,
128
+ dropout_mask1=None,
129
+ upcast=False,
130
+ ):
131
+ dtype = x.dtype
132
+ if upcast:
133
+ x = x.float()
134
+ weight = weight.float()
135
+ bias = bias.float() if bias is not None else None
136
+ residual = residual.float() if residual is not None else residual
137
+ x1 = x1.float() if x1 is not None else None
138
+ weight1 = weight1.float() if weight1 is not None else None
139
+ bias1 = bias1.float() if bias1 is not None else None
140
+ if zero_centered_weight:
141
+ weight = weight + 1.0
142
+ if weight1 is not None:
143
+ weight1 = weight1 + 1.0
144
+ if x1 is not None:
145
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
146
+ if rowscale is not None:
147
+ x = x * rowscale[..., None]
148
+ if dropout_p > 0.0:
149
+ if dropout_mask is not None:
150
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
151
+ else:
152
+ x = F.dropout(x, p=dropout_p)
153
+ if x1 is not None:
154
+ if dropout_mask1 is not None:
155
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
156
+ else:
157
+ x1 = F.dropout(x1, p=dropout_p)
158
+ if x1 is not None:
159
+ x = x + x1
160
+ if residual is not None:
161
+ x = (x + residual).to(x.dtype)
162
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
163
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
164
+ if weight1 is None:
165
+ return out if not prenorm else (out, x)
166
+ else:
167
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
168
+ dtype
169
+ )
170
+ return (out, out1) if not prenorm else (out, out1, x)
171
+
172
+
173
+ @triton.autotune(
174
+ configs=triton_autotune_configs(),
175
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
176
+ )
177
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
178
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
179
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
180
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
181
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
182
+ @triton.jit
183
+ def _layer_norm_fwd_1pass_kernel(
184
+ X, # pointer to the input
185
+ Y, # pointer to the output
186
+ W, # pointer to the weights
187
+ B, # pointer to the biases
188
+ RESIDUAL, # pointer to the residual
189
+ X1,
190
+ W1,
191
+ B1,
192
+ Y1,
193
+ RESIDUAL_OUT, # pointer to the residual
194
+ ROWSCALE,
195
+ SEEDS, # Dropout seeds for each row
196
+ DROPOUT_MASK,
197
+ Mean, # pointer to the mean
198
+ Rstd, # pointer to the 1/std
199
+ stride_x_row, # how much to increase the pointer when moving by 1 row
200
+ stride_y_row,
201
+ stride_res_row,
202
+ stride_res_out_row,
203
+ stride_x1_row,
204
+ stride_y1_row,
205
+ M, # number of rows in X
206
+ N, # number of columns in X
207
+ eps, # epsilon to avoid division by zero
208
+ dropout_p, # Dropout probability
209
+ zero_centered_weight, # If true, add 1.0 to the weight
210
+ IS_RMS_NORM: tl.constexpr,
211
+ BLOCK_N: tl.constexpr,
212
+ HAS_RESIDUAL: tl.constexpr,
213
+ STORE_RESIDUAL_OUT: tl.constexpr,
214
+ HAS_BIAS: tl.constexpr,
215
+ HAS_DROPOUT: tl.constexpr,
216
+ STORE_DROPOUT_MASK: tl.constexpr,
217
+ HAS_ROWSCALE: tl.constexpr,
218
+ HAS_X1: tl.constexpr,
219
+ HAS_W1: tl.constexpr,
220
+ HAS_B1: tl.constexpr,
221
+ ):
222
+ # Map the program id to the row of X and Y it should compute.
223
+ row = tl.program_id(0)
224
+ X += row * stride_x_row
225
+ Y += row * stride_y_row
226
+ if HAS_RESIDUAL:
227
+ RESIDUAL += row * stride_res_row
228
+ if STORE_RESIDUAL_OUT:
229
+ RESIDUAL_OUT += row * stride_res_out_row
230
+ if HAS_X1:
231
+ X1 += row * stride_x1_row
232
+ if HAS_W1:
233
+ Y1 += row * stride_y1_row
234
+ # Compute mean and variance
235
+ cols = tl.arange(0, BLOCK_N)
236
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
237
+ if HAS_ROWSCALE:
238
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
239
+ x *= rowscale
240
+ if HAS_DROPOUT:
241
+ # Compute dropout mask
242
+ # 7 rounds is good enough, and reduces register pressure
243
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
244
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
245
+ if STORE_DROPOUT_MASK:
246
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
247
+ if HAS_X1:
248
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
249
+ if HAS_ROWSCALE:
250
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
251
+ x1 *= rowscale
252
+ if HAS_DROPOUT:
253
+ # Compute dropout mask
254
+ # 7 rounds is good enough, and reduces register pressure
255
+ keep_mask = (
256
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
257
+ )
258
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
259
+ if STORE_DROPOUT_MASK:
260
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
261
+ x += x1
262
+ if HAS_RESIDUAL:
263
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
264
+ x += residual
265
+ if STORE_RESIDUAL_OUT:
266
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
267
+ if not IS_RMS_NORM:
268
+ mean = tl.sum(x, axis=0) / N
269
+ tl.store(Mean + row, mean)
270
+ xbar = tl.where(cols < N, x - mean, 0.0)
271
+ var = tl.sum(xbar * xbar, axis=0) / N
272
+ else:
273
+ xbar = tl.where(cols < N, x, 0.0)
274
+ var = tl.sum(xbar * xbar, axis=0) / N
275
+ rstd = 1 / tl.sqrt(var + eps)
276
+ tl.store(Rstd + row, rstd)
277
+ # Normalize and apply linear transformation
278
+ mask = cols < N
279
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
280
+ if zero_centered_weight:
281
+ w += 1.0
282
+ if HAS_BIAS:
283
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
284
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
285
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
286
+ # Write output
287
+ tl.store(Y + cols, y, mask=mask)
288
+ if HAS_W1:
289
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
290
+ if zero_centered_weight:
291
+ w1 += 1.0
292
+ if HAS_B1:
293
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
294
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
295
+ tl.store(Y1 + cols, y1, mask=mask)
296
+
297
+
298
+ def _layer_norm_fwd(
299
+ x,
300
+ weight,
301
+ bias,
302
+ eps,
303
+ residual=None,
304
+ x1=None,
305
+ weight1=None,
306
+ bias1=None,
307
+ dropout_p=0.0,
308
+ rowscale=None,
309
+ out_dtype=None,
310
+ residual_dtype=None,
311
+ zero_centered_weight=False,
312
+ is_rms_norm=False,
313
+ return_dropout_mask=False,
314
+ out=None,
315
+ residual_out=None
316
+ ):
317
+ if residual is not None:
318
+ residual_dtype = residual.dtype
319
+ M, N = x.shape
320
+ assert x.stride(-1) == 1
321
+ if residual is not None:
322
+ assert residual.stride(-1) == 1
323
+ assert residual.shape == (M, N)
324
+ assert weight.shape == (N,)
325
+ assert weight.stride(-1) == 1
326
+ if bias is not None:
327
+ assert bias.stride(-1) == 1
328
+ assert bias.shape == (N,)
329
+ if x1 is not None:
330
+ assert x1.shape == x.shape
331
+ assert rowscale is None
332
+ assert x1.stride(-1) == 1
333
+ if weight1 is not None:
334
+ assert weight1.shape == (N,)
335
+ assert weight1.stride(-1) == 1
336
+ if bias1 is not None:
337
+ assert bias1.shape == (N,)
338
+ assert bias1.stride(-1) == 1
339
+ if rowscale is not None:
340
+ assert rowscale.is_contiguous()
341
+ assert rowscale.shape == (M,)
342
+ # allocate output
343
+ if out is None:
344
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
345
+ else:
346
+ assert out.shape == x.shape
347
+ assert out.stride(-1) == 1
348
+ if weight1 is not None:
349
+ y1 = torch.empty_like(out)
350
+ assert y1.stride(-1) == 1
351
+ else:
352
+ y1 = None
353
+ if (
354
+ residual is not None
355
+ or (residual_dtype is not None and residual_dtype != x.dtype)
356
+ or dropout_p > 0.0
357
+ or rowscale is not None
358
+ or x1 is not None
359
+ ):
360
+ if residual_out is None:
361
+ residual_out = torch.empty(
362
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
363
+ )
364
+ else:
365
+ assert residual_out.shape == x.shape
366
+ assert residual_out.stride(-1) == 1
367
+ else:
368
+ residual_out = None
369
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
370
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
+ if dropout_p > 0.0:
372
+ seeds = torch.randint(
373
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
+ )
375
+ else:
376
+ seeds = None
377
+ if return_dropout_mask and dropout_p > 0.0:
378
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
379
+ else:
380
+ dropout_mask = None
381
+ # Less than 64KB per feature: enqueue fused kernel
382
+ MAX_FUSED_SIZE = 65536 // x.element_size()
383
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
384
+ if N > BLOCK_N:
385
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
386
+ with torch.cuda.device(x.device.index):
387
+ _layer_norm_fwd_1pass_kernel[(M,)](
388
+ x,
389
+ out,
390
+ weight,
391
+ bias,
392
+ residual,
393
+ x1,
394
+ weight1,
395
+ bias1,
396
+ y1,
397
+ residual_out,
398
+ rowscale,
399
+ seeds,
400
+ dropout_mask,
401
+ mean,
402
+ rstd,
403
+ x.stride(0),
404
+ out.stride(0),
405
+ residual.stride(0) if residual is not None else 0,
406
+ residual_out.stride(0) if residual_out is not None else 0,
407
+ x1.stride(0) if x1 is not None else 0,
408
+ y1.stride(0) if y1 is not None else 0,
409
+ M,
410
+ N,
411
+ eps,
412
+ dropout_p,
413
+ zero_centered_weight,
414
+ is_rms_norm,
415
+ BLOCK_N,
416
+ residual is not None,
417
+ residual_out is not None,
418
+ bias is not None,
419
+ dropout_p > 0.0,
420
+ dropout_mask is not None,
421
+ rowscale is not None,
422
+ )
423
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
424
+ if dropout_mask is not None and x1 is not None:
425
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
426
+ else:
427
+ dropout_mask1 = None
428
+ return (
429
+ out,
430
+ y1,
431
+ mean,
432
+ rstd,
433
+ residual_out if residual_out is not None else x,
434
+ seeds,
435
+ dropout_mask,
436
+ dropout_mask1,
437
+ )
438
+
439
+
440
+ @triton.autotune(
441
+ configs=triton_autotune_configs(),
442
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
443
+ )
444
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
445
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
446
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
447
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
448
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
449
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
450
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
451
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
452
+ @triton.jit
453
+ def _layer_norm_bwd_kernel(
454
+ X, # pointer to the input
455
+ W, # pointer to the weights
456
+ B, # pointer to the biases
457
+ Y, # pointer to the output to be recomputed
458
+ DY, # pointer to the output gradient
459
+ DX, # pointer to the input gradient
460
+ DW, # pointer to the partial sum of weights gradient
461
+ DB, # pointer to the partial sum of biases gradient
462
+ DRESIDUAL,
463
+ W1,
464
+ DY1,
465
+ DX1,
466
+ DW1,
467
+ DB1,
468
+ DRESIDUAL_IN,
469
+ ROWSCALE,
470
+ SEEDS,
471
+ Mean, # pointer to the mean
472
+ Rstd, # pointer to the 1/std
473
+ stride_x_row, # how much to increase the pointer when moving by 1 row
474
+ stride_y_row,
475
+ stride_dy_row,
476
+ stride_dx_row,
477
+ stride_dres_row,
478
+ stride_dy1_row,
479
+ stride_dx1_row,
480
+ stride_dres_in_row,
481
+ M, # number of rows in X
482
+ N, # number of columns in X
483
+ eps, # epsilon to avoid division by zero
484
+ dropout_p,
485
+ zero_centered_weight,
486
+ rows_per_program,
487
+ IS_RMS_NORM: tl.constexpr,
488
+ BLOCK_N: tl.constexpr,
489
+ HAS_DRESIDUAL: tl.constexpr,
490
+ STORE_DRESIDUAL: tl.constexpr,
491
+ HAS_BIAS: tl.constexpr,
492
+ HAS_DROPOUT: tl.constexpr,
493
+ HAS_ROWSCALE: tl.constexpr,
494
+ HAS_DY1: tl.constexpr,
495
+ HAS_DX1: tl.constexpr,
496
+ HAS_B1: tl.constexpr,
497
+ RECOMPUTE_OUTPUT: tl.constexpr,
498
+ ):
499
+ # Map the program id to the elements of X, DX, and DY it should compute.
500
+ row_block_id = tl.program_id(0)
501
+ row_start = row_block_id * rows_per_program
502
+ # Do not early exit if row_start >= M, because we need to write DW and DB
503
+ cols = tl.arange(0, BLOCK_N)
504
+ mask = cols < N
505
+ X += row_start * stride_x_row
506
+ if HAS_DRESIDUAL:
507
+ DRESIDUAL += row_start * stride_dres_row
508
+ if STORE_DRESIDUAL:
509
+ DRESIDUAL_IN += row_start * stride_dres_in_row
510
+ DY += row_start * stride_dy_row
511
+ DX += row_start * stride_dx_row
512
+ if HAS_DY1:
513
+ DY1 += row_start * stride_dy1_row
514
+ if HAS_DX1:
515
+ DX1 += row_start * stride_dx1_row
516
+ if RECOMPUTE_OUTPUT:
517
+ Y += row_start * stride_y_row
518
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
519
+ if zero_centered_weight:
520
+ w += 1.0
521
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
522
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
523
+ if HAS_DY1:
524
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
525
+ if zero_centered_weight:
526
+ w1 += 1.0
527
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
528
+ if HAS_BIAS:
529
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
530
+ if HAS_DY1:
531
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
532
+ if HAS_B1:
533
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
534
+ row_end = min((row_block_id + 1) * rows_per_program, M)
535
+ for row in range(row_start, row_end):
536
+ # Load data to SRAM
537
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
538
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
539
+ if HAS_DY1:
540
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
541
+ if not IS_RMS_NORM:
542
+ mean = tl.load(Mean + row)
543
+ rstd = tl.load(Rstd + row)
544
+ # Compute dx
545
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
546
+ xhat = tl.where(mask, xhat, 0.0)
547
+ if RECOMPUTE_OUTPUT:
548
+ y = xhat * w + b if HAS_BIAS else xhat * w
549
+ tl.store(Y + cols, y, mask=mask)
550
+ wdy = w * dy
551
+ dw += dy * xhat
552
+ if HAS_BIAS:
553
+ db += dy
554
+ if HAS_DY1:
555
+ wdy += w1 * dy1
556
+ dw1 += dy1 * xhat
557
+ if HAS_B1:
558
+ db1 += dy1
559
+ if not IS_RMS_NORM:
560
+ c1 = tl.sum(xhat * wdy, axis=0) / N
561
+ c2 = tl.sum(wdy, axis=0) / N
562
+ dx = (wdy - (xhat * c1 + c2)) * rstd
563
+ else:
564
+ c1 = tl.sum(xhat * wdy, axis=0) / N
565
+ dx = (wdy - xhat * c1) * rstd
566
+ if HAS_DRESIDUAL:
567
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
568
+ dx += dres
569
+ # Write dx
570
+ if STORE_DRESIDUAL:
571
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
572
+ if HAS_DX1:
573
+ if HAS_DROPOUT:
574
+ keep_mask = (
575
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
576
+ )
577
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
578
+ else:
579
+ dx1 = dx
580
+ tl.store(DX1 + cols, dx1, mask=mask)
581
+ if HAS_DROPOUT:
582
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
583
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
584
+ if HAS_ROWSCALE:
585
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
586
+ dx *= rowscale
587
+ tl.store(DX + cols, dx, mask=mask)
588
+
589
+ X += stride_x_row
590
+ if HAS_DRESIDUAL:
591
+ DRESIDUAL += stride_dres_row
592
+ if STORE_DRESIDUAL:
593
+ DRESIDUAL_IN += stride_dres_in_row
594
+ if RECOMPUTE_OUTPUT:
595
+ Y += stride_y_row
596
+ DY += stride_dy_row
597
+ DX += stride_dx_row
598
+ if HAS_DY1:
599
+ DY1 += stride_dy1_row
600
+ if HAS_DX1:
601
+ DX1 += stride_dx1_row
602
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
603
+ if HAS_BIAS:
604
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
605
+ if HAS_DY1:
606
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
607
+ if HAS_B1:
608
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
609
+
610
+
611
+ def _layer_norm_bwd(
612
+ dy,
613
+ x,
614
+ weight,
615
+ bias,
616
+ eps,
617
+ mean,
618
+ rstd,
619
+ dresidual=None,
620
+ dy1=None,
621
+ weight1=None,
622
+ bias1=None,
623
+ seeds=None,
624
+ dropout_p=0.0,
625
+ rowscale=None,
626
+ has_residual=False,
627
+ has_x1=False,
628
+ zero_centered_weight=False,
629
+ is_rms_norm=False,
630
+ x_dtype=None,
631
+ recompute_output=False,
632
+ ):
633
+ M, N = x.shape
634
+ assert x.stride(-1) == 1
635
+ assert dy.stride(-1) == 1
636
+ assert dy.shape == (M, N)
637
+ if dresidual is not None:
638
+ assert dresidual.stride(-1) == 1
639
+ assert dresidual.shape == (M, N)
640
+ assert weight.shape == (N,)
641
+ assert weight.stride(-1) == 1
642
+ if bias is not None:
643
+ assert bias.stride(-1) == 1
644
+ assert bias.shape == (N,)
645
+ if dy1 is not None:
646
+ assert weight1 is not None
647
+ assert dy1.shape == dy.shape
648
+ assert dy1.stride(-1) == 1
649
+ if weight1 is not None:
650
+ assert weight1.shape == (N,)
651
+ assert weight1.stride(-1) == 1
652
+ if bias1 is not None:
653
+ assert bias1.shape == (N,)
654
+ assert bias1.stride(-1) == 1
655
+ if seeds is not None:
656
+ assert seeds.is_contiguous()
657
+ assert seeds.shape == (M if not has_x1 else M * 2,)
658
+ if rowscale is not None:
659
+ assert rowscale.is_contiguous()
660
+ assert rowscale.shape == (M,)
661
+ # allocate output
662
+ dx = (
663
+ torch.empty_like(x)
664
+ if x_dtype is None
665
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
666
+ )
667
+ dresidual_in = (
668
+ torch.empty_like(x)
669
+ if has_residual
670
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
671
+ else None
672
+ )
673
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
674
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
675
+ if recompute_output:
676
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
677
+
678
+ # Less than 64KB per feature: enqueue fused kernel
679
+ MAX_FUSED_SIZE = 65536 // x.element_size()
680
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
681
+ if N > BLOCK_N:
682
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
683
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
684
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
685
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
686
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
687
+ _db = (
688
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
689
+ if bias is not None
690
+ else None
691
+ )
692
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
693
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
694
+ rows_per_program = math.ceil(M / sm_count)
695
+ grid = (sm_count,)
696
+ with torch.cuda.device(x.device.index):
697
+ _layer_norm_bwd_kernel[grid](
698
+ x,
699
+ weight,
700
+ bias,
701
+ y,
702
+ dy,
703
+ dx,
704
+ _dw,
705
+ _db,
706
+ dresidual,
707
+ weight1,
708
+ dy1,
709
+ dx1,
710
+ _dw1,
711
+ _db1,
712
+ dresidual_in,
713
+ rowscale,
714
+ seeds,
715
+ mean,
716
+ rstd,
717
+ x.stride(0),
718
+ 0 if not recompute_output else y.stride(0),
719
+ dy.stride(0),
720
+ dx.stride(0),
721
+ dresidual.stride(0) if dresidual is not None else 0,
722
+ dy1.stride(0) if dy1 is not None else 0,
723
+ dx1.stride(0) if dx1 is not None else 0,
724
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
725
+ M,
726
+ N,
727
+ eps,
728
+ dropout_p,
729
+ zero_centered_weight,
730
+ rows_per_program,
731
+ is_rms_norm,
732
+ BLOCK_N,
733
+ dresidual is not None,
734
+ dresidual_in is not None,
735
+ bias is not None,
736
+ dropout_p > 0.0,
737
+ )
738
+ dw = _dw.sum(0).to(weight.dtype)
739
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
740
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
741
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
742
+ # Don't need to compute dresidual_in separately in this case
743
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
744
+ dresidual_in = dx
745
+ if has_x1 and dropout_p == 0.0:
746
+ dx1 = dx
747
+ return (
748
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
749
+ if not recompute_output
750
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
751
+ )
752
+
753
+
754
+ class LayerNormFn(torch.autograd.Function):
755
+ @staticmethod
756
+ def forward(
757
+ ctx,
758
+ x,
759
+ weight,
760
+ bias,
761
+ residual=None,
762
+ x1=None,
763
+ weight1=None,
764
+ bias1=None,
765
+ eps=1e-6,
766
+ dropout_p=0.0,
767
+ rowscale=None,
768
+ prenorm=False,
769
+ residual_in_fp32=False,
770
+ zero_centered_weight=False,
771
+ is_rms_norm=False,
772
+ return_dropout_mask=False,
773
+ out=None,
774
+ residual_out=None
775
+ ):
776
+ x_shape_og = x.shape
777
+ # Check for zero sequence length
778
+ if x.numel() == 0:
779
+ ctx.zero_seq_length = True
780
+ # Only save minimal required tensors for backward
781
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
782
+ ctx.x_shape_og = x_shape_og
783
+ ctx.weight_shape = weight.shape
784
+ ctx.weight_dtype = weight.dtype
785
+ ctx.weight_device = weight.device
786
+
787
+ ctx.has_bias = bias is not None
788
+ ctx.bias_shape = bias.shape if bias is not None else None
789
+ ctx.bias_dtype = bias.dtype if bias is not None else None
790
+ ctx.bias_device = bias.device if bias is not None else None
791
+
792
+ ctx.has_weight1 = weight1 is not None
793
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
794
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
795
+ ctx.weight1_device = weight1.device if weight1 is not None else None
796
+
797
+ ctx.has_bias1 = bias1 is not None
798
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
799
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
800
+ ctx.bias1_device = bias1.device if bias1 is not None else None
801
+
802
+ ctx.has_residual = residual is not None
803
+ ctx.has_x1 = x1 is not None
804
+ ctx.dropout_p = dropout_p
805
+
806
+ # Handle output tensors with correct dtype
807
+ y = x # Preserve input tensor properties
808
+ y1 = torch.empty_like(x) if x1 is not None else None
809
+
810
+ # Only create residual_out if prenorm is True
811
+ residual_out = torch.empty(x.shape,
812
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
813
+ device=x.device) if prenorm else None
814
+
815
+ # Handle dropout masks
816
+ dropout_mask = None
817
+ dropout_mask1 = None
818
+ if return_dropout_mask:
819
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
820
+ if x1 is not None:
821
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
822
+
823
+ # Return based on configuration
824
+ if not return_dropout_mask:
825
+ if weight1 is None:
826
+ return y if not prenorm else (y, residual_out)
827
+ else:
828
+ return (y, y1) if not prenorm else (y, y1, residual_out)
829
+ else:
830
+ if weight1 is None:
831
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
832
+ else (y, residual_out, dropout_mask, dropout_mask1))
833
+ else:
834
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
835
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
836
+
837
+ ctx.zero_seq_length = False
838
+ # reshape input data into 2D tensor
839
+ x = x.reshape(-1, x.shape[-1])
840
+ if x.stride(-1) != 1:
841
+ x = x.contiguous()
842
+ if residual is not None:
843
+ assert residual.shape == x_shape_og
844
+ residual = residual.reshape(-1, residual.shape[-1])
845
+ if residual.stride(-1) != 1:
846
+ residual = residual.contiguous()
847
+ if x1 is not None:
848
+ assert x1.shape == x_shape_og
849
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
850
+ x1 = x1.reshape(-1, x1.shape[-1])
851
+ if x1.stride(-1) != 1:
852
+ x1 = x1.contiguous()
853
+ weight = weight.contiguous()
854
+ if bias is not None:
855
+ bias = bias.contiguous()
856
+ if weight1 is not None:
857
+ weight1 = weight1.contiguous()
858
+ if bias1 is not None:
859
+ bias1 = bias1.contiguous()
860
+ if rowscale is not None:
861
+ rowscale = rowscale.reshape(-1).contiguous()
862
+ residual_dtype = (
863
+ residual.dtype
864
+ if residual is not None
865
+ else (torch.float32 if residual_in_fp32 else None)
866
+ )
867
+ if out is not None:
868
+ out = out.reshape(-1, out.shape[-1])
869
+ if residual_out is not None:
870
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
871
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
872
+ x,
873
+ weight,
874
+ bias,
875
+ eps,
876
+ residual,
877
+ x1,
878
+ weight1,
879
+ bias1,
880
+ dropout_p=dropout_p,
881
+ rowscale=rowscale,
882
+ residual_dtype=residual_dtype,
883
+ zero_centered_weight=zero_centered_weight,
884
+ is_rms_norm=is_rms_norm,
885
+ return_dropout_mask=return_dropout_mask,
886
+ out=out,
887
+ residual_out=residual_out
888
+ )
889
+ ctx.save_for_backward(
890
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
891
+ )
892
+ ctx.x_shape_og = x_shape_og
893
+ ctx.eps = eps
894
+ ctx.dropout_p = dropout_p
895
+ ctx.is_rms_norm = is_rms_norm
896
+ ctx.has_residual = residual is not None
897
+ ctx.has_x1 = x1 is not None
898
+ ctx.prenorm = prenorm
899
+ ctx.x_dtype = x.dtype
900
+ ctx.zero_centered_weight = zero_centered_weight
901
+ y = y.reshape(x_shape_og)
902
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
903
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
904
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
905
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
906
+ if not return_dropout_mask:
907
+ if weight1 is None:
908
+ return y if not prenorm else (y, residual_out)
909
+ else:
910
+ return (y, y1) if not prenorm else (y, y1, residual_out)
911
+ else:
912
+ if weight1 is None:
913
+ return (
914
+ (y, dropout_mask, dropout_mask1)
915
+ if not prenorm
916
+ else (y, residual_out, dropout_mask, dropout_mask1)
917
+ )
918
+ else:
919
+ return (
920
+ (y, y1, dropout_mask, dropout_mask1)
921
+ if not prenorm
922
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
923
+ )
924
+
925
+ @staticmethod
926
+ def backward(ctx, dy, *args):
927
+ if ctx.zero_seq_length:
928
+ return (
929
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
930
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
931
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
932
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
933
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
934
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
935
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
936
+ None,
937
+ None,
938
+ None,
939
+ None,
940
+ None,
941
+ None,
942
+ None,
943
+ None,
944
+ None,
945
+ None,
946
+ )
947
+
948
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
949
+ dy = dy.reshape(-1, dy.shape[-1])
950
+ if dy.stride(-1) != 1:
951
+ dy = dy.contiguous()
952
+ assert dy.shape == x.shape
953
+ if weight1 is not None:
954
+ dy1, args = args[0], args[1:]
955
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
956
+ if dy1.stride(-1) != 1:
957
+ dy1 = dy1.contiguous()
958
+ assert dy1.shape == x.shape
959
+ else:
960
+ dy1 = None
961
+ if ctx.prenorm:
962
+ dresidual = args[0]
963
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
964
+ if dresidual.stride(-1) != 1:
965
+ dresidual = dresidual.contiguous()
966
+ assert dresidual.shape == x.shape
967
+ else:
968
+ dresidual = None
969
+
970
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
971
+ dy,
972
+ x,
973
+ weight,
974
+ bias,
975
+ ctx.eps,
976
+ mean,
977
+ rstd,
978
+ dresidual,
979
+ dy1,
980
+ weight1,
981
+ bias1,
982
+ seeds,
983
+ ctx.dropout_p,
984
+ rowscale,
985
+ ctx.has_residual,
986
+ ctx.has_x1,
987
+ ctx.zero_centered_weight,
988
+ ctx.is_rms_norm,
989
+ x_dtype=ctx.x_dtype,
990
+ )
991
+ return (
992
+ dx.reshape(ctx.x_shape_og),
993
+ dw,
994
+ db,
995
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
996
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
997
+ dw1,
998
+ db1,
999
+ None,
1000
+ None,
1001
+ None,
1002
+ None,
1003
+ None,
1004
+ None,
1005
+ None,
1006
+ None,
1007
+ None,
1008
+ None,
1009
+ )
1010
+
1011
+
1012
+ def layer_norm_fn(
1013
+ x,
1014
+ weight,
1015
+ bias,
1016
+ residual=None,
1017
+ x1=None,
1018
+ weight1=None,
1019
+ bias1=None,
1020
+ eps=1e-6,
1021
+ dropout_p=0.0,
1022
+ rowscale=None,
1023
+ prenorm=False,
1024
+ residual_in_fp32=False,
1025
+ zero_centered_weight=False,
1026
+ is_rms_norm=False,
1027
+ return_dropout_mask=False,
1028
+ out=None,
1029
+ residual_out=None
1030
+ ):
1031
+ return LayerNormFn.apply(
1032
+ x,
1033
+ weight,
1034
+ bias,
1035
+ residual,
1036
+ x1,
1037
+ weight1,
1038
+ bias1,
1039
+ eps,
1040
+ dropout_p,
1041
+ rowscale,
1042
+ prenorm,
1043
+ residual_in_fp32,
1044
+ zero_centered_weight,
1045
+ is_rms_norm,
1046
+ return_dropout_mask,
1047
+ out,
1048
+ residual_out
1049
+ )
1050
+
1051
+
1052
+ def rms_norm_fn(
1053
+ x,
1054
+ weight,
1055
+ bias,
1056
+ residual=None,
1057
+ x1=None,
1058
+ weight1=None,
1059
+ bias1=None,
1060
+ eps=1e-6,
1061
+ dropout_p=0.0,
1062
+ rowscale=None,
1063
+ prenorm=False,
1064
+ residual_in_fp32=False,
1065
+ zero_centered_weight=False,
1066
+ return_dropout_mask=False,
1067
+ out=None,
1068
+ residual_out=None
1069
+ ):
1070
+ return LayerNormFn.apply(
1071
+ x,
1072
+ weight,
1073
+ bias,
1074
+ residual,
1075
+ x1,
1076
+ weight1,
1077
+ bias1,
1078
+ eps,
1079
+ dropout_p,
1080
+ rowscale,
1081
+ prenorm,
1082
+ residual_in_fp32,
1083
+ zero_centered_weight,
1084
+ True,
1085
+ return_dropout_mask,
1086
+ out,
1087
+ residual_out
1088
+ )
1089
+
1090
+
1091
+ class RMSNorm(torch.nn.Module):
1092
+
1093
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
1094
+ device=None, dtype=None):
1095
+ factory_kwargs = {"device": device, "dtype": dtype}
1096
+ super().__init__()
1097
+ self.eps = eps
1098
+ if dropout_p > 0.0:
1099
+ self.drop = torch.nn.Dropout(dropout_p)
1100
+ else:
1101
+ self.drop = None
1102
+ self.zero_centered_weight = zero_centered_weight
1103
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1104
+ self.register_parameter("bias", None)
1105
+ self.reset_parameters()
1106
+
1107
+ def reset_parameters(self):
1108
+ if not self.zero_centered_weight:
1109
+ torch.nn.init.ones_(self.weight)
1110
+ else:
1111
+ torch.nn.init.zeros_(self.weight)
1112
+
1113
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1114
+ return rms_norm_fn(
1115
+ x,
1116
+ self.weight,
1117
+ self.bias,
1118
+ residual=residual,
1119
+ eps=self.eps,
1120
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1121
+ prenorm=prenorm,
1122
+ residual_in_fp32=residual_in_fp32,
1123
+ zero_centered_weight=self.zero_centered_weight,
1124
+ )
1125
+
1126
+
1127
+ class LayerNormLinearFn(torch.autograd.Function):
1128
+ @staticmethod
1129
+ @custom_fwd
1130
+ def forward(
1131
+ ctx,
1132
+ x,
1133
+ norm_weight,
1134
+ norm_bias,
1135
+ linear_weight,
1136
+ linear_bias,
1137
+ residual=None,
1138
+ eps=1e-6,
1139
+ prenorm=False,
1140
+ residual_in_fp32=False,
1141
+ is_rms_norm=False,
1142
+ ):
1143
+ x_shape_og = x.shape
1144
+ # reshape input data into 2D tensor
1145
+ x = x.reshape(-1, x.shape[-1])
1146
+ if x.stride(-1) != 1:
1147
+ x = x.contiguous()
1148
+ if residual is not None:
1149
+ assert residual.shape == x_shape_og
1150
+ residual = residual.reshape(-1, residual.shape[-1])
1151
+ if residual.stride(-1) != 1:
1152
+ residual = residual.contiguous()
1153
+ norm_weight = norm_weight.contiguous()
1154
+ if norm_bias is not None:
1155
+ norm_bias = norm_bias.contiguous()
1156
+ residual_dtype = (
1157
+ residual.dtype
1158
+ if residual is not None
1159
+ else (torch.float32 if residual_in_fp32 else None)
1160
+ )
1161
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1162
+ x,
1163
+ norm_weight,
1164
+ norm_bias,
1165
+ eps,
1166
+ residual,
1167
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
1168
+ residual_dtype=residual_dtype,
1169
+ is_rms_norm=is_rms_norm,
1170
+ )
1171
+ y = y.reshape(x_shape_og)
1172
+ dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
1173
+ linear_weight = linear_weight.to(dtype)
1174
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1175
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1176
+ # We don't store y, will be recomputed in the backward pass to save memory
1177
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
1178
+ ctx.x_shape_og = x_shape_og
1179
+ ctx.eps = eps
1180
+ ctx.is_rms_norm = is_rms_norm
1181
+ ctx.has_residual = residual is not None
1182
+ ctx.prenorm = prenorm
1183
+ ctx.x_dtype = x.dtype
1184
+ ctx.linear_bias_is_none = linear_bias is None
1185
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1186
+
1187
+ @staticmethod
1188
+ @custom_bwd
1189
+ def backward(ctx, dout, *args):
1190
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1191
+ dout = dout.reshape(-1, dout.shape[-1])
1192
+ dy = F.linear(dout, linear_weight.t())
1193
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1194
+ if dy.stride(-1) != 1:
1195
+ dy = dy.contiguous()
1196
+ assert dy.shape == x.shape
1197
+ if ctx.prenorm:
1198
+ dresidual = args[0]
1199
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1200
+ if dresidual.stride(-1) != 1:
1201
+ dresidual = dresidual.contiguous()
1202
+ assert dresidual.shape == x.shape
1203
+ else:
1204
+ dresidual = None
1205
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1206
+ dy,
1207
+ x,
1208
+ norm_weight,
1209
+ norm_bias,
1210
+ ctx.eps,
1211
+ mean,
1212
+ rstd,
1213
+ dresidual=dresidual,
1214
+ has_residual=ctx.has_residual,
1215
+ is_rms_norm=ctx.is_rms_norm,
1216
+ x_dtype=ctx.x_dtype,
1217
+ recompute_output=True,
1218
+ )
1219
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1220
+ return (
1221
+ dx.reshape(ctx.x_shape_og),
1222
+ dnorm_weight,
1223
+ dnorm_bias,
1224
+ dlinear_weight,
1225
+ dlinear_bias,
1226
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1227
+ None,
1228
+ None,
1229
+ None,
1230
+ None,
1231
+ )
1232
+
1233
+
1234
+ def layer_norm_linear_fn(
1235
+ x,
1236
+ norm_weight,
1237
+ norm_bias,
1238
+ linear_weight,
1239
+ linear_bias,
1240
+ residual=None,
1241
+ eps=1e-6,
1242
+ prenorm=False,
1243
+ residual_in_fp32=False,
1244
+ is_rms_norm=False,
1245
+ ):
1246
+ return LayerNormLinearFn.apply(
1247
+ x,
1248
+ norm_weight,
1249
+ norm_bias,
1250
+ linear_weight,
1251
+ linear_bias,
1252
+ residual,
1253
+ eps,
1254
+ prenorm,
1255
+ residual_in_fp32,
1256
+ is_rms_norm,
1257
+ )
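For orientation, here is a minimal usage sketch of the fused norms defined above. It assumes a CUDA device with Triton available and that the module is importable under this commit's layout (`omnigen2.ops.triton.layer_norm`); shapes and dtypes are illustrative only.

```py
import torch
from omnigen2.ops.triton.layer_norm import RMSNorm, rms_norm_fn

# Module-style usage: fused RMSNorm, optionally with a fused residual add (pre-norm pattern).
norm = RMSNorm(hidden_size=2048, eps=1e-5).to("cuda", torch.bfloat16)
x = torch.randn(2, 1024, 2048, device="cuda", dtype=torch.bfloat16)

y = norm(x)  # plain RMSNorm
y, residual = norm(x, residual=x, prenorm=True, residual_in_fp32=True)  # returns (normed, updated residual)

# Functional form, e.g. when the weight is managed outside an nn.Module.
w = torch.ones(2048, device="cuda", dtype=torch.bfloat16)
y2 = rms_norm_fn(x, w, bias=None, eps=1e-5)
```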
omnigen2/optim/__init__.py ADDED
File without changes
omnigen2/optim/scheduler/__init__.py ADDED
File without changes
omnigen2/optim/scheduler/cosine_lr.py ADDED
@@ -0,0 +1,118 @@
1
+ """ Cosine Scheduler
2
+
3
+ Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
4
+
5
+ Hacked together by / Copyright 2021 Ross Wightman
6
+ """
7
+ import logging
8
+ import math
9
+ import torch
10
+ from typing import List
11
+
12
+ from .scheduler import Scheduler
13
+
14
+
15
+ _logger = logging.getLogger(__name__)
16
+
17
+
18
+ class CosineLRScheduler(Scheduler):
19
+ """
20
+ Cosine decay with restarts.
21
+ This is described in the paper https://arxiv.org/abs/1608.03983.
22
+
23
+ Inspiration from
24
+ https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
25
+
26
+ k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ optimizer: torch.optim.Optimizer,
32
+ t_initial: int,
33
+ lr_min: float = 0.,
34
+ cycle_mul: float = 1.,
35
+ cycle_decay: float = 1.,
36
+ cycle_limit: int = 1,
37
+ warmup_t=0,
38
+ warmup_lr_init=0,
39
+ warmup_prefix=False,
40
+ t_in_epochs=True,
41
+ noise_range_t=None,
42
+ noise_pct=0.67,
43
+ noise_std=1.0,
44
+ noise_seed=42,
45
+ k_decay=1.0,
46
+ initialize=True,
47
+ ) -> None:
48
+ super().__init__(
49
+ optimizer,
50
+ param_group_field="lr",
51
+ t_in_epochs=t_in_epochs,
52
+ noise_range_t=noise_range_t,
53
+ noise_pct=noise_pct,
54
+ noise_std=noise_std,
55
+ noise_seed=noise_seed,
56
+ initialize=initialize,
57
+ )
58
+
59
+ assert t_initial > 0
60
+ assert lr_min >= 0
61
+ if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
62
+ _logger.warning(
63
+ "Cosine annealing scheduler will have no effect on the learning "
64
+ "rate since t_initial = t_mul = eta_mul = 1.")
65
+ self.t_initial = t_initial
66
+ self.lr_min = lr_min
67
+ self.cycle_mul = cycle_mul
68
+ self.cycle_decay = cycle_decay
69
+ self.cycle_limit = cycle_limit
70
+ self.warmup_t = warmup_t
71
+ self.warmup_lr_init = warmup_lr_init
72
+ self.warmup_prefix = warmup_prefix
73
+ self.k_decay = k_decay
74
+ if self.warmup_t:
75
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
76
+ super().update_groups(self.warmup_lr_init)
77
+ else:
78
+ self.warmup_steps = [1 for _ in self.base_values]
79
+
80
+ self._step_count = 0 # no use
81
+
82
+ def _get_lr(self, t: int) -> List[float]:
83
+
84
+ if t < self.warmup_t:
85
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
86
+ else:
87
+ if self.warmup_prefix:
88
+ t = t - self.warmup_t
89
+
90
+ if self.cycle_mul != 1:
91
+ i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
92
+ t_i = self.cycle_mul ** i * self.t_initial
93
+ t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
94
+ else:
95
+ i = t // self.t_initial
96
+ t_i = self.t_initial
97
+ t_curr = t - (self.t_initial * i)
98
+
99
+ gamma = self.cycle_decay ** i
100
+ lr_max_values = [v * gamma for v in self.base_values]
101
+ k = self.k_decay
102
+
103
+ if i < self.cycle_limit:
104
+ lrs = [
105
+ self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
106
+ for lr_max in lr_max_values
107
+ ]
108
+ else:
109
+ lrs = [self.lr_min for _ in self.base_values]
110
+
111
+ return lrs
112
+
113
+ def get_cycle_length(self, cycles=0):
114
+ cycles = max(1, cycles or self.cycle_limit)
115
+ if self.cycle_mul == 1.0:
116
+ return self.t_initial * cycles
117
+ else:
118
+ return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
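A small illustrative sketch of how this scheduler can be driven; per the `Scheduler` base class (next file), epoch counts are passed in explicitly rather than tracked internally. The optimizer and hyperparameters below are placeholders.

```py
import torch
from omnigen2.optim.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(8, 8)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = CosineLRScheduler(
    optimizer,
    t_initial=100,       # length of the cosine cycle (epochs here, since t_in_epochs=True)
    lr_min=1e-6,
    warmup_t=5,          # linear warmup over the first 5 epochs
    warmup_lr_init=1e-7,
)

for epoch in range(100):
    # ... train one epoch ...
    scheduler.step(epoch + 1)  # set the LR for the next epoch
print(scheduler.get_last_lr())
```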
omnigen2/optim/scheduler/scheduler.py ADDED
@@ -0,0 +1,131 @@
1
+ import abc
2
+ from abc import ABC
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import torch
6
+
7
+
8
+ class Scheduler(ABC):
9
+ """ Parameter Scheduler Base Class
10
+ A scheduler base class that can be used to schedule any optimizer parameter groups.
11
+
12
+ Unlike the builtin PyTorch schedulers, this is intended to be consistently called
13
+ * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
14
+ * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
15
+
16
+ The schedulers built on this should try to remain as stateless as possible (for simplicity).
17
+
18
+ This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
19
+ and -1 values for special behaviour. All epoch and update counts must be tracked in the training
20
+ code and explicitly passed in to the schedulers on the corresponding step or step_update call.
21
+
22
+ Based on ideas from:
23
+ * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
24
+ * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ optimizer: torch.optim.Optimizer,
30
+ param_group_field: str,
31
+ t_in_epochs: bool = True,
32
+ noise_range_t=None,
33
+ noise_type='normal',
34
+ noise_pct=0.67,
35
+ noise_std=1.0,
36
+ noise_seed=None,
37
+ initialize: bool = True,
38
+ ) -> None:
39
+ self.optimizer = optimizer
40
+ self.param_group_field = param_group_field
41
+ self._initial_param_group_field = f"initial_{param_group_field}"
42
+ if initialize:
43
+ for i, group in enumerate(self.optimizer.param_groups):
44
+ if param_group_field not in group:
45
+ raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
46
+ group.setdefault(self._initial_param_group_field, group[param_group_field])
47
+ else:
48
+ for i, group in enumerate(self.optimizer.param_groups):
49
+ if self._initial_param_group_field not in group:
50
+ raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
51
+ self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
52
+ self.metric = None # any point to having this for all?
53
+ self.t_in_epochs = t_in_epochs
54
+ self.noise_range_t = noise_range_t
55
+ self.noise_pct = noise_pct
56
+ self.noise_type = noise_type
57
+ self.noise_std = noise_std
58
+ self.noise_seed = noise_seed if noise_seed is not None else 42
59
+ self.update_groups(self.base_values)
60
+
61
+ def state_dict(self) -> Dict[str, Any]:
62
+ return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
63
+
64
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
65
+ self.__dict__.update(state_dict)
66
+
67
+ def get_last_lr(self):
68
+ """ Return last computed learning rate by current scheduler.
69
+ """
70
+ return self._last_lr
71
+
72
+ @abc.abstractmethod
73
+ def _get_lr(self, t: int) -> List[float]:
74
+ pass
75
+
76
+ def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
77
+ return self._get_lr(t)
78
+
79
+ def step(self, epoch: int, metric: float = None) -> None:
80
+ self.metric = metric
81
+ values = self._get_values(epoch, on_epoch=True)
82
+ if values is not None:
83
+ values = self._add_noise(values, epoch)
84
+ self.update_groups(values)
85
+
86
+ # def step_update(self, num_updates: int, metric: float = None):
87
+ # self.metric = metric
88
+ # values = self._get_values(num_updates, on_epoch=False)
89
+ # if values is not None:
90
+ # values = self._add_noise(values, num_updates)
91
+ # self.update_groups(values)
92
+
93
+ def update_groups(self, values):
94
+ if not isinstance(values, (list, tuple)):
95
+ values = [values] * len(self.optimizer.param_groups)
96
+ for param_group, value in zip(self.optimizer.param_groups, values):
97
+ if 'lr_scale' in param_group:
98
+ param_group[self.param_group_field] = value * param_group['lr_scale']
99
+ else:
100
+ param_group[self.param_group_field] = value
101
+
102
+ self._last_lr = [group[self.param_group_field] for group in self.optimizer.param_groups]
103
+
104
+ def _add_noise(self, lrs, t):
105
+ if self._is_apply_noise(t):
106
+ noise = self._calculate_noise(t)
107
+ lrs = [v + v * noise for v in lrs]
108
+ return lrs
109
+
110
+ def _is_apply_noise(self, t) -> bool:
111
+ """Return True if scheduler in noise range."""
112
+ apply_noise = False
113
+ if self.noise_range_t is not None:
114
+ if isinstance(self.noise_range_t, (list, tuple)):
115
+ apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
116
+ else:
117
+ apply_noise = t >= self.noise_range_t
118
+ return apply_noise
119
+
120
+ def _calculate_noise(self, t) -> float:
121
+ g = torch.Generator()
122
+ g.manual_seed(self.noise_seed + t)
123
+ if self.noise_type == 'normal':
124
+ while True:
125
+ # resample if noise out of percent limit, brute force but shouldn't spin much
126
+ noise = torch.randn(1, generator=g).item()
127
+ if abs(noise) < self.noise_pct:
128
+ return noise
129
+ else:
130
+ noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
131
+ return noise
omnigen2/optim/scheduler/step_lr.py ADDED
@@ -0,0 +1,63 @@
1
+ """ Step Scheduler
2
+
3
+ Basic step LR schedule with warmup, noise.
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import math
8
+ import torch
9
+ from typing import List
10
+
11
+
12
+ from .scheduler import Scheduler
13
+
14
+
15
+ class StepLRScheduler(Scheduler):
16
+ """
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ optimizer: torch.optim.Optimizer,
22
+ decay_t: float,
23
+ decay_rate: float = 1.,
24
+ warmup_t=0,
25
+ warmup_lr_init=0,
26
+ warmup_prefix=True,
27
+ t_in_epochs=True,
28
+ noise_range_t=None,
29
+ noise_pct=0.67,
30
+ noise_std=1.0,
31
+ noise_seed=42,
32
+ initialize=True,
33
+ ) -> None:
34
+ super().__init__(
35
+ optimizer,
36
+ param_group_field="lr",
37
+ t_in_epochs=t_in_epochs,
38
+ noise_range_t=noise_range_t,
39
+ noise_pct=noise_pct,
40
+ noise_std=noise_std,
41
+ noise_seed=noise_seed,
42
+ initialize=initialize,
43
+ )
44
+
45
+ self.decay_t = decay_t
46
+ self.decay_rate = decay_rate
47
+ self.warmup_t = warmup_t
48
+ self.warmup_lr_init = warmup_lr_init
49
+ self.warmup_prefix = warmup_prefix
50
+ if self.warmup_t:
51
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
52
+ super().update_groups(self.warmup_lr_init)
53
+ else:
54
+ self.warmup_steps = [1 for _ in self.base_values]
55
+
56
+ def _get_lr(self, t: int) -> List[float]:
57
+ if t < self.warmup_t:
58
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
59
+ else:
60
+ if self.warmup_prefix:
61
+ t = t - self.warmup_t
62
+ lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
63
+ return lrs
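As a quick illustration of the decay rule in `_get_lr` (lr = base_lr * decay_rate ** (t // decay_t)), a toy run with placeholder values:

```py
import torch
from omnigen2.optim.scheduler.step_lr import StepLRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = StepLRScheduler(opt, decay_t=30, decay_rate=0.1, warmup_t=0)

# Expected: 1e-3 for t in [0, 30), 1e-4 for t in [30, 60), 1e-5 from t = 60, ...
for t in (0, 29, 30, 59, 60):
    scheduler.step(t)
    print(t, scheduler.get_last_lr())
```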
omnigen2/pipelines/__init__.py ADDED
File without changes
omnigen2/pipelines/image_processor.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+
23
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
24
+ from diffusers.configuration_utils import register_to_config
25
+
26
+ class OmniGen2ImageProcessor(VaeImageProcessor):
27
+ """
28
+ Image processor for OmniGen2 image resize and crop.
29
+
30
+ Args:
31
+ do_resize (`bool`, *optional*, defaults to `True`):
32
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
33
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
34
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
35
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
36
+ resample (`str`, *optional*, defaults to `lanczos`):
37
+ Resampling filter to use when resizing the image.
38
+ do_normalize (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the image to [-1,1].
40
+ do_binarize (`bool`, *optional*, defaults to `False`):
41
+ Whether to binarize the image to 0/1.
42
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
43
+ Whether to convert the images to RGB format.
44
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
45
+ Whether to convert the images to grayscale format.
46
+ """
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ do_resize: bool = True,
52
+ vae_scale_factor: int = 16,
53
+ resample: str = "lanczos",
54
+ max_pixels: Optional[int] = None,
55
+ max_side_length: Optional[int] = None,
56
+ do_normalize: bool = True,
57
+ do_binarize: bool = False,
58
+ do_convert_grayscale: bool = False,
59
+ ):
60
+ super().__init__(
61
+ do_resize=do_resize,
62
+ vae_scale_factor=vae_scale_factor,
63
+ resample=resample,
64
+ do_normalize=do_normalize,
65
+ do_binarize=do_binarize,
66
+ do_convert_grayscale=do_convert_grayscale,
67
+ )
68
+
69
+ self.max_pixels = max_pixels
70
+ self.max_side_length = max_side_length
71
+
72
+ def get_new_height_width(
73
+ self,
74
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
75
+ height: Optional[int] = None,
76
+ width: Optional[int] = None,
77
+ max_pixels: Optional[int] = None,
78
+ max_side_length: Optional[int] = None,
79
+ ) -> Tuple[int, int]:
80
+ r"""
81
+ Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
82
+
83
+ Args:
84
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
85
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
86
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
87
+ tensor, it should have shape `[batch, channels, height, width]`.
88
+ height (`Optional[int]`, *optional*, defaults to `None`):
89
+ The height of the preprocessed image. If `None`, the height of the `image` input will be used.
90
+ width (`Optional[int]`, *optional*, defaults to `None`):
91
+ The width of the preprocessed image. If `None`, the width of the `image` input will be used.
92
+
93
+ Returns:
94
+ `Tuple[int, int]`:
95
+ A tuple containing the height and width, both resized to the nearest integer multiple of
96
+ `vae_scale_factor`.
97
+ """
98
+
99
+ if height is None:
100
+ if isinstance(image, PIL.Image.Image):
101
+ height = image.height
102
+ elif isinstance(image, torch.Tensor):
103
+ height = image.shape[2]
104
+ else:
105
+ height = image.shape[1]
106
+
107
+ if width is None:
108
+ if isinstance(image, PIL.Image.Image):
109
+ width = image.width
110
+ elif isinstance(image, torch.Tensor):
111
+ width = image.shape[3]
112
+ else:
113
+ width = image.shape[2]
114
+
115
+ if max_side_length is None:
116
+ max_side_length = self.max_side_length
117
+
118
+ if max_pixels is None:
119
+ max_pixels = self.max_pixels
120
+
121
+ ratio = 1.0
122
+ if max_side_length is not None:
123
+ if height > width:
124
+ max_side_length_ratio = max_side_length / height
125
+ else:
126
+ max_side_length_ratio = max_side_length / width
127
+
128
+ cur_pixels = height * width
129
+ max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5
130
+ ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image
131
+
132
+ new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
133
+ return new_height, new_width
134
+
135
+ def preprocess(
136
+ self,
137
+ image: PipelineImageInput,
138
+ height: Optional[int] = None,
139
+ width: Optional[int] = None,
140
+ max_pixels: Optional[int] = None,
141
+ max_side_length: Optional[int] = None,
142
+ resize_mode: str = "default", # "default", "fill", "crop"
143
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
144
+ ) -> torch.Tensor:
145
+ """
146
+ Preprocess the image input.
147
+
148
+ Args:
149
+ image (`PipelineImageInput`):
150
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
151
+ supported formats.
152
+ height (`int`, *optional*):
153
+ The height of the preprocessed image. If `None`, `get_new_height_width()` is used to compute the default
154
+ height.
155
+ width (`int`, *optional*):
156
+ The width of the preprocessed image. If `None`, `get_new_height_width()` is used to compute the default width.
157
+ resize_mode (`str`, *optional*, defaults to `default`):
158
+ The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit within
159
+ the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
160
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
161
+ center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize the
162
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
163
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
164
+ supported for PIL image input.
165
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
166
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
167
+
168
+ Returns:
169
+ `torch.Tensor`:
170
+ The preprocessed image.
171
+ """
172
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
173
+
174
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
175
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
176
+ if isinstance(image, torch.Tensor):
177
+ # if image is a pytorch tensor could have 2 possible shapes:
178
+ # 1. batch x height x width: we should insert the channel dimension at position 1
179
+ # 2. channel x height x width: we should insert batch dimension at position 0,
180
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
181
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
182
+ image = image.unsqueeze(1)
183
+ else:
184
+ # if it is a numpy array, it could have 2 possible shapes:
185
+ # 1. batch x height x width: insert channel dimension on last position
186
+ # 2. height x width x channel: insert batch dimension on first position
187
+ if image.shape[-1] == 1:
188
+ image = np.expand_dims(image, axis=0)
189
+ else:
190
+ image = np.expand_dims(image, axis=-1)
191
+
192
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
193
+ warnings.warn(
194
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
195
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
196
+ FutureWarning,
197
+ )
198
+ image = np.concatenate(image, axis=0)
199
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
200
+ warnings.warn(
201
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
202
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
203
+ FutureWarning,
204
+ )
205
+ image = torch.cat(image, axis=0)
206
+
207
+ if not is_valid_image_imagelist(image):
208
+ raise ValueError(
209
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
210
+ )
211
+ if not isinstance(image, list):
212
+ image = [image]
213
+
214
+ if isinstance(image[0], PIL.Image.Image):
215
+ if crops_coords is not None:
216
+ image = [i.crop(crops_coords) for i in image]
217
+ if self.config.do_resize:
218
+ height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
219
+ image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
220
+ if self.config.do_convert_rgb:
221
+ image = [self.convert_to_rgb(i) for i in image]
222
+ elif self.config.do_convert_grayscale:
223
+ image = [self.convert_to_grayscale(i) for i in image]
224
+ image = self.pil_to_numpy(image) # to np
225
+ image = self.numpy_to_pt(image) # to pt
226
+
227
+ elif isinstance(image[0], np.ndarray):
228
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
229
+
230
+ image = self.numpy_to_pt(image)
231
+
232
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
233
+ if self.config.do_resize:
234
+ image = self.resize(image, height, width)
235
+
236
+ elif isinstance(image[0], torch.Tensor):
237
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
238
+
239
+ if self.config.do_convert_grayscale and image.ndim == 3:
240
+ image = image.unsqueeze(1)
241
+
242
+ channel = image.shape[1]
243
+ # don't need any preprocess if the image is latents
244
+ if channel == self.config.vae_latent_channels:
245
+ return image
246
+
247
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
248
+ if self.config.do_resize:
249
+ image = self.resize(image, height, width)
250
+
251
+ # expected range [0,1], normalize to [-1,1]
252
+ do_normalize = self.config.do_normalize
253
+ if do_normalize and image.min() < 0:
254
+ warnings.warn(
255
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
256
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
257
+ FutureWarning,
258
+ )
259
+ do_normalize = False
260
+ if do_normalize:
261
+ image = self.normalize(image)
262
+
263
+ if self.config.do_binarize:
264
+ image = self.binarize(image)
265
+
266
+ return image
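A minimal preprocessing sketch using the class above. The file path and size limits are placeholders; `vae_scale_factor=16` mirrors how the pipeline constructs this processor (`vae_scale_factor * 2`).

```py
from PIL import Image
from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor

processor = OmniGen2ImageProcessor(vae_scale_factor=16)
img = Image.open("input.png").convert("RGB")  # placeholder path

# Downscales so that H*W <= max_pixels and the longer side <= max_side_length (never upscales),
# snaps both sides to multiples of vae_scale_factor, and normalizes to [-1, 1].
pixels = processor.preprocess(img, max_pixels=1024 * 1024, max_side_length=2048)
print(pixels.shape)  # torch.Size([1, 3, H, W]) with H, W multiples of 16

# Round-trip back to PIL (postprocess is inherited from VaeImageProcessor).
out = processor.postprocess(pixels, output_type="pil")[0]
```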
omnigen2/pipelines/lora_pipeline.py ADDED
@@ -0,0 +1,388 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from huggingface_hub.utils import validate_hf_hub_args
20
+
21
+ from diffusers.utils import (
22
+ USE_PEFT_BACKEND,
23
+ is_peft_available,
24
+ is_peft_version,
25
+ is_torch_version,
26
+ is_transformers_available,
27
+ is_transformers_version,
28
+ logging,
29
+ )
30
+ from diffusers.loaders.lora_base import ( # noqa
31
+ LoraBaseMixin,
32
+ _fetch_state_dict,
33
+ )
34
+ from diffusers.loaders.lora_conversion_utils import (
35
+ _convert_non_diffusers_lumina2_lora_to_diffusers,
36
+ )
37
+
38
+
39
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
40
+ if is_torch_version(">=", "1.9.0"):
41
+ if (
42
+ is_peft_available()
43
+ and is_peft_version(">=", "0.13.1")
44
+ and is_transformers_available()
45
+ and is_transformers_version(">", "4.45.2")
46
+ ):
47
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ TRANSFORMER_NAME = "transformer"
53
+
54
+ class OmniGen2LoraLoaderMixin(LoraBaseMixin):
55
+ r"""
56
+ Load LoRA layers into [`OmniGen2Transformer2DModel`]. Specific to [`OmniGen2Pipeline`].
57
+ """
58
+
59
+ _lora_loadable_modules = ["transformer"]
60
+ transformer_name = TRANSFORMER_NAME
61
+
62
+ @classmethod
63
+ @validate_hf_hub_args
64
+ def lora_state_dict(
65
+ cls,
66
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
67
+ **kwargs,
68
+ ):
69
+ r"""
70
+ Return state dict for lora weights and the network alphas.
71
+
72
+ <Tip warning={true}>
73
+
74
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
75
+
76
+ This function is experimental and might change in the future.
77
+
78
+ </Tip>
79
+
80
+ Parameters:
81
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
82
+ Can be either:
83
+
84
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
85
+ the Hub.
86
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
87
+ with [`ModelMixin.save_pretrained`].
88
+ - A [torch state
89
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
90
+
91
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
92
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
93
+ is not used.
94
+ force_download (`bool`, *optional*, defaults to `False`):
95
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
96
+ cached versions if they exist.
97
+
98
+ proxies (`Dict[str, str]`, *optional*):
99
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
100
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
101
+ local_files_only (`bool`, *optional*, defaults to `False`):
102
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
103
+ won't be downloaded from the Hub.
104
+ token (`str` or *bool*, *optional*):
105
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
106
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
107
+ revision (`str`, *optional*, defaults to `"main"`):
108
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
109
+ allowed by Git.
110
+ subfolder (`str`, *optional*, defaults to `""`):
111
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
112
+
113
+ """
114
+ # Load the main state dict first which has the LoRA layers for either of
115
+ # transformer and text encoder or both.
116
+ cache_dir = kwargs.pop("cache_dir", None)
117
+ force_download = kwargs.pop("force_download", False)
118
+ proxies = kwargs.pop("proxies", None)
119
+ local_files_only = kwargs.pop("local_files_only", None)
120
+ token = kwargs.pop("token", None)
121
+ revision = kwargs.pop("revision", None)
122
+ subfolder = kwargs.pop("subfolder", None)
123
+ weight_name = kwargs.pop("weight_name", None)
124
+ use_safetensors = kwargs.pop("use_safetensors", None)
125
+
126
+ allow_pickle = False
127
+ if use_safetensors is None:
128
+ use_safetensors = True
129
+ allow_pickle = True
130
+
131
+ user_agent = {
132
+ "file_type": "attn_procs_weights",
133
+ "framework": "pytorch",
134
+ }
135
+
136
+ state_dict = _fetch_state_dict(
137
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
138
+ weight_name=weight_name,
139
+ use_safetensors=use_safetensors,
140
+ local_files_only=local_files_only,
141
+ cache_dir=cache_dir,
142
+ force_download=force_download,
143
+ proxies=proxies,
144
+ token=token,
145
+ revision=revision,
146
+ subfolder=subfolder,
147
+ user_agent=user_agent,
148
+ allow_pickle=allow_pickle,
149
+ )
150
+
151
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
152
+ if is_dora_scale_present:
153
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
154
+ logger.warning(warn_msg)
155
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
156
+
157
+ # conversion.
158
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
159
+ if non_diffusers:
160
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
161
+
162
+ return state_dict
163
+
164
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
165
+ def load_lora_weights(
166
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
167
+ ):
168
+ """
169
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
170
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
171
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
172
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
173
+ dict is loaded into `self.transformer`.
174
+
175
+ Parameters:
176
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
177
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
178
+ adapter_name (`str`, *optional*):
179
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
180
+ `default_{i}` where i is the total number of adapters being loaded.
181
+ low_cpu_mem_usage (`bool`, *optional*):
182
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
183
+ weights.
184
+ kwargs (`dict`, *optional*):
185
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
186
+ """
187
+ if not USE_PEFT_BACKEND:
188
+ raise ValueError("PEFT backend is required for this method.")
189
+
190
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
191
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
192
+ raise ValueError(
193
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
194
+ )
195
+
196
+ # if a dict is passed, copy it instead of modifying it inplace
197
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
198
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
199
+
200
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
201
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
202
+
203
+ is_correct_format = all("lora" in key for key in state_dict.keys())
204
+ if not is_correct_format:
205
+ raise ValueError("Invalid LoRA checkpoint.")
206
+
207
+ self.load_lora_into_transformer(
208
+ state_dict,
209
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
210
+ adapter_name=adapter_name,
211
+ _pipeline=self,
212
+ low_cpu_mem_usage=low_cpu_mem_usage,
213
+ )
214
+
215
+ @classmethod
216
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
217
+ def load_lora_into_transformer(
218
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
219
+ ):
220
+ """
221
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
222
+
223
+ Parameters:
224
+ state_dict (`dict`):
225
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
226
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
227
+ encoder lora layers.
228
+ transformer (`Lumina2Transformer2DModel`):
229
+ The Transformer model to load the LoRA layers into.
230
+ adapter_name (`str`, *optional*):
231
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
232
+ `default_{i}` where i is the total number of adapters being loaded.
233
+ low_cpu_mem_usage (`bool`, *optional*):
234
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
235
+ weights.
236
+ hotswap : (`bool`, *optional*)
237
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
238
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
239
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
240
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
241
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
242
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
243
+
244
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
245
+ to call an additional method before loading the adapter:
246
+
247
+ ```py
248
+ pipeline = ... # load diffusers pipeline
249
+ max_rank = ... # the highest rank among all LoRAs that you want to load
250
+ # call *before* compiling and loading the LoRA adapter
251
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
252
+ pipeline.load_lora_weights(file_name)
253
+ # optionally compile the model now
254
+ ```
255
+
256
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
257
+ limitations to this technique, which are documented here:
258
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
259
+ """
260
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
261
+ raise ValueError(
262
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
263
+ )
264
+
265
+ # Load the layers corresponding to transformer.
266
+ logger.info(f"Loading {cls.transformer_name}.")
267
+ transformer.load_lora_adapter(
268
+ state_dict,
269
+ network_alphas=None,
270
+ adapter_name=adapter_name,
271
+ _pipeline=_pipeline,
272
+ low_cpu_mem_usage=low_cpu_mem_usage,
273
+ hotswap=hotswap,
274
+ )
275
+
276
+ @classmethod
277
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
278
+ def save_lora_weights(
279
+ cls,
280
+ save_directory: Union[str, os.PathLike],
281
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
282
+ is_main_process: bool = True,
283
+ weight_name: str = None,
284
+ save_function: Callable = None,
285
+ safe_serialization: bool = True,
286
+ ):
287
+ r"""
288
+ Save the LoRA parameters corresponding to the transformer.
289
+
290
+ Arguments:
291
+ save_directory (`str` or `os.PathLike`):
292
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
293
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
294
+ State dict of the LoRA layers corresponding to the `transformer`.
295
+ is_main_process (`bool`, *optional*, defaults to `True`):
296
+ Whether the process calling this is the main process or not. Useful during distributed training when you
297
+ need to call this function on all processes; in that case, set `is_main_process=True` only on the main
298
+ process to avoid race conditions.
299
+ save_function (`Callable`):
300
+ The function to use to save the state dictionary. Useful during distributed training when you need to
301
+ replace `torch.save` with another method. Can be configured with the environment variable
302
+ `DIFFUSERS_SAVE_MODE`.
303
+ safe_serialization (`bool`, *optional*, defaults to `True`):
304
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
305
+ """
306
+ state_dict = {}
307
+
308
+ if not transformer_lora_layers:
309
+ raise ValueError("You must pass `transformer_lora_layers`.")
310
+
311
+ if transformer_lora_layers:
312
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
313
+
314
+ # Save the model
315
+ cls.write_lora_layers(
316
+ state_dict=state_dict,
317
+ save_directory=save_directory,
318
+ is_main_process=is_main_process,
319
+ weight_name=weight_name,
320
+ save_function=save_function,
321
+ safe_serialization=safe_serialization,
322
+ )
323
+
324
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
325
+ def fuse_lora(
326
+ self,
327
+ components: List[str] = ["transformer"],
328
+ lora_scale: float = 1.0,
329
+ safe_fusing: bool = False,
330
+ adapter_names: Optional[List[str]] = None,
331
+ **kwargs,
332
+ ):
333
+ r"""
334
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
335
+
336
+ <Tip warning={true}>
337
+
338
+ This is an experimental API.
339
+
340
+ </Tip>
341
+
342
+ Args:
343
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
344
+ lora_scale (`float`, defaults to 1.0):
345
+ Controls how much to influence the outputs with the LoRA parameters.
346
+ safe_fusing (`bool`, defaults to `False`):
347
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
348
+ adapter_names (`List[str]`, *optional*):
349
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
350
+
351
+ Example:
352
+
353
+ ```py
354
+ from diffusers import DiffusionPipeline
355
+ import torch
356
+
357
+ pipeline = DiffusionPipeline.from_pretrained(
358
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
359
+ ).to("cuda")
360
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
361
+ pipeline.fuse_lora(lora_scale=0.7)
362
+ ```
363
+ """
364
+ super().fuse_lora(
365
+ components=components,
366
+ lora_scale=lora_scale,
367
+ safe_fusing=safe_fusing,
368
+ adapter_names=adapter_names,
369
+ **kwargs,
370
+ )
371
+
372
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
373
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
374
+ r"""
375
+ Reverses the effect of
376
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
377
+
378
+ <Tip warning={true}>
379
+
380
+ This is an experimental API.
381
+
382
+ </Tip>
383
+
384
+ Args:
385
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
386
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
387
+ """
388
+ super().unfuse_lora(components=components, **kwargs)
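A usage sketch for the mixin above, as it is reached through `OmniGen2Pipeline` (which inherits it). The checkpoint directory, LoRA path, and weight name are placeholders.

```py
import torch
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

pipe = OmniGen2Pipeline.from_pretrained("path/to/OmniGen2", torch_dtype=torch.bfloat16).to("cuda")

pipe.load_lora_weights(
    "path/to/lora_dir",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="style",
)
pipe.fuse_lora(lora_scale=0.8)  # optionally bake the LoRA into the transformer weights
# ... run inference ...
pipe.unfuse_lora()              # restore the original transformer weights
```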
omnigen2/pipelines/omnigen2/pipeline_omnigen2.py ADDED
@@ -0,0 +1,774 @@
1
+ """
2
+ OmniGen2 Diffusion Pipeline
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import math
23
+
24
+ from PIL import Image
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+ from transformers import Qwen2_5_VLForConditionalGeneration
30
+
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import OmniGen2Transformer2DModel
33
+ from ...models.transformers.repo import OmniGen2RotaryPosEmbed
34
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.utils import (
36
+ is_torch_xla_available,
37
+ logging,
38
+ )
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
41
+
42
+ from dataclasses import dataclass
43
+
44
+ import PIL.Image
45
+
46
+ from diffusers.utils import BaseOutput
47
+
48
+ from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
49
+
50
+ from omnigen2.utils.teacache_util import TeaCacheParams
51
+
52
+ from ..lora_pipeline import OmniGen2LoraLoaderMixin
53
+
54
+
55
+ if is_torch_xla_available():
56
+ import torch_xla.core.xla_model as xm
57
+
58
+ XLA_AVAILABLE = True
59
+ else:
60
+ XLA_AVAILABLE = False
61
+
62
+ from ...cache_functions import cache_init
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+ @dataclass
67
+ class FMPipelineOutput(BaseOutput):
68
+ """
69
+ Output class for OmniGen2 pipeline.
70
+
71
+ Args:
72
+ images (Union[List[PIL.Image.Image], np.ndarray]):
73
+ List of denoised PIL images of length `batch_size` or numpy array of shape
74
+ `(batch_size, height, width, num_channels)`. Contains the generated images.
75
+ """
76
+ images: Union[List[PIL.Image.Image], np.ndarray]
77
+
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
80
+ def retrieve_timesteps(
81
+ scheduler,
82
+ num_inference_steps: Optional[int] = None,
83
+ device: Optional[Union[str, torch.device]] = None,
84
+ timesteps: Optional[List[int]] = None,
85
+ **kwargs,
86
+ ):
87
+ """
88
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
89
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
90
+
91
+ Args:
92
+ scheduler (`SchedulerMixin`):
93
+ The scheduler to get timesteps from.
94
+ num_inference_steps (`int`):
95
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
96
+ must be `None`.
97
+ device (`str` or `torch.device`, *optional*):
98
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
99
+ timesteps (`List[int]`, *optional*):
100
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
101
+ `num_inference_steps` must be `None`.
105
+
106
+ Returns:
107
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
108
+ second element is the number of inference steps.
109
+ """
110
+ if timesteps is not None:
111
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
112
+ if not accepts_timesteps:
113
+ raise ValueError(
114
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
115
+ f" timestep schedules. Please check whether you are using the correct scheduler."
116
+ )
117
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
118
+ timesteps = scheduler.timesteps
119
+ num_inference_steps = len(timesteps)
120
+ else:
121
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ return timesteps, num_inference_steps
124
+
125
+
126
+ class OmniGen2Pipeline(DiffusionPipeline, OmniGen2LoraLoaderMixin):
127
+ """
128
+ Pipeline for text-to-image generation using OmniGen2.
129
+
130
+ This pipeline implements a text-to-image generation model that uses:
131
+ - Qwen2.5-VL for text encoding
132
+ - A custom transformer architecture for image generation
133
+ - VAE for image encoding/decoding
134
+ - FlowMatchEulerDiscreteScheduler for noise scheduling
135
+
136
+ Args:
137
+ transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
138
+ vae (AutoencoderKL): The VAE model for image encoding/decoding.
139
+ scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
140
+ mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used to encode text prompts.
141
+ processor: The processor providing the tokenizer used for prompt processing.
142
+ """
143
+
144
+ model_cpu_offload_seq = "mllm->transformer->vae"
145
+
146
+ def __init__(
147
+ self,
148
+ transformer: OmniGen2Transformer2DModel,
149
+ vae: AutoencoderKL,
150
+ scheduler: FlowMatchEulerDiscreteScheduler,
151
+ mllm: Qwen2_5_VLForConditionalGeneration,
152
+ processor,
153
+ ) -> None:
154
+ """
155
+ Initialize the OmniGen2 pipeline.
156
+
157
+ Args:
158
+ transformer: The transformer model for image generation.
159
+ vae: The VAE model for image encoding/decoding.
160
+ scheduler: The scheduler for noise scheduling.
161
+ mllm: The multimodal LLM used to encode text prompts.
162
+ processor: The processor providing the tokenizer for prompt processing.
163
+ """
164
+ super().__init__()
165
+
166
+ self.register_modules(
167
+ transformer=transformer,
168
+ vae=vae,
169
+ scheduler=scheduler,
170
+ mllm=mllm,
171
+ processor=processor
172
+ )
173
+ self.vae_scale_factor = (
174
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
175
+ )
176
+ self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
177
+ self.default_sample_size = 128
178
+
179
+ def prepare_latents(
180
+ self,
181
+ batch_size: int,
182
+ num_channels_latents: int,
183
+ height: int,
184
+ width: int,
185
+ dtype: torch.dtype,
186
+ device: torch.device,
187
+ generator: Optional[torch.Generator],
188
+ latents: Optional[torch.FloatTensor] = None,
189
+ ) -> torch.FloatTensor:
190
+ """
191
+ Prepare the initial latents for the diffusion process.
192
+
193
+ Args:
194
+ batch_size: The number of images to generate.
195
+ num_channels_latents: The number of channels in the latent space.
196
+ height: The height of the generated image.
197
+ width: The width of the generated image.
198
+ dtype: The data type of the latents.
199
+ device: The device to place the latents on.
200
+ generator: The random number generator to use.
201
+ latents: Optional pre-computed latents to use instead of random initialization.
202
+
203
+ Returns:
204
+ torch.FloatTensor: The prepared latents tensor.
205
+ """
206
+ height = int(height) // self.vae_scale_factor
207
+ width = int(width) // self.vae_scale_factor
208
+
209
+ shape = (batch_size, num_channels_latents, height, width)
210
+
211
+ if latents is None:
212
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
213
+ else:
214
+ latents = latents.to(device)
215
+ return latents
216
+
217
+ def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
218
+ """
219
+ Encode an image into the VAE latent space.
220
+
221
+ Args:
222
+ img: The input image tensor to encode.
223
+
224
+ Returns:
225
+ torch.FloatTensor: The encoded latent representation.
226
+ """
227
+ z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
228
+ if self.vae.config.shift_factor is not None:
229
+ z0 = z0 - self.vae.config.shift_factor
230
+ if self.vae.config.scaling_factor is not None:
231
+ z0 = z0 * self.vae.config.scaling_factor
232
+ z0 = z0.to(dtype=self.vae.dtype)
233
+ return z0
234
+
235
+ def prepare_image(
236
+ self,
237
+ images: Union[List[PIL.Image.Image], PIL.Image.Image],
238
+ batch_size: int,
239
+ num_images_per_prompt: int,
240
+ max_pixels: int,
241
+ max_side_length: int,
242
+ device: torch.device,
243
+ dtype: torch.dtype,
244
+ ) -> List[Optional[torch.FloatTensor]]:
245
+ """
246
+ Prepare input images for processing by encoding them into the VAE latent space.
247
+
248
+ Args:
249
+ images: Single image or list of images to process.
250
+ batch_size: The number of images to generate per prompt.
251
+ num_images_per_prompt: The number of images to generate for each prompt.
252
+ device: The device to place the encoded latents on.
253
+ dtype: The data type of the encoded latents.
254
+
255
+ Returns:
256
+ List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
257
+ """
258
+ if batch_size == 1:
259
+ images = [images]
260
+ latents = []
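+ # Each entry of `latents` is either None (no reference images for that prompt) or a list of VAE latents,
+ # one per reference image; the entry is repeated once for every image generated per prompt.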
261
+ for i, img in enumerate(images):
262
+ if img is not None and len(img) > 0:
263
+ ref_latents = []
264
+ for j, img_j in enumerate(img):
265
+ img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
266
+ ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
267
+ else:
268
+ ref_latents = None
269
+ for _ in range(num_images_per_prompt):
270
+ latents.append(ref_latents)
271
+
272
+ return latents
273
+
274
+ def _get_qwen2_prompt_embeds(
275
+ self,
276
+ prompt: Union[str, List[str]],
277
+ device: Optional[torch.device] = None,
278
+ max_sequence_length: int = 256,
279
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
280
+ """
281
+ Get prompt embeddings from the Qwen2 text encoder.
282
+
283
+ Args:
284
+ prompt: The prompt or list of prompts to encode.
285
+ device: The device to place the embeddings on. If None, uses the pipeline's device.
286
+ max_sequence_length: Maximum sequence length for tokenization.
287
+
288
+ Returns:
289
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
290
+ - The prompt embeddings tensor
291
+ - The attention mask tensor
292
+
293
+ Raises:
294
+ Warning: If the input text is truncated due to sequence length limitations.
295
+ """
296
+ device = device or self._execution_device
297
+ prompt = [prompt] if isinstance(prompt, str) else prompt
298
+ # text_inputs = self.processor.tokenizer(
299
+ # prompt,
300
+ # padding="max_length",
301
+ # max_length=max_sequence_length,
302
+ # truncation=True,
303
+ # return_tensors="pt",
304
+ # )
305
+ text_inputs = self.processor.tokenizer(
306
+ prompt,
307
+ padding="longest",
308
+ max_length=max_sequence_length,
309
+ truncation=True,
310
+ return_tensors="pt",
311
+ )
312
+
313
+ text_input_ids = text_inputs.input_ids.to(device)
314
+ untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
315
+
316
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
317
+ removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
318
+ logger.warning(
319
+ "The following part of your input was truncated because Gemma can only handle sequences up to"
320
+ f" {max_sequence_length} tokens: {removed_text}"
321
+ )
322
+
323
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
324
+ prompt_embeds = self.mllm(
325
+ text_input_ids,
326
+ attention_mask=prompt_attention_mask,
327
+ output_hidden_states=True,
328
+ ).hidden_states[-1]
329
+
330
+ if self.mllm is not None:
331
+ dtype = self.mllm.dtype
332
+ elif self.transformer is not None:
333
+ dtype = self.transformer.dtype
334
+ else:
335
+ dtype = None
336
+
337
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
338
+
339
+ return prompt_embeds, prompt_attention_mask
340
+
341
+ def _apply_chat_template(self, prompt: str):
342
+ prompt = [
343
+ {
344
+ "role": "system",
345
+ "content": "You are a helpful assistant that generates high-quality images based on user instructions.",
346
+ },
347
+ {"role": "user", "content": prompt},
348
+ ]
349
+ prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
350
+ return prompt
351
+
352
+ def encode_prompt(
353
+ self,
354
+ prompt: Union[str, List[str]],
355
+ do_classifier_free_guidance: bool = True,
356
+ negative_prompt: Optional[Union[str, List[str]]] = None,
357
+ num_images_per_prompt: int = 1,
358
+ device: Optional[torch.device] = None,
359
+ prompt_embeds: Optional[torch.Tensor] = None,
360
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
361
+ prompt_attention_mask: Optional[torch.Tensor] = None,
362
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
363
+ max_sequence_length: int = 256,
364
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
365
+ r"""
366
+ Encodes the prompt into text encoder hidden states.
367
+
368
+ Args:
369
+ prompt (`str` or `List[str]`, *optional*):
370
+ prompt to be encoded
371
+ negative_prompt (`str` or `List[str]`, *optional*):
372
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
373
+ instead. Ignored when not using guidance (i.e., ignored if `text_guidance_scale` is not greater than `1`).
374
+ Defaults to the empty string "".
375
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
376
+ whether to use classifier free guidance or not
377
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
378
+ number of images that should be generated per prompt
379
+ device: (`torch.device`, *optional*):
380
+ torch device to place the resulting embeddings on
381
+ prompt_embeds (`torch.Tensor`, *optional*):
382
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
383
+ provided, text embeddings will be generated from `prompt` input argument.
384
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
385
+ Pre-generated negative text embeddings. By default, these are the embeddings of the empty string "".
386
+ max_sequence_length (`int`, defaults to `256`):
387
+ Maximum sequence length to use for the prompt.
388
+ """
389
+ device = device or self._execution_device
390
+
391
+ prompt = [prompt] if isinstance(prompt, str) else prompt
392
+ prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
393
+
394
+ if prompt is not None:
395
+ batch_size = len(prompt)
396
+ else:
397
+ batch_size = prompt_embeds.shape[0]
398
+ if prompt_embeds is None:
399
+ prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
400
+ prompt=prompt,
401
+ device=device,
402
+ max_sequence_length=max_sequence_length
403
+ )
404
+
405
+ batch_size, seq_len, _ = prompt_embeds.shape
406
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
407
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
408
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
409
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
410
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
411
+
412
+ # Get negative embeddings for classifier free guidance
413
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
414
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
415
+
416
+ # Normalize str to list
417
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
418
+ negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
419
+
420
+ if prompt is not None and type(prompt) is not type(negative_prompt):
421
+ raise TypeError(
422
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
423
+ f" {type(prompt)}."
424
+ )
425
+ elif isinstance(negative_prompt, str):
426
+ negative_prompt = [negative_prompt]
427
+ elif batch_size != len(negative_prompt):
428
+ raise ValueError(
429
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
430
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
431
+ " the batch size of `prompt`."
432
+ )
433
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
434
+ prompt=negative_prompt,
435
+ device=device,
436
+ max_sequence_length=max_sequence_length,
437
+ )
438
+
439
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
440
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
441
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
442
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
443
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
444
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
445
+ batch_size * num_images_per_prompt, -1
446
+ )
447
+
448
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
449
+
450
+ @property
451
+ def num_timesteps(self):
452
+ return self._num_timesteps
453
+
454
+ @property
455
+ def text_guidance_scale(self):
456
+ return self._text_guidance_scale
457
+
458
+ @property
459
+ def image_guidance_scale(self):
460
+ return self._image_guidance_scale
461
+
462
+ @property
463
+ def cfg_range(self):
464
+ return self._cfg_range
465
+
466
+ @torch.no_grad()
467
+ def __call__(
468
+ self,
469
+ prompt: Optional[Union[str, List[str]]] = None,
470
+ negative_prompt: Optional[Union[str, List[str]]] = None,
471
+ prompt_embeds: Optional[torch.FloatTensor] = None,
472
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
473
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
474
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
475
+ max_sequence_length: Optional[int] = None,
476
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
477
+ input_images: Optional[List[PIL.Image.Image]] = None,
478
+ num_images_per_prompt: int = 1,
479
+ height: Optional[int] = None,
480
+ width: Optional[int] = None,
481
+ max_pixels: int = 1024 * 1024,
482
+ max_input_image_side_length: int = 1024,
483
+ align_res: bool = True,
484
+ num_inference_steps: int = 28,
485
+ text_guidance_scale: float = 4.0,
486
+ image_guidance_scale: float = 1.0,
487
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
488
+ attention_kwargs: Optional[Dict[str, Any]] = None,
489
+ timesteps: List[int] = None,
490
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
491
+ latents: Optional[torch.FloatTensor] = None,
492
+ output_type: Optional[str] = "pil",
493
+ return_dict: bool = True,
494
+ verbose: bool = False,
495
+ step_func=None,
496
+ ):
497
+
498
+ height = height or self.default_sample_size * self.vae_scale_factor
499
+ width = width or self.default_sample_size * self.vae_scale_factor
500
+
501
+ self._text_guidance_scale = text_guidance_scale
502
+ self._image_guidance_scale = image_guidance_scale
503
+ self._cfg_range = cfg_range
504
+ self._attention_kwargs = attention_kwargs
505
+
506
+ # 2. Define call parameters
507
+ if prompt is not None and isinstance(prompt, str):
508
+ batch_size = 1
509
+ elif prompt is not None and isinstance(prompt, list):
510
+ batch_size = len(prompt)
511
+ else:
512
+ batch_size = prompt_embeds.shape[0]
513
+
514
+ device = self._execution_device
515
+
516
+ # 3. Encode input prompt
517
+ (
518
+ prompt_embeds,
519
+ prompt_attention_mask,
520
+ negative_prompt_embeds,
521
+ negative_prompt_attention_mask,
522
+ ) = self.encode_prompt(
523
+ prompt,
524
+ self.text_guidance_scale > 1.0,
525
+ negative_prompt=negative_prompt,
526
+ num_images_per_prompt=num_images_per_prompt,
527
+ device=device,
528
+ prompt_embeds=prompt_embeds,
529
+ negative_prompt_embeds=negative_prompt_embeds,
530
+ prompt_attention_mask=prompt_attention_mask,
531
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
532
+ max_sequence_length=max_sequence_length,
533
+ )
534
+
535
+ dtype = self.vae.dtype
536
+ # 3. Prepare control image
537
+ ref_latents = self.prepare_image(
538
+ images=input_images,
539
+ batch_size=batch_size,
540
+ num_images_per_prompt=num_images_per_prompt,
541
+ max_pixels=max_pixels,
542
+ max_side_length=max_input_image_side_length,
543
+ device=device,
544
+ dtype=dtype,
545
+ )
546
+
547
+ if input_images is None:
548
+ input_images = []
549
+
550
+ if len(input_images) == 1 and align_res:
551
+ width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
552
+ ori_width, ori_height = width, height
553
+ else:
554
+ ori_width, ori_height = width, height
555
+
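+ # Cap the target resolution at max_pixels (keeping the aspect ratio) and snap both sides down to multiples of 16.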
556
+ cur_pixels = height * width
557
+ ratio = (max_pixels / cur_pixels) ** 0.5
558
+ ratio = min(ratio, 1.0)
559
+
560
+ height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
561
+
562
+ if len(input_images) == 0:
563
+ self._image_guidance_scale = 1
564
+
565
+ # 4. Prepare latents.
566
+ latent_channels = self.transformer.config.in_channels
567
+ latents = self.prepare_latents(
568
+ batch_size * num_images_per_prompt,
569
+ latent_channels,
570
+ height,
571
+ width,
572
+ prompt_embeds.dtype,
573
+ device,
574
+ generator,
575
+ latents,
576
+ )
577
+
578
+ freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
579
+ self.transformer.config.axes_dim_rope,
580
+ self.transformer.config.axes_lens,
581
+ theta=10000,
582
+ )
583
+
584
+ image = self.processing(
585
+ latents=latents,
586
+ ref_latents=ref_latents,
587
+ prompt_embeds=prompt_embeds,
588
+ freqs_cis=freqs_cis,
589
+ negative_prompt_embeds=negative_prompt_embeds,
590
+ prompt_attention_mask=prompt_attention_mask,
591
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
592
+ num_inference_steps=num_inference_steps,
593
+ timesteps=timesteps,
594
+ device=device,
595
+ dtype=dtype,
596
+ verbose=verbose,
597
+ step_func=step_func,
598
+ )
599
+
600
+ image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
601
+
602
+ image = self.image_processor.postprocess(image, output_type=output_type)
603
+
604
+ # Offload all models
605
+ self.maybe_free_model_hooks()
606
+
607
+ if not return_dict:
608
+ return image
609
+ else:
610
+ return FMPipelineOutput(images=image)
611
+
612
+ def processing(
613
+ self,
614
+ latents,
615
+ ref_latents,
616
+ prompt_embeds,
617
+ freqs_cis,
618
+ negative_prompt_embeds,
619
+ prompt_attention_mask,
620
+ negative_prompt_attention_mask,
621
+ num_inference_steps,
622
+ timesteps,
623
+ device,
624
+ dtype,
625
+ verbose,
626
+ step_func=None
627
+ ):
628
+ batch_size = latents.shape[0]
629
+
630
+ timesteps, num_inference_steps = retrieve_timesteps(
631
+ self.scheduler,
632
+ num_inference_steps,
633
+ device,
634
+ timesteps,
635
+ num_tokens=latents.shape[-2] * latents.shape[-1]
636
+ )
637
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
638
+ self._num_timesteps = len(timesteps)
639
+
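+ # TaylorSeer/TeaCache keep a separate cache per guidance branch (conditional, image-conditional, unconditional)
+ # so that cached features from different conditions are never mixed.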
640
+ enable_taylorseer = getattr(self, "enable_taylorseer", False)
641
+ if enable_taylorseer:
642
+ model_pred_cache_dic, model_pred_current = cache_init(self, num_inference_steps)
643
+ model_pred_ref_cache_dic, model_pred_ref_current = cache_init(self, num_inference_steps)
644
+ model_pred_uncond_cache_dic, model_pred_uncond_current = cache_init(self, num_inference_steps)
645
+ self.transformer.enable_taylorseer = True
646
+ elif self.transformer.enable_teacache:
647
+ # Use different TeaCacheParams for different conditions
648
+ teacache_params = TeaCacheParams()
649
+ teacache_params_uncond = TeaCacheParams()
650
+ teacache_params_ref = TeaCacheParams()
651
+
652
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
653
+ for i, t in enumerate(timesteps):
654
+ if enable_taylorseer:
655
+ self.transformer.cache_dic = model_pred_cache_dic
656
+ self.transformer.current = model_pred_current
657
+ elif self.transformer.enable_teacache:
658
+ teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
659
+ self.transformer.teacache_params = teacache_params
660
+
661
+ model_pred = self.predict(
662
+ t=t,
663
+ latents=latents,
664
+ prompt_embeds=prompt_embeds,
665
+ freqs_cis=freqs_cis,
666
+ prompt_attention_mask=prompt_attention_mask,
667
+ ref_image_hidden_states=ref_latents,
668
+ )
669
+ text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
670
+ image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
671
+
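+ # Dual guidance: combine the unconditional, image-conditioned (negative prompt + reference images) and
+ # fully conditioned predictions; with no reference images only text guidance is applied.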
672
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
673
+ if enable_taylorseer:
674
+ self.transformer.cache_dic = model_pred_ref_cache_dic
675
+ self.transformer.current = model_pred_ref_current
676
+ elif self.transformer.enable_teacache:
677
+ teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
678
+ self.transformer.teacache_params = teacache_params_ref
679
+
680
+ model_pred_ref = self.predict(
681
+ t=t,
682
+ latents=latents,
683
+ prompt_embeds=negative_prompt_embeds,
684
+ freqs_cis=freqs_cis,
685
+ prompt_attention_mask=negative_prompt_attention_mask,
686
+ ref_image_hidden_states=ref_latents,
687
+ )
688
+
689
+ if enable_taylorseer:
690
+ self.transformer.cache_dic = model_pred_uncond_cache_dic
691
+ self.transformer.current = model_pred_uncond_current
692
+ elif self.transformer.enable_teacache:
693
+ teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
694
+ self.transformer.teacache_params = teacache_params_uncond
695
+
696
+ model_pred_uncond = self.predict(
697
+ t=t,
698
+ latents=latents,
699
+ prompt_embeds=negative_prompt_embeds,
700
+ freqs_cis=freqs_cis,
701
+ prompt_attention_mask=negative_prompt_attention_mask,
702
+ ref_image_hidden_states=None,
703
+ )
704
+
705
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
706
+ text_guidance_scale * (model_pred - model_pred_ref)
707
+ elif text_guidance_scale > 1.0:
708
+ if enable_taylorseer:
709
+ self.transformer.cache_dic = model_pred_uncond_cache_dic
710
+ self.transformer.current = model_pred_uncond_current
711
+ elif self.transformer.enable_teacache:
712
+ teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
713
+ self.transformer.teacache_params = teacache_params_uncond
714
+
715
+ model_pred_uncond = self.predict(
716
+ t=t,
717
+ latents=latents,
718
+ prompt_embeds=negative_prompt_embeds,
719
+ freqs_cis=freqs_cis,
720
+ prompt_attention_mask=negative_prompt_attention_mask,
721
+ ref_image_hidden_states=None,
722
+ )
723
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
724
+
725
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
726
+
727
+ latents = latents.to(dtype=dtype)
728
+
729
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
730
+ progress_bar.update()
731
+
732
+ if step_func is not None:
733
+ step_func(i, self._num_timesteps)
734
+
735
+ if enable_taylorseer:
736
+ del model_pred_cache_dic, model_pred_ref_cache_dic, model_pred_uncond_cache_dic
737
+ del model_pred_current, model_pred_ref_current, model_pred_uncond_current
738
+
739
+ latents = latents.to(dtype=dtype)
740
+ if self.vae.config.scaling_factor is not None:
741
+ latents = latents / self.vae.config.scaling_factor
742
+ if self.vae.config.shift_factor is not None:
743
+ latents = latents + self.vae.config.shift_factor
744
+ image = self.vae.decode(latents, return_dict=False)[0]
745
+
746
+ return image
747
+
748
+ def predict(
749
+ self,
750
+ t,
751
+ latents,
752
+ prompt_embeds,
753
+ freqs_cis,
754
+ prompt_attention_mask,
755
+ ref_image_hidden_states,
756
+ ):
757
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
758
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
759
+
760
+ batch_size, num_channels_latents, height, width = latents.shape
761
+
762
+ optional_kwargs = {}
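+ # Only pass reference latents when the loaded transformer's forward() accepts them, so the call also
+ # works with transformer variants that do not take `ref_image_hidden_states`.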
763
+ if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
764
+ optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
765
+
766
+ model_pred = self.transformer(
767
+ latents,
768
+ timestep,
769
+ prompt_embeds,
770
+ freqs_cis,
771
+ prompt_attention_mask,
772
+ **optional_kwargs
773
+ )
774
+ return model_pred
omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py ADDED
@@ -0,0 +1,830 @@
1
+ """
2
+ OmniGen2 Diffusion Pipeline
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import math
23
+
24
+ from PIL import Image
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+ from transformers import Qwen2_5_VLForConditionalGeneration
30
+
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import OmniGen2Transformer2DModel
33
+ from ...models.transformers.repo import OmniGen2RotaryPosEmbed
34
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.utils import (
36
+ is_torch_xla_available,
37
+ logging,
38
+ )
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
41
+
42
+ from dataclasses import dataclass
43
+
44
+ import PIL.Image
45
+
46
+ from diffusers.utils import BaseOutput
47
+
48
+ from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ @dataclass
61
+ class OmniGen2PipelineOutput(BaseOutput):
62
+ """
63
+ Output class for the OmniGen2 chat pipeline.
64
+
65
+ Args:
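+ text (str):
+ The text response generated by the multimodal LLM for the current chat turn.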
66
+ images (Union[List[PIL.Image.Image], np.ndarray]):
67
+ List of denoised PIL images of length `batch_size` or numpy array of shape
68
+ `(batch_size, height, width, num_channels)`. Contains the generated images.
69
+ """
70
+ text: str
71
+ images: Union[List[PIL.Image.Image], np.ndarray]
72
+
73
+
74
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
75
+ def retrieve_timesteps(
76
+ scheduler,
77
+ num_inference_steps: Optional[int] = None,
78
+ device: Optional[Union[str, torch.device]] = None,
79
+ timesteps: Optional[List[int]] = None,
80
+ **kwargs,
81
+ ):
82
+ """
83
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
+
86
+ Args:
87
+ scheduler (`SchedulerMixin`):
88
+ The scheduler to get timesteps from.
89
+ num_inference_steps (`int`):
90
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
+ must be `None`.
92
+ device (`str` or `torch.device`, *optional*):
93
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94
+ timesteps (`List[int]`, *optional*):
95
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
+ `num_inference_steps` and `sigmas` must be `None`.
97
+ sigmas (`List[float]`, *optional*):
98
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
+ `num_inference_steps` and `timesteps` must be `None`.
100
+
101
+ Returns:
102
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
+ second element is the number of inference steps.
104
+ """
105
+ if timesteps is not None:
106
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
107
+ if not accepts_timesteps:
108
+ raise ValueError(
109
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
110
+ f" timestep schedules. Please check whether you are using the correct scheduler."
111
+ )
112
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
113
+ timesteps = scheduler.timesteps
114
+ num_inference_steps = len(timesteps)
115
+ else:
116
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
117
+ timesteps = scheduler.timesteps
118
+ return timesteps, num_inference_steps
119
+
120
+
121
+ class OmniGen2ChatPipeline(DiffusionPipeline):
122
+ """
123
+ Pipeline for interleaved text and image generation (chat mode) using OmniGen2.
124
+
125
+ This pipeline implements a text-to-image generation model that uses:
126
+ - Qwen2.5-VL for text encoding
127
+ - A custom transformer architecture for image generation
128
+ - VAE for image encoding/decoding
129
+ - FlowMatchEulerDiscreteScheduler for noise scheduling
130
+
131
+ Args:
132
+ transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
133
+ vae (AutoencoderKL): The VAE model for image encoding/decoding.
134
+ scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
135
+ mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used for text generation and prompt encoding.
136
+ processor: The multimodal processor used for tokenization and image preprocessing.
137
+ """
138
+
139
+ model_cpu_offload_seq = "mllm->transformer->vae"
140
+ def __init__(
141
+ self,
142
+ transformer: OmniGen2Transformer2DModel,
143
+ vae: AutoencoderKL,
144
+ scheduler: FlowMatchEulerDiscreteScheduler,
145
+ mllm: Qwen2_5_VLForConditionalGeneration,
146
+ processor,
147
+ ) -> None:
148
+ """
149
+ Initialize the OmniGen2 pipeline.
150
+
151
+ Args:
152
+ transformer: The transformer model for image generation.
153
+ vae: The VAE model for image encoding/decoding.
154
+ scheduler: The scheduler for noise scheduling.
155
+ mllm: The multimodal LLM used for text generation and prompt encoding.
156
+ processor: The multimodal processor used for tokenization and image preprocessing.
157
+ """
158
+ super().__init__()
159
+
160
+ self.register_modules(
161
+ transformer=transformer,
162
+ vae=vae,
163
+ scheduler=scheduler,
164
+ mllm=mllm,
165
+ processor=processor
166
+ )
167
+ self.vae_scale_factor = (
168
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
169
+ )
170
+ self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
171
+ self.default_sample_size = 128
172
+
173
+ def prepare_latents(
174
+ self,
175
+ batch_size: int,
176
+ num_channels_latents: int,
177
+ height: int,
178
+ width: int,
179
+ dtype: torch.dtype,
180
+ device: torch.device,
181
+ generator: Optional[torch.Generator],
182
+ latents: Optional[torch.FloatTensor] = None,
183
+ ) -> torch.FloatTensor:
184
+ """
185
+ Prepare the initial latents for the diffusion process.
186
+
187
+ Args:
188
+ batch_size: The number of images to generate.
189
+ num_channels_latents: The number of channels in the latent space.
190
+ height: The height of the generated image.
191
+ width: The width of the generated image.
192
+ dtype: The data type of the latents.
193
+ device: The device to place the latents on.
194
+ generator: The random number generator to use.
195
+ latents: Optional pre-computed latents to use instead of random initialization.
196
+
197
+ Returns:
198
+ torch.FloatTensor: The prepared latents tensor.
199
+ """
200
+ height = int(height) // self.vae_scale_factor
201
+ width = int(width) // self.vae_scale_factor
202
+
203
+ shape = (batch_size, num_channels_latents, height, width)
204
+
205
+ if latents is None:
206
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
207
+ else:
208
+ latents = latents.to(device)
209
+ return latents
210
+
211
+ def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
212
+ """
213
+ Encode an image into the VAE latent space.
214
+
215
+ Args:
216
+ img: The input image tensor to encode.
217
+
218
+ Returns:
219
+ torch.FloatTensor: The encoded latent representation.
220
+ """
221
+ z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
222
+ if self.vae.config.shift_factor is not None:
223
+ z0 = z0 - self.vae.config.shift_factor
224
+ if self.vae.config.scaling_factor is not None:
225
+ z0 = z0 * self.vae.config.scaling_factor
226
+ z0 = z0.to(dtype=self.vae.dtype)
227
+ return z0
228
+
229
+ def prepare_image(
230
+ self,
231
+ images: Union[List[PIL.Image.Image], PIL.Image.Image],
232
+ batch_size: int,
233
+ num_images_per_prompt: int,
234
+ max_pixels: int,
235
+ max_side_length: int,
236
+ device: torch.device,
237
+ dtype: torch.dtype,
238
+ ) -> List[Optional[torch.FloatTensor]]:
239
+ """
240
+ Prepare input images for processing by encoding them into the VAE latent space.
241
+
242
+ Args:
243
+ images: Single image or list of images to process.
244
+ batch_size: The number of images to generate per prompt.
245
+ num_images_per_prompt: The number of images to generate for each prompt.
246
+ device: The device to place the encoded latents on.
247
+ dtype: The data type of the encoded latents.
248
+
249
+ Returns:
250
+ List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
251
+ """
252
+ if batch_size == 1:
253
+ images = [images]
254
+ latents = []
255
+ for i, img in enumerate(images):
256
+ if img is not None and len(img) > 0:
257
+ ref_latents = []
258
+ for j, img_j in enumerate(img):
259
+ img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
260
+ ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
261
+ else:
262
+ ref_latents = None
263
+ for _ in range(num_images_per_prompt):
264
+ latents.append(ref_latents)
265
+
266
+ return latents
267
+
268
+ def _apply_chat_template(self, prompt: str, images: List = None):
269
+ if images is not None:
270
+ prompt = "".join(
271
+ [
272
+ f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>"
273
+ for i in range(1, len(images) + 1)
274
+ ]
275
+ ) + prompt
276
+ prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
277
+ return prompt
278
+
279
+ def _get_qwen2_prompt_embeds(
280
+ self,
281
+ prompt: Union[str, List[str]],
282
+ input_images = None,
283
+ device: Optional[torch.device] = None,
284
+ use_only_text_hidden_states: bool = True,
285
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
286
+ """
287
+ Get prompt embeddings from the Qwen2 text encoder.
288
+
289
+ Args:
290
+ prompt: The prompt or list of prompts to encode.
291
+ device: The device to place the embeddings on. If None, uses the pipeline's device.
292
+
293
+ Returns:
294
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
295
+ - The prompt embeddings tensor
296
+ - The attention mask tensor
297
+
298
+ Raises:
299
+ Warning: If the input text is truncated due to sequence length limitations.
300
+ """
301
+ device = device or self._execution_device
302
+ prompt = [prompt] if isinstance(prompt, str) else prompt
303
+
304
+ inputs = self.processor(
305
+ text=prompt,
306
+ images=input_images,
307
+ videos=None,
308
+ padding=True,
309
+ return_tensors="pt",
310
+ )
311
+ inputs = inputs.to(device)
312
+
313
+ prompt_embeds = self.mllm(
314
+ **inputs,
315
+ output_hidden_states=True,
316
+ ).hidden_states[-1]
317
+
318
+ text_input_ids = inputs.input_ids
319
+ text_mask = inputs.attention_mask
320
+ if use_only_text_hidden_states:
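+ # Drop hidden states at image-pad token positions and left-pack the remaining text tokens into a
+ # freshly padded batch, so that only text hidden states remain in the prompt embeddings.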
321
+ mask = text_input_ids != self.mllm.config.image_token_id
322
+ mask = mask & text_mask
323
+ mask = mask.bool()
324
+
325
+ text_l = mask.sum(dim=-1)
326
+ max_l = text_l.max()
327
+ text_batch_size = prompt_embeds.size(0)
328
+ new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
329
+ new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
330
+ for i in range(text_batch_size):
331
+ new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
332
+ new_text_mask[i, :text_l[i]] = 1
333
+
334
+ prompt_embeds = new_prompt_embeds
335
+ text_mask = new_text_mask
336
+
337
+ prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
338
+ return prompt_embeds, text_mask
339
+
340
+
341
+ def encode_prompt(
342
+ self,
343
+ prompt: Union[str, List[str]],
344
+ input_images: Optional[Union[str, List[str]]] = None,
345
+ do_classifier_free_guidance: bool = True,
346
+ negative_prompt: Optional[Union[str, List[str]]] = None,
347
+ num_images_per_prompt: int = 1,
348
+ device: Optional[torch.device] = None,
349
+ prompt_embeds: Optional[torch.Tensor] = None,
350
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
351
+ prompt_attention_mask: Optional[torch.Tensor] = None,
352
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
353
+ max_sequence_length: int = 256,
354
+ use_text_encoder_penultimate_layer_feats: bool = False
355
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
356
+ r"""
357
+ Encodes the prompt into text encoder hidden states.
358
+
359
+ Args:
360
+ prompt (`str` or `List[str]`, *optional*):
361
+ prompt to be encoded
362
+ negative_prompt (`str` or `List[str]`, *optional*):
363
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
364
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
365
+ Lumina-T2I, this should be "".
366
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
367
+ whether to use classifier free guidance or not
368
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
369
+ number of images that should be generated per prompt
370
+ device: (`torch.device`, *optional*):
371
+ torch device to place the resulting embeddings on
372
+ prompt_embeds (`torch.Tensor`, *optional*):
373
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
374
+ provided, text embeddings will be generated from `prompt` input argument.
375
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
376
+ Pre-generated negative text embeddings. By default, these are the embeddings of the empty string "".
377
+ max_sequence_length (`int`, defaults to `256`):
378
+ Maximum sequence length to use for the prompt.
379
+ """
380
+ device = device or self._execution_device
381
+
382
+ prompt = [prompt] if isinstance(prompt, str) else prompt
383
+
384
+ if prompt is not None:
385
+ batch_size = len(prompt)
386
+ else:
387
+ batch_size = prompt_embeds.shape[0]
388
+ if prompt_embeds is None:
389
+ prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
390
+ prompt=prompt,
391
+ input_images=input_images,
392
+ device=device,
393
+ )
394
+
395
+ batch_size, seq_len, _ = prompt_embeds.shape
396
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
397
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
398
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
399
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
400
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
401
+
402
+ # Get negative embeddings for classifier free guidance
403
+ negative_prompt_embeds, negative_prompt_attention_mask = None, None
404
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
405
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
406
+
407
+ # Normalize str to list
408
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
409
+ negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
410
+
411
+ if prompt is not None and type(prompt) is not type(negative_prompt):
412
+ raise TypeError(
413
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
414
+ f" {type(prompt)}."
415
+ )
416
+ elif isinstance(negative_prompt, str):
417
+ negative_prompt = [negative_prompt]
418
+ elif batch_size != len(negative_prompt):
419
+ raise ValueError(
420
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
421
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
422
+ " the batch size of `prompt`."
423
+ )
424
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
425
+ prompt=negative_prompt,
426
+ device=device,
427
+ )
428
+
429
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
430
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
431
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
432
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
433
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
434
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
435
+ batch_size * num_images_per_prompt, -1
436
+ )
437
+
438
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
439
+
440
+ @property
441
+ def num_timesteps(self):
442
+ return self._num_timesteps
443
+
444
+ @property
445
+ def text_guidance_scale(self):
446
+ return self._text_guidance_scale
447
+
448
+ @property
449
+ def image_guidance_scale(self):
450
+ return self._image_guidance_scale
451
+
452
+ @property
453
+ def cfg_range(self):
454
+ return self._cfg_range
455
+
456
+ def prepare_inputs_for_text_generation(self, prompts, input_images, device):
457
+ if isinstance(prompts, str):
458
+ prompts = [prompts]
459
+
460
+ ori_padding_side = self.processor.tokenizer.padding_side
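+ # Switch to left padding for batched generation (new tokens must continue directly from the end of
+ # each prompt); the original padding side is restored afterwards.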
461
+ self.processor.tokenizer.padding_side = "left"
462
+ inputs = self.processor(
463
+ text=prompts,
464
+ images=input_images,
465
+ videos=None,
466
+ padding=True,
467
+ return_tensors="pt",
468
+ ).to(device)
469
+ self.processor.tokenizer.padding_side = ori_padding_side
470
+ return inputs
471
+
472
+ def generate_text(self, prompt, input_images):
473
+ inputs = self.prepare_inputs_for_text_generation(
474
+ prompt, input_images, self.mllm.device
475
+ )
476
+ generated_ids = self.mllm.generate(
477
+ **inputs,
478
+ tokenizer=self.processor.tokenizer,
479
+ max_new_tokens=256,
480
+ stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
481
+ ) # stop_words=[151643, 151645, 151665]
482
+ generated_ids_trimmed = [
483
+ out_ids[len(in_ids) :]
484
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
485
+ ]
486
+ output_texts = self.processor.batch_decode(
487
+ generated_ids_trimmed,
488
+ # skip_special_tokens=True,
489
+ skip_special_tokens=False,
490
+ clean_up_tokenization_spaces=False,
491
+ )
492
+ return output_texts
493
+
494
+ def generate_image(
495
+ self,
496
+ prompt: Optional[Union[str, List[str]]] = None,
497
+ negative_prompt: Optional[Union[str, List[str]]] = None,
498
+ prompt_embeds: Optional[torch.FloatTensor] = None,
499
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
500
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
501
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
502
+ use_text_encoder_penultimate_layer_feats: bool = False,
503
+ max_sequence_length: Optional[int] = None,
504
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
505
+ input_images: Optional[List[PIL.Image.Image]] = None,
506
+ num_images_per_prompt: int = 1,
507
+ height: Optional[int] = None,
508
+ width: Optional[int] = None,
509
+ max_pixels: int = 1024 * 1024,
510
+ max_input_image_side_length: int = 1024,
511
+ align_res: bool = True,
512
+ num_inference_steps: int = 28,
513
+ text_guidance_scale: float = 4.0,
514
+ image_guidance_scale: float = 1.0,
515
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
516
+ attention_kwargs: Optional[Dict[str, Any]] = None,
517
+ timesteps: List[int] = None,
518
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
519
+ latents: Optional[torch.FloatTensor] = None,
520
+ output_type: Optional[str] = "pil",
521
+ return_dict: bool = True,
522
+ verbose: bool = False,
523
+ step_func=None,
524
+ ):
525
+ height = height or self.default_sample_size * self.vae_scale_factor
526
+ width = width or self.default_sample_size * self.vae_scale_factor
527
+
528
+ self._text_guidance_scale = text_guidance_scale
529
+ self._image_guidance_scale = image_guidance_scale
530
+ self._cfg_range = cfg_range
531
+ self._attention_kwargs = attention_kwargs
532
+
533
+ # 2. Define call parameters
534
+ if prompt is not None and isinstance(prompt, str):
535
+ batch_size = 1
536
+ elif prompt is not None and isinstance(prompt, list):
537
+ batch_size = len(prompt)
538
+ else:
539
+ batch_size = prompt_embeds.shape[0]
540
+
541
+ device = self._execution_device
542
+
543
+ # 3. Encode input prompt
544
+ (
545
+ prompt_embeds,
546
+ prompt_attention_mask,
547
+ negative_prompt_embeds,
548
+ negative_prompt_attention_mask,
549
+ ) = self.encode_prompt(
550
+ prompt,
551
+ input_images,
552
+ self.text_guidance_scale > 1.0,
553
+ negative_prompt=negative_prompt,
554
+ num_images_per_prompt=num_images_per_prompt,
555
+ device=device,
556
+ prompt_embeds=prompt_embeds,
557
+ negative_prompt_embeds=negative_prompt_embeds,
558
+ prompt_attention_mask=prompt_attention_mask,
559
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
560
+ max_sequence_length=max_sequence_length,
561
+ use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
562
+ )
563
+
564
+ dtype = self.vae.dtype
565
+ # 3. Prepare control image
566
+ ref_latents = self.prepare_image(
567
+ images=input_images,
568
+ batch_size=batch_size,
569
+ num_images_per_prompt=num_images_per_prompt,
570
+ max_pixels=max_pixels,
571
+ max_side_length=max_input_image_side_length,
572
+ device=device,
573
+ dtype=dtype,
574
+ )
575
+
576
+ if input_images is None:
577
+ input_images = []
578
+
579
+ if len(input_images) == 1 and align_res:
580
+ width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
581
+ ori_width, ori_height = width, height
582
+ else:
583
+ ori_width, ori_height = width, height
584
+
585
+ cur_pixels = height * width
586
+ ratio = (max_pixels / cur_pixels) ** 0.5
587
+ ratio = min(ratio, 1.0)
588
+
589
+ height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
590
+
591
+ if len(input_images) == 0:
592
+ self._image_guidance_scale = 1
593
+
594
+ # 4. Prepare latents.
595
+ latent_channels = self.transformer.config.in_channels
596
+ latents = self.prepare_latents(
597
+ batch_size * num_images_per_prompt,
598
+ latent_channels,
599
+ height,
600
+ width,
601
+ prompt_embeds.dtype,
602
+ device,
603
+ generator,
604
+ latents,
605
+ )
606
+
607
+ freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
608
+ self.transformer.config.axes_dim_rope,
609
+ self.transformer.config.axes_lens,
610
+ theta=10000,
611
+ )
612
+
613
+ image = self.processing(
614
+ latents=latents,
615
+ ref_latents=ref_latents,
616
+ prompt_embeds=prompt_embeds,
617
+ freqs_cis=freqs_cis,
618
+ negative_prompt_embeds=negative_prompt_embeds,
619
+ prompt_attention_mask=prompt_attention_mask,
620
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
621
+ num_inference_steps=num_inference_steps,
622
+ timesteps=timesteps,
623
+ device=device,
624
+ dtype=dtype,
625
+ verbose=verbose,
626
+ step_func=step_func,
627
+ )
628
+
629
+ image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
630
+
631
+ image = self.image_processor.postprocess(image, output_type=output_type)
632
+
633
+ # Offload all models
634
+ self.maybe_free_model_hooks()
635
+ return image
636
+
637
+ @torch.no_grad()
638
+ def __call__(
639
+ self,
640
+ prompt: Optional[Union[str, List[str]]] = None,
641
+ negative_prompt: Optional[Union[str, List[str]]] = None,
642
+ prompt_embeds: Optional[torch.FloatTensor] = None,
643
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
644
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
645
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
646
+ use_text_encoder_penultimate_layer_feats: bool = False,
647
+ max_sequence_length: Optional[int] = None,
648
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
649
+ input_images: Optional[List[PIL.Image.Image]] = None,
650
+ num_images_per_prompt: int = 1,
651
+ height: Optional[int] = 1024,
652
+ width: Optional[int] = 1024,
653
+ max_pixels: Optional[int] = 1024 * 1024,
654
+ max_input_image_side_length: int = 1024,
655
+ align_res: bool = True,
656
+ num_inference_steps: int = 28,
657
+ text_guidance_scale: float = 4.0,
658
+ image_guidance_scale: float = 1.0,
659
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
660
+ attention_kwargs: Optional[Dict[str, Any]] = None,
661
+ timesteps: List[int] = None,
662
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
663
+ latents: Optional[torch.FloatTensor] = None,
664
+ output_type: Optional[str] = "pil",
665
+ return_dict: bool = True,
666
+ verbose: bool = False,
667
+ step_func=None,
668
+ ):
669
+ assert isinstance(prompt, str), "prompt must be a string since chat mode only supports one prompt per turn"
670
+
671
+ # input_images = self.preprocess_images(input_images, max_input_image_size)
672
+ prompt = self._apply_chat_template(prompt, input_images)
673
+ generated_text = self.generate_text(prompt, input_images)[0]
674
+
675
+ images = None
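+ # The MLLM signals that an image should be produced by emitting the special <|img|> token;
+ # only then is the diffusion branch run, conditioned on the chat-formatted prompt plus any text
+ # generated before the <|img|> token.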
676
+ if generated_text.startswith("<|img|>"):
677
+ #TODO: reuse the hidden states from text generation instead of re-encoding the prompt
678
+ prompt = prompt + generated_text.split("<|img|>")[0]
679
+ images = self.generate_image(
680
+ prompt=prompt,
681
+ negative_prompt=negative_prompt,
682
+ use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
683
+ max_sequence_length=max_sequence_length,
684
+ input_images=input_images,
685
+ num_images_per_prompt=num_images_per_prompt,
686
+ height=height,
687
+ width=width,
688
+ max_pixels=max_pixels,
689
+ max_input_image_side_length=max_input_image_side_length,
690
+ align_res=align_res,
691
+ num_inference_steps=num_inference_steps,
692
+ text_guidance_scale=text_guidance_scale,
693
+ image_guidance_scale=image_guidance_scale,
694
+ cfg_range=cfg_range,
695
+ timesteps=timesteps,
696
+ generator=generator,
697
+ latents=latents,
698
+ return_dict=False,
699
+ verbose=verbose,
700
+ step_func=step_func,
701
+ )
702
+
703
+ generated_text = generated_text.replace("<|im_end|>", "")
704
+ if not return_dict:
705
+ return generated_text, images
706
+ else:
707
+ return OmniGen2PipelineOutput(text=generated_text, images=images)
708
+
709
+ def processing(
710
+ self,
711
+ latents,
712
+ ref_latents,
713
+ prompt_embeds,
714
+ freqs_cis,
715
+ negative_prompt_embeds,
716
+ prompt_attention_mask,
717
+ negative_prompt_attention_mask,
718
+ num_inference_steps,
719
+ timesteps,
720
+ device,
721
+ dtype,
722
+ verbose,
723
+ step_func=None
724
+ ):
725
+ batch_size = latents.shape[0]
726
+
727
+ timesteps, num_inference_steps = retrieve_timesteps(
728
+ self.scheduler,
729
+ num_inference_steps,
730
+ device,
731
+ timesteps,
732
+ num_tokens=latents.shape[-2] * latents.shape[-1]
733
+ )
734
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
735
+ self._num_timesteps = len(timesteps)
736
+
737
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
738
+ for i, t in enumerate(timesteps):
739
+ model_pred = self.predict(
740
+ t=t,
741
+ latents=latents,
742
+ prompt_embeds=prompt_embeds,
743
+ freqs_cis=freqs_cis,
744
+ prompt_attention_mask=prompt_attention_mask,
745
+ ref_image_hidden_states=ref_latents,
746
+ )
747
+
748
+ text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
749
+ image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
750
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
751
+ model_pred_ref = self.predict(
752
+ t=t,
753
+ latents=latents,
754
+ prompt_embeds=negative_prompt_embeds,
755
+ freqs_cis=freqs_cis,
756
+ prompt_attention_mask=negative_prompt_attention_mask,
757
+ ref_image_hidden_states=ref_latents,
758
+ )
759
+
760
+ if image_guidance_scale != 1:
761
+ model_pred_uncond = self.predict(
762
+ t=t,
763
+ latents=latents,
764
+ prompt_embeds=negative_prompt_embeds,
765
+ freqs_cis=freqs_cis,
766
+ prompt_attention_mask=negative_prompt_attention_mask,
767
+ ref_image_hidden_states=None,
768
+ )
769
+ else:
770
+ model_pred_uncond = torch.zeros_like(model_pred)
771
+
772
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
773
+ text_guidance_scale * (model_pred - model_pred_ref)
774
+ elif text_guidance_scale > 1.0:
775
+ model_pred_uncond = self.predict(
776
+ t=t,
777
+ latents=latents,
778
+ prompt_embeds=negative_prompt_embeds,
779
+ freqs_cis=freqs_cis,
780
+ prompt_attention_mask=negative_prompt_attention_mask,
781
+ ref_image_hidden_states=None,
782
+ )
783
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
784
+
785
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
786
+
787
+ latents = latents.to(dtype=dtype)
788
+
789
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
790
+ progress_bar.update()
791
+
792
+ if step_func is not None:
793
+ step_func(i, self._num_timesteps)
794
+
795
+ latents = latents.to(dtype=dtype)
796
+ if self.vae.config.scaling_factor is not None:
797
+ latents = latents / self.vae.config.scaling_factor
798
+ if self.vae.config.shift_factor is not None:
799
+ latents = latents + self.vae.config.shift_factor
800
+ image = self.vae.decode(latents, return_dict=False)[0]
801
+
802
+ return image
803
+
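The loop above runs up to three forward passes per step: a fully conditioned prediction, a reference-only prediction (negative text, reference latents kept), and an unconditional prediction. A minimal sketch of how those are combined, mirroring the arithmetic used above on plain tensors (not the pipeline's API):

    import torch

    def combine_guidance(pred_cond, pred_ref, pred_uncond, text_scale, image_scale):
        # pred_cond:   text prompt + reference images
        # pred_ref:    negative prompt, reference images kept
        # pred_uncond: negative prompt, no reference images
        return (
            pred_uncond
            + image_scale * (pred_ref - pred_uncond)
            + text_scale * (pred_cond - pred_ref)
        )

    # With text_scale = image_scale = 1.0 this collapses to pred_cond, which matches
    # the behaviour outside cfg_range where both scales are forced to 1.0.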
804
+ def predict(
805
+ self,
806
+ t,
807
+ latents,
808
+ prompt_embeds,
809
+ freqs_cis,
810
+ prompt_attention_mask,
811
+ ref_image_hidden_states,
812
+ ):
813
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
814
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
815
+
816
+ batch_size, num_channels_latents, height, width = latents.shape
817
+
818
+ optional_kwargs = {}
819
+ if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
820
+ optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
821
+
822
+ model_pred = self.transformer(
823
+ latents,
824
+ timestep,
825
+ prompt_embeds,
826
+ freqs_cis,
827
+ prompt_attention_mask,
828
+ **optional_kwargs
829
+ )
830
+ return model_pred
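The `inspect.signature` check in `predict` lets the same pipeline drive transformer variants that may or may not accept `ref_image_hidden_states`. A small, generic sketch of that pattern (standalone, not tied to the OmniGen2 transformer):

    import inspect

    def call_with_optional(fn, *args, **maybe_kwargs):
        # Forward only the keyword arguments that `fn` actually declares.
        accepted = set(inspect.signature(fn).parameters)
        return fn(*args, **{k: v for k, v in maybe_kwargs.items() if k in accepted})

    def new_forward(x, ref_image_hidden_states=None):
        return x if ref_image_hidden_states is None else x + ref_image_hidden_states

    def old_forward(x):
        return x

    print(call_with_optional(new_forward, 1, ref_image_hidden_states=2))  # 3
    print(call_with_optional(old_forward, 1, ref_image_hidden_states=2))  # 1, kwarg dropped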
omnigen2/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+
3
+
4
+ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
5
+ """ Get pipeline embeds for prompts bigger than the maxlength of the pipe
6
+ :param pipeline:
7
+ :param prompt:
8
+ :param negative_prompt:
9
+ :param device:
10
+ :return:
11
+ """
12
+ max_length = pipeline.tokenizer.model_max_length
13
+
14
+ # simple way to determine length of tokens
15
+ # count_prompt = len(prompt.split(" "))
16
+ # count_negative_prompt = len(negative_prompt.split(" "))
17
+
18
+ # create the tensor based on which prompt is longer
19
+ # if count_prompt >= count_negative_prompt:
20
+ input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device)
21
+ # input_ids = pipeline.tokenizer(prompt, padding="max_length",
22
+ # max_length=pipeline.tokenizer.model_max_length,
23
+ # truncation=True,
24
+ # return_tensors="pt",).input_ids.to(device)
25
+ shape_max_length = input_ids.shape[-1]
26
+
27
+ if negative_prompt is not None:
28
+ negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
29
+ max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
30
+
31
+ # else:
32
+ # negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
33
+ # shape_max_length = negative_ids.shape[-1]
34
+ # input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
35
+ # max_length=shape_max_length).input_ids.to(device)
36
+
37
+ concat_embeds = []
38
+ neg_embeds = []
39
+ for i in range(0, shape_max_length, max_length):
40
+ if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
41
+ attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device)
42
+ else:
43
+ attention_mask = None
44
+ concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
45
+ attention_mask=attention_mask)[0])
46
+
47
+ if negative_prompt is not None:
48
+ if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
49
+ attention_mask = (negative_ids[:, i: i + max_length] != pipeline.tokenizer.pad_token_id).long().to(device)  # sliced ids are a plain tensor, so derive the mask from the pad token
50
+ else:
51
+ attention_mask = None
52
+ neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
53
+ attention_mask=attention_mask)[0])
54
+
55
+ concat_embeds = torch.cat(concat_embeds, dim=1)
56
+
57
+ if negative_prompt is not None:
58
+ neg_embeds = torch.cat(neg_embeds, dim=1)
59
+ else:
60
+ neg_embeds = None
61
+
62
+ return concat_embeds, neg_embeds
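A usage sketch for the helper above. It tokenizes the prompt without truncation, encodes it in `tokenizer.model_max_length`-sized chunks, and concatenates the chunk embeddings along the sequence dimension, padding the negative prompt to the same token length. `pipe` is an assumption here: any pipeline object exposing a `tokenizer` and a `text_encoder` as the function expects; the device string is a placeholder.

    long_prompt = "a highly detailed scene with " + ", ".join(["ornament"] * 200)
    negative_prompt = "low quality, blurry"

    prompt_embeds, negative_embeds = get_pipeline_embeds(
        pipe, long_prompt, negative_prompt, device="cuda"
    )
    assert prompt_embeds.shape[1] == negative_embeds.shape[1]  # same sequence length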
omnigen2/schedulers/__init__.py ADDED
File without changes
omnigen2/schedulers/scheduling_dpmsolver_multistep.py ADDED
@@ -0,0 +1,1052 @@
1
+ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import deprecate, is_scipy_available
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+
29
+ if is_scipy_available():
30
+ import scipy.stats
31
+
32
+
33
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
34
+ def betas_for_alpha_bar(
35
+ num_diffusion_timesteps,
36
+ max_beta=0.999,
37
+ alpha_transform_type="cosine",
38
+ ):
39
+ """
40
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
41
+ (1-beta) over time from t = [0,1].
42
+
43
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
44
+ to that part of the diffusion process.
45
+
46
+
47
+ Args:
48
+ num_diffusion_timesteps (`int`): the number of betas to produce.
49
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
50
+ prevent singularities.
51
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
52
+ Choose from `cosine` or `exp`
53
+
54
+ Returns:
55
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
56
+ """
57
+ if alpha_transform_type == "cosine":
58
+
59
+ def alpha_bar_fn(t):
60
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
61
+
62
+ elif alpha_transform_type == "exp":
63
+
64
+ def alpha_bar_fn(t):
65
+ return math.exp(t * -12.0)
66
+
67
+ else:
68
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
69
+
70
+ betas = []
71
+ for i in range(num_diffusion_timesteps):
72
+ t1 = i / num_diffusion_timesteps
73
+ t2 = (i + 1) / num_diffusion_timesteps
74
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
75
+ return torch.tensor(betas, dtype=torch.float32)
76
+
77
+
78
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
79
+ def rescale_zero_terminal_snr(betas):
80
+ """
81
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
82
+
83
+
84
+ Args:
85
+ betas (`torch.Tensor`):
86
+ the betas that the scheduler is being initialized with.
87
+
88
+ Returns:
89
+ `torch.Tensor`: rescaled betas with zero terminal SNR
90
+ """
91
+ # Convert betas to alphas_bar_sqrt
92
+ alphas = 1.0 - betas
93
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
94
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
95
+
96
+ # Store old values.
97
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
98
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
99
+
100
+ # Shift so the last timestep is zero.
101
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
102
+
103
+ # Scale so the first timestep is back to the old value.
104
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
105
+
106
+ # Convert alphas_bar_sqrt to betas
107
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
108
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
109
+ alphas = torch.cat([alphas_bar[0:1], alphas])
110
+ betas = 1 - alphas
111
+
112
+ return betas
113
+
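A quick numerical check of the rescaling above: after `rescale_zero_terminal_snr` the cumulative alpha at the final timestep is exactly zero (the last step is pure noise) while the first timestep is left essentially unchanged. Sketch, reusing `betas_for_alpha_bar` and the `torch` import from this file:

    betas = betas_for_alpha_bar(1000)            # cosine schedule
    rescaled = rescale_zero_terminal_snr(betas)

    alphas_bar = torch.cumprod(1.0 - rescaled, dim=0)
    assert alphas_bar[-1].abs() < 1e-6           # terminal SNR is zero
    assert torch.allclose(alphas_bar[0], 1.0 - betas[0], rtol=1e-4)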
114
+
115
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
116
+ """
117
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
118
+
119
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
120
+ methods the library implements for all schedulers such as loading and saving.
121
+
122
+ Args:
123
+ num_train_timesteps (`int`, defaults to 1000):
124
+ The number of diffusion steps to train the model.
125
+ beta_start (`float`, defaults to 0.0001):
126
+ The starting `beta` value of inference.
127
+ beta_end (`float`, defaults to 0.02):
128
+ The final `beta` value.
129
+ beta_schedule (`str`, defaults to `"linear"`):
130
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
131
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
132
+ trained_betas (`np.ndarray`, *optional*):
133
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
134
+ solver_order (`int`, defaults to 2):
135
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
136
+ sampling, and `solver_order=3` for unconditional sampling.
137
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
138
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
139
+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
140
+ Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
141
+ thresholding (`bool`, defaults to `False`):
142
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
143
+ as Stable Diffusion.
144
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
145
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
146
+ sample_max_value (`float`, defaults to 1.0):
147
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
148
+ `algorithm_type="dpmsolver++"`.
149
+ algorithm_type (`str`, defaults to `dpmsolver++`):
150
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
151
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
152
+ paper, and the `dpmsolver++` type implements the algorithms in the
153
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
154
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
155
+ solver_type (`str`, defaults to `midpoint`):
156
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
157
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
158
+ lower_order_final (`bool`, defaults to `True`):
159
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
160
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
161
+ euler_at_final (`bool`, defaults to `False`):
162
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
163
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
164
+ steps, but sometimes may result in blurring.
165
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
166
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
167
+ the sigmas are determined according to a sequence of noise levels {σi}.
168
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
169
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
170
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
171
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
172
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
173
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
174
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
175
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
176
+ `lambda(t)`.
177
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
178
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
179
+ flow_shift (`float`, *optional*, defaults to 1.0):
180
+ The shift value for the timestep schedule for flow matching.
181
+ final_sigmas_type (`str`, defaults to `"zero"`):
182
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
183
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
184
+ lambda_min_clipped (`float`, defaults to `-inf`):
185
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
186
+ cosine (`squaredcos_cap_v2`) noise schedule.
187
+ variance_type (`str`, *optional*):
188
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
189
+ contains the predicted Gaussian variance.
190
+ timestep_spacing (`str`, defaults to `"linspace"`):
191
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
192
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
193
+ steps_offset (`int`, defaults to 0):
194
+ An offset added to the inference steps, as required by some model families.
195
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
196
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
197
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
198
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
199
+ """
200
+
201
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
202
+ order = 1
203
+
204
+ @register_to_config
205
+ def __init__(
206
+ self,
207
+ num_train_timesteps: int = 1000,
208
+ beta_start: float = 0.0001,
209
+ beta_end: float = 0.02,
210
+ beta_schedule: str = "linear",
211
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
212
+ solver_order: int = 2,
213
+ prediction_type: str = "epsilon",
214
+ thresholding: bool = False,
215
+ dynamic_thresholding_ratio: float = 0.995,
216
+ sample_max_value: float = 1.0,
217
+ algorithm_type: str = "dpmsolver++",
218
+ solver_type: str = "midpoint",
219
+ lower_order_final: bool = True,
220
+ euler_at_final: bool = False,
221
+ final_sigmas_type: str = 'zero',
222
+ dynamic_time_shift: bool = True
223
+ ):
224
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
225
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
226
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
227
+
228
+ if trained_betas is not None:
229
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
230
+ elif beta_schedule == "linear":
231
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
232
+ elif beta_schedule == "scaled_linear":
233
+ # this schedule is very specific to the latent diffusion model.
234
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
235
+ elif beta_schedule == "squaredcos_cap_v2":
236
+ # Glide cosine schedule
237
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
238
+ else:
239
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
240
+ self.alphas = 1.0 - self.betas
241
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
242
+
243
+ # Currently we only support VP-type noise schedule
244
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
245
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
246
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
247
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
248
+
249
+ # standard deviation of the initial noise distribution
250
+ self.init_noise_sigma = 1.0
251
+
252
+ # settings for DPM-Solver
253
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
254
+ if algorithm_type == "deis":
255
+ self.register_to_config(algorithm_type="dpmsolver++")
256
+ else:
257
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
258
+
259
+ if solver_type not in ["midpoint", "heun"]:
260
+ if solver_type in ["logrho", "bh1", "bh2"]:
261
+ self.register_to_config(solver_type="midpoint")
262
+ else:
263
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
264
+
265
+ # if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
266
+ # raise ValueError(
267
+ # f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
268
+ # )
269
+
270
+ # setable values
271
+ self.num_inference_steps = None
272
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
273
+ self.timesteps = torch.from_numpy(timesteps)
274
+ self.model_outputs = [None] * solver_order
275
+ self.lower_order_nums = 0
276
+ self._step_index = None
277
+ self._begin_index = None
278
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
279
+
280
+ @property
281
+ def step_index(self):
282
+ """
283
+ The index counter for current timestep. It will increase 1 after each scheduler step.
284
+ """
285
+ return self._step_index
286
+
287
+ @property
288
+ def begin_index(self):
289
+ """
290
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
291
+ """
292
+ return self._begin_index
293
+
294
+ def set_begin_index(self, begin_index: int = 0):
295
+ """
296
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
297
+
298
+ Args:
299
+ begin_index (`int`):
300
+ The begin index for the scheduler.
301
+ """
302
+ self._begin_index = begin_index
303
+
304
+ def set_timesteps(
305
+ self,
306
+ num_inference_steps: int = None,
307
+ device: Union[str, torch.device] = None,
308
+ timesteps: Optional[List[int]] = None,
309
+ num_tokens: Optional[int] = None
310
+ ):
311
+ if timesteps is None:
312
+ self.num_inference_steps = num_inference_steps
313
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
314
+ if self.config.dynamic_time_shift and num_tokens is not None:
315
+ m = np.sqrt(num_tokens) / 40  # m = 1 for a 320x320 input (40x40 latent), m = 3.2 for 1024x1024 (128x128 latent)
316
+ timesteps = timesteps / (m - m * timesteps + timesteps)
317
+
318
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
319
+ sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)])
320
+
321
+ self.sigmas = sigmas
322
+ self.timesteps = timesteps
323
+
324
+ self.num_inference_steps = len(timesteps)
325
+
326
+ self.model_outputs = [
327
+ None,
328
+ ] * self.config.solver_order
329
+ self.lower_order_nums = 0
330
+
331
+ # add an index counter for schedulers that allow duplicated timesteps
332
+ self._step_index = None
333
+ self._begin_index = None
334
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
335
+
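The `dynamic_time_shift` branch above warps the uniform grid with t -> t / (m - m*t + t), m = sqrt(num_tokens) / 40, so larger latents spend a bigger fraction of the trajectory at high noise (sigma = 1 - t stays near 1 for longer). A small numeric sketch, independent of the scheduler object:

    import numpy as np

    num_tokens = 128 * 128                 # 1024x1024 image -> 128x128 latent grid
    m = np.sqrt(num_tokens) / 40           # 3.2
    t = np.array([0.25, 0.5, 0.75])
    shifted = t / (m - m * t + t)
    # shifted ~= [0.094, 0.238, 0.484]: every timestep is pulled toward 0,
    # so more of the inference steps land in the high-noise regime.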
336
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
337
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
338
+ """
339
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
340
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
341
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
342
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
343
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
344
+
345
+ https://arxiv.org/abs/2205.11487
346
+ """
347
+ dtype = sample.dtype
348
+ batch_size, channels, *remaining_dims = sample.shape
349
+
350
+ if dtype not in (torch.float32, torch.float64):
351
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
352
+
353
+ # Flatten sample for doing quantile calculation along each image
354
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
355
+
356
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
357
+
358
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
359
+ s = torch.clamp(
360
+ s, min=1, max=self.config.sample_max_value
361
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
362
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
363
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
364
+
365
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
366
+ sample = sample.to(dtype)
367
+
368
+ return sample
369
+
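A toy illustration of the dynamic thresholding above, run outside the scheduler: the per-sample percentile `s` is clamped to at least 1, so well-behaved samples are untouched and only outlier-heavy samples get squashed and rescaled.

    import torch

    x = torch.tensor([[[[0.2, -0.5, 3.0, -4.0]]]])        # one sample with outliers
    flat = x.reshape(1, -1)
    s = torch.quantile(flat.abs(), 0.995, dim=1, keepdim=True)
    s = torch.clamp(s, min=1.0, max=1.5)                   # ratio=0.995, sample_max_value=1.5
    out = torch.clamp(flat, -s, s) / s
    # s == 1.5 here, so 3.0 / -4.0 are squashed to +-1.0 while 0.2 and -0.5
    # are rescaled to ~0.13 and ~-0.33.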
370
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
371
+ def _sigma_to_t(self, sigma, log_sigmas):
372
+ # get log sigma
373
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
374
+
375
+ # get distribution
376
+ dists = log_sigma - log_sigmas[:, np.newaxis]
377
+
378
+ # get sigmas range
379
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
380
+ high_idx = low_idx + 1
381
+
382
+ low = log_sigmas[low_idx]
383
+ high = log_sigmas[high_idx]
384
+
385
+ # interpolate sigmas
386
+ w = (low - log_sigma) / (low - high)
387
+ w = np.clip(w, 0, 1)
388
+
389
+ # transform interpolation to time range
390
+ t = (1 - w) * low_idx + w * high_idx
391
+ t = t.reshape(sigma.shape)
392
+ return t
393
+
394
+ def _sigma_to_alpha_sigma_t(self, sigma):
395
+ alpha_t = 1 - sigma
396
+ sigma_t = sigma
397
+
398
+ return alpha_t, sigma_t
399
+
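Unlike the VP-style `alpha_t`/`sigma_t` buffers built in `__init__`, `_sigma_to_alpha_sigma_t` encodes the rectified-flow interpolation the rest of this scheduler actually uses: alpha = 1 - sigma, so a noisy sample is the straight-line mix x_t = (1 - sigma) * x0 + sigma * noise, which is exactly what `add_noise` at the bottom of this file computes. A one-line sanity sketch:

    import torch

    x0, noise = torch.randn(2, 4), torch.randn(2, 4)
    sigma = torch.tensor(0.3)
    alpha_t, sigma_t = 1 - sigma, sigma                    # mirrors _sigma_to_alpha_sigma_t
    x_t = alpha_t * x0 + sigma_t * noise                   # mirrors add_noise
    assert torch.allclose((x_t - sigma_t * noise) / alpha_t, x0, atol=1e-6)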
400
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
401
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
402
+ """Constructs the noise schedule of Karras et al. (2022)."""
403
+
404
+ # Hack to make sure that other schedulers which copy this function don't break
405
+ # TODO: Add this logic to the other schedulers
406
+ if hasattr(self.config, "sigma_min"):
407
+ sigma_min = self.config.sigma_min
408
+ else:
409
+ sigma_min = None
410
+
411
+ if hasattr(self.config, "sigma_max"):
412
+ sigma_max = self.config.sigma_max
413
+ else:
414
+ sigma_max = None
415
+
416
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
417
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
418
+
419
+ rho = 7.0 # 7.0 is the value used in the paper
420
+ ramp = np.linspace(0, 1, num_inference_steps)
421
+ min_inv_rho = sigma_min ** (1 / rho)
422
+ max_inv_rho = sigma_max ** (1 / rho)
423
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
424
+ return sigmas
425
+
426
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
427
+ """Constructs the noise schedule of Lu et al. (2022)."""
428
+
429
+ lambda_min: float = in_lambdas[-1].item()
430
+ lambda_max: float = in_lambdas[0].item()
431
+
432
+ rho = 1.0 # 1.0 is the value used in the paper
433
+ ramp = np.linspace(0, 1, num_inference_steps)
434
+ min_inv_rho = lambda_min ** (1 / rho)
435
+ max_inv_rho = lambda_max ** (1 / rho)
436
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
437
+ return lambdas
438
+
439
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
440
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
441
+ """Constructs an exponential noise schedule."""
442
+
443
+ # Hack to make sure that other schedulers which copy this function don't break
444
+ # TODO: Add this logic to the other schedulers
445
+ if hasattr(self.config, "sigma_min"):
446
+ sigma_min = self.config.sigma_min
447
+ else:
448
+ sigma_min = None
449
+
450
+ if hasattr(self.config, "sigma_max"):
451
+ sigma_max = self.config.sigma_max
452
+ else:
453
+ sigma_max = None
454
+
455
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
456
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
457
+
458
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
459
+ return sigmas
460
+
461
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
462
+ def _convert_to_beta(
463
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
464
+ ) -> torch.Tensor:
465
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
466
+
467
+ # Hack to make sure that other schedulers which copy this function don't break
468
+ # TODO: Add this logic to the other schedulers
469
+ if hasattr(self.config, "sigma_min"):
470
+ sigma_min = self.config.sigma_min
471
+ else:
472
+ sigma_min = None
473
+
474
+ if hasattr(self.config, "sigma_max"):
475
+ sigma_max = self.config.sigma_max
476
+ else:
477
+ sigma_max = None
478
+
479
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
480
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
481
+
482
+ sigmas = np.array(
483
+ [
484
+ sigma_min + (ppf * (sigma_max - sigma_min))
485
+ for ppf in [
486
+ scipy.stats.beta.ppf(timestep, alpha, beta)
487
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
488
+ ]
489
+ ]
490
+ )
491
+ return sigmas
492
+
493
+ def convert_model_output(
494
+ self,
495
+ model_output: torch.Tensor,
496
+ *args,
497
+ sample: torch.Tensor = None,
498
+ **kwargs,
499
+ ) -> torch.Tensor:
500
+ """
501
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
502
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
503
+ integral of the data prediction model.
504
+
505
+ <Tip>
506
+
507
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
508
+ prediction and data prediction models.
509
+
510
+ </Tip>
511
+
512
+ Args:
513
+ model_output (`torch.Tensor`):
514
+ The direct output from the learned diffusion model.
515
+ sample (`torch.Tensor`):
516
+ A current instance of a sample created by the diffusion process.
517
+
518
+ Returns:
519
+ `torch.Tensor`:
520
+ The converted model output.
521
+ """
522
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
523
+ if sample is None:
524
+ if len(args) > 1:
525
+ sample = args[1]
526
+ else:
527
+ raise ValueError("missing `sample` as a required keyward argument")
528
+ if timestep is not None:
529
+ deprecate(
530
+ "timesteps",
531
+ "1.0.0",
532
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
533
+ )
534
+
535
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
536
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
537
+ if self.config.prediction_type == "epsilon":
538
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
539
+ if self.config.variance_type in ["learned", "learned_range"]:
540
+ model_output = model_output[:, :3]
541
+ sigma = self.sigmas[self.step_index]
542
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
543
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
544
+ elif self.config.prediction_type == "sample":
545
+ x0_pred = model_output
546
+ elif self.config.prediction_type == "v_prediction":
547
+ sigma = self.sigmas[self.step_index]
548
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
549
+ x0_pred = alpha_t * sample - sigma_t * model_output
550
+ elif self.config.prediction_type == "flow_prediction":
551
+ sigma_t = self.sigmas[self.step_index]
552
+ x0_pred = sample + sigma_t * model_output
553
+ else:
554
+ raise ValueError(
555
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
556
+ "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
557
+ )
558
+
559
+ if self.config.thresholding:
560
+ x0_pred = self._threshold_sample(x0_pred)
561
+
562
+ return x0_pred
563
+
564
+ # DPM-Solver needs to solve an integral of the noise prediction model.
565
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
566
+ if self.config.prediction_type == "epsilon":
567
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
568
+ if self.config.variance_type in ["learned", "learned_range"]:
569
+ epsilon = model_output[:, :3]
570
+ else:
571
+ epsilon = model_output
572
+ elif self.config.prediction_type == "sample":
573
+ sigma = self.sigmas[self.step_index]
574
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
575
+ epsilon = (sample - alpha_t * model_output) / sigma_t
576
+ elif self.config.prediction_type == "v_prediction":
577
+ sigma = self.sigmas[self.step_index]
578
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
579
+ epsilon = alpha_t * model_output + sigma_t * sample
580
+ else:
581
+ raise ValueError(
582
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
583
+ " `v_prediction` for the DPMSolverMultistepScheduler."
584
+ )
585
+
586
+ if self.config.thresholding:
587
+ sigma = self.sigmas[self.step_index]
588
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
589
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
590
+ x0_pred = self._threshold_sample(x0_pred)
591
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
592
+
593
+ return epsilon
594
+
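For the `flow_prediction` branch above, the recovery `x0_pred = sample + sigma_t * model_output` is consistent with the interpolation x_t = (1 - sigma) * x0 + sigma * eps when the network predicts the data-pointing direction x0 - eps: substituting gives x_t + sigma * (x0 - eps) = x0. That prediction target is inferred from the formula here, not from the training code; a short numeric check of the algebra:

    import torch

    x0, eps, sigma = torch.randn(4), torch.randn(4), 0.7
    x_t = (1 - sigma) * x0 + sigma * eps
    model_output = x0 - eps                                # assumed flow target
    assert torch.allclose(x_t + sigma * model_output, x0, atol=1e-6)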
595
+ def dpm_solver_first_order_update(
596
+ self,
597
+ model_output: torch.Tensor,
598
+ *args,
599
+ sample: torch.Tensor = None,
600
+ noise: Optional[torch.Tensor] = None,
601
+ **kwargs,
602
+ ) -> torch.Tensor:
603
+ """
604
+ One step for the first-order DPMSolver (equivalent to DDIM).
605
+
606
+ Args:
607
+ model_output (`torch.Tensor`):
608
+ The direct output from the learned diffusion model.
609
+ sample (`torch.Tensor`):
610
+ A current instance of a sample created by the diffusion process.
611
+
612
+ Returns:
613
+ `torch.Tensor`:
614
+ The sample tensor at the previous timestep.
615
+ """
616
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
617
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
618
+ if sample is None:
619
+ if len(args) > 2:
620
+ sample = args[2]
621
+ else:
622
+ raise ValueError(" missing `sample` as a required keyward argument")
623
+ if timestep is not None:
624
+ deprecate(
625
+ "timesteps",
626
+ "1.0.0",
627
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
628
+ )
629
+
630
+ if prev_timestep is not None:
631
+ deprecate(
632
+ "prev_timestep",
633
+ "1.0.0",
634
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
635
+ )
636
+
637
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
638
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
639
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
640
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
641
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
642
+
643
+ h = lambda_t - lambda_s
644
+ if self.config.algorithm_type == "dpmsolver++":
645
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
646
+ elif self.config.algorithm_type == "dpmsolver":
647
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
648
+ elif self.config.algorithm_type == "sde-dpmsolver++":
649
+ assert noise is not None
650
+ x_t = (
651
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
652
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
653
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
654
+ )
655
+ elif self.config.algorithm_type == "sde-dpmsolver":
656
+ assert noise is not None
657
+ x_t = (
658
+ (alpha_t / alpha_s) * sample
659
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
660
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
661
+ )
662
+ return x_t
663
+
664
+ def multistep_dpm_solver_second_order_update(
665
+ self,
666
+ model_output_list: List[torch.Tensor],
667
+ *args,
668
+ sample: torch.Tensor = None,
669
+ noise: Optional[torch.Tensor] = None,
670
+ **kwargs,
671
+ ) -> torch.Tensor:
672
+ """
673
+ One step for the second-order multistep DPMSolver.
674
+
675
+ Args:
676
+ model_output_list (`List[torch.Tensor]`):
677
+ The direct outputs from learned diffusion model at current and latter timesteps.
678
+ sample (`torch.Tensor`):
679
+ A current instance of a sample created by the diffusion process.
680
+
681
+ Returns:
682
+ `torch.Tensor`:
683
+ The sample tensor at the previous timestep.
684
+ """
685
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
686
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
687
+ if sample is None:
688
+ if len(args) > 2:
689
+ sample = args[2]
690
+ else:
691
+ raise ValueError(" missing `sample` as a required keyward argument")
692
+ if timestep_list is not None:
693
+ deprecate(
694
+ "timestep_list",
695
+ "1.0.0",
696
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
697
+ )
698
+
699
+ if prev_timestep is not None:
700
+ deprecate(
701
+ "prev_timestep",
702
+ "1.0.0",
703
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
704
+ )
705
+
706
+ sigma_t, sigma_s0, sigma_s1 = (
707
+ self.sigmas[self.step_index + 1],
708
+ self.sigmas[self.step_index],
709
+ self.sigmas[self.step_index - 1],
710
+ )
711
+
712
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
713
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
714
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
715
+
716
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
717
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
718
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
719
+
720
+ m0, m1 = model_output_list[-1], model_output_list[-2]
721
+
722
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
723
+ r0 = h_0 / h
724
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
725
+ if self.config.algorithm_type == "dpmsolver++":
726
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
727
+ if self.config.solver_type == "midpoint":
728
+ x_t = (
729
+ (sigma_t / sigma_s0) * sample
730
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
731
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
732
+ )
733
+ elif self.config.solver_type == "heun":
734
+ x_t = (
735
+ (sigma_t / sigma_s0) * sample
736
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
737
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
738
+ )
739
+ elif self.config.algorithm_type == "dpmsolver":
740
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
741
+ if self.config.solver_type == "midpoint":
742
+ x_t = (
743
+ (alpha_t / alpha_s0) * sample
744
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
745
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
746
+ )
747
+ elif self.config.solver_type == "heun":
748
+ x_t = (
749
+ (alpha_t / alpha_s0) * sample
750
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
751
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
752
+ )
753
+ elif self.config.algorithm_type == "sde-dpmsolver++":
754
+ assert noise is not None
755
+ if self.config.solver_type == "midpoint":
756
+ x_t = (
757
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
758
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
759
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
760
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
761
+ )
762
+ elif self.config.solver_type == "heun":
763
+ x_t = (
764
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
765
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
766
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
767
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
768
+ )
769
+ elif self.config.algorithm_type == "sde-dpmsolver":
770
+ assert noise is not None
771
+ if self.config.solver_type == "midpoint":
772
+ x_t = (
773
+ (alpha_t / alpha_s0) * sample
774
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
775
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
776
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
777
+ )
778
+ elif self.config.solver_type == "heun":
779
+ x_t = (
780
+ (alpha_t / alpha_s0) * sample
781
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
782
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
783
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
784
+ )
785
+ return x_t
786
+
787
+ def multistep_dpm_solver_third_order_update(
788
+ self,
789
+ model_output_list: List[torch.Tensor],
790
+ *args,
791
+ sample: torch.Tensor = None,
792
+ noise: Optional[torch.Tensor] = None,
793
+ **kwargs,
794
+ ) -> torch.Tensor:
795
+ """
796
+ One step for the third-order multistep DPMSolver.
797
+
798
+ Args:
799
+ model_output_list (`List[torch.Tensor]`):
800
+ The direct outputs from the learned diffusion model at the current and later timesteps.
801
+ sample (`torch.Tensor`):
802
+ A current instance of a sample created by diffusion process.
803
+
804
+ Returns:
805
+ `torch.Tensor`:
806
+ The sample tensor at the previous timestep.
807
+ """
808
+
809
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
810
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
811
+ if sample is None:
812
+ if len(args) > 2:
813
+ sample = args[2]
814
+ else:
815
+ raise ValueError(" missing`sample` as a required keyward argument")
816
+ if timestep_list is not None:
817
+ deprecate(
818
+ "timestep_list",
819
+ "1.0.0",
820
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
821
+ )
822
+
823
+ if prev_timestep is not None:
824
+ deprecate(
825
+ "prev_timestep",
826
+ "1.0.0",
827
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
828
+ )
829
+
830
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
831
+ self.sigmas[self.step_index + 1],
832
+ self.sigmas[self.step_index],
833
+ self.sigmas[self.step_index - 1],
834
+ self.sigmas[self.step_index - 2],
835
+ )
836
+
837
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
838
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
839
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
840
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
841
+
842
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
843
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
844
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
845
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
846
+
847
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
848
+
849
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
850
+ r0, r1 = h_0 / h, h_1 / h
851
+ D0 = m0
852
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
853
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
854
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
855
+ if self.config.algorithm_type == "dpmsolver++":
856
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
857
+ x_t = (
858
+ (sigma_t / sigma_s0) * sample
859
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
860
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
861
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
862
+ )
863
+ elif self.config.algorithm_type == "dpmsolver":
864
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
865
+ x_t = (
866
+ (alpha_t / alpha_s0) * sample
867
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
868
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
869
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
870
+ )
871
+ elif self.config.algorithm_type == "sde-dpmsolver++":
872
+ assert noise is not None
873
+ x_t = (
874
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
875
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
876
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
877
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
878
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
879
+ )
880
+ return x_t
881
+
882
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
883
+ if schedule_timesteps is None:
884
+ schedule_timesteps = self.timesteps
885
+
886
+ index_candidates = (schedule_timesteps == timestep).nonzero()
887
+
888
+ if len(index_candidates) == 0:
889
+ step_index = len(self.timesteps) - 1
890
+ # The sigma index that is taken for the **very** first `step`
891
+ # is always the second index (or the last index if there is only 1)
892
+ # This way we can ensure we don't accidentally skip a sigma in
893
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
894
+ elif len(index_candidates) > 1:
895
+ step_index = index_candidates[1].item()
896
+ else:
897
+ step_index = index_candidates[0].item()
898
+
899
+ return step_index
900
+
901
+ def _init_step_index(self, timestep):
902
+ """
903
+ Initialize the step_index counter for the scheduler.
904
+ """
905
+
906
+ if self.begin_index is None:
907
+ if isinstance(timestep, torch.Tensor):
908
+ timestep = timestep.to(self.timesteps.device)
909
+ self._step_index = self.index_for_timestep(timestep)
910
+ else:
911
+ self._step_index = self._begin_index
912
+
913
+ def step(
914
+ self,
915
+ model_output: torch.Tensor,
916
+ timestep: Union[int, torch.Tensor],
917
+ sample: torch.Tensor,
918
+ generator=None,
919
+ variance_noise: Optional[torch.Tensor] = None,
920
+ return_dict: bool = True,
921
+ ) -> Union[SchedulerOutput, Tuple]:
922
+ """
923
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
924
+ the multistep DPMSolver.
925
+
926
+ Args:
927
+ model_output (`torch.Tensor`):
928
+ The direct output from learned diffusion model.
929
+ timestep (`int`):
930
+ The current discrete timestep in the diffusion chain.
931
+ sample (`torch.Tensor`):
932
+ A current instance of a sample created by the diffusion process.
933
+ generator (`torch.Generator`, *optional*):
934
+ A random number generator.
935
+ variance_noise (`torch.Tensor`):
936
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
937
+ itself. Useful for methods such as [`LEdits++`].
938
+ return_dict (`bool`):
939
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
940
+
941
+ Returns:
942
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
943
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
944
+ tuple is returned where the first element is the sample tensor.
945
+
946
+ """
947
+ if self.num_inference_steps is None:
948
+ raise ValueError(
949
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
950
+ )
951
+
952
+ if self.step_index is None:
953
+ self._init_step_index(timestep)
954
+
955
+ # Improve numerical stability for small number of steps
956
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
957
+ self.config.euler_at_final
958
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
959
+ or self.config.final_sigmas_type == "zero"
960
+ )
961
+ lower_order_second = (
962
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
963
+ )
964
+
965
+ model_output = self.convert_model_output(model_output, sample=sample)
966
+ for i in range(self.config.solver_order - 1):
967
+ self.model_outputs[i] = self.model_outputs[i + 1]
968
+ self.model_outputs[-1] = model_output
969
+
970
+ # Upcast to avoid precision issues when computing prev_sample
971
+ sample = sample.to(torch.float32)
972
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
973
+ noise = randn_tensor(
974
+ model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
975
+ )
976
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
977
+ noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
978
+ else:
979
+ noise = None
980
+
981
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
982
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
983
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
984
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
985
+ else:
986
+ prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
987
+
988
+ if self.lower_order_nums < self.config.solver_order:
989
+ self.lower_order_nums += 1
990
+
991
+ # Cast sample back to expected dtype
992
+ prev_sample = prev_sample.to(model_output.dtype)
993
+
994
+ # upon completion increase step index by one
995
+ self._step_index += 1
996
+
997
+ if not return_dict:
998
+ return (prev_sample,)
999
+
1000
+ return SchedulerOutput(prev_sample=prev_sample)
1001
+
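The pipeline drives this scheduler like any diffusers multistep scheduler: `set_timesteps` (optionally with `num_tokens` for the dynamic shift), then one `step` per timestep. A stripped-down sketch with a dummy denoiser standing in for the OmniGen2 transformer:

    import torch

    scheduler = DPMSolverMultistepScheduler(prediction_type="flow_prediction")
    scheduler.set_timesteps(num_inference_steps=30, num_tokens=128 * 128)

    sample = torch.randn(1, 16, 128, 128)                  # start from pure noise (sigma = 1)
    denoiser = lambda x, t: torch.zeros_like(x)            # placeholder for the real transformer

    for t in scheduler.timesteps:
        model_output = denoiser(sample, t)
        sample = scheduler.step(model_output, t, sample, return_dict=False)[0]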
1002
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1003
+ """
1004
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
1005
+ current timestep.
1006
+
1007
+ Args:
1008
+ sample (`torch.Tensor`):
1009
+ The input sample.
1010
+
1011
+ Returns:
1012
+ `torch.Tensor`:
1013
+ A scaled input sample.
1014
+ """
1015
+ return sample
1016
+
1017
+ def add_noise(
1018
+ self,
1019
+ original_samples: torch.Tensor,
1020
+ noise: torch.Tensor,
1021
+ timesteps: torch.IntTensor,
1022
+ ) -> torch.Tensor:
1023
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
1024
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
1025
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
1026
+ # mps does not support float64
1027
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
1028
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
1029
+ else:
1030
+ schedule_timesteps = self.timesteps.to(original_samples.device)
1031
+ timesteps = timesteps.to(original_samples.device)
1032
+
1033
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
1034
+ if self.begin_index is None:
1035
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
1036
+ elif self.step_index is not None:
1037
+ # add_noise is called after first denoising step (for inpainting)
1038
+ step_indices = [self.step_index] * timesteps.shape[0]
1039
+ else:
1040
+ # add noise is called before first denoising step to create initial latent(img2img)
1041
+ step_indices = [self.begin_index] * timesteps.shape[0]
1042
+
1043
+ sigma = sigmas[step_indices].flatten()
1044
+ while len(sigma.shape) < len(original_samples.shape):
1045
+ sigma = sigma.unsqueeze(-1)
1046
+
1047
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
1048
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
1049
+ return noisy_samples
1050
+
1051
+ def __len__(self):
1052
+ return self.config.num_train_timesteps
omnigen2/schedulers/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Euler scheduler.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ timestep_spacing (`str`, defaults to `"linspace"`):
55
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
57
+ shift (`float`, defaults to 1.0):
58
+ The shift value for the timestep schedule.
59
+ """
60
+
61
+ _compatibles = []
62
+ order = 1
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ dynamic_time_shift: bool = True
69
+ ):
70
+ timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
71
+
72
+ self.timesteps = timesteps
73
+
74
+ self._step_index = None
75
+ self._begin_index = None
76
+
77
+ @property
78
+ def step_index(self):
79
+ """
80
+ The index counter for current timestep. It will increase 1 after each scheduler step.
81
+ """
82
+ return self._step_index
83
+
84
+ @property
85
+ def begin_index(self):
86
+ """
87
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
88
+ """
89
+ return self._begin_index
90
+
91
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
92
+ def set_begin_index(self, begin_index: int = 0):
93
+ """
94
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
95
+
96
+ Args:
97
+ begin_index (`int`):
98
+ The begin index for the scheduler.
99
+ """
100
+ self._begin_index = begin_index
101
+
102
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
103
+ if schedule_timesteps is None:
104
+ schedule_timesteps = self._timesteps
105
+
106
+ indices = (schedule_timesteps == timestep).nonzero()
107
+
108
+ # The sigma index that is taken for the **very** first `step`
109
+ # is always the second index (or the last index if there is only 1)
110
+ # This way we can ensure we don't accidentally skip a sigma in
111
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
112
+ pos = 1 if len(indices) > 1 else 0
113
+
114
+ return indices[pos].item()
115
+
116
+ # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
117
+ # return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
118
+
119
+ def set_timesteps(
120
+ self,
121
+ num_inference_steps: int = None,
122
+ device: Union[str, torch.device] = None,
123
+ timesteps: Optional[List[float]] = None,
124
+ num_tokens: Optional[int] = None
125
+ ):
126
+ """
127
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
128
+
129
+ Args:
130
+ num_inference_steps (`int`):
131
+ The number of diffusion steps used when generating samples with a pre-trained model.
132
+ device (`str` or `torch.device`, *optional*):
133
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
134
+ timesteps (`List[float]`, *optional*):
+ Custom timesteps to use instead of the automatically generated schedule.
+ num_tokens (`int`, *optional*):
+ The number of image tokens, used to compute the dynamic time shift when
+ `config.dynamic_time_shift` is enabled.
+ """
135
+
136
+ if timesteps is None:
137
+ self.num_inference_steps = num_inference_steps
138
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
139
+ if self.config.dynamic_time_shift and num_tokens is not None:
140
+ m = np.sqrt(num_tokens) / 40  # m = 1 when the input resolution is 320 * 320; m = 3.2 when it is 1024 * 1024
141
+ timesteps = timesteps / (m - m * timesteps + timesteps)
142
+
143
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
144
+ _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
145
+
146
+ self.timesteps = timesteps
147
+ self._timesteps = _timesteps
148
+ self._step_index = None
149
+ self._begin_index = None
150
+
151
+ def _init_step_index(self, timestep):
152
+ if self.begin_index is None:
153
+ if isinstance(timestep, torch.Tensor):
154
+ timestep = timestep.to(self.timesteps.device)
155
+ self._step_index = self.index_for_timestep(timestep)
156
+ else:
157
+ self._step_index = self._begin_index
158
+
159
+ def step(
160
+ self,
161
+ model_output: torch.FloatTensor,
162
+ timestep: Union[float, torch.FloatTensor],
163
+ sample: torch.FloatTensor,
164
+ generator: Optional[torch.Generator] = None,
165
+ return_dict: bool = True,
166
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
167
+ """
168
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
169
+ process from the learned model outputs (most often the predicted noise).
170
+
171
+ Args:
172
+ model_output (`torch.FloatTensor`):
173
+ The direct output from learned diffusion model.
174
+ timestep (`float`):
175
+ The current discrete timestep in the diffusion chain.
176
+ sample (`torch.FloatTensor`):
177
+ A current instance of a sample created by the diffusion process.
178
+ generator (`torch.Generator`, *optional*):
184
+ A random number generator.
185
+ return_dict (`bool`):
186
+ Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or
187
+ tuple.
188
+
189
+ Returns:
190
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
191
+ If return_dict is `True`, [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
192
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
193
+ """
194
+
195
+ if (
196
+ isinstance(timestep, int)
197
+ or isinstance(timestep, torch.IntTensor)
198
+ or isinstance(timestep, torch.LongTensor)
199
+ ):
200
+ raise ValueError(
201
+ (
202
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
203
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
204
+ " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
205
+ ),
206
+ )
207
+
208
+ if self.step_index is None:
209
+ self._init_step_index(timestep)
210
+ # Upcast to avoid precision issues when computing prev_sample
211
+ sample = sample.to(torch.float32)
212
+ t = self._timesteps[self.step_index]
213
+ t_next = self._timesteps[self.step_index + 1]
214
+
215
+ prev_sample = sample + (t_next - t) * model_output
216
+
217
+ # Cast sample back to model compatible dtype
218
+ prev_sample = prev_sample.to(model_output.dtype)
219
+
220
+ # upon completion increase step index by one
221
+ self._step_index += 1
222
+
223
+ if not return_dict:
224
+ return (prev_sample,)
225
+
226
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
227
+
228
+ def __len__(self):
229
+ return self.config.num_train_timesteps
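The dynamic time shift in `set_timesteps` above remaps a uniform grid t to t / (m - m * t + t) with m = sqrt(num_tokens) / 40, so larger inputs get a more strongly shifted schedule. A minimal sketch of that remapping (illustrative only, not part of the diff; `num_tokens` is assumed here to be the number of image tokens, e.g. 128 * 128 for a 1024 x 1024 input):

import numpy as np

# Sketch of the dynamic time shift used by set_timesteps; only the remapping
# formula is taken from the code above, the token count is an assumption.
num_inference_steps = 5
num_tokens = 128 * 128

t = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
m = np.sqrt(num_tokens) / 40        # m = 3.2 for this resolution
shifted = t / (m - m * t + t)       # same remapping as in set_timesteps

print(t)        # [0.  0.2 0.4 0.6 0.8]
print(shifted)  # every value is pulled toward 0, so the later steps cover larger intervals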
omnigen2/taylorseer_utils/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/taylorseer_utils/__init__.py
2
+
3
+ from typing import Dict
4
+ import torch
5
+ import math
6
+
7
+ def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
8
+ """
9
+ Compute derivative approximation.
10
+
11
+ :param cache_dic: Cache dictionary
12
+ :param current: Information of the current step
13
+ :param feature: Feature computed at the current step, used as the zeroth-order Taylor term
+ """
14
+ difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2]
15
+
16
+ updated_taylor_factors = {}
17
+ updated_taylor_factors[0] = feature
18
+
19
+ for i in range(cache_dic['max_order']):
20
+ if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2):
21
+ updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance
22
+ else:
23
+ break
24
+
25
+ cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors
26
+
27
+ def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
28
+ """
29
+ Evaluate the cached Taylor expansion to approximate the feature at the current step.
30
+
31
+ :param cache_dic: Cache dictionary
32
+ :param current: Information of the current step
33
+ """
34
+ x = current['step'] - current['activated_steps'][-1]
35
+ #x = current['t'] - current['activated_times'][-1]
36
+ output = 0
37
+
38
+ for i in range(len(cache_dic['cache'][-1][current['stream']][current['layer']][current['module']])):
39
+ output += (1 / math.factorial(i)) * cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i] * (x ** i)
40
+
41
+ return output
42
+
43
+ def taylor_cache_init(cache_dic: Dict, current: Dict):
44
+ """
45
+ Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache.
46
+
47
+ :param cache_dic: Cache dictionary
48
+ :param current: Information of the current step
49
+ """
50
+ if (current['step'] == 0) and (cache_dic['taylor_cache']):
51
+ cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {}
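The cache entries written by `derivative_approximation` above are finite-difference derivative estimates, and `taylor_formula` evaluates the truncated Taylor series feature(step) ≈ sum_i f^(i) / i! * (step - last_activated_step)^i. A minimal sketch of that evaluation (illustrative only; the 'attn' module key and the tensor values are placeholders, while 'layers_stream' matches the cache layout used in this repo):

import math
import torch

# Sketch (not part of the diff): evaluate the cached Taylor expansion two steps
# after the last full computation, mirroring the loop in taylor_formula above.
feature = torch.tensor([1.0, 2.0])            # zeroth-order term (last computed feature)
first_derivative = torch.tensor([0.1, -0.2])  # finite-difference estimate

cache_dic = {'cache': {-1: {'layers_stream': {0: {'attn': {0: feature, 1: first_derivative}}}}}}
current = {'stream': 'layers_stream', 'layer': 0, 'module': 'attn',
           'step': 12, 'activated_steps': [8, 10]}

x = current['step'] - current['activated_steps'][-1]   # x = 2
terms = cache_dic['cache'][-1][current['stream']][current['layer']][current['module']]
approx = sum((1 / math.factorial(i)) * terms[i] * (x ** i) for i in range(len(terms)))
# approx == feature + 2 * first_derivative == tensor([1.2, 1.6])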
omnigen2/training_utils.py ADDED
@@ -0,0 +1,645 @@
1
+ import contextlib
2
+ import copy
3
+ import gc
4
+ import math
5
+ import random
6
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from diffusers.models import UNet2DConditionModel
12
+ from diffusers.schedulers import SchedulerMixin
13
+ from diffusers.utils import (
14
+ convert_state_dict_to_diffusers,
15
+ convert_state_dict_to_peft,
16
+ deprecate,
17
+ is_peft_available,
18
+ is_torch_npu_available,
19
+ is_torchvision_available,
20
+ is_transformers_available,
21
+ )
22
+
23
+
24
+ if is_transformers_available():
25
+ import transformers
26
+
27
+ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
28
+ import deepspeed
29
+
30
+ if is_peft_available():
31
+ from peft import set_peft_model_state_dict
32
+
33
+ if is_torchvision_available():
34
+ from torchvision import transforms
35
+
36
+ if is_torch_npu_available():
37
+ import torch_npu # noqa: F401
38
+
39
+
40
+ def set_seed(seed: int):
41
+ """
42
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
43
+
44
+ Args:
45
+ seed (`int`): The seed to set.
46
+
47
+ Returns:
48
+ `None`
49
+ """
50
+ random.seed(seed)
51
+ np.random.seed(seed)
52
+ torch.manual_seed(seed)
53
+ if is_torch_npu_available():
54
+ torch.npu.manual_seed_all(seed)
55
+ else:
56
+ torch.cuda.manual_seed_all(seed)
57
+ # ^^ safe to call this function even if cuda is not available
58
+
59
+
60
+ def compute_snr(noise_scheduler, timesteps):
61
+ """
62
+ Computes SNR as per
63
+ https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
64
+ for the given timesteps using the provided noise scheduler.
65
+
66
+ Args:
67
+ noise_scheduler (`NoiseScheduler`):
68
+ An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
69
+ the SNR values.
70
+ timesteps (`torch.Tensor`):
71
+ A tensor of timesteps for which the SNR is computed.
72
+
73
+ Returns:
74
+ `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
75
+ """
76
+ alphas_cumprod = noise_scheduler.alphas_cumprod
77
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
78
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
79
+
80
+ # Expand the tensors.
81
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
82
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
83
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
84
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
85
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
86
+
87
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
88
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
89
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
90
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
91
+
92
+ # Compute SNR.
93
+ snr = (alpha / sigma) ** 2
94
+ return snr
95
+
96
+
97
+ def resolve_interpolation_mode(interpolation_type: str):
98
+ """
99
+ Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
100
+ full list of supported enums is documented at
101
+ https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.
102
+
103
+ Args:
104
+ interpolation_type (`str`):
105
+ A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
106
+ `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
107
+ in torchvision.
108
+
109
+ Returns:
110
+ `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
111
+ transform.
112
+ """
113
+ if not is_torchvision_available():
114
+ raise ImportError(
115
+ "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
116
+ )
117
+
118
+ if interpolation_type == "bilinear":
119
+ interpolation_mode = transforms.InterpolationMode.BILINEAR
120
+ elif interpolation_type == "bicubic":
121
+ interpolation_mode = transforms.InterpolationMode.BICUBIC
122
+ elif interpolation_type == "box":
123
+ interpolation_mode = transforms.InterpolationMode.BOX
124
+ elif interpolation_type == "nearest":
125
+ interpolation_mode = transforms.InterpolationMode.NEAREST
126
+ elif interpolation_type == "nearest_exact":
127
+ interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT
128
+ elif interpolation_type == "hamming":
129
+ interpolation_mode = transforms.InterpolationMode.HAMMING
130
+ elif interpolation_type == "lanczos":
131
+ interpolation_mode = transforms.InterpolationMode.LANCZOS
132
+ else:
133
+ raise ValueError(
134
+ f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
135
+ f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
136
+ )
137
+
138
+ return interpolation_mode
139
+
140
+
141
+ def compute_dream_and_update_latents(
142
+ unet: UNet2DConditionModel,
143
+ noise_scheduler: SchedulerMixin,
144
+ timesteps: torch.Tensor,
145
+ noise: torch.Tensor,
146
+ noisy_latents: torch.Tensor,
147
+ target: torch.Tensor,
148
+ encoder_hidden_states: torch.Tensor,
149
+ dream_detail_preservation: float = 1.0,
150
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
151
+ """
152
+ Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
153
+ DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
154
+ forward step without gradients.
155
+
156
+ Args:
157
+ `unet`: The state unet to use to make a prediction.
158
+ `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
159
+ `timesteps`: The timesteps for the noise_scheduler to use.
160
+ `noise`: A tensor of noise in the shape of noisy_latents.
161
+ `noisy_latents`: Previously noise latents from the training loop.
162
+ `target`: The ground-truth tensor to predict after eps is removed.
163
+ `encoder_hidden_states`: Text embeddings from the text model.
164
+ `dream_detail_preservation`: A float value that indicates detail preservation level.
165
+ See reference.
166
+
167
+ Returns:
168
+ `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
169
+ """
170
+ alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
171
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
172
+
173
+ # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
174
+ dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation
175
+
176
+ pred = None
177
+ with torch.no_grad():
178
+ pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
179
+
180
+ _noisy_latents, _target = (None, None)
181
+ if noise_scheduler.config.prediction_type == "epsilon":
182
+ predicted_noise = pred
183
+ delta_noise = (noise - predicted_noise).detach()
184
+ delta_noise.mul_(dream_lambda)
185
+ _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
186
+ _target = target.add(delta_noise)
187
+ elif noise_scheduler.config.prediction_type == "v_prediction":
188
+ raise NotImplementedError("DREAM has not been implemented for v-prediction")
189
+ else:
190
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
191
+
192
+ return _noisy_latents, _target
193
+
194
+
195
+ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
196
+ r"""
197
+ Returns:
198
+ A state dict containing just the LoRA parameters.
199
+ """
200
+ lora_state_dict = {}
201
+
202
+ for name, module in unet.named_modules():
203
+ if hasattr(module, "set_lora_layer"):
204
+ lora_layer = getattr(module, "lora_layer")
205
+ if lora_layer is not None:
206
+ current_lora_layer_sd = lora_layer.state_dict()
207
+ for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
208
+ # The matrix name can either be "down" or "up".
209
+ lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
210
+
211
+ return lora_state_dict
212
+
213
+
214
+ def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
215
+ """
216
+ Casts the training parameters of the model to the specified data type.
217
+
218
+ Args:
219
+ model: The PyTorch model whose parameters will be cast.
220
+ dtype: The data type to which the model parameters will be cast.
221
+ """
222
+ if not isinstance(model, list):
223
+ model = [model]
224
+ for m in model:
225
+ for param in m.parameters():
226
+ # only upcast trainable parameters into fp32
227
+ if param.requires_grad:
228
+ param.data = param.to(dtype)
229
+
230
+
231
+ def _set_state_dict_into_text_encoder(
232
+ lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
233
+ ):
234
+ """
235
+ Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.
236
+
237
+ Args:
238
+ lora_state_dict: The state dictionary to be set.
239
+ prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
240
+ text_encoder: Where the `lora_state_dict` is to be set.
241
+ """
242
+
243
+ text_encoder_state_dict = {
244
+ f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
245
+ }
246
+ text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
247
+ set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
248
+
249
+
250
+ def compute_density_for_timestep_sampling(
251
+ weighting_scheme: str,
252
+ batch_size: int,
253
+ logit_mean: float = None,
254
+ logit_std: float = None,
255
+ mode_scale: float = None,
256
+ device: Union[torch.device, str] = "cpu",
257
+ generator: Optional[torch.Generator] = None,
258
+ ):
259
+ """
260
+ Compute the density for sampling the timesteps when doing SD3 training.
261
+
262
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
263
+
264
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
265
+ """
266
+ if weighting_scheme == "logit_normal":
267
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
268
+ u = torch.nn.functional.sigmoid(u)
269
+ elif weighting_scheme == "mode":
270
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
271
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
272
+ else:
273
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
274
+ return u
275
+
276
+
277
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
278
+ """
279
+ Computes loss weighting scheme for SD3 training.
280
+
281
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
282
+
283
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
284
+ """
285
+ if weighting_scheme == "sigma_sqrt":
286
+ weighting = (sigmas**-2.0).float()
287
+ elif weighting_scheme == "cosmap":
288
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
289
+ weighting = 2 / (math.pi * bot)
290
+ else:
291
+ weighting = torch.ones_like(sigmas)
292
+ return weighting
293
+
294
+
295
+ def free_memory():
296
+ """
297
+ Runs garbage collection. Then clears the cache of the available accelerator.
298
+ """
299
+ gc.collect()
300
+
301
+ if torch.cuda.is_available():
302
+ torch.cuda.empty_cache()
303
+ elif torch.backends.mps.is_available():
304
+ torch.mps.empty_cache()
305
+ elif is_torch_npu_available():
306
+ torch_npu.npu.empty_cache()
307
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
308
+ torch.xpu.empty_cache()
309
+
310
+
311
+ # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
312
+ class EMAModel:
313
+ """
314
+ Exponential Moving Average of model weights
315
+ """
316
+
317
+ def __init__(
318
+ self,
319
+ parameters: Iterable[torch.nn.Parameter],
320
+ decay: float = 0.9999,
321
+ min_decay: float = 0.0,
322
+ update_after_step: int = 0,
323
+ use_ema_warmup: bool = False,
324
+ inv_gamma: Union[float, int] = 1.0,
325
+ power: Union[float, int] = 2 / 3,
326
+ foreach: bool = False,
327
+ model_cls: Optional[Any] = None,
328
+ model_config: Dict[str, Any] = None,
329
+ **kwargs,
330
+ ):
331
+ """
332
+ Args:
333
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
334
+ decay (float): The decay factor for the exponential moving average.
335
+ min_decay (float): The minimum decay factor for the exponential moving average.
336
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
337
+ use_ema_warmup (bool): Whether to use EMA warmup.
338
+ inv_gamma (float):
339
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
340
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
341
+ foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
342
+ device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
343
+ weights will be stored on CPU.
344
+
345
+ @crowsonkb's notes on EMA Warmup:
346
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
347
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
348
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
349
+ at 215.4k steps).
350
+ """
351
+
352
+ if isinstance(parameters, torch.nn.Module):
353
+ deprecation_message = (
354
+ "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
355
+ "Please pass the parameters of the module instead."
356
+ )
357
+ deprecate(
358
+ "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
359
+ "1.0.0",
360
+ deprecation_message,
361
+ standard_warn=False,
362
+ )
363
+ parameters = parameters.parameters()
364
+
365
+ # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
366
+ use_ema_warmup = True
367
+
368
+ if kwargs.get("max_value", None) is not None:
369
+ deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
370
+ deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
371
+ decay = kwargs["max_value"]
372
+
373
+ if kwargs.get("min_value", None) is not None:
374
+ deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
375
+ deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
376
+ min_decay = kwargs["min_value"]
377
+
378
+ parameters = list(parameters)
379
+ self.shadow_params = [p.clone().detach() for p in parameters]
380
+
381
+ if kwargs.get("device", None) is not None:
382
+ deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
383
+ deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
384
+ self.to(device=kwargs["device"])
385
+
386
+ self.temp_stored_params = None
387
+
388
+ self.decay = decay
389
+ self.min_decay = min_decay
390
+ self.update_after_step = update_after_step
391
+ self.use_ema_warmup = use_ema_warmup
392
+ self.inv_gamma = inv_gamma
393
+ self.power = power
394
+ self.optimization_step = 0
395
+ self.cur_decay_value = None # set in `step()`
396
+ self.foreach = foreach
397
+
398
+ self.model_cls = model_cls
399
+ self.model_config = model_config
400
+
401
+ @classmethod
402
+ def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
403
+ _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
404
+ model = model_cls.from_pretrained(path)
405
+
406
+ ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
407
+
408
+ ema_model.load_state_dict(ema_kwargs)
409
+ return ema_model
410
+
411
+ def save_pretrained(self, path):
412
+ if self.model_cls is None:
413
+ raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
414
+
415
+ if self.model_config is None:
416
+ raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
417
+
418
+ model = self.model_cls.from_config(self.model_config)
419
+ state_dict = self.state_dict()
420
+ state_dict.pop("shadow_params", None)
421
+
422
+ model.register_to_config(**state_dict)
423
+ self.copy_to(model.parameters())
424
+ model.save_pretrained(path)
425
+
426
+ def get_decay(self, optimization_step: int) -> float:
427
+ """
428
+ Compute the decay factor for the exponential moving average.
429
+ """
430
+ step = max(0, optimization_step - self.update_after_step - 1)
431
+
432
+ if step <= 0:
433
+ return 0.0
434
+
435
+ if self.use_ema_warmup:
436
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
437
+ else:
438
+ cur_decay_value = (1 + step) / (10 + step)
439
+
440
+ cur_decay_value = min(cur_decay_value, self.decay)
441
+ # make sure decay is not smaller than min_decay
442
+ cur_decay_value = max(cur_decay_value, self.min_decay)
443
+ return cur_decay_value
444
+
445
+ @torch.no_grad()
446
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
447
+ if isinstance(parameters, torch.nn.Module):
448
+ deprecation_message = (
449
+ "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
450
+ "Please pass the parameters of the module instead."
451
+ )
452
+ deprecate(
453
+ "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
454
+ "1.0.0",
455
+ deprecation_message,
456
+ standard_warn=False,
457
+ )
458
+ parameters = parameters.parameters()
459
+
460
+ parameters = list(parameters)
461
+
462
+ self.optimization_step += 1
463
+
464
+ # Compute the decay factor for the exponential moving average.
465
+ decay = self.get_decay(self.optimization_step)
466
+ self.cur_decay_value = decay
467
+ one_minus_decay = 1 - decay
468
+
469
+ context_manager = contextlib.nullcontext()
470
+
471
+ if self.foreach:
472
+ if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
473
+ context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
474
+
475
+ with context_manager:
476
+ params_grad = [param for param in parameters if param.requires_grad]
477
+ s_params_grad = [
478
+ s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
479
+ ]
480
+
481
+ if len(params_grad) < len(parameters):
482
+ torch._foreach_copy_(
483
+ [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
484
+ [param for param in parameters if not param.requires_grad],
485
+ non_blocking=True,
486
+ )
487
+
488
+ torch._foreach_sub_(
489
+ s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
490
+ )
491
+
492
+ else:
493
+ for s_param, param in zip(self.shadow_params, parameters):
494
+ if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
495
+ context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
496
+
497
+ with context_manager:
498
+ if param.requires_grad:
499
+ # print(f"{s_param.shape=} {param.shape=}")
500
+ s_param.sub_(one_minus_decay * (s_param - param))
501
+ else:
502
+ s_param.copy_(param)
503
+
504
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
505
+ """
506
+ Copy current averaged parameters into given collection of parameters.
507
+
508
+ Args:
509
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
510
+ updated with the stored moving averages. If `None`, the parameters with which this
511
+ `ExponentialMovingAverage` was initialized will be used.
512
+ """
513
+ parameters = list(parameters)
514
+ if self.foreach:
515
+ torch._foreach_copy_(
516
+ [param.data for param in parameters],
517
+ [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
518
+ )
519
+ else:
520
+ for s_param, param in zip(self.shadow_params, parameters):
521
+ param.data.copy_(s_param.to(param.device).data)
522
+
523
+ def pin_memory(self) -> None:
524
+ r"""
525
+ Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
526
+ offloading EMA params to the host.
527
+ """
528
+
529
+ self.shadow_params = [p.pin_memory() for p in self.shadow_params]
530
+
531
+ def to(self, device=None, dtype=None, non_blocking=False) -> None:
532
+ r"""
533
+ Move internal buffers of the ExponentialMovingAverage to `device`.
534
+
535
+ Args:
536
+ device: like `device` argument to `torch.Tensor.to`
537
+ """
538
+ # .to() on the tensors handles None correctly
539
+ self.shadow_params = [
540
+ p.to(device=device, dtype=dtype, non_blocking=non_blocking)
541
+ if p.is_floating_point()
542
+ else p.to(device=device, non_blocking=non_blocking)
543
+ for p in self.shadow_params
544
+ ]
545
+
546
+ def state_dict(self) -> dict:
547
+ r"""
548
+ Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
549
+ checkpointing to save the ema state dict.
550
+ """
551
+ # Following PyTorch conventions, references to tensors are returned:
552
+ # "returns a reference to the state and not its copy!" -
553
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
554
+ return {
555
+ "decay": self.decay,
556
+ "min_decay": self.min_decay,
557
+ "optimization_step": self.optimization_step,
558
+ "update_after_step": self.update_after_step,
559
+ "use_ema_warmup": self.use_ema_warmup,
560
+ "inv_gamma": self.inv_gamma,
561
+ "power": self.power,
562
+ "shadow_params": self.shadow_params,
563
+ }
564
+
565
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
566
+ r"""
567
+ Saves the current parameters for restoring later.
568
+
569
+ Args:
570
+ parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
571
+ """
572
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
573
+
574
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
575
+ r"""
576
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
577
+ without affecting the original optimization process. Store the parameters before the `copy_to()` method. After
578
+ validation (or model saving), use this to restore the former parameters.
579
+
580
+ Args:
581
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
582
+ updated with the stored parameters. If `None`, the parameters with which this
583
+ `ExponentialMovingAverage` was initialized will be used.
584
+ """
585
+
586
+ if self.temp_stored_params is None:
587
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
588
+ if self.foreach:
589
+ torch._foreach_copy_(
590
+ [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
591
+ )
592
+ else:
593
+ for c_param, param in zip(self.temp_stored_params, parameters):
594
+ param.data.copy_(c_param.data)
595
+
596
+ # Better memory-wise.
597
+ self.temp_stored_params = None
598
+
599
+ def load_state_dict(self, state_dict: dict) -> None:
600
+ r"""
601
+ Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
602
+ ema state dict.
603
+
604
+ Args:
605
+ state_dict (dict): EMA state. Should be an object returned
606
+ from a call to :meth:`state_dict`.
607
+ """
608
+ # deepcopy, to be consistent with module API
609
+ state_dict = copy.deepcopy(state_dict)
610
+
611
+ self.decay = state_dict.get("decay", self.decay)
612
+ if self.decay < 0.0 or self.decay > 1.0:
613
+ raise ValueError("Decay must be between 0 and 1")
614
+
615
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
616
+ if not isinstance(self.min_decay, float):
617
+ raise ValueError("Invalid min_decay")
618
+
619
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
620
+ if not isinstance(self.optimization_step, int):
621
+ raise ValueError("Invalid optimization_step")
622
+
623
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
624
+ if not isinstance(self.update_after_step, int):
625
+ raise ValueError("Invalid update_after_step")
626
+
627
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
628
+ if not isinstance(self.use_ema_warmup, bool):
629
+ raise ValueError("Invalid use_ema_warmup")
630
+
631
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
632
+ if not isinstance(self.inv_gamma, (float, int)):
633
+ raise ValueError("Invalid inv_gamma")
634
+
635
+ self.power = state_dict.get("power", self.power)
636
+ if not isinstance(self.power, (float, int)):
637
+ raise ValueError("Invalid power")
638
+
639
+ shadow_params = state_dict.get("shadow_params", None)
640
+ if shadow_params is not None:
641
+ self.shadow_params = shadow_params
642
+ if not isinstance(self.shadow_params, list):
643
+ raise ValueError("shadow_params must be a list")
644
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
645
+ raise ValueError("shadow_params must all be Tensors")
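For context, a sketch of the usual workflow with the `EMAModel` class above (illustrative only; the toy model, optimizer, and loop are placeholders):

import torch
from omnigen2.training_utils import EMAModel

# Sketch of the typical EMA workflow with the class above.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
ema = EMAModel(model.parameters(), decay=0.9999, use_ema_warmup=True)

for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())      # update the shadow params after each optimizer step

ema.store(model.parameters())         # stash the live weights
ema.copy_to(model.parameters())       # validate / checkpoint with the EMA weights
ema.restore(model.parameters())       # put the live weights back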
omnigen2/transport/__init__.py ADDED
@@ -0,0 +1,74 @@
1
+ from .transport import ModelType, PathType, Sampler, Transport, WeightType
2
+
3
+
4
+ def create_transport(
5
+ path_type="Linear",
6
+ prediction="velocity",
7
+ loss_weight=None,
8
+ train_eps=None,
9
+ sample_eps=None,
10
+ snr_type="uniform",
11
+ do_shift=True,
12
+ seq_len=1024, # corresponding to 512x512
13
+ dynamic_time_shift: bool = False,
14
+ time_shift_version: str = "v1",
15
+ ):
16
+ """function for creating Transport object
17
+ **Note**: model prediction defaults to velocity
18
+ Args:
19
+ - path_type: type of path to use; default to linear
20
+ - prediction: model prediction type; one of "velocity" (default), "noise", or "score"
22
+ - loss_weight: loss weighting; one of "velocity", "likelihood", or None (no weighting)
24
+ - train_eps: small epsilon for avoiding instability during training
25
+ - sample_eps: small epsilon for avoiding instability during sampling
26
+ """
27
+
28
+ if prediction == "noise":
29
+ model_type = ModelType.NOISE
30
+ elif prediction == "score":
31
+ model_type = ModelType.SCORE
32
+ else:
33
+ model_type = ModelType.VELOCITY
34
+
35
+ if loss_weight == "velocity":
36
+ loss_type = WeightType.VELOCITY
37
+ elif loss_weight == "likelihood":
38
+ loss_type = WeightType.LIKELIHOOD
39
+ else:
40
+ loss_type = WeightType.NONE
41
+
42
+ path_choice = {
43
+ "Linear": PathType.LINEAR,
44
+ "GVP": PathType.GVP,
45
+ "VP": PathType.VP,
46
+ }
47
+
48
+ path_type = path_choice[path_type]
49
+
50
+ if path_type in [PathType.VP]:
51
+ train_eps = 1e-5 if train_eps is None else train_eps
52
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
53
+ elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
54
+ train_eps = 1e-3 if train_eps is None else train_eps
55
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
56
+ else: # velocity & [GVP, LINEAR] is stable everywhere
57
+ train_eps = 0
58
+ sample_eps = 0
59
+
60
+ # create flow state
61
+ state = Transport(
62
+ model_type=model_type,
63
+ path_type=path_type,
64
+ loss_type=loss_type,
65
+ train_eps=train_eps,
66
+ sample_eps=sample_eps,
67
+ snr_type=snr_type,
68
+ do_shift=do_shift,
69
+ seq_len=seq_len,
70
+ dynamic_time_shift=dynamic_time_shift,
71
+ time_shift_version=time_shift_version,
72
+ )
73
+
74
+ return state
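A sketch of how `create_transport` is typically called (illustrative only; the keyword values simply mirror the defaults in the signature above):

from omnigen2.transport import create_transport

# Velocity-prediction, linear-path transport matching the defaults above.
transport = create_transport(
    path_type="Linear",
    prediction="velocity",   # anything other than "noise" / "score" maps to ModelType.VELOCITY
    loss_weight=None,        # maps to WeightType.NONE
    snr_type="uniform",
    do_shift=True,
    seq_len=1024,            # corresponds to 512x512 per the comment in the signature
)
# `transport` is a Transport instance; its training and sampling methods are
# defined in omnigen2/transport/transport.py.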
omnigen2/transport/dpm_solver.py ADDED
@@ -0,0 +1,1386 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
18
+ import os
19
+
20
+ import torch
21
+ from tqdm import tqdm
22
+
23
+
24
+ class NoiseScheduleFlow:
25
+ def __init__(
26
+ self,
27
+ schedule="discrete_flow",
28
+ ):
29
+ """Create a wrapper class for the forward SDE (flow-matching type)."""
30
+ self.T = 1
31
+ self.t0 = 0.001
32
+ self.schedule = schedule # ['continuous', 'discrete_flow']
33
+ self.total_N = 1000
34
+
35
+ def marginal_log_mean_coeff(self, t):
36
+ """
37
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
38
+ """
39
+ return torch.log(self.marginal_alpha(t))
40
+
41
+ def marginal_alpha(self, t):
42
+ """
43
+ Compute alpha_t of a given continuous-time label t in [0, T].
44
+ """
45
+ return 1 - t
46
+
47
+ @staticmethod
48
+ def marginal_std(t):
49
+ """
50
+ Compute sigma_t of a given continuous-time label t in [0, T].
51
+ """
52
+ return t
53
+
54
+ def marginal_lambda(self, t):
55
+ """
56
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
57
+ """
58
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
59
+ log_std = torch.log(self.marginal_std(t))
60
+ return log_mean_coeff - log_std
61
+
62
+ @staticmethod
63
+ def inverse_lambda(lamb):
64
+ """
65
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
66
+ """
67
+ return torch.exp(-lamb)
68
+
69
+
70
+ def model_wrapper(
71
+ model,
72
+ noise_schedule,
73
+ model_type="noise",
74
+ model_kwargs={},
75
+ guidance_type="uncond",
76
+ condition=None,
77
+ unconditional_condition=None,
78
+ guidance_scale=1.0,
79
+ interval_guidance=[0, 1.0],
80
+ classifier_fn=None,
81
+ classifier_kwargs={},
82
+ ):
83
+ """Create a wrapper function for the noise prediction model.
84
+
85
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
86
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
87
+
88
+ We support the following types of diffusion model by setting `model_type`:
89
+
90
+ 1. "noise": noise prediction model. (Trained by predicting noise).
91
+
92
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
93
+
94
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
95
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
96
+
97
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
98
+ arXiv preprint arXiv:2202.00512 (2022).
99
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
100
+ arXiv preprint arXiv:2210.02303 (2022).
101
+
102
+ 4. "score": marginal score function. (Trained by denoising score matching).
103
+ Note that the score function and the noise prediction model follow a simple relationship:
104
+ ```
105
+ noise(x_t, t) = -sigma_t * score(x_t, t)
106
+ ```
107
+
+ 5. "flow": flow / velocity-field prediction, as used by the flow-matching models in this repository.
+
108
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
109
+ 1. "uncond": unconditional sampling by DPMs.
110
+ The input `model` has the following format:
111
+ ``
112
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
113
+ ``
114
+
115
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
116
+ The input `model` has the following format:
117
+ ``
118
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
119
+ ``
120
+
121
+ The input `classifier_fn` has the following format:
122
+ ``
123
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
124
+ ``
125
+
126
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
127
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
128
+
129
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
130
+ The input `model` has the following format:
131
+ ``
132
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
133
+ ``
134
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
135
+
136
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
137
+ arXiv preprint arXiv:2207.12598 (2022).
138
+
139
+
140
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
141
+ or continuous-time labels (i.e. epsilon to T).
142
+
143
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
144
+ ``
145
+ def model_fn(x, t_continuous) -> noise:
146
+ t_input = get_model_input_time(t_continuous)
147
+ return noise_pred(model, x, t_input, **model_kwargs)
148
+ ``
149
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
150
+
151
+ ===============================================================
152
+
153
+ Args:
154
+ model: A diffusion model with the corresponding format described above.
155
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
156
+ model_type: A `str`. The parameterization type of the diffusion model.
157
+ "noise" or "x_start" or "v" or "score".
158
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
159
+ guidance_type: A `str`. The type of the guidance for sampling.
160
+ "uncond" or "classifier" or "classifier-free".
161
+ condition: A pytorch tensor. The condition for the guided sampling.
162
+ Only used for "classifier" or "classifier-free" guidance type.
163
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
164
+ Only used for "classifier-free" guidance type.
165
+ guidance_scale: A `float`. The scale for the guided sampling.
166
+ classifier_fn: A classifier function. Only used for the classifier guidance.
167
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
168
+ Returns:
169
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
170
+ """
171
+
172
+ def get_model_input_time(t_continuous):
173
+ """
174
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
175
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
176
+ For continuous-time DPMs, we just use `t_continuous`.
177
+ """
178
+ if noise_schedule.schedule == "discrete":
179
+ return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N
180
+ elif noise_schedule.schedule == "discrete_flow":
181
+ return t_continuous * noise_schedule.total_N
182
+ else:
183
+ return t_continuous
184
+
185
+ def noise_pred_fn(x, t_continuous, cond=None):
186
+ t_input = get_model_input_time(t_continuous)
187
+ if cond is None:
188
+ output = model(x, t_input, **model_kwargs)
189
+ else:
190
+ output = model(x, t_input, cond, **model_kwargs)
191
+ if model_type == "noise":
192
+ return output
193
+ elif model_type == "x_start":
194
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
195
+ return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
196
+ elif model_type == "v":
197
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
198
+ return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
199
+ elif model_type == "score":
200
+ sigma_t = noise_schedule.marginal_std(t_continuous)
201
+ return -expand_dims(sigma_t, x.dim()) * output
202
+ elif model_type == "flow":
203
+ _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
204
+ try:
205
+ noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output + x
206
+ except:
207
+ noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0] + x
208
+ return noise
209
+
210
+ def cond_grad_fn(x, t_input):
211
+ """
212
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
213
+ """
214
+ with torch.enable_grad():
215
+ x_in = x.detach().requires_grad_(True)
216
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
217
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
218
+
219
+ def model_fn(x, t_continuous):
220
+ """
221
+ The noise prediction model function that is used for DPM-Solver.
222
+ """
223
+ guidance_tp = guidance_type
224
+ if guidance_tp == "uncond":
225
+ return noise_pred_fn(x, t_continuous)
226
+ elif guidance_tp == "classifier":
227
+ assert classifier_fn is not None
228
+ t_input = get_model_input_time(t_continuous)
229
+ cond_grad = cond_grad_fn(x, t_input)
230
+ sigma_t = noise_schedule.marginal_std(t_continuous)
231
+ noise = noise_pred_fn(x, t_continuous)
232
+ return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
233
+ elif guidance_tp == "classifier-free":
234
+ if (
235
+ guidance_scale == 1.0
236
+ or unconditional_condition is None
237
+ or not (interval_guidance[0] < t_continuous[0] < interval_guidance[1])
238
+ ):
239
+ return noise_pred_fn(x, t_continuous, cond=condition)
240
+ else:
241
+ x_in = torch.cat([x] * 2)
242
+ t_in = torch.cat([t_continuous] * 2)
243
+ c_in = torch.cat([unconditional_condition, condition])
244
+ try:
245
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
246
+ except:
247
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
248
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
249
+
250
+ assert model_type in ["noise", "x_start", "v", "score", "flow"]
251
+ assert guidance_type in [
252
+ "uncond",
253
+ "classifier",
254
+ "classifier-free",
255
+ ]
256
+ return model_fn
257
+
258
+
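An assumed call pattern for `model_wrapper` with classifier-free guidance and a flow-type model (illustrative sketch only; `toy_net` and the tensor shapes are placeholders, and `expand_dims` used internally is presumed to be the helper defined later in this file):

import torch

from omnigen2.transport.dpm_solver import NoiseScheduleFlow, model_wrapper

# Toy stand-in for the real transformer; it only matches the expected
# (x, t, cond) call signature and returns zeros.
def toy_net(x, t, cond):
    return torch.zeros_like(x)

ns = NoiseScheduleFlow(schedule="discrete_flow")
model_fn = model_wrapper(
    toy_net,
    ns,
    model_type="flow",                        # treat the output as a flow / velocity field
    guidance_type="classifier-free",
    condition=torch.randn(1, 16),
    unconditional_condition=torch.zeros(1, 16),
    guidance_scale=4.0,
)

x_t = torch.randn(1, 4, 8, 8)
t = torch.full((1,), 0.5)
noise_pred = model_fn(x_t, t)                 # guided noise prediction consumed by DPM_Solver below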
259
+ class DPM_Solver:
260
+ def __init__(
261
+ self,
262
+ model_fn,
263
+ noise_schedule,
264
+ algorithm_type="dpmsolver++",
265
+ correcting_x0_fn=None,
266
+ correcting_xt_fn=None,
267
+ thresholding_max_val=1.0,
268
+ dynamic_thresholding_ratio=0.995,
269
+ ):
270
+ """Construct a DPM-Solver.
271
+
272
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
273
+
274
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
275
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
276
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
277
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
278
+ DPMs (such as stable-diffusion).
279
+
280
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
281
+ both x0 and xt.
282
+
283
+ Args:
284
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
285
+ ``
286
+ def model_fn(x, t_continuous):
287
+ return noise
288
+ ``
289
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
290
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
291
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
292
+ correcting_x0_fn: A `str` or a function with the following format:
293
+ ```
294
+ def correcting_x0_fn(x0, t):
295
+ x0_new = ...
296
+ return x0_new
297
+ ```
298
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
299
+ ```
300
+ x0_pred = data_pred_model(xt, t)
301
+ if correcting_x0_fn is not None:
302
+ x0_pred = correcting_x0_fn(x0_pred, t)
303
+ xt_1 = update(x0_pred, xt, t)
304
+ ```
305
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
306
+ correcting_xt_fn: A function with the following format:
307
+ ```
308
+ def correcting_xt_fn(xt, t, step):
309
+ x_new = ...
310
+ return x_new
311
+ ```
312
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
313
+ ```
314
+ xt = ...
315
+ xt = correcting_xt_fn(xt, t, step)
316
+ ```
317
+ thresholding_max_val: A `float`. The max value for thresholding.
318
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
319
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
320
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
321
+
322
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
323
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
324
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
325
+ """
326
+ self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
327
+ self.noise_schedule = noise_schedule
328
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
329
+ self.algorithm_type = algorithm_type
330
+ if correcting_x0_fn == "dynamic_thresholding":
331
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
332
+ else:
333
+ self.correcting_x0_fn = correcting_x0_fn
334
+ self.correcting_xt_fn = correcting_xt_fn
335
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
336
+ self.thresholding_max_val = thresholding_max_val
337
+ self.register_progress_bar()
338
+
339
+ def register_progress_bar(self, progress_fn=None):
340
+ """
341
+ Register a progress bar callback function
342
+
343
+ Args:
344
+ progress_fn: Callback function that takes current step and total steps as parameters
345
+ """
346
+ self.progress_fn = progress_fn if progress_fn is not None else lambda step, total: None
347
+
348
+ def update_progress(self, step, total_steps):
349
+ """
350
+ Update sampling progress
351
+
352
+ Args:
353
+ step: Current step number
354
+ total_steps: Total number of steps
355
+ """
356
+ if hasattr(self, "progress_fn"):
357
+ try:
358
+ self.progress_fn(step / total_steps, desc=f"Generating {step}/{total_steps}")
359
+ except:
360
+ self.progress_fn(step, total_steps)
361
+
362
+ else:
363
+ # If no progress_fn registered, use default empty function
364
+ pass
365
+
366
+ def dynamic_thresholding_fn(self, x0, t):
367
+ """
368
+ The dynamic thresholding method.
369
+ """
370
+ dims = x0.dim()
371
+ p = self.dynamic_thresholding_ratio
372
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
373
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
374
+ x0 = torch.clamp(x0, -s, s) / s
375
+ return x0
376
+
377
+ def noise_prediction_fn(self, x, t):
378
+ """
379
+ Return the noise prediction model.
380
+ """
381
+ return self.model(x, t)
382
+
383
+ def data_prediction_fn(self, x, t):
384
+ """
385
+ Return the data prediction model (with corrector).
386
+ """
387
+ noise = self.noise_prediction_fn(x, t)
388
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
389
+ x0 = (x - sigma_t * noise) / alpha_t
390
+ if self.correcting_x0_fn is not None:
391
+ x0 = self.correcting_x0_fn(x0, t)
392
+ return x0
393
+
394
+ def model_fn(self, x, t):
395
+ """
396
+ Convert the model to the noise prediction model or the data prediction model.
397
+ """
398
+ if self.algorithm_type == "dpmsolver++":
399
+ return self.data_prediction_fn(x, t)
400
+ else:
401
+ return self.noise_prediction_fn(x, t)
402
+
403
+ def get_time_steps(self, skip_type, t_T, t_0, N, device, shift=1.0):
404
+ """Compute the intermediate time steps for sampling.
405
+
406
+ Args:
407
+ skip_type: A `str`. The type for the spacing of the time steps. We support the following types:
408
+ - 'logSNR': uniform logSNR for the time steps.
409
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
410
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ - 'time_uniform_flow': uniform time remapped by `shift`, for flow-matching schedules.
411
+ t_T: A `float`. The starting time of the sampling (default is T).
412
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
413
+ N: A `int`. The total number of the spacing of the time steps.
414
+ device: A torch device.
415
+ Returns:
416
+ A pytorch tensor of the time steps, with the shape (N + 1,).
417
+ """
418
+ if skip_type == "logSNR":
419
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
420
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
421
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
422
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
423
+ elif skip_type == "time_uniform":
424
+ return torch.linspace(t_T, t_0, N + 1).to(device)
425
+ elif skip_type == "time_quadratic":
426
+ t_order = 2
427
+ t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
428
+ return t
429
+ elif skip_type == "time_uniform_flow":
430
+ betas = torch.linspace(t_T, t_0, N + 1).to(device)
431
+ sigmas = 1.0 - betas
432
+ sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
433
+ return sigmas
434
+ else:
435
+ raise ValueError(
436
+ f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
437
+ )
438
+
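To make the 'time_uniform_flow' branch concrete, here is a stand-alone version with toy values; a larger `shift` packs more of the grid near the high-noise end (sigma close to 1).

import torch

t_T, t_0, N, shift = 1.0, 0.0, 4, 3.0
betas = torch.linspace(t_T, t_0, N + 1)
sigmas = 1.0 - betas                                               # [0.00, 0.25, 0.50, 0.75, 1.00]
sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
print(sigmas)                                                      # tensor([1.0000, 0.9000, 0.7500, 0.5000, 0.0000])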
439
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
440
+ """
441
+ Get the order of each step for sampling by the singlestep DPM-Solver.
442
+
443
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
444
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
445
+ - If order == 1:
446
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
447
+ - If order == 2:
448
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
449
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
450
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
451
+ - If order == 3:
452
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
453
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
454
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
455
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
456
+
457
+ ============================================
458
+ Args:
459
+ order: A `int`. The max order for the solver (2 or 3).
460
+ steps: A `int`. The total number of function evaluations (NFE).
461
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
462
+ - 'logSNR': uniform logSNR for the time steps.
463
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
464
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
465
+ t_T: A `float`. The starting time of the sampling (default is T).
466
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
467
+ device: A torch device.
468
+ Returns:
469
+ orders: A list of the solver order of each step.
470
+ """
471
+ if order == 3:
472
+ K = steps // 3 + 1
473
+ if steps % 3 == 0:
474
+ orders = [3,] * (
475
+ K - 2
476
+ ) + [2, 1]
477
+ elif steps % 3 == 1:
478
+ orders = [3,] * (
479
+ K - 1
480
+ ) + [1]
481
+ else:
482
+ orders = [3,] * (
483
+ K - 1
484
+ ) + [2]
485
+ elif order == 2:
486
+ if steps % 2 == 0:
487
+ K = steps // 2
488
+ orders = [
489
+ 2,
490
+ ] * K
491
+ else:
492
+ K = steps // 2 + 1
493
+ orders = [2,] * (
494
+ K - 1
495
+ ) + [1]
496
+ elif order == 1:
497
+ K = 1
498
+ orders = [
499
+ 1,
500
+ ] * steps
501
+ else:
502
+ raise ValueError("'order' must be '1' or '2' or '3'.")
503
+ if skip_type == "logSNR":
504
+ # To reproduce the results in DPM-Solver paper
505
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
506
+ else:
507
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
508
+ torch.cumsum(
509
+ torch.tensor(
510
+ [
511
+ 0,
512
+ ]
513
+ + orders
514
+ ),
515
+ 0,
516
+ ).to(device)
517
+ ]
518
+ return timesteps_outer, orders
519
+
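As a concrete check of the rules above: with steps=7 and order=3, K = 7 // 3 + 1 = 3 and 7 % 3 == 1, so orders == [3, 3, 1], which uses exactly 3 + 3 + 1 = 7 function evaluations; with steps=6 the result is [3, 2, 1].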
520
+ def denoise_to_zero_fn(self, x, s):
521
+ """
522
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
523
+ """
524
+ return self.data_prediction_fn(x, s)
525
+
526
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
527
+ """
528
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
529
+
530
+ Args:
531
+ x: A pytorch tensor. The initial value at time `s`.
532
+ s: A pytorch tensor. The starting time, with the shape (1,).
533
+ t: A pytorch tensor. The ending time, with the shape (1,).
534
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
535
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
536
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
537
+ Returns:
538
+ x_t: A pytorch tensor. The approximated solution at time `t`.
539
+ """
540
+ ns = self.noise_schedule
541
+ dims = x.dim()
542
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
543
+ h = lambda_t - lambda_s
544
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
545
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
546
+ alpha_t = torch.exp(log_alpha_t)
547
+
548
+ if self.algorithm_type == "dpmsolver++":
549
+ phi_1 = torch.expm1(-h)
550
+ if model_s is None:
551
+ model_s = self.model_fn(x, s)
552
+ x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
553
+ if return_intermediate:
554
+ return x_t, {"model_s": model_s}
555
+ else:
556
+ return x_t
557
+ else:
558
+ phi_1 = torch.expm1(h)
559
+ if model_s is None:
560
+ model_s = self.model_fn(x, s)
561
+ x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
562
+ if return_intermediate:
563
+ return x_t, {"model_s": model_s}
564
+ else:
565
+ return x_t
566
+
567
+ def singlestep_dpm_solver_second_update(
568
+ self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
569
+ ):
570
+ """
571
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
572
+
573
+ Args:
574
+ x: A pytorch tensor. The initial value at time `s`.
575
+ s: A pytorch tensor. The starting time, with the shape (1,).
576
+ t: A pytorch tensor. The ending time, with the shape (1,).
577
+ r1: A `float`. The hyperparameter of the second-order solver.
578
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
579
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
580
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
581
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
582
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
583
+ Returns:
584
+ x_t: A pytorch tensor. The approximated solution at time `t`.
585
+ """
586
+ if solver_type not in ["dpmsolver", "taylor"]:
587
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
588
+ if r1 is None:
589
+ r1 = 0.5
590
+ ns = self.noise_schedule
591
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
592
+ h = lambda_t - lambda_s
593
+ lambda_s1 = lambda_s + r1 * h
594
+ s1 = ns.inverse_lambda(lambda_s1)
595
+ log_alpha_s, log_alpha_s1, log_alpha_t = (
596
+ ns.marginal_log_mean_coeff(s),
597
+ ns.marginal_log_mean_coeff(s1),
598
+ ns.marginal_log_mean_coeff(t),
599
+ )
600
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
601
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
602
+
603
+ if self.algorithm_type == "dpmsolver++":
604
+ phi_11 = torch.expm1(-r1 * h)
605
+ phi_1 = torch.expm1(-h)
606
+
607
+ if model_s is None:
608
+ model_s = self.model_fn(x, s)
609
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
610
+ model_s1 = self.model_fn(x_s1, s1)
611
+ if solver_type == "dpmsolver":
612
+ x_t = (
613
+ (sigma_t / sigma_s) * x
614
+ - (alpha_t * phi_1) * model_s
615
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
616
+ )
617
+ elif solver_type == "taylor":
618
+ x_t = (
619
+ (sigma_t / sigma_s) * x
620
+ - (alpha_t * phi_1) * model_s
621
+ + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
622
+ )
623
+ else:
624
+ phi_11 = torch.expm1(r1 * h)
625
+ phi_1 = torch.expm1(h)
626
+
627
+ if model_s is None:
628
+ model_s = self.model_fn(x, s)
629
+ x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
630
+ model_s1 = self.model_fn(x_s1, s1)
631
+ if solver_type == "dpmsolver":
632
+ x_t = (
633
+ torch.exp(log_alpha_t - log_alpha_s) * x
634
+ - (sigma_t * phi_1) * model_s
635
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
636
+ )
637
+ elif solver_type == "taylor":
638
+ x_t = (
639
+ torch.exp(log_alpha_t - log_alpha_s) * x
640
+ - (sigma_t * phi_1) * model_s
641
+ - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
642
+ )
643
+ if return_intermediate:
644
+ return x_t, {"model_s": model_s, "model_s1": model_s1}
645
+ else:
646
+ return x_t
647
+
648
+ def singlestep_dpm_solver_third_update(
649
+ self,
650
+ x,
651
+ s,
652
+ t,
653
+ r1=1.0 / 3.0,
654
+ r2=2.0 / 3.0,
655
+ model_s=None,
656
+ model_s1=None,
657
+ return_intermediate=False,
658
+ solver_type="dpmsolver",
659
+ ):
660
+ """
661
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
662
+
663
+ Args:
664
+ x: A pytorch tensor. The initial value at time `s`.
665
+ s: A pytorch tensor. The starting time, with the shape (1,).
666
+ t: A pytorch tensor. The ending time, with the shape (1,).
667
+ r1: A `float`. The hyperparameter of the third-order solver.
668
+ r2: A `float`. The hyperparameter of the third-order solver.
669
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
670
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
671
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
672
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
673
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
674
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
675
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
676
+ Returns:
677
+ x_t: A pytorch tensor. The approximated solution at time `t`.
678
+ """
679
+ if solver_type not in ["dpmsolver", "taylor"]:
680
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
681
+ if r1 is None:
682
+ r1 = 1.0 / 3.0
683
+ if r2 is None:
684
+ r2 = 2.0 / 3.0
685
+ ns = self.noise_schedule
686
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
687
+ h = lambda_t - lambda_s
688
+ lambda_s1 = lambda_s + r1 * h
689
+ lambda_s2 = lambda_s + r2 * h
690
+ s1 = ns.inverse_lambda(lambda_s1)
691
+ s2 = ns.inverse_lambda(lambda_s2)
692
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
693
+ ns.marginal_log_mean_coeff(s),
694
+ ns.marginal_log_mean_coeff(s1),
695
+ ns.marginal_log_mean_coeff(s2),
696
+ ns.marginal_log_mean_coeff(t),
697
+ )
698
+ sigma_s, sigma_s1, sigma_s2, sigma_t = (
699
+ ns.marginal_std(s),
700
+ ns.marginal_std(s1),
701
+ ns.marginal_std(s2),
702
+ ns.marginal_std(t),
703
+ )
704
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
705
+
706
+ if self.algorithm_type == "dpmsolver++":
707
+ phi_11 = torch.expm1(-r1 * h)
708
+ phi_12 = torch.expm1(-r2 * h)
709
+ phi_1 = torch.expm1(-h)
710
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
711
+ phi_2 = phi_1 / h + 1.0
712
+ phi_3 = phi_2 / h - 0.5
713
+
714
+ if model_s is None:
715
+ model_s = self.model_fn(x, s)
716
+ if model_s1 is None:
717
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
718
+ model_s1 = self.model_fn(x_s1, s1)
719
+ x_s2 = (
720
+ (sigma_s2 / sigma_s) * x
721
+ - (alpha_s2 * phi_12) * model_s
722
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
723
+ )
724
+ model_s2 = self.model_fn(x_s2, s2)
725
+ if solver_type == "dpmsolver":
726
+ x_t = (
727
+ (sigma_t / sigma_s) * x
728
+ - (alpha_t * phi_1) * model_s
729
+ + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
730
+ )
731
+ elif solver_type == "taylor":
732
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
733
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
734
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
735
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
736
+ x_t = (
737
+ (sigma_t / sigma_s) * x
738
+ - (alpha_t * phi_1) * model_s
739
+ + (alpha_t * phi_2) * D1
740
+ - (alpha_t * phi_3) * D2
741
+ )
742
+ else:
743
+ phi_11 = torch.expm1(r1 * h)
744
+ phi_12 = torch.expm1(r2 * h)
745
+ phi_1 = torch.expm1(h)
746
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
747
+ phi_2 = phi_1 / h - 1.0
748
+ phi_3 = phi_2 / h - 0.5
749
+
750
+ if model_s is None:
751
+ model_s = self.model_fn(x, s)
752
+ if model_s1 is None:
753
+ x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
754
+ model_s1 = self.model_fn(x_s1, s1)
755
+ x_s2 = (
756
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
757
+ - (sigma_s2 * phi_12) * model_s
758
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
759
+ )
760
+ model_s2 = self.model_fn(x_s2, s2)
761
+ if solver_type == "dpmsolver":
762
+ x_t = (
763
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
764
+ - (sigma_t * phi_1) * model_s
765
+ - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
766
+ )
767
+ elif solver_type == "taylor":
768
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
769
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
770
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
771
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
772
+ x_t = (
773
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
774
+ - (sigma_t * phi_1) * model_s
775
+ - (sigma_t * phi_2) * D1
776
+ - (sigma_t * phi_3) * D2
777
+ )
778
+
779
+ if return_intermediate:
780
+ return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
781
+ else:
782
+ return x_t
783
+
784
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
785
+ """
786
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
787
+
788
+ Args:
789
+ x: A pytorch tensor. The initial value at time `s`.
790
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
791
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
792
+ t: A pytorch tensor. The ending time, with the shape (1,).
793
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
794
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
795
+ Returns:
796
+ x_t: A pytorch tensor. The approximated solution at time `t`.
797
+ """
798
+ if solver_type not in ["dpmsolver", "taylor"]:
799
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
800
+ ns = self.noise_schedule
801
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
802
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
803
+ lambda_prev_1, lambda_prev_0, lambda_t = (
804
+ ns.marginal_lambda(t_prev_1),
805
+ ns.marginal_lambda(t_prev_0),
806
+ ns.marginal_lambda(t),
807
+ )
808
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
809
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
810
+ alpha_t = torch.exp(log_alpha_t)
811
+
812
+ h_0 = lambda_prev_0 - lambda_prev_1
813
+ h = lambda_t - lambda_prev_0
814
+ r0 = h_0 / h
815
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
816
+ if self.algorithm_type == "dpmsolver++":
817
+ phi_1 = torch.expm1(-h)
818
+ if solver_type == "dpmsolver":
819
+ x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
820
+ elif solver_type == "taylor":
821
+ x_t = (
822
+ (sigma_t / sigma_prev_0) * x
823
+ - (alpha_t * phi_1) * model_prev_0
824
+ + (alpha_t * (phi_1 / h + 1.0)) * D1_0
825
+ )
826
+ else:
827
+ phi_1 = torch.expm1(h)
828
+ if solver_type == "dpmsolver":
829
+ x_t = (
830
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
831
+ - (sigma_t * phi_1) * model_prev_0
832
+ - 0.5 * (sigma_t * phi_1) * D1_0
833
+ )
834
+ elif solver_type == "taylor":
835
+ x_t = (
836
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
837
+ - (sigma_t * phi_1) * model_prev_0
838
+ - (sigma_t * (phi_1 / h - 1.0)) * D1_0
839
+ )
840
+ return x_t
841
+
842
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
843
+ """
844
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
845
+
846
+ Args:
847
+ x: A pytorch tensor. The initial value at time `s`.
848
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
849
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
850
+ t: A pytorch tensor. The ending time, with the shape (1,).
851
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
852
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
853
+ Returns:
854
+ x_t: A pytorch tensor. The approximated solution at time `t`.
855
+ """
856
+ ns = self.noise_schedule
857
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
858
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
859
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
860
+ ns.marginal_lambda(t_prev_2),
861
+ ns.marginal_lambda(t_prev_1),
862
+ ns.marginal_lambda(t_prev_0),
863
+ ns.marginal_lambda(t),
864
+ )
865
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
866
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
867
+ alpha_t = torch.exp(log_alpha_t)
868
+
869
+ h_1 = lambda_prev_1 - lambda_prev_2
870
+ h_0 = lambda_prev_0 - lambda_prev_1
871
+ h = lambda_t - lambda_prev_0
872
+ r0, r1 = h_0 / h, h_1 / h
873
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
874
+ D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
875
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
876
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
877
+ if self.algorithm_type == "dpmsolver++":
878
+ phi_1 = torch.expm1(-h)
879
+ phi_2 = phi_1 / h + 1.0
880
+ phi_3 = phi_2 / h - 0.5
881
+ x_t = (
882
+ (sigma_t / sigma_prev_0) * x
883
+ - (alpha_t * phi_1) * model_prev_0
884
+ + (alpha_t * phi_2) * D1
885
+ - (alpha_t * phi_3) * D2
886
+ )
887
+ else:
888
+ phi_1 = torch.expm1(h)
889
+ phi_2 = phi_1 / h - 1.0
890
+ phi_3 = phi_2 / h - 0.5
891
+ x_t = (
892
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
893
+ - (sigma_t * phi_1) * model_prev_0
894
+ - (sigma_t * phi_2) * D1
895
+ - (sigma_t * phi_3) * D2
896
+ )
897
+ return x_t
898
+
899
+ def singlestep_dpm_solver_update(
900
+ self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
901
+ ):
902
+ """
903
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
904
+
905
+ Args:
906
+ x: A pytorch tensor. The initial value at time `s`.
907
+ s: A pytorch tensor. The starting time, with the shape (1,).
908
+ t: A pytorch tensor. The ending time, with the shape (1,).
909
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
910
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
911
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
912
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
913
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
914
+ r2: A `float`. The hyperparameter of the third-order solver.
915
+ Returns:
916
+ x_t: A pytorch tensor. The approximated solution at time `t`.
917
+ """
918
+ if order == 1:
919
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
920
+ elif order == 2:
921
+ return self.singlestep_dpm_solver_second_update(
922
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
923
+ )
924
+ elif order == 3:
925
+ return self.singlestep_dpm_solver_third_update(
926
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
927
+ )
928
+ else:
929
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
930
+
931
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
932
+ """
933
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
934
+
935
+ Args:
936
+ x: A pytorch tensor. The initial value at time `s`.
937
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
938
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
939
+ t: A pytorch tensor. The ending time, with the shape (1,).
940
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
941
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
942
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
943
+ Returns:
944
+ x_t: A pytorch tensor. The approximated solution at time `t`.
945
+ """
946
+ if order == 1:
947
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
948
+ elif order == 2:
949
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
950
+ elif order == 3:
951
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
952
+ else:
953
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
954
+
955
+ def dpm_solver_adaptive(
956
+ self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
957
+ ):
958
+ """
959
+ The adaptive step size solver based on singlestep DPM-Solver.
960
+
961
+ Args:
962
+ x: A pytorch tensor. The initial value at time `t_T`.
963
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
964
+ t_T: A `float`. The starting time of the sampling (default is T).
965
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
966
+ h_init: A `float`. The initial step size (for logSNR).
967
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
968
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
969
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
970
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
971
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
972
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
973
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
974
+ Returns:
975
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
976
+
977
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
978
+ """
979
+ ns = self.noise_schedule
980
+ s = t_T * torch.ones((1,)).to(x)
981
+ lambda_s = ns.marginal_lambda(s)
982
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
983
+ h = h_init * torch.ones_like(s).to(x)
984
+ x_prev = x
985
+ nfe = 0
986
+ if order == 2:
987
+ r1 = 0.5
988
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
989
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
990
+ x, s, t, r1=r1, solver_type=solver_type, **kwargs
991
+ )
992
+ elif order == 3:
993
+ r1, r2 = 1.0 / 3.0, 2.0 / 3.0
994
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
995
+ x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
996
+ )
997
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
998
+ x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
999
+ )
1000
+ else:
1001
+ raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
1002
+ while torch.abs(s - t_0).mean() > t_err:
1003
+ t = ns.inverse_lambda(lambda_s + h)
1004
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1005
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1006
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1007
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1008
+ E = norm_fn((x_higher - x_lower) / delta).max()
1009
+ if torch.all(E <= 1.0):
1010
+ x = x_higher
1011
+ s = t
1012
+ x_prev = x_lower
1013
+ lambda_s = ns.marginal_lambda(s)
1014
+ h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
1015
+ nfe += order
1016
+ print("adaptive solver nfe", nfe)
1017
+ return x
1018
+
1019
+ def add_noise(self, x, t, noise=None):
1020
+ """
1021
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1022
+
1023
+ Args:
1024
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1025
+ t: A `torch.Tensor` with shape `(t_size,)`.
1026
+ Returns:
1027
+ xt with shape `(t_size, batch_size, *shape)`.
1028
+ """
1029
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1030
+ if noise is None:
1031
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1032
+ x = x.reshape((-1, *x.shape))
1033
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1034
+ if t.shape[0] == 1:
1035
+ return xt.squeeze(0)
1036
+ else:
1037
+ return xt
1038
+
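A hedged usage sketch for `add_noise`; `solver` stands for an already-constructed DPM_Solver built on some noise schedule, and is not defined in this file.

import torch

x = torch.randn(2, 4, 32, 32)      # clean latents
t = torch.tensor([0.5])            # a single time value
xt = solver.add_noise(x, t)        # alpha_t * x + sigma_t * noise; same shape as x since t_size == 1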
1039
+ def inverse(
1040
+ self,
1041
+ x,
1042
+ steps=20,
1043
+ t_start=None,
1044
+ t_end=None,
1045
+ order=2,
1046
+ skip_type="time_uniform",
1047
+ method="multistep",
1048
+ lower_order_final=True,
1049
+ denoise_to_zero=False,
1050
+ solver_type="dpmsolver",
1051
+ atol=0.0078,
1052
+ rtol=0.05,
1053
+ return_intermediate=False,
1054
+ ):
1055
+ """
1056
+ Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1057
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1058
+ """
1059
+ t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
1060
+ t_T = self.noise_schedule.T if t_end is None else t_end
1061
+ assert (
1062
+ t_0 > 0 and t_T > 0
1063
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1064
+ return self.sample(
1065
+ x,
1066
+ steps=steps,
1067
+ t_start=t_0,
1068
+ t_end=t_T,
1069
+ order=order,
1070
+ skip_type=skip_type,
1071
+ method=method,
1072
+ lower_order_final=lower_order_final,
1073
+ denoise_to_zero=denoise_to_zero,
1074
+ solver_type=solver_type,
1075
+ atol=atol,
1076
+ rtol=rtol,
1077
+ return_intermediate=return_intermediate,
1078
+ )
1079
+
1080
+ def sample(
1081
+ self,
1082
+ x,
1083
+ steps=20,
1084
+ t_start=None,
1085
+ t_end=None,
1086
+ order=2,
1087
+ skip_type="time_uniform",
1088
+ method="multistep",
1089
+ lower_order_final=True,
1090
+ denoise_to_zero=False,
1091
+ solver_type="dpmsolver",
1092
+ atol=0.0078,
1093
+ rtol=0.05,
1094
+ return_intermediate=False,
1095
+ flow_shift=1.0,
1096
+ ):
1097
+ """
1098
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1099
+
1100
+ =====================================================
1101
+
1102
+ We support the following algorithms for both noise prediction model and data prediction model:
1103
+ - 'singlestep':
1104
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1105
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1106
+ The total number of function evaluations (NFE) == `steps`.
1107
+ Given a fixed NFE == `steps`, the sampling procedure is:
1108
+ - If `order` == 1:
1109
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1110
+ - If `order` == 2:
1111
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1112
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1113
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1114
+ - If `order` == 3:
1115
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1116
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1117
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1118
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1119
+ - 'multistep':
1120
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1121
+ We initialize the first `order` values by lower order multistep solvers.
1122
+ Given a fixed NFE == `steps`, the sampling procedure is:
1123
+ Denote K = steps.
1124
+ - If `order` == 1:
1125
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1126
+ - If `order` == 2:
1127
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
1128
+ - If `order` == 3:
1129
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
1130
+ - 'singlestep_fixed':
1131
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1132
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1133
+ - 'adaptive':
1134
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1135
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1136
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1137
+ (NFE) and the sample quality.
1138
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1139
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1140
+
1141
+ =====================================================
1142
+
1143
+ Some advice for choosing the algorithm:
1144
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1145
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1146
+ e.g., DPM-Solver:
1147
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1148
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1149
+ skip_type='time_uniform', method='singlestep')
1150
+ e.g., DPM-Solver++:
1151
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1152
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1153
+ skip_type='time_uniform', method='singlestep')
1154
+ - For **guided sampling with large guidance scale** by DPMs:
1155
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1156
+ e.g.
1157
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1158
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1159
+ skip_type='time_uniform', method='multistep')
1160
+
1161
+ We support three types of `skip_type`:
1162
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1163
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1164
+ - 'time_quadratic': quadratic time for the time steps.
1165
+
1166
+ =====================================================
1167
+ Args:
1168
+ x: A pytorch tensor. The initial value at time `t_start`
1169
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1170
+ steps: A `int`. The total number of function evaluations (NFE).
1171
+ t_start: A `float`. The starting time of the sampling.
1172
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1173
+ t_end: A `float`. The ending time of the sampling.
1174
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1175
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1176
+ For discrete-time DPMs:
1177
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1178
+ For continuous-time DPMs:
1179
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1180
+ order: A `int`. The order of DPM-Solver.
1181
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1182
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1183
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1184
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1185
+
1186
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1187
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1188
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
1189
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1190
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1191
+ it for high-resolutional images.
1192
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1193
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1194
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1195
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1196
+ solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1197
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1198
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1199
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1200
+ When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
1201
+ Returns:
1202
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1203
+
1204
+ """
1205
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
1206
+ t_T = self.noise_schedule.T if t_start is None else t_start
1207
+ assert (
1208
+ t_0 > 0 and t_T > 0
1209
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1210
+ if return_intermediate:
1211
+ assert method in [
1212
+ "multistep",
1213
+ "singlestep",
1214
+ "singlestep_fixed",
1215
+ ], "Cannot use adaptive solver when saving intermediate values"
1216
+ if self.correcting_xt_fn is not None:
1217
+ assert method in [
1218
+ "multistep",
1219
+ "singlestep",
1220
+ "singlestep_fixed",
1221
+ ], "Cannot use adaptive solver when correcting_xt_fn is not None"
1222
+ device = x.device
1223
+ intermediates = []
1224
+ with torch.no_grad():
1225
+ if method == "adaptive":
1226
+ x = self.dpm_solver_adaptive(
1227
+ x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
1228
+ )
1229
+ elif method == "multistep":
1230
+ assert steps >= order
1231
+ timesteps = self.get_time_steps(
1232
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, shift=flow_shift
1233
+ )
1234
+ assert timesteps.shape[0] - 1 == steps
1235
+ # Init the initial values.
1236
+ step = 0
1237
+ t = timesteps[step]
1238
+ t_prev_list = [t]
1239
+ model_prev_list = [self.model_fn(x, t)]
1240
+ if self.correcting_xt_fn is not None:
1241
+ x = self.correcting_xt_fn(x, t, step)
1242
+ if return_intermediate:
1243
+ intermediates.append(x)
1244
+ self.update_progress(step + 1, len(timesteps))
1245
+ # Init the first `order` values by lower order multistep DPM-Solver.
1246
+ for step in range(1, order):
1247
+ t = timesteps[step]
1248
+ x = self.multistep_dpm_solver_update(
1249
+ x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
1250
+ )
1251
+ if self.correcting_xt_fn is not None:
1252
+ x = self.correcting_xt_fn(x, t, step)
1253
+ if return_intermediate:
1254
+ intermediates.append(x)
1255
+ t_prev_list.append(t)
1256
+ model_prev_list.append(self.model_fn(x, t))
1257
+ # update progress bar
1258
+ self.update_progress(step + 1, len(timesteps))
1259
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1260
+ for step in tqdm(range(order, steps + 1), disable=os.getenv("DPM_TQDM", "False") == "True"):
1261
+ t = timesteps[step]
1262
+ # We only use lower order for steps < 10
1263
+ # if lower_order_final and steps < 10:
1264
+ if lower_order_final: # recommended by Shuchen Xue
1265
+ step_order = min(order, steps + 1 - step)
1266
+ else:
1267
+ step_order = order
1268
+ x = self.multistep_dpm_solver_update(
1269
+ x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
1270
+ )
1271
+ if self.correcting_xt_fn is not None:
1272
+ x = self.correcting_xt_fn(x, t, step)
1273
+ if return_intermediate:
1274
+ intermediates.append(x)
1275
+ for i in range(order - 1):
1276
+ t_prev_list[i] = t_prev_list[i + 1]
1277
+ model_prev_list[i] = model_prev_list[i + 1]
1278
+ t_prev_list[-1] = t
1279
+ # We do not need to evaluate the final model value.
1280
+ if step < steps:
1281
+ model_prev_list[-1] = self.model_fn(x, t)
1282
+ # update progress bar
1283
+ self.update_progress(step + 1, len(timesteps))
1284
+ elif method in ["singlestep", "singlestep_fixed"]:
1285
+ if method == "singlestep":
1286
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
1287
+ steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
1288
+ )
1289
+ elif method == "singlestep_fixed":
1290
+ K = steps // order
1291
+ orders = [
1292
+ order,
1293
+ ] * K
1294
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1295
+ for step, order in enumerate(orders):
1296
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1297
+ timesteps_inner = self.get_time_steps(
1298
+ skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
1299
+ )
1300
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1301
+ h = lambda_inner[-1] - lambda_inner[0]
1302
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1303
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1304
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1305
+ if self.correcting_xt_fn is not None:
1306
+ x = self.correcting_xt_fn(x, t, step)
1307
+ if return_intermediate:
1308
+ intermediates.append(x)
1309
+ self.update_progress(step + 1, len(timesteps_outer))
1310
+ else:
1311
+ raise ValueError(f"Got wrong method {method}")
1312
+ if denoise_to_zero:
1313
+ t = torch.ones((1,)).to(device) * t_0
1314
+ x = self.denoise_to_zero_fn(x, t)
1315
+ if self.correcting_xt_fn is not None:
1316
+ x = self.correcting_xt_fn(x, t, step + 1)
1317
+ if return_intermediate:
1318
+ intermediates.append(x)
1319
+ if return_intermediate:
1320
+ return x, intermediates
1321
+ else:
1322
+ return x
1323
+
1324
+
1325
+ #############################################################
1326
+ # other utility functions
1327
+ #############################################################
1328
+
1329
+
1330
+ def interpolate_fn(x, xp, yp):
1331
+ """
1332
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1333
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1334
+ The function f(x) is well-defined on the entire x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
1335
+
1336
+ Args:
1337
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1338
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1339
+ yp: PyTorch tensor with shape [C, K].
1340
+ Returns:
1341
+ The function values f(x), with shape [N, C].
1342
+ """
1343
+ N, K = x.shape[0], xp.shape[1]
1344
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1345
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1346
+ x_idx = torch.argmin(x_indices, dim=2)
1347
+ cand_start_idx = x_idx - 1
1348
+ start_idx = torch.where(
1349
+ torch.eq(x_idx, 0),
1350
+ torch.tensor(1, device=x.device),
1351
+ torch.where(
1352
+ torch.eq(x_idx, K),
1353
+ torch.tensor(K - 2, device=x.device),
1354
+ cand_start_idx,
1355
+ ),
1356
+ )
1357
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1358
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1359
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1360
+ start_idx2 = torch.where(
1361
+ torch.eq(x_idx, 0),
1362
+ torch.tensor(0, device=x.device),
1363
+ torch.where(
1364
+ torch.eq(x_idx, K),
1365
+ torch.tensor(K - 2, device=x.device),
1366
+ cand_start_idx,
1367
+ ),
1368
+ )
1369
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1370
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1371
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1372
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1373
+ return cand
1374
+
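A tiny sanity check of `interpolate_fn` on the ramp y = 2x with keypoints at x = 0 and x = 1 (run with this module's `interpolate_fn` in scope); inputs outside the keypoint range are extrapolated linearly.

import torch

xp = torch.tensor([[0.0, 1.0]])            # [C=1, K=2]
yp = torch.tensor([[0.0, 2.0]])            # [C=1, K=2]
x = torch.tensor([[0.25], [0.5], [2.0]])   # [N=3, C=1]
print(interpolate_fn(x, xp, yp))           # tensor([[0.5000], [1.0000], [4.0000]])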
1375
+
1376
+ def expand_dims(v, dims):
1377
+ """
1378
+ Expand the tensor `v` to the dim `dims`.
1379
+
1380
+ Args:
1381
+ `v`: a PyTorch tensor with shape [N].
1382
+ `dims`: an `int`.
1383
+ Returns:
1384
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1385
+ """
1386
+ return v[(...,) + (None,) * (dims - 1)]
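For example, with the `expand_dims` defined above:

import torch

v = torch.ones(4)                   # shape [N]
print(expand_dims(v, 4).shape)      # torch.Size([4, 1, 1, 1]); broadcasts against an [N, C, H, W] tensor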
omnigen2/transport/integrators.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch as th
2
+ from torchdiffeq import odeint
3
+ from .utils import time_shift, get_lin_function
4
+
5
+ class sde:
6
+ """SDE solver class"""
7
+
8
+ def __init__(
9
+ self,
10
+ drift,
11
+ diffusion,
12
+ *,
13
+ t0,
14
+ t1,
15
+ num_steps,
16
+ sampler_type,
17
+ ):
18
+ assert t0 < t1, "SDE sampler has to be in forward time"
19
+
20
+ self.num_timesteps = num_steps
21
+ self.t = th.linspace(t0, t1, num_steps)
22
+ self.dt = self.t[1] - self.t[0]
23
+ self.drift = drift
24
+ self.diffusion = diffusion
25
+ self.sampler_type = sampler_type
26
+
27
+ def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
28
+ w_cur = th.randn(x.size()).to(x)
29
+ t = th.ones(x.size(0)).to(x) * t
30
+ dw = w_cur * th.sqrt(self.dt)
31
+ drift = self.drift(x, t, model, **model_kwargs)
32
+ diffusion = self.diffusion(x, t)
33
+ mean_x = x + drift * self.dt
34
+ x = mean_x + th.sqrt(2 * diffusion) * dw
35
+ return x, mean_x
36
+
37
+ def __Heun_step(self, x, _, t, model, **model_kwargs):
38
+ w_cur = th.randn(x.size()).to(x)
39
+ dw = w_cur * th.sqrt(self.dt)
40
+ t_cur = th.ones(x.size(0)).to(x) * t
41
+ diffusion = self.diffusion(x, t_cur)
42
+ xhat = x + th.sqrt(2 * diffusion) * dw
43
+ K1 = self.drift(xhat, t_cur, model, **model_kwargs)
44
+ xp = xhat + self.dt * K1
45
+ K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
46
+ return (
47
+ xhat + 0.5 * self.dt * (K1 + K2),
48
+ xhat,
49
+ ) # at last time point we do not perform the heun step
50
+
51
+ def __forward_fn(self):
52
+ """TODO: generalize here by adding all private functions ending with steps to it"""
53
+ sampler_dict = {
54
+ "Euler": self.__Euler_Maruyama_step,
55
+ "Heun": self.__Heun_step,
56
+ }
57
+
58
+ try:
59
+ sampler = sampler_dict[self.sampler_type]
60
+ except KeyError:
61
+ raise NotImplementedError("Sampler type not implemented.")
62
+
63
+ return sampler
64
+
65
+ def sample(self, init, model, **model_kwargs):
66
+ """forward loop of sde"""
67
+ x = init
68
+ mean_x = init
69
+ samples = []
70
+ sampler = self.__forward_fn()
71
+ for ti in self.t[:-1]:
72
+ with th.no_grad():
73
+ x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
74
+ samples.append(x)
75
+
76
+ return samples
77
+
78
+
79
+ class ode:
80
+ """ODE solver class"""
81
+
82
+ def __init__(
83
+ self,
84
+ drift,
85
+ *,
86
+ t0,
87
+ t1,
88
+ sampler_type,
89
+ num_steps,
90
+ atol,
91
+ rtol,
92
+ do_shift=False,
93
+ time_shifting_factor=None,
94
+ ):
95
+ assert t0 < t1, "ODE sampler has to be in forward time"
96
+
97
+ self.drift = drift
98
+ self.do_shift = do_shift
99
+ self.t = th.linspace(t0, t1, num_steps)
100
+ if time_shifting_factor:
101
+ self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
102
+ self.atol = atol
103
+ self.rtol = rtol
104
+ self.sampler_type = sampler_type
105
+
106
+ def sample(self, x, model, **model_kwargs):
107
+ x = x.float()
108
+ device = x[0].device if isinstance(x, tuple) else x.device
109
+
110
+ def _fn(t, x):
111
+ t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
112
+ model_output = self.drift(x, t, model, **model_kwargs).float()
113
+ return model_output
114
+
115
+ t = self.t.to(device)
116
+ if self.do_shift:
117
+ mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
118
+ t = time_shift(mu, 1.0, t)
119
+ atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
120
+ rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
121
+ samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
122
+ return samples
omnigen2/transport/path.py ADDED
@@ -0,0 +1,201 @@
1
+ import numpy as np
2
+ import torch as th
3
+
4
+
5
+ def expand_t_like_x(t, x):
6
+ """Function to reshape time t to broadcastable dimension of x
7
+ Args:
8
+ t: [batch_dim,], time vector
9
+ x: [batch_dim,...], data point
10
+ """
11
+ dims = [1] * len(x[0].size())
12
+ t = t.view(t.size(0), *dims)
13
+ return t
14
+
15
+
16
+ #################### Coupling Plans ####################
17
+
18
+
19
+ class ICPlan:
20
+ """Linear Coupling Plan"""
21
+
22
+ def __init__(self, sigma=0.0):
23
+ self.sigma = sigma
24
+
25
+ def compute_alpha_t(self, t):
26
+ """Compute the data coefficient along the path"""
27
+ return t, 1
28
+
29
+ def compute_sigma_t(self, t):
30
+ """Compute the noise coefficient along the path"""
31
+ return 1 - t, -1
32
+
33
+ def compute_d_alpha_alpha_ratio_t(self, t):
34
+ """Compute the ratio between d_alpha and alpha"""
35
+ return 1 / t
36
+
37
+ def compute_drift(self, x, t):
38
+ """We always output sde according to score parametrization;"""
39
+ t = expand_t_like_x(t, x)
40
+ alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
41
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
42
+ drift = alpha_ratio * x
43
+ diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t
44
+
45
+ return -drift, diffusion
46
+
47
+ def compute_diffusion(self, x, t, form="constant", norm=1.0):
48
+ """Compute the diffusion term of the SDE
49
+ Args:
50
+ x: [batch_dim, ...], data point
51
+ t: [batch_dim,], time vector
52
+ form: str, form of the diffusion term
53
+ norm: float, norm of the diffusion term
54
+ """
55
+ t = expand_t_like_x(t, x)
56
+ choices = {
57
+ "constant": norm,
58
+ "SBDM": norm * self.compute_drift(x, t)[1],
59
+ "sigma": norm * self.compute_sigma_t(t)[0],
60
+ "linear": norm * (1 - t),
61
+ "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
62
+ "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
63
+ }
64
+
65
+ try:
66
+ diffusion = choices[form]
67
+ except KeyError:
68
+ raise NotImplementedError(f"Diffusion form {form} not implemented")
69
+
70
+ return diffusion
71
+
72
+ def get_score_from_velocity(self, velocity, x, t):
73
+ """Wrapper function: transfrom velocity prediction model to score
74
+ Args:
75
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
76
+ x: [batch_dim, ...] shaped tensor; x_t data point
77
+ t: [batch_dim,] time tensor
78
+ """
79
+ t = expand_t_like_x(t, x)
80
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
81
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
82
+ mean = x
83
+ reverse_alpha_ratio = alpha_t / d_alpha_t
84
+ var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
85
+ score = (reverse_alpha_ratio * velocity - mean) / var
86
+ return score
87
+
88
+ def get_noise_from_velocity(self, velocity, x, t):
89
+ """Wrapper function: transfrom velocity prediction model to denoiser
90
+ Args:
91
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
92
+ x: [batch_dim, ...] shaped tensor; x_t data point
93
+ t: [batch_dim,] time tensor
94
+ """
95
+ t = expand_t_like_x(t, x)
96
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
97
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
98
+ mean = x
99
+ reverse_alpha_ratio = alpha_t / d_alpha_t
100
+ var = reverse_alpha_ratio * d_sigma_t - sigma_t
101
+ noise = (reverse_alpha_ratio * velocity - mean) / var
102
+ return noise
103
+
104
+ def get_velocity_from_score(self, score, x, t):
105
+ """Wrapper function: transfrom score prediction model to velocity
106
+ Args:
107
+ score: [batch_dim, ...] shaped tensor; score model output
108
+ x: [batch_dim, ...] shaped tensor; x_t data point
109
+ t: [batch_dim,] time tensor
110
+ """
111
+ t = expand_t_like_x(t, x)
112
+ drift, var = self.compute_drift(x, t)
113
+ velocity = var * score - drift
114
+ return velocity
115
+
116
+ def compute_mu_t(self, t, x0, x1):
117
+ """Compute the mean of time-dependent density p_t"""
118
+ t = expand_t_like_x(t, x1)
119
+ alpha_t, _ = self.compute_alpha_t(t)
120
+ sigma_t, _ = self.compute_sigma_t(t)
121
+ if isinstance(x1, (list, tuple)):
122
+ return [alpha_t[i] * x1[i] + sigma_t[i] * x0[i] for i in range(len(x1))]
123
+ else:
124
+ return alpha_t * x1 + sigma_t * x0
125
+
126
+ def compute_xt(self, t, x0, x1):
127
+ """Sample xt from time-dependent density p_t; rng is required"""
128
+ xt = self.compute_mu_t(t, x0, x1)
129
+ return xt
130
+
131
+ def compute_ut(self, t, x0, x1, xt):
132
+ """Compute the vector field corresponding to p_t"""
133
+ t = expand_t_like_x(t, x1)
134
+ _, d_alpha_t = self.compute_alpha_t(t)
135
+ _, d_sigma_t = self.compute_sigma_t(t)
136
+ if isinstance(x1, (list, tuple)):
137
+ return [d_alpha_t * x1[i] + d_sigma_t * x0[i] for i in range(len(x1))]
138
+ else:
139
+ return d_alpha_t * x1 + d_sigma_t * x0
140
+
141
+ def plan(self, t, x0, x1):
142
+ xt = self.compute_xt(t, x0, x1)
143
+ ut = self.compute_ut(t, x0, x1, xt)
144
+ return t, xt, ut
145
+
146
+
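A quick check of the linear coupling plan above (run with this module's `ICPlan` in scope): the path is x_t = t * x1 + (1 - t) * x0 with constant velocity u_t = x1 - x0.

import torch as th

plan = ICPlan()
x0, x1 = th.zeros(1, 4), th.ones(1, 4)
t, xt, ut = plan.plan(th.tensor([0.25]), x0, x1)
print(xt)   # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])
print(ut)   # tensor([[1., 1., 1., 1.]])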
147
+ class VPCPlan(ICPlan):
148
+ """class for VP path flow matching"""
149
+
150
+ def __init__(self, sigma_min=0.1, sigma_max=20.0):
151
+ self.sigma_min = sigma_min
152
+ self.sigma_max = sigma_max
153
+ self.log_mean_coeff = (
154
+ lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
155
+ )
156
+ self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
157
+
158
+ def compute_alpha_t(self, t):
159
+ """Compute coefficient of x1"""
160
+ alpha_t = self.log_mean_coeff(t)
161
+ alpha_t = th.exp(alpha_t)
162
+ d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
163
+ return alpha_t, d_alpha_t
164
+
165
+ def compute_sigma_t(self, t):
166
+ """Compute coefficient of x0"""
167
+ p_sigma_t = 2 * self.log_mean_coeff(t)
168
+ sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
169
+ d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
170
+ return sigma_t, d_sigma_t
171
+
172
+ def compute_d_alpha_alpha_ratio_t(self, t):
173
+ """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
174
+ return self.d_log_mean_coeff(t)
175
+
176
+ def compute_drift(self, x, t):
177
+ """Compute the drift term of the SDE"""
178
+ t = expand_t_like_x(t, x)
179
+ beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
180
+ return -0.5 * beta_t * x, beta_t / 2
181
+
182
+
183
+ class GVPCPlan(ICPlan):
184
+ def __init__(self, sigma=0.0):
185
+ super().__init__(sigma)
186
+
187
+ def compute_alpha_t(self, t):
188
+ """Compute coefficient of x1"""
189
+ alpha_t = th.sin(t * np.pi / 2)
190
+ d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
191
+ return alpha_t, d_alpha_t
192
+
193
+ def compute_sigma_t(self, t):
194
+ """Compute coefficient of x0"""
195
+ sigma_t = th.cos(t * np.pi / 2)
196
+ d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
197
+ return sigma_t, d_sigma_t
198
+
199
+ def compute_d_alpha_alpha_ratio_t(self, t):
200
+ """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
201
+ return np.pi / (2 * th.tan(t * np.pi / 2))
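Note on the GVP plan above: with alpha_t = sin(pi * t / 2) and sigma_t = cos(pi * t / 2), the coefficients satisfy alpha_t**2 + sigma_t**2 = 1 for every t, whereas the linear ICPlan keeps alpha_t + sigma_t = 1 instead.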
omnigen2/transport/transport.py ADDED
@@ -0,0 +1,545 @@
1
+ import enum
2
+ import math
3
+ from typing import Callable, Optional
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import random
8
+
9
+ from . import path
10
+ from .integrators import ode, sde
11
+ from .utils import mean_flat, expand_dims
12
+ from .dpm_solver import NoiseScheduleFlow, model_wrapper, DPM_Solver
13
+
14
+
15
+ class ModelType(enum.Enum):
16
+ """
17
+ Which type of output the model predicts.
18
+ """
19
+
20
+ NOISE = enum.auto() # the model predicts epsilon
21
+ SCORE = enum.auto() # the model predicts \nabla \log p(x)
22
+ VELOCITY = enum.auto() # the model predicts v(x)
23
+
24
+
25
+ class PathType(enum.Enum):
26
+ """
27
+ Which type of path to use.
28
+ """
29
+
30
+ LINEAR = enum.auto()
31
+ GVP = enum.auto()
32
+ VP = enum.auto()
33
+
34
+
35
+ class WeightType(enum.Enum):
36
+ """
37
+ Which type of weighting to use.
38
+ """
39
+
40
+ NONE = enum.auto()
41
+ VELOCITY = enum.auto()
42
+ LIKELIHOOD = enum.auto()
43
+
44
+
45
+ class Transport:
46
+ def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift, seq_len,
47
+ dynamic_time_shift: bool = False,
48
+ time_shift_version: str = "v1"):
49
+ path_options = {
50
+ PathType.LINEAR: path.ICPlan,
51
+ PathType.GVP: path.GVPCPlan,
52
+ PathType.VP: path.VPCPlan,
53
+ }
54
+
55
+ self.loss_type = loss_type
56
+ self.model_type = model_type
57
+ self.path_sampler = path_options[path_type]()
58
+ self.train_eps = train_eps
59
+ self.sample_eps = sample_eps
60
+
61
+ self.snr_type = snr_type
62
+ self.do_shift = do_shift
63
+ self.seq_len = seq_len
64
+ self.dynamic_time_shift = dynamic_time_shift
65
+ self.time_shift_version = time_shift_version
66
+ def prior_logp(self, z):
67
+ """
68
+ Standard multivariate normal prior
69
+ Assume z is batched
70
+ """
71
+ shape = th.tensor(z.size())
72
+ N = th.prod(shape[1:])
73
+ _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
74
+ return th.vmap(_fn)(z)
75
+
76
+ def check_interval(
77
+ self,
78
+ train_eps,
79
+ sample_eps,
80
+ *,
81
+ diffusion_form="SBDM",
82
+ sde=False,
83
+ reverse=False,
84
+ eval=False,
85
+ last_step_size=0.0,
86
+ ):
87
+ t0 = 0
88
+ t1 = 1
89
+ eps = train_eps if not eval else sample_eps
90
+ if type(self.path_sampler) in [path.VPCPlan]:
91
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
92
+
93
+ elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
94
+ self.model_type != ModelType.VELOCITY or sde
95
+ ): # avoid numerical issue by taking a first semi-implicit step
96
+ t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
97
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
98
+
99
+ if reverse:
100
+ t0, t1 = 1 - t0, 1 - t1
101
+
102
+ return t0, t1
103
+
104
+ def sample(self, x1, process_index, num_processes):
105
+ """Sampling x0 & t based on shape of x1 (if needed)
106
+ Args:
107
+ x1 - data point; [batch, *dim]
108
+ """
109
+ if isinstance(x1, (list, tuple)):
110
+ x0 = [th.randn_like(img_start) for img_start in x1]
111
+ else:
112
+ x0 = th.randn_like(x1)
113
+ t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
114
+
115
+ if self.snr_type.startswith("uniform"):
116
+ assert t0 == 0.0 and t1 == 1.0, "not implemented."
117
+ if "_" in self.snr_type:
118
+ _, t0, t1 = self.snr_type.split("_")
119
+ t0, t1 = float(t0), float(t1)
120
+ t = th.rand((len(x1),)) * (t1 - t0) + t0
121
+ if self.snr_type == "stratified_uniform":
122
+ batch_size = len(x1)
123
+ n = batch_size * num_processes
124
+ offsets = th.arange(process_index, n, num_processes)
125
+ u = th.rand(size=(batch_size,))
126
+ t = ((offsets + u) / n)
127
+ elif self.snr_type == "lognorm":
128
+ u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
129
+ t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
130
+ elif self.snr_type == "zero":
131
+ t = th.rand((len(x1),))
132
+ for _ in range(len(x1)):
133
+ if random.random() < 1.0:
134
+ t[_] = 0.0
135
+ # print(t)
136
+ else:
137
+ raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)
138
+
139
+ if self.do_shift:
140
+ if self.dynamic_time_shift:
141
+ if self.time_shift_version == "v1":
142
+ base_shift: float = 0.5
143
+ max_shift: float = 1.15
144
+ lin_func = self.get_lin_function(y1=base_shift, y2=max_shift)
145
+
146
+ mu = th.tensor([lin_func((_x1.shape[-2] // 2) * (_x1.shape[-1] // 2)) for _x1 in x1], dtype=t.dtype, device=t.device).view_as(t)
147
+ t = self.time_shift(mu, 1.0, t)
148
+ elif self.time_shift_version == "v2":
149
+ tokens = th.tensor([(_x1.shape[-2] // 2) * (_x1.shape[-1] // 2) for _x1 in x1], dtype=t.dtype, device=t.device).view_as(t)
150
+ t = self.time_shift_v2(tokens, t)
151
+ else:
152
+ if self.time_shift_version == "v1":
153
+ base_shift: float = 0.5
154
+ max_shift: float = 1.15
155
+ mu = self.get_lin_function(y1=base_shift, y2=max_shift)(self.seq_len)
156
+ t = self.time_shift(mu, 1.0, t)
157
+ elif self.time_shift_version == "v2":
158
+ tokens = th.tensor([self.seq_len] * len(x1), dtype=t.dtype, device=t.device).view_as(t)
159
+ t = self.time_shift_v2(tokens, t)
160
+ t = t.to(x1[0])
161
+ return t, x0, x1
162
+
163
+ def time_shift(self, mu: float, sigma: float, t: th.Tensor):
164
+ # the following implementation was originally written for t=0: clean / t=1: noise
165
+ # Since we adopt the reverse convention, the 1 - t operations are needed
166
+ t = 1 - t
167
+ if isinstance(mu, th.Tensor):
168
+ t = th.exp(mu) / (th.exp(mu) + (1 / t - 1) ** sigma)
169
+ else:
170
+ t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
171
+ t = 1 - t
172
+ return t
173
+
174
+ def time_shift_v2(self, tokens: th.Tensor, t: th.Tensor):
175
+ # t = th.exp(mu) / (th.exp(mu) + (1 / t - 1) ** sigma)
176
+ m = th.sqrt(tokens) / 20
177
+ t = t / (m - m * t + t)
178
+ return t
179
+
180
+ def get_lin_function(
181
+ self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
182
+ ) -> Callable[[float], float]:
183
+ m = (y2 - y1) / (x2 - x1)
184
+ b = y1 - m * x1
185
+ return lambda x: m * x + b
186
+
187
+ def training_losses(
188
+ self,
189
+ model,
190
+ x1,
191
+ model_kwargs=None,
192
+ process_index: Optional[int] = None,
193
+ num_processes: Optional[int] = None,
194
+ reduction: str = 'mean',
195
+ ):
196
+ """Loss for training the score model
197
+ Args:
198
+ - model: backbone model; could be score, noise, or velocity
199
+ - x1: datapoint
200
+ - model_kwargs: additional arguments for the model
201
+ """
202
+
203
+ terms = {}
204
+
205
+ if model_kwargs is None:
206
+ model_kwargs = {}
207
+ t, x0, x1 = self.sample(x1, process_index, num_processes)
208
+ t, xt, ut = self.path_sampler.plan(t, x0, x1)
209
+
210
+ terms = {}
211
+ terms['t'] = t
212
+ terms['xt'] = xt
213
+
214
+ if "cond" in model_kwargs:
215
+ conds = model_kwargs.pop("cond")
216
+ xt = [th.cat([x, cond], dim=0) if cond is not None else x for x, cond in zip(xt, conds)]
217
+ model_output = model(xt, t, **model_kwargs)
218
+ B = len(x0)
219
+
220
+ terms['pred'] = model_output
221
+ if self.model_type == ModelType.VELOCITY:
222
+ if isinstance(x1, (list, tuple)):
223
+ assert len(model_output) == len(ut) == len(x1)
224
+ for i in range(B):
225
+ assert (
226
+ model_output[i].shape == ut[i].shape == x1[i].shape
227
+ ), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
228
+ terms["task_loss"] = th.stack(
229
+ [th.nn.functional.mse_loss(ut[i].float(), model_output[i].float(), reduction=reduction) for i in range(B)],
230
+ dim=0,
231
+ )
232
+ else:
233
+ terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
234
+ else:
235
+ raise NotImplementedError
236
+
237
+ terms["loss"] = terms["task_loss"]
238
+ terms["t"] = t
239
+ return terms
240
+
241
+ def get_drift(self):
242
+ """member function for obtaining the drift of the probability flow ODE"""
243
+
244
+ def score_ode(x, t, model, **model_kwargs):
245
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
246
+ model_output = model(x, t, **model_kwargs)
247
+ return -drift_mean + drift_var * model_output # by change of variable
248
+
249
+ def noise_ode(x, t, model, **model_kwargs):
250
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
251
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
252
+ model_output = model(x, t, **model_kwargs)
253
+ score = model_output / -sigma_t
254
+ return -drift_mean + drift_var * score
255
+
256
+ def velocity_ode(x, t, model, **model_kwargs):
257
+ model_output = model(x, t, **model_kwargs)
258
+ return model_output
259
+
260
+ if self.model_type == ModelType.NOISE:
261
+ drift_fn = noise_ode
262
+ elif self.model_type == ModelType.SCORE:
263
+ drift_fn = score_ode
264
+ else:
265
+ drift_fn = velocity_ode
266
+
267
+ def body_fn(x, t, model, **model_kwargs):
268
+ model_output = drift_fn(x, t, model, **model_kwargs)
269
+ assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
270
+ return model_output
271
+
272
+ return body_fn
273
+
274
+ def get_score(
275
+ self,
276
+ ):
277
+ """member function for obtaining score of
278
+ x_t = alpha_t * x + sigma_t * eps"""
279
+ if self.model_type == ModelType.NOISE:
280
+ score_fn = (
281
+ lambda x, t, model, **kwargs: model(x, t, **kwargs)
282
+ / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
283
+ )
284
+ elif self.model_type == ModelType.SCORE:
285
+ score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
286
+ elif self.model_type == ModelType.VELOCITY:
287
+ score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
288
+ model(x, t, **kwargs), x, t
289
+ )
290
+ else:
291
+ raise NotImplementedError()
292
+
293
+ return score_fn
294
+
295
+
296
+ class Sampler:
297
+ """Sampler class for the transport model"""
298
+
299
+ def __init__(
300
+ self,
301
+ transport,
302
+ ):
303
+ """Constructor for a general sampler; supporting different sampling methods
304
+ Args:
305
+ - transport: a transport object specifying the model prediction & interpolant type
306
+ """
307
+
308
+ self.transport = transport
309
+ self.drift = self.transport.get_drift()
310
+ self.score = self.transport.get_score()
311
+
312
+ def __get_sde_diffusion_and_drift(
313
+ self,
314
+ *,
315
+ diffusion_form="SBDM",
316
+ diffusion_norm=1.0,
317
+ ):
318
+ def diffusion_fn(x, t):
319
+ diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
320
+ return diffusion
321
+
322
+ sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
323
+ x, t, model, **kwargs
324
+ )
325
+
326
+ sde_diffusion = diffusion_fn
327
+
328
+ return sde_drift, sde_diffusion
329
+
330
+ def __get_last_step(
331
+ self,
332
+ sde_drift,
333
+ *,
334
+ last_step,
335
+ last_step_size,
336
+ ):
337
+ """Get the last step function of the SDE solver"""
338
+
339
+ if last_step is None:
340
+ last_step_fn = lambda x, t, model, **model_kwargs: x
341
+ elif last_step == "Mean":
342
+ last_step_fn = (
343
+ lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
344
+ )
345
+ elif last_step == "Tweedie":
346
+ alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
347
+ sigma = self.transport.path_sampler.compute_sigma_t
348
+ last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][
349
+ 0
350
+ ] * self.score(x, t, model, **model_kwargs)
351
+ elif last_step == "Euler":
352
+ last_step_fn = (
353
+ lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
354
+ )
355
+ else:
356
+ raise NotImplementedError()
357
+
358
+ return last_step_fn
359
+
360
+ def sample_sde(
361
+ self,
362
+ *,
363
+ sampling_method="Euler",
364
+ diffusion_form="SBDM",
365
+ diffusion_norm=1.0,
366
+ last_step="Mean",
367
+ last_step_size=0.04,
368
+ num_steps=250,
369
+ ):
370
+ """returns a sampling function with given SDE settings
371
+ Args:
372
+ - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
373
+ - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
374
+ - diffusion_norm: function magnitude of diffusion coefficient; default to 1
375
+ - last_step: type of the last step; default to identity
376
+ - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
377
+ - num_steps: total integration step of SDE
378
+ """
379
+
380
+ if last_step is None:
381
+ last_step_size = 0.0
382
+
383
+ sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
384
+ diffusion_form=diffusion_form,
385
+ diffusion_norm=diffusion_norm,
386
+ )
387
+
388
+ t0, t1 = self.transport.check_interval(
389
+ self.transport.train_eps,
390
+ self.transport.sample_eps,
391
+ diffusion_form=diffusion_form,
392
+ sde=True,
393
+ eval=True,
394
+ reverse=False,
395
+ last_step_size=last_step_size,
396
+ )
397
+
398
+ _sde = sde(
399
+ sde_drift,
400
+ sde_diffusion,
401
+ t0=t0,
402
+ t1=t1,
403
+ num_steps=num_steps,
404
+ sampler_type=sampling_method,
405
+ )
406
+
407
+ last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
408
+
409
+ def _sample(init, model, **model_kwargs):
410
+ xs = _sde.sample(init, model, **model_kwargs)
411
+ ts = th.ones(init.size(0), device=init.device) * t1
412
+ x = last_step_fn(xs[-1], ts, model, **model_kwargs)
413
+ xs.append(x)
414
+
415
+ assert len(xs) == num_steps, "Number of samples does not match the number of steps"
416
+
417
+ return xs
418
+
419
+ return _sample
420
+
421
+ def sample_dpm(
422
+ self,
423
+ model,
424
+ model_kwargs=None,
425
+ ):
426
+
427
+ noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
428
+
429
+ def noise_pred_fn(x, t_continuous):
430
+ output = model(x, 1 - t_continuous, **model_kwargs)
431
+ _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
432
+ try:
433
+ noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output
434
+ except Exception:
435
+ noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0]
436
+ return noise
437
+
438
+ return DPM_Solver(noise_pred_fn, noise_schedule, algorithm_type="dpmsolver++").sample
439
+
440
+
441
+ def sample_ode(
442
+ self,
443
+ *,
444
+ sampling_method="dopri5",
445
+ num_steps=50,
446
+ atol=1e-6,
447
+ rtol=1e-3,
448
+ reverse=False,
449
+ do_shift=False,
450
+ time_shifting_factor=None,
451
+ ):
452
+ """returns a sampling function with given ODE settings
453
+ Args:
454
+ - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
455
+ - num_steps:
456
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
457
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
458
+ - atol: absolute error tolerance for the solver
459
+ - rtol: relative error tolerance for the solver
460
+ """
461
+
462
+ # for flux
463
+ drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs)
464
+
465
+ t0, t1 = self.transport.check_interval(
466
+ self.transport.train_eps,
467
+ self.transport.sample_eps,
468
+ sde=False,
469
+ eval=True,
470
+ reverse=reverse,
471
+ last_step_size=0.0,
472
+ )
473
+
474
+ _ode = ode(
475
+ drift=drift,
476
+ t0=t0,
477
+ t1=t1,
478
+ sampler_type=sampling_method,
479
+ num_steps=num_steps,
480
+ atol=atol,
481
+ rtol=rtol,
482
+ do_shift=do_shift,
483
+ time_shifting_factor=time_shifting_factor,
484
+ )
485
+
486
+ return _ode.sample
487
+
488
+ def sample_ode_likelihood(
489
+ self,
490
+ *,
491
+ sampling_method="dopri5",
492
+ num_steps=50,
493
+ atol=1e-6,
494
+ rtol=1e-3,
495
+ ):
496
+ """returns a sampling function for calculating likelihood with given ODE settings
497
+ Args:
498
+ - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
499
+ - num_steps:
500
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
501
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
502
+ - atol: absolute error tolerance for the solver
503
+ - rtol: relative error tolerance for the solver
504
+ """
505
+
506
+ def _likelihood_drift(x, t, model, **model_kwargs):
507
+ x, _ = x
508
+ eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
509
+ t = th.ones_like(t) * (1 - t)
510
+ with th.enable_grad():
511
+ x.requires_grad = True
512
+ grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
513
+ logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
514
+ drift = self.drift(x, t, model, **model_kwargs)
515
+ return (-drift, logp_grad)
516
+
517
+ t0, t1 = self.transport.check_interval(
518
+ self.transport.train_eps,
519
+ self.transport.sample_eps,
520
+ sde=False,
521
+ eval=True,
522
+ reverse=False,
523
+ last_step_size=0.0,
524
+ )
525
+
526
+ _ode = ode(
527
+ drift=_likelihood_drift,
528
+ t0=t0,
529
+ t1=t1,
530
+ sampler_type=sampling_method,
531
+ num_steps=num_steps,
532
+ atol=atol,
533
+ rtol=rtol,
534
+ )
535
+
536
+ def _sample_fn(x, model, **model_kwargs):
537
+ init_logp = th.zeros(x.size(0)).to(x)
538
+ input = (x, init_logp)
539
+ drift, delta_logp = _ode.sample(input, model, **model_kwargs)
540
+ drift, delta_logp = drift[-1], delta_logp[-1]
541
+ prior_logp = self.transport.prior_logp(drift)
542
+ logp = prior_logp - delta_logp
543
+ return logp, drift
544
+
545
+ return _sample_fn
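For context, a minimal way to exercise `Transport.training_losses` along the tensor (non-list) code path above might look like the sketch below. It assumes `ICPlan.plan` (defined earlier in `path.py`, not shown in this hunk) returns `(t, x_t, u_t)` as used in `training_losses`; the toy zero-velocity model and hyperparameter values are placeholders, not pipeline defaults.

```python
# Sketch only; the toy model and constructor arguments are assumptions.
import torch as th

from omnigen2.transport.transport import Transport, ModelType, PathType, WeightType

transport = Transport(
    model_type=ModelType.VELOCITY,
    path_type=PathType.LINEAR,
    loss_type=WeightType.NONE,
    train_eps=0.0,
    sample_eps=0.0,
    snr_type="uniform",
    do_shift=False,
    seq_len=256,
)

# Toy "model": ignores conditioning and predicts a zero velocity field.
toy_model = lambda x, t, **kwargs: th.zeros_like(x)

x1 = th.randn(4, 8, 16, 16)  # a fake batch of latents
terms = transport.training_losses(toy_model, x1, model_kwargs={})
print(terms["loss"].shape)  # one MSE value per sample: torch.Size([4])
```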
omnigen2/transport/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch as th
2
+ import math
3
+
4
+ class EasyDict:
5
+ def __init__(self, sub_dict):
6
+ for k, v in sub_dict.items():
7
+ setattr(self, k, v)
8
+
9
+ def __getitem__(self, key):
10
+ return getattr(self, key)
11
+
12
+
13
+ def mean_flat(x):
14
+ """
15
+ Take the mean over all non-batch dimensions.
16
+ """
17
+ return th.mean(x, dim=list(range(1, len(x.size()))))
18
+
19
+
20
+ def log_state(state):
21
+ result = []
22
+
23
+ sorted_state = dict(sorted(state.items()))
24
+ for key, value in sorted_state.items():
25
+ # Check if the value is an instance of a class
26
+ if "<object" in str(value) or "object at" in str(value):
27
+ result.append(f"{key}: [{value.__class__.__name__}]")
28
+ else:
29
+ result.append(f"{key}: {value}")
30
+
31
+ return "\n".join(result)
32
+
33
+ def time_shift(mu: float, sigma: float, t: th.Tensor):
34
+ # the following implementation was originally written for t=0: clean / t=1: noise
35
+ # Since we adopt the reverse convention, the 1 - t operations are needed
36
+ t = 1 - t
37
+ t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
38
+ t = 1 - t
39
+ return t
40
+
41
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
42
+ m = (y2 - y1) / (x2 - x1)
43
+ b = y1 - m * x1
44
+ return lambda x: m * x + b
45
+
46
+ def expand_dims(v, dims):
47
+ """
48
+ Expand the tensor `v` to the dim `dims`.
49
+
50
+ Args:
51
+ `v`: a PyTorch tensor with shape [N].
52
+ `dims`: an `int`.
53
+ Returns:
54
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
55
+ """
56
+ return v[(...,) + (None,) * (dims - 1)]
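A quick shape check for the two most-used helpers above (illustrative, not part of the commit):

```python
import torch as th

from omnigen2.transport.utils import expand_dims, mean_flat

v = th.tensor([1.0, 2.0, 3.0])   # shape [3]
print(expand_dims(v, 4).shape)   # torch.Size([3, 1, 1, 1])

x = th.randn(3, 4, 5)
print(mean_flat(x).shape)        # torch.Size([3]): mean over all non-batch dims
```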
omnigen2/utils/__init__.py ADDED
File without changes
omnigen2/utils/img_util.py ADDED
@@ -0,0 +1,31 @@
1
+ from typing import List
2
+
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from torchvision.transforms.functional import to_pil_image
7
+
8
+ def resize_image(image, max_pixels, img_scale_num):
9
+ width, height = image.size
10
+ cur_pixels = height * width
11
+ ratio = (max_pixels / cur_pixels) ** 0.5
12
+ ratio = min(ratio, 1.0) # do not upscale input image
13
+
14
+ new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num
15
+
16
+ image = image.resize((new_width, new_height), resample=Image.BICUBIC)
17
+ return image
18
+
19
+ def create_collage(images: List[torch.Tensor]) -> Image.Image:
20
+ """Create a horizontal collage from a list of images."""
21
+ max_height = max(img.shape[-2] for img in images)
22
+ total_width = sum(img.shape[-1] for img in images)
23
+ canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
24
+
25
+ current_x = 0
26
+ for img in images:
27
+ h, w = img.shape[-2:]
28
+ canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
29
+ current_x += w
30
+
31
+ return to_pil_image(canvas)
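Illustrative usage of the two helpers above (the ~1MP budget and scale factor of 16 are example values, not pipeline defaults):

```python
import torch
from PIL import Image

from omnigen2.utils.img_util import resize_image, create_collage

img = Image.new("RGB", (2000, 1500))
img = resize_image(img, max_pixels=1024 * 1024, img_scale_num=16)
print(img.size)  # both sides are multiples of 16 and the area stays under the budget

# create_collage expects CHW tensors scaled to [-1, 1] and returns one PIL image.
tensors = [torch.rand(3, 128, 128) * 2 - 1, torch.rand(3, 96, 160) * 2 - 1]
collage = create_collage(tensors)
```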
omnigen2/utils/import_utils.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Import utilities: Utilities related to imports and our lazy inits.
16
+ """
17
+
18
+ import importlib.util
19
+ import sys
20
+
21
+ # The package importlib_metadata is in a different place, depending on the python version.
22
+ if sys.version_info < (3, 8):
23
+ import importlib_metadata
24
+ else:
25
+ import importlib.metadata as importlib_metadata
26
+
27
+ def _is_package_available(pkg_name: str):
28
+ pkg_exists = importlib.util.find_spec(pkg_name) is not None
29
+ pkg_version = "N/A"
30
+
31
+ if pkg_exists:
32
+ try:
33
+ pkg_version = importlib_metadata.version(pkg_name)
34
+ except (ImportError, importlib_metadata.PackageNotFoundError):
35
+ pkg_exists = False
36
+
37
+ return pkg_exists, pkg_version
38
+
39
+ _triton_available, _triton_version = _is_package_available("triton")
40
+ _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
41
+
42
+ def is_triton_available():
43
+ return _triton_available
44
+
45
+ def is_flash_attn_available():
46
+ return _flash_attn_available
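A typical guard built on these helpers might look like this (illustrative; `flash_attn_func` is flash-attn's public entry point and is imported only when the package is actually installed):

```python
from omnigen2.utils.import_utils import is_flash_attn_available

if is_flash_attn_available():
    from flash_attn import flash_attn_func
else:
    flash_attn_func = None  # callers fall back to a standard attention implementation
```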
omnigen2/utils/logging_utils.py ADDED
@@ -0,0 +1,15 @@
1
+ import logging
2
+
3
+ class TqdmToLogger(object):
4
+ """File-like object to redirect tqdm output to a logger."""
5
+ def __init__(self, logger, level=logging.INFO):
6
+ self.logger = logger
7
+ self.level = level
8
+
9
+ def write(self, buf):
10
+ for line in buf.rstrip().splitlines():
11
+ self.logger.log(self.level, line)
12
+
13
+ def flush(self):
14
+ for handler in self.logger.logger.handlers:
15
+ handler.flush()
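Note that `flush` dereferences `self.logger.logger.handlers`, so the object passed in is expected to be a logger wrapper that exposes the underlying `logging.Logger` as `.logger` (e.g. a `logging.LoggerAdapter`, or the adapter returned by accelerate's `get_logger`). An illustrative hookup with tqdm:

```python
import logging

from tqdm import tqdm

from omnigen2.utils.logging_utils import TqdmToLogger

logging.basicConfig(level=logging.INFO)
logger = logging.LoggerAdapter(logging.getLogger("omnigen2"), {})  # exposes .logger

for _ in tqdm(range(100), file=TqdmToLogger(logger), mininterval=1.0):
    pass  # progress lines are emitted through the logger instead of stderr
```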
omnigen2/utils/reproducibility.py ADDED
@@ -0,0 +1,22 @@
1
+ import random
2
+ import numpy as np
3
+
4
+ import torch
5
+
6
+ from diffusers.utils import is_torch_npu_available
7
+
8
+
9
+ def worker_init_fn(worker_id, num_processes, num_workers, process_index, seed, same_seed_per_epoch=False):
10
+ if same_seed_per_epoch:
11
+ worker_seed = seed + num_processes + num_workers * process_index + worker_id
12
+ else:
13
+ worker_seed = torch.initial_seed()
14
+
15
+ random.seed(worker_seed)
16
+ np.random.seed(worker_seed % 2**32)
17
+ torch.manual_seed(worker_seed)
18
+
19
+ if is_torch_npu_available():
20
+ torch.npu.manual_seed_all(seed)
21
+ else:
22
+ torch.cuda.manual_seed_all(seed)
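Since `DataLoader` only passes `worker_id` to `worker_init_fn`, the remaining arguments are typically bound with `functools.partial`. A sketch (the process and worker counts are placeholders; in training they would come from the accelerator):

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

from omnigen2.utils.reproducibility import worker_init_fn

dataset = TensorDataset(torch.arange(100))
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=2,
    worker_init_fn=partial(
        worker_init_fn,
        num_processes=1,   # e.g. accelerator.num_processes
        num_workers=2,
        process_index=0,   # e.g. accelerator.process_index
        seed=42,
        same_seed_per_epoch=False,
    ),
)
# Iterating the loader invokes worker_init_fn(worker_id) inside each worker process.
```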
omnigen2/utils/teacache_util.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ Utility for TeaCache
3
+
4
+ Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ from dataclasses import dataclass
20
+ from typing import Optional
21
+
22
+ import torch
23
+
24
+ @dataclass
25
+ class TeaCacheParams:
26
+ """
27
+ TeaCache parameters for `OmniGen2Transformer2DModel`
28
+ See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding
29
+
30
+ Args:
31
+ previous_residual (Optional[torch.Tensor]):
32
+ The tensor difference between the output and the input of the transformer layers from the previous timestep.
33
+ previous_modulated_inp (Optional[torch.Tensor]):
34
+ The modulated input from the previous timestep used to indicate the change of the transformer layer's output.
35
+ accumulated_rel_l1_distance (float):
36
+ The accumulated relative L1 distance.
37
+ is_first_or_last_step (bool):
38
+ Whether the current timestep is the first or last step.
39
+ """
40
+ previous_residual: Optional[torch.Tensor] = None
41
+ previous_modulated_inp: Optional[torch.Tensor] = None
42
+ accumulated_rel_l1_distance: float = 0
43
+ is_first_or_last_step: bool = False
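As a rough sketch of how these fields could drive a skip-or-recompute decision (this mirrors the TeaCache idea but is a paraphrase, not the transformer's actual code; the threshold value is hypothetical):

```python
import torch

from omnigen2.utils.teacache_util import TeaCacheParams

params = TeaCacheParams()
REL_L1_THRESHOLD = 0.05  # hypothetical tolerance

def should_recompute(modulated_inp: torch.Tensor) -> bool:
    """Return True to run the transformer blocks; False to reuse the cached residual."""
    if params.is_first_or_last_step or params.previous_modulated_inp is None:
        params.previous_modulated_inp = modulated_inp
        return True
    rel_l1 = ((modulated_inp - params.previous_modulated_inp).abs().mean()
              / params.previous_modulated_inp.abs().mean()).item()
    params.accumulated_rel_l1_distance += rel_l1
    params.previous_modulated_inp = modulated_inp
    if params.accumulated_rel_l1_distance < REL_L1_THRESHOLD:
        return False  # add params.previous_residual to the input instead of recomputing
    params.accumulated_rel_l1_distance = 0.0
    return True
```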