Commit d8dd7a1 (verified) · first commit

Files changed:
- .gitignore +98 -0
- README.md +291 -0
- config.py +28 -0
- config/train_smollm3.py +107 -0
- config/train_smollm3_dpo.py +38 -0
- config/train_smollm3_long_context.py +38 -0
- create_sample_dataset.py +41 -0
- data.py +238 -0
- model.py +188 -0
- requirements.txt +35 -0
- test_setup.py +206 -0
- train.py +144 -0
- trainer.py +242 -0
.gitignore
ADDED
@@ -0,0 +1,98 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyTorch
*.pth
*.pt
*.ckpt

# Jupyter Notebook
.ipynb_checkpoints

# Environment
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Logs
*.log
logs/
tensorboard_logs/

# Model outputs
output/
checkpoints/
models/
wandb/

# Datasets
data/
datasets/
my_dataset/
test_dataset/

# Temporary files
tmp/
temp/
*.tmp
*.temp

# Hugging Face cache
.cache/
transformers_cache/

# Accelerate
accelerate_config.yaml

# Training outputs
runs/
*.json
!config/*.json
!*.json.example

# Evaluation results
eval_results/
test_results/

# Documentation
docs/_build/
README.md
ADDED
@@ -0,0 +1,291 @@
# SmolLM3 Fine-tuning for FlexAI Console

This repository provides a complete setup for fine-tuning SmolLM3 models using the FlexAI console, following the nanoGPT structure but adapted for modern transformer models.

## Overview

SmolLM3 is a 3B-parameter transformer decoder model optimized for efficiency, long-context reasoning, and multilingual support. This setup allows you to fine-tune SmolLM3 for various tasks including:

- **Supervised Fine-tuning (SFT)**: Adapt the model for instruction following
- **Direct Preference Optimization (DPO)**: Improve model alignment
- **Long-context fine-tuning**: Support for up to 128k tokens
- **Tool calling**: Fine-tune for function calling capabilities

## Quick Start

### 1. Repository Setup

The repository follows the FlexAI console structure with the following key files:

- `train.py`: Main entry point script
- `config/train_smollm3.py`: Default configuration
- `model.py`: Model wrapper and loading
- `data.py`: Dataset handling and preprocessing
- `trainer.py`: Training loop and trainer setup
- `requirements.txt`: Dependencies

### 2. FlexAI Console Configuration

When setting up a Fine Tuning Job in the FlexAI console, use these settings:

#### Basic Configuration
- **Name**: `smollm3-finetune`
- **Cluster**: Your organization's designated cluster
- **Checkpoint**: (Optional) Previous training job checkpoint
- **Node Count**: 1
- **Accelerator Count**: 1-8 (depending on your needs)

#### Repository Settings
- **Repository URL**: `https://github.com/your-username/flexai-finetune`
- **Repository Revision**: `main`

#### Dataset Configuration
- **Datasets**: Your dataset (mounted under `/input`)
- **Mount Directory**: `my_dataset`

#### Entry Point
```
train.py config/train_smollm3.py --dataset_dir=my_dataset --init_from=resume --out_dir=/input-checkpoint --max_iters=1500
```

### 3. Dataset Format

The script supports multiple dataset formats:

#### Chat Format (Recommended)
```json
[
  {
    "messages": [
      {"role": "user", "content": "What is machine learning?"},
      {"role": "assistant", "content": "Machine learning is a subset of AI..."}
    ]
  }
]
```

#### Instruction Format
```json
[
  {
    "instruction": "What is machine learning?",
    "output": "Machine learning is a subset of AI..."
  }
]
```

#### User-Assistant Format
```json
[
  {
    "user": "What is machine learning?",
    "assistant": "Machine learning is a subset of AI..."
  }
]
```

### 4. Configuration Options

The default configuration in `config/train_smollm3.py` includes:

```python
@dataclass
class SmolLM3Config:
    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 4096
    use_flash_attention: bool = True

    # Training configuration
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    max_iters: int = 1000

    # Mixed precision
    fp16: bool = True
    bf16: bool = False
```

### 5. Command Line Arguments

The `train.py` script accepts various arguments:

```bash
# Basic usage
python train.py config/train_smollm3.py

# With custom parameters
python train.py config/train_smollm3.py \
    --dataset_dir=my_dataset \
    --out_dir=/output-checkpoint \
    --init_from=resume \
    --max_iters=1500 \
    --batch_size=8 \
    --learning_rate=1e-5 \
    --max_seq_length=8192
```

## Advanced Usage

### 1. Custom Configuration

Create a custom configuration file:

```python
# config/my_config.py
from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    model_name="HuggingFaceTB/SmolLM3-3B-Instruct",
    max_seq_length=8192,
    batch_size=2,
    learning_rate=1e-5,
    max_iters=2000,
    use_flash_attention=True,
    fp16=True
)
```

### 2. Long-Context Fine-tuning

For long-context tasks (up to 128k tokens):

```python
config = SmolLM3Config(
    max_seq_length=131072,  # 128k tokens
    model_name="HuggingFaceTB/SmolLM3-3B",
    use_flash_attention=True,
    use_gradient_checkpointing=True
)
```

### 3. DPO Training

For preference optimization, use the DPO trainer:

```python
from trainer import SmolLM3DPOTrainer

dpo_trainer = SmolLM3DPOTrainer(
    model=model,
    dataset=dataset,
    config=config,
    output_dir="./dpo-output"
)

dpo_trainer.train()
```

### 4. Tool Calling Fine-tuning

Include tool calling examples in your dataset:

```json
[
  {
    "messages": [
      {"role": "user", "content": "What's the weather in New York?"},
      {"role": "assistant", "content": "<tool_call>\n<invoke name=\"get_weather\">\n<parameter name=\"location\">New York</parameter>\n</invoke>\n</tool_call>"},
      {"role": "tool", "content": "The weather in New York is 72°F and sunny."},
      {"role": "assistant", "content": "The weather in New York is currently 72°F and sunny."}
    ]
  }
]
```

## Model Variants

SmolLM3 comes in several variants:

- **SmolLM3-3B-Base**: Base model for general fine-tuning
- **SmolLM3-3B**: Instruction-tuned model
- **SmolLM3-3B-Instruct**: Enhanced instruction model
- **Quantized versions**: Available for deployment

## Hardware Requirements

### Minimum Requirements
- **GPU**: 16GB+ VRAM (for 3B model)
- **RAM**: 32GB+ system memory
- **Storage**: 50GB+ free space

### Recommended
- **GPU**: A100/H100 or similar
- **RAM**: 64GB+ system memory
- **Storage**: 100GB+ SSD

## Troubleshooting

### Common Issues

1. **Out of Memory (OOM)** (see the sketch after this list)
   - Reduce `batch_size`
   - Increase `gradient_accumulation_steps`
   - Enable `gradient_checkpointing`
   - Use `fp16` or `bf16`

2. **Slow Training**
   - Enable `flash_attention`
   - Use mixed precision (`fp16`/`bf16`)
   - Increase `dataloader_num_workers`

3. **Dataset Loading Issues**
   - Check dataset format
   - Ensure proper JSON structure
   - Verify file permissions

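A minimal sketch of how the OOM mitigations above map onto this repository's configuration; the field names follow `config/train_smollm3.py`, and the values are illustrative starting points rather than tested settings:

```python
from config.train_smollm3 import SmolLM3Config

# Illustrative memory-saving overrides; tune the numbers for your GPU.
config = SmolLM3Config(
    batch_size=1,                     # smaller per-device batch
    gradient_accumulation_steps=16,   # keep the effective batch size roughly constant
    use_gradient_checkpointing=True,  # trade extra compute for lower activation memory
    bf16=True,                        # mixed precision (fp16 and bf16 are mutually exclusive here)
    fp16=False,
)
```
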
### Debug Mode

Enable debug logging:

```python
import logging
logging.basicConfig(level=logging.DEBUG)
```

## Evaluation

After training, evaluate your model:

```python
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="./output-checkpoint",
    device=0,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7
)

# Test the model
messages = [{"role": "user", "content": "Explain gravity in simple terms."}]
outputs = pipe(messages)
print(outputs[0]["generated_text"][-1]["content"])
```

## Deployment

### Using vLLM
```bash
vllm serve ./output-checkpoint --enable-auto-tool-choice
```

### Using llama.cpp
```bash
# Convert to GGUF format
python -m llama_cpp.convert_model ./output-checkpoint --outfile model.gguf
```

## Resources

- [SmolLM3 Blog Post](https://huggingface.co/blog/smollm3)
- [Model Repository](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)
- [GitHub Repository](https://github.com/huggingface/smollm)
- [SmolTalk Dataset](https://huggingface.co/datasets/HuggingFaceTB/smoltalk)

## License

This project follows the same license as the SmolLM3 model. Please refer to the Hugging Face model page for licensing information.
config.py
ADDED
@@ -0,0 +1,28 @@
"""
Configuration management for SmolLM3 fine-tuning
"""

import os
import importlib.util
from typing import Any
from config.train_smollm3 import SmolLM3Config, get_config as get_default_config

def get_config(config_path: str) -> SmolLM3Config:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Try to find a config class
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3Config):
                    return attr

    # Return default configuration
    return get_default_config(config_path)
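A brief usage sketch for the loader above (run from the repository root so the `config` package resolves; the path is the default configuration shipped in this commit):

```python
from config import get_config

# Loads config/train_smollm3.py and returns its SmolLM3Config instance.
cfg = get_config("config/train_smollm3.py")
print(cfg.model_name, cfg.batch_size, cfg.learning_rate)
```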
config/train_smollm3.py
ADDED
@@ -0,0 +1,107 @@
"""
SmolLM3 Training Configuration
Based on nanoGPT structure but adapted for SmolLM3
"""

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class SmolLM3Config:
    """Configuration for SmolLM3 fine-tuning"""

    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 4096
    use_flash_attention: bool = True
    use_gradient_checkpointing: bool = True

    # Training configuration
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 100
    max_iters: int = 1000
    eval_interval: int = 100
    log_interval: int = 10
    save_interval: int = 500

    # Optimizer configuration
    optimizer: str = "adamw"
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8

    # Scheduler configuration
    scheduler: str = "cosine"
    min_lr: float = 1e-6

    # Mixed precision
    fp16: bool = True
    bf16: bool = False

    # DDP configuration
    ddp_backend: str = "nccl"
    ddp_find_unused_parameters: bool = False

    # Logging and saving
    save_steps: int = 500
    eval_steps: int = 100
    logging_steps: int = 10
    save_total_limit: Optional[int] = 3

    # Evaluation
    eval_strategy: str = "steps"
    metric_for_best_model: str = "eval_loss"
    greater_is_better: bool = False
    load_best_model_at_end: bool = True

    # Data configuration
    data_dir: str = "my_dataset"
    train_file: str = "train.json"
    validation_file: Optional[str] = None
    test_file: Optional[str] = None

    # Chat template configuration
    use_chat_template: bool = True
    chat_template_kwargs: dict = None

    def __post_init__(self):
        if self.chat_template_kwargs is None:
            self.chat_template_kwargs = {
                "enable_thinking": False,
                "add_generation_prompt": True
            }

        # Validate configuration
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")

        if self.max_seq_length > 131072:  # 128k limit
            raise ValueError("max_seq_length cannot exceed 131072")

def get_config(config_path: str) -> SmolLM3Config:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Try to find a config class
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3Config):
                    return attr

    # Return default configuration
    return SmolLM3Config()

# Default configuration instance
config = SmolLM3Config()
config/train_smollm3_dpo.py
ADDED
@@ -0,0 +1,38 @@
"""
SmolLM3 DPO Training Configuration
Optimized for Direct Preference Optimization
"""

from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    # Model configuration
    model_name="HuggingFaceTB/SmolLM3-3B-Instruct",  # Start from instruction-tuned model
    max_seq_length=4096,
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    # Training configuration
    batch_size=2,  # Smaller batch size for DPO
    gradient_accumulation_steps=4,
    learning_rate=5e-6,  # Very low learning rate for DPO
    weight_decay=0.01,
    warmup_steps=100,
    max_iters=1000,

    # Mixed precision
    fp16=True,
    bf16=False,

    # Logging and saving
    save_steps=200,
    eval_steps=100,
    logging_steps=20,

    # Chat template configuration
    use_chat_template=True,
    chat_template_kwargs={
        "enable_thinking": False,  # Disable reasoning for preference learning
        "add_generation_prompt": True
    }
)
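Note that DPO trains on preference pairs rather than plain conversations. A minimal sketch of one such record, assuming the prompt/chosen/rejected layout commonly used with TRL-style DPO trainers (the SFT data formats documented in the README do not cover this case):

```python
# Hypothetical preference record for DPO-style training.
preference_example = {
    "prompt": "What is machine learning?",
    "chosen": "Machine learning is a subset of AI that learns patterns from data.",
    "rejected": "Machine learning means manually programming every rule by hand.",
}
```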
config/train_smollm3_long_context.py
ADDED
@@ -0,0 +1,38 @@
"""
SmolLM3 Long-Context Training Configuration
Optimized for long-context tasks (up to 128k tokens)
"""

from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    # Model configuration
    model_name="HuggingFaceTB/SmolLM3-3B",
    max_seq_length=131072,  # 128k tokens
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    # Training configuration
    batch_size=1,  # Reduced for long sequences
    gradient_accumulation_steps=8,  # Increased to maintain effective batch size
    learning_rate=1e-5,  # Lower learning rate for stability
    weight_decay=0.01,
    warmup_steps=200,
    max_iters=500,

    # Mixed precision
    fp16=True,
    bf16=False,

    # Logging and saving
    save_steps=100,
    eval_steps=50,
    logging_steps=10,

    # Chat template configuration
    use_chat_template=True,
    chat_template_kwargs={
        "enable_thinking": True,  # Enable reasoning mode
        "add_generation_prompt": True
    }
)
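A quick sanity check on the batch settings above; with sequences this long, the per-step token count is dominated by `max_seq_length` (numbers are per device, per optimizer step):

```python
# Effective batch size and tokens per optimizer step for this configuration.
batch_size = 1
gradient_accumulation_steps = 8
max_seq_length = 131072

effective_batch = batch_size * gradient_accumulation_steps  # 8 sequences
tokens_per_step = effective_batch * max_seq_length          # 1,048,576 tokens
print(effective_batch, tokens_per_step)
```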
create_sample_dataset.py
ADDED
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
Sample Dataset Creation Script
Creates sample datasets for testing SmolLM3 fine-tuning
"""

import os
import json
import argparse
from data import create_sample_dataset

def main():
    parser = argparse.ArgumentParser(description='Create sample dataset for SmolLM3 fine-tuning')
    parser.add_argument('--output_dir', type=str, default='my_dataset',
                        help='Output directory for the dataset')
    parser.add_argument('--format', type=str, default='chat',
                        choices=['chat', 'instruction', 'user_assistant'],
                        help='Dataset format')
    parser.add_argument('--num_samples', type=int, default=100,
                        help='Number of samples to create')

    args = parser.parse_args()

    # Create sample dataset
    output_path = create_sample_dataset(args.output_dir)

    print(f"Sample dataset created in: {output_path}")
    print(f"Format: {args.format}")
    print(f"Samples: {args.num_samples}")
    print("\nFiles created:")
    print(f"- {os.path.join(output_path, 'train.json')}")
    print(f"- {os.path.join(output_path, 'validation.json')}")

    # Show sample data
    with open(os.path.join(output_path, 'train.json'), 'r') as f:
        data = json.load(f)
        print(f"\nSample data:")
        print(json.dumps(data[0], indent=2))

if __name__ == '__main__':
    main()
data.py
ADDED
@@ -0,0 +1,238 @@
"""
SmolLM3 Dataset Handler
Handles data loading, preprocessing, and tokenization for SmolLM3 fine-tuning
"""

import os
import json
import torch
from typing import Dict, List, Optional, Union
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizer
import logging

logger = logging.getLogger(__name__)

class SmolLM3Dataset:
    """Dataset handler for SmolLM3 fine-tuning"""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int = 4096,
        use_chat_template: bool = True,
        chat_template_kwargs: Optional[Dict] = None
    ):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.use_chat_template = use_chat_template
        self.chat_template_kwargs = chat_template_kwargs or {}

        # Load and process dataset
        self.dataset = self._load_dataset()
        self.processed_dataset = self._process_dataset()

    def _load_dataset(self) -> Dataset:
        """Load dataset from various formats"""
        logger.info(f"Loading dataset from {self.data_path}")

        # Check if it's a Hugging Face dataset
        if os.path.isdir(self.data_path):
            # Local directory
            try:
                # Only include splits whose files actually exist; load_dataset rejects None entries
                data_files = {"train": os.path.join(self.data_path, "train.json")}
                for split in ("validation", "test"):
                    split_path = os.path.join(self.data_path, f"{split}.json")
                    if os.path.exists(split_path):
                        data_files[split] = split_path
                dataset = load_dataset("json", data_files=data_files)
                logger.info("Loaded dataset from local JSON files")
                return dataset
            except Exception as e:
                logger.warning(f"Failed to load as JSON dataset: {e}")

        # Try to load as a single JSON file
        if os.path.isfile(self.data_path) and self.data_path.endswith('.json'):
            try:
                with open(self.data_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Convert to dataset format
                if isinstance(data, list):
                    dataset = Dataset.from_list(data)
                else:
                    dataset = Dataset.from_dict(data)

                logger.info("Loaded dataset from single JSON file")
                return dataset
            except Exception as e:
                logger.error(f"Failed to load JSON file: {e}")
                raise

        # Try to load as a Hugging Face dataset name
        try:
            dataset = load_dataset(self.data_path)
            logger.info(f"Loaded Hugging Face dataset: {self.data_path}")
            return dataset
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _process_dataset(self) -> Dataset:
        """Process the dataset for training"""
        logger.info("Processing dataset for training")

        def format_chat_template(example):
            """Format example using chat template"""
            if self.use_chat_template:
                try:
                    # Handle different input formats
                    if "messages" in example:
                        messages = example["messages"]
                    elif "conversations" in example:
                        messages = example["conversations"]
                    elif "user" in example and "assistant" in example:
                        messages = [
                            {"role": "user", "content": example["user"]},
                            {"role": "assistant", "content": example["assistant"]}
                        ]
                    elif "instruction" in example and "output" in example:
                        messages = [
                            {"role": "user", "content": example["instruction"]},
                            {"role": "assistant", "content": example["output"]}
                        ]
                    elif "prompt" in example and "completion" in example:
                        messages = [
                            {"role": "user", "content": example["prompt"]},
                            {"role": "assistant", "content": example["completion"]}
                        ]
                    else:
                        # Fallback: treat as plain text
                        return {"text": str(example)}

                    # Apply chat template
                    text = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        **self.chat_template_kwargs
                    )
                    return {"text": text}
                except Exception as e:
                    logger.warning(f"Failed to apply chat template: {e}")
                    # Fallback to plain text
                    return {"text": str(example)}
            else:
                # Use plain text
                if "text" in example:
                    return {"text": example["text"]}
                else:
                    return {"text": str(example)}

        def tokenize_function(examples):
            """Tokenize the examples"""
            # Tokenize the texts
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.max_seq_length,
                return_overflowing_tokens=True,
                return_length=True,
            )

            # Calculate input length
            input_length = [len(x) for x in tokenized["input_ids"]]

            # Create labels (same as input_ids for causal LM)
            tokenized["labels"] = tokenized["input_ids"].copy()

            return {
                "input_ids": tokenized["input_ids"],
                "attention_mask": tokenized["attention_mask"],
                "labels": tokenized["labels"],
                "length": input_length,
            }

        # Process the dataset
        processed_dataset = self.dataset.map(
            format_chat_template,
            remove_columns=self.dataset["train"].column_names,
            desc="Formatting dataset"
        )

        # Tokenize the dataset
        tokenized_dataset = processed_dataset.map(
            tokenize_function,
            remove_columns=processed_dataset["train"].column_names,
            desc="Tokenizing dataset",
            batched=True,
        )

        logger.info(f"Dataset processed. Train samples: {len(tokenized_dataset['train'])}")
        if "validation" in tokenized_dataset:
            logger.info(f"Validation samples: {len(tokenized_dataset['validation'])}")

        return tokenized_dataset

    def get_train_dataset(self) -> Dataset:
        """Get training dataset"""
        return self.processed_dataset["train"]

    def get_eval_dataset(self) -> Optional[Dataset]:
        """Get evaluation dataset if available"""
        if "validation" in self.processed_dataset:
            return self.processed_dataset["validation"]
        elif "test" in self.processed_dataset:
            return self.processed_dataset["test"]
        else:
            return None

    def get_data_collator(self):
        """Get data collator for training"""
        from transformers import DataCollatorForLanguageModeling

        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # We're doing causal LM, not masked LM
        )

def create_sample_dataset(output_path: str = "my_dataset"):
    """Create a sample dataset for testing"""
    os.makedirs(output_path, exist_ok=True)

    # Sample conversations
    conversations = [
        {
            "messages": [
                {"role": "user", "content": "What is machine learning?"},
                {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed."}
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "Explain gravity in simple terms."},
                {"role": "assistant", "content": "Gravity is the force that pulls objects toward each other, like how the Earth pulls things down to the ground."}
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "How do I make a cup of coffee?"},
                {"role": "assistant", "content": "To make a cup of coffee: 1) Boil water, 2) Add coffee grounds to a filter, 3) Pour hot water over the grounds, 4) Let it brew for a few minutes, 5) Enjoy!"}
            ]
        }
    ]

    # Split into train/validation
    train_data = conversations[:2]
    validation_data = conversations[2:]

    # Save to files
    with open(os.path.join(output_path, "train.json"), 'w', encoding='utf-8') as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    with open(os.path.join(output_path, "validation.json"), 'w', encoding='utf-8') as f:
        json.dump(validation_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Sample dataset created in {output_path}")
    return output_path
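A minimal usage sketch for the dataset handler above, using the bundled sample-data helper (the model name matches the repository default; tokenization alone does not require a GPU):

```python
from transformers import AutoTokenizer
from data import SmolLM3Dataset, create_sample_dataset

# Build the tiny sample dataset, then format and tokenize it with the SmolLM3 chat template.
path = create_sample_dataset("test_dataset")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B", trust_remote_code=True)

ds = SmolLM3Dataset(path, tokenizer, max_seq_length=4096)
print(len(ds.get_train_dataset()), "training examples")
print(ds.get_train_dataset()[0].keys())
```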
model.py
ADDED
@@ -0,0 +1,188 @@
"""
SmolLM3 Model Wrapper
Handles model loading, tokenizer, and training setup
"""

import os
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
    Trainer
)
from typing import Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)

class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        # Set device and dtype
        if torch_dtype is None:
            if torch.cuda.is_available():
                self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype

        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )

            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def _load_model(self):
        """Load the model"""
        logger.info(f"Loading model from {self.model_name}")
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Update configuration if needed
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True,
                use_flash_attention_2=self.config.use_flash_attention if self.config else True,
                use_cache=False  # Disable KV cache for training
            )

            # Enable gradient checkpointing if specified
            if self.config and self.config.use_gradient_checkpointing:
                self.model.gradient_checkpointing_enable()

            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")

        # Merge config with kwargs
        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "evaluation_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            "ddp_backend": self.config.ddp_backend,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters,
            "report_to": "none",  # Disable external logging
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "group_by_length": True,
            "length_column_name": "length",
            "ignore_data_skip": False,
            "seed": 42,
            "data_seed": 42,
            "dataloader_num_workers": 4,
            "max_grad_norm": 1.0,
            "optim": self.config.optimizer,
            "lr_scheduler_type": self.config.scheduler,
            "warmup_ratio": 0.1,
            "save_strategy": "steps",
            "logging_strategy": "steps",
            "prediction_loss_only": True,
        }

        # Override with kwargs
        training_args.update(kwargs)

        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)

        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        # Save configuration
        if self.config:
            import json
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise
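A brief sketch tying the wrapper above to the training configuration (defaults as defined in this commit; loading the 3B checkpoint needs network access and enough GPU memory):

```python
from config import get_config
from model import SmolLM3Model

cfg = get_config("config/train_smollm3.py")

# Loads tokenizer and model; dtype and device_map are chosen automatically.
wrapper = SmolLM3Model(model_name=cfg.model_name, max_seq_length=cfg.max_seq_length, config=cfg)

# Hugging Face TrainingArguments derived from the same config, with an ad-hoc override.
training_args = wrapper.get_training_arguments(output_dir="./output-checkpoint", logging_steps=5)
print(training_args.max_steps, training_args.learning_rate)
```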
requirements.txt
ADDED
@@ -0,0 +1,35 @@
# Core dependencies
torch>=2.0.0
transformers>=4.53.0
datasets>=2.14.0
accelerate>=0.20.0
trl>=0.7.0

# Hugging Face ecosystem
huggingface-hub>=0.16.0
tokenizers>=0.13.0

# Training and optimization
flash-attn>=2.0.0
xformers>=0.0.20
bitsandbytes>=0.41.0

# Utilities
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
tqdm>=4.65.0
wandb>=0.15.0

# Optional: for evaluation
lighteval>=0.1.0
evaluate>=0.4.0

# Optional: for deployment
vllm>=0.2.0
sentencepiece>=0.1.99

# Development
pytest>=7.0.0
black>=23.0.0
isort>=5.12.0
test_setup.py
ADDED
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Test Setup Script
Verifies that all components are working correctly
"""

import os
import sys
import torch
import logging
from pathlib import Path

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_imports():
    """Test that all required modules can be imported"""
    logger.info("Testing imports...")

    try:
        import transformers
        logger.info(f"✓ transformers {transformers.__version__}")
    except ImportError as e:
        logger.error(f"✗ transformers: {e}")
        return False

    try:
        import datasets
        logger.info(f"✓ datasets {datasets.__version__}")
    except ImportError as e:
        logger.error(f"✗ datasets: {e}")
        return False

    try:
        import trl
        logger.info(f"✓ trl {trl.__version__}")
    except ImportError as e:
        logger.error(f"✗ trl: {e}")
        return False

    try:
        import accelerate
        logger.info(f"✓ accelerate {accelerate.__version__}")
    except ImportError as e:
        logger.error(f"✗ accelerate: {e}")
        return False

    return True

def test_local_imports():
    """Test that local modules can be imported"""
    logger.info("Testing local imports...")

    try:
        from config import get_config
        logger.info("✓ config module")
    except ImportError as e:
        logger.error(f"✗ config module: {e}")
        return False

    try:
        from model import SmolLM3Model
        logger.info("✓ model module")
    except ImportError as e:
        logger.error(f"✗ model module: {e}")
        return False

    try:
        from data import SmolLM3Dataset
        logger.info("✓ data module")
    except ImportError as e:
        logger.error(f"✗ data module: {e}")
        return False

    try:
        from trainer import SmolLM3Trainer
        logger.info("✓ trainer module")
    except ImportError as e:
        logger.error(f"✗ trainer module: {e}")
        return False

    return True

def test_config():
    """Test configuration loading"""
    logger.info("Testing configuration...")

    try:
        from config import get_config
        config = get_config("config/train_smollm3.py")
        logger.info(f"✓ Configuration loaded: {config.model_name}")
        return True
    except Exception as e:
        logger.error(f"✗ Configuration loading failed: {e}")
        return False

def test_dataset_creation():
    """Test dataset creation"""
    logger.info("Testing dataset creation...")

    try:
        from data import create_sample_dataset
        output_path = create_sample_dataset("test_dataset")

        # Check if files were created
        train_file = os.path.join(output_path, "train.json")
        val_file = os.path.join(output_path, "validation.json")

        if os.path.exists(train_file) and os.path.exists(val_file):
            logger.info("✓ Sample dataset created successfully")

            # Clean up
            import shutil
            shutil.rmtree(output_path)
            return True
        else:
            logger.error("✗ Dataset files not created")
            return False
    except Exception as e:
        logger.error(f"✗ Dataset creation failed: {e}")
        return False

def test_gpu_availability():
    """Test GPU availability"""
    logger.info("Testing GPU availability...")

    if torch.cuda.is_available():
        logger.info(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
        logger.info(f"✓ CUDA version: {torch.version.cuda}")
        logger.info(f"✓ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        return True
    else:
        logger.warning("⚠ No GPU available, will use CPU")
        return True

def test_model_loading():
    """Test model loading (without downloading)"""
    logger.info("Testing model loading...")

    try:
        from transformers import AutoTokenizer, AutoConfig

        # Test tokenizer loading
        tokenizer = AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            trust_remote_code=True,
            use_fast=True
        )
        logger.info(f"✓ Tokenizer loaded, vocab size: {tokenizer.vocab_size}")

        # Test config loading
        config = AutoConfig.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            trust_remote_code=True
        )
        logger.info(f"✓ Config loaded, model type: {config.model_type}")

        return True
    except Exception as e:
        logger.error(f"✗ Model loading test failed: {e}")
        return False

def main():
    """Run all tests"""
    logger.info("Starting SmolLM3 setup tests...")

    tests = [
        ("Import Tests", test_imports),
        ("Local Import Tests", test_local_imports),
        ("Configuration Tests", test_config),
        ("Dataset Creation Tests", test_dataset_creation),
        ("GPU Availability Tests", test_gpu_availability),
        ("Model Loading Tests", test_model_loading),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        logger.info(f"\n{'='*50}")
        logger.info(f"Running: {test_name}")
        logger.info('='*50)

        try:
            if test_func():
                passed += 1
                logger.info(f"✓ {test_name} PASSED")
            else:
                logger.error(f"✗ {test_name} FAILED")
        except Exception as e:
            logger.error(f"✗ {test_name} FAILED with exception: {e}")

    logger.info(f"\n{'='*50}")
    logger.info(f"Test Results: {passed}/{total} tests passed")
    logger.info('='*50)

    if passed == total:
        logger.info("🎉 All tests passed! Setup is ready for SmolLM3 fine-tuning.")
        return 0
    else:
        logger.error("❌ Some tests failed. Please check the errors above.")
        return 1

if __name__ == '__main__':
    sys.exit(main())
train.py
ADDED
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
SmolLM3 Fine-tuning Script for FlexAI Console
Based on the nanoGPT structure but adapted for SmolLM3 model
"""

import os
import sys
import argparse
import json
import torch
import logging
from pathlib import Path
from typing import Optional, Dict, Any

# Add the current directory to the path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config import get_config
from model import SmolLM3Model
from data import SmolLM3Dataset
from trainer import SmolLM3Trainer

def setup_logging():
    """Setup logging configuration"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler('training.log')
        ]
    )
    return logging.getLogger(__name__)

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='SmolLM3 Fine-tuning Script')

    # Configuration file
    parser.add_argument('config', type=str, help='Path to configuration file')

    # Dataset arguments
    parser.add_argument('--dataset_dir', type=str, default='my_dataset',
                        help='Path to dataset directory within /input')

    # Checkpoint arguments
    parser.add_argument('--out_dir', type=str, default='/output-checkpoint',
                        help='Output directory for checkpoints')
    parser.add_argument('--init_from', type=str, default='scratch',
                        choices=['scratch', 'resume', 'pretrained'],
                        help='Initialization method')

    # Training arguments
    parser.add_argument('--max_iters', type=int, default=None,
                        help='Maximum number of training iterations')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size for training')
    parser.add_argument('--learning_rate', type=float, default=None,
                        help='Learning rate')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=None,
                        help='Gradient accumulation steps')

    # Model arguments
    parser.add_argument('--model_name', type=str,
                        default='HuggingFaceTB/SmolLM3-3B',
                        help='Model name or path')
    parser.add_argument('--max_seq_length', type=int, default=4096,
                        help='Maximum sequence length')

    # Logging and saving
    parser.add_argument('--save_steps', type=int, default=500,
                        help='Save checkpoint every N steps')
    parser.add_argument('--eval_steps', type=int, default=100,
                        help='Evaluate every N steps')
    parser.add_argument('--logging_steps', type=int, default=10,
                        help='Log every N steps')

    return parser.parse_args()

def main():
    """Main training function"""
    args = parse_args()
    logger = setup_logging()

    logger.info("Starting SmolLM3 fine-tuning...")
    logger.info(f"Arguments: {vars(args)}")

    # Load configuration
    config = get_config(args.config)

    # Override config with command line arguments
    if args.max_iters is not None:
        config.max_iters = args.max_iters
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.learning_rate = args.learning_rate
    if args.gradient_accumulation_steps is not None:
        config.gradient_accumulation_steps = args.gradient_accumulation_steps

    # Setup paths
    dataset_path = os.path.join('/input', args.dataset_dir)
    output_path = args.out_dir

    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    logger.info(f"Dataset path: {dataset_path}")
    logger.info(f"Output path: {output_path}")

    # Initialize model
    model = SmolLM3Model(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        config=config
    )

    # Load dataset
| 120 |
+
dataset = SmolLM3Dataset(
|
| 121 |
+
data_path=dataset_path,
|
| 122 |
+
tokenizer=model.tokenizer,
|
| 123 |
+
max_seq_length=args.max_seq_length
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Initialize trainer
|
| 127 |
+
trainer = SmolLM3Trainer(
|
| 128 |
+
model=model,
|
| 129 |
+
dataset=dataset,
|
| 130 |
+
config=config,
|
| 131 |
+
output_dir=output_path,
|
| 132 |
+
init_from=args.init_from
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Start training
|
| 136 |
+
try:
|
| 137 |
+
trainer.train()
|
| 138 |
+
logger.info("Training completed successfully!")
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Training failed: {e}")
|
| 141 |
+
raise
|
| 142 |
+
|
| 143 |
+
if __name__ == '__main__':
|
| 144 |
+
main()
|
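For reference, parse_args() in the script above takes the configuration file as its only positional argument and treats everything else as an optional override of that config, with data read from the /input mount and checkpoints written to /output-checkpoint by default. A typical invocation would look like `python train.py config/train_smollm3.py --dataset_dir my_dataset --max_iters 1000 --batch_size 2`; the override values here are illustrative, not defaults taken from this commit.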
trainer.py
ADDED
@@ -0,0 +1,242 @@
"""
SmolLM3 Trainer
Handles the training loop and integrates with Hugging Face Trainer
"""

import os
import torch
import logging
from typing import Optional, Dict, Any
from transformers import Trainer, TrainingArguments
from trl import SFTTrainer
import json

logger = logging.getLogger(__name__)

class SmolLM3Trainer:
    """Trainer for SmolLM3 fine-tuning"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        init_from: str = "scratch",
        use_sft_trainer: bool = True
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.init_from = init_from
        self.use_sft_trainer = use_sft_trainer

        # Setup trainer
        self.trainer = self._setup_trainer()

    def _setup_trainer(self):
        """Setup the trainer"""
        logger.info("Setting up trainer")

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get datasets
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Get data collator
        data_collator = self.dataset.get_data_collator()

        if self.use_sft_trainer:
            # Use SFTTrainer for supervised fine-tuning
            trainer = SFTTrainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=training_args,
                data_collator=data_collator,
                dataset_text_field="text",
                max_seq_length=self.config.max_seq_length,
                packing=False,  # Disable packing for better control
            )
        else:
            # Use standard Trainer
            trainer = Trainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
            )

        return trainer

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint for resuming training"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")

        if self.init_from == "resume":
            # Load the model from checkpoint
            self.model.load_checkpoint(checkpoint_path)

            # Update trainer with loaded model
            self.trainer.model = self.model.model

            logger.info("Checkpoint loaded successfully")
        elif self.init_from == "pretrained":
            # Model is already loaded from pretrained
            logger.info("Using pretrained model")
        else:
            logger.info("Starting from scratch")

    def train(self):
        """Start training"""
        logger.info("Starting training")

        # Load checkpoint if resuming
        if self.init_from == "resume":
            checkpoint_path = "/input-checkpoint"
            if os.path.exists(checkpoint_path):
                self.load_checkpoint(checkpoint_path)
            else:
                logger.warning(f"Checkpoint path {checkpoint_path} not found, starting from scratch")

        # Start training
        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def evaluate(self):
        """Evaluate the model"""
        logger.info("Starting evaluation")

        try:
            eval_results = self.trainer.evaluate()

            # Save evaluation results
            with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f:
                json.dump(eval_results, f, indent=2)

            logger.info(f"Evaluation completed: {eval_results}")
            return eval_results

        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            raise

    def save_model(self, path: Optional[str] = None):
        """Save the trained model"""
        save_path = path or self.output_dir
        logger.info(f"Saving model to {save_path}")

        try:
            self.trainer.save_model(save_path)
            self.model.tokenizer.save_pretrained(save_path)

            # Save training configuration
            if self.config:
                config_dict = {k: v for k, v in self.config.__dict__.items()
                               if not k.startswith('_')}
                with open(os.path.join(save_path, 'training_config.json'), 'w') as f:
                    json.dump(config_dict, f, indent=2, default=str)

            logger.info("Model saved successfully!")

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

class SmolLM3DPOTrainer:
    """DPO Trainer for SmolLM3 preference optimization"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        ref_model=None
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.ref_model = ref_model

        # Setup DPO trainer
        self.trainer = self._setup_dpo_trainer()

    def _setup_dpo_trainer(self):
        """Setup DPO trainer"""
        from trl import DPOTrainer

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get preference dataset
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Setup DPO trainer
        trainer = DPOTrainer(
            model=self.model.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.model.tokenizer,
            max_prompt_length=self.config.max_seq_length // 2,
            max_length=self.config.max_seq_length,
        )

        return trainer

    def train(self):
        """Start DPO training"""
        logger.info("Starting DPO training")

        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "dpo_train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("DPO training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.error(f"DPO training failed: {e}")
            raise
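Note that train.py above only wires up SmolLM3Trainer; nothing in this commit drives SmolLM3DPOTrainer directly. A minimal sketch of how it could be invoked, under the loud assumptions that SmolLM3Model and SmolLM3Dataset accept the same arguments they do in train.py and that the dataset at the made-up path below already carries preference pairs:

# Hypothetical DPO driver; not part of this commit.
from config import get_config
from model import SmolLM3Model
from data import SmolLM3Dataset
from trainer import SmolLM3DPOTrainer

# config/train_smollm3_dpo.py is one of the configs added in this commit.
config = get_config("config/train_smollm3_dpo.py")

model = SmolLM3Model(
    model_name="HuggingFaceTB/SmolLM3-3B",
    max_seq_length=config.max_seq_length,
    config=config,
)
dataset = SmolLM3Dataset(
    data_path="/input/my_preference_dataset",  # hypothetical mount path
    tokenizer=model.tokenizer,
    max_seq_length=config.max_seq_length,
)

dpo_trainer = SmolLM3DPOTrainer(
    model=model,
    dataset=dataset,
    config=config,
    output_dir="/output-checkpoint",
    ref_model=None,  # class default; reference-model handling is left to trl's DPOTrainer
)
dpo_trainer.train()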