# _utils/gerar_relatorio_modelo_usuario/EnhancedDocumentSummarizer.py
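"""Enhanced document summarizer.

Combines dense (Chroma embedding) retrieval with sparse (BM25) retrieval,
fuses the two rankings, and runs a two-stage LLM pipeline to produce a
structured, source-annotated summary of the document.
"""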
import os
from typing import List, Dict, Tuple, Optional
from _utils.vector_stores.Vector_store_class import VectorStore
from setup.easy_imports import (
Chroma,
ChatOpenAI,
PromptTemplate,
BM25Okapi,
Response,
)
import logging
import requests
from _utils.gerar_relatorio_modelo_usuario.DocumentSummarizer_simples import (
DocumentSummarizer,
)
from _utils.models.gerar_relatorio import (
RetrievalConfig,
)
from modelos_usuarios.serializer import ModeloUsuarioSerializer
from setup.environment import api_url
from _utils.gerar_relatorio_modelo_usuario.contextual_retriever import (
ContextualRetriever,
)
from asgiref.sync import sync_to_async
class EnhancedDocumentSummarizer(DocumentSummarizer):
def __init__(
self,
openai_api_key: str,
claude_api_key: str,
config: RetrievalConfig,
embedding_model,
chunk_size,
chunk_overlap,
num_k_rerank,
model_cohere_rerank,
claude_context_model,
prompt_auxiliar,
gpt_model,
gpt_temperature,
# id_modelo_do_usuario,
prompt_gerar_documento,
reciprocal_rank_fusion,
):
super().__init__(
openai_api_key,
os.environ.get("COHERE_API_KEY"),
embedding_model,
chunk_size,
chunk_overlap,
num_k_rerank,
model_cohere_rerank,
)
self.config = config
self.contextual_retriever = ContextualRetriever(
config, claude_api_key, claude_context_model
)
self.logger = logging.getLogger(__name__)
self.prompt_auxiliar = prompt_auxiliar
self.gpt_model = gpt_model
self.gpt_temperature = gpt_temperature
# self.id_modelo_do_usuario = id_modelo_do_usuario
self.prompt_gerar_documento = prompt_gerar_documento
self.reciprocal_rank_fusion = reciprocal_rank_fusion
self.resumo_gerado = ""
self.vector_store = VectorStore(embedding_model)
def retrieve_with_rank_fusion(
self, vector_store: Chroma, bm25: BM25Okapi, chunk_ids: List[str], query: str
) -> List[Dict]:
"""Combine embedding and BM25 retrieval results"""
try:
# Get embedding results
embedding_results = vector_store.similarity_search_with_score(
query, k=self.config.num_chunks
)
# Convert embedding results to list of (chunk_id, score)
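            # similarity_search_with_score returns a distance (lower = more
            # similar), so 1 / (1 + score) converts it into a similarity-like
            # weight in (0, 1] that can be fused with the BM25 scores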
embedding_list = [
(doc.metadata["chunk_id"], 1 / (1 + score))
for doc, score in embedding_results
]
# Get BM25 results
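            # BM25 scores every chunk in the corpus against the
            # whitespace-tokenized query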
tokenized_query = query.split()
bm25_scores = bm25.get_scores(tokenized_query)
# Convert BM25 scores to list of (chunk_id, score)
bm25_list = [
(chunk_ids[i], float(score)) for i, score in enumerate(bm25_scores)
]
# Sort bm25_list by score in descending order and limit to top N results
bm25_list = sorted(bm25_list, key=lambda x: x[1], reverse=True)[
: self.config.num_chunks
]
            # Normalize BM25 scores to [0, 1]; use a default of 0 so max()
            # does not fail on an empty list, and fall back to 1 when the
            # maximum is 0 to avoid dividing by zero
            calculo_max = max((score for _, score in bm25_list), default=0)
            max_bm25 = calculo_max if bm25_list and calculo_max else 1
            bm25_list = [(doc_id, score / max_bm25) for doc_id, score in bm25_list]
# Pass the lists to rank fusion
result_lists = [embedding_list, bm25_list]
weights = [self.config.embedding_weight, self.config.bm25_weight]
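            # Fuse the two ranked lists with the weighted reciprocal rank
            # fusion function injected via the constructor; chunks ranked
            # highly by either retriever end up near the top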
combined_results = self.reciprocal_rank_fusion(
result_lists, weights=weights
)
return combined_results
except Exception as e:
self.logger.error(f"Error in rank fusion retrieval: {str(e)}")
raise
async def generate_enhanced_summary(
self,
vector_store: Chroma,
bm25: BM25Okapi,
chunk_ids: List[str],
query: str = "Summarize the main points of this document",
) -> List[Dict]:
"""Generate enhanced summary using both vector and BM25 retrieval"""
try:
# Get combined results using rank fusion
ranked_results = self.retrieve_with_rank_fusion(
vector_store, bm25, chunk_ids, query
)
# Prepare context and track sources
contexts = []
sources = []
# Get full documents for top results
for chunk_id, score in ranked_results[: self.config.num_chunks]:
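                # Fetch the full chunk text and metadata from the vector
                # store by its chunk_id metadata field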
results = vector_store.get(
where={"chunk_id": chunk_id}, include=["documents", "metadatas"]
)
if results["documents"]:
context = results["documents"][0]
metadata = results["metadatas"][0]
contexts.append(context)
sources.append(
{
"content": context,
"page": metadata["page"],
"chunk_id": chunk_id,
"relevance_score": score,
"context": metadata.get("context", ""),
}
)
# url_request = f"{api_url}/modelo/{self.id_modelo_do_usuario}"
# try:
# print("url_request: ", url_request)
# resposta = requests.get(url_request)
# print("resposta: ", resposta)
# if resposta.status_code != 200:
# print("Entrou no if de erro")
# return Response(
# {
# "error": "Ocorreu um problema. Pode ser que o modelo não tenha sido encontrado. Tente novamente e/ou entre em contato com a equipe técnica"
# }
# )
# except:
# return Response(
# {
# "error": "Ocorreu um problema. Pode ser que o modelo não tenha sido encontrado. Tente novamente e/ou entre em contato com a equipe técnica"
# }
# )
# modelo_buscado = resposta.json()["modelo"]
# from modelos_usuarios.models import ModeloUsuarioModel
# try:
# # modelo_buscado = ModeloUsuarioModel.objects.get(
# # pk=self.id_modelo_do_usuario
# # )
# # serializer = ModeloUsuarioSerializer(modelo_buscado)
# # print("serializer.data: ", serializer.data)
# modelo_buscado = await sync_to_async(ModeloUsuarioModel.objects.get)(
# pk=self.id_modelo_do_usuario
# )
# serializer = await sync_to_async(ModeloUsuarioSerializer)(
# modelo_buscado
# )
# print("serializer.data: ", serializer.data)
# except Exception as e:
# print("e: ", e)
# return Response(
# {
# "error": "Ocorreu um problema. Pode ser que o modelo não tenha sido encontrado. Tente novamente e/ou entre em contato com a equipe técnica",
# "full_error": e,
# },
# 400,
# )
# print("modelo_buscado: ", serializer.data["modelo"])
llm = ChatOpenAI(
temperature=self.gpt_temperature,
model_name=self.gpt_model,
api_key=self.openai_api_key,
)
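            # First stage: condense the retrieved contexts into an
            # intermediate summary using the auxiliary prompt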
prompt_auxiliar = PromptTemplate(
template=self.prompt_auxiliar, input_variables=["context"]
)
resumo_auxiliar_do_documento = llm.invoke(
prompt_auxiliar.format(context="\n\n".join(contexts))
)
self.resumo_gerado = resumo_auxiliar_do_documento.content
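            # Second stage: feed the intermediate summary into the
            # user-defined document-generation prompt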
prompt_gerar_documento = PromptTemplate(
template=self.prompt_gerar_documento,
input_variables=["context"],
)
documento_gerado = llm.invoke(
prompt_gerar_documento.format(
context=self.resumo_gerado,
# modelo_usuario=serializer.data["modelo"],
)
).content
# Split the response into paragraphs
summaries = [p.strip() for p in documento_gerado.split("\n\n") if p.strip()]
# Create structured output
structured_output = []
for idx, summary in enumerate(summaries):
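                # Pair each paragraph with a retrieved source; if there are
                # more paragraphs than sources, reuse the last source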
source_idx = min(idx, len(sources) - 1)
structured_output.append(
{
"content": summary,
"source": {
"page": sources[source_idx]["page"],
"text": sources[source_idx]["content"][:200] + "...",
"context": sources[source_idx]["context"],
"relevance_score": sources[source_idx]["relevance_score"],
"chunk_id": sources[source_idx]["chunk_id"],
},
}
)
return structured_output
except Exception as e:
self.logger.error(f"Error generating enhanced summary: {str(e)}")
raise