fix: omnigen2
- omnigen2/__init__.py +0 -0
- omnigen2/cache_functions/__init__.py +3 -0
- omnigen2/cache_functions/cache_init.py +38 -0
- omnigen2/cache_functions/cal_type.py +41 -0
- omnigen2/cache_functions/force_scheduler.py +19 -0
- omnigen2/dataset/__init__.py +0 -0
- omnigen2/dataset/omnigen2_test_dataset.py +153 -0
- omnigen2/dataset/omnigen2_train_dataset.py +203 -0
- omnigen2/models/__init__.py +0 -0
- omnigen2/models/attention_processor.py +357 -0
- omnigen2/models/embeddings.py +126 -0
- omnigen2/models/transformers/__init__.py +3 -0
- omnigen2/models/transformers/block_lumina2.py +218 -0
- omnigen2/models/transformers/components.py +4 -0
- omnigen2/models/transformers/repo.py +129 -0
- omnigen2/models/transformers/transformer_omnigen2.py +716 -0
- omnigen2/ops/triton/__init__.py +0 -0
- omnigen2/ops/triton/layer_norm.py +1257 -0
- omnigen2/optim/__init__.py +0 -0
- omnigen2/optim/scheduler/__init__.py +0 -0
- omnigen2/optim/scheduler/cosine_lr.py +118 -0
- omnigen2/optim/scheduler/scheduler.py +131 -0
- omnigen2/optim/scheduler/step_lr.py +63 -0
- omnigen2/pipelines/__init__.py +0 -0
- omnigen2/pipelines/image_processor.py +266 -0
- omnigen2/pipelines/lora_pipeline.py +388 -0
- omnigen2/pipelines/omnigen2/pipeline_omnigen2.py +774 -0
- omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py +830 -0
- omnigen2/pipelines/pipeline_utils.py +62 -0
- omnigen2/schedulers/__init__.py +0 -0
- omnigen2/schedulers/scheduling_dpmsolver_multistep.py +1052 -0
- omnigen2/schedulers/scheduling_flow_match_euler_discrete.py +229 -0
- omnigen2/taylorseer_utils/__init__.py +51 -0
- omnigen2/training_utils.py +645 -0
- omnigen2/transport/__init__.py +74 -0
- omnigen2/transport/dpm_solver.py +1386 -0
- omnigen2/transport/integrators.py +122 -0
- omnigen2/transport/path.py +201 -0
- omnigen2/transport/transport.py +545 -0
- omnigen2/transport/utils.py +56 -0
- omnigen2/utils/__init__.py +0 -0
- omnigen2/utils/img_util.py +31 -0
- omnigen2/utils/import_utils.py +46 -0
- omnigen2/utils/logging_utils.py +15 -0
- omnigen2/utils/reproducibility.py +22 -0
- omnigen2/utils/teacache_util.py +43 -0
omnigen2/__init__.py
ADDED
File without changes
omnigen2/cache_functions/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .cache_init import cache_init
+from .cal_type import cal_type
+from .force_scheduler import force_scheduler
omnigen2/cache_functions/cache_init.py
ADDED
@@ -0,0 +1,38 @@
+# Modified from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cache_init.py
+
+# Type hinting would cause circular import, self should be `OmniGen2Pipeline`
+def cache_init(self, num_steps: int):
+    '''
+    Initialization for cache.
+    '''
+    cache_dic = {}
+    cache = {}
+    cache_index = {}
+    cache[-1] = {}
+    cache_index[-1] = {}
+    cache_index['layer_index'] = {}
+    cache[-1]['layers_stream'] = {}
+    cache_dic['cache_counter'] = 0
+
+    for j in range(len(self.transformer.layers)):
+        cache[-1]['layers_stream'][j] = {}
+        cache_index[-1][j] = {}
+
+    cache_dic['Delta-DiT'] = False
+    cache_dic['cache_type'] = 'random'
+    cache_dic['cache_index'] = cache_index
+    cache_dic['cache'] = cache
+    cache_dic['fresh_ratio_schedule'] = 'ToCa'
+    cache_dic['fresh_ratio'] = 0.0
+    cache_dic['fresh_threshold'] = 3
+    cache_dic['soft_fresh_weight'] = 0.0
+    cache_dic['taylor_cache'] = True
+    cache_dic['max_order'] = 4
+    cache_dic['first_enhance'] = 5
+
+    current = {}
+    current['activated_steps'] = [0]
+    current['step'] = 0
+    current['num_steps'] = num_steps
+
+    return cache_dic, current
omnigen2/cache_functions/cal_type.py
ADDED
@@ -0,0 +1,41 @@
+# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cal_type.py
+
+from .force_scheduler import force_scheduler
+
+def cal_type(cache_dic, current):
+    '''
+    Determine the calculation type for this step.
+    '''
+    if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
+        # FORA: Uniform
+        first_step = (current['step'] == 0)
+    else:
+        # ToCa: First enhanced
+        first_step = (current['step'] < cache_dic['first_enhance'])
+
+    if not first_step:
+        fresh_interval = cache_dic['cal_threshold']
+    else:
+        fresh_interval = cache_dic['fresh_threshold']
+
+    if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1):
+        current['type'] = 'full'
+        cache_dic['cache_counter'] = 0
+        current['activated_steps'].append(current['step'])
+        force_scheduler(cache_dic, current)
+
+    elif (cache_dic['taylor_cache']):
+        cache_dic['cache_counter'] += 1
+        current['type'] = 'Taylor'
+
+    elif (cache_dic['cache_counter'] % 2 == 1):  # 0: ToCa-Aggressive-ToCa, 1: Aggressive-ToCa-Aggressive
+        cache_dic['cache_counter'] += 1
+        current['type'] = 'ToCa'
+        # 'cache_noise' 'ToCa' 'FORA'
+    elif cache_dic['Delta-DiT']:
+        cache_dic['cache_counter'] += 1
+        current['type'] = 'Delta-Cache'
+    else:
+        cache_dic['cache_counter'] += 1
+        current['type'] = 'ToCa'
omnigen2/cache_functions/force_scheduler.py
ADDED
@@ -0,0 +1,19 @@
+# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/force_scheduler.py
+
+import torch
+
+def force_scheduler(cache_dic, current):
+    if cache_dic['fresh_ratio'] == 0:
+        # FORA
+        linear_step_weight = 0.0
+    else:
+        # TokenCache
+        linear_step_weight = 0.0
+    step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
+    threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)
+
+    # No forced constraint for sensitive steps, since the performance is already good enough;
+    # you may want to experiment with one.
+
+    cache_dic['cal_threshold'] = threshold
+    # return threshold
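Taken together: cache_init builds the bookkeeping dictionaries once, cal_type decides per step whether to run the transformer fully or reuse Taylor-cached features, and force_scheduler refreshes the interval after each full step. A minimal sketch of that interplay, not part of the commit; `pipe` is a hypothetical stand-in for OmniGen2Pipeline (cache_init only touches `self.transformer.layers`):

import types

from omnigen2.cache_functions import cache_init, cal_type

pipe = types.SimpleNamespace(transformer=types.SimpleNamespace(layers=[None] * 4))
cache_dic, current = cache_init(pipe, num_steps=10)

for step in range(current['num_steps']):
    current['step'] = step
    cal_type(cache_dic, current)
    # With the defaults above: first_enhance=5 full steps, then a full refresh
    # every fresh_threshold=3 steps, Taylor-cached steps in between.
    print(step, current['type'])  # 0-4: 'full'; 5-6: 'Taylor'; 7: 'full'; ...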
omnigen2/dataset/__init__.py
ADDED
File without changes
omnigen2/dataset/omnigen2_test_dataset.py
ADDED
@@ -0,0 +1,153 @@
+from typing import Optional
+
+import os
+import random
+import yaml
+import glob
+from PIL import Image
+
+import torch
+
+from datasets import load_dataset, concatenate_datasets
+
+from ..pipelines.omnigen2.pipeline_omnigen2 import OmniGen2ImageProcessor
+
+class OmniGen2TestDataset(torch.utils.data.Dataset):
+    SYSTEM_PROMPT = "You are a helpful assistant that generates high-quality images based on user instructions."
+
+    def __init__(
+        self,
+        config_path: str,
+        tokenizer,
+        use_chat_template: bool,
+        max_pixels: Optional[int] = None,
+        max_side_length: Optional[int] = None,
+        img_scale_num: int = 16,
+        align_res: bool = True
+    ):
+
+        self.max_pixels = max_pixels
+        self.max_side_length = max_side_length
+        self.img_scale_num = img_scale_num
+        self.align_res = align_res
+
+        with open(config_path, "r") as f:
+            self.config = yaml.load(f, Loader=yaml.FullLoader)
+
+        self.use_chat_template = use_chat_template
+        self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=img_scale_num, do_resize=True)
+
+        data = self._collect_annotations(self.config)
+
+        self.data = data
+        self.tokenizer = tokenizer
+
+    def _collect_annotations(self, config):
+        json_datasets = []
+        for data in config['data']:
+            data_path, data_type = data['path'], data.get("type", "default")
+            if os.path.isdir(data_path):
+                jsonl_files = list(glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True)) + list(glob.glob(os.path.join(data_path, "**/*.json"), recursive=True))
+                json_dataset = load_dataset('json', data_files=jsonl_files, cache_dir=None)['train']
+            else:
+                data_ext = os.path.splitext(data_path)[-1]
+                if data_ext in [".json", ".jsonl"]:
+                    json_dataset = load_dataset('json', data_files=data_path, cache_dir=None)['train']
+                elif data_ext in [".yml", ".yaml"]:
+                    with open(data_path, "r") as f:
+                        sub_config = yaml.load(f, Loader=yaml.FullLoader)
+                    json_dataset = self._collect_annotations(sub_config)
+                else:
+                    raise NotImplementedError(
+                        f'Unknown data file extension: "{data_ext}". '
+                        "Currently, .json, .jsonl, .yml and .yaml are supported. "
+                        "If you are using a supported format, please set the file extension so that the proper parsing "
+                        "routine can be called."
+                    )
+            json_datasets.append(json_dataset)
+
+        json_dataset = concatenate_datasets(json_datasets)
+        return json_dataset
+
+    def apply_chat_template(self, instruction, system_prompt):
+        if self.use_chat_template:
+            prompt = [
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {"role": "user", "content": instruction},
+            ]
+            instruction = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
+        return instruction
+
+    def process_item(self, data_item):
+        assert data_item['instruction'] is not None
+        if 'input_images' in data_item and data_item['input_images'] is not None:
+            input_images_path = data_item['input_images']
+            input_images = []
+
+            for input_image_path in input_images_path:
+                input_image = Image.open(input_image_path).convert("RGB")
+                input_images.append(input_image)
+        else:
+            input_images_path, input_images = None, None
+
+        if input_images is not None and len(input_images) == 1 and self.align_res:
+            target_img_size = (input_images[0].width, input_images[0].height)
+        else:
+            target_img_size = data_item["target_img_size"]
+
+        w, h = target_img_size
+        cur_pixels = w * h
+        ratio = min(1, (self.max_pixels / cur_pixels) ** 0.5)
+
+        target_img_size = (int(w * ratio) // self.img_scale_num * self.img_scale_num, int(h * ratio) // self.img_scale_num * self.img_scale_num)
+
+        data = {
+            'task_type': data_item['task_type'],
+            'instruction': data_item['instruction'],
+            'input_images_path': input_images_path,
+            'input_images': input_images,
+            'target_img_size': target_img_size,
+        }
+        return data
+
+    def __getitem__(self, index):
+        data_item = self.data[index]
+        return self.process_item(data_item)
+
+    def __len__(self):
+        return len(self.data)
+
+class OmniGen2Collator():
+    def __init__(self, tokenizer, max_token_len):
+        self.tokenizer = tokenizer
+        self.max_token_len = max_token_len
+
+    def __call__(self, batch):
+        task_type = [data['task_type'] for data in batch]
+        instruction = [data['instruction'] for data in batch]
+        input_images_path = [data['input_images_path'] for data in batch]
+        input_images = [data['input_images'] for data in batch]
+        output_image = [data['output_image'] for data in batch]
+        output_image_path = [data['output_image_path'] for data in batch]
+
+        text_inputs = self.tokenizer(
+            instruction,
+            padding="longest",
+            max_length=self.max_token_len,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        data = {
+            "task_type": task_type,
+            "text_ids": text_inputs.input_ids,
+            "text_mask": text_inputs.attention_mask,
+            "input_images": input_images,
+            "input_images_path": input_images_path,
+            "output_image": output_image,
+            "output_image_path": output_image_path,
+        }
+        return data
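For reference, _collect_annotations implies a config layout: a top-level `data` list whose entries carry a `path` (a .json/.jsonl file, a directory scanned recursively for such files, or a nested .yml/.yaml config) and an optional `type`. A hedged sketch of such a file, with invented paths:

import yaml

config = {
    "data": [
        {"path": "annotations/t2i.jsonl", "type": "text_to_image"},
        {"path": "annotations/edit_dir", "type": "edit"},      # directory: scanned recursively
        {"path": "annotations/more.yaml", "type": "default"},  # nested config, parsed recursively
    ]
}
with open("test_data.yaml", "w") as f:
    yaml.safe_dump(config, f)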
omnigen2/dataset/omnigen2_train_dataset.py
ADDED
@@ -0,0 +1,203 @@
+from typing import Optional, Union, List
+
+import os
+import random
+import yaml
+import glob
+from PIL import Image
+
+import torch
+from torchvision import transforms
+
+from datasets import load_dataset, concatenate_datasets
+
+from ..pipelines.omnigen2.pipeline_omnigen2 import OmniGen2ImageProcessor
+
+class OmniGen2TrainDataset(torch.utils.data.Dataset):
+    SYSTEM_PROMPT = "You are a helpful assistant that generates high-quality images based on user instructions."
+    SYSTEM_PROMPT_DROP = "You are a helpful assistant that generates images."
+
+    def __init__(
+        self,
+        config_path: str,
+        tokenizer,
+        use_chat_template: bool,
+        max_input_pixels: Optional[Union[int, List[int]]] = None,
+        max_output_pixels: Optional[int] = None,
+        max_side_length: Optional[int] = None,
+        img_scale_num: int = 16,
+        prompt_dropout_prob: float = 0.0,
+        ref_img_dropout_prob: float = 0.0,
+    ):
+        self.max_input_pixels = max_input_pixels
+        self.max_output_pixels = max_output_pixels
+
+        self.max_side_length = max_side_length
+        self.img_scale_num = img_scale_num
+        self.prompt_dropout_prob = prompt_dropout_prob
+        self.ref_img_dropout_prob = ref_img_dropout_prob
+
+        with open(config_path, "r") as f:
+            self.config = yaml.load(f, Loader=yaml.FullLoader)
+
+        self.use_chat_template = use_chat_template
+        self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=img_scale_num, do_resize=True)
+
+        data = self._collect_annotations(self.config)
+
+        self.data = data
+        self.tokenizer = tokenizer
+
+    def _collect_annotations(self, config):
+        total_samples = 0
+        total_ratio = 0
+        json_datasets = []
+        for data in config['data']:
+            data_path, data_type = data['path'], data.get("type", "default")
+            if os.path.isdir(data_path):
+                jsonl_files = list(glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True)) + list(glob.glob(os.path.join(data_path, "**/*.json"), recursive=True))
+                json_dataset = load_dataset('json', data_files=jsonl_files, cache_dir=None)['train']
+            else:
+                data_ext = os.path.splitext(data_path)[-1]
+                if data_ext in [".json", ".jsonl"]:
+                    json_dataset = load_dataset('json', data_files=data_path, cache_dir=None)['train']
+                elif data_ext in [".yml", ".yaml"]:
+                    with open(data_path, "r") as f:
+                        sub_config = yaml.load(f, Loader=yaml.FullLoader)
+                    json_dataset = self._collect_annotations(sub_config)
+                else:
+                    raise NotImplementedError(
+                        f'Unknown data file extension: "{data_ext}". '
+                        "Currently, .json, .jsonl, .yml and .yaml are supported. "
+                        "If you are using a supported format, please set the file extension so that the proper parsing "
+                        "routine can be called."
+                    )
+            total_ratio += data['ratio']
+            total_samples += len(json_dataset)
+            json_datasets.append(json_dataset)
+
+        for i, json_dataset in enumerate(json_datasets):
+            # Use each dataset's own ratio (normalized) and keep the resampled split.
+            target_size = int(len(json_dataset) * config['data'][i]['ratio'] / total_ratio)
+            if target_size <= len(json_dataset):
+                # Random selection without replacement
+                indices = random.sample(range(len(json_dataset)), target_size)
+            else:
+                # Oversample with replacement
+                indices = random.choices(range(len(json_dataset)), k=target_size)
+            json_datasets[i] = json_dataset.select(indices)
+
+        json_dataset = concatenate_datasets(json_datasets)
+        return json_dataset
+
+    def clean_data_item(self, data_item):
+        task_type = data_item['task_type']
+        prefixs = ["The image portrays ", "The image depicts ", "The image captures ", "The image highlights ", "The image shows ", "这张图片展示了"]
+        if "text_to_image" in task_type or "t2i" in task_type:
+            if random.random() < 0.5:
+                for p in prefixs:
+                    if p in data_item['instruction']:
+                        data_item['instruction'] = data_item['instruction'].replace(p, "")
+                        break
+        return data_item
+
+    def apply_chat_template(self, instruction, system_prompt):
+        if self.use_chat_template:
+            prompt = [
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {"role": "user", "content": instruction},
+            ]
+            instruction = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
+        return instruction
+
+    def process_item(self, data_item):
+        assert data_item['instruction'] is not None
+        data_item = self.clean_data_item(data_item)
+
+        drop_prompt = random.random() < self.prompt_dropout_prob
+        drop_ref_img = drop_prompt and random.random() < self.ref_img_dropout_prob
+
+        if drop_prompt:
+            instruction = self.apply_chat_template("", self.SYSTEM_PROMPT_DROP)
+        else:
+            instruction = self.apply_chat_template(data_item['instruction'], self.SYSTEM_PROMPT)
+
+        if not drop_ref_img and 'input_images' in data_item and data_item['input_images'] is not None:
+            input_images_path = data_item['input_images']
+            input_images = []
+
+            max_input_pixels = self.max_input_pixels[len(input_images_path) - 1] if isinstance(self.max_input_pixels, list) else self.max_input_pixels
+
+            for input_image_path in input_images_path:
+                input_image = Image.open(input_image_path).convert("RGB")
+                input_image = self.image_processor.preprocess(input_image, max_pixels=max_input_pixels, max_side_length=self.max_side_length)
+                input_images.append(input_image)
+        else:
+            input_images_path, input_images = None, None
+
+        output_image_path = data_item['output_image']
+        output_image = Image.open(output_image_path).convert("RGB")
+        output_image = self.image_processor.preprocess(output_image, max_pixels=self.max_output_pixels, max_side_length=self.max_side_length)
+
+        data = {
+            'task_type': data_item['task_type'],
+            'instruction': instruction,
+            'input_images_path': input_images_path,
+            'input_images': input_images,
+            'output_image': output_image,
+            'output_image_path': output_image_path,
+        }
+        return data
+
+    def __getitem__(self, index):
+        max_retries = 12
+
+        current_index = index
+        for attempt in range(max_retries):
+            try:
+                data_item = self.data[current_index]
+                return self.process_item(data_item)
+            except Exception as e:
+                if attempt == max_retries - 1:
+                    raise e
+                else:
+                    # Try a different index for the next attempt
+                    current_index = random.randint(0, len(self.data) - 1)
+                    continue
+
+    def __len__(self):
+        return len(self.data)
+
+class OmniGen2Collator():
+    def __init__(self, tokenizer, max_token_len):
+        self.tokenizer = tokenizer
+        self.max_token_len = max_token_len
+
+    def __call__(self, batch):
+        task_type = [data['task_type'] for data in batch]
+        instruction = [data['instruction'] for data in batch]
+        input_images_path = [data['input_images_path'] for data in batch]
+        input_images = [data['input_images'] for data in batch]
+        output_image = [data['output_image'] for data in batch]
+        output_image_path = [data['output_image_path'] for data in batch]
+
+        text_inputs = self.tokenizer(
+            instruction,
+            padding="longest",
+            max_length=self.max_token_len,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        data = {
+            "task_type": task_type,
+            "text_ids": text_inputs.input_ids,
+            "text_mask": text_inputs.attention_mask,
+            "input_images": input_images,
+            "input_images_path": input_images_path,
+            "output_image": output_image,
+            "output_image_path": output_image_path,
+        }
+        return data
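A minimal usage sketch for the training dataset and collator. The tokenizer checkpoint, config filename, and hyperparameters are illustrative assumptions, not taken from this commit; the config is expected to follow the layout sketched earlier, with a `ratio` field added per entry (the train-time _collect_annotations reads it) and `output_image` paths in the annotations.

import torch
from transformers import AutoTokenizer

from omnigen2.dataset.omnigen2_train_dataset import OmniGen2TrainDataset, OmniGen2Collator

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # assumed checkpoint
dataset = OmniGen2TrainDataset(
    "train_data.yaml",                 # assumed config path
    tokenizer=tokenizer,
    use_chat_template=True,
    max_output_pixels=1024 * 1024,
    prompt_dropout_prob=0.1,
)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, collate_fn=OmniGen2Collator(tokenizer, max_token_len=512)
)
batch = next(iter(loader))  # dict with text_ids, text_mask, input_images, output_image, ...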
omnigen2/models/__init__.py
ADDED
File without changes
omnigen2/models/attention_processor.py
ADDED
@@ -0,0 +1,357 @@
+"""
+OmniGen2 Attention Processor Module
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import warnings
+import math
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+import torch.nn.functional as F
+from einops import repeat
+
+from ..utils.import_utils import is_flash_attn_available
+
+if is_flash_attn_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+else:
+    warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")
+
+from diffusers.models.attention_processor import Attention
+from .embeddings import apply_rotary_emb
+
+
+class OmniGen2AttnProcessorFlash2Varlen:
+    """
+    Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
+
+    This processor implements:
+    - Flash attention with variable length sequences
+    - Rotary position embeddings (RoPE)
+    - Query-Key normalization
+    - Proportional attention scaling
+    """
+
+    def __init__(self) -> None:
+        """Initialize the attention processor."""
+        if not is_flash_attn_available():
+            raise ImportError(
+                "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
+                "Please install flash_attn."
+            )
+
+    def _upad_input(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_length: int,
+        num_heads: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+        """
+        Unpad the input tensors for flash attention.
+
+        Args:
+            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
+            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
+            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
+            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
+            query_length: Length of the query sequence
+            num_heads: Number of attention heads
+
+        Returns:
+            Tuple containing:
+                - Unpadded query tensor
+                - Unpadded key tensor
+                - Unpadded value tensor
+                - Query indices
+                - Tuple of cumulative sequence lengths for query and key
+                - Tuple of maximum sequence lengths for query and key
+        """
+        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+            """Helper function to get unpadding data from the attention mask."""
+            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+            max_seqlen_in_batch = seqlens_in_batch.max().item()
+            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+            return indices, cu_seqlens, max_seqlen_in_batch
+
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        # Unpad key and value layers
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+
+        # Handle different query length cases
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        base_sequence_length: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        Process attention computation with flash attention.
+
+        Args:
+            attn: Attention module
+            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
+            encoder_hidden_states: Encoder hidden states tensor
+            attention_mask: Optional attention mask tensor
+            image_rotary_emb: Optional rotary embeddings for image tokens
+            base_sequence_length: Optional base sequence length for proportional attention
+
+        Returns:
+            torch.Tensor: Processed hidden states after attention computation
+        """
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        # Get Query-Key-Value Pair
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query_dim = query.shape[-1]
+        inner_dim = key.shape[-1]
+        head_dim = query_dim // attn.heads
+        dtype = query.dtype
+
+        # Get key-value heads
+        kv_heads = inner_dim // head_dim
+
+        # Reshape tensors for attention computation
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, kv_heads, head_dim)
+        value = value.view(batch_size, -1, kv_heads, head_dim)
+
+        # Apply Query-Key normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply Rotary Position Embeddings
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+        query, key = query.to(dtype), key.to(dtype)
+
+        # Calculate attention scale
+        if base_sequence_length is not None:
+            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+        else:
+            softmax_scale = attn.scale
+
+        # Unpad input for flash attention
+        (
+            query_states,
+            key_states,
+            value_states,
+            indices_q,
+            cu_seq_lens,
+            max_seq_lens,
+        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
+
+        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+        # Handle different number of heads
+        if kv_heads < attn.heads:
+            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
+            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
+
+        # Apply flash attention
+        attn_output_unpad = flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_in_batch_q,
+            max_seqlen_k=max_seqlen_in_batch_k,
+            dropout_p=0.0,
+            causal=False,
+            softmax_scale=softmax_scale,
+        )
+
+        # Pad output and apply final transformations
+        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
+        hidden_states = hidden_states.flatten(-2)
+        hidden_states = hidden_states.type_as(query)
+
+        # Apply output projection
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class OmniGen2AttnProcessor:
+    """
+    Processor implementing scaled dot-product attention for variable length sequences.
+
+    This processor is optimized for PyTorch 2.0 and implements:
+    - Scaled dot-product attention (which may dispatch to flash attention kernels)
+    - Rotary position embeddings (RoPE)
+    - Query-Key normalization
+    - Proportional attention scaling
+
+    Raises:
+        ImportError: If PyTorch version is less than 2.0
+    """
+
+    def __init__(self) -> None:
+        """Initialize the attention processor."""
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "OmniGen2AttnProcessor requires PyTorch 2.0. "
+                "Please upgrade PyTorch to version 2.0 or later."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        base_sequence_length: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        Process attention computation with scaled dot-product attention.
+
+        Args:
+            attn: Attention module
+            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
+            encoder_hidden_states: Encoder hidden states tensor
+            attention_mask: Optional attention mask tensor
+            image_rotary_emb: Optional rotary embeddings for image tokens
+            base_sequence_length: Optional base sequence length for proportional attention
+
+        Returns:
+            torch.Tensor: Processed hidden states after attention computation
+        """
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        # Get Query-Key-Value Pair
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query_dim = query.shape[-1]
+        inner_dim = key.shape[-1]
+        head_dim = query_dim // attn.heads
+        dtype = query.dtype
+
+        # Get key-value heads
+        kv_heads = inner_dim // head_dim
+
+        # Reshape tensors for attention computation
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, kv_heads, head_dim)
+        value = value.view(batch_size, -1, kv_heads, head_dim)
+
+        # Apply Query-Key normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply Rotary Position Embeddings
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+        query, key = query.to(dtype), key.to(dtype)
+
+        # Calculate attention scale
+        if base_sequence_length is not None:
+            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+        else:
+            softmax_scale = attn.scale
+
+        # scaled_dot_product_attention expects attention_mask shape to be
+        # (batch, heads, source_length, target_length)
+        if attention_mask is not None:
+            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Explicitly repeat key and value to match the number of query heads; in our tests
+        # on PyTorch 2.6, using enable_gqa=True fell back to the MATH backend of SDPA.
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=softmax_scale
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.type_as(query)
+
+        # Apply output projection
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
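The _upad_input helper exists because flash_attn_varlen_func consumes packed tokens plus cumulative sequence lengths rather than padded batches. A self-contained toy run of the same bookkeeping its inner _get_unpad_data performs (plain torch, no flash_attn required):

import torch
import torch.nn.functional as F

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])  # two padded sequences, true lengths 3 and 2
seqlens = mask.sum(dim=-1, dtype=torch.int32)                       # tensor([3, 2])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()   # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32): packed-sequence boundaries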
omnigen2/models/embeddings.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.models.activations import get_activation
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: int = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim=None,
+        sample_proj_bias=True,
+    ):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+        if cond_proj_dim is not None:
+            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+
+        self.act = get_activation(act_fn)
+
+        if out_dim is not None:
+            time_embed_dim_out = out_dim
+        else:
+            time_embed_dim_out = time_embed_dim
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+        if post_act_fn is None:
+            self.post_act = None
+        else:
+            self.post_act = get_activation(post_act_fn)
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        nn.init.normal_(self.linear_1.weight, std=0.02)
+        nn.init.zeros_(self.linear_1.bias)
+        nn.init.normal_(self.linear_2.weight, std=0.02)
+        nn.init.zeros_(self.linear_2.bias)
+
+    def forward(self, sample, condition=None):
+        if condition is not None:
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to an input tensor using the given frequency tensor. This function applies rotary
+    embeddings to the given query or key tensor 'x' using the provided frequency tensor 'freqs_cis'. The input
+    tensor is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility.
+    The resulting tensor contains rotary embeddings and is returned as a real tensor.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to. [B, H, S, D]
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        torch.Tensor: The query or key tensor with rotary embeddings applied.
+    """
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Used for flux, cogvideox, hunyuan-dit
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Used for Stable Audio, OmniGen and CogView4
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+        return out
+    else:
+        # used for lumina
+        # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+        return x_out.type_as(x)
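For intuition, the use_real=False branch (the one the attention processors above call) is a plain complex rotation. A standalone toy check of that arithmetic, assuming the [B, S, D//2] freqs_cis layout implied by the unsqueeze(2) broadcast over the head axis: rotating a head dimension of 2 by 90 degrees maps (1, 0) to (0, 1).

import math
import torch

x = torch.tensor([[[[1.0, 0.0]]]])                 # [B=1, S=1, H=1, D=2]
angle = torch.full((1, 1, 1), math.pi / 2)         # assumed [B, S, D//2] layout
freqs_cis = torch.polar(torch.ones(1, 1, 1), angle)  # complex e^{i*pi/2}

x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
out = torch.view_as_real(x_c * freqs_cis.unsqueeze(2)).flatten(3)
print(out)  # ~tensor([[[[0., 1.]]]]): (1, 0) rotated by 90 degrees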
omnigen2/models/transformers/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .transformer_omnigen2 import OmniGen2Transformer2DModel
+
+__all__ = ["OmniGen2Transformer2DModel"]
omnigen2/models/transformers/block_lumina2.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
from typing import Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from diffusers.models.embeddings import Timesteps
|
| 23 |
+
from ..embeddings import TimestepEmbedding
|
| 24 |
+
|
| 25 |
+
from ...utils.import_utils import is_flash_attn_available, is_triton_available
|
| 26 |
+
|
| 27 |
+
if is_triton_available():
|
| 28 |
+
from ...ops.triton.layer_norm import RMSNorm
|
| 29 |
+
else:
|
| 30 |
+
from torch.nn import RMSNorm
|
| 31 |
+
warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
|
| 32 |
+
|
| 33 |
+
if is_flash_attn_available():
|
| 34 |
+
from flash_attn.ops.activations import swiglu
|
| 35 |
+
else:
|
| 36 |
+
from .components import swiglu
|
| 37 |
+
warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance")
|
| 38 |
+
|
| 39 |
+
# try:
|
| 40 |
+
# from flash_attn.ops.activations import swiglu as fused_swiglu
|
| 41 |
+
# FUSEDSWIGLU_AVALIBLE = True
|
| 42 |
+
# except ImportError:
|
| 43 |
+
|
| 44 |
+
# FUSEDSWIGLU_AVALIBLE = False
|
| 45 |
+
# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
| 46 |
+
|
| 47 |
+
class LuminaRMSNormZero(nn.Module):
|
| 48 |
+
"""
|
| 49 |
+
Norm layer adaptive RMS normalization zero.
|
| 50 |
+
|
| 51 |
+
Parameters:
|
| 52 |
+
embedding_dim (`int`): The size of each embedding vector.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
embedding_dim: int,
|
| 58 |
+
norm_eps: float,
|
| 59 |
+
norm_elementwise_affine: bool,
|
| 60 |
+
):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.silu = nn.SiLU()
|
| 63 |
+
self.linear = nn.Linear(
|
| 64 |
+
min(embedding_dim, 1024),
|
| 65 |
+
4 * embedding_dim,
|
| 66 |
+
bias=True,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
self.norm = RMSNorm(embedding_dim, eps=norm_eps)
|
| 70 |
+
|
| 71 |
+
def forward(
|
| 72 |
+
self,
|
| 73 |
+
x: torch.Tensor,
|
| 74 |
+
emb: Optional[torch.Tensor] = None,
|
| 75 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 76 |
+
emb = self.linear(self.silu(emb))
|
| 77 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
| 78 |
+
x = self.norm(x) * (1 + scale_msa[:, None])
|
| 79 |
+
return x, gate_msa, scale_mlp, gate_mlp
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class LuminaLayerNormContinuous(nn.Module):
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
embedding_dim: int,
|
| 86 |
+
conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
    ):
        super().__init__()

        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # Convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT).
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x


class LuminaFeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        inner_dim (`int`): The intermediate dimension of the feedforward layer.
        multiple_of (`int`, *optional*): Value to ensure the hidden dimension is a multiple of this value.
        ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden dimension. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        self.swiglu = swiglu

        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        text_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim, eps=norm_eps),
            nn.Linear(text_feat_dim, hidden_size, bias=True),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
        nn.init.zeros_(self.caption_embedder[1].bias)

    def forward(
        self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
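For orientation, a minimal shape sketch of the combined embedding above. The batch size, caption length, and dtype here are illustrative assumptions, not part of this commit, and the snippet assumes the module's imports (Timesteps, TimestepEmbedding, RMSNorm) are available:

import torch

embed = Lumina2CombinedTimestepCaptionEmbedding(hidden_size=4096, text_feat_dim=2048)
timestep = torch.rand(2)             # one diffusion timestep per sample
captions = torch.randn(2, 7, 2048)   # text-encoder hidden states
time_embed, caption_embed = embed(timestep, captions, dtype=torch.float32)
# time_embed: (2, 1024) because the timestep embedder is capped at min(hidden_size, 1024)
# caption_embed: (2, 7, 4096), one hidden-size vector per caption token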
omnigen2/models/transformers/components.py
ADDED
@@ -0,0 +1,4 @@
import torch.nn.functional as F

def swiglu(x, y):
    return F.silu(x.float(), inplace=False).to(x.dtype) * y
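The helper computes the SiLU gate in float32 and casts back, so half-precision inputs keep a full-precision activation. A quick equivalence check (illustrative, not part of this commit):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, dtype=torch.float16)
y = torch.randn(4, 8, dtype=torch.float16)
assert torch.equal(swiglu(x, y), F.silu(x.float()).to(x.dtype) * y)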
omnigen2/models/transformers/repo.py
ADDED
@@ -0,0 +1,129 @@
from typing import List, Tuple

import torch
import torch.nn as nn

from einops import repeat
from diffusers.models.embeddings import get_1d_rotary_pos_embed

class OmniGen2RotaryPosEmbed(nn.Module):
    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        freqs_cis = []
        # MPS lacks float64 support, so fall back to float32 there.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for d, e in zip(axes_dim, axes_lens):
            emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        device = ids.device
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # add text position ids
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # add image position ids

                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )
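For reference, the static table builder above yields one complex tensor per axis (sequence, height, width). The shapes below assume diffusers' get_1d_rotary_pos_embed default of returning (length, dim // 2) complex entries; that is an assumption about the installed diffusers version, not something this commit pins:

tables = OmniGen2RotaryPosEmbed.get_freqs_cis(
    axes_dim=(32, 32, 32), axes_lens=(300, 512, 512), theta=10000
)
for table in tables:
    print(table.shape)   # expected: (300, 16), (512, 16), (512, 16)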
omnigen2/models/transformers/transformer_omnigen2.py
ADDED
@@ -0,0 +1,716 @@
import warnings
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

import torch
import torch.nn as nn

from einops import rearrange

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin

from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor
from .repo import OmniGen2RotaryPosEmbed
from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding

from ...utils.import_utils import is_triton_available, is_flash_attn_available
from ...utils.teacache_util import TeaCacheParams

if is_triton_available():
    from ...ops.triton.layer_norm import RMSNorm
else:
    from torch.nn import RMSNorm

from ...taylorseer_utils import derivative_approximation, taylor_formula, taylor_cache_init
from ...cache_functions import cache_init, cal_type

logger = logging.get_logger(__name__)

class OmniGen2TransformerBlock(nn.Module):
    """
    Transformer block for the OmniGen2 model.

    This block implements a transformer layer with:
    - Multi-head attention with flash attention
    - Feed-forward network with SwiGLU activation
    - RMS normalization
    - Optional modulation for conditional generation

    Args:
        dim: Dimension of the input and output tensors
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Multiple to which the feed-forward hidden dimension is rounded
        ffn_dim_multiplier: Multiplier for the feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        modulation: Whether to use modulation for conditional generation
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        """Initialize the transformer block."""
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Fall back to the plain attention processor when flash-attn is unavailable.
        try:
            processor = OmniGen2AttnProcessorFlash2Varlen()
        except ImportError:
            processor = OmniGen2AttnProcessor()

        # Initialize attention layer
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        # Initialize feed-forward network
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier
        )

        # Initialize normalization layers
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the transformer block.

        Uses Xavier uniform initialization for linear layers and zero initialization for biases.
        """
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer block.

        Args:
            hidden_states: Input hidden states tensor
            attention_mask: Attention mask tensor
            image_rotary_emb: Rotary embeddings for image tokens
            temb: Optional timestep embedding tensor

        Returns:
            torch.Tensor: Output hidden states after transformer block processing
        """
        enable_taylorseer = getattr(self, 'enable_taylorseer', False)
        if enable_taylorseer:
            if self.modulation:
                if temb is None:
                    raise ValueError("temb must be provided when modulation is enabled")

                if self.current['type'] == 'full':
                    self.current['module'] = 'total'
                    taylor_cache_init(cache_dic=self.cache_dic, current=self.current)

                    norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
                    attn_output = self.attn(
                        hidden_states=norm_hidden_states,
                        encoder_hidden_states=norm_hidden_states,
                        attention_mask=attention_mask,
                        image_rotary_emb=image_rotary_emb,
                    )
                    hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
                    mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
                    hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)

                    derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)

                elif self.current['type'] == 'Taylor':
                    self.current['module'] = 'total'
                    hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
            else:
                norm_hidden_states = self.norm1(hidden_states)
                attn_output = self.attn(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_hidden_states,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
                hidden_states = hidden_states + self.norm2(attn_output)
                mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
                hidden_states = hidden_states + self.ffn_norm2(mlp_output)
        else:
            if self.modulation:
                if temb is None:
                    raise ValueError("temb must be provided when modulation is enabled")

                norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
                attn_output = self.attn(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_hidden_states,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
                hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
                mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
                hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
            else:
                norm_hidden_states = self.norm1(hidden_states)
                attn_output = self.attn(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_hidden_states,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
                hidden_states = hidden_states + self.norm2(attn_output)
                mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
                hidden_states = hidden_states + self.ffn_norm2(mlp_output)

        return hidden_states

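# --- Illustrative helper (not part of the original commit): constructs one
# block with the default dimensions of the model class below and reports its
# parameter count. Defined only; never called at import time.
def _demo_transformer_block() -> None:
    block = OmniGen2TransformerBlock(
        dim=2304,
        num_attention_heads=24,
        num_kv_heads=8,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        modulation=True,
    )
    n_params = sum(p.numel() for p in block.parameters())
    print(f"one block holds {n_params / 1e6:.1f}M parameters")
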
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    OmniGen2 Transformer 2D Model.

    A transformer-based diffusion model for image generation with:
    - Patch-based image processing
    - Rotary position embeddings
    - Multi-head attention
    - Conditional generation support

    Args:
        patch_size: Size of image patches
        in_channels: Number of input channels
        out_channels: Number of output channels (defaults to in_channels)
        hidden_size: Size of hidden layers
        num_layers: Number of transformer layers
        num_refiner_layers: Number of refiner layers
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Multiple to which the feed-forward hidden dimension is rounded
        ffn_dim_multiplier: Multiplier for feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        axes_dim_rope: Dimensions for rotary position embeddings
        axes_lens: Lengths for rotary position embeddings
        text_feat_dim: Dimension of text features
        timestep_scale: Scale factor for timestep embeddings
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGen2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0
    ) -> None:
        """Initialize the OmniGen2 transformer model."""
        super().__init__()

        # Validate configuration
        if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
            raise ValueError(
                f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
                f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
            )

        self.out_channels = out_channels or in_channels

        # Initialize embeddings
        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.ref_image_patch_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale
        )

        # Initialize transformer blocks
        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.context_refiner = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels
        )

        # Add learnable embeddings to distinguish different images
        self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))  # supports at most 5 reference images

        self.gradient_checkpointing = False

        self.initialize_weights()

        # TeaCache settings
        self.enable_teacache = False
        self.teacache_rel_l1_thresh = 0.05
        self.teacache_params = TeaCacheParams()

        coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
        self.rescale_func = np.poly1d(coefficients)

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the model.

        Uses Xavier uniform initialization for linear layers.
        """
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.constant_(self.x_embedder.bias, 0.0)

        nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
        nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)

        nn.init.zeros_(self.norm_out.linear_1.weight)
        nn.init.zeros_(self.norm_out.linear_1.bias)
        nn.init.zeros_(self.norm_out.linear_2.weight)
        nn.init.zeros_(self.norm_out.linear_2.bias)

        nn.init.normal_(self.image_index_embedding, std=0.02)

    def img_patch_embed_and_refine(
        self,
        hidden_states,
        ref_image_hidden_states,
        padded_img_mask,
        padded_ref_img_mask,
        noise_rotary_emb,
        ref_img_rotary_emb,
        l_effective_ref_img_len,
        l_effective_img_len,
        temb
    ):
        batch_size = len(hidden_states)
        max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])

        hidden_states = self.x_embedder(hidden_states)
        ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)

        for i in range(batch_size):
            shift = 0
            for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
                shift += ref_img_len

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
        num_ref_images = len(flat_l_effective_ref_img_len)
        max_ref_img_len = max(flat_l_effective_ref_img_len)

        batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
        batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
        batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
        batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)

        # sequence of ref imgs to batch
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                batch_ref_img_mask[idx, :ref_img_len] = True
                batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
                batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
                batch_temb[idx] = temb[i]
                shift += ref_img_len
                idx += 1

        # refine ref imgs separately
        for layer in self.ref_image_refiner:
            batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)

        # batch of ref imgs to sequence
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
                shift += ref_img_len
                idx += 1

        combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
        for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
            combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
            combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]

        return combined_img_hidden_states

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        batch_size = len(hidden_states)
        p = self.config.patch_size
        device = hidden_states[0].device

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None:
            ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
            l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # ref image patch embeddings
        flat_ref_img_hidden_states = []
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                imgs = []
                for ref_img in ref_image_hidden_states[i]:
                    C, H, W = ref_img.size()
                    ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
                    imgs.append(ref_img)

                img = torch.cat(imgs, dim=0)
                flat_ref_img_hidden_states.append(img)
            else:
                flat_ref_img_hidden_states.append(None)

        # image patch embeddings
        flat_hidden_states = []
        for i in range(batch_size):
            img = hidden_states[i]
            C, H, W = img.size()

            img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
            flat_hidden_states.append(img)

        padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
        padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
                padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True

        padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
            padded_img_mask[i, :l_effective_img_len[i]] = True

        return (
            padded_hidden_states,
            padded_ref_img_hidden_states,
            padded_img_mask,
            padded_ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        )

    def forward(
        self,
        hidden_states: Union[torch.Tensor, List[torch.Tensor]],
        timestep: torch.Tensor,
        text_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        text_attention_mask: torch.Tensor,
        ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        enable_taylorseer = getattr(self, 'enable_taylorseer', False)
        if enable_taylorseer:
            cal_type(self.cache_dic, self.current)

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # 1. Condition, positional & patch embedding
        batch_size = len(hidden_states)
        is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)

        if is_hidden_states_tensor:
            assert hidden_states.ndim == 4
            hidden_states = [_hidden_states for _hidden_states in hidden_states]

        device = hidden_states[0].device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb,
            ref_img_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(
            freqs_cis,
            text_attention_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
            device,
        )

        # 2. Context refinement
        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            noise_rotary_emb,
            ref_img_rotary_emb,
            l_effective_ref_img_len,
            l_effective_img_len,
            temb,
        )

        # 3. Joint Transformer blocks
        max_seq_len = max(seq_lengths)

        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
            joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]

        hidden_states = joint_hidden_states

        if self.enable_teacache:
            teacache_hidden_states = hidden_states.clone()
            teacache_temb = temb.clone()
            modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
            if self.teacache_params.is_first_or_last_step:
                should_calc = True
                self.teacache_params.accumulated_rel_l1_distance = 0
            else:
                self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
                    ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean()
                        / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
                )
                if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.teacache_params.accumulated_rel_l1_distance = 0
            self.teacache_params.previous_modulated_inp = modulated_inp

        if self.enable_teacache:
            if not should_calc:
                hidden_states += self.teacache_params.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()
                for layer_idx, layer in enumerate(self.layers):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        hidden_states = self._gradient_checkpointing_func(
                            layer, hidden_states, attention_mask, rotary_emb, temb
                        )
                    else:
                        hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
                self.teacache_params.previous_residual = hidden_states - ori_hidden_states
        else:
            if enable_taylorseer:
                self.current['stream'] = 'layers_stream'

            for layer_idx, layer in enumerate(self.layers):
                if enable_taylorseer:
                    layer.current = self.current
                    layer.cache_dic = self.cache_dic
                    layer.enable_taylorseer = True
                    self.current['layer'] = layer_idx

                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        layer, hidden_states, attention_mask, rotary_emb, temb
                    )
                else:
                    hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)

        p = self.config.patch_size
        output = []
        for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
            height, width = img_size
            output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
        if is_hidden_states_tensor:
            output = torch.stack(output, dim=0)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if enable_taylorseer:
            self.current['step'] += 1

        if not return_dict:
            return output
        return Transformer2DModelOutput(sample=output)
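A downsized, hypothetical configuration for smoke-testing the model class above. Every value here is chosen only to satisfy the hidden_size // num_attention_heads == sum(axes_dim_rope) check in __init__, and is not a configuration shipped with OmniGen2:

model = OmniGen2Transformer2DModel(
    patch_size=2,
    in_channels=16,
    hidden_size=192,          # 192 // 8 heads == 24 == sum(axes_dim_rope)
    num_layers=2,
    num_refiner_layers=1,
    num_attention_heads=8,
    num_kv_heads=4,
    axes_dim_rope=(8, 8, 8),
    text_feat_dim=64,
)
print(sum(p.numel() for p in model.parameters()))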
omnigen2/ops/triton/__init__.py
ADDED
File without changes
omnigen2/ops/triton/layer_norm.py
ADDED
@@ -0,0 +1,1257 @@
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl


from typing import Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
    deprecated = True
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
else:
    deprecated = False
    from torch.cuda.amp import custom_fwd, custom_bwd

custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)


def triton_autotune_configs():
    # Return configs with a valid warp count for the current device
    configs = []
    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
    max_threads_per_block = 1024
    # Default to warp size 32 if not defined by device
    warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
    # Autotune for warp counts which are powers of 2 and do not exceed the thread-per-block limit
    warp_count = 1
    while warp_count * warp_size <= max_threads_per_block:
        configs.append(triton.Config({}, num_warps=warp_count))
        warp_count *= 2
    return configs

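# --- Illustrative helper (not part of the original commit): on a CUDA device
# with warp_size 32 the function above yields num_warps 1, 2, 4, 8, 16, 32,
# since 32 warps of 32 threads reach the 1024-thread-per-block cap used above.
# Requires a CUDA device; defined only, never called at import time.
def _demo_autotune_configs() -> None:
    for config in triton_autotune_configs():
        print(config.num_warps)
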
| 54 |
+
def layer_norm_ref(
|
| 55 |
+
x,
|
| 56 |
+
weight,
|
| 57 |
+
bias,
|
| 58 |
+
residual=None,
|
| 59 |
+
x1=None,
|
| 60 |
+
weight1=None,
|
| 61 |
+
bias1=None,
|
| 62 |
+
eps=1e-6,
|
| 63 |
+
dropout_p=0.0,
|
| 64 |
+
rowscale=None,
|
| 65 |
+
prenorm=False,
|
| 66 |
+
zero_centered_weight=False,
|
| 67 |
+
dropout_mask=None,
|
| 68 |
+
dropout_mask1=None,
|
| 69 |
+
upcast=False,
|
| 70 |
+
):
|
| 71 |
+
dtype = x.dtype
|
| 72 |
+
if upcast:
|
| 73 |
+
x = x.float()
|
| 74 |
+
weight = weight.float()
|
| 75 |
+
bias = bias.float() if bias is not None else None
|
| 76 |
+
residual = residual.float() if residual is not None else residual
|
| 77 |
+
x1 = x1.float() if x1 is not None else None
|
| 78 |
+
weight1 = weight1.float() if weight1 is not None else None
|
| 79 |
+
bias1 = bias1.float() if bias1 is not None else None
|
| 80 |
+
if zero_centered_weight:
|
| 81 |
+
weight = weight + 1.0
|
| 82 |
+
if weight1 is not None:
|
| 83 |
+
weight1 = weight1 + 1.0
|
| 84 |
+
if x1 is not None:
|
| 85 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
| 86 |
+
if rowscale is not None:
|
| 87 |
+
x = x * rowscale[..., None]
|
| 88 |
+
if dropout_p > 0.0:
|
| 89 |
+
if dropout_mask is not None:
|
| 90 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
| 91 |
+
else:
|
| 92 |
+
x = F.dropout(x, p=dropout_p)
|
| 93 |
+
if x1 is not None:
|
| 94 |
+
if dropout_mask1 is not None:
|
| 95 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
| 96 |
+
else:
|
| 97 |
+
x1 = F.dropout(x1, p=dropout_p)
|
| 98 |
+
if x1 is not None:
|
| 99 |
+
x = x + x1
|
| 100 |
+
if residual is not None:
|
| 101 |
+
x = (x + residual).to(x.dtype)
|
| 102 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
| 103 |
+
dtype
|
| 104 |
+
)
|
| 105 |
+
if weight1 is None:
|
| 106 |
+
return out if not prenorm else (out, x)
|
| 107 |
+
else:
|
| 108 |
+
out1 = F.layer_norm(
|
| 109 |
+
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
|
| 110 |
+
).to(dtype)
|
| 111 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def rms_norm_ref(
|
| 115 |
+
x,
|
| 116 |
+
weight,
|
| 117 |
+
bias,
|
| 118 |
+
residual=None,
|
| 119 |
+
x1=None,
|
| 120 |
+
weight1=None,
|
| 121 |
+
bias1=None,
|
| 122 |
+
eps=1e-6,
|
| 123 |
+
dropout_p=0.0,
|
| 124 |
+
rowscale=None,
|
| 125 |
+
prenorm=False,
|
| 126 |
+
zero_centered_weight=False,
|
| 127 |
+
dropout_mask=None,
|
| 128 |
+
dropout_mask1=None,
|
| 129 |
+
upcast=False,
|
| 130 |
+
):
|
| 131 |
+
dtype = x.dtype
|
| 132 |
+
if upcast:
|
| 133 |
+
x = x.float()
|
| 134 |
+
weight = weight.float()
|
| 135 |
+
bias = bias.float() if bias is not None else None
|
| 136 |
+
residual = residual.float() if residual is not None else residual
|
| 137 |
+
x1 = x1.float() if x1 is not None else None
|
| 138 |
+
weight1 = weight1.float() if weight1 is not None else None
|
| 139 |
+
bias1 = bias1.float() if bias1 is not None else None
|
| 140 |
+
if zero_centered_weight:
|
| 141 |
+
weight = weight + 1.0
|
| 142 |
+
if weight1 is not None:
|
| 143 |
+
weight1 = weight1 + 1.0
|
| 144 |
+
if x1 is not None:
|
| 145 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
| 146 |
+
if rowscale is not None:
|
| 147 |
+
x = x * rowscale[..., None]
|
| 148 |
+
if dropout_p > 0.0:
|
| 149 |
+
if dropout_mask is not None:
|
| 150 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
| 151 |
+
else:
|
| 152 |
+
x = F.dropout(x, p=dropout_p)
|
| 153 |
+
if x1 is not None:
|
| 154 |
+
if dropout_mask1 is not None:
|
| 155 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
| 156 |
+
else:
|
| 157 |
+
x1 = F.dropout(x1, p=dropout_p)
|
| 158 |
+
if x1 is not None:
|
| 159 |
+
x = x + x1
|
| 160 |
+
if residual is not None:
|
| 161 |
+
x = (x + residual).to(x.dtype)
|
| 162 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
| 163 |
+
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
|
| 164 |
+
if weight1 is None:
|
| 165 |
+
return out if not prenorm else (out, x)
|
| 166 |
+
else:
|
| 167 |
+
out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
|
| 168 |
+
dtype
|
| 169 |
+
)
|
| 170 |
+
return (out, out1) if not prenorm else (out, out1, x)
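A quick self-check that the reference matches the textbook RMSNorm formula; the shapes are illustrative and eps is left at the default 1e-6:

import torch

x = torch.randn(2, 128)
w = torch.randn(128)
ref = rms_norm_ref(x, w, bias=None)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
assert torch.allclose(ref, manual, atol=1e-5)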
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@triton.autotune(
|
| 174 |
+
configs=triton_autotune_configs(),
|
| 175 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
| 176 |
+
)
|
| 177 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 178 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
| 179 |
+
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
|
| 180 |
+
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
|
| 181 |
+
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
|
| 182 |
+
@triton.jit
|
| 183 |
+
def _layer_norm_fwd_1pass_kernel(
|
| 184 |
+
X, # pointer to the input
|
| 185 |
+
Y, # pointer to the output
|
| 186 |
+
W, # pointer to the weights
|
| 187 |
+
B, # pointer to the biases
|
| 188 |
+
RESIDUAL, # pointer to the residual
|
| 189 |
+
X1,
|
| 190 |
+
W1,
|
| 191 |
+
B1,
|
| 192 |
+
Y1,
|
| 193 |
+
RESIDUAL_OUT, # pointer to the residual
|
| 194 |
+
ROWSCALE,
|
| 195 |
+
SEEDS, # Dropout seeds for each row
|
| 196 |
+
DROPOUT_MASK,
|
| 197 |
+
Mean, # pointer to the mean
|
| 198 |
+
Rstd, # pointer to the 1/std
|
| 199 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 200 |
+
stride_y_row,
|
| 201 |
+
stride_res_row,
|
| 202 |
+
stride_res_out_row,
|
| 203 |
+
stride_x1_row,
|
| 204 |
+
stride_y1_row,
|
| 205 |
+
M, # number of rows in X
|
| 206 |
+
N, # number of columns in X
|
| 207 |
+
eps, # epsilon to avoid division by zero
|
| 208 |
+
dropout_p, # Dropout probability
|
| 209 |
+
zero_centered_weight, # If true, add 1.0 to the weight
|
| 210 |
+
IS_RMS_NORM: tl.constexpr,
|
| 211 |
+
BLOCK_N: tl.constexpr,
|
| 212 |
+
HAS_RESIDUAL: tl.constexpr,
|
| 213 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
| 214 |
+
HAS_BIAS: tl.constexpr,
|
| 215 |
+
HAS_DROPOUT: tl.constexpr,
|
| 216 |
+
STORE_DROPOUT_MASK: tl.constexpr,
|
| 217 |
+
HAS_ROWSCALE: tl.constexpr,
|
| 218 |
+
HAS_X1: tl.constexpr,
|
| 219 |
+
HAS_W1: tl.constexpr,
|
| 220 |
+
HAS_B1: tl.constexpr,
|
| 221 |
+
):
|
| 222 |
+
# Map the program id to the row of X and Y it should compute.
|
| 223 |
+
row = tl.program_id(0)
|
| 224 |
+
X += row * stride_x_row
|
| 225 |
+
Y += row * stride_y_row
|
| 226 |
+
if HAS_RESIDUAL:
|
| 227 |
+
RESIDUAL += row * stride_res_row
|
| 228 |
+
if STORE_RESIDUAL_OUT:
|
| 229 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
| 230 |
+
if HAS_X1:
|
| 231 |
+
X1 += row * stride_x1_row
|
| 232 |
+
if HAS_W1:
|
| 233 |
+
Y1 += row * stride_y1_row
|
| 234 |
+
# Compute mean and variance
|
| 235 |
+
cols = tl.arange(0, BLOCK_N)
|
| 236 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 237 |
+
if HAS_ROWSCALE:
|
| 238 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
| 239 |
+
x *= rowscale
|
| 240 |
+
if HAS_DROPOUT:
|
| 241 |
+
# Compute dropout mask
|
| 242 |
+
# 7 rounds is good enough, and reduces register pressure
|
| 243 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
| 244 |
+
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
|
| 245 |
+
if STORE_DROPOUT_MASK:
|
| 246 |
+
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
|
| 247 |
+
if HAS_X1:
|
| 248 |
+
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 249 |
+
if HAS_ROWSCALE:
|
| 250 |
+
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
|
| 251 |
+
x1 *= rowscale
|
| 252 |
+
if HAS_DROPOUT:
|
| 253 |
+
# Compute dropout mask
|
| 254 |
+
# 7 rounds is good enough, and reduces register pressure
|
| 255 |
+
keep_mask = (
|
| 256 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
| 257 |
+
)
|
| 258 |
+
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
|
| 259 |
+
if STORE_DROPOUT_MASK:
|
| 260 |
+
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
|
| 261 |
+
x += x1
|
| 262 |
+
if HAS_RESIDUAL:
|
| 263 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 264 |
+
x += residual
|
| 265 |
+
if STORE_RESIDUAL_OUT:
|
| 266 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
| 267 |
+
if not IS_RMS_NORM:
|
| 268 |
+
mean = tl.sum(x, axis=0) / N
|
| 269 |
+
tl.store(Mean + row, mean)
|
| 270 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
| 271 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 272 |
+
else:
|
| 273 |
+
xbar = tl.where(cols < N, x, 0.0)
|
| 274 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 275 |
+
rstd = 1 / tl.sqrt(var + eps)
|
| 276 |
+
tl.store(Rstd + row, rstd)
|
| 277 |
+
# Normalize and apply linear transformation
|
| 278 |
+
mask = cols < N
|
| 279 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 280 |
+
if zero_centered_weight:
|
| 281 |
+
w += 1.0
|
| 282 |
+
if HAS_BIAS:
|
| 283 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
| 284 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 285 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
| 286 |
+
# Write output
|
| 287 |
+
tl.store(Y + cols, y, mask=mask)
|
| 288 |
+
if HAS_W1:
|
| 289 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
| 290 |
+
if zero_centered_weight:
|
| 291 |
+
w1 += 1.0
|
| 292 |
+
if HAS_B1:
|
| 293 |
+
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
|
| 294 |
+
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
|
| 295 |
+
tl.store(Y1 + cols, y1, mask=mask)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _layer_norm_fwd(
|
| 299 |
+
x,
|
| 300 |
+
weight,
|
| 301 |
+
bias,
|
| 302 |
+
eps,
|
| 303 |
+
residual=None,
|
| 304 |
+
x1=None,
|
| 305 |
+
weight1=None,
|
| 306 |
+
bias1=None,
|
| 307 |
+
dropout_p=0.0,
|
| 308 |
+
rowscale=None,
|
| 309 |
+
out_dtype=None,
|
| 310 |
+
residual_dtype=None,
|
| 311 |
+
zero_centered_weight=False,
|
| 312 |
+
is_rms_norm=False,
|
| 313 |
+
return_dropout_mask=False,
|
| 314 |
+
out=None,
|
| 315 |
+
residual_out=None
|
| 316 |
+
):
|
| 317 |
+
if residual is not None:
|
| 318 |
+
residual_dtype = residual.dtype
|
| 319 |
+
M, N = x.shape
|
| 320 |
+
assert x.stride(-1) == 1
|
| 321 |
+
if residual is not None:
|
| 322 |
+
assert residual.stride(-1) == 1
|
| 323 |
+
assert residual.shape == (M, N)
|
| 324 |
+
assert weight.shape == (N,)
|
| 325 |
+
assert weight.stride(-1) == 1
|
| 326 |
+
if bias is not None:
|
| 327 |
+
assert bias.stride(-1) == 1
|
| 328 |
+
assert bias.shape == (N,)
|
| 329 |
+
if x1 is not None:
|
| 330 |
+
assert x1.shape == x.shape
|
| 331 |
+
assert rowscale is None
|
| 332 |
+
assert x1.stride(-1) == 1
|
| 333 |
+
if weight1 is not None:
|
| 334 |
+
assert weight1.shape == (N,)
|
| 335 |
+
assert weight1.stride(-1) == 1
|
| 336 |
+
if bias1 is not None:
|
| 337 |
+
assert bias1.shape == (N,)
|
| 338 |
+
assert bias1.stride(-1) == 1
|
| 339 |
+
if rowscale is not None:
|
| 340 |
+
assert rowscale.is_contiguous()
|
| 341 |
+
assert rowscale.shape == (M,)
|
| 342 |
+
# allocate output
|
| 343 |
+
if out is None:
|
| 344 |
+
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
| 345 |
+
else:
|
| 346 |
+
assert out.shape == x.shape
|
| 347 |
+
assert out.stride(-1) == 1
|
| 348 |
+
if weight1 is not None:
|
| 349 |
+
y1 = torch.empty_like(out)
|
| 350 |
+
assert y1.stride(-1) == 1
|
| 351 |
+
else:
|
| 352 |
+
y1 = None
|
| 353 |
+
if (
|
| 354 |
+
residual is not None
|
| 355 |
+
or (residual_dtype is not None and residual_dtype != x.dtype)
|
| 356 |
+
or dropout_p > 0.0
|
| 357 |
+
or rowscale is not None
|
| 358 |
+
or x1 is not None
|
| 359 |
+
):
|
| 360 |
+
if residual_out is None:
|
| 361 |
+
residual_out = torch.empty(
|
| 362 |
+
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
|
| 363 |
+
)
|
| 364 |
+
else:
|
| 365 |
+
assert residual_out.shape == x.shape
|
| 366 |
+
assert residual_out.stride(-1) == 1
|
| 367 |
+
else:
|
| 368 |
+
residual_out = None
|
| 369 |
+
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
| 370 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
| 371 |
+
if dropout_p > 0.0:
|
| 372 |
+
seeds = torch.randint(
|
| 373 |
+
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
| 374 |
+
)
|
| 375 |
+
else:
|
| 376 |
+
seeds = None
|
| 377 |
+
if return_dropout_mask and dropout_p > 0.0:
|
| 378 |
+
dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
|
| 379 |
+
else:
|
| 380 |
+
dropout_mask = None
|
| 381 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 382 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 383 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 384 |
+
if N > BLOCK_N:
|
| 385 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 386 |
+
with torch.cuda.device(x.device.index):
|
| 387 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
| 388 |
+
x,
|
| 389 |
+
out,
|
| 390 |
+
weight,
|
| 391 |
+
bias,
|
| 392 |
+
residual,
|
| 393 |
+
x1,
|
| 394 |
+
weight1,
|
| 395 |
+
bias1,
|
| 396 |
+
y1,
|
| 397 |
+
residual_out,
|
| 398 |
+
rowscale,
|
| 399 |
+
seeds,
|
| 400 |
+
dropout_mask,
|
| 401 |
+
mean,
|
| 402 |
+
rstd,
|
| 403 |
+
x.stride(0),
|
| 404 |
+
out.stride(0),
|
| 405 |
+
residual.stride(0) if residual is not None else 0,
|
| 406 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
| 407 |
+
x1.stride(0) if x1 is not None else 0,
|
| 408 |
+
y1.stride(0) if y1 is not None else 0,
|
| 409 |
+
M,
|
| 410 |
+
N,
|
| 411 |
+
eps,
|
| 412 |
+
dropout_p,
|
| 413 |
+
zero_centered_weight,
|
| 414 |
+
is_rms_norm,
|
| 415 |
+
BLOCK_N,
|
| 416 |
+
residual is not None,
|
| 417 |
+
residual_out is not None,
|
| 418 |
+
bias is not None,
|
| 419 |
+
dropout_p > 0.0,
|
| 420 |
+
dropout_mask is not None,
|
| 421 |
+
rowscale is not None,
|
| 422 |
+
)
|
| 423 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
|
| 424 |
+
if dropout_mask is not None and x1 is not None:
|
| 425 |
+
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
|
| 426 |
+
else:
|
| 427 |
+
dropout_mask1 = None
|
| 428 |
+
return (
|
| 429 |
+
out,
|
| 430 |
+
y1,
|
| 431 |
+
mean,
|
| 432 |
+
rstd,
|
| 433 |
+
residual_out if residual_out is not None else x,
|
| 434 |
+
seeds,
|
| 435 |
+
dropout_mask,
|
| 436 |
+
dropout_mask1,
|
| 437 |
+
)
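A minimal call sketch for the fused forward; it assumes a CUDA device with Triton available, and the fp16 shapes are illustrative:

import torch

x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
w = torch.ones(1024, device="cuda", dtype=torch.float16)
out, y1, mean, rstd, res, seeds, mask, mask1 = _layer_norm_fwd(
    x, w, None, 1e-6, is_rms_norm=True
)
# y1 and mean are None here: no second weight was given, and RMS norm skips the mean buffer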
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
@triton.autotune(
|
| 441 |
+
configs=triton_autotune_configs(),
|
| 442 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
|
| 443 |
+
)
|
| 444 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 445 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
| 446 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
| 447 |
+
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
| 448 |
+
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
| 449 |
+
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
| 450 |
+
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
| 451 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
| 452 |
+
@triton.jit
|
| 453 |
+
def _layer_norm_bwd_kernel(
|
| 454 |
+
X, # pointer to the input
|
| 455 |
+
W, # pointer to the weights
|
| 456 |
+
B, # pointer to the biases
|
| 457 |
+
Y, # pointer to the output to be recomputed
|
| 458 |
+
DY, # pointer to the output gradient
|
| 459 |
+
DX, # pointer to the input gradient
|
| 460 |
+
DW, # pointer to the partial sum of weights gradient
|
| 461 |
+
DB, # pointer to the partial sum of biases gradient
|
| 462 |
+
DRESIDUAL,
|
| 463 |
+
W1,
|
| 464 |
+
DY1,
|
| 465 |
+
DX1,
|
| 466 |
+
DW1,
|
| 467 |
+
DB1,
|
| 468 |
+
DRESIDUAL_IN,
|
| 469 |
+
ROWSCALE,
|
| 470 |
+
SEEDS,
|
| 471 |
+
Mean, # pointer to the mean
|
| 472 |
+
Rstd, # pointer to the 1/std
|
| 473 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 474 |
+
stride_y_row,
|
| 475 |
+
stride_dy_row,
|
| 476 |
+
stride_dx_row,
|
| 477 |
+
stride_dres_row,
|
| 478 |
+
stride_dy1_row,
|
| 479 |
+
stride_dx1_row,
|
| 480 |
+
stride_dres_in_row,
|
| 481 |
+
M, # number of rows in X
|
| 482 |
+
N, # number of columns in X
|
| 483 |
+
eps, # epsilon to avoid division by zero
|
| 484 |
+
dropout_p,
|
| 485 |
+
zero_centered_weight,
|
| 486 |
+
rows_per_program,
|
| 487 |
+
IS_RMS_NORM: tl.constexpr,
|
| 488 |
+
BLOCK_N: tl.constexpr,
|
| 489 |
+
HAS_DRESIDUAL: tl.constexpr,
|
| 490 |
+
STORE_DRESIDUAL: tl.constexpr,
|
| 491 |
+
HAS_BIAS: tl.constexpr,
|
| 492 |
+
HAS_DROPOUT: tl.constexpr,
|
| 493 |
+
HAS_ROWSCALE: tl.constexpr,
|
| 494 |
+
HAS_DY1: tl.constexpr,
|
| 495 |
+
HAS_DX1: tl.constexpr,
|
| 496 |
+
HAS_B1: tl.constexpr,
|
| 497 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
| 498 |
+
):
|
| 499 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
| 500 |
+
row_block_id = tl.program_id(0)
|
| 501 |
+
row_start = row_block_id * rows_per_program
|
| 502 |
+
# Do not early exit if row_start >= M, because we need to write DW and DB
|
| 503 |
+
cols = tl.arange(0, BLOCK_N)
|
| 504 |
+
mask = cols < N
|
| 505 |
+
X += row_start * stride_x_row
|
| 506 |
+
if HAS_DRESIDUAL:
|
| 507 |
+
DRESIDUAL += row_start * stride_dres_row
|
| 508 |
+
if STORE_DRESIDUAL:
|
| 509 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
| 510 |
+
DY += row_start * stride_dy_row
|
| 511 |
+
DX += row_start * stride_dx_row
|
| 512 |
+
if HAS_DY1:
|
| 513 |
+
DY1 += row_start * stride_dy1_row
|
| 514 |
+
if HAS_DX1:
|
| 515 |
+
DX1 += row_start * stride_dx1_row
|
| 516 |
+
if RECOMPUTE_OUTPUT:
|
| 517 |
+
Y += row_start * stride_y_row
|
| 518 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 519 |
+
if zero_centered_weight:
|
| 520 |
+
w += 1.0
|
| 521 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
| 522 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
| 523 |
+
if HAS_DY1:
|
| 524 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
| 525 |
+
if zero_centered_weight:
|
| 526 |
+
w1 += 1.0
|
| 527 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 528 |
+
if HAS_BIAS:
|
| 529 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 530 |
+
if HAS_DY1:
|
| 531 |
+
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 532 |
+
if HAS_B1:
|
| 533 |
+
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 534 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
| 535 |
+
for row in range(row_start, row_end):
|
| 536 |
+
# Load data to SRAM
|
| 537 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
| 538 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
| 539 |
+
if HAS_DY1:
|
| 540 |
+
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
| 541 |
+
if not IS_RMS_NORM:
|
| 542 |
+
mean = tl.load(Mean + row)
|
| 543 |
+
rstd = tl.load(Rstd + row)
|
| 544 |
+
# Compute dx
|
| 545 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 546 |
+
xhat = tl.where(mask, xhat, 0.0)
|
| 547 |
+
if RECOMPUTE_OUTPUT:
|
| 548 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
| 549 |
+
tl.store(Y + cols, y, mask=mask)
|
| 550 |
+
wdy = w * dy
|
| 551 |
+
dw += dy * xhat
|
| 552 |
+
if HAS_BIAS:
|
| 553 |
+
db += dy
|
| 554 |
+
if HAS_DY1:
|
| 555 |
+
wdy += w1 * dy1
|
| 556 |
+
dw1 += dy1 * xhat
|
| 557 |
+
if HAS_B1:
|
| 558 |
+
db1 += dy1
|
| 559 |
+
if not IS_RMS_NORM:
|
| 560 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 561 |
+
c2 = tl.sum(wdy, axis=0) / N
|
| 562 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
| 563 |
+
else:
|
| 564 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 565 |
+
dx = (wdy - xhat * c1) * rstd
|
| 566 |
+
if HAS_DRESIDUAL:
|
| 567 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
| 568 |
+
dx += dres
|
| 569 |
+
# Write dx
|
| 570 |
+
if STORE_DRESIDUAL:
|
| 571 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
| 572 |
+
if HAS_DX1:
|
| 573 |
+
if HAS_DROPOUT:
|
| 574 |
+
keep_mask = (
|
| 575 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
| 576 |
+
)
|
| 577 |
+
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
| 578 |
+
else:
|
| 579 |
+
dx1 = dx
|
| 580 |
+
tl.store(DX1 + cols, dx1, mask=mask)
|
| 581 |
+
if HAS_DROPOUT:
|
| 582 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
| 583 |
+
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
| 584 |
+
if HAS_ROWSCALE:
|
| 585 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
| 586 |
+
dx *= rowscale
|
| 587 |
+
tl.store(DX + cols, dx, mask=mask)
|
| 588 |
+
|
| 589 |
+
X += stride_x_row
|
| 590 |
+
if HAS_DRESIDUAL:
|
| 591 |
+
DRESIDUAL += stride_dres_row
|
| 592 |
+
if STORE_DRESIDUAL:
|
| 593 |
+
DRESIDUAL_IN += stride_dres_in_row
|
| 594 |
+
if RECOMPUTE_OUTPUT:
|
| 595 |
+
Y += stride_y_row
|
| 596 |
+
DY += stride_dy_row
|
| 597 |
+
DX += stride_dx_row
|
| 598 |
+
if HAS_DY1:
|
| 599 |
+
DY1 += stride_dy1_row
|
| 600 |
+
if HAS_DX1:
|
| 601 |
+
DX1 += stride_dx1_row
|
| 602 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
| 603 |
+
if HAS_BIAS:
|
| 604 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
| 605 |
+
if HAS_DY1:
|
| 606 |
+
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
| 607 |
+
if HAS_B1:
|
| 608 |
+
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
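Each program instance above accumulates per-column partial sums for its strip of rows into one row of a (num_programs, N) buffer; the final weight/bias gradients are reduced on the host (see dw = _dw.sum(0) in _layer_norm_bwd below). A toy model of that two-stage reduction, with illustrative sizes:

import torch

num_programs, N = 864, 1024           # mirrors sm_count * 8 computed in _layer_norm_bwd
_dw = torch.zeros(num_programs, N)    # one partial row written per program instance
dw = _dw.sum(0)                       # host-side reduction after the kernel launch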
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def _layer_norm_bwd(
|
| 612 |
+
dy,
|
| 613 |
+
x,
|
| 614 |
+
weight,
|
| 615 |
+
bias,
|
| 616 |
+
eps,
|
| 617 |
+
mean,
|
| 618 |
+
rstd,
|
| 619 |
+
dresidual=None,
|
| 620 |
+
dy1=None,
|
| 621 |
+
weight1=None,
|
| 622 |
+
bias1=None,
|
| 623 |
+
seeds=None,
|
| 624 |
+
dropout_p=0.0,
|
| 625 |
+
rowscale=None,
|
| 626 |
+
has_residual=False,
|
| 627 |
+
has_x1=False,
|
| 628 |
+
zero_centered_weight=False,
|
| 629 |
+
is_rms_norm=False,
|
| 630 |
+
x_dtype=None,
|
| 631 |
+
recompute_output=False,
|
| 632 |
+
):
|
| 633 |
+
M, N = x.shape
|
| 634 |
+
assert x.stride(-1) == 1
|
| 635 |
+
assert dy.stride(-1) == 1
|
| 636 |
+
assert dy.shape == (M, N)
|
| 637 |
+
if dresidual is not None:
|
| 638 |
+
assert dresidual.stride(-1) == 1
|
| 639 |
+
assert dresidual.shape == (M, N)
|
| 640 |
+
assert weight.shape == (N,)
|
| 641 |
+
assert weight.stride(-1) == 1
|
| 642 |
+
if bias is not None:
|
| 643 |
+
assert bias.stride(-1) == 1
|
| 644 |
+
assert bias.shape == (N,)
|
| 645 |
+
if dy1 is not None:
|
| 646 |
+
assert weight1 is not None
|
| 647 |
+
assert dy1.shape == dy.shape
|
| 648 |
+
assert dy1.stride(-1) == 1
|
| 649 |
+
if weight1 is not None:
|
| 650 |
+
assert weight1.shape == (N,)
|
| 651 |
+
assert weight1.stride(-1) == 1
|
| 652 |
+
if bias1 is not None:
|
| 653 |
+
assert bias1.shape == (N,)
|
| 654 |
+
assert bias1.stride(-1) == 1
|
| 655 |
+
if seeds is not None:
|
| 656 |
+
assert seeds.is_contiguous()
|
| 657 |
+
assert seeds.shape == (M if not has_x1 else M * 2,)
|
| 658 |
+
if rowscale is not None:
|
| 659 |
+
assert rowscale.is_contiguous()
|
| 660 |
+
assert rowscale.shape == (M,)
|
| 661 |
+
# allocate output
|
| 662 |
+
dx = (
|
| 663 |
+
torch.empty_like(x)
|
| 664 |
+
if x_dtype is None
|
| 665 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
| 666 |
+
)
|
| 667 |
+
dresidual_in = (
|
| 668 |
+
torch.empty_like(x)
|
| 669 |
+
if has_residual
|
| 670 |
+
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
|
| 671 |
+
else None
|
| 672 |
+
)
|
| 673 |
+
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
|
| 674 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
| 675 |
+
if recompute_output:
|
| 676 |
+
assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
|
| 677 |
+
|
| 678 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 679 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 680 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 681 |
+
if N > BLOCK_N:
|
| 682 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 683 |
+
# Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
|
| 684 |
+
# latency of the gmem reads/writes, but will increase the time of summing up dw / db.
|
| 685 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
|
| 686 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
| 687 |
+
_db = (
|
| 688 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
| 689 |
+
if bias is not None
|
| 690 |
+
else None
|
| 691 |
+
)
|
| 692 |
+
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
|
| 693 |
+
_db1 = torch.empty_like(_db) if bias1 is not None else None
|
| 694 |
+
rows_per_program = math.ceil(M / sm_count)
|
| 695 |
+
grid = (sm_count,)
|
| 696 |
+
with torch.cuda.device(x.device.index):
|
| 697 |
+
_layer_norm_bwd_kernel[grid](
|
| 698 |
+
x,
|
| 699 |
+
weight,
|
| 700 |
+
bias,
|
| 701 |
+
y,
|
| 702 |
+
dy,
|
| 703 |
+
dx,
|
| 704 |
+
_dw,
|
| 705 |
+
_db,
|
| 706 |
+
dresidual,
|
| 707 |
+
weight1,
|
| 708 |
+
dy1,
|
| 709 |
+
dx1,
|
| 710 |
+
_dw1,
|
| 711 |
+
_db1,
|
| 712 |
+
dresidual_in,
|
| 713 |
+
rowscale,
|
| 714 |
+
seeds,
|
| 715 |
+
mean,
|
| 716 |
+
rstd,
|
| 717 |
+
x.stride(0),
|
| 718 |
+
0 if not recompute_output else y.stride(0),
|
| 719 |
+
dy.stride(0),
|
| 720 |
+
dx.stride(0),
|
| 721 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
| 722 |
+
dy1.stride(0) if dy1 is not None else 0,
|
| 723 |
+
dx1.stride(0) if dx1 is not None else 0,
|
| 724 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
| 725 |
+
M,
|
| 726 |
+
N,
|
| 727 |
+
eps,
|
| 728 |
+
dropout_p,
|
| 729 |
+
zero_centered_weight,
|
| 730 |
+
rows_per_program,
|
| 731 |
+
is_rms_norm,
|
| 732 |
+
BLOCK_N,
|
| 733 |
+
dresidual is not None,
|
| 734 |
+
dresidual_in is not None,
|
| 735 |
+
bias is not None,
|
| 736 |
+
dropout_p > 0.0,
|
| 737 |
+
)
|
| 738 |
+
dw = _dw.sum(0).to(weight.dtype)
|
| 739 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
| 740 |
+
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
|
| 741 |
+
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
|
| 742 |
+
# Don't need to compute dresidual_in separately in this case
|
| 743 |
+
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
|
| 744 |
+
dresidual_in = dx
|
| 745 |
+
if has_x1 and dropout_p == 0.0:
|
| 746 |
+
dx1 = dx
|
| 747 |
+
return (
|
| 748 |
+
(dx, dw, db, dresidual_in, dx1, dw1, db1)
|
| 749 |
+
if not recompute_output
|
| 750 |
+
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
class LayerNormFn(torch.autograd.Function):
|
| 755 |
+
@staticmethod
|
| 756 |
+
def forward(
|
| 757 |
+
ctx,
|
| 758 |
+
x,
|
| 759 |
+
weight,
|
| 760 |
+
bias,
|
| 761 |
+
residual=None,
|
| 762 |
+
x1=None,
|
| 763 |
+
weight1=None,
|
| 764 |
+
bias1=None,
|
| 765 |
+
eps=1e-6,
|
| 766 |
+
dropout_p=0.0,
|
| 767 |
+
rowscale=None,
|
| 768 |
+
prenorm=False,
|
| 769 |
+
residual_in_fp32=False,
|
| 770 |
+
zero_centered_weight=False,
|
| 771 |
+
is_rms_norm=False,
|
| 772 |
+
return_dropout_mask=False,
|
| 773 |
+
out=None,
|
| 774 |
+
residual_out=None
|
| 775 |
+
):
|
| 776 |
+
x_shape_og = x.shape
|
| 777 |
+
# Check for zero sequence length
|
| 778 |
+
if x.numel() == 0:
|
| 779 |
+
ctx.zero_seq_length = True
|
| 780 |
+
# Only save minimal required tensors for backward
|
| 781 |
+
# ctx.save_for_backward(weight, bias, weight1, bias1)
|
| 782 |
+
ctx.x_shape_og = x_shape_og
|
| 783 |
+
ctx.weight_shape = weight.shape
|
| 784 |
+
ctx.weight_dtype = weight.dtype
|
| 785 |
+
ctx.weight_device = weight.device
|
| 786 |
+
|
| 787 |
+
ctx.has_bias = bias is not None
|
| 788 |
+
ctx.bias_shape = bias.shape if bias is not None else None
|
| 789 |
+
ctx.bias_dtype = bias.dtype if bias is not None else None
|
| 790 |
+
ctx.bias_device = bias.device if bias is not None else None
|
| 791 |
+
|
| 792 |
+
ctx.has_weight1 = weight1 is not None
|
| 793 |
+
ctx.weight1_shape = weight1.shape if weight1 is not None else None
|
| 794 |
+
ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
|
| 795 |
+
ctx.weight1_device = weight1.device if weight1 is not None else None
|
| 796 |
+
|
| 797 |
+
ctx.has_bias1 = bias1 is not None
|
| 798 |
+
ctx.bias1_shape = bias1.shape if bias1 is not None else None
|
| 799 |
+
ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
|
| 800 |
+
ctx.bias1_device = bias1.device if bias1 is not None else None
|
| 801 |
+
|
| 802 |
+
ctx.has_residual = residual is not None
|
| 803 |
+
ctx.has_x1 = x1 is not None
|
| 804 |
+
ctx.dropout_p = dropout_p
|
| 805 |
+
|
| 806 |
+
# Handle output tensors with correct dtype
|
| 807 |
+
y = x # Preserve input tensor properties
|
| 808 |
+
y1 = torch.empty_like(x) if x1 is not None else None
|
| 809 |
+
|
| 810 |
+
# Only create residual_out if prenorm is True
|
| 811 |
+
residual_out = torch.empty(x.shape,
|
| 812 |
+
dtype=torch.float32 if residual_in_fp32 else x.dtype,
|
| 813 |
+
device=x.device) if prenorm else None
|
| 814 |
+
|
| 815 |
+
# Handle dropout masks
|
| 816 |
+
dropout_mask = None
|
| 817 |
+
dropout_mask1 = None
|
| 818 |
+
if return_dropout_mask:
|
| 819 |
+
dropout_mask = torch.empty_like(x, dtype=torch.uint8)
|
| 820 |
+
if x1 is not None:
|
| 821 |
+
dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
|
| 822 |
+
|
| 823 |
+
# Return based on configuration
|
| 824 |
+
if not return_dropout_mask:
|
| 825 |
+
if weight1 is None:
|
| 826 |
+
return y if not prenorm else (y, residual_out)
|
| 827 |
+
else:
|
| 828 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
| 829 |
+
else:
|
| 830 |
+
if weight1 is None:
|
| 831 |
+
return ((y, dropout_mask, dropout_mask1) if not prenorm
|
| 832 |
+
else (y, residual_out, dropout_mask, dropout_mask1))
|
| 833 |
+
else:
|
| 834 |
+
return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
|
| 835 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1))
|
| 836 |
+
|
| 837 |
+
ctx.zero_seq_length = False
|
| 838 |
+
# reshape input data into 2D tensor
|
| 839 |
+
x = x.reshape(-1, x.shape[-1])
|
| 840 |
+
if x.stride(-1) != 1:
|
| 841 |
+
x = x.contiguous()
|
| 842 |
+
if residual is not None:
|
| 843 |
+
assert residual.shape == x_shape_og
|
| 844 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 845 |
+
if residual.stride(-1) != 1:
|
| 846 |
+
residual = residual.contiguous()
|
| 847 |
+
if x1 is not None:
|
| 848 |
+
assert x1.shape == x_shape_og
|
| 849 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
| 850 |
+
x1 = x1.reshape(-1, x1.shape[-1])
|
| 851 |
+
if x1.stride(-1) != 1:
|
| 852 |
+
x1 = x1.contiguous()
|
| 853 |
+
weight = weight.contiguous()
|
| 854 |
+
if bias is not None:
|
| 855 |
+
bias = bias.contiguous()
|
| 856 |
+
if weight1 is not None:
|
| 857 |
+
weight1 = weight1.contiguous()
|
| 858 |
+
if bias1 is not None:
|
| 859 |
+
bias1 = bias1.contiguous()
|
| 860 |
+
if rowscale is not None:
|
| 861 |
+
rowscale = rowscale.reshape(-1).contiguous()
|
| 862 |
+
residual_dtype = (
|
| 863 |
+
residual.dtype
|
| 864 |
+
if residual is not None
|
| 865 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 866 |
+
)
|
| 867 |
+
if out is not None:
|
| 868 |
+
out = out.reshape(-1, out.shape[-1])
|
| 869 |
+
if residual_out is not None:
|
| 870 |
+
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
|
| 871 |
+
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
|
| 872 |
+
x,
|
| 873 |
+
weight,
|
| 874 |
+
bias,
|
| 875 |
+
eps,
|
| 876 |
+
residual,
|
| 877 |
+
x1,
|
| 878 |
+
weight1,
|
| 879 |
+
bias1,
|
| 880 |
+
dropout_p=dropout_p,
|
| 881 |
+
rowscale=rowscale,
|
| 882 |
+
residual_dtype=residual_dtype,
|
| 883 |
+
zero_centered_weight=zero_centered_weight,
|
| 884 |
+
is_rms_norm=is_rms_norm,
|
| 885 |
+
return_dropout_mask=return_dropout_mask,
|
| 886 |
+
out=out,
|
| 887 |
+
residual_out=residual_out
|
| 888 |
+
)
|
| 889 |
+
ctx.save_for_backward(
|
| 890 |
+
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
|
| 891 |
+
)
|
| 892 |
+
ctx.x_shape_og = x_shape_og
|
| 893 |
+
ctx.eps = eps
|
| 894 |
+
ctx.dropout_p = dropout_p
|
| 895 |
+
ctx.is_rms_norm = is_rms_norm
|
| 896 |
+
ctx.has_residual = residual is not None
|
| 897 |
+
ctx.has_x1 = x1 is not None
|
| 898 |
+
ctx.prenorm = prenorm
|
| 899 |
+
ctx.x_dtype = x.dtype
|
| 900 |
+
ctx.zero_centered_weight = zero_centered_weight
|
| 901 |
+
y = y.reshape(x_shape_og)
|
| 902 |
+
y1 = y1.reshape(x_shape_og) if y1 is not None else None
|
| 903 |
+
residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
|
| 904 |
+
dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
|
| 905 |
+
dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
|
| 906 |
+
if not return_dropout_mask:
|
| 907 |
+
if weight1 is None:
|
| 908 |
+
return y if not prenorm else (y, residual_out)
|
| 909 |
+
else:
|
| 910 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
| 911 |
+
else:
|
| 912 |
+
if weight1 is None:
|
| 913 |
+
return (
|
| 914 |
+
(y, dropout_mask, dropout_mask1)
|
| 915 |
+
if not prenorm
|
| 916 |
+
else (y, residual_out, dropout_mask, dropout_mask1)
|
| 917 |
+
)
|
| 918 |
+
else:
|
| 919 |
+
return (
|
| 920 |
+
(y, y1, dropout_mask, dropout_mask1)
|
| 921 |
+
if not prenorm
|
| 922 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1)
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
@staticmethod
|
| 926 |
+
def backward(ctx, dy, *args):
|
| 927 |
+
if ctx.zero_seq_length:
|
| 928 |
+
return (
|
| 929 |
+
torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
|
| 930 |
+
torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
|
| 931 |
+
torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
|
| 932 |
+
torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
|
| 933 |
+
torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
|
| 934 |
+
torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
|
| 935 |
+
torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
|
| 936 |
+
None,
|
| 937 |
+
None,
|
| 938 |
+
None,
|
| 939 |
+
None,
|
| 940 |
+
None,
|
| 941 |
+
None,
|
| 942 |
+
None,
|
| 943 |
+
None,
|
| 944 |
+
None,
|
| 945 |
+
None,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
|
| 949 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
| 950 |
+
if dy.stride(-1) != 1:
|
| 951 |
+
dy = dy.contiguous()
|
| 952 |
+
assert dy.shape == x.shape
|
| 953 |
+
if weight1 is not None:
|
| 954 |
+
dy1, args = args[0], args[1:]
|
| 955 |
+
dy1 = dy1.reshape(-1, dy1.shape[-1])
|
| 956 |
+
if dy1.stride(-1) != 1:
|
| 957 |
+
dy1 = dy1.contiguous()
|
| 958 |
+
assert dy1.shape == x.shape
|
| 959 |
+
else:
|
| 960 |
+
dy1 = None
|
| 961 |
+
if ctx.prenorm:
|
| 962 |
+
dresidual = args[0]
|
| 963 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 964 |
+
if dresidual.stride(-1) != 1:
|
| 965 |
+
dresidual = dresidual.contiguous()
|
| 966 |
+
assert dresidual.shape == x.shape
|
| 967 |
+
else:
|
| 968 |
+
dresidual = None
|
| 969 |
+
|
| 970 |
+
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
|
| 971 |
+
dy,
|
| 972 |
+
x,
|
| 973 |
+
weight,
|
| 974 |
+
bias,
|
| 975 |
+
ctx.eps,
|
| 976 |
+
mean,
|
| 977 |
+
rstd,
|
| 978 |
+
dresidual,
|
| 979 |
+
dy1,
|
| 980 |
+
weight1,
|
| 981 |
+
bias1,
|
| 982 |
+
seeds,
|
| 983 |
+
ctx.dropout_p,
|
| 984 |
+
rowscale,
|
| 985 |
+
ctx.has_residual,
|
| 986 |
+
ctx.has_x1,
|
| 987 |
+
ctx.zero_centered_weight,
|
| 988 |
+
ctx.is_rms_norm,
|
| 989 |
+
x_dtype=ctx.x_dtype,
|
| 990 |
+
)
|
| 991 |
+
return (
|
| 992 |
+
dx.reshape(ctx.x_shape_og),
|
| 993 |
+
dw,
|
| 994 |
+
db,
|
| 995 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 996 |
+
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
|
| 997 |
+
dw1,
|
| 998 |
+
db1,
|
| 999 |
+
None,
|
| 1000 |
+
None,
|
| 1001 |
+
None,
|
| 1002 |
+
None,
|
| 1003 |
+
None,
|
| 1004 |
+
None,
|
| 1005 |
+
None,
|
| 1006 |
+
None,
|
| 1007 |
+
None,
|
| 1008 |
+
None,
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
def layer_norm_fn(
|
| 1013 |
+
x,
|
| 1014 |
+
weight,
|
| 1015 |
+
bias,
|
| 1016 |
+
residual=None,
|
| 1017 |
+
x1=None,
|
| 1018 |
+
weight1=None,
|
| 1019 |
+
bias1=None,
|
| 1020 |
+
eps=1e-6,
|
| 1021 |
+
dropout_p=0.0,
|
| 1022 |
+
rowscale=None,
|
| 1023 |
+
prenorm=False,
|
| 1024 |
+
residual_in_fp32=False,
|
| 1025 |
+
zero_centered_weight=False,
|
| 1026 |
+
is_rms_norm=False,
|
| 1027 |
+
return_dropout_mask=False,
|
| 1028 |
+
out=None,
|
| 1029 |
+
residual_out=None
|
| 1030 |
+
):
|
| 1031 |
+
return LayerNormFn.apply(
|
| 1032 |
+
x,
|
| 1033 |
+
weight,
|
| 1034 |
+
bias,
|
| 1035 |
+
residual,
|
| 1036 |
+
x1,
|
| 1037 |
+
weight1,
|
| 1038 |
+
bias1,
|
| 1039 |
+
eps,
|
| 1040 |
+
dropout_p,
|
| 1041 |
+
rowscale,
|
| 1042 |
+
prenorm,
|
| 1043 |
+
residual_in_fp32,
|
| 1044 |
+
zero_centered_weight,
|
| 1045 |
+
is_rms_norm,
|
| 1046 |
+
return_dropout_mask,
|
| 1047 |
+
out,
|
| 1048 |
+
residual_out
|
| 1049 |
+
)
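A parity sketch checking the Triton path against the eager reference defined earlier; it assumes a CUDA device with Triton, and the fp16 shapes and tolerances are illustrative:

import torch

x = torch.randn(8, 512, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(512, device="cuda", dtype=torch.float16)
y = layer_norm_fn(x, w, None)
torch.testing.assert_close(y, layer_norm_ref(x, w, None), rtol=1e-2, atol=1e-2)
y.sum().backward()                    # gradients flow through LayerNormFn.backward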
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def rms_norm_fn(
|
| 1053 |
+
x,
|
| 1054 |
+
weight,
|
| 1055 |
+
bias,
|
| 1056 |
+
residual=None,
|
| 1057 |
+
x1=None,
|
| 1058 |
+
weight1=None,
|
| 1059 |
+
bias1=None,
|
| 1060 |
+
eps=1e-6,
|
| 1061 |
+
dropout_p=0.0,
|
| 1062 |
+
rowscale=None,
|
| 1063 |
+
prenorm=False,
|
| 1064 |
+
residual_in_fp32=False,
|
| 1065 |
+
zero_centered_weight=False,
|
| 1066 |
+
return_dropout_mask=False,
|
| 1067 |
+
out=None,
|
| 1068 |
+
residual_out=None
|
| 1069 |
+
):
|
| 1070 |
+
return LayerNormFn.apply(
|
| 1071 |
+
x,
|
| 1072 |
+
weight,
|
| 1073 |
+
bias,
|
| 1074 |
+
residual,
|
| 1075 |
+
x1,
|
| 1076 |
+
weight1,
|
| 1077 |
+
bias1,
|
| 1078 |
+
eps,
|
| 1079 |
+
dropout_p,
|
| 1080 |
+
rowscale,
|
| 1081 |
+
prenorm,
|
| 1082 |
+
residual_in_fp32,
|
| 1083 |
+
zero_centered_weight,
|
| 1084 |
+
True,
|
| 1085 |
+
return_dropout_mask,
|
| 1086 |
+
out,
|
| 1087 |
+
residual_out
|
| 1088 |
+
)
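The only difference from layer_norm_fn is the hard-coded True in the is_rms_norm slot. A usage sketch with illustrative bf16 shapes (assumes CUDA + Triton):

import torch

x = torch.randn(2, 64, 2048, device="cuda", dtype=torch.bfloat16)
w = torch.ones(2048, device="cuda", dtype=torch.bfloat16)
y, residual = rms_norm_fn(x, w, None, prenorm=True, residual_in_fp32=True)
# residual is the fp32 running stream to pass into the next block's rms_norm_fn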
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
class RMSNorm(torch.nn.Module):
|
| 1092 |
+
|
| 1093 |
+
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
|
| 1094 |
+
device=None, dtype=None):
|
| 1095 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 1096 |
+
super().__init__()
|
| 1097 |
+
self.eps = eps
|
| 1098 |
+
if dropout_p > 0.0:
|
| 1099 |
+
self.drop = torch.nn.Dropout(dropout_p)
|
| 1100 |
+
else:
|
| 1101 |
+
self.drop = None
|
| 1102 |
+
self.zero_centered_weight = zero_centered_weight
|
| 1103 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
| 1104 |
+
self.register_parameter("bias", None)
|
| 1105 |
+
self.reset_parameters()
|
| 1106 |
+
|
| 1107 |
+
def reset_parameters(self):
|
| 1108 |
+
if not self.zero_centered_weight:
|
| 1109 |
+
torch.nn.init.ones_(self.weight)
|
| 1110 |
+
else:
|
| 1111 |
+
torch.nn.init.zeros_(self.weight)
|
| 1112 |
+
|
| 1113 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
| 1114 |
+
return rms_norm_fn(
|
| 1115 |
+
x,
|
| 1116 |
+
self.weight,
|
| 1117 |
+
self.bias,
|
| 1118 |
+
residual=residual,
|
| 1119 |
+
eps=self.eps,
|
| 1120 |
+
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
|
| 1121 |
+
prenorm=prenorm,
|
| 1122 |
+
residual_in_fp32=residual_in_fp32,
|
| 1123 |
+
zero_centered_weight=self.zero_centered_weight,
|
| 1124 |
+
)
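A module-level usage sketch; hidden size, dtype, and the fused residual-add pattern are illustrative (assumes CUDA + Triton):

import torch

norm = RMSNorm(2048, eps=1e-5, device="cuda", dtype=torch.bfloat16)
h = torch.randn(4, 128, 2048, device="cuda", dtype=torch.bfloat16)
out = norm(h)                                     # plain RMSNorm
out, res = norm(h, residual=h, prenorm=True)      # fused add + norm, returns the new residual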
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
| 1128 |
+
@staticmethod
|
| 1129 |
+
@custom_fwd
|
| 1130 |
+
def forward(
|
| 1131 |
+
ctx,
|
| 1132 |
+
x,
|
| 1133 |
+
norm_weight,
|
| 1134 |
+
norm_bias,
|
| 1135 |
+
linear_weight,
|
| 1136 |
+
linear_bias,
|
| 1137 |
+
residual=None,
|
| 1138 |
+
eps=1e-6,
|
| 1139 |
+
prenorm=False,
|
| 1140 |
+
residual_in_fp32=False,
|
| 1141 |
+
is_rms_norm=False,
|
| 1142 |
+
):
|
| 1143 |
+
x_shape_og = x.shape
|
| 1144 |
+
# reshape input data into 2D tensor
|
| 1145 |
+
x = x.reshape(-1, x.shape[-1])
|
| 1146 |
+
if x.stride(-1) != 1:
|
| 1147 |
+
x = x.contiguous()
|
| 1148 |
+
if residual is not None:
|
| 1149 |
+
assert residual.shape == x_shape_og
|
| 1150 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 1151 |
+
if residual.stride(-1) != 1:
|
| 1152 |
+
residual = residual.contiguous()
|
| 1153 |
+
norm_weight = norm_weight.contiguous()
|
| 1154 |
+
if norm_bias is not None:
|
| 1155 |
+
norm_bias = norm_bias.contiguous()
|
| 1156 |
+
residual_dtype = (
|
| 1157 |
+
residual.dtype
|
| 1158 |
+
if residual is not None
|
| 1159 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 1160 |
+
)
|
| 1161 |
+
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
| 1162 |
+
x,
|
| 1163 |
+
norm_weight,
|
| 1164 |
+
norm_bias,
|
| 1165 |
+
eps,
|
| 1166 |
+
residual,
|
| 1167 |
+
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
|
| 1168 |
+
residual_dtype=residual_dtype,
|
| 1169 |
+
is_rms_norm=is_rms_norm,
|
| 1170 |
+
)
|
| 1171 |
+
y = y.reshape(x_shape_og)
|
| 1172 |
+
dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
|
| 1173 |
+
linear_weight = linear_weight.to(dtype)
|
| 1174 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 1175 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 1176 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
| 1177 |
+
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
| 1178 |
+
ctx.x_shape_og = x_shape_og
|
| 1179 |
+
ctx.eps = eps
|
| 1180 |
+
ctx.is_rms_norm = is_rms_norm
|
| 1181 |
+
ctx.has_residual = residual is not None
|
| 1182 |
+
ctx.prenorm = prenorm
|
| 1183 |
+
ctx.x_dtype = x.dtype
|
| 1184 |
+
ctx.linear_bias_is_none = linear_bias is None
|
| 1185 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 1186 |
+
|
| 1187 |
+
@staticmethod
|
| 1188 |
+
@custom_bwd
|
| 1189 |
+
def backward(ctx, dout, *args):
|
| 1190 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 1191 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
| 1192 |
+
dy = F.linear(dout, linear_weight.t())
|
| 1193 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 1194 |
+
if dy.stride(-1) != 1:
|
| 1195 |
+
dy = dy.contiguous()
|
| 1196 |
+
assert dy.shape == x.shape
|
| 1197 |
+
if ctx.prenorm:
|
| 1198 |
+
dresidual = args[0]
|
| 1199 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 1200 |
+
if dresidual.stride(-1) != 1:
|
| 1201 |
+
dresidual = dresidual.contiguous()
|
| 1202 |
+
assert dresidual.shape == x.shape
|
| 1203 |
+
else:
|
| 1204 |
+
dresidual = None
|
| 1205 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
| 1206 |
+
dy,
|
| 1207 |
+
x,
|
| 1208 |
+
norm_weight,
|
| 1209 |
+
norm_bias,
|
| 1210 |
+
ctx.eps,
|
| 1211 |
+
mean,
|
| 1212 |
+
rstd,
|
| 1213 |
+
dresidual=dresidual,
|
| 1214 |
+
has_residual=ctx.has_residual,
|
| 1215 |
+
is_rms_norm=ctx.is_rms_norm,
|
| 1216 |
+
x_dtype=ctx.x_dtype,
|
| 1217 |
+
recompute_output=True,
|
| 1218 |
+
)
|
| 1219 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 1220 |
+
return (
|
| 1221 |
+
dx.reshape(ctx.x_shape_og),
|
| 1222 |
+
dnorm_weight,
|
| 1223 |
+
dnorm_bias,
|
| 1224 |
+
dlinear_weight,
|
| 1225 |
+
dlinear_bias,
|
| 1226 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 1227 |
+
None,
|
| 1228 |
+
None,
|
| 1229 |
+
None,
|
| 1230 |
+
None,
|
| 1231 |
+
)
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
def layer_norm_linear_fn(
|
| 1235 |
+
x,
|
| 1236 |
+
norm_weight,
|
| 1237 |
+
norm_bias,
|
| 1238 |
+
linear_weight,
|
| 1239 |
+
linear_bias,
|
| 1240 |
+
residual=None,
|
| 1241 |
+
eps=1e-6,
|
| 1242 |
+
prenorm=False,
|
| 1243 |
+
residual_in_fp32=False,
|
| 1244 |
+
is_rms_norm=False,
|
| 1245 |
+
):
|
| 1246 |
+
return LayerNormLinearFn.apply(
|
| 1247 |
+
x,
|
| 1248 |
+
norm_weight,
|
| 1249 |
+
norm_bias,
|
| 1250 |
+
linear_weight,
|
| 1251 |
+
linear_bias,
|
| 1252 |
+
residual,
|
| 1253 |
+
eps,
|
| 1254 |
+
prenorm,
|
| 1255 |
+
residual_in_fp32,
|
| 1256 |
+
is_rms_norm,
|
| 1257 |
+
)
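A usage sketch for the fused norm + linear; the weights and shapes are illustrative (assumes CUDA + Triton). The design choice here is memory, not speed: the normalized activations are recomputed in the backward pass instead of being stored for the linear layer's gradient:

import torch

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
norm_w = torch.ones(1024, device="cuda", dtype=torch.float16)
lin_w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
out = layer_norm_linear_fn(x, norm_w, None, lin_w, None, is_rms_norm=True)  # shape (8, 4096)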
|
omnigen2/optim/__init__.py
ADDED
|
File without changes
|
omnigen2/optim/scheduler/__init__.py
ADDED
|
File without changes
|
omnigen2/optim/scheduler/cosine_lr.py
ADDED
|
@@ -0,0 +1,118 @@
| 1 |
+
""" Cosine Scheduler
|
| 2 |
+
|
| 3 |
+
Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
|
| 4 |
+
|
| 5 |
+
Hacked together by / Copyright 2021 Ross Wightman
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
from typing import List
|
| 11 |
+
|
| 12 |
+
from .scheduler import Scheduler
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CosineLRScheduler(Scheduler):
|
| 19 |
+
"""
|
| 20 |
+
Cosine decay with restarts.
|
| 21 |
+
This is described in the paper https://arxiv.org/abs/1608.03983.
|
| 22 |
+
|
| 23 |
+
Inspiration from
|
| 24 |
+
https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
|
| 25 |
+
|
| 26 |
+
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
optimizer: torch.optim.Optimizer,
|
| 32 |
+
t_initial: int,
|
| 33 |
+
lr_min: float = 0.,
|
| 34 |
+
cycle_mul: float = 1.,
|
| 35 |
+
cycle_decay: float = 1.,
|
| 36 |
+
cycle_limit: int = 1,
|
| 37 |
+
warmup_t=0,
|
| 38 |
+
warmup_lr_init=0,
|
| 39 |
+
warmup_prefix=False,
|
| 40 |
+
t_in_epochs=True,
|
| 41 |
+
noise_range_t=None,
|
| 42 |
+
noise_pct=0.67,
|
| 43 |
+
noise_std=1.0,
|
| 44 |
+
noise_seed=42,
|
| 45 |
+
k_decay=1.0,
|
| 46 |
+
initialize=True,
|
| 47 |
+
) -> None:
|
| 48 |
+
super().__init__(
|
| 49 |
+
optimizer,
|
| 50 |
+
param_group_field="lr",
|
| 51 |
+
t_in_epochs=t_in_epochs,
|
| 52 |
+
noise_range_t=noise_range_t,
|
| 53 |
+
noise_pct=noise_pct,
|
| 54 |
+
noise_std=noise_std,
|
| 55 |
+
noise_seed=noise_seed,
|
| 56 |
+
initialize=initialize,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
assert t_initial > 0
|
| 60 |
+
assert lr_min >= 0
|
| 61 |
+
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
|
| 62 |
+
_logger.warning(
|
| 63 |
+
"Cosine annealing scheduler will have no effect on the learning "
|
| 64 |
+
"rate since t_initial = t_mul = eta_mul = 1.")
|
| 65 |
+
self.t_initial = t_initial
|
| 66 |
+
self.lr_min = lr_min
|
| 67 |
+
self.cycle_mul = cycle_mul
|
| 68 |
+
self.cycle_decay = cycle_decay
|
| 69 |
+
self.cycle_limit = cycle_limit
|
| 70 |
+
self.warmup_t = warmup_t
|
| 71 |
+
self.warmup_lr_init = warmup_lr_init
|
| 72 |
+
self.warmup_prefix = warmup_prefix
|
| 73 |
+
self.k_decay = k_decay
|
| 74 |
+
if self.warmup_t:
|
| 75 |
+
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
| 76 |
+
super().update_groups(self.warmup_lr_init)
|
| 77 |
+
else:
|
| 78 |
+
self.warmup_steps = [1 for _ in self.base_values]
|
| 79 |
+
|
| 80 |
+
self._step_count = 0  # unused; kept for compatibility with LR-scheduler interfaces
|
| 81 |
+
|
| 82 |
+
def _get_lr(self, t: int) -> List[float]:
|
| 83 |
+
|
| 84 |
+
if t < self.warmup_t:
|
| 85 |
+
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
| 86 |
+
else:
|
| 87 |
+
if self.warmup_prefix:
|
| 88 |
+
t = t - self.warmup_t
|
| 89 |
+
|
| 90 |
+
if self.cycle_mul != 1:
|
| 91 |
+
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
|
| 92 |
+
t_i = self.cycle_mul ** i * self.t_initial
|
| 93 |
+
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
|
| 94 |
+
else:
|
| 95 |
+
i = t // self.t_initial
|
| 96 |
+
t_i = self.t_initial
|
| 97 |
+
t_curr = t - (self.t_initial * i)
|
| 98 |
+
|
| 99 |
+
gamma = self.cycle_decay ** i
|
| 100 |
+
lr_max_values = [v * gamma for v in self.base_values]
|
| 101 |
+
k = self.k_decay
|
| 102 |
+
|
| 103 |
+
if i < self.cycle_limit:
|
| 104 |
+
lrs = [
|
| 105 |
+
self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
|
| 106 |
+
for lr_max in lr_max_values
|
| 107 |
+
]
|
| 108 |
+
else:
|
| 109 |
+
lrs = [self.lr_min for _ in self.base_values]
|
| 110 |
+
|
| 111 |
+
return lrs
|
| 112 |
+
|
| 113 |
+
def get_cycle_length(self, cycles=0):
|
| 114 |
+
cycles = max(1, cycles or self.cycle_limit)
|
| 115 |
+
if self.cycle_mul == 1.0:
|
| 116 |
+
return self.t_initial * cycles
|
| 117 |
+
else:
|
| 118 |
+
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
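A usage sketch; the optimizer, horizon, and warmup settings are illustrative. Note that the LR is a pure function of the count passed to step, so resuming a run only needs the step/epoch counter:

import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = CosineLRScheduler(opt, t_initial=1000, lr_min=1e-6,
                          warmup_t=100, warmup_lr_init=1e-7)
for t in range(1000):
    sched.step(t)                     # linear warmup for 100 steps, then cosine decay to lr_min
    lr = sched.get_last_lr()[0]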
|
omnigen2/optim/scheduler/scheduler.py
ADDED
|
@@ -0,0 +1,131 @@
| 1 |
+
import abc
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Scheduler(ABC):
|
| 9 |
+
""" Parameter Scheduler Base Class
|
| 10 |
+
A scheduler base class that can be used to schedule any optimizer parameter groups.
|
| 11 |
+
|
| 12 |
+
Unlike the builtin PyTorch schedulers, this is intended to be consistently called
|
| 13 |
+
* At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
|
| 14 |
+
* At the END of each optimizer update, after incrementing the update count, to calculate next update's value
|
| 15 |
+
|
| 16 |
+
The schedulers built on this should try to remain as stateless as possible (for simplicity).
|
| 17 |
+
|
| 18 |
+
This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
|
| 19 |
+
and -1 values for special behaviour. All epoch and update counts must be tracked in the training
|
| 20 |
+
code and explicitly passed in to the schedulers on the corresponding step or step_update call.
|
| 21 |
+
|
| 22 |
+
Based on ideas from:
|
| 23 |
+
* https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
|
| 24 |
+
* https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
optimizer: torch.optim.Optimizer,
|
| 30 |
+
param_group_field: str,
|
| 31 |
+
t_in_epochs: bool = True,
|
| 32 |
+
noise_range_t=None,
|
| 33 |
+
noise_type='normal',
|
| 34 |
+
noise_pct=0.67,
|
| 35 |
+
noise_std=1.0,
|
| 36 |
+
noise_seed=None,
|
| 37 |
+
initialize: bool = True,
|
| 38 |
+
) -> None:
|
| 39 |
+
self.optimizer = optimizer
|
| 40 |
+
self.param_group_field = param_group_field
|
| 41 |
+
self._initial_param_group_field = f"initial_{param_group_field}"
|
| 42 |
+
if initialize:
|
| 43 |
+
for i, group in enumerate(self.optimizer.param_groups):
|
| 44 |
+
if param_group_field not in group:
|
| 45 |
+
raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
|
| 46 |
+
group.setdefault(self._initial_param_group_field, group[param_group_field])
|
| 47 |
+
else:
|
| 48 |
+
for i, group in enumerate(self.optimizer.param_groups):
|
| 49 |
+
if self._initial_param_group_field not in group:
|
| 50 |
+
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
|
| 51 |
+
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
|
| 52 |
+
self.metric = None  # stored so metric-aware schedulers can read the latest value
|
| 53 |
+
self.t_in_epochs = t_in_epochs
|
| 54 |
+
self.noise_range_t = noise_range_t
|
| 55 |
+
self.noise_pct = noise_pct
|
| 56 |
+
self.noise_type = noise_type
|
| 57 |
+
self.noise_std = noise_std
|
| 58 |
+
self.noise_seed = noise_seed if noise_seed is not None else 42
|
| 59 |
+
self.update_groups(self.base_values)
|
| 60 |
+
|
| 61 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 62 |
+
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
| 63 |
+
|
| 64 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 65 |
+
self.__dict__.update(state_dict)
|
| 66 |
+
|
| 67 |
+
def get_last_lr(self):
|
| 68 |
+
""" Return last computed learning rate by current scheduler.
|
| 69 |
+
"""
|
| 70 |
+
return self._last_lr
|
| 71 |
+
|
| 72 |
+
@abc.abstractmethod
|
| 73 |
+
def _get_lr(self, t: int) -> List[float]:
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
|
| 77 |
+
return self._get_lr(t)
|
| 78 |
+
|
| 79 |
+
def step(self, epoch: int, metric: float = None) -> None:
|
| 80 |
+
self.metric = metric
|
| 81 |
+
values = self._get_values(epoch, on_epoch=True)
|
| 82 |
+
if values is not None:
|
| 83 |
+
values = self._add_noise(values, epoch)
|
| 84 |
+
self.update_groups(values)
|
| 85 |
+
|
| 86 |
+
# def step_update(self, num_updates: int, metric: float = None):
|
| 87 |
+
# self.metric = metric
|
| 88 |
+
# values = self._get_values(num_updates, on_epoch=False)
|
| 89 |
+
# if values is not None:
|
| 90 |
+
# values = self._add_noise(values, num_updates)
|
| 91 |
+
# self.update_groups(values)
|
| 92 |
+
|
| 93 |
+
def update_groups(self, values):
|
| 94 |
+
if not isinstance(values, (list, tuple)):
|
| 95 |
+
values = [values] * len(self.optimizer.param_groups)
|
| 96 |
+
for param_group, value in zip(self.optimizer.param_groups, values):
|
| 97 |
+
if 'lr_scale' in param_group:
|
| 98 |
+
param_group[self.param_group_field] = value * param_group['lr_scale']
|
| 99 |
+
else:
|
| 100 |
+
param_group[self.param_group_field] = value
|
| 101 |
+
|
| 102 |
+
self._last_lr = [group[self.param_group_field] for group in self.optimizer.param_groups]
|
| 103 |
+
|
| 104 |
+
def _add_noise(self, lrs, t):
|
| 105 |
+
if self._is_apply_noise(t):
|
| 106 |
+
noise = self._calculate_noise(t)
|
| 107 |
+
lrs = [v + v * noise for v in lrs]
|
| 108 |
+
return lrs
|
| 109 |
+
|
| 110 |
+
def _is_apply_noise(self, t) -> bool:
|
| 111 |
+
"""Return True if scheduler in noise range."""
|
| 112 |
+
apply_noise = False
|
| 113 |
+
if self.noise_range_t is not None:
|
| 114 |
+
if isinstance(self.noise_range_t, (list, tuple)):
|
| 115 |
+
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
|
| 116 |
+
else:
|
| 117 |
+
apply_noise = t >= self.noise_range_t
|
| 118 |
+
return apply_noise
|
| 119 |
+
|
| 120 |
+
def _calculate_noise(self, t) -> float:
|
| 121 |
+
g = torch.Generator()
|
| 122 |
+
g.manual_seed(self.noise_seed + t)
|
| 123 |
+
if self.noise_type == 'normal':
|
| 124 |
+
while True:
|
| 125 |
+
# resample if noise out of percent limit, brute force but shouldn't spin much
|
| 126 |
+
noise = torch.randn(1, generator=g).item()
|
| 127 |
+
if abs(noise) < self.noise_pct:
|
| 128 |
+
return noise
|
| 129 |
+
else:
|
| 130 |
+
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
| 131 |
+
return noise
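The contract for subclasses is a single pure method: _get_lr(t) returns one value per param group. A hypothetical minimal subclass, illustrative only:

from typing import List

class ConstantLRScheduler(Scheduler):
    """Illustrative only: hold every param group at its base value."""
    def _get_lr(self, t: int) -> List[float]:
        return list(self.base_values)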
|
omnigen2/optim/scheduler/step_lr.py
ADDED
|
@@ -0,0 +1,63 @@
""" Step Scheduler

Basic step LR schedule with warmup, noise.

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch
from typing import List


from .scheduler import Scheduler


class StepLRScheduler(Scheduler):
    """Basic step LR schedule with warmup and optional LR noise."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        decay_t: float,
        decay_rate: float = 1.,
        warmup_t=0,
        warmup_lr_init=0,
        warmup_prefix=True,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=42,
        initialize=True,
    ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            t_in_epochs=t_in_epochs,
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t: int) -> List[float]:
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                t = t - self.warmup_t
            lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
        return lrs
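Usage sketch for the scheduler above (illustrative values; the model and optimizer are placeholders): halve the learning rate every 30 epochs after a 5-epoch linear warmup. `_get_lr` is the per-epoch hook the base `Scheduler` drives.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = StepLRScheduler(
    optimizer,
    decay_t=30,        # decay period, in epochs
    decay_rate=0.5,    # multiply LR by 0.5 at each decay
    warmup_t=5,        # linear warmup over the first 5 epochs
    warmup_lr_init=1e-5,
)
for epoch in (0, 5, 35, 65):
    print(epoch, scheduler._get_lr(epoch))
# epoch 35 -> t = 30 after the warmup prefix, so lr = 1e-3 * 0.5 ** (30 // 30) = 5e-4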
omnigen2/pipelines/__init__.py
ADDED
File without changes
omnigen2/pipelines/image_processor.py
ADDED
@@ -0,0 +1,266 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
from diffusers.configuration_utils import register_to_config

class OmniGen2ImageProcessor(VaeImageProcessor):
    """
    Image processor for OmniGen2 image resize and crop.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `16`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_rgb (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to RGB format.
        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to grayscale format.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 16,
        resample: str = "lanczos",
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

        self.max_pixels = max_pixels
        self.max_side_length = max_side_length

    def get_new_height_width(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
    ) -> Tuple[int, int]:
        r"""
        Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.

        Args:
            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
                The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
                should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
                tensor, it should have shape `[batch, channels, height, width]`.
            height (`Optional[int]`, *optional*, defaults to `None`):
                The height of the preprocessed image. If `None`, the height of the `image` input will be used.
            width (`Optional[int]`, *optional*, defaults to `None`):
                The width of the preprocessed image. If `None`, the width of the `image` input will be used.

        Returns:
            `Tuple[int, int]`:
                A tuple containing the height and width, both resized to the nearest integer multiple of
                `vae_scale_factor`.
        """

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[2]
            else:
                height = image.shape[1]

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]
            else:
                width = image.shape[2]

        if max_side_length is None:
            max_side_length = self.max_side_length

        if max_pixels is None:
            max_pixels = self.max_pixels

        ratio = 1.0
        if max_side_length is not None:
            if height > width:
                max_side_length_ratio = max_side_length / height
            else:
                max_side_length_ratio = max_side_length / width

            cur_pixels = height * width
            max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5
            ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0)  # do not upscale input image

        new_height = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
        new_width = int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
        return new_height, new_width

    def preprocess(
        self,
        image: PipelineImageInput,
        height: Optional[int] = None,
        width: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
        resize_mode: str = "default",  # "default", "fill", "crop"
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input.

        Args:
            image (`PipelineImageInput`):
                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; also accepts lists of
                supported formats.
            height (`int`, *optional*):
                The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
                default height.
            width (`int`, *optional*):
                The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
                default width.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
                the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
                center the image within the dimensions, filling empty areas with data from the image. If `crop`, will
                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
                center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop`
                are only supported for PIL image input.
            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                The crop coordinates for each image in the batch. If `None`, will not crop the image.

        Returns:
            `torch.Tensor`:
                The preprocessed image.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor could have 2 possible shapes:
                #   1. batch x height x width: we should insert the channel dimension at position 1
                #   2. channel x height x width: we should insert batch dimension at position 0,
                #      however, since both channel and batch dimension has same size 1, it is same to insert at position 1
                # for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                #   1. batch x height x width: insert channel dimension on last position
                #   2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
                FutureWarning,
            )
            image = np.concatenate(image, axis=0)
        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
                FutureWarning,
            )
            image = torch.cat(image, axis=0)

        if not is_valid_image_imagelist(image):
            raise ValueError(
                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )
        if not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            if crops_coords is not None:
                image = [i.crop(crops_coords) for i in image]
            if self.config.do_resize:
                height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)

            image = self.numpy_to_pt(image)

            height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)

            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == self.config.vae_latent_channels:
                return image

            height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False
        if do_normalize:
            image = self.normalize(image)

        if self.config.do_binarize:
            image = self.binarize(image)

        return image
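To make the resizing rule concrete, a worked sketch with assumed numbers: `get_new_height_width` caps the scale by both the longest-side limit and the pixel budget, never upscales, and snaps each side down to a multiple of `vae_scale_factor`.

height, width = 1536, 1024                   # input image size (assumed)
max_side_length, max_pixels = 1024, 1048576  # limits (assumed)
vae_scale_factor = 16

max_side_length_ratio = max_side_length / max(height, width)  # ~0.667
max_pixels_ratio = (max_pixels / (height * width)) ** 0.5     # ~0.816
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0)     # ~0.667, never upscale

new_height = int(height * ratio) // vae_scale_factor * vae_scale_factor  # 1024
new_width = int(width * ratio) // vae_scale_factor * vae_scale_factor    # 672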
omnigen2/pipelines/lora_pipeline.py
ADDED
@@ -0,0 +1,388 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
)
from diffusers.loaders.lora_base import (  # noqa
    LoraBaseMixin,
    _fetch_state_dict,
)
from diffusers.loaders.lora_conversion_utils import (
    _convert_non_diffusers_lumina2_lora_to_diffusers,
)


_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
    if (
        is_peft_available()
        and is_peft_version(">=", "0.13.1")
        and is_transformers_available()
        and is_transformers_version(">", "4.45.2")
    ):
        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


logger = logging.get_logger(__name__)

TRANSFORMER_NAME = "transformer"

class OmniGen2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`OmniGen2Transformer2DModel`]. Specific to [`OmniGen2Pipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # conversion.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs are
        forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`Lumina2Transformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap : (`bool`, *optional*)
                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
                in-place. This means that, instead of loading an additional adapter, this will take the existing
                adapter weights and replace them with the weights of the new adapter. This can be faster and more
                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
                torch.compile, loading the new adapter does not require recompilation of the model. When using
                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.

                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
                to call an additional method before loading the adapter:

                ```py
                pipeline = ...  # load diffusers pipeline
                max_rank = ...  # the highest rank among all LoRAs that you want to load
                # call *before* compiling and loading the LoRA adapter
                pipeline.enable_lora_hotswap(target_rank=max_rank)
                pipeline.load_lora_weights(file_name)
                # optionally compile the model now
                ```

                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training when
                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
                main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
        """
        super().unfuse_lora(components=components, **kwargs)
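Typical use of the mixin from an OmniGen2 pipeline instance (a sketch; the weight file and adapter name are placeholders):

# `pipe` is an OmniGen2Pipeline; the path and adapter name below are placeholders.
pipe.load_lora_weights("path/to/omnigen2_lora.safetensors", adapter_name="my_adapter")
pipe.fuse_lora(lora_scale=0.8)  # bake the LoRA into the transformer weights
# ... run inference ...
pipe.unfuse_lora()              # restore the original transformer weights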
omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
ADDED
@@ -0,0 +1,774 @@
| 1 |
+
"""
|
| 2 |
+
OmniGen2 Diffusion Pipeline
|
| 3 |
+
|
| 4 |
+
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
| 5 |
+
|
| 6 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
you may not use this file except in compliance with the License.
|
| 8 |
+
You may obtain a copy of the License at
|
| 9 |
+
|
| 10 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
|
| 12 |
+
Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
See the License for the specific language governing permissions and
|
| 16 |
+
limitations under the License.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import inspect
|
| 20 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
|
| 24 |
+
from PIL import Image
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
from transformers import Qwen2_5_VLForConditionalGeneration
|
| 30 |
+
|
| 31 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 32 |
+
from ...models.transformers import OmniGen2Transformer2DModel
|
| 33 |
+
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
|
| 34 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 35 |
+
from diffusers.utils import (
|
| 36 |
+
is_torch_xla_available,
|
| 37 |
+
logging,
|
| 38 |
+
)
|
| 39 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 40 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 41 |
+
|
| 42 |
+
from dataclasses import dataclass
|
| 43 |
+
|
| 44 |
+
import PIL.Image
|
| 45 |
+
|
| 46 |
+
from diffusers.utils import BaseOutput
|
| 47 |
+
|
| 48 |
+
from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor
|
| 49 |
+
|
| 50 |
+
from omnigen2.utils.teacache_util import TeaCacheParams
|
| 51 |
+
|
| 52 |
+
from ..lora_pipeline import OmniGen2LoraLoaderMixin
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if is_torch_xla_available():
|
| 56 |
+
import torch_xla.core.xla_model as xm
|
| 57 |
+
|
| 58 |
+
XLA_AVAILABLE = True
|
| 59 |
+
else:
|
| 60 |
+
XLA_AVAILABLE = False
|
| 61 |
+
|
| 62 |
+
from ...cache_functions import cache_init
|
| 63 |
+
|
| 64 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class FMPipelineOutput(BaseOutput):
|
| 68 |
+
"""
|
| 69 |
+
Output class for OmniGen2 pipeline.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
images (Union[List[PIL.Image.Image], np.ndarray]):
|
| 73 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape
|
| 74 |
+
`(batch_size, height, width, num_channels)`. Contains the generated images.
|
| 75 |
+
"""
|
| 76 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 80 |
+
def retrieve_timesteps(
|
| 81 |
+
scheduler,
|
| 82 |
+
num_inference_steps: Optional[int] = None,
|
| 83 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 84 |
+
timesteps: Optional[List[int]] = None,
|
| 85 |
+
**kwargs,
|
| 86 |
+
):
|
| 87 |
+
"""
|
| 88 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 89 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
scheduler (`SchedulerMixin`):
|
| 93 |
+
The scheduler to get timesteps from.
|
| 94 |
+
num_inference_steps (`int`):
|
| 95 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 96 |
+
must be `None`.
|
| 97 |
+
device (`str` or `torch.device`, *optional*):
|
| 98 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 99 |
+
timesteps (`List[int]`, *optional*):
|
| 100 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 101 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 102 |
+
sigmas (`List[float]`, *optional*):
|
| 103 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 104 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 108 |
+
second element is the number of inference steps.
|
| 109 |
+
"""
|
| 110 |
+
if timesteps is not None:
|
| 111 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 112 |
+
if not accepts_timesteps:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 115 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 116 |
+
)
|
| 117 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 118 |
+
timesteps = scheduler.timesteps
|
| 119 |
+
num_inference_steps = len(timesteps)
|
| 120 |
+
else:
|
| 121 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 122 |
+
timesteps = scheduler.timesteps
|
| 123 |
+
return timesteps, num_inference_steps
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class OmniGen2Pipeline(DiffusionPipeline, OmniGen2LoraLoaderMixin):
|
| 127 |
+
"""
|
| 128 |
+
Pipeline for text-to-image generation using OmniGen2.
|
| 129 |
+
|
| 130 |
+
This pipeline implements a text-to-image generation model that uses:
|
| 131 |
+
- Qwen2.5-VL for text encoding
|
| 132 |
+
- A custom transformer architecture for image generation
|
| 133 |
+
- VAE for image encoding/decoding
|
| 134 |
+
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
| 138 |
+
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
| 139 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
| 140 |
+
text_encoder (Qwen2_5_VLModel): The text encoder model.
|
| 141 |
+
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
model_cpu_offload_seq = "mllm->transformer->vae"
|
| 145 |
+
|
| 146 |
+
def __init__(
|
| 147 |
+
self,
|
| 148 |
+
transformer: OmniGen2Transformer2DModel,
|
| 149 |
+
vae: AutoencoderKL,
|
| 150 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 151 |
+
mllm: Qwen2_5_VLForConditionalGeneration,
|
| 152 |
+
processor,
|
| 153 |
+
) -> None:
|
| 154 |
+
"""
|
| 155 |
+
Initialize the OmniGen2 pipeline.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
transformer: The transformer model for image generation.
|
| 159 |
+
vae: The VAE model for image encoding/decoding.
|
| 160 |
+
scheduler: The scheduler for noise scheduling.
|
| 161 |
+
text_encoder: The text encoder model.
|
| 162 |
+
tokenizer: The tokenizer for text processing.
|
| 163 |
+
"""
|
| 164 |
+
super().__init__()
|
| 165 |
+
|
| 166 |
+
self.register_modules(
|
| 167 |
+
transformer=transformer,
|
| 168 |
+
vae=vae,
|
| 169 |
+
scheduler=scheduler,
|
| 170 |
+
mllm=mllm,
|
| 171 |
+
processor=processor
|
| 172 |
+
)
|
| 173 |
+
self.vae_scale_factor = (
|
| 174 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 175 |
+
)
|
| 176 |
+
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
| 177 |
+
self.default_sample_size = 128
|
| 178 |
+
|
| 179 |
+
def prepare_latents(
|
| 180 |
+
self,
|
| 181 |
+
batch_size: int,
|
| 182 |
+
num_channels_latents: int,
|
| 183 |
+
height: int,
|
| 184 |
+
width: int,
|
| 185 |
+
dtype: torch.dtype,
|
| 186 |
+
device: torch.device,
|
| 187 |
+
generator: Optional[torch.Generator],
|
| 188 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 189 |
+
) -> torch.FloatTensor:
|
| 190 |
+
"""
|
| 191 |
+
Prepare the initial latents for the diffusion process.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
batch_size: The number of images to generate.
|
| 195 |
+
num_channels_latents: The number of channels in the latent space.
|
| 196 |
+
height: The height of the generated image.
|
| 197 |
+
width: The width of the generated image.
|
| 198 |
+
dtype: The data type of the latents.
|
| 199 |
+
device: The device to place the latents on.
|
| 200 |
+
generator: The random number generator to use.
|
| 201 |
+
latents: Optional pre-computed latents to use instead of random initialization.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
torch.FloatTensor: The prepared latents tensor.
|
| 205 |
+
"""
|
| 206 |
+
height = int(height) // self.vae_scale_factor
|
| 207 |
+
width = int(width) // self.vae_scale_factor
|
| 208 |
+
|
| 209 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 210 |
+
|
| 211 |
+
if latents is None:
|
| 212 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 213 |
+
else:
|
| 214 |
+
latents = latents.to(device)
|
| 215 |
+
return latents
|
| 216 |
+
|
| 217 |
+
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
|
| 218 |
+
"""
|
| 219 |
+
Encode an image into the VAE latent space.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
img: The input image tensor to encode.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
torch.FloatTensor: The encoded latent representation.
|
| 226 |
+
"""
|
| 227 |
+
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
|
| 228 |
+
if self.vae.config.shift_factor is not None:
|
| 229 |
+
z0 = z0 - self.vae.config.shift_factor
|
| 230 |
+
if self.vae.config.scaling_factor is not None:
|
| 231 |
+
z0 = z0 * self.vae.config.scaling_factor
|
| 232 |
+
z0 = z0.to(dtype=self.vae.dtype)
|
| 233 |
+
return z0
|
| 234 |
+
|
| 235 |
+
def prepare_image(
|
| 236 |
+
self,
|
| 237 |
+
images: Union[List[PIL.Image.Image], PIL.Image.Image],
|
| 238 |
+
batch_size: int,
|
| 239 |
+
num_images_per_prompt: int,
|
| 240 |
+
max_pixels: int,
|
| 241 |
+
max_side_length: int,
|
| 242 |
+
device: torch.device,
|
| 243 |
+
dtype: torch.dtype,
|
| 244 |
+
) -> List[Optional[torch.FloatTensor]]:
|
| 245 |
+
"""
|
| 246 |
+
Prepare input images for processing by encoding them into the VAE latent space.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
images: Single image or list of images to process.
|
| 250 |
+
batch_size: The number of images to generate per prompt.
|
| 251 |
+
num_images_per_prompt: The number of images to generate for each prompt.
|
| 252 |
+
device: The device to place the encoded latents on.
|
| 253 |
+
dtype: The data type of the encoded latents.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
|
| 257 |
+
"""
|
| 258 |
+
if batch_size == 1:
|
| 259 |
+
images = [images]
|
| 260 |
+
latents = []
|
| 261 |
+
for i, img in enumerate(images):
|
| 262 |
+
if img is not None and len(img) > 0:
|
| 263 |
+
ref_latents = []
|
| 264 |
+
for j, img_j in enumerate(img):
|
| 265 |
+
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
|
| 266 |
+
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
|
| 267 |
+
else:
|
| 268 |
+
ref_latents = None
|
| 269 |
+
for _ in range(num_images_per_prompt):
|
| 270 |
+
latents.append(ref_latents)
|
| 271 |
+
|
| 272 |
+
return latents
|
| 273 |
+
|
| 274 |
+
def _get_qwen2_prompt_embeds(
|
| 275 |
+
self,
|
| 276 |
+
prompt: Union[str, List[str]],
|
| 277 |
+
device: Optional[torch.device] = None,
|
| 278 |
+
max_sequence_length: int = 256,
|
| 279 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 280 |
+
"""
|
| 281 |
+
Get prompt embeddings from the Qwen2 text encoder.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
prompt: The prompt or list of prompts to encode.
|
| 285 |
+
device: The device to place the embeddings on. If None, uses the pipeline's device.
|
| 286 |
+
max_sequence_length: Maximum sequence length for tokenization.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
| 290 |
+
- The prompt embeddings tensor
|
| 291 |
+
- The attention mask tensor
|
| 292 |
+
|
| 293 |
+
Raises:
|
| 294 |
+
Warning: If the input text is truncated due to sequence length limitations.
|
| 295 |
+
"""
|
| 296 |
+
device = device or self._execution_device
|
| 297 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 298 |
+
# text_inputs = self.processor.tokenizer(
|
| 299 |
+
# prompt,
|
| 300 |
+
# padding="max_length",
|
| 301 |
+
# max_length=max_sequence_length,
|
| 302 |
+
# truncation=True,
|
| 303 |
+
# return_tensors="pt",
|
| 304 |
+
# )
|
| 305 |
+
text_inputs = self.processor.tokenizer(
|
| 306 |
+
prompt,
|
| 307 |
+
padding="longest",
|
| 308 |
+
max_length=max_sequence_length,
|
| 309 |
+
truncation=True,
|
| 310 |
+
return_tensors="pt",
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
| 314 |
+
untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
| 315 |
+
|
| 316 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 317 |
+
removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 318 |
+
logger.warning(
|
| 319 |
+
"The following part of your input was truncated because Gemma can only handle sequences up to"
|
| 320 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
| 324 |
+
prompt_embeds = self.mllm(
|
| 325 |
+
text_input_ids,
|
| 326 |
+
attention_mask=prompt_attention_mask,
|
| 327 |
+
output_hidden_states=True,
|
| 328 |
+
).hidden_states[-1]
|
| 329 |
+
|
| 330 |
+
if self.mllm is not None:
|
| 331 |
+
dtype = self.mllm.dtype
|
| 332 |
+
elif self.transformer is not None:
|
| 333 |
+
dtype = self.transformer.dtype
|
| 334 |
+
else:
|
| 335 |
+
dtype = None
|
| 336 |
+
|
| 337 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 338 |
+
|
| 339 |
+
return prompt_embeds, prompt_attention_mask
|
| 340 |
+
|
| 341 |
+
def _apply_chat_template(self, prompt: str):
|
| 342 |
+
prompt = [
|
| 343 |
+
{
|
| 344 |
+
"role": "system",
|
| 345 |
+
"content": "You are a helpful assistant that generates high-quality images based on user instructions.",
|
| 346 |
+
},
|
| 347 |
+
{"role": "user", "content": prompt},
|
| 348 |
+
]
|
| 349 |
+
prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
|
| 350 |
+
return prompt
|

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                OmniGen2, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier-free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For OmniGen2, these should be the embeddings of the "" string.
            max_sequence_length (`int`, defaults to `256`):
                Maximum sequence length to use for the prompt.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
                prompt=prompt,
                device=device,
                max_sequence_length=max_sequence_length,
            )

        batch_size, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

        # Get negative embeddings for classifier-free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt if negative_prompt is not None else ""

            # Normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
                prompt=negative_prompt,
                device=device,
                max_sequence_length=max_sequence_length,
            )

            batch_size, seq_len, _ = negative_prompt_embeds.shape
            # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(
                batch_size * num_images_per_prompt, -1
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
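
A rough call-site sketch for `encode_prompt` (assumes `pipe` is an already-constructed OmniGen2Pipeline; names are illustrative):

    prompt_embeds, prompt_mask, neg_embeds, neg_mask = pipe.encode_prompt(
        "A watercolor painting of a lighthouse at dusk",
        True,                                   # do_classifier_free_guidance: also encode the negative prompt
        negative_prompt="blurry, low quality",
        num_images_per_prompt=2,                # embeddings are tiled to an effective batch of 2
        max_sequence_length=256,
    )

With `do_classifier_free_guidance=False`, `neg_embeds` and `neg_mask` come back unchanged from whatever was passed in (here `None`).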

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def text_guidance_scale(self):
        return self._text_guidance_scale

    @property
    def image_guidance_scale(self):
        return self._image_guidance_scale

    @property
    def cfg_range(self):
        return self._cfg_range

+
@torch.no_grad()
|
| 467 |
+
def __call__(
|
| 468 |
+
self,
|
| 469 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 470 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 471 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 472 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 473 |
+
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
| 474 |
+
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
| 475 |
+
max_sequence_length: Optional[int] = None,
|
| 476 |
+
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
| 477 |
+
input_images: Optional[List[PIL.Image.Image]] = None,
|
| 478 |
+
num_images_per_prompt: int = 1,
|
| 479 |
+
height: Optional[int] = None,
|
| 480 |
+
width: Optional[int] = None,
|
| 481 |
+
max_pixels: int = 1024 * 1024,
|
| 482 |
+
max_input_image_side_length: int = 1024,
|
| 483 |
+
align_res: bool = True,
|
| 484 |
+
num_inference_steps: int = 28,
|
| 485 |
+
text_guidance_scale: float = 4.0,
|
| 486 |
+
image_guidance_scale: float = 1.0,
|
| 487 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
| 488 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 489 |
+
timesteps: List[int] = None,
|
| 490 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 491 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 492 |
+
output_type: Optional[str] = "pil",
|
| 493 |
+
return_dict: bool = True,
|
| 494 |
+
verbose: bool = False,
|
| 495 |
+
step_func=None,
|
| 496 |
+
):
|
| 497 |
+
|
| 498 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 499 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 500 |
+
|
| 501 |
+
self._text_guidance_scale = text_guidance_scale
|
| 502 |
+
self._image_guidance_scale = image_guidance_scale
|
| 503 |
+
self._cfg_range = cfg_range
|
| 504 |
+
self._attention_kwargs = attention_kwargs
|
| 505 |
+
|
| 506 |
+
# 2. Define call parameters
|
| 507 |
+
if prompt is not None and isinstance(prompt, str):
|
| 508 |
+
batch_size = 1
|
| 509 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 510 |
+
batch_size = len(prompt)
|
| 511 |
+
else:
|
| 512 |
+
batch_size = prompt_embeds.shape[0]
|
| 513 |
+
|
| 514 |
+
device = self._execution_device
|
| 515 |
+
|
| 516 |
+
# 3. Encode input prompt
|
| 517 |
+
(
|
| 518 |
+
prompt_embeds,
|
| 519 |
+
prompt_attention_mask,
|
| 520 |
+
negative_prompt_embeds,
|
| 521 |
+
negative_prompt_attention_mask,
|
| 522 |
+
) = self.encode_prompt(
|
| 523 |
+
prompt,
|
| 524 |
+
self.text_guidance_scale > 1.0,
|
| 525 |
+
negative_prompt=negative_prompt,
|
| 526 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 527 |
+
device=device,
|
| 528 |
+
prompt_embeds=prompt_embeds,
|
| 529 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 530 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 531 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 532 |
+
max_sequence_length=max_sequence_length,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
dtype = self.vae.dtype
|
| 536 |
+
# 3. Prepare control image
|
| 537 |
+
ref_latents = self.prepare_image(
|
| 538 |
+
images=input_images,
|
| 539 |
+
batch_size=batch_size,
|
| 540 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 541 |
+
max_pixels=max_pixels,
|
| 542 |
+
max_side_length=max_input_image_side_length,
|
| 543 |
+
device=device,
|
| 544 |
+
dtype=dtype,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
if input_images is None:
|
| 548 |
+
input_images = []
|
| 549 |
+
|
| 550 |
+
if len(input_images) == 1 and align_res:
|
| 551 |
+
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
|
| 552 |
+
ori_width, ori_height = width, height
|
| 553 |
+
else:
|
| 554 |
+
ori_width, ori_height = width, height
|
| 555 |
+
|
| 556 |
+
cur_pixels = height * width
|
| 557 |
+
ratio = (max_pixels / cur_pixels) ** 0.5
|
| 558 |
+
ratio = min(ratio, 1.0)
|
| 559 |
+
|
| 560 |
+
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
|
| 561 |
+
|
| 562 |
+
if len(input_images) == 0:
|
| 563 |
+
self._image_guidance_scale = 1
|
| 564 |
+
|
| 565 |
+
# 4. Prepare latents.
|
| 566 |
+
latent_channels = self.transformer.config.in_channels
|
| 567 |
+
latents = self.prepare_latents(
|
| 568 |
+
batch_size * num_images_per_prompt,
|
| 569 |
+
latent_channels,
|
| 570 |
+
height,
|
| 571 |
+
width,
|
| 572 |
+
prompt_embeds.dtype,
|
| 573 |
+
device,
|
| 574 |
+
generator,
|
| 575 |
+
latents,
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
|
| 579 |
+
self.transformer.config.axes_dim_rope,
|
| 580 |
+
self.transformer.config.axes_lens,
|
| 581 |
+
theta=10000,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
image = self.processing(
|
| 585 |
+
latents=latents,
|
| 586 |
+
ref_latents=ref_latents,
|
| 587 |
+
prompt_embeds=prompt_embeds,
|
| 588 |
+
freqs_cis=freqs_cis,
|
| 589 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 590 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 591 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 592 |
+
num_inference_steps=num_inference_steps,
|
| 593 |
+
timesteps=timesteps,
|
| 594 |
+
device=device,
|
| 595 |
+
dtype=dtype,
|
| 596 |
+
verbose=verbose,
|
| 597 |
+
step_func=step_func,
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
|
| 601 |
+
|
| 602 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 603 |
+
|
| 604 |
+
# Offload all models
|
| 605 |
+
self.maybe_free_model_hooks()
|
| 606 |
+
|
| 607 |
+
if not return_dict:
|
| 608 |
+
return image
|
| 609 |
+
else:
|
| 610 |
+
return FMPipelineOutput(images=image)
|
| 611 |
+
|
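A rough end-to-end sketch of invoking the pipeline (the repository id is an assumption for illustration, not defined by this commit):

    import torch
    from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

    pipe = OmniGen2Pipeline.from_pretrained("OmniGen2/OmniGen2", torch_dtype=torch.bfloat16)  # assumed repo id
    pipe = pipe.to("cuda")
    result = pipe(
        prompt="A corgi wearing sunglasses on a beach",
        height=1024,
        width=1024,
        num_inference_steps=28,
        text_guidance_scale=4.0,
        generator=torch.Generator("cuda").manual_seed(0),
    )
    result.images[0].save("corgi.png")  # __call__ returns FMPipelineOutput(images=...) by default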
    def processing(
        self,
        latents,
        ref_latents,
        prompt_embeds,
        freqs_cis,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
        num_inference_steps,
        timesteps,
        device,
        dtype,
        verbose,
        step_func=None,
    ):
        batch_size = latents.shape[0]

        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            num_tokens=latents.shape[-2] * latents.shape[-1],
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        enable_taylorseer = getattr(self, "enable_taylorseer", False)
        if enable_taylorseer:
            model_pred_cache_dic, model_pred_current = cache_init(self, num_inference_steps)
            model_pred_ref_cache_dic, model_pred_ref_current = cache_init(self, num_inference_steps)
            model_pred_uncond_cache_dic, model_pred_uncond_current = cache_init(self, num_inference_steps)
            self.transformer.enable_taylorseer = True
        elif self.transformer.enable_teacache:
            # Use different TeaCacheParams for different conditions
            teacache_params = TeaCacheParams()
            teacache_params_uncond = TeaCacheParams()
            teacache_params_ref = TeaCacheParams()

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if enable_taylorseer:
                    self.transformer.cache_dic = model_pred_cache_dic
                    self.transformer.current = model_pred_current
                elif self.transformer.enable_teacache:
                    teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                    self.transformer.teacache_params = teacache_params

                model_pred = self.predict(
                    t=t,
                    latents=latents,
                    prompt_embeds=prompt_embeds,
                    freqs_cis=freqs_cis,
                    prompt_attention_mask=prompt_attention_mask,
                    ref_image_hidden_states=ref_latents,
                )
                # Guidance is only applied inside the configured fraction of the schedule
                text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
                image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0

                if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
                    if enable_taylorseer:
                        self.transformer.cache_dic = model_pred_ref_cache_dic
                        self.transformer.current = model_pred_ref_current
                    elif self.transformer.enable_teacache:
                        teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                        self.transformer.teacache_params = teacache_params_ref

                    model_pred_ref = self.predict(
                        t=t,
                        latents=latents,
                        prompt_embeds=negative_prompt_embeds,
                        freqs_cis=freqs_cis,
                        prompt_attention_mask=negative_prompt_attention_mask,
                        ref_image_hidden_states=ref_latents,
                    )

                    if enable_taylorseer:
                        self.transformer.cache_dic = model_pred_uncond_cache_dic
                        self.transformer.current = model_pred_uncond_current
                    elif self.transformer.enable_teacache:
                        teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                        self.transformer.teacache_params = teacache_params_uncond

                    model_pred_uncond = self.predict(
                        t=t,
                        latents=latents,
                        prompt_embeds=negative_prompt_embeds,
                        freqs_cis=freqs_cis,
                        prompt_attention_mask=negative_prompt_attention_mask,
                        ref_image_hidden_states=None,
                    )

                    model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
                        text_guidance_scale * (model_pred - model_pred_ref)
                elif text_guidance_scale > 1.0:
                    if enable_taylorseer:
                        self.transformer.cache_dic = model_pred_uncond_cache_dic
                        self.transformer.current = model_pred_uncond_current
                    elif self.transformer.enable_teacache:
                        teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                        self.transformer.teacache_params = teacache_params_uncond

                    model_pred_uncond = self.predict(
                        t=t,
                        latents=latents,
                        prompt_embeds=negative_prompt_embeds,
                        freqs_cis=freqs_cis,
                        prompt_attention_mask=negative_prompt_attention_mask,
                        ref_image_hidden_states=None,
                    )
                    model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)

                latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]

                latents = latents.to(dtype=dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if step_func is not None:
                    step_func(i, self._num_timesteps)

        if enable_taylorseer:
            del model_pred_cache_dic, model_pred_ref_cache_dic, model_pred_uncond_cache_dic
            del model_pred_current, model_pred_ref_current, model_pred_uncond_current

        latents = latents.to(dtype=dtype)
        if self.vae.config.scaling_factor is not None:
            latents = latents / self.vae.config.scaling_factor
        if self.vae.config.shift_factor is not None:
            latents = latents + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]

        return image

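The dual-guidance line above composes two classifier-free-guidance terms: an image term that moves the unconditional prediction toward the reference-conditioned one, and a text term that moves the reference-conditioned prediction toward the fully conditioned one. A minimal self-contained sketch of just that combination (tensor values are illustrative):

    import torch

    def combine_guidance(pred_cond, pred_ref, pred_uncond, text_gs, image_gs):
        # pred = uncond + image_gs * (ref - uncond) + text_gs * (cond - ref)
        return pred_uncond + image_gs * (pred_ref - pred_uncond) + text_gs * (pred_cond - pred_ref)

    x = torch.randn(1, 16, 64, 64)
    # With both scales at 1.0 the expression telescopes back to the conditional prediction.
    assert torch.allclose(combine_guidance(x, 0.5 * x, 0.25 * x, 1.0, 1.0), x)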
    def predict(
        self,
        t,
        latents,
        prompt_embeds,
        freqs_cis,
        prompt_attention_mask,
        ref_image_hidden_states,
    ):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        batch_size, num_channels_latents, height, width = latents.shape

        # Only pass reference latents if the transformer's forward accepts them
        optional_kwargs = {}
        if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
            optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states

        model_pred = self.transformer(
            latents,
            timestep,
            prompt_embeds,
            freqs_cis,
            prompt_attention_mask,
            **optional_kwargs
        )
        return model_pred
omnigen2/pipelines/omnigen2/pipeline_omnigen2_chat.py
ADDED
@@ -0,0 +1,830 @@
"""
OmniGen2 Diffusion Pipeline

Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import math

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F

from transformers import Qwen2_5_VLForConditionalGeneration

from diffusers.models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGen2Transformer2DModel
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    is_torch_xla_available,
    logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from dataclasses import dataclass

import PIL.Image

from diffusers.utils import BaseOutput

from omnigen2.pipelines.image_processor import OmniGen2ImageProcessor

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class OmniGen2PipelineOutput(BaseOutput):
    """
    Output class for the OmniGen2 chat pipeline.

    Args:
        text (str):
            The text generated by the MLLM for this turn, with the trailing `<|im_end|>` marker stripped.
        images (Union[List[PIL.Image.Image], np.ndarray]):
            List of denoised PIL images of length `batch_size` or numpy array of shape
            `(batch_size, height, width, num_channels)`. `None` when the model replied with text only.
    """
    text: str
    images: Union[List[PIL.Image.Image], np.ndarray]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

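A small call-site sketch for `retrieve_timesteps`, assuming the stock diffusers scheduler (names are illustrative):

    from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    # Default path: derive the schedule from a step count; any extra kwargs are
    # forwarded to scheduler.set_timesteps (the pipeline forwards num_tokens this way
    # to its own scheduler variant).
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=28, device="cpu")
    assert n == 28 and len(timesteps) == 28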
+
class OmniGen2ChatPipeline(DiffusionPipeline):
|
| 122 |
+
"""
|
| 123 |
+
Pipeline for text-to-image generation using OmniGen2.
|
| 124 |
+
|
| 125 |
+
This pipeline implements a text-to-image generation model that uses:
|
| 126 |
+
- Qwen2.5-VL for text encoding
|
| 127 |
+
- A custom transformer architecture for image generation
|
| 128 |
+
- VAE for image encoding/decoding
|
| 129 |
+
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
| 133 |
+
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
| 134 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
| 135 |
+
text_encoder (Qwen2_5_VLModel): The text encoder model.
|
| 136 |
+
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
model_cpu_offload_seq = "mllm->transformer->vae"
|
| 140 |
+
def __init__(
|
| 141 |
+
self,
|
| 142 |
+
transformer: OmniGen2Transformer2DModel,
|
| 143 |
+
vae: AutoencoderKL,
|
| 144 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 145 |
+
mllm: Qwen2_5_VLForConditionalGeneration,
|
| 146 |
+
processor,
|
| 147 |
+
) -> None:
|
| 148 |
+
"""
|
| 149 |
+
Initialize the OmniGen2 pipeline.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
transformer: The transformer model for image generation.
|
| 153 |
+
vae: The VAE model for image encoding/decoding.
|
| 154 |
+
scheduler: The scheduler for noise scheduling.
|
| 155 |
+
text_encoder: The text encoder model.
|
| 156 |
+
tokenizer: The tokenizer for text processing.
|
| 157 |
+
"""
|
| 158 |
+
super().__init__()
|
| 159 |
+
|
| 160 |
+
self.register_modules(
|
| 161 |
+
transformer=transformer,
|
| 162 |
+
vae=vae,
|
| 163 |
+
scheduler=scheduler,
|
| 164 |
+
mllm=mllm,
|
| 165 |
+
processor=processor
|
| 166 |
+
)
|
| 167 |
+
self.vae_scale_factor = (
|
| 168 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 169 |
+
)
|
| 170 |
+
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
| 171 |
+
self.default_sample_size = 128
|
| 172 |
+
|
    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[torch.Generator],
        latents: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Prepare the initial latents for the diffusion process.

        Args:
            batch_size: The number of images to generate.
            num_channels_latents: The number of channels in the latent space.
            height: The height of the generated image.
            width: The width of the generated image.
            dtype: The data type of the latents.
            device: The device to place the latents on.
            generator: The random number generator to use.
            latents: Optional pre-computed latents to use instead of random initialization.

        Returns:
            torch.FloatTensor: The prepared latents tensor.
        """
        height = int(height) // self.vae_scale_factor
        width = int(width) // self.vae_scale_factor

        shape = (batch_size, num_channels_latents, height, width)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)
        return latents

    def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
        """
        Encode an image into the VAE latent space.

        Args:
            img: The input image tensor to encode.

        Returns:
            torch.FloatTensor: The encoded latent representation.
        """
        z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
        if self.vae.config.shift_factor is not None:
            z0 = z0 - self.vae.config.shift_factor
        if self.vae.config.scaling_factor is not None:
            z0 = z0 * self.vae.config.scaling_factor
        z0 = z0.to(dtype=self.vae.dtype)
        return z0

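`encode_vae` applies the usual shift/scale latent normalization, z = (z0 - shift) * scale, and the decode path in `processing` inverts it in the opposite order before calling `vae.decode`. A tiny numeric check with made-up config values:

    shift, scale = 0.12, 1.53                     # illustrative values, not the real VAE config
    z0 = 0.8
    z = (z0 - shift) * scale                      # what encode_vae stores
    assert abs((z / scale + shift) - z0) < 1e-9   # what the decode path undoes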
    def prepare_image(
        self,
        images: Union[List[PIL.Image.Image], PIL.Image.Image],
        batch_size: int,
        num_images_per_prompt: int,
        max_pixels: int,
        max_side_length: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> List[Optional[torch.FloatTensor]]:
        """
        Prepare input images for processing by encoding them into the VAE latent space.

        Args:
            images: Single image or list of images to process.
            batch_size: The number of images to generate per prompt.
            num_images_per_prompt: The number of images to generate for each prompt.
            max_pixels: Maximum pixel count per reference image before downscaling.
            max_side_length: Maximum side length per reference image before downscaling.
            device: The device to place the encoded latents on.
            dtype: The data type of the encoded latents.

        Returns:
            List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
        """
        if batch_size == 1:
            images = [images]
        latents = []
        for i, img in enumerate(images):
            if img is not None and len(img) > 0:
                ref_latents = []
                for j, img_j in enumerate(img):
                    img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
                    ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
            else:
                ref_latents = None
            for _ in range(num_images_per_prompt):
                latents.append(ref_latents)

        return latents

    def _apply_chat_template(self, prompt: str, images: List = None):
        if images is not None:
            prompt = "".join(
                [
                    f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>"
                    for i in range(1, len(images) + 1)
                ]
            ) + prompt
        prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        return prompt

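For a prompt with two reference images, the method above assembles a string of the following shape (line breaks shown literally; the user text is an example):

    <|im_start|>system
    You are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>
    <|im_start|>user
    <img1>: <|vision_start|><|image_pad|><|vision_end|><img2>: <|vision_start|><|image_pad|><|vision_end|>Merge the two scenes<|im_end|>
    <|im_start|>assistant

The `<|image_pad|>` placeholders are expanded by the Qwen2.5-VL processor into the appropriate number of image tokens when the string and the images are passed to it together.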
    def _get_qwen2_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        input_images=None,
        device: Optional[torch.device] = None,
        use_only_text_hidden_states: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get prompt embeddings from the Qwen2 text encoder.

        Args:
            prompt: The prompt or list of prompts to encode.
            input_images: Optional reference images whose features are injected at the `<|image_pad|>` positions.
            device: The device to place the embeddings on. If None, uses the pipeline's device.
            use_only_text_hidden_states: If True, drop image-token positions and re-pack only the text hidden states.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The prompt embeddings tensor
                - The attention mask tensor
        """
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        inputs = self.processor(
            text=prompt,
            images=input_images,
            videos=None,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        prompt_embeds = self.mllm(
            **inputs,
            output_hidden_states=True,
        ).hidden_states[-1]

        text_input_ids = inputs.input_ids
        text_mask = inputs.attention_mask
        if use_only_text_hidden_states:
            # Keep only text-token positions: drop image tokens and padding,
            # then left-pack the surviving features into a fresh, tightly padded batch.
            mask = text_input_ids != self.mllm.config.image_token_id
            mask = mask & text_mask
            mask = mask.bool()

            text_l = mask.sum(dim=-1)
            max_l = text_l.max()
            text_batch_size = prompt_embeds.size(0)
            new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
            new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
            for i in range(text_batch_size):
                new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
                new_text_mask[i, :text_l[i]] = 1

            prompt_embeds = new_prompt_embeds
            text_mask = new_text_mask

        prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
        return prompt_embeds, text_mask

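The re-packing loop above, isolated on a toy example (no model required; shapes are illustrative):

    import torch

    hidden = torch.arange(12.0).reshape(1, 6, 2)                 # (batch, seq, dim) hidden states
    keep = torch.tensor([[1, 0, 0, 1, 1, 0]], dtype=torch.bool)  # True = text token to keep
    lengths = keep.sum(dim=-1)                                   # tensor([3])
    packed = torch.zeros(1, int(lengths.max()), 2)
    new_mask = torch.zeros(1, int(lengths.max()), dtype=torch.long)
    packed[0, :lengths[0]] = hidden[0, keep[0]]                  # rows 0, 3, 4 survive, left-packed
    new_mask[0, :lengths[0]] = 1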
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        input_images: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 256,
        use_text_encoder_penultimate_layer_feats: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                OmniGen2, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier-free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For OmniGen2, these should be the embeddings of the "" string.
            max_sequence_length (`int`, defaults to `256`):
                Maximum sequence length to use for the prompt.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
                prompt=prompt,
                input_images=input_images,
                device=device,
            )

        batch_size, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

        # Get negative embeddings for classifier-free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt if negative_prompt is not None else ""

            # Normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
                prompt=negative_prompt,
                device=device,
            )

            batch_size, seq_len, _ = negative_prompt_embeds.shape
            # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(
                batch_size * num_images_per_prompt, -1
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def text_guidance_scale(self):
        return self._text_guidance_scale

    @property
    def image_guidance_scale(self):
        return self._image_guidance_scale

    @property
    def cfg_range(self):
        return self._cfg_range

    def prepare_inputs_for_text_generation(self, prompts, input_images, device):
        if isinstance(prompts, str):
            prompts = [prompts]

        # Generation requires left padding; restore the tokenizer's setting afterwards
        ori_padding_side = self.processor.tokenizer.padding_side
        self.processor.tokenizer.padding_side = "left"
        inputs = self.processor(
            text=prompts,
            images=input_images,
            videos=None,
            padding=True,
            return_tensors="pt",
        ).to(device)
        self.processor.tokenizer.padding_side = ori_padding_side
        return inputs

    def generate_text(self, prompt, input_images):
        inputs = self.prepare_inputs_for_text_generation(
            prompt, input_images, self.mllm.device
        )
        generated_ids = self.mllm.generate(
            **inputs,
            tokenizer=self.processor.tokenizer,
            max_new_tokens=256,
            stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
        )  # stop_words=[151643, 151645, 151665]
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_texts = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        return output_texts

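`generate_text` stops as soon as the MLLM emits `<|im_end|>`, `<|img|>`, or `<|endoftext|>`; `__call__` below then branches on whether the reply starts with `<|img|>` to decide if the diffusion branch should run. A sketch of that turn protocol (`pipe` and `chat_prompt` are assumed to exist; strings are illustrative):

    reply = pipe.generate_text(chat_prompt, input_images)[0]
    if reply.startswith("<|img|>"):
        # Image turn: hand the accumulated chat context to the diffusion branch
        images = pipe.generate_image(prompt=chat_prompt, num_inference_steps=28)
    else:
        print(reply.replace("<|im_end|>", ""))  # plain text turn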
    def generate_image(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.LongTensor] = None,
        negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
        use_text_encoder_penultimate_layer_feats: bool = False,
        max_sequence_length: Optional[int] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        input_images: Optional[List[PIL.Image.Image]] = None,
        num_images_per_prompt: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        max_pixels: int = 1024 * 1024,
        max_input_image_side_length: int = 1024,
        align_res: bool = True,
        num_inference_steps: int = 28,
        text_guidance_scale: float = 4.0,
        image_guidance_scale: float = 1.0,
        cfg_range: Tuple[float, float] = (0.0, 1.0),
        attention_kwargs: Optional[Dict[str, Any]] = None,
        timesteps: List[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        verbose: bool = False,
        step_func=None,
    ):
        # 1. Resolve output resolution
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        self._text_guidance_scale = text_guidance_scale
        self._image_guidance_scale = image_guidance_scale
        self._cfg_range = cfg_range
        self._attention_kwargs = attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            input_images,
            self.text_guidance_scale > 1.0,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
        )

        dtype = self.vae.dtype
        # 4. Prepare reference image latents
        ref_latents = self.prepare_image(
            images=input_images,
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            max_pixels=max_pixels,
            max_side_length=max_input_image_side_length,
            device=device,
            dtype=dtype,
        )

        if input_images is None:
            input_images = []

        if len(input_images) == 1 and align_res:
            # Align the output resolution with the single reference image
            width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
            ori_width, ori_height = width, height
        else:
            ori_width, ori_height = width, height

        # Cap the pixel count at max_pixels and round down to multiples of 16
        cur_pixels = height * width
        ratio = (max_pixels / cur_pixels) ** 0.5
        ratio = min(ratio, 1.0)

        height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16

        # Without reference images, image guidance degenerates to a no-op
        if len(input_images) == 0:
            self._image_guidance_scale = 1

        # 5. Prepare latents
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
            self.transformer.config.axes_dim_rope,
            self.transformer.config.axes_lens,
            theta=10000,
        )

        # 6. Run the denoising loop and decode
        image = self.processing(
            latents=latents,
            ref_latents=ref_latents,
            prompt_embeds=prompt_embeds,
            freqs_cis=freqs_cis,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            device=device,
            dtype=dtype,
            verbose=verbose,
            step_func=step_func,
        )

        image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()
        return image

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.LongTensor] = None,
        negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
        use_text_encoder_penultimate_layer_feats: bool = False,
        max_sequence_length: Optional[int] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        input_images: Optional[List[PIL.Image.Image]] = None,
        num_images_per_prompt: int = 1,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        max_pixels: Optional[int] = 1024 * 1024,
        max_input_image_side_length: int = 1024,
        align_res: bool = True,
        num_inference_steps: int = 28,
        text_guidance_scale: float = 4.0,
        image_guidance_scale: float = 1.0,
        cfg_range: Tuple[float, float] = (0.0, 1.0),
        attention_kwargs: Optional[Dict[str, Any]] = None,
        timesteps: List[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        verbose: bool = False,
        step_func=None,
    ):
        assert isinstance(prompt, str), "prompt must be a string since chat mode only supports one prompt per turn"

        prompt = self._apply_chat_template(prompt, input_images)
        generated_text = self.generate_text(prompt, input_images)[0]

        images = None
        if generated_text.startswith("<|img|>"):
            # TODO: reuse the hidden states from text generation instead of re-encoding the prompt
            prompt = prompt + generated_text.split("<|img|>")[0]
            images = self.generate_image(
                prompt=prompt,
                negative_prompt=negative_prompt,
                use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
                max_sequence_length=max_sequence_length,
                input_images=input_images,
                num_images_per_prompt=num_images_per_prompt,
                height=height,
                width=width,
                max_pixels=max_pixels,
                max_input_image_side_length=max_input_image_side_length,
                align_res=align_res,
                num_inference_steps=num_inference_steps,
                text_guidance_scale=text_guidance_scale,
                image_guidance_scale=image_guidance_scale,
                cfg_range=cfg_range,
                timesteps=timesteps,
                generator=generator,
                latents=latents,
                return_dict=False,
                verbose=verbose,
                step_func=step_func,
            )

        generated_text = generated_text.replace("<|im_end|>", "")
        if not return_dict:
            return generated_text, images
        else:
            return OmniGen2PipelineOutput(text=generated_text, images=images)

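A rough chat-mode sketch (the repository id and names are assumptions for illustration, not defined by this commit):

    import torch
    from omnigen2.pipelines.omnigen2.pipeline_omnigen2_chat import OmniGen2ChatPipeline

    pipe = OmniGen2ChatPipeline.from_pretrained("OmniGen2/OmniGen2", torch_dtype=torch.bfloat16)  # assumed repo id
    pipe = pipe.to("cuda")
    out = pipe(
        prompt="Draw a small wooden cabin in a snowy forest at night",
        num_inference_steps=28,
        text_guidance_scale=4.0,
    )
    print(out.text)                  # the MLLM's textual reply for this turn
    if out.images is not None:       # images is None when the reply was text-only
        out.images[0].save("cabin.png")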
| 709 |
+
def processing(
|
| 710 |
+
self,
|
| 711 |
+
latents,
|
| 712 |
+
ref_latents,
|
| 713 |
+
prompt_embeds,
|
| 714 |
+
freqs_cis,
|
| 715 |
+
negative_prompt_embeds,
|
| 716 |
+
prompt_attention_mask,
|
| 717 |
+
negative_prompt_attention_mask,
|
| 718 |
+
num_inference_steps,
|
| 719 |
+
timesteps,
|
| 720 |
+
device,
|
| 721 |
+
dtype,
|
| 722 |
+
verbose,
|
| 723 |
+
step_func=None
|
| 724 |
+
):
|
| 725 |
+
batch_size = latents.shape[0]
|
| 726 |
+
|
| 727 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 728 |
+
self.scheduler,
|
| 729 |
+
num_inference_steps,
|
| 730 |
+
device,
|
| 731 |
+
timesteps,
|
| 732 |
+
num_tokens=latents.shape[-2] * latents.shape[-1]
|
| 733 |
+
)
|
| 734 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 735 |
+
self._num_timesteps = len(timesteps)
|
| 736 |
+
|
| 737 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 738 |
+
for i, t in enumerate(timesteps):
|
| 739 |
+
model_pred = self.predict(
|
| 740 |
+
t=t,
|
| 741 |
+
latents=latents,
|
| 742 |
+
prompt_embeds=prompt_embeds,
|
| 743 |
+
freqs_cis=freqs_cis,
|
| 744 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 745 |
+
ref_image_hidden_states=ref_latents,
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
| 749 |
+
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
| 750 |
+
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
| 751 |
+
model_pred_ref = self.predict(
|
| 752 |
+
t=t,
|
| 753 |
+
latents=latents,
|
| 754 |
+
prompt_embeds=negative_prompt_embeds,
|
| 755 |
+
freqs_cis=freqs_cis,
|
| 756 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
| 757 |
+
ref_image_hidden_states=ref_latents,
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
if image_guidance_scale != 1:
|
| 761 |
+
model_pred_uncond = self.predict(
|
| 762 |
+
t=t,
|
| 763 |
+
latents=latents,
|
| 764 |
+
prompt_embeds=negative_prompt_embeds,
|
| 765 |
+
freqs_cis=freqs_cis,
|
| 766 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
| 767 |
+
ref_image_hidden_states=None,
|
| 768 |
+
)
|
| 769 |
+
else:
|
| 770 |
+
model_pred_uncond = torch.zeros_like(model_pred)
|
| 771 |
+
|
| 772 |
+
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
|
| 773 |
+
text_guidance_scale * (model_pred - model_pred_ref)
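                    # Nested CFG: starting from the fully unconditional prediction, image guidance
                    # pulls toward the reference-conditioned output, then text guidance pulls from
                    # there toward the fully conditioned output. With image_guidance_scale == 1 the
                    # image term collapses and this reduces to standard classifier-free guidance.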
                elif text_guidance_scale > 1.0:
                    model_pred_uncond = self.predict(
                        t=t,
                        latents=latents,
                        prompt_embeds=negative_prompt_embeds,
                        freqs_cis=freqs_cis,
                        prompt_attention_mask=negative_prompt_attention_mask,
                        ref_image_hidden_states=None,
                    )
                    model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)

                latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]

                latents = latents.to(dtype=dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if step_func is not None:
                    step_func(i, self._num_timesteps)

        latents = latents.to(dtype=dtype)
        if self.vae.config.scaling_factor is not None:
            latents = latents / self.vae.config.scaling_factor
        if self.vae.config.shift_factor is not None:
            latents = latents + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]

        return image

    def predict(
        self,
        t,
        latents,
        prompt_embeds,
        freqs_cis,
        prompt_attention_mask,
        ref_image_hidden_states,
    ):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        batch_size, num_channels_latents, height, width = latents.shape

        optional_kwargs = {}
        if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
            optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states

        model_pred = self.transformer(
            latents,
            timestep,
            prompt_embeds,
            freqs_cis,
            prompt_attention_mask,
            **optional_kwargs,
        )
        return model_pred
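
The nested guidance update in `processing` is easy to sanity-check in isolation. Below is a minimal, self-contained sketch (plain PyTorch on dummy tensors; `combine_guidance` is an illustrative helper, not part of the pipeline) showing that it collapses to standard classifier-free guidance when the image branch is disabled:

import torch

def combine_guidance(pred_cond, pred_ref, pred_uncond, text_scale, image_scale):
    # Same arithmetic as the dual-guidance update in `processing`:
    # uncond -> reference-conditioned -> fully conditioned.
    return (
        pred_uncond
        + image_scale * (pred_ref - pred_uncond)
        + text_scale * (pred_cond - pred_ref)
    )

cond, uncond = torch.randn(1, 16, 8, 8), torch.randn(1, 16, 8, 8)
# With image_scale == 1 and pred_ref == pred_uncond, the nested form reduces to
# uncond + text_scale * (cond - uncond), i.e. vanilla CFG.
nested = combine_guidance(cond, uncond, uncond, text_scale=5.0, image_scale=1.0)
plain = uncond + 5.0 * (cond - uncond)
assert torch.allclose(nested, plain)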
omnigen2/pipelines/pipeline_utils.py
ADDED
@@ -0,0 +1,62 @@
import torch


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """Get pipeline embeds for prompts longer than the max length of the pipeline.

    :param pipeline: pipeline exposing `tokenizer` and `text_encoder`
    :param prompt: the (possibly long) prompt string
    :param negative_prompt: optional negative prompt string
    :param device: device to place the tensors on
    :return: tuple of (prompt embeddings, negative prompt embeddings or None)
    """
    max_length = pipeline.tokenizer.model_max_length

    # Tokenize the full prompt without truncation; it is encoded below in windows
    # of `max_length` tokens and the embeddings are concatenated along dim=1.
    inputs = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="longest")
    input_ids = inputs.input_ids.to(device)
    input_attention_mask = inputs.attention_mask.to(device)
    shape_max_length = input_ids.shape[-1]

    if negative_prompt is not None:
        # Pad/truncate the negative prompt to the same total length so the chunks align.
        negative_inputs = pipeline.tokenizer(
            negative_prompt, truncation=True, padding="max_length",
            max_length=shape_max_length, return_tensors="pt",
        )
        negative_ids = negative_inputs.input_ids.to(device)
        negative_attention_mask = negative_inputs.attention_mask.to(device)

    concat_embeds = []
    neg_embeds = []
    use_attention_mask = getattr(pipeline.text_encoder.config, "use_attention_mask", False)
    for i in range(0, shape_max_length, max_length):
        # Attention masks come from the tokenizer output and are sliced in step with the ids.
        attention_mask = input_attention_mask[:, i : i + max_length] if use_attention_mask else None
        concat_embeds.append(
            pipeline.text_encoder(input_ids[:, i : i + max_length], attention_mask=attention_mask)[0]
        )

        if negative_prompt is not None:
            attention_mask = negative_attention_mask[:, i : i + max_length] if use_attention_mask else None
            neg_embeds.append(
                pipeline.text_encoder(negative_ids[:, i : i + max_length], attention_mask=attention_mask)[0]
            )

    concat_embeds = torch.cat(concat_embeds, dim=1)

    if negative_prompt is not None:
        neg_embeds = torch.cat(neg_embeds, dim=1)
    else:
        neg_embeds = None

    return concat_embeds, neg_embeds
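
To see the chunking that `get_pipeline_embeds` performs, here is a minimal runnable sketch with a toy encoder (an `nn.Embedding` standing in for `pipeline.text_encoder`; all names below are illustrative, not pipeline API):

import torch
import torch.nn as nn

max_length = 4                      # stand-in for tokenizer.model_max_length
encoder = nn.Embedding(100, 8)      # toy stand-in for the text encoder

input_ids = torch.randint(0, 100, (1, 10))  # a "long prompt": 10 tokens > max_length
chunks = [
    encoder(input_ids[:, i : i + max_length])
    for i in range(0, input_ids.shape[-1], max_length)
]
embeds = torch.cat(chunks, dim=1)
assert embeds.shape == (1, 10, 8)   # windows of 4, 4, 2 tokens are re-joined losslessly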
omnigen2/schedulers/__init__.py
ADDED
File without changes
omnigen2/schedulers/scheduling_dpmsolver_multistep.py
ADDED
@@ -0,0 +1,1052 @@
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


if is_scipy_available():
    import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
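
# Because each beta is defined via the ratio alpha_bar(t2) / alpha_bar(t1), the cumulative
# product of (1 - beta) over the discrete grid reproduces alpha_bar_fn (up to the max_beta
# clamp), so the discrete schedule tracks the chosen continuous one.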


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
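
# After rescaling, alphas_cumprod[-1] == 0: the last training timestep carries zero SNR,
# so sampling can start from pure noise instead of a slightly informative sigma_max.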


class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
            paper, and the `dpmsolver++` type implements the algorithms in the
            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
        solver_type (`str`, defaults to `midpoint`):
            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        euler_at_final (`bool`, defaults to `False`):
            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
            richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
            steps, but sometimes may result in blurring.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
        dynamic_time_shift (`bool`, defaults to `True`):
            Whether to shift the timestep schedule according to the token count of the latents, so that higher
            resolutions spend more steps in the high-noise region (see `set_timesteps`).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        final_sigmas_type: str = "zero",
        dynamic_time_shift: bool = True,
    ):
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        # if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
        #     raise ValueError(
        #         f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
        #     )

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
        num_tokens: Optional[int] = None,
    ):
        if timesteps is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
            if self.config.dynamic_time_shift and num_tokens is not None:
                m = np.sqrt(num_tokens) / 40  # when input resolution is 320 * 320, m = 1; when input resolution is 1024 * 1024, m = 3.2
                timesteps = timesteps / (m - m * timesteps + timesteps)
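                # Algebraically this is t' = t / (m + (1 - m) * t): the identity map for
                # m == 1, and for m > 1 it compresses t toward 0 so that sigma = 1 - t
                # stays high for longer, giving larger latents more high-noise steps.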

        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
        sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)])

        self.sigmas = sigmas
        self.timesteps = timesteps

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = 1 - sigma
        sigma_t = sigma

        return alpha_t, sigma_t
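    # Note: this is the rectified-flow parameterization x_t = (1 - sigma) * x0 + sigma * noise,
    # so alpha_t = 1 - sigma_t here, unlike the VP parameterization where alpha_t = sqrt(alphas_cumprod).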

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Lu et al. (2022)."""

        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()

        rho = 1.0  # 1.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = lambda_min ** (1 / rho)
        max_inv_rho = lambda_max ** (1 / rho)
        lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return lambdas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            elif self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample + sigma_t * model_output
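                # With x_t = (1 - sigma) * x0 + sigma * noise and a velocity target
                # v = x0 - noise, the data prediction is x0 = x_t + sigma * v.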
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
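        # h is the step size in log-SNR (lambda) space; the exponential-integrator updates
        # below are exact when the converted model output is constant over [lambda_s, lambda_t].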
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from the learned diffusion model at the current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
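        # D0 is the current model output; D1 is a backward finite-difference estimate of
        # its first derivative with respect to lambda, built from the previous output.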
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from the learned diffusion model at the current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """

        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
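        # D1 and D2 are divided-difference estimates of the first and second derivatives of
        # the model output with respect to lambda, built from the two previous outputs.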
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
|
| 968 |
+
self.model_outputs[-1] = model_output
|
| 969 |
+
|
| 970 |
+
# Upcast to avoid precision issues when computing prev_sample
|
| 971 |
+
sample = sample.to(torch.float32)
|
| 972 |
+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
| 973 |
+
noise = randn_tensor(
|
| 974 |
+
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
| 975 |
+
)
|
| 976 |
+
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
| 977 |
+
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
| 978 |
+
else:
|
| 979 |
+
noise = None
|
| 980 |
+
|
| 981 |
+
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
| 982 |
+
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
| 983 |
+
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
| 984 |
+
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
| 985 |
+
else:
|
| 986 |
+
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
|
| 987 |
+
|
| 988 |
+
if self.lower_order_nums < self.config.solver_order:
|
| 989 |
+
self.lower_order_nums += 1
|
| 990 |
+
|
| 991 |
+
# Cast sample back to expected dtype
|
| 992 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
| 993 |
+
|
| 994 |
+
# upon completion increase step index by one
|
| 995 |
+
self._step_index += 1
|
| 996 |
+
|
| 997 |
+
if not return_dict:
|
| 998 |
+
return (prev_sample,)
|
| 999 |
+
|
| 1000 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
| 1001 |
+
|
| 1002 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 1003 |
+
"""
|
| 1004 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
| 1005 |
+
current timestep.
|
| 1006 |
+
|
| 1007 |
+
Args:
|
| 1008 |
+
sample (`torch.Tensor`):
|
| 1009 |
+
The input sample.
|
| 1010 |
+
|
| 1011 |
+
Returns:
|
| 1012 |
+
`torch.Tensor`:
|
| 1013 |
+
A scaled input sample.
|
| 1014 |
+
"""
|
| 1015 |
+
return sample
|
| 1016 |
+
|
| 1017 |
+
def add_noise(
|
| 1018 |
+
self,
|
| 1019 |
+
original_samples: torch.Tensor,
|
| 1020 |
+
noise: torch.Tensor,
|
| 1021 |
+
timesteps: torch.IntTensor,
|
| 1022 |
+
) -> torch.Tensor:
|
| 1023 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
| 1024 |
+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 1025 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
| 1026 |
+
# mps does not support float64
|
| 1027 |
+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
| 1028 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
| 1029 |
+
else:
|
| 1030 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
| 1031 |
+
timesteps = timesteps.to(original_samples.device)
|
| 1032 |
+
|
| 1033 |
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
| 1034 |
+
if self.begin_index is None:
|
| 1035 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
| 1036 |
+
elif self.step_index is not None:
|
| 1037 |
+
# add_noise is called after first denoising step (for inpainting)
|
| 1038 |
+
step_indices = [self.step_index] * timesteps.shape[0]
|
| 1039 |
+
else:
|
| 1040 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
| 1041 |
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
| 1042 |
+
|
| 1043 |
+
sigma = sigmas[step_indices].flatten()
|
| 1044 |
+
while len(sigma.shape) < len(original_samples.shape):
|
| 1045 |
+
sigma = sigma.unsqueeze(-1)
|
| 1046 |
+
|
| 1047 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 1048 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
| 1049 |
+
return noisy_samples
|
| 1050 |
+
|
| 1051 |
+
def __len__(self):
|
| 1052 |
+
return self.config.num_train_timesteps
|
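
For context, this is the kind of denoising loop the multistep DPM-Solver scheduler above is designed for. A minimal sketch, assuming a constructed scheduler instance with the diffusers-style API shown above (`set_timesteps`, `step`, `timesteps`) and a `denoiser` callable standing in for the diffusion transformer; neither name is defined in this file:

import torch

def sample_loop(scheduler, denoiser, shape, num_inference_steps=20, device="cpu"):
    # Populate scheduler.timesteps/sigmas and reset the internal step index.
    scheduler.set_timesteps(num_inference_steps, device=device)
    sample = torch.randn(shape, device=device)  # start from pure noise
    for t in scheduler.timesteps:
        model_output = denoiser(sample, t)
        # step() converts the model output, pushes it into the history buffer
        # (model_outputs), and applies a first-, second-, or third-order update
        # depending on solver_order and the current step index.
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample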
omnigen2/schedulers/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,229 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler for flow matching.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        dynamic_time_shift (`bool`, defaults to `True`):
            Whether to shift the timestep schedule based on the number of image tokens, so that larger images
            receive a stronger shift toward the high-noise region.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        dynamic_time_shift: bool = True
    ):
        timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]

        self.timesteps = timesteps

        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self._timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    #     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[float]] = None,
        num_tokens: Optional[int] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """

        if timesteps is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
            if self.config.dynamic_time_shift and num_tokens is not None:
                m = np.sqrt(num_tokens) / 40  # when input resolution is 320 * 320, m = 1; when input resolution is 1024 * 1024, m = 3.2
                timesteps = timesteps / (m - m * timesteps + timesteps)

        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
        _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])

        self.timesteps = timesteps
        self._timesteps = _timesteps
        self._step_index = None
        self._begin_index = None

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted velocity).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        t = self._timesteps[self.step_index]
        t_next = self._timesteps[self.step_index + 1]

        prev_sample = sample + (t_next - t) * model_output

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
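
The distinctive piece of this scheduler is the dynamic time shift inside `set_timesteps`. A standalone sketch of just that transform, using only the formula from the file (m = sqrt(num_tokens) / 40, then t -> t / (m - m * t + t)):

import numpy as np

def shifted_timesteps(num_inference_steps: int, num_tokens: int) -> np.ndarray:
    t = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
    m = np.sqrt(num_tokens) / 40  # m = 1 at 320x320 latents, m = 3.2 at 1024x1024
    # Larger m compresses the schedule toward small t, spending more of the
    # step budget in the high-noise region for larger images.
    return t / (m - m * t + t)

# e.g. shifted_timesteps(4, 16384) -> approximately [0.0, 0.094, 0.238, 0.484]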
omnigen2/taylorseer_utils/__init__.py ADDED
@@ -0,0 +1,51 @@
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/taylorseer_utils/__init__.py

from typing import Dict
import torch
import math

def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute finite-difference derivative approximations and update the Taylor cache.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    :param feature: Feature tensor observed at the current (activated) step
    """
    difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2]

    updated_taylor_factors = {}
    updated_taylor_factors[0] = feature

    for i in range(cache_dic['max_order']):
        if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2):
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance
        else:
            break

    cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors

def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Evaluate the Taylor expansion to approximate the feature at the current step.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    x = current['step'] - current['activated_steps'][-1]
    # x = current['t'] - current['activated_times'][-1]
    output = 0

    for i in range(len(cache_dic['cache'][-1][current['stream']][current['layer']][current['module']])):
        output += (1 / math.factorial(i)) * cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i] * (x ** i)

    return output

def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if (current['step'] == 0) and (cache_dic['taylor_cache']):
        cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {}
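
To make the cache layout concrete, here is a toy run of the three functions above outside any pipeline. Every key and value is fabricated for the example, and `activated_steps` is seeded with a dummy earlier step purely so the distance computation in `derivative_approximation` has two activations to diff:

import torch

cache_dic = {
    "max_order": 2,
    "first_enhance": 1,
    "taylor_cache": True,
    "cache": {-1: {"s": {0: {"attn": {}}}}},
}
current = {"stream": "s", "layer": 0, "module": "attn",
           "step": 0, "activated_steps": [-1, 0]}

taylor_cache_init(cache_dic, current)                        # reset storage at step 0
derivative_approximation(cache_dic, current, torch.ones(4))  # caches the order-0 term

current["step"] = 2
current["activated_steps"].append(2)
derivative_approximation(cache_dic, current, torch.full((4,), 3.0))
# cache now holds {0: feature, 1: finite-difference first derivative (3 - 1) / 2 = 1}

current["step"] = 3  # one step past the last activation
approx = taylor_formula(cache_dic, current)  # 3.0 + 1.0 * (3 - 2) -> tensor of 4.0s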
omnigen2/training_utils.py ADDED
@@ -0,0 +1,645 @@
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.models import UNet2DConditionModel
from diffusers.schedulers import SchedulerMixin
from diffusers.utils import (
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    deprecate,
    is_peft_available,
    is_torch_npu_available,
    is_torchvision_available,
    is_transformers_available,
)


if is_transformers_available():
    import transformers

    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
        import deepspeed

if is_peft_available():
    from peft import set_peft_model_state_dict

if is_torchvision_available():
    from torchvision import transforms

if is_torch_npu_available():
    import torch_npu  # noqa: F401


def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.

    Returns:
        `None`
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)
        # ^^ safe to call this function even if cuda is not available


def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    for the given timesteps using the provided noise scheduler.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
            the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr


def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    if interpolation_type == "bilinear":
        interpolation_mode = transforms.InterpolationMode.BILINEAR
    elif interpolation_type == "bicubic":
        interpolation_mode = transforms.InterpolationMode.BICUBIC
    elif interpolation_type == "box":
        interpolation_mode = transforms.InterpolationMode.BOX
    elif interpolation_type == "nearest":
        interpolation_mode = transforms.InterpolationMode.NEAREST
    elif interpolation_type == "nearest_exact":
        interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT
    elif interpolation_type == "hamming":
        interpolation_mode = transforms.InterpolationMode.HAMMING
    elif interpolation_type == "lanczos":
        interpolation_mode = transforms.InterpolationMode.LANCZOS
    else:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    return interpolation_mode


def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
    DREAM helps align training with sampling to make training more efficient and accurate at the cost of an extra
    forward step without gradients.

    Args:
        `unet`: The state unet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    pred = None
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return _noisy_latents, _target


def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Returns:
        A state dict containing just the LoRA parameters.
    """
    lora_state_dict = {}

    for name, module in unet.named_modules():
        if hasattr(module, "set_lora_layer"):
            lora_layer = getattr(module, "lora_layer")
            if lora_layer is not None:
                current_lora_layer_sd = lora_layer.state_dict()
                for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                    # The matrix name can either be "down" or "up".
                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

    return lora_state_dict


def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model whose parameters will be cast.
        dtype: The data type to which the model parameters will be cast.
    """
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)


def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """

    text_encoder_state_dict = {
        f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
    }
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes the loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


def free_memory():
    """
    Runs garbage collection. Then clears the cache of the available accelerator.
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of model weights.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext()

        if self.foreach:
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager:
                    if param.requires_grad:
                        # print(f"{s_param.shape=} {param.shape=}")
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Saves the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method. After
        validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """

        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
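
A minimal sketch of the store/copy_to/restore pattern `EMAModel` is built for, with a throwaway linear layer standing in for the real transformer (the optimizer and loss are omitted):

import torch

model = torch.nn.Linear(8, 8)
ema = EMAModel(model.parameters(), decay=0.9999)

for _ in range(100):
    # ... forward/backward/optimizer.step() would go here ...
    ema.step(model.parameters())   # blend live weights into the shadow copies

ema.store(model.parameters())      # stash the live weights
ema.copy_to(model.parameters())    # swap in the EMA weights for evaluation
# ... run validation with the smoothed weights ...
ema.restore(model.parameters())    # put the live weights back and keep training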
omnigen2/transport/__init__.py ADDED
@@ -0,0 +1,74 @@
from .transport import ModelType, PathType, Sampler, Transport, WeightType


def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    snr_type="uniform",
    do_shift=True,
    seq_len=1024,  # corresponding to 512x512
    dynamic_time_shift: bool = False,
    time_shift_version: str = "v1",
):
    """Function for creating a Transport object.
    **Note**: model prediction defaults to velocity.
    Args:
    - path_type: type of path to use; defaults to linear
    - prediction: model prediction type; "velocity" (default), "noise", or "score"
    - loss_weight: loss weighting; "velocity", "likelihood", or None
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        snr_type=snr_type,
        do_shift=do_shift,
        seq_len=seq_len,
        dynamic_time_shift=dynamic_time_shift,
        time_shift_version=time_shift_version,
    )

    return state
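A hypothetical usage sketch for `create_transport` as defined above; the argument values are illustrative, not taken from the repository's training configs:

transport = create_transport(
    path_type="Linear",
    prediction="velocity",  # velocity + Linear path => train_eps = sample_eps = 0
    snr_type="uniform",
    do_shift=True,
    seq_len=1024,           # token count corresponding to 512x512
)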
omnigen2/transport/dpm_solver.py
ADDED
@@ -0,0 +1,1386 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
import os

import torch
from tqdm import tqdm


class NoiseScheduleFlow:
    def __init__(
        self,
        schedule="discrete_flow",
    ):
        """Create a wrapper class for the forward SDE (EDM type)."""
        self.T = 1
        self.t0 = 0.001
        self.schedule = schedule  # ['continuous', 'discrete_flow']
        self.total_N = 1000

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        return torch.log(self.marginal_alpha(t))

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return 1 - t

    @staticmethod
    def marginal_std(t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return t

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = torch.log(self.marginal_std(t))
        return log_mean_coeff - log_std

    @staticmethod
    def inverse_lambda(lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        return torch.exp(-lamb)

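# Editorial sketch (not part of the diff): a quick self-contained check of the
# rectified-flow schedule above. With alpha_t = 1 - t and sigma_t = t, the
# half-logSNR is lambda_t = log((1 - t) / t); its exact inverse is
# t = sigmoid(-lambda_t), while exp(-lambda_t) = t / (1 - t), which matches t
# only in the small-t limit.
import torch

t = torch.tensor([0.001, 0.1, 0.5, 0.9])
lam = torch.log((1 - t) / t)   # what marginal_lambda(t) evaluates to here
print(torch.sigmoid(-lam))     # exact inverse: recovers t
print(torch.exp(-lam))         # what inverse_lambda returns: t / (1 - t)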
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    interval_guidance=[0, 1.0],
    classifier_fn=None,
    classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function into a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

    1. "noise": noise prediction model. (Trained by predicting noise).

    2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

    3. "v": velocity prediction model. (Trained by predicting the velocity).
        Its derivation is detailed in Appendix D of [1], and it is used in Imagen-Video [2].

        [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
            arXiv preprint arXiv:2202.00512 (2022).
        [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
            arXiv preprint arXiv:2210.02303 (2022).

    4. "score": marginal score function. (Trained by denoising score matching).
        Note that the score function and the noise prediction model follow a simple relationship:
        ```
        noise(x_t, t) = -sigma_t * score(x_t, t)
        ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
    1. "uncond": unconditional sampling by DPMs.
        The input `model` has the following format:
        ``
        model(x, t_input, **model_kwargs) -> noise | x_start | v | score
        ``

    2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
        The input `model` has the following format:
        ``
        model(x, t_input, **model_kwargs) -> noise | x_start | v | score
        ``

        The input `classifier_fn` has the following format:
        ``
        classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
        ``

        [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
            in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

    3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
        The input `model` has the following format:
        ``
        model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
        ``
        And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

        [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
            arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
    ``
    def model_fn(x, t_continuous) -> noise:
        t_input = get_model_input_time(t_continuous)
        return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
            "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
            "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
            Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
            Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N
        elif noise_schedule.schedule == "discrete_flow":
            return t_continuous * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, x.dim()) * output
        elif model_type == "flow":
            _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            try:
                noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output + x
            except Exception:  # some wrapped models return a tuple; fall back to its first element
                noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0] + x
            return noise

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        guidance_tp = guidance_type
        if guidance_tp == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_tp == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
        elif guidance_tp == "classifier-free":
            if (
                guidance_scale == 1.0
                or unconditional_condition is None
                or not (interval_guidance[0] < t_continuous[0] < interval_guidance[1])
            ):
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                try:
                    noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                except Exception:  # tuple-returning models: take the first element
                    noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v", "score", "flow"]
    assert guidance_type in [
        "uncond",
        "classifier",
        "classifier-free",
    ]
    return model_fn

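# Editorial sketch (not part of the diff): wrapping a toy velocity-field model
# with classifier-free guidance via model_wrapper above. The toy model and the
# conditioning tensors are placeholders, not the OmniGen2 transformer.
ns = NoiseScheduleFlow(schedule="discrete_flow")

def toy_model(x, t_input, cond=None):
    return torch.zeros_like(x)  # stands in for a real network

guided_fn = model_wrapper(
    toy_model,
    ns,
    model_type="flow",
    guidance_type="classifier-free",
    condition=torch.ones(2, 8),
    unconditional_condition=torch.zeros(2, 8),
    guidance_scale=4.0,
)
eps = guided_fn(torch.randn(2, 4), torch.full((2,), 0.5))  # guided noise at t = 0.5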
class DPM_Solver:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="dpmsolver++",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.0,
        dynamic_thresholding_ratio=0.995,
    ):
        """Construct a DPM-Solver.

        We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).

        We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
        can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
        dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
        DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
        DPMs (such as stable-diffusion).

        To support advanced algorithms in image-to-image applications, we also support corrector functions for
        both x0 and xt.

        Args:
            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
                ``
                def model_fn(x, t_continuous):
                    return noise
                ``
                The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
            algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
            correcting_x0_fn: A `str` or a function with the following format:
                ```
                def correcting_x0_fn(x0, t):
                    x0_new = ...
                    return x0_new
                ```
                This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
                ```
                x0_pred = data_pred_model(xt, t)
                if correcting_x0_fn is not None:
                    x0_pred = correcting_x0_fn(x0_pred, t)
                xt_1 = update(x0_pred, xt, t)
                ```
                If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
            correcting_xt_fn: A function with the following format:
                ```
                def correcting_xt_fn(xt, t, step):
                    x_new = ...
                    return x_new
                ```
                This function is to correct the intermediate samples xt at each sampling step. e.g.,
                ```
                xt = ...
                xt = correcting_xt_fn(xt, t, step)
                ```
            thresholding_max_val: A `float`. The max value for thresholding.
                Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
            dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
                Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.

        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
            Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
            with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
        """
        self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["dpmsolver", "dpmsolver++"]
        self.algorithm_type = algorithm_type
        if correcting_x0_fn == "dynamic_thresholding":
            self.correcting_x0_fn = self.dynamic_thresholding_fn
        else:
            self.correcting_x0_fn = correcting_x0_fn
        self.correcting_xt_fn = correcting_xt_fn
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.thresholding_max_val = thresholding_max_val
        self.register_progress_bar()

    def register_progress_bar(self, progress_fn=None):
        """
        Register a progress bar callback function.

        Args:
            progress_fn: Callback function that takes the current step and total steps as parameters.
        """
        self.progress_fn = progress_fn if progress_fn is not None else lambda step, total: None

    def update_progress(self, step, total_steps):
        """
        Update sampling progress.

        Args:
            step: Current step number.
            total_steps: Total number of steps.
        """
        if hasattr(self, "progress_fn"):
            try:
                self.progress_fn(step / total_steps, desc=f"Generating {step}/{total_steps}")
            except Exception:  # fall back to a plain (step, total_steps) callback
                self.progress_fn(step, total_steps)
        else:
            # If no progress_fn is registered, do nothing.
            pass

    def dynamic_thresholding_fn(self, x0, t):
        """
        The dynamic thresholding method.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

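# Editorial sketch (not part of the diff): the dynamic thresholding above in
# isolation. Each sample is clamped to its p-quantile of |x0| (floored at
# thresholding_max_val = 1.0) and rescaled, keeping pixel-space samples in
# [-1, 1] under large guidance scales. Tensor bounds in torch.clamp need a
# reasonably recent PyTorch.
import torch

x0 = 3.0 * torch.randn(2, 16)
s = torch.quantile(x0.abs().reshape(2, -1), 0.995, dim=1)
s = torch.maximum(s, torch.ones_like(s)).reshape(2, 1)
x0 = torch.clamp(x0, -s, s) / s
assert x0.abs().max() <= 1.0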
    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with corrector).
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0, t)
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.algorithm_type == "dpmsolver++":
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device, shift=1.0):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support four types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
                - 'time_uniform_flow': uniform time mapped through the flow time shift (controlled by `shift`).
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: An `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        elif skip_type == "time_uniform_flow":
            betas = torch.linspace(t_T, t_0, N + 1).to(device)
            sigmas = 1.0 - betas
            sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
            return sigmas
        else:
            raise ValueError(
                f"Unsupported skip_type {skip_type}, need to be 'logSNR', 'time_uniform', 'time_quadratic' or 'time_uniform_flow'"
            )
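    # Editorial sketch (not part of the diff): the 'time_uniform_flow' spacing
    # above in isolation. Uniform betas are mapped to sigmas = 1 - betas and
    # then time-shifted; for shift > 1 the spacing becomes finer near sigma = 1
    # (the high-noise end).
    # import torch
    # t_T, t_0, N, shift = 1.0, 0.001, 5, 3.0
    # betas = torch.linspace(t_T, t_0, N + 1)
    # sigmas = 1.0 - betas
    # sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
    # print(sigmas)  # descends from ~1 to 0, with finer steps at the noisy end
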
    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.

        We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named "DPM-Solver-fast".
        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
            - If order == 1:
                We take `steps` of DPM-Solver-1 (i.e. DDIM).
            - If order == 2:
                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If order == 3:
                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

        ============================================
        Args:
            order: An `int`. The max order for the solver (2 or 3).
            steps: An `int`. The total number of function evaluations (NFE).
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            device: A torch device.
        Returns:
            orders: A list of the solver order of each step.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3] * (K - 1) + [1]
            else:
                orders = [3] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2] * K
            else:
                K = steps // 2 + 1
                orders = [2] * (K - 1) + [1]
        elif order == 1:
            K = 1
            orders = [1] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == "logSNR":
            # To reproduce the results in the DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
                torch.cumsum(torch.tensor([0] + orders), 0).to(device)
            ]
        return timesteps_outer, orders
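    # Editorial sketch (not part of the diff): the order decomposition above for
    # a concrete budget, assuming `dpm` is an instantiated DPM_Solver. With
    # steps=6 and order=3: K = 6 // 3 + 1 = 3 and steps % 3 == 0, so
    # orders = [3, 2, 1]: one third-order, one second-order, and one first-order
    # (DDIM) step, spending exactly 6 model evaluations.
    # timesteps, orders = dpm.get_orders_and_timesteps_for_singlestep_solver(
    #     steps=6, order=3, skip_type="time_uniform", t_T=1.0, t_0=0.001, device="cpu"
    # )
    # assert orders == [3, 2, 1]
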
    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
        """
        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        if self.algorithm_type == "dpmsolver++":
            phi_1 = torch.expm1(-h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
            if return_intermediate:
                return x_t, {"model_s": model_s}
            else:
                return x_t
        else:
            phi_1 = torch.expm1(h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
            if return_intermediate:
                return x_t, {"model_s": model_s}
            else:
                return x_t

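    # Editorial sketch (not part of the diff): a numerical check that the
    # dpmsolver++ first-order update above is exact for straight-line
    # (rectified-flow) trajectories under alpha = 1 - t, sigma = t, i.e. it
    # reduces to interpolation toward the predicted x0.
    # import torch
    # s, t = torch.tensor(0.8), torch.tensor(0.4)
    # x, x0 = torch.tensor(2.0), torch.tensor(1.0)   # state at s and predicted data
    # def lam(u):
    #     return torch.log((1 - u) / u)              # half-logSNR of this schedule
    # phi_1 = torch.expm1(-(lam(t) - lam(s)))
    # x_t = (t / s) * x - (1 - t) * phi_1 * x0       # the dpmsolver++ branch above
    # noise = (x - (1 - s) * x0) / s                 # noise implied by the line
    # assert torch.allclose(x_t, (1 - t) * x0 + t * noise)
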
    def singlestep_dpm_solver_second_update(
        self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
    ):
        """
        Singlestep solver DPM-Solver-2 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            r1: A `float`. The hyperparameter of the second-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpmsolver", "taylor"]:
            raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
        log_alpha_s, log_alpha_s1, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(s1),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

        if self.algorithm_type == "dpmsolver++":
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == "dpmsolver":
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
                )
            elif solver_type == "taylor":
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
                )
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_1 = torch.expm1(h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == "dpmsolver":
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s) * x
                    - (sigma_t * phi_1) * model_s
                    - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
                )
            elif solver_type == "taylor":
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s) * x
                    - (sigma_t * phi_1) * model_s
                    - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
                )
        if return_intermediate:
            return x_t, {"model_s": model_s, "model_s1": model_s1}
        else:
            return x_t

    def singlestep_dpm_solver_third_update(
        self,
        x,
        s,
        t,
        r1=1.0 / 3.0,
        r2=2.0 / 3.0,
        model_s=None,
        model_s1=None,
        return_intermediate=False,
        solver_type="dpmsolver",
    ):
        """
        Singlestep solver DPM-Solver-3 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            r1: A `float`. The hyperparameter of the third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpmsolver", "taylor"]:
            raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
        if r1 is None:
            r1 = 1.0 / 3.0
        if r2 is None:
            r2 = 2.0 / 3.0
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(s1),
            ns.marginal_log_mean_coeff(s2),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_s1, sigma_s2, sigma_t = (
            ns.marginal_std(s),
            ns.marginal_std(s1),
            ns.marginal_std(s2),
            ns.marginal_std(t),
        )
        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)

        if self.algorithm_type == "dpmsolver++":
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
            phi_2 = phi_1 / h + 1.0
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
                (sigma_s2 / sigma_s) * x
                - (alpha_s2 * phi_12) * model_s
                + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == "dpmsolver":
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
                )
            elif solver_type == "taylor":
                D1_0 = (1.0 / r1) * (model_s1 - model_s)
                D1_1 = (1.0 / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    + (alpha_t * phi_2) * D1
                    - (alpha_t * phi_3) * D2
                )
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_12 = torch.expm1(r2 * h)
            phi_1 = torch.expm1(h)
            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
            phi_2 = phi_1 / h - 1.0
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
                (torch.exp(log_alpha_s2 - log_alpha_s)) * x
                - (sigma_s2 * phi_12) * model_s
                - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == "dpmsolver":
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_s)) * x
                    - (sigma_t * phi_1) * model_s
                    - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
                )
            elif solver_type == "taylor":
                D1_0 = (1.0 / r1) * (model_s1 - model_s)
                D1_1 = (1.0 / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_s)) * x
                    - (sigma_t * phi_1) * model_s
                    - (sigma_t * phi_2) * D1
                    - (sigma_t * phi_3) * D2
                )

        if return_intermediate:
            return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
        else:
            return x_t

    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpmsolver", "taylor"]:
            raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
        ns = self.noise_schedule
        model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
        t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
        lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        if self.algorithm_type == "dpmsolver++":
            phi_1 = torch.expm1(-h)
            if solver_type == "dpmsolver":
                x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
            elif solver_type == "taylor":
                x_t = (
                    (sigma_t / sigma_prev_0) * x
                    - (alpha_t * phi_1) * model_prev_0
                    + (alpha_t * (phi_1 / h + 1.0)) * D1_0
                )
        else:
            phi_1 = torch.expm1(h)
            if solver_type == "dpmsolver":
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                    - (sigma_t * phi_1) * model_prev_0
                    - 0.5 * (sigma_t * phi_1) * D1_0
                )
            elif solver_type == "taylor":
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                    - (sigma_t * phi_1) * model_prev_0
                    - (sigma_t * (phi_1 / h - 1.0)) * D1_0
                )
        return x_t

    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_2),
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.algorithm_type == "dpmsolver++":
            phi_1 = torch.expm1(-h)
            phi_2 = phi_1 / h + 1.0
            phi_3 = phi_2 / h - 0.5
            x_t = (
                (sigma_t / sigma_prev_0) * x
                - (alpha_t * phi_1) * model_prev_0
                + (alpha_t * phi_2) * D1
                - (alpha_t * phi_3) * D2
            )
        else:
            phi_1 = torch.expm1(h)
            phi_2 = phi_1 / h - 1.0
            phi_3 = phi_2 / h - 0.5
            x_t = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * phi_1) * model_prev_0
                - (sigma_t * phi_2) * D1
                - (sigma_t * phi_3) * D2
            )
        return x_t

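    # Editorial sketch (not part of the diff): the finite differences formed by
    # the multistep updates above, with toy scalars in place of cached model
    # outputs. D1_0 and D1_1 are first differences in lambda-time, D1
    # extrapolates the slope to t_prev_0, and D2 estimates the curvature.
    # m2, m1, m0 = 0.0, 1.0, 4.0     # stand-ins for model_prev_2, _1, _0
    # r0, r1 = 0.5, 0.5              # step-size ratios h_0 / h and h_1 / h
    # D1_0 = (1.0 / r0) * (m0 - m1)  # 6.0
    # D1_1 = (1.0 / r1) * (m1 - m2)  # 2.0
    # D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)   # 8.0
    # D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)         # 4.0
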
    def singlestep_dpm_solver_update(
        self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
    ):
        """
        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
            r1: A `float`. The hyperparameter of the second-order or third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
        elif order == 2:
            return self.singlestep_dpm_solver_second_update(
                x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
            )
        elif order == 3:
            return self.singlestep_dpm_solver_third_update(
                x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
            )
        else:
            raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")

    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
        """
        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
        elif order == 2:
            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
        elif order == 3:
            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
        else:
            raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")

def dpm_solver_adaptive(
|
| 956 |
+
self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
|
| 957 |
+
):
|
| 958 |
+
"""
|
| 959 |
+
The adaptive step size solver based on singlestep DPM-Solver.
|
| 960 |
+
|
| 961 |
+
Args:
|
| 962 |
+
x: A pytorch tensor. The initial value at time `t_T`.
|
| 963 |
+
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
|
| 964 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
| 965 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
| 966 |
+
h_init: A `float`. The initial step size (for logSNR).
|
| 967 |
+
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
|
| 968 |
+
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
|
| 969 |
+
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
|
| 970 |
+
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
|
| 971 |
+
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
|
| 972 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
| 973 |
+
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
|
| 974 |
+
Returns:
|
| 975 |
+
x_0: A pytorch tensor. The approximated solution at time `t_0`.
|
| 976 |
+
|
| 977 |
+
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
|
| 978 |
+
"""
|
| 979 |
+
ns = self.noise_schedule
|
| 980 |
+
s = t_T * torch.ones((1,)).to(x)
|
| 981 |
+
lambda_s = ns.marginal_lambda(s)
|
| 982 |
+
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
|
| 983 |
+
h = h_init * torch.ones_like(s).to(x)
|
| 984 |
+
x_prev = x
|
| 985 |
+
nfe = 0
|
| 986 |
+
if order == 2:
|
| 987 |
+
r1 = 0.5
|
| 988 |
+
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
| 989 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
|
| 990 |
+
x, s, t, r1=r1, solver_type=solver_type, **kwargs
|
| 991 |
+
)
|
| 992 |
+
elif order == 3:
|
| 993 |
+
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
|
| 994 |
+
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
|
| 995 |
+
x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
|
| 996 |
+
)
|
| 997 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
|
| 998 |
+
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
|
| 999 |
+
)
|
| 1000 |
+
else:
|
| 1001 |
+
raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
|
| 1002 |
+
while torch.abs(s - t_0).mean() > t_err:
|
| 1003 |
+
t = ns.inverse_lambda(lambda_s + h)
|
| 1004 |
+
x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
| 1005 |
+
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
|
| 1006 |
+
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
|
| 1007 |
+
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
|
| 1008 |
+
E = norm_fn((x_higher - x_lower) / delta).max()
|
| 1009 |
+
if torch.all(E <= 1.0):
|
| 1010 |
+
x = x_higher
|
| 1011 |
+
s = t
|
| 1012 |
+
x_prev = x_lower
|
| 1013 |
+
lambda_s = ns.marginal_lambda(s)
|
| 1014 |
+
h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
|
| 1015 |
+
nfe += order
|
| 1016 |
+
print("adaptive solver nfe", nfe)
|
| 1017 |
+
return x
|
| 1018 |
+
|
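For intuition, the step-size controller used above can be exercised on its own. The sketch below is illustrative only (standalone toy values, not part of the diff): a step is accepted when the scaled error E is at most 1, and h is rescaled by theta * E^(-1/order), exactly as in `dpm_solver_adaptive` (the real method additionally clips h so it never overshoots lambda_0).

# --- illustrative sketch, not part of the diff ---
import torch

order, theta = 2, 0.9
h = torch.tensor([0.05])
for E in [2.0, 0.8, 0.2]:
    E = torch.tensor(E)
    accepted = bool(torch.all(E <= 1.0))  # same acceptance test as above
    h = theta * h * torch.float_power(E, -1.0 / order).float()
    print(f"E={E.item():.2f} accepted={accepted} next h={h.item():.4f}")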
    def add_noise(self, x, t, noise=None):
        """
        Compute the noised input xt = alpha_t * x + sigma_t * noise.

        Args:
            x: A `torch.Tensor` with shape `(batch_size, *shape)`.
            t: A `torch.Tensor` with shape `(t_size,)`.
        Returns:
            xt with shape `(t_size, batch_size, *shape)`.
        """
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        if noise is None:
            noise = torch.randn((t.shape[0], *x.shape), device=x.device)
        x = x.reshape((-1, *x.shape))
        xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
        if t.shape[0] == 1:
            return xt.squeeze(0)
        else:
            return xt

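The broadcasting used by `add_noise` can be checked with a standalone sketch. The alpha_t/sigma_t values below are hypothetical stand-ins; a real call reads them from `self.noise_schedule`.

# --- illustrative sketch, not part of the diff ---
import torch

x = torch.randn(4, 3, 8, 8)       # (batch_size, *shape)
t = torch.tensor([0.25, 0.75])    # (t_size,)
alpha_t = 1.0 - t                 # hypothetical schedule values
sigma_t = t
noise = torch.randn(t.shape[0], *x.shape)
xt = alpha_t.view(-1, 1, 1, 1, 1) * x.unsqueeze(0) + sigma_t.view(-1, 1, 1, 1, 1) * noise
print(xt.shape)  # torch.Size([2, 4, 3, 8, 8]) == (t_size, batch_size, *shape)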
    def inverse(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
    ):
        """
        Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
        For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total number of time steps during training.
        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
        t_T = self.noise_schedule.T if t_end is None else t_end
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"
        return self.sample(
            x,
            steps=steps,
            t_start=t_0,
            t_end=t_T,
            order=order,
            skip_type=skip_type,
            method=method,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero,
            solver_type=solver_type,
            atol=atol,
            rtol=rtol,
            return_intermediate=return_intermediate,
        )

    def sample(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
        flow_shift=1.0,
    ):
        """
        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.

        =====================================================

        We support the following algorithms for both noise prediction model and data prediction model:
            - 'singlestep':
                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
                The total number of function evaluations (NFE) == `steps`.
                Given a fixed NFE == `steps`, the sampling procedure is:
                    - If `order` == 1:
                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                    - If `order` == 2:
                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If `order` == 3:
                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
            - 'multistep':
                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
                We initialize the first `order` values by lower order multistep solvers.
                Given a fixed NFE == `steps`, the sampling procedure is:
                    Denote K = steps.
                    - If `order` == 1:
                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
                    - If `order` == 2:
                        - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
                    - If `order` == 3:
                        - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
            - 'singlestep_fixed':
                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
                We use singlestep DPM-Solver-`order` for `order` = 1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
            - 'adaptive':
                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
                (NFE) and the sample quality.
                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.

        =====================================================

        Some advice on choosing the algorithm:
            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
                Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
                e.g., DPM-Solver:
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                            skip_type='time_uniform', method='singlestep')
                e.g., DPM-Solver++:
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                            skip_type='time_uniform', method='singlestep')
            - For **guided sampling with large guidance scale** by DPMs:
                Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
                e.g.
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                            skip_type='time_uniform', method='multistep')

        We support three types of `skip_type`:
            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
            - 'time_quadratic': quadratic time for the time steps.

        =====================================================
        Args:
            x: A pytorch tensor. The initial value at time `t_start`,
                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
            steps: An `int`. The total number of function evaluations (NFE).
            t_start: A `float`. The starting time of the sampling.
                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
            t_end: A `float`. The ending time of the sampling.
                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
                e.g. if total_N == 1000, we have `t_end` == 1e-3.
                For discrete-time DPMs:
                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
                For continuous-time DPMs:
                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
            order: An `int`. The order of DPM-Solver.
            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).

                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
                sampling diffusion models with diffusion SDEs on low-resolution images
                (such as CIFAR-10). However, we observed that it does not matter for
                high-resolution images. As it needs an additional NFE, we do not recommend
                it for high-resolution images.
            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
                Only valid for `method=multistep` and `steps < 15`. We empirically find that
                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
                (especially for steps <= 10). So we recommend setting it to `True`.
            solver_type: A `str`. The Taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            return_intermediate: A `bool`. Whether to save the xt at each step.
                When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
        Returns:
            x_end: A pytorch tensor. The approximated solution at time `t_end`.

        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"
        if return_intermediate:
            assert method in [
                "multistep",
                "singlestep",
                "singlestep_fixed",
            ], "Cannot use adaptive solver when saving intermediate values"
        if self.correcting_xt_fn is not None:
            assert method in [
                "multistep",
                "singlestep",
                "singlestep_fixed",
            ], "Cannot use adaptive solver when correcting_xt_fn is not None"
        device = x.device
        intermediates = []
        with torch.no_grad():
            if method == "adaptive":
                x = self.dpm_solver_adaptive(
                    x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
                )
            elif method == "multistep":
                assert steps >= order
                timesteps = self.get_time_steps(
                    skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, shift=flow_shift
                )
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                self.update_progress(step + 1, len(timesteps))
                # Init the first `order` values by lower order multistep DPM-Solver.
                for step in range(1, order):
                    t = timesteps[step]
                    x = self.multistep_dpm_solver_update(
                        x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
                    )
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    t_prev_list.append(t)
                    model_prev_list.append(self.model_fn(x, t))
                    # update progress bar
                    self.update_progress(step + 1, len(timesteps))
                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in tqdm(range(order, steps + 1), disable=os.getenv("DPM_TQDM", "False") == "True"):
                    t = timesteps[step]
                    # We only use lower order for steps < 10
                    # if lower_order_final and steps < 10:
                    if lower_order_final:  # recommended by Shuchen Xue
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    x = self.multistep_dpm_solver_update(
                        x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
                    )
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        model_prev_list[-1] = self.model_fn(x, t)
                    # update progress bar
                    self.update_progress(step + 1, len(timesteps))
            elif method in ["singlestep", "singlestep_fixed"]:
                if method == "singlestep":
                    timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
                        steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
                    )
                elif method == "singlestep_fixed":
                    K = steps // order
                    orders = [order] * K
                    timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
                for step, order in enumerate(orders):
                    s, t = timesteps_outer[step], timesteps_outer[step + 1]
                    timesteps_inner = self.get_time_steps(
                        skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
                    )
                    lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                    h = lambda_inner[-1] - lambda_inner[0]
                    r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                    r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                    x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    self.update_progress(step + 1, len(timesteps_outer))
            else:
                raise ValueError(f"Got wrong method {method}")
            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x


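The bookkeeping of the multistep branch is easy to trace in isolation. This is a pure-Python sketch of the scheduling logic only (no solver math, not part of the diff): the first `order` values are built with increasing order, and with `lower_order_final=True` the effective order shrinks near the final steps.

# --- illustrative sketch, not part of the diff ---
order, steps = 3, 6
for step in range(1, order):
    print(f"warm-up step {step}: order-{step} multistep update")
for step in range(order, steps + 1):
    step_order = min(order, steps + 1 - step)  # lower_order_final=True
    print(f"step {step}: order-{step_order} multistep update")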
#############################################################
# other utility functions
#############################################################


def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand


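A quick check of `interpolate_fn` with hypothetical keypoints (assuming this file is importable as `omnigen2.transport.dpm_solver`): f(0)=0, f(1)=2, f(2)=3, with linear extrapolation outside [0, 2].

# --- illustrative sketch, not part of the diff ---
import torch
from omnigen2.transport.dpm_solver import interpolate_fn

xp = torch.tensor([[0.0, 1.0, 2.0]])     # [C=1, K=3]
yp = torch.tensor([[0.0, 2.0, 3.0]])     # [C=1, K=3]
x = torch.tensor([[0.5], [1.5], [3.0]])  # [N=3, C=1]
print(interpolate_fn(x, xp, yp))         # tensor([[1.0000], [2.5000], [4.0000]])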
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dimension `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total number of dimensions is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
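`expand_dims` simply appends trailing singleton dimensions until the total rank is `dims`, e.g.:

# --- illustrative sketch, not part of the diff ---
import torch
from omnigen2.transport.dpm_solver import expand_dims

v = torch.tensor([1.0, 2.0, 3.0])  # shape [N]
print(expand_dims(v, 4).shape)     # torch.Size([3, 1, 1, 1])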
omnigen2/transport/integrators.py
ADDED
@@ -0,0 +1,122 @@
import torch as th
from torchdiffeq import odeint
from .utils import time_shift, get_lin_function

class sde:
    """SDE solver class"""

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return (
            xhat + 0.5 * self.dt * (K1 + K2),
            xhat,
        )  # at the last time point we do not perform the Heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of sde"""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples


class ode:
    """ODE solver class"""

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
        do_shift=False,
        time_shifting_factor=None,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.do_shift = do_shift
        self.t = th.linspace(t0, t1, num_steps)
        if time_shifting_factor:
            self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        x = x.float()
        device = x[0].device if isinstance(x, tuple) else x.device

        def _fn(t, x):
            t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
            model_output = self.drift(x, t, model, **model_kwargs).float()
            return model_output

        t = self.t.to(device)
        if self.do_shift:
            mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
            t = time_shift(mu, 1.0, t)
        atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
        rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
        samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
        return samples
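A minimal usage sketch for the `ode` wrapper above, assuming only `torchdiffeq` is installed: integrate dx/dt = -x with the fixed-step "euler" method (the `model` argument is unused by this toy drift).

# --- illustrative sketch, not part of the diff ---
import torch as th
from omnigen2.transport.integrators import ode

def toy_drift(x, t, model, **kwargs):
    return -x  # exponential decay toward zero

sampler = ode(drift=toy_drift, t0=0.0, t1=1.0, sampler_type="euler",
              num_steps=50, atol=1e-6, rtol=1e-3)
xs = sampler.sample(th.randn(2, 4), model=None)
print(xs.shape)  # (num_steps, 2, 4): the full trajectory returned by odeint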
omnigen2/transport/path.py
ADDED
@@ -0,0 +1,201 @@
import numpy as np
import torch as th


def expand_t_like_x(t, x):
    """Function to reshape the time t to the broadcastable dimension of x
    Args:
      t: [batch_dim,], time vector
      x: [batch_dim,...], data point
    """
    dims = [1] * len(x[0].size())
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################


class ICPlan:
    """Linear Coupling Plan"""

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path"""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path"""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output the SDE according to the score parametrization;"""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE
        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, form of the diffusion term
          norm: float, norm of the diffusion term
        """
        t = expand_t_like_x(t, x)
        choices = {
            "constant": norm,
            "SBDM": norm * self.compute_drift(x, t)[1],
            "sigma": norm * self.compute_sigma_t(t)[0],
            "linear": norm * (1 - t),
            "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
        }

        try:
            diffusion = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a score
        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a denoiser
        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction model into a velocity
        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        if isinstance(x1, (list, tuple)):
            return [alpha_t[i] * x1[i] + sigma_t[i] * x0[i] for i in range(len(x1))]
        else:
            return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from the time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        if isinstance(x1, (list, tuple)):
            return [d_alpha_t * x1[i] + d_sigma_t * x0[i] for i in range(len(x1))]
        else:
            return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut


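A toy check of the linear plan (assuming this file is importable as `omnigen2.transport.path`): xt interpolates from noise x0 at t=0 to data x1 at t=1, and the velocity target is x1 - x0 throughout.

# --- illustrative sketch, not part of the diff ---
import torch as th
from omnigen2.transport.path import ICPlan

plan = ICPlan()
x0, x1 = th.zeros(2, 3), th.ones(2, 3)
t = th.tensor([0.25, 0.75])
t, xt, ut = plan.plan(t, x0, x1)
print(xt)  # rows of 0.25 and 0.75: t * x1 + (1 - t) * x0
print(ut)  # all ones: d/dt xt = x1 - x0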
class VPCPlan(ICPlan):
    """class for VP path flow matching"""

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.log_mean_coeff = (
            lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
        )
        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute the coefficient of x1"""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute the coefficient of x0"""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing the numerically stable d_alpha_t / alpha_t"""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE"""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2


class GVPCPlan(ICPlan):
    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute the coefficient of x1"""
        alpha_t = th.sin(t * np.pi / 2)
        d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute the coefficient of x0"""
        sigma_t = th.cos(t * np.pi / 2)
        d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing the numerically stable d_alpha_t / alpha_t"""
        return np.pi / (2 * th.tan(t * np.pi / 2))
omnigen2/transport/transport.py
ADDED
@@ -0,0 +1,545 @@
import enum
import math
from typing import Callable, Optional

import numpy as np
import torch as th
import random

from . import path
from .integrators import ode, sde
from .utils import mean_flat, expand_dims
from .dpm_solver import NoiseScheduleFlow, model_wrapper, DPM_Solver


class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)


class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()


class Transport:
    def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift, seq_len,
                 dynamic_time_shift: bool = False,
                 time_shift_version: str = "v1"):
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

        self.snr_type = snr_type
        self.do_shift = do_shift
        self.seq_len = seq_len
        self.dynamic_time_shift = dynamic_time_shift
        self.time_shift_version = time_shift_version

    def prior_logp(self, z):
        """
        Standard multivariate normal prior
        Assume z is batched
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if type(self.path_sampler) in [path.VPCPlan]:
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issue by taking a first semi-implicit step
            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample(self, x1, process_index, num_processes):
        """Sampling x0 & t based on the shape of x1 (if needed)
        Args:
            x1 - data point; [batch, *dim]
        """
        if isinstance(x1, (list, tuple)):
            x0 = [th.randn_like(img_start) for img_start in x1]
        else:
            x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)

        if self.snr_type.startswith("uniform"):
            assert t0 == 0.0 and t1 == 1.0, "not implemented."
            if "_" in self.snr_type:
                _, t0, t1 = self.snr_type.split("_")
                t0, t1 = float(t0), float(t1)
            t = th.rand((len(x1),)) * (t1 - t0) + t0
            if self.snr_type == "stratified_uniform":
                batch_size = len(x1)
                n = batch_size * num_processes
                offsets = th.arange(process_index, n, num_processes)
                u = th.rand(size=(batch_size,))
                t = ((offsets + u) / n)
        elif self.snr_type == "lognorm":
            u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
            t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
        elif self.snr_type == "zero":
            t = th.rand((len(x1),))
            for _ in range(len(x1)):
                if random.random() < 1.0:
                    t[_] = 0.0
            # print(t)
        else:
            raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)

        if self.do_shift:
            if self.dynamic_time_shift:
                if self.time_shift_version == "v1":
                    base_shift: float = 0.5
                    max_shift: float = 1.15
                    lin_func = self.get_lin_function(y1=base_shift, y2=max_shift)

                    mu = th.tensor([lin_func((_x1.shape[-2] // 2) * (_x1.shape[-1] // 2)) for _x1 in x1], dtype=t.dtype, device=t.device).view_as(t)
                    t = self.time_shift(mu, 1.0, t)
                elif self.time_shift_version == "v2":
                    tokens = th.tensor([(_x1.shape[-2] // 2) * (_x1.shape[-1] // 2) for _x1 in x1], dtype=t.dtype, device=t.device).view_as(t)
                    t = self.time_shift_v2(tokens, t)
            else:
                if self.time_shift_version == "v1":
                    base_shift: float = 0.5
                    max_shift: float = 1.15
                    mu = self.get_lin_function(y1=base_shift, y2=max_shift)(self.seq_len)
                    t = self.time_shift(mu, 1.0, t)
                elif self.time_shift_version == "v2":
                    tokens = th.tensor([self.seq_len] * len(x1), dtype=t.dtype, device=t.device).view_as(t)
                    t = self.time_shift_v2(tokens, t)
        t = t.to(x1[0])
        return t, x0, x1

    def time_shift(self, mu: float, sigma: float, t: th.Tensor):
        # The following implementation was originally written for t=0: clean / t=1: noise;
        # since we adopt the reverse convention, the 1 - t operations are needed.
        t = 1 - t
        if isinstance(mu, th.Tensor):
            t = th.exp(mu) / (th.exp(mu) + (1 / t - 1) ** sigma)
        else:
            t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
        t = 1 - t
        return t

    def time_shift_v2(self, tokens: th.Tensor, t: th.Tensor):
        # t = th.exp(mu) / (th.exp(mu) + (1 / t - 1) ** sigma)
        m = th.sqrt(tokens) / 20
        t = t / (m - m * t + t)
        return t

    def get_lin_function(
        self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
    ) -> Callable[[float], float]:
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

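As a numeric sketch of the resolution-dependent shift (values assume the v1 defaults y1=0.5 at 256 tokens and y2=1.15 at 4096 tokens): a 1024-token latent gets mu = 0.63, which warps t = 0.5 down to roughly 0.35, i.e. toward the noisy end under this file's convention (t=0 noise, t=1 data).

# --- illustrative sketch, not part of the diff ---
import math

m = (1.15 - 0.5) / (4096 - 256)         # slope of the linear map
b = 0.5 - m * 256
mu = m * 1024 + b                       # mu for a 1024-token latent
s = 1 - 0.5                             # reverse convention: work on 1 - t
shifted = 1 - math.exp(mu) / (math.exp(mu) + (1 / s - 1) ** 1.0)
print(round(mu, 3), round(shifted, 3))  # 0.63 0.348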
    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None,
        process_index: Optional[int] = None,
        num_processes: Optional[int] = None,
        reduction: str = 'mean',
    ):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """

        terms = {}

        if model_kwargs is None:
            model_kwargs = {}
        t, x0, x1 = self.sample(x1, process_index, num_processes)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)

        terms = {}
        terms['t'] = t
        terms['xt'] = xt

        if "cond" in model_kwargs:
            conds = model_kwargs.pop("cond")
            xt = [th.cat([x, cond], dim=0) if cond is not None else x for x, cond in zip(xt, conds)]
        model_output = model(xt, t, **model_kwargs)
        B = len(x0)

        terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            if isinstance(x1, (list, tuple)):
                assert len(model_output) == len(ut) == len(x1)
                for i in range(B):
                    assert (
                        model_output[i].shape == ut[i].shape == x1[i].shape
                    ), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
                terms["task_loss"] = th.stack(
                    [th.nn.functional.mse_loss(ut[i].float(), model_output[i].float(), reduction=reduction) for i in range(B)],
                    dim=0,
                )
            else:
                terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
        else:
            raise NotImplementedError

        terms["loss"] = terms["task_loss"]
        terms["t"] = t
        return terms

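A minimal end-to-end sketch of a training step with a stand-in velocity model; the constructor arguments below are illustrative, not the repo's actual training configuration.

# --- illustrative sketch, not part of the diff ---
import torch as th
from omnigen2.transport.transport import Transport, ModelType, PathType, WeightType

transport = Transport(
    model_type=ModelType.VELOCITY, path_type=PathType.LINEAR,
    loss_type=WeightType.NONE, train_eps=0.0, sample_eps=0.0,
    snr_type="uniform", do_shift=False, seq_len=1024,
)
model = lambda x, t, **kw: th.zeros_like(x)  # stand-in for the transformer
terms = transport.training_losses(model, th.randn(4, 8), process_index=0, num_processes=1)
print(terms["loss"].shape, terms["t"].shape)  # torch.Size([4]) torch.Size([4])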
    def get_drift(self):
        """member function for obtaining the drift of the probability flow ODE"""

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return -drift_mean + drift_var * model_output  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return -drift_mean + drift_var * score

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining the score of
        x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = (
                lambda x, t, model, **kwargs: model(x, t, **kwargs)
                / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
            )
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
                model(x, t, **kwargs), x, t
            )
        else:
            raise NotImplementedError()

        return score_fn


class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
            return diffusion

        sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
            x, t, model, **kwargs
        )

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver"""

        if last_step is None:
            last_step_fn = lambda x, t, model, **model_kwargs: x
        elif last_step == "Mean":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
            )
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][
                0
            ] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
            )
        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
        - diffusion_form: function form of the diffusion coefficient; defaults to matching SBDM
        - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
        - last_step: type of the last step; defaults to identity
        - last_step_size: size of the last step; defaults to matching the stride of 250 steps over [0,1]
        - num_steps: total number of integration steps of the SDE
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method,
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Number of samples does not match the number of steps"

            return xs

        return _sample

    def sample_dpm(
        self,
        model,
        model_kwargs=None,
    ):

        noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")

        def noise_pred_fn(x, t_continuous):
            output = model(x, 1 - t_continuous, **model_kwargs)
            _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            try:
                noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output
            except Exception:
                # the model may return a tuple; fall back to its first element
                noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0]
            return noise

        return DPM_Solver(noise_pred_fn, noise_schedule, algorithm_type="dpmsolver++").sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
        do_shift=False,
        time_shifting_factor=None,
    ):
        """returns a sampling function with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """

        # for flux
        drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
            do_shift=do_shift,
            time_shifting_factor=time_shifting_factor,
        )

        return _ode.sample

| 488 |
+
def sample_ode_likelihood(
|
| 489 |
+
self,
|
| 490 |
+
*,
|
| 491 |
+
sampling_method="dopri5",
|
| 492 |
+
num_steps=50,
|
| 493 |
+
atol=1e-6,
|
| 494 |
+
rtol=1e-3,
|
| 495 |
+
):
|
| 496 |
+
"""returns a sampling function for calculating likelihood with given ODE settings
|
| 497 |
+
Args:
|
| 498 |
+
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
| 499 |
+
- num_steps:
|
| 500 |
+
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
| 501 |
+
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
| 502 |
+
- atol: absolute error tolerance for the solver
|
| 503 |
+
- rtol: relative error tolerance for the solver
|
| 504 |
+
"""
|
| 505 |
+
|
| 506 |
+
def _likelihood_drift(x, t, model, **model_kwargs):
|
| 507 |
+
x, _ = x
|
| 508 |
+
eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
|
| 509 |
+
t = th.ones_like(t) * (1 - t)
|
| 510 |
+
with th.enable_grad():
|
| 511 |
+
x.requires_grad = True
|
| 512 |
+
grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
|
| 513 |
+
logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
|
| 514 |
+
drift = self.drift(x, t, model, **model_kwargs)
|
| 515 |
+
return (-drift, logp_grad)
|
| 516 |
+
|
| 517 |
+
t0, t1 = self.transport.check_interval(
|
| 518 |
+
self.transport.train_eps,
|
| 519 |
+
self.transport.sample_eps,
|
| 520 |
+
sde=False,
|
| 521 |
+
eval=True,
|
| 522 |
+
reverse=False,
|
| 523 |
+
last_step_size=0.0,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
_ode = ode(
|
| 527 |
+
drift=_likelihood_drift,
|
| 528 |
+
t0=t0,
|
| 529 |
+
t1=t1,
|
| 530 |
+
sampler_type=sampling_method,
|
| 531 |
+
num_steps=num_steps,
|
| 532 |
+
atol=atol,
|
| 533 |
+
rtol=rtol,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
def _sample_fn(x, model, **model_kwargs):
|
| 537 |
+
init_logp = th.zeros(x.size(0)).to(x)
|
| 538 |
+
input = (x, init_logp)
|
| 539 |
+
drift, delta_logp = _ode.sample(input, model, **model_kwargs)
|
| 540 |
+
drift, delta_logp = drift[-1], delta_logp[-1]
|
| 541 |
+
prior_logp = self.transport.prior_logp(drift)
|
| 542 |
+
logp = prior_logp - delta_logp
|
| 543 |
+
return logp, drift
|
| 544 |
+
|
| 545 |
+
return _sample_fn
|
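Usage note (not part of the diff): a minimal sketch of how the sampling functions returned by this sampler class might be driven. `create_transport` is assumed to be the factory exported by `omnigen2/transport/__init__.py`, and the velocity-predicting `model` callable and tensor shapes are placeholder assumptions.

import torch as th

# Hypothetical driver; `create_transport`, the shapes, and `model` are assumptions.
transport = create_transport("Linear", "velocity")  # assumed factory in transport/__init__.py
sampler = Sampler(transport)

sample_fn = sampler.sample_ode(num_steps=50)        # adaptive dopri5 by default
x_init = th.randn(4, 16, 32, 32)                    # initial noise; shape is illustrative
xs = sample_fn(x_init, model)                       # `model` predicts the velocity field
x_final = xs[-1]                                    # last entry is the integrated sample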
omnigen2/transport/utils.py
ADDED
@@ -0,0 +1,56 @@
import math

import torch as th


class EasyDict:
    def __init__(self, sub_dict):
        for k, v in sub_dict.items():
            setattr(self, k, v)

    def __getitem__(self, key):
        return getattr(self, key)


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return th.mean(x, dim=list(range(1, len(x.size()))))


def log_state(state):
    result = []

    sorted_state = dict(sorted(state.items()))
    for key, value in sorted_state.items():
        # Check if the value is an instance of a class
        if "<object" in str(value) or "object at" in str(value):
            result.append(f"{key}: [{value.__class__.__name__}]")
        else:
            result.append(f"{key}: {value}")

    return "\n".join(result)


def time_shift(mu: float, sigma: float, t: th.Tensor):
    # The following implementation was originally written for t=0: clean / t=1: noise.
    # Since we adopt the reverse convention, the 1 - t flips are needed.
    t = 1 - t
    t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
    t = 1 - t
    return t


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Line through (x1, y1) and (x2, y2); maps a sequence length to a shift value.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dimension `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
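Usage note (not part of the diff): `get_lin_function` and `time_shift` together implement a resolution-dependent timestep shift: the linear function maps an image token count to a shift parameter `mu`, which `time_shift` then applies to the schedule. A small sketch with illustrative values; the endpoints t=0 and t=1 are avoided because `1 / t - 1` is undefined there after the flip.

import torch as th

seq_len = 1024                          # number of image tokens; illustrative
mu = get_lin_function()(seq_len)        # default line maps 256 -> 0.5 and 4096 -> 1.15

t = th.linspace(1.0, 0.0, 6)[1:-1]      # interior timesteps only
t_shifted = time_shift(mu, 1.0, t)

expand_dims(th.randn(8), 4).shape       # -> torch.Size([8, 1, 1, 1])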
omnigen2/utils/__init__.py
ADDED
File without changes
omnigen2/utils/img_util.py
ADDED
@@ -0,0 +1,31 @@
from typing import List

from PIL import Image

import torch
from torchvision.transforms.functional import to_pil_image


def resize_image(image, max_pixels, img_scale_num):
    """Downscale `image` so it has at most `max_pixels` pixels, snapping both
    sides to multiples of `img_scale_num`; the image is never upscaled."""
    width, height = image.size
    cur_pixels = height * width
    ratio = (max_pixels / cur_pixels) ** 0.5
    ratio = min(ratio, 1.0)  # do not upscale the input image

    new_height = int(height * ratio) // img_scale_num * img_scale_num
    new_width = int(width * ratio) // img_scale_num * img_scale_num

    image = image.resize((new_width, new_height), resample=Image.BICUBIC)
    return image


def create_collage(images: List[torch.Tensor]) -> Image.Image:
    """Create a horizontal collage from a list of images in [-1, 1]."""
    max_height = max(img.shape[-2] for img in images)
    total_width = sum(img.shape[-1] for img in images)
    canvas = torch.zeros((3, max_height, total_width), device=images[0].device)

    current_x = 0
    for img in images:
        h, w = img.shape[-2:]
        canvas[:, :h, current_x:current_x + w] = img * 0.5 + 0.5  # map [-1, 1] to [0, 1]
        current_x += w

    return to_pil_image(canvas)
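Usage note (not part of the diff): a short sketch for these helpers; the file paths and sizes are placeholders. `create_collage` expects tensors scaled to [-1, 1] with shape (3, H, W), since it maps them back with `img * 0.5 + 0.5`.

import torch
from PIL import Image

img = Image.open("input.png").convert("RGB")  # path is illustrative
img = resize_image(img, max_pixels=1024 * 1024, img_scale_num=16)

tensors = [torch.rand(3, 256, 256) * 2 - 1,
           torch.rand(3, 256, 320) * 2 - 1]
create_collage(tensors).save("collage.png")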
omnigen2/utils/import_utils.py
ADDED
@@ -0,0 +1,46 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""

import importlib.util
import sys

# The package importlib_metadata is in a different place, depending on the Python version.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata


def _is_package_available(pkg_name: str):
    pkg_exists = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"

    if pkg_exists:
        try:
            pkg_version = importlib_metadata.version(pkg_name)
        except (ImportError, importlib_metadata.PackageNotFoundError):
            pkg_exists = False

    return pkg_exists, pkg_version


_triton_available, _triton_version = _is_package_available("triton")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")


def is_triton_available():
    return _triton_available


def is_flash_attn_available():
    return _flash_attn_available
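Usage note (not part of the diff): the availability checks run once at import time, so callers can branch cheaply. A hedged example of the intended guard pattern:

if is_flash_attn_available():
    import flash_attn  # only imported when the wheel is actually installed
else:
    flash_attn = None  # callers fall back to PyTorch's built-in attention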
omnigen2/utils/logging_utils.py
ADDED
@@ -0,0 +1,15 @@
import logging


class TqdmToLogger(object):
    """File-like object to redirect tqdm output to a logger."""

    def __init__(self, logger, level=logging.INFO):
        self.logger = logger
        self.level = level

    def write(self, buf):
        for line in buf.rstrip().splitlines():
            self.logger.log(self.level, line)

    def flush(self):
        # Expects a logger adapter that exposes the underlying logging.Logger
        # as `.logger` (e.g. accelerate's multi-process logger wrapper).
        for handler in self.logger.logger.handlers:
            handler.flush()
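Usage note (not part of the diff): because `flush()` reaches through `self.logger.logger`, the class assumes a logger adapter rather than a plain `logging.Logger`. A sketch under that assumption, using `accelerate.logging.get_logger`:

from accelerate.logging import get_logger
from tqdm import tqdm

logger = get_logger(__name__)
for _ in tqdm(range(100), file=TqdmToLogger(logger), mininterval=1.0):
    pass  # progress lines go through the logger instead of stderr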
omnigen2/utils/reproducibility.py
ADDED
@@ -0,0 +1,22 @@
import random

import numpy as np
import torch

from diffusers.utils import is_torch_npu_available


def worker_init_fn(worker_id, num_processes, num_workers, process_index, seed, same_seed_per_epoch=False):
    if same_seed_per_epoch:
        # Derive a deterministic per-worker seed that is identical across epochs.
        worker_seed = seed + num_processes + num_workers * process_index + worker_id
    else:
        # torch.initial_seed() already differs per worker and per epoch.
        worker_seed = torch.initial_seed()

    random.seed(worker_seed)
    np.random.seed(worker_seed % 2**32)
    torch.manual_seed(worker_seed)

    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)
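Usage note (not part of the diff): `DataLoader` passes only `worker_id` to its `worker_init_fn`, so the remaining arguments have to be bound up front. A sketch with placeholder values; the dataset and accelerator fields are illustrative.

from functools import partial
from torch.utils.data import DataLoader

dataset = list(range(64))  # placeholder dataset
init_fn = partial(
    worker_init_fn,
    num_processes=1,   # e.g. accelerator.num_processes
    num_workers=4,
    process_index=0,   # e.g. accelerator.process_index
    seed=42,
)
loader = DataLoader(dataset, batch_size=8, num_workers=4, worker_init_fn=init_fn)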
omnigen2/utils/teacache_util.py
ADDED
@@ -0,0 +1,43 @@
"""
Utility for TeaCache

Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TeaCacheParams:
    """
    TeaCache parameters for `OmniGen2Transformer2DModel`.
    See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding.

    Args:
        previous_residual (Optional[torch.Tensor]):
            The difference between the output and the input of the transformer layers at the previous timestep.
        previous_modulated_inp (Optional[torch.Tensor]):
            The modulated input from the previous timestep, used to estimate how much the layer output changes.
        accumulated_rel_l1_distance (float):
            The accumulated relative L1 distance.
        is_first_or_last_step (bool):
            Whether the current timestep is the first or last step.
    """

    previous_residual: Optional[torch.Tensor] = None
    previous_modulated_inp: Optional[torch.Tensor] = None
    accumulated_rel_l1_distance: float = 0.0
    is_first_or_last_step: bool = False
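For context (not part of the diff): a minimal sketch of the caching decision these fields support, following the TeaCache recipe of accumulating a relative L1 distance between modulated inputs and reusing the cached residual while it stays under a threshold. The threshold value is illustrative, and the model-specific polynomial rescaling used by TeaCache is omitted here.

import torch

def should_reuse_residual(params: TeaCacheParams, modulated_inp: torch.Tensor,
                          threshold: float = 0.2) -> bool:
    """Return True when the cached `previous_residual` can be reused this step."""
    if params.is_first_or_last_step or params.previous_modulated_inp is None:
        # Always run the transformer on the first/last step or before any cache exists.
        reuse = False
        params.accumulated_rel_l1_distance = 0.0
    else:
        rel_l1 = ((modulated_inp - params.previous_modulated_inp).abs().mean()
                  / params.previous_modulated_inp.abs().mean()).item()
        params.accumulated_rel_l1_distance += rel_l1
        reuse = params.accumulated_rel_l1_distance < threshold
        if not reuse:
            params.accumulated_rel_l1_distance = 0.0  # reset after a full forward pass
    params.previous_modulated_inp = modulated_inp
    return reuse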