usmanyousaf committed
Commit f12016a · verified · 1 Parent(s): c879e4e

Update app.py

Files changed (1)
  1. app.py +157 -213
app.py CHANGED
@@ -12,138 +12,13 @@ from langchain_community.vectorstores import Chroma
  from langchain.chains import RetrievalQA
  import re

- from app import check_custom_db_exists
-
- # Custom CSS Injection
- def inject_custom_css():
-     st.markdown("""
-     <style>
-     /* Main container */
-     .stApp {
-         background: linear-gradient(135deg, #1a1a1a, #2d2d2d);
-         color: #e0e0e0;
-     }
-
-     /* Chat containers */
-     .stChatMessage {
-         padding: 1.5rem;
-         border-radius: 15px;
-         margin: 1rem 0;
-         box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
-     }
-
-     /* User message styling */
-     [data-testid="stChatMessage"][aria-label="user"] {
-         background-color: #2d2d2d;
-         border: 1px solid #3d3d3d;
-         margin-left: 10%;
-     }
-
-     /* Assistant message styling */
-     [data-testid="stChatMessage"][aria-label="assistant"] {
-         background-color: #004d40;
-         border: 1px solid #00695c;
-         margin-right: 10%;
-     }
-
-     /* Sidebar styling */
-     [data-testid="stSidebar"] {
-         background: #121212 !important;
-         border-right: 2px solid #2d2d2d;
-         padding: 1rem;
-     }
-
-     /* Button styling */
-     .stButton>button {
-         background: linear-gradient(45deg, #00695c, #004d40);
-         color: white !important;
-         border: none;
-         border-radius: 8px;
-         padding: 0.8rem 1.5rem;
-         transition: all 0.3s;
-         font-weight: 500;
-     }
-
-     .stButton>button:hover {
-         transform: translateY(-2px);
-         box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);
-     }
-
-     /* File uploader */
-     [data-testid="stFileUploader"] {
-         border: 2px dashed #3d3d3d;
-         border-radius: 10px;
-         padding: 1rem;
-         background: #2d2d2d;
-     }
-
-     /* Input field */
-     .stTextInput>div>div>input {
-         background-color: #2d2d2d;
-         color: white;
-         border: 1px solid #3d3d3d;
-         border-radius: 8px;
-         padding: 0.8rem;
-     }
-
-     /* Spinner color */
-     .stSpinner>div>div {
-         border-color: #00bcd4 transparent transparent transparent;
-     }
-
-     /* Custom title styling */
-     .title-text {
-         background: linear-gradient(45deg, #00bcd4, #00695c);
-         -webkit-background-clip: text;
-         -webkit-text-fill-color: transparent;
-         font-family: 'Roboto', sans-serif;
-         font-size: 2.8rem;
-         text-align: center;
-         margin-bottom: 2rem;
-         letter-spacing: -0.5px;
-         text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
-     }
-
-     /* Similar questions buttons */
-     .stButton>button.similar-q {
-         background: #2d2d2d;
-         border: 1px solid #00bcd4;
-         color: #00bcd4 !important;
-         white-space: normal;
-         height: auto;
-         min-height: 3rem;
-         transition: all 0.3s;
-     }
-
-     /* Hover effects */
-     .stButton>button.similar-q:hover {
-         background: #004d40 !important;
-         transform: scale(1.02);
-     }
-
-     /* Source text styling */
-     .source-text {
-         color: #00bcd4;
-         font-size: 0.9rem;
-         margin-top: 1rem;
-         padding-top: 0.5rem;
-         border-top: 1px solid #3d3d3d;
-     }
-     </style>
-     """, unsafe_allow_html=True)
-
  # Page Configuration
- st.set_page_config(
-     page_title="AI Law Agent",
-     page_icon="⚖️",
-     layout="centered",
-     initial_sidebar_state="expanded"
- )
+ st.set_page_config(page_title="Pakistan Law AI Agent", page_icon="⚖️")

  # Constants
  DEFAULT_GROQ_API_KEY = os.getenv("GROQ_API_KEY")
  MODEL_NAME = "llama-3.3-70b-versatile"
- DEFAULT_DOCUMENT_PATH = "/Users/appleenterprises/Desktop/ai law bot/lawbook.pdf"
+ DEFAULT_DOCUMENT_PATH = "/Users/appleenterprises/Desktop/ai law bot/lawbook.pdf"  # Path to your hardcoded Pakistan laws PDF
  DEFAULT_COLLECTION_NAME = "pakistan_laws_default"
  CHROMA_PERSIST_DIR = "./chroma_db"

@@ -166,9 +41,11 @@ if "custom_collection_name" not in st.session_state:
      st.session_state.custom_collection_name = f"custom_laws_{st.session_state.user_id}"

  def setup_embeddings():
+     """Sets up embeddings model"""
      return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

  def setup_llm():
+     """Setup the language model"""
      if st.session_state.llm is None:
          st.session_state.llm = ChatGroq(
              model_name=MODEL_NAME,
@@ -178,37 +55,50 @@ def setup_llm():
      return st.session_state.llm

  def check_default_db_exists():
-     return os.path.exists(os.path.join(CHROMA_PERSIST_DIR, DEFAULT_COLLECTION_NAME))
+     """Check if the default document database already exists"""
+     if os.path.exists(os.path.join(CHROMA_PERSIST_DIR, DEFAULT_COLLECTION_NAME)):
+         return True
+     return False

  def load_existing_vectordb(collection_name):
+     """Load an existing vector database from disk"""
+     embeddings = setup_embeddings()
      try:
-         return Chroma(
+         db = Chroma(
              persist_directory=CHROMA_PERSIST_DIR,
-             embedding_function=setup_embeddings(),
+             embedding_function=embeddings,
              collection_name=collection_name
          )
+         return db
      except Exception as e:
-         st.error(f"Error loading database: {str(e)}")
+         st.error(f"Error loading existing database: {str(e)}")
          return None

  def process_default_document(force_rebuild=False):
+     """Process the default Pakistan laws document or load from disk if available"""
+     # Check if database already exists
      if check_default_db_exists() and not force_rebuild:
+         st.info("Loading existing Pakistan law database...")
          db = load_existing_vectordb(DEFAULT_COLLECTION_NAME)
-         if db:
+         if db is not None:
              st.session_state.vectordb = db
              setup_qa_chain()
              st.session_state.using_custom_docs = False
              return True

+     # If database doesn't exist or force rebuild, create it
      if not os.path.exists(DEFAULT_DOCUMENT_PATH):
-         st.error("Default document not found.")
+         st.error(f"Default document {DEFAULT_DOCUMENT_PATH} not found. Please make sure it exists.")
          return False

+     embeddings = setup_embeddings()
+
      try:
-         with st.spinner("Building knowledge base..."):
+         with st.spinner("Building Pakistan law database (this may take a few minutes)..."):
              loader = PyPDFLoader(DEFAULT_DOCUMENT_PATH)
              documents = loader.load()

+             # Add source filename to metadata
              for doc in documents:
                  doc.metadata["source"] = "Pakistan Laws (Official)"

@@ -218,40 +108,61 @@ def process_default_document(force_rebuild=False):
              )
              chunks = text_splitter.split_documents(documents)

+             # Create vector store
              db = Chroma.from_documents(
                  documents=chunks,
-                 embedding=setup_embeddings(),
+                 embedding=embeddings,
                  collection_name=DEFAULT_COLLECTION_NAME,
                  persist_directory=CHROMA_PERSIST_DIR
              )

+             # Explicitly persist to disk
              db.persist()
+
              st.session_state.vectordb = db
              setup_qa_chain()
              st.session_state.using_custom_docs = False
+
              return True
      except Exception as e:
-         st.error(f"Error processing document: {str(e)}")
+         st.error(f"Error processing default document: {str(e)}")
          return False

+ def check_custom_db_exists(collection_name):
+     """Check if a custom document database already exists"""
+     if os.path.exists(os.path.join(CHROMA_PERSIST_DIR, collection_name)):
+         return True
+     return False
+
  def process_custom_documents(uploaded_files):
+     """Process user-uploaded PDF documents"""
+     embeddings = setup_embeddings()
      collection_name = st.session_state.custom_collection_name
+
      documents = []

      for uploaded_file in uploaded_files:
+         # Save file temporarily
          with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
              tmp_file.write(uploaded_file.getvalue())
              tmp_path = tmp_file.name

+         # Load and split the document
          try:
              loader = PyPDFLoader(tmp_path)
              file_docs = loader.load()
+
+             # Add source filename to metadata
              for doc in file_docs:
                  doc.metadata["source"] = uploaded_file.name
+
              documents.extend(file_docs)
+
+             # Clean up temp file
              os.unlink(tmp_path)
          except Exception as e:
              st.error(f"Error processing {uploaded_file.name}: {str(e)}")
+             continue

      if documents:
          text_splitter = RecursiveCharacterTextSplitter(
@@ -260,33 +171,45 @@ def process_custom_documents(uploaded_files):
          )
          chunks = text_splitter.split_documents(documents)

-         with st.spinner("Analyzing documents..."):
+         # Create vector store
+         with st.spinner("Building custom document database..."):
+             # If a previous custom DB exists for this user, delete it first
              if check_custom_db_exists(collection_name):
+                 # We need to recreate the vectorstore to delete the old collection
                  temp_db = Chroma(
                      persist_directory=CHROMA_PERSIST_DIR,
-                     embedding_function=setup_embeddings(),
+                     embedding_function=embeddings,
                      collection_name=collection_name
                  )
                  temp_db.delete_collection()

+             # Create new vector store
              db = Chroma.from_documents(
                  documents=chunks,
-                 embedding=setup_embeddings(),
+                 embedding=embeddings,
                  collection_name=collection_name,
                  persist_directory=CHROMA_PERSIST_DIR
              )

+             # Explicitly persist to disk
              db.persist()
+
              st.session_state.vectordb = db
              setup_qa_chain()
              st.session_state.using_custom_docs = True
+
              return True
      return False

  def setup_qa_chain():
+     """Set up the QA chain with the RAG system"""
      if st.session_state.vectordb:
-         template = """You are a legal expert specializing in Pakistani law.
-         Use context to answer. If unsure, state uncertainty but provide general legal info.
+         llm = setup_llm()
+
+         # Create prompt template
+         template = """You are a helpful legal assistant specializing in Pakistani law.
+         Use the following context to answer the question. If you don't know the answer based on the context,
+         say that you don't have enough information, but provide general legal information if possible.

          Context: {context}

@@ -296,8 +219,9 @@ def setup_qa_chain():

          prompt = ChatPromptTemplate.from_template(template)

+         # Create the QA chain
          st.session_state.qa_chain = RetrievalQA.from_chain_type(
-             llm=setup_llm(),
+             llm=llm,
              chain_type="stuff",
              retriever=st.session_state.vectordb.as_retriever(search_kwargs={"k": 3}),
              chain_type_kwargs={"prompt": prompt},
@@ -305,135 +229,155 @@ def setup_qa_chain():
          )

  def generate_similar_questions(question, docs):
+     """Generate similar questions based on retrieved documents"""
      llm = setup_llm()
+
+     # Extract key content from docs
      context = "\n".join([doc.page_content for doc in docs[:2]])

-     prompt = f"""Generate 3 specific Pakistani law questions related to:
+     # Prompt to generate similar questions
+     prompt = f"""Based on the following user question and legal context, generate 3 similar questions that the user might also be interested in.
+     Make the questions specific, related to Pakistani law, and directly relevant to the original question.

-     Original: {question}
+     Original Question: {question}

-     Context: {context}
+     Legal Context: {context}

-     Generate exactly 3 questions:"""
+     Generate exactly 3 similar questions:"""

      try:
          response = llm.invoke(prompt)
+         # Extract questions from response using regex
          questions = re.findall(r"\d+\.\s+(.*?)(?=\d+\.|$)", response.content, re.DOTALL)
          if not questions:
              questions = response.content.split("\n")
-             questions = [q.strip() for q in questions if q.strip() and "?" in q]
-         return [q.strip().replace("\n", " ") for q in questions if "?" in q][:3]
-     except:
+             questions = [q.strip() for q in questions if q.strip() and not q.startswith("Similar") and "?" in q]
+
+         # Clean and limit to 3 questions
+         questions = [q.strip().replace("\n", " ") for q in questions if "?" in q]
+         return questions[:3]
+     except Exception as e:
+         print(f"Error generating similar questions: {e}")
          return []

  def get_answer(question):
+     """Get answer from QA chain"""
+     # If default documents haven't been processed yet, try to load them
      if not st.session_state.vectordb:
-         with st.spinner("Initializing system..."):
+         with st.spinner("Loading Pakistan law database..."):
              process_default_document()

      if st.session_state.qa_chain:
          result = st.session_state.qa_chain({"query": question})
          answer = result["result"]

-         st.session_state.similar_questions = generate_similar_questions(question, result.get("source_documents", []))
+         # Generate similar questions
+         source_docs = result.get("source_documents", [])
+         st.session_state.similar_questions = generate_similar_questions(question, source_docs)

+         # Add source information
          sources = set()
-         for doc in result.get("source_documents", []):
+         for doc in source_docs:
              if "source" in doc.metadata:
                  sources.add(doc.metadata["source"])

          if sources:
-             answer += f"\n\n<div class='source-text'>Sources: {', '.join(sources)}</div>"
+             answer += f"\n\nSources: {', '.join(sources)}"

          return answer
-     return "System initializing... Please try again."
+     else:
+         return "Initializing the knowledge base. Please try again in a moment."

  def main():
-     inject_custom_css()
+     st.title("Pakistan Law AI Agent")

-     st.markdown("""
-     <h1 class="title-text">
-         <div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem;">
-             <span>⚖️</span>
-             <span>Your AI Law Agent</span>
-         </div>
-     </h1>
-     """, unsafe_allow_html=True)
-
-     # Sidebar Management
+     # Determine current mode
+     if st.session_state.using_custom_docs:
+         st.subheader("Training on your personal resources")
+     else:
+         st.subheader("Powered by Pakistan law database")
+
+     # Sidebar for uploading documents and switching modes
      with st.sidebar:
-         st.header("📚 Document Management")
+         st.header("Resource Management")

+         # Option to return to default documents
          if st.session_state.using_custom_docs:
-             if st.button("🔙 Return to Official Database", use_container_width=True):
-                 with st.spinner("Switching..."):
+             if st.button("Return to Official Database"):
+                 with st.spinner("Loading official Pakistan law database..."):
                      process_default_document()
-                 st.session_state.messages.append(AIMessage(content="Switched to official database"))
+                 st.success("Switched to official Pakistan law database!")
+                 st.session_state.messages.append(AIMessage(content="Switched to official Pakistan law database. You can now ask legal questions."))
                  st.rerun()

+         # Option to rebuild the default database
          if not st.session_state.using_custom_docs:
-             if st.button("🔄 Rebuild Database", use_container_width=True):
-                 with st.spinner("Rebuilding..."):
+             if st.button("Rebuild Official Database"):
+                 with st.spinner("Rebuilding official Pakistan law database..."):
                      process_default_document(force_rebuild=True)
+                 st.success("Official database rebuilt successfully!")
                  st.rerun()

-         st.header("📁 Upload Documents")
+         # Option to upload custom documents
+         st.header("Upload Custom Legal Documents")
          uploaded_files = st.file_uploader(
-             "Upload legal PDFs",
+             "Upload PDF files containing legal documents",
              type=["pdf"],
-             accept_multiple_files=True,
-             label_visibility="collapsed"
+             accept_multiple_files=True
          )

-         if st.button("🚀 Train on Uploads", use_container_width=True) and uploaded_files:
-             with st.spinner("Processing..."):
-                 if process_custom_documents(uploaded_files):
-                     st.session_state.messages.append(AIMessage(content="Custom documents loaded"))
+         if st.button("Train on Uploaded Documents") and uploaded_files:
+             with st.spinner("Processing your documents..."):
+                 success = process_custom_documents(uploaded_files)
+                 if success:
+                     st.success("Your documents processed successfully!")
+                     st.session_state.messages.append(AIMessage(content="Custom legal documents loaded successfully. You are now training on your personal resources."))
                      st.rerun()
-
-     # Chat Display
+
+     # Display chat messages
      for message in st.session_state.messages:
-         avatar = "👤" if isinstance(message, HumanMessage) else "⚖️"
-         with st.chat_message("user" if isinstance(message, HumanMessage) else "assistant", avatar=avatar):
-             st.write(message.content)
-
-     # Similar Questions
+         if isinstance(message, HumanMessage):
+             with st.chat_message("user"):
+                 st.write(message.content)
+         else:
+             with st.chat_message("assistant", avatar="⚖️"):
+                 st.write(message.content)
+
+     # Display similar questions if available
      if st.session_state.similar_questions:
-         st.markdown("""
-         <div style="padding: 1rem; background: #2d2d2d; border-radius: 10px; margin: 1rem 0;">
-             <h4 style="color: #00bcd4; margin-bottom: 0.5rem;">🔍 Related Queries</h4>
-         """, unsafe_allow_html=True)
-
-         cols = st.columns([1,1,1])
+         st.markdown("#### Related Questions:")
+         cols = st.columns(len(st.session_state.similar_questions))
          for i, question in enumerate(st.session_state.similar_questions):
-             with cols[i]:
-                 if st.button(
-                     f"❓ {question}",
-                     key=f"similar_q_{i}",
-                     use_container_width=True,
-                     help="Click to ask this related question"
-                 ):
-                     st.session_state.messages.append(HumanMessage(content=question))
-                     with st.chat_message("assistant", avatar="⚖️"):
-                         with st.spinner("Analyzing..."):
-                             response = get_answer(question)
-                             st.write(response, unsafe_allow_html=True)
-                     st.session_state.messages.append(AIMessage(content=response))
-                     st.rerun()
-
-         st.markdown("</div>", unsafe_allow_html=True)
-
-     # Input Handling
-     if user_input := st.chat_input("Ask your legal question..."):
+             if cols[i].button(question, key=f"similar_q_{i}"):
+                 # Add selected question as user input
+                 st.session_state.messages.append(HumanMessage(content=question))
+
+                 # Generate and display assistant response
+                 with st.chat_message("assistant", avatar="⚖️"):
+                     with st.spinner("Thinking..."):
+                         response = get_answer(question)
+                         st.write(response)
+
+                 # Add assistant response to chat history
+                 st.session_state.messages.append(AIMessage(content=response))
+                 st.rerun()
+
+     # Input for new question
+     if user_input := st.chat_input("Ask a legal question..."):
+         # Add user message to chat history
          st.session_state.messages.append(HumanMessage(content=user_input))
+
+         # Display user message
          with st.chat_message("user"):
              st.write(user_input)

+         # Generate and display assistant response
          with st.chat_message("assistant", avatar="⚖️"):
-             with st.spinner("Researching..."):
+             with st.spinner("Thinking..."):
                  response = get_answer(user_input)
-                 st.write(response, unsafe_allow_html=True)
+                 st.write(response)

+         # Add assistant response to chat history
          st.session_state.messages.append(AIMessage(content=response))
          st.rerun()
