Spaces:
Running
Running
improves interface flow
Browse files- interface.py +181 -233
interface.py
CHANGED
|
@@ -829,10 +829,91 @@ joinus = """
|
|
| 829 |
"""
|
| 830 |
|
| 831 |
|
| 832 |
-
def on_family_change(family: str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 833 |
confs = list(get_config_map(family).keys())
|
| 834 |
exp, repo_short, desc, space = ui_defaults(family)
|
| 835 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
|
| 837 |
|
| 838 |
def start_pipeline(
|
|
@@ -932,243 +1013,110 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
| 932 |
)
|
| 933 |
gr.Markdown(joinus)
|
| 934 |
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="Model family")
|
| 964 |
-
trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="Trainer type")
|
| 965 |
-
monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="both", label="Monitoring mode")
|
| 966 |
-
|
| 967 |
-
config_choice = gr.Dropdown(choices=list(get_config_map("SmolLM3").keys()), value="Basic Training", label="Training configuration")
|
| 968 |
-
|
| 969 |
-
exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
|
| 970 |
-
with gr.Row():
|
| 971 |
-
experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
|
| 972 |
-
repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
|
| 973 |
-
|
| 974 |
-
with gr.Row():
|
| 975 |
-
author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
|
| 976 |
-
model_description = gr.Textbox(value=desc_default, label="Model description")
|
| 977 |
-
|
| 978 |
-
with gr.Row():
|
| 979 |
-
trackio_space_name = gr.Textbox(value=trackio_space_default, label="Trackio Space name (used when monitoring != none)")
|
| 980 |
-
deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space")
|
| 981 |
-
create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo")
|
| 982 |
-
|
| 983 |
-
with gr.Row():
|
| 984 |
-
push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
|
| 985 |
-
switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
|
| 986 |
-
|
| 987 |
-
gr.Markdown("### Medical SFT (GPT-OSS o1)")
|
| 988 |
-
gr.Markdown("Configure GPT-OSS Medical o1 SFT (FreedomIntelligence/medical-o1-reasoning-SFT)")
|
| 989 |
-
med_dataset_config = gr.Dropdown(choices=["en", "en_mix", "zh", "zh_mix"], value="en", label="Dataset config")
|
| 990 |
-
med_system = gr.Textbox(value="You are GPT-Tonic, a large language model trained by TonicAI.", label="System message", lines=2)
|
| 991 |
-
med_developer = gr.Textbox(value="You are are GPT-Tonic, an intelligent assistant that always answers health-related queries scientifically.", label="Developer message", lines=3)
|
| 992 |
-
with gr.Row():
|
| 993 |
-
med_epochs = gr.Number(value=2.0, precision=2, label="Epochs")
|
| 994 |
-
med_bs = gr.Number(value=4, precision=0, label="Batch size")
|
| 995 |
-
med_gas = gr.Number(value=4, precision=0, label="Grad accumulation")
|
| 996 |
-
med_lr = gr.Number(value=2e-4, precision=6, label="Learning rate")
|
| 997 |
-
med_msl = gr.Number(value=2048, precision=0, label="Max seq length")
|
| 998 |
-
med_generate = gr.Button("Generate Medical Config")
|
| 999 |
-
med_status = gr.Textbox(label="Generated config path", interactive=False)
|
| 1000 |
-
|
| 1001 |
-
logs = gr.Textbox(value="", label="Logs", lines=20)
|
| 1002 |
-
start_btn = gr.Button("Run Pipeline")
|
| 1003 |
-
|
| 1004 |
-
with gr.Tab("Advanced Config"):
|
| 1005 |
-
with gr.Accordion("GPT-OSS Scheduler Overrides", open=False):
|
| 1006 |
-
scheduler_override = gr.Dropdown(choices=[c for c in SCHEDULER_CHOICES if c is not None], value=None, allow_custom_value=True, label="Scheduler override")
|
| 1007 |
-
min_lr = gr.Number(value=None, precision=6, label="min_lr (when cosine_with_min_lr)")
|
| 1008 |
-
min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (when cosine_with_min_lr)")
|
| 1009 |
-
|
| 1010 |
-
gr.Markdown("### GPT-OSS Custom Dataset")
|
| 1011 |
-
with gr.Row():
|
| 1012 |
-
cds_dataset = gr.Textbox(value="legmlai/openhermes-fr", label="Dataset name")
|
| 1013 |
-
cds_split = gr.Textbox(value="train", label="Split")
|
| 1014 |
-
cds_format = gr.Dropdown(choices=["openhermes_fr", "messages", "text", "medical_o1_sft", "custom", "preference"], value="openhermes_fr", label="Format")
|
| 1015 |
-
with gr.Row():
|
| 1016 |
-
cds_input = gr.Textbox(value="prompt", label="Input field")
|
| 1017 |
-
cds_target = gr.Textbox(value="accepted_completion", label="Target field (optional, blank for None)")
|
| 1018 |
-
with gr.Row():
|
| 1019 |
-
cds_sys = gr.Textbox(value="", label="System message (optional)")
|
| 1020 |
-
cds_dev = gr.Textbox(value="", label="Developer message (optional)")
|
| 1021 |
-
with gr.Row():
|
| 1022 |
-
cds_identity = gr.Textbox(value="You are GPT-Tonic, a large language model trained by TonicAI.", label="Model identity (chat_template_kwargs.model_identity)")
|
| 1023 |
-
with gr.Row():
|
| 1024 |
-
cds_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
|
| 1025 |
-
cds_min_len = gr.Number(value=10, precision=0, label="Min length")
|
| 1026 |
-
cds_max_len = gr.Number(value=None, precision=0, label="Max length (optional)")
|
| 1027 |
-
gr.Markdown("#### Training Hyperparameters")
|
| 1028 |
-
with gr.Row():
|
| 1029 |
-
cds_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
|
| 1030 |
-
cds_bs = gr.Number(value=4, precision=0, label="Batch size")
|
| 1031 |
-
cds_gas = gr.Number(value=4, precision=0, label="Grad accumulation")
|
| 1032 |
-
cds_lr = gr.Number(value=2e-4, precision=6, label="Learning rate")
|
| 1033 |
-
cds_minlr = gr.Number(value=2e-5, precision=6, label="Min LR")
|
| 1034 |
-
with gr.Row():
|
| 1035 |
-
cds_wd = gr.Number(value=0.01, precision=6, label="Weight decay")
|
| 1036 |
-
cds_warm = gr.Number(value=0.03, precision=6, label="Warmup ratio")
|
| 1037 |
-
cds_msl = gr.Number(value=2048, precision=0, label="Max seq length")
|
| 1038 |
-
gr.Markdown("#### LoRA / Precision / Quantization / Perf")
|
| 1039 |
-
with gr.Row():
|
| 1040 |
-
cds_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
|
| 1041 |
-
cds_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
|
| 1042 |
-
cds_lora_dropout = gr.Number(value=0.05, precision=4, label="LoRA dropout")
|
| 1043 |
-
with gr.Row():
|
| 1044 |
-
cds_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
|
| 1045 |
-
cds_workers = gr.Number(value=4, precision=0, label="Data workers")
|
| 1046 |
-
cds_quant = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
|
| 1047 |
-
with gr.Row():
|
| 1048 |
-
cds_mgn = gr.Number(value=1.0, precision=4, label="Max grad norm")
|
| 1049 |
-
cds_log_steps = gr.Number(value=10, precision=0, label="Logging steps")
|
| 1050 |
-
cds_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
|
| 1051 |
-
cds_save_steps = gr.Number(value=500, precision=0, label="Save steps")
|
| 1052 |
-
cds_generate = gr.Button("Generate GPT-OSS Custom Config")
|
| 1053 |
-
cds_status = gr.Textbox(label="Generated config path", interactive=False)
|
| 1054 |
-
|
| 1055 |
-
gr.Markdown("### SmolLM3 Custom Configuration")
|
| 1056 |
-
with gr.Row():
|
| 1057 |
-
sm_model = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
|
| 1058 |
-
sm_dataset = gr.Textbox(value="legmlai/openhermes-fr", label="Dataset (optional; leave blank for local)")
|
| 1059 |
-
with gr.Row():
|
| 1060 |
-
sm_msl = gr.Number(value=4096, precision=0, label="Max seq length")
|
| 1061 |
-
sm_bs = gr.Number(value=2, precision=0, label="Batch size")
|
| 1062 |
-
sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
|
| 1063 |
-
sm_lr = gr.Number(value=5e-6, precision=8, label="Learning rate")
|
| 1064 |
-
with gr.Row():
|
| 1065 |
-
sm_save = gr.Number(value=500, precision=0, label="Save steps")
|
| 1066 |
-
sm_eval = gr.Number(value=100, precision=0, label="Eval steps")
|
| 1067 |
-
sm_log = gr.Number(value=10, precision=0, label="Logging steps")
|
| 1068 |
-
with gr.Row():
|
| 1069 |
-
sm_filter = gr.Checkbox(value=False, label="Filter bad entries")
|
| 1070 |
-
sm_in = gr.Textbox(value="prompt", label="Input field")
|
| 1071 |
-
sm_out = gr.Textbox(value="accepted_completion", label="Target field")
|
| 1072 |
-
with gr.Row():
|
| 1073 |
-
sm_sample = gr.Number(value=None, precision=0, label="Sample size (optional)")
|
| 1074 |
-
sm_seed = gr.Number(value=42, precision=0, label="Sample seed")
|
| 1075 |
-
sm_trainer = gr.Dropdown(choices=["SFT", "DPO"], value="SFT", label="Trainer type")
|
| 1076 |
-
sm_generate = gr.Button("Generate SmolLM3 Custom Config")
|
| 1077 |
-
sm_status = gr.Textbox(label="Generated config path", interactive=False)
|
| 1078 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1079 |
logs = gr.Textbox(value="", label="Logs", lines=20)
|
| 1080 |
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
-
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
)
|
| 1099 |
-
),
|
| 1100 |
-
inputs=[med_dataset_config, med_system, med_developer, med_epochs, med_bs, med_gas, med_lr, med_msl],
|
| 1101 |
-
outputs=[med_status],
|
| 1102 |
)
|
| 1103 |
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
input_field=ifld,
|
| 1111 |
-
target_field=(tfld or None),
|
| 1112 |
-
system_message=sm,
|
| 1113 |
-
developer_message=dm,
|
| 1114 |
-
model_identity=ident,
|
| 1115 |
-
max_samples=(int(ms) if ms is not None else None),
|
| 1116 |
-
min_length=int(minl or 10),
|
| 1117 |
-
max_length=(int(maxl) if maxl is not None else None),
|
| 1118 |
-
num_train_epochs=float(ep or 1.0),
|
| 1119 |
-
batch_size=int(bs or 4),
|
| 1120 |
-
gradient_accumulation_steps=int(gas or 4),
|
| 1121 |
-
learning_rate=float(lr or 2e-4),
|
| 1122 |
-
min_lr=float(minlr or 2e-5),
|
| 1123 |
-
weight_decay=float(wd or 0.01),
|
| 1124 |
-
warmup_ratio=float(warm or 0.03),
|
| 1125 |
-
max_seq_length=int(msl or 2048),
|
| 1126 |
-
lora_r=int(lr_),
|
| 1127 |
-
lora_alpha=int(la),
|
| 1128 |
-
lora_dropout=float(ld),
|
| 1129 |
-
mixed_precision=prec,
|
| 1130 |
-
num_workers=int(nw or 4),
|
| 1131 |
-
quantization_type=q,
|
| 1132 |
-
max_grad_norm=float(mgn or 1.0),
|
| 1133 |
-
logging_steps=int(logst or 10),
|
| 1134 |
-
eval_steps=int(evst or 100),
|
| 1135 |
-
save_steps=int(savst or 500),
|
| 1136 |
-
)
|
| 1137 |
-
),
|
| 1138 |
-
inputs=[
|
| 1139 |
-
cds_dataset, cds_split, cds_format, cds_input, cds_target, cds_sys, cds_dev, cds_identity,
|
| 1140 |
-
cds_max_samples, cds_min_len, cds_max_len, cds_epochs, cds_bs, cds_gas, cds_lr, cds_minlr, cds_wd,
|
| 1141 |
-
cds_warm, cds_msl, cds_lora_r, cds_lora_alpha, cds_lora_dropout, cds_precision, cds_workers, cds_quant,
|
| 1142 |
-
cds_mgn, cds_log_steps, cds_eval_steps, cds_save_steps
|
| 1143 |
-
],
|
| 1144 |
-
outputs=[cds_status],
|
| 1145 |
)
|
| 1146 |
|
| 1147 |
-
|
| 1148 |
-
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
dataset_name=(dn or None),
|
| 1152 |
-
max_seq_length=int(msl or 4096),
|
| 1153 |
-
batch_size=int(bs or 2),
|
| 1154 |
-
gradient_accumulation_steps=int(gas or 8),
|
| 1155 |
-
learning_rate=float(lr or 5e-6),
|
| 1156 |
-
save_steps=int(sst or 500),
|
| 1157 |
-
eval_steps=int(est or 100),
|
| 1158 |
-
logging_steps=int(lst or 10),
|
| 1159 |
-
filter_bad_entries=bool(fbe),
|
| 1160 |
-
input_field=ifld,
|
| 1161 |
-
target_field=tfld,
|
| 1162 |
-
sample_size=(int(ss) if ss is not None else None),
|
| 1163 |
-
sample_seed=int(seed or 42),
|
| 1164 |
-
trainer_type=tt,
|
| 1165 |
-
)
|
| 1166 |
-
),
|
| 1167 |
-
inputs=[
|
| 1168 |
-
sm_model, sm_dataset, sm_msl, sm_bs, sm_gas, sm_lr, sm_save, sm_eval, sm_log,
|
| 1169 |
-
sm_filter, sm_in, sm_out, sm_sample, sm_seed, sm_trainer,
|
| 1170 |
-
],
|
| 1171 |
-
outputs=[sm_status],
|
| 1172 |
)
|
| 1173 |
|
| 1174 |
start_btn.click(
|
|
@@ -1199,6 +1147,6 @@ if __name__ == "__main__":
|
|
| 1199 |
# Optional: allow setting server parameters via env
|
| 1200 |
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|
| 1201 |
server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
|
| 1202 |
-
demo.queue().launch(server_name=server_name, server_port=server_port)
|
| 1203 |
|
| 1204 |
|
|
|
|
| 829 |
"""
|
| 830 |
|
| 831 |
|
| 832 |
+
def on_family_change(family: str):
    """Rebuild the form state after a different model family is picked.

    Looks up the prebuilt configurations and the UI defaults for ``family``
    and returns an 11-item tuple, in this order:

    1. update for the configuration dropdown (choices + first choice selected)
    2. default experiment name
    3. default short repo name
    4. default model description
    5. default Trackio space name
    6. placeholder markdown shown until a configuration is chosen
    7. update that clears the dataset dropdown
    8. update revealing step 2 (trainer type)
    9. update hiding step 3 until a trainer is selected
    10. update hiding step 4 until monitoring is selected
    11. update toggling the advanced scheduler group (shown only for "GPT-OSS")
    """
    config_names = list(get_config_map(family).keys())
    default_exp, default_repo, default_desc, default_space = ui_defaults(family)

    # Nothing specific to describe yet: a concrete config has not been chosen.
    placeholder_md = f"Select a training configuration for {family} to see details (dataset, batch size, etc.)."

    # Pre-select the first configuration when the family has any.
    first_config = config_names[0] if config_names else None
    config_update = gr.update(choices=config_names, value=first_config)
    dataset_reset = gr.update(choices=[], value=None)

    # Progressive disclosure: only step 2 becomes visible at this point.
    reveal_step2 = gr.update(visible=True)
    hide_step3 = gr.update(visible=False)
    hide_step4 = gr.update(visible=False)
    scheduler_group = gr.update(visible=(family == "GPT-OSS"))

    return (
        config_update,
        default_exp,
        default_repo,
        default_desc,
        default_space,
        placeholder_md,
        dataset_reset,
        reveal_step2,
        hide_step3,
        hide_step4,
        scheduler_group,
    )
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def on_config_change(family: str, config_choice: str):
    """Summarize the selected prebuilt configuration and prefill its dataset.

    Args:
        family: Model family whose configuration map is consulted.
        config_choice: Key of the selected configuration; may be empty or None.

    Returns:
        A 2-tuple ``(markdown_text, dataset_dropdown_update)``. When no
        configuration is selected, or the selection is not present in the
        family's configuration map, the markdown is empty and the dataset
        dropdown is cleared.
    """
    if not config_choice:
        return "", gr.update(choices=[], value=None)

    conf_map = get_config_map(family)
    if config_choice not in conf_map:
        # The selection does not belong to this family (e.g. a stale value);
        # clear the details instead of raising KeyError inside a UI callback.
        return "", gr.update(choices=[], value=None)

    entry = conf_map[config_choice]
    cfg_obj = import_config_object(PROJECT_ROOT / entry["config_file"])

    def _cfg_value(attr_name):
        # A falsy config object (e.g. a failed import) yields no details.
        return getattr(cfg_obj, attr_name, None) if cfg_obj else None

    dataset_name = _cfg_value("dataset_name")
    batch_size = _cfg_value("batch_size")
    learning_rate = _cfg_value("learning_rate")
    max_seq_length = _cfg_value("max_seq_length")
    base_model = entry["default_model"]

    md_lines = [
        f"**Configuration**: {config_choice}",
        f"**Base model**: {base_model}",
    ]
    if dataset_name:
        md_lines.append(f"**Dataset**: `{dataset_name}`")
    if batch_size is not None:
        md_lines.append(f"**Batch size**: {batch_size}")
    if learning_rate is not None:
        md_lines.append(f"**Learning rate**: {learning_rate}")
    if max_seq_length is not None:
        md_lines.append(f"**Max seq length**: {max_seq_length}")

    # Separate entries with a blank line: Markdown treats a single newline as
    # a soft break, which would render every entry in one run-on paragraph.
    training_md = "\n\n".join(md_lines)

    # Dataset selection: custom values stay possible, but prefill with the
    # configuration's own dataset when it declares one.
    ds_choices = [dataset_name] if dataset_name else []

    return training_md, gr.update(choices=ds_choices, value=(dataset_name or None))
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def on_trainer_selected(_: str):
    """Return an update that makes the monitoring step visible.

    The selected trainer value is ignored; choosing any trainer is enough.
    """
    reveal_monitoring_step = gr.update(visible=True)
    return reveal_monitoring_step
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def on_monitoring_change(mode: str):
    """Reveal the details step and toggle monitoring-specific fields.

    Returns a 4-tuple of visibility updates, in order: the configuration /
    details step (always shown), the Trackio space name, the "deploy Trackio
    space" checkbox (both shown only for modes "both" and "trackio"), and the
    "create dataset repo" checkbox (shown for every mode except "none").
    """
    trackio_visible = mode == "both" or mode == "trackio"
    dataset_repo_visible = not (mode == "none")

    # One flag per output component, in the order the outputs are consumed.
    visibility_flags = (True, trackio_visible, trackio_visible, dataset_repo_visible)
    return tuple(gr.update(visible=flag) for flag in visibility_flags)
|
| 917 |
|
| 918 |
|
| 919 |
def start_pipeline(
|
|
|
|
| 1013 |
)
|
| 1014 |
gr.Markdown(joinus)
|
| 1015 |
|
| 1016 |
+
# --- Progressive interface --------------------------------------------------------
|
| 1017 |
+
gr.Markdown("### Configure your run in simple steps")
|
| 1018 |
+
|
| 1019 |
+
# Step 1: Model family
|
| 1020 |
+
with gr.Group():
|
| 1021 |
+
model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="1) Model family")
|
| 1022 |
+
|
| 1023 |
+
# Step 2: Trainer (revealed after family)
|
| 1024 |
+
step2_group = gr.Group(visible=False)
|
| 1025 |
+
with step2_group:
|
| 1026 |
+
trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="2) Trainer type")
|
| 1027 |
+
|
| 1028 |
+
# Step 3: Monitoring (revealed after trainer)
|
| 1029 |
+
step3_group = gr.Group(visible=False)
|
| 1030 |
+
with step3_group:
|
| 1031 |
+
monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="dataset", label="3) Monitoring mode")
|
| 1032 |
+
|
| 1033 |
+
# Step 4: Config & details (revealed after monitoring)
|
| 1034 |
+
step4_group = gr.Group(visible=False)
|
| 1035 |
+
with step4_group:
|
| 1036 |
+
# Defaults based on initial family selection
|
| 1037 |
+
exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
|
| 1038 |
+
|
| 1039 |
+
config_choice = gr.Dropdown(
|
| 1040 |
+
choices=list(get_config_map("SmolLM3").keys()),
|
| 1041 |
+
value="Basic Training",
|
| 1042 |
+
label="4) Training configuration",
|
| 1043 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1044 |
|
| 1045 |
+
with gr.Tabs():
|
| 1046 |
+
with gr.Tab("Overview"):
|
| 1047 |
+
training_info = gr.Markdown("Select a training configuration to see details.")
|
| 1048 |
+
dataset_choice = gr.Dropdown(
|
| 1049 |
+
choices=[],
|
| 1050 |
+
value=None,
|
| 1051 |
+
allow_custom_value=True,
|
| 1052 |
+
label="Dataset (from config; optional)",
|
| 1053 |
+
)
|
| 1054 |
+
with gr.Row():
|
| 1055 |
+
experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
|
| 1056 |
+
repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
|
| 1057 |
+
with gr.Row():
|
| 1058 |
+
author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
|
| 1059 |
+
model_description = gr.Textbox(value=desc_default, label="Model description")
|
| 1060 |
+
trackio_space_name = gr.Textbox(
|
| 1061 |
+
value=trackio_space_default,
|
| 1062 |
+
label="Trackio Space name (used when monitoring != none)",
|
| 1063 |
+
visible=False,
|
| 1064 |
+
)
|
| 1065 |
+
deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space", visible=False)
|
| 1066 |
+
create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo", visible=True)
|
| 1067 |
+
with gr.Row():
|
| 1068 |
+
push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
|
| 1069 |
+
switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
|
| 1070 |
+
|
| 1071 |
+
with gr.Tab("Advanced"):
|
| 1072 |
+
# GPT-OSS specific scheduler overrides
|
| 1073 |
+
advanced_scheduler_group = gr.Group(visible=False)
|
| 1074 |
+
with advanced_scheduler_group:
|
| 1075 |
+
scheduler_override = gr.Dropdown(
|
| 1076 |
+
choices=[c for c in SCHEDULER_CHOICES if c is not None],
|
| 1077 |
+
value=None,
|
| 1078 |
+
allow_custom_value=True,
|
| 1079 |
+
label="Scheduler override",
|
| 1080 |
+
)
|
| 1081 |
+
with gr.Row():
|
| 1082 |
+
min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
|
| 1083 |
+
min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
|
| 1084 |
+
|
| 1085 |
+
# Final action & logs
|
| 1086 |
+
start_btn = gr.Button("Run Pipeline", variant="primary")
|
| 1087 |
logs = gr.Textbox(value="", label="Logs", lines=20)
|
| 1088 |
|
| 1089 |
+
# --- Events ---------------------------------------------------------------------
|
| 1090 |
+
model_family.change(
|
| 1091 |
+
on_family_change,
|
| 1092 |
+
inputs=model_family,
|
| 1093 |
+
outputs=[
|
| 1094 |
+
config_choice,
|
| 1095 |
+
experiment_name,
|
| 1096 |
+
repo_short,
|
| 1097 |
+
model_description,
|
| 1098 |
+
trackio_space_name,
|
| 1099 |
+
training_info,
|
| 1100 |
+
dataset_choice,
|
| 1101 |
+
step2_group,
|
| 1102 |
+
step3_group,
|
| 1103 |
+
step4_group,
|
| 1104 |
+
advanced_scheduler_group,
|
| 1105 |
+
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1106 |
)
|
| 1107 |
|
| 1108 |
+
trainer_type.change(on_trainer_selected, inputs=trainer_type, outputs=step3_group)
|
| 1109 |
+
|
| 1110 |
+
monitoring_mode.change(
|
| 1111 |
+
on_monitoring_change,
|
| 1112 |
+
inputs=monitoring_mode,
|
| 1113 |
+
outputs=[step4_group, trackio_space_name, deploy_trackio_space, create_dataset_repo],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
)
|
| 1115 |
|
| 1116 |
+
config_choice.change(
|
| 1117 |
+
on_config_change,
|
| 1118 |
+
inputs=[model_family, config_choice],
|
| 1119 |
+
outputs=[training_info, dataset_choice],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1120 |
)
|
| 1121 |
|
| 1122 |
start_btn.click(
|
|
|
|
| 1147 |
# Optional: allow setting server parameters via env
|
| 1148 |
server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
|
| 1149 |
server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
|
| 1150 |
+
demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)
|
| 1151 |
|
| 1152 |
|