# RAG + web-search ReAct agent for actuarial Q&A (HuggingFace Space app script)
# --- Imports (stdlib first, then third-party) --------------------------------
import os
from datetime import datetime
from typing import List

import chromadb
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.tools import TavilySearchResults
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm

load_dotenv()

# --- Global params -----------------------------------------------------------
# API credentials (read once; the original read OPENAI_API_KEY twice).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

MODEL_EMB = "mxbai-embed-large"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"

# Mixedbread client used by the reranker in init_rag_tool()
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
# HF model id used for query-time embeddings (NOTE: differs from MODEL_EMB above)
model_emb = "mixedbread-ai/mxbai-embed-large-v1"

# --- Set up ChromaDB ---------------------------------------------------------
# Stream the pre-embedded corpus from the HF Hub into an in-memory collection.
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

# Ingestion loop kept disabled here (collection assumed to be populated elsewhere):
# for batch in tqdm(batched_ds, desc="Processing dataset batches"):
#     collection.add(
#         ids=batch["id"],
#         metadatas=batch["metadata"],
#         documents=batch["document"],
#         embeddings=batch["embedding"],
#     )

print(f"Collection complete: {collection.count()}")
# Release the dataset handles; only the Chroma collection is needed from here on.
del memoires_ds, batched_ds

# NOTE(review): this module-level model is unused in this file (create_agent builds
# its own); kept in case another module imports it — confirm before removing.
llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, temperature=0)
def init_rag_tool():
    """Build the RAG tool that lets the agent query the actuarial-memoir corpus.

    Pipeline: Chroma similarity search (k=25) -> Mixedbread rerank (keep top 4)
    -> stuff-documents QA chain answering in French from the retrieved context.

    Returns:
        A LangChain tool (named "RAG_search") wrapping the retrieval chain.
    """
    # Vector store backed by the module-level in-memory Chroma client.
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",  # fixed: was an f-string with no placeholder
        embedding_function=HuggingFaceEmbeddings(model_name=model_emb),
    )

    class Reranker(BaseRetriever):
        """Retriever that re-ranks base-retriever hits via the Mixedbread rerank API."""

        retriever: VectorStoreRetriever
        k: int  # number of documents kept after reranking

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            # NOTE(review): original metadata is dropped here — only the reranked
            # text is returned. Confirm downstream chain does not need metadata.
            return [Document(page_content=res.input) for res in results.data]

    # Set up reranker + LLM
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)
    # api_key added for consistency with the other ChatOpenAI instantiations.
    llm = ChatOpenAI(model=LLM_NAME, api_key=OPENAI_API_KEY, verbose=True)

    system_prompt = (
        "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
        "Si tu ne connais pas la réponse, dis que tu ne sais pas."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)
    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Recherche d'information dans les mémoires d'actuariat",
        arg_types={"input": str},
    )
    return rag_tool
def init_websearch_tool():
    """Create the Tavily web-search tool ("Web_search") exposed to the agent."""
    search_config = {
        "name": "Web_search",
        "max_results": 5,
        "description": "Recherche d'informations sur le web",
        "search_depth": "advanced",   # deeper crawl than the default
        "include_answer": True,       # ask Tavily for a synthesized answer
        "include_raw_content": True,
        "include_images": False,
        "verbose": False,
    }
    return TavilySearchResults(**search_config)
def create_agent():
    """Assemble the ReAct agent: RAG tool + web-search tool + conversation memory."""
    tools = [init_rag_tool(), init_websearch_tool()]
    memory = MemorySaver()
    # Streaming chat model driving the agent loop.
    llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    system_message = """
    Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
    Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
    """  # (alt prompt, disabled): ask the model to separate RAG vs Web info in the final answer
    react_agent = create_react_agent(llm_4o, tools, state_modifier=system_message, checkpointer=memory, debug=False)
    return react_agent