# Source: llmguard / model/train_classifier.py
# Last commit 6826247 by Tuathe: "Clean repo without large checkpoint files"
# model/train_classifier.py
import inspect

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    Trainer,
    TrainingArguments,
)
# Read the labelled prompt CSV. The single file is exposed as the "train" split,
# which is split further into train/test below.
DATA_FILE = "data/cleaned_injection_prompts.csv"
dataset = load_dataset("csv", data_files=DATA_FILE)
# Convert string labels to integers (0 for safe, 1 for injection)
def encode_labels(example):
    """Cast the row's ``label`` field to ``int`` in place and return the row.

    The classifier head is created with ``num_labels=2``, so only 0 (safe) and
    1 (injection) are valid targets. Any other value is rejected here with a
    clear message instead of surfacing later as an out-of-range target error
    in the middle of training.

    Raises:
        ValueError: if the label is not 0 or 1 (``int()`` itself raises for
            values that cannot be parsed as an integer).
    """
    label = int(example["label"])
    if label not in (0, 1):
        raise ValueError(
            f"label must be 0 (safe) or 1 (injection), got {example['label']!r}"
        )
    example["label"] = label
    return example
dataset = dataset.map(encode_labels)
# Split into train and test (90% / 10%). A fixed seed keeps the held-out set
# identical across runs. Eval accuracy selects the best checkpoint, so without
# a seed results would not be comparable between training runs.
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
# Tokenize inputs
def tokenize(example):
    """Tokenize one batch of rows passed by ``dataset.map(..., batched=True)``.

    ``example["text"]`` holds the prompt strings of the batch. Sequences are
    truncated to the tokenizer's maximum length but deliberately NOT padded
    here. The Trainer is given the tokenizer, so its default collator pads each
    training batch only to that batch's longest sequence. Padding every row to
    ``max_length`` (512 for DistilBERT) wasted memory and compute on short
    prompts.
    """
    return tokenizer(example["text"], truncation=True)
# Tokenize both the train and test splits; `tokenize` receives rows in batches.
tokenized_dataset = dataset.map(function=tokenize, batched=True)
# Pretrained DistilBERT encoder topped with a new two-output classification head.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
)
# Metrics
def compute_metrics(eval_pred):
    """Compute evaluation metrics from the Trainer's predictions.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. Logits are
            assumed to be per-class scores with the class axis last, and labels
            the matching integer targets (0 = safe, 1 = injection).

    Returns:
        dict with ``accuracy`` (the key ``metric_for_best_model`` refers to)
        plus ``precision``, ``recall`` and ``f1`` for the injection class
        (label 1). On an imbalanced set, accuracy alone can look high while most
        injections are missed, so the positive-class scores are reported too.
        Any ratio with a zero denominator, and every metric for an empty
        evaluation set, is reported as 0.0.
    """
    logits, labels = eval_pred
    # When the model returns extra outputs, the predictions arrive as a tuple
    # whose first element holds the class logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = np.asarray(labels)
    if labels.size == 0:
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}

    predictions = np.argmax(logits, axis=-1)
    predicted_pos = predictions == 1
    actual_pos = labels == 1
    tp = int(np.sum(predicted_pos & actual_pos))
    fp = int(np.sum(predicted_pos & ~actual_pos))
    fn = int(np.sum(~predicted_pos & actual_pos))

    def _ratio(num, den):
        # Undefined ratios (nothing predicted / nothing to find) count as 0.0.
        return num / den if den else 0.0

    return {
        "accuracy": float(np.mean(predictions == labels)),
        "precision": _ratio(tp, tp + fp),
        "recall": _ratio(tp, tp + fn),
        "f1": _ratio(2 * tp, 2 * tp + fp + fn),
    }
# Training arguments
# Checkpoints and the final model both live here.
OUTPUT_DIR = "model/injection_classifier"

# transformers renamed TrainingArguments' `evaluation_strategy` to
# `eval_strategy` (the old name was deprecated, then removed) and Trainer's
# `tokenizer` argument to `processing_class`. Pass whichever name the installed
# release accepts so the script runs on both older and newer versions.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)
_tokenizer_key = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=4,
    weight_decay=0.01,
    save_total_limit=1,
    # Reload the checkpoint with the best eval accuracy when training ends.
    # This requires the save and evaluation strategies to match (both "epoch").
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    **{_eval_strategy_key: "epoch"},
)
# Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
    **{_tokenizer_key: tokenizer},
)
# Train the model
trainer.train()
# Save the model (the best checkpoint, because of load_best_model_at_end) and tokenizer
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)