""" | |
SmolLM3 Dataset Handler | |
Handles data loading, preprocessing, and tokenization for SmolLM3 fine-tuning | |
""" | |
import os
import json
import logging
from typing import Dict, Optional, Union

from datasets import Dataset, DatasetDict, load_dataset
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


class SmolLM3Dataset:
    """Dataset handler for SmolLM3 fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int = 4096,
        use_chat_template: bool = True,
        chat_template_kwargs: Optional[Dict] = None,
        filter_bad_entries: bool = False,
        bad_entry_field: str = "bad_entry",
    ):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.use_chat_template = use_chat_template
        self.chat_template_kwargs = chat_template_kwargs or {}
        self.filter_bad_entries = filter_bad_entries
        self.bad_entry_field = bad_entry_field

        # Load and process the dataset eagerly so failures surface at construction time
        self.dataset = self._load_dataset()
        self.processed_dataset = self._process_dataset()

    def _load_dataset(self) -> Union[Dataset, DatasetDict]:
        """Load a dataset from a local directory, a single JSON file, or the Hub."""
        logger.info(f"Loading dataset from {self.data_path}")

        if os.path.isdir(self.data_path):
            # Local directory: train.json is required, validation/test are optional.
            # Only existing files go into data_files; load_dataset rejects None values.
            try:
                data_files = {"train": os.path.join(self.data_path, "train.json")}
                for split in ("validation", "test"):
                    split_path = os.path.join(self.data_path, f"{split}.json")
                    if os.path.exists(split_path):
                        data_files[split] = split_path
                dataset = load_dataset("json", data_files=data_files)
                logger.info("Loaded dataset from local JSON files")
                return dataset
            except Exception as e:
                logger.warning(f"Failed to load as JSON dataset: {e}")

        # Try to load as a single JSON file
        if os.path.isfile(self.data_path) and self.data_path.endswith(".json"):
            try:
                with open(self.data_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Convert to dataset format: a JSON array becomes rows, an object becomes columns
                if isinstance(data, list):
                    dataset = Dataset.from_list(data)
                else:
                    dataset = Dataset.from_dict(data)
                logger.info("Loaded dataset from single JSON file")
                return dataset
            except Exception as e:
                logger.error(f"Failed to load JSON file: {e}")
                raise

        # Otherwise, treat the path as a Hugging Face Hub dataset name
        try:
            dataset = load_dataset(self.data_path)
            logger.info(f"Loaded Hugging Face dataset: {self.data_path}")

            # Filter bad entries if requested; check each split for the field
            if self.filter_bad_entries:
                logger.info(f"Filtering out bad entries using field: {self.bad_entry_field}")
                for split in dataset:
                    if self.bad_entry_field in dataset[split].column_names:
                        original_size = len(dataset[split])
                        dataset[split] = dataset[split].filter(lambda x: not x[self.bad_entry_field])
                        filtered_size = len(dataset[split])
                        logger.info(f"Filtered {split}: {original_size} -> {filtered_size} samples")

            # If only a 'train' split exists, carve out validation and test splits
            if "train" in dataset and ("validation" not in dataset or "test" not in dataset):
                logger.info("Automatically splitting train into train/validation/test (98/1/1)")
                split_dataset = dataset["train"].train_test_split(test_size=0.02, seed=42)
                # Split the held-out 2% evenly into validation and test (1% each)
                val_test_split = split_dataset["test"].train_test_split(test_size=0.5, seed=42)
                dataset = DatasetDict({
                    "train": split_dataset["train"],
                    "validation": val_test_split["train"],
                    "test": val_test_split["test"],
                })
            return dataset
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _process_dataset(self) -> Union[Dataset, DatasetDict]:
        """Format and tokenize the dataset for training."""
        logger.info("Processing dataset for training")

        def format_chat_template(example):
            """Format a single example into a chat-templated text string."""
            if self.use_chat_template:
                try:
                    # Normalize the supported input schemas into a messages list
                    if "messages" in example:
                        messages = example["messages"]
                    elif "conversations" in example:
                        messages = example["conversations"]
                    elif "user" in example and "assistant" in example:
                        messages = [
                            {"role": "user", "content": example["user"]},
                            {"role": "assistant", "content": example["assistant"]},
                        ]
                    elif "instruction" in example and "output" in example:
                        messages = [
                            {"role": "user", "content": example["instruction"]},
                            {"role": "assistant", "content": example["output"]},
                        ]
                    elif "prompt" in example and "completion" in example:
                        messages = [
                            {"role": "user", "content": example["prompt"]},
                            {"role": "assistant", "content": example["completion"]},
                        ]
                    elif "prompt" in example and "accepted_completion" in example:
                        messages = [
                            {"role": "user", "content": example["prompt"]},
                            {"role": "assistant", "content": example["accepted_completion"]},
                        ]
                    else:
                        # Fallback: treat the whole example as plain text
                        return {"text": str(example)}

                    # Prepend a system message if the conversation lacks one
                    if messages and messages[0]["role"] != "system":
                        system_content = "You are a helpful assistant."
                        # Optionally append SmolLM3's /no_think tag to disable reasoning traces
                        if self.chat_template_kwargs.get("no_think_system_message"):
                            system_content = "You are a helpful assistant. /no_think"
                        messages.insert(0, {"role": "system", "content": system_content})

                    # Apply the tokenizer's chat template. Training conversations
                    # already end with an assistant turn, so no generation prompt
                    # is appended by default.
                    text = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=self.chat_template_kwargs.get("add_generation_prompt", False),
                    )
                    return {"text": text}
                except Exception as e:
                    logger.warning(f"Failed to apply chat template: {e}")
                    # Fall back to plain text
                    return {"text": str(example)}
            else:
                # Use plain text
                if "text" in example:
                    return {"text": example["text"]}
                else:
                    return {"text": str(example)}

        def tokenize_function(examples):
            """Tokenize a batch of formatted texts."""
            # Truncate to max_seq_length but do not pad here: the data collator
            # pads each training batch dynamically and builds the causal-LM
            # labels, masking padding positions with -100. Padding here and
            # copying input_ids into labels would compute loss on pad tokens.
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.max_seq_length,
                return_overflowing_tokens=False,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "attention_mask": tokenized["attention_mask"],
                "length": [len(ids) for ids in tokenized["input_ids"]],
            }

        # Process each split when we have a DatasetDict, otherwise the single dataset
        if isinstance(self.dataset, dict):
            processed_dataset = DatasetDict()
            for split_name, split_dataset in self.dataset.items():
                logger.info(f"Processing {split_name} split...")
                # Format the split
                processed_split = split_dataset.map(
                    format_chat_template,
                    remove_columns=split_dataset.column_names,
                    desc=f"Formatting {split_name} dataset",
                )
                # Tokenize the split
                tokenized_split = processed_split.map(
                    tokenize_function,
                    remove_columns=processed_split.column_names,
                    desc=f"Tokenizing {split_name} dataset",
                    batched=True,
                )
                processed_dataset[split_name] = tokenized_split
        else:
            # Single dataset without named splits
            processed_dataset = self.dataset.map(
                format_chat_template,
                remove_columns=self.dataset.column_names,
                desc="Formatting dataset",
            )
            processed_dataset = processed_dataset.map(
                tokenize_function,
                remove_columns=processed_dataset.column_names,
                desc="Tokenizing dataset",
                batched=True,
            )

        # Log processing results
        if isinstance(processed_dataset, dict):
            logger.info(f"Dataset processed. Train samples: {len(processed_dataset['train'])}")
            if "validation" in processed_dataset:
                logger.info(f"Validation samples: {len(processed_dataset['validation'])}")
            if "test" in processed_dataset:
                logger.info(f"Test samples: {len(processed_dataset['test'])}")
        else:
            logger.info(f"Dataset processed. Samples: {len(processed_dataset)}")
        return processed_dataset

    def get_train_dataset(self) -> Dataset:
        """Get the training dataset."""
        if isinstance(self.processed_dataset, Dataset):
            # No named splits: the whole dataset is used for training
            return self.processed_dataset
        return self.processed_dataset["train"]

    def get_eval_dataset(self) -> Optional[Dataset]:
        """Get the evaluation dataset, if one is available."""
        if isinstance(self.processed_dataset, Dataset):
            return None
        if "validation" in self.processed_dataset:
            return self.processed_dataset["validation"]
        if "test" in self.processed_dataset:
            return self.processed_dataset["test"]
        return None

    def get_data_collator(self):
        """Get the data collator for training."""
        from transformers import DataCollatorForLanguageModeling

        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # Causal LM, not masked LM
            pad_to_multiple_of=8,  # Pad to a multiple of 8 for tensor-core efficiency
            return_tensors="pt",  # Return PyTorch tensors
        )
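

# The helper below is a sketch added for illustration, not part of the original
# API: it shows how the handler's three accessors map onto transformers.Trainer
# arguments. The TrainingArguments values are placeholder assumptions, not
# tuned settings.
def build_trainer(handler: SmolLM3Dataset, model, output_dir: str = "out"):
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        num_train_epochs=1,
        logging_steps=1,
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=handler.get_train_dataset(),
        eval_dataset=handler.get_eval_dataset(),
        data_collator=handler.get_data_collator(),
    )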


def create_sample_dataset(output_path: str = "my_dataset"):
    """Create a small sample dataset for testing."""
    os.makedirs(output_path, exist_ok=True)

    # Sample conversations in the "messages" schema
    conversations = [
        {
            "messages": [
                {"role": "user", "content": "What is machine learning?"},
                {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed."},
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "Explain gravity in simple terms."},
                {"role": "assistant", "content": "Gravity is the force that pulls objects toward each other, like how the Earth pulls things down to the ground."},
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "How do I make a cup of coffee?"},
                {"role": "assistant", "content": "To make a cup of coffee: 1) Boil water, 2) Add coffee grounds to a filter, 3) Pour hot water over the grounds, 4) Let it brew for a few minutes, 5) Enjoy!"},
            ]
        },
    ]

    # Split into train/validation
    train_data = conversations[:2]
    validation_data = conversations[2:]

    # Save to files
    with open(os.path.join(output_path, "train.json"), "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)
    with open(os.path.join(output_path, "validation.json"), "w", encoding="utf-8") as f:
        json.dump(validation_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Sample dataset created in {output_path}")
    return output_path
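

if __name__ == "__main__":
    # Minimal smoke test, a sketch of intended use: build the sample dataset
    # above and run it through the handler. The checkpoint name below is an
    # assumption; any tokenizer that ships a chat template should work.
    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    sample_path = create_sample_dataset("my_dataset")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
    handler = SmolLM3Dataset(sample_path, tokenizer, max_seq_length=1024)
    print(f"Train samples: {len(handler.get_train_dataset())}")
    eval_ds = handler.get_eval_dataset()
    print(f"Eval samples: {len(eval_ds) if eval_ds is not None else 0}")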