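"""PapersChat: a Gradio app for chatting with scientific papers.

Uploaded PDFs are ingested into Qdrant collections and queried through a
LlamaIndex ReAct agent that can also search ArXiv and PubMed; all LLM calls
are traced with Arize Phoenix.
"""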
import os
import time
from typing import List

import gradio as gr
from llama_index.core import Settings
from llama_index.core.agent import ReActAgent
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import QueryEngineTool, FunctionTool
from llama_index.llms.mistralai import MistralAI
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from phoenix.otel import register

from toolsFunctions import pubmed_tool, arxiv_tool
from utils import ingest_documents, qdrant_client, QdrantVectorStore, VectorStoreIndex, embedder

## Observing and tracing
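# Phoenix picks up the API key and collector endpoint from these environment
# variables; every LlamaIndex call is then traced under the "llamaindex_hf" project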
PHOENIX_API_KEY = os.getenv("phoenix_api_key")
os.environ["PHOENIX_CLIENT_HEADERS"] = f"api_key={PHOENIX_API_KEY}"
os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com"
tracer_provider = register(project_name="llamaindex_hf")
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)

## Global
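# The embedding model is shared app-wide; the two search tools wrap plain
# functions so the ReAct agent can invoke them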
Settings.embed_model = embedder
arxivtool = FunctionTool.from_defaults(arxiv_tool, name="arxiv_tool", description="A tool to search ArXiv (pre-print papers database) for specific papers")
pubmedtool = FunctionTool.from_defaults(pubmed_tool, name="pubmed_tool", description="A tool to search PubMed (published medical papers database) for specific papers")
query_engine = None
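# Seed the conversation with the system prompt; reply() appends each
# user/assistant turn to this shared history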
message_history = [
    ChatMessage(role="system", content="You are a useful assistant that has to help the user with questions that they ask about several papers they uploaded. You should base your answers on the context you can retrieve from the PDFs and, if you cannot retrieve any, search ArXiv for a potential answer. If you cannot find any viable answer, please reply that you do not know the answer to the user's question")
]

## Functions
def stream_text(text: str):
    """Yield progressively longer prefixes of text to simulate token streaming in the chat UI."""
    streamed = ""
    for char in text:
        streamed += char
        time.sleep(0.001)
        yield streamed

def reply(message, history, files: List[str] | None, collection: str, llamaparse: bool, llamacloud_api_key: str, mistral_api_key: str):
    global message_history
    if mistral_api_key == "":
        yield from stream_text("You should provide a Mistral AI API key")
    else:
        try:
            # Validate the key with a throwaway completion before using it for the session
            chat_mis = MistralAI(model="mistral-small-latest", temperature=0, api_key=mistral_api_key)
            chat_mis.complete("Hello, who are you?")
        except Exception:
            yield from stream_text("Your Mistral AI key is not valid")
        else:
            Settings.llm = MistralAI(model="mistral-small-latest", temperature=0, api_key=mistral_api_key)
            if llamaparse and llamacloud_api_key == "":
                yield from stream_text("If you activate LlamaParse, you should provide a LlamaCloud API key")
            elif message == "" or message is None:
                yield from stream_text("You should provide a message")
            elif files is None and collection == "":
                # No documents and no collection: fall back to the ArXiv/PubMed tools only
                warning = "### WARNING! You did not specify any collection, so I only interrogated ArXiv and/or PubMed to answer your question\n\n"
                agent = ReActAgent.from_tools(tools=[pubmedtool, arxivtool], verbose=True)
                response = str(agent.chat(message=message, chat_history=message_history))
                message_history.append(ChatMessage(role="user", content=message))
                message_history.append(ChatMessage(role="assistant", content=response))
                yield from stream_text(warning + response)
            elif files is None and collection != "" and collection not in [c.name for c in qdrant_client.get_collections().collections]:
                yield from stream_text("Make sure that the name of the existing collection to use as a knowledge base is correct, because the one you provided does not exist! You can check your existing collections and their features in the dedicated tab of the app :)")
            elif files is not None:
                if len(files) > 5:
                    yield from stream_text("You cannot upload more than 5 files")
                elif collection == "":
                    yield from stream_text("You should provide a collection name (new or existing) if you want to ingest files!")
                else:
                    # Ingest the uploaded PDFs into Qdrant, then expose them to the agent as a RAG tool
                    index = ingest_documents(files, collection, llamaparse, llamacloud_api_key)
                    query_engine = index.as_query_engine()
                    rag_tool = QueryEngineTool.from_defaults(query_engine, name="papers_rag", description="A RAG engine with information from selected scientific papers")
                    agent = ReActAgent.from_tools(tools=[rag_tool, pubmedtool, arxivtool], verbose=True)
                    response = str(agent.chat(message=message, chat_history=message_history))
                    message_history.append(ChatMessage(role="user", content=message))
                    message_history.append(ChatMessage(role="assistant", content=response))
                    yield from stream_text(response)
            else:
                # Existing collection and no new files: rebuild the index from the stored vectors
                vector_store = QdrantVectorStore(client=qdrant_client, collection_name=collection, enable_hybrid=True)
                index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
                query_engine = index.as_query_engine()
                rag_tool = QueryEngineTool.from_defaults(query_engine, name="papers_rag", description="A RAG engine with information from selected scientific papers")
                agent = ReActAgent.from_tools(tools=[rag_tool, pubmedtool, arxivtool], verbose=True)
                response = str(agent.chat(message=message, chat_history=message_history))
                message_history.append(ChatMessage(role="user", content=message))
                message_history.append(ChatMessage(role="assistant", content=response))
                yield from stream_text(response)

def to_markdown_color(grade: str):
    """Return a small colored-square markdown image (via placehold.co) matching a Qdrant collection status."""
    colors = {"red": "ff0000", "yellow": "ffcc00", "green": "33cc33"}
    mdcode = f"![#{colors[grade]}](https://placehold.co/15x15/{colors[grade]}/{colors[grade]}.png)"
    return mdcode

def get_qdrant_collections_dets():
    """Build a markdown report with the name, size and status of every user-facing Qdrant collection."""
    # Exclude internal collections with a filter; a plain list.remove() would
    # raise ValueError if either of them were missing
    collections = [c.name for c in qdrant_client.get_collections().collections if c.name not in ("semantic_cache_med", "stem_cot_qa")]
    details = []
    for counter, collection in enumerate(collections, start=1):
        dets = qdrant_client.get_collection(collection)
        p = f"### {counter}. {collection}\n\n**Number of Points**: {dets.points_count}\n\n**Status**: {to_markdown_color(dets.status)} {dets.status}\n\n"
        details.append(p)
    final_text = "<h2 align='center'>Available Collections</h2>\n\n"
    final_text += "\n\n".join(details)
    return final_text

## Frontend
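# Three tabs: the chat itself, the usage guide rendered from usage.md, and a
# read-only view of the available Qdrant collections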
accordion = gr.Accordion(label="⚠️Set up these parameters before you start chatting!⚠️")

iface1 = gr.ChatInterface(
    fn=reply,
    additional_inputs=[
        gr.File(label="Upload Papers (only PDF allowed!)", file_count="multiple", file_types=[".pdf", "pdf", ".PDF", "PDF"], value=None),
        gr.Textbox(label="Collection", info="Upload your papers to a collection (new or existing)", value=""),
        gr.Checkbox(label="Use LlamaParse", info="Needs the LlamaCloud API key", value=False),
        gr.Textbox(label="LlamaCloud API key", type="password", info="Set this field if you enable LlamaParse", value=""),
        gr.Textbox(label="Mistral AI API key", type="password", value=""),
    ],
    additional_inputs_accordion=accordion,
)
with open("usage.md") as u:
    content = u.read()
iface2 = gr.Blocks()
with iface2:
    with gr.Row():
        gr.Markdown(content)
iface3 = gr.Interface(fn=get_qdrant_collections_dets, inputs=None, outputs=gr.Markdown(label="Collections"), submit_btn="See your collections")
iface = gr.TabbedInterface([iface1, iface2, iface3], ["Chat💬", "Usage Guide⚙️", "Available Collections🔎"], title="PapersChat📝")
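# Listen on all interfaces on port 7860 (Gradio's default, and the port
# Hugging Face Spaces expects)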
iface.launch(server_name="0.0.0.0", server_port=7860)