nimz12 committed
Commit 85a0876 (1 parent: 0ad8e2d)

Upload 5 files
Files changed (5)
  1. README.md +68 -3
  2. config.json +48 -0
  3. configuration_rene.py +103 -0
  4. model.safetensors +3 -0
  5. rene.py +435 -0
README.md CHANGED
@@ -1,3 +1,68 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - en
+ datasets:
+ - allenai/dolma
+ tags:
+ - rene
+ - mamba
+ - cartesia
+ ---
+
+ # Model Card for Rene
+
+ Rene is a 1.3 billion-parameter language model trained by [Cartesia](https://cartesia.ai).
+ Rene has a hybrid architecture based on [Mamba-2](https://arxiv.org/abs/2405.21060), with feedforward and sliding window attention layers interspersed.
+ It uses the [allenai/OLMo-1B-hf](https://huggingface.co/allenai/OLMo-1B-hf) tokenizer.
+ Rene was pretrained on 1.5 trillion tokens of the [Dolma-1.7](https://huggingface.co/datasets/allenai/dolma) dataset.
+ For more details, see our [blog post](https://cartesia.ai/blog/on-device).
+
+ ## Usage
+ ### Installation
+ The Rene model depends on the `cartesia-pytorch` package, which can be installed with `pip` as follows:
+ ```shell
+ pip install --no-binary :all: cartesia-pytorch
+ ```
+
+ ### Generation example
+ ```python
+ from cartesia_pytorch import ReneLMHeadModel
+ from transformers import AutoTokenizer
+
+ model = ReneLMHeadModel.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-pytorch").half().cuda()
+ tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
+ in_message = ["Rene Descartes was"]
+ inputs = tokenizer(in_message, return_tensors="pt")
+ outputs = model.generate(inputs.input_ids.cuda(), max_length=50, top_k=100, top_p=0.99)
+ out_message = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+ print(out_message)
+ # Example output: "Rene Descartes was a French mathematician, philosopher, and scientist. Descartes is famously credited for creating the Cartesian coordinate system: a 3 dimensional representation of points, vectors, and directions. This work is, for the most part" ...
+ ```
+
+ ### Evaluation example
+ You can use our `cartesia_lm_eval` wrapper around the [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) to evaluate our model on standard text benchmarks. Example command (clone this repo and run the command below from within the `cartesia-pytorch` directory):
+ ```shell
+ python -m evals.cartesia_lm_eval --model rene_ssm --model_args pretrained=cartesia-ai/Rene-v0.1-1.3b-pytorch,trust_remote_code=True --trust_remote_code --tasks copa,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --cache_requests true --batch_size auto:4 --output_path outputs/rene_evals/
+ ```
+ ## Results on common benchmarks
+ | Model | Params (B) | Train Tokens (T) | COPA | HellaSwag | MMLU (5-shot) | PIQA | ARC-e | ARC-c | WinoGrande | OpenBookQA | Average |
+ |------------------------------------------------|------------|------------------|------|-----------|---------------|------|-------|-------|------------|------------|---------|
+ | allenai/OLMo-1B-hf | 1.2 | 3.0 | 82.0 | 62.9 | 26.2 | 75.1 | 57.4 | 31.1 | 60.0 | 36.2 | 53.9 |
+ | apple/OpenELM-1\_1B | 1.1 | 1.5 | 81.0 | 64.8 | 27.1 | 75.6 | 55.4 | 32.3 | 61.9 | 36.2 | 54.3 |
+ | state-spaces/mamba2-1.3b | 1.3 | 0.3 | 82.0 | 60.0 | 25.8 | 73.7 | 64.2 | 33.3 | 61.0 | 37.8 | 54.7 |
+ | microsoft/phi-1\_5 | 1.4 | 0.15 | 79.0 | 62.6 | 42.5 | 75.5 | 73.2 | 48.0 | 72.8 | 48.0 | 62.7 |
+ | Qwen/Qwen2-1.5B | 1.5 | 7.0 | 80.0 | 65.4 | 56.0 | 75.5 | 60.4 | 35.0 | 65.8 | 36.4 | 59.3 |
+ | RWKV/rwkv-6-world-1b6 | 1.6 | 1.1 | 84.0 | 58.3 | 25.9 | 73.5 | 56.7 | 34.1 | 60.0 | 37.4 | 53.7 |
+ | stabilityai/stablelm-2-1\_6b | 1.6 | 4.0 | 86.0 | 69.0 | 38.1 | 76.7 | 68.1 | 38.9 | 63.6 | 38.8 | 59.9 |
+ | HuggingFaceTB/SmolLM-1.7B | 1.7 | 1.0 | 76.0 | 65.8 | 29.9 | 76.1 | 73.5 | 46.4 | 60.9 | 42.0 | 58.8 |
+ | h2oai/h2o-danube2-1.8b-base | 1.8 | 3.0 | 82.0 | 72.4 | 39.9 | 77.3 | 69.0 | 39.9 | 63.9 | 41.4 | 60.7 |
+ | google/recurrentgemma-2b | 2.7 | 2.0 | 62.0 | 61.8 | 32.3 | 68.8 | 46.4 | 29.9 | 57.1 | 29.0 | 48.4 |
+ | cognitivecomputations/TinyDolphin-2.8.1-1.1b | 1.1 | | 71.0 | 59.9 | 25.7 | 73.1 | 55.8 | 33.0 | 59.7 | 36.6 | 51.9 |
+ | cartesia-ai/Rene-v0.1-1.3b-pytorch (OUR MODEL) | 1.3 | 1.5 | 82.0 | 69.4 | 32.6 | 77.5 | 61.7 | 34.4 | 62.9 | 39.2 | 57.5 |
+
+ ## Bias, Risks, and Limitations
+ Rene is a pretrained base model that has not undergone any alignment or instruction tuning, and therefore comes with no moderation or safety guarantees. Users should implement appropriate guardrails and moderation mechanisms based on their particular needs to ensure responsible and ethical usage.
+
+ ## About Cartesia
+ At [Cartesia](https://cartesia.ai/), we're building real-time multimodal intelligence for every device.
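
Editor's note (not part of this commit): beyond sampling with `generate`, the model's forward pass can be used to score text, since `ReneLMHeadModel.forward` in `rene.py` below returns a `CausalLMOutput` with logits. The following is a minimal sketch under that assumption, reusing the model and tokenizer from the generation example above; the prompt string is arbitrary.

```python
# Illustrative scoring sketch (not part of the commit): average per-token
# log-probability of a prompt, computed from the forward-pass logits.
import torch
from transformers import AutoTokenizer
from cartesia_pytorch import ReneLMHeadModel

model = ReneLMHeadModel.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-pytorch").half().cuda()
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")

text = "Rene Descartes was a French philosopher."
input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()

with torch.no_grad():
    logits = model(input_ids).logits.float()  # (batch, seqlen, vocab)

# Position t predicts token t+1, so align logits[:, :-1] with input_ids[:, 1:].
logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
token_logprobs = logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
print(f"average log-prob per token: {token_logprobs.mean().item():.3f}")
```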
config.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "attn_cfg": {
+     "causal": true,
+     "head_dim": 64,
+     "num_heads": 48,
+     "out_proj_bias": true,
+     "qkv_proj_bias": true,
+     "sliding_window_length": 2048
+   },
+   "attn_layer_idx": [
+     6,
+     18,
+     30,
+     42
+   ],
+   "d_model": 2048,
+   "eos_token_id": 50279,
+   "mlp_cfg": {},
+   "mlp_layer_idx": [
+     2,
+     5,
+     8,
+     11,
+     14,
+     17,
+     20,
+     23,
+     26,
+     29,
+     32,
+     35,
+     38,
+     41,
+     44,
+     47
+   ],
+   "model_type": "rene",
+   "n_layer": 48,
+   "pad_token_id": 1,
+   "pad_vocab_size_multiple": 16,
+   "residual_in_fp32": true,
+   "rms_norm": true,
+   "ssm_cfg": {
+     "norm_before_gate": true
+   },
+   "tie_word_embeddings": true,
+   "vocab_size": 50280
+ }
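
Editor's note (not part of this commit): the block schedule implied by this config can be recovered with a few lines of Python that mirror the dispatch in `_create_block` from `rene.py` below: layers in `attn_layer_idx` use sliding-window attention, layers in `mlp_layer_idx` use the GELU MLP, and every remaining layer is a Mamba-2 mixer. The relative path to `config.json` assumes a local checkout of this repository.

```python
# Illustrative sketch: derive the per-layer mixer type from config.json.
import json
from collections import Counter

with open("config.json") as f:
    cfg = json.load(f)

attn = set(cfg["attn_layer_idx"])  # {6, 18, 30, 42}
mlp = set(cfg["mlp_layer_idx"])    # every third layer starting at index 2
schedule = [
    "attention" if i in attn else "mlp" if i in mlp else "mamba2"
    for i in range(cfg["n_layer"])
]
print(Counter(schedule))  # expected: 28 mamba2, 16 mlp, 4 attention
```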
configuration_rene.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Dict, List, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class ReneConfig(PretrainedConfig):
+     r"""Configuration class for the Rene model.
+
+     This is the configuration class to store the configuration of a [`ReneLMHeadModel`].
+     It is used to instantiate a Rene model according to the specified arguments,
+     defining the model architecture. Instantiating a configuration with the defaults will yield
+     a similar configuration to that of the Rene-v0.1-1.3b-pytorch model.
+     [cartesia-ai/Rene-v0.1-1.3b-pytorch](https://huggingface.co/cartesia-ai/Rene-v0.1-1.3b-pytorch)
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         d_model (`int`, *optional*, defaults to 2048):
+             Dimension of the hidden representations.
+         n_layer (`int`, *optional*, defaults to 48):
+             Number of architecture blocks.
+         vocab_size (`int`, *optional*, defaults to 50280):
+             Vocabulary size of the Rene model. Defines the number of different tokens that can be represented by the
+             `input_ids` passed when calling [`ReneModel`].
+         ssm_cfg (`dict`, *optional*):
+             Configuration parameters for the SSM layers.
+         attn_layer_idx (`List[int]`, *optional*):
+             Indices of the architecture blocks that should have attention layers.
+         attn_cfg (`dict`, *optional*):
+             Configuration parameters for the attention layers.
+         mlp_layer_idx (`List[int]`, *optional*):
+             Indices of the architecture blocks that should have MLP layers.
+         mlp_cfg (`dict`, *optional*):
+             Configuration parameters for the MLP layers.
+         rms_norm (`bool`, *optional*, defaults to `True`):
+             Whether to use RMSNorm (instead of LayerNorm).
+         residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+             Whether to keep residual values in fp32.
+         pad_vocab_size_multiple (`int`, *optional*, defaults to 16):
+             Pad the vocabulary size up to the next multiple of this value.
+         tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+             Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+             model has an output word embedding layer.
+         pad_token_id (`int`, *optional*, defaults to 1):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 50279):
+             The id of the "end-of-sequence" token.
+     """
+
+     model_type = "rene"
+
+     def __init__(
+         self,
+         d_model: int = 2048,
+         n_layer: int = 48,
+         vocab_size: int = 50280,
+         ssm_cfg: Optional[Dict] = None,
+         attn_layer_idx: Optional[List] = None,
+         attn_cfg: Optional[Dict] = None,
+         mlp_layer_idx: Optional[List] = None,
+         mlp_cfg: Optional[Dict] = None,
+         rms_norm: bool = True,
+         residual_in_fp32: bool = True,
+         pad_vocab_size_multiple: int = 16,
+         tie_word_embeddings: bool = True,
+         pad_token_id=1,
+         bos_token_id=None,
+         eos_token_id=50279,
+         **kwargs,
+     ):
+         if ssm_cfg is None:
+             ssm_cfg = {}
+         if attn_layer_idx is None:
+             attn_layer_idx = []
+         if attn_cfg is None:
+             attn_cfg = {}
+         if mlp_layer_idx is None:
+             mlp_layer_idx = []
+         if mlp_cfg is None:
+             mlp_cfg = {}
+
+         self.d_model = d_model
+         self.n_layer = n_layer
+         self.vocab_size = vocab_size
+         self.ssm_cfg = ssm_cfg
+         self.attn_layer_idx = attn_layer_idx
+         self.attn_cfg = attn_cfg
+         self.mlp_layer_idx = mlp_layer_idx
+         self.mlp_cfg = mlp_cfg
+         self.rms_norm = rms_norm
+         self.residual_in_fp32 = residual_in_fp32
+         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+         self.tie_word_embeddings = tie_word_embeddings
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
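
Editor's note (not part of this commit): since `ReneConfig` inherits from `PretrainedConfig`, it can be constructed directly and round-tripped with the standard `save_pretrained` / `from_pretrained` API. The sketch below assumes it is run from a checkout of this repository (so `configuration_rene.py` is importable) and uses the values from `config.json` above; the output directory name is arbitrary.

```python
# Minimal sketch: build a ReneConfig matching config.json and round-trip it.
from configuration_rene import ReneConfig

config = ReneConfig(
    d_model=2048,
    n_layer=48,
    vocab_size=50280,
    attn_layer_idx=[6, 18, 30, 42],
    attn_cfg={
        "causal": True,
        "head_dim": 64,
        "num_heads": 48,
        "out_proj_bias": True,
        "qkv_proj_bias": True,
        "sliding_window_length": 2048,
    },
    mlp_layer_idx=list(range(2, 48, 3)),  # 2, 5, ..., 47 as in config.json
    ssm_cfg={"norm_before_gate": True},
    tie_word_embeddings=True,
)
config.save_pretrained("rene-config")  # writes rene-config/config.json
reloaded = ReneConfig.from_pretrained("rene-config")
assert reloaded.attn_layer_idx == [6, 18, 30, 42]
```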
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a62c98beb82cd70e4ff866b3cd479f836f17676a76b82a337a1dde2126673de
+ size 2866628624
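
Editor's note (not part of this commit): what is committed here is only the Git LFS pointer (~2.87 GB object). Once the actual weights are downloaded, the checkpoint can be inspected with the `safetensors` library; the sketch below assumes a local copy named `model.safetensors`.

```python
# Illustrative sketch: count tensors and parameters in the downloaded checkpoint.
from safetensors.torch import load_file

state = load_file("model.safetensors")  # dict of tensor name -> torch.Tensor
total = sum(t.numel() for t in state.values())
print(f"{len(state)} tensors, {total / 1e9:.2f}B parameters")
```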
rene.py ADDED
@@ -0,0 +1,435 @@
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from flash_attn import flash_attn_with_kvcache
+ from mamba_ssm.models.mixer_seq_simple import _init_weights
+ from mamba_ssm.modules.mamba2 import Mamba2
+ from mamba_ssm.modules.mha import _update_kv_cache
+ from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin
+ from transformers.modeling_outputs import CausalLMOutput
+ from transformers.modeling_utils import PreTrainedModel
+
+ from .configuration_rene import ReneConfig
+
+
+ class ReneMLP(nn.Module):
+     """One-hidden-layer network with GELU activation.
+
+     Args:
+         d_input: Block input dimension.
+         d_output: Block output dimension.
+         expand: Block expansion factor.
+         bias: Use biases in linear layers.
+     """
+
+     def __init__(self, d_input, d_output=None, expand=3, bias=True, device=None, dtype=None):
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.d_input = d_input
+         self.d_output = d_input if d_output is None else d_output
+         self.d_inner = int(round(expand * d_input))
+         self.in_proj = nn.Linear(self.d_input, self.d_inner, bias=bias, **factory_kwargs)
+         self.activation = nn.GELU()
+         self.out_proj = nn.Linear(self.d_inner, self.d_input, bias=bias, **factory_kwargs)
+
+     def forward(self, x, inference_params=None):
+         """Forward pass through the MLP module."""
+         y = self.in_proj(x)
+         y = self.activation(y)
+         y = self.out_proj(y)
+         return y
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         """Allocate inference cache for ReneMLP. (There is nothing to cache for this module)."""
+         return None
+
+
+ class ReneMHA(nn.Module):
+     """Multi-head self-attention. Adapted from mamba_ssm MHA class."""
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         num_heads_kv=None,
+         head_dim=None,  # If None, use embed_dim // num_heads
+         qkv_proj_bias=True,
+         out_proj_bias=True,
+         softmax_scale=None,
+         causal=True,
+         sliding_window_length=None,  # If None, infinite context
+         layer_idx=None,
+         device=None,
+         dtype=None,
+     ) -> None:
+         """
+         num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
+         return_residual: whether to return the input x along with the output. This is for
+             performance reason: for post-norm architecture, returning the input allows us
+             to fuse the backward of nn.Linear with the residual connection.
+         """
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.embed_dim = embed_dim
+         self.layer_idx = layer_idx
+         self.softmax_scale = softmax_scale
+         self.causal = causal
+         assert self.causal, "Rene does not yet support non-causal modeling"
+
+         self.num_heads = num_heads
+         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
+         assert (
+             self.num_heads % self.num_heads_kv == 0
+         ), "num_heads must be divisible by num_heads_kv"
+         if head_dim is None:
+             assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
+         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
+         out_dim = self.head_dim * self.num_heads
+
+         self.sliding_window_length = sliding_window_length
+         if self.sliding_window_length is None:
+             self.window_size = (-1, -1)
+         else:
+             self.window_size = (self.sliding_window_length - 1, 0)  # for flash_attn
+
+         self.in_proj = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
+         self.out_proj = nn.Linear(out_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
+         """Allocate inference cache for the multi-head self-attention module."""
+         dtype = self.out_proj.weight.dtype if dtype is None else dtype
+         device = self.out_proj.weight.device
+         kv_cache = torch.empty(
+             batch_size,
+             max_seqlen,
+             2,
+             self.num_heads_kv,
+             self.head_dim,
+             dtype=dtype,
+             device=device,
+         )
+         return kv_cache, None
+
+     def _pytorch_attn(self, q, kv):
+         k, v = kv.unbind(dim=-3)
+         k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
+         v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+         L, S = q.size(-2), k.size(-2)
+         if S > self.sliding_window_length:
+             attn_mask = (
+                 torch.ones(L, S, dtype=torch.bool)
+                 .tril(diagonal=0)
+                 .triu(-self.window_size[0])
+                 .to(device=q.device)
+             )
+             # Since we pass in an attn_mask explicitly, we need to pass is_causal=False to
+             # `scaled_dot_product_attention` (even though the attn_mask itself is in fact causal).
+             is_causal_arg = False
+         else:
+             # The previous branch would also handle this case correctly, but it is more efficient
+             # to omit the attn_mask when we don't need it.
+             attn_mask = None
+             is_causal_arg = True
+         return F.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, is_causal=is_causal_arg, scale=self.softmax_scale
+         ).transpose(1, 2)
+
+     def _update_kv_cache(self, kv, inference_params):
+         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)."""
+         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+         return _update_kv_cache(kv, inference_params, self.layer_idx)
+
+     def _update_kvcache_attention(self, q, kv, inference_params):
+         """Write kv to inference_params, then compute attention."""
+         if inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None:
+             # TODO: this only uses seqlen_offset and not lengths_per_sample.
+             kv = self._update_kv_cache(kv, inference_params)
+             return self._pytorch_attn(q, kv)
+         else:
+             batch = q.shape[0]
+             kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
+             kv_cache = kv_cache[:batch]
+             cache_seqlens = (
+                 inference_params.lengths_per_sample[:batch]
+                 if inference_params.lengths_per_sample is not None
+                 else inference_params.seqlen_offset
+             )
+             return flash_attn_with_kvcache(
+                 q,
+                 kv_cache[:, :, 0],
+                 kv_cache[:, :, 1],
+                 kv[:, :, 0],
+                 kv[:, :, 1],
+                 cache_seqlens=cache_seqlens,
+                 softmax_scale=self.softmax_scale,
+                 causal=self.causal,
+                 window_size=self.window_size,
+             )
+
+     def forward(self, x, inference_params=None):
+         """Forward pass through the multi-head self-attention module."""
+         if (
+             inference_params is not None
+             and self.layer_idx not in inference_params.key_value_memory_dict
+         ):
+             inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
+                 x.shape[0], inference_params.max_seqlen, dtype=x.dtype
+             )
+         qkv = self.in_proj(x)
+         q, kv = qkv.split(
+             [self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1
+         )
+         q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+         kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+         if inference_params is None:
+             context = self._pytorch_attn(q, kv)
+         else:
+             context = self._update_kvcache_attention(q, kv, inference_params)
+         context = rearrange(context, "... h d -> ... (h d)")
+         out = self.out_proj(context)
+         return out
+
+
+ class Block(nn.Module):
+     """Simple residual block with normalization that wraps an inner "mixer" module."""
+
+     def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False):
+         """
+         dim: The dimension of the input data.
+         mixer_cls: The class of the mixer module.
+         norm_cls: The class of the normalization module.
+         residual_in_fp32: Whether to keep residuals in fp32.
+         """
+         super().__init__()
+         self.residual_in_fp32 = residual_in_fp32
+         self.norm = norm_cls(dim)
+         self.mixer = mixer_cls(dim)
+
+     def forward(self, x, inference_params=None, **mixer_kwargs):
+         """Forward pass through the block."""
+         y = self.norm(x.to(dtype=self.norm.weight.dtype))
+         y = self.mixer(y, inference_params=inference_params, **mixer_kwargs)
+
+         residual = x
+         if self.residual_in_fp32:
+             residual = residual.to(torch.float32)
+         y = y + residual
+         y = y.to(dtype=x.dtype)
+
+         return y
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         """Allocate inference cache for the mixer module."""
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+
+ def _create_block(
+     d_model,
+     norm_cls,
+     ssm_cfg=None,
+     attn_layer_idx=None,
+     attn_cfg=None,
+     mlp_layer_idx=None,
+     mlp_cfg=None,
+     residual_in_fp32=False,
+     layer_idx=None,
+     device=None,
+     dtype=None,
+ ):
+     factory_kwargs = {"device": device, "dtype": dtype}
+     if ssm_cfg is None:
+         ssm_cfg = {}
+     if attn_layer_idx is None:
+         attn_layer_idx = []
+     if attn_cfg is None:
+         attn_cfg = {}
+     if mlp_layer_idx is None:
+         mlp_layer_idx = []
+     if mlp_cfg is None:
+         mlp_cfg = {}
+     if layer_idx in attn_layer_idx:
+         mixer_cls = partial(ReneMHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+     elif layer_idx in mlp_layer_idx:
+         mixer_cls = partial(ReneMLP, **mlp_cfg, **factory_kwargs)
+     else:
+         mixer_cls = partial(Mamba2, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
+     return Block(d_model, mixer_cls, norm_cls=norm_cls, residual_in_fp32=residual_in_fp32)
+
+
+ class MixerModel(nn.Module):
+     """Adapted from mamba_ssm.models.mixer_seq_simple.MixerModel."""
+
+     def __init__(
+         self,
+         d_model: int,
+         n_layer: int,
+         vocab_size: int,
+         ssm_cfg=None,
+         attn_layer_idx=None,
+         attn_cfg=None,
+         mlp_layer_idx=None,
+         mlp_cfg=None,
+         norm_epsilon: float = 1e-5,
+         rms_norm: bool = False,
+         initializer_cfg=None,
+         residual_in_fp32=False,
+         device=None,
+         dtype=None,
+     ) -> None:
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.residual_in_fp32 = residual_in_fp32
+
+         if rms_norm:
+             from mamba_ssm.ops.triton.layer_norm import RMSNorm as norm_cls_base
+         else:
+             norm_cls_base = nn.LayerNorm
+         norm_cls = partial(norm_cls_base, eps=norm_epsilon, **factory_kwargs)
+
+         self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
+
+         self.layers = nn.ModuleList(
+             [
+                 _create_block(
+                     d_model,
+                     norm_cls=norm_cls,
+                     ssm_cfg=ssm_cfg,
+                     attn_layer_idx=attn_layer_idx,
+                     attn_cfg=attn_cfg,
+                     mlp_layer_idx=mlp_layer_idx,
+                     mlp_cfg=mlp_cfg,
+                     residual_in_fp32=residual_in_fp32,
+                     layer_idx=i,
+                     **factory_kwargs,
+                 )
+                 for i in range(n_layer)
+             ]
+         )
+
+         self.norm_f = norm_cls(d_model)
+
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=n_layer,
+                 **(initializer_cfg if initializer_cfg is not None else {}),
+                 n_residuals_per_layer=1,
+             )
+         )
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         """Allocate inference cache for all layers."""
+         return {
+             i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+             for i, layer in enumerate(self.layers)
+         }
+
+     def forward(self, input_ids, inference_params=None, **mixer_kwargs):
+         """Forward pass through the model."""
+         hidden_states = self.embedding(input_ids)
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, inference_params=inference_params, **mixer_kwargs)
+         hidden_states = self.norm_f(hidden_states.to(dtype=self.norm_f.weight.dtype))
+         return hidden_states
+
+
+ class ReneLMHeadModel(PreTrainedModel, MambaGenerationMixin):
+     """
+     Rene language model architecture.
+     Based on mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel, with several adaptations.
+     """
+
+     config_class = ReneConfig
+     base_model_prefix = "backbone"
+     _no_split_modules = ["Block", "Mamba2"]
+     supports_gradient_checkpointing = True
+     _is_stateful = True
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(
+         self,
+         config: ReneConfig,
+         initializer_cfg=None,
+         device=None,
+         dtype=None,
+     ) -> None:
+         super().__init__(config)
+         d_model = config.d_model
+         n_layer = config.n_layer
+         vocab_size = config.vocab_size
+         ssm_cfg = config.ssm_cfg
+         attn_layer_idx = config.attn_layer_idx
+         attn_cfg = config.attn_cfg
+         mlp_layer_idx = config.mlp_layer_idx
+         mlp_cfg = config.mlp_cfg
+         rms_norm = config.rms_norm
+         residual_in_fp32 = config.residual_in_fp32
+         pad_vocab_size_multiple = config.pad_vocab_size_multiple
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         if set(attn_layer_idx).intersection(mlp_layer_idx):
+             raise ValueError(f"Conflicting {attn_layer_idx=} and {mlp_layer_idx=}")
+
+         if vocab_size % pad_vocab_size_multiple != 0:
+             vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+
+         self.backbone = MixerModel(
+             d_model=d_model,
+             n_layer=n_layer,
+             vocab_size=vocab_size,
+             ssm_cfg=ssm_cfg,
+             attn_layer_idx=attn_layer_idx,
+             attn_cfg=attn_cfg,
+             mlp_layer_idx=mlp_layer_idx,
+             mlp_cfg=mlp_cfg,
+             rms_norm=rms_norm,
+             initializer_cfg=initializer_cfg,
+             residual_in_fp32=residual_in_fp32,
+             **factory_kwargs,
+         )
+         self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+         # Initialize weights
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=n_layer,
+                 **(initializer_cfg if initializer_cfg is not None else {}),
+             )
+         )
+         self.tie_weights()
+
+     def tie_weights(self):
+         """Tie embeddings and softmax layer weights if specified by config."""
+         if self.config.tie_word_embeddings:
+             self.lm_head.weight = self.backbone.embedding.weight
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         """Allocate inference cache."""
+         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+     def forward(
+         self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs
+     ):
+         """
+         "position_ids" is just to be compatible with Transformer generation. We don't use it.
+         num_last_tokens: if > 0, only return the logits for the last n tokens.
+         """
+         hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
+         if num_last_tokens > 0:
+             hidden_states = hidden_states[:, -num_last_tokens:]
+         lm_logits = self.lm_head(hidden_states)
+
+         return CausalLMOutput(logits=lm_logits)
+
+     def generate(self, *args, **kwargs):
+         """
+         Calls the custom `generate` method from `mamba_ssm.utils.generation.GenerationMixin`.
+         Refer to that method for argument names and defaults.
+         """
+         return MambaGenerationMixin.generate(self, *args, **kwargs)
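
Editor's note (not part of this commit): the least obvious piece of `ReneMHA` is the prefill path of `_pytorch_attn`, which builds a banded causal mask with `tril(diagonal=0).triu(-(window - 1))` so that query `i` attends only to keys `j` with `i - (window - 1) <= j <= i`. The standalone sketch below reproduces that mask for a toy sequence length and window size (values chosen only for illustration; the model itself uses a window of 2048).

```python
# Illustrative sketch of the sliding-window causal mask from ReneMHA._pytorch_attn.
import torch

seqlen, window = 6, 3
mask = torch.ones(seqlen, seqlen, dtype=torch.bool).tril(diagonal=0).triu(-(window - 1))
print(mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
```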