Update LoRA fine-tune example - more target_modules, lower LR, bf16 (#49)
Browse files- Update LoRA fine-tune example - more target_modules, lower LR, bf16 (a69ca0f303d6079e51f4d323a81e2ec76484fc92)
Co-authored-by: Michael Gokhman <[email protected]>
README.md
CHANGED
|
@@ -96,31 +96,40 @@ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
|
|
| 96 |
</details>
|
| 97 |
|
| 98 |
### Fine-tuning example
|
| 99 |
-
Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library:
|
| 100 |
|
| 101 |
```python
|
|
|
|
| 102 |
from datasets import load_dataset
|
| 103 |
-
from trl import SFTTrainer
|
| 104 |
from peft import LoraConfig
|
| 105 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
| 106 |
|
| 107 |
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
|
| 108 |
-
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
dataset = load_dataset("Abirate/english_quotes", split="train")
|
| 111 |
-
training_args =
|
| 112 |
output_dir="./results",
|
| 113 |
-
num_train_epochs=
|
| 114 |
per_device_train_batch_size=4,
|
| 115 |
logging_dir='./logs',
|
| 116 |
logging_steps=10,
|
| 117 |
-
learning_rate=
|
| 118 |
-
|
| 119 |
-
lora_config = LoraConfig(
|
| 120 |
-
r=8,
|
| 121 |
-
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
|
| 122 |
-
task_type="CAUSAL_LM",
|
| 123 |
-
bias="none"
|
| 124 |
)
|
| 125 |
trainer = SFTTrainer(
|
| 126 |
model=model,
|
|
@@ -128,9 +137,7 @@ trainer = SFTTrainer(
|
|
| 128 |
args=training_args,
|
| 129 |
peft_config=lora_config,
|
| 130 |
train_dataset=dataset,
|
| 131 |
-
dataset_text_field="quote",
|
| 132 |
)
|
| 133 |
-
|
| 134 |
trainer.train()
|
| 135 |
```
|
| 136 |
|
|
|
|
| 96 |
</details>
|
| 97 |
|
| 98 |
### Fine-tuning example
|
| 99 |
+
Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library (requires ~120GB of GPU memory, for example 2x A100 80GB):
|
| 100 |
|
| 101 |
```python
|
| 102 |
+
import torch
|
| 103 |
from datasets import load_dataset
|
| 104 |
+
from trl import SFTTrainer, SFTConfig
|
| 105 |
from peft import LoraConfig
|
| 106 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
| 107 |
|
| 108 |
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
|
| 109 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 110 |
+
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
|
| 111 |
+
|
| 112 |
+
lora_config = LoraConfig(
|
| 113 |
+
r=8,
|
| 114 |
+
target_modules=[
|
| 115 |
+
"embed_tokens",
|
| 116 |
+
"x_proj", "in_proj", "out_proj", # mamba
|
| 117 |
+
"gate_proj", "up_proj", "down_proj", # mlp
|
| 118 |
+
"q_proj", "k_proj", "v_proj" # attention
|
| 119 |
+
],
|
| 120 |
+
task_type="CAUSAL_LM",
|
| 121 |
+
bias="none"
|
| 122 |
+
)
|
| 123 |
|
| 124 |
dataset = load_dataset("Abirate/english_quotes", split="train")
|
| 125 |
+
training_args = SFTConfig(
|
| 126 |
output_dir="./results",
|
| 127 |
+
num_train_epochs=2,
|
| 128 |
per_device_train_batch_size=4,
|
| 129 |
logging_dir='./logs',
|
| 130 |
logging_steps=10,
|
| 131 |
+
learning_rate=1e-5,
|
| 132 |
+
dataset_text_field="quote",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
trainer = SFTTrainer(
|
| 135 |
model=model,
|
|
|
|
| 137 |
args=training_args,
|
| 138 |
peft_config=lora_config,
|
| 139 |
train_dataset=dataset,
|
|
|
|
| 140 |
)
|
|
|
|
| 141 |
trainer.train()
|
| 142 |
```
|
| 143 |
|