MojicaPoC / semantic_classifier.py
Carlos Isael Ramírez González
Modelo nuevo completado
56ff037
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from config import Config
from load_json import load_examples
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
class SemanticClassifier:
def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", initialized_train=True):
self.model = SentenceTransformer(model_name)
self.clusters = {}
self.examples_embeddings = None
self.kmeans = None
if initialized_train:
self.train()
def train(self, train_data=Config.EXMAPLES_JSON, n_clusters=15):
examples = load_examples(train_data)
# * Aplanar ejemplos
flat_examples = []
for category, items in examples.items():
for item in items:
flat_examples.append({
"category": category,
"pregunta": item["pregunta"],
"query": item["query"]
})
questions = [ex["pregunta"] for ex in flat_examples]
# * Obtener embeddings
embeddings = self.model.encode(questions)
# * Clustering
self.kmeans = KMeans(n_clusters=n_clusters, random_state=12)
cluster_ids = self.kmeans.fit_predict(embeddings)
# * Guardar ejemplos por cluster
for i, cluster_id in enumerate(cluster_ids):
# * Crear lista si no existe
if cluster_id not in self.clusters:
self.clusters[cluster_id] = []
# * Agregamos el ejemplo
self.clusters[cluster_id].append(flat_examples[i])
self.examples_embeddings = embeddings
def classify(self, question: str):
# * En formato de embedding
question_embedding = self.model.encode([question])
# * Encontrar el cluster más cercano
cluster_id = self.kmeans.predict(question_embedding)[0]
# * Retornamos los ejemplos de ese cluster
return self.clusters.get(cluster_id, [])
# * FORMA DE USARSE
# classifier = SemanticClassifier()
# classifier.train(Config.EXMAPLES_JSON, n_clusters=5)
# resultado = classifier.classify("¿Cuantas ciudades tenemos registradas?")
# print(resultado) # te devuelve ejemplos de ese cluster