tyfeld
commited on
Commit
·
ea359a8
1
Parent(s):
2f15a78
initial
Browse files- app.py +871 -0
- models/__init__.py +3 -0
- models/common_modules.py +357 -0
- models/configuration_llada.py +463 -0
- models/logging.py +338 -0
- models/lr_schedulers.py +302 -0
- models/misc.py +53 -0
- models/modeling_llada.py +1500 -0
- models/modeling_magvitv2.py +440 -0
- models/modeling_mmada.py +668 -0
- models/modeling_utils.py +1207 -0
- models/sampling.py +118 -0
- models/training_utils.py +455 -0
- training/__init__.py +1 -0
- training/prompting_utils.py +475 -0
app.py
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
from models import MAGVITv2, get_mask_schedule, MMadaModelLM
|
| 8 |
+
from training.prompting_utils import UniversalPrompting
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
def image_transform(image, resolution=256, normalize=True):
|
| 12 |
+
image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
|
| 13 |
+
image = transforms.CenterCrop((resolution, resolution))(image)
|
| 14 |
+
image = transforms.ToTensor()(image)
|
| 15 |
+
if normalize:
|
| 16 |
+
image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
|
| 17 |
+
return image
|
| 18 |
+
|
| 19 |
+
def add_gumbel_noise(logits, temperature):
|
| 20 |
+
"""
|
| 21 |
+
Adds Gumbel noise to logits for stochastic sampling.
|
| 22 |
+
Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
|
| 23 |
+
This version is more numerically stable than a version involving exp() and division.
|
| 24 |
+
"""
|
| 25 |
+
if abs(temperature) < 1e-9: # Effectively zero temperature
|
| 26 |
+
return logits
|
| 27 |
+
# Ensure logits are float64 for precision with noise, as suggested by user context
|
| 28 |
+
logits = logits.to(torch.float64)
|
| 29 |
+
# Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
|
| 30 |
+
# Add small epsilon for numerical stability inside logs
|
| 31 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 32 |
+
standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
|
| 33 |
+
return logits + temperature * standard_gumbel_noise
|
| 34 |
+
|
| 35 |
+
def get_num_transfer_tokens(mask_index, steps):
|
| 36 |
+
mask_num = mask_index.sum(dim=1, keepdim=True)
|
| 37 |
+
# Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0)
|
| 38 |
+
steps = max(1, int(steps)) # Ensure steps is a positive integer
|
| 39 |
+
base = mask_num // steps
|
| 40 |
+
remainder = mask_num % steps
|
| 41 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
|
| 42 |
+
for i in range(mask_num.size(0)): # Iterate over batch
|
| 43 |
+
if remainder[i] > 0 : # Ensure remainder is positive before indexing
|
| 44 |
+
num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
|
| 45 |
+
return num_transfer_tokens
|
| 46 |
+
|
| 47 |
+
MODEL = None
|
| 48 |
+
TOKENIZER = None
|
| 49 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 50 |
+
MASK_ID = None
|
| 51 |
+
uni_prompting = None
|
| 52 |
+
VQ_MODEL = MAGVITv2().from_pretrained("/data_storage/shared/pretrained_models/models--showlab--magvitv2").to(DEVICE)
|
| 53 |
+
|
| 54 |
+
DEFAULT_MODEL_PATH = "/data_storage/lbw/MMaDA/mmada-training-stage3-llada-instruct-512-cot-uni/checkpoint-210000/unwrapped_model" # Default
|
| 55 |
+
CURRENT_MODEL_PATH = None
|
| 56 |
+
|
| 57 |
+
MODEL_CHOICES = [
|
| 58 |
+
"MMaDA-8B-Base",
|
| 59 |
+
"MMaDA-8B-MixCoT (coming soon)",
|
| 60 |
+
"MMaDA-8B-Max (coming soon)"
|
| 61 |
+
]
|
| 62 |
+
MODEL_ACTUAL_PATHS = {
|
| 63 |
+
"MMaDA-8B-Base": DEFAULT_MODEL_PATH,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
def clear_outputs_action():
|
| 67 |
+
return None, None
|
| 68 |
+
|
| 69 |
+
def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
|
| 70 |
+
global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
|
| 71 |
+
|
| 72 |
+
if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
|
| 73 |
+
return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"
|
| 74 |
+
|
| 75 |
+
CURRENT_MODEL_PATH = model_path_to_load
|
| 76 |
+
|
| 77 |
+
status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
|
| 78 |
+
try:
|
| 79 |
+
TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
|
| 80 |
+
status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
|
| 81 |
+
|
| 82 |
+
MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
|
| 83 |
+
status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
|
| 84 |
+
|
| 85 |
+
uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
|
| 86 |
+
|
| 87 |
+
if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
|
| 88 |
+
MASK_ID = TOKENIZER.mask_token_id
|
| 89 |
+
status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
|
| 90 |
+
else:
|
| 91 |
+
MASK_ID = 126336
|
| 92 |
+
status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
|
| 93 |
+
|
| 94 |
+
if TOKENIZER.pad_token_id is None:
|
| 95 |
+
if TOKENIZER.eos_token_id is not None:
|
| 96 |
+
TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
|
| 97 |
+
TOKENIZER.pad_token = TOKENIZER.eos_token
|
| 98 |
+
status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
|
| 99 |
+
else:
|
| 100 |
+
status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
|
| 101 |
+
|
| 102 |
+
if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
|
| 103 |
+
status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
|
| 104 |
+
|
| 105 |
+
TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
|
| 106 |
+
|
| 107 |
+
return " ".join(status_msg_parts)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
MODEL = None
|
| 110 |
+
TOKENIZER = None
|
| 111 |
+
MASK_ID = None
|
| 112 |
+
CURRENT_MODEL_PATH = None
|
| 113 |
+
return f"Error loading model '{model_display_name_for_status}': {str(e)}"
|
| 114 |
+
|
| 115 |
+
def handle_model_selection_change(selected_model_name_ui):
|
| 116 |
+
if "coming soon" in selected_model_name_ui.lower():
|
| 117 |
+
global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
|
| 118 |
+
MODEL = None
|
| 119 |
+
TOKENIZER = None
|
| 120 |
+
MASK_ID = None
|
| 121 |
+
CURRENT_MODEL_PATH = None
|
| 122 |
+
return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
|
| 123 |
+
|
| 124 |
+
actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
|
| 125 |
+
if not actual_path:
|
| 126 |
+
return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
|
| 127 |
+
|
| 128 |
+
return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
|
| 132 |
+
if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
|
| 133 |
+
return [("Error in sequence data for visualization.", "ERROR")]
|
| 134 |
+
# only answer part
|
| 135 |
+
current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
|
| 136 |
+
seq_ids = current_x_ids_batch[0].tolist()
|
| 137 |
+
eos_token_id = tk.eos_token_id # Get EOS token ID
|
| 138 |
+
|
| 139 |
+
# Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
|
| 140 |
+
# This helps in identifying EOS tokens later without re-checking the type.
|
| 141 |
+
intermediate_tuples = []
|
| 142 |
+
for j, token_id_int in enumerate(seq_ids):
|
| 143 |
+
try:
|
| 144 |
+
token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
| 145 |
+
except Exception: # Handle cases where a token ID might be problematic (e.g. with mock)
|
| 146 |
+
token_str = f"[ID:{token_id_int}]"
|
| 147 |
+
|
| 148 |
+
label = "ERROR"
|
| 149 |
+
if token_id_int == current_mask_id:
|
| 150 |
+
token_str = "[MASK]"
|
| 151 |
+
label = "MASK"
|
| 152 |
+
else:
|
| 153 |
+
label = "GEN"
|
| 154 |
+
intermediate_tuples.append((token_str, label, token_id_int))
|
| 155 |
+
|
| 156 |
+
return intermediate_tuples
|
| 157 |
+
|
| 158 |
+
@torch.no_grad()
|
| 159 |
+
def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
|
| 160 |
+
global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
|
| 161 |
+
|
| 162 |
+
if MODEL is None or TOKENIZER is None or MASK_ID is None:
|
| 163 |
+
yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
|
| 164 |
+
return
|
| 165 |
+
steps = int(steps)
|
| 166 |
+
guidance_scale = float(guidance_scale)
|
| 167 |
+
|
| 168 |
+
image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
|
| 169 |
+
prompt_text = [prompt_text]
|
| 170 |
+
input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
|
| 171 |
+
|
| 172 |
+
if guidance_scale > 0:
|
| 173 |
+
uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
|
| 174 |
+
else:
|
| 175 |
+
uncond_input_ids, uncond_attention_mask = None, None
|
| 176 |
+
|
| 177 |
+
mask_schedule = get_mask_schedule(mask_schedule)
|
| 178 |
+
blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
|
| 179 |
+
yield blank_image, "Starting generation..."
|
| 180 |
+
for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
|
| 181 |
+
input_ids = input_ids,
|
| 182 |
+
uncond_input_ids = uncond_input_ids,
|
| 183 |
+
attention_mask = attention_mask,
|
| 184 |
+
uncond_attention_mask = uncond_attention_mask,
|
| 185 |
+
temperature=1.0,
|
| 186 |
+
timesteps = steps,
|
| 187 |
+
guidance_scale = guidance_scale,
|
| 188 |
+
noise_schedule = mask_schedule,
|
| 189 |
+
noise_type = "mask",
|
| 190 |
+
seq_len = 1024,
|
| 191 |
+
vq_model = VQ_MODEL,
|
| 192 |
+
uni_prompting=uni_prompting):
|
| 193 |
+
yield image_step, status_msg_step
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@torch.no_grad()
|
| 199 |
+
def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
|
| 200 |
+
cfg_scale, remasking_strategy, thinking_mode_lm):
|
| 201 |
+
global MODEL, TOKENIZER, MASK_ID, DEVICE
|
| 202 |
+
print(f"thinking_mode_lm: {thinking_mode_lm}")
|
| 203 |
+
if MODEL is None or TOKENIZER is None or MASK_ID is None:
|
| 204 |
+
yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
|
| 205 |
+
return
|
| 206 |
+
|
| 207 |
+
steps = int(steps)
|
| 208 |
+
gen_length = int(gen_length)
|
| 209 |
+
block_length = int(block_length)
|
| 210 |
+
|
| 211 |
+
if thinking_mode_lm:
|
| 212 |
+
prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
m = [{"role": "user", "content": prompt_text}]
|
| 216 |
+
processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
|
| 217 |
+
except Exception as e:
|
| 218 |
+
yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
|
| 219 |
+
processed_prompt_text = prompt_text
|
| 220 |
+
try:
|
| 221 |
+
if TOKENIZER.pad_token_id is None:
|
| 222 |
+
if TOKENIZER.eos_token_id is not None:
|
| 223 |
+
TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
|
| 224 |
+
else: # Should have been caught by load_model, but double check
|
| 225 |
+
yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
|
| 229 |
+
raw_prompt_attention_mask = None
|
| 230 |
+
|
| 231 |
+
except Exception as e:
|
| 232 |
+
yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
|
| 233 |
+
return
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
batch_size = input_ids.shape[0]
|
| 238 |
+
prompt_len = input_ids.shape[1]
|
| 239 |
+
|
| 240 |
+
x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
|
| 241 |
+
x[:, :prompt_len] = input_ids.clone()
|
| 242 |
+
|
| 243 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
|
| 244 |
+
|
| 245 |
+
if gen_length == 0:
|
| 246 |
+
final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
|
| 247 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
|
| 248 |
+
return
|
| 249 |
+
|
| 250 |
+
if block_length <= 0 or gen_length % block_length != 0 :
|
| 251 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
|
| 252 |
+
f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
|
| 253 |
+
return
|
| 254 |
+
num_blocks = gen_length // block_length
|
| 255 |
+
|
| 256 |
+
if steps <=0 or steps % num_blocks != 0:
|
| 257 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
|
| 258 |
+
f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
|
| 259 |
+
return
|
| 260 |
+
steps_per_block = steps // num_blocks
|
| 261 |
+
|
| 262 |
+
for num_block_iter in range(num_blocks):
|
| 263 |
+
current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
|
| 264 |
+
current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
|
| 265 |
+
|
| 266 |
+
block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
|
| 267 |
+
block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
|
| 268 |
+
(x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
|
| 269 |
+
|
| 270 |
+
num_transfer_tokens_for_this_block = get_num_transfer_tokens(
|
| 271 |
+
block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
|
| 272 |
+
steps_per_block
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
for i_step_in_block in range(steps_per_block):
|
| 276 |
+
mask_index_global = (x == MASK_ID)
|
| 277 |
+
|
| 278 |
+
if cfg_scale > 0.:
|
| 279 |
+
un_x = x.clone()
|
| 280 |
+
# For unconditional pass, mask out the original prompt tokens that are not padding
|
| 281 |
+
# raw_prompt_attention_mask is (B, prompt_len)
|
| 282 |
+
prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
|
| 283 |
+
un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
|
| 284 |
+
|
| 285 |
+
x_cfg_input = torch.cat([x, un_x], dim=0)
|
| 286 |
+
# Pass attention_mask for CFG if model expects it, covering both parts
|
| 287 |
+
# For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
|
| 288 |
+
model_output = MODEL(x_cfg_input)
|
| 289 |
+
logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
|
| 290 |
+
logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
|
| 291 |
+
else:
|
| 292 |
+
# Not passing explicit attention_mask here; relies on model's internal handling.
|
| 293 |
+
model_output = MODEL(x)
|
| 294 |
+
logits = model_output.logits
|
| 295 |
+
|
| 296 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
| 297 |
+
x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
|
| 298 |
+
|
| 299 |
+
if remasking_strategy == 'low_confidence':
|
| 300 |
+
probs = F.softmax(logits.to(torch.float64), dim=-1)
|
| 301 |
+
x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
|
| 302 |
+
elif remasking_strategy == 'random':
|
| 303 |
+
x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
|
| 304 |
+
else:
|
| 305 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
|
| 306 |
+
return
|
| 307 |
+
|
| 308 |
+
confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
|
| 309 |
+
candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
|
| 310 |
+
confidence_for_selection = torch.where(
|
| 311 |
+
candidate_positions_for_unmasking,
|
| 312 |
+
x0_probs,
|
| 313 |
+
-torch.inf
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
|
| 317 |
+
|
| 318 |
+
transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
|
| 319 |
+
num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
|
| 320 |
+
|
| 321 |
+
for j_batch_idx in range(batch_size):
|
| 322 |
+
k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
|
| 323 |
+
candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
|
| 324 |
+
|
| 325 |
+
if k_val > 0:
|
| 326 |
+
# Ensure confidence_for_selection[j_batch_idx] is 1D for topk
|
| 327 |
+
conf_slice = confidence_for_selection[j_batch_idx]
|
| 328 |
+
if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
|
| 329 |
+
|
| 330 |
+
# Check if there are enough valid (non -inf) confidences
|
| 331 |
+
valid_conf_count = (conf_slice > -torch.inf).sum().item()
|
| 332 |
+
actual_k = min(k_val, valid_conf_count)
|
| 333 |
+
|
| 334 |
+
if actual_k > 0:
|
| 335 |
+
_, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
|
| 336 |
+
transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
|
| 337 |
+
|
| 338 |
+
x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
|
| 339 |
+
|
| 340 |
+
current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
|
| 341 |
+
total_overall_steps = num_blocks * steps_per_block
|
| 342 |
+
status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
|
| 343 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
|
| 344 |
+
|
| 345 |
+
final_generated_ids = x[:, prompt_len:]
|
| 346 |
+
final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
|
| 347 |
+
|
| 348 |
+
final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
|
| 349 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
|
| 350 |
+
|
| 351 |
+
@torch.no_grad()
|
| 352 |
+
def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
|
| 353 |
+
cfg_scale, remasking_strategy, thinking_mode_mmu):
|
| 354 |
+
global MODEL, TOKENIZER, MASK_ID, DEVICE
|
| 355 |
+
|
| 356 |
+
if MODEL is None or TOKENIZER is None or MASK_ID is None:
|
| 357 |
+
yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
|
| 358 |
+
return
|
| 359 |
+
|
| 360 |
+
steps = int(steps)
|
| 361 |
+
gen_length = int(gen_length)
|
| 362 |
+
block_length = int(block_length)
|
| 363 |
+
|
| 364 |
+
if thinking_mode_mmu:
|
| 365 |
+
prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
|
| 366 |
+
|
| 367 |
+
try:
|
| 368 |
+
m = [{"role": "user", "content": prompt_text}]
|
| 369 |
+
processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
|
| 370 |
+
except Exception as e:
|
| 371 |
+
yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
|
| 372 |
+
processed_prompt_text = prompt_text
|
| 373 |
+
|
| 374 |
+
image_vq_ids_tensor = None
|
| 375 |
+
if uploaded_image_pil is not None:
|
| 376 |
+
try:
|
| 377 |
+
|
| 378 |
+
image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
|
| 379 |
+
image = image.unsqueeze(0)
|
| 380 |
+
image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
|
| 381 |
+
except Exception as e:
|
| 382 |
+
yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
|
| 383 |
+
return
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
try:
|
| 387 |
+
if TOKENIZER.pad_token_id is None:
|
| 388 |
+
if TOKENIZER.eos_token_id is not None:
|
| 389 |
+
TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
|
| 390 |
+
else:
|
| 391 |
+
yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
|
| 392 |
+
return
|
| 393 |
+
|
| 394 |
+
input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
|
| 395 |
+
raw_prompt_attention_mask = None
|
| 396 |
+
if image_vq_ids_tensor is not None:
|
| 397 |
+
if image_vq_ids_tensor.ndim == 1:
|
| 398 |
+
image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
|
| 399 |
+
|
| 400 |
+
input_ids = torch.cat([
|
| 401 |
+
(torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
|
| 402 |
+
(torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
|
| 403 |
+
image_vq_ids_tensor,
|
| 404 |
+
(torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
|
| 405 |
+
input_ids
|
| 406 |
+
], dim=1).long()
|
| 407 |
+
|
| 408 |
+
else:
|
| 409 |
+
input_ids = input_ids
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
except Exception as e:
|
| 413 |
+
yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
|
| 414 |
+
return
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
batch_size = input_ids.shape[0]
|
| 419 |
+
prompt_len = input_ids.shape[1]
|
| 420 |
+
|
| 421 |
+
x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
|
| 422 |
+
x[:, :prompt_len] = input_ids.clone()
|
| 423 |
+
|
| 424 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
|
| 425 |
+
|
| 426 |
+
if gen_length == 0:
|
| 427 |
+
final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
|
| 428 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
|
| 429 |
+
return
|
| 430 |
+
|
| 431 |
+
if block_length <= 0 or gen_length % block_length != 0 :
|
| 432 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
|
| 433 |
+
f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
|
| 434 |
+
return
|
| 435 |
+
num_blocks = gen_length // block_length
|
| 436 |
+
|
| 437 |
+
if steps <=0 or steps % num_blocks != 0:
|
| 438 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
|
| 439 |
+
f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
|
| 440 |
+
return
|
| 441 |
+
steps_per_block = steps // num_blocks
|
| 442 |
+
|
| 443 |
+
for num_block_iter in range(num_blocks):
|
| 444 |
+
current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
|
| 445 |
+
current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
|
| 446 |
+
|
| 447 |
+
block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
|
| 448 |
+
block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
|
| 449 |
+
(x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
|
| 450 |
+
|
| 451 |
+
num_transfer_tokens_for_this_block = get_num_transfer_tokens(
|
| 452 |
+
block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
|
| 453 |
+
steps_per_block
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
for i_step_in_block in range(steps_per_block):
|
| 457 |
+
mask_index_global = (x == MASK_ID)
|
| 458 |
+
|
| 459 |
+
if cfg_scale > 0.:
|
| 460 |
+
un_x = x.clone()
|
| 461 |
+
# For unconditional pass, mask out the original prompt tokens that are not padding
|
| 462 |
+
# raw_prompt_attention_mask is (B, prompt_len)
|
| 463 |
+
prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
|
| 464 |
+
un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
|
| 465 |
+
|
| 466 |
+
x_cfg_input = torch.cat([x, un_x], dim=0)
|
| 467 |
+
# Pass attention_mask for CFG if model expects it, covering both parts
|
| 468 |
+
# For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
|
| 469 |
+
model_output = MODEL(x_cfg_input)
|
| 470 |
+
logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
|
| 471 |
+
logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
|
| 472 |
+
else:
|
| 473 |
+
# Not passing explicit attention_mask here; relies on model's internal handling.
|
| 474 |
+
model_output = MODEL(x)
|
| 475 |
+
logits = model_output.logits
|
| 476 |
+
|
| 477 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
| 478 |
+
x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
|
| 479 |
+
|
| 480 |
+
if remasking_strategy == 'low_confidence':
|
| 481 |
+
probs = F.softmax(logits.to(torch.float64), dim=-1)
|
| 482 |
+
x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
|
| 483 |
+
elif remasking_strategy == 'random':
|
| 484 |
+
x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
|
| 485 |
+
else:
|
| 486 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
|
| 487 |
+
return
|
| 488 |
+
|
| 489 |
+
confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
|
| 490 |
+
candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
|
| 491 |
+
confidence_for_selection = torch.where(
|
| 492 |
+
candidate_positions_for_unmasking,
|
| 493 |
+
x0_probs,
|
| 494 |
+
-torch.inf
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
|
| 498 |
+
|
| 499 |
+
transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
|
| 500 |
+
num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
|
| 501 |
+
|
| 502 |
+
for j_batch_idx in range(batch_size):
|
| 503 |
+
k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
|
| 504 |
+
candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
|
| 505 |
+
|
| 506 |
+
if k_val > 0:
|
| 507 |
+
# Ensure confidence_for_selection[j_batch_idx] is 1D for topk
|
| 508 |
+
conf_slice = confidence_for_selection[j_batch_idx]
|
| 509 |
+
if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
|
| 510 |
+
|
| 511 |
+
# Check if there are enough valid (non -inf) confidences
|
| 512 |
+
valid_conf_count = (conf_slice > -torch.inf).sum().item()
|
| 513 |
+
actual_k = min(k_val, valid_conf_count)
|
| 514 |
+
|
| 515 |
+
if actual_k > 0:
|
| 516 |
+
_, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
|
| 517 |
+
transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
|
| 518 |
+
|
| 519 |
+
x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
|
| 520 |
+
|
| 521 |
+
current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
|
| 522 |
+
total_overall_steps = num_blocks * steps_per_block
|
| 523 |
+
status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
|
| 524 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
|
| 525 |
+
|
| 526 |
+
final_generated_ids = x[:, prompt_len:]
|
| 527 |
+
final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
|
| 528 |
+
|
| 529 |
+
final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
|
| 530 |
+
yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
css_styles = """
|
| 534 |
+
.gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
|
| 535 |
+
.gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
|
| 536 |
+
.gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
|
| 537 |
+
|
| 538 |
+
.highlighted-text span{
|
| 539 |
+
padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
footer{display:none !important}
|
| 543 |
+
|
| 544 |
+
#live-update-scrollable-box {
|
| 545 |
+
max-height: 800px; /* 您可以根据需要调整这个最大高度,例如 '300px', '50vh' 等 */
|
| 546 |
+
overflow-y: auto !important; /* 当内容超出 max-height 时显示垂直滚动条 */
|
| 547 |
+
display: block; /* 确保元素是块级元素,以便 max-height 生效 */
|
| 548 |
+
|
| 549 |
+
}
|
| 550 |
+
#think_btn {
|
| 551 |
+
background-color: #f3f4f6 !important;
|
| 552 |
+
border: 1px solid #d0d0d0 !important;
|
| 553 |
+
color: #111827 !important;
|
| 554 |
+
font-size: 16px !important;
|
| 555 |
+
font-weight: bold !important;
|
| 556 |
+
}
|
| 557 |
+
#think_btn:hover {
|
| 558 |
+
background-color: #e0e0e0 !important;
|
| 559 |
+
border: 1px solid #c0c0c0 !important;
|
| 560 |
+
color: #222 !important;
|
| 561 |
+
}
|
| 562 |
+
#think_btn:active {
|
| 563 |
+
background-color: #2563eb !important;
|
| 564 |
+
border: 1px solid #b0b0b0 !important;
|
| 565 |
+
color: white !important;
|
| 566 |
+
}
|
| 567 |
+
"""
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# thinking_mode_t2i = gr.State(False)
|
| 571 |
+
def toggle_thinking_mode_lm(current_thinking_mode):
|
| 572 |
+
# print(f"current_thinking_mode: {current_thinking_mode}")
|
| 573 |
+
new_state = not current_thinking_mode
|
| 574 |
+
new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
|
| 575 |
+
return new_state, gr.update(value=new_label)
|
| 576 |
+
|
| 577 |
+
def toggle_thinking_mode_mmu(current_thinking_mode):
|
| 578 |
+
new_state = not current_thinking_mode
|
| 579 |
+
new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
|
| 580 |
+
return new_state, gr.update(value=new_label)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
color_map_config = {
|
| 584 |
+
"MASK": "lightgrey",
|
| 585 |
+
"GEN": "#DCABFA",
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
theme = gr.themes.Ocean(
|
| 589 |
+
primary_hue="fuchsia",
|
| 590 |
+
)
|
| 591 |
+
with gr.Blocks(css=css_styles, theme=theme) as demo:
|
| 592 |
+
# with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
|
| 593 |
+
# with gr.Blocks() as demo:
|
| 594 |
+
thinking_mode_lm = gr.State(False)
|
| 595 |
+
thinking_mode_mmu = gr.State(False)
|
| 596 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>MMaDA </h1>")
|
| 597 |
+
gr.Markdown("Interactively explore the step-by-step generation process of a diffusion language model. "
|
| 598 |
+
"The model begins with a fully masked sequence (except for the prompt) and progressively refines it by unmasking tokens.")
|
| 599 |
+
gr.Markdown("### Select Model")
|
| 600 |
+
with gr.Row():
|
| 601 |
+
model_select_radio = gr.Radio(
|
| 602 |
+
label="Select Text Generation Model",
|
| 603 |
+
choices=MODEL_CHOICES,
|
| 604 |
+
value=MODEL_CHOICES[0]
|
| 605 |
+
)
|
| 606 |
+
model_load_status_box = gr.Textbox(
|
| 607 |
+
label="Model Load Status",
|
| 608 |
+
interactive=False,
|
| 609 |
+
lines=3,
|
| 610 |
+
max_lines=5
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
gr.Markdown("## Part 1. Text Generation")
|
| 614 |
+
with gr.Row():
|
| 615 |
+
with gr.Column(scale=2):
|
| 616 |
+
prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
|
| 617 |
+
think_button_lm = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
|
| 618 |
+
with gr.Accordion("Generation Parameters", open=True):
|
| 619 |
+
with gr.Row():
|
| 620 |
+
gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
|
| 621 |
+
steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
|
| 622 |
+
with gr.Row():
|
| 623 |
+
block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
|
| 624 |
+
remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
|
| 625 |
+
with gr.Row():
|
| 626 |
+
cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
|
| 627 |
+
temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
with gr.Row():
|
| 631 |
+
run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
|
| 632 |
+
clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)
|
| 633 |
+
|
| 634 |
+
with gr.Column(scale=3):
|
| 635 |
+
# gr.Markdown("## Live Generation Process")
|
| 636 |
+
output_visualization_box_lm = gr.HighlightedText(
|
| 637 |
+
label="Live Generation Process",
|
| 638 |
+
show_legend=True,
|
| 639 |
+
color_map=color_map_config,
|
| 640 |
+
combine_adjacent=False,
|
| 641 |
+
interactive=False,
|
| 642 |
+
elem_id="live-update-scrollable-box",
|
| 643 |
+
)
|
| 644 |
+
# gr.Markdown("## Final Generated Text")
|
| 645 |
+
output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
gr.Examples(
|
| 650 |
+
examples=[
|
| 651 |
+
["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
|
| 652 |
+
["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
|
| 653 |
+
],
|
| 654 |
+
inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
|
| 655 |
+
outputs=[output_visualization_box_lm, output_final_text_box_lm],
|
| 656 |
+
fn=generate_viz_wrapper_lm,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
gr.Markdown("---")
|
| 660 |
+
gr.Markdown("## Part 2. Multimodal Understanding")
|
| 661 |
+
with gr.Row():
|
| 662 |
+
with gr.Column(scale=2):
|
| 663 |
+
prompt_input_box_mmu = gr.Textbox(
|
| 664 |
+
label="Enter your prompt:",
|
| 665 |
+
lines=3,
|
| 666 |
+
value="Please describe this image in detail."
|
| 667 |
+
)
|
| 668 |
+
think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
|
| 669 |
+
with gr.Accordion("Generation Parameters", open=True):
|
| 670 |
+
with gr.Row():
|
| 671 |
+
gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
|
| 672 |
+
steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
|
| 673 |
+
with gr.Row():
|
| 674 |
+
block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
|
| 675 |
+
remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
|
| 676 |
+
with gr.Row():
|
| 677 |
+
cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
|
| 678 |
+
temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
|
| 679 |
+
|
| 680 |
+
with gr.Row():
|
| 681 |
+
image_upload_box = gr.Image(type="pil", label="Upload Image")
|
| 682 |
+
|
| 683 |
+
with gr.Row():
|
| 684 |
+
run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
|
| 685 |
+
clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)
|
| 686 |
+
|
| 687 |
+
with gr.Column(scale=3):
|
| 688 |
+
gr.Markdown("## Live Generation Process")
|
| 689 |
+
output_visualization_box_mmu = gr.HighlightedText(
|
| 690 |
+
label="Token Sequence (Live Update)",
|
| 691 |
+
show_legend=True,
|
| 692 |
+
color_map=color_map_config,
|
| 693 |
+
combine_adjacent=False,
|
| 694 |
+
interactive=False,
|
| 695 |
+
elem_id="live-update-scrollable-box",
|
| 696 |
+
)
|
| 697 |
+
gr.Markdown("## Final Generated Text")
|
| 698 |
+
output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
gr.Examples(
|
| 702 |
+
examples=[
|
| 703 |
+
[
|
| 704 |
+
"mmu_validation_2/sunflower.jpg",
|
| 705 |
+
"Please describe this image in detail.",
|
| 706 |
+
256,
|
| 707 |
+
512,
|
| 708 |
+
128,
|
| 709 |
+
1,
|
| 710 |
+
0,
|
| 711 |
+
"low_confidence"
|
| 712 |
+
],
|
| 713 |
+
[
|
| 714 |
+
"mmu_validation_2/woman.jpg",
|
| 715 |
+
"Please describe this image in detail.",
|
| 716 |
+
256,
|
| 717 |
+
512,
|
| 718 |
+
128,
|
| 719 |
+
1,
|
| 720 |
+
0,
|
| 721 |
+
"low_confidence"
|
| 722 |
+
]
|
| 723 |
+
],
|
| 724 |
+
inputs=[
|
| 725 |
+
image_upload_box,
|
| 726 |
+
prompt_input_box_mmu,
|
| 727 |
+
steps_slider_mmu,
|
| 728 |
+
gen_length_slider_mmu,
|
| 729 |
+
block_length_slider_mmu,
|
| 730 |
+
temperature_slider_mmu,
|
| 731 |
+
cfg_scale_slider_mmu,
|
| 732 |
+
remasking_dropdown_mmu
|
| 733 |
+
],
|
| 734 |
+
outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
|
| 735 |
+
fn=generate_viz_wrapper,
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
gr.Markdown("---")
|
| 739 |
+
gr.Markdown("## Part 3. Text-to-Image Generation")
|
| 740 |
+
with gr.Row():
|
| 741 |
+
with gr.Column(scale=2):
|
| 742 |
+
prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")
|
| 743 |
+
|
| 744 |
+
with gr.Accordion("Generation Parameters", open=True):
|
| 745 |
+
with gr.Row():
|
| 746 |
+
steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
|
| 747 |
+
guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.")
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
with gr.Row():
|
| 751 |
+
scheduler_radio_t2i = gr.Radio(
|
| 752 |
+
choices=["cosine", "sigmoid", "linear"],
|
| 753 |
+
value="cosine",
|
| 754 |
+
label="Scheduler",
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
with gr.Row():
|
| 758 |
+
run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
|
| 759 |
+
clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
with gr.Column(scale=3):
|
| 763 |
+
# gr.Markdown("## Live Generation Process")
|
| 764 |
+
output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
|
| 765 |
+
output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)
|
| 766 |
+
|
| 767 |
+
gr.Examples(
|
| 768 |
+
examples=[
|
| 769 |
+
["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
|
| 770 |
+
["A beautiful sunset over a calm ocean, with a few clouds in the sky.", 15, 3.5, "cosine"]
|
| 771 |
+
],
|
| 772 |
+
inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i],
|
| 773 |
+
outputs=[output_image_t2i, output_status_t2i],
|
| 774 |
+
fn=generate_viz_wrapper_t2i,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
run_button_ui_t2i.click(
|
| 778 |
+
fn=generate_viz_wrapper_t2i,
|
| 779 |
+
inputs=[
|
| 780 |
+
prompt_input_box_t2i,
|
| 781 |
+
steps_slider_t2i,
|
| 782 |
+
guidance_scale_slider_t2i,
|
| 783 |
+
scheduler_radio_t2i
|
| 784 |
+
],
|
| 785 |
+
outputs=[output_image_t2i, output_status_t2i]
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
clear_button_ui_t2i.click(
|
| 789 |
+
fn=lambda: (None, ""),
|
| 790 |
+
inputs=None,
|
| 791 |
+
outputs=[output_image_t2i, output_status_t2i],
|
| 792 |
+
queue=False
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
think_button_lm.click(
|
| 796 |
+
fn=toggle_thinking_mode_lm,
|
| 797 |
+
inputs=[thinking_mode_lm],
|
| 798 |
+
outputs=[thinking_mode_lm, think_button_lm]
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
think_button_mmu.click(
|
| 802 |
+
fn=toggle_thinking_mode_mmu,
|
| 803 |
+
inputs=[thinking_mode_mmu],
|
| 804 |
+
outputs=[thinking_mode_mmu, think_button_mmu]
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
def initialize_default_model():
|
| 810 |
+
default_model = "MMaDA-8B-Base"
|
| 811 |
+
result = handle_model_selection_change(default_model)
|
| 812 |
+
return default_model, result
|
| 813 |
+
|
| 814 |
+
demo.load(
|
| 815 |
+
fn=initialize_default_model,
|
| 816 |
+
inputs=None,
|
| 817 |
+
outputs=[model_select_radio, model_load_status_box],
|
| 818 |
+
queue=True
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
def clear_outputs():
|
| 822 |
+
return None, None, None # Clear image, visualization, and final text
|
| 823 |
+
|
| 824 |
+
clear_button_ui_lm.click(
|
| 825 |
+
fn=clear_outputs,
|
| 826 |
+
inputs=None,
|
| 827 |
+
outputs=[image_upload_box, output_visualization_box_lm, output_final_text_box_lm],
|
| 828 |
+
queue=False
|
| 829 |
+
)
|
| 830 |
+
clear_button_ui_mmu.click(
|
| 831 |
+
fn=clear_outputs,
|
| 832 |
+
inputs=None,
|
| 833 |
+
outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu],
|
| 834 |
+
queue=False
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
run_button_ui_lm.click(
|
| 838 |
+
fn=generate_viz_wrapper_lm,
|
| 839 |
+
inputs=[
|
| 840 |
+
prompt_input_box_lm,
|
| 841 |
+
steps_slider_lm,
|
| 842 |
+
gen_length_slider_lm,
|
| 843 |
+
block_length_slider_lm,
|
| 844 |
+
temperature_slider_lm,
|
| 845 |
+
cfg_scale_slider_lm,
|
| 846 |
+
remasking_dropdown_lm,
|
| 847 |
+
thinking_mode_lm
|
| 848 |
+
],
|
| 849 |
+
outputs=[output_visualization_box_lm, output_final_text_box_lm]
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
run_button_ui_mmu.click(
|
| 853 |
+
fn=generate_viz_wrapper,
|
| 854 |
+
inputs=[
|
| 855 |
+
image_upload_box,
|
| 856 |
+
prompt_input_box_mmu,
|
| 857 |
+
steps_slider_mmu,
|
| 858 |
+
gen_length_slider_mmu,
|
| 859 |
+
block_length_slider_mmu,
|
| 860 |
+
temperature_slider_mmu,
|
| 861 |
+
cfg_scale_slider_mmu,
|
| 862 |
+
remasking_dropdown_mmu,
|
| 863 |
+
thinking_mode_mmu
|
| 864 |
+
],
|
| 865 |
+
outputs=[output_visualization_box_mmu, output_final_text_box_mmu]
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
if __name__ == "__main__":
|
| 870 |
+
print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
|
| 871 |
+
demo.launch(share=True)
|
models/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
|
| 2 |
+
from .sampling import *
|
| 3 |
+
from .modeling_mmada import MMadaModelLM, MMadaConfig
|
models/common_modules.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import Tuple, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from einops import rearrange, repeat
|
| 13 |
+
from einops.layers.torch import Rearrange


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class DepthToSpaceUpsample(nn.Module):
    def __init__(
        self,
        in_channels,
    ):
        super().__init__()
        conv = nn.Conv2d(in_channels, in_channels * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        return out


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def unpack_time(t, batch):
    _, c, w, h = t.size()
    out = torch.reshape(t, [batch, -1, c, w, h])
    out = rearrange(out, "b t c h w -> b c t h w")
    return out


def pack_time(t):
    out = rearrange(t, "b c t h w -> b t c h w")
    _, _, c, w, h = out.size()
    return torch.reshape(out, [-1, c, w, h])


class TimeDownsample2x(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        kernel_size=3,
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        self.time_causal_padding = (kernel_size - 1, 0)
        self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        b, h, w, c, t = x.size()
        x = torch.reshape(x, [-1, c, t])

        x = F.pad(x, self.time_causal_padding)
        out = self.conv(x)

        out = torch.reshape(out, [b, h, w, c, t])
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


class TimeUpsample2x(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        conv = nn.Conv1d(dim, dim_out * 2, 1)

        self.net = nn.Sequential(
            nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, t = conv.weight.shape
        conv_weight = torch.empty(o // 2, i, t)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        b, h, w, c, t = x.size()
        x = torch.reshape(x, [-1, c, t])

        out = self.net(x)
        out = out[:, :, 1:].contiguous()

        out = torch.reshape(out, [b, h, w, c, t])
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class TimeAttention(AttnBlock):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, "b c t h w -> b h w t c")
        b, h, w, t, c = x.size()
        x = torch.reshape(x, (-1, t, c))

        x = super().forward(x, *args, **kwargs)

        x = torch.reshape(x, [b, h, w, t, c])
        return rearrange(x, "b h w t c -> b c t h w")


class Residual(nn.Module):
    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        dilation = kwargs.pop("dilation", 1)
        stride = kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (
            width_pad,
            width_pad,
            height_pad,
            height_pad,
            time_pad,
            0,
        )

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
        )

    def forward(self, x):
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"

        x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        return self.conv(x)


def ResnetBlockCausal3D(
    dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
):
    net = nn.Sequential(
        Normalize(dim),
        nn.SiLU(),
        CausalConv3d(dim, dim, kernel_size, pad_mode),
        Normalize(dim),
        nn.SiLU(),
        CausalConv3d(dim, dim, kernel_size, pad_mode),
    )
    return Residual(net)


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.temb_proj = None
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
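
The block below is an illustrative shape check, not part of the committed file: it shows how Upsample/Downsample double and halve the spatial resolution, and how CausalConv3d pads only on the past side of the time axis so the frame count is preserved. The tensor sizes are made-up values for the sketch and assume the imports at the top of this module.

    x = torch.randn(2, 64, 32, 32)                   # (batch, channels, H, W)
    assert Upsample(64, with_conv=True)(x).shape == (2, 64, 64, 64)    # nearest-neighbor x2, then 3x3 conv
    assert Downsample(64, with_conv=True)(x).shape == (2, 64, 16, 16)  # asymmetric pad, then stride-2 conv

    video = torch.randn(1, 8, 5, 32, 32)             # (batch, channels, time, H, W)
    causal = CausalConv3d(8, 8, kernel_size=3)
    assert causal(video).shape == video.shape        # k_t - 1 frames of padding are added before t=0 only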
models/configuration_llada.py
ADDED
@@ -0,0 +1,463 @@
"""
LLaDA configuration
"""
from transformers import AutoConfig, PretrainedConfig

from enum import Enum
from os import PathLike
from typing import Union
from dataclasses import asdict, dataclass, field
from glob import glob
from pathlib import Path
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)


__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
]

PathOrStr = Union[str, PathLike]


class StrEnum(str, Enum):
    """
    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
    We include this here for compatibility with older versions of Python.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"'{str(self)}'"


class LayerNormType(StrEnum):
    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    gemma_rms = "gemma_rms"
    """
    An RMSNorm implementation by Gemma. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    amd_compatible = "amd_compatible"
    """
    LayerNorm implemented manually to work around an issue with ROCm.
    """


class ActivationType(StrEnum):
    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"


class BlockType(StrEnum):
    sequential = "sequential"
    parallel = "parallel"

    llama = "llama"
    """
    A block similar to the sequential block with slightly different
    implementations of operations like attention to imitate the behavior of Llama.
    """


class InitFnType(StrEnum):
    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW.
    This uses a truncated normal distribution with an adaptive standard deviation that depends
    on the size of the weights as well as the depth of the layer.
    """

    normal = "normal"
    """
    All weights are initialized from the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    All weights are initialized with the Kaiming method from a normal distribution.
    Note this currently won't work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """


@dataclass
class ModelConfig():
    """
    LLaDA (model) configuration.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    An input hidden_states norm implementation by Gemma.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The rms layernorm eps param.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The rope base param.
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in qkv linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the token to use for the mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cut off.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    @property
    def effective_n_kv_heads(self) -> int:
        if self.n_kv_heads is None:
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise Exception(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )


class ActivationCheckpointingStrategy(StrEnum):
    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one in two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one in three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one in four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Focus checkpointing on where it is cheap to recompute and saves most memory.
    """


class LLaDAConfig(PretrainedConfig):
    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        model_config = ModelConfig()
        all_kwargs = model_config.__dict__
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        all_kwargs.update(
            {
                "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
            }
        )
        super().__init__(**all_kwargs)

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

    @property
    def hidden_size(self):
        return self.d_model


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("llada", LLaDAConfig)
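
As a usage sketch (assumed, not taken from this repo's own entry points): LLaDAConfig copies the ModelConfig dataclass defaults into PretrainedConfig, so any field can be overridden by keyword and read back through the usual Hugging Face aliases.

    config = LLaDAConfig(d_model=1024, n_heads=16, n_layers=24, rope=True)
    print(config.hidden_size)           # 1024, aliased from d_model
    print(config.num_attention_heads)   # 16
    print(config.num_hidden_layers)     # 24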
models/logging.py
ADDED
@@ -0,0 +1,338 @@
# coding=utf-8
# Copyright 2023 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities."""

import logging
import os
import sys
import threading
from logging import CRITICAL  # NOQA
from logging import DEBUG  # NOQA
from logging import ERROR  # NOQA
from logging import FATAL  # NOQA
from logging import INFO  # NOQA
from logging import NOTSET  # NOQA
from logging import WARN  # NOQA
from logging import WARNING  # NOQA
from typing import Optional

from tqdm import auto as tqdm_lib

_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

_default_log_level = logging.WARNING

_tqdm_active = True


def _get_default_logging_level():
    """
    If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
    not - fall back to `_default_log_level`
    """
    env_level_str = os.getenv("muse_VERBOSITY", None)
    if env_level_str:
        if env_level_str in log_levels:
            return log_levels[env_level_str]
        else:
            logging.getLogger().warning(
                f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
            )
    return _default_log_level


def _get_library_name() -> str:
    return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:
    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_get_default_logging_level())
        library_root_logger.propagate = False


def _reset_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None


def get_log_levels_dict():
    return log_levels


def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return a logger with the specified name.

    This function is not supposed to be directly accessed unless you are writing a custom muse module.
    """

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()
    return logging.getLogger(name)


def get_verbosity() -> int:
    """
    Return the current level for the 🤗 muse' root logger as an int.

    Returns:
        `int`: The logging level.

    <Tip>

    🤗 muse has the following logging levels:

    - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
    - 40: `muse.logging.ERROR`
    - 30: `muse.logging.WARNING` or `muse.logging.WARN`
    - 20: `muse.logging.INFO`
    - 10: `muse.logging.DEBUG`

    </Tip>"""

    _configure_library_root_logger()
    return _get_library_root_logger().getEffectiveLevel()


def set_verbosity(verbosity: int) -> None:
    """
    Set the verbosity level for the 🤗 muse' root logger.

    Args:
        verbosity (`int`):
            Logging level, e.g., one of:

            - `muse.logging.CRITICAL` or `muse.logging.FATAL`
            - `muse.logging.ERROR`
            - `muse.logging.WARNING` or `muse.logging.WARN`
            - `muse.logging.INFO`
            - `muse.logging.DEBUG`
    """

    _configure_library_root_logger()
    _get_library_root_logger().setLevel(verbosity)


def set_verbosity_info():
    """Set the verbosity to the `INFO` level."""
    return set_verbosity(INFO)


def set_verbosity_warning():
    """Set the verbosity to the `WARNING` level."""
    return set_verbosity(WARNING)


def set_verbosity_debug():
    """Set the verbosity to the `DEBUG` level."""
    return set_verbosity(DEBUG)


def set_verbosity_error():
    """Set the verbosity to the `ERROR` level."""
    return set_verbosity(ERROR)


def disable_default_handler() -> None:
    """Disable the default handler of the HuggingFace muse' root logger."""

    _configure_library_root_logger()

    assert _default_handler is not None
    _get_library_root_logger().removeHandler(_default_handler)


def enable_default_handler() -> None:
    """Enable the default handler of the HuggingFace muse' root logger."""

    _configure_library_root_logger()

    assert _default_handler is not None
    _get_library_root_logger().addHandler(_default_handler)


def add_handler(handler: logging.Handler) -> None:
    """adds a handler to the HuggingFace muse' root logger."""

    _configure_library_root_logger()

    assert handler is not None
    _get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
    """removes given handler from the HuggingFace muse' root logger."""

    _configure_library_root_logger()

    assert handler is not None and handler not in _get_library_root_logger().handlers
    _get_library_root_logger().removeHandler(handler)


def disable_propagation() -> None:
    """
    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
    """

    _configure_library_root_logger()
    _get_library_root_logger().propagate = False


def enable_propagation() -> None:
    """
    Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent
    double logging if the root logger has been configured.
    """

    _configure_library_root_logger()
    _get_library_root_logger().propagate = True


def enable_explicit_format() -> None:
    """
    Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows:
    ```
        [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
    ```
    All handlers currently bound to the root logger are affected by this method.
    """
    handlers = _get_library_root_logger().handlers

    for handler in handlers:
        formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
        handler.setFormatter(formatter)


def reset_format() -> None:
    """
    Resets the formatting for HuggingFace muse' loggers.

    All handlers currently bound to the root logger are affected by this method.
    """
    handlers = _get_library_root_logger().handlers

    for handler in handlers:
        handler.setFormatter(None)


def warning_advice(self, *args, **kwargs):
    """
    This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
    warning will not be printed
    """
    no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
    if no_advisory_warnings:
        return
    self.warning(*args, **kwargs)


logging.Logger.warning_advice = warning_advice


class EmptyTqdm:
    """Dummy tqdm which doesn't do anything."""

    def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
        self._iterator = args[0] if args else None

    def __iter__(self):
        return iter(self._iterator)

    def __getattr__(self, _):
        """Return empty function."""

        def empty_fn(*args, **kwargs):  # pylint: disable=unused-argument
            return

        return empty_fn

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        return


class _tqdm_cls:
    def __call__(self, *args, **kwargs):
        if _tqdm_active:
            return tqdm_lib.tqdm(*args, **kwargs)
        else:
            return EmptyTqdm(*args, **kwargs)

    def set_lock(self, *args, **kwargs):
        self._lock = None
        if _tqdm_active:
            return tqdm_lib.tqdm.set_lock(*args, **kwargs)

    def get_lock(self):
        if _tqdm_active:
            return tqdm_lib.tqdm.get_lock()


tqdm = _tqdm_cls()


def is_progress_bar_enabled() -> bool:
    """Return a boolean indicating whether tqdm progress bars are enabled."""
    global _tqdm_active
    return bool(_tqdm_active)


def enable_progress_bar():
    """Enable tqdm progress bar."""
    global _tqdm_active
    _tqdm_active = True


def disable_progress_bar():
    """Disable tqdm progress bar."""
    global _tqdm_active
    _tqdm_active = False
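
A short usage sketch, assuming this module is imported as part of the models package: callers ask for the shared library logger instead of configuring `logging` themselves, and verbosity can be changed globally.

    logger = get_logger()        # library root logger with the shared stderr handler
    set_verbosity_info()         # or set the muse_VERBOSITY env var before launch
    logger.info("loading checkpoint")
    for _ in tqdm(range(3)):     # respects enable_progress_bar() / disable_progress_bar()
        pass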
models/lr_schedulers.py
ADDED
@@ -0,0 +1,302 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch optimization for diffusion models."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from enum import Enum
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
from torch.optim import Optimizer
|
| 22 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 23 |
+
|
| 24 |
+
from .logging import get_logger
|
| 25 |
+
|
| 26 |
+
logger = get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SchedulerType(Enum):
|
| 30 |
+
LINEAR = "linear"
|
| 31 |
+
COSINE = "cosine"
|
| 32 |
+
COSINE_WITH_RESTARTS = "cosine_with_restarts"
|
| 33 |
+
POLYNOMIAL = "polynomial"
|
| 34 |
+
CONSTANT = "constant"
|
| 35 |
+
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
| 39 |
+
"""
|
| 40 |
+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 44 |
+
The optimizer for which to schedule the learning rate.
|
| 45 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 46 |
+
The index of the last epoch when resuming training.
|
| 47 |
+
|
| 48 |
+
Return:
|
| 49 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 50 |
+
"""
|
| 51 |
+
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
|
| 55 |
+
"""
|
| 56 |
+
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
| 57 |
+
increases linearly between 0 and the initial lr set in the optimizer.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 61 |
+
The optimizer for which to schedule the learning rate.
|
| 62 |
+
num_warmup_steps (`int`):
|
| 63 |
+
The number of steps for the warmup phase.
|
| 64 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 65 |
+
The index of the last epoch when resuming training.
|
| 66 |
+
|
| 67 |
+
Return:
|
| 68 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def lr_lambda(current_step: int):
|
| 72 |
+
if current_step < num_warmup_steps:
|
| 73 |
+
return float(current_step) / float(max(1.0, num_warmup_steps))
|
| 74 |
+
return 1.0
|
| 75 |
+
|
| 76 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
| 80 |
+
"""
|
| 81 |
+
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
| 82 |
+
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 86 |
+
The optimizer for which to schedule the learning rate.
|
| 87 |
+
num_warmup_steps (`int`):
|
| 88 |
+
The number of steps for the warmup phase.
|
| 89 |
+
num_training_steps (`int`):
|
| 90 |
+
The total number of training steps.
|
| 91 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 92 |
+
The index of the last epoch when resuming training.
|
| 93 |
+
|
| 94 |
+
Return:
|
| 95 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def lr_lambda(current_step: int):
|
| 99 |
+
if current_step < num_warmup_steps:
|
| 100 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 101 |
+
return max(
|
| 102 |
+
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_cosine_schedule_with_warmup(
|
| 109 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1, min_lr_scale: float = 0.0
|
| 110 |
+
):
|
| 111 |
+
"""
|
| 112 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 113 |
+
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
| 114 |
+
initial lr set in the optimizer.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 118 |
+
The optimizer for which to schedule the learning rate.
|
| 119 |
+
num_warmup_steps (`int`):
|
| 120 |
+
The number of steps for the warmup phase.
|
| 121 |
+
num_training_steps (`int`):
|
| 122 |
+
The total number of training steps.
|
| 123 |
+
num_periods (`float`, *optional*, defaults to 0.5):
|
| 124 |
+
The number of periods of the cosine function in a schedule (the default is to just decrease from the max
|
| 125 |
+
value to 0 following a half-cosine).
|
| 126 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 127 |
+
The index of the last epoch when resuming training.
|
| 128 |
+
|
| 129 |
+
Return:
|
| 130 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
# def lr_lambda(current_step):
|
| 134 |
+
# if current_step < num_warmup_steps:
|
| 135 |
+
# return float(current_step) / float(max(1, num_warmup_steps))
|
| 136 |
+
# progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 137 |
+
# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
| 138 |
+
|
| 139 |
+
# return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 140 |
+
|
| 141 |
+
def lr_lambda(current_step):
|
| 142 |
+
if current_step < num_warmup_steps:
|
| 143 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 144 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 145 |
+
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
|
| 146 |
+
return min_lr_scale + (1.0 - min_lr_scale) * cosine_decay
|
| 147 |
+
|
| 148 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 152 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
| 153 |
+
):
|
| 154 |
+
"""
|
| 155 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 156 |
+
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
| 157 |
+
linearly between 0 and the initial lr set in the optimizer.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 161 |
+
The optimizer for which to schedule the learning rate.
|
| 162 |
+
num_warmup_steps (`int`):
|
| 163 |
+
The number of steps for the warmup phase.
|
| 164 |
+
num_training_steps (`int`):
|
| 165 |
+
The total number of training steps.
|
| 166 |
+
num_cycles (`int`, *optional*, defaults to 1):
|
| 167 |
+
The number of hard restarts to use.
|
| 168 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 169 |
+
The index of the last epoch when resuming training.
|
| 170 |
+
|
| 171 |
+
Return:
|
| 172 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
def lr_lambda(current_step):
|
| 176 |
+
if current_step < num_warmup_steps:
|
| 177 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 178 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 179 |
+
if progress >= 1.0:
|
| 180 |
+
return 0.0
|
| 181 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
|
| 182 |
+
|
| 183 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_polynomial_decay_schedule_with_warmup(
|
| 187 |
+
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
|
| 188 |
+
):
|
| 189 |
+
"""
|
| 190 |
+
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
|
| 191 |
+
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
|
| 192 |
+
initial lr set in the optimizer.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 196 |
+
The optimizer for which to schedule the learning rate.
|
| 197 |
+
num_warmup_steps (`int`):
|
| 198 |
+
The number of steps for the warmup phase.
|
| 199 |
+
num_training_steps (`int`):
|
| 200 |
+
The total number of training steps.
|
| 201 |
+
lr_end (`float`, *optional*, defaults to 1e-7):
|
| 202 |
+
The end LR.
|
| 203 |
+
power (`float`, *optional*, defaults to 1.0):
|
| 204 |
+
Power factor.
|
| 205 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 206 |
+
The index of the last epoch when resuming training.
|
| 207 |
+
|
| 208 |
+
Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
|
| 209 |
+
implementation at
|
| 210 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
|
| 211 |
+
|
| 212 |
+
Return:
|
| 213 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 214 |
+
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
lr_init = optimizer.defaults["lr"]
|
| 218 |
+
if not (lr_init > lr_end):
|
| 219 |
+
raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
|
| 220 |
+
|
| 221 |
+
def lr_lambda(current_step: int):
|
| 222 |
+
if current_step < num_warmup_steps:
|
| 223 |
+
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    min_lr_scale: float = 0.0,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor for the `POLYNOMIAL` scheduler.
        min_lr_scale (`float`, *optional*, defaults to 0.0):
            Minimum learning-rate scale, forwarded to the schedulers that support it (currently
            `COSINE_WITH_RESTARTS`).
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers also require `num_training_steps`.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            min_lr_scale=min_lr_scale,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
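For reference, a minimal sketch of how this unified entry point could be wired into a training script. The import path, the toy optimizer, and the step counts below are illustrative assumptions, not part of this commit.

import torch
from models.lr_schedulers import SchedulerType, get_scheduler  # module path assumed from this repo's layout

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Cosine decay with a linear warmup over the first 1,000 of 100,000 steps.
lr_scheduler = get_scheduler(
    SchedulerType.COSINE,
    optimizer,
    num_warmup_steps=1_000,
    num_training_steps=100_000,
)

for _ in range(10):       # stand-in for the real training loop
    optimizer.step()      # update parameters first
    lr_scheduler.step()   # then advance the learning-rate schedule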
models/misc.py
ADDED
@@ -0,0 +1,53 @@
from omegaconf import OmegaConf
import torch
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    NewType,
    Optional,
    Sized,
    Tuple,
    Type,
    TypeVar,
    Union,
)
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

# Tensor dtype
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt

# Config type
from omegaconf import DictConfig

# PyTorch Tensor type
from torch import Tensor

# Runtime type checking decorator
from typeguard import typechecked as typechecker


def broadcast(tensor, src=0):
    if not _distributed_available():
        return tensor
    else:
        torch.distributed.broadcast(tensor, src=src)
        return tensor

def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()

def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
    if '--local-rank' in cfg:
        del cfg['--local-rank']
    # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
    scfg = OmegaConf.structured(fields(**cfg))
    return scfg
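A minimal usage sketch for parse_structured: it instantiates a dataclass-style schema from a plain dict and wraps the result in a structured OmegaConf config, so schema defaults fill in any keys the dict omits. The TrainerConfig schema and the import path are made-up examples for illustration, not part of this commit.

from dataclasses import dataclass

from models.misc import parse_structured  # module path assumed from this repo's layout

@dataclass
class TrainerConfig:              # hypothetical schema, for illustration only
    lr: float = 1e-4
    batch_size: int = 8
    mixed_precision: str = "bf16"

cfg = parse_structured(TrainerConfig, {"lr": 3e-4, "batch_size": 16})
print(cfg.lr, cfg.batch_size, cfg.mixed_precision)  # 0.0003 16 bf16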
models/modeling_llada.py
ADDED
|
@@ -0,0 +1,1500 @@
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import sys
|
| 6 |
+
from abc import abstractmethod
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import (
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
Iterable,
|
| 13 |
+
List,
|
| 14 |
+
NamedTuple,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
cast,
|
| 20 |
+
)
|
| 21 |
+
from dataclasses import fields
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.backends.cuda
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from torch import einsum
|
| 29 |
+
from transformers import PreTrainedModel
|
| 30 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 31 |
+
from transformers.models.auto import AutoModel
|
| 32 |
+
from transformers.cache_utils import Cache
|
| 33 |
+
|
| 34 |
+
from .configuration_llada import (
|
| 35 |
+
LLaDAConfig,
|
| 36 |
+
StrEnum,
|
| 37 |
+
InitFnType,
|
| 38 |
+
ActivationType,
|
| 39 |
+
BlockType,
|
| 40 |
+
LayerNormType,
|
| 41 |
+
ModelConfig,
|
| 42 |
+
ActivationCheckpointingStrategy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if sys.version_info.minor > 8:
|
| 46 |
+
from collections.abc import MutableMapping
|
| 47 |
+
elif sys.version_info.minor == 8:
|
| 48 |
+
from typing import MutableMapping
|
| 49 |
+
else:
|
| 50 |
+
raise SystemExit("This script supports Python 3.8 or higher")
|
| 51 |
+
|
| 52 |
+
__all__ = [
|
| 53 |
+
"LayerNormBase",
|
| 54 |
+
"LayerNorm",
|
| 55 |
+
"RMSLayerNorm",
|
| 56 |
+
"GemmaRMSLayerNorm",
|
| 57 |
+
"RotaryEmbedding",
|
| 58 |
+
"Activation",
|
| 59 |
+
"GELU",
|
| 60 |
+
"ReLU",
|
| 61 |
+
"SwiGLU",
|
| 62 |
+
"LLaDABlock",
|
| 63 |
+
"LLaDASequentialBlock",
|
| 64 |
+
"LLaDAModel",
|
| 65 |
+
"LLaDAOutput",
|
| 66 |
+
"LLaDAGenerateOutput",
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
log = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class ModuleType(StrEnum):
|
| 74 |
+
in_module = "in"
|
| 75 |
+
out_module = "out"
|
| 76 |
+
emb = "emb"
|
| 77 |
+
final_out = "final_out"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def init_weights(
|
| 81 |
+
config: ModelConfig,
|
| 82 |
+
module: Union[nn.Linear, nn.Embedding],
|
| 83 |
+
d: Optional[int] = None,
|
| 84 |
+
layer_id: Optional[int] = None,
|
| 85 |
+
std_factor: float = 1.0,
|
| 86 |
+
type_of_module: Optional[ModuleType] = None,
|
| 87 |
+
) -> None:
|
| 88 |
+
"""
|
| 89 |
+
Initialize weights of a linear or embedding module.
|
| 90 |
+
|
| 91 |
+
:param config: The model config.
|
| 92 |
+
:param module: The linear or embedding submodule to initialize.
|
| 93 |
+
:param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
|
| 94 |
+
for fused layers.
|
| 95 |
+
:param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
|
| 96 |
+
``1 / sqrt(2 * (layer_id + 1))``.
|
| 97 |
+
"""
|
| 98 |
+
d = d if d is not None else config.d_model
|
| 99 |
+
if config.init_fn == InitFnType.normal:
|
| 100 |
+
std = config.init_std * std_factor
|
| 101 |
+
if config.init_cutoff_factor is not None:
|
| 102 |
+
cutoff_value = config.init_cutoff_factor * std
|
| 103 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
|
| 104 |
+
else:
|
| 105 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 106 |
+
elif config.init_fn == InitFnType.mitchell:
|
| 107 |
+
std = std_factor / math.sqrt(d)
|
| 108 |
+
if layer_id is not None:
|
| 109 |
+
std = std / math.sqrt(2 * (layer_id + 1))
|
| 110 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
|
| 111 |
+
elif config.init_fn == InitFnType.kaiming_normal:
|
| 112 |
+
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
|
| 113 |
+
elif config.init_fn == InitFnType.fan_in:
|
| 114 |
+
std = std_factor / math.sqrt(d)
|
| 115 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 116 |
+
elif config.init_fn == InitFnType.full_megatron:
|
| 117 |
+
if type_of_module is None:
|
| 118 |
+
raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
|
| 119 |
+
|
| 120 |
+
cutoff_factor = config.init_cutoff_factor
|
| 121 |
+
if cutoff_factor is None:
|
| 122 |
+
cutoff_factor = 3
|
| 123 |
+
|
| 124 |
+
if type_of_module == ModuleType.in_module:
|
| 125 |
+
# for att_proj (same as QKV), ff_proj
|
| 126 |
+
std = config.init_std
|
| 127 |
+
elif type_of_module == ModuleType.out_module:
|
| 128 |
+
# for attn_out, ff_out
|
| 129 |
+
std = config.init_std / math.sqrt(2.0 * config.n_layers)
|
| 130 |
+
elif type_of_module == ModuleType.emb:
|
| 131 |
+
# positional embeddings (wpe)
|
| 132 |
+
# token embeddings (wte)
|
| 133 |
+
std = config.init_std
|
| 134 |
+
elif type_of_module == ModuleType.final_out:
|
| 135 |
+
# final output (ff_out)
|
| 136 |
+
std = config.d_model**-0.5
|
| 137 |
+
else:
|
| 138 |
+
raise RuntimeError(f"Unknown module type '{type_of_module}'")
|
| 139 |
+
nn.init.trunc_normal_(
|
| 140 |
+
module.weight,
|
| 141 |
+
mean=0.0,
|
| 142 |
+
std=std,
|
| 143 |
+
a=-cutoff_factor * std,
|
| 144 |
+
b=cutoff_factor * std,
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
raise NotImplementedError(config.init_fn)
|
| 148 |
+
|
| 149 |
+
if isinstance(module, nn.Linear):
|
| 150 |
+
if module.bias is not None:
|
| 151 |
+
nn.init.zeros_(module.bias)
|
| 152 |
+
|
| 153 |
+
if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
module.weight.div_(math.sqrt(2 * config.n_layers))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
| 159 |
+
"""
|
| 160 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
| 161 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
| 162 |
+
"""
|
| 163 |
+
if check_neg_inf:
|
| 164 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
| 165 |
+
if check_pos_inf:
|
| 166 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def activation_checkpoint_function(cfg: ModelConfig):
|
| 170 |
+
preserve_rng_state = (
|
| 171 |
+
(cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
|
| 172 |
+
)
|
| 173 |
+
from torch.utils.checkpoint import checkpoint
|
| 174 |
+
|
| 175 |
+
return partial(
|
| 176 |
+
checkpoint,
|
| 177 |
+
preserve_rng_state=preserve_rng_state,
|
| 178 |
+
use_reentrant=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
|
| 183 |
+
"""
|
| 184 |
+
Cache for attention biases and other things that would normally be stored as buffers.
|
| 185 |
+
We avoid using buffers because we've run into various issues doing so with FSDP.
|
| 186 |
+
In general it appears the way FSDP handles buffers is not well-defined.
|
| 187 |
+
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
|
| 188 |
+
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
|
| 189 |
+
NaNs when they're synchronized due to casting or some other issue.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _non_meta_init_device(config: ModelConfig) -> torch.device:
|
| 194 |
+
if config.init_device is not None and config.init_device != "meta":
|
| 195 |
+
return torch.device(config.init_device)
|
| 196 |
+
else:
|
| 197 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Dropout(nn.Dropout):
|
| 201 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 202 |
+
if self.p == 0.0:
|
| 203 |
+
return input
|
| 204 |
+
else:
|
| 205 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class LayerNormBase(nn.Module):
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
config: ModelConfig,
|
| 212 |
+
*,
|
| 213 |
+
size: Optional[int] = None,
|
| 214 |
+
elementwise_affine: Optional[bool] = True,
|
| 215 |
+
eps: float = 1e-05,
|
| 216 |
+
):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.config = config
|
| 219 |
+
self.eps = eps
|
| 220 |
+
self.normalized_shape = (size or config.d_model,)
|
| 221 |
+
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
| 222 |
+
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
|
| 223 |
+
use_bias = self.config.bias_for_layer_norm
|
| 224 |
+
if use_bias is None:
|
| 225 |
+
use_bias = self.config.include_bias
|
| 226 |
+
if use_bias:
|
| 227 |
+
self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
|
| 228 |
+
else:
|
| 229 |
+
self.register_parameter("bias", None)
|
| 230 |
+
else:
|
| 231 |
+
self.register_parameter("bias", None)
|
| 232 |
+
self.register_parameter("weight", None)
|
| 233 |
+
|
| 234 |
+
@abstractmethod
|
| 235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 236 |
+
raise NotImplementedError
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
|
| 240 |
+
if config.layer_norm_type == LayerNormType.default:
|
| 241 |
+
return LayerNorm(config, size=size, low_precision=False, **kwargs)
|
| 242 |
+
elif config.layer_norm_type == LayerNormType.low_precision:
|
| 243 |
+
return LayerNorm(config, size=size, low_precision=True, **kwargs)
|
| 244 |
+
elif config.layer_norm_type == LayerNormType.rms:
|
| 245 |
+
return RMSLayerNorm(config, size=size, **kwargs)
|
| 246 |
+
elif config.layer_norm_type == LayerNormType.gemma_rms:
|
| 247 |
+
return GemmaRMSLayerNorm(config, size=size, **kwargs)
|
| 248 |
+
else:
|
| 249 |
+
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
|
| 250 |
+
|
| 251 |
+
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
| 252 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 253 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 254 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 255 |
+
if tensor.device.type == "cuda" and torch.is_autocast_enabled():
|
| 256 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
|
| 257 |
+
elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 258 |
+
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
| 259 |
+
else:
|
| 260 |
+
return tensor
|
| 261 |
+
|
| 262 |
+
def reset_parameters(self):
|
| 263 |
+
if self.weight is not None:
|
| 264 |
+
torch.nn.init.ones_(self.weight) # type: ignore
|
| 265 |
+
if self.bias is not None:
|
| 266 |
+
torch.nn.init.zeros_(self.bias) # type: ignore
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LayerNorm(LayerNormBase):
|
| 270 |
+
"""
|
| 271 |
+
The default :class:`LayerNorm` implementation which can optionally run in low precision.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
config: ModelConfig,
|
| 277 |
+
size: Optional[int] = None,
|
| 278 |
+
low_precision: bool = False,
|
| 279 |
+
elementwise_affine: Optional[bool] = None,
|
| 280 |
+
eps: float = 1e-05,
|
| 281 |
+
):
|
| 282 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
| 283 |
+
self.low_precision = low_precision
|
| 284 |
+
|
| 285 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 286 |
+
if self.low_precision:
|
| 287 |
+
module_device = x.device
|
| 288 |
+
downcast_x = self._cast_if_autocast_enabled(x)
|
| 289 |
+
downcast_weight = (
|
| 290 |
+
self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 291 |
+
)
|
| 292 |
+
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
| 293 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
| 294 |
+
return F.layer_norm(
|
| 295 |
+
downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
|
| 296 |
+
)
|
| 297 |
+
else:
|
| 298 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class RMSLayerNorm(LayerNormBase):
|
| 302 |
+
"""
|
| 303 |
+
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
config: ModelConfig,
|
| 309 |
+
size: Optional[int] = None,
|
| 310 |
+
elementwise_affine: Optional[bool] = None,
|
| 311 |
+
eps: float = 1e-5,
|
| 312 |
+
):
|
| 313 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 314 |
+
|
| 315 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 316 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 317 |
+
og_dtype = x.dtype
|
| 318 |
+
x = x.to(torch.float32)
|
| 319 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 320 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 321 |
+
x = x.to(og_dtype)
|
| 322 |
+
|
| 323 |
+
if self.weight is not None:
|
| 324 |
+
if self.bias is not None:
|
| 325 |
+
return self.weight * x + self.bias
|
| 326 |
+
else:
|
| 327 |
+
return self.weight * x
|
| 328 |
+
else:
|
| 329 |
+
return x
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class GemmaRMSLayerNorm(LayerNormBase):
|
| 333 |
+
"""
|
| 334 |
+
Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
def __init__(
|
| 338 |
+
self,
|
| 339 |
+
config: ModelConfig,
|
| 340 |
+
size: Optional[int] = None,
|
| 341 |
+
elementwise_affine: Optional[bool] = None,
|
| 342 |
+
eps: float = 1e-5,
|
| 343 |
+
):
|
| 344 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
|
| 345 |
+
|
| 346 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 347 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 348 |
+
og_dtype = x.dtype
|
| 349 |
+
x = x.to(torch.float32)
|
| 350 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 351 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 352 |
+
x = x.to(og_dtype)
|
| 353 |
+
|
| 354 |
+
if self.weight is not None:
|
| 355 |
+
if self.bias is not None:
|
| 356 |
+
return x * (1 + self.weight) + self.bias
|
| 357 |
+
else:
|
| 358 |
+
return x * (1 + self.weight)
|
| 359 |
+
else:
|
| 360 |
+
return x
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class RotaryEmbedding(nn.Module):
|
| 364 |
+
"""
|
| 365 |
+
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
|
| 366 |
+
"""
|
| 367 |
+
|
| 368 |
+
def __init__(self, config: ModelConfig, cache: BufferCache):
|
| 369 |
+
super().__init__()
|
| 370 |
+
self.config = config
|
| 371 |
+
self.__cache = cache
|
| 372 |
+
# Warm up cache.
|
| 373 |
+
self.rope_theta = config.rope_theta
|
| 374 |
+
self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
|
| 375 |
+
|
| 376 |
+
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 377 |
+
if (
|
| 378 |
+
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
|
| 379 |
+
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
|
| 380 |
+
and pos_sin.shape[-2] >= seq_len
|
| 381 |
+
and pos_cos.shape[-2] >= seq_len
|
| 382 |
+
):
|
| 383 |
+
if pos_sin.device != device:
|
| 384 |
+
pos_sin = pos_sin.to(device)
|
| 385 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 386 |
+
if pos_cos.device != device:
|
| 387 |
+
pos_cos = pos_cos.to(device)
|
| 388 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 389 |
+
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
| 390 |
+
|
| 391 |
+
with torch.autocast(device.type, enabled=False):
|
| 392 |
+
dim = self.config.d_model // self.config.n_heads
|
| 393 |
+
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
|
| 394 |
+
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
| 395 |
+
freqs = einsum("i , j -> i j", seq, inv_freq)
|
| 396 |
+
positions = torch.cat((freqs, freqs), dim=-1)
|
| 397 |
+
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
|
| 398 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 399 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 400 |
+
return pos_sin, pos_cos
|
| 401 |
+
|
| 402 |
+
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
|
| 403 |
+
B, nh, T, hs = x.size()
|
| 404 |
+
x = x.view(B, nh, T, 2, hs // 2)
|
| 405 |
+
x1, x2 = x.unbind(dim=-2)
|
| 406 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 407 |
+
|
| 408 |
+
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 409 |
+
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
|
| 410 |
+
|
| 411 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 412 |
+
if self.config.rope_full_precision:
|
| 413 |
+
q_, k_ = q.float(), k.float()
|
| 414 |
+
else:
|
| 415 |
+
q_, k_ = q, k
|
| 416 |
+
|
| 417 |
+
with torch.autocast(q.device.type, enabled=False):
|
| 418 |
+
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
| 419 |
+
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
|
| 420 |
+
pos_sin = pos_sin.type_as(q_)
|
| 421 |
+
pos_cos = pos_cos.type_as(q_)
|
| 422 |
+
q_ = self.apply_rotary_pos_emb(
|
| 423 |
+
pos_sin[:, :, key_len - query_len : key_len, :],
|
| 424 |
+
pos_cos[:, :, key_len - query_len : key_len, :],
|
| 425 |
+
q_,
|
| 426 |
+
)
|
| 427 |
+
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
|
| 428 |
+
return q_.type_as(q), k_.type_as(k)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class Activation(nn.Module):
|
| 432 |
+
def __init__(self, config: ModelConfig):
|
| 433 |
+
super().__init__()
|
| 434 |
+
self.config = config
|
| 435 |
+
|
| 436 |
+
@abstractmethod
|
| 437 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 438 |
+
raise NotImplementedError
|
| 439 |
+
|
| 440 |
+
@property
|
| 441 |
+
@abstractmethod
|
| 442 |
+
def output_multiplier(self) -> float:
|
| 443 |
+
raise NotImplementedError
|
| 444 |
+
|
| 445 |
+
@classmethod
|
| 446 |
+
def build(cls, config: ModelConfig) -> Activation:
|
| 447 |
+
if config.activation_type == ActivationType.gelu:
|
| 448 |
+
return cast(Activation, GELU(approximate="none"))
|
| 449 |
+
elif config.activation_type == ActivationType.relu:
|
| 450 |
+
return cast(Activation, ReLU(inplace=False))
|
| 451 |
+
elif config.activation_type == ActivationType.silu:
|
| 452 |
+
return cast(Activation, SiLU(inplace=False))
|
| 453 |
+
elif config.activation_type == ActivationType.swiglu:
|
| 454 |
+
return SwiGLU(config)
|
| 455 |
+
else:
|
| 456 |
+
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class GELU(nn.GELU):
|
| 460 |
+
@property
|
| 461 |
+
def output_multiplier(self) -> float:
|
| 462 |
+
return 1.0
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class ReLU(nn.ReLU):
|
| 466 |
+
@property
|
| 467 |
+
def output_multiplier(self) -> float:
|
| 468 |
+
return 1.0
|
| 469 |
+
|
| 470 |
+
class SiLU(nn.SiLU):
|
| 471 |
+
@property
|
| 472 |
+
def output_multiplier(self) -> float:
|
| 473 |
+
return 1.0
|
| 474 |
+
|
| 475 |
+
class SwiGLU(Activation):
|
| 476 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 477 |
+
x, gate = x.chunk(2, dim=-1)
|
| 478 |
+
return F.silu(gate) * x
|
| 479 |
+
|
| 480 |
+
@property
|
| 481 |
+
def output_multiplier(self) -> float:
|
| 482 |
+
return 0.5
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
|
| 486 |
+
att_bias = torch.triu(
|
| 487 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
|
| 488 |
+
diagonal=1,
|
| 489 |
+
)
|
| 490 |
+
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
|
| 491 |
+
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 495 |
+
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
|
| 496 |
+
if causal_bias.device != device:
|
| 497 |
+
causal_bias = causal_bias.to(device)
|
| 498 |
+
cache["causal_attention_bias"] = causal_bias
|
| 499 |
+
return causal_bias
|
| 500 |
+
with torch.autocast(device.type, enabled=False):
|
| 501 |
+
causal_bias = causal_attention_bias(seq_len, device)
|
| 502 |
+
cache["causal_attention_bias"] = causal_bias
|
| 503 |
+
return causal_bias
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
|
| 507 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
|
| 508 |
+
|
| 509 |
+
# shape: (1, 1, seq_len, seq_len)
|
| 510 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
|
| 511 |
+
alibi_bias.abs_().mul_(-1)
|
| 512 |
+
|
| 513 |
+
# shape: (n_heads,)
|
| 514 |
+
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
|
| 515 |
+
m.mul_(config.alibi_bias_max / config.n_heads)
|
| 516 |
+
|
| 517 |
+
# shape: (1, n_heads, seq_len, seq_len)
|
| 518 |
+
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class LLaDABlock(nn.Module):
|
| 522 |
+
"""
|
| 523 |
+
A base class for transformer block implementations.
|
| 524 |
+
"""
|
| 525 |
+
|
| 526 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 527 |
+
super().__init__()
|
| 528 |
+
self.layer_id = layer_id
|
| 529 |
+
self.config = config
|
| 530 |
+
self.hidden_size = (
|
| 531 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
| 532 |
+
)
|
| 533 |
+
self.__cache = cache
|
| 534 |
+
assert config.d_model % config.n_heads == 0
|
| 535 |
+
|
| 536 |
+
self._activation_checkpoint_fn = None
|
| 537 |
+
|
| 538 |
+
# Dropout.
|
| 539 |
+
self.dropout = Dropout(config.residual_dropout)
|
| 540 |
+
|
| 541 |
+
# Layer norms.
|
| 542 |
+
self.k_norm: Optional[LayerNormBase] = None
|
| 543 |
+
self.q_norm: Optional[LayerNormBase] = None
|
| 544 |
+
if config.attention_layer_norm:
|
| 545 |
+
self.k_norm = LayerNormBase.build(
|
| 546 |
+
config,
|
| 547 |
+
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
|
| 548 |
+
elementwise_affine=config.attention_layer_norm_with_affine,
|
| 549 |
+
)
|
| 550 |
+
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
|
| 551 |
+
|
| 552 |
+
# Activation function.
|
| 553 |
+
self.act = Activation.build(config)
|
| 554 |
+
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
| 555 |
+
|
| 556 |
+
# Attention output projection.
|
| 557 |
+
self.attn_out = nn.Linear(
|
| 558 |
+
config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# Feed-forward output projection.
|
| 562 |
+
self.ff_out = nn.Linear(
|
| 563 |
+
int(self.act.output_multiplier * self.hidden_size),
|
| 564 |
+
config.d_model,
|
| 565 |
+
bias=config.include_bias,
|
| 566 |
+
device=config.init_device,
|
| 567 |
+
)
|
| 568 |
+
self.ff_out._is_residual = True # type: ignore
|
| 569 |
+
|
| 570 |
+
# Rotary embeddings.
|
| 571 |
+
if self.config.rope:
|
| 572 |
+
self.rotary_emb = RotaryEmbedding(config, self.__cache)
|
| 573 |
+
|
| 574 |
+
self.flash_attn_func = None
|
| 575 |
+
if config.flash_attention:
|
| 576 |
+
try:
|
| 577 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 578 |
+
|
| 579 |
+
self.flash_attn_func = flash_attn_func
|
| 580 |
+
except ModuleNotFoundError:
|
| 581 |
+
pass
|
| 582 |
+
|
| 583 |
+
def reset_parameters(self):
|
| 584 |
+
if self.k_norm is not None:
|
| 585 |
+
self.k_norm.reset_parameters()
|
| 586 |
+
if self.q_norm is not None:
|
| 587 |
+
self.q_norm.reset_parameters()
|
| 588 |
+
init_weights(
|
| 589 |
+
self.config,
|
| 590 |
+
self.attn_out,
|
| 591 |
+
d=self.config.d_model,
|
| 592 |
+
layer_id=self.layer_id,
|
| 593 |
+
type_of_module=ModuleType.out_module,
|
| 594 |
+
)
|
| 595 |
+
init_weights(
|
| 596 |
+
self.config,
|
| 597 |
+
self.ff_out,
|
| 598 |
+
d=self.ff_out.in_features,
|
| 599 |
+
layer_id=self.layer_id,
|
| 600 |
+
type_of_module=ModuleType.out_module,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 604 |
+
if strategy == ActivationCheckpointingStrategy.fine_grained:
|
| 605 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 606 |
+
else:
|
| 607 |
+
self._activation_checkpoint_fn = None
|
| 608 |
+
|
| 609 |
+
@classmethod
|
| 610 |
+
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
|
| 611 |
+
target_dtype = input_dtype
|
| 612 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 613 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 614 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 615 |
+
if bias.device.type == "cuda" and torch.is_autocast_enabled():
|
| 616 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 617 |
+
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 618 |
+
target_dtype = torch.get_autocast_cpu_dtype()
|
| 619 |
+
if bias.dtype != target_dtype:
|
| 620 |
+
bias = bias.to(target_dtype)
|
| 621 |
+
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
|
| 622 |
+
return bias
|
| 623 |
+
|
| 624 |
+
def _scaled_dot_product_attention(
|
| 625 |
+
self,
|
| 626 |
+
q: torch.Tensor,
|
| 627 |
+
k: torch.Tensor,
|
| 628 |
+
v: torch.Tensor,
|
| 629 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 630 |
+
dropout_p: float = 0.0,
|
| 631 |
+
is_causal: bool = False,
|
| 632 |
+
) -> torch.Tensor:
|
| 633 |
+
"""
|
| 634 |
+
Computes scaled dot product attention on query, key and value tensors, using an optional
|
| 635 |
+
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
|
| 636 |
+
"""
|
| 637 |
+
if self.flash_attn_func is not None and attn_mask is None:
|
| 638 |
+
r = self.flash_attn_func(
|
| 639 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
|
| 640 |
+
)
|
| 641 |
+
return r.transpose(1, 2)
|
| 642 |
+
else:
|
| 643 |
+
# torch's sdpa doesn't support GQA, so we're doing this
|
| 644 |
+
assert k.size(1) == v.size(1)
|
| 645 |
+
num_kv_heads = k.size(1)
|
| 646 |
+
num_q_heads = q.size(1)
|
| 647 |
+
if num_q_heads != num_kv_heads:
|
| 648 |
+
assert num_q_heads % num_kv_heads == 0
|
| 649 |
+
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 650 |
+
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 651 |
+
|
| 652 |
+
# Modify: MDM set causal to False, and with no attn_mask.
|
| 653 |
+
return F.scaled_dot_product_attention(
|
| 654 |
+
q,
|
| 655 |
+
k,
|
| 656 |
+
v,
|
| 657 |
+
attn_mask=None,
|
| 658 |
+
dropout_p=dropout_p,
|
| 659 |
+
is_causal=False,
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
def attention(
|
| 663 |
+
self,
|
| 664 |
+
q: torch.Tensor,
|
| 665 |
+
k: torch.Tensor,
|
| 666 |
+
v: torch.Tensor,
|
| 667 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 668 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 669 |
+
use_cache: bool = False,
|
| 670 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 671 |
+
B, T, C = q.size() # batch size, sequence length, d_model
|
| 672 |
+
dtype = k.dtype
|
| 673 |
+
|
| 674 |
+
# Optionally apply layer norm to keys and queries.
|
| 675 |
+
if self.q_norm is not None and self.k_norm is not None:
|
| 676 |
+
q = self.q_norm(q).to(dtype=dtype)
|
| 677 |
+
k = self.k_norm(k).to(dtype=dtype)
|
| 678 |
+
|
| 679 |
+
# Move head forward to be next to the batch dim.
|
| 680 |
+
# shape: (B, nh, T, hs)
|
| 681 |
+
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
|
| 682 |
+
# shape: (B, n_kv_h, T, hs)
|
| 683 |
+
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 684 |
+
# shape: (B, n_kv_h, T, hs)
|
| 685 |
+
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 686 |
+
|
| 687 |
+
if layer_past is not None:
|
| 688 |
+
past_key, past_value = layer_past
|
| 689 |
+
k = torch.cat((past_key, k), dim=-2)
|
| 690 |
+
v = torch.cat((past_value, v), dim=-2)
|
| 691 |
+
|
| 692 |
+
present = (k, v) if use_cache else None
|
| 693 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
| 694 |
+
|
| 695 |
+
if self.config.rope:
|
| 696 |
+
# Apply rotary embeddings.
|
| 697 |
+
q, k = self.rotary_emb(q, k)
|
| 698 |
+
|
| 699 |
+
if attention_bias is not None:
|
| 700 |
+
# Resize and cast attention bias.
|
| 701 |
+
# The current dtype of the attention bias might not match the dtype that the SDP attn function will
|
| 702 |
+
# run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
|
| 703 |
+
# as down-casting the attention bias to the autocast precision will result in -infs, which will
|
| 704 |
+
# cause the SDP attn function to produce NaNs.
|
| 705 |
+
attention_bias = self._cast_attn_bias(
|
| 706 |
+
attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# Get the attention scores.
|
| 710 |
+
# shape: (B, nh, T, hs)
|
| 711 |
+
att = self._scaled_dot_product_attention(
|
| 712 |
+
q,
|
| 713 |
+
k,
|
| 714 |
+
v,
|
| 715 |
+
attn_mask=None,
|
| 716 |
+
dropout_p=0.0 if not self.training else self.config.attention_dropout,
|
| 717 |
+
is_causal=False,
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
# Re-assemble all head outputs side-by-side.
|
| 721 |
+
att = att.transpose(1, 2).contiguous().view(B, T, C)
|
| 722 |
+
|
| 723 |
+
# Apply output projection.
|
| 724 |
+
return self.attn_out(att), present
|
| 725 |
+
|
| 726 |
+
@abstractmethod
|
| 727 |
+
def forward(
|
| 728 |
+
self,
|
| 729 |
+
x: torch.Tensor,
|
| 730 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 731 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 732 |
+
use_cache: bool = False,
|
| 733 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 734 |
+
raise NotImplementedError
|
| 735 |
+
|
| 736 |
+
@classmethod
|
| 737 |
+
def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
|
| 738 |
+
if config.block_type == BlockType.sequential:
|
| 739 |
+
return LLaDASequentialBlock(layer_id, config, cache)
|
| 740 |
+
elif config.block_type == BlockType.llama:
|
| 741 |
+
return LLaDALlamaBlock(layer_id, config, cache)
|
| 742 |
+
else:
|
| 743 |
+
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
class LLaDASequentialBlock(LLaDABlock):
|
| 747 |
+
"""
|
| 748 |
+
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 749 |
+
(plus another skip connection).
|
| 750 |
+
"""
|
| 751 |
+
|
| 752 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 753 |
+
super().__init__(layer_id, config, cache)
|
| 754 |
+
# Layer norms.
|
| 755 |
+
self.attn_norm = LayerNorm.build(config)
|
| 756 |
+
self.ff_norm = LayerNorm.build(config)
|
| 757 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 758 |
+
head_dim = config.d_model // config.n_heads
|
| 759 |
+
self.fused_dims = (
|
| 760 |
+
config.d_model,
|
| 761 |
+
config.effective_n_kv_heads * head_dim,
|
| 762 |
+
config.effective_n_kv_heads * head_dim,
|
| 763 |
+
)
|
| 764 |
+
self.att_proj = nn.Linear(
|
| 765 |
+
config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 766 |
+
)
|
| 767 |
+
# Feed-forward input projection.
|
| 768 |
+
self.ff_proj = nn.Linear(
|
| 769 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
def reset_parameters(self):
|
| 773 |
+
super().reset_parameters()
|
| 774 |
+
self.attn_norm.reset_parameters()
|
| 775 |
+
self.ff_norm.reset_parameters()
|
| 776 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 777 |
+
init_weights(
|
| 778 |
+
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 779 |
+
)
|
| 780 |
+
init_weights(
|
| 781 |
+
self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
def forward(
|
| 785 |
+
self,
|
| 786 |
+
x: torch.Tensor,
|
| 787 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 788 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 789 |
+
use_cache: bool = False,
|
| 790 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 791 |
+
# Get query, key, value projections.
|
| 792 |
+
# shape:
|
| 793 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 794 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 795 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 796 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 797 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 798 |
+
if self._activation_checkpoint_fn is not None:
|
| 799 |
+
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
|
| 800 |
+
self.fused_dims, dim=-1
|
| 801 |
+
)
|
| 802 |
+
else:
|
| 803 |
+
q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
|
| 804 |
+
|
| 805 |
+
# Get attention scores.
|
| 806 |
+
if self._activation_checkpoint_fn is not None:
|
| 807 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 808 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 809 |
+
)
|
| 810 |
+
else:
|
| 811 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 812 |
+
|
| 813 |
+
# Add attention scores.
|
| 814 |
+
# shape: (B, T, C)
|
| 815 |
+
x = x + self.dropout(att)
|
| 816 |
+
|
| 817 |
+
# Add feed-forward projection.
|
| 818 |
+
# shape: (batch_size, seq_len, d_model)
|
| 819 |
+
og_x = x
|
| 820 |
+
if self._activation_checkpoint_fn is not None:
|
| 821 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 822 |
+
else:
|
| 823 |
+
x = self.ff_norm(x)
|
| 824 |
+
x = self.ff_proj(x)
|
| 825 |
+
if self._activation_checkpoint_fn is not None:
|
| 826 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 827 |
+
else:
|
| 828 |
+
x = self.act(x)
|
| 829 |
+
x = self.ff_out(x)
|
| 830 |
+
x = self.dropout(x)
|
| 831 |
+
x = og_x + x
|
| 832 |
+
|
| 833 |
+
return x, cache
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
class LLaDALlamaBlock(LLaDABlock):
|
| 837 |
+
"""
|
| 838 |
+
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 839 |
+
(plus another skip connection). This block is similar to `LLaDASequentialBlock`
|
| 840 |
+
but some operations have slightly different implementations to imitate the
|
| 841 |
+
behavior of Llama.
|
| 842 |
+
"""
|
| 843 |
+
|
| 844 |
+
def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
|
| 845 |
+
super().__init__(layer_id, config, cache)
|
| 846 |
+
# Layer norms.
|
| 847 |
+
self.attn_norm = LayerNorm.build(config)
|
| 848 |
+
self.ff_norm = LayerNorm.build(config)
|
| 849 |
+
self.__cache = cache
|
| 850 |
+
|
| 851 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 852 |
+
head_dim = config.d_model // config.n_heads
|
| 853 |
+
q_proj_out_dim = config.d_model
|
| 854 |
+
k_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 855 |
+
v_proj_out_dim = config.effective_n_kv_heads * head_dim
|
| 856 |
+
self.q_proj = nn.Linear(
|
| 857 |
+
config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 858 |
+
)
|
| 859 |
+
self.k_proj = nn.Linear(
|
| 860 |
+
config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 861 |
+
)
|
| 862 |
+
self.v_proj = nn.Linear(
|
| 863 |
+
config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
# Feed-forward input projection.
|
| 867 |
+
self.ff_proj = nn.Linear(
|
| 868 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 869 |
+
)
|
| 870 |
+
# new add
|
| 871 |
+
self.up_proj = nn.Linear(
|
| 872 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
def reset_parameters(self):
|
| 876 |
+
super().reset_parameters()
|
| 877 |
+
self.attn_norm.reset_parameters()
|
| 878 |
+
self.ff_norm.reset_parameters()
|
| 879 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 880 |
+
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
|
| 881 |
+
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
|
| 882 |
+
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
|
| 883 |
+
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
|
| 884 |
+
init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # new add
|
| 885 |
+
|
| 886 |
+
def forward(
|
| 887 |
+
self,
|
| 888 |
+
x: torch.Tensor,
|
| 889 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 890 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 891 |
+
use_cache: bool = False,
|
| 892 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 893 |
+
# Get query, key, value projections.
|
| 894 |
+
# shape:
|
| 895 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 896 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 897 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 898 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 899 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 900 |
+
x_normed = self.attn_norm(x)
|
| 901 |
+
q = self.q_proj(x_normed)
|
| 902 |
+
k = self.k_proj(x_normed)
|
| 903 |
+
v = self.v_proj(x_normed)
|
| 904 |
+
|
| 905 |
+
# Get attention scores.
|
| 906 |
+
if self._activation_checkpoint_fn is not None:
|
| 907 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 908 |
+
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 909 |
+
)
|
| 910 |
+
else:
|
| 911 |
+
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 912 |
+
|
| 913 |
+
# Add attention scores.
|
| 914 |
+
# shape: (B, T, C)
|
| 915 |
+
x = x + self.dropout(att)
|
| 916 |
+
|
| 917 |
+
# Add feed-forward projection.
|
| 918 |
+
# shape: (batch_size, seq_len, d_model)
|
| 919 |
+
og_x = x
|
| 920 |
+
if self._activation_checkpoint_fn is not None:
|
| 921 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 922 |
+
else:
|
| 923 |
+
x = self.ff_norm(x)
|
| 924 |
+
x, x_up = self.ff_proj(x), self.up_proj(x) # new add
|
| 925 |
+
if self._activation_checkpoint_fn is not None:
|
| 926 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 927 |
+
else:
|
| 928 |
+
x = self.act(x)
|
| 929 |
+
x = x * x_up # new add
|
| 930 |
+
x = self.ff_out(x)
|
| 931 |
+
x = self.dropout(x)
|
| 932 |
+
x = og_x + x
|
| 933 |
+
|
| 934 |
+
return x, cache
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
class LLaDAOutput(NamedTuple):
|
| 938 |
+
logits: torch.FloatTensor
|
| 939 |
+
"""
|
| 940 |
+
A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
|
| 941 |
+
for the next token *before* normalization via (log) softmax.
|
| 942 |
+
"""
|
| 943 |
+
|
| 944 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
|
| 945 |
+
"""
|
| 946 |
+
Attention keys and values from each block.
|
| 947 |
+
"""
|
| 948 |
+
|
| 949 |
+
hidden_states: Optional[Tuple[torch.Tensor]]
|
| 950 |
+
"""
|
| 951 |
+
Hidden states from each block.
|
| 952 |
+
"""
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
class LLaDAGenerateOutput(NamedTuple):
|
| 956 |
+
token_ids: torch.LongTensor
|
| 957 |
+
"""
|
| 958 |
+
The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
|
| 959 |
+
These do *not* include the original input IDs.
|
| 960 |
+
"""
|
| 961 |
+
|
| 962 |
+
scores: torch.FloatTensor
|
| 963 |
+
"""
|
| 964 |
+
The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
|
| 965 |
+
"""
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
class LLaDABlockGroup(nn.ModuleList):
|
| 969 |
+
def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
|
| 970 |
+
super().__init__(modules)
|
| 971 |
+
self.config = config
|
| 972 |
+
self.layer_offset = layer_offset
|
| 973 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 974 |
+
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
|
| 975 |
+
|
| 976 |
+
def forward(
|
| 977 |
+
self,
|
| 978 |
+
x: torch.Tensor,
|
| 979 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 980 |
+
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 981 |
+
use_cache: bool = False,
|
| 982 |
+
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
|
| 983 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 984 |
+
for block_idx, block in enumerate(self):
|
| 985 |
+
layer_past = None if layers_past is None else layers_past[block_idx]
|
| 986 |
+
block_idx += self.layer_offset
|
| 987 |
+
if (
|
| 988 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 989 |
+
or (
|
| 990 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 991 |
+
and block_idx % 2 == 0
|
| 992 |
+
)
|
| 993 |
+
or (
|
| 994 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 995 |
+
and block_idx % 3 == 0
|
| 996 |
+
)
|
| 997 |
+
or (
|
| 998 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 999 |
+
and block_idx % 4 == 0
|
| 1000 |
+
)
|
| 1001 |
+
):
|
| 1002 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1003 |
+
x, cache = self._activation_checkpoint_fn( # type: ignore
|
| 1004 |
+
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1005 |
+
)
|
| 1006 |
+
else:
|
| 1007 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1008 |
+
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1009 |
+
if attn_key_values is not None:
|
| 1010 |
+
assert cache is not None
|
| 1011 |
+
attn_key_values.append(cache)
|
| 1012 |
+
return x, attn_key_values
|
| 1013 |
+
|
| 1014 |
+
def reset_parameters(self):
|
| 1015 |
+
for block in self:
|
| 1016 |
+
block.reset_parameters()
|
| 1017 |
+
|
| 1018 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1019 |
+
self.activation_checkpointing_strategy = strategy
|
| 1020 |
+
for block in self:
|
| 1021 |
+
block.set_activation_checkpointing(strategy)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
class LLaDAModel(nn.Module):
|
| 1025 |
+
def __init__(self, config: ModelConfig, init_params: bool = True):
|
| 1026 |
+
super().__init__()
|
| 1027 |
+
self.config = config
|
| 1028 |
+
self.__cache = BufferCache()
|
| 1029 |
+
|
| 1030 |
+
# Validate config.
|
| 1031 |
+
if self.config.alibi and self.config.flash_attention:
|
| 1032 |
+
raise Exception("ALiBi is currently not supported with FlashAttention")
|
| 1033 |
+
|
| 1034 |
+
if self.config.alibi and self.config.rope:
|
| 1035 |
+
raise Exception("ALiBi and RoPE are mutually exclusive")
|
| 1036 |
+
|
| 1037 |
+
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
|
| 1038 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 1039 |
+
raise Exception("embedding size should be at least as big as vocab size")
|
| 1040 |
+
elif self.config.embedding_size % 128 != 0:
|
| 1041 |
+
import warnings
|
| 1042 |
+
|
| 1043 |
+
warnings.warn(
|
| 1044 |
+
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
|
| 1048 |
+
self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
|
| 1049 |
+
|
| 1050 |
+
if not (
|
| 1051 |
+
0 < self.config.block_group_size <= self.config.n_layers
|
| 1052 |
+
and self.config.n_layers % self.config.block_group_size == 0
|
| 1053 |
+
):
|
| 1054 |
+
raise Exception("n layers must be divisible by block group size")
|
| 1055 |
+
|
| 1056 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 1057 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
|
| 1058 |
+
|
| 1059 |
+
self.transformer = nn.ModuleDict(
|
| 1060 |
+
dict(
|
| 1061 |
+
wte=nn.Embedding(
|
| 1062 |
+
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
| 1063 |
+
),
|
| 1064 |
+
emb_drop=Dropout(config.embedding_dropout),
|
| 1065 |
+
ln_f=LayerNorm.build(config),
|
| 1066 |
+
)
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
| 1070 |
+
if self.config.block_group_size > 1:
|
| 1071 |
+
block_groups = [
|
| 1072 |
+
LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
|
| 1073 |
+
for i in range(0, config.n_layers, config.block_group_size)
|
| 1074 |
+
]
|
| 1075 |
+
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
|
| 1076 |
+
else:
|
| 1077 |
+
self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
| 1078 |
+
|
| 1079 |
+
if not (self.config.alibi or self.config.rope):
|
| 1080 |
+
self.transformer.update(
|
| 1081 |
+
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
|
| 1082 |
+
)
|
| 1083 |
+
if not config.weight_tying:
|
| 1084 |
+
self.transformer.update(
|
| 1085 |
+
{
|
| 1086 |
+
"ff_out": nn.Linear(
|
| 1087 |
+
config.d_model,
|
| 1088 |
+
config.embedding_size or config.vocab_size,
|
| 1089 |
+
bias=config.include_bias,
|
| 1090 |
+
device=config.init_device,
|
| 1091 |
+
)
|
| 1092 |
+
}
|
| 1093 |
+
)
|
| 1094 |
+
# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
|
| 1095 |
+
if init_params and self.config.init_device != "meta":
|
| 1096 |
+
self.reset_parameters()
|
| 1097 |
+
self.__num_fwd_flops: Optional[int] = None
|
| 1098 |
+
|
| 1099 |
+
# Warm up cache.
|
| 1100 |
+
if self.config.alibi:
|
| 1101 |
+
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
|
| 1102 |
+
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
|
| 1103 |
+
|
| 1104 |
+
def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
|
| 1105 |
+
self.activation_checkpointing_strategy = strategy
|
| 1106 |
+
if self.config.block_group_size != 1:
|
| 1107 |
+
for block_group in self.transformer.block_groups:
|
| 1108 |
+
block_group.set_activation_checkpointing(strategy)
|
| 1109 |
+
else:
|
| 1110 |
+
for block in self.transformer.blocks:
|
| 1111 |
+
block.set_activation_checkpointing(strategy)
|
| 1112 |
+
|
| 1113 |
+
@property
|
| 1114 |
+
def device(self) -> torch.device:
|
| 1115 |
+
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
| 1116 |
+
if device.type == "meta":
|
| 1117 |
+
return _non_meta_init_device(self.config)
|
| 1118 |
+
else:
|
| 1119 |
+
return device
|
| 1120 |
+
|
| 1121 |
+
def reset_parameters(self):
|
| 1122 |
+
log.info("Initializing model parameters...")
|
| 1123 |
+
# Top-level embeddings / linear layers.
|
| 1124 |
+
init_weights(
|
| 1125 |
+
self.config,
|
| 1126 |
+
self.transformer.wte, # type: ignore
|
| 1127 |
+
std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
|
| 1128 |
+
type_of_module=ModuleType.emb,
|
| 1129 |
+
)
|
| 1130 |
+
if hasattr(self.transformer, "wpe"):
|
| 1131 |
+
init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
|
| 1132 |
+
|
| 1133 |
+
# Top-level layer norm.
|
| 1134 |
+
self.transformer.ln_f.reset_parameters() # type: ignore
|
| 1135 |
+
|
| 1136 |
+
# Output weights.
|
| 1137 |
+
if hasattr(self.transformer, "ff_out"):
|
| 1138 |
+
init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
|
| 1139 |
+
|
| 1140 |
+
# Let the blocks handle themselves.
|
| 1141 |
+
if self.config.block_group_size == 1:
|
| 1142 |
+
for block in self.transformer.blocks:
|
| 1143 |
+
block.reset_parameters()
|
| 1144 |
+
else:
|
| 1145 |
+
for block_group in self.transformer.block_groups:
|
| 1146 |
+
block_group.reset_parameters()
|
| 1147 |
+
|
| 1148 |
+
def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 1149 |
+
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
|
| 1150 |
+
-1
|
| 1151 |
+
] >= seq_len:
|
| 1152 |
+
if alibi_bias.device != device:
|
| 1153 |
+
alibi_bias = alibi_bias.to(device)
|
| 1154 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1155 |
+
return alibi_bias
|
| 1156 |
+
with torch.autocast(device.type, enabled=False):
|
| 1157 |
+
alibi_bias = alibi_attention_bias(seq_len, self.config, device)
|
| 1158 |
+
self.__cache["alibi_attention_bias"] = alibi_bias
|
| 1159 |
+
return alibi_bias
|
| 1160 |
+
|
| 1161 |
+
def forward(
|
| 1162 |
+
self,
|
| 1163 |
+
input_ids: torch.LongTensor,
|
| 1164 |
+
input_embeddings: Optional[torch.FloatTensor] = None,
|
| 1165 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1166 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1167 |
+
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1168 |
+
use_cache: bool = False,
|
| 1169 |
+
last_logits_only: bool = False,
|
| 1170 |
+
output_hidden_states: Optional[bool] = None,
|
| 1171 |
+
) -> LLaDAOutput:
|
| 1172 |
+
"""
|
| 1173 |
+
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
| 1174 |
+
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
|
| 1175 |
+
embeddings. When provided, it is treated as the output of the input embedding layer.
|
| 1176 |
+
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
| 1177 |
+
which input IDs are masked. A `1` value in the mask means that
|
| 1178 |
+
the corresponding input ID should *not* be ignored. A `0` means
|
| 1179 |
+
that the corresponding input ID is masked.
|
| 1180 |
+
|
| 1181 |
+
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
|
| 1182 |
+
library.
|
| 1183 |
+
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
|
| 1184 |
+
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
|
| 1185 |
+
to introduce causal or other biases.
|
| 1186 |
+
|
| 1187 |
+
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
|
| 1188 |
+
indicates that the i-th element in the sequence is allowed to attend to the j-th
|
| 1189 |
+
element in the sequence.
|
| 1190 |
+
|
| 1191 |
+
If the tensor is a float tensor, it will just be added to the attention
|
| 1192 |
+
scores before the softmax.
|
| 1193 |
+
|
| 1194 |
+
The default is causal, which corresponds to a lower-triangular byte matrix of ones.
|
| 1195 |
+
:param past_key_values: Pre-computed keys and values for each attention block.
|
| 1196 |
+
Can be used to speed up sequential decoding. The `input_ids` which have
|
| 1197 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
| 1198 |
+
:param use_cache: If `True`, return key and value tensors for each block.
|
| 1199 |
+
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
|
| 1200 |
+
This can speed up decoding when you only care about the next token.
|
| 1201 |
+
"""
|
| 1202 |
+
# Add Basic MDM Model config check
|
| 1203 |
+
assert not self.config.alibi, "ALiBi length extrapolation is not supported for MDM."
|
| 1204 |
+
assert self.config.rope, "RoPE must be used in the LLaMA-style encoder for MDM."
|
| 1205 |
+
assert (past_key_values is None and not use_cache), "The kv-cache is not supported for MDM."
|
| 1206 |
+
|
| 1207 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 1208 |
+
|
| 1209 |
+
if past_key_values:
|
| 1210 |
+
assert len(past_key_values) == self.config.n_layers
|
| 1211 |
+
|
| 1212 |
+
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
| 1213 |
+
if past_key_values is None:
|
| 1214 |
+
past_length = 0
|
| 1215 |
+
else:
|
| 1216 |
+
past_length = past_key_values[0][0].size(-2)
|
| 1217 |
+
|
| 1218 |
+
# Get embeddings of input.
|
| 1219 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1220 |
+
# print(f"input_ids: {input_ids}, input_ids.shape: {input_ids.shape}")
|
| 1221 |
+
# print(f"transformer wte weight shape: {self.transformer.wte.weight.shape}")
|
| 1222 |
+
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
| 1223 |
+
|
| 1224 |
+
# print(f"xshape: {x.shape}")
|
| 1225 |
+
|
| 1226 |
+
if self.config.input_emb_norm:
|
| 1227 |
+
x = x * (self.config.d_model**0.5)
|
| 1228 |
+
|
| 1229 |
+
if not (self.config.alibi or self.config.rope):
|
| 1230 |
+
# Get positional embeddings.
|
| 1231 |
+
# shape: (1, seq_len)
|
| 1232 |
+
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
|
| 1233 |
+
# shape: (1, seq_len, d_model)
|
| 1234 |
+
pos_emb = self.transformer.wpe(pos) # type: ignore
|
| 1235 |
+
x = pos_emb + x
|
| 1236 |
+
|
| 1237 |
+
# Add input + positional embeddings and apply dropout.
|
| 1238 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1239 |
+
x = self.transformer.emb_drop(x) # type: ignore
|
| 1240 |
+
|
| 1241 |
+
# Transform the attention mask into what the blocks expect.
|
| 1242 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1243 |
+
# shape: (batch_size, 1, 1, seq_len)
|
| 1244 |
+
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
| 1245 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
| 1246 |
+
else:
|
| 1247 |
+
attention_mask = None
|
| 1248 |
+
|
| 1249 |
+
# Merge attention mask with attention bias.
|
| 1250 |
+
if (
|
| 1251 |
+
attention_bias is not None
|
| 1252 |
+
or attention_mask is not None
|
| 1253 |
+
or self.config.alibi
|
| 1254 |
+
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
|
| 1255 |
+
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
|
| 1256 |
+
# scores correctly.
|
| 1257 |
+
or past_key_values is not None
|
| 1258 |
+
):
|
| 1259 |
+
if attention_bias is None and self.config.alibi:
|
| 1260 |
+
# print(f"get_causal_attention_bias")
|
| 1261 |
+
attention_bias = get_causal_attention_bias(
|
| 1262 |
+
self.__cache, past_length + seq_len, x.device
|
| 1263 |
+
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
|
| 1264 |
+
elif attention_bias is None:
|
| 1265 |
+
# print(f"get_causal_attention_bias")
|
| 1266 |
+
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
|
| 1267 |
+
elif attention_bias.dtype in (torch.int8, torch.bool):
|
| 1268 |
+
# print(f"attention_bias.dtype in (torch.int8, torch.bool)")
|
| 1269 |
+
attention_bias = attention_bias.to(dtype=torch.float)
|
| 1270 |
+
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
|
| 1271 |
+
|
| 1272 |
+
# Transform to the right shape and data type.
|
| 1273 |
+
mask_len = seq_len
|
| 1274 |
+
if attention_mask is not None:
|
| 1275 |
+
mask_len = attention_mask.shape[-1]
|
| 1276 |
+
elif past_key_values is not None:
|
| 1277 |
+
mask_len = past_key_values[0][0].shape[-2] + seq_len
|
| 1278 |
+
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
|
| 1279 |
+
|
| 1280 |
+
# Add in the masking bias.
|
| 1281 |
+
if attention_mask is not None:
|
| 1282 |
+
attention_bias = attention_bias + attention_mask
|
| 1283 |
+
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
|
| 1284 |
+
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
|
| 1285 |
+
# it can produce NaNs.
|
| 1286 |
+
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
|
| 1287 |
+
|
| 1288 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1289 |
+
|
| 1290 |
+
# decoder layers
|
| 1291 |
+
all_hidden_states = []
|
| 1292 |
+
|
| 1293 |
+
# Apply blocks one-by-one.
|
| 1294 |
+
if self.config.block_group_size == 1:
|
| 1295 |
+
for block_idx, block in enumerate(self.transformer.blocks):
|
| 1296 |
+
if output_hidden_states:
|
| 1297 |
+
# add hidden states
|
| 1298 |
+
all_hidden_states.append(x)
|
| 1299 |
+
|
| 1300 |
+
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
| 1301 |
+
if (
|
| 1302 |
+
(self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
|
| 1303 |
+
or (
|
| 1304 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
|
| 1305 |
+
and block_idx % 2 == 0
|
| 1306 |
+
)
|
| 1307 |
+
or (
|
| 1308 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
|
| 1309 |
+
and block_idx % 3 == 0
|
| 1310 |
+
)
|
| 1311 |
+
or (
|
| 1312 |
+
self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
|
| 1313 |
+
and block_idx % 4 == 0
|
| 1314 |
+
)
|
| 1315 |
+
):
|
| 1316 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1317 |
+
x, cache = self._activation_checkpoint_fn(
|
| 1318 |
+
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
|
| 1319 |
+
)
|
| 1320 |
+
else:
|
| 1321 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1322 |
+
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
|
| 1323 |
+
if attn_key_values is not None:
|
| 1324 |
+
assert cache is not None
|
| 1325 |
+
attn_key_values.append(cache)
|
| 1326 |
+
else:
|
| 1327 |
+
for group_idx, block_group in enumerate(self.transformer.block_groups):
|
| 1328 |
+
if output_hidden_states:
|
| 1329 |
+
# add hidden states
|
| 1330 |
+
all_hidden_states.append(x)
|
| 1331 |
+
|
| 1332 |
+
layers_past = (
|
| 1333 |
+
None
|
| 1334 |
+
if past_key_values is None
|
| 1335 |
+
else past_key_values[
|
| 1336 |
+
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
|
| 1337 |
+
]
|
| 1338 |
+
)
|
| 1339 |
+
x, cache = block_group(
|
| 1340 |
+
x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
|
| 1341 |
+
)
|
| 1342 |
+
if attn_key_values is not None:
|
| 1343 |
+
assert cache is not None
|
| 1344 |
+
attn_key_values.extend(cache)
|
| 1345 |
+
|
| 1346 |
+
if last_logits_only:
|
| 1347 |
+
# shape: (batch_size, 1, d_model)
|
| 1348 |
+
x = x[:, -1, :].unsqueeze(1)
|
| 1349 |
+
|
| 1350 |
+
# Apply final layer norm.
|
| 1351 |
+
# shape: (batch_size, seq_len or 1, d_model)
|
| 1352 |
+
x = self.transformer.ln_f(x) # type: ignore
|
| 1353 |
+
if output_hidden_states:
|
| 1354 |
+
# add final hidden state post-final-layernorm, following HuggingFace's convention
|
| 1355 |
+
all_hidden_states.append(x)
|
| 1356 |
+
|
| 1357 |
+
# Get logits.
|
| 1358 |
+
# shape: (batch_size, seq_len or 1, vocab_size)
|
| 1359 |
+
if self.config.weight_tying:
|
| 1360 |
+
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 1361 |
+
else:
|
| 1362 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
| 1363 |
+
if self.config.scale_logits:
|
| 1364 |
+
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 1365 |
+
|
| 1366 |
+
return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
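# Illustrative sketch (not part of the original file): building the inputs described
# in LLaDAModel.forward's docstring. The sizes and random token ids are assumptions
# for illustration; a real call would use tokenizer output.
def _example_forward_inputs(batch_size: int = 2, seq_len: int = 8):
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    # Boolean bias of shape (1, 1, seq_len, seq_len): True at [i, j] means
    # position i may attend to position j (here: full bidirectional attention).
    attention_bias = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool)
    # 0/1 mask of shape (batch_size, seq_len): 0 marks padding positions.
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, -2:] = 0
    return input_ids, attention_bias, attention_mask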
|
| 1367 |
+
|
| 1368 |
+
|
| 1369 |
+
def create_model_config_from_pretrained_config(config: LLaDAConfig):
|
| 1370 |
+
"""
|
| 1371 |
+
Utility function that copies the fields of a pretrained LLaDAConfig into a ModelConfig.
|
| 1372 |
+
"""
|
| 1373 |
+
|
| 1374 |
+
kwargs = {}
|
| 1375 |
+
for field in fields(ModelConfig):
|
| 1376 |
+
kwargs[field.name] = getattr(config, field.name)
|
| 1377 |
+
|
| 1378 |
+
model_config = ModelConfig(**kwargs)
|
| 1379 |
+
return model_config
|
| 1380 |
+
|
| 1381 |
+
|
| 1382 |
+
class LLaDAModelLM(PreTrainedModel):
|
| 1383 |
+
"""
|
| 1384 |
+
Extremely barebones HF model wrapper.
|
| 1385 |
+
"""
|
| 1386 |
+
|
| 1387 |
+
config_class = LLaDAConfig
|
| 1388 |
+
base_model_prefix = "model"
|
| 1389 |
+
_no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
|
| 1390 |
+
|
| 1391 |
+
def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
|
| 1392 |
+
super().__init__(config)
|
| 1393 |
+
|
| 1394 |
+
if not model:
|
| 1395 |
+
model_config = create_model_config_from_pretrained_config(config)
|
| 1396 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
| 1397 |
+
model_config.init_device = "cpu"
|
| 1398 |
+
self.model = LLaDAModel(model_config, init_params=init_params)
|
| 1399 |
+
else:
|
| 1400 |
+
self.model = model
|
| 1401 |
+
|
| 1402 |
+
def forward(
|
| 1403 |
+
self,
|
| 1404 |
+
input_ids: torch.LongTensor = None,
|
| 1405 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1406 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1407 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1408 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1409 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1410 |
+
use_cache: Optional[bool] = None,
|
| 1411 |
+
output_attentions: Optional[bool] = None,
|
| 1412 |
+
output_hidden_states: Optional[bool] = None,
|
| 1413 |
+
return_dict: Optional[bool] = None,
|
| 1414 |
+
cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
|
| 1415 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1416 |
+
if use_cache is None:
|
| 1417 |
+
use_cache = self.config.use_cache
|
| 1418 |
+
|
| 1419 |
+
if output_attentions:
|
| 1420 |
+
raise ValueError("output_attentions is not yet supported in LLaDA")
|
| 1421 |
+
|
| 1422 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1423 |
+
|
| 1424 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1425 |
+
outputs = self.model.forward(
|
| 1426 |
+
input_ids=input_ids,
|
| 1427 |
+
input_embeddings=inputs_embeds,
|
| 1428 |
+
attention_mask=attention_mask,
|
| 1429 |
+
attention_bias=attention_bias,
|
| 1430 |
+
past_key_values=None,
|
| 1431 |
+
use_cache=False,
|
| 1432 |
+
output_hidden_states=output_hidden_states,
|
| 1433 |
+
)
|
| 1434 |
+
|
| 1435 |
+
logits = outputs.logits
|
| 1436 |
+
hidden_states = outputs.hidden_states
|
| 1437 |
+
|
| 1438 |
+
loss = None
|
| 1439 |
+
if labels is not None:
|
| 1440 |
+
import warnings
|
| 1441 |
+
warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
|
| 1442 |
+
if not return_dict:
|
| 1443 |
+
output = (logits,) + outputs[1:]
|
| 1444 |
+
return (loss,) + output if loss is not None else output
|
| 1445 |
+
|
| 1446 |
+
return CausalLMOutputWithPast(
|
| 1447 |
+
logits=logits,
|
| 1448 |
+
past_key_values=outputs.attn_key_values,
|
| 1449 |
+
hidden_states=hidden_states,
|
| 1450 |
+
)
|
| 1451 |
+
|
| 1452 |
+
def can_generate(self) -> bool:
|
| 1453 |
+
return True
|
| 1454 |
+
|
| 1455 |
+
def prepare_inputs_for_generation(
|
| 1456 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
| 1457 |
+
):
|
| 1458 |
+
if past_key_values:
|
| 1459 |
+
# This is because we want the model to only process the last generated token.
|
| 1460 |
+
input_ids = input_ids[:, -1:]
|
| 1461 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 1462 |
+
|
| 1463 |
+
model_inputs.update(kwargs)
|
| 1464 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
| 1465 |
+
return model_inputs
|
| 1466 |
+
|
| 1467 |
+
# TODO: these are required to make the implementation complete.
|
| 1468 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
| 1469 |
+
# pass
|
| 1470 |
+
#
|
| 1471 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
| 1472 |
+
# pass
|
| 1473 |
+
#
|
| 1474 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
| 1475 |
+
# pass
|
| 1476 |
+
|
| 1477 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
| 1478 |
+
return self.model.transformer.wte
|
| 1479 |
+
|
| 1480 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
| 1481 |
+
self.model.transformer.wte = value
|
| 1482 |
+
|
| 1483 |
+
def get_output_embeddings(self):
|
| 1484 |
+
if self.config.weight_tying:
|
| 1485 |
+
return self.model.transformer.wte
|
| 1486 |
+
else:
|
| 1487 |
+
return self.model.transformer.ff_out
|
| 1488 |
+
|
| 1489 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
| 1490 |
+
if self.config.weight_tying:
|
| 1491 |
+
self.model.transformer.wte = value
|
| 1492 |
+
else:
|
| 1493 |
+
self.model.transformer.ff_out = value
|
| 1494 |
+
|
| 1495 |
+
def tie_weights(self):
|
| 1496 |
+
if self.config.weight_tying:
|
| 1497 |
+
self.model.transformer.ff_out = self.model.transformer.wte
|
| 1498 |
+
|
| 1499 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
| 1500 |
+
AutoModel.register(LLaDAConfig, LLaDAModelLM)
|
models/modeling_magvitv2.py
ADDED
|
@@ -0,0 +1,440 @@
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from .common_modules import *
|
| 6 |
+
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
|
| 7 |
+
from .misc import *
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
class Updateable:
|
| 11 |
+
def do_update_step(
|
| 12 |
+
self, epoch: int, global_step: int, on_load_weights: bool = False
|
| 13 |
+
):
|
| 14 |
+
for attr in self.__dir__():
|
| 15 |
+
if attr.startswith("_"):
|
| 16 |
+
continue
|
| 17 |
+
try:
|
| 18 |
+
module = getattr(self, attr)
|
| 19 |
+
except Exception:
|
| 20 |
+
continue # ignore attributes such as properties, which can't be retrieved using getattr
|
| 21 |
+
if isinstance(module, Updateable):
|
| 22 |
+
module.do_update_step(
|
| 23 |
+
epoch, global_step, on_load_weights=on_load_weights
|
| 24 |
+
)
|
| 25 |
+
self.update_step(epoch, global_step, on_load_weights=on_load_weights)
|
| 26 |
+
|
| 27 |
+
def do_update_step_end(self, epoch: int, global_step: int):
|
| 28 |
+
for attr in self.__dir__():
|
| 29 |
+
if attr.startswith("_"):
|
| 30 |
+
continue
|
| 31 |
+
try:
|
| 32 |
+
module = getattr(self, attr)
|
| 33 |
+
except Exception:
|
| 34 |
+
continue # ignore attributes such as properties, which can't be retrieved using getattr
|
| 35 |
+
if isinstance(module, Updateable):
|
| 36 |
+
module.do_update_step_end(epoch, global_step)
|
| 37 |
+
self.update_step_end(epoch, global_step)
|
| 38 |
+
|
| 39 |
+
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
| 40 |
+
# override this method to implement custom update logic
|
| 41 |
+
# if on_load_weights is True, you should be careful doing things related to model evaluations,
|
| 42 |
+
# as the models and tensors are not guaranteed to be on the same device
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
def update_step_end(self, epoch: int, global_step: int):
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
class VQGANEncoder(ModelMixin, ConfigMixin):
|
| 49 |
+
@dataclass
|
| 50 |
+
class Config:
|
| 51 |
+
ch: int = 128
|
| 52 |
+
ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
|
| 53 |
+
num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
|
| 54 |
+
attn_resolutions: List[int] = field(default_factory=lambda: [5])
|
| 55 |
+
dropout: float = 0.0
|
| 56 |
+
in_ch: int = 3
|
| 57 |
+
out_ch: int = 3
|
| 58 |
+
resolution: int = 256
|
| 59 |
+
z_channels: int = 13
|
| 60 |
+
double_z: bool = False
|
| 61 |
+
|
| 62 |
+
def __init__(self,
|
| 63 |
+
ch: int = 128,
|
| 64 |
+
ch_mult: List[int] = [1, 2, 2, 4, 4],
|
| 65 |
+
num_res_blocks: List[int] = [4, 3, 4, 3, 4],
|
| 66 |
+
attn_resolutions: List[int] = [5],
|
| 67 |
+
dropout: float = 0.0,
|
| 68 |
+
in_ch: int = 3,
|
| 69 |
+
out_ch: int = 3,
|
| 70 |
+
resolution: int = 256,
|
| 71 |
+
z_channels: int = 13,
|
| 72 |
+
double_z: bool = False):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.ch = ch
|
| 75 |
+
self.temb_ch = 0
|
| 76 |
+
self.num_resolutions = len(ch_mult)
|
| 77 |
+
self.num_res_blocks = num_res_blocks
|
| 78 |
+
self.resolution = resolution
|
| 79 |
+
self.in_ch = in_ch
|
| 80 |
+
# downsampling
|
| 81 |
+
self.conv_in = torch.nn.Conv2d(
|
| 82 |
+
self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
curr_res = self.resolution
|
| 86 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 87 |
+
self.down = nn.ModuleList()
|
| 88 |
+
for i_level in range(self.num_resolutions):
|
| 89 |
+
block = nn.ModuleList()
|
| 90 |
+
attn = nn.ModuleList()
|
| 91 |
+
block_in = self.ch * in_ch_mult[i_level]
|
| 92 |
+
block_out = self.ch * ch_mult[i_level]
|
| 93 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
| 94 |
+
block.append(
|
| 95 |
+
ResnetBlock(
|
| 96 |
+
in_channels=block_in,
|
| 97 |
+
out_channels=block_out,
|
| 98 |
+
temb_channels=self.temb_ch,
|
| 99 |
+
dropout=dropout,
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
block_in = block_out
|
| 103 |
+
if curr_res in attn_resolutions:
|
| 104 |
+
attn.append(AttnBlock(block_in))
|
| 105 |
+
down = nn.Module()
|
| 106 |
+
down.block = block
|
| 107 |
+
down.attn = attn
|
| 108 |
+
if i_level != self.num_resolutions - 1:
|
| 109 |
+
down.downsample = Downsample(block_in, True)
|
| 110 |
+
curr_res = curr_res // 2
|
| 111 |
+
self.down.append(down)
|
| 112 |
+
|
| 113 |
+
# middle
|
| 114 |
+
self.mid = nn.Module()
|
| 115 |
+
self.mid.block_1 = ResnetBlock(
|
| 116 |
+
in_channels=block_in,
|
| 117 |
+
out_channels=block_in,
|
| 118 |
+
temb_channels=self.temb_ch,
|
| 119 |
+
dropout=dropout,
|
| 120 |
+
)
|
| 121 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 122 |
+
self.mid.block_2 = ResnetBlock(
|
| 123 |
+
in_channels=block_in,
|
| 124 |
+
out_channels=block_in,
|
| 125 |
+
temb_channels=self.temb_ch,
|
| 126 |
+
dropout=dropout,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
self.norm_out = Normalize(block_in)
|
| 131 |
+
self.conv_out = torch.nn.Conv2d(
|
| 132 |
+
block_in,
|
| 133 |
+
2 * z_channels if double_z else z_channels,
|
| 134 |
+
kernel_size=3,
|
| 135 |
+
stride=1,
|
| 136 |
+
padding=1,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
| 140 |
+
# for param in self.parameters():
|
| 141 |
+
# broadcast(param, src=0)
|
| 142 |
+
|
| 143 |
+
def forward(self, x):
|
| 144 |
+
# timestep embedding
|
| 145 |
+
temb = None
|
| 146 |
+
|
| 147 |
+
# downsampling
|
| 148 |
+
hs = [self.conv_in(x)]
|
| 149 |
+
for i_level in range(self.num_resolutions):
|
| 150 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
| 151 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 152 |
+
if len(self.down[i_level].attn) > 0:
|
| 153 |
+
h = self.down[i_level].attn[i_block](h)
|
| 154 |
+
hs.append(h)
|
| 155 |
+
if i_level != self.num_resolutions - 1:
|
| 156 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 157 |
+
|
| 158 |
+
# middle
|
| 159 |
+
h = hs[-1]
|
| 160 |
+
h = self.mid.block_1(h, temb)
|
| 161 |
+
h = self.mid.attn_1(h)
|
| 162 |
+
h = self.mid.block_2(h, temb)
|
| 163 |
+
|
| 164 |
+
# end
|
| 165 |
+
h = self.norm_out(h)
|
| 166 |
+
h = nonlinearity(h)
|
| 167 |
+
h = self.conv_out(h)
|
| 168 |
+
h = self.quant_conv(h)
|
| 169 |
+
return h
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class LFQuantizer(nn.Module):
|
| 173 |
+
def __init__(self, num_codebook_entry: int = -1,
|
| 174 |
+
codebook_dim: int = 13,
|
| 175 |
+
beta: float = 0.25,
|
| 176 |
+
entropy_multiplier: float = 0.1,
|
| 177 |
+
commit_loss_multiplier: float = 0.1, ):
|
| 178 |
+
super().__init__()
|
| 179 |
+
self.codebook_size = 2 ** codebook_dim
|
| 180 |
+
print(
|
| 181 |
+
f"Look-up free quantizer with codebook size: {self.codebook_size}"
|
| 182 |
+
)
|
| 183 |
+
self.e_dim = codebook_dim
|
| 184 |
+
self.beta = beta
|
| 185 |
+
|
| 186 |
+
indices = torch.arange(self.codebook_size)
|
| 187 |
+
|
| 188 |
+
binary = (
|
| 189 |
+
indices.unsqueeze(1)
|
| 190 |
+
>> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
|
| 191 |
+
) & 1
|
| 192 |
+
|
| 193 |
+
embedding = binary.float() * 2 - 1
|
| 194 |
+
self.register_buffer("embedding", embedding)
|
| 195 |
+
self.register_buffer(
|
| 196 |
+
"power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
|
| 197 |
+
)
|
| 198 |
+
self.commit_loss_multiplier = commit_loss_multiplier
|
| 199 |
+
self.entropy_multiplier = entropy_multiplier
|
| 200 |
+
|
| 201 |
+
def get_indices(self, z_q):
|
| 202 |
+
return (
|
| 203 |
+
(self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
|
| 204 |
+
.sum(1, keepdim=True)
|
| 205 |
+
.long()
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def get_codebook_entry(self, indices, shape=None):
|
| 209 |
+
if shape is None:
|
| 210 |
+
h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
|
| 211 |
+
else:
|
| 212 |
+
h, w = shape
|
| 213 |
+
b, _ = indices.shape
|
| 214 |
+
indices = indices.reshape(-1)
|
| 215 |
+
z_q = self.embedding[indices]
|
| 216 |
+
z_q = z_q.view(b, h, w, -1)
|
| 217 |
+
|
| 218 |
+
# reshape back to match original input shape
|
| 219 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 220 |
+
|
| 221 |
+
return z_q
|
| 222 |
+
|
| 223 |
+
def forward(self, z, get_code=False):
|
| 224 |
+
"""
|
| 225 |
+
Inputs the output of the encoder network z and maps it to a discrete
|
| 226 |
+
one-hot vector that is the index of the closest embedding vector e_j
|
| 227 |
+
z (continuous) -> z_q (discrete)
|
| 228 |
+
z.shape = (batch, channel, height, width)
|
| 229 |
+
quantization pipeline:
|
| 230 |
+
1. get encoder input (B,C,H,W)
|
| 231 |
+
2. flatten input to (B*H*W,C)
|
| 232 |
+
"""
|
| 233 |
+
if get_code:
|
| 234 |
+
return self.get_codebook_entry(z)
|
| 235 |
+
|
| 236 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 237 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
| 238 |
+
z_flattened = z.view(-1, self.e_dim)
|
| 239 |
+
ge_zero = (z_flattened > 0).float()
|
| 240 |
+
ones = torch.ones_like(z_flattened)
|
| 241 |
+
z_q = ones * ge_zero + -ones * (1 - ge_zero)
|
| 242 |
+
|
| 243 |
+
# preserve gradients
|
| 244 |
+
z_q = z_flattened + (z_q - z_flattened).detach()
|
| 245 |
+
|
| 246 |
+
# compute entropy loss
|
| 247 |
+
CatDist = torch.distributions.categorical.Categorical
|
| 248 |
+
logit = torch.stack(
|
| 249 |
+
[
|
| 250 |
+
-(z_flattened - torch.ones_like(z_q)).pow(2),
|
| 251 |
+
-(z_flattened - torch.ones_like(z_q) * -1).pow(2),
|
| 252 |
+
],
|
| 253 |
+
dim=-1,
|
| 254 |
+
)
|
| 255 |
+
cat_dist = CatDist(logits=logit)
|
| 256 |
+
entropy = cat_dist.entropy().mean()
|
| 257 |
+
mean_prob = cat_dist.probs.mean(0)
|
| 258 |
+
mean_entropy = CatDist(probs=mean_prob).entropy().mean()
|
| 259 |
+
|
| 260 |
+
# compute loss for embedding
|
| 261 |
+
commit_loss = torch.mean(
|
| 262 |
+
(z_q.detach() - z_flattened) ** 2
|
| 263 |
+
) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)
|
| 264 |
+
|
| 265 |
+
# reshape back to match original input shape
|
| 266 |
+
z_q = z_q.view(z.shape)
|
| 267 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 268 |
+
|
| 269 |
+
return {
|
| 270 |
+
"z": z_q,
|
| 271 |
+
"quantizer_loss": commit_loss * self.commit_loss_multiplier,
|
| 272 |
+
"entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
|
| 273 |
+
"indices": self.get_indices(z_q),
|
| 274 |
+
}
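# Illustrative sketch (not part of the original file): look-up-free quantization in
# LFQuantizer amounts to taking the sign of each latent channel and reading the
# resulting bits as an integer code. Tensor sizes below are assumptions.
def _example_lfq_round_trip():
    quantizer = LFQuantizer(codebook_dim=13)
    z = torch.randn(1, 13, 4, 4)                     # (batch, channels, height, width)
    out = quantizer(z)
    z_q, indices = out["z"], out["indices"]
    # Every quantized channel is +1 or -1, and each spatial position maps to a
    # code index in [0, 2**13).
    assert torch.allclose(z_q.abs(), torch.ones_like(z_q))
    assert int(indices.min()) >= 0 and int(indices.max()) < quantizer.codebook_size
    return z_q, indices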
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class VQGANDecoder(ModelMixin, ConfigMixin):
|
| 278 |
+
def __init__(self, ch: int = 128,
|
| 279 |
+
ch_mult: List[int] = [1, 1, 2, 2, 4],
|
| 280 |
+
num_res_blocks: List[int] = [4, 4, 3, 4, 3],
|
| 281 |
+
attn_resolutions: List[int] = [5],
|
| 282 |
+
dropout: float = 0.0,
|
| 283 |
+
in_ch: int = 3,
|
| 284 |
+
out_ch: int = 3,
|
| 285 |
+
resolution: int = 256,
|
| 286 |
+
z_channels: int = 13,
|
| 287 |
+
double_z: bool = False):
|
| 288 |
+
super().__init__()
|
| 289 |
+
self.ch = ch
|
| 290 |
+
self.temb_ch = 0
|
| 291 |
+
self.num_resolutions = len(ch_mult)
|
| 292 |
+
self.num_res_blocks = num_res_blocks
|
| 293 |
+
self.resolution = resolution
|
| 294 |
+
self.in_ch = in_ch
|
| 295 |
+
self.give_pre_end = False
|
| 296 |
+
|
| 297 |
+
self.z_channels = z_channels
|
| 298 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 299 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 300 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 301 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
| 302 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 303 |
+
print(
|
| 304 |
+
"Working with z of shape {} = {} dimensions.".format(
|
| 305 |
+
self.z_shape, np.prod(self.z_shape)
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# z to block_in
|
| 310 |
+
self.conv_in = torch.nn.Conv2d(
|
| 311 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# middle
|
| 315 |
+
self.mid = nn.Module()
|
| 316 |
+
self.mid.block_1 = ResnetBlock(
|
| 317 |
+
in_channels=block_in,
|
| 318 |
+
out_channels=block_in,
|
| 319 |
+
temb_channels=self.temb_ch,
|
| 320 |
+
dropout=dropout,
|
| 321 |
+
)
|
| 322 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 323 |
+
self.mid.block_2 = ResnetBlock(
|
| 324 |
+
in_channels=block_in,
|
| 325 |
+
out_channels=block_in,
|
| 326 |
+
temb_channels=self.temb_ch,
|
| 327 |
+
dropout=dropout,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# upsampling
|
| 331 |
+
self.up = nn.ModuleList()
|
| 332 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 333 |
+
block = nn.ModuleList()
|
| 334 |
+
attn = nn.ModuleList()
|
| 335 |
+
block_out = ch * ch_mult[i_level]
|
| 336 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
| 337 |
+
block.append(
|
| 338 |
+
ResnetBlock(
|
| 339 |
+
in_channels=block_in,
|
| 340 |
+
out_channels=block_out,
|
| 341 |
+
temb_channels=self.temb_ch,
|
| 342 |
+
dropout=dropout,
|
| 343 |
+
)
|
| 344 |
+
)
|
| 345 |
+
block_in = block_out
|
| 346 |
+
if curr_res in attn_resolutions:
|
| 347 |
+
attn.append(AttnBlock(block_in))
|
| 348 |
+
up = nn.Module()
|
| 349 |
+
up.block = block
|
| 350 |
+
up.attn = attn
|
| 351 |
+
if i_level != 0:
|
| 352 |
+
up.upsample = Upsample(block_in, True)
|
| 353 |
+
curr_res = curr_res * 2
|
| 354 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 355 |
+
|
| 356 |
+
self.norm_out = Normalize(block_in)
|
| 357 |
+
self.conv_out = torch.nn.Conv2d(
|
| 358 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
| 359 |
+
)
|
| 360 |
+
self.post_quant_conv = torch.nn.Conv2d(
|
| 361 |
+
z_channels, z_channels, 1
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def forward(self, z):
|
| 366 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
| 367 |
+
self.last_z_shape = z.shape
|
| 368 |
+
# timestep embedding
|
| 369 |
+
temb = None
|
| 370 |
+
output = dict()
|
| 371 |
+
z = self.post_quant_conv(z)
|
| 372 |
+
|
| 373 |
+
# z to block_in
|
| 374 |
+
h = self.conv_in(z)
|
| 375 |
+
|
| 376 |
+
# middle
|
| 377 |
+
h = self.mid.block_1(h, temb)
|
| 378 |
+
h = self.mid.attn_1(h)
|
| 379 |
+
h = self.mid.block_2(h, temb)
|
| 380 |
+
|
| 381 |
+
# upsampling
|
| 382 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 383 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
| 384 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 385 |
+
if len(self.up[i_level].attn) > 0:
|
| 386 |
+
h = self.up[i_level].attn[i_block](h)
|
| 387 |
+
if i_level != 0:
|
| 388 |
+
h = self.up[i_level].upsample(h)
|
| 389 |
+
|
| 390 |
+
# end
|
| 391 |
+
output["output"] = h
|
| 392 |
+
if self.give_pre_end:
|
| 393 |
+
return output
|
| 394 |
+
|
| 395 |
+
h = self.norm_out(h)
|
| 396 |
+
h = nonlinearity(h)
|
| 397 |
+
h = self.conv_out(h)
|
| 398 |
+
output["output"] = h
|
| 399 |
+
return output
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class MAGVITv2(ModelMixin, ConfigMixin):
|
| 403 |
+
@register_to_config
|
| 404 |
+
def __init__(
|
| 405 |
+
self,
|
| 406 |
+
):
|
| 407 |
+
super().__init__()
|
| 408 |
+
|
| 409 |
+
self.encoder = VQGANEncoder()
|
| 410 |
+
self.decoder = VQGANDecoder()
|
| 411 |
+
self.quantize = LFQuantizer()
|
| 412 |
+
|
| 413 |
+
def forward(self, pixel_values, return_loss=False):
|
| 414 |
+
pass
|
| 415 |
+
|
| 416 |
+
def encode(self, pixel_values, return_loss=False):
|
| 417 |
+
hidden_states = self.encoder(pixel_values)
|
| 418 |
+
quantized_states = self.quantize(hidden_states)['z']
|
| 419 |
+
codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
|
| 420 |
+
output = (quantized_states, codebook_indices)
|
| 421 |
+
return output
|
| 422 |
+
|
| 423 |
+
def get_code(self, pixel_values):
|
| 424 |
+
hidden_states = self.encoder(pixel_values)
|
| 425 |
+
codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)
|
| 426 |
+
|
| 427 |
+
return codebook_indices
|
| 428 |
+
|
| 429 |
+
def decode_code(self, codebook_indices, shape=None):
|
| 430 |
+
z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)
|
| 431 |
+
|
| 432 |
+
reconstructed_pixel_values = self.decoder(z_q)["output"]
|
| 433 |
+
return reconstructed_pixel_values
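# Illustrative sketch (not part of the original file): the intended round trip
# through the tokenizer -- encode pixels into discrete codes, then decode the codes
# back into pixels. The 256x256 input resolution is an assumption that matches the
# default VQGANEncoder/VQGANDecoder configuration.
def _example_magvit_round_trip():
    vq_model = MAGVITv2().eval()
    pixel_values = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        codes = vq_model.get_code(pixel_values)      # (1, 256) token ids for a 256x256 input
        recon = vq_model.decode_code(codes)          # (1, 3, 256, 256) pixels
    return codes, recon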
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
if __name__ == '__main__':
|
| 437 |
+
encoder = VQGANEncoder()
|
| 438 |
+
import ipdb
|
| 439 |
+
ipdb.set_trace()
|
| 440 |
+
print()
|
models/modeling_mmada.py
ADDED
|
@@ -0,0 +1,668 @@
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import sys
|
| 6 |
+
from abc import abstractmethod
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import (
|
| 10 |
+
Callable,
|
| 11 |
+
Dict,
|
| 12 |
+
Iterable,
|
| 13 |
+
List,
|
| 14 |
+
NamedTuple,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
cast,
|
| 20 |
+
)
|
| 21 |
+
from dataclasses import fields
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.backends.cuda
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from torch import einsum
|
| 29 |
+
from transformers import PreTrainedModel
|
| 30 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 31 |
+
from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM
|
| 32 |
+
from transformers.cache_utils import Cache
|
| 33 |
+
from PIL import Image
|
| 34 |
+
from .configuration_llada import (
|
| 35 |
+
LLaDAConfig,
|
| 36 |
+
StrEnum,
|
| 37 |
+
InitFnType,
|
| 38 |
+
ActivationType,
|
| 39 |
+
BlockType,
|
| 40 |
+
LayerNormType,
|
| 41 |
+
ModelConfig,
|
| 42 |
+
ActivationCheckpointingStrategy,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
from .modeling_llada import LLaDAModelLM
|
| 46 |
+
from .sampling import cosine_schedule, mask_by_random_topk
|
| 47 |
+
from transformers import PretrainedConfig
|
| 48 |
+
|
| 49 |
+
def add_gumbel_noise(logits, temperature):
|
| 50 |
+
'''
|
| 51 |
+
The Gumbel-max trick is a method for sampling from categorical distributions.
|
| 52 |
+
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
|
| 53 |
+
Thus, we use float64.
|
| 54 |
+
'''
|
| 55 |
+
if temperature == 0:
|
| 56 |
+
return logits
|
| 57 |
+
logits = logits.to(torch.float64)
|
| 58 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 59 |
+
gumbel_noise = (- torch.log(noise)) ** temperature
|
| 60 |
+
return logits.exp() / gumbel_noise
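# Illustrative sketch (not part of the original file): greedy vs. Gumbel-noised
# decoding with add_gumbel_noise. With temperature == 0 the argmax equals the argmax
# of the raw logits; with temperature > 0 the perturbed scores give stochastic
# samples. The batch/sequence/vocab sizes are assumptions.
def _example_gumbel_sampling():
    logits = torch.randn(2, 16, 32000)               # (batch, seq_len, vocab)
    greedy = torch.argmax(add_gumbel_noise(logits, temperature=0.0), dim=-1)
    sampled = torch.argmax(add_gumbel_noise(logits, temperature=1.0), dim=-1)
    return greedy, sampled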
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_num_transfer_tokens(mask_index, steps):
|
| 64 |
+
'''
|
| 65 |
+
In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
|
| 66 |
+
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
|
| 67 |
+
the expected number of tokens transitioned at each step should be consistent.
|
| 68 |
+
|
| 69 |
+
This function is designed to precompute the number of tokens that need to be transitioned at each step.
|
| 70 |
+
'''
|
| 71 |
+
mask_num = mask_index.sum(dim=1, keepdim=True)
|
| 72 |
+
|
| 73 |
+
base = mask_num // steps
|
| 74 |
+
remainder = mask_num % steps
|
| 75 |
+
|
| 76 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
| 77 |
+
|
| 78 |
+
for i in range(mask_num.size(0)):
|
| 79 |
+
num_transfer_tokens[i, :remainder[i]] += 1
|
| 80 |
+
|
| 81 |
+
return num_transfer_tokens
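# Illustrative sketch (not part of the original file): with 10 masked positions and
# 4 denoising steps, the per-step unmasking budget is split as evenly as possible
# (here [3, 3, 2, 2]) and always sums back to the total mask count. The sizes are
# assumptions for illustration.
def _example_transfer_schedule():
    mask_index = torch.zeros(1, 16, dtype=torch.bool)
    mask_index[0, :10] = True                        # 10 masked positions
    schedule = get_num_transfer_tokens(mask_index, 4)
    assert schedule.sum().item() == 10               # schedule == tensor([[3, 3, 2, 2]])
    return schedule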
|
| 82 |
+
|
| 83 |
+
class MMadaConfig(PretrainedConfig):
|
| 84 |
+
model_type = "mmada"
|
| 85 |
+
|
| 86 |
+
def __init__(self, **kwargs):
|
| 87 |
+
super().__init__(**kwargs)
|
| 88 |
+
|
| 89 |
+
allowed_keys = [
|
| 90 |
+
"vocab_size",
|
| 91 |
+
"llm_vocab_size",
|
| 92 |
+
"llm_model_path",
|
| 93 |
+
"codebook_size",
|
| 94 |
+
"num_vq_tokens",
|
| 95 |
+
"num_new_special_tokens",
|
| 96 |
+
"gradient_checkpointing",
|
| 97 |
+
"new_vocab_size",
|
| 98 |
+
]
|
| 99 |
+
|
| 100 |
+
for key in allowed_keys:
|
| 101 |
+
if key in kwargs:
|
| 102 |
+
setattr(self, key, kwargs[key])
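# Illustrative sketch (not part of the original file): MMadaConfig only records the
# keys listed in `allowed_keys`. The values below are placeholders, not the released
# checkpoint's actual configuration.
def _example_mmada_config() -> MMadaConfig:
    return MMadaConfig(
        vocab_size=126464,       # placeholder value
        codebook_size=8192,      # placeholder value
        num_vq_tokens=1024,      # placeholder value
    )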
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class MMadaModelLM(LLaDAModelLM):
|
| 107 |
+
config_class = MMadaConfig
|
| 108 |
+
base_model_prefix = "model"
|
| 109 |
+
def __init__(self, config: MMadaConfig, *args, **kwargs):
|
| 110 |
+
print(f"Initializing MMadaModelLM with config: {config}")
|
| 111 |
+
super().__init__(config, *args, **kwargs)
|
| 112 |
+
|
| 113 |
+
# # resize token embeddings
|
| 114 |
+
# print(f"Resizing token embeddings to {config.new_vocab_size}")
|
| 115 |
+
# self.resize_token_embeddings(config.new_vocab_size)
|
| 116 |
+
|
| 117 |
+
@torch.no_grad()
|
| 118 |
+
def t2i_generate(
|
| 119 |
+
self,
|
| 120 |
+
input_ids: torch.LongTensor = None,
|
| 121 |
+
uncond_input_ids: torch.LongTensor = None,
|
| 122 |
+
attention_mask=None,
|
| 123 |
+
uncond_attention_mask=None,
|
| 124 |
+
temperature=1.0,
|
| 125 |
+
timesteps=18, # ideal number of steps is 18 in maskgit paper
|
| 126 |
+
guidance_scale=0,
|
| 127 |
+
noise_schedule=cosine_schedule,
|
| 128 |
+
generator: torch.Generator = None,
|
| 129 |
+
config=None,
|
| 130 |
+
seq_len=1024,
|
| 131 |
+
mask_token_id = 126336,
|
| 132 |
+
resolution = 512,
|
| 133 |
+
codebook_size = 8192,
|
| 134 |
+
**kwargs,
|
| 135 |
+
):
|
| 136 |
+
"""
|
| 137 |
+
Generate image tokens following the original MaskGit decoding loop 1:1
|
| 138 |
+
https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
# begin with all image token ids masked
|
| 142 |
+
# count how many tokens in the input are mask tokens
|
| 143 |
+
mask_count = (input_ids == mask_token_id).sum().item()
|
| 144 |
+
num_vq_tokens = seq_len
|
| 145 |
+
num_new_special_tokens = 0
|
| 146 |
+
uni_prompting = kwargs.get("uni_prompting", None)
|
| 147 |
+
# print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}")
|
| 148 |
+
input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
|
| 149 |
+
input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens)
|
| 150 |
+
|
| 151 |
+
# for classifier-free guidance
|
| 152 |
+
if uncond_input_ids is not None:
|
| 153 |
+
uncond_prefix = uncond_input_ids[:, :resolution + 1]
|
| 154 |
+
|
| 155 |
+
for step in range(timesteps):
|
| 156 |
+
if uncond_input_ids is not None and guidance_scale > 0:
|
| 157 |
+
uncond_input_ids = torch.cat(
|
| 158 |
+
[uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
|
| 159 |
+
model_input = torch.cat([input_ids, uncond_input_ids])
|
| 160 |
+
attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
|
| 161 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
| 162 |
+
logits = self(model_input, attention_bias=attention_bias).logits
|
| 163 |
+
# print(f"logits.shape: {logits.shape}")
|
| 164 |
+
cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
|
| 165 |
+
# logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
| 166 |
+
# it seems that muse has a different cfg setting
|
| 167 |
+
logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
|
| 168 |
+
logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
|
| 169 |
+
else:
|
| 170 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
| 171 |
+
logits = self(input_ids, attention_bias=attention_bias).logits
|
| 172 |
+
logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
|
| 173 |
+
|
| 174 |
+
# logits: 1, 1024, 8192
|
| 175 |
+
# print(f"logits.shape: {logits.shape}")
|
| 176 |
+
probs = logits.softmax(dim=-1)
|
| 177 |
+
sampled = probs.reshape(-1, logits.size(-1))
|
| 178 |
+
# print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}")
|
| 179 |
+
sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024
|
| 180 |
+
|
| 181 |
+
unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
|
| 182 |
+
# print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}")
|
| 183 |
+
sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
|
| 184 |
+
# Defines the mask ratio for the next round. The number to mask out is
|
| 185 |
+
# determined by mask_ratio * unknown_number_in_the_beginning.
|
| 186 |
+
ratio = 1.0 * (step + 1) / timesteps
|
| 187 |
+
mask_ratio = noise_schedule(torch.tensor(ratio))
|
| 188 |
+
# Computes the probabilities of each selected tokens.
|
| 189 |
+
selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
|
| 190 |
+
selected_probs = selected_probs.squeeze(-1)
|
| 191 |
+
|
| 192 |
+
# Ignores the tokens given in the input by overwriting their confidence.
|
| 193 |
+
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
|
| 194 |
+
# Gets mask lens for each sample in the batch according to the mask ratio.
|
| 195 |
+
mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
|
| 196 |
+
# Keeps at least one prediction in this round and also masks out at least
|
| 197 |
+
# one token for the next iteration
|
| 198 |
+
mask_len = torch.max(
|
| 199 |
+
torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
| 200 |
+
)
|
| 201 |
+
# print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}")
|
| 202 |
+
# Adds noise for randomness
|
| 203 |
+
temperature = temperature * (1.0 - ratio)
|
| 204 |
+
masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
|
| 205 |
+
# Masks tokens with lower confidence.
|
| 206 |
+
input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
|
| 207 |
+
sampled_ids + len(uni_prompting.text_tokenizer)
|
| 208 |
+
+ num_new_special_tokens)
|
| 209 |
+
input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
|
| 210 |
+
|
| 211 |
+
return sampled_ids
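# Illustrative sketch (not part of the original file): a standalone view of the
# per-step masking schedule used by t2i_generate above, assuming the cosine schedule
# from models/sampling.py behaves like MaskGit's (monotonically decreasing to 0).
# The token count and step count are assumptions.
def _example_maskgit_schedule(num_vq_tokens: int = 1024, timesteps: int = 18):
    masked_per_step = []
    for step in range(timesteps):
        ratio = (step + 1) / timesteps
        mask_ratio = cosine_schedule(torch.tensor(ratio))
        # Number of image tokens that remain masked after this step.
        masked_per_step.append(int((num_vq_tokens * mask_ratio).floor()))
    return masked_per_step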
|
| 212 |
+
|
| 213 |
+
    def forward_process(
            self,
            input_ids,
            labels,
            batch_size_t2i=0,
            batch_size_lm=0,
            batch_size_mmu=0,
            max_seq_length=128,
            p_mask_lm=None,
            p_mask_mmu=None,
            answer_lengths=None,
            t2i_masks=None,
            answer_lengths_lm=None
    ):
        # attention bias, True for batch_size, 1, seq_len, seq_len
        attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
        attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
        attention_bias[:batch_size_t2i] = attention_bias_t2i
        logits = self(input_ids, attention_bias=attention_bias).logits
        # logits = self(input_ids).logits
        self.output_size = logits.shape[-1]

        # print(f"logits shape: {logits.shape}") B, 359, vocab_size

        if batch_size_t2i == 0:
            loss_t2i = torch.tensor(0.0, device=input_ids.device)
        else:
            # t2i loss
            loss_t2i = F.cross_entropy(
                logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
                labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
            )

        # llada loss
        masked_indices = input_ids == self.config.mask_token_id
        masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm]
        # Debug code: count the number of masked positions in each LM row.
        # if masked_indices_lm.numel() > 0:
        #     mask_counts = torch.sum(masked_indices_lm, dim=1)
        #     logging.info(f"[LM mask nums]: {mask_counts.cpu()}.")
        # else:
        #     logging.info("[LM mask nums] no LM sample.")
        masked_indices_mmu = masked_indices[-batch_size_mmu:]
        p_mask_lm = p_mask_lm.to(masked_indices_lm.device)
        p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device)
        answer_lengths = answer_lengths.to(masked_indices_mmu.device)
        loss_lm = F.cross_entropy(
            logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size),
            labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask_lm[masked_indices_lm]
        # print(f"logits lm shape: {logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape}")
        loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1])

        # llm loss
        answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device)
        loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0])

        loss_mmu = F.cross_entropy(
            logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size),
            labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask_mmu[masked_indices_mmu]
        loss_mmu = torch.sum(loss_mmu / answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0])

        return logits, loss_t2i, loss_lm, loss_mmu

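    # ------------------------------------------------------------------
    # Editorial sketch: how a mixed training batch might flow through
    # `forward_process`. All variable names and the equal loss weights below
    # are illustrative assumptions, not values taken from this repository's
    # training configs.
    #
    # logits, loss_t2i, loss_lm, loss_mmu = model.forward_process(
    #     input_ids, labels,
    #     batch_size_t2i=bs_t2i, batch_size_lm=bs_lm, batch_size_mmu=bs_mmu,
    #     max_seq_length=128,
    #     p_mask_lm=p_mask_lm, p_mask_mmu=p_mask_mmu,
    #     answer_lengths=answer_lengths, answer_lengths_lm=answer_lengths_lm,
    #     t2i_masks=t2i_masks,
    # )
    # loss = loss_t2i + loss_lm + loss_mmu  # equal weights assumed for illustration
    # loss.backward()
    # ------------------------------------------------------------------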
    def forward_process_with_r2i(
            self,
            input_ids,
            labels,
            t2i_masks=None,
            max_seq_length=128,
            batch_size_t2i=0,
            batch_size_lm=0,
            batch_size_mmu=0,
            batch_size_r2i=0,
            p_mask_lm=None,
            p_mask_mmu=None,
            p_mask_r2i=None,
            answer_lengths=None,
            answer_lengths_lm=None,
            answer_lengths_r2i=None,
    ):
        # attention bias, True for batch_size, 1, seq_len, seq_len
        attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
        attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
        attention_bias[:batch_size_t2i] = attention_bias_t2i
        logits = self(input_ids, attention_bias=attention_bias).logits
        # logits = self(input_ids).logits
        self.output_size = logits.shape[-1]

        # print(f"logits shape: {logits.shape}") B, 359, vocab_size

        if batch_size_t2i == 0:
            loss_t2i = torch.tensor(0.0, device=input_ids.device)
        else:
            # t2i loss
            loss_t2i = F.cross_entropy(
                logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
                labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
            )

        # llada loss

        start_lm = batch_size_t2i
        end_lm = start_lm + batch_size_lm
        start_mmu = end_lm
        end_mmu = start_mmu + batch_size_mmu
        start_r2i = end_mmu
        end_r2i = start_r2i + batch_size_r2i

        masked_indices = input_ids == self.config.mask_token_id
        masked_indices_lm = masked_indices[start_lm:end_lm]
        masked_indices_mmu = masked_indices[start_mmu:end_mmu]
        masked_indices_r2i = masked_indices[start_r2i:end_r2i]

        p_mask_lm = p_mask_lm.to(masked_indices_lm.device)
        p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device)
        p_mask_r2i = p_mask_r2i.to(masked_indices_r2i.device)

        answer_lengths = answer_lengths.to(masked_indices_mmu.device)
        answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device)
        answer_lengths_r2i = answer_lengths_r2i.to(masked_indices_r2i.device)

        loss_lm = F.cross_entropy(
            logits[start_lm:end_lm][masked_indices_lm].contiguous().view(-1, self.output_size),
            labels[start_lm:end_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask_lm[masked_indices_lm]
        # print(f"logits lm shape: {logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape}")
        loss_lm = loss_lm.sum() / (logits[start_lm:end_lm].shape[0] * logits[start_lm:end_lm].shape[1])
        loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[start_lm:end_lm].shape[0])

        loss_mmu = F.cross_entropy(
            logits[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1, self.output_size),
            labels[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask_mmu[masked_indices_mmu]
        loss_mmu = torch.sum(loss_mmu / answer_lengths[masked_indices_mmu]) / (logits[start_mmu:end_mmu].shape[0])

        loss_r2i = F.cross_entropy(
            logits[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1, self.output_size),
            labels[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask_r2i[masked_indices_r2i]
        loss_r2i = torch.sum(loss_r2i / answer_lengths_r2i[masked_indices_r2i]) / (logits[start_r2i:end_r2i].shape[0])

        return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i

    def forward_t2i(
            self,
            input_ids,
            labels,
            batch_size_t2i=0,
            max_seq_length=128,
            t2i_masks=None
    ):
        # attention bias, True for batch_size, 1, seq_len, seq_len
        attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
        attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
        attention_bias[:batch_size_t2i] = attention_bias_t2i
        logits = self(input_ids, attention_bias=attention_bias).logits
        # logits = self(input_ids).logits
        self.output_size = logits.shape[-1]

        # print(f"logits shape: {logits.shape}") B, 359, vocab_size

        loss_t2i = F.cross_entropy(
            logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
            labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
        )

        return loss_t2i


    @torch.no_grad()
    def mmu_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128, block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """

        if attention_mask is not None and 0.0 in attention_mask:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
            # print(f"attention_bias: {attention_bias}")
        else:
            attention_bias = None
        try:
            device = idx.device
        except:
            device = input_embeddings.device

        result = []
        batch_size = idx.shape[0]
        x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
        x[:, :idx.shape[1]] = idx.clone()
        prompt_index = (x != mask_id)

        assert max_new_tokens % block_length == 0
        num_blocks = max_new_tokens // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks

        # print(f"num_blocks: {num_blocks}, steps: {steps}")
        # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps)
        for num_block in range(num_blocks):
            block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id)
            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
            # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps)
            # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}")
            for i in range(steps):
                mask_index = (x == mask_id)
                if cfg_scale > 0.0:
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    x_ = torch.cat([x, un_x], dim=0)
                    logits = self(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = self(x, attention_bias=attention_bias).logits

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
                if remasking == 'low_confidence':
                    p = F.softmax(logits.to(torch.float64), dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
                elif remasking == 'random':
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                    transfer_index[j, select_index] = True
                x[transfer_index] = x0[transfer_index]

            # Legacy autoregressive sampling path, kept commented out in the source:
            # logits = logits[:, -1, :] / temperature
            # # optionally crop the logits to only the top k options
            # if top_k is not None:
            #     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            #     logits[logits < v[:, [-1]]] = -float('Inf')
            # # apply softmax to convert logits to (normalized) probabilities
            # probs = F.softmax(logits, dim=-1)
            # # sample from the distribution
            # idx_next = torch.multinomial(probs, num_samples=1)
            # result.append(idx_next[0][0])
            # # append sampled index to the running sequence and continue
            # if self.config.w_clip_vit:
            #     idx_next_embeddings = self.mmada.model.embed_tokens(idx_next)
            #     input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            # else:
            #     idx = torch.cat((idx, idx_next), dim=1)

            # if eot_token is not None and idx_next.cpu() == eot_token:
            #     break

        return x

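    # ------------------------------------------------------------------
    # Editorial sketch of the helpers used by `mmu_generate` above
    # (`add_gumbel_noise`, `get_num_transfer_tokens`). They are defined
    # elsewhere in the package; the versions below follow the LLaDA-style
    # reference implementation and are shown only as an approximation.
    #
    # def add_gumbel_noise(logits, temperature):
    #     # temperature == 0 keeps decoding deterministic (plain argmax).
    #     if temperature == 0:
    #         return logits
    #     logits = logits.to(torch.float64)
    #     noise = torch.rand_like(logits, dtype=torch.float64)
    #     gumbel_noise = (-torch.log(noise)) ** temperature
    #     return logits.exp() / gumbel_noise
    #
    # def get_num_transfer_tokens(mask_index, steps):
    #     # Spread the masked positions of each sample evenly over `steps`,
    #     # giving the first `remainder` steps one extra token.
    #     mask_num = mask_index.sum(dim=1, keepdim=True)
    #     base = mask_num // steps
    #     remainder = mask_num % steps
    #     num_transfer_tokens = base.expand(-1, steps).clone()
    #     for i in range(mask_num.size(0)):
    #         num_transfer_tokens[i, : remainder[i]] += 1
    #     return num_transfer_tokens
    # ------------------------------------------------------------------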
    @torch.no_grad()
    def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128, block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """

        if attention_mask is not None and 0.0 in attention_mask:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
            # print(f"attention_bias: {attention_bias}")
        else:
            attention_bias = None
        try:
            device = idx.device
        except:
            device = input_embeddings.device

        result = []
        batch_size = idx.shape[0]
        x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
        x[:, :idx.shape[1]] = idx.clone()
        prompt_index = (x != mask_id)

        assert max_new_tokens % block_length == 0
        num_blocks = max_new_tokens // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks

        for num_block in range(num_blocks):
            block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id)
            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
            for i in range(steps):
                mask_index = (x == mask_id)
                if cfg_scale > 0.0:
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    x_ = torch.cat([x, un_x], dim=0)
                    logits = self(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = self(x, attention_bias=attention_bias).logits

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
                if remasking == 'low_confidence':
                    p = F.softmax(logits.to(torch.float64), dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
                elif remasking == 'random':
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                    transfer_index[j, select_index] = True
                x[transfer_index] = x0[transfer_index]
            if eot_token is not None:
                last_token_index_in_current_block = idx.shape[1] + (num_block + 1) * block_length - 1
                if last_token_index_in_current_block < x.shape[1]:
                    tokens_at_block_end = x[:, last_token_index_in_current_block]
                    if torch.all(tokens_at_block_end == eot_token):
                        break
        return x

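    # ------------------------------------------------------------------
    # Editorial sketch: typical chat-style usage of `mmu_generate_fast`.
    # The prompt handling below is an assumption made for illustration;
    # the real call site presumably lives in app.py.
    #
    # input_ids = uni_prompting.text_tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    # output_ids = model.mmu_generate_fast(
    #     input_ids, max_new_tokens=256, steps=256, block_length=128,
    #     remasking='low_confidence',
    # )
    # answer = uni_prompting.text_tokenizer.decode(
    #     output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    # )
    # ------------------------------------------------------------------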
    @torch.no_grad()
    def t2i_generate_decoding_stepwise(
            self,
            input_ids: torch.LongTensor = None,
            uncond_input_ids: torch.LongTensor = None,
            attention_mask=None,
            uncond_attention_mask=None,
            temperature=1.0,
            timesteps=18,  # ideal number of steps is 18 in maskgit paper
            guidance_scale=0,
            noise_schedule=cosine_schedule,
            generator: torch.Generator = None,
            config=None,
            seq_len=1024,
            mask_token_id=126336,
            resolution=512,
            codebook_size=8192,
            vq_model=None,
            **kwargs,
    ):
        """
        Generate 1:1 similar to the original MaskGit repo
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
        """

        # begin with all image token ids masked
        # Count how many mask tokens there are.
        mask_count = (input_ids == mask_token_id).sum().item()
        num_vq_tokens = seq_len
        num_new_special_tokens = 0
        uni_prompting = kwargs.get("uni_prompting", None)
        # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}")
        input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
        input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens)

        # for classifier-free guidance
        if uncond_input_ids is not None:
            uncond_prefix = uncond_input_ids[:, :resolution + 1]

        for step in range(timesteps):
            if uncond_input_ids is not None and guidance_scale > 0:
                uncond_input_ids = torch.cat(
                    [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
                attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
                logits = self(model_input, attention_bias=attention_bias).logits
                # print(f"logits.shape: {logits.shape}")
                cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
                # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                # it seems that muse has a different cfg setting
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
                logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
            else:
                attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
                logits = self(input_ids, attention_bias=attention_bias).logits
                logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]

            # logits: 1, 1024, 8192
            # print(f"logits.shape: {logits.shape}")
            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}")
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])  # 1, 1024

            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}")
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
            # Decode the current token grid and yield an intermediate preview image.
            current_image_vq_indices = sampled_ids.clone()
            # print(f"current_image_vq_indices: {current_image_vq_indices}")
            current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, 8192 - 1)
            current_image = vq_model.decode_code(current_image_vq_indices)
            images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0)
            images *= 255.0
            images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            pil_images = Image.fromarray(images[0])
            yield pil_images, f"Step {step + 1}/{timesteps}"
            # Defines the mask ratio for the next round. The number of tokens to mask out
            # is determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))
            # Computes the probability of each selected token.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            # Gets the mask length for each sample in the batch according to the mask ratio.
            mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one prediction in this round and also masks out at least
            # one token for the next iteration.
            mask_len = torch.max(
                torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            )
            # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}")
            # Adds noise for randomness.
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
            # Masks tokens with lower confidence.
            input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
                                                                sampled_ids + len(uni_prompting.text_tokenizer)
                                                                + num_new_special_tokens)
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids


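# ------------------------------------------------------------------
# Editorial sketch: because `t2i_generate_decoding_stepwise` is a generator
# that yields a PIL image per decoding step, a driver (for example a Gradio
# app) can stream intermediate previews. The variable names below are
# illustrative assumptions, not the repository's actual call site.
#
# for preview, status in model.t2i_generate_decoding_stepwise(
#     input_ids=input_ids,
#     uncond_input_ids=uncond_input_ids,
#     attention_mask=attention_mask,
#     uncond_attention_mask=uncond_attention_mask,
#     guidance_scale=3.5,
#     timesteps=18,
#     vq_model=vq_model,
#     uni_prompting=uni_prompting,
# ):
#     print(status)            # e.g. "Step 3/18"
#     preview.save("step.png")  # intermediate decode of the image tokens
# ------------------------------------------------------------------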
AutoConfig.register("mmada", MMadaConfig)
AutoModelForCausalLM.register(MMadaConfig, MMadaModelLM)
AutoModel.register(MMadaConfig, MMadaModelLM)
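# ------------------------------------------------------------------
# Editorial note: with the registrations above, checkpoints whose config
# declares `"model_type": "mmada"` resolve to MMadaModelLM through the
# transformers Auto* factories once this module has been imported. The
# checkpoint path below is a placeholder, not a real model id.
#
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained("path/to/mmada-checkpoint", torch_dtype=torch.bfloat16)
# ------------------------------------------------------------------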
models/modeling_utils.py
ADDED
@@ -0,0 +1,1207 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import itertools
import json
import os
import re
from collections import OrderedDict
from functools import partial
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import safetensors
import torch
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn

from diffusers import __version__
from diffusers.utils import (
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_checkpoint_shard_files,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_torch_version,
    logging,
)

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"

from diffusers.utils.hub_utils import (
    PushToHubMixin,
    load_or_create_model_card,
    populate_model_card,
)
from diffusers.models.model_loading_utils import (
    _determine_device_map,
    _fetch_index_file,
    _load_state_dict_into_model,
    load_model_dict_into_meta,
    load_state_dict,
)

from diffusers.configuration_utils import ConfigMixin, register_to_config

logger = logging.get_logger(__name__)

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")


if is_torch_version(">=", "1.9.0"):
    _LOW_CPU_MEM_USAGE_DEFAULT = True
else:
    _LOW_CPU_MEM_USAGE_DEFAULT = False


if is_accelerate_available():
    import accelerate


def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    try:
        params = tuple(parameter.parameters())
        if len(params) > 0:
            return params[0].dtype

        buffers = tuple(parameter.buffers())
        if len(buffers) > 0:
            return buffers[0].dtype

    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


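# ------------------------------------------------------------------
# Editorial sketch: the two helpers above back ModelMixin's `device` and
# `dtype` properties; used standalone they behave roughly like this
# (illustrative example, not part of the original file):
#
# layer = torch.nn.Linear(4, 4)
# get_parameter_device(layer)   # -> device(type='cpu')
# get_parameter_dtype(layer)    # -> torch.float32
# ------------------------------------------------------------------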
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
| 121 |
+
r"""
|
| 122 |
+
Base class for all models.
|
| 123 |
+
|
| 124 |
+
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
| 125 |
+
saving models.
|
| 126 |
+
|
| 127 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
config_name = CONFIG_NAME
|
| 131 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
| 132 |
+
_supports_gradient_checkpointing = False
|
| 133 |
+
_keys_to_ignore_on_load_unexpected = None
|
| 134 |
+
_no_split_modules = None
|
| 135 |
+
|
| 136 |
+
def __init__(self):
|
| 137 |
+
super().__init__()
|
| 138 |
+
|
| 139 |
+
def __getattr__(self, name: str) -> Any:
|
| 140 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
| 141 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
| 142 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
| 143 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
| 147 |
+
is_attribute = name in self.__dict__
|
| 148 |
+
|
| 149 |
+
if is_in_config and not is_attribute:
|
| 150 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
| 151 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
| 152 |
+
return self._internal_dict[name]
|
| 153 |
+
|
| 154 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
| 155 |
+
return super().__getattr__(name)
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def is_gradient_checkpointing(self) -> bool:
|
| 159 |
+
"""
|
| 160 |
+
Whether gradient checkpointing is activated for this model or not.
|
| 161 |
+
"""
|
| 162 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
| 163 |
+
|
| 164 |
+
def enable_gradient_checkpointing(self) -> None:
|
| 165 |
+
"""
|
| 166 |
+
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
| 167 |
+
*checkpoint activations* in other frameworks).
|
| 168 |
+
"""
|
| 169 |
+
if not self._supports_gradient_checkpointing:
|
| 170 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
| 171 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
| 172 |
+
|
| 173 |
+
def disable_gradient_checkpointing(self) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
| 176 |
+
*checkpoint activations* in other frameworks).
|
| 177 |
+
"""
|
| 178 |
+
if self._supports_gradient_checkpointing:
|
| 179 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
| 180 |
+
|
| 181 |
+
def set_use_npu_flash_attention(self, valid: bool) -> None:
|
| 182 |
+
r"""
|
| 183 |
+
Set the switch for the npu flash attention.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
|
| 187 |
+
if hasattr(module, "set_use_npu_flash_attention"):
|
| 188 |
+
module.set_use_npu_flash_attention(valid)
|
| 189 |
+
|
| 190 |
+
for child in module.children():
|
| 191 |
+
fn_recursive_set_npu_flash_attention(child)
|
| 192 |
+
|
| 193 |
+
for module in self.children():
|
| 194 |
+
if isinstance(module, torch.nn.Module):
|
| 195 |
+
fn_recursive_set_npu_flash_attention(module)
|
| 196 |
+
|
| 197 |
+
def enable_npu_flash_attention(self) -> None:
|
| 198 |
+
r"""
|
| 199 |
+
Enable npu flash attention from torch_npu
|
| 200 |
+
|
| 201 |
+
"""
|
| 202 |
+
self.set_use_npu_flash_attention(True)
|
| 203 |
+
|
| 204 |
+
def disable_npu_flash_attention(self) -> None:
|
| 205 |
+
r"""
|
| 206 |
+
disable npu flash attention from torch_npu
|
| 207 |
+
|
| 208 |
+
"""
|
| 209 |
+
self.set_use_npu_flash_attention(False)
|
| 210 |
+
|
| 211 |
+
def set_use_memory_efficient_attention_xformers(
|
| 212 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
| 213 |
+
) -> None:
|
| 214 |
+
# Recursively walk through all the children.
|
| 215 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
| 216 |
+
# gets the message
|
| 217 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
| 218 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
| 219 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
| 220 |
+
|
| 221 |
+
for child in module.children():
|
| 222 |
+
fn_recursive_set_mem_eff(child)
|
| 223 |
+
|
| 224 |
+
for module in self.children():
|
| 225 |
+
if isinstance(module, torch.nn.Module):
|
| 226 |
+
fn_recursive_set_mem_eff(module)
|
| 227 |
+
|
| 228 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
|
| 229 |
+
r"""
|
| 230 |
+
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
| 231 |
+
|
| 232 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
| 233 |
+
inference. Speed up during training is not guaranteed.
|
| 234 |
+
|
| 235 |
+
<Tip warning={true}>
|
| 236 |
+
|
| 237 |
+
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
| 238 |
+
precedent.
|
| 239 |
+
|
| 240 |
+
</Tip>
|
| 241 |
+
|
| 242 |
+
Parameters:
|
| 243 |
+
attention_op (`Callable`, *optional*):
|
| 244 |
+
Override the default `None` operator for use as `op` argument to the
|
| 245 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
| 246 |
+
function of xFormers.
|
| 247 |
+
|
| 248 |
+
Examples:
|
| 249 |
+
|
| 250 |
+
```py
|
| 251 |
+
>>> import torch
|
| 252 |
+
>>> from diffusers import UNet2DConditionModel
|
| 253 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
| 254 |
+
|
| 255 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
| 256 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
| 257 |
+
... )
|
| 258 |
+
>>> model = model.to("cuda")
|
| 259 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
| 260 |
+
```
|
| 261 |
+
"""
|
| 262 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
| 263 |
+
|
| 264 |
+
def disable_xformers_memory_efficient_attention(self) -> None:
|
| 265 |
+
r"""
|
| 266 |
+
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
| 267 |
+
"""
|
| 268 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
| 269 |
+
|
| 270 |
+
def save_pretrained(
|
| 271 |
+
self,
|
| 272 |
+
save_directory: Union[str, os.PathLike],
|
| 273 |
+
is_main_process: bool = True,
|
| 274 |
+
save_function: Optional[Callable] = None,
|
| 275 |
+
safe_serialization: bool = True,
|
| 276 |
+
variant: Optional[str] = None,
|
| 277 |
+
max_shard_size: Union[int, str] = "10GB",
|
| 278 |
+
push_to_hub: bool = False,
|
| 279 |
+
**kwargs,
|
| 280 |
+
):
|
| 281 |
+
"""
|
| 282 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
| 283 |
+
[`~models.ModelMixin.from_pretrained`] class method.
|
| 284 |
+
|
| 285 |
+
Arguments:
|
| 286 |
+
save_directory (`str` or `os.PathLike`):
|
| 287 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
| 288 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
| 289 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
| 290 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
| 291 |
+
process to avoid race conditions.
|
| 292 |
+
save_function (`Callable`):
|
| 293 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
| 294 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
| 295 |
+
`DIFFUSERS_SAVE_MODE`.
|
| 296 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
| 297 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
| 298 |
+
variant (`str`, *optional*):
|
| 299 |
+
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
| 300 |
+
max_shard_size (`int` or `str`, defaults to `"10GB"`):
|
| 301 |
+
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
| 302 |
+
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
| 303 |
+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
|
| 304 |
+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
|
| 305 |
+
This is to establish a common default size for this argument across different libraries in the Hugging
|
| 306 |
+
Face ecosystem (`transformers`, and `accelerate`, for example).
|
| 307 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 308 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
| 309 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 310 |
+
namespace).
|
| 311 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 312 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 313 |
+
"""
|
| 314 |
+
if os.path.isfile(save_directory):
|
| 315 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 316 |
+
return
|
| 317 |
+
|
| 318 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
| 319 |
+
weights_name = _add_variant(weights_name, variant)
|
| 320 |
+
weight_name_split = weights_name.split(".")
|
| 321 |
+
if len(weight_name_split) in [2, 3]:
|
| 322 |
+
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
|
| 323 |
+
else:
|
| 324 |
+
raise ValueError(f"Invalid {weights_name} provided.")
|
| 325 |
+
|
| 326 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 327 |
+
|
| 328 |
+
if push_to_hub:
|
| 329 |
+
commit_message = kwargs.pop("commit_message", None)
|
| 330 |
+
private = kwargs.pop("private", False)
|
| 331 |
+
create_pr = kwargs.pop("create_pr", False)
|
| 332 |
+
token = kwargs.pop("token", None)
|
| 333 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
| 334 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
| 335 |
+
|
| 336 |
+
# Only save the model itself if we are using distributed training
|
| 337 |
+
model_to_save = self
|
| 338 |
+
|
| 339 |
+
# Attach architecture to the config
|
| 340 |
+
# Save the config
|
| 341 |
+
if is_main_process:
|
| 342 |
+
model_to_save.save_config(save_directory)
|
| 343 |
+
|
| 344 |
+
# Save the model
|
| 345 |
+
state_dict = model_to_save.state_dict()
|
| 346 |
+
|
| 347 |
+
# Save the model
|
| 348 |
+
state_dict_split = split_torch_state_dict_into_shards(
|
| 349 |
+
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Clean the folder from a previous save
|
| 353 |
+
if is_main_process:
|
| 354 |
+
for filename in os.listdir(save_directory):
|
| 355 |
+
if filename in state_dict_split.filename_to_tensors.keys():
|
| 356 |
+
continue
|
| 357 |
+
full_filename = os.path.join(save_directory, filename)
|
| 358 |
+
if not os.path.isfile(full_filename):
|
| 359 |
+
continue
|
| 360 |
+
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
| 361 |
+
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
| 362 |
+
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
| 363 |
+
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
| 364 |
+
if (
|
| 365 |
+
filename.startswith(weights_without_ext)
|
| 366 |
+
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
| 367 |
+
):
|
| 368 |
+
os.remove(full_filename)
|
| 369 |
+
|
| 370 |
+
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
| 371 |
+
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
| 372 |
+
filepath = os.path.join(save_directory, filename)
|
| 373 |
+
if safe_serialization:
|
| 374 |
+
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
| 375 |
+
# joyfulness), but for now this enough.
|
| 376 |
+
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
| 377 |
+
else:
|
| 378 |
+
torch.save(shard, filepath)
|
| 379 |
+
|
| 380 |
+
if state_dict_split.is_sharded:
|
| 381 |
+
index = {
|
| 382 |
+
"metadata": state_dict_split.metadata,
|
| 383 |
+
"weight_map": state_dict_split.tensor_to_filename,
|
| 384 |
+
}
|
| 385 |
+
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
| 386 |
+
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
| 387 |
+
# Save the index as well
|
| 388 |
+
with open(save_index_file, "w", encoding="utf-8") as f:
|
| 389 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
| 390 |
+
f.write(content)
|
| 391 |
+
logger.info(
|
| 392 |
+
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
| 393 |
+
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
| 394 |
+
f"index located at {save_index_file}."
|
| 395 |
+
)
|
| 396 |
+
else:
|
| 397 |
+
path_to_weights = os.path.join(save_directory, weights_name)
|
| 398 |
+
logger.info(f"Model weights saved in {path_to_weights}")
|
| 399 |
+
|
| 400 |
+
if push_to_hub:
|
| 401 |
+
# Create a new empty model card and eventually tag it
|
| 402 |
+
model_card = load_or_create_model_card(repo_id, token=token)
|
| 403 |
+
model_card = populate_model_card(model_card)
|
| 404 |
+
model_card.save(Path(save_directory, "README.md").as_posix())
|
| 405 |
+
|
| 406 |
+
self._upload_folder(
|
| 407 |
+
save_directory,
|
| 408 |
+
repo_id,
|
| 409 |
+
token=token,
|
| 410 |
+
commit_message=commit_message,
|
| 411 |
+
create_pr=create_pr,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
@classmethod
|
| 415 |
+
@validate_hf_hub_args
|
| 416 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
| 417 |
+
r"""
|
| 418 |
+
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
| 419 |
+
|
| 420 |
+
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
| 421 |
+
train the model, set it back in training mode with `model.train()`.
|
| 422 |
+
|
| 423 |
+
Parameters:
|
| 424 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
| 425 |
+
Can be either:
|
| 426 |
+
|
| 427 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
| 428 |
+
the Hub.
|
| 429 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
| 430 |
+
with [`~ModelMixin.save_pretrained`].
|
| 431 |
+
|
| 432 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
| 433 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
| 434 |
+
is not used.
|
| 435 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
| 436 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
| 437 |
+
dtype is automatically derived from the model's weights.
|
| 438 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 439 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 440 |
+
cached versions if they exist.
|
| 441 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 442 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
| 443 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 444 |
+
output_loading_info (`bool`, *optional*, defaults to `False`):
|
| 445 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 446 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 447 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
| 448 |
+
won't be downloaded from the Hub.
|
| 449 |
+
token (`str` or *bool*, *optional*):
|
| 450 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
| 451 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
| 452 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 453 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
| 454 |
+
allowed by Git.
|
| 455 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
| 456 |
+
Load the model weights from a Flax checkpoint save file.
|
| 457 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
| 458 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
| 459 |
+
mirror (`str`, *optional*):
|
| 460 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
| 461 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
| 462 |
+
information.
|
| 463 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
| 464 |
+
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
| 465 |
+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
| 466 |
+
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
| 467 |
+
|
| 468 |
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
| 469 |
+
more information about each option see [designing a device
|
| 470 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
| 471 |
+
max_memory (`Dict`, *optional*):
|
| 472 |
+
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
| 473 |
+
each GPU and the available CPU RAM if unset.
|
| 474 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
| 475 |
+
The path to offload weights if `device_map` contains the value `"disk"`.
|
| 476 |
+
offload_state_dict (`bool`, *optional*):
|
| 477 |
+
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
| 478 |
+
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
| 479 |
+
when there is some disk offload.
|
| 480 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
| 481 |
+
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
| 482 |
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
| 483 |
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
| 484 |
+
argument to `True` will raise an error.
|
| 485 |
+
variant (`str`, *optional*):
|
| 486 |
+
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
| 487 |
+
loading `from_flax`.
|
| 488 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
| 489 |
+
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
| 490 |
+
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
| 491 |
+
weights. If set to `False`, `safetensors` weights are not loaded.
|
| 492 |
+
|
| 493 |
+
<Tip>
|
| 494 |
+
|
| 495 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
| 496 |
+
`huggingface-cli login`. You can also activate the special
|
| 497 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
| 498 |
+
firewalled environment.
|
| 499 |
+
|
| 500 |
+
</Tip>
|
| 501 |
+
|
| 502 |
+
Example:
|
| 503 |
+
|
| 504 |
+
```py
|
| 505 |
+
from diffusers import UNet2DConditionModel
|
| 506 |
+
|
| 507 |
+
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
| 508 |
+
```
|
| 509 |
+
|
| 510 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
| 511 |
+
|
| 512 |
+
```bash
|
| 513 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
| 514 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
| 515 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
| 516 |
+
```
|
| 517 |
+
"""
|
| 518 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 519 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
| 520 |
+
force_download = kwargs.pop("force_download", False)
|
| 521 |
+
from_flax = kwargs.pop("from_flax", False)
|
| 522 |
+
proxies = kwargs.pop("proxies", None)
|
| 523 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
| 524 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
| 525 |
+
token = kwargs.pop("token", None)
|
| 526 |
+
revision = kwargs.pop("revision", None)
|
| 527 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
| 528 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 529 |
+
device_map = kwargs.pop("device_map", None)
|
| 530 |
+
max_memory = kwargs.pop("max_memory", None)
|
| 531 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
| 532 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
| 533 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
| 534 |
+
variant = kwargs.pop("variant", None)
|
| 535 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
| 536 |
+
|
| 537 |
+
allow_pickle = False
|
| 538 |
+
if use_safetensors is None:
|
| 539 |
+
use_safetensors = True
|
| 540 |
+
allow_pickle = True
|
| 541 |
+
|
| 542 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
| 543 |
+
low_cpu_mem_usage = False
|
| 544 |
+
logger.warning(
|
| 545 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
| 546 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
| 547 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
| 548 |
+
" install accelerate\n```\n."
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
if device_map is not None and not is_accelerate_available():
|
| 552 |
+
raise NotImplementedError(
|
| 553 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
| 554 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# Check if we can handle device_map and dispatching the weights
|
| 558 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
| 559 |
+
raise NotImplementedError(
|
| 560 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
| 561 |
+
" `device_map=None`."
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
| 565 |
+
raise NotImplementedError(
|
| 566 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
| 567 |
+
" `low_cpu_mem_usage=False`."
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
| 571 |
+
raise ValueError(
|
| 572 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
| 573 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# change device_map into a map if we passed an int, a str or a torch.device
|
| 577 |
+
if isinstance(device_map, torch.device):
|
| 578 |
+
device_map = {"": device_map}
|
| 579 |
+
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
| 580 |
+
try:
|
| 581 |
+
device_map = {"": torch.device(device_map)}
|
| 582 |
+
except RuntimeError:
|
| 583 |
+
raise ValueError(
|
| 584 |
+
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
|
| 585 |
+
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
|
| 586 |
+
)
|
| 587 |
+
elif isinstance(device_map, int):
|
| 588 |
+
if device_map < 0:
|
| 589 |
+
raise ValueError(
|
| 590 |
+
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
|
| 591 |
+
)
|
| 592 |
+
else:
|
| 593 |
+
device_map = {"": device_map}
|
| 594 |
+
|
| 595 |
+
if device_map is not None:
|
| 596 |
+
if low_cpu_mem_usage is None:
|
| 597 |
+
low_cpu_mem_usage = True
|
| 598 |
+
elif not low_cpu_mem_usage:
|
| 599 |
+
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
|
| 600 |
+
|
| 601 |
+
if low_cpu_mem_usage:
|
| 602 |
+
if device_map is not None and not is_torch_version(">=", "1.10"):
|
| 603 |
+
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
|
| 604 |
+
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
|
| 605 |
+
|
| 606 |
+
# Load config if we don't provide a configuration
|
| 607 |
+
config_path = pretrained_model_name_or_path
|
| 608 |
+
|
| 609 |
+
user_agent = {
|
| 610 |
+
"diffusers": __version__,
|
| 611 |
+
"file_type": "model",
|
| 612 |
+
"framework": "pytorch",
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
# load config
|
| 616 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
| 617 |
+
config_path,
|
| 618 |
+
cache_dir=cache_dir,
|
| 619 |
+
return_unused_kwargs=True,
|
| 620 |
+
return_commit_hash=True,
|
| 621 |
+
force_download=force_download,
|
| 622 |
+
proxies=proxies,
|
| 623 |
+
local_files_only=local_files_only,
|
| 624 |
+
token=token,
|
| 625 |
+
revision=revision,
|
| 626 |
+
subfolder=subfolder,
|
| 627 |
+
user_agent=user_agent,
|
| 628 |
+
**kwargs,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Determine if we're loading from a directory of sharded checkpoints.
|
| 632 |
+
is_sharded = False
|
| 633 |
+
index_file = None
|
| 634 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
| 635 |
+
index_file = _fetch_index_file(
|
| 636 |
+
is_local=is_local,
|
| 637 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 638 |
+
subfolder=subfolder or "",
|
| 639 |
+
use_safetensors=use_safetensors,
|
| 640 |
+
cache_dir=cache_dir,
|
| 641 |
+
variant=variant,
|
| 642 |
+
force_download=force_download,
|
| 643 |
+
proxies=proxies,
|
| 644 |
+
local_files_only=local_files_only,
|
| 645 |
+
token=token,
|
| 646 |
+
revision=revision,
|
| 647 |
+
user_agent=user_agent,
|
| 648 |
+
commit_hash=commit_hash,
|
| 649 |
+
)
|
| 650 |
+
if index_file is not None and index_file.is_file():
|
| 651 |
+
is_sharded = True
|
| 652 |
+
|
| 653 |
+
if is_sharded and from_flax:
|
| 654 |
+
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
|
| 655 |
+
|
| 656 |
+
# load model
|
| 657 |
+
model_file = None
|
| 658 |
+
if from_flax:
|
| 659 |
+
model_file = _get_model_file(
|
| 660 |
+
pretrained_model_name_or_path,
|
| 661 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
| 662 |
+
cache_dir=cache_dir,
|
| 663 |
+
force_download=force_download,
|
| 664 |
+
proxies=proxies,
|
| 665 |
+
local_files_only=local_files_only,
|
| 666 |
+
token=token,
|
| 667 |
+
revision=revision,
|
| 668 |
+
subfolder=subfolder,
|
| 669 |
+
user_agent=user_agent,
|
| 670 |
+
commit_hash=commit_hash,
|
| 671 |
+
)
|
| 672 |
+
model = cls.from_config(config, **unused_kwargs)
|
| 673 |
+
|
| 674 |
+
# Convert the weights
|
| 675 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
| 676 |
+
|
| 677 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
| 678 |
+
else:
|
| 679 |
+
if is_sharded:
|
| 680 |
+
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
|
| 681 |
+
pretrained_model_name_or_path,
|
| 682 |
+
index_file,
|
| 683 |
+
cache_dir=cache_dir,
|
| 684 |
+
proxies=proxies,
|
| 685 |
+
local_files_only=local_files_only,
|
| 686 |
+
token=token,
|
| 687 |
+
user_agent=user_agent,
|
| 688 |
+
revision=revision,
|
| 689 |
+
subfolder=subfolder or "",
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
elif use_safetensors and not is_sharded:
|
| 693 |
+
try:
|
| 694 |
+
model_file = _get_model_file(
|
| 695 |
+
pretrained_model_name_or_path,
|
| 696 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
| 697 |
+
cache_dir=cache_dir,
|
| 698 |
+
force_download=force_download,
|
| 699 |
+
proxies=proxies,
|
| 700 |
+
local_files_only=local_files_only,
|
| 701 |
+
token=token,
|
| 702 |
+
revision=revision,
|
| 703 |
+
subfolder=subfolder,
|
| 704 |
+
user_agent=user_agent,
|
| 705 |
+
commit_hash=commit_hash,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
except IOError as e:
|
| 709 |
+
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
| 710 |
+
if not allow_pickle:
|
| 711 |
+
raise
|
| 712 |
+
logger.warning(
|
| 713 |
+
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
if model_file is None and not is_sharded:
|
| 717 |
+
model_file = _get_model_file(
|
| 718 |
+
pretrained_model_name_or_path,
|
| 719 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
| 720 |
+
cache_dir=cache_dir,
|
| 721 |
+
force_download=force_download,
|
| 722 |
+
proxies=proxies,
|
| 723 |
+
local_files_only=local_files_only,
|
| 724 |
+
token=token,
|
| 725 |
+
revision=revision,
|
| 726 |
+
subfolder=subfolder,
|
| 727 |
+
user_agent=user_agent,
|
| 728 |
+
commit_hash=commit_hash,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
if low_cpu_mem_usage:
|
| 732 |
+
# Instantiate model with empty weights
|
| 733 |
+
with accelerate.init_empty_weights():
|
| 734 |
+
model = cls.from_config(config, **unused_kwargs)
|
| 735 |
+
|
| 736 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
| 737 |
+
if device_map is None and not is_sharded:
|
| 738 |
+
param_device = "cpu"
|
| 739 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
| 740 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
| 741 |
+
# move the params from meta device to cpu
|
| 742 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
| 743 |
+
if len(missing_keys) > 0:
|
| 744 |
+
raise ValueError(
|
| 745 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
| 746 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
| 747 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
| 748 |
+
" those weights or else make sure your checkpoint file is correct."
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 752 |
+
model,
|
| 753 |
+
state_dict,
|
| 754 |
+
device=param_device,
|
| 755 |
+
dtype=torch_dtype,
|
| 756 |
+
model_name_or_path=pretrained_model_name_or_path,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 760 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 761 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 762 |
+
|
| 763 |
+
if len(unexpected_keys) > 0:
|
| 764 |
+
logger.warning(
|
| 765 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
else: # else let accelerate handle loading and dispatching.
|
| 769 |
+
# Load weights and dispatch according to the device_map
|
| 770 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
| 771 |
+
force_hook = True
|
| 772 |
+
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
|
| 773 |
+
if device_map is None and is_sharded:
|
| 774 |
+
# we load the parameters on the cpu
|
| 775 |
+
device_map = {"": "cpu"}
|
| 776 |
+
force_hook = False
|
| 777 |
+
try:
|
| 778 |
+
accelerate.load_checkpoint_and_dispatch(
|
| 779 |
+
model,
|
| 780 |
+
model_file if not is_sharded else index_file,
|
| 781 |
+
device_map,
|
| 782 |
+
max_memory=max_memory,
|
| 783 |
+
offload_folder=offload_folder,
|
| 784 |
+
offload_state_dict=offload_state_dict,
|
| 785 |
+
dtype=torch_dtype,
|
| 786 |
+
force_hooks=force_hook,
|
| 787 |
+
strict=True,
|
| 788 |
+
)
|
| 789 |
+
except AttributeError as e:
|
| 790 |
+
# When using accelerate loading, we do not have the ability to load the state
|
| 791 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
| 792 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
| 793 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
| 794 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
| 795 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
| 796 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
| 797 |
+
# the weights so we don't have to do this again.
|
| 798 |
+
|
| 799 |
+
if "'Attention' object has no attribute" in str(e):
|
| 800 |
+
logger.warning(
|
| 801 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
| 802 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
| 803 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
| 804 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
| 805 |
+
" please also re-upload it or open a PR on the original repository."
|
| 806 |
+
)
|
| 807 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
| 808 |
+
accelerate.load_checkpoint_and_dispatch(
|
| 809 |
+
model,
|
| 810 |
+
model_file if not is_sharded else index_file,
|
| 811 |
+
device_map,
|
| 812 |
+
max_memory=max_memory,
|
| 813 |
+
offload_folder=offload_folder,
|
| 814 |
+
offload_state_dict=offload_state_dict,
|
| 815 |
+
dtype=torch_dtype,
|
| 816 |
+
force_hooks=force_hook,
|
| 817 |
+
strict=True,
|
| 818 |
+
)
|
| 819 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
| 820 |
+
else:
|
| 821 |
+
raise e
|
| 822 |
+
|
| 823 |
+
loading_info = {
|
| 824 |
+
"missing_keys": [],
|
| 825 |
+
"unexpected_keys": [],
|
| 826 |
+
"mismatched_keys": [],
|
| 827 |
+
"error_msgs": [],
|
| 828 |
+
}
|
| 829 |
+
else:
|
| 830 |
+
model = cls.from_config(config, **unused_kwargs)
|
| 831 |
+
|
| 832 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
| 833 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
| 834 |
+
|
| 835 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
| 836 |
+
model,
|
| 837 |
+
state_dict,
|
| 838 |
+
model_file,
|
| 839 |
+
pretrained_model_name_or_path,
|
| 840 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
loading_info = {
|
| 844 |
+
"missing_keys": missing_keys,
|
| 845 |
+
"unexpected_keys": unexpected_keys,
|
| 846 |
+
"mismatched_keys": mismatched_keys,
|
| 847 |
+
"error_msgs": error_msgs,
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
| 851 |
+
raise ValueError(
|
| 852 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
| 853 |
+
)
|
| 854 |
+
elif torch_dtype is not None:
|
| 855 |
+
model = model.to(torch_dtype)
|
| 856 |
+
|
| 857 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
| 858 |
+
|
| 859 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
| 860 |
+
model.eval()
|
| 861 |
+
if output_loading_info:
|
| 862 |
+
return model, loading_info
|
| 863 |
+
|
| 864 |
+
return model
|
| 865 |
+
|
| 866 |
+
@classmethod
|
| 867 |
+
def _load_pretrained_model(
|
| 868 |
+
cls,
|
| 869 |
+
model,
|
| 870 |
+
state_dict: OrderedDict,
|
| 871 |
+
resolved_archive_file,
|
| 872 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
| 873 |
+
ignore_mismatched_sizes: bool = False,
|
| 874 |
+
):
|
| 875 |
+
# Retrieve missing & unexpected_keys
|
| 876 |
+
model_state_dict = model.state_dict()
|
| 877 |
+
loaded_keys = list(state_dict.keys())
|
| 878 |
+
|
| 879 |
+
expected_keys = list(model_state_dict.keys())
|
| 880 |
+
|
| 881 |
+
original_loaded_keys = loaded_keys
|
| 882 |
+
|
| 883 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
| 884 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
| 885 |
+
|
| 886 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
| 887 |
+
model_to_load = model
|
| 888 |
+
|
| 889 |
+
def _find_mismatched_keys(
|
| 890 |
+
state_dict,
|
| 891 |
+
model_state_dict,
|
| 892 |
+
loaded_keys,
|
| 893 |
+
ignore_mismatched_sizes,
|
| 894 |
+
):
|
| 895 |
+
mismatched_keys = []
|
| 896 |
+
if ignore_mismatched_sizes:
|
| 897 |
+
for checkpoint_key in loaded_keys:
|
| 898 |
+
model_key = checkpoint_key
|
| 899 |
+
|
| 900 |
+
if (
|
| 901 |
+
model_key in model_state_dict
|
| 902 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
| 903 |
+
):
|
| 904 |
+
mismatched_keys.append(
|
| 905 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
| 906 |
+
)
|
| 907 |
+
del state_dict[checkpoint_key]
|
| 908 |
+
return mismatched_keys
|
| 909 |
+
|
| 910 |
+
if state_dict is not None:
|
| 911 |
+
# Whole checkpoint
|
| 912 |
+
mismatched_keys = _find_mismatched_keys(
|
| 913 |
+
state_dict,
|
| 914 |
+
model_state_dict,
|
| 915 |
+
original_loaded_keys,
|
| 916 |
+
ignore_mismatched_sizes,
|
| 917 |
+
)
|
| 918 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
| 919 |
+
|
| 920 |
+
if len(error_msgs) > 0:
|
| 921 |
+
error_msg = "\n\t".join(error_msgs)
|
| 922 |
+
if "size mismatch" in error_msg:
|
| 923 |
+
error_msg += (
|
| 924 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
| 925 |
+
)
|
| 926 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
| 927 |
+
|
| 928 |
+
if len(unexpected_keys) > 0:
|
| 929 |
+
logger.warning(
|
| 930 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
| 931 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
| 932 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
| 933 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
| 934 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
| 935 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
| 936 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
| 937 |
+
" BertForSequenceClassification model)."
|
| 938 |
+
)
|
| 939 |
+
else:
|
| 940 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
| 941 |
+
if len(missing_keys) > 0:
|
| 942 |
+
logger.warning(
|
| 943 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
| 944 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
| 945 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
| 946 |
+
)
|
| 947 |
+
elif len(mismatched_keys) == 0:
|
| 948 |
+
logger.info(
|
| 949 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
| 950 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
| 951 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
| 952 |
+
" without further training."
|
| 953 |
+
)
|
| 954 |
+
if len(mismatched_keys) > 0:
|
| 955 |
+
mismatched_warning = "\n".join(
|
| 956 |
+
[
|
| 957 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
| 958 |
+
for key, shape1, shape2 in mismatched_keys
|
| 959 |
+
]
|
| 960 |
+
)
|
| 961 |
+
logger.warning(
|
| 962 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
| 963 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
| 964 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
| 965 |
+
" able to use it for predictions and inference."
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
| 969 |
+
|
| 970 |
+
@classmethod
|
| 971 |
+
def _get_signature_keys(cls, obj):
|
| 972 |
+
parameters = inspect.signature(obj.__init__).parameters
|
| 973 |
+
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
| 974 |
+
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
| 975 |
+
expected_modules = set(required_parameters.keys()) - {"self"}
|
| 976 |
+
|
| 977 |
+
return expected_modules, optional_parameters
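For orientation, a toy illustration of what `_get_signature_keys` computes; the `Toy` class is hypothetical and only mirrors the `inspect` logic above:

```py
import inspect

class Toy:
    def __init__(self, in_channels, out_channels=4):
        self.in_channels = in_channels
        self.out_channels = out_channels

params = inspect.signature(Toy.__init__).parameters
required = {k for k, v in params.items() if v.default is inspect._empty} - {"self"}
optional = {k for k, v in params.items() if v.default is not inspect._empty}
print(required, optional)  # {'in_channels'} {'out_channels'}
```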
|
| 978 |
+
|
| 979 |
+
# Adapted from `transformers` modeling_utils.py
|
| 980 |
+
def _get_no_split_modules(self, device_map: str):
|
| 981 |
+
"""
|
| 982 |
+
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
|
| 983 |
+
get the underlying `_no_split_modules`.
|
| 984 |
+
|
| 985 |
+
Args:
|
| 986 |
+
device_map (`str`):
|
| 987 |
+
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
`List[str]`: List of modules that should not be split
|
| 991 |
+
"""
|
| 992 |
+
_no_split_modules = set()
|
| 993 |
+
modules_to_check = [self]
|
| 994 |
+
while len(modules_to_check) > 0:
|
| 995 |
+
module = modules_to_check.pop(-1)
|
| 996 |
+
# if the module does not appear in _no_split_modules, we also check the children
|
| 997 |
+
if module.__class__.__name__ not in _no_split_modules:
|
| 998 |
+
if isinstance(module, ModelMixin):
|
| 999 |
+
if module._no_split_modules is None:
|
| 1000 |
+
raise ValueError(
|
| 1001 |
+
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
|
| 1002 |
+
"class needs to implement the `_no_split_modules` attribute."
|
| 1003 |
+
)
|
| 1004 |
+
else:
|
| 1005 |
+
_no_split_modules = _no_split_modules | set(module._no_split_modules)
|
| 1006 |
+
modules_to_check += list(module.children())
|
| 1007 |
+
return list(_no_split_modules)
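A model class opts in to `device_map` dispatch by declaring `_no_split_modules`; a minimal sketch with hypothetical class names (not part of this commit):

```py
class MyDiffusionBackbone(ModelMixin):  # ModelMixin as defined in this file
    # Blocks listed here are kept on a single device when weights are dispatched.
    _no_split_modules = ["MyTransformerBlock"]

    def __init__(self):
        super().__init__()
        ...
```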
|
| 1008 |
+
|
| 1009 |
+
@property
|
| 1010 |
+
def device(self) -> torch.device:
|
| 1011 |
+
"""
|
| 1012 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
| 1013 |
+
device).
|
| 1014 |
+
"""
|
| 1015 |
+
return get_parameter_device(self)
|
| 1016 |
+
|
| 1017 |
+
@property
|
| 1018 |
+
def dtype(self) -> torch.dtype:
|
| 1019 |
+
"""
|
| 1020 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
| 1021 |
+
"""
|
| 1022 |
+
return get_parameter_dtype(self)
|
| 1023 |
+
|
| 1024 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
| 1025 |
+
"""
|
| 1026 |
+
Get number of (trainable or non-embedding) parameters in the module.
|
| 1027 |
+
|
| 1028 |
+
Args:
|
| 1029 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
| 1030 |
+
Whether or not to return only the number of trainable parameters.
|
| 1031 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
| 1032 |
+
Whether or not to return only the number of non-embedding parameters.
|
| 1033 |
+
|
| 1034 |
+
Returns:
|
| 1035 |
+
`int`: The number of parameters.
|
| 1036 |
+
|
| 1037 |
+
Example:
|
| 1038 |
+
|
| 1039 |
+
```py
|
| 1040 |
+
from diffusers import UNet2DConditionModel
|
| 1041 |
+
|
| 1042 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
| 1043 |
+
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
| 1044 |
+
unet.num_parameters(only_trainable=True)
|
| 1045 |
+
859520964
|
| 1046 |
+
```
|
| 1047 |
+
"""
|
| 1048 |
+
|
| 1049 |
+
if exclude_embeddings:
|
| 1050 |
+
embedding_param_names = [
|
| 1051 |
+
f"{name}.weight"
|
| 1052 |
+
for name, module_type in self.named_modules()
|
| 1053 |
+
if isinstance(module_type, torch.nn.Embedding)
|
| 1054 |
+
]
|
| 1055 |
+
non_embedding_parameters = [
|
| 1056 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
| 1057 |
+
]
|
| 1058 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
| 1059 |
+
else:
|
| 1060 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
| 1061 |
+
|
| 1062 |
+
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
| 1063 |
+
deprecated_attention_block_paths = []
|
| 1064 |
+
|
| 1065 |
+
def recursive_find_attn_block(name, module):
|
| 1066 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
| 1067 |
+
deprecated_attention_block_paths.append(name)
|
| 1068 |
+
|
| 1069 |
+
for sub_name, sub_module in module.named_children():
|
| 1070 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
| 1071 |
+
recursive_find_attn_block(sub_name, sub_module)
|
| 1072 |
+
|
| 1073 |
+
recursive_find_attn_block("", self)
|
| 1074 |
+
|
| 1075 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
| 1076 |
+
# because it is possible we are loading from a state dict that was already
|
| 1077 |
+
# converted
|
| 1078 |
+
|
| 1079 |
+
for path in deprecated_attention_block_paths:
|
| 1080 |
+
# group_norm path stays the same
|
| 1081 |
+
|
| 1082 |
+
# query -> to_q
|
| 1083 |
+
if f"{path}.query.weight" in state_dict:
|
| 1084 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
| 1085 |
+
if f"{path}.query.bias" in state_dict:
|
| 1086 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
| 1087 |
+
|
| 1088 |
+
# key -> to_k
|
| 1089 |
+
if f"{path}.key.weight" in state_dict:
|
| 1090 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
| 1091 |
+
if f"{path}.key.bias" in state_dict:
|
| 1092 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
| 1093 |
+
|
| 1094 |
+
# value -> to_v
|
| 1095 |
+
if f"{path}.value.weight" in state_dict:
|
| 1096 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
| 1097 |
+
if f"{path}.value.bias" in state_dict:
|
| 1098 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
| 1099 |
+
|
| 1100 |
+
# proj_attn -> to_out.0
|
| 1101 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
| 1102 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
| 1103 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
| 1104 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
| 1105 |
+
|
| 1106 |
+
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
| 1107 |
+
deprecated_attention_block_modules = []
|
| 1108 |
+
|
| 1109 |
+
def recursive_find_attn_block(module):
|
| 1110 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
| 1111 |
+
deprecated_attention_block_modules.append(module)
|
| 1112 |
+
|
| 1113 |
+
for sub_module in module.children():
|
| 1114 |
+
recursive_find_attn_block(sub_module)
|
| 1115 |
+
|
| 1116 |
+
recursive_find_attn_block(self)
|
| 1117 |
+
|
| 1118 |
+
for module in deprecated_attention_block_modules:
|
| 1119 |
+
module.query = module.to_q
|
| 1120 |
+
module.key = module.to_k
|
| 1121 |
+
module.value = module.to_v
|
| 1122 |
+
module.proj_attn = module.to_out[0]
|
| 1123 |
+
|
| 1124 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
| 1125 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
| 1126 |
+
# making an incorrect assumption that this model should be converted when
|
| 1127 |
+
# it really shouldn't be.
|
| 1128 |
+
del module.to_q
|
| 1129 |
+
del module.to_k
|
| 1130 |
+
del module.to_v
|
| 1131 |
+
del module.to_out
|
| 1132 |
+
|
| 1133 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
| 1134 |
+
deprecated_attention_block_modules = []
|
| 1135 |
+
|
| 1136 |
+
def recursive_find_attn_block(module) -> None:
|
| 1137 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
| 1138 |
+
deprecated_attention_block_modules.append(module)
|
| 1139 |
+
|
| 1140 |
+
for sub_module in module.children():
|
| 1141 |
+
recursive_find_attn_block(sub_module)
|
| 1142 |
+
|
| 1143 |
+
recursive_find_attn_block(self)
|
| 1144 |
+
|
| 1145 |
+
for module in deprecated_attention_block_modules:
|
| 1146 |
+
module.to_q = module.query
|
| 1147 |
+
module.to_k = module.key
|
| 1148 |
+
module.to_v = module.value
|
| 1149 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
| 1150 |
+
|
| 1151 |
+
del module.query
|
| 1152 |
+
del module.key
|
| 1153 |
+
del module.value
|
| 1154 |
+
del module.proj_attn
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
class LegacyModelMixin(ModelMixin):
|
| 1158 |
+
r"""
|
| 1159 |
+
A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
|
| 1160 |
+
pipeline-specific classes (like `DiTTransformer2DModel`).
|
| 1161 |
+
"""
|
| 1162 |
+
|
| 1163 |
+
@classmethod
|
| 1164 |
+
@validate_hf_hub_args
|
| 1165 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
| 1166 |
+
# To prevent dependency import problem.
|
| 1167 |
+
from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
|
| 1168 |
+
|
| 1169 |
+
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
| 1170 |
+
kwargs_copy = kwargs.copy()
|
| 1171 |
+
|
| 1172 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 1173 |
+
force_download = kwargs.pop("force_download", False)
|
| 1174 |
+
proxies = kwargs.pop("proxies", None)
|
| 1175 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
| 1176 |
+
token = kwargs.pop("token", None)
|
| 1177 |
+
revision = kwargs.pop("revision", None)
|
| 1178 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 1179 |
+
|
| 1180 |
+
# Load config if we don't provide a configuration
|
| 1181 |
+
config_path = pretrained_model_name_or_path
|
| 1182 |
+
|
| 1183 |
+
user_agent = {
|
| 1184 |
+
"diffusers": __version__,
|
| 1185 |
+
"file_type": "model",
|
| 1186 |
+
"framework": "pytorch",
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
# load config
|
| 1190 |
+
config, _, _ = cls.load_config(
|
| 1191 |
+
config_path,
|
| 1192 |
+
cache_dir=cache_dir,
|
| 1193 |
+
return_unused_kwargs=True,
|
| 1194 |
+
return_commit_hash=True,
|
| 1195 |
+
force_download=force_download,
|
| 1196 |
+
proxies=proxies,
|
| 1197 |
+
local_files_only=local_files_only,
|
| 1198 |
+
token=token,
|
| 1199 |
+
revision=revision,
|
| 1200 |
+
subfolder=subfolder,
|
| 1201 |
+
user_agent=user_agent,
|
| 1202 |
+
**kwargs,
|
| 1203 |
+
)
|
| 1204 |
+
# resolve remapping
|
| 1205 |
+
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
| 1206 |
+
|
| 1207 |
+
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
models/sampling.py
ADDED
|
@@ -0,0 +1,118 @@
# Adapted from https://github.com/lucidrains/muse-maskgit-pytorch

import math
from functools import partial

import torch
import torch.nn.functional as F


def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t, generator=None):
    noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)
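`gumbel_sample` is the Gumbel-max trick: adding Gumbel noise to temperature-scaled logits and taking the argmax draws a token from the corresponding softmax distribution. A quick illustrative use (assumes the functions above are in scope; values are arbitrary):

```py
import torch

logits = torch.tensor([[2.0, 0.5, 0.1]])
token = gumbel_sample(logits, temperature=1.0)  # stochastic; index 0 is the most likely draw
```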


def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(2, ind, val)
    return probs


def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
    sorted_confidence = torch.sort(confidence, dim=-1).values
    cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
    masking = confidence < cut_off
    return masking
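`mask_by_random_topk` re-masks the `mask_len` least-confident positions, with Gumbel noise controlling how random the choice is. With `temperature=0.0` the noise term drops out, so the result is deterministic (illustrative values):

```py
import torch

probs = torch.tensor([[0.9, 0.2, 0.6, 0.1]])  # per-token confidence
mask_len = torch.tensor([[2]])                # number of tokens to re-mask
remask = mask_by_random_topk(mask_len, probs, temperature=0.0)
print(remask)  # tensor([[False,  True, False,  True]]) -> indices 1 and 3 are re-masked
```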


def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)


def linear_schedule(t):
    mask_ratio = 1 - t
    mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
    return mask_ratio


def pow(t, method):
    exponent = float(method.replace("pow", ""))
    mask_ratio = 1.0 - t**exponent
    mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
    return mask_ratio


def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
    for item in [t, start, end, tau]:
        item = torch.tensor(item) if not torch.is_tensor(item) else item

    # A gamma function based on sigmoid function.
    v_start = torch.sigmoid(torch.tensor(start / tau))
    v_end = torch.sigmoid(torch.tensor(end / tau))
    output = torch.sigmoid((t * (end - start) + start) / tau)
    output = (v_end - output) / (v_end - v_start)
    return torch.clip(output, clip_min, 1.0)


def get_mask_schedule(method, **schedule_kwargs):
    if method == "cosine":
        return cosine_schedule
    elif method == "linear":
        return linear_schedule
    elif "pow" in method:
        return partial(pow, method=method)
    elif method == "sigmoid":
        return partial(sigmoid_schedule, **schedule_kwargs)
    else:
        raise ValueError("Unknown schedule method: {}".format(method))


def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
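Typical use of `top_k_top_p_filtering` in a sampling loop: filter the logits, renormalise, then sample. A sketch (note the function mutates its input, hence the `clone()`):

```py
import torch
import torch.nn.functional as F

logits = torch.randn(2, 50)  # (batch, vocab)
filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9)
probs = F.softmax(filtered, dim=-1)             # removed tokens now carry zero probability
next_tokens = torch.multinomial(probs, num_samples=1)
```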
models/training_utils.py
ADDED
|
@@ -0,0 +1,455 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
from typing import Any, Dict, Iterable, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def enable_full_determinism(seed: int):
|
| 28 |
+
"""
|
| 29 |
+
Helper function for reproducible behavior during distributed training. See
|
| 30 |
+
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
|
| 31 |
+
"""
|
| 32 |
+
# set seed first
|
| 33 |
+
set_seed(seed)
|
| 34 |
+
|
| 35 |
+
# Enable PyTorch deterministic mode. This potentially requires either the environment
|
| 36 |
+
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
|
| 37 |
+
# depending on the CUDA version, so we set them both here
|
| 38 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 39 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
| 40 |
+
torch.use_deterministic_algorithms(True)
|
| 41 |
+
|
| 42 |
+
# Enable CUDNN deterministic mode
|
| 43 |
+
torch.backends.cudnn.deterministic = True
|
| 44 |
+
torch.backends.cudnn.benchmark = False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def set_seed(seed: int):
|
| 48 |
+
"""
|
| 49 |
+
Args:
|
| 50 |
+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
| 51 |
+
seed (`int`): The seed to set.
|
| 52 |
+
"""
|
| 53 |
+
random.seed(seed)
|
| 54 |
+
np.random.seed(seed)
|
| 55 |
+
torch.manual_seed(seed)
|
| 56 |
+
torch.cuda.manual_seed_all(seed)
|
| 57 |
+
# ^^ safe to call this function even if cuda is not available
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
| 61 |
+
class EMA:
|
| 62 |
+
"""
|
| 63 |
+
Exponential Moving Average of models weights
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
parameters: Iterable[torch.nn.Parameter],
|
| 69 |
+
decay: float = 0.9999,
|
| 70 |
+
min_decay: float = 0.0,
|
| 71 |
+
update_after_step: int = 0,
|
| 72 |
+
use_ema_warmup: bool = False,
|
| 73 |
+
inv_gamma: Union[float, int] = 1.0,
|
| 74 |
+
power: Union[float, int] = 2 / 3,
|
| 75 |
+
model_cls: Optional[Any] = None,
|
| 76 |
+
model_config: Dict[str, Any] = None,
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
"""
|
| 80 |
+
Args:
|
| 81 |
+
parameters (Iterable[torch.nn.Parameter]): The parameters to track.
|
| 82 |
+
decay (float): The decay factor for the exponential moving average.
|
| 83 |
+
min_decay (float): The minimum decay factor for the exponential moving average.
|
| 84 |
+
update_after_step (int): The number of steps to wait before starting to update the EMA weights.
|
| 85 |
+
use_ema_warmup (bool): Whether to use EMA warmup.
|
| 86 |
+
inv_gamma (float):
|
| 87 |
+
Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
|
| 88 |
+
power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
|
| 89 |
+
device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
|
| 90 |
+
weights will be stored on CPU.
|
| 91 |
+
|
| 92 |
+
@crowsonkb's notes on EMA Warmup:
|
| 93 |
+
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
| 94 |
+
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
| 95 |
+
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
| 96 |
+
at 215.4k steps).
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
parameters = list(parameters)
|
| 100 |
+
self.shadow_params = [p.clone().detach() for p in parameters]
|
| 101 |
+
|
| 102 |
+
self.temp_stored_params = None
|
| 103 |
+
|
| 104 |
+
self.decay = decay
|
| 105 |
+
self.min_decay = min_decay
|
| 106 |
+
self.update_after_step = update_after_step
|
| 107 |
+
self.use_ema_warmup = use_ema_warmup
|
| 108 |
+
self.inv_gamma = inv_gamma
|
| 109 |
+
self.power = power
|
| 110 |
+
self.optimization_step = 0
|
| 111 |
+
self.cur_decay_value = None # set in `step()`
|
| 112 |
+
|
| 113 |
+
self.model_cls = model_cls
|
| 114 |
+
self.model_config = model_config
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def from_pretrained(cls, path, model_cls) -> "EMA":
|
| 118 |
+
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
|
| 119 |
+
model = model_cls.from_pretrained(path)
|
| 120 |
+
|
| 121 |
+
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
|
| 122 |
+
|
| 123 |
+
ema_model.load_state_dict(ema_kwargs)
|
| 124 |
+
return ema_model
|
| 125 |
+
|
| 126 |
+
def save_pretrained(self, path):
|
| 127 |
+
if self.model_cls is None:
|
| 128 |
+
raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
|
| 129 |
+
|
| 130 |
+
if self.model_config is None:
|
| 131 |
+
raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
|
| 132 |
+
|
| 133 |
+
model = self.model_cls.from_config(self.model_config)
|
| 134 |
+
state_dict = self.state_dict()
|
| 135 |
+
state_dict.pop("shadow_params", None)
|
| 136 |
+
|
| 137 |
+
model.register_to_config(**state_dict)
|
| 138 |
+
self.copy_to(model.parameters())
|
| 139 |
+
model.save_pretrained(path)
|
| 140 |
+
|
| 141 |
+
def get_decay(self, optimization_step: int) -> float:
|
| 142 |
+
"""
|
| 143 |
+
Compute the decay factor for the exponential moving average.
|
| 144 |
+
"""
|
| 145 |
+
step = max(0, optimization_step - self.update_after_step - 1)
|
| 146 |
+
|
| 147 |
+
if step <= 0:
|
| 148 |
+
return 0.0
|
| 149 |
+
|
| 150 |
+
if self.use_ema_warmup:
|
| 151 |
+
cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
| 152 |
+
else:
|
| 153 |
+
cur_decay_value = (1 + step) / (10 + step)
|
| 154 |
+
|
| 155 |
+
cur_decay_value = min(cur_decay_value, self.decay)
|
| 156 |
+
# make sure decay is not smaller than min_decay
|
| 157 |
+
cur_decay_value = max(cur_decay_value, self.min_decay)
|
| 158 |
+
return cur_decay_value
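A quick numeric check of the warmup branch above against the note quoted in the class docstring (gamma = 1, power = 2/3 reaches a decay of roughly 0.999 near 31.6K steps):

```py
step = 31_623
decay = 1 - (1 + step / 1.0) ** (-2 / 3)  # use_ema_warmup branch with inv_gamma=1, power=2/3
print(round(decay, 4))  # ~0.999
```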
|
| 159 |
+
|
| 160 |
+
@torch.no_grad()
|
| 161 |
+
def step(self, parameters: Iterable[torch.nn.Parameter]):
|
| 162 |
+
parameters = list(parameters)
|
| 163 |
+
|
| 164 |
+
self.optimization_step += 1
|
| 165 |
+
|
| 166 |
+
# Compute the decay factor for the exponential moving average.
|
| 167 |
+
decay = self.get_decay(self.optimization_step)
|
| 168 |
+
self.cur_decay_value = decay
|
| 169 |
+
one_minus_decay = 1 - decay
|
| 170 |
+
|
| 171 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 172 |
+
if param.requires_grad:
|
| 173 |
+
s_param.sub_(one_minus_decay * (s_param - param))
|
| 174 |
+
else:
|
| 175 |
+
s_param.copy_(param)
|
| 176 |
+
|
| 177 |
+
torch.cuda.empty_cache()
|
| 178 |
+
|
| 179 |
+
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
| 180 |
+
"""
|
| 181 |
+
Copy current averaged parameters into given collection of parameters.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 185 |
+
updated with the stored moving averages. If `None`, the parameters with which this
|
| 186 |
+
`ExponentialMovingAverage` was initialized will be used.
|
| 187 |
+
"""
|
| 188 |
+
parameters = list(parameters)
|
| 189 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 190 |
+
param.data.copy_(s_param.to(param.device).data)
|
| 191 |
+
|
| 192 |
+
def to(self, device=None, dtype=None) -> None:
|
| 193 |
+
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
device: like `device` argument to `torch.Tensor.to`
|
| 197 |
+
"""
|
| 198 |
+
# .to() on the tensors handles None correctly
|
| 199 |
+
self.shadow_params = [
|
| 200 |
+
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
|
| 201 |
+
for p in self.shadow_params
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
def state_dict(self) -> dict:
|
| 205 |
+
r"""
|
| 206 |
+
Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
|
| 207 |
+
checkpointing to save the ema state dict.
|
| 208 |
+
"""
|
| 209 |
+
# Following PyTorch conventions, references to tensors are returned:
|
| 210 |
+
# "returns a reference to the state and not its copy!" -
|
| 211 |
+
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
|
| 212 |
+
return {
|
| 213 |
+
"decay": self.decay,
|
| 214 |
+
"min_decay": self.min_decay,
|
| 215 |
+
"optimization_step": self.optimization_step,
|
| 216 |
+
"update_after_step": self.update_after_step,
|
| 217 |
+
"use_ema_warmup": self.use_ema_warmup,
|
| 218 |
+
"inv_gamma": self.inv_gamma,
|
| 219 |
+
"power": self.power,
|
| 220 |
+
"shadow_params": self.shadow_params,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
| 224 |
+
r"""
|
| 225 |
+
Args:
|
| 226 |
+
Save the current parameters for restoring later.
|
| 227 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 228 |
+
temporarily stored.
|
| 229 |
+
"""
|
| 230 |
+
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
|
| 231 |
+
|
| 232 |
+
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
| 233 |
+
r"""
|
| 234 |
+
Args:
|
| 235 |
+
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without
|
| 236 |
+
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
|
| 237 |
+
validation (or model saving), use this to restore the former parameters.
|
| 238 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 239 |
+
updated with the stored parameters. If `None`, the parameters with which this
|
| 240 |
+
`ExponentialMovingAverage` was initialized will be used.
|
| 241 |
+
"""
|
| 242 |
+
if self.temp_stored_params is None:
|
| 243 |
+
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
|
| 244 |
+
for c_param, param in zip(self.temp_stored_params, parameters):
|
| 245 |
+
param.data.copy_(c_param.data)
|
| 246 |
+
|
| 247 |
+
# Better memory-wise.
|
| 248 |
+
self.temp_stored_params = None
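Putting the pieces together, a minimal sketch of how this EMA helper is meant to be driven during training and evaluation; the tiny linear model and loop below are illustrative only:

```py
import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
ema = EMA(net.parameters(), decay=0.999)

for _ in range(10):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.step(net.parameters())   # update the shadow weights after every optimizer step

# Evaluate with the averaged weights, then restore the raw training weights.
ema.store(net.parameters())
ema.copy_to(net.parameters())
# ... run validation here ...
ema.restore(net.parameters())
```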
|
| 249 |
+
|
| 250 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 251 |
+
r"""
|
| 252 |
+
Args:
|
| 253 |
+
Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
|
| 254 |
+
ema state dict.
|
| 255 |
+
state_dict (dict): EMA state. Should be an object returned
|
| 256 |
+
from a call to :meth:`state_dict`.
|
| 257 |
+
"""
|
| 258 |
+
# deepcopy, to be consistent with module API
|
| 259 |
+
state_dict = copy.deepcopy(state_dict)
|
| 260 |
+
|
| 261 |
+
self.decay = state_dict.get("decay", self.decay)
|
| 262 |
+
if self.decay < 0.0 or self.decay > 1.0:
|
| 263 |
+
raise ValueError("Decay must be between 0 and 1")
|
| 264 |
+
|
| 265 |
+
self.min_decay = state_dict.get("min_decay", self.min_decay)
|
| 266 |
+
if not isinstance(self.min_decay, float):
|
| 267 |
+
raise ValueError("Invalid min_decay")
|
| 268 |
+
|
| 269 |
+
self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
|
| 270 |
+
if not isinstance(self.optimization_step, int):
|
| 271 |
+
raise ValueError("Invalid optimization_step")
|
| 272 |
+
|
| 273 |
+
self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
|
| 274 |
+
if not isinstance(self.update_after_step, int):
|
| 275 |
+
raise ValueError("Invalid update_after_step")
|
| 276 |
+
|
| 277 |
+
self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
|
| 278 |
+
if not isinstance(self.use_ema_warmup, bool):
|
| 279 |
+
raise ValueError("Invalid use_ema_warmup")
|
| 280 |
+
|
| 281 |
+
self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
|
| 282 |
+
if not isinstance(self.inv_gamma, (float, int)):
|
| 283 |
+
raise ValueError("Invalid inv_gamma")
|
| 284 |
+
|
| 285 |
+
self.power = state_dict.get("power", self.power)
|
| 286 |
+
if not isinstance(self.power, (float, int)):
|
| 287 |
+
raise ValueError("Invalid power")
|
| 288 |
+
|
| 289 |
+
shadow_params = state_dict.get("shadow_params", None)
|
| 290 |
+
if shadow_params is not None:
|
| 291 |
+
self.shadow_params = shadow_params
|
| 292 |
+
if not isinstance(self.shadow_params, list):
|
| 293 |
+
raise ValueError("shadow_params must be a list")
|
| 294 |
+
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
|
| 295 |
+
raise ValueError("shadow_params must all be Tensors")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# calculates entropy over each pixel distribution
|
| 299 |
+
def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
|
| 300 |
+
# only calculated entropy over image tokens that were masked in the original image
|
| 301 |
+
masked_tokens = input_ids == mask_id
|
| 302 |
+
num_masked_pixels = masked_tokens.sum(-1)
|
| 303 |
+
|
| 304 |
+
probs = F.softmax(logits, dim=-1)
|
| 305 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 306 |
+
|
| 307 |
+
entropy_per_pixel = -((probs * log_probs).sum(-1))
|
| 308 |
+
|
| 309 |
+
# the predictions for non-masked aren't used, so set their entropies to zero
|
| 310 |
+
entropy_per_pixel[~masked_tokens] = 0
|
| 311 |
+
|
| 312 |
+
entropy_per_image_numerator = entropy_per_pixel.sum(-1)
|
| 313 |
+
entropy_per_image = entropy_per_image_numerator / num_masked_pixels
|
| 314 |
+
|
| 315 |
+
total_buckets = 10
|
| 316 |
+
masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
|
| 317 |
+
|
| 318 |
+
entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
|
| 319 |
+
|
| 320 |
+
return entropy_by_masked_bucket
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# calculates entropy over the averaged distribution of pixels for the whole image
def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    # only calculate entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    num_masked_pixels = masked_tokens.sum(-1, keepdim=True)

    pixel_probs = F.softmax(logits, dim=-1)
    pixel_probs[~masked_tokens] = 0
    image_probs_numerator = pixel_probs.sum(-2)
    image_probs = image_probs_numerator / num_masked_pixels

    image_log_probs = image_probs.log()

    entropy_per_image = -((image_probs * image_log_probs).sum(-1))

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)

    return entropy_by_masked_bucket


def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
    cross_entropy_per_image = F.cross_entropy(
        logits.view(-1, output_size),
        labels.view(-1),
        ignore_index=-100,
        label_smoothing=label_smoothing,
        reduction="none",
    )

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)

    return cross_entropy_by_percent_masked_bucket


def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
    probs = F.softmax(logits, dim=-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    data = []

    for bucket_idx in range(total_buckets):
        indices_for_bucket = masked_buckets[masked_buckets == bucket_idx]

        # It's ok if none were noised in the range of this bucket. This
        # function will be called for a later training step where it's likely
        # there will be an element noised in the range.
        if indices_for_bucket.shape[0] == 0:
            continue

        index_for_bucket = indices_for_bucket[0]

        image_probs = probs[index_for_bucket]

        # find the index of a masked pixel for the image
        input_ids_for_image = input_ids[index_for_bucket]
        masked_pixels_probs = image_probs[input_ids_for_image == mask_id]

        masked_pixel_probs = masked_pixels_probs[0]

        masked_pixel_probs = masked_pixel_probs.cpu().numpy()

        for masked_pixel_prob in masked_pixel_probs:
            data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})

    df = pd.DataFrame(data)

    return df

def average_by_buckets(values, masked_buckets, total_buckets):
    unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)

    numerator = torch.zeros(total_buckets, device=values.device)

    numerator.scatter_add_(0, masked_buckets, values)

    # default value is one because the buckets for which there aren't
    # any values will have a numerator of zero. So we just need to not divide
    # by zero.
    denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
    denominator[unique_buckets] = bucket_counts

    averaged_by_buckets = numerator / denominator

    return averaged_by_buckets

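A tiny worked example of the averaging above, purely illustrative: three per-image values falling into buckets 2, 2, and 5. Buckets that receive no samples keep a denominator of one, so they simply average to zero instead of dividing by zero.

    values = torch.tensor([1.0, 3.0, 10.0])
    masked_buckets = torch.tensor([2, 2, 5])
    print(average_by_buckets(values, masked_buckets, total_buckets=10))
    # tensor([ 0.,  0.,  2.,  0.,  0., 10.,  0.,  0.,  0.,  0.])
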
def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
    assert total_buckets == 10

    masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]

    # we do not formally use timesteps to noise images. Instead, we mask a percent
    # of the pixels. We don't want to log entropy for every mask percent between 0 and 1,
    # and we also want to track how the entropy evolves over time w/in a range of mask
    # percents that should have similar entropy. So we bucket the masked percents into a
    # fixed number of buckets.

    # we could generalize this later if needed but for now, let's just assume a fixed
    # number of 10 buckets.

    # How this maps to a bucket index:
    # (mask_0) * bucket_index_0 +
    # (mask_1) * bucket_index_1 + ...
    #
    # -> where a mask is true, the entry is set to that mask's bucket index;
    #    where it is false, it contributes 0.
    #
    # Given the percents are between 0 and 1, each masked_percent is mapped
    # to a bucket by one and only one of the masks.

    masked_buckets = (
        ((0 < masked_percent) & (masked_percent <= 0.1)) * 0
        + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1
        + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2
        + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3
        + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4
        + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5
        + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6
        + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7
        + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8
        + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9
    )

    return masked_buckets
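A short illustration of the bucketing, with made-up token ids: two sequences of 8 tokens where mask_id is 99. The first has 2/8 = 25% of its tokens masked and lands in bucket 2; the second is fully masked and lands in bucket 9.

    input_ids = torch.tensor([
        [99, 99, 3, 4, 5, 6, 7, 8],
        [99, 99, 99, 99, 99, 99, 99, 99],
    ])
    print(input_ids_to_masked_buckets(input_ids, mask_id=99))
    # tensor([2, 9])
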
training/__init__.py ADDED
@@ -0,0 +1 @@
# from .mmada_grpo_trainer import DiffusionGRPOTrainer
training/prompting_utils.py ADDED
@@ -0,0 +1,475 @@
# coding=utf-8
# Copyright 2025 MMaDA team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


reserved_token_mapping = {
    '<|soi|>': 126084,
    '<|eoi|>': 126085,
    '<|sov|>': 126086,
    '<|eov|>': 126087,
    '<|t2i|>': 126088,
    '<|mmu|>': 126089,
    '<|t2v|>': 126090,
    '<|v2v|>': 126091,
    '<|lvg|>': 126092,
    '[iPAD]': 126093,
    '<|r2i|>': 126094,
}

import torch


class UniversalPrompting():
    def __init__(self, text_tokenizer,
                 special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                 max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=False):
        """
        :param text_tokenizer: original text tokenizer
        """
        if not use_reserved_token:
            self.text_tokenizer = text_tokenizer
            self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.text_tokenizer.add_tokens(list(special_tokens))
            self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
                                special_tokens}
            self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
            self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
            self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
        else:
            self.text_tokenizer = text_tokenizer
            self.sptids_dict = {}
            for token, token_id in reserved_token_mapping.items():
                self.sptids_dict[token] = torch.tensor([token_id])
            self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
            self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
            end_header_tokens = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>'])
            if end_header_tokens and len(end_header_tokens) > 0 and end_header_tokens[0]:
                self.sptids_dict['<|end_header_id|>'] = torch.tensor(end_header_tokens)
                self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>']))
                self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>']))
            else:
                special_tokens_dict = {
                    'additional_special_tokens': [
                        '<|start_header_id|>',
                        '<|end_header_id|>',
                        '<|eot_id|>'
                    ]
                }
                num_added = self.text_tokenizer.add_special_tokens(special_tokens_dict)
                new_token_id = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>'])
                self.sptids_dict['<|end_header_id|>'] = torch.tensor(new_token_id)
                self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>']))
                self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>']))
        # plus 1 because at this time we add a task token before
        print(f"self.sptids_dict: {self.sptids_dict}")
        self.max_text_len = max_text_len + 1
        self.pad_id = reserved_token_mapping['[iPAD]']
        self.ignore_id = ignore_id
        self.cond_dropout_prob = cond_dropout_prob

    def t2i_prompt(self, text_ids, image_ids, labels):

        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        label_ids = []
        probs = torch.rand(len(text_ids))
        for i in range(len(text_ids)):

            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]

            # randomly dropout text condition
            if probs[i] < self.cond_dropout_prob:
                temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]

            if self.max_text_len >= len(temp_ids):
                old_len = len(temp_ids)
                temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
                temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2)
            else:
                # should add the eos token
                temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_label_ids = torch.cat([
                # should we predict text tokens when doing image reconstruction?
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                labels[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)

            temp_ids = torch.cat([
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            # sequence_ids: [pad]...[pad] <|t2i|> <bos> text_1 ... text_n <eos> <|soi|> image_1 ... image_m <|eoi|>
            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_label_ids.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)

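As a concrete illustration of the layout built above (the numbers are made up, not taken from the repo): with self.max_text_len = 6, a caption that tokenizes to [<|t2i|>, <bos>, t1, t2, <eos>] (5 ids) and 4 image tokens, the padded text becomes [[iPAD], <|t2i|>, <bos>, t1, t2, <eos>] and the assembled row is

    [iPAD] <|t2i|> <bos> t1 t2 <eos> <|soi|> i1 i2 i3 i4 <|eoi|>    # 12 tokens
    mask:    0       1     1   1  1    1       1    1  1  1  1  1   # [0] * 1 + [1] * (5 + 4 + 2)

with the label row holding the same text ids followed by <|soi|>, the image labels, and <|eoi|>, and every [iPAD] position replaced by ignore_id (-100).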
    def t2i_gen_prompt(self, text_ids, image_ids):

        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        for i in range(len(text_ids)):
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
            # note that the llama3 tokenizer automatically adds the bos token at the start, but not the eos token
            temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
            if self.max_text_len >= len(temp_ids):
                old_len = len(temp_ids)
                temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
                temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2)
            else:
                # should add the eos token
                temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_ids = torch.cat([
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)

    # language modeling
    def lm_prompt(self, text_ids, max_seq_len):
        sequence_ids = []
        attention_masks = []
        label_ids = []
        for i in range(len(text_ids)):
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]

            if max_seq_len >= len(temp_ids):
                temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
                temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
                temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
            else:
                # In language modeling, we only process text tokens. We do not add the eos token if the text length
                # exceeds the max sequence length
                temp_labels_ids = temp_ids[:max_seq_len]
                temp_ids = temp_ids[:max_seq_len]
                temp_masks = [1] * len(temp_ids)

            temp_ids = torch.tensor(temp_ids)
            temp_masks = torch.tensor(temp_masks)
            temp_labels_ids = torch.tensor(temp_labels_ids)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_labels_ids.unsqueeze(0))

        # input_ids, masks, labels
        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)

    # language modeling (chat format)
    def lm_chat_prompt(self, text_ids, max_seq_len):
        sequence_ids = []
        prompt_masks = []
        label_ids = []

        for i in range(len(text_ids)):
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]

            if max_seq_len >= len(temp_ids):
                temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
                temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
            else:
                # In language modeling, we only process text tokens. We do not add the eos token if the text length
                # exceeds the max sequence length
                temp_labels_ids = temp_ids[:max_seq_len]
                temp_ids = temp_ids[:max_seq_len]

            end_header_id = int(self.sptids_dict['<|end_header_id|>'])
            end_header_pos = -1
            for pos in range(len(temp_ids) - 1, -1, -1):  # search the text sequence backwards for <|end_header_id|>
                if temp_ids[pos] == end_header_id:
                    end_header_pos = pos
                    break
            if end_header_pos != -1:
                prompt_length = end_header_pos + 1
            else:
                prompt_length = 0
            temp_masks = [1] * prompt_length + [0] * (len(temp_ids) - prompt_length)

            temp_ids = torch.tensor(temp_ids)
            temp_masks = torch.tensor(temp_masks)
            temp_labels_ids = torch.tensor(temp_labels_ids)
            sequence_ids.append(temp_ids.unsqueeze(0))
            prompt_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_labels_ids.unsqueeze(0))

        # input_ids, masks, labels
        return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0)

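A small sketch of what the backwards scan above produces, with made-up token ids: everything up to and including the last <|end_header_id|> is treated as the prompt (mask 1), and the remaining positions (mask 0) are the answer the model is expected to predict.

    end_header_id = 9                      # pretend id of <|end_header_id|>
    temp_ids = [1, 5, 9, 42, 43, 44, 2]    # <bos> header <|end_header_id|> answer tokens <eos>
    end_header_pos = max(p for p, t in enumerate(temp_ids) if t == end_header_id)  # -> 2
    prompt_length = end_header_pos + 1     # -> 3
    temp_masks = [1] * prompt_length + [0] * (len(temp_ids) - prompt_length)
    # [1, 1, 1, 0, 0, 0, 0]
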
    def mmu_prompt(self, image_ids, text_ids):
        device = image_ids.device
        sequence_ids = []
        prompt_masks = []
        label_ids = []
        max_text_len = self.max_text_len - 1
        for i in range(len(text_ids)):
            # note that the llama3 tokenizer automatically adds the bos token at the start, but not the eos token
            # for empty list []

            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]

            if max_text_len >= len(temp_ids):
                # minus 1 because task token was prepended to the former image tokens
                temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_text_len - len(temp_ids))
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
            else:
                # should add the eos token
                temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3)  # +3 for three special tokens

            # prompting -- [task token] [soi] [image tokens] [eoi] [sot] [text tokens] [eot]
            temp_label_ids = torch.cat([
                torch.tensor([self.ignore_id]).to(device),
                torch.tensor([self.ignore_id]).to(device),
                torch.ones_like(image_ids[i]) * self.ignore_id,
                torch.tensor([self.ignore_id]).to(device),
                torch.tensor(temp_ids).to(device),
            ], dim=0)

            temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)

            return_temp_ids = torch.cat([
                self.sptids_dict['<|mmu|>'].to(device),  # task token
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device),
                torch.tensor(temp_ids).to(device),
            ], dim=0)
            end_header_id = int(self.sptids_dict['<|end_header_id|>'])
            end_header_pos = -1
            for pos in range(len(temp_ids) - 1, -1, -1):
                if temp_ids[pos] == end_header_id:
                    end_header_pos = pos
                    break
            if end_header_pos != -1:
                prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1
            else:
                prompt_length = len(return_temp_ids) - len(temp_ids)
            predict_length = len(return_temp_ids) - prompt_length
            prompt_mask = [1] * prompt_length + [0] * predict_length
            prompt_mask = torch.tensor(prompt_mask).to(device)
            sequence_ids.append(return_temp_ids.unsqueeze(0))
            prompt_masks.append(prompt_mask.unsqueeze(0))
            label_ids.append(temp_label_ids.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0)

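For reference, the row assembled by mmu_prompt is laid out as <|mmu|> <|soi|> img_1 ... img_m <|eoi|> <bos> text ... <eos>; the label row carries ignore_id (-100) for everything up to and including <|eoi|>, so only the text answer is supervised, and the prompt mask additionally covers the chat header up to the last <|end_header_id|> when one is present.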
    def mmu_gen_prompt(self, image_ids, text_ids):
        device = image_ids.device
        sequence_ids = []
        prompt_masks = []
        max_text_len = self.max_text_len - 1
        for i in range(len(text_ids)):

            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]

            if max_text_len >= len(temp_ids):
                # minus 1 because task token was prepended to the former image tokens
                temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_text_len - len(temp_ids))
            else:
                # should add the eos token
                temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]

            # print(f"mmu temp_ids: {temp_ids}")
            return_temp_ids = torch.cat([
                self.sptids_dict['<|mmu|>'].to(device),  # task token
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device),
                torch.tensor(temp_ids).to(device),
            ], dim=0)

            end_header_id = int(self.sptids_dict['<|end_header_id|>'])
            end_header_pos = -1
            for pos in range(len(temp_ids) - 1, -1, -1):
                if temp_ids[pos] == end_header_id:
                    end_header_pos = pos
                    break
            if end_header_pos != -1:
                prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1
            else:
                prompt_length = len(return_temp_ids) - len(temp_ids)
            predict_length = len(temp_ids) - prompt_length
            print(f"prompt_length: {prompt_length}, predict_length: {predict_length}, all length: {len(return_temp_ids)}, {return_temp_ids[-predict_length:]}")
            prompt_mask = [1] * prompt_length + [0] * predict_length
            prompt_mask = torch.tensor(prompt_mask).to(device)
            sequence_ids.append(return_temp_ids.unsqueeze(0))
            prompt_masks.append(prompt_mask.unsqueeze(0))
        return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0)

    def r2i_prompt(self, image_ids, text_ids):
        device = image_ids.device
        sequence_ids = []
        prompt_masks = []
        label_ids = []
        r2i_id = int(self.sptids_dict['<|r2i|>'])
        soi_id = int(self.sptids_dict['<|soi|>'])
        eoi_id = int(self.sptids_dict['<|eoi|>'])
        max_text_len = self.max_text_len - 1  # 512, includes BOS + text + EOS
        for i in range(len(text_ids)):
            # note that the llama3 tokenizer automatically adds the bos token at the start, but not the eos token
            # for empty list []
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
            text_ids_with_bos_eos = text_ids[i] + [self.text_tokenizer.eos_token_id]
            if max_text_len >= len(text_ids_with_bos_eos):
                # minus 1 because task token was prepended to the former image tokens
                text_ids_full_len = text_ids_with_bos_eos + [self.text_tokenizer.eos_token_id] * (max_text_len - len(text_ids_with_bos_eos))
            else:
                # should add the eos token
                text_ids_full_len = text_ids_with_bos_eos[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]

            sequence_ids.append(torch.cat([
                torch.tensor([r2i_id]).to(device),  # task token
                torch.tensor(text_ids_full_len).to(device),
                torch.tensor([soi_id]).to(device),
                image_ids[i],
                torch.tensor([eoi_id]).to(device),
            ], dim=0).unsqueeze(0))

            end_header_id = int(self.sptids_dict['<|end_header_id|>'])
            end_header_pos = -1
            for pos in range(len(text_ids_full_len) - 1, -1, -1):
                if text_ids_full_len[pos] == end_header_id:
                    end_header_pos = pos
                    break
            prompt_mask = torch.zeros(sequence_ids[i].size(1)).to(device)
            prompt_mask[0] = 1  # task_id
            if end_header_pos != -1:
                prompt_mask[1:end_header_pos + 2] = 1
            else:
                prompt_mask[1:len(text_ids_full_len) + 1] = 1
            prompt_mask[len(text_ids_full_len) + 1] = 1
            prompt_mask[len(text_ids_full_len) + 2 + len(image_ids[i])] = 1
            prompt_masks.append(prompt_mask.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(sequence_ids, dim=0)

    def mask_prompt(self):
        pass

    def __call__(self, input, task, padding=True, config=None):
        """
        input (tuple) : data pairs containing text (str), images (tensor), or videos (tensor).
        task (str) : a flag indicating the current task.
        """
        if task == "t2i":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])

        elif task == "t2v":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2v_prompt(text_ids, image_ids, input[2])

        elif task == "t2i_plus_lm":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_prompt(text_ids[:config.training.batch_size], image_ids,
                                                      input[2])
            sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
            return sequence_ids_with_masks, sequence_ids_with_masks_lm

        elif task == "t2i_gen":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)

        elif task == "t2v_gen":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2v_gen_prompt(text_ids, image_ids)

        elif task == "lm":
            text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids']  # (B, max_len)
            sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])

        elif task == "lm_chat":
            text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids']  # (B, max_len)
            sequence_ids_with_masks = self.lm_chat_prompt(text_ids, input[1])

        elif task == "mmu":
            image_ids = input[0]
            text_ids = self.text_tokenizer(input[1])['input_ids']
            sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)

        elif task == "r2i":
            image_ids = input[0]
            text_ids = self.text_tokenizer(input[1])['input_ids']
            sequence_ids_with_masks = self.r2i_prompt(image_ids, text_ids)

        else:
            raise NotImplementedError

        return sequence_ids_with_masks


if __name__ == '__main__':
    pass
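A minimal end-to-end usage sketch, not part of the commit: the tokenizer checkpoint is a hypothetical choice, a llama3-style tokenizer with bos/eos tokens is assumed, and the image token ids are random stand-ins for VQ codes.

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True)  # hypothetical
    uni_prompting = UniversalPrompting(tokenizer, max_text_len=128, ignore_id=-100,
                                       cond_dropout_prob=0.1, use_reserved_token=True)

    captions = ["a photo of a cat", "a red car on a street"]
    image_tokens = torch.randint(0, 8192, (2, 256))    # stand-in for image VQ token ids
    input_ids, attention_mask, labels = uni_prompting((captions, image_tokens, image_tokens), task="t2i")
    print(input_ids.shape, attention_mask.shape, labels.shape)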