from utils import ingest_documents, qdrant_client, List, QdrantVectorStore, VectorStoreIndex, embedder
import gradio as gr
from toolsFunctions import pubmed_tool, arxiv_tool
from llama_index.core.tools import QueryEngineTool, FunctionTool
from llama_index.core import Settings
from llama_index.llms.mistralai import MistralAI
from llama_index.core.llms import ChatMessage
from llama_index.core.agent import ReActAgent
from phoenix.otel import register
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
import time
import os
## Observing and tracing
PHOENIX_API_KEY = os.getenv("phoenix_api_key")
os.environ["PHOENIX_CLIENT_HEADERS"] = f"api_key={PHOENIX_API_KEY}"
os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com"
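# Register a Phoenix tracer and instrument LlamaIndex so that LLM and agent calls are traced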
tracer_provider = register(
project_name="llamaindex_hf",
)
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
## Global
Settings.embed_model = embedder
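# Tools the agent can call in every conversation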
arxivtool = FunctionTool.from_defaults(arxiv_tool, name="arxiv_tool", description="A tool to search ArXiv (pre-print papers database) for specific papers")
pubmedtool = FunctionTool.from_defaults(pubmed_tool, name="pubmed_tool", description="A tool to search PubMed (database of published biomedical papers) for specific papers")
query_engine = None
message_history = [
    ChatMessage(role="system", content="You are a helpful assistant that answers the questions the user asks about the papers they uploaded. You should base your answers on the context you can retrieve from the PDFs and, if you cannot retrieve any, search ArXiv for a potential answer. If you cannot find any viable answer, reply that you do not know the answer to the user's question")
]
## Functions
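# The character-by-character streaming below was originally repeated at every exit point of
# reply(); factoring it into one generator keeps the typewriter effect identical at each call site.
def stream_text(text: str):
    """Yield progressively longer prefixes of the text to simulate typing in the chat UI."""
    partial = ""
    for char in text:
        partial += char
        time.sleep(0.001)
        yield partial
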
def reply(message, history, files: List[str] | None, collection: str, llamaparse: bool, llamacloud_api_key: str, mistral_api_key: str):
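    """Chat callback for the Gradio interface: validate the inputs and the Mistral AI key, optionally
    ingest the uploaded PDFs into a Qdrant collection, then answer with a ReAct agent equipped with
    a RAG engine (when a collection is available) plus the ArXiv and PubMed tools, streaming the reply."""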
global message_history
if mistral_api_key == "":
response = "You should provide a Mistral AI API key"
        yield from stream_text(response)
else:
try:
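            # Cheap test completion to verify that the provided Mistral AI API key is valid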
chat_mis = MistralAI(model="mistral-small-latest", temperature=0, api_key=mistral_api_key)
chat_mis.complete("Hello, who are you?")
        except Exception:
            response = "Your Mistral AI API key is not valid"
            yield from stream_text(response)
else:
Settings.llm = MistralAI(model="mistral-small-latest", temperature=0, api_key=mistral_api_key)
if llamaparse and llamacloud_api_key == "":
response = "If you activate LlamaParse, you should provide a LlamaCloud API key"
                yield from stream_text(response)
elif message == "" or message is None:
response = "You should provide a message"
                yield from stream_text(response)
elif files is None and collection == "":
                res = "### WARNING! You did not specify any collection, so I only queried ArXiv and/or PubMed to answer your question\n\n"
agent = ReActAgent.from_tools(tools=[pubmedtool, arxivtool], verbose=True)
                response = agent.chat(message=message, chat_history=message_history)
response = str(response)
message_history.append(ChatMessage(role="user", content=message))
message_history.append(ChatMessage(role="assistant", content=response))
response = res + response
                yield from stream_text(response)
elif files is None and collection != "" and collection not in [c.name for c in qdrant_client.get_collections().collections]:
response = "Make sure that the name of the existing collection to use as a knowledge base is correct, because the one you provided does not exist! You can check your existing collections and their features in the dedicated tab of the app :)"
                yield from stream_text(response)
elif files is not None:
if len(files) > 5:
response = "You cannot upload more than 5 files"
                    yield from stream_text(response)
elif collection == "":
response = "You should provide a collection name (new or existing) if you want to ingest files!"
                    yield from stream_text(response)
else:
collection_name = collection
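                    # Ingest the uploaded PDFs into the collection (parsed with LlamaParse if enabled)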
index = ingest_documents(files, collection_name, llamaparse, llamacloud_api_key)
query_engine = index.as_query_engine()
rag_tool = QueryEngineTool.from_defaults(query_engine, name="papers_rag", description="A RAG engine with information from selected scientific papers")
agent = ReActAgent.from_tools(tools=[rag_tool, pubmedtool, arxivtool], verbose=True)
                    response = agent.chat(message=message, chat_history=message_history)
response = str(response)
message_history.append(ChatMessage(role="user", content=message))
message_history.append(ChatMessage(role="assistant", content=response))
                    yield from stream_text(response)
else:
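                # No new files: reuse the existing Qdrant collection as the knowledge base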
                vector_store = QdrantVectorStore(client=qdrant_client, collection_name=collection, enable_hybrid=True)
index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
query_engine = index.as_query_engine()
rag_tool = QueryEngineTool.from_defaults(query_engine, name="papers_rag", description="A RAG engine with information from selected scientific papers")
agent = ReActAgent.from_tools(tools=[rag_tool, pubmedtool, arxivtool], verbose=True)
                response = agent.chat(message=message, chat_history=message_history)
response = str(response)
message_history.append(ChatMessage(role="user", content=message))
message_history.append(ChatMessage(role="assistant", content=response))
                yield from stream_text(response)
def to_markdown_color(grade: str):
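    """Map a Qdrant collection status ('red', 'yellow' or 'green') to a small colored-dot Markdown image."""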
colors = {"red": "ff0000", "yellow": "ffcc00", "green": "33cc33"}
mdcode = f"![#{colors[grade]}](https://placehold.co/15x15/{colors[grade]}/{colors[grade]}.png)"
return mdcode
def get_qdrant_collections_dets():
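    """Return a Markdown overview of the available Qdrant collections: name, point count and status."""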
    collections = [c.name for c in qdrant_client.get_collections().collections]
    details = []
    # Hide internal collections that are not part of PapersChat; a bare remove() would raise if they were absent
    for internal in ("semantic_cache_med", "stem_cot_qa"):
        if internal in collections:
            collections.remove(internal)
    for counter, collection in enumerate(collections, start=1):
        dets = qdrant_client.get_collection(collection)
        p = f"### {counter}. {collection}\n\n**Number of Points**: {dets.points_count}\n\n**Status**: {to_markdown_color(dets.status)} {dets.status}\n\n"
        details.append(p)
final_text = "<h2 align='center'>Available Collections</h2>\n\n"
final_text += "\n\n".join(details)
return final_text
## Frontend
accordion = gr.Accordion(label="⚠️Set up these parameters before you start chatting!⚠️")
iface1 = gr.ChatInterface(
    fn=reply,
    additional_inputs=[
        gr.File(label="Upload Papers (only PDF allowed!)", file_count="multiple", file_types=[".pdf", "pdf", ".PDF", "PDF"], value=None),
        gr.Textbox(label="Collection", info="Upload your papers to a collection (new or existing)", value=""),
        gr.Checkbox(label="Use LlamaParse", info="Needs the LlamaCloud API key", value=False),
        gr.Textbox(label="LlamaCloud API key", type="password", info="Set this field if you enable LlamaParse", value=""),
        gr.Textbox(label="Mistral AI API key", type="password", value=""),
    ],
    additional_inputs_accordion=accordion,
)
with open("usage.md") as u:
    content = u.read()
iface2 = gr.Blocks()
with iface2:
with gr.Row():
gr.Markdown(content)
iface3 = gr.Interface(fn=get_qdrant_collections_dets, inputs=None, outputs=gr.Markdown(label="Collections"), submit_btn="See your collections")
iface = gr.TabbedInterface([iface1, iface2, iface3], ["Chat💬", "Usage Guide⚙️", "Available Collections🔎"], title="PapersChat📝")
iface.launch(server_name="0.0.0.0", server_port=7860)