AIR-hl committed
Commit b4caec0 · verified · 1 Parent(s): e11adc4

Update README.md

Files changed (1)
  1. README.md +61 -3
README.md CHANGED

@@ -27,8 +27,6 @@ tags:
 
 ## Training Details
 
-**Cutome training codes**
-
 ### Training Hyperparameters
 `attn_implementation`: flash_attention_2 \
 `bf16`: True \
@@ -39,10 +37,70 @@ tags:
 `torch_dtype`: bfloat16 \
 `num_train_epochs`: 1 \
 `max_seq_length`: 2048 \
-`warmup_ratio`: 0.1 \
+`warmup_ratio`: 0.1
 
 ### Results
 
 `init_train_loss`: 1.421 \
 `final_train_loss`: 1.192 \
 `eval_loss`: 1.2003
+
+### Training script
+
+```python
+import multiprocessing
+
+from datasets import load_dataset
+from tqdm.rich import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import (
+    ModelConfig,
+    SFTTrainer,
+    get_peft_config,
+    get_quantization_config,
+    get_kbit_device_map,
+    SFTConfig,
+    ScriptArguments
+)
+from trl.commands.cli_utils import TrlParser
+
+tqdm.pandas()
+
+if __name__ == "__main__":
+    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
+    args, training_args, model_config = parser.parse_args_and_config()
+
+    quantization_config = get_quantization_config(model_config)
+    model_kwargs = dict(
+        revision=model_config.model_revision,
+        trust_remote_code=model_config.trust_remote_code,
+        attn_implementation=model_config.attn_implementation,
+        torch_dtype=model_config.torch_dtype,
+        use_cache=False if training_args.gradient_checkpointing else True,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,
+                                                 **model_kwargs)
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
+    )
+    tokenizer.pad_token = tokenizer.eos_token
+
+    train_dataset = load_dataset(args.dataset_name,
+                                 split=args.dataset_train_split,
+                                 num_proc=multiprocessing.cpu_count())
+
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        processing_class=tokenizer,
+        peft_config=get_peft_config(model_config),
+    )
+
+    trainer.train()
+
+    trainer.save_model(training_args.output_dir)
+```
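
For context on the diff above: the script routes everything through `TrlParser`, so the hyperparameters listed in the README arrive as CLI flags or a config file rather than being hard-coded. Below is a minimal sketch of the equivalent programmatic setup, assuming the same `trl` version the script targets (one where `SFTConfig` still accepts `max_seq_length`); the model name and output directory are hypothetical placeholders, not values from this commit.

```python
# Sketch only: the README's hyperparameters expressed through TRL's config
# dataclasses. Hyperparameter values come from the README above;
# model_name_or_path and output_dir are hypothetical placeholders.
from trl import ModelConfig, SFTConfig

model_config = ModelConfig(
    model_name_or_path="Qwen/Qwen2.5-1.5B",   # placeholder, not from the commit
    attn_implementation="flash_attention_2",  # `attn_implementation`
    torch_dtype="bfloat16",                   # `torch_dtype`
)

training_args = SFTConfig(
    output_dir="./sft-output",  # placeholder
    bf16=True,                  # `bf16`
    num_train_epochs=1,         # `num_train_epochs`
    max_seq_length=2048,        # `max_seq_length`
    warmup_ratio=0.1,           # `warmup_ratio`
)
```

When launched from a shell, `TrlParser` exposes these same dataclass fields as flags (for example `--attn_implementation flash_attention_2 --bf16 --max_seq_length 2048 --warmup_ratio 0.1`), which is presumably how the training run behind the reported losses was invoked.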