Insight_DKG / embedding.py
jeysshon's picture
Update embedding.py
63dc01f verified
# embedding.py
import logging
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from chroma_setup import initialize_client
import uuid
# Creamos una instancia del modelo local de sentence-transformers
# (se descargará y cacheará la primera vez que se ejecute)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
def embed_text_chunks(pages_and_chunks: list[dict]) -> pd.DataFrame:
"""
Genera embeddings para cada chunk de texto usando un modelo local
de sentence-transformers.
"""
for item in pages_and_chunks:
text_chunk = item["sentence_chunk"]
try:
# encode() acepta una lista de strings y retorna una lista de embeddings (ndarray).
embedding_array = model.encode([text_chunk])
# Devuelve una matriz shape (1, 384) si es all-MiniLM-L6-v2, así que tomamos el [0]
embedding = embedding_array[0].tolist()
# embedding ahora es una lista de floats
item["embedding"] = embedding
except Exception as e:
logging.error(f"Fallo al generar embedding para: {text_chunk}. Error: {e}")
item["embedding"] = None
return pd.DataFrame(pages_and_chunks)
def save_to_chroma_db(embeddings_df: pd.DataFrame, user_id: str, document_id: str):
"""
Guarda en ChromaDB los embeddings generados.
"""
client = initialize_client()
# Creas o recuperas la colección. Asegúrate de usar el mismo nombre
# que luego usarás en tus queries.
collection = client.get_or_create_collection(name=f"text_embeddings_{user_id}")
combined_key = f"{user_id}_{document_id}"
ids = [f"{combined_key}_{i}" for i in range(len(embeddings_df))]
documents = embeddings_df["sentence_chunk"].tolist()
embeddings = embeddings_df["embedding"].tolist()
# Verificamos que ninguno sea None
for idx, emb in enumerate(embeddings):
if emb is None:
raise ValueError(
f"El chunk con ID {ids[idx]} no tiene embedding válido (None)."
)
# ¡Ahora todos deben ser listas de floats!
# Podemos añadirlos a la colección:
collection.add(
documents=documents,
embeddings=embeddings,
ids=ids,
metadatas=[{"combined_key": combined_key} for _ in range(len(embeddings_df))]
)
def generate_document_id() -> str:
return str(uuid.uuid4())
def query_chroma_db(user_id: str, document_id: str, query: str):
client = initialize_client()
collection = client.get_collection(name=f"text_embeddings_{user_id}")
combined_key = f"{user_id}_{document_id}"
results = collection.query(
query_texts=[query],
n_results=5,
where={"combined_key": combined_key},
)
documents = results.get("documents", [])
if not documents:
return "No se encontraron documentos"
# Aplanar la lista de documentos
relevant_docs = [doc for sublist in documents for doc in sublist]
return "\n\n".join(relevant_docs)