|
""" |
|
This module is used to launch Axolotl with user defined configurations. |
|
""" |
|
|
|
import gradio as gr |
|
import yaml |
|
from config import config |
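# Note: `config` (imported above) is assumed to be a project-local helper that
# assembles the form values below into an Axolotl configuration; its parameter
# list must match the long `inputs` list wired into `submit_button.click`
# further down.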
|
example_yml = """ |
|
base_model: NousResearch/Llama-2-7b-hf |
|
model_type: LlamaForCausalLM |
|
tokenizer_type: LlamaTokenizer |
|
|
|
load_in_8bit: false |
|
load_in_4bit: true |
|
strict: false |
|
|
|
datasets: |
|
- path: mhenrichsen/alpaca_2k_test |
|
type: alpaca |
|
dataset_prepared_path: |
|
val_set_size: 0.05 |
|
output_dir: ./qlora-out |
|
|
|
adapter: qlora |
|
lora_model_dir: |
|
|
|
sequence_len: 4096 |
|
sample_packing: true |
|
pad_to_sequence_len: true |
|
|
|
lora_r: 32 |
|
lora_alpha: 16 |
|
lora_dropout: 0.05 |
|
lora_target_modules: |
|
lora_target_linear: true |
|
lora_fan_in_fan_out: |
|
|
|
wandb_project: |
|
wandb_entity: |
|
wandb_watch: |
|
wandb_name: |
|
wandb_log_model: |
|
|
|
gradient_accumulation_steps: 4 |
|
micro_batch_size: 2 |
|
num_epochs: 4 |
|
optimizer: paged_adamw_32bit |
|
lr_scheduler: cosine |
|
learning_rate: 0.0002 |
|
|
|
train_on_inputs: false |
|
group_by_length: false |
|
bf16: auto |
|
fp16: |
|
tf32: false |
|
|
|
gradient_checkpointing: true |
|
early_stopping_patience: |
|
resume_from_checkpoint: |
|
local_rank: |
|
logging_steps: 1 |
|
xformers_attention: |
|
flash_attention: true |
|
|
|
warmup_steps: 10 |
|
evals_per_epoch: 4 |
|
eval_table_size: |
|
saves_per_epoch: 1 |
|
debug: |
|
deepspeed: |
|
weight_decay: 0.0 |
|
fsdp: |
|
fsdp_config: |
|
special_tokens: |
|
""" |
|
|
|
|
|
def yml_config(yml_text):
    """
    Parse user-supplied YAML text, save it to config.yml, and return the
    normalized dump for display.
    """
    try:
        parsed = yaml.safe_load(yml_text)
    except yaml.YAMLError as err:
        return f"Invalid YAML: {err}"
    with open("config.yml", "w", encoding="utf-8") as file:
        yaml.dump(parsed, file)

    return yaml.dump(parsed)
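# A quick, hypothetical round-trip example (safe_load followed by dump
# normalizes spacing and drops comments; the call also writes config.yml
# to the working directory):
#
#   >>> print(yml_config("base_model: NousResearch/Llama-2-7b-hf"))
#   base_model: NousResearch/Llama-2-7b-hf
#
# A typical next step outside this app is to hand the saved file to the
# Axolotl CLI, e.g. `accelerate launch -m axolotl.cli.train config.yml`.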
|
with gr.Blocks(title="Axolotl Launcher") as demo: |
|
gr.Markdown(""" |
|
# Axolotl Launcher |
|
Fill out the required fields below to create a training run. |
|
""") |
|
with gr.Tab("Base Model & Tokenizer"): |
|
with gr.Column(): |
|
with gr.Row(): |
|
base_model = gr.Textbox(label="Base Model") |
|
base_model_ignore_patterns = gr.Textbox( |
|
label="Base Model Ignore Patterns") |
|
base_model_config = gr.Textbox(label="Base Model Config") |
|
model_revision = gr.Textbox(label="Model Revision") |
|
with gr.Row(): |
|
tokenizer_config = gr.Textbox(label="Tokenizer Config") |
|
model_type = gr.Textbox(label="Model Type") |
|
tokenizer_type = gr.Textbox(label="Tokenizer Type") |
|
with gr.Row(): |
|
trust_remote_code = gr.Checkbox(label="Trust Remote Code", value=False) |
|
tokenizer_use_fast = gr.Checkbox(label="Use Fast Tokenizer", |
|
value=True) |
|
tokenizer_legacy = gr.Checkbox(label="Use Legacy Tokenizer", |
|
value=False) |
|
resize_token_embeddings_to_32x = gr.Checkbox( |
|
label="Resize Token Embeddings to 32x", value=False) |
|
with gr.Accordion("Adv. Config", open=False): |
|
with gr.Accordion("Model Derivation & Configuration Overrides", open=False): |
|
with gr.Column(): |
|
is_falcon_derived_model = gr.Checkbox( |
|
label="Is Falcon Derived Model", value=False) |
|
is_llama_derived_model = gr.Checkbox(label="Is Llama Derived Model", |
|
value=False) |
|
is_mistral_derived_model = gr.Checkbox( |
|
label="Is Mistral Derived Model", value=False) |
|
is_qwen_derived_model = gr.Checkbox(label="Is Qwen Derived Model", |
|
value=False) |
|
model_config = gr.TextArea(label="Model Config Overrides", |
|
placeholder="YAML or JSON format") |
|
bnb_config_kwargs = gr.TextArea(label="BnB Config KWArgs", |
|
placeholder="YAML or JSON format") |
|
|
|
with gr.Accordion("Quantization & Precision", open=False): |
|
with gr.Column(): |
|
with gr.Row(): |
|
gptq = gr.Checkbox(label="GPTQ", value=False) |
|
gptq_groupsize = gr.Number(label="GPTQ Groupsize", value=128) |
|
gptq_model_v1 = gr.Checkbox(label="GPTQ Model V1", value=False) |
|
load_in_8bit = gr.Checkbox(label="Load in 8-bit", value=False) |
|
load_in_4bit = gr.Checkbox(label="Load in 4-bit", value=False) |
|
with gr.Row(): |
|
bf16 = gr.Checkbox(label="BF16", value=False) |
|
fp16 = gr.Checkbox(label="FP16", value=False) |
|
tf32 = gr.Checkbox(label="TF32", value=False) |
|
bfloat16 = gr.Checkbox(label="BFloat16", value=False) |
|
float16 = gr.Checkbox(label="Float16", value=False) |
|
|
|
with gr.Accordion("GPU & LoRA Settings", open=False): |
|
gpu_memory_limit = gr.Textbox(label="GPU Memory Limit") |
|
lora_on_cpu = gr.Checkbox(label="LoRA on CPU", value=False) |
|
datasets = gr.TextArea(label="Datasets", |
|
placeholder="YAML or JSON format for datasets") |
|
test_datasets = gr.TextArea( |
|
label="Test Datasets", |
|
placeholder="YAML or JSON format for test datasets") |
|
rl = gr.Textbox(label="RL") |
|
chat_template = gr.Textbox(label="Chat Template") |
|
default_system_message = gr.Textbox(label="Default System Message") |
|
dataset_prepared_path = gr.Textbox(label="Dataset Prepared Path") |
|
push_dataset_to_hub = gr.Textbox(label="Push Dataset to Hub") |
|
dataset_processes = gr.Number(label="Dataset Processes", value=1) |
|
dataset_keep_in_memory = gr.Checkbox(label="Dataset Keep in Memory", |
|
value=False) |
|
with gr.Row(): |
|
hub_model_id = gr.Textbox(label="Hub Model ID") |
|
hub_strategy = gr.Textbox(label="Hub Strategy") |
|
hf_use_auth_token = gr.Checkbox(label="HF Use Auth Token", |
|
value=False) |
|
with gr.Row(): |
|
val_set_size = gr.Number(label="Validation Set Size", |
|
value=0.04, |
|
step=0.01) |
|
dataset_shard_num = gr.Number(label="Dataset Shard Num") |
|
dataset_shard_idx = gr.Number(label="Dataset Shard Index") |
|
|
|
with gr.Accordion("Training & Evaluation", open=False): |
|
with gr.Row(): |
|
sequence_len = gr.Number(label="Sequence Length", value=2048) |
|
pad_to_sequence_len = gr.Checkbox(label="Pad to Sequence Length", |
|
value=False) |
|
with gr.Row(): |
|
sample_packing = gr.Checkbox(label="Sample Packing", value=False) |
|
eval_sample_packing = gr.Checkbox(label="Eval Sample Packing", |
|
value=False) |
|
sample_packing_eff_est = gr.Number(label="Sample Packing Eff Est") |
|
with gr.Row(): |
|
total_num_tokens = gr.Number(label="Total Num Tokens") |
|
device_map = gr.Textbox(label="Device Map") |
|
max_memory = gr.Textbox(label="Max Memory") |
|
adapter = gr.Textbox(label="Adapter") |
|
with gr.Column(): |
|
lora_model_dir = gr.Textbox(label="LoRA Model Dir") |
|
lora_r = gr.Number(label="LoRA R", value=8) |
|
lora_alpha = gr.Number(label="LoRA Alpha", value=16) |
|
lora_dropout = gr.Number(label="LoRA Dropout", value=0.05, step=0.01) |
|
lora_target_modules = gr.TextArea(label="LoRA Target Modules") |
|
lora_target_linear = gr.Checkbox(label="LoRA Target Linear", |
|
value=False) |
|
lora_modules_to_save = gr.TextArea(label="LoRA Modules to Save") |
|
lora_fan_in_fan_out = gr.Checkbox(label="LoRA Fan In Fan Out", |
|
value=False) |
|
peft = gr.Textbox(label="PEFT") |
|
with gr.Row(): |
|
relora_steps = gr.Number(label="ReLoRA Steps") |
|
relora_warmup_steps = gr.Number(label="ReLoRA Warmup Steps") |
|
relora_anneal_steps = gr.Number(label="ReLoRA Anneal Steps") |
|
relora_prune_ratio = gr.Number(label="ReLoRA Prune Ratio") |
|
relora_cpu_offload = gr.Checkbox(label="ReLoRA CPU Offload", |
|
value=False) |
|
with gr.Row(): |
|
wandb_mode = gr.Textbox(label="WandB Mode") |
|
wandb_project = gr.Textbox(label="WandB Project") |
|
wandb_entity = gr.Textbox(label="WandB Entity") |
|
wandb_watch = gr.Checkbox(label="WandB Watch", value=False) |
|
wandb_name = gr.Textbox(label="WandB Name") |
|
wandb_run_id = gr.Textbox(label="WandB Run ID") |
|
wandb_log_model = gr.Checkbox(label="WandB Log Model", value=False) |
|
with gr.Column(): |
|
mlflow_tracking_uri = gr.Textbox(label="MLFlow Tracking URI") |
|
mlflow_experiment_name = gr.Textbox(label="MLFlow Experiment Name") |
|
output_dir = gr.Textbox(label="Output Dir") |
|
torch_compile = gr.Checkbox(label="Torch Compile", value=False) |
|
torch_compile_backend = gr.Textbox(label="Torch Compile Backend") |
|
gradient_accumulation_steps = gr.Number( |
|
label="Gradient Accumulation Steps", value=1) |
|
micro_batch_size = gr.Number(label="Micro Batch Size", value=2) |
|
eval_batch_size = gr.Number(label="Eval Batch Size", value=2) |
|
num_epochs = gr.Number(label="Number of Epochs", value=4) |
|
warmup_steps = gr.Number(label="Warmup Steps", value=100) |
|
warmup_ratio = gr.Number(label="Warmup Ratio") |
|
learning_rate = gr.Number(label="Learning Rate", |
|
value=0.00003, |
|
step=1e-5) |
|
lr_quadratic_warmup = gr.Checkbox(label="LR Quadratic Warmup", |
|
value=False) |
|
logging_steps = gr.Number(label="Logging Steps", value=1) |
|
eval_steps = gr.Textbox(label="Eval Steps") |
|
evals_per_epoch = gr.Number(label="Evals per Epoch", value=4) |
|
save_strategy = gr.Textbox(label="Save Strategy") |
|
save_steps = gr.Textbox(label="Save Steps") |
|
saves_per_epoch = gr.Number(label="Saves per Epoch", value=1) |
|
save_total_limit = gr.Number(label="Save Total Limit") |
|
max_steps = gr.Number(label="Max Steps") |
|
eval_table_size = gr.Number(label="Eval Table Size") |
|
eval_max_new_tokens = gr.Number(label="Eval Max New Tokens", |
|
value=128) |
|
eval_causal_lm_metrics = gr.TextArea(label="Eval Causal LM Metrics") |
|
loss_watchdog_threshold = gr.Number(label="Loss Watchdog Threshold") |
|
loss_watchdog_patience = gr.Number(label="Loss Watchdog Patience", |
|
value=3) |
|
save_safetensors = gr.Checkbox(label="Save SafeTensors", value=False) |
|
train_on_inputs = gr.Checkbox(label="Train on Inputs", value=False) |
|
group_by_length = gr.Checkbox(label="Group by Length", value=False) |
|
gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", |
|
value=False) |
|
early_stopping_patience = gr.Number(label="Early Stopping Patience", |
|
value=3) |
|
lr_scheduler = gr.Textbox(label="LR Scheduler") |
|
lr_scheduler_kwargs = gr.TextArea(label="LR Scheduler KWArgs") |
|
cosine_min_lr_ratio = gr.Number(label="Cosine Min LR Ratio") |
|
cosine_constant_lr_ratio = gr.Number( |
|
label="Cosine Constant LR Ratio") |
|
lr_div_factor = gr.Number(label="LR Div Factor") |
|
log_sweep_min_lr = gr.Number(label="Log Sweep Min LR") |
|
log_sweep_max_lr = gr.Number(label="Log Sweep Max LR") |
|
optimizer = gr.Textbox(label="Optimizer") |
|
weight_decay = gr.Number(label="Weight Decay", value=0.0, step=0.01) |
|
adam_beta1 = gr.Number(label="Adam Beta1", value=0.9, step=0.01) |
|
adam_beta2 = gr.Number(label="Adam Beta2", value=0.999, step=0.001) |
|
adam_epsilon = gr.Number(label="Adam Epsilon", value=1e-8, step=1e-9) |
|
max_grad_norm = gr.Number(label="Max Grad Norm") |
|
neftune_noise_alpha = gr.Number(label="NEFTune Noise Alpha") |
|
flash_optimum = gr.Checkbox(label="Flash Optimum", value=False) |
|
xformers_attention = gr.Checkbox(label="XFormers Attention", |
|
value=False) |
|
flash_attention = gr.Checkbox(label="Flash Attention", value=False) |
|
flash_attn_cross_entropy = gr.Checkbox( |
|
label="Flash Attn Cross Entropy", value=False) |
|
flash_attn_rms_norm = gr.Checkbox(label="Flash Attn RMS Norm", |
|
value=False) |
|
flash_attn_fuse_qkv = gr.Checkbox(label="Flash Attn Fuse QKV", |
|
value=False) |
|
flash_attn_fuse_mlp = gr.Checkbox(label="Flash Attn Fuse MLP", |
|
value=False) |
|
sdp_attention = gr.Checkbox(label="SDP Attention", value=False) |
|
s2_attention = gr.Checkbox(label="S2 Attention", value=False) |
|
resume_from_checkpoint = gr.Textbox(label="Resume From Checkpoint") |
|
auto_resume_from_checkpoints = gr.Checkbox( |
|
label="Auto Resume From Checkpoints", value=False) |
|
local_rank = gr.Number(label="Local Rank") |
|
special_tokens = gr.TextArea(label="Special Tokens") |
|
tokens = gr.TextArea(label="Tokens") |
|
fsdp = gr.Checkbox(label="FSDP", value=False) |
|
fsdp_config = gr.TextArea(label="FSDP Config") |
|
deepspeed = gr.Textbox(label="Deepspeed") |
|
ddp_timeout = gr.Number(label="DDP Timeout") |
|
ddp_bucket_cap_mb = gr.Number(label="DDP Bucket Cap MB") |
|
ddp_broadcast_buffers = gr.Checkbox(label="DDP Broadcast Buffers", |
|
value=False) |
|
torchdistx_path = gr.Textbox(label="TorchDistX Path") |
|
pretraining_dataset = gr.Textbox(label="Pretraining Dataset") |
|
debug = gr.Checkbox(label="Debug", value=False) |
|
seed = gr.Number(label="Seed", value=42) |
|
strict = gr.Checkbox(label="Strict", value=False) |
|
|
|
submit_button = gr.Button("Launch Configuration") |
|
output_area = gr.TextArea(label="Configuration Output") |
|
|
|
    submit_button.click(
        config,
        inputs=[
            base_model, base_model_ignore_patterns, base_model_config,
            model_revision, tokenizer_config, model_type, tokenizer_type,
            trust_remote_code, tokenizer_use_fast, tokenizer_legacy,
            resize_token_embeddings_to_32x, is_falcon_derived_model,
            is_llama_derived_model, is_mistral_derived_model,
            is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
            gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16,
            fp16, tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu,
            datasets, test_datasets, rl, chat_template, default_system_message,
            dataset_prepared_path, push_dataset_to_hub, dataset_processes,
            dataset_keep_in_memory, hub_model_id, hub_strategy,
            hf_use_auth_token, val_set_size, dataset_shard_num,
            dataset_shard_idx, sequence_len, pad_to_sequence_len, sample_packing,
            eval_sample_packing, sample_packing_eff_est, total_num_tokens,
            device_map, max_memory, adapter, lora_model_dir, lora_r, lora_alpha,
            lora_dropout, lora_target_modules, lora_target_linear,
            lora_modules_to_save, lora_fan_in_fan_out, peft, relora_steps,
            relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
            relora_cpu_offload, wandb_mode, wandb_project, wandb_entity,
            wandb_watch, wandb_name, wandb_run_id, wandb_log_model,
            mlflow_tracking_uri, mlflow_experiment_name, output_dir,
            torch_compile, torch_compile_backend, gradient_accumulation_steps,
            micro_batch_size, eval_batch_size, num_epochs, warmup_steps,
            warmup_ratio, learning_rate, lr_quadratic_warmup, logging_steps,
            eval_steps, evals_per_epoch, save_strategy, save_steps,
            saves_per_epoch, save_total_limit, max_steps, eval_table_size,
            eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
            loss_watchdog_patience, save_safetensors, train_on_inputs,
            group_by_length, gradient_checkpointing, early_stopping_patience,
            lr_scheduler, lr_scheduler_kwargs, cosine_min_lr_ratio,
            cosine_constant_lr_ratio, lr_div_factor, log_sweep_min_lr,
            log_sweep_max_lr, optimizer, weight_decay, adam_beta1, adam_beta2,
            adam_epsilon, max_grad_norm, neftune_noise_alpha, flash_optimum,
            xformers_attention, flash_attention, flash_attn_cross_entropy,
            flash_attn_rms_norm, flash_attn_fuse_qkv, flash_attn_fuse_mlp,
            sdp_attention, s2_attention, resume_from_checkpoint,
            auto_resume_from_checkpoints, local_rank, special_tokens, tokens,
            fsdp, fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
            ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug,
            seed, strict
        ],
        outputs=output_area)
|
""" |
|
This section is used to create a configuration file from user text. |
|
""" |
|
with gr.Tab(label="YML text"): |
|
yml_config_text = gr.TextArea(label='YML Config', |
|
lines=50, |
|
value=example_yml) |
|
create_config = gr.Button("Create config") |
|
output = gr.TextArea(label="Generated config") |
|
create_config.click( |
|
yml_config, |
|
inputs=[yml_config_text], |
|
outputs=output, |
|
) |
|
|
|
demo.launch(share=True) |
|
|