Shreyas094 committed (verified) · Commit 7483e97 · Parent(s): 6bdbafb

Update app.py

Files changed (1): app.py (+723 -154)

app.py CHANGED
@@ -1,203 +1,772 @@
  import os
- import logging
- import asyncio
  import gradio as gr
- from huggingface_hub import InferenceClient
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import FAISS
- from langchain.schema import Document
  from duckduckgo_search import DDGS
- from dotenv import load_dotenv
- from functools import lru_cache

- # Load environment variables
- load_dotenv()

- # Configure logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)

  # Environment variables and configurations
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
  MODELS = [
      "mistralai/Mistral-7B-Instruct-v0.3",
      "mistralai/Mixtral-8x7B-Instruct-v0.1",
-     "mistralai/Mistral-Nemo-Instruct-2407",
-     "meta-llama/Meta-Llama-3.1-8B-Instruct",
-     "meta-llama/Meta-Llama-3.1-70B-Instruct",
-     "google/gemma-2-9b-it",
-     "google/gemma-2-27b-it"
- ]

- DEFAULT_SYSTEM_PROMPT = """You are a world-class financial AI assistant, capable of complex reasoning and reflection.
- Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags.
- Providing comprehensive and accurate information based on web search results is essential.
- Your goal is to synthesize the given context into a coherent and detailed response that directly addresses the user's query.
- Please ensure that your response is well-structured and factual.
- If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags."""

- class WebSearcher:
-     def __init__(self):
-         self.ddgs = DDGS()

-     @lru_cache(maxsize=100)
-     def search(self, query, max_results=5):
          try:
-             results = list(self.ddgs.text(query, max_results=max_results))
-             logger.info(f"Search completed for query: {query}")
-             return results
          except Exception as e:
-             logger.error(f"Error during DuckDuckGo search: {str(e)}")
-             return []

- @lru_cache(maxsize=1)
  def get_embeddings():
      return HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")

- def create_web_search_vectors(search_results):
      embed = get_embeddings()
-     documents = [
-         Document(
-             page_content=f"{result['title']}\n{result['body']}\nSource: {result['href']}",
-             metadata={"source": result['href']}
-         )
-         for result in search_results if 'body' in result
-     ]
-     logger.info(f"Created vectors for {len(documents)} search results.")
-     return FAISS.from_documents(documents, embed)

- async def get_response_with_search(query, system_prompt, model, use_embeddings, num_calls=3, temperature=0.2):
-     searcher = WebSearcher()
-     search_results = searcher.search(query)

-     if not search_results:
-         logger.warning(f"No web search results found for query: {query}")
-         yield "No web search results available. Please try again.", ""
-         return

-     sources = [result['href'] for result in search_results if 'href' in result]
-     source_list_str = "\n".join(sources)

-     if use_embeddings:
-         web_search_database = create_web_search_vectors(search_results)
-         retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
          relevant_docs = retriever.get_relevant_documents(query)
-         context = "\n".join([doc.page_content for doc in relevant_docs])
      else:
-         context = "\n".join([f"{result['title']}\n{result['body']}" for result in search_results])

-     logger.info(f"Context created for query: {query}")

-     user_message = f"""Using the following context from web search results:
  {context}

- Write a detailed and complete research document that fulfills the following user request: '{query}'."""

-     async with InferenceClient(model, token=HUGGINGFACE_TOKEN) as client:
-         full_response = ""
          try:
-             for _ in range(num_calls):
-                 async for response in client.chat_completion_stream(
-                     messages=[
-                         {"role": "system", "content": system_prompt},
-                         {"role": "user", "content": user_message}
-                     ],
-                     max_tokens=6000,
-                     temperature=temperature,
-                     top_p=0.8,
-                 ):
-                     if "content" in response:
-                         chunk = response["content"]
-                         full_response += chunk
-                         yield full_response, ""
          except Exception as e:
-             logger.error(f"Error in get_response_with_search: {str(e)}")
-             yield f"An error occurred while processing your request: {str(e)}", ""
-
          if not full_response:
-             logger.warning("No response generated from the model")
-             yield "No response generated from the model.", ""

-     yield f"{full_response}\n\nSources:\n{source_list_str}", ""

- async def respond(message, system_prompt, history, model, temperature, num_calls, use_embeddings):
-     logger.info(f"User Query: {message}")
-     logger.info(f"Model Used: {model}")
-     logger.info(f"Temperature: {temperature}")
-     logger.info(f"Number of API Calls: {num_calls}")
-     logger.info(f"Use Embeddings: {use_embeddings}")
-     logger.info(f"System Prompt: {system_prompt}")

-     try:
-         async for main_content, sources in get_response_with_search(message, system_prompt, model, use_embeddings, num_calls=num_calls, temperature=temperature):
-             yield main_content
-     except asyncio.CancelledError:
-         logger.warning("The operation was cancelled.")
-         yield "The operation was cancelled. Please try again."
-     except Exception as e:
-         logger.error(f"Error in respond function: {str(e)}")
-         yield f"An error occurred: {str(e)}"

  css = """
  /* Fine-tune chatbox size */
- .chatbot-container {
-     height: 600px !important;
-     width: 100% !important;
- }
- .chatbot-container > div {
-     height: 100%;
-     width: 100%;
  }
  """

- def create_gradio_interface():
-     custom_placeholder = "Enter your question here for web search."
-
-     demo = gr.ChatInterface(
-         fn=respond,
-         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=False),
-         additional_inputs=[
-             gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, lines=6, label="System Prompt", placeholder="Enter your system prompt here"),
-             gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[3]),
-             gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
-             gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
-             gr.Checkbox(label="Use Embeddings", value=False),
-         ],
-         title="AI-powered Web Search Assistant",
-         description="Use web search to answer questions or generate summaries.",
-         theme=gr.Theme.from_hub("allenai/gradio-theme"),
-         css=css,
-         examples=[
-             ["What are the latest developments in artificial intelligence?"],
-             ["Explain the concept of quantum computing."],
-             ["What are the environmental impacts of renewable energy?"]
-         ],
-         cache_examples=False,
-         analytics_enabled=False,
-         textbox=gr.Textbox(placeholder=custom_placeholder, container=False, scale=7),
-         chatbot=gr.Chatbot(
-             show_copy_button=True,
-             likeable=True,
-             layout="bubble",
-             height=400,
-         )
      )

-     with demo:
-         gr.Markdown("""
-         ## How to use
-         1. Enter your question in the chat interface.
-         2. Optionally, modify the System Prompt to guide the AI's behavior.
-         3. Select the model you want to use from the dropdown.
-         4. Adjust the Temperature to control the randomness of the response.
-         5. Set the Number of API Calls to determine how many times the model will be queried.
-         6. Check or uncheck the "Use Embeddings" box to toggle between using embeddings or direct text summarization.
-         7. Press Enter or click the submit button to get your answer.
-         8. Use the provided examples or ask your own questions.
-         """)
-
-     return demo

  if __name__ == "__main__":
-     demo = create_gradio_interface()
      demo.launch(share=True)
 
  import os
+ import json
+ import re
  import gradio as gr
+ import requests
  from duckduckgo_search import DDGS
+ from typing import List
+ from pydantic import BaseModel, Field
+ from tempfile import NamedTemporaryFile
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.vectorstores import VectorStore
+ from langchain_core.documents import Document
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from llama_parse import LlamaParse
+ from huggingface_hub import InferenceClient
+ import inspect
+ import logging
+ import shutil

+ # Set up basic configuration for logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

  # Environment variables and configurations
+ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
+ llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
+ ACCOUNT_ID = os.environ.get("CLOUDFARE_ACCOUNT_ID")
+ API_TOKEN = os.environ.get("CLOUDFLARE_AUTH_TOKEN")
+ API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/a17f03e0f049ccae0c15cdcf3b9737ce/ai/run/"

+ print(f"ACCOUNT_ID: {ACCOUNT_ID}")
+ print(f"CLOUDFLARE_AUTH_TOKEN: {API_TOKEN[:5]}..." if API_TOKEN else "Not set")
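The Cloudflare calls later in this file go through the Workers AI REST endpoint configured above. A minimal standalone sketch of that call shape (non-streaming; assumes the same environment variables are set, and the documented result.response field on success):

    import os
    import requests

    account_id = os.environ["CLOUDFARE_ACCOUNT_ID"]   # env var name as spelled in this app
    api_token = os.environ["CLOUDFLARE_AUTH_TOKEN"]
    url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/meta/llama-3.1-8b-instruct"
    resp = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_token}"},
        json={"messages": [{"role": "user", "content": "Say hello"}]},
    )
    print(resp.json()["result"]["response"])  # non-streaming replies nest the text under result.response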
  MODELS = [
      "mistralai/Mistral-7B-Instruct-v0.3",
      "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "@cf/meta/llama-3.1-8b-instruct",
+     "mistralai/Mistral-Nemo-Instruct-2407"
+ ]

+ # Initialize LlamaParse
+ llama_parser = LlamaParse(
+     api_key=llama_cloud_api_key,
+     result_type="markdown",
+     num_workers=4,
+     verbose=True,
+     language="en",
+ )

+ def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[Document]:
+     """Loads and splits the document into pages."""
+     if parser == "pypdf":
+         loader = PyPDFLoader(file.name)
+         return loader.load_and_split()
+     elif parser == "llamaparse":
          try:
+             documents = llama_parser.load_data(file.name)
+             return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
          except Exception as e:
+             print(f"Error using Llama Parse: {str(e)}")
+             print("Falling back to PyPDF parser")
+             loader = PyPDFLoader(file.name)
+             return loader.load_and_split()
+     else:
+         raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
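A quick usage sketch for load_document: anything with a .name attribute pointing at a PDF works, so a plain file handle is enough to exercise the PyPDF path (sample.pdf is a hypothetical local file):

    with open("sample.pdf", "rb") as f:
        pages = load_document(f, parser="pypdf")
    print(f"{len(pages)} chunks; first source: {pages[0].metadata.get('source')}")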
  def get_embeddings():
      return HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")

+ # Add this at the beginning of your script, after imports
+ DOCUMENTS_FILE = "uploaded_documents.json"
+ 
+ def load_documents():
+     if os.path.exists(DOCUMENTS_FILE):
+         with open(DOCUMENTS_FILE, "r") as f:
+             return json.load(f)
+     return []
+ 
+ def save_documents(documents):
+     with open(DOCUMENTS_FILE, "w") as f:
+         json.dump(documents, f)
+ 
+ # Replace the global uploaded_documents with this
+ uploaded_documents = load_documents()
+ 
+ # Modify the update_vectors function
+ def update_vectors(files, parser):
+     global uploaded_documents
+     logging.info(f"Entering update_vectors with {len(files)} files and parser: {parser}")
+ 
+     if not files:
+         logging.warning("No files provided for update_vectors")
+         return "Please upload at least one PDF file.", display_documents()
+ 
      embed = get_embeddings()
+     total_chunks = 0
+ 
+     all_data = []
+     for file in files:
+         logging.info(f"Processing file: {file.name}")
+         try:
+             data = load_document(file, parser)
+             if not data:
+                 logging.warning(f"No chunks loaded from {file.name}")
+                 continue
+             logging.info(f"Loaded {len(data)} chunks from {file.name}")
+             all_data.extend(data)
+             total_chunks += len(data)
+             if not any(doc["name"] == file.name for doc in uploaded_documents):
+                 uploaded_documents.append({"name": file.name, "selected": True})
+                 logging.info(f"Added new document to uploaded_documents: {file.name}")
+             else:
+                 logging.info(f"Document already exists in uploaded_documents: {file.name}")
+         except Exception as e:
+             logging.error(f"Error processing file {file.name}: {str(e)}")
+ 
+     logging.info(f"Total chunks processed: {total_chunks}")
+ 
+     if not all_data:
+         logging.warning("No valid data extracted from uploaded files")
+         return "No valid data could be extracted from the uploaded files. Please check the file contents and try again.", display_documents()
+ 
+     try:
+         if os.path.exists("faiss_database"):
+             logging.info("Updating existing FAISS database")
+             database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+             database.add_documents(all_data)
+         else:
+             logging.info("Creating new FAISS database")
+             database = FAISS.from_documents(all_data, embed)
+ 
+         database.save_local("faiss_database")
+         logging.info("FAISS database saved")
+     except Exception as e:
+         logging.error(f"Error updating FAISS database: {str(e)}")
+         return f"Error updating vector store: {str(e)}", display_documents()
+ 
+     # Save the updated list of documents
+     save_documents(uploaded_documents)
+ 
+     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}.", display_documents()
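update_vectors follows the standard LangChain FAISS persistence round trip: build or load, add, save. A self-contained sketch of just that round trip, with toy documents and the same calls as above:

    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from langchain_core.documents import Document

    embed = HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")
    db = FAISS.from_documents([Document(page_content="hello world", metadata={"source": "a.pdf"})], embed)
    db.save_local("faiss_database")
    # On the next upload: reload, append, re-save.
    db = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    db.add_documents([Document(page_content="more text", metadata={"source": "b.pdf"})])
    db.save_local("faiss_database")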
 
+ def delete_documents(selected_docs):
+     global uploaded_documents

+     if not selected_docs:
+         return "No documents selected for deletion.", display_documents()
+ 
+     embed = get_embeddings()
+     database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+ 
+     deleted_docs = []
+     docs_to_keep = []
+     for doc in database.docstore._dict.values():
+         if doc.metadata.get("source") not in selected_docs:
+             docs_to_keep.append(doc)
+         else:
+             deleted_docs.append(doc.metadata.get("source", "Unknown"))
+ 
+     # Print debugging information
+     logging.info(f"Total documents before deletion: {len(database.docstore._dict)}")
+     logging.info(f"Documents to keep: {len(docs_to_keep)}")
+     logging.info(f"Documents to delete: {len(deleted_docs)}")
+ 
+     if not docs_to_keep:
+         # If all documents are deleted, remove the FAISS database directory
+         if os.path.exists("faiss_database"):
+             shutil.rmtree("faiss_database")
+             logging.info("All documents deleted. Removed FAISS database directory.")
+     else:
+         # Create new FAISS index with remaining documents
+         new_database = FAISS.from_documents(docs_to_keep, embed)
+         new_database.save_local("faiss_database")
+         logging.info(f"Created new FAISS index with {len(docs_to_keep)} documents.")
+ 
+     # Update uploaded_documents list
+     uploaded_documents = [doc for doc in uploaded_documents if doc["name"] not in deleted_docs]
+     save_documents(uploaded_documents)
+ 
+     return f"Deleted documents: {', '.join(deleted_docs)}", display_documents()
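delete_documents removes entries by rebuilding the index from the surviving docs, since it filters on the private docstore._dict. If the installed langchain_community version exposes FAISS.delete(ids) — an assumption worth verifying — an in-place variant could look like:

    # Hypothetical in-place deletion; requires a FAISS wrapper version with delete(ids).
    ids_to_drop = [doc_id for doc_id, doc in database.docstore._dict.items()
                   if doc.metadata.get("source") in selected_docs]
    if ids_to_drop:
        database.delete(ids_to_drop)
        database.save_local("faiss_database")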
+ def generate_chunked_response(prompt, model, max_tokens=10000, num_calls=3, temperature=0.2, should_stop=False):
+     print(f"Starting generate_chunked_response with {num_calls} calls")
+     full_response = ""
+     messages = [{"role": "user", "content": prompt}]
+ 
+     if model == "@cf/meta/llama-3.1-8b-instruct":
+         # Cloudflare API
+         for i in range(num_calls):
+             print(f"Starting Cloudflare API call {i+1}")
+             if should_stop:
+                 print("Stop clicked, breaking loop")
+                 break
+             try:
+                 response = requests.post(
+                     f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/@cf/meta/llama-3.1-8b-instruct",
+                     headers={"Authorization": f"Bearer {API_TOKEN}"},
+                     json={
+                         "stream": True,
+                         "messages": [
+                             {"role": "system", "content": "You are a friendly assistant"},
+                             {"role": "user", "content": prompt}
+                         ],
+                         "max_tokens": max_tokens,
+                         "temperature": temperature
+                     },
+                     stream=True
+                 )
+ 
+                 for line in response.iter_lines():
+                     if should_stop:
+                         print("Stop clicked during streaming, breaking")
+                         break
+                     if line:
+                         try:
+                             json_data = json.loads(line.decode('utf-8').split('data: ')[1])
+                             chunk = json_data['response']
+                             full_response += chunk
+                         except json.JSONDecodeError:
+                             continue
+                 print(f"Cloudflare API call {i+1} completed")
+             except Exception as e:
+                 print(f"Error in generating response from Cloudflare: {str(e)}")
+     else:
+         # Original Hugging Face API logic
+         client = InferenceClient(model, token=huggingface_token)
+ 
+         for i in range(num_calls):
+             print(f"Starting Hugging Face API call {i+1}")
+             if should_stop:
+                 print("Stop clicked, breaking loop")
+                 break
+             try:
+                 for message in client.chat_completion(
+                     messages=messages,
+                     max_tokens=max_tokens,
+                     temperature=temperature,
+                     stream=True,
+                 ):
+                     if should_stop:
+                         print("Stop clicked during streaming, breaking")
+                         break
+                     if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                         chunk = message.choices[0].delta.content
+                         full_response += chunk
+                 print(f"Hugging Face API call {i+1} completed")
+             except Exception as e:
+                 print(f"Error in generating response from Hugging Face: {str(e)}")
+ 
+     # Clean up the response
+     clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
+     clean_response = clean_response.replace("Using the following context:", "").strip()
+     clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()
+ 
+     # Remove duplicate paragraphs and sentences
+     paragraphs = clean_response.split('\n\n')
+     unique_paragraphs = []
+     for paragraph in paragraphs:
+         if paragraph not in unique_paragraphs:
+             sentences = paragraph.split('. ')
+             unique_sentences = []
+             for sentence in sentences:
+                 if sentence not in unique_sentences:
+                     unique_sentences.append(sentence)
+             unique_paragraphs.append('. '.join(unique_sentences))
+ 
+     final_response = '\n\n'.join(unique_paragraphs)
+ 
+     print(f"Final clean response: {final_response[:100]}...")
+     return final_response
+ 
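The clean-up pass above dedups by exact string match, splitting paragraphs on blank lines and sentences on '. '. A toy run of the same logic (illustrative strings only):

    text = "Revenue grew. Margins held.\n\nRevenue grew. Margins held.\n\nCosts fell. Costs fell"
    unique_paragraphs = []
    for paragraph in text.split('\n\n'):
        if paragraph not in unique_paragraphs:
            unique_sentences = []
            for sentence in paragraph.split('. '):
                if sentence not in unique_sentences:
                    unique_sentences.append(sentence)
            unique_paragraphs.append('. '.join(unique_sentences))
    print('\n\n'.join(unique_paragraphs))
    # Prints the first paragraph once and collapses the repeated sentence in the last one.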
+ def duckduckgo_search(query):
+     with DDGS() as ddgs:
+         results = ddgs.text(query, max_results=5)
+     return results
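duckduckgo_search returns plain dicts from DDGS.text(); each result carries title, href and body keys, which is what create_web_search_vectors below relies on. A quick standalone check:

    for r in duckduckgo_search("ECB rate decision"):
        print(r["title"], "->", r["href"])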
 
+ class CitingSources(BaseModel):
+     sources: List[str] = Field(
+         ...,
+         description="List of sources to cite. Should be a URL of the source."
+     )
+ 
+ def chatbot_interface(message, history, use_web_search, model, temperature, num_calls):
+     if not message.strip():
+         return "", history

+     history = history + [(message, "")]
+ 
+     try:
+         for response in respond(message, history, model, temperature, num_calls, use_web_search):
+             history[-1] = (message, response)
+             yield history
+     except gr.CancelledError:
+         yield history
+     except Exception as e:
+         logging.error(f"Unexpected error in chatbot_interface: {str(e)}")
+         history[-1] = (message, f"An unexpected error occurred: {str(e)}")
+         yield history
+ 
+ def retry_last_response(history, use_web_search, model, temperature, num_calls):
+     if not history:
+         return history
+ 
+     last_user_msg = history[-1][0]
+     history = history[:-1]  # Remove the last response
+ 
+     return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
+ 
+ def respond(message, history, model, temperature, num_calls, use_web_search, selected_docs, instruction_key):
+     logging.info(f"User Query: {message}")
+     logging.info(f"Model Used: {model}")
+     logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
+     logging.info(f"Selected Documents: {selected_docs}")
+     logging.info(f"Instruction Key: {instruction_key}")
+ 
+     try:
+         if instruction_key and instruction_key != "None":
+             # This is a summary generation request
+             instruction = INSTRUCTION_PROMPTS[instruction_key]
+             context_str = get_context_for_summary(selected_docs)
+             message = f"{instruction}\n\nUsing the following context from the PDF documents:\n{context_str}\nGenerate a detailed summary."
+             use_web_search = False  # Ensure we use PDF search for summaries
+ 
+         if use_web_search:
+             for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+                 response = f"{main_content}\n\n{sources}"
+                 first_line = response.split('\n')[0] if response else ''
+                 # logging.info(f"Generated Response (first line): {first_line}")
+                 yield response
+         else:
+             embed = get_embeddings()
+             if os.path.exists("faiss_database"):
+                 database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+                 retriever = database.as_retriever()
+ 
+                 # Filter relevant documents based on user selection
+                 all_relevant_docs = retriever.get_relevant_documents(message)
+                 relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
+ 
+                 if not relevant_docs:
+                     yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
+                     return
+ 
+                 context_str = "\n".join([doc.page_content for doc in relevant_docs])
+             else:
+                 context_str = "No documents available."
+                 yield "No documents available. Please upload PDF documents to answer questions."
+                 return
+ 
+             if model == "@cf/meta/llama-3.1-8b-instruct":
+                 # Use Cloudflare API
+                 for partial_response in get_response_from_cloudflare(prompt="", context=context_str, query=message, num_calls=num_calls, temperature=temperature, search_type="pdf"):
+                     first_line = partial_response.split('\n')[0] if partial_response else ''
+                     # logging.info(f"Generated Response (first line): {first_line}")
+                     yield partial_response
+             else:
+                 # Use Hugging Face API
+                 for partial_response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
+                     first_line = partial_response.split('\n')[0] if partial_response else ''
+                     # logging.info(f"Generated Response (first line): {first_line}")
+                     yield partial_response
+ 
+     except Exception as e:
+         logging.error(f"Error with {model}: {str(e)}")
+         if "microsoft/Phi-3-mini-4k-instruct" in model:
+             logging.info("Falling back to Mistral model due to Phi-3 error")
+             fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
+             yield from respond(message, history, fallback_model, temperature, num_calls, use_web_search, selected_docs, instruction_key)
+         else:
+             yield f"An error occurred with the {model} model: {str(e)}. Please try again or select a different model."
+ 
+ logging.basicConfig(level=logging.DEBUG)
+ 
+ def get_context_for_summary(selected_docs):
+     embed = get_embeddings()
+     if os.path.exists("faiss_database"):
+         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+         retriever = database.as_retriever(search_kwargs={"k": 5})  # Retrieve top 5 most relevant chunks
+ 
+         # Create a generic query that covers common financial summary topics
+         generic_query = "financial performance revenue profit assets liabilities cash flow key metrics highlights"
+ 
+         relevant_docs = retriever.get_relevant_documents(generic_query)
+         filtered_docs = [doc for doc in relevant_docs if doc.metadata["source"] in selected_docs]
+ 
+         if not filtered_docs:
+             return "No relevant information found in the selected documents for summary generation."
+ 
+         context_str = "\n".join([doc.page_content for doc in filtered_docs])
+         return context_str
+     else:
+         return "No documents available for summary generation."
+ 
+ def get_context_for_query(query, selected_docs):
+     embed = get_embeddings()
+     if os.path.exists("faiss_database"):
+         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+         retriever = database.as_retriever(search_kwargs={"k": 3})  # Retrieve top 3 most relevant chunks
+ 
          relevant_docs = retriever.get_relevant_documents(query)
+         filtered_docs = [doc for doc in relevant_docs if doc.metadata["source"] in selected_docs]
+ 
+         if not filtered_docs:
+             return "No relevant information found in the selected documents for the given query."
+ 
+         context_str = "\n".join([doc.page_content for doc in filtered_docs])
+         return context_str
      else:
+         return "No documents available to answer the query."

+ def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperature=0.2, search_type="pdf"):
+     headers = {
+         "Authorization": f"Bearer {API_TOKEN}",
+         "Content-Type": "application/json"
+     }
+     model = "@cf/meta/llama-3.1-8b-instruct"

+     if search_type == "pdf":
+         instruction = f"""Using the following context from the PDF documents:
  {context}
+ Write a detailed and complete response that answers the following user question: '{query}'"""
+     else:  # web search
+         instruction = f"""Using the following context:
+ {context}
+ Write a detailed and complete research document that fulfills the following user request: '{query}'
+ After writing the document, please provide a list of sources used in your response."""
+ 
+     inputs = [
+         {"role": "system", "content": instruction},
+         {"role": "user", "content": query}
+     ]

+     payload = {
+         "messages": inputs,
+         "stream": True,
+         "temperature": temperature,
+         "max_tokens": 32000
+     }

+     full_response = ""
+     for i in range(num_calls):
          try:
+             with requests.post(f"{API_BASE_URL}{model}", headers=headers, json=payload, stream=True) as response:
+                 if response.status_code == 200:
+                     for line in response.iter_lines():
+                         if line:
+                             try:
+                                 json_response = json.loads(line.decode('utf-8').split('data: ')[1])
+                                 if 'response' in json_response:
+                                     chunk = json_response['response']
+                                     full_response += chunk
+                                     yield full_response
+                             except (json.JSONDecodeError, IndexError) as e:
+                                 logging.error(f"Error parsing streaming response: {str(e)}")
+                                 continue
+                 else:
+                     logging.error(f"HTTP Error: {response.status_code}, Response: {response.text}")
+                     yield f"I apologize, but I encountered an HTTP error: {response.status_code}. Please try again later."
          except Exception as e:
+             logging.error(f"Error in generating response from Cloudflare: {str(e)}")
+             yield f"I apologize, but an error occurred: {str(e)}. Please try again later."
+ 
      if not full_response:
+         yield "I apologize, but I couldn't generate a response at this time. Please try again later."
+ 
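Both Cloudflare branches in this file parse the stream by splitting each raw line on 'data: '. A defensive standalone version of that step (assuming Workers AI server-sent events of the form data: {"response": "..."} with a terminal data: [DONE]):

    import json

    def parse_sse_line(raw: bytes) -> str:
        # Extract the text delta from one streamed line; return '' if there is none.
        parts = raw.decode("utf-8").split("data: ", 1)
        if len(parts) < 2 or parts[1].strip() == "[DONE]":
            return ""
        try:
            return json.loads(parts[1]).get("response", "")
        except json.JSONDecodeError:
            return ""

    assert parse_sse_line(b'data: {"response": "Hi"}') == "Hi"
    assert parse_sse_line(b'data: [DONE]') == ""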
+ def create_web_search_vectors(search_results):
+     embed = get_embeddings()
+ 
+     documents = []
+     for result in search_results:
+         if 'body' in result:
+             content = f"{result['title']}\n{result['body']}\nSource: {result['href']}"
+             documents.append(Document(page_content=content, metadata={"source": result['href']}))
+ 
+     return FAISS.from_documents(documents, embed)
+ 
+ def get_response_with_search(query, model, num_calls=3, temperature=0.2):
+     search_results = duckduckgo_search(query)
+     web_search_database = create_web_search_vectors(search_results)
+ 
+     if not web_search_database:
+         yield "No web search results available. Please try again.", ""
+         return
+ 
+     retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
+     relevant_docs = retriever.get_relevant_documents(query)
+ 
+     context = "\n".join([doc.page_content for doc in relevant_docs])
+ 
+     prompt = f"""Using the following context from web search results:
+ {context}
+ Write a detailed and complete research document that fulfills the following user request: '{query}'
+ After writing the document, please provide a list of sources used in your response."""
+ 
+     if model == "@cf/meta/llama-3.1-8b-instruct":
+         # Use Cloudflare API
+         for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
+             yield response, ""  # Yield streaming response without sources
+     else:
+         # Use Hugging Face API
+         client = InferenceClient(model, token=huggingface_token)
+ 
+         main_content = ""
+         for i in range(num_calls):
+             for message in client.chat_completion(
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=10000,
+                 temperature=temperature,
+                 stream=True,
+             ):
+                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                     chunk = message.choices[0].delta.content
+                     main_content += chunk
+                     yield main_content, ""  # Yield partial main content without sources
+ 
+ INSTRUCTION_PROMPTS = {
+     "Asset Managers": "Summarize the key financial metrics, assets under management, and performance highlights for this asset management company.",
+     "Consumer Finance Companies": "Provide a summary of the company's loan portfolio, interest income, credit quality, and key operational metrics.",
+     "Mortgage REITs": "Summarize the REIT's mortgage-backed securities portfolio, net interest income, book value per share, and dividend yield.",
+     # Add more instruction prompts as needed
+ }
+ 
+ def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
+     logging.info(f"Entering get_response_from_pdf with query: {query}, model: {model}, selected_docs: {selected_docs}")
+ 
+     embed = get_embeddings()
+     if os.path.exists("faiss_database"):
+         logging.info("Loading FAISS database")
+         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+     else:
+         logging.warning("No FAISS database found")
+         yield "No documents available. Please upload PDF documents to answer questions."
+         return
+ 
+     # Pre-filter the documents
+     filtered_docs = []
+     for doc_id, doc in database.docstore._dict.items():
+         if isinstance(doc, Document) and doc.metadata.get("source") in selected_docs:
+             filtered_docs.append(doc)
+ 
+     logging.info(f"Number of documents after pre-filtering: {len(filtered_docs)}")
+ 
+     if not filtered_docs:
+         logging.warning(f"No documents found for the selected sources: {selected_docs}")
+         yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
+         return
+ 
+     # Create a new FAISS index with only the selected documents
+     filtered_db = FAISS.from_documents(filtered_docs, embed)
+ 
+     retriever = filtered_db.as_retriever(search_kwargs={"k": 10})
+     logging.info(f"Retrieving relevant documents for query: {query}")
+     relevant_docs = retriever.get_relevant_documents(query)
+     logging.info(f"Number of relevant documents retrieved: {len(relevant_docs)}")
+ 
+     for doc in relevant_docs:
+         logging.info(f"Document source: {doc.metadata['source']}")
+         logging.info(f"Document content preview: {doc.page_content[:100]}...")  # Log first 100 characters of each document
+ 
+     context_str = "\n".join([doc.page_content for doc in relevant_docs])
+     logging.info(f"Total context length: {len(context_str)}")
+ 
+     if model == "@cf/meta/llama-3.1-8b-instruct":
+         logging.info("Using Cloudflare API")
+         # Use Cloudflare API with the retrieved context
+         for response in get_response_from_cloudflare(prompt="", context=context_str, query=query, num_calls=num_calls, temperature=temperature, search_type="pdf"):
+             yield response
+     else:
+         logging.info("Using Hugging Face API")
+         # Use Hugging Face API
+         prompt = f"""Using the following context from the PDF documents:
+ {context_str}
+ Write a detailed and complete response that answers the following user question: '{query}'"""
+ 
+         client = InferenceClient(model, token=huggingface_token)
+ 
+         response = ""
+         for i in range(num_calls):
+             logging.info(f"API call {i+1}/{num_calls}")
+             for message in client.chat_completion(
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=10000,
+                 temperature=temperature,
+                 stream=True,
+             ):
+                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                     chunk = message.choices[0].delta.content
+                     response += chunk
+                     yield response  # Yield partial response
+ 
+     logging.info("Finished generating response")
+ 
+ def vote(data: gr.LikeData):
+     if data.liked:
+         print(f"You upvoted this response: {data.value}")
+     else:
+         print(f"You downvoted this response: {data.value}")
+ 
  css = """
  /* Fine-tune chatbox size */
  }
  """
+ uploaded_documents = []
+ 
+ def display_documents():
+     return gr.CheckboxGroup(
+         choices=[doc["name"] for doc in uploaded_documents],
+         value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
+         label="Select documents to query or delete"
      )

+ def initial_conversation():
+     return [
+         (None, "Welcome! I'm your AI assistant for web search and PDF analysis. Here's how you can use me:\n\n"
+                "1. Set the toggle for Web Search and PDF Search from the checkbox in Additional Inputs drop down window\n"
+                "2. Use web search to find information\n"
+                "3. Upload the documents and ask questions about uploaded PDF documents by selecting your respective document\n"
+                "4. For any queries feel free to reach out @[email protected] or discord - shreyas094\n\n"
+                "To get started, upload some PDFs or ask me a question!")
+     ]
+ 
+ # Add this new function
+ def refresh_documents():
+     global uploaded_documents
+     uploaded_documents = load_documents()
+     return display_documents()
+ 
+ # Define the checkbox outside the demo block
+ document_selector = gr.CheckboxGroup(label="Select documents to query")
+ 
+ use_web_search = gr.Checkbox(label="Use Web Search", value=True)
+ 
+ custom_placeholder = "Ask a question (Note: You can toggle between Web Search and PDF Chat in Additional Inputs below)"
+ 
+ instruction_choices = ["None"] + list(INSTRUCTION_PROMPTS.keys())
+ 
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[3]),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
+         gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
+         use_web_search,
+         document_selector,
+         gr.Dropdown(choices=instruction_choices, label="Select Entity Type for Summary", value="None")
+     ],
+     title="AI-powered Web Search and PDF Chat Assistant",
+     description="Chat with your PDFs, use web search to answer questions, or generate summaries. Select an Entity Type for Summary to generate a specific summary.",
+     theme=gr.themes.Soft(
+         primary_hue="orange",
+         secondary_hue="amber",
+         neutral_hue="gray",
+         font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]
+     ).set(
+         body_background_fill_dark="#0c0505",
+         block_background_fill_dark="#0c0505",
+         block_border_width="1px",
+         block_title_background_fill_dark="#1b0f0f",
+         input_background_fill_dark="#140b0b",
+         button_secondary_background_fill_dark="#140b0b",
+         border_color_accent_dark="#1b0f0f",
+         border_color_primary_dark="#1b0f0f",
+         background_fill_secondary_dark="#0c0505",
+         color_accent_soft_dark="transparent",
+         code_background_fill_dark="#140b0b"
+     ),
+     css=css,
+     examples=[
+         ["Tell me about the contents of the uploaded PDFs."],
+         ["What are the main topics discussed in the documents?"],
+         ["Can you summarize the key points from the PDFs?"]
+     ],
+     cache_examples=False,
+     analytics_enabled=False,
+     textbox=gr.Textbox(placeholder=custom_placeholder, container=False, scale=7),
+     chatbot=gr.Chatbot(
+         show_copy_button=True,
+         likeable=True,
+         layout="bubble",
+         height=400,
+         value=initial_conversation()
+     )
+ )
+ 
+ # Add file upload functionality
+ with demo:
+     gr.Markdown("## Upload and Manage PDF Documents")
+ 
+     with gr.Row():
+         file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
+         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="llamaparse")
+         update_button = gr.Button("Upload Document")
+         refresh_button = gr.Button("Refresh Document List")
+ 
+     update_output = gr.Textbox(label="Update Status")
+     delete_button = gr.Button("Delete Selected Documents")
+ 
+     # Update both the output text and the document selector
+     update_button.click(update_vectors,
+                         inputs=[file_input, parser_dropdown],
+                         outputs=[update_output, document_selector])
+ 
+     # Add the refresh button functionality
+     refresh_button.click(refresh_documents,
+                          inputs=[],
+                          outputs=[document_selector])
+ 
+     # Add the delete button functionality
+     delete_button.click(delete_documents,
+                         inputs=[document_selector],
+                         outputs=[update_output, document_selector])
+ 
+     gr.Markdown(
+         """
+         ## How to use
+         1. Upload PDF documents using the file input at the top.
+         2. Select the PDF parser (pypdf or llamaparse) and click "Upload Document" to update the vector store.
+         3. Select the documents you want to query using the checkboxes.
+         4. Ask questions in the chat interface.
+         5. Toggle "Use Web Search" to switch between PDF chat and web search.
+         6. Adjust Temperature and Number of API Calls to fine-tune the response generation.
+         7. Use the provided examples or ask your own questions.
+         """
+     )
  if __name__ == "__main__":
      demo.launch(share=True)
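share=True asks Gradio for a temporary public tunnel link; on a hosted Space the app is already reachable, so a local-only launch is a common alternative (standard launch() parameters):

    if __name__ == "__main__":
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)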