Insight_DKG / embedding.py
jeysshon's picture
Create embedding.py
5d0bb3c verified
raw
history blame
6.25 kB
import time
import requests
from requests.exceptions import ReadTimeout, HTTPError
import logging
import json
import pandas as pd
import chromadb
from chromadb.utils import embedding_functions
import os
from dotenv import load_dotenv
import datetime
import uuid
from chroma_setup import initialize_client
import numpy as np
# Carga las variables de entorno
load_dotenv()
def get_embedding_model():
"""
Retorna una función de incrustación (embedding) basada en un modelo de HuggingFace.
Lee la clave de la API desde las variables de entorno.
"""
return embedding_functions.HuggingFaceEmbeddingFunction(
api_key=os.getenv("HUGGINGFACE_API_KEY"),
model_name="sentence-transformers/all-MiniLM-L6-v2",
)
def embed_with_retry(embedding_model, text_chunk, max_retries=3, backoff_factor=2):
"""
Reintenta la generación de embeddings en caso de errores de timeout o límites de la API.
Parámetros:
-----------
embedding_model : objeto de función
Función de incrustación proporcionada por HuggingFaceEmbeddingFunction.
text_chunk : str
Texto a convertir en embedding.
max_retries : int
Máximo número de reintentos.
backoff_factor : int
Factor de espera exponencial antes de cada reintento.
Retorna:
--------
list[float]
Lista de valores flotantes que representan el embedding del texto.
"""
retries = 0
while retries < max_retries:
try:
embedding = embedding_model(input=text_chunk)
return embedding
except ReadTimeout as e:
logging.warning(f"Timeout (ReadTimeout): {e}. Reintentando... ({retries+1}/{max_retries})")
retries += 1
time.sleep(backoff_factor ** retries)
except HTTPError as e:
if e.response.status_code == 429: # Límite de peticiones
retry_after = int(e.response.headers.get("Retry-After", 60))
logging.warning(f"Límite de la API alcanzado. Reintentando en {retry_after} segundos...")
time.sleep(retry_after)
retries += 1
else:
raise e
raise Exception(f"No se pudo generar el embedding después de {max_retries} intentos.")
def embed_text_chunks(pages_and_chunks: list[dict]) -> pd.DataFrame:
"""
Genera embeddings para cada chunk de texto usando un modelo de HuggingFace,
con lógica de reintento en caso de errores.
Parámetros:
-----------
pages_and_chunks : list[dict]
Lista de diccionarios que contienen chunks de texto y metadatos.
Retorna:
--------
pd.DataFrame
DataFrame que incluye cada chunk, sus metadatos y su embedding.
"""
embedding_model = get_embedding_model()
for item in pages_and_chunks:
try:
embedding = embed_with_retry(embedding_model, item["sentence_chunk"])
# Verifica que sea una lista anidada y la aplana
if isinstance(embedding, list):
embedding = [float(val) for sublist in embedding for val in sublist]
else:
raise ValueError(f"Formato de embedding inesperado: {type(embedding)}")
item["embedding"] = embedding
except Exception as e:
logging.error(f"No se pudo generar embedding para: {item['sentence_chunk']}. Error: {e}")
item["embedding"] = None
return pd.DataFrame(pages_and_chunks)
def save_to_chroma_db(embeddings_df: pd.DataFrame, user_id: str, document_id: str):
"""
Guarda en la base de datos Chroma los embeddings generados,
asignándoles metadatos con un identificador combinado de usuario y documento.
Parámetros:
-----------
embeddings_df : pd.DataFrame
DataFrame con los chunks y sus embeddings.
user_id : str
Identificador único de usuario.
document_id : str
Identificador único de documento.
"""
client = initialize_client()
collection = client.get_or_create_collection(name=f"text_embeddings_{user_id}")
combined_key = f"{user_id}_{document_id}"
ids = [f"{combined_key}_{i}" for i in range(len(embeddings_df))]
documents = embeddings_df["sentence_chunk"].tolist()
embeddings = []
for embedding in embeddings_df["embedding"]:
if isinstance(embedding, np.ndarray):
embeddings.append(embedding.flatten().tolist())
else:
embeddings.append(embedding)
metadatas = [{"combined_key": combined_key} for _ in range(len(embeddings_df))]
print(f"Guardando documentos con combined_key: {combined_key}")
collection.add(
documents=documents,
embeddings=embeddings,
ids=ids,
metadatas=metadatas
)
def query_chroma_db(user_id: str, document_id: str, query: str):
"""
Consulta la base de datos Chroma para recuperar los fragmentos de texto más
relevantes basados en la consulta dada.
Parámetros:
-----------
user_id : str
Identificador único de usuario.
document_id : str
Identificador único de documento.
query : str
Consulta que se desea realizar.
Retorna:
--------
str
Texto combinado de los documentos más relevantes, o mensaje indicando
que no se encontraron documentos.
"""
client = initialize_client()
collection = client.get_collection(name=f"text_embeddings_{user_id}")
combined_key = f"{user_id}_{document_id}"
print(f"Consultando con combined_key: {combined_key}")
results = collection.query(
query_texts=[query],
n_results=5,
where={"combined_key": combined_key},
)
print(f"Resultados de la consulta: {results}")
documents = results.get("documents", [])
if documents:
relevant_docs = [doc for sublist in documents for doc in sublist] # Aplanar la lista
context = "\n\n".join(relevant_docs)
else:
context = "No se encontraron documentos"
return context
def generate_document_id() -> str:
"""
Genera un ID único de documento usando UUID.
Retorna:
--------
str
Cadena única que identifica el documento.
"""
return str(uuid.uuid4())