File size: 5,630 Bytes
5fe83da
 
 
 
 
 
 
 
 
 
 
d60ab6c
 
 
5fe83da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb64084
 
0de9de2
bb64084
 
 
 
 
0de9de2
bb64084
 
 
5fe83da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb64084
 
 
 
 
 
5fe83da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb64084
 
 
 
5fe83da
 
 
 
 
 
 
 
 
 
829d8f4
5fe83da
829d8f4
 
5fe83da
 
 
829d8f4
 
 
 
 
 
 
 
 
5fe83da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python3
"""
Script to run A100 large-scale experiments on OpenHermes-FR dataset
Supports multiple configurations for different training scenarios
"""

import argparse
import os
import sys
from pathlib import Path

# Configure the PyTorch CUDA allocator via its environment variable, enabling
# the "expandable_segments" option. The assignment runs at import time (before
# main() imports the training module) and overwrites any value already present
# in the environment.
# NOTE(review): presumably this must happen before CUDA is first initialised
# for the setting to take effect — confirm against the PyTorch docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def _build_parser():
    """Create the command-line parser for the experiment launcher."""
    parser = argparse.ArgumentParser(description="Run A100 large-scale experiments")
    parser.add_argument(
        "--config",
        type=str,
        default="config/train_smollm3_openhermes_fr_a100_large.py",
        help="Configuration file to use",
    )
    parser.add_argument(
        "--experiment-name",
        type=str,
        help="Custom experiment name for tracking",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./outputs",
        help="Output directory for checkpoints and logs",
    )
    # NOTE(review): the value given to --resume is only tested for truthiness;
    # the checkpoint path itself is never forwarded to train.py. Confirm how
    # train.py locates the checkpoint when started with "--init_from resume".
    parser.add_argument(
        "--resume",
        type=str,
        help="Resume training from checkpoint",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print configuration without starting training",
    )
    # Both hyphenated and underscored spellings are accepted for Trackio flags.
    parser.add_argument(
        "--trackio-url",
        "--trackio_url",
        type=str,
        help="Trackio URL for experiment tracking",
    )
    parser.add_argument(
        "--trackio-token",
        "--trackio_token",
        type=str,
        help="Trackio token for authentication",
    )
    return parser


def _load_config(config_path):
    """Return the config object built by the module matching *config_path*.

    Only the module that is actually needed gets imported, so a problem in the
    other configuration module cannot prevent this one from loading. Any path
    other than the multiple-passes config is handed to the large-batch
    ``get_config``.

    Raises:
        ImportError: if the selected configuration module cannot be imported.
    """
    if config_path == "config/train_smollm3_openhermes_fr_a100_multiple_passes.py":
        from config.train_smollm3_openhermes_fr_a100_multiple_passes import get_config
    else:
        from config.train_smollm3_openhermes_fr_a100_large import get_config
    return get_config(config_path)


def _print_summary(args, config):
    """Print a human-readable summary of the effective experiment settings."""
    print(f"\n{'='*60}")
    print("EXPERIMENT CONFIGURATION")
    print(f"{'='*60}")
    print(f"Config file: {args.config}")
    print(f"Experiment name: {config.experiment_name}")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {config.model_name}")
    print(f"Batch size: {config.batch_size}")
    print(f"Gradient accumulation: {config.gradient_accumulation_steps}")
    print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Max iterations: {config.max_iters}")
    print(f"Max sequence length: {config.max_seq_length}")
    print(f"Mixed precision: {'bf16' if config.bf16 else 'fp16'}")
    print(f"Dataset: {config.dataset_name}")
    if config.trackio_url:
        print(f"Trackio URL: {config.trackio_url}")
    if config.trackio_token:
        # Never echo the token itself; show one '*' per character instead.
        print(f"Trackio Token: {'*' * len(config.trackio_token)}")
    print(f"{'='*60}\n")


def _build_train_args(args):
    """Translate launcher options into the argument list passed to train.py."""
    # The config file is a positional argument for train.py, not "--config".
    train_args = [args.config, "--out_dir", args.output_dir]
    if args.resume:
        train_args.extend(["--init_from", "resume"])
    if args.trackio_url:
        train_args.extend(["--trackio_url", args.trackio_url])
    if args.trackio_token:
        train_args.extend(["--trackio_token", args.trackio_token])
    if args.experiment_name:
        train_args.extend(["--experiment_name", args.experiment_name])
    return train_args


def main(argv=None):
    """Parse launcher options, show the configuration and start training.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, ``sys.argv[1:]`` is used.

    Returns:
        Process exit status: ``0`` on success or after a dry run, ``1`` when
        the configuration or training module cannot be imported or when
        training raises an exception.
    """
    args = _build_parser().parse_args(argv)

    # Make the sibling "config" package and "train" module importable
    # regardless of the directory the script is launched from.
    sys.path.insert(0, str(Path(__file__).parent))

    try:
        config = _load_config(args.config)
    except ImportError as e:
        print(f"Error importing configuration: {e}")
        print("Available configurations:")
        print("  - config/train_smollm3_openhermes_fr_a100_large.py (Large batch, 1.3 passes)")
        print("  - config/train_smollm3_openhermes_fr_a100_multiple_passes.py (Multiple passes, 4 epochs)")
        return 1

    # Command-line values take precedence over the ones stored in the config.
    if args.experiment_name:
        config.experiment_name = args.experiment_name
    if args.trackio_url:
        config.trackio_url = args.trackio_url
    if args.trackio_token:
        config.trackio_token = args.trackio_token

    _print_summary(args, config)

    if args.dry_run:
        print("DRY RUN - Configuration printed above. Use without --dry-run to start training.")
        return 0

    # Created only for real runs so that a dry run leaves the filesystem alone.
    os.makedirs(args.output_dir, exist_ok=True)

    # Keep this try limited to the import itself: an ImportError raised later,
    # while training is running, must not be reported as "train.py missing".
    try:
        from train import main as train_main
    except ImportError as e:
        print(f"Error importing training module: {e}")
        print("Make sure train.py is available in the current directory.")
        return 1

    # train.py reads its options from sys.argv, so swap it in temporarily and
    # always put the original back, even if training fails or exits.
    original_argv = sys.argv
    sys.argv = ["train.py"] + _build_train_args(args)
    try:
        train_main()
    except Exception as e:
        print(f"Error during training: {e}")
        return 1
    finally:
        sys.argv = original_argv

    return 0

if __name__ == "__main__":
    # Use sys.exit rather than the interactive-only exit() helper, which is
    # injected by the site module and is not guaranteed to exist in programs.
    sys.exit(main())