# SmolFactory / model.py
"""
SmolLM3 Model Wrapper
Handles model loading, tokenizer, and training setup
"""
import os
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
    Trainer
)
from typing import Optional, Dict, Any
import logging
logger = logging.getLogger(__name__)


class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        # Set device and dtype
        if torch_dtype is None:
            if torch.cuda.is_available():
                # Prefer bf16 on GPUs with compute capability >= 8 (Ampere and newer), fp16 otherwise
                self.torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype
        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )
            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def _load_model(self):
        """Load the model"""
        logger.info(f"Loading model from {self.model_name}")
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            # Update configuration if needed
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length
            # Load model
            model_kwargs = {
                "torch_dtype": self.torch_dtype,
                "device_map": self.device_map,
                "trust_remote_code": True,
                "use_cache": False  # Disable KV cache for training
            }
            # Only request flash attention if the training config asks for it
            # and the model's own configuration exposes the flag
            if hasattr(self.config, 'use_flash_attention') and self.config.use_flash_attention:
                if hasattr(model_config, 'use_flash_attention_2'):
                    model_kwargs["use_flash_attention_2"] = True
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                **model_kwargs
            )
            # Enable gradient checkpointing if specified
            if self.config and getattr(self.config, 'use_gradient_checkpointing', False):
                self.model.gradient_checkpointing_enable()
            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")
        # Merge config with kwargs
        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "evaluation_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            "ddp_backend": self.config.ddp_backend,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters,
            "report_to": "none",  # Disable external logging
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "group_by_length": True,
            "length_column_name": "length",
            "ignore_data_skip": False,
            "seed": 42,
            "data_seed": 42,
            "dataloader_num_workers": 4,
            "max_grad_norm": 1.0,
            "optim": self.config.optimizer,
            "lr_scheduler_type": self.config.scheduler,
            "warmup_ratio": 0.1,
            "save_strategy": "steps",
            "logging_strategy": "steps",
            "prediction_loss_only": True,
        }
        # Override with kwargs
        training_args.update(kwargs)
        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        # Save configuration
        if self.config:
            import json
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise
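

# --- Usage sketch (illustrative only, not part of the training pipeline) ---
# A minimal example of how this wrapper is typically driven. The SimpleNamespace
# below is a hypothetical stand-in for the project's real training config: its
# attribute names mirror what _load_model() and get_training_arguments() read,
# but every value here is an assumption for illustration. Note that constructing
# SmolLM3Model downloads/loads the full checkpoint.
if __name__ == "__main__":
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    demo_config = SimpleNamespace(
        batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=100,
        max_iters=1000,
        save_steps=500,
        eval_steps=100,
        logging_steps=10,
        save_total_limit=2,
        eval_strategy="steps",
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        fp16=False,
        bf16=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=False,
        optimizer="adamw_torch",
        scheduler="cosine",
        use_flash_attention=False,
        use_gradient_checkpointing=True,
    )

    wrapper = SmolLM3Model(
        model_name="HuggingFaceTB/SmolLM3-3B",
        max_seq_length=4096,
        config=demo_config,
    )
    training_args = wrapper.get_training_arguments(output_dir="./smollm3-finetune")
    logger.info(f"Prepared TrainingArguments for {training_args.max_steps} steps")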