"""
SmolLM3 Model Wrapper
Handles model loading, tokenizer, and training setup
"""

import os
import json
import logging
from typing import Optional, Any

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
)

logger = logging.getLogger(__name__)

class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""
    
    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config
        
        # Pick a default dtype: bf16 on GPUs with compute capability >= 8.0
        # (Ampere and newer), fp16 on older GPUs, fp32 on CPU
        if torch_dtype is None:
            if torch.cuda.is_available():
                major, _ = torch.cuda.get_device_capability()
                self.torch_dtype = torch.bfloat16 if major >= 8 else torch.float16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype
            
        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map
        
        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()
        
    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )
            
            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")
            
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise
    
    def _load_model(self):
        """Load the model"""
        logger.info(f"Loading model from {self.model_name}")
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            
            # Align the model's position-embedding limit with max_seq_length
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length
            
            # Load model
            model_kwargs = {
                "torch_dtype": self.torch_dtype,
                "device_map": self.device_map,
                "trust_remote_code": True
            }
            
            # Request FlashAttention-2 when the config enables it. The kernel
            # requires the flash-attn package, so fall back to the default
            # attention implementation when it is not installed.
            if getattr(self.config, 'use_flash_attention', False):
                try:
                    import flash_attn  # noqa: F401
                    model_kwargs["attn_implementation"] = "flash_attention_2"
                except ImportError:
                    logger.warning(
                        "flash-attn is not installed; using default attention"
                    )
            
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                **model_kwargs
            )
            
            # Enable gradient checkpointing if specified
            if self.config and getattr(self.config, 'use_gradient_checkpointing', False):
                self.model.gradient_checkpointing_enable()
            
            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")
            
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    
    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")
        
        # Merge config with kwargs
        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "eval_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            "ddp_backend": self.config.ddp_backend,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters,
            "report_to": "none",  # Disable external logging
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "group_by_length": True,
            "length_column_name": "length",
            "ignore_data_skip": False,
            "seed": 42,
            "data_seed": 42,
            "dataloader_num_workers": 4,
            "max_grad_norm": 1.0,
            "optim": self.config.optimizer,
            "lr_scheduler_type": self.config.scheduler,
            "warmup_ratio": 0.1,
            "save_strategy": "steps",
            "logging_strategy": "steps",
            "prediction_loss_only": True,
        }
        
        # Override with kwargs
        training_args.update(kwargs)
        
        return TrainingArguments(**training_args)
    
    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)
        
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        
        # Save configuration
        if self.config:
            config_dict = {k: v for k, v in self.config.__dict__.items() 
                          if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)
    
    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise
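
# Minimal end-to-end sketch (assumptions: `my_config` exposes the attributes
# read in get_training_arguments, and `train_dataset` is already tokenized):
#
#     from transformers import Trainer
#
#     wrapper = SmolLM3Model(config=my_config)
#     trainer = Trainer(
#         model=wrapper.model,
#         args=wrapper.get_training_arguments("./checkpoints"),
#         train_dataset=train_dataset,
#         tokenizer=wrapper.tokenizer,
#     )
#     trainer.train()
#     wrapper.save_pretrained("./final_model")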