MojicaPoC / supervised_classifier.py
Carlos Isael Ramírez González
Modelo nuevo completado
56ff037
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
Trainer,
TrainingArguments,
)
from config import Config
import json
from datasets import Dataset
import torch
class QuestionClassifier:
    """Supervised question classifier built on a sequence-classification model.

    The classifier is fine-tuned on labelled example questions loaded from a
    JSON file and then maps a free-text question to one of the category names
    found in that file.
    """

    def __init__(
        self, model_name="distilbert-base-multilingual-cased", initialized_train=True
    ):
        """Load the tokenizer and, optionally, train straight away.

        Args:
            model_name: Checkpoint name passed to ``from_pretrained`` for both
                the tokenizer and the classification model.
            initialized_train: When true, ``train()`` is called with its
                default arguments before the constructor returns.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model_name = model_name
        # Populated by train(); kept as None until then so that an untrained
        # instance can be detected instead of failing on a missing attribute.
        self.category2id = None
        self.id2category = None
        self.model = None
        if initialized_train:
            self.train()

    def train(self, json_path=Config.EXMAPLES_JSON, num_epochs=3) -> None:
        """Fine-tune a fresh classification model on the examples in a JSON file.

        Args:
            json_path: Path to a UTF-8 JSON file holding an object that maps
                each category name to a list of items, every item carrying the
                question text under the ``"pregunta"`` key.
            num_epochs: Number of training epochs handed to the trainer.

        Raises:
            ValueError: If the file contains no example questions.
        """
        # Load the labelled examples.
        with open(json_path, "r", encoding="utf-8") as f:
            examples = json.load(f)

        texts, labels, category2id = self._prepare_supervised_data(examples)
        if not texts:
            raise ValueError(f"No training examples found in {json_path!r}")

        self.category2id = category2id
        self.id2category = {value: key for key, value in category2id.items()}

        # A new model is created on every call so the size of the
        # classification head always matches the categories just loaded.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, num_labels=len(category2id)
        )

        encodings = self.tokenizer(texts, truncation=True, padding=True)
        dataset = Dataset.from_dict(
            {
                "input_ids": encodings["input_ids"],
                "attention_mask": encodings["attention_mask"],
                "labels": labels,
            }
        )

        training_args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=8,
            num_train_epochs=num_epochs,
            logging_steps=1,
            report_to="none",
            save_strategy="no",
            remove_unused_columns=False,
            eval_strategy="no",
        )

        trainer = Trainer(model=self.model, args=training_args, train_dataset=dataset)
        trainer.train()

    def _prepare_supervised_data(self, examples):
        """Flatten the category -> items mapping into parallel training lists.

        Args:
            examples: Mapping of category name to a list of items; each item
                must provide the question text under ``"pregunta"``.

        Returns:
            A ``(texts, labels, category2id)`` tuple where ``texts`` are the
            questions, ``labels`` the matching integer ids and ``category2id``
            maps each category name to its id (assigned in iteration order).
        """
        category2id = {cat: i for i, cat in enumerate(examples)}
        texts = []
        labels = []
        for category, items in examples.items():
            label = category2id[category]
            for item in items:
                texts.append(item["pregunta"])
                labels.append(label)
        return texts, labels, category2id

    def predict(self, question: str) -> str:
        """Return the category name predicted for ``question``.

        Args:
            question: The question text to classify.

        Returns:
            The category name whose logit is the highest.

        Raises:
            RuntimeError: If the classifier has not been trained yet.
        """
        if self.model is None or self.id2category is None:
            raise RuntimeError(
                "QuestionClassifier is not trained; call train() before predict()"
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        # Training can leave the model in train mode; switch to evaluation
        # mode so layers such as dropout do not randomise the prediction.
        self.model.eval()

        inputs = self.tokenizer(
            question, return_tensors="pt", truncation=True, padding=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Reduce over the class axis explicitly rather than over the
        # flattened logits tensor.
        predicted_class_id = outputs.logits.argmax(dim=-1).item()
        return self.id2category[predicted_class_id]
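# * HOW TO USE
# qc = QuestionClassifier()  # already trains on construction (initialized_train=True by default)
# qc.train()  # optional: only needed to retrain, or when built with initialized_train=False
# categoria = qc.predict("Dame los productos más vendidos")
# print(categoria) # → 'PRODUCTOS'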