Phi2 rewrite (#1058)
Browse files* restore to current phi modeling code from phi-2
* enable gradient checkpointing
* don't cast everything to float32 all the time
* gradient checkpointing for phi2 ParallelBlock module too
* fix enabling flash attn for phi2
* add comment about import
* fix phi2 example
* fix model type check for tokenizer
* revert float32 -> bf16 casting changes
* support fused dense flash attn
* fix the repo for flash-attn
* add package name for subdir pkg
* fix the data collator when not using sample packing
* install packaging for pytests in ci
* also fix setup to not install flash attn fused dense subdir if not extras
* split out the fused-dense-lib in extra requires
* don't train w group_by_length for phi
* update integration test to use phi2
* set max steps and save steps for phi e2e tests
* try to workaround ssave issue in ci
* skip phi2 e2e test for now
- examples/phi/phi2-ft.yml +73 -0
- requirements.txt +1 -0
- setup.py +4 -0
- src/axolotl/core/trainer_builder.py +9 -1
- src/axolotl/models/phi/modeling_phi.py +115 -86
- src/axolotl/utils/models.py +11 -1
- tests/e2e/test_phi.py +15 -9
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
base_model: microsoft/phi-2
|
| 2 |
+
model_type: AutoModelForCausalLM
|
| 3 |
+
tokenizer_type: AutoTokenizer
|
| 4 |
+
trust_remote_code: true
|
| 5 |
+
|
| 6 |
+
load_in_8bit: false
|
| 7 |
+
load_in_4bit: false
|
| 8 |
+
strict: false
|
| 9 |
+
|
| 10 |
+
datasets:
|
| 11 |
+
- path: garage-bAInd/Open-Platypus
|
| 12 |
+
type: alpaca
|
| 13 |
+
|
| 14 |
+
dataset_prepared_path:
|
| 15 |
+
val_set_size: 0.05
|
| 16 |
+
output_dir: ./phi-sft-out
|
| 17 |
+
|
| 18 |
+
sequence_len: 2048
|
| 19 |
+
sample_packing: false # currently unsupported
|
| 20 |
+
pad_to_sequence_len:
|
| 21 |
+
|
| 22 |
+
adapter:
|
| 23 |
+
lora_model_dir:
|
| 24 |
+
lora_r: 16
|
| 25 |
+
lora_alpha: 32
|
| 26 |
+
lora_dropout: 0.1
|
| 27 |
+
lora_target_linear: true
|
| 28 |
+
lora_fan_in_fan_out:
|
| 29 |
+
lora_modules_to_save:
|
| 30 |
+
- embd
|
| 31 |
+
- lm_head
|
| 32 |
+
|
| 33 |
+
wandb_project:
|
| 34 |
+
wandb_entity:
|
| 35 |
+
wandb_watch:
|
| 36 |
+
wandb_name:
|
| 37 |
+
wandb_log_model:
|
| 38 |
+
|
| 39 |
+
gradient_accumulation_steps: 1
|
| 40 |
+
micro_batch_size: 1
|
| 41 |
+
num_epochs: 4
|
| 42 |
+
optimizer: paged_adamw_8bit
|
| 43 |
+
adam_beta2: 0.95
|
| 44 |
+
adam_epsilon: 0.00001
|
| 45 |
+
max_grad_norm: 1.0
|
| 46 |
+
lr_scheduler: cosine
|
| 47 |
+
learning_rate: 1e-5
|
| 48 |
+
|
| 49 |
+
train_on_inputs: false
|
| 50 |
+
group_by_length: false
|
| 51 |
+
bf16: true
|
| 52 |
+
fp16: false
|
| 53 |
+
tf32: true
|
| 54 |
+
|
| 55 |
+
gradient_checkpointing: true
|
| 56 |
+
early_stopping_patience:
|
| 57 |
+
resume_from_checkpoint:
|
| 58 |
+
local_rank:
|
| 59 |
+
logging_steps: 1
|
| 60 |
+
xformers_attention:
|
| 61 |
+
flash_attention: true
|
| 62 |
+
|
| 63 |
+
warmup_steps: 100
|
| 64 |
+
evals_per_epoch: 4
|
| 65 |
+
saves_per_epoch: 1
|
| 66 |
+
debug:
|
| 67 |
+
deepspeed:
|
| 68 |
+
weight_decay: 0.1
|
| 69 |
+
fsdp:
|
| 70 |
+
fsdp_config:
|
| 71 |
+
resize_token_embeddings_to_32x: true
|
| 72 |
+
special_tokens:
|
| 73 |
+
pad_token: "<|endoftext|>"
|
|
@@ -12,6 +12,7 @@ fire
|
|
| 12 |
PyYAML>=6.0
|
| 13 |
datasets>=2.15.0
|
| 14 |
flash-attn==2.3.3
|
|
|
|
| 15 |
sentencepiece
|
| 16 |
wandb
|
| 17 |
einops
|
|
|
|
| 12 |
PyYAML>=6.0
|
| 13 |
datasets>=2.15.0
|
| 14 |
flash-attn==2.3.3
|
| 15 |
+
fused-dense-lib @ git+https://github.com/Dao-AILab/[email protected]#subdirectory=csrc/fused_dense_lib
|
| 16 |
sentencepiece
|
| 17 |
wandb
|
| 18 |
einops
|
|
@@ -17,6 +17,7 @@ def parse_requirements():
|
|
| 17 |
_dependency_links.append(url)
|
| 18 |
elif (
|
| 19 |
"flash-attn" not in line
|
|
|
|
| 20 |
and "deepspeed" not in line
|
| 21 |
and line
|
| 22 |
and line[0] != "#"
|
|
@@ -51,6 +52,9 @@ setup(
|
|
| 51 |
"flash-attn": [
|
| 52 |
"flash-attn==2.3.3",
|
| 53 |
],
|
|
|
|
|
|
|
|
|
|
| 54 |
"deepspeed": [
|
| 55 |
"deepspeed",
|
| 56 |
],
|
|
|
|
| 17 |
_dependency_links.append(url)
|
| 18 |
elif (
|
| 19 |
"flash-attn" not in line
|
| 20 |
+
and "flash-attention" not in line
|
| 21 |
and "deepspeed" not in line
|
| 22 |
and line
|
| 23 |
and line[0] != "#"
|
|
|
|
| 52 |
"flash-attn": [
|
| 53 |
"flash-attn==2.3.3",
|
| 54 |
],
|
| 55 |
+
"fused-dense-lib": [
|
| 56 |
+
"fused-dense-lib @ git+https://github.com/Dao-AILab/[email protected]#subdirectory=csrc/fused_dense_lib",
|
| 57 |
+
],
|
| 58 |
"deepspeed": [
|
| 59 |
"deepspeed",
|
| 60 |
],
|
|
@@ -34,6 +34,7 @@ from axolotl.utils.callbacks import (
|
|
| 34 |
)
|
| 35 |
from axolotl.utils.collators import (
|
| 36 |
BatchSamplerDataCollatorForSeq2Seq,
|
|
|
|
| 37 |
MambaDataCollator,
|
| 38 |
)
|
| 39 |
from axolotl.utils.samplers import MultipackBatchSampler
|
|
@@ -843,7 +844,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 843 |
if self.cfg.model_config_type == "mamba":
|
| 844 |
return MambaDataCollator(tokenizer=self.tokenizer)
|
| 845 |
|
| 846 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
self.tokenizer,
|
| 848 |
return_tensors="pt",
|
| 849 |
**kwargs,
|
|
|
|
| 34 |
)
|
| 35 |
from axolotl.utils.collators import (
|
| 36 |
BatchSamplerDataCollatorForSeq2Seq,
|
| 37 |
+
DataCollatorForSeq2Seq,
|
| 38 |
MambaDataCollator,
|
| 39 |
)
|
| 40 |
from axolotl.utils.samplers import MultipackBatchSampler
|
|
|
|
| 844 |
if self.cfg.model_config_type == "mamba":
|
| 845 |
return MambaDataCollator(tokenizer=self.tokenizer)
|
| 846 |
|
| 847 |
+
if training_args.sample_packing:
|
| 848 |
+
return BatchSamplerDataCollatorForSeq2Seq(
|
| 849 |
+
self.tokenizer,
|
| 850 |
+
return_tensors="pt",
|
| 851 |
+
**kwargs,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
return DataCollatorForSeq2Seq(
|
| 855 |
self.tokenizer,
|
| 856 |
return_tensors="pt",
|
| 857 |
**kwargs,
|
|
@@ -9,27 +9,32 @@ from __future__ import annotations
|
|
| 9 |
|
| 10 |
import math
|
| 11 |
from dataclasses import dataclass, field
|
| 12 |
-
from typing import Any, Dict, Optional, Tuple, Union
|
| 13 |
|
| 14 |
import torch
|
| 15 |
import torch.nn as nn
|
| 16 |
from einops import rearrange, repeat
|
|
|
|
| 17 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 18 |
from transformers.activations import ACT2FN
|
| 19 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 20 |
|
| 21 |
-
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
| 22 |
from .configuration_phi import PhiConfig
|
| 23 |
|
| 24 |
try:
|
| 25 |
from flash_attn.bert_padding import pad_input, unpad_input
|
| 26 |
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
| 27 |
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
| 28 |
-
|
| 29 |
-
except: # noqa: E722
|
| 30 |
pad_input, unpad_input = None, None
|
| 31 |
FlashRotaryEmbedding = None
|
| 32 |
FlashSelfAttention, FlashCrossAttention = None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
FusedDense = None
|
| 34 |
|
| 35 |
|
|
@@ -224,7 +229,9 @@ class RotaryEmbedding(nn.Module):
|
|
| 224 |
|
| 225 |
# Initialize cached attributes since ONNX can't rely on dynamic initialization
|
| 226 |
self._update_cos_sin_cache(
|
| 227 |
-
max_position_embeddings,
|
|
|
|
|
|
|
| 228 |
)
|
| 229 |
|
| 230 |
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
|
@@ -281,34 +288,32 @@ class RotaryEmbedding(nn.Module):
|
|
| 281 |
seqlen_offset: int = 0,
|
| 282 |
**kwargs,
|
| 283 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 284 |
-
seq_start = seqlen_offset
|
| 285 |
-
seq_end = seq_start + qkv.shape[1]
|
| 286 |
-
|
| 287 |
if (
|
| 288 |
-
self.
|
|
|
|
| 289 |
or self._cos_cached.dtype != qkv.dtype
|
| 290 |
or (self.training and self._cos_cached.is_inference())
|
| 291 |
):
|
| 292 |
self._update_cos_sin_cache(
|
| 293 |
-
|
| 294 |
)
|
| 295 |
|
| 296 |
if kv is None:
|
| 297 |
return _apply_rotary_emb_qkv(
|
| 298 |
qkv,
|
| 299 |
-
self._cos_cached[
|
| 300 |
-
self._sin_cached[
|
| 301 |
)
|
| 302 |
else:
|
| 303 |
q = _apply_rotary_emb(
|
| 304 |
qkv,
|
| 305 |
-
self._cos_cached[
|
| 306 |
-
self._sin_cached[
|
| 307 |
)
|
| 308 |
kv = _apply_rotary_emb_kv(
|
| 309 |
kv,
|
| 310 |
-
self._cos_cached[
|
| 311 |
-
self._sin_cached[
|
| 312 |
)
|
| 313 |
|
| 314 |
return q, kv
|
|
@@ -511,7 +516,7 @@ def _update_kv_cache(
|
|
| 511 |
num_heads, head_dim = kv.shape[-2:]
|
| 512 |
|
| 513 |
if layer_idx not in inference_params.key_value_memory_dict:
|
| 514 |
-
|
| 515 |
inference_params.max_batch_size,
|
| 516 |
inference_params.max_seqlen,
|
| 517 |
2,
|
|
@@ -520,9 +525,6 @@ def _update_kv_cache(
|
|
| 520 |
dtype=kv.dtype,
|
| 521 |
device=kv.device,
|
| 522 |
)
|
| 523 |
-
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
| 524 |
-
else:
|
| 525 |
-
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
| 526 |
|
| 527 |
batch_start = inference_params.batch_size_offset
|
| 528 |
batch_end = batch_start + kv.shape[0]
|
|
@@ -530,8 +532,19 @@ def _update_kv_cache(
|
|
| 530 |
sequence_start = inference_params.seqlen_offset
|
| 531 |
sequence_end = sequence_start + kv.shape[1]
|
| 532 |
|
| 533 |
-
|
| 534 |
-
kv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
|
| 536 |
return kv
|
| 537 |
|
|
@@ -624,13 +637,10 @@ class MHA(nn.Module):
|
|
| 624 |
self.layer_idx = layer_idx
|
| 625 |
self.return_residual = return_residual
|
| 626 |
self.checkpointing = checkpointing
|
|
|
|
| 627 |
|
| 628 |
def _forward_self_attn(
|
| 629 |
-
self,
|
| 630 |
-
x: torch.FloatTensor,
|
| 631 |
-
key_padding_mask: Optional[torch.BoolTensor],
|
| 632 |
-
cu_seqlens: Optional[torch.LongTensor] = None,
|
| 633 |
-
max_seqlen: Optional[int] = None,
|
| 634 |
) -> torch.FloatTensor:
|
| 635 |
qkv = self.Wqkv(x)
|
| 636 |
qkv = rearrange(
|
|
@@ -643,20 +653,21 @@ class MHA(nn.Module):
|
|
| 643 |
if self.flash_attn:
|
| 644 |
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 645 |
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
and cu_seqlens is None
|
| 649 |
-
and max_seqlen is None
|
| 650 |
-
):
|
| 651 |
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
|
| 652 |
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
| 653 |
qkv, indices, cu_seqlens, max_seqlen = unpad_input(
|
| 654 |
qkv, key_padding_mask
|
| 655 |
)
|
| 656 |
|
| 657 |
-
if self.checkpointing:
|
| 658 |
-
attn_output =
|
| 659 |
-
self.inner_attn,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
)
|
| 661 |
else:
|
| 662 |
attn_output = self.inner_attn(
|
|
@@ -670,9 +681,12 @@ class MHA(nn.Module):
|
|
| 670 |
else attn_output
|
| 671 |
)
|
| 672 |
|
| 673 |
-
if self.checkpointing:
|
| 674 |
-
return
|
| 675 |
-
self.inner_attn,
|
|
|
|
|
|
|
|
|
|
| 676 |
)
|
| 677 |
|
| 678 |
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
|
@@ -725,8 +739,8 @@ class MHA(nn.Module):
|
|
| 725 |
q, key_padding_mask
|
| 726 |
)
|
| 727 |
|
| 728 |
-
if self.checkpointing:
|
| 729 |
-
attn_output =
|
| 730 |
self.inner_cross_attn,
|
| 731 |
q,
|
| 732 |
kv,
|
|
@@ -735,6 +749,7 @@ class MHA(nn.Module):
|
|
| 735 |
max_seqlen=max_seqlen_q,
|
| 736 |
cu_seqlens_k=cu_seqlens_k,
|
| 737 |
max_seqlen_k=max_seqlen_k,
|
|
|
|
| 738 |
)
|
| 739 |
else:
|
| 740 |
attn_output = self.inner_cross_attn(
|
|
@@ -753,13 +768,14 @@ class MHA(nn.Module):
|
|
| 753 |
else attn_output
|
| 754 |
)
|
| 755 |
|
| 756 |
-
if self.checkpointing:
|
| 757 |
-
return
|
| 758 |
self.inner_cross_attn,
|
| 759 |
q,
|
| 760 |
kv,
|
| 761 |
key_padding_mask=key_padding_mask,
|
| 762 |
causal=causal,
|
|
|
|
| 763 |
)
|
| 764 |
|
| 765 |
return self.inner_cross_attn(
|
|
@@ -771,11 +787,8 @@ class MHA(nn.Module):
|
|
| 771 |
x: torch.FloatTensor,
|
| 772 |
past_key_values: Optional[InferenceParams] = None,
|
| 773 |
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
| 774 |
-
cu_seqlens: Optional[torch.LongTensor] = None,
|
| 775 |
-
max_seqlen: Optional[int] = None,
|
| 776 |
**kwargs,
|
| 777 |
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 778 |
-
# TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
|
| 779 |
if attention_mask is not None:
|
| 780 |
attention_mask = attention_mask.bool()
|
| 781 |
else:
|
|
@@ -785,18 +798,12 @@ class MHA(nn.Module):
|
|
| 785 |
if self.n_head == self.n_head_kv:
|
| 786 |
if past_key_values is None:
|
| 787 |
# If `past_key_values` are not supplied, we run self-attention
|
| 788 |
-
attn_output = self._forward_self_attn(
|
| 789 |
-
x, attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
| 790 |
-
)
|
| 791 |
else:
|
| 792 |
# If `past_key_values` are supplied, it means that we might have cached values and
|
| 793 |
# could take advantage of cross-attention
|
| 794 |
attn_output = self._forward_cross_attn(
|
| 795 |
-
x,
|
| 796 |
-
past_key_values,
|
| 797 |
-
attention_mask,
|
| 798 |
-
cu_seqlens=cu_seqlens,
|
| 799 |
-
max_seqlen=max_seqlen,
|
| 800 |
)
|
| 801 |
# MQA / GQA
|
| 802 |
else:
|
|
@@ -830,6 +837,8 @@ class ParallelBlock(nn.Module):
|
|
| 830 |
|
| 831 |
self.mixer = MHA(config, layer_idx=block_idx)
|
| 832 |
self.mlp = MLP(config)
|
|
|
|
|
|
|
| 833 |
|
| 834 |
def forward(
|
| 835 |
self,
|
|
@@ -838,23 +847,52 @@ class ParallelBlock(nn.Module):
|
|
| 838 |
attention_mask: Optional[torch.BoolTensor] = None,
|
| 839 |
**kwargs,
|
| 840 |
) -> torch.FloatTensor:
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
|
|
|
| 845 |
hidden_states,
|
| 846 |
-
past_key_values
|
| 847 |
-
attention_mask
|
| 848 |
-
)
|
| 849 |
-
|
| 850 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
|
| 852 |
-
|
| 853 |
-
|
| 854 |
|
| 855 |
-
|
| 856 |
|
| 857 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
|
| 859 |
|
| 860 |
class CausalLMHead(nn.Module):
|
|
@@ -911,7 +949,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
| 911 |
|
| 912 |
config_class = PhiConfig
|
| 913 |
base_model_prefix = "transformer"
|
| 914 |
-
supports_gradient_checkpointing =
|
| 915 |
_no_split_modules = ["ParallelBlock"]
|
| 916 |
|
| 917 |
def __init__(self, *inputs, **kwargs) -> None:
|
|
@@ -931,6 +969,14 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
| 931 |
module.bias.data.zero_()
|
| 932 |
module.weight.data.fill_(1.0)
|
| 933 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 934 |
def prepare_inputs_for_generation(
|
| 935 |
self,
|
| 936 |
input_ids: torch.LongTensor,
|
|
@@ -951,7 +997,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
| 951 |
)
|
| 952 |
else:
|
| 953 |
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
| 954 |
-
past_key_values.seqlen_offset =
|
| 955 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 956 |
|
| 957 |
return {
|
|
@@ -988,8 +1034,6 @@ class PhiModel(PhiPreTrainedModel):
|
|
| 988 |
input_ids: torch.LongTensor,
|
| 989 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 990 |
attention_mask: Optional[torch.BoolTensor] = None,
|
| 991 |
-
cu_seqlens: Optional[torch.LongTensor] = None,
|
| 992 |
-
max_seqlen: Optional[int] = None,
|
| 993 |
) -> torch.FloatTensor:
|
| 994 |
hidden_states = self.embd(input_ids)
|
| 995 |
|
|
@@ -998,8 +1042,6 @@ class PhiModel(PhiPreTrainedModel):
|
|
| 998 |
hidden_states,
|
| 999 |
past_key_values=past_key_values,
|
| 1000 |
attention_mask=attention_mask,
|
| 1001 |
-
cu_seqlens=cu_seqlens,
|
| 1002 |
-
max_seqlen=max_seqlen,
|
| 1003 |
)
|
| 1004 |
|
| 1005 |
return hidden_states
|
|
@@ -1034,23 +1076,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
| 1034 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 1035 |
attention_mask: Optional[torch.BoolTensor] = None,
|
| 1036 |
labels: Optional[torch.LongTensor] = None,
|
| 1037 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 1038 |
**kwargs,
|
| 1039 |
) -> CausalLMOutputWithPast:
|
| 1040 |
-
cu_seqlens: Optional[torch.LongTensor] = None
|
| 1041 |
-
max_seqlen: Optional[int] = None
|
| 1042 |
-
if position_ids is not None:
|
| 1043 |
-
batch_size, seq_length = input_ids.shape
|
| 1044 |
-
position_ids = position_ids.view(-1, seq_length).long()
|
| 1045 |
-
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
| 1046 |
-
cu_seqlens = cu_seqlens.squeeze()
|
| 1047 |
-
|
| 1048 |
hidden_states = self.transformer(
|
| 1049 |
-
input_ids,
|
| 1050 |
-
past_key_values=past_key_values,
|
| 1051 |
-
attention_mask=attention_mask,
|
| 1052 |
-
cu_seqlens=cu_seqlens,
|
| 1053 |
-
max_seqlen=max_seqlen,
|
| 1054 |
)
|
| 1055 |
lm_logits = self.lm_head(hidden_states)
|
| 1056 |
|
|
|
|
| 9 |
|
| 10 |
import math
|
| 11 |
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 13 |
|
| 14 |
import torch
|
| 15 |
import torch.nn as nn
|
| 16 |
from einops import rearrange, repeat
|
| 17 |
+
from torch.utils.checkpoint import checkpoint
|
| 18 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 19 |
from transformers.activations import ACT2FN
|
| 20 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 21 |
|
|
|
|
| 22 |
from .configuration_phi import PhiConfig
|
| 23 |
|
| 24 |
try:
|
| 25 |
from flash_attn.bert_padding import pad_input, unpad_input
|
| 26 |
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
| 27 |
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
| 28 |
+
except ImportError:
|
|
|
|
| 29 |
pad_input, unpad_input = None, None
|
| 30 |
FlashRotaryEmbedding = None
|
| 31 |
FlashSelfAttention, FlashCrossAttention = None, None
|
| 32 |
+
|
| 33 |
+
# this is in a seperate try/except block since sometimes fused_dense isn't available
|
| 34 |
+
# and it shouldn't completely disable flash attn when it isn't
|
| 35 |
+
try:
|
| 36 |
+
from flash_attn.ops.fused_dense import FusedDense
|
| 37 |
+
except ImportError:
|
| 38 |
FusedDense = None
|
| 39 |
|
| 40 |
|
|
|
|
| 229 |
|
| 230 |
# Initialize cached attributes since ONNX can't rely on dynamic initialization
|
| 231 |
self._update_cos_sin_cache(
|
| 232 |
+
max_position_embeddings,
|
| 233 |
+
device=device,
|
| 234 |
+
dtype=torch.float32,
|
| 235 |
)
|
| 236 |
|
| 237 |
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
|
|
|
| 288 |
seqlen_offset: int = 0,
|
| 289 |
**kwargs,
|
| 290 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
|
|
|
| 291 |
if (
|
| 292 |
+
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
| 293 |
+
or self._cos_cached.device != qkv.device
|
| 294 |
or self._cos_cached.dtype != qkv.dtype
|
| 295 |
or (self.training and self._cos_cached.is_inference())
|
| 296 |
):
|
| 297 |
self._update_cos_sin_cache(
|
| 298 |
+
qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
|
| 299 |
)
|
| 300 |
|
| 301 |
if kv is None:
|
| 302 |
return _apply_rotary_emb_qkv(
|
| 303 |
qkv,
|
| 304 |
+
self._cos_cached[seqlen_offset:],
|
| 305 |
+
self._sin_cached[seqlen_offset:],
|
| 306 |
)
|
| 307 |
else:
|
| 308 |
q = _apply_rotary_emb(
|
| 309 |
qkv,
|
| 310 |
+
self._cos_cached[seqlen_offset:],
|
| 311 |
+
self._sin_cached[seqlen_offset:],
|
| 312 |
)
|
| 313 |
kv = _apply_rotary_emb_kv(
|
| 314 |
kv,
|
| 315 |
+
self._cos_cached[seqlen_offset:],
|
| 316 |
+
self._sin_cached[seqlen_offset:],
|
| 317 |
)
|
| 318 |
|
| 319 |
return q, kv
|
|
|
|
| 516 |
num_heads, head_dim = kv.shape[-2:]
|
| 517 |
|
| 518 |
if layer_idx not in inference_params.key_value_memory_dict:
|
| 519 |
+
inference_params.key_value_memory_dict[layer_idx] = torch.empty(
|
| 520 |
inference_params.max_batch_size,
|
| 521 |
inference_params.max_seqlen,
|
| 522 |
2,
|
|
|
|
| 525 |
dtype=kv.dtype,
|
| 526 |
device=kv.device,
|
| 527 |
)
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
batch_start = inference_params.batch_size_offset
|
| 530 |
batch_end = batch_start + kv.shape[0]
|
|
|
|
| 532 |
sequence_start = inference_params.seqlen_offset
|
| 533 |
sequence_end = sequence_start + kv.shape[1]
|
| 534 |
|
| 535 |
+
# When the current sequence length is equal to or larger than the maximum sequence length,
|
| 536 |
+
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
| 537 |
+
if sequence_end >= inference_params.max_seqlen:
|
| 538 |
+
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
|
| 539 |
+
(inference_params.key_value_memory_dict[layer_idx], kv), dim=1
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
inference_params.key_value_memory_dict[layer_idx][
|
| 543 |
+
batch_start:batch_end, sequence_start:sequence_end, ...
|
| 544 |
+
] = kv
|
| 545 |
+
kv = inference_params.key_value_memory_dict[layer_idx][
|
| 546 |
+
batch_start:batch_end, :sequence_end, ...
|
| 547 |
+
]
|
| 548 |
|
| 549 |
return kv
|
| 550 |
|
|
|
|
| 637 |
self.layer_idx = layer_idx
|
| 638 |
self.return_residual = return_residual
|
| 639 |
self.checkpointing = checkpointing
|
| 640 |
+
self._gradient_checkpointing_func = None
|
| 641 |
|
| 642 |
def _forward_self_attn(
|
| 643 |
+
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
) -> torch.FloatTensor:
|
| 645 |
qkv = self.Wqkv(x)
|
| 646 |
qkv = rearrange(
|
|
|
|
| 653 |
if self.flash_attn:
|
| 654 |
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 655 |
|
| 656 |
+
cu_seqlens, max_seqlen = None, None
|
| 657 |
+
if key_padding_mask is not None:
|
|
|
|
|
|
|
|
|
|
| 658 |
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
|
| 659 |
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
| 660 |
qkv, indices, cu_seqlens, max_seqlen = unpad_input(
|
| 661 |
qkv, key_padding_mask
|
| 662 |
)
|
| 663 |
|
| 664 |
+
if self.checkpointing and self.training:
|
| 665 |
+
attn_output = self._gradient_checkpointing_func(
|
| 666 |
+
self.inner_attn,
|
| 667 |
+
qkv,
|
| 668 |
+
cu_seqlens=cu_seqlens,
|
| 669 |
+
max_seqlen=max_seqlen,
|
| 670 |
+
use_reentrant=False,
|
| 671 |
)
|
| 672 |
else:
|
| 673 |
attn_output = self.inner_attn(
|
|
|
|
| 681 |
else attn_output
|
| 682 |
)
|
| 683 |
|
| 684 |
+
if self.checkpointing and self.training:
|
| 685 |
+
return self._gradient_checkpointing_func(
|
| 686 |
+
self.inner_attn,
|
| 687 |
+
qkv,
|
| 688 |
+
key_padding_mask=key_padding_mask,
|
| 689 |
+
use_reentrant=False,
|
| 690 |
)
|
| 691 |
|
| 692 |
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
|
|
|
| 739 |
q, key_padding_mask
|
| 740 |
)
|
| 741 |
|
| 742 |
+
if self.checkpointing and self.training:
|
| 743 |
+
attn_output = self._gradient_checkpointing_func(
|
| 744 |
self.inner_cross_attn,
|
| 745 |
q,
|
| 746 |
kv,
|
|
|
|
| 749 |
max_seqlen=max_seqlen_q,
|
| 750 |
cu_seqlens_k=cu_seqlens_k,
|
| 751 |
max_seqlen_k=max_seqlen_k,
|
| 752 |
+
use_reentrant=False,
|
| 753 |
)
|
| 754 |
else:
|
| 755 |
attn_output = self.inner_cross_attn(
|
|
|
|
| 768 |
else attn_output
|
| 769 |
)
|
| 770 |
|
| 771 |
+
if self.checkpointing and self.training:
|
| 772 |
+
return self._gradient_checkpointing_func(
|
| 773 |
self.inner_cross_attn,
|
| 774 |
q,
|
| 775 |
kv,
|
| 776 |
key_padding_mask=key_padding_mask,
|
| 777 |
causal=causal,
|
| 778 |
+
use_reentrant=False,
|
| 779 |
)
|
| 780 |
|
| 781 |
return self.inner_cross_attn(
|
|
|
|
| 787 |
x: torch.FloatTensor,
|
| 788 |
past_key_values: Optional[InferenceParams] = None,
|
| 789 |
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
|
|
|
|
|
|
| 790 |
**kwargs,
|
| 791 |
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
|
|
|
| 792 |
if attention_mask is not None:
|
| 793 |
attention_mask = attention_mask.bool()
|
| 794 |
else:
|
|
|
|
| 798 |
if self.n_head == self.n_head_kv:
|
| 799 |
if past_key_values is None:
|
| 800 |
# If `past_key_values` are not supplied, we run self-attention
|
| 801 |
+
attn_output = self._forward_self_attn(x, attention_mask)
|
|
|
|
|
|
|
| 802 |
else:
|
| 803 |
# If `past_key_values` are supplied, it means that we might have cached values and
|
| 804 |
# could take advantage of cross-attention
|
| 805 |
attn_output = self._forward_cross_attn(
|
| 806 |
+
x, past_key_values, attention_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
)
|
| 808 |
# MQA / GQA
|
| 809 |
else:
|
|
|
|
| 837 |
|
| 838 |
self.mixer = MHA(config, layer_idx=block_idx)
|
| 839 |
self.mlp = MLP(config)
|
| 840 |
+
self.checkpointing = False
|
| 841 |
+
self._gradient_checkpointing_func = None
|
| 842 |
|
| 843 |
def forward(
|
| 844 |
self,
|
|
|
|
| 847 |
attention_mask: Optional[torch.BoolTensor] = None,
|
| 848 |
**kwargs,
|
| 849 |
) -> torch.FloatTensor:
|
| 850 |
+
def _forward(
|
| 851 |
+
mixer,
|
| 852 |
+
resid_dropout,
|
| 853 |
+
mlp,
|
| 854 |
+
ln,
|
| 855 |
hidden_states,
|
| 856 |
+
past_key_values,
|
| 857 |
+
attention_mask,
|
| 858 |
+
):
|
| 859 |
+
residual = hidden_states
|
| 860 |
+
hidden_states = ln(hidden_states)
|
| 861 |
+
|
| 862 |
+
attn_outputs = mixer(
|
| 863 |
+
hidden_states,
|
| 864 |
+
past_key_values=past_key_values,
|
| 865 |
+
attention_mask=attention_mask,
|
| 866 |
+
)
|
| 867 |
+
if isinstance(attn_outputs, tuple):
|
| 868 |
+
attn_outputs = attn_outputs[0]
|
| 869 |
|
| 870 |
+
attn_outputs = resid_dropout(attn_outputs)
|
| 871 |
+
feed_forward_hidden_states = resid_dropout(mlp(hidden_states))
|
| 872 |
|
| 873 |
+
return attn_outputs + feed_forward_hidden_states + residual
|
| 874 |
|
| 875 |
+
if self.training and self.checkpointing:
|
| 876 |
+
return self._gradient_checkpointing_func(
|
| 877 |
+
_forward,
|
| 878 |
+
self.mixer,
|
| 879 |
+
self.resid_dropout,
|
| 880 |
+
self.mlp,
|
| 881 |
+
self.ln,
|
| 882 |
+
hidden_states,
|
| 883 |
+
past_key_values,
|
| 884 |
+
attention_mask,
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
return _forward(
|
| 888 |
+
self.mixer,
|
| 889 |
+
self.resid_dropout,
|
| 890 |
+
self.mlp,
|
| 891 |
+
self.ln,
|
| 892 |
+
hidden_states,
|
| 893 |
+
past_key_values,
|
| 894 |
+
attention_mask,
|
| 895 |
+
)
|
| 896 |
|
| 897 |
|
| 898 |
class CausalLMHead(nn.Module):
|
|
|
|
| 949 |
|
| 950 |
config_class = PhiConfig
|
| 951 |
base_model_prefix = "transformer"
|
| 952 |
+
supports_gradient_checkpointing = True
|
| 953 |
_no_split_modules = ["ParallelBlock"]
|
| 954 |
|
| 955 |
def __init__(self, *inputs, **kwargs) -> None:
|
|
|
|
| 969 |
module.bias.data.zero_()
|
| 970 |
module.weight.data.fill_(1.0)
|
| 971 |
|
| 972 |
+
def _set_gradient_checkpointing(
|
| 973 |
+
self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint
|
| 974 |
+
):
|
| 975 |
+
for module in self.modules():
|
| 976 |
+
if hasattr(module, "checkpointing"):
|
| 977 |
+
module._gradient_checkpointing_func = gradient_checkpointing_func
|
| 978 |
+
module.checkpointing = enable
|
| 979 |
+
|
| 980 |
def prepare_inputs_for_generation(
|
| 981 |
self,
|
| 982 |
input_ids: torch.LongTensor,
|
|
|
|
| 997 |
)
|
| 998 |
else:
|
| 999 |
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
| 1000 |
+
past_key_values.seqlen_offset = input_ids.shape[1] - 1
|
| 1001 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 1002 |
|
| 1003 |
return {
|
|
|
|
| 1034 |
input_ids: torch.LongTensor,
|
| 1035 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 1036 |
attention_mask: Optional[torch.BoolTensor] = None,
|
|
|
|
|
|
|
| 1037 |
) -> torch.FloatTensor:
|
| 1038 |
hidden_states = self.embd(input_ids)
|
| 1039 |
|
|
|
|
| 1042 |
hidden_states,
|
| 1043 |
past_key_values=past_key_values,
|
| 1044 |
attention_mask=attention_mask,
|
|
|
|
|
|
|
| 1045 |
)
|
| 1046 |
|
| 1047 |
return hidden_states
|
|
|
|
| 1076 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 1077 |
attention_mask: Optional[torch.BoolTensor] = None,
|
| 1078 |
labels: Optional[torch.LongTensor] = None,
|
|
|
|
| 1079 |
**kwargs,
|
| 1080 |
) -> CausalLMOutputWithPast:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1081 |
hidden_states = self.transformer(
|
| 1082 |
+
input_ids, past_key_values=past_key_values, attention_mask=attention_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1083 |
)
|
| 1084 |
lm_logits = self.lm_head(hidden_states)
|
| 1085 |
|
|
@@ -55,6 +55,8 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|
| 55 |
|
| 56 |
def load_model_config(cfg):
|
| 57 |
model_config_name = cfg.base_model_config or cfg.base_model
|
|
|
|
|
|
|
| 58 |
trust_remote_code = cfg.trust_remote_code is True
|
| 59 |
|
| 60 |
try:
|
|
@@ -80,6 +82,7 @@ def load_model_config(cfg):
|
|
| 80 |
|
| 81 |
|
| 82 |
def load_tokenizer(cfg):
|
|
|
|
| 83 |
tokenizer_kwargs = {}
|
| 84 |
use_fast = True # this is the default
|
| 85 |
|
|
@@ -139,6 +142,7 @@ def load_tokenizer(cfg):
|
|
| 139 |
for k, val in cfg.special_tokens.items():
|
| 140 |
# check if new special token is not already in tokenizer and
|
| 141 |
# is adapter training to make sure lora_modules_to_save is set
|
|
|
|
| 142 |
if (
|
| 143 |
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
| 144 |
and cfg.adapter
|
|
@@ -149,6 +153,7 @@ def load_tokenizer(cfg):
|
|
| 149 |
for x in ["embed_tokens", "lm_head"]
|
| 150 |
)
|
| 151 |
)
|
|
|
|
| 152 |
):
|
| 153 |
raise ValueError(
|
| 154 |
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
|
@@ -386,6 +391,10 @@ def load_model(
|
|
| 386 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 387 |
"eager"
|
| 388 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
try:
|
| 391 |
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
|
@@ -438,11 +447,12 @@ def load_model(
|
|
| 438 |
# device=cfg.device,
|
| 439 |
# )
|
| 440 |
# model.train() # sets to train instead of eval mode
|
| 441 |
-
elif model_type == "PhiForCausalLM":
|
| 442 |
from axolotl.models.phi import PhiForCausalLM
|
| 443 |
|
| 444 |
model = PhiForCausalLM.from_pretrained(
|
| 445 |
base_model,
|
|
|
|
| 446 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 447 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 448 |
**model_kwargs,
|
|
|
|
| 55 |
|
| 56 |
def load_model_config(cfg):
|
| 57 |
model_config_name = cfg.base_model_config or cfg.base_model
|
| 58 |
+
if not model_config_name and cfg.tokenizer_config:
|
| 59 |
+
model_config_name = cfg.tokenizer_config
|
| 60 |
trust_remote_code = cfg.trust_remote_code is True
|
| 61 |
|
| 62 |
try:
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
def load_tokenizer(cfg):
|
| 85 |
+
model_config = load_model_config(cfg)
|
| 86 |
tokenizer_kwargs = {}
|
| 87 |
use_fast = True # this is the default
|
| 88 |
|
|
|
|
| 142 |
for k, val in cfg.special_tokens.items():
|
| 143 |
# check if new special token is not already in tokenizer and
|
| 144 |
# is adapter training to make sure lora_modules_to_save is set
|
| 145 |
+
# pylint: disable=too-many-boolean-expressions
|
| 146 |
if (
|
| 147 |
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
| 148 |
and cfg.adapter
|
|
|
|
| 153 |
for x in ["embed_tokens", "lm_head"]
|
| 154 |
)
|
| 155 |
)
|
| 156 |
+
and (model_config.model_type in ["llama", "mistral", "mixtral"])
|
| 157 |
):
|
| 158 |
raise ValueError(
|
| 159 |
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
|
|
|
| 391 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 392 |
"eager"
|
| 393 |
)
|
| 394 |
+
if model_config.model_type == "phi-msft":
|
| 395 |
+
model_config.flash_attn = True
|
| 396 |
+
model_config.flash_rotary = True
|
| 397 |
+
model_config.fused_dense = True
|
| 398 |
|
| 399 |
try:
|
| 400 |
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
|
|
|
| 447 |
# device=cfg.device,
|
| 448 |
# )
|
| 449 |
# model.train() # sets to train instead of eval mode
|
| 450 |
+
elif model_type == "PhiForCausalLM" or model_config.model_type == "phi-msft":
|
| 451 |
from axolotl.models.phi import PhiForCausalLM
|
| 452 |
|
| 453 |
model = PhiForCausalLM.from_pretrained(
|
| 454 |
base_model,
|
| 455 |
+
config=model_config,
|
| 456 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 457 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 458 |
**model_kwargs,
|
|
@@ -7,6 +7,8 @@ import os
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
|
|
|
| 10 |
from axolotl.cli import load_datasets
|
| 11 |
from axolotl.common.cli import TrainerCliArgs
|
| 12 |
from axolotl.train import train
|
|
@@ -21,17 +23,18 @@ os.environ["WANDB_DISABLED"] = "true"
|
|
| 21 |
|
| 22 |
class TestPhi(unittest.TestCase):
|
| 23 |
"""
|
| 24 |
-
Test case for
|
| 25 |
"""
|
| 26 |
|
|
|
|
| 27 |
@with_temp_dir
|
| 28 |
-
def
|
| 29 |
# pylint: disable=duplicate-code
|
| 30 |
cfg = DictDefault(
|
| 31 |
{
|
| 32 |
-
"base_model": "microsoft/phi-
|
| 33 |
"trust_remote_code": True,
|
| 34 |
-
"model_type": "
|
| 35 |
"tokenizer_type": "AutoTokenizer",
|
| 36 |
"sequence_len": 512,
|
| 37 |
"sample_packing": False,
|
|
@@ -39,9 +42,6 @@ class TestPhi(unittest.TestCase):
|
|
| 39 |
"adapter": None,
|
| 40 |
"val_set_size": 0.1,
|
| 41 |
"special_tokens": {
|
| 42 |
-
"unk_token": "<|endoftext|>",
|
| 43 |
-
"bos_token": "<|endoftext|>",
|
| 44 |
-
"eos_token": "<|endoftext|>",
|
| 45 |
"pad_token": "<|endoftext|>",
|
| 46 |
},
|
| 47 |
"datasets": [
|
|
@@ -57,9 +57,14 @@ class TestPhi(unittest.TestCase):
|
|
| 57 |
"gradient_accumulation_steps": 1,
|
| 58 |
"output_dir": temp_dir,
|
| 59 |
"learning_rate": 0.00001,
|
| 60 |
-
"optimizer": "
|
| 61 |
"lr_scheduler": "cosine",
|
| 62 |
"bf16": True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
}
|
| 64 |
)
|
| 65 |
normalize_config(cfg)
|
|
@@ -69,12 +74,13 @@ class TestPhi(unittest.TestCase):
|
|
| 69 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 70 |
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
| 71 |
|
|
|
|
| 72 |
@with_temp_dir
|
| 73 |
def test_ft_packed(self, temp_dir):
|
| 74 |
# pylint: disable=duplicate-code
|
| 75 |
cfg = DictDefault(
|
| 76 |
{
|
| 77 |
-
"base_model": "microsoft/phi-
|
| 78 |
"trust_remote_code": True,
|
| 79 |
"model_type": "PhiForCausalLM",
|
| 80 |
"tokenizer_type": "AutoTokenizer",
|
|
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
| 10 |
+
import pytest
|
| 11 |
+
|
| 12 |
from axolotl.cli import load_datasets
|
| 13 |
from axolotl.common.cli import TrainerCliArgs
|
| 14 |
from axolotl.train import train
|
|
|
|
| 23 |
|
| 24 |
class TestPhi(unittest.TestCase):
|
| 25 |
"""
|
| 26 |
+
Test case for Phi2 models
|
| 27 |
"""
|
| 28 |
|
| 29 |
+
@pytest.mark.skip(reason="fixme later")
|
| 30 |
@with_temp_dir
|
| 31 |
+
def test_phi2_ft(self, temp_dir):
|
| 32 |
# pylint: disable=duplicate-code
|
| 33 |
cfg = DictDefault(
|
| 34 |
{
|
| 35 |
+
"base_model": "microsoft/phi-2",
|
| 36 |
"trust_remote_code": True,
|
| 37 |
+
"model_type": "AutoModelForCausalLM",
|
| 38 |
"tokenizer_type": "AutoTokenizer",
|
| 39 |
"sequence_len": 512,
|
| 40 |
"sample_packing": False,
|
|
|
|
| 42 |
"adapter": None,
|
| 43 |
"val_set_size": 0.1,
|
| 44 |
"special_tokens": {
|
|
|
|
|
|
|
|
|
|
| 45 |
"pad_token": "<|endoftext|>",
|
| 46 |
},
|
| 47 |
"datasets": [
|
|
|
|
| 57 |
"gradient_accumulation_steps": 1,
|
| 58 |
"output_dir": temp_dir,
|
| 59 |
"learning_rate": 0.00001,
|
| 60 |
+
"optimizer": "paged_adamw_8bit",
|
| 61 |
"lr_scheduler": "cosine",
|
| 62 |
"bf16": True,
|
| 63 |
+
"flash_attention": True,
|
| 64 |
+
"max_steps": 10,
|
| 65 |
+
"save_steps": 10,
|
| 66 |
+
"eval_steps": 10,
|
| 67 |
+
"save_safetensors": True,
|
| 68 |
}
|
| 69 |
)
|
| 70 |
normalize_config(cfg)
|
|
|
|
| 74 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 75 |
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
| 76 |
|
| 77 |
+
@pytest.mark.skip(reason="multipack no longer supported atm")
|
| 78 |
@with_temp_dir
|
| 79 |
def test_ft_packed(self, temp_dir):
|
| 80 |
# pylint: disable=duplicate-code
|
| 81 |
cfg = DictDefault(
|
| 82 |
{
|
| 83 |
+
"base_model": "microsoft/phi-2",
|
| 84 |
"trust_remote_code": True,
|
| 85 |
"model_type": "PhiForCausalLM",
|
| 86 |
"tokenizer_type": "AutoTokenizer",
|