Improve model card: Add metadata, paper/code links, abstract, and usage examples
Browse filesThis PR significantly enhances the model card for `Llama3.2-Mamba2-3B-dpo` by:
* Adding the `pipeline_tag: text-generation`, making the model discoverable on the Hugging Face Hub (e.g., at https://huggingface.co/models?pipeline_tag=text-generation).
* Specifying `library_name: transformers`, which correctly indicates its compatibility with the Transformers library's `AutoTokenizer` and its underlying architecture.
* Including a direct link to the paper: [The Mamba in the Llama: Distilling and Accelerating Hybrid Models](https://huggingface.co/papers/2408.15237).
* Adding a link to the official GitHub repository: https://github.com/jxiw/MambaInLlama.
* Integrating the paper's abstract for quick context.
* Providing a clear Python usage example for direct inference with this specific model.
* Updating the performance evaluation section with more comprehensive tables and visuals from the original GitHub repository.
These additions will greatly improve the model's visibility, usability, and documentation for the community.
@@ -1,30 +1,217 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
3 |
---
|
4 |
|
5 |
-
|
6 |
|
7 |
-
|
8 |
-
|---------------|---------------------------------------------------------------------------------|-----------------------------------|-----------------------------------|-----------------------------------|-----------------------------------|
|
9 |
-
| Initialization Model | N/A | Llama-3.2-3B-Instruct | Llama-3.2-3B-Instruct | Llama-3.2-3B-Instruct | Llama-3.2-3B-Instruct |
|
10 |
-
| Teacher Model | N/A | Llama-3.1-70B-Instruct | Llama-3.1-70B-Instruct | Llama-3.1-70B-Instruct | Llama-3.1-70B-Instruct |
|
11 |
-
| arc_challenge | 0.459 | 0.4838 | 0.5265 | 0.4667 | 0.541 |
|
12 |
-
| arc_easy | 0.7407 | 0.7765 | 0.7997 | 0.7668 | 0.8026 |
|
13 |
-
| hellaswag | 0.7043 | 0.7037 | 0.7256 | 0.6913 | 0.7445 |
|
14 |
-
| mmlu | 0.6043 | 0.5448 | 0.5509 | 0.5312 | 0.5247 |
|
15 |
-
| openbookqa | 0.36 | 0.394 | 0.416 | 0.388 | 0.424 |
|
16 |
-
| piqa | 0.7568 | 0.7731 | 0.7731 | 0.7601 | 0.7769 |
|
17 |
-
| pubmedqa | 0.696 | 0.664 | 0.7 | 0.638 | 0.654 |
|
18 |
-
| race | 0.4067 | 0.4029 | 0.4364 | 0.3981 | 0.4344 |
|
19 |
-
| winogrande | 0.6748 | 0.6732 | 0.674 | 0.6606 | 0.6732 |
|
20 |
-
| truthfulqa | 0.3801 | 0.4202 | 0.4853 | 0.3478 | 0.5028 |
|
21 |
|
|
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
```
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
}
|
30 |
```
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
pipeline_tag: text-generation
|
4 |
+
library_name: transformers
|
5 |
---
|
6 |
|
7 |
+
# The Mamba in the Llama: Distilling and Accelerating Hybrid Models
|
8 |
|
9 |
+
This repository contains the code and released models for the distillation approach described in our paper [The Mamba in the Llama: Distilling and Accelerating Hybrid Models](https://huggingface.co/papers/2408.15237).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
Code: https://github.com/jxiw/MambaInLlama
|
12 |
|
13 |
+
<div style="display: flex; justify-content: space-between;">
|
14 |
+
<img src="https://raw.githubusercontent.com/jxiw/MambaInLlama/main/assets/mambainllama.png" alt="MambaInLlama" style="height:200px; width:auto; margin-right: 10px;">
|
15 |
+
<img src="https://raw.githubusercontent.com/jxiw/MambaInLlama/main/assets/mambainllama2.png" alt="MambaInLlama" style="height:200px; width:auto; margin-left: 10px;">
|
16 |
+
</div>
|
17 |
+
|
18 |
+
## Abstract
|
19 |
+
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best 8B scale instruction-tuned linear RNN model. We also find that the distilled model has natural length extrapolation, showing almost perfect accuracy in the needle-in-a-haystack test at 20x the distillation length.
|
20 |
+
|
21 |
+
## Approach
|
22 |
+
|
23 |
+
1. **Stepwise layer alignment** (Optional). Replace the attention layers by Mamba2, one by one in a stepwise manner. **MLP layers are frozen in this stage**. See [here](https://github.com/jxiw/MambaInLlama/blob/main/train_mamba2/train_hybrid.py)
|
24 |
+
2. **End to end distillation** (Most important). Minimize KL divergence loss between the student and teacher models. You can consider to use a larger teacher model to get better results. (**the is a end to end training, and MLP layers are not frozen in this stage**). See [here](https://github.com/jxiw/MambaInLlama/blob/main/train_mamba2/train_distill.py).
|
25 |
+
3. **Instruction tuning** (Optional). For simplicity, we use DPO for this process.
|
26 |
+
|
27 |
+
**We freeze the MLP layers in the first stage because we want to produce a model similar to the initialization model. However, in the end-to-end training/distillation, we only focus on the KL loss, so training all parameters (not freezing the MLP layers) will give better results.**
|
28 |
+
|
29 |
+
## Usage
|
30 |
+
|
31 |
+
### Environment
|
32 |
+
We provide an [environment file](https://github.com/jxiw/MambaInLlama/blob/main/environment.yml) that lists the specific Python package versions used in our experiments. To ensure the best reproducibility, we suggest using these same package versions. Nonetheless, you may also use alternative versions and still be able to run the program. The alignment-handbook version that we use is [here](https://github.com/huggingface/alignment-handbook/tree/606d2e954fd17999af40e6fb4f712055ca11b2f0). The following script is to install `mamba-ssm==2.2.2` and cuda-11.8.0.
|
33 |
+
|
34 |
+
```bash
|
35 |
+
# CUDA>=11.6 needed for `mamba-ssm` and `causal-conv1d`.
|
36 |
+
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
|
37 |
+
# Install PyTorch (with CUDA 11.8) before everything else. those assume you are using cu118
|
38 |
+
pip install torch==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
|
39 |
+
|
40 |
+
pip install causal-conv1d==1.4.0
|
41 |
+
pip install flash-attn==2.6.3
|
42 |
+
|
43 |
+
# make sure you use this alignment version
|
44 |
+
git clone https://github.com/huggingface/alignment-handbook.git
|
45 |
+
cd alignment-handbook/
|
46 |
+
git checkout 606d2e9
|
47 |
+
|
48 |
+
git clone https://github.com/huggingface/transformers.git --branch v4.43.1
|
49 |
+
|
50 |
+
# check your version matches those
|
51 |
+
# deepspeed==0.12.2
|
52 |
+
# torch==2.1.1+cu118
|
53 |
+
# transformers==4.43.1
|
54 |
+
# trl==0.8.6
|
55 |
+
# accelerate==0.33.0
|
56 |
+
# peft==0.12.0
|
57 |
+
# huggingface-hub==0.24.5
|
58 |
```
|
59 |
+
|
60 |
+
If you install mamba-ssm using `pip install mamba-ssm==2.2.2`, you will need to manually change `CONDA_ENV_PATH/site-packages/mamba_ssm/modules/mha.py` to [this version](https://github.com/state-spaces/mamba/blob/014c094d11f780a27330657faabecaaded7a31db/mamba_ssm/modules/mha.py) to support GQA, since GQA is used in Llama3. The **mamba-ssm** used in my experiment is from this [commit](https://github.com/state-spaces/mamba/tree/49ddf8321e4987650e8dc8dc44caa44b892f207a).
|
61 |
+
|
62 |
+
Alternatively, you can build mamba-ssm from source, but ensure the commit is after [this one](https://github.com/state-spaces/mamba/commit/014c094d11f780a27330657faabecaaded7a31db), which fixes the GQA bugs in generations.
|
63 |
+
|
64 |
+
### Generation Example
|
65 |
+
|
66 |
+
To use this model (`Llama3.2-Mamba2-3B-dpo`) for text generation, you can leverage the `mamba2_inference.hybrid_wrapper` with `transformers.AutoTokenizer`.
|
67 |
+
|
68 |
+
```python
|
69 |
+
import torch
|
70 |
+
from transformers import AutoTokenizer
|
71 |
+
# You might need to install 'mamba_inference' or 'mamba2_inference' from the original GitHub repository.
|
72 |
+
# Refer to the official repository for installation instructions:
|
73 |
+
# https://github.com/jxiw/MambaInLlama
|
74 |
+
from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper
|
75 |
+
|
76 |
+
pretrained_model_name = "JunxiongWang/Llama3.2-Mamba2-3B-dpo" # This model
|
77 |
+
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
|
78 |
+
model.eval()
|
79 |
+
|
80 |
+
messages = [[
|
81 |
+
{
|
82 |
+
"role": "user",
|
83 |
+
"content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
|
84 |
+
},
|
85 |
+
]]
|
86 |
+
|
87 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
88 |
+
formatted_prompts = [
|
89 |
+
tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
|
90 |
+
]
|
91 |
+
|
92 |
+
prompts = [
|
93 |
+
tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
|
94 |
+
for formatted_prompt in formatted_prompts
|
95 |
+
]
|
96 |
+
batch_prompts = torch.cat(prompts, dim=0).cuda()
|
97 |
+
|
98 |
+
outputs = model.generate(
|
99 |
+
input_ids=batch_prompts,
|
100 |
+
max_length=1000,
|
101 |
+
cg=True, # Set to True for speculative decoding (if supported by model config)
|
102 |
+
return_dict_in_generate=True,
|
103 |
+
output_scores=True,
|
104 |
+
enable_timing=True,
|
105 |
+
top_k=1,
|
106 |
+
eos_token_id=tokenizer.eos_token_id
|
107 |
+
)
|
108 |
+
|
109 |
+
generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
|
110 |
+
print(generated_text[0])
|
111 |
+
```
|
112 |
+
|
113 |
+
## Evaluation
|
114 |
+
|
115 |
+
Please follow the instructions [here](https://github.com/jxiw/MambaInLlama/blob/main/benchmark/README.md). Our evaluation includes: a. Standard tasks in [LM Eval](https://github.com/jxiw/MambaInLlama/tree/main/benchmark/llm_eval), b. [Chat Benchmarks](https://github.com/jxiw/MambaInLlama/tree/main/benchmark/alpaca_eval) and [here](https://github.com/jxiw/MambaInLlama/tree/main/benchmark/mt_bench), c. Reasoning tasks [Math and Code Reasoning Benchmarks](https://huggingface.co/spaces/allenai/ZeroEval), and d. Long-range tasks, [NeedleInAHaystack](https://github.com/jxiw/MambaInLlama/blob/main/benchmark/needle/README.md). Our goal is to provide a thorough evaluation and study.
|
116 |
+
|
117 |
+
### Released Models
|
118 |
+
|
119 |
+
### Hybrid Mamba Math Reasoning models
|
120 |
+
|
121 |
+
| **Model** | **AIME 2025** | **AIME 2024** | **MATH 500** | **AMC 2023** | **OlympiadBench** |
|
122 |
+
|-----------------------------------|---------------|---------------|--------------|--------------|-------------------|
|
123 |
+
| Qwen2.5-Math-7B-Instruct (Transformer) | – | 13.3 | 79.8 | 50.6 | 40.7 |
|
124 |
+
| rStar-Math-7B (Transformer) | – | 26.7 | 78.4 | 47.5 | 47.1 |
|
125 |
+
| Eurus-2-7B-PRIME (Transformer) | – | 26.7 | 79.2 | 57.8 | 42.1 |
|
126 |
+
| Qwen2.5-7B-SimpleRL (Transformer) | – | 26.7 | 82.4 | 62.5 | 43.3 |
|
127 |
+
| DeepSeek-R1-Distill-Qwen-1.5B (Transformer) | 23.0 | 28.8 | 82.8 | 62.9 | 43.3 |
|
128 |
+
| [**M1-3B (Mamba Hybrid Models)**](https://huggingface.co/togethercomputer/M1-3B) | 23.5 | 28.5 | 84.0 | 62.8 | 47.3 |
|
129 |
+
|
130 |
+
To reproduce the results, please check this and refer to [M1](https://github.com/jxiw/M1).
|
131 |
+
|
132 |
+
Check [here](https://huggingface.co/collections/JunxiongWang/mambainllama-math-reasoning-67c151eb6ea48bd56b35f434) for reasoning models distilled from llama 1B and llama 3B.
|
133 |
+
|
134 |
+
### Hybrid Mamba (8B) distilled from Llama3.1 8B
|
135 |
+
|
136 |
+
Check [this](https://github.com/jxiw/MambaInLlama/blob/main/llama3.1_8B/README.md) for more details.
|
137 |
+
|
138 |
+
Models are available [here](https://huggingface.co/collections/JunxiongWang/mambainllama-distill-6737cbebfd1af6c3bd75a06c).
|
139 |
+
|
140 |
+
| Model | MMLU (0-shot) <br> | AlpacaEval <br> (Win against GPT-4) | MT-Bench <br> (scored by GPT-4) | GSM8K (0-shot) | [CRUX](https://huggingface.co/spaces/allenai/ZeroEval) (0-shot) |
|
141 |
+
|---------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|
|
142 |
+
[Llama3.1-Mamba2-8B-distill](https://huggingface.co/JunxiongWang/Llama3.1-Mamba2-8B-distill) | 61.01 | 21.22 | 7.5 | 40.65 | 32.50 |
|
143 |
+
[Llama3.1-Mamba-8B-distill](https://huggingface.co/JunxiongWang/Llama3.1-Mamba-8B-distill) | 62.13 | 21.55 | 7.7 | 67.15 | 34.12 |
|
144 |
+
|
145 |
+
These models are trained without using SFT + DPO. We find that with DPO, you can achieve significantly higher scores on the common sense task in LM evaluation benchmark or AlpacaEval. However, it may result in lower scores on reasoning benchmarks and long-range tasks, such as 'needle in a haystack'. Therefore, it is unclear whether this actually makes the model better.
|
146 |
+
|
147 |
+
### Hybrid Mamba (3B) distilled from Llama3.2 3B
|
148 |
+
|
149 |
+
Check [this](https://github.com/jxiw/MambaInLlama/blob/main/llama3.2_3B/README.md) for more details.
|
150 |
+
|
151 |
+
Models are available [here](https://huggingface.co/collections/JunxiongWang/mambainllama-distill-6737cbebfd1af6c3bd75a06c).
|
152 |
+
|
153 |
+
| Model | MMLU (0-shot) <br> | AlpacaEval <br> (Win against GPT-4) | MT-Bench <br> (scored by GPT-4) | GSM8K (0-shot) | [CRUX](https://huggingface.co/spaces/allenai/ZeroEval) (0-shot) |
|
154 |
+
|---------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|
|
155 |
+
[Llama3.2-Mamba2-3B-distill](https://huggingface.co/JunxiongWang/Llama3.2-Mamba2-3B-distill) | 53.12 | 14.34 | 6.7 | 49.37 | 23.38 |
|
156 |
+
[Llama3.2-Mamba-3B-distill](https://huggingface.co/JunxiongWang/Llama3.2-Mamba-3B-distill) | 54.50 | 15.54 | 7.2 | 62.93 | 25.75 |
|
157 |
+
|
158 |
+
Needle In A Haystack. The distillation training length is 2k.
|
159 |
+
|
160 |
+
<img src="https://raw.githubusercontent.com/jxiw/MambaInLlama/main/benchmark/needle/img/needle.png" alt="needle">
|
161 |
+
|
162 |
+
### Hybrid Mamba distilled from Llama3
|
163 |
+
|
164 |
+
| Teacher Model | Hybrid Mamba Model - DPO |Hybrid Mamba2 Model - DPO |
|
165 |
+
|---------------|---------------------------|---------------------------|
|
166 |
+
| Meta-Llama-3-8B-Instruct | [Mamba (1/2 attention)](https://huggingface.co/JunxiongWang/MambaInLlama_0_50) | [Mamba2 (1/2 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_0_50) |
|
167 |
+
| | [Mamba (1/4 attention)](https://huggingface.co/JunxiongWang/MambaInLlama_0_75) | [Mamba2 (1/4 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_0_75) |
|
168 |
+
| | [Mamba (1/8 attention)](https://huggingface.co/JunxiongWang/MambaInLlama_0_875) | [Mamba2 (1/8 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_0_875) |
|
169 |
+
| | | [Mamba2 (0 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_1) |
|
170 |
+
|
171 |
+
|
172 |
+
| Model | MMLU <br> (5 shots) | AlpacaEval <br> (LC win against GPT-4) | MT-Bench <br> (scored by GPT-4) |
|
173 |
+
|-------|----------------|-----------------------------------|----------------------------|
|
174 |
+
| [Mamba (1/2 attention)](https://huggingface.co/JunxiongWang/MambaInLlama_0_50) | 59.26 | 29.61 | 7.35 |
|
175 |
+
| [Mamba2 (1/2 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_0_50) | 56.67 | 25.00 | 7.32 |
|
176 |
+
| [Mamba (1/4 attention)](https://huggingface.co/JunxiongWang/MambaInLlama_0_75) | 52.68 | 25.85 | 6.86 |
|
177 |
+
| [Mamba2 (1/4 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_0_75) | 53.94 | 20.25 | 6.74 |
|
178 |
+
| [Mamba (1/8 attention)](https://huggingface.co/JunxiongWang/MambaInLlama_0_875) | 49.20 | 20.76 | 6.46 |
|
179 |
+
| [Mamba2 (1/8 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_0_875) | 50.85 | 20.25 | 6.48 |
|
180 |
+
| [Mamba2 (0 attention)](https://huggingface.co/JunxiongWang/Mamba2InLlama_1) | 43.19 | 14.49 | 5.64 |
|
181 |
+
|
182 |
+
For reproduction, please follow the instructions [here](https://github.com/jxiw/MambaInLlama/blob/main/mamba_llama/README.md).
|
183 |
+
|
184 |
+
### Hybrid Mamba distilled from Zephyr
|
185 |
+
|
186 |
+
| Teacher Model | Hybrid Mamba Model - SFT | Hybrid Mamba Model - DPO | Hybrid Mamba Model - DPO |
|
187 |
+
|---------------|---------------------------------------------------|--------------------------------------------------|--------------------------------------------------|
|
188 |
+
| Zephyr | [Mamba (1/2 attention)](https://huggingface.co/JunxiongWang/mamba_0_5_sft) | [Mamba (1/2 attention)](https://huggingface.co/JunxiongWang/mamba_0_5_dpo_ep1) | [Mamba (1/2 attention)](https://huggingface.co/JunxiongWang/mamba_0_5_dpo_ep3) |
|
189 |
+
| | [Mamba (1/4 attention)](https://huggingface.co/JunxiongWang/mamba_0_75_sft) | [Mamba (1/4 attention)](https://huggingface.co/JunxiongWang/mamba_0_75_dpo_ep1) | [Mamba (1/4 attention)](https://huggingface.co/JunxiongWang/mamba_0_75_dpo_ep3) |
|
190 |
+
| | [Mamba (1/8 attention)](https://huggingface.co/JunxiongWang/mamba_0_875_sft) | [Mamba (1/8 attention)](https://huggingface.co/JunxiongWang/mamba_0_875_dpo_ep1) | [Mamba (1/8 attention)](https://huggingface.co/JunxiongWang/mamba_0_875_dpo_ep3) |
|
191 |
+
|
192 |
+
|
193 |
+
| Model | MMLU <br> (5 shots) | AlpacaEval <br> (LC win against GPT-4) | MT-Bench <br> (scored by GPT-4) |
|
194 |
+
|-------|---------------------|-----------------------------------|----------------------------|
|
195 |
+
| [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) | 61.44 | 13.20 | 7.34 |
|
196 |
+
| [Mamba DPO 1 (1/2 attention)](https://huggingface.co/JunxiongWang/mamba_0_5_dpo_ep1) | 55.23 | 20.66 | 7.12 |
|
197 |
+
| [Mamba DPO 3 (1/2 attention)](https://huggingface.co/JunxiongWang/mamba_0_5_dpo_ep3) | 55.38 | 17.48 | 7.31 |
|
198 |
+
| [Mamba DPO 1 (1/4 attention)](https://huggingface.co/JunxiongWang/mamba_0_75_dpo_ep1) | 50.94 | 17.16 | 7.03 |
|
199 |
+
| [Mamba DPO 3 (1/4 attention)](https://huggingface.co/JunxiongWang/mamba_0_75_dpo_ep3) | 51.19 | 13.89 | 6.58 |
|
200 |
+
| [Mamba DPO 1 (1/8 attention)](https://huggingface.co/JunxiongWang/mamba_0_875_dpo_ep1) | 48.35 | 15.32 | 6.39 |
|
201 |
+
| [Mamba DPO 3 (1/8 attention)](https://huggingface.co/JunxiongWang/mamba_0_875_dpo_ep3) | 48.44 | 12.67 | 6.37 |
|
202 |
+
|
203 |
+
For reproduction, please follow the instructions [here](https://github.com/jxiw/MambaInLlama/blob/main/mamba_zephyr/README.md).
|
204 |
+
|
205 |
+
## Citation
|
206 |
+
If you use this codebase, or otherwise found our work valuable, please cite:
|
207 |
+
|
208 |
+
```bibtex
|
209 |
+
@inproceedings{
|
210 |
+
junxiongdaniele2024mambainllama,
|
211 |
+
title={The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
|
212 |
+
author={Junxiong Wang and Daniele Paliotta and Avner May and Alexander M Rush and Tri Dao},
|
213 |
+
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
|
214 |
+
year={2024},
|
215 |
+
url={https://openreview.net/forum?id=uAzhODjALU}
|
216 |
}
|
217 |
```
|