fix config bug
Files changed:
- config/train_smollm3_h100_lightweight.py  +142 -92
- scripts/training/train.py  +20 -3
- src/train.py  +6 -2
- test_config.py  +53 -0
config/train_smollm3_h100_lightweight.py  (CHANGED)

@@ -3,112 +3,162 @@ SmolLM3 H100 Lightweight Training Configuration
 Optimized for rapid training on H100 with 80K Hermes-FR samples
 """
 
+import os
+from dataclasses import dataclass
+from typing import Optional
 from config.train_smollm3 import SmolLM3Config
 
+@dataclass
+class SmolLM3ConfigH100Lightweight(SmolLM3Config):
+    """Configuration for SmolLM3 fine-tuning on OpenHermes-FR dataset - H100 Lightweight"""
+
+    # Model configuration - optimized for H100
+    model_name: str = "HuggingFaceTB/SmolLM3-3B"
+    max_seq_length: int = 8192  # Increased for better context understanding
+    use_flash_attention: bool = True
+    use_gradient_checkpointing: bool = True  # Enabled for memory efficiency
+
+    # Training configuration - H100 optimized for rapid training
+    batch_size: int = 16  # Larger batch size for H100
+    gradient_accumulation_steps: int = 4  # Reduced for faster updates
+    learning_rate: float = 8e-6  # Slightly higher for rapid convergence
+    weight_decay: float = 0.01
+    warmup_steps: int = 50  # Reduced warmup for rapid training
+    max_iters: int = None  # Will be calculated based on epochs
+    eval_interval: int = 50  # More frequent evaluation
+    log_interval: int = 5  # More frequent logging
+    save_interval: int = 200  # More frequent saving
+
+    # Optimizer configuration - optimized for rapid training
+    optimizer: str = "adamw_torch"
+    beta1: float = 0.9
+    beta2: float = 0.95
+    eps: float = 1e-8
+
+    # Scheduler configuration - faster learning
+    scheduler: str = "cosine"
+    min_lr: float = 2e-6  # Higher minimum LR
 
     # Mixed precision - Full precision for H100
-    fp16=True
-    bf16=False
+    fp16: bool = True
+    bf16: bool = False
 
-    # Logging and saving -
-    save_steps=200
-    eval_steps=50
-    logging_steps=5
-    save_total_limit=2
+    # Logging and saving - more frequent for rapid training
+    save_steps: int = 200
+    eval_steps: int = 50
+    logging_steps: int = 5
+    save_total_limit: Optional[int] = 2  # Keep fewer checkpoints
 
     # Evaluation
-    eval_strategy="steps"
-    metric_for_best_model="eval_loss"
-    greater_is_better=False
-    load_best_model_at_end=True
-
-    #
-    dataset_name="legmlai/openhermes-fr"
-    dataset_split="train"
-    input_field="prompt"
-    target_field="completion"
-    filter_bad_entries=False
-    bad_entry_field="bad_entry"
-    sample_size=80000
-    sample_seed=42
+    eval_strategy: str = "steps"
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+    load_best_model_at_end: bool = True
+
+    # OpenHermes-FR Dataset configuration with sampling
+    dataset_name: str = "legmlai/openhermes-fr"
+    dataset_split: str = "train"
+    input_field: str = "prompt"
+    target_field: str = "completion"
+    filter_bad_entries: bool = False
+    bad_entry_field: str = "bad_entry"
+    sample_size: int = 80000  # 80K samples for lightweight training
+    sample_seed: int = 42  # For reproducibility
+
+    # Data configuration (not used for HF datasets but kept for compatibility)
+    data_dir: str = "my_dataset"
+    train_file: str = "train.json"
+    validation_file: Optional[str] = "validation.json"
+    test_file: Optional[str] = None
 
     # Chat template configuration
-    use_chat_template=True
-    chat_template_kwargs={
-        "enable_thinking": False,
-        "add_generation_prompt": True,
-        "no_think_system_message": True
-    },
+    use_chat_template: bool = True
+    chat_template_kwargs: dict = None
 
     # Trackio monitoring configuration
-    enable_tracking=True
-    trackio_url
-    trackio_token=None
-    log_artifacts=True
-    log_metrics=True
-    log_config=True
-    experiment_name
+    enable_tracking: bool = True
+    trackio_url: Optional[str] = None
+    trackio_token: Optional[str] = None
+    log_artifacts: bool = True
+    log_metrics: bool = True
+    log_config: bool = True
+    experiment_name: Optional[str] = None
 
     # HF Datasets configuration
+    hf_token: Optional[str] = None
+    dataset_repo: Optional[str] = None
 
     # H100-specific optimizations
-    dataloader_num_workers=4
-    dataloader_pin_memory=True
+    dataloader_num_workers: int = 4  # Optimized for H100
+    dataloader_pin_memory: bool = True
+    dataloader_prefetch_factor: int = 2
 
     # Memory optimizations for rapid training
-    max_grad_norm=1.0
+    max_grad_norm: float = 1.0
+    group_by_length: bool = True  # Group similar length sequences
+
+    # Training duration calculations
+    # With 80k datapoints and effective batch size of 64:
+    # Steps per epoch = 80,000 / 64 = 1,250 steps
+    # For 1 epoch: 1,250 steps
+    # For 2 epochs: 2,500 steps
+
+    def __post_init__(self):
+        if self.chat_template_kwargs is None:
+            self.chat_template_kwargs = {
+                "enable_thinking": False,
+                "add_generation_prompt": True,
+                "no_think_system_message": True
+            }
+
+        # Validate configuration
+        if self.fp16 and self.bf16:
+            raise ValueError("Cannot use both fp16 and bf16")
+
+        if self.max_seq_length > 131072:  # 128k limit
+            raise ValueError("max_seq_length cannot exceed 131072")
+
+        # Calculate training statistics
+        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
+        steps_per_epoch = self.sample_size // effective_batch_size  # For 80k dataset
+        epochs_for_max_iters = self.max_iters / steps_per_epoch if self.max_iters else 1
+
+        print(f"=== H100 Lightweight Training Configuration ===")
+        print(f"Effective batch size: {effective_batch_size}")
+        print(f"Steps per epoch: ~{steps_per_epoch}")
+        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
+        print(f"Total training steps: {self.max_iters or 'auto'}")
+        print(f"Learning rate: {self.learning_rate}")
+        print(f"Mixed precision: {'fp16' if self.fp16 else 'bf16'}")
+        print(f"Max sequence length: {self.max_seq_length}")
+        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
+        print(f"Dataset sample size: {self.sample_size}")
+        print("=" * 50)
+
+        # Set default experiment name if not provided
+        if self.experiment_name is None:
+            self.experiment_name = "smollm3_h100_lightweight"
+
+def get_config(config_path: str) -> SmolLM3ConfigH100Lightweight:
+    """Load configuration from file or return default"""
+    if os.path.exists(config_path):
+        # Load from file if it exists
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("config_module", config_path)
+        config_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(config_module)
+
+        if hasattr(config_module, 'config'):
+            return config_module.config
+        else:
+            # Try to find a config class
+            for attr_name in dir(config_module):
+                attr = getattr(config_module, attr_name)
+                if isinstance(attr, SmolLM3ConfigH100Lightweight):
+                    return attr
+
+    # Return default configuration
+    return SmolLM3ConfigH100Lightweight()
+
+# Default configuration instance
+config = SmolLM3ConfigH100Lightweight()
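The notable pattern above is chat_template_kwargs: dict = None together with __post_init__: Python dataclasses reject a literal dict (or list) as a field default at class-definition time, so the kwargs are defaulted to None and filled in after construction. A minimal sketch of the restriction and the two standard workarounds (the None sentinel this commit uses, and the stdlib field(default_factory=...) idiom); class names here are illustrative only:

    from dataclasses import dataclass, field

    # A literal dict default is rejected when the class is defined:
    #   @dataclass
    #   class Bad:
    #       kwargs: dict = {"enable_thinking": False}
    #   -> ValueError: mutable default <class 'dict'> for field kwargs is not allowed

    @dataclass
    class SentinelStyle:
        """The pattern used in this commit: default to None, fill in later."""
        kwargs: dict = None

        def __post_init__(self):
            if self.kwargs is None:
                self.kwargs = {"enable_thinking": False}

    @dataclass
    class FactoryStyle:
        """Equivalent stdlib idiom: each instance gets a fresh dict."""
        kwargs: dict = field(default_factory=lambda: {"enable_thinking": False})

    print(SentinelStyle().kwargs)  # {'enable_thinking': False}
    print(FactoryStyle().kwargs)   # {'enable_thinking': False}

The sentinel style also avoids sharing one dict across instances, which a plain class-level default would cause.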
scripts/training/train.py  (CHANGED)

@@ -53,6 +53,12 @@ def main():
         type=str,
         help="Trackio token for authentication"
     )
+    parser.add_argument(
+        "--dataset-dir",
+        type=str,
+        default="my_dataset",
+        help="Dataset directory path"
+    )
 
     args = parser.parse_args()
 
@@ -65,13 +71,13 @@ def main():
     # Import all available configurations
     from config.train_smollm3_openhermes_fr_a100_large import get_config as get_large_config
     from config.train_smollm3_openhermes_fr_a100_multiple_passes import get_config as get_multiple_passes_config
-    from config.train_smollm3_h100_lightweight import
+    from config.train_smollm3_h100_lightweight import get_config as get_h100_lightweight_config
 
     # Map config files to their respective functions
     config_map = {
         "config/train_smollm3_openhermes_fr_a100_large.py": get_large_config,
         "config/train_smollm3_openhermes_fr_a100_multiple_passes.py": get_multiple_passes_config,
-        "config/train_smollm3_h100_lightweight.py":
+        "config/train_smollm3_h100_lightweight.py": get_h100_lightweight_config,
     }
 
     if args.config in config_map:

@@ -116,7 +122,15 @@ def main():
     print(f"Max iterations: {config.max_iters}")
     print(f"Max sequence length: {config.max_seq_length}")
     print(f"Mixed precision: {'bf16' if config.bf16 else 'fp16'}")
+    if hasattr(config, 'dataset_name') and config.dataset_name:
+        print(f"Dataset: {config.dataset_name}")
+        if hasattr(config, 'sample_size') and config.sample_size:
+            print(f"Sample size: {config.sample_size}")
+    else:
+        print(f"Dataset directory: {config.data_dir}")
+        print(f"Training file: {config.train_file}")
+        if config.validation_file:
+            print(f"Validation file: {config.validation_file}")
     if config.trackio_url:
         print(f"Trackio URL: {config.trackio_url}")
     if config.trackio_token:

@@ -151,6 +165,9 @@ def main():
     if args.experiment_name:
         train_args.extend(["--experiment_name", args.experiment_name])
 
+    # Add dataset directory argument
+    train_args.extend(["--dataset_dir", args.dataset_dir])
+
     # Override sys.argv for the training script
     original_argv = sys.argv
     sys.argv = ["train.py"] + train_args
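Note the spelling difference between the wrapper's --dataset-dir flag and the --dataset_dir argument it forwards to src/train.py: argparse converts dashes in long option names to underscores for the namespace attribute, so both spellings surface as args.dataset_dir. A minimal sketch of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    # Dashes in the flag name become underscores in the parsed namespace.
    parser.add_argument("--dataset-dir", type=str, default="my_dataset",
                        help="Dataset directory path")

    args = parser.parse_args(["--dataset-dir", "my_dataset"])
    print(args.dataset_dir)  # -> my_dataset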
src/train.py  (CHANGED)

@@ -174,13 +174,17 @@ def main():
     )
 
     # Determine dataset path
+    # Check if using Hugging Face dataset or local dataset
     if hasattr(config, 'dataset_name') and config.dataset_name:
         # Use Hugging Face dataset
         dataset_path = config.dataset_name
         logger.info(f"Using Hugging Face dataset: {dataset_path}")
     else:
-        # Use local dataset
+        # Use local dataset from config or command line argument
+        if args.dataset_dir:
+            dataset_path = os.path.join('/input', args.dataset_dir)
+        else:
+            dataset_path = os.path.join('/input', config.data_dir)
         logger.info(f"Using local dataset: {dataset_path}")
 
     # Load dataset with filtering options and sampling
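The patch establishes a clear resolution order: a Hugging Face dataset name on the config wins; otherwise the command-line dataset directory is used, falling back to config.data_dir, with local paths joined under the '/input' prefix the training environment expects. A standalone sketch of that precedence (the function name is illustrative, not the repo's API):

    import os

    def resolve_dataset_path(dataset_name, dataset_dir, data_dir="my_dataset"):
        """Illustrative only: mirrors the precedence added in src/train.py."""
        if dataset_name:                   # HF dataset name takes priority
            return dataset_name
        if dataset_dir:                    # then the --dataset_dir CLI value
            return os.path.join('/input', dataset_dir)
        return os.path.join('/input', data_dir)  # finally the config default

    assert resolve_dataset_path("legmlai/openhermes-fr", None) == "legmlai/openhermes-fr"
    assert resolve_dataset_path(None, "my_dataset") == "/input/my_dataset"
    assert resolve_dataset_path(None, None) == "/input/my_dataset"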
test_config.py  (ADDED)

@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+"""
+Test script to verify H100 lightweight configuration loads correctly
+"""
+
+import sys
+import os
+
+# Add project root to path
+project_root = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, project_root)
+
+def test_h100_lightweight_config():
+    """Test the H100 lightweight configuration"""
+    try:
+        from config.train_smollm3_h100_lightweight import config
+
+        print("✅ H100 Lightweight configuration loaded successfully!")
+        print(f"Model: {config.model_name}")
+        print(f"Dataset: {config.dataset_name}")
+        print(f"Sample size: {config.sample_size}")
+        print(f"Batch size: {config.batch_size}")
+        print(f"Learning rate: {config.learning_rate}")
+        print(f"Max sequence length: {config.max_seq_length}")
+
+        return True
+    except Exception as e:
+        print(f"❌ Error loading H100 lightweight configuration: {e}")
+        return False
+
+def test_training_script_import():
+    """Test that the training script can import the configuration"""
+    try:
+        from scripts.training.train import main
+        print("✅ Training script imports successfully!")
+        return True
+    except Exception as e:
+        print(f"❌ Error importing training script: {e}")
+        return False
+
+if __name__ == "__main__":
+    print("Testing H100 Lightweight Configuration...")
+    print("=" * 50)
+
+    success = True
+    success &= test_h100_lightweight_config()
+    success &= test_training_script_import()
+
+    if success:
+        print("\n🎉 All tests passed! Configuration is ready for training.")
+    else:
+        print("\n❌ Some tests failed. Please check the configuration.")
+        sys.exit(1)
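Running "python test_config.py" from the repository root exercises both checks; because the script exits with status 1 when either test fails, it can double as a quick pre-training sanity gate in CI.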