Feat(wandb): Refactor to be more flexible (#767)
Browse files* Feat: Update to handle wandb env better
* chore: rename wandb_run_id to wandb_name
* feat: add new recommendation and update config
* fix: indent and pop disabled env if project passed
* feat: test env set for wandb and recommendation
* feat: update to use wandb_name and allow id
* chore: add info to readme
- README.md +3 -2
- examples/cerebras/btlm-ft.yml +1 -1
- examples/cerebras/qlora.yml +1 -1
- examples/code-llama/13b/lora.yml +1 -1
- examples/code-llama/13b/qlora.yml +1 -1
- examples/code-llama/34b/lora.yml +1 -1
- examples/code-llama/34b/qlora.yml +1 -1
- examples/code-llama/7b/lora.yml +1 -1
- examples/code-llama/7b/qlora.yml +1 -1
- examples/falcon/config-7b-lora.yml +1 -1
- examples/falcon/config-7b-qlora.yml +1 -1
- examples/falcon/config-7b.yml +1 -1
- examples/gptj/qlora.yml +1 -1
- examples/jeopardy-bot/config.yml +1 -1
- examples/llama-2/fft_optimized.yml +1 -1
- examples/llama-2/gptq-lora.yml +1 -1
- examples/llama-2/lora.yml +1 -1
- examples/llama-2/qlora.yml +1 -1
- examples/llama-2/relora.yml +1 -1
- examples/llama-2/tiny-llama.yml +1 -1
- examples/mistral/config.yml +1 -1
- examples/mistral/qlora.yml +1 -1
- examples/mpt-7b/config.yml +1 -1
- examples/openllama-3b/config.yml +1 -1
- examples/openllama-3b/lora.yml +1 -1
- examples/openllama-3b/qlora.yml +1 -1
- examples/phi/phi-ft.yml +1 -1
- examples/phi/phi-qlora.yml +1 -1
- examples/pythia-12b/config.yml +1 -1
- examples/pythia/lora.yml +1 -1
- examples/qwen/lora.yml +1 -1
- examples/qwen/qlora.yml +1 -1
- examples/redpajama/config-3b.yml +1 -1
- examples/replit-3b/config-lora.yml +1 -1
- examples/xgen-7b/xgen-7b-8k-qlora.yml +1 -1
- src/axolotl/core/trainer_builder.py +1 -1
- src/axolotl/utils/config.py +7 -0
- src/axolotl/utils/wandb_.py +13 -13
- tests/test_validation.py +82 -0
README.md
CHANGED
|
@@ -659,7 +659,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
|
|
| 659 |
wandb_project: # Your wandb project name
|
| 660 |
wandb_entity: # A wandb Team name if using a Team
|
| 661 |
wandb_watch:
|
| 662 |
-
|
|
|
|
| 663 |
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
| 664 |
|
| 665 |
# Where to save the full-finetuned model to
|
|
@@ -955,7 +956,7 @@ wandb_mode:
|
|
| 955 |
wandb_project:
|
| 956 |
wandb_entity:
|
| 957 |
wandb_watch:
|
| 958 |
-
|
| 959 |
wandb_log_model:
|
| 960 |
```
|
| 961 |
|
|
|
|
| 659 |
wandb_project: # Your wandb project name
|
| 660 |
wandb_entity: # A wandb Team name if using a Team
|
| 661 |
wandb_watch:
|
| 662 |
+
wandb_name: # Set the name of your wandb run
|
| 663 |
+
wandb_run_id: # Set the ID of your wandb run
|
| 664 |
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
| 665 |
|
| 666 |
# Where to save the full-finetuned model to
|
|
|
|
| 956 |
wandb_project:
|
| 957 |
wandb_entity:
|
| 958 |
wandb_watch:
|
| 959 |
+
wandb_name:
|
| 960 |
wandb_log_model:
|
| 961 |
```
|
| 962 |
|
examples/cerebras/btlm-ft.yml
CHANGED
|
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
|
| 35 |
wandb_project:
|
| 36 |
wandb_entity:
|
| 37 |
wandb_watch:
|
| 38 |
-
|
| 39 |
wandb_log_model:
|
| 40 |
|
| 41 |
output_dir: btlm-out
|
|
|
|
| 35 |
wandb_project:
|
| 36 |
wandb_entity:
|
| 37 |
wandb_watch:
|
| 38 |
+
wandb_name:
|
| 39 |
wandb_log_model:
|
| 40 |
|
| 41 |
output_dir: btlm-out
|
examples/cerebras/qlora.yml
CHANGED
|
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
|
| 24 |
wandb_project:
|
| 25 |
wandb_entity:
|
| 26 |
wandb_watch:
|
| 27 |
-
|
| 28 |
wandb_log_model:
|
| 29 |
output_dir: ./qlora-out
|
| 30 |
batch_size: 4
|
|
|
|
| 24 |
wandb_project:
|
| 25 |
wandb_entity:
|
| 26 |
wandb_watch:
|
| 27 |
+
wandb_name:
|
| 28 |
wandb_log_model:
|
| 29 |
output_dir: ./qlora-out
|
| 30 |
batch_size: 4
|
examples/code-llama/13b/lora.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
examples/code-llama/13b/qlora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
examples/code-llama/34b/lora.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
examples/code-llama/34b/qlora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
examples/code-llama/7b/lora.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
examples/code-llama/7b/qlora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
examples/falcon/config-7b-lora.yml
CHANGED
|
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|
| 26 |
wandb_project:
|
| 27 |
wandb_entity:
|
| 28 |
wandb_watch:
|
| 29 |
-
|
| 30 |
wandb_log_model:
|
| 31 |
output_dir: ./falcon-7b
|
| 32 |
batch_size: 2
|
|
|
|
| 26 |
wandb_project:
|
| 27 |
wandb_entity:
|
| 28 |
wandb_watch:
|
| 29 |
+
wandb_name:
|
| 30 |
wandb_log_model:
|
| 31 |
output_dir: ./falcon-7b
|
| 32 |
batch_size: 2
|
examples/falcon/config-7b-qlora.yml
CHANGED
|
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
|
| 40 |
wandb_project:
|
| 41 |
wandb_entity:
|
| 42 |
wandb_watch:
|
| 43 |
-
|
| 44 |
wandb_log_model:
|
| 45 |
output_dir: ./qlora-out
|
| 46 |
|
|
|
|
| 40 |
wandb_project:
|
| 41 |
wandb_entity:
|
| 42 |
wandb_watch:
|
| 43 |
+
wandb_name:
|
| 44 |
wandb_log_model:
|
| 45 |
output_dir: ./qlora-out
|
| 46 |
|
examples/falcon/config-7b.yml
CHANGED
|
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|
| 26 |
wandb_project:
|
| 27 |
wandb_entity:
|
| 28 |
wandb_watch:
|
| 29 |
-
|
| 30 |
wandb_log_model:
|
| 31 |
output_dir: ./falcon-7b
|
| 32 |
batch_size: 2
|
|
|
|
| 26 |
wandb_project:
|
| 27 |
wandb_entity:
|
| 28 |
wandb_watch:
|
| 29 |
+
wandb_name:
|
| 30 |
wandb_log_model:
|
| 31 |
output_dir: ./falcon-7b
|
| 32 |
batch_size: 2
|
examples/gptj/qlora.yml
CHANGED
|
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|
| 21 |
wandb_project:
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
-
|
| 25 |
wandb_log_model:
|
| 26 |
output_dir: ./qlora-out
|
| 27 |
gradient_accumulation_steps: 2
|
|
|
|
| 21 |
wandb_project:
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
+
wandb_name:
|
| 25 |
wandb_log_model:
|
| 26 |
output_dir: ./qlora-out
|
| 27 |
gradient_accumulation_steps: 2
|
examples/jeopardy-bot/config.yml
CHANGED
|
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
|
|
| 19 |
wandb_project:
|
| 20 |
wandb_entity:
|
| 21 |
wandb_watch:
|
| 22 |
-
|
| 23 |
wandb_log_model:
|
| 24 |
output_dir: ./jeopardy-bot-7b
|
| 25 |
gradient_accumulation_steps: 1
|
|
|
|
| 19 |
wandb_project:
|
| 20 |
wandb_entity:
|
| 21 |
wandb_watch:
|
| 22 |
+
wandb_name:
|
| 23 |
wandb_log_model:
|
| 24 |
output_dir: ./jeopardy-bot-7b
|
| 25 |
gradient_accumulation_steps: 1
|
examples/llama-2/fft_optimized.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 1
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 1
|
examples/llama-2/gptq-lora.yml
CHANGED
|
@@ -32,7 +32,7 @@ lora_target_linear:
|
|
| 32 |
lora_fan_in_fan_out:
|
| 33 |
wandb_project:
|
| 34 |
wandb_watch:
|
| 35 |
-
|
| 36 |
wandb_log_model:
|
| 37 |
output_dir: ./model-out
|
| 38 |
gradient_accumulation_steps: 1
|
|
|
|
| 32 |
lora_fan_in_fan_out:
|
| 33 |
wandb_project:
|
| 34 |
wandb_watch:
|
| 35 |
+
wandb_name:
|
| 36 |
wandb_log_model:
|
| 37 |
output_dir: ./model-out
|
| 38 |
gradient_accumulation_steps: 1
|
examples/llama-2/lora.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
examples/llama-2/qlora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
examples/llama-2/relora.yml
CHANGED
|
@@ -35,7 +35,7 @@ relora_cpu_offload: false
|
|
| 35 |
wandb_project:
|
| 36 |
wandb_entity:
|
| 37 |
wandb_watch:
|
| 38 |
-
|
| 39 |
wandb_log_model:
|
| 40 |
|
| 41 |
gradient_accumulation_steps: 4
|
|
|
|
| 35 |
wandb_project:
|
| 36 |
wandb_entity:
|
| 37 |
wandb_watch:
|
| 38 |
+
wandb_name:
|
| 39 |
wandb_log_model:
|
| 40 |
|
| 41 |
gradient_accumulation_steps: 4
|
examples/llama-2/tiny-llama.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
|
| 35 |
gradient_accumulation_steps: 4
|
examples/mistral/config.yml
CHANGED
|
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
|
|
| 21 |
wandb_project:
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
-
|
| 25 |
wandb_log_model:
|
| 26 |
|
| 27 |
gradient_accumulation_steps: 4
|
|
|
|
| 21 |
wandb_project:
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
+
wandb_name:
|
| 25 |
wandb_log_model:
|
| 26 |
|
| 27 |
gradient_accumulation_steps: 4
|
examples/mistral/qlora.yml
CHANGED
|
@@ -38,7 +38,7 @@ lora_target_modules:
|
|
| 38 |
wandb_project:
|
| 39 |
wandb_entity:
|
| 40 |
wandb_watch:
|
| 41 |
-
|
| 42 |
wandb_log_model:
|
| 43 |
|
| 44 |
gradient_accumulation_steps: 4
|
|
|
|
| 38 |
wandb_project:
|
| 39 |
wandb_entity:
|
| 40 |
wandb_watch:
|
| 41 |
+
wandb_name:
|
| 42 |
wandb_log_model:
|
| 43 |
|
| 44 |
gradient_accumulation_steps: 4
|
examples/mpt-7b/config.yml
CHANGED
|
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
|
|
| 21 |
wandb_project: mpt-alpaca-7b
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
-
|
| 25 |
wandb_log_model:
|
| 26 |
output_dir: ./mpt-alpaca-7b
|
| 27 |
gradient_accumulation_steps: 1
|
|
|
|
| 21 |
wandb_project: mpt-alpaca-7b
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
+
wandb_name:
|
| 25 |
wandb_log_model:
|
| 26 |
output_dir: ./mpt-alpaca-7b
|
| 27 |
gradient_accumulation_steps: 1
|
examples/openllama-3b/config.yml
CHANGED
|
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|
| 23 |
wandb_project:
|
| 24 |
wandb_entity:
|
| 25 |
wandb_watch:
|
| 26 |
-
|
| 27 |
wandb_log_model:
|
| 28 |
output_dir: ./openllama-out
|
| 29 |
gradient_accumulation_steps: 1
|
|
|
|
| 23 |
wandb_project:
|
| 24 |
wandb_entity:
|
| 25 |
wandb_watch:
|
| 26 |
+
wandb_name:
|
| 27 |
wandb_log_model:
|
| 28 |
output_dir: ./openllama-out
|
| 29 |
gradient_accumulation_steps: 1
|
examples/openllama-3b/lora.yml
CHANGED
|
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
-
|
| 33 |
wandb_log_model:
|
| 34 |
output_dir: ./lora-out
|
| 35 |
gradient_accumulation_steps: 1
|
|
|
|
| 29 |
wandb_project:
|
| 30 |
wandb_entity:
|
| 31 |
wandb_watch:
|
| 32 |
+
wandb_name:
|
| 33 |
wandb_log_model:
|
| 34 |
output_dir: ./lora-out
|
| 35 |
gradient_accumulation_steps: 1
|
examples/openllama-3b/qlora.yml
CHANGED
|
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|
| 23 |
wandb_project:
|
| 24 |
wandb_entity:
|
| 25 |
wandb_watch:
|
| 26 |
-
|
| 27 |
wandb_log_model:
|
| 28 |
output_dir: ./qlora-out
|
| 29 |
gradient_accumulation_steps: 1
|
|
|
|
| 23 |
wandb_project:
|
| 24 |
wandb_entity:
|
| 25 |
wandb_watch:
|
| 26 |
+
wandb_name:
|
| 27 |
wandb_log_model:
|
| 28 |
output_dir: ./qlora-out
|
| 29 |
gradient_accumulation_steps: 1
|
examples/phi/phi-ft.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 1
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 1
|
examples/phi/phi-qlora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 1
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 1
|
examples/pythia-12b/config.yml
CHANGED
|
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|
| 24 |
wandb_project:
|
| 25 |
wandb_entity:
|
| 26 |
wandb_watch:
|
| 27 |
-
|
| 28 |
wandb_log_model:
|
| 29 |
output_dir: ./pythia-12b
|
| 30 |
gradient_accumulation_steps: 1
|
|
|
|
| 24 |
wandb_project:
|
| 25 |
wandb_entity:
|
| 26 |
wandb_watch:
|
| 27 |
+
wandb_name:
|
| 28 |
wandb_log_model:
|
| 29 |
output_dir: ./pythia-12b
|
| 30 |
gradient_accumulation_steps: 1
|
examples/pythia/lora.yml
CHANGED
|
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|
| 18 |
wandb_project:
|
| 19 |
wandb_entity:
|
| 20 |
wandb_watch:
|
| 21 |
-
|
| 22 |
wandb_log_model:
|
| 23 |
output_dir: ./lora-alpaca-pythia
|
| 24 |
gradient_accumulation_steps: 1
|
|
|
|
| 18 |
wandb_project:
|
| 19 |
wandb_entity:
|
| 20 |
wandb_watch:
|
| 21 |
+
wandb_name:
|
| 22 |
wandb_log_model:
|
| 23 |
output_dir: ./lora-alpaca-pythia
|
| 24 |
gradient_accumulation_steps: 1
|
examples/qwen/lora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
examples/qwen/qlora.yml
CHANGED
|
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
-
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
|
|
|
| 31 |
wandb_project:
|
| 32 |
wandb_entity:
|
| 33 |
wandb_watch:
|
| 34 |
+
wandb_name:
|
| 35 |
wandb_log_model:
|
| 36 |
|
| 37 |
gradient_accumulation_steps: 4
|
examples/redpajama/config-3b.yml
CHANGED
|
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
|
|
| 22 |
wandb_project: redpajama-alpaca-3b
|
| 23 |
wandb_entity:
|
| 24 |
wandb_watch:
|
| 25 |
-
|
| 26 |
wandb_log_model:
|
| 27 |
output_dir: ./redpajama-alpaca-3b
|
| 28 |
batch_size: 4
|
|
|
|
| 22 |
wandb_project: redpajama-alpaca-3b
|
| 23 |
wandb_entity:
|
| 24 |
wandb_watch:
|
| 25 |
+
wandb_name:
|
| 26 |
wandb_log_model:
|
| 27 |
output_dir: ./redpajama-alpaca-3b
|
| 28 |
batch_size: 4
|
examples/replit-3b/config-lora.yml
CHANGED
|
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|
| 21 |
wandb_project: lora-replit
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
-
|
| 25 |
wandb_log_model:
|
| 26 |
output_dir: ./lora-replit
|
| 27 |
batch_size: 8
|
|
|
|
| 21 |
wandb_project: lora-replit
|
| 22 |
wandb_entity:
|
| 23 |
wandb_watch:
|
| 24 |
+
wandb_name:
|
| 25 |
wandb_log_model:
|
| 26 |
output_dir: ./lora-replit
|
| 27 |
batch_size: 8
|
examples/xgen-7b/xgen-7b-8k-qlora.yml
CHANGED
|
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
|
|
| 38 |
wandb_project:
|
| 39 |
wandb_entity:
|
| 40 |
wandb_watch:
|
| 41 |
-
|
| 42 |
wandb_log_model:
|
| 43 |
output_dir: ./qlora-out
|
| 44 |
|
|
|
|
| 38 |
wandb_project:
|
| 39 |
wandb_entity:
|
| 40 |
wandb_watch:
|
| 41 |
+
wandb_name:
|
| 42 |
wandb_log_model:
|
| 43 |
output_dir: ./qlora-out
|
| 44 |
|
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -647,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 647 |
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
| 648 |
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
| 649 |
training_arguments_kwargs["run_name"] = (
|
| 650 |
-
self.cfg.
|
| 651 |
)
|
| 652 |
training_arguments_kwargs["optim"] = (
|
| 653 |
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
|
|
|
| 647 |
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
| 648 |
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
| 649 |
training_arguments_kwargs["run_name"] = (
|
| 650 |
+
self.cfg.wandb_name if self.cfg.use_wandb else None
|
| 651 |
)
|
| 652 |
training_arguments_kwargs["optim"] = (
|
| 653 |
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
src/axolotl/utils/config.py
CHANGED
|
@@ -397,6 +397,13 @@ def validate_config(cfg):
|
|
| 397 |
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
| 398 |
)
|
| 399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
# TODO
|
| 401 |
# MPT 7b
|
| 402 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 397 |
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
| 398 |
)
|
| 399 |
|
| 400 |
+
if cfg.wandb_run_id and not cfg.wandb_name:
|
| 401 |
+
cfg.wandb_name = cfg.wandb_run_id
|
| 402 |
+
|
| 403 |
+
LOG.warning(
|
| 404 |
+
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
# TODO
|
| 408 |
# MPT 7b
|
| 409 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
src/axolotl/utils/wandb_.py
CHANGED
|
@@ -2,20 +2,20 @@
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
cfg.use_wandb = True
|
| 12 |
-
|
| 13 |
-
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
| 14 |
-
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
| 15 |
-
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
| 16 |
-
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
| 17 |
-
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
| 18 |
-
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
| 19 |
-
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
| 20 |
else:
|
| 21 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
|
| 5 |
+
from axolotl.utils.dict import DictDefault
|
| 6 |
|
| 7 |
+
|
| 8 |
+
def setup_wandb_env_vars(cfg: DictDefault):
|
| 9 |
+
for key in cfg.keys():
|
| 10 |
+
if key.startswith("wandb_"):
|
| 11 |
+
value = cfg.get(key, "")
|
| 12 |
+
|
| 13 |
+
if value and isinstance(value, str) and len(value) > 0:
|
| 14 |
+
os.environ[key.upper()] = value
|
| 15 |
+
|
| 16 |
+
# Enable wandb if project name is present
|
| 17 |
+
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
| 18 |
cfg.use_wandb = True
|
| 19 |
+
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
else:
|
| 21 |
os.environ["WANDB_DISABLED"] = "true"
|
tests/test_validation.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""Module for testing the validation module"""
|
| 2 |
|
| 3 |
import logging
|
|
|
|
| 4 |
import unittest
|
| 5 |
from typing import Optional
|
| 6 |
|
|
@@ -8,6 +9,7 @@ import pytest
|
|
| 8 |
|
| 9 |
from axolotl.utils.config import validate_config
|
| 10 |
from axolotl.utils.dict import DictDefault
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class ValidationTest(unittest.TestCase):
|
|
@@ -679,3 +681,83 @@ class ValidationTest(unittest.TestCase):
|
|
| 679 |
)
|
| 680 |
|
| 681 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Module for testing the validation module"""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
+
import os
|
| 5 |
import unittest
|
| 6 |
from typing import Optional
|
| 7 |
|
|
|
|
| 9 |
|
| 10 |
from axolotl.utils.config import validate_config
|
| 11 |
from axolotl.utils.dict import DictDefault
|
| 12 |
+
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
| 13 |
|
| 14 |
|
| 15 |
class ValidationTest(unittest.TestCase):
|
|
|
|
| 681 |
)
|
| 682 |
|
| 683 |
validate_config(cfg)
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
class ValidationWandbTest(ValidationTest):
|
| 687 |
+
"""
|
| 688 |
+
Validation test for wandb
|
| 689 |
+
"""
|
| 690 |
+
|
| 691 |
+
def test_wandb_set_run_id_to_name(self):
|
| 692 |
+
cfg = DictDefault(
|
| 693 |
+
{
|
| 694 |
+
"wandb_run_id": "foo",
|
| 695 |
+
}
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
with self._caplog.at_level(logging.WARNING):
|
| 699 |
+
validate_config(cfg)
|
| 700 |
+
assert any(
|
| 701 |
+
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
| 702 |
+
in record.message
|
| 703 |
+
for record in self._caplog.records
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
| 707 |
+
|
| 708 |
+
cfg = DictDefault(
|
| 709 |
+
{
|
| 710 |
+
"wandb_name": "foo",
|
| 711 |
+
}
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
validate_config(cfg)
|
| 715 |
+
|
| 716 |
+
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
| 717 |
+
|
| 718 |
+
def test_wandb_sets_env(self):
|
| 719 |
+
cfg = DictDefault(
|
| 720 |
+
{
|
| 721 |
+
"wandb_project": "foo",
|
| 722 |
+
"wandb_name": "bar",
|
| 723 |
+
"wandb_run_id": "bat",
|
| 724 |
+
"wandb_entity": "baz",
|
| 725 |
+
"wandb_mode": "online",
|
| 726 |
+
"wandb_watch": "false",
|
| 727 |
+
"wandb_log_model": "checkpoint",
|
| 728 |
+
}
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
validate_config(cfg)
|
| 732 |
+
|
| 733 |
+
setup_wandb_env_vars(cfg)
|
| 734 |
+
|
| 735 |
+
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
| 736 |
+
assert os.environ.get("WANDB_NAME", "") == "bar"
|
| 737 |
+
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
| 738 |
+
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
| 739 |
+
assert os.environ.get("WANDB_MODE", "") == "online"
|
| 740 |
+
assert os.environ.get("WANDB_WATCH", "") == "false"
|
| 741 |
+
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
| 742 |
+
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
| 743 |
+
|
| 744 |
+
def test_wandb_set_disabled(self):
|
| 745 |
+
cfg = DictDefault({})
|
| 746 |
+
|
| 747 |
+
validate_config(cfg)
|
| 748 |
+
|
| 749 |
+
setup_wandb_env_vars(cfg)
|
| 750 |
+
|
| 751 |
+
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
| 752 |
+
|
| 753 |
+
cfg = DictDefault(
|
| 754 |
+
{
|
| 755 |
+
"wandb_project": "foo",
|
| 756 |
+
}
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
validate_config(cfg)
|
| 760 |
+
|
| 761 |
+
setup_wandb_env_vars(cfg)
|
| 762 |
+
|
| 763 |
+
assert os.environ.get("WANDB_DISABLED", "") != "true"
|