EthanReid
Initial import of code and 4-bit weights
1c87faa
raw
history blame
4.69 kB
import torch
import torch.nn as nn
from torch.nn import functional as F
from bitblas.cache import OperatorCache
from .layers import layer_norm, mlp, Linear
from .rope import apply_rotary_emb, precompute_freqs_cis
from .config import TextConfig
def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    """Look up the embedding row for every token id.

    Args:
        input_ids: Integer tensor of token ids; used as indices into ``w.wte``.
        w: Object exposing the embedding table as the attribute ``wte``.

    Returns:
        The result of ``F.embedding(input_ids, w.wte)``: one row of ``w.wte``
        per entry of ``input_ids``.
    """
    embedding_table = w.wte
    return F.embedding(input_ids, embedding_table)
def attn(
    x: torch.Tensor,
    w: nn.Module,
    freqs_cis: torch.Tensor,
    kv_cache: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    n_kv_heads: int,
    position_ids: torch.Tensor,
):
    """Self-attention over ``x`` using a fused qkv projection.

    Args:
        x: Input of shape ``(bsz, q_len, d_model)``.
        w: Attention weights; ``w.qkv`` and ``w.proj`` are called as the
            fused input projection and the output projection.
        freqs_cis: Rotary table forwarded to ``apply_rotary_emb``.
        kv_cache: Optional cache. When it is not ``None`` its
            ``update(position_ids, k, v)`` result replaces ``k`` and ``v``.
        attn_mask: Mask forwarded to ``F.scaled_dot_product_attention``.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads.
        position_ids: Positions forwarded to the rotary embedding and to the
            cache update.

    Returns:
        ``w.proj`` applied to the attention output reshaped to
        ``(bsz, q_len, d_model)``.
    """
    batch, seq_len, width = x.shape
    per_head = width // n_heads

    # Column offsets inside the fused projection, laid out as [q | k | v].
    q_end = n_heads * per_head
    k_end = q_end + n_kv_heads * per_head

    def to_heads(chunk, heads):
        # (batch, seq_len, heads * per_head) -> (batch, heads, seq_len, per_head)
        return chunk.view(batch, seq_len, heads, per_head).transpose(1, 2)

    # shape: (batch, seq_len, (n_heads + 2*n_kv_heads)*per_head)
    fused = w.qkv(x)
    q = to_heads(fused[..., :q_end], n_heads)
    k = to_heads(fused[..., q_end:k_end], n_kv_heads)
    v = to_heads(fused[..., k_end:], n_kv_heads)

    # Rotary embedding is applied to queries and keys only, before caching.
    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)

    if kv_cache is not None:
        k, v = kv_cache.update(position_ids, k, v)

    # Grouped-query attention is requested only when the head counts differ.
    grouped = n_heads != n_kv_heads
    attended = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, enable_gqa=grouped
    )

    merged = attended.transpose(1, 2).reshape(batch, seq_len, width)
    return w.proj(merged)
def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    position_ids: torch.Tensor,
    config: TextConfig,
):
    """Run ``x`` through every block in ``w.blocks`` and return the result.

    In each block the input is layer-normalised once, and that same
    normalised tensor feeds both the attention branch and the MLP branch.
    Both branch outputs are then added to the un-normalised block input.

    Args:
        x: Hidden states entering the first block.
        w: Text model weights; ``w.blocks`` and ``w.freqs_cis`` are read, and
            each block must expose ``ln``, ``attn``, ``mlp`` and ``kv_cache``.
        attn_mask: Mask forwarded unchanged to every attention call.
        position_ids: Positions forwarded unchanged to every attention call.
        config: Supplies ``n_heads`` and ``n_kv_heads``.

    Returns:
        The hidden states after the last block (``x`` itself when
        ``w.blocks`` is empty).
    """
    hidden = x
    for layer in w.blocks:
        normed = layer_norm(hidden, layer.ln)
        attn_branch = attn(
            normed,
            layer.attn,
            freqs_cis=w.freqs_cis,
            kv_cache=layer.kv_cache,
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            position_ids=position_ids,
        )
        mlp_branch = mlp(normed, layer.mlp)
        # Parallel residual: both branches read the same normalised input.
        hidden = hidden + attn_branch + mlp_branch
    return hidden
def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    """Compute logits from the last position of ``hidden_BTC``.

    Args:
        hidden_BTC: Rank-3 hidden states; only index ``-1`` along the second
            axis is used (the name suggests batch/time/channel — TODO confirm).
        w: Text model weights; ``w.post_ln`` is passed to ``layer_norm`` and
            ``w.lm_head`` is called on the normalised result.

    Returns:
        The output of ``w.lm_head`` for the selected position.
    """
    last_position = hidden_BTC[:, -1, :]
    normed = layer_norm(last_position, w.post_ln)
    return w.lm_head(normed)
def build_text_model(
    config: TextConfig,
    linear_dtype: torch.dtype = torch.float16,
    layernorm_dtype: torch.dtype = torch.float16,
) -> nn.Module:
    """Build the text model's module tree (weights are not loaded here).

    Args:
        config: Supplies ``dim``, ``ff_dim``, ``n_heads``, ``n_kv_heads``,
            ``n_layers``, ``vocab_size`` and ``max_context``; ``group_size``
            is read only when ``linear_dtype`` is ``torch.int8``.
        linear_dtype: dtype passed to every ``Linear`` created for the
            attention and MLP projections.
        layernorm_dtype: dtype used for the layer norms, ``lm_head`` and
            ``wte`` (not just the layer norms).

    Returns:
        An ``nn.ModuleDict`` with the entries ``blocks``, ``post_ln`` and
        ``lm_head``, plus a ``wte`` parameter and a non-persistent
        ``freqs_cis`` buffer.
    """
    print(
        "Initializing quantized backend. This only has to run once, but may take a few minutes."
    )

    # Width of the fused qkv projection: dim * (n_heads + 2 * n_kv_heads) / n_heads.
    # Kept in integer arithmetic so a layer size never depends on float
    # rounding followed by int() truncation.
    qkv_dim = config.dim * (config.n_heads + 2 * config.n_kv_heads) // config.n_heads

    # group_size is forwarded to Linear only for int8; otherwise it stays None.
    group_size = config.group_size if linear_dtype == torch.int8 else None

    def create_linear(in_features, out_features, dtype=linear_dtype):
        # Factory for Linear layers so the shared arguments are not repeated
        # at every call site.
        return Linear(
            in_features=in_features,
            out_features=out_features,
            dtype=dtype,
            group_size=group_size,
        )

    def create_block():
        # One transformer block: layer norm, fused-qkv attention, two-layer MLP.
        return nn.ModuleDict(
            {
                "ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
                "attn": nn.ModuleDict(
                    {
                        "qkv": create_linear(config.dim, qkv_dim),
                        "proj": create_linear(config.dim, config.dim),
                    }
                ),
                "mlp": nn.ModuleDict(
                    {
                        "fc1": create_linear(config.dim, config.ff_dim),
                        "fc2": create_linear(config.ff_dim, config.dim),
                    }
                ),
            }
        )

    text = nn.ModuleDict(
        {
            "blocks": nn.ModuleList(
                [create_block() for _ in range(config.n_layers)]
            ),
            "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype),
        }
    )

    # Token embedding table; allocated uninitialised.
    text.wte = nn.Parameter(
        torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
    )

    # Non-persistent, so the table is excluded from the state dict.
    text.register_buffer(
        "freqs_cis",
        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
        persistent=False,
    )
    return text