import math
import numpy as np
import torch
import torch.nn as nn

from torch.nn import functional as F
from einops import rearrange
from typing import Dict, List, Optional
from dataclasses import dataclass
from transformers.configuration_utils import PretrainedConfig


def find_multiple(n: int, k: int):
    if n % k == 0:
        return n
    return n + k - (n % k)


def batch_seq_shuffle(x, orders=None):
    assert x.ndim >= 2, "The input should contain at least two dimensions, batch and length"
    bs, seq_len = x.shape[:2]

    if orders is None:
        orders = torch.rand(bs, seq_len, device=x.device).argsort(dim=1)

    orders_expand = orders.view(*orders.shape, *(1,) * (x.ndim - orders.ndim))
    shuffled_data = torch.gather(x, 1, orders_expand.expand(*x.shape))

    return shuffled_data, orders
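# Illustrative usage of batch_seq_shuffle (comment only, not executed; shapes assume a (2, 4) batch):
#   x = torch.arange(8).reshape(2, 4)
#   shuffled, orders = batch_seq_shuffle(x)            # shuffled[b, i] == x[b, orders[b, i]]
#   torch.gather(shuffled, 1, orders.argsort(dim=1))   # recovers x: argsort(orders) inverts the shuffle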
class ModelArgs(PretrainedConfig):
    def __init__(
        self,
        dim: int = 4096,
        n_layer: int = 32,
        n_head: int = 32,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        rope_base: float = 10000,
        norm_eps: float = 1e-5,
        initializer_range: float = 0.02,
        token_dropout_p: float = 0.1,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
        ffn_dropout_p: float = 0.1,
        drop_path_rate: float = 0.0,
        num_classes: int = 1000,
        class_dropout_prob: float = 0.1,
        model_type: str = 'c2i',
        vocab_size: int = 16384,
        cls_token_num: int = 1,
        block_size: int = 256,
        **kwargs,
    ):
        self.dim = dim
        self.n_layer = n_layer
        self.n_head = n_head
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.rope_base = rope_base
        self.norm_eps = norm_eps
        self.initializer_range = initializer_range

        self.token_dropout_p = token_dropout_p
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout_p = resid_dropout_p
        self.ffn_dropout_p = ffn_dropout_p
        self.drop_path_rate = drop_path_rate

        self.num_classes = num_classes
        self.class_dropout_prob = class_dropout_prob
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.cls_token_num = cls_token_num
        self.block_size = block_size

        # Forward any remaining keyword arguments to PretrainedConfig so that the
        # usual save_pretrained / from_pretrained machinery keeps working.
        super().__init__(**kwargs)
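# Illustrative note on ModelArgs defaults: the ARPG_* factories at the bottom of this file fill these
# in, e.g. ARPG-L uses dim=1024, n_layer=24, n_head=16 (head dim 64) with a 256-token (16 x 16) image
# grid and a 16384-entry codebook.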
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
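# RMSNorm above implements y = x / sqrt(mean(x^2, dim=-1) + eps) * weight, computed in float32 and
# cast back to the input dtype (Zhang & Sennrich, 2019).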
class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        hidden_dim = 4 * config.dim
        hidden_dim = int(2 * hidden_dim / 3)

        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
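# Worked example of the hidden-dim arithmetic above (illustrative): with dim=1024 and the default
# multiple_of=256, hidden_dim = int(2 * 4096 / 3) = 2730, which find_multiple rounds up to 2816.
# The forward pass is the SwiGLU form w2(silu(w1(x)) * w3(x)) followed by dropout.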
class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.n_head = config.n_head
        self.head_dim = config.dim // config.n_head

        self.to_q = nn.Linear(config.dim, config.dim, bias=False)
        self.to_k = nn.Linear(config.dim, config.dim, bias=False)
        self.to_v = nn.Linear(config.dim, config.dim, bias=False)

        self.proj = nn.Linear(config.dim, config.dim, bias=False)

        self.attn_drop = config.attn_dropout_p
        self.proj_drop = nn.Dropout(config.resid_dropout_p)

        self.kv_cache = False
        self.k_cache = None
        self.v_cache = None

    def reset_kv_cache(self):
        self.k_cache = None
        self.v_cache = None

    def update_kv_cache(self, k: torch.Tensor, v: torch.Tensor):
        if self.k_cache is None and self.v_cache is None:
            k_cache = k
            v_cache = v
        else:
            k_cache = torch.cat([self.k_cache, k], dim=-2)
            v_cache = torch.cat([self.v_cache, v], dim=-2)

        self.k_cache = k_cache
        self.v_cache = v_cache

        return k_cache, v_cache

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor = None
    ):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=self.n_head), (q, k, v))

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache:
            k, v = self.update_kv_cache(k, v)

        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            is_causal=self.training,
            dropout_p=self.attn_drop if self.training else 0.0
        )
        output = rearrange(output, 'b h n d -> b n (h d)').contiguous()
        output = self.proj_drop(self.proj(output))
        return output
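# Note: the causal mask above is only applied in training mode. At inference time the newly visible
# tokens are appended to the KV cache and attend to the whole cache (including tokens sampled in the
# same parallel step), so no causal mask is used.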
class CrossAttention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.n_head = config.n_head
        self.head_dim = config.dim // config.n_head

        self.to_q = nn.Linear(config.dim, config.dim, bias=False)

        self.proj = nn.Linear(config.dim, config.dim, bias=False)

        self.attn_drop = config.attn_dropout_p
        self.proj_drop = nn.Dropout(config.resid_dropout_p)

        self.kv_cache = False
        self.k_cache = None
        self.v_cache = None

    def reset_kv_cache(self):
        self.k_cache = None
        self.v_cache = None

    def update_kv_cache(self, k: torch.Tensor, v: torch.Tensor):
        if self.k_cache is None and self.v_cache is None:
            k_cache = k
            v_cache = v
        else:
            k_cache = torch.cat([self.k_cache, k], dim=-2)
            v_cache = torch.cat([self.v_cache, v], dim=-2)

        self.k_cache = k_cache
        self.v_cache = v_cache

        return k_cache, v_cache

    def forward(
        self,
        x: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        freqs_cis: torch.Tensor = None
    ):
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b n h d', h=self.n_head)

        q = apply_rotary_emb(q, freqs_cis[:, -q.shape[1]:, ...])

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache:
            k, v = self.update_kv_cache(k, v)

        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            is_causal=self.training,
            dropout_p=self.attn_drop if self.training else 0.0
        )
        output = rearrange(output, 'b h n d -> b n (h d)').contiguous()
        output = self.proj_drop(self.proj(output))
        return output
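# Note: CrossAttention projects only the queries; the keys/values are computed once from the
# self-decoder output in Decoder_Decoder below and are shared by every cross-decoder layer.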
class SelfDecoder(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.attn = Attention(config)
        self.ffn = FeedForward(config)

        self.attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor = None
    ):
        h = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis[:, :x.shape[1], ...])
        out = h + self.ffn(self.ffn_norm(h))

        return out


class CrossDecoder(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.attn = CrossAttention(config)
        self.ffn = FeedForward(config)

        self.attn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        freqs_cis: torch.Tensor = None
    ):
        h = x + self.attn(x=self.attn_norm(x), k=k, v=v, freqs_cis=freqs_cis)
        out = h + self.ffn(self.ffn_norm(h))

        return out
class Decoder_Decoder(nn.Module):
    def __init__(self, config: ModelArgs, n_layer):
        super().__init__()
        self.config = config
        self.self_dec = nn.ModuleList([SelfDecoder(config) for _ in range(n_layer // 2)])
        self.cross_dec = nn.ModuleList([CrossDecoder(config) for _ in range(n_layer // 2)])

        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.to_k = nn.Linear(config.dim, config.dim, bias=False)
        self.to_v = nn.Linear(config.dim, config.dim, bias=False)

        self.kv_cache = False
        self.k_cache = None
        self.v_cache = None

    def reset_kv_cache(self):
        self.k_cache = None
        self.v_cache = None

    def update_kv_cache(self, k: torch.Tensor, v: torch.Tensor, head_first=False):
        t_dim = 2 if head_first else 1

        if self.k_cache is None and self.v_cache is None:
            k_cache = k
            v_cache = v
        else:
            k_cache = torch.cat([self.k_cache, k], dim=t_dim)
            v_cache = torch.cat([self.v_cache, v], dim=t_dim)

        self.k_cache = k_cache
        self.v_cache = v_cache

        return k_cache, v_cache

    def forward(
        self,
        x: torch.Tensor,
        q: torch.Tensor,
        freqs_cis: torch.Tensor = None
    ):
        for layer in self.self_dec:
            x = layer(x=x, freqs_cis=freqs_cis)

        x_norm = self.norm(x)
        k = self.to_k(x_norm)
        v = self.to_v(x_norm)

        k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=self.config.n_head), (k, v))
        k = apply_rotary_emb(k, freqs_cis[:, :k.shape[1], ...])

        if self.kv_cache:
            k, v = self.update_kv_cache(k, v)

        for layer in self.cross_dec:
            q = layer(x=q, k=k, v=v, freqs_cis=freqs_cis)

        return q
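# Decoder_Decoder splits the n_layer blocks in half: the first half (SelfDecoder) encodes the visible
# context tokens with self-attention, a single shared K/V projection (with rotary embedding applied to
# the keys) is taken from the normalized self-decoder output, and the second half (CrossDecoder) lets
# mask-token queries attend to it. During generation the shared K/V are cached, so each step only
# encodes the newly sampled tokens.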
class Transformer(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.image_seq_len = config.block_size

        """
        ref: https://github.com/bytedance/1d-tokenizer/blob/main/modeling/rar.py
        Token space:
        [0, vocab_size - 1]                        : those are the learned quantized image tokens
        [vocab_size]                               : the mask token id
        [vocab_size + 1, vocab_size + num_classes] : the imagenet class tokens
        [vocab_size + num_classes + 1]             : the class drop label
        [vocab_size + num_classes + 2]             : the drop token for scg
        """
        self.embeddings = nn.Embedding(config.vocab_size + 1 + config.num_classes + 1 + 1, config.dim)
        self.embed_drop = nn.Dropout(config.token_dropout_p)

        self.mask_token_id = config.vocab_size
        self.none_conds_id = config.vocab_size + config.num_classes + 1
        self.none_token_id = config.vocab_size + config.num_classes + 2
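        # With the defaults (vocab_size=16384, num_classes=1000) this gives, illustratively:
        #   image tokens    0 ... 16383
        #   mask_token_id   16384
        #   class tokens    16385 ... 17384
        #   none_conds_id   17385   (class-drop label for CFG)
        #   none_token_id   17386   (drop token)
        # i.e. an embedding table with 17387 entries.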
        self.layers = Decoder_Decoder(config, config.n_layer)

        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.head = nn.Linear(config.dim, config.vocab_size, bias=False)

        grid_size = int(self.image_seq_len ** 0.5)
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, config.dim // config.n_head, config.rope_base, config.cls_token_num)

        self.initialize_weights()

    def initialize_weights(self):
        self.apply(self._init_weights)

        nn.init.constant_(self.head.weight, 0)

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def setup_kv_cache(self, enable=True):
        for block in self.layers.self_dec:
            block.attn.kv_cache = enable
            block.attn.reset_kv_cache()

        self.layers.kv_cache = enable
        self.layers.reset_kv_cache()
    def preprocess_condition(self, condition, cond_drop_prob=0.0):
        # Shift class labels into the class-token id range and randomly replace some
        # with the class-drop label (classifier-free guidance).
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        condition = condition + self.config.vocab_size + 1
        condition[drop_label_mask] = self.none_conds_id

        if condition.ndim == 1:
            condition = condition.unsqueeze(-1)

        return condition

    def forward_shared(self, input_ids, freqs_cis, num_query=None):
        embedds = self.embeddings(input_ids)

        x = self.embed_drop(embedds)
        num_query = input_ids.shape[-1] if num_query is None else num_query
        queries = self.embeddings(torch.full((input_ids.shape[0], num_query), self.mask_token_id, device=input_ids.device))

        x = self.layers(x, queries, freqs_cis=freqs_cis)
        logits = self.head(self.norm(x)).float()

        return logits
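    # forward_shared above is used for both training and generation: input_ids provides the visible
    # context (class token plus already-known image tokens), num_query mask-token queries are decoded
    # in parallel, and the returned logits have shape (batch, num_query, vocab_size).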
    def forward(self, input_ids, condition, targets=None, debug=False):
        condition = self.preprocess_condition(condition, cond_drop_prob=self.config.class_dropout_prob)

        shuffled_ids, orders = batch_seq_shuffle(input_ids)

        freqs_cis = self.freqs_cis.unsqueeze(0).repeat(input_ids.shape[0], 1, 1, 1).to(input_ids.device)
        fixed_freqs_cis = freqs_cis[:, :1, ...]
        shuffled_freqs_cis = batch_seq_shuffle(freqs_cis[:, 1:, ...], orders)[0]
        freqs_cis = torch.cat([fixed_freqs_cis, shuffled_freqs_cis], dim=1)

        logits = self.forward_shared(torch.cat([condition, shuffled_ids[:, :-1]], dim=-1), freqs_cis)

        loss = None
        if targets is not None:
            targets = batch_seq_shuffle(targets, orders)[0]
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
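    # Training scheme (as implemented above): image tokens and their 2D rotary frequencies are shuffled
    # with the same per-sample permutation. The model sees the class token plus all but the last
    # shuffled token, while each query carries the rotary frequency of the position it must predict, so
    # query i is trained to predict shuffled token i from the condition and the first i shuffled tokens
    # (the causal masks in Attention/CrossAttention enforce this during training).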
    @torch.inference_mode()
    def generate(
        self,
        condition,
        guidance_scale=4.0,
        cfg_schedule='linear',
        sample_schedule='arccos',
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        seq_len=256,
        num_iter=64,
    ):
        device = condition.device
        num_samples = condition.shape[0]
        freqs_cis_ = self.freqs_cis.unsqueeze(0).to(device)

        condition = self.preprocess_condition(condition, cond_drop_prob=0.0)

        # Random generation order over the seq_len image positions; the +1 offset skips the
        # class-token slot in freqs_cis.
        orders = torch.rand(seq_len, device=device).argsort(dim=0) + 1

        last_pos = 0
        last_range = range(0, 1)
        sequences = []

        self.setup_kv_cache(enable=True)
        for step in range(num_iter):
            if sample_schedule == 'arccos':
                mask_ratio = np.arccos(1. * (step + 1) / num_iter) / (math.pi * 0.5)
            elif sample_schedule == 'cosine':
                mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            else:
                raise NotImplementedError

            mask_len = int(seq_len * mask_ratio)
            mask_len = max(1, min(seq_len - last_pos - 1, mask_len))

            num_pred = seq_len - last_pos - mask_len
            if step == num_iter - 1:
                num_pred = seq_len - last_pos

            next_range = orders[range(last_pos, last_pos + num_pred)]
            last_pos += num_pred

            if cfg_schedule == 'linear':
                cfg_scale = 1.0 + (guidance_scale - 1.0) * last_pos / seq_len
            elif cfg_schedule == 'constant':
                cfg_scale = guidance_scale
            else:
                raise NotImplementedError

            # 1. The cached keys already have rotary embeddings applied, so only the current
            #    positions' frequencies are needed for the new keys.
            # 2. The queries take the next positions' frequencies to achieve target-aware guidance.
            freqs_cis = torch.cat([
                freqs_cis_[:, last_range, ...],
                freqs_cis_[:, next_range, ...],
            ], dim=1)
            if guidance_scale != 0:
                if step == 0:
                    input_ids = torch.cat([condition, torch.full_like(condition, self.none_conds_id)], dim=0)
                else:
                    input_ids = torch.cat([sequences[-1], sequences[-1]], dim=0)

                logits = self.forward_shared(input_ids, freqs_cis, num_pred)
                cond_logits, uncond_logits = logits[:num_samples], logits[num_samples:]
                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
            else:
                raise NotImplementedError

            logits = logits[:, -num_pred:] / max(temperature, 1e-5)

            if top_k > 0 or top_p < 1.0:
                logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            probs = F.softmax(logits, dim=-1)
            sampled = torch.multinomial(probs.flatten(0, 1), num_samples=1)
            sequences.append(sampled.reshape(num_samples, -1))

            last_range = next_range

        self.setup_kv_cache(enable=False)

        sequences = torch.cat(sequences, dim=-1)
        # Undo the random generation order to recover raster order.
        return sequences[:, orders.argsort(dim=0)]
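    # Illustrative schedule: with seq_len=256 and num_iter=64, the arccos schedule predicts only a few
    # tokens per step at the beginning (mask_ratio ~ 0.99 at step 0) and progressively larger chunks
    # towards the end, while the 'linear' CFG schedule ramps the guidance scale from ~1.0 up to
    # guidance_scale as more positions are decoded.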
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])
    return cond_cache
def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    half_dim = n_elem // 2
    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
    t = torch.arange(grid_size, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_grid = torch.cat([
        freqs[:, None, :].expand(-1, grid_size, -1),
        freqs[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1)
    cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1)
    cache = cache_grid.flatten(0, 1)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])
    return cond_cache
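# precompute_freqs_cis_2d returns a (cls_token_num + grid_size**2, n_elem // 2, 2) table of (cos, sin)
# pairs: the first half of each head dimension rotates with the row index and the second half with the
# column index (axial 2D RoPE). The first cls_token_num rows are zeros, reserved for the class token.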
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
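# apply_rotary_emb treats consecutive channel pairs (x0, x1) as complex numbers and rotates them by
# the cached angles: (x0, x1) -> (x0 * cos - x1 * sin, x1 * cos + x0 * sin).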
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (..., vocabulary size).
        if top_k > 0: keep only the top k tokens with the highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
    Make sure we keep at least min_tokens_to_keep per batch example in the output.
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove all tokens with a probability below the last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep tokens.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the mask to the right to also keep the first token above the threshold.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter back to the original (unsorted) vocabulary order; use the last dim so this also
        # handles the 3D (batch, positions, vocab) logits passed in by generate().
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
def ARPG_XXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))


def ARPG_XL(**kwargs):
    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))


def ARPG_L(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))


ARPG_models = {'ARPG-L': ARPG_L, 'ARPG-XL': ARPG_XL, 'ARPG-XXL': ARPG_XXL}
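# Minimal smoke test (illustrative only: the tiny hyperparameters below are arbitrary and do not
# correspond to any official ARPG configuration).
if __name__ == '__main__':
    cfg = ModelArgs(dim=64, n_layer=2, n_head=4, vocab_size=1024, num_classes=10, block_size=16)
    model = Transformer(cfg)

    # Random codebook indices and class labels standing in for tokenized images.
    tokens = torch.randint(0, cfg.vocab_size, (2, cfg.block_size))
    labels = torch.randint(0, cfg.num_classes, (2,))

    # Training-style forward pass: predict the (shuffled) tokens themselves.
    logits, loss = model(tokens, labels, targets=tokens)
    print('train logits:', tuple(logits.shape), 'loss:', float(loss))

    # Parallel generation with classifier-free guidance.
    model.eval()
    samples = model.generate(labels, guidance_scale=2.0, seq_len=cfg.block_size, num_iter=4)
    print('sampled tokens:', tuple(samples.shape))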