Spaces:
Sleeping
Sleeping
from sentence_transformers import SentenceTransformer | |
from sklearn.cluster import KMeans | |
from config import Config | |
from load_json import load_examples | |
from sentence_transformers import SentenceTransformer | |
from sklearn.cluster import KMeans | |
class SemanticClassifier:
    """Route a natural-language question to a group of similar stored examples.

    Example questions are embedded with a SentenceTransformer model and grouped
    with KMeans. At query time the question is embedded the same way, and the
    examples stored for its predicted cluster are returned.

    Attributes:
        model: SentenceTransformer used to embed questions.
        clusters: Maps cluster id (int) to the list of example dicts
            (keys ``category``, ``pregunta``, ``query``) assigned to it.
        examples_embeddings: Embeddings of the training questions from the
            most recent ``train()`` call, or None before training.
        kmeans: The fitted KMeans model, or None before training.
    """

    def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", initialized_train=True):
        """Load the embedding model and optionally train right away.

        Args:
            model_name: Name of the SentenceTransformer model to load.
            initialized_train: If true, call ``train()`` with its default
                arguments immediately.
        """
        self.model = SentenceTransformer(model_name)
        self.clusters = {}
        self.examples_embeddings = None
        self.kmeans = None
        if initialized_train:
            self.train()

    def train(self, train_data=Config.EXMAPLES_JSON, n_clusters=15):
        """Embed and cluster the example questions loaded from ``train_data``.

        Previous training state is discarded, so the classifier can be
        retrained (e.g. with a different ``n_clusters``) without mixing in
        examples or cluster ids from an earlier run.

        Args:
            train_data: Source passed to ``load_examples``. Assumed to yield a
                mapping of category -> list of dicts with "pregunta" and
                "query" keys (TODO confirm against load_examples).
            n_clusters: Requested number of clusters. It is capped at the
                number of example questions, because KMeans cannot form more
                clusters than it has samples.

        Raises:
            ValueError: If no examples are loaded from ``train_data``.
        """
        examples = load_examples(train_data)

        # Flatten {category: [item, ...]} into a single list, keeping the category.
        flat_examples = [
            {"category": category, "pregunta": item["pregunta"], "query": item["query"]}
            for category, items in examples.items()
            for item in items
        ]
        if not flat_examples:
            raise ValueError("No training examples were loaded from train_data")

        questions = [ex["pregunta"] for ex in flat_examples]
        embeddings = self.model.encode(questions)

        # KMeans raises when n_clusters exceeds the number of samples.
        n_clusters = min(n_clusters, len(questions))
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=12)
        cluster_ids = self.kmeans.fit_predict(embeddings)

        # Build a fresh mapping. Appending to the old dict would keep examples
        # from a previous train() call under reused or stale cluster ids.
        clusters = {}
        for cluster_id, example in zip(cluster_ids, flat_examples):
            clusters.setdefault(int(cluster_id), []).append(example)
        self.clusters = clusters
        self.examples_embeddings = embeddings

    def classify(self, question: str):
        """Return the stored examples from the cluster predicted for ``question``.

        Args:
            question: The natural-language question to route.

        Returns:
            The list of example dicts stored for the predicted cluster, or an
            empty list if that cluster has none. The list returned is the
            internal one, so callers should not mutate it.

        Raises:
            RuntimeError: If called before ``train()`` has fitted the model.
        """
        if self.kmeans is None:
            raise RuntimeError("SemanticClassifier.classify() called before train()")
        question_embedding = self.model.encode([question])
        # predict() returns one label per input row; only one row was encoded.
        cluster_id = int(self.kmeans.predict(question_embedding)[0])
        return self.clusters.get(cluster_id, [])
# * FORMA DE USARSE | |
# classifier = SemanticClassifier(initialized_train=False)  # evita entrenar dos veces
# classifier.train(Config.EXMAPLES_JSON, n_clusters=5)
# resultado = classifier.classify("¿Cuantas ciudades tenemos registradas?") | |
# print(resultado) # te devuelve ejemplos de ese cluster |