"""
This module provides a Gradio UI for launching Axolotl with user-defined configurations.
"""
import gradio as gr
import yaml
from config import config
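# Default QLoRA fine-tuning config for Llama-2-7b, pre-filled in the "YML text" tab.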
example_yml = """
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
"""
def yml_config(yml_text):
    """
    Parse user-provided YAML text, write it to config.yml, and return the
    normalized dump so the user can verify what was saved.
    """
    try:
        parsed = yaml.safe_load(yml_text)
    except yaml.YAMLError as err:
        return f"Invalid YAML: {err}"
    with open("config.yml", "w", encoding="utf-8") as file:
        yaml.dump(parsed, file)
    return yaml.dump(parsed)
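# Two input paths: the "Base Model & Tokenizer" tab exposes each Axolotl
# option as its own widget wired to config(), while the "YML text" tab
# accepts a complete YAML config directly.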
with gr.Blocks(title="Axolotl Launcher") as demo:
gr.Markdown("""
# Axolotl Launcher
Fill out the required fields below to create a training run.
""")
with gr.Tab("Base Model & Tokenizer"):
with gr.Column():
with gr.Row():
base_model = gr.Textbox(label="Base Model")
base_model_ignore_patterns = gr.Textbox(
label="Base Model Ignore Patterns")
base_model_config = gr.Textbox(label="Base Model Config")
model_revision = gr.Textbox(label="Model Revision")
with gr.Row():
tokenizer_config = gr.Textbox(label="Tokenizer Config")
model_type = gr.Textbox(label="Model Type")
tokenizer_type = gr.Textbox(label="Tokenizer Type")
with gr.Row():
trust_remote_code = gr.Checkbox(label="Trust Remote Code", value=False)
tokenizer_use_fast = gr.Checkbox(label="Use Fast Tokenizer",
value=True)
tokenizer_legacy = gr.Checkbox(label="Use Legacy Tokenizer",
value=False)
resize_token_embeddings_to_32x = gr.Checkbox(
label="Resize Token Embeddings to 32x", value=False)
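        # Less common options live in the collapsible accordions below.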
with gr.Accordion("Adv. Config", open=False):
with gr.Accordion("Model Derivation & Configuration Overrides", open=False):
with gr.Column():
is_falcon_derived_model = gr.Checkbox(
label="Is Falcon Derived Model", value=False)
is_llama_derived_model = gr.Checkbox(label="Is Llama Derived Model",
value=False)
is_mistral_derived_model = gr.Checkbox(
label="Is Mistral Derived Model", value=False)
is_qwen_derived_model = gr.Checkbox(label="Is Qwen Derived Model",
value=False)
model_config = gr.TextArea(label="Model Config Overrides",
placeholder="YAML or JSON format")
bnb_config_kwargs = gr.TextArea(label="BnB Config KWArgs",
placeholder="YAML or JSON format")
with gr.Accordion("Quantization & Precision", open=False):
with gr.Column():
with gr.Row():
gptq = gr.Checkbox(label="GPTQ", value=False)
gptq_groupsize = gr.Number(label="GPTQ Groupsize", value=128)
gptq_model_v1 = gr.Checkbox(label="GPTQ Model V1", value=False)
load_in_8bit = gr.Checkbox(label="Load in 8-bit", value=False)
load_in_4bit = gr.Checkbox(label="Load in 4-bit", value=False)
with gr.Row():
bf16 = gr.Checkbox(label="BF16", value=False)
fp16 = gr.Checkbox(label="FP16", value=False)
tf32 = gr.Checkbox(label="TF32", value=False)
bfloat16 = gr.Checkbox(label="BFloat16", value=False)
float16 = gr.Checkbox(label="Float16", value=False)
with gr.Accordion("GPU & LoRA Settings", open=False):
gpu_memory_limit = gr.Textbox(label="GPU Memory Limit")
lora_on_cpu = gr.Checkbox(label="LoRA on CPU", value=False)
datasets = gr.TextArea(label="Datasets",
placeholder="YAML or JSON format for datasets")
test_datasets = gr.TextArea(
label="Test Datasets",
placeholder="YAML or JSON format for test datasets")
rl = gr.Textbox(label="RL")
chat_template = gr.Textbox(label="Chat Template")
default_system_message = gr.Textbox(label="Default System Message")
dataset_prepared_path = gr.Textbox(label="Dataset Prepared Path")
push_dataset_to_hub = gr.Textbox(label="Push Dataset to Hub")
dataset_processes = gr.Number(label="Dataset Processes", value=1)
dataset_keep_in_memory = gr.Checkbox(label="Dataset Keep in Memory",
value=False)
with gr.Row():
hub_model_id = gr.Textbox(label="Hub Model ID")
hub_strategy = gr.Textbox(label="Hub Strategy")
hf_use_auth_token = gr.Checkbox(label="HF Use Auth Token",
value=False)
with gr.Row():
val_set_size = gr.Number(label="Validation Set Size",
value=0.04,
step=0.01)
dataset_shard_num = gr.Number(label="Dataset Shard Num")
dataset_shard_idx = gr.Number(label="Dataset Shard Index")
with gr.Accordion("Training & Evaluation", open=False):
with gr.Row():
sequence_len = gr.Number(label="Sequence Length", value=2048)
pad_to_sequence_len = gr.Checkbox(label="Pad to Sequence Length",
value=False)
with gr.Row():
sample_packing = gr.Checkbox(label="Sample Packing", value=False)
eval_sample_packing = gr.Checkbox(label="Eval Sample Packing",
value=False)
sample_packing_eff_est = gr.Number(label="Sample Packing Eff Est")
with gr.Row():
total_num_tokens = gr.Number(label="Total Num Tokens")
device_map = gr.Textbox(label="Device Map")
max_memory = gr.Textbox(label="Max Memory")
adapter = gr.Textbox(label="Adapter")
with gr.Column():
lora_model_dir = gr.Textbox(label="LoRA Model Dir")
lora_r = gr.Number(label="LoRA R", value=8)
lora_alpha = gr.Number(label="LoRA Alpha", value=16)
lora_dropout = gr.Number(label="LoRA Dropout", value=0.05, step=0.01)
lora_target_modules = gr.TextArea(label="LoRA Target Modules")
lora_target_linear = gr.Checkbox(label="LoRA Target Linear",
value=False)
lora_modules_to_save = gr.TextArea(label="LoRA Modules to Save")
lora_fan_in_fan_out = gr.Checkbox(label="LoRA Fan In Fan Out",
value=False)
peft = gr.Textbox(label="PEFT")
with gr.Row():
relora_steps = gr.Number(label="ReLoRA Steps")
relora_warmup_steps = gr.Number(label="ReLoRA Warmup Steps")
relora_anneal_steps = gr.Number(label="ReLoRA Anneal Steps")
relora_prune_ratio = gr.Number(label="ReLoRA Prune Ratio")
relora_cpu_offload = gr.Checkbox(label="ReLoRA CPU Offload",
value=False)
with gr.Row():
wandb_mode = gr.Textbox(label="WandB Mode")
wandb_project = gr.Textbox(label="WandB Project")
wandb_entity = gr.Textbox(label="WandB Entity")
wandb_watch = gr.Checkbox(label="WandB Watch", value=False)
wandb_name = gr.Textbox(label="WandB Name")
wandb_run_id = gr.Textbox(label="WandB Run ID")
wandb_log_model = gr.Checkbox(label="WandB Log Model", value=False)
with gr.Column():
mlflow_tracking_uri = gr.Textbox(label="MLFlow Tracking URI")
mlflow_experiment_name = gr.Textbox(label="MLFlow Experiment Name")
output_dir = gr.Textbox(label="Output Dir")
torch_compile = gr.Checkbox(label="Torch Compile", value=False)
torch_compile_backend = gr.Textbox(label="Torch Compile Backend")
gradient_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps", value=1)
micro_batch_size = gr.Number(label="Micro Batch Size", value=2)
eval_batch_size = gr.Number(label="Eval Batch Size", value=2)
num_epochs = gr.Number(label="Number of Epochs", value=4)
warmup_steps = gr.Number(label="Warmup Steps", value=100)
warmup_ratio = gr.Number(label="Warmup Ratio")
learning_rate = gr.Number(label="Learning Rate",
value=0.00003,
step=1e-5)
lr_quadratic_warmup = gr.Checkbox(label="LR Quadratic Warmup",
value=False)
logging_steps = gr.Number(label="Logging Steps", value=1)
eval_steps = gr.Textbox(label="Eval Steps")
evals_per_epoch = gr.Number(label="Evals per Epoch", value=4)
save_strategy = gr.Textbox(label="Save Strategy")
save_steps = gr.Textbox(label="Save Steps")
saves_per_epoch = gr.Number(label="Saves per Epoch", value=1)
save_total_limit = gr.Number(label="Save Total Limit")
max_steps = gr.Number(label="Max Steps")
eval_table_size = gr.Number(label="Eval Table Size")
eval_max_new_tokens = gr.Number(label="Eval Max New Tokens",
value=128)
eval_causal_lm_metrics = gr.TextArea(label="Eval Causal LM Metrics")
loss_watchdog_threshold = gr.Number(label="Loss Watchdog Threshold")
loss_watchdog_patience = gr.Number(label="Loss Watchdog Patience",
value=3)
save_safetensors = gr.Checkbox(label="Save SafeTensors", value=False)
train_on_inputs = gr.Checkbox(label="Train on Inputs", value=False)
group_by_length = gr.Checkbox(label="Group by Length", value=False)
gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing",
value=False)
early_stopping_patience = gr.Number(label="Early Stopping Patience",
value=3)
lr_scheduler = gr.Textbox(label="LR Scheduler")
lr_scheduler_kwargs = gr.TextArea(label="LR Scheduler KWArgs")
cosine_min_lr_ratio = gr.Number(label="Cosine Min LR Ratio")
cosine_constant_lr_ratio = gr.Number(
label="Cosine Constant LR Ratio")
lr_div_factor = gr.Number(label="LR Div Factor")
log_sweep_min_lr = gr.Number(label="Log Sweep Min LR")
log_sweep_max_lr = gr.Number(label="Log Sweep Max LR")
optimizer = gr.Textbox(label="Optimizer")
weight_decay = gr.Number(label="Weight Decay", value=0.0, step=0.01)
adam_beta1 = gr.Number(label="Adam Beta1", value=0.9, step=0.01)
adam_beta2 = gr.Number(label="Adam Beta2", value=0.999, step=0.001)
adam_epsilon = gr.Number(label="Adam Epsilon", value=1e-8, step=1e-9)
max_grad_norm = gr.Number(label="Max Grad Norm")
neftune_noise_alpha = gr.Number(label="NEFTune Noise Alpha")
flash_optimum = gr.Checkbox(label="Flash Optimum", value=False)
xformers_attention = gr.Checkbox(label="XFormers Attention",
value=False)
flash_attention = gr.Checkbox(label="Flash Attention", value=False)
flash_attn_cross_entropy = gr.Checkbox(
label="Flash Attn Cross Entropy", value=False)
flash_attn_rms_norm = gr.Checkbox(label="Flash Attn RMS Norm",
value=False)
flash_attn_fuse_qkv = gr.Checkbox(label="Flash Attn Fuse QKV",
value=False)
flash_attn_fuse_mlp = gr.Checkbox(label="Flash Attn Fuse MLP",
value=False)
sdp_attention = gr.Checkbox(label="SDP Attention", value=False)
s2_attention = gr.Checkbox(label="S2 Attention", value=False)
resume_from_checkpoint = gr.Textbox(label="Resume From Checkpoint")
auto_resume_from_checkpoints = gr.Checkbox(
label="Auto Resume From Checkpoints", value=False)
local_rank = gr.Number(label="Local Rank")
special_tokens = gr.TextArea(label="Special Tokens")
tokens = gr.TextArea(label="Tokens")
fsdp = gr.Checkbox(label="FSDP", value=False)
fsdp_config = gr.TextArea(label="FSDP Config")
deepspeed = gr.Textbox(label="Deepspeed")
ddp_timeout = gr.Number(label="DDP Timeout")
ddp_bucket_cap_mb = gr.Number(label="DDP Bucket Cap MB")
ddp_broadcast_buffers = gr.Checkbox(label="DDP Broadcast Buffers",
value=False)
torchdistx_path = gr.Textbox(label="TorchDistX Path")
pretraining_dataset = gr.Textbox(label="Pretraining Dataset")
debug = gr.Checkbox(label="Debug", value=False)
seed = gr.Number(label="Seed", value=42)
strict = gr.Checkbox(label="Strict", value=False)
submit_button = gr.Button("Launch Configuration")
output_area = gr.TextArea(label="Configuration Output")
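    # Gradio passes these inputs positionally, so the order of this list
    # must match the parameter order of config() in config.py.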
submit_button.click(
config,
inputs=[
base_model, base_model_ignore_patterns, base_model_config,
model_revision, tokenizer_config, model_type, tokenizer_type,
trust_remote_code, tokenizer_use_fast, tokenizer_legacy,
resize_token_embeddings_to_32x, is_falcon_derived_model,
is_llama_derived_model, is_mistral_derived_model,
is_qwen_derived_model, model_config, bnb_config_kwargs, gptq,
gptq_groupsize, gptq_model_v1, load_in_8bit, load_in_4bit, bf16,
fp16, tf32, bfloat16, float16, gpu_memory_limit, lora_on_cpu,
datasets, test_datasets, rl, chat_template, default_system_message,
dataset_prepared_path, push_dataset_to_hub, dataset_processes,
dataset_keep_in_memory, hub_model_id, hub_strategy,
hf_use_auth_token, val_set_size, dataset_shard_num,
dataset_shard_idx, sequence_len, pad_to_sequence_len, sample_packing,
eval_sample_packing, sample_packing_eff_est, total_num_tokens,
device_map, max_memory, adapter, lora_model_dir, lora_r, lora_alpha,
lora_dropout, lora_target_modules, lora_target_linear,
lora_modules_to_save, lora_fan_in_fan_out, peft, relora_steps,
relora_warmup_steps, relora_anneal_steps, relora_prune_ratio,
relora_cpu_offload, wandb_mode, wandb_project, wandb_entity,
wandb_watch, wandb_name, wandb_run_id, wandb_log_model,
mlflow_tracking_uri, mlflow_experiment_name, output_dir,
torch_compile, torch_compile_backend, gradient_accumulation_steps,
micro_batch_size, eval_batch_size, num_epochs, warmup_steps,
warmup_ratio, learning_rate, lr_quadratic_warmup, logging_steps,
eval_steps, evals_per_epoch, save_strategy, save_steps,
saves_per_epoch, save_total_limit, max_steps, eval_table_size,
eval_max_new_tokens, eval_causal_lm_metrics, loss_watchdog_threshold,
loss_watchdog_patience, save_safetensors, train_on_inputs,
group_by_length, gradient_checkpointing, early_stopping_patience,
lr_scheduler, lr_scheduler_kwargs, cosine_min_lr_ratio,
cosine_constant_lr_ratio, lr_div_factor, log_sweep_min_lr,
log_sweep_max_lr, optimizer, weight_decay, adam_beta1, adam_beta2,
adam_epsilon, max_grad_norm, neftune_noise_alpha, flash_optimum,
xformers_attention, flash_attention, flash_attn_cross_entropy,
flash_attn_rms_norm, flash_attn_fuse_qkv, flash_attn_fuse_mlp,
sdp_attention, s2_attention, resume_from_checkpoint,
auto_resume_from_checkpoints, local_rank, special_tokens, tokens,
fsdp, fsdp_config, deepspeed, ddp_timeout, ddp_bucket_cap_mb,
ddp_broadcast_buffers, torchdistx_path, pretraining_dataset, debug,
seed, strict
],
outputs=output_area)
"""
This section is used to create a configuration file from user text.
"""
with gr.Tab(label="YML text"):
        yml_config_text = gr.TextArea(label="YML Config",
lines=50,
value=example_yml)
create_config = gr.Button("Create config")
output = gr.TextArea(label="Generated config")
create_config.click(
yml_config,
inputs=[yml_config_text],
outputs=output,
)
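# share=True additionally exposes the app via a temporary public Gradio link.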
demo.launch(share=True)