sanchit-gandhi committed
Commit 501a06b · verified · 1 Parent(s): e8d9a1c

Saving train state of step 5000

.gitignore ADDED
@@ -0,0 +1 @@
+ wandb
accelerate_config.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
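
The launcher config above pins training to a single node with 8 GPU workers and bf16 mixed precision. As a quick sanity check (editor's sketch, not part of this commit), each process started via `accelerate launch --config_file accelerate_config.yaml ...` should report the following:

from accelerate import Accelerator

# Minimal check of the distributed environment implied by accelerate_config.yaml
# (num_processes: 8, mixed_precision: bf16).
accelerator = Accelerator(mixed_precision="bf16")
print(accelerator.num_processes)    # 8: one process per GPU on the single machine
print(accelerator.process_index)    # 0..7, unique per worker
print(accelerator.mixed_precision)  # "bf16"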
checkpoint-5000-epoch-0/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "_name_or_path": "sanchit-gandhi/Mistral-1.5B-Instruct-v0.2-first-6",
+   "architectures": [
+     "MistralForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "max_position_embeddings": 32768,
+   "model_type": "mistral",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 6,
+   "num_key_value_heads": 8,
+   "output_router_logits": true,
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.0.dev0",
+   "use_cache": true,
+   "vocab_size": 32000
+ }
checkpoint-5000-epoch-0/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "max_length": 4096,
+   "transformers_version": "4.40.0.dev0"
+ }
checkpoint-5000-epoch-0/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4102a8dd687cfc3fc54bc2848c8130965b95de98d42ca5743ec537b152804c41
+ size 4987196936
checkpoint-5000-epoch-0/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67e551f279a6ae86f905511f92037cf65e1a3839cff4ef09118d32bdb378a460
+ size 1296089984
checkpoint-5000-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f19ba61aba16eea08933f227bda11b6d10a0c71a55f891350d93902232328cc
+ size 6283286904
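
The three-line files above are Git LFS pointers (per the spec URL they reference): the actual tensors live in LFS storage and are addressed by SHA-256. A hedged sketch (editor's illustration; the local path is hypothetical) of checking a downloaded shard against its pointer:

import hashlib, os

def matches_lfs_pointer(path, expected_oid, expected_size):
    # Stream the file so multi-GB shards never need to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid and os.path.getsize(path) == expected_size

# e.g. for checkpoint-5000-epoch-0/model.safetensors listed above:
# matches_lfs_pointer("model.safetensors",
#                     "3f19ba61aba16eea08933f227bda11b6d10a0c71a55f891350d93902232328cc",
#                     6283286904)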
checkpoint-5000-epoch-0/model.safetensors.index.json ADDED
@@ -0,0 +1,64 @@
+ {
+   "metadata": {
+     "total_size": 6283280384
+   },
+   "weight_map": {
+     "lm_head.weight": "model-00002-of-00002.safetensors",
+     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.input_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.5.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.5.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+     "model.norm.weight": "model-00002-of-00002.safetensors"
+   }
+ }
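
The index above assigns every parameter of the 6-layer student to one of the two shards, and transformers resolves it automatically when the checkpoint folder is loaded. A minimal sketch (editor's illustration, assuming the folder has been downloaded locally):

from transformers import AutoModelForCausalLM

# from_pretrained reads model.safetensors.index.json and pulls each tensor from the
# shard named in "weight_map".
model = AutoModelForCausalLM.from_pretrained("checkpoint-5000-epoch-0")
print(model.num_parameters())  # ~1.57e9 parameters (6,283,280,384 bytes of fp32 weights / 4)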
checkpoint-5000-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65173ac419081e25f0d5f93ed77393cf05f5158325ee154a5cbb3e14b47ece07
+ size 4450837792
checkpoint-5000-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4729e5dc2d01a6c24756ee0a1599314320a69ec290e9d66593a274370b055ec2
+ size 874185360
checkpoint-5000-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aca503cc09e63ca033e29a437a20cc580a9c1db27fef2174e533f58ba275879
+ size 16100
checkpoint-5000-epoch-0/random_states_1.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31831c2134536b1e81ba1e763e72b2ff98a14a83774fcfb30d153a66dca7879c
+ size 16100
checkpoint-5000-epoch-0/random_states_2.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a628258539b4090ce50e9faf5fda4d613f523ca957f3e837c02d316e4b20122
+ size 16100
checkpoint-5000-epoch-0/random_states_3.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d594aa54f68e8eb41c3deb9753bf43474028f44edb92db1930ebdf967f708a7c
+ size 16100
checkpoint-5000-epoch-0/random_states_4.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28ca4240374ff4b93ad0537aca2f28bfc293153a29ee8069cf09d088ca30fee7
+ size 16100
checkpoint-5000-epoch-0/random_states_5.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d6f3577977e8c32eac49b1c5136c6718fcd9c66051b703ba6e305cca03a8fb0
+ size 16100
checkpoint-5000-epoch-0/random_states_6.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0ef1d86e60e6cedda41454cd08e0b3652ab6a6eb017b4eed0d6b84866ed7d46
+ size 16100
checkpoint-5000-epoch-0/random_states_7.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08d860c07ef8d57c8162394106fcd87c34e7924d859b28b4b292e9e792a96af2
+ size 16100
checkpoint-5000-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c25f7255aa53945ccffbdb6904da689924024cb2e693a6c6739ade9fae0454a2
+ size 1064
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "_name_or_path": "sanchit-gandhi/Mistral-1.5B-Instruct-v0.2-first-6",
+   "architectures": [
+     "MistralForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "max_position_embeddings": 32768,
+   "model_type": "mistral",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 6,
+   "num_key_value_heads": 8,
+   "output_router_logits": true,
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.0.dev0",
+   "use_cache": true,
+   "vocab_size": 32000
+ }
config_mistral.yaml ADDED
@@ -0,0 +1,53 @@
+ # Model arguments
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2-first-6
+ teacher_model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ dtype: bfloat16
+ load_teacher_in_4bit: true
+ optim: adamw_bnb_8bit
+ freeze_embeddings: true
+ freeze_n_layers: 4
+
+ # Data arguments
+ train_dataset_name: HuggingFaceTB/cosmopedia
+ train_dataset_config_name:
+ - auto_math_text
+ - khanacademy
+ - openstax
+ - stanford
+ - stories
+ - web_samples_v1
+ - web_samples_v2
+ - wikihow
+ train_split_name: train[1000:]
+ eval_split_name: train[:1000]
+ prompt_column_name: prompt
+ eval_prompt_column_name: prompt
+ max_steps: 200000
+ max_train_samples: 20000000
+
+ # Training arguments
+ do_train: true
+ do_eval: true
+ per_device_eval_batch_size: 4
+ per_device_train_batch_size: 4
+ gradient_accumulation_steps: 2
+ max_label_length: 4096
+ learning_rate: 0.0001
+ warmup_steps: 500
+ gradient_checkpointing: false
+ dataloader_num_workers: 4
+ preprocessing_num_workers: 32
+ ddp_timeout: 7200
+ save_strategy: steps
+ save_steps: 5000
+ evaluation_strategy: steps
+ eval_steps: 5000
+ logging_steps: 25
+ output_router_logits: true
+ report_to: all
+ output_dir: ./
+ overwrite_output_dir: false
+ save_total_limit: 1
+ wandb_project: distil-mistral
+ push_to_hub: true
+
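
run_distillation.py (added below) consumes this YAML directly through HfArgumentParser, so each key above maps onto a field of one of its three argument dataclasses. A minimal sketch of that parsing path (editor's illustration; the dataclasses are the ones defined in the script):

from transformers import HfArgumentParser
# ModelArguments, DataTrainingArguments and DistillationTrainingArguments come from
# run_distillation.py below.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
model_args, data_args, training_args = parser.parse_yaml_file("config_mistral.yaml")
print(model_args.teacher_model_name_or_path)  # mistralai/Mistral-7B-Instruct-v0.2
print(training_args.save_steps)               # 5000, which is why this commit saves checkpoint-5000-epoch-0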
distil-mistral/1714987090.311693/events.out.tfevents.1714987090.ip-26-0-168-30.3307834.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b801b4b147bcc14041bc6a6cc51d1eb92578d4a12cccaa892ae78f78026117f
+ size 1168
distil-mistral/1714987090.3164885/hparams.yml ADDED
@@ -0,0 +1,16 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 32
+ gradient_accumulation_steps: 2
+ learning_rate: 0.0001
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2-first-6
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 4
+ teacher_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
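
A quick cross-check of the logged values (editor's note): the script computes global_batch_size as the per-device batch size times the number of processes, without gradient accumulation, so the hyper-parameters above are self-consistent.

per_device_train_batch_size = 4   # from hparams.yml
num_processes = 8                 # from accelerate_config.yaml
gradient_accumulation_steps = 2
global_batch_size = per_device_train_batch_size * num_processes               # 32, as logged
samples_per_optimizer_step = global_batch_size * gradient_accumulation_steps  # 64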
distil-mistral/events.out.tfevents.1714987080.ip-26-0-168-30.3307834.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03cc9f820c2e7bdae6c0beb0fc7a1c28264167e027508a79ea0f258c721e2928
+ size 62058
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "max_length": 4096,
+   "transformers_version": "4.40.0.dev0"
+ }
run_distillation.py ADDED
@@ -0,0 +1,1594 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training language models for conditional language modelling tasks via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import shutil
26
+ import sys
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from functools import partial
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Union
32
+
33
+ import datasets
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ Dataset,
42
+ DatasetDict,
43
+ IterableDataset,
44
+ IterableDatasetDict,
45
+ concatenate_datasets,
46
+ interleave_datasets,
47
+ load_dataset,
48
+ )
49
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
50
+ from peft import LoraConfig, get_peft_model
51
+ from torch.utils.data import DataLoader
52
+ from tqdm import tqdm
53
+ from transformers import (
54
+ AutoConfig,
55
+ AutoModelForCausalLM,
56
+ AutoTokenizer,
57
+ BatchEncoding,
58
+ BitsAndBytesConfig,
59
+ HfArgumentParser,
60
+ PreTrainedTokenizerBase,
61
+ Seq2SeqTrainingArguments,
62
+ get_scheduler,
63
+ set_seed, is_bitsandbytes_available,
64
+ )
65
+ from transformers.training_args import OptimizerNames
66
+ from transformers.utils import check_min_version
67
+ from transformers.utils.versions import require_version
68
+
69
+
70
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
71
+ check_min_version("4.34.0.dev0")
72
+
73
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
74
+
75
+ logger = get_logger(__name__)
76
+
77
+
78
+ @dataclass
79
+ class ModelArguments:
80
+ """
81
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
82
+ """
83
+
84
+ model_name_or_path: str = field(
85
+ metadata={"help": "Path to pretrained student model or model identifier from huggingface.co/models"}
86
+ )
87
+ teacher_model_name_or_path: Optional[str] = field(
88
+ default=None,
89
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
90
+ )
91
+ config_name: Optional[str] = field(
92
+ default=None,
93
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
94
+ )
95
+ tokenizer_name: Optional[str] = field(
96
+ default=None,
97
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
98
+ )
99
+ cache_dir: Optional[str] = field(
100
+ default=None,
101
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
102
+ )
103
+ use_fast_tokenizer: bool = field(
104
+ default=True,
105
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106
+ )
107
+ model_revision: str = field(
108
+ default="main",
109
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
110
+ )
111
+ subfolder: str = field(
112
+ default="",
113
+ metadata={
114
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
115
+ "specify the folder name here."
116
+ },
117
+ )
118
+ token: str = field(
119
+ default=None,
120
+ metadata={
121
+ "help": (
122
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
123
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
124
+ )
125
+ },
126
+ )
127
+ attn_implementation: Optional[str] = field(
128
+ default=None,
129
+ metadata={
130
+ "help": (
131
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
132
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
133
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
134
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
135
+ )
136
+ },
137
+ )
138
+ load_teacher_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the teacher model."})
139
+ load_teacher_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the teacher model."})
140
+ load_student_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the student model."})
141
+ load_student_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the student model."})
142
+ bnb_4bit_quant_type: Optional[str] = field(
143
+ default="nf4", metadata={"help": "Quantization type if the teacher is quantized (fp4 or nf4)"}
144
+ )
145
+ use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether or not to use nested quantization."})
146
+ lora_r: Optional[int] = field(
147
+ default=16,
148
+ metadata={"help": "LoRA R value."},
149
+ )
150
+ lora_alpha: Optional[int] = field(
151
+ default=32,
152
+ metadata={"help": "LoRA alpha."},
153
+ )
154
+ lora_dropout: Optional[float] = field(
155
+ default=0.05,
156
+ metadata={"help": "LoRA dropout."},
157
+ )
158
+ lora_target_modules: Optional[List[str]] = field(
159
+ default=None,
160
+ metadata={"help": "LoRA target modules."},
161
+ )
162
+ lora_modules_to_save: Optional[List[str]] = field(
163
+ default=None,
164
+ metadata={"help": "Model layers to unfreeze & train"},
165
+ )
166
+ instruction_model: Optional[bool] = field(
167
+ default=None,
168
+ metadata={"help": "Whether or not the pre-trained model is instruction tuned"},
169
+ )
170
+
171
+
172
+ @dataclass
173
+ class DataTrainingArguments:
174
+ """
175
+ Arguments pertaining to what data we are going to input our model for training and eval.
176
+ """
177
+
178
+ train_dataset_name: List[str] = field(
179
+ default=None,
180
+ metadata={
181
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
182
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
183
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
184
+ },
185
+ )
186
+ train_dataset_config_name: Optional[List[str]] = field(
187
+ default=None,
188
+ metadata={
189
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
190
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
191
+ "match the order of the datasets."
192
+ },
193
+ )
194
+ train_dataset_samples: Optional[List[str]] = field(
195
+ default=None,
196
+ metadata={
197
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
198
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
199
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
200
+ "sample from every dataset is used once per epoch."
201
+ },
202
+ )
203
+ eval_dataset_name: Optional[List[str]] = field(
204
+ default=None,
205
+ metadata={
206
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
207
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
208
+ "ids by a '+' symbol."
209
+ },
210
+ )
211
+ eval_dataset_config_name: Optional[List[str]] = field(
212
+ default=None,
213
+ metadata={
214
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
215
+ "training dataset config name if unspecified."
216
+ },
217
+ )
218
+ dataset_cache_dir: Optional[str] = field(
219
+ default=None,
220
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
221
+ )
222
+ overwrite_cache: bool = field(
223
+ default=False,
224
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
225
+ )
226
+ preprocessing_num_workers: Optional[int] = field(
227
+ default=None,
228
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
229
+ )
230
+ max_train_samples: Optional[int] = field(
231
+ default=None,
232
+ metadata={
233
+ "help": (
234
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
235
+ )
236
+ },
237
+ )
238
+ max_eval_samples: Optional[int] = field(
239
+ default=None,
240
+ metadata={
241
+ "help": (
242
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
243
+ )
244
+ },
245
+ )
246
+ text_column_name: Optional[List[str]] = field(
247
+ default=None,
248
+ metadata={"help": "The name of the dataset column containing the generated text data in the training set."},
249
+ )
250
+ prompt_column_name: Optional[List[str]] = field(
251
+ default=None,
252
+ metadata={"help": "The name of the dataset column containing the prompt data. Defaults to 'prompt'"},
253
+ )
254
+ eval_text_column_name: Optional[List[str]] = field(
255
+ default=None,
256
+ metadata={"help": "The name of the dataset column containing the generated text data in the evaluation set."},
257
+ )
258
+ eval_prompt_column_name: Optional[List[str]] = field(
259
+ default=None,
260
+ metadata={"help": "The name of the dataset column containing the prompt data in the evaluation set."},
261
+ )
262
+ max_label_length: Optional[int] = field(
263
+ default=4096,
264
+ metadata={"help": "Truncate target labels that are longer `max_label_length` tokens."},
265
+ )
266
+ pad_target_to_multiple_of: Optional[int] = field(
267
+ default=None,
268
+ metadata={
269
+ "help": (
270
+ "If set will pad the target sequence to a multiple of the provided value. This is important to "
271
+ "avoid triggering recompilations when using torch compile. If unspecified, will default to padding "
272
+ "the targets to max length."
273
+ )
274
+ },
275
+ )
276
+ preprocessing_only: bool = field(
277
+ default=False,
278
+ metadata={
279
+ "help": (
280
+ "Whether to only do data preprocessing and skip training. This is especially useful when data "
281
+ "preprocessing errors out in distributed training due to timeout. In this case, one should run the "
282
+ "preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets "
283
+ "can consequently be loaded in distributed training"
284
+ )
285
+ },
286
+ )
287
+ train_split_name: Optional[List[str]] = field(
288
+ default=lambda: ["train"],
289
+ metadata={
290
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
291
+ },
292
+ )
293
+ eval_split_name: Optional[List[str]] = field(
294
+ default=lambda: ["validation"],
295
+ metadata={
296
+ "help": (
297
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
298
+ )
299
+ },
300
+ )
301
+ streaming: bool = field(
302
+ default=False,
303
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
304
+ )
305
+ wandb_project: str = field(
306
+ default="distil-mistral",
307
+ metadata={"help": "The name of the wandb project."},
308
+ )
309
+
310
+
311
+ @dataclass
312
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
313
+ freeze_embeddings: Optional[bool] = field(
314
+ default=False, metadata={"help": "Whether to freeze the input and output embeddings of the student model."}
315
+ )
316
+ freeze_n_layers: Optional[int] = field(
317
+ default=None, metadata={"help": "Freeze the first n layers of the student model."}
318
+ )
319
+ temperature: Optional[float] = field(
320
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
321
+ )
322
+ kl_weight: Optional[float] = field(
323
+ default=1.0,
324
+ metadata={
325
+ "help": (
326
+ "Weighting assigned to the KL divergence term in the distillation loss, computed "
+ "between the temperature-softened teacher and student distributions."
328
+ )
329
+ },
330
+ )
331
+ output_router_logits: Optional[bool] = field(
332
+ default=False,
333
+ metadata={
334
+ "help": "Whether or not to return the router logits in the forward pass. Enabling this will "
335
+ "also configure the model to compute the auxiliary loss."
336
+ },
337
+ )
338
+ dtype: Optional[str] = field(
339
+ default="float32",
340
+ metadata={
341
+ "help": (
342
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
343
+ "`float16` or `bfloat16` (both half-precision)."
344
+ )
345
+ },
346
+ )
347
+ completions_only: Optional[bool] = field(
348
+ default=False,
349
+ metadata={
350
+ "help": "Whether to train only on the target completions, or the prompt + completions."
351
+ },
352
+ )
353
+
354
+
355
+ @dataclass
356
+ class DataCollatorCausalLMWithPadding:
357
+ """
358
+ Data collator that will dynamically pad the inputs received.
359
+ Args:
360
+ tokenizer ([`PreTrainedTokenizer`])
361
+ The tokenizer used for tokenizing the data.
362
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
363
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
364
+ See above for details.
365
+ max_target_length (:obj:`int`, `optional`):
366
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
367
+ completions_only (:obj:`bool`, `optional`):
368
+ Whether to train on the assistant responses (completions) only, or the combination of prompt + responses.
369
+ """
370
+
371
+ tokenizer: PreTrainedTokenizerBase
372
+ target_padding: Union[bool, str] = "max_length"
373
+ max_target_length: Optional[int] = None
374
+ completions_only: Optional[bool] = False
375
+
376
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> BatchEncoding:
377
+ # dataloader returns a list of features which we convert to a dict
378
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
379
+ label_lengths = [len(feature["labels"]) for feature in features]
380
+ prompt_lengths = [feature["prompt_length"] for feature in features]
381
+
382
+ batch = self.tokenizer.pad(
383
+ label_features,
384
+ max_length=self.max_target_length,
385
+ padding=self.target_padding,
386
+ return_tensors="pt",
387
+ )
388
+
389
+ labels_mask = batch["attention_mask"]
390
+
391
+ if self.completions_only:
392
+ # don't include prompts in loss calculation
393
+ for idx in range(len(prompt_lengths)):
394
+ padding_length = labels_mask.shape[1] - label_lengths[idx]
395
+ labels_mask[idx, padding_length : padding_length + prompt_lengths[idx]] = 0
396
+
397
+ # replace padding with -100 to ignore loss correctly
398
+ labels = batch["input_ids"].masked_fill(labels_mask.ne(1), -100)
399
+
400
+ batch["labels"] = labels
401
+
402
+ return batch
403
+
404
+
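# --- Editor's illustrative usage sketch (not part of the diff) ---------------------
# The collator above expects features that already carry tokenised "labels" and a
# "prompt_length"; it pads the labels, replaces padding with -100, and, when
# completions_only=True, also masks the prompt tokens so the loss only covers the
# completion (the offset arithmetic above assumes left padding). Hypothetical usage:
#
#   collator = DataCollatorCausalLMWithPadding(
#       tokenizer=tokenizer, target_padding="longest", completions_only=True
#   )
#   batch = collator([{"labels": [1, 733, 16289, 28793], "prompt_length": 2}])
#   # batch["labels"] now has the first two (prompt) positions set to -100
# ------------------------------------------------------------------------------------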
405
+ def log_metric(
406
+ accelerator,
407
+ metrics: Dict,
408
+ train_time: float,
409
+ step: int,
410
+ epoch: int,
411
+ learning_rate: float = None,
412
+ prefix: str = "train",
413
+ ):
414
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
415
+ log_metrics = {}
416
+ for k, v in metrics.items():
417
+ log_metrics[f"{prefix}/{k}"] = v
418
+ log_metrics[f"{prefix}/time"] = train_time
419
+ log_metrics[f"{prefix}/epoch"] = epoch
420
+ if learning_rate is not None:
421
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
422
+ accelerator.log(log_metrics, step=step)
423
+
424
+
425
+ def log_pred(
426
+ accelerator,
427
+ pred_str: List[str],
428
+ label_str: List[str],
429
+ step: int,
430
+ epoch: int,
431
+ evaluation_strategy: str,
432
+ prefix: str = "eval",
433
+ num_lines: int = 200000,
434
+ ):
435
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
436
+ if accelerator.is_main_process:
437
+ wandb_tracker = accelerator.get_tracker("wandb")
438
+ # pretty name for current step: step 50000 -> step 50k
439
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
440
+ prefix_pretty = prefix.replace("/", "-")
441
+
442
+ if evaluation_strategy == "epoch":
443
+ table_name = f"predictions/{prefix_pretty}-epoch-{epoch}"
444
+ else:
445
+ table_name = f"predictions/{prefix_pretty}-step-{cur_step_pretty}"
446
+
447
+ # convert str data to a wandb compatible format
448
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
449
+ # log as a table with the appropriate headers
450
+ wandb_tracker.log_table(
451
+ table_name=table_name,
452
+ columns=["Target", "Pred"],
453
+ data=str_data[:num_lines],
454
+ step=step,
455
+ )
456
+
457
+
458
+ def convert_dataset_str_to_list(
459
+ dataset_names,
460
+ dataset_config_names,
461
+ splits=None,
462
+ text_column_names=None,
463
+ prompt_column_names=None,
464
+ dataset_samples=None,
465
+ default_split="train",
466
+ ) -> List[Dict]:
467
+ """
468
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
469
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
470
+ function returns a list of dictionaries, one for each dataset.
471
+ """
472
+ if isinstance(dataset_names, str):
473
+ dataset_names = [dataset_names]
474
+ splits = [splits] if splits else None
475
+ text_column_names = [text_column_names] if text_column_names else None
476
+ prompt_column_names = [prompt_column_names] if prompt_column_names else None
477
+ if isinstance(dataset_config_names, str):
478
+ dataset_config_names = [dataset_config_names]
479
+
480
+ if len(dataset_names) == 1 and len(dataset_config_names) > 1:
481
+ dataset_names = len(dataset_config_names) * dataset_names
482
+
483
+ if isinstance(splits, list) and len(splits) == 1 and len(dataset_config_names) > 1:
484
+ splits = len(dataset_config_names) * splits
485
+
486
+ if isinstance(text_column_names, list) and len(text_column_names) == 1 and len(dataset_config_names) > 1:
487
+ text_column_names = len(dataset_config_names) * text_column_names
488
+
489
+ if isinstance(prompt_column_names, list) and len(prompt_column_names) == 1 and len(dataset_config_names) > 1:
490
+ prompt_column_names = len(dataset_config_names) * prompt_column_names
491
+
492
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
493
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
494
+ raise ValueError(
495
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
496
+ f" {len(dataset_config_names)} configs."
497
+ )
498
+
499
+ if splits is not None and len(splits) != len(dataset_names):
500
+ raise ValueError(
501
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
502
+ )
503
+
504
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
505
+ raise ValueError(
506
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
507
+ f" {len(text_column_names)} text column names."
508
+ )
509
+
510
+ if prompt_column_names is not None and len(prompt_column_names) != len(dataset_names):
511
+ raise ValueError(
512
+ f"Ensure one prompt column name is passed for each dataset, got {len(dataset_names)} datasets and"
513
+ f" {len(prompt_column_names)} prompt column names."
514
+ )
515
+
516
+ if dataset_samples is not None:
517
+ if len(dataset_samples) != len(dataset_names):
518
+ raise ValueError(
519
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
520
+ f"{len(dataset_samples)} samples."
521
+ )
522
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
523
+ else:
524
+ dataset_samples = [None] * len(dataset_names)
525
+
526
+ dataset_config_names = (
527
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
528
+ )
529
+ text_column_names = (
530
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
531
+ )
532
+ prompt_column_names = (
533
+ prompt_column_names if prompt_column_names is not None else [None for _ in range(len(dataset_names))]
534
+ )
535
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
536
+
537
+ dataset_names_dict = []
538
+ for i, ds_name in enumerate(dataset_names):
539
+ dataset_names_dict.append(
540
+ {
541
+ "name": ds_name,
542
+ "config": dataset_config_names[i],
543
+ "split": splits[i],
544
+ "text_column_name": text_column_names[i],
545
+ "prompt_column_name": prompt_column_names[i],
546
+ "samples": dataset_samples[i],
547
+ }
548
+ )
549
+ return dataset_names_dict
550
+
551
+
552
+ def load_multiple_datasets(
553
+ dataset_names: Union[List, str],
554
+ dataset_config_names: Union[List, str],
555
+ splits: Optional[Union[List, str]] = None,
556
+ text_column_names: Optional[List] = None,
557
+ prompt_column_names: Optional[List] = None,
558
+ stopping_strategy: Optional[str] = "first_exhausted",
559
+ dataset_samples: Optional[Union[List, np.array]] = None,
560
+ streaming: Optional[bool] = False,
561
+ seed: Optional[int] = None,
562
+ accelerator: Optional[Accelerator] = None,
563
+ **kwargs,
564
+ ) -> Union[Dataset, IterableDataset]:
565
+ dataset_names_dict = convert_dataset_str_to_list(
566
+ dataset_names, dataset_config_names, splits, text_column_names, prompt_column_names, dataset_samples
567
+ )
568
+
569
+ if dataset_samples is not None:
570
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
571
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
572
+ else:
573
+ probabilities = None
574
+
575
+ all_datasets = []
576
+ # iterate over the datasets we want to interleave
577
+ for dataset_dict in tqdm(
578
+ dataset_names_dict,
579
+ desc="Combining datasets...",
580
+ disable=not accelerator.is_main_process,
581
+ ):
582
+ dataset = load_dataset(
583
+ dataset_dict["name"],
584
+ dataset_dict["config"],
585
+ split=dataset_dict["split"],
586
+ streaming=streaming,
587
+ **kwargs,
588
+ )
589
+
590
+ columns_to_keep = {"text"}
591
+ dataset_features = dataset.features.keys()
592
+
593
+ if dataset_dict["text_column_name"] not in dataset_features:
594
+ raise ValueError(
595
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
596
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
597
+ f" correct text column - one of {', '.join(dataset_features)}."
598
+ )
599
+
600
+ # blanket renaming of all transcription columns to text
601
+ if dataset_dict["text_column_name"] != "text":
602
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
603
+
604
+ # blanket renaming of all prompt columns to prompt
605
+ if dataset_dict["prompt_column_name"] is not None:
606
+ if dataset_dict["prompt_column_name"] not in dataset_features:
607
+ raise ValueError(
608
+ f"Prompt column name {dataset_dict['prompt_column_name']} not found in dataset"
609
+ f" '{dataset_dict['name']}'. Make sure to set `--prompt_column_name` to the"
610
+ f" correct prompt column - one of {', '.join(dataset_features)}."
611
+ )
612
+ elif dataset_dict["prompt_column_name"] != "prompt":
613
+ dataset = dataset.rename_column(dataset_dict["prompt_column_name"], "prompt")
614
+ columns_to_keep.add("prompt")
615
+
616
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
617
+ all_datasets.append(dataset)
618
+
619
+ if len(all_datasets) == 1:
620
+ # we have a single dataset so just return it as is
621
+ return all_datasets[0]
622
+
623
+ if streaming:
624
+ interleaved_dataset = interleave_datasets(
625
+ all_datasets,
626
+ stopping_strategy=stopping_strategy,
627
+ probabilities=probabilities,
628
+ seed=seed,
629
+ )
630
+ else:
631
+ interleaved_dataset = concatenate_datasets(all_datasets)
632
+
633
+ # shuffle mixed dataset prior to potentially truncating it
634
+ interleaved_dataset = interleaved_dataset.shuffle(seed)
635
+ return interleaved_dataset
636
+
637
+
638
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
639
+ """Helper function to sort saved checkpoints from oldest to newest."""
640
+ ordering_and_checkpoint_path = []
641
+
642
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
643
+
644
+ for path in glob_checkpoints:
645
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
646
+ if regex_match is not None and regex_match.groups() is not None:
647
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
648
+
649
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
650
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
651
+ return checkpoints_sorted
652
+
653
+
654
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> Union[List, None]:
655
+ """Helper function to delete old checkpoints."""
656
+ if save_total_limit is None or save_total_limit <= 0:
657
+ return
658
+ # Check if we should delete older checkpoint(s)
659
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
660
+ if len(checkpoints_sorted) <= save_total_limit:
661
+ return
662
+
663
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
664
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
665
+ for checkpoint in checkpoints_to_be_deleted:
666
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
667
+ shutil.rmtree(checkpoint, ignore_errors=True)
668
+ checkpoints_to_be_deleted = [f"*{Path(checkpoint).absolute().name}*" for checkpoint in checkpoints_to_be_deleted]
669
+ return checkpoints_to_be_deleted
670
+
671
+
672
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
673
+
674
+
675
+ def get_last_checkpoint(folder):
676
+ content = os.listdir(folder)
677
+ checkpoints = [
678
+ path
679
+ for path in content
680
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
681
+ ]
682
+ if len(checkpoints) == 0:
683
+ return
684
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
685
+
686
+
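# Editor's note (illustrative, not part of the diff): the folder saved in this commit,
# "checkpoint-5000-epoch-0", matches _RE_CHECKPOINT, so a restarted run resumes from
# step 5000 of epoch 0:
#
#   >>> _RE_CHECKPOINT.search("checkpoint-5000-epoch-0").groups()
#   ('5000', '0')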
687
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
688
+ """
689
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
690
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
691
+ (e.g. if the module is frozen).
692
+ """
693
+ result = []
694
+ for name, child in model.named_children():
695
+ result += [
696
+ f"{name}.{n}"
697
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
698
+ if not (
699
+ isinstance(child, tuple(forbidden_layer_types))
700
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
701
+ )
702
+ ]
703
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
704
+ result += list(model._parameters.keys())
705
+ return result
706
+
707
+
708
+ def get_quantization_config(
709
+ model_args: ModelArguments, torch_dtype: torch.dtype
710
+ ) -> tuple[BitsAndBytesConfig | None, BitsAndBytesConfig | None]:
711
+ if model_args.load_teacher_in_4bit:
712
+ quantization_config_teacher = BitsAndBytesConfig(
713
+ load_in_4bit=True,
714
+ bnb_4bit_compute_dtype=torch_dtype,
715
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
716
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
717
+ )
718
+ elif model_args.load_teacher_in_8bit:
719
+ quantization_config_teacher = BitsAndBytesConfig(load_in_8bit=True)
720
+ else:
721
+ quantization_config_teacher = None
722
+
723
+ if model_args.load_student_in_4bit:
724
+ quantization_config_student = BitsAndBytesConfig(
725
+ load_in_4bit=True,
726
+ bnb_4bit_compute_dtype=torch_dtype,
727
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
728
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
729
+ )
730
+ elif model_args.load_student_in_8bit:
731
+ quantization_config_student = BitsAndBytesConfig(load_in_8bit=True)
732
+ else:
733
+ quantization_config_student = None
734
+
735
+ return quantization_config_teacher, quantization_config_student
736
+
737
+
738
+ def main():
739
+ # 1. Parse input arguments
740
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
741
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
742
+
743
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
744
+ # If we pass only one argument to the script and it's the path to a json file,
745
+ # let's parse it to get our arguments.
746
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
747
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
748
+ # If we pass only one argument to the script and it's the path to a yaml file,
749
+ # let's parse it to get our arguments.
750
+ model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
751
+ else:
752
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
753
+
754
+ # 2. Initialize the accelerator
755
+ # We will let the accelerator handle device placement for us in this example
756
+ # We simply have to specify the training precision and any trackers being used
757
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
758
+ # it to accelerate format
759
+ if training_args.dtype == "float16":
760
+ mixed_precision = "fp16"
761
+ teacher_dtype = torch.float16
762
+ elif training_args.dtype == "bfloat16":
763
+ mixed_precision = "bf16"
764
+ teacher_dtype = torch.bfloat16
765
+ else:
766
+ mixed_precision = "no"
767
+ teacher_dtype = torch.float32
768
+
769
+ accelerator = Accelerator(
770
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
771
+ mixed_precision=mixed_precision,
772
+ log_with=training_args.report_to,
773
+ project_dir=training_args.output_dir,
774
+ )
775
+
776
+ accelerator.init_trackers(
777
+ project_name=data_args.wandb_project,
778
+ config={
779
+ "learning_rate": training_args.learning_rate,
780
+ "model_name_or_path": model_args.model_name_or_path,
781
+ "teacher_name_or_path": model_args.teacher_model_name_or_path,
782
+ "num_train_epochs": training_args.num_train_epochs,
783
+ "max_steps": training_args.max_steps,
784
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
785
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
786
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
787
+ "mixed_precision": mixed_precision,
788
+ "lr_scheduler_type": training_args.lr_scheduler_type,
789
+ "warmup_steps": training_args.warmup_steps,
790
+ "weight_decay": training_args.weight_decay,
791
+ "adam_beta1": training_args.adam_beta1,
792
+ "adam_beta2": training_args.adam_beta2,
793
+ "temperature": training_args.temperature,
794
+ },
795
+ )
796
+
797
+ # 3. Set-up basic logging
798
+ # Create one log on every process with the configuration for debugging
799
+ logging.basicConfig(
800
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
801
+ datefmt="%m/%d/%Y %H:%M:%S",
802
+ level=logging.INFO,
803
+ )
804
+ # Log a small summary on each process
805
+ logger.warning(
806
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
807
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
808
+ )
809
+
810
+ # Set the verbosity to info of the Transformers logger (on main process only)
811
+ if accelerator.is_local_main_process:
812
+ datasets.utils.logging.set_verbosity_warning()
813
+ transformers.utils.logging.set_verbosity_info()
814
+ else:
815
+ datasets.utils.logging.set_verbosity_error()
816
+ transformers.utils.logging.set_verbosity_error()
817
+ logger.info("Training/evaluation parameters %s", training_args)
818
+
819
+ # 4. Detect the last checkpoint and, if one exists, resume training from it
820
+ last_checkpoint = None
821
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
822
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
823
+ if last_checkpoint is None and len(sorted_checkpoints(training_args.output_dir)) > 0:
824
+ raise ValueError(
825
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
826
+ "Use --overwrite_output_dir to overcome."
827
+ )
828
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
829
+ logger.info(
830
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
831
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
832
+ )
833
+
834
+ # 5. Handle the repository creation
835
+ if accelerator.is_main_process:
836
+ if training_args.output_dir is not None:
837
+ os.makedirs(training_args.output_dir, exist_ok=True)
838
+ if training_args.push_to_hub:
839
+ if training_args.hub_model_id is None:
840
+ repo_name = get_full_repo_name(
841
+ Path(training_args.output_dir).absolute().name,
842
+ token=training_args.hub_token,
843
+ )
844
+ else:
845
+ repo_name = training_args.hub_model_id
846
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
847
+
848
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
849
+ if "wandb" not in gitignore:
850
+ gitignore.write("wandb\n")
851
+ accelerator.wait_for_everyone()
852
+
853
+ # 6. Load dataset - either streaming or non-streaming (offline)
854
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
855
+
856
+ # set seed for determinism
857
+ set_seed(training_args.seed)
858
+
859
+ if training_args.do_train:
860
+ raw_datasets["train"] = load_multiple_datasets(
861
+ data_args.train_dataset_name,
862
+ data_args.train_dataset_config_name,
863
+ splits=data_args.train_split_name,
864
+ text_column_names=data_args.text_column_name,
865
+ prompt_column_names=data_args.prompt_column_name,
866
+ streaming=data_args.streaming,
867
+ dataset_samples=data_args.train_dataset_samples,
868
+ seed=training_args.seed,
869
+ accelerator=accelerator,
870
+ cache_dir=data_args.dataset_cache_dir,
871
+ token=model_args.token,
872
+ num_proc=data_args.preprocessing_num_workers,
873
+ )
874
+ raw_datasets_train_features = set(raw_datasets["train"].features.keys())
875
+
876
+ if training_args.do_eval:
877
+ dataset_names_dict = convert_dataset_str_to_list(
878
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
879
+ (
880
+ data_args.eval_dataset_config_name
881
+ if data_args.eval_dataset_config_name
882
+ else data_args.train_dataset_config_name
883
+ ),
884
+ splits=data_args.eval_split_name,
885
+ text_column_names=data_args.eval_text_column_name,
886
+ prompt_column_names=data_args.eval_prompt_column_name,
887
+ )
888
+ all_eval_splits = []
889
+ if len(dataset_names_dict) == 1:
890
+ # load a single eval set
891
+ dataset_dict = dataset_names_dict[0]
892
+ all_eval_splits.append("eval")
893
+ raw_datasets["eval"] = load_dataset(
894
+ dataset_dict["name"],
895
+ dataset_dict["config"],
896
+ split=dataset_dict["split"],
897
+ cache_dir=data_args.dataset_cache_dir,
898
+ token=model_args.token,
899
+ streaming=data_args.streaming,
900
+ )
901
+ if dataset_dict["text_column_name"] != "text":
902
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
903
+ if dataset_dict["prompt_column_name"] and dataset_dict["prompt_column_name"] != "prompt":
904
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_prompt_column_name, "prompt")
905
+ else:
906
+ # load multiple eval sets
907
+ for dataset_dict in dataset_names_dict:
908
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['config'].replace('.', '-')}"
909
+ all_eval_splits.append(pretty_name)
910
+ raw_datasets[pretty_name] = load_dataset(
911
+ dataset_dict["name"],
912
+ dataset_dict["config"],
913
+ split=dataset_dict["split"],
914
+ cache_dir=data_args.dataset_cache_dir,
915
+ token=model_args.token,
916
+ streaming=data_args.streaming,
917
+ )
918
+ # make column names consistent (text, prompt)
919
+ columns_to_keep = {"text"}
920
+ if dataset_dict["text_column_name"] != "text":
921
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
922
+ dataset_dict["text_column_name"], "text"
923
+ )
924
+ if dataset_dict["prompt_column_name"]:
925
+ if dataset_dict["prompt_column_name"] != "prompt":
926
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
927
+ dataset_dict["prompt_column_name"], "prompt"
928
+ )
929
+ columns_to_keep.add("prompt")
930
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
931
+ set(raw_datasets[pretty_name].features.keys()) - columns_to_keep
932
+ )
933
+
934
+ if not training_args.do_train and not training_args.do_eval:
935
+ raise ValueError(
936
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
937
+ )
938
+
939
+ # 7. Load pretrained model, tokenizer, and feature extractor
940
+ config = AutoConfig.from_pretrained(
941
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
942
+ cache_dir=model_args.cache_dir,
943
+ revision=model_args.model_revision,
944
+ token=model_args.token,
945
+ )
946
+ if training_args.output_router_logits:
947
+ config.output_router_logits = True
948
+
949
+ tokenizer = AutoTokenizer.from_pretrained(
950
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
951
+ cache_dir=model_args.cache_dir,
952
+ use_fast=model_args.use_fast_tokenizer,
953
+ revision=model_args.model_revision,
954
+ token=model_args.token,
955
+ )
956
+ if tokenizer.pad_token_id is None:
957
+ tokenizer.pad_token = tokenizer.eos_token
958
+
959
+ quantization_config_teacher, quantization_config_student = get_quantization_config(
960
+ model_args, torch_dtype=teacher_dtype
961
+ )
962
+
963
+ if model_args.teacher_model_name_or_path:
964
+ # The teacher model can safely be cast to the dtype of training since we don't
965
+ # update the params
966
+ teacher_model = AutoModelForCausalLM.from_pretrained(
967
+ model_args.teacher_model_name_or_path,
968
+ cache_dir=model_args.cache_dir,
969
+ token=model_args.token,
970
+ low_cpu_mem_usage=True,
971
+ torch_dtype=teacher_dtype,
972
+ attn_implementation=model_args.attn_implementation,
973
+ quantization_config=quantization_config_teacher,
974
+ ).eval()
975
+ else:
976
+ teacher_model = None
977
+
978
+ student_model = AutoModelForCausalLM.from_pretrained(
979
+ model_args.model_name_or_path,
980
+ config=config,
981
+ cache_dir=model_args.cache_dir,
982
+ revision=model_args.model_revision,
983
+ subfolder=model_args.subfolder,
984
+ token=model_args.token,
985
+ low_cpu_mem_usage=True,
986
+ attn_implementation=model_args.attn_implementation,
987
+ quantization_config=quantization_config_student,
988
+ )
989
+
990
+ if quantization_config_student is not None:
991
+ lora_config = LoraConfig(
992
+ r=model_args.lora_r,
993
+ lora_alpha=model_args.lora_alpha,
994
+ target_modules=model_args.lora_target_modules,
995
+ lora_dropout=model_args.lora_dropout,
996
+ bias="none",
997
+ task_type="CAUSAL_LM",
998
+ )
999
+ student_model = get_peft_model(student_model, lora_config)
1000
+
1001
+ if student_model.generation_config.bos_token_id is None or (teacher_model and teacher_model.generation_config.bos_token_id is None):
1002
+ student_error = f"Make sure that `generation_config.bos_token_id` is correctly defined. Got {student_model.generation_config.bos_token_id} for the student."
1003
+ teacher_error = f"Got {teacher_model.generation_config.bos_token_id} for the teacher." if teacher_model else None
1004
+ raise ValueError(student_error + teacher_error)
1005
+
1006
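+ # helper that freezes or un-freezes every parameter of a module; used below for embedding / layer freezing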
+ def set_trainable_parameters(module, requires_grad=False):
1007
+ for param in module.parameters():
1008
+ param.requires_grad = requires_grad
1009
+ module._requires_grad = requires_grad
1010
+
1011
+ forbidden_module = []
1012
+ # freeze student embeddings if necessary
1013
+ if training_args.freeze_embeddings:
1014
+ set_trainable_parameters(student_model.get_output_embeddings(), requires_grad=False)
1015
+ set_trainable_parameters(student_model.get_input_embeddings(), requires_grad=False)
1016
+ forbidden_module.extend([student_model.get_output_embeddings(), student_model.get_input_embeddings()])
1017
+
1018
+ if training_args.freeze_n_layers:
1019
+ for i in range(int(training_args.freeze_n_layers)):
1020
+ set_trainable_parameters(student_model.model.layers[i], requires_grad=False)
1021
+ forbidden_module.extend([student_model.model.layers[i]])
1022
+
1023
+ # enable gradient checkpointing if necessary
1024
+ if training_args.gradient_checkpointing:
1025
+ if training_args.freeze_embeddings or training_args.freeze_n_layers:
1026
+ raise ValueError(
1027
+ "Gradient checkpointing is not compatible with `--freeze_embeddings` or `--freeze_n_layers`. "
1028
+ "Either un-freeze these layers, or set `--gradient_checkpointing=False`."
1029
+ )
1030
+ student_model.gradient_checkpointing_enable()
1031
+
1032
+ student_model.generation_config.max_length = data_args.max_label_length
1033
+
1034
+ # 8. Save all pre-processed tokenizers/config/generation configs
1035
+ if accelerator.is_main_process:
1036
+ tokenizer.save_pretrained(training_args.output_dir)
1037
+ # save the config and generation config as well
1038
+ config.save_pretrained(training_args.output_dir)
1039
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1040
+
1041
+ accelerator.wait_for_everyone()
1042
+
1043
+
1044
+ # 10. Preprocessing the datasets: we need to combine the prompt and generations and tokenize the targets.
1045
+ # 10.1: Define the pre-processing constants
1046
+ max_label_length = (
1047
+ data_args.max_label_length if data_args.max_label_length is not None else config.max_length
1048
+ )
1049
+ num_workers = data_args.preprocessing_num_workers
1050
+ dataloader_num_workers = training_args.dataloader_num_workers
1051
+ prefetch_factor = training_args.dataloader_prefetch_factor
1052
+ eos_token_id = tokenizer.eos_token_id
1053
+ if model_args.instruction_model is not None:
1054
+ instruction_model = model_args.instruction_model
1055
+ else:
1056
+ instruction_model = "instruct" in model_args.model_name_or_path.lower()
1057
+ if instruction_model and "prompt" not in raw_datasets_train_features:
1058
+ raise ValueError(
1059
+ "Distilling an instruction model, but `--prompt_column_name` is set to None. "
1060
+ "Ensure `--prompt_column_name` is set according to the dataset features."
1061
+ )
1062
+
1063
+ # 10.2: filter based on maximum number of training/evaluation samples
1064
+ if training_args.do_train and data_args.max_train_samples is not None:
1065
+ raw_datasets["train"] = (
1066
+ raw_datasets["train"].take(data_args.max_train_samples)
1067
+ if data_args.streaming
1068
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1069
+ )
1070
+
1071
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1072
+ for eval_split in all_eval_splits:
1073
+ raw_datasets[eval_split] = (
1074
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1075
+ if data_args.streaming
1076
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1077
+ )
1078
+
1079
+ # 10.3: pre-process training/evaluation datasets
1080
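+ # labels are the full tokenized sequence (prompt + target, terminated with EOS); prompt_length records how many leading tokens come from the prompt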
+ def prepare_dataset(example):
1081
+ prompt = example.get("prompt")
1082
+ target_text = prompt + example["text"] if prompt is not None else example["text"]
1083
+ example["labels"] = tokenizer(target_text).input_ids
1084
+ if example["labels"][-1] != eos_token_id:
1085
+ example["labels"] += [eos_token_id]
1086
+ example["prompt_length"] = len(tokenizer(prompt).input_ids) if prompt else 0
1087
+ return example
1088
+
1089
+ def prepare_instruction_dataset(example):
1090
+ messages = [
1091
+ {"role": "user", "content": example["prompt"]},
1092
+ {"role": "assistant", "content": example["text"]},
1093
+ ]
1094
+ example["labels"] = tokenizer.apply_chat_template(messages)
1095
+ if example["labels"][-1] != eos_token_id:
1096
+ example["labels"] = example["labels"][:-1]
1097
+
1098
+ example["prompt_length"] = len(tokenizer.apply_chat_template([messages[0]]))
1099
+ return example
1100
+
1101
+ prepare_dataset = prepare_instruction_dataset if instruction_model else prepare_dataset
1102
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1103
+ if training_args.do_train:
1104
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1105
+ # we can use `num_workers` (which is much faster)
1106
+ # We gate the pre-processing function accordingly
1107
+ map_fn_train = partial(
1108
+ raw_datasets["train"].map,
1109
+ function=prepare_dataset,
1110
+ remove_columns=raw_datasets_train_features,
1111
+ )
1112
+ with accelerator.main_process_first():
1113
+ vectorized_datasets["train"] = (
1114
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1115
+ if not data_args.streaming
1116
+ else map_fn_train()
1117
+ )
1118
+ if training_args.do_eval:
1119
+ for eval_split in all_eval_splits:
1120
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1121
+ map_fn_eval = partial(
1122
+ raw_datasets[eval_split].map, function=prepare_dataset, remove_columns=raw_datasets_eval_features
1123
+ )
1124
+ with accelerator.main_process_first():
1125
+ vectorized_datasets[eval_split] = (
1126
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1127
+ if not data_args.streaming
1128
+ else map_fn_eval()
1129
+ )
1130
+
1131
+ # 10.4: Filter training/evaluation data with labels longer than `max_label_length`
1132
+ def is_labels_in_length_range(labels):
1133
+ return 0 < len(labels) <= max_label_length
1134
+
1135
+ filter_by_labels_fn = partial(
1136
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1137
+ )
1138
+ with accelerator.main_process_first():
1139
+ vectorized_datasets = (
1140
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering datasets by label length")
1141
+ if not data_args.streaming
1142
+ else filter_by_labels_fn()
1143
+ )
1144
+
1145
+ # Pre-processing complete!
1146
+ # For large datasets it is advised to run the preprocessing on a
1147
+ # single machine first with `--preprocessing_only` since there will most likely
1148
+ # be a timeout when running the script in distributed mode.
1149
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1150
+ # cached dataset
1151
+ if data_args.preprocessing_only:
1152
+ if data_args.streaming:
1153
+ raise ValueError(
1154
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1155
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1156
+ "on the fly with streaming mode."
1157
+ )
1158
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1159
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1160
+ return
1161
+
1162
+ # 11. Define Evaluation Metrics
1163
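+ # decode predictions and labels back to strings for logging; padded label positions (-100) are first restored to the pad token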
+ def compute_metrics(preds, labels):
1164
+ # TODO(SG): better metrics for performance?
1165
+ # replace padded labels by the padding token
1166
+ for idx in range(len(labels)):
1167
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1168
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
1169
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1170
+ return pred_str, label_str
1171
+
1172
+ # 12. Define Training Schedule
1173
+ # 12.1: Store some constants
1174
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1175
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1176
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1177
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1178
+ temperature = training_args.temperature
1179
+
1180
+ # 12.2: Set max training steps
1181
+ if not data_args.streaming and training_args.max_steps < 0:
1182
+ num_epochs = int(training_args.num_train_epochs)
1183
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1184
+ total_train_steps = steps_per_epoch * num_epochs
1185
+ elif training_args.max_steps > 0:
1186
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1187
+ total_train_steps = int(training_args.max_steps)
1188
+ if not data_args.streaming:
1189
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1190
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1191
+ else:
1192
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1193
+ num_epochs = sys.maxsize
1194
+ steps_per_epoch = total_train_steps
1195
+ else:
1196
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1197
+
1198
+ # 12.3: Set evaluation steps
1199
+ if training_args.evaluation_strategy == "epoch":
1200
+ eval_steps = steps_per_epoch
1201
+ elif training_args.eval_steps is None:
1202
+ logger.info(
1203
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1204
+ )
1205
+ eval_steps = steps_per_epoch
1206
+ else:
1207
+ eval_steps = training_args.eval_steps
1208
+
1209
+ # 12.4: Set save steps
1210
+ if training_args.save_strategy == "epoch":
1211
+ save_steps = steps_per_epoch
1212
+ elif training_args.save_strategy == "steps":
1213
+ save_steps = training_args.save_steps
1214
+ else:
1215
+ save_steps = sys.maxsize
1216
+
1217
+ # 13. Define optimizer, LR scheduler, collator
1218
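+ # apply weight decay to all parameters except biases, LayerNorm weights and frozen (forbidden) modules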
+ decay_parameters = get_parameter_names(
1219
+ student_model,
1220
+ [nn.LayerNorm],
1221
+ forbidden_module,
1222
+ )
1223
+
1224
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1225
+ optimizer_grouped_parameters = [
1226
+ {
1227
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1228
+ "weight_decay": training_args.weight_decay,
1229
+ },
1230
+ {
1231
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1232
+ "weight_decay": 0.0,
1233
+ },
1234
+ ]
1235
+ if training_args.optim == OptimizerNames.ADAMW_TORCH:
1236
+ optim_cls = torch.optim.AdamW
1237
+ elif training_args.optim == OptimizerNames.ADAMW_BNB:
1238
+ if not is_bitsandbytes_available():
1239
+ raise ValueError(
1240
+ "bitsandbytes package required for Adam8bit. Install via: `pip install --upgrade bitsandbytes`"
1241
+ )
1242
+ import bitsandbytes as bnb
1243
+
1244
+ optim_cls = bnb.optim.Adam8bit
1245
+ else:
1246
+ raise ValueError(
1247
+ f"Got invalid `--optim` {training_args.optim}, should be one of `['adam_torch', 'adamw_bnb_8bit']`."
1248
+ )
1249
+
1250
+ optimizer = optim_cls(
1251
+ params=optimizer_grouped_parameters,
1252
+ lr=training_args.learning_rate,
1253
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1254
+ eps=training_args.adam_epsilon,
1255
+ )
1256
+
1257
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1258
+ lr_scheduler = get_scheduler(
1259
+ name=training_args.lr_scheduler_type,
1260
+ optimizer=optimizer,
1261
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1262
+ num_training_steps=total_train_steps * accelerator.num_processes,
1263
+ )
1264
+
1265
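+ # collator pads label sequences to max_label_length; completions_only presumably masks prompt tokens from the loss (see the collator definition)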
+ data_collator = DataCollatorCausalLMWithPadding(
1266
+ tokenizer=tokenizer,
1267
+ target_padding="max_length",
1268
+ max_target_length=max_label_length,
1269
+ completions_only=training_args.completions_only,
1270
+ )
1271
+
1272
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1273
+ # so that we can still access the configs
1274
+ num_beams = (
1275
+ training_args.generation_num_beams
1276
+ if training_args.generation_num_beams is not None
1277
+ else getattr(student_model.generation_config, "num_beams", 1)
1278
+ )
1279
+
1280
+ # 15. Prepare everything with accelerate
1281
+ student_model, optimizer, lr_scheduler = accelerator.prepare(student_model, optimizer, lr_scheduler)
1282
+ teacher_model = accelerator.prepare(teacher_model) if teacher_model else None
1283
+
1284
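+ # token-level KL(teacher || student): expects teacher probabilities as the target and student
+ # log-probabilities as the prediction; positions where labels == -100 are masked out and the
+ # summed divergence is averaged over the remaining (non-padded) tokens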
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1285
+ kl_loss = nn.KLDivLoss(reduction="none")
1286
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1287
+ # mask out padded tokens (positions where labels are set to -100) so they do not contribute to the divergence
1288
+ padding_mask = labels >= 0
1289
+ padding_mask = padding_mask.unsqueeze(-1)
1290
+ divergence = divergence * padding_mask
1291
+ # take the average over the mini-batch
1292
+ divergence = divergence.sum() / padding_mask.sum()
1293
+ return divergence
1294
+
1295
+ # Define gradient update step fn
1296
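+ # returns the scalar training loss (CE, plus temperature-scaled KL to the teacher when one is provided) together with a metrics dict for logging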
+ def train_step(batch):
1297
+ student_model.train()
1298
+ student_outputs = student_model(**batch)
1299
+
1300
+ # CE (data) loss
1301
+ ce_loss = student_outputs.loss
1302
+ metrics = {"ce_loss": ce_loss}
1303
+
1304
+ if teacher_model:
1305
+ with torch.no_grad():
1306
+ teacher_outputs = teacher_model(**batch)
1307
+ # rescale distribution by temperature to ensure gradients scale correctly
1308
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1309
+ # log softmax of student predictions for numerical stability
1310
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1311
+ # KL-divergence loss (scaled by temperature)
1312
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature ** 2
1313
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1314
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1315
+ metrics["kl_loss"] = kl_loss
1316
+ else:
1317
+ loss = ce_loss
1318
+
1319
+ metrics["loss"] = loss
1320
+ return loss, metrics
1321
+
1322
+ # Define eval fn
1323
+ @torch.no_grad()
1324
+ def eval_step(batch):
1325
+ student_model.eval()
1326
+
1327
+ # CE (data) loss
1328
+ student_outputs = student_model(**batch)
1329
+ ce_loss = student_outputs.loss
1330
+ metrics = {"ce_loss": ce_loss}
1331
+
1332
+ if teacher_model:
1333
+ teacher_outputs = teacher_model(**batch)
1334
+ # log softmax / softmax for numerical stability
1335
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1336
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1337
+ # temperature is always 1 for eval
1338
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1339
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1340
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1341
+ metrics["kl_loss"] = kl_loss
1342
+ else:
1343
+ loss = ce_loss
1344
+
1345
+ metrics["loss"] = loss
1346
+ return metrics
1347
+
1348
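+ # generation used for predict_with_generate; outputs are padded across processes so variable-length sequences can be gathered for decoding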
+ def generate_step(batch):
1349
+ output_ids = accelerator.unwrap_model(student_model).generate(
1350
+ **batch, max_length=max_label_length, num_beams=num_beams
1351
+ )
1352
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1353
+ return output_ids
1354
+
1355
+ logger.info("***** Running training *****")
1356
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1357
+ if not data_args.streaming:
1358
+ logger.info(f" Num epochs = {num_epochs}")
1359
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1360
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1361
+ logger.info(
1362
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1363
+ )
1364
+ logger.info(f" Total optimization steps = {total_train_steps}")
1365
+
1366
+ # ======================== Training ================================
1367
+ train_time = 0
1368
+ train_start = time.time()
1369
+ steps_trained_progress_bar = tqdm(
1370
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1371
+ )
1372
+ continue_training = True
1373
+ epochs_trained = 0
1374
+ cur_step = 0
1375
+
1376
+ checkpoint = None
1377
+ if training_args.resume_from_checkpoint is not None:
1378
+ checkpoint = training_args.resume_from_checkpoint
1379
+ elif last_checkpoint is not None:
1380
+ checkpoint = last_checkpoint
1381
+
1382
+ if checkpoint is not None:
1383
+ accelerator.load_state(checkpoint)
1384
+ # Find num steps and epoch from saved state string pattern
1385
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1386
+ match = re.search(pattern, checkpoint)
1387
+ cur_step = int(match.group(1))
1388
+ epochs_trained = int(match.group(2))
1389
+
1390
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1391
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1392
+ logger.info(f" Continuing training from global step {cur_step}")
1393
+
1394
+ steps_trained_progress_bar.update(cur_step)
1395
+
1396
+ for epoch in range(0, epochs_trained):
1397
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1398
+
1399
+ if not data_args.streaming and training_args.max_steps < 0:
1400
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1401
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1402
+ else:
1403
+ # Currently we don't know how many steps we've taken in the current epoch
1404
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1405
+ # This is "good enough" for our purposes but not fully correct
1406
+ resume_step = None
1407
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1408
+ else:
1409
+ resume_step = None
1410
+
1411
+ for epoch in range(epochs_trained, num_epochs):
1412
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1413
+ train_dataloader = DataLoader(
1414
+ vectorized_datasets["train"],
1415
+ collate_fn=data_collator,
1416
+ batch_size=per_device_train_batch_size,
1417
+ num_workers=dataloader_num_workers,
1418
+ prefetch_factor=prefetch_factor,
1419
+ pin_memory=training_args.dataloader_pin_memory,
1420
+ )
1421
+ train_dataloader = accelerator.prepare(train_dataloader)
1422
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1423
+ train_dataloader.dataset.set_epoch(epoch)
1424
+
1425
+ if resume_step is not None:
1426
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1427
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1428
+ resume_step = None
1429
+
1430
+ for batch in train_dataloader:
1431
+ with accelerator.accumulate(student_model):
1432
+ loss, train_metric = train_step(batch)
1433
+ accelerator.backward(loss)
1434
+ if accelerator.sync_gradients:
1435
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1436
+ optimizer.step()
1437
+ lr_scheduler.step()
1438
+ optimizer.zero_grad()
1439
+
1440
+ # Check if the accelerator has performed an optimization step behind the scenes
1441
+ if accelerator.sync_gradients:
1442
+ steps_trained_progress_bar.update(1)
1443
+ cur_step += 1
1444
+
1445
+ if cur_step % training_args.logging_steps == 0:
1446
+ steps_trained_progress_bar.write(
1447
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1448
+ f" {train_metric['loss']}, Learning Rate:"
1449
+ f" {lr_scheduler.get_last_lr()[0]})"
1450
+ )
1451
+ log_metric(
1452
+ accelerator,
1453
+ metrics=train_metric,
1454
+ learning_rate=lr_scheduler.get_last_lr()[0],
1455
+ train_time=train_time + time.time() - train_start,
1456
+ step=cur_step,
1457
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1458
+ prefix="train",
1459
+ )
1460
+
1461
+ # save checkpoint and weights after each save_steps and at the end of training
1462
+ if (cur_step % save_steps == 0) or cur_step == total_train_steps:
1463
+ accelerator.wait_for_everyone()
1464
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1465
+ accelerator.save_state(output_dir=intermediate_dir)
1466
+ unwrapped_model = accelerator.unwrap_model(student_model)
1467
+ unwrapped_model.save_pretrained(
1468
+ intermediate_dir,
1469
+ is_main_process=accelerator.is_main_process,
1470
+ save_function=accelerator.save,
1471
+ )
1472
+ if accelerator.is_main_process:
1473
+ checkpoint_to_be_deleted = rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1474
+ if training_args.push_to_hub:
1475
+ upload_folder(
1476
+ folder_path=training_args.output_dir,
1477
+ repo_id=repo_name,
1478
+ repo_type="model",
1479
+ commit_message=f"Saving train state of step {cur_step}",
1480
+ delete_patterns=checkpoint_to_be_deleted,
1481
+ )
1482
+
1483
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1484
+ train_time += time.time() - train_start
1485
+ student_model.eval()
1486
+ # ======================== Evaluating ==============================
1487
+ for eval_split in all_eval_splits:
1488
+ eval_metrics = []
1489
+ eval_preds = []
1490
+ eval_labels = []
1491
+ eval_start = time.time()
1492
+
1493
+ validation_dataloader = DataLoader(
1494
+ vectorized_datasets[eval_split],
1495
+ collate_fn=data_collator,
1496
+ batch_size=per_device_eval_batch_size,
1497
+ drop_last=False,
1498
+ num_workers=dataloader_num_workers,
1499
+ prefetch_factor=prefetch_factor,
1500
+ pin_memory=training_args.dataloader_pin_memory,
1501
+ )
1502
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1503
+
1504
+ for batch in tqdm(
1505
+ validation_dataloader,
1506
+ desc=f"Evaluating {eval_split}...",
1507
+ position=2,
1508
+ disable=not accelerator.is_local_main_process,
1509
+ ):
1510
+ # Model forward
1511
+ eval_metric = eval_step(batch)
1512
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1513
+ eval_metrics.append(eval_metric)
1514
+
1515
+ # generation
1516
+ if training_args.predict_with_generate:
1517
+ generated_ids = generate_step(batch)
1518
+ # Gather all predictions and targets
1519
+ generated_ids, labels = accelerator.gather_for_metrics(
1520
+ (generated_ids, batch["labels"])
1521
+ )
1522
+ eval_preds.extend(generated_ids)
1523
+ eval_labels.extend(labels)
1524
+
1525
+ eval_time = time.time() - eval_start
1526
+ stack = torch.stack if accelerator.num_processes == 1 else torch.concatenate
1527
+ # normalize eval metrics
1528
+ eval_metrics = {
1529
+ key: torch.mean(stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1530
+ }
1531
+ try:
1532
+ eval_metrics["perplexity"] = math.exp(eval_metrics["ce_loss"])
1533
+ except OverflowError:
1534
+ eval_metrics["perplexity"] = float("inf")
1535
+
1536
+ if training_args.predict_with_generate:
1537
+ pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1538
+ log_pred(
1539
+ accelerator,
1540
+ pred_str,
1541
+ label_str,
1542
+ step=cur_step,
1543
+ epoch=epoch,
1544
+ evaluation_strategy=training_args.evaluation_strategy,
1545
+ prefix=eval_split,
1546
+ )
1547
+
1548
+ # Print metrics and update progress bar
1549
+ logger_desc = " ".join([f"Eval {key}: {value} |" for key, value in eval_metrics.items()])
1550
+ steps_trained_progress_bar.write(
1551
+ f"Eval results for step ({cur_step} / {total_train_steps} | {logger_desc}"
1552
+ )
1553
+
1554
+ log_metric(
1555
+ accelerator,
1556
+ metrics=eval_metrics,
1557
+ train_time=eval_time,
1558
+ step=cur_step,
1559
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1560
+ prefix=eval_split,
1561
+ )
1562
+
1563
+ # reset the train-time clock now that evaluation is complete
1564
+ train_start = time.time()
1565
+
1566
+ # break condition
1567
+ if cur_step == total_train_steps:
1568
+ accelerator.wait_for_everyone()
1569
+ # un-wrap student model for save
1570
+ student_model = accelerator.unwrap_model(student_model)
1571
+ student_model.save_pretrained(
1572
+ training_args.output_dir,
1573
+ is_main_process=accelerator.is_main_process,
1574
+ save_function=accelerator.save,
1575
+ )
1576
+ if training_args.push_to_hub and accelerator.is_main_process:
1577
+ upload_folder(
1578
+ folder_path=training_args.output_dir,
1579
+ repo_id=repo_name,
1580
+ repo_type="model",
1581
+ commit_message=f"Saving final weights of step {cur_step}",
1582
+ )
1583
+ continue_training = False
1584
+ break
1585
+
1586
+ if not continue_training:
1587
+ break
1588
+
1589
+ accelerator.end_training()
1590
+
1591
+
1592
+ if __name__ == "__main__":
1593
+ main()
1594
+
slurm_job.slurm ADDED
@@ -0,0 +1,75 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=distil-mistral
3
+ #SBATCH --nodes=1
4
+ # set 48h job wall time limit
5
+ #SBATCH --time=48:00:00
6
+ #SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for distributed training!
7
+ #SBATCH --cpus-per-task=32
8
+ #SBATCH --gres=gpu:8
9
+ #SBATCH --exclusive
10
+ #SBATCH --partition=hopper-prod
11
+ #SBATCH --output=/fsx/sanchit/logs/%x-%j.out
12
+
13
+ set -x -e
14
+
15
+ # START EDIT
16
+ source ~/.bashrc
17
+ source /fsx/sanchit/miniconda3/bin/activate venv
18
+
19
+ LOG_PATH="/fsx/sanchit/logs/main_log.txt"
20
+ SAVE_DIR="/fsx/sanchit"
21
+ # END EDIT
22
+
23
+ echo "START TIME: $(date)"
24
+
25
+ GPUS_PER_NODE=8
26
+ NNODES=$SLURM_NNODES
27
+
28
+ # so processes know who to talk to
29
+ MASTER_ADDR=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1`
30
+
31
+ # From https://i.hsfzxjy.site/2021-03-10-obtain-a-random-unused-tcp-port-with-bash/
32
+ function unused_port() {
33
+ N=${1:-1}
34
+ comm -23 \
35
+ <(seq "1025" "65535" | sort) \
36
+ <(ss -Htan |
37
+ awk '{print $4}' |
38
+ cut -d':' -f2 |
39
+ sort -u) |
40
+ shuf |
41
+ head -n "$N"
42
+ }
43
+ MASTER_PORT=$(unused_port)
44
+
45
+ # export TORCH_CPP_LOG_LEVEL=INFO
46
+ # export TORCH_DISTRIBUTED_DEBUG=DETAIL
47
+
48
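+ # accelerate launch starts one process per GPU on the node (8, matching --gres above) using the settings in accelerate_config.yaml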
+ export LAUNCHER="python -u -m accelerate.commands.launch --config_file ./accelerate_config.yaml"
49
+
50
+ export PROGRAM="./run_distillation.py ./config_mistral.yaml"
51
+ export CMD="$LAUNCHER $PROGRAM"
52
+ echo $CMD
53
+
54
+ SRUN_ARGS=" \
55
+ --wait=60 \
56
+ --kill-on-bad-exit=1 \
57
+ "
58
+
59
+ # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
60
+ clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
61
+
62
+
63
+ # srun error handling:
64
+ # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
65
+ # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
66
+
67
+ # SRUN_ARGS=" \
68
+ # --wait=60 \
69
+ # --kill-on-bad-exit=1 \
70
+ # "
71
+ #
72
+ # # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
73
+ # clear; srun $SRUN_ARGS --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
74
+
75
+ echo "END TIME: $(date)"
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "additional_special_tokens": [],
31
+ "bos_token": "<s>",
32
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "legacy": true,
36
+ "model_max_length": 1000000000000000019884624838656,
37
+ "pad_token": "</s>",
38
+ "sp_model_kwargs": {},
39
+ "spaces_between_special_tokens": false,
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }