Spaces: Build error
import os
from typing import List

from dotenv import load_dotenv
import gradio as gr
import chromadb
from datasets import load_dataset
from tqdm import tqdm
from mixedbread_ai.client import MixedbreadAI

from langchain_chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import ChatOpenAI
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_huggingface import HuggingFaceEmbeddings
# Global params
load_dotenv()  # pick up API keys from a local .env file, if present
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"  # unused with the in-memory client below; kept for a persistent variant
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"  # must match the model used to precompute the dataset embeddings
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Mixedbread AI client (rerank calls go through the hosted API)
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
# Set up ChromaDB: stream the precomputed embeddings from the Hub into an
# in-memory Chroma collection
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=50000)
client = chromadb.Client()  # ephemeral, in-memory client; the collection is rebuilt on every start
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")
# Wrap the collection for LangChain; queries are embedded locally with the same
# model that produced the stored vectors
db = Chroma(
    client=client,
    collection_name="embeddings_mxbai",
    embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
)
# Reranker: wraps a vector-store retriever and reorders its hits with the
# Mixedbread rerank API, keeping the top k
class Reranker(BaseRetriever):
    retriever: VectorStoreRetriever
    k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = self.retriever.invoke(query)
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            return_input=True,
            top_k=self.k,
        )
        # res.index points back into the input list, so the original metadata
        # can be carried over instead of being dropped
        return [
            Document(page_content=res.input, metadata=docs[res.index].metadata)
            for res in results.data
        ]
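# Alternative (a sketch, not part of the original app, which hinted at a local
# CrossEncoder in a commented-out field): assuming sentence-transformers is
# installed, the same rerank model can run locally instead of calling the
# hosted API:
#
#     from sentence_transformers import CrossEncoder
#     local_rrk = CrossEncoder(MODEL_RRK)
#     ranked = local_rrk.rank(
#         query,
#         [doc.page_content for doc in docs],
#         return_documents=True,
#         top_k=self.k,
#     )
#     return [Document(page_content=r["text"]) for r in ranked]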
# Set up reranker + LLM
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
reranker = Reranker(retriever=retriever, k=4)
llm = ChatOpenAI(model=LLM_NAME, verbose=True)  # reads OPENAI_API_KEY from the environment
# Contextualize-question prompt: rewrite the latest user turn as a standalone
# question. (English gloss of the French system prompt: "Given the chat history
# and the latest user question, which may reference context in the history,
# formulate a standalone question that can be understood without the history.
# Do NOT answer the question; just reformulate it if needed, otherwise return
# it as is.")
contextualize_q_system_prompt = (
    "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur "
    "qui peut faire référence à un contexte dans l'historique du chat, "
    "formule une question autonome qui peut être comprise "
    "sans l'historique du chat. Ne réponds PAS à la question, "
    "reformule-la juste si nécessaire, sinon renvoie-la telle quelle."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# History-aware retriever: the LLM first rewrites the question, then the
# rewritten query goes through the reranking retriever
history_aware_retriever = create_history_aware_retriever(
    llm, reranker, contextualize_q_prompt
)

# QA prompt (English gloss: "Answer the question based only on the following
# context: {context}")
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant : \n\n {context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Question-answer chain: stuff the retrieved documents into the prompt, then
# wire retrieval and generation together
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
# Conversation history, keyed by session id. The store is process-local and
# in-memory, so histories are lost on restart.
store = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
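# Quick smoke test (a sketch; "debug" is an arbitrary session id):
# create_retrieval_chain returns a dict with "input", "context" (the reranked
# Documents) and "answer", so retrieved sources can be inspected alongside the
# generated reply:
#
#     out = conversational_rag_chain.invoke(
#         {"input": "Qu'est-ce que la garantie DTA ?"},
#         config={"configurable": {"session_id": "debug"}},
#     )
#     print(out["answer"])
#     for doc in out["context"]:
#         print(doc.page_content[:80])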
# Gradio callback. A single fixed session id means every visitor shares one
# conversation history.
def chatbot(message, history):
    session_id = "gradio_session"
    response = conversational_rag_chain.invoke(
        {"input": message},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [ConsoleCallbackHandler()],  # trace each run to stdout
        },
    )["answer"]
    return response
iface = gr.ChatInterface(
    chatbot,
    title="Assurance Chatbot",
    description="Posez vos questions sur l'assurance",  # "Ask your insurance questions"
    theme="soft",
    examples=[
        "Qu'est-ce que l'assurance multirisque habitation ?",
        "Qu'est-ce que la garantie DTA ?",
    ],
    # The three *_btn kwargs below exist in Gradio 4.x only; Gradio 5 removed them
    retry_btn=None,
    undo_btn=None,
    clear_btn="Effacer la conversation",  # "Clear the conversation"
)
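# Note: Gradio 5 removed the retry_btn / undo_btn / clear_btn keyword arguments
# (retry, undo and clear are built into the component), so on a Space that
# resolves to gradio>=5 the constructor above raises a TypeError. Whether that
# is what broke this particular build is an assumption; a Gradio-5-compatible
# sketch simply drops the three kwargs:
#
#     iface = gr.ChatInterface(
#         chatbot,
#         title="Assurance Chatbot",
#         description="Posez vos questions sur l'assurance",
#         theme="soft",
#         examples=[
#             "Qu'est-ce que l'assurance multirisque habitation ?",
#             "Qu'est-ce que la garantie DTA ?",
#         ],
#     )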
| if __name__ == "__main__": | |
| iface.launch() # share=True |