# embedding.py import logging import pandas as pd import numpy as np from sentence_transformers import SentenceTransformer from chroma_setup import initialize_client import uuid # Creamos una instancia del modelo local de sentence-transformers # (se descargará y cacheará la primera vez que se ejecute) model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') def embed_text_chunks(pages_and_chunks: list[dict]) -> pd.DataFrame: """ Genera embeddings para cada chunk de texto usando un modelo local de sentence-transformers. """ for item in pages_and_chunks: text_chunk = item["sentence_chunk"] try: # encode() acepta una lista de strings y retorna una lista de embeddings (ndarray). embedding_array = model.encode([text_chunk]) # Devuelve una matriz shape (1, 384) si es all-MiniLM-L6-v2, así que tomamos el [0] embedding = embedding_array[0].tolist() # embedding ahora es una lista de floats item["embedding"] = embedding except Exception as e: logging.error(f"Fallo al generar embedding para: {text_chunk}. Error: {e}") item["embedding"] = None return pd.DataFrame(pages_and_chunks) def save_to_chroma_db(embeddings_df: pd.DataFrame, user_id: str, document_id: str): """ Guarda en ChromaDB los embeddings generados. """ client = initialize_client() # Creas o recuperas la colección. Asegúrate de usar el mismo nombre # que luego usarás en tus queries. collection = client.get_or_create_collection(name=f"text_embeddings_{user_id}") combined_key = f"{user_id}_{document_id}" ids = [f"{combined_key}_{i}" for i in range(len(embeddings_df))] documents = embeddings_df["sentence_chunk"].tolist() embeddings = embeddings_df["embedding"].tolist() # Verificamos que ninguno sea None for idx, emb in enumerate(embeddings): if emb is None: raise ValueError( f"El chunk con ID {ids[idx]} no tiene embedding válido (None)." ) # ¡Ahora todos deben ser listas de floats! # Podemos añadirlos a la colección: collection.add( documents=documents, embeddings=embeddings, ids=ids, metadatas=[{"combined_key": combined_key} for _ in range(len(embeddings_df))] ) def generate_document_id() -> str: return str(uuid.uuid4()) def query_chroma_db(user_id: str, document_id: str, query: str): client = initialize_client() collection = client.get_collection(name=f"text_embeddings_{user_id}") combined_key = f"{user_id}_{document_id}" results = collection.query( query_texts=[query], n_results=5, where={"combined_key": combined_key}, ) documents = results.get("documents", []) if not documents: return "No se encontraron documentos" # Aplanar la lista de documentos relevant_docs = [doc for sublist in documents for doc in sublist] return "\n\n".join(relevant_docs)