import json

import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from config import Config


class QuestionClassifier:
    """Fine-tune a sequence-classification model that maps questions to categories.

    Training data is a JSON object whose keys are category names and whose
    values are lists of objects carrying the question text under the
    ``"pregunta"`` key (see ``_prepare_supervised_data``).
    """

    def __init__(
        self, model_name="distilbert-base-multilingual-cased", initialized_train=True
    ):
        """Load the tokenizer and optionally train immediately.

        Args:
            model_name: Hugging Face model identifier used for both the
                tokenizer and the sequence-classification model.
            initialized_train: If True, call ``train()`` with its default
                arguments before returning.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model_name = model_name
        # All three are populated by train(); predict() refuses to run before then.
        self.model = None
        self.category2id = None
        self.id2category = None
        if initialized_train:
            self.train()

    # NOTE(review): ``EXMAPLES_JSON`` looks like a typo of ``EXAMPLES_JSON``,
    # but it must match the attribute name defined in ``config`` — rename both
    # together if at all. The default is evaluated once, at class definition.
    def train(self, json_path=Config.EXMAPLES_JSON, num_epochs=3):
        """Fine-tune a freshly initialised classification model on the examples.

        Args:
            json_path: Path to the UTF-8 JSON file of labelled examples.
            num_epochs: Number of training epochs.

        Raises:
            ValueError: If the file contains no training examples.
        """
        # Load the labelled examples.
        with open(json_path, "r", encoding="utf-8") as f:
            examples = json.load(f)

        texts, labels, category2id = self._prepare_supervised_data(examples)
        if not texts:
            # Fail clearly instead of building a model with no usable data.
            raise ValueError(f"No training examples found in {json_path}")

        self.category2id = category2id
        self.id2category = {value: key for key, value in category2id.items()}

        # A new head sized to the current category set; any previous model is replaced.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, num_labels=len(category2id)
        )

        # Pad the whole corpus to its longest sequence so every row has equal length.
        encodings = self.tokenizer(texts, truncation=True, padding=True)
        dataset = Dataset.from_dict(
            {
                "input_ids": encodings["input_ids"],
                "attention_mask": encodings["attention_mask"],
                "labels": labels,
            }
        )

        training_args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=8,
            num_train_epochs=num_epochs,
            logging_steps=1,
            report_to="none",
            save_strategy="no",
            remove_unused_columns=False,
            eval_strategy="no",
        )

        # Build the Trainer and run training.
        trainer = Trainer(model=self.model, args=training_args, train_dataset=dataset)
        trainer.train()

    def _prepare_supervised_data(self, examples):
        """Flatten ``{category: [{"pregunta": text}, ...]}`` into parallel lists.

        Category ids follow the key order of ``examples``.

        Returns:
            A tuple ``(texts, labels, category2id)``. ``texts`` holds the
            question strings, ``labels`` their integer category ids, and
            ``category2id`` maps each category name to its id.
        """
        category2id = {cat: i for i, cat in enumerate(examples.keys())}
        texts = []
        labels = []
        for category, items in examples.items():
            for item in items:
                texts.append(item["pregunta"])
                labels.append(category2id[category])
        return texts, labels, category2id

    def predict(self, question: str):
        """Return the predicted category name for a single question.

        Raises:
            RuntimeError: If called before ``train()`` has completed.
        """
        if self.model is None or self.id2category is None:
            raise RuntimeError("QuestionClassifier.predict() called before train()")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        # Trainer leaves the model in training mode. Disable dropout so that
        # predictions are deterministic.
        self.model.eval()

        inputs = self.tokenizer(
            question, return_tensors="pt", truncation=True, padding=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Argmax over the class axis of the single-row batch.
        predicted_class_id = logits.argmax(dim=-1)[0].item()
        return self.id2category[predicted_class_id]


# * USAGE
# qc = QuestionClassifier()  # trains on construction by default
# # or: qc = QuestionClassifier(initialized_train=False); qc.train()
# categoria = qc.predict("Dame los productos más vendidos")
# print(categoria)  # → 'PRODUCTOS'