---
library_name: transformers
license: apache-2.0
tags:
- jamba
- mamba
- moe
base_model: ai21labs/Jamba-v0.1
---
# Disclaimer and Requirements
This model is a clone of [**ai21labs/Jamba-v0.1**](https://huggingface.co/ai21labs/Jamba-v0.1) compressed using ZipNN. Compressed losslessly to 67% of its original size, ZipNN saved ~35GB of storage and potentially ~1PB of data transfer **monthly**.
### Requirement
In order to use the model, ZipNN is necessary:
```bash
pip install zipnn
```
### Use This Model
```python
# Use a pipeline as a high-level helper
from transformers import pipeline
from zipnn import zipnn_hf
zipnn_hf()
pipe = pipeline("text-generation", model="royleibov/Jamba-v0.1-ZipNN-Compressed")
```
```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
from zipnn import zipnn_hf
zipnn_hf()
tokenizer = AutoTokenizer.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed")
model = AutoModelForCausalLM.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed")
```
### ZipNN
ZipNN also allows you to seamlessly save local disk space in your cache after the model is downloaded.
To compress the cached model, simply run:
```bash
python zipnn_compress_path.py safetensors --model royleibov/Jamba-v0.1-ZipNN-Compressed --hf_cache
```
The model will be decompressed automatically and safely as long as `zipnn_hf()` is added at the top of the file like in the [example above](#use-this-model).
To decompress manually, simply run:
```bash
python zipnn_decompress_path.py --model royleibov/Jamba-v0.1-ZipNN-Compressed --hf_cache
```
# Model Card for Jamba
Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It delivers throughput gains over traditional Transformer-based models, while outperforming or matching the leading models of its size class on most common benchmarks.
Jamba is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.
This model card is for the base version of Jamba. It’s a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.
For full details of this model please read the [white paper](https://arxiv.org/abs/2403.19887) and the [release blog post](https://www.ai21.com/blog/announcing-jamba).
## Model Details
- **Developed by:** [AI21](https://www.ai21.com)
- **Model type:** Joint Attention and Mamba (Jamba)
- **License:** Apache 2.0
- **Context length:** 256K
- **Knowledge cutoff date:** March 5, 2024
## Usage
### Prerequisites
In order to use Jamba, it is recommended to use `transformers` version 4.40.0 or higher (4.39.0 is the minimum required version):
```bash
pip install transformers>=4.40.0
```
In order to run optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:
```bash
pip install mamba-ssm causal-conv1d>=1.2.0
```
You also need the model to be on a CUDA device.
You can run the model without the optimized Mamba kernels, but this is **not** recommended as it results in significantly degraded latency. To do so, specify `use_mamba_kernels=False` when loading the model, as shown in the sketch below.
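A minimal sketch of loading with the kernels disabled (the flag is forwarded by `from_pretrained` to the Jamba configuration):
```python
from transformers import AutoModelForCausalLM
from zipnn import zipnn_hf

zipnn_hf()  # enable transparent ZipNN decompression of the downloaded weights

# Fall back to the native PyTorch Mamba implementation (no mamba-ssm / causal-conv1d needed)
model = AutoModelForCausalLM.from_pretrained(
    "royleibov/Jamba-v0.1-ZipNN-Compressed",
    use_mamba_kernels=False,
)
```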
### Run the model
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from zipnn import zipnn_hf
zipnn_hf()
model = AutoModelForCausalLM.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed")
tokenizer = AutoTokenizer.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
```
Please note that if you're using `transformers<4.40.0`, `trust_remote_code=True` is required for running the new Jamba architecture.
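For example, a minimal sketch for older `transformers` versions (only the extra `trust_remote_code=True` argument differs from the snippet above):
```python
from transformers import AutoModelForCausalLM
from zipnn import zipnn_hf

zipnn_hf()

# Required only on transformers<4.40.0, where the Jamba architecture is loaded as remote code
model = AutoModelForCausalLM.from_pretrained(
    "royleibov/Jamba-v0.1-ZipNN-Compressed",
    trust_remote_code=True,
)
```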
<details>
<summary><strong>Loading the model in half precision</strong></summary>
The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`:
```python
from transformers import AutoModelForCausalLM
import torch
from zipnn import zipnn_hf
zipnn_hf()
model = AutoModelForCausalLM.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed",
torch_dtype=torch.bfloat16) # you can also use torch_dtype=torch.float16
```
When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
```python
from transformers import AutoModelForCausalLM
from zipnn import zipnn_hf
zipnn_hf()
import torch
model = AutoModelForCausalLM.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
```
</details>
<details><summary><strong>Load the model in 8-bit</strong></summary>
**Using 8-bit precision, it is possible to fit sequences of up to 140K tokens on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). To avoid degrading model quality, we recommend excluding the Mamba blocks from the quantization:
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from zipnn import zipnn_hf
zipnn_hf()
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
```
</details>
### Fine-tuning example
Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library (requires ~120GB GPU RAM, e.g. 2x A100 80GB GPUs):
```python
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from zipnn import zipnn_hf
zipnn_hf()
tokenizer = AutoTokenizer.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed")
model = AutoModelForCausalLM.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed",
device_map='auto', torch_dtype=torch.bfloat16)
lora_config = LoraConfig(
r=8,
target_modules=[
"embed_tokens",
"x_proj", "in_proj", "out_proj", # mamba
"gate_proj", "up_proj", "down_proj", # mlp
"q_proj", "k_proj", "v_proj" # attention
],
task_type="CAUSAL_LM",
bias="none"
)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(
output_dir="./results",
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=1e-5,
dataset_text_field="quote",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
)
trainer.train()
```
## Results on common benchmarks
| Benchmark | Score |
|--------------|:-----:|
| HellaSwag | 87.1% |
| Arc Challenge | 64.4% |
| WinoGrande | 82.5% |
| PIQA | 83.2% |
| MMLU | 67.4% |
| BBH | 45.4% |
| TruthfulQA | 46.4% |
| GSM8K (CoT) | 59.9% |
It's crucial that the 'BOS' token is added to all prompts; this might not be enabled by default in all evaluation frameworks.
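As a quick sanity check (a sketch; whether BOS is added by default depends on the tokenizer configuration shipped with the checkpoint), you can verify that the tokenizer prepends it:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("royleibov/Jamba-v0.1-ZipNN-Compressed")
input_ids = tokenizer("A test prompt")["input_ids"]

# The first id should be the BOS token; if not, prepend tokenizer.bos_token to your prompts
# or enable add_special_tokens in your evaluation framework.
print(input_ids[0] == tokenizer.bos_token_id)
```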
## Notice
Jamba is a pretrained base model and did not undergo any alignment for instruct/chat interactions.
As a base model, Jamba is intended for use as a foundation layer for fine tuning, training, and developing custom solutions. Jamba does not have safety moderation mechanisms and guardrails should be added for responsible and safe use.
## About AI21
AI21 builds reliable, practical, and scalable AI solutions for the enterprise.
Jamba is the first in AI21’s new family of models, and the Instruct version of Jamba is coming soon to the [AI21 platform](https://www.ai21.com/studio).