File size: 2,317 Bytes
56ff037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from config import Config
from load_json import load_examples

from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

class SemanticClassifier:
    def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", initialized_train=True):
        self.model = SentenceTransformer(model_name)
        self.clusters = {}
        self.examples_embeddings = None
        self.kmeans = None
        if initialized_train:
            self.train()

    def train(self, train_data=Config.EXMAPLES_JSON, n_clusters=15):
        examples = load_examples(train_data)
        
        # * Aplanar ejemplos
        flat_examples = []
        for category, items in examples.items():
            for item in items:
                flat_examples.append({
                    "category": category,
                    "pregunta": item["pregunta"],
                    "query": item["query"]
                })
        
        questions = [ex["pregunta"] for ex in flat_examples]
        
        # * Obtener embeddings
        embeddings = self.model.encode(questions)

        # * Clustering
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=12)
        cluster_ids = self.kmeans.fit_predict(embeddings)

        # * Guardar ejemplos por cluster
        for i, cluster_id in enumerate(cluster_ids):
            # * Crear lista si no existe
            if cluster_id not in self.clusters:
                self.clusters[cluster_id] = []
            # * Agregamos el ejemplo
            self.clusters[cluster_id].append(flat_examples[i])
        self.examples_embeddings = embeddings
    
    def classify(self, question: str): 
        # * En formato de embedding
        question_embedding = self.model.encode([question])
        # * Encontrar el cluster más cercano
        cluster_id = self.kmeans.predict(question_embedding)[0]
        # * Retornamos los ejemplos de ese cluster
        return self.clusters.get(cluster_id, [])
    
    
# * FORMA DE USARSE
# classifier = SemanticClassifier()
# classifier.train(Config.EXMAPLES_JSON, n_clusters=5)

# resultado = classifier.classify("¿Cuantas ciudades tenemos registradas?")
# print(resultado)  # te devuelve ejemplos de ese cluster