# Modified from:
# LlamaGen: https://github.com/FoundationVision/LlamaGen/
# YOCO: https://github.com/microsoft/unilm/tree/master/YOCO
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
from typing import Dict, List, Optional
from dataclasses import dataclass
from transformers.configuration_utils import PretrainedConfig
def find_multiple(n: int, k: int):
if n % k == 0:
return n
return n + k - (n % k)
def batch_seq_shuffle(x, orders=None):
assert x.ndim >= 2, "The input should contain at least two dimensions, batch and length"
bs, seq_len = x.shape[:2]
if orders is None:
orders = torch.rand(bs, seq_len, device=x.device).argsort(dim=1)
orders_expand = orders.view(*orders.shape, *(1,) * (x.ndim - orders.ndim))
shuffled_data = torch.gather(x, 1, orders_expand.expand(*x.shape))
return shuffled_data, orders
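# Illustrative usage (mirrors how forward() keeps inputs and targets aligned):
#   shuffled_ids, orders = batch_seq_shuffle(input_ids)       # fresh random permutation per sample
#   shuffled_targets = batch_seq_shuffle(targets, orders)[0]  # reuse the same permutation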
# @dataclass
class ModelArgs(PretrainedConfig):
def __init__(
self,
dim: int = 4096,
n_layer: int = 32,
n_head: int = 32,
multiple_of: int = 256, # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None,
rope_base: float = 10000,
norm_eps: float = 1e-5,
initializer_range: float = 0.02,
token_dropout_p: float = 0.1,
attn_dropout_p: float = 0.0,
resid_dropout_p: float = 0.1,
ffn_dropout_p: float = 0.1,
drop_path_rate: float = 0.0,
num_classes: int = 1000,
class_dropout_prob: float = 0.1,
model_type: str = 'c2i',
vocab_size: int = 16384,
cls_token_num: int = 1,
block_size: int = 256,
):
self.dim = dim
self.n_layer = n_layer
self.n_head = n_head
self.multiple_of = multiple_of
self.ffn_dim_multiplier = ffn_dim_multiplier
self.rope_base = rope_base
self.norm_eps = norm_eps
self.initializer_range = initializer_range
self.token_dropout_p = token_dropout_p
self.attn_dropout_p = attn_dropout_p
self.resid_dropout_p = resid_dropout_p
self.ffn_dropout_p = ffn_dropout_p
self.drop_path_rate = drop_path_rate
self.num_classes = num_classes
self.class_dropout_prob = class_dropout_prob
self.model_type = model_type
self.vocab_size = vocab_size
self.cls_token_num = cls_token_num
self.block_size = block_size
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class FeedForward(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
hidden_dim = 4 * config.dim
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if config.ffn_dim_multiplier is not None:
hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
hidden_dim = find_multiple(hidden_dim, config.multiple_of)
self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
def forward(self, x):
return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
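# For reference, with the ARPG-L config (dim=1024, multiple_of=256, no ffn_dim_multiplier)
# the SwiGLU hidden size works out to: 4 * 1024 = 4096 -> int(2 * 4096 / 3) = 2730
# -> find_multiple(2730, 256) = 2816.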
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
self.dim = config.dim
self.n_head = config.n_head
self.head_dim = config.dim // config.n_head
self.to_q = nn.Linear(config.dim, config.dim, bias=False)
self.to_k = nn.Linear(config.dim, config.dim, bias=False)
self.to_v = nn.Linear(config.dim, config.dim, bias=False)
self.proj = nn.Linear(config.dim, config.dim, bias=False)
self.attn_drop = config.attn_dropout_p
self.proj_drop = nn.Dropout(config.resid_dropout_p)
self.kv_cache = False
self.k_cache = None
self.v_cache = None
def reset_kv_cache(self):
self.k_cache = None
self.v_cache = None
def update_kv_cache(self, k: torch.Tensor, v: torch.Tensor):
if self.k_cache is None and self.v_cache is None:
k_cache = k
v_cache = v
else:
k_cache = torch.cat([self.k_cache, k], dim=-2)
v_cache = torch.cat([self.v_cache, v], dim=-2)
self.k_cache = k_cache
self.v_cache = v_cache
return k_cache, v_cache
def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor = None
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=self.n_head), (q, k, v))
q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
if self.kv_cache:
k, v = self.update_kv_cache(k, v)
output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
            is_causal=self.training,
dropout_p=self.attn_drop if self.training else 0
)
output = rearrange(output, 'b h n d -> b n (h d)').contiguous()
output = self.proj_drop(self.proj(output))
return output
class CrossAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
self.dim = config.dim
self.n_head = config.n_head
self.head_dim = config.dim // config.n_head
self.to_q = nn.Linear(config.dim, config.dim, bias=False)
self.proj = nn.Linear(config.dim, config.dim, bias=False)
self.attn_drop = config.attn_dropout_p
self.proj_drop = nn.Dropout(config.resid_dropout_p)
self.kv_cache = False
self.k_cache = None
self.v_cache = None
def reset_kv_cache(self):
self.k_cache = None
self.v_cache = None
def update_kv_cache(self, k: torch.Tensor, v: torch.Tensor):
if self.k_cache is None and self.v_cache is None:
k_cache = k
v_cache = v
else:
k_cache = torch.cat([self.k_cache, k], dim=-2)
v_cache = torch.cat([self.v_cache, v], dim=-2)
self.k_cache = k_cache
self.v_cache = v_cache
return k_cache, v_cache
def forward(
self,
x: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
freqs_cis: torch.Tensor = None
):
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b n h d', h=self.n_head)
# target-aware
q = apply_rotary_emb(q, freqs_cis[:, -q.shape[1]:, ...])
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
if self.kv_cache:
k, v = self.update_kv_cache(k, v)
output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
            is_causal=self.training,
dropout_p=self.attn_drop if self.training else 0
)
output = rearrange(output, 'b h n d -> b n (h d)').contiguous()
output = self.proj_drop(self.proj(output))
return output
class SelfDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.attn = Attention(config)
self.ffn = FeedForward(config)
self.attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor = None
):
h = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis[:, :x.shape[1], ...])
out = h + self.ffn(self.ffn_norm(h))
return out
class CrossDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.attn = CrossAttention(config)
self.ffn = FeedForward(config)
self.attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
def forward(
self,
x: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
freqs_cis: torch.Tensor = None
):
h = x + self.attn(x=self.attn_norm(x), k=k, v=v, freqs_cis=freqs_cis)
out = h + self.ffn(self.ffn_norm(h))
return out
class Decoder_Decoder(nn.Module):
def __init__(self, config: ModelArgs, n_layer):
super().__init__()
self.config = config
self.self_dec = nn.ModuleList([SelfDecoder(config) for _ in range(n_layer//2)])
self.cross_dec = nn.ModuleList([CrossDecoder(config) for _ in range(n_layer//2)])
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.to_k = nn.Linear(config.dim, config.dim, bias=False)
self.to_v = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = False
self.k_cache = None
self.v_cache = None
def reset_kv_cache(self):
self.k_cache = None
self.v_cache = None
def update_kv_cache(self, k: torch.Tensor, v: torch.Tensor, head_first=False):
t_dim = 2 if head_first else 1
if self.k_cache is None and self.v_cache is None:
k_cache = k
v_cache = v
else:
k_cache = torch.cat([self.k_cache, k], dim=t_dim)
v_cache = torch.cat([self.v_cache, v], dim=t_dim)
self.k_cache = k_cache
self.v_cache = v_cache
return k_cache, v_cache
def forward(
self,
x: torch.Tensor,
q: torch.Tensor,
freqs_cis: torch.Tensor = None
):
for layer in self.self_dec:
x = layer(x=x, freqs_cis=freqs_cis)
x_norm = self.norm(x)
k = self.to_k(x_norm)
v = self.to_v(x_norm)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=self.config.n_head), (k, v))
k = apply_rotary_emb(k, freqs_cis[:, :k.shape[1], ...])
if self.kv_cache:
k, v = self.update_kv_cache(k, v)
for layer in self.cross_dec:
q = layer(x=q, k=k, v=v, freqs_cis=freqs_cis)
return q
class Transformer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.image_seq_len = config.block_size
"""
ref: https://github.com/bytedance/1d-tokenizer/blob/main/modeling/rar.py
Token space:
[0, vocab_size - 1] : those are the learned quantized image tokens
[vocab_size] : the mask token id
[vocab_size + 1, vocab_size + num_classes] : the imagenet class tokens
[vocab_size + num_classes + 1] : the class drop label
[vocab_size + num_classes + 2] : the drop token for scg
"""
self.embeddings = nn.Embedding(config.vocab_size + 1 + config.num_classes + 1 + 1, config.dim)
self.embed_drop = nn.Dropout(config.token_dropout_p)
self.mask_token_id = config.vocab_size
self.none_conds_id = config.vocab_size + config.num_classes + 1
self.none_token_id = config.vocab_size + config.num_classes + 2
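        # With the default config (vocab_size=16384, num_classes=1000) this layout is:
        #   image tokens 0..16383, mask token 16384, class tokens 16385..17384,
        #   class-drop label 17385, drop token 17386 (embedding table size 17387)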
# 2-pass decoder
self.layers = Decoder_Decoder(config, config.n_layer)
# output layer
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.head = nn.Linear(config.dim, config.vocab_size, bias=False)
# 2d rotary pos embedding
grid_size = int(self.image_seq_len ** 0.5)
self.freqs_cis = precompute_freqs_cis_2d(grid_size, config.dim // config.n_head, config.rope_base, config.cls_token_num)
self.initialize_weights()
def initialize_weights(self):
# Initialize nn.Linear and nn.Embedding
self.apply(self._init_weights)
# Zero-out output layers:
nn.init.constant_(self.head.weight, 0)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
def setup_kv_cache(self, enable=True):
for block in self.layers.self_dec:
block.attn.kv_cache = enable
block.attn.reset_kv_cache()
self.layers.kv_cache = enable
self.layers.reset_kv_cache()
def preprocess_condition(self, condition, cond_drop_prob=0.0):
        # Randomly drop the class label (replace it with the none-condition id) for classifier-free guidance
drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        condition = condition + self.config.vocab_size + 1  # [0, num_classes - 1] -> [vocab_size + 1, vocab_size + num_classes]
condition[drop_label_mask] = self.none_conds_id
if condition.ndim == 1:
condition = condition.unsqueeze(-1)
return condition
def forward_shared(self, input_ids, freqs_cis, num_query=None):
embedds = self.embeddings(input_ids)
x = self.embed_drop(embedds)
        num_query = input_ids.shape[-1] if num_query is None else num_query
queries = self.embeddings(torch.full((input_ids.shape[0], num_query), self.mask_token_id, device=input_ids.device))
x = self.layers(x, queries, freqs_cis=freqs_cis)
logits = self.head(self.norm(x)).float()
return logits
def forward(self, input_ids, condition, targets=None, debug=False):
# shift class id and dropout for classifier-free guidance
condition = self.preprocess_condition(condition, cond_drop_prob=self.config.class_dropout_prob)
# shuffle input
shuffled_ids, orders = batch_seq_shuffle(input_ids)
# shuffle RoPE
freqs_cis = self.freqs_cis.unsqueeze(0).repeat(input_ids.shape[0], 1, 1, 1).to(input_ids.device)
fixed_freqs_cis = freqs_cis[:, :1, ...]
shuffled_freqs_cis = batch_seq_shuffle(freqs_cis[:, 1:, ...], orders)[0]
freqs_cis = torch.cat([fixed_freqs_cis, shuffled_freqs_cis], dim=1)
# teacher-forcing input
logits = self.forward_shared(torch.cat([condition, shuffled_ids[:, :-1]], dim=-1), freqs_cis)
loss = None
if targets is not None:
targets = batch_seq_shuffle(targets, orders)[0]
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
@torch.inference_mode()
def generate(
self,
condition,
guidance_scale=4.0,
cfg_schedule='linear',
sample_schedule='arccos',
temperature=1.0,
top_k=0,
top_p=1,
seq_len=256,
num_iter=64,
):
device = condition.device
num_samples = condition.shape[0]
freqs_cis_ = self.freqs_cis.unsqueeze(0).to(device)
# shift condition id
condition = self.preprocess_condition(condition, cond_drop_prob=0.0)
        # generate a random decoding order over positions 1..seq_len (position 0 is the class token)
        orders = torch.rand(seq_len, device=device).argsort(dim=0) + 1
        last_pos = 0
        last_range = range(0, 1)  # the class token occupies position 0
sequences = []
self.setup_kv_cache(enable=True)
for step in range(num_iter):
if sample_schedule == 'arccos':
mask_ratio = np.arccos(1. * (step + 1) / num_iter) / (math.pi * 0.5)
elif sample_schedule == 'cosine':
mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
else:
raise NotImplementedError
mask_len = int(seq_len * mask_ratio)
mask_len = max(1, min(seq_len - last_pos - 1, mask_len))
num_pred = seq_len - last_pos - mask_len
if step == num_iter - 1:
num_pred = seq_len - last_pos
next_range = orders[range(last_pos, last_pos + num_pred)]
last_pos += num_pred
if cfg_schedule == 'linear':
cfg_scale = 1.0 + (guidance_scale - 1.0) * last_pos / seq_len
elif cfg_schedule == 'constant':
cfg_scale = guidance_scale
else:
raise NotImplementedError
"""
1. Since the cached key has already had rotary embedding applied,
we only need to input the current position's frequencies for key.
2. We need the next position's frequencies for query to achieve target-aware guidance.
"""
freqs_cis = torch.cat([
freqs_cis_[:, last_range, ...],
freqs_cis_[:, next_range, ...]], dim=1
)
if guidance_scale != 0:
if step == 0:
input_ids = torch.cat([condition, torch.full_like(condition, self.none_conds_id)], dim=0)
else:
input_ids = torch.cat([sequences[-1], sequences[-1]], dim=0)
logits = self.forward_shared(input_ids, freqs_cis, num_pred)
cond_logits, uncond_logits = logits[:num_samples], logits[num_samples:]
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
else:
raise NotImplementedError
            # keep only the logits of the last num_pred (newly queried) tokens
logits = logits[:, -num_pred:] / max(temperature, 1e-5)
if top_k > 0 or top_p < 1.0:
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
sampled = torch.multinomial(probs.flatten(0, 1), num_samples=1)
sequences.append(sampled.reshape(num_samples, -1))
last_range = next_range
self.setup_kv_cache(enable=False)
sequences = torch.cat(sequences, dim=-1)
return sequences[:, orders.argsort(dim=0)]
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (seq_len, head_dim // 2, 2)
cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
return cond_cache
def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
# split the dimension into half, one for x and one for y
half_dim = n_elem // 2
freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
t = torch.arange(grid_size, device=freqs.device)
freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
freqs_grid = torch.concat([
freqs[:, None, :].expand(-1, grid_size, -1),
freqs[None, :, :].expand(grid_size, -1, -1),
], dim=-1) # (grid_size, grid_size, head_dim // 2)
cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
cache = cache_grid.flatten(0, 1)
cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
return cond_cache
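# With the default c2i setup (block_size=256 -> 16x16 grid, cls_token_num=1), the returned cache
# has shape (1 + 256, head_dim // 2, 2): a zero entry for the class-token position followed by
# one (cos, sin) pair per spatial position.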
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
# x: (bs, seq_len, n_head, head_dim)
# freqs_cis (seq_len, head_dim // 2, 2)
xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
    freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)  # (bs or 1, seq_len, 1, head_dim // 2, 2)
x_out2 = torch.stack([
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
], dim=-1)
x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)
def top_k_top_p_filtering(
logits,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens whose cumulative probability exceeds the threshold
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
        # scatter sorted tensors back to the original indexing (last dim, so both 2D and 3D logits work)
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
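# Illustrative usage (hypothetical tensors): keep at most the top 50 tokens, further restricted
# to the 95% nucleus. Note that the function modifies `logits` in place and also returns it:
#   logits = torch.randn(8, 16384)
#   filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.95)
#   probs = F.softmax(filtered, dim=-1)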
def ARPG_XXL(**kwargs):
return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))
def ARPG_XL(**kwargs):
return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))
def ARPG_L(**kwargs):
return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))
ARPG_models = {'ARPG-L': ARPG_L, 'ARPG-XL': ARPG_XL, 'ARPG-XXL': ARPG_XXL}
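if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch only: random token ids and class labels,
    # run on CPU; not a real training or sampling pipeline).
    model = ARPG_L(vocab_size=16384, block_size=256, num_classes=1000)
    input_ids = torch.randint(0, 16384, (2, 256))
    condition = torch.randint(0, 1000, (2,))
    logits, loss = model(input_ids, condition, targets=input_ids)
    print(logits.shape, loss.item())  # -> torch.Size([2, 256, 16384]) and a scalar loss
    model.eval()
    samples = model.generate(condition, guidance_scale=4.0, seq_len=256, num_iter=8)
    print(samples.shape)  # -> torch.Size([2, 256])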