# Hugging Face Spaces metadata banner (extraction artifact, not program code):
# Spaces: Sleeping / Sleeping
import streamlit as st | |
import os | |
import tempfile | |
import uuid | |
from langchain_groq import ChatGroq | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema import HumanMessage, AIMessage | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.vectorstores import Chroma | |
from langchain.chains import RetrievalQA | |
import re | |
# Page Configuration
st.set_page_config(page_title="Pakistan Law AI Agent", page_icon="⚖️")

# Constants
# API key is read from the environment; may be None if unset (ChatGroq will then fail at call time).
DEFAULT_GROQ_API_KEY = os.getenv("GROQ_API_KEY")
MODEL_NAME = "llama-3.3-70b-versatile"
DEFAULT_DOCUMENT_PATH = "lawbook.pdf"  # Path to your hardcoded Pakistan laws PDF
DEFAULT_COLLECTION_NAME = "pakistan_laws_default"
CHROMA_PERSIST_DIR = "./chroma_db"  # On-disk location shared by default and custom collections
# Session state initialization: seed every key only on first run of the session.
# Order matters: user_id must exist before custom_collection_name is derived from it.
st.session_state.setdefault("messages", [])
st.session_state.setdefault("user_id", str(uuid.uuid4()))
st.session_state.setdefault("vectordb", None)
st.session_state.setdefault("llm", None)
st.session_state.setdefault("qa_chain", None)
st.session_state.setdefault("similar_questions", [])
st.session_state.setdefault("using_custom_docs", False)
st.session_state.setdefault(
    "custom_collection_name", f"custom_laws_{st.session_state.user_id}"
)
def setup_embeddings():
    """Create the sentence-transformer embedding model used by every vector store."""
    embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
    return HuggingFaceEmbeddings(model_name=embedding_model)
def setup_llm():
    """Return the session-cached Groq chat model, creating it on first use."""
    if st.session_state.llm is not None:
        return st.session_state.llm
    # Low temperature keeps legal answers conservative and repeatable.
    st.session_state.llm = ChatGroq(
        model_name=MODEL_NAME,
        groq_api_key=DEFAULT_GROQ_API_KEY,
        temperature=0.2,
    )
    return st.session_state.llm
def check_default_db_exists():
    """Return True if the default Pakistan-law database appears to exist on disk.

    NOTE(review): this checks for a subdirectory named after the collection,
    which matches older Chroma on-disk layouts; newer Chroma versions store
    collections under UUID-named directories — confirm against the installed
    chromadb version.
    """
    # Direct boolean return instead of the if/return True/return False pattern.
    return os.path.exists(os.path.join(CHROMA_PERSIST_DIR, DEFAULT_COLLECTION_NAME))
def load_existing_vectordb(collection_name):
    """Open a persisted Chroma collection by name.

    Returns the Chroma handle, or None (after surfacing a UI error) when the
    store cannot be opened.
    """
    # Embeddings are built outside the try so their failures propagate unchanged.
    embeddings = setup_embeddings()
    try:
        return Chroma(
            persist_directory=CHROMA_PERSIST_DIR,
            embedding_function=embeddings,
            collection_name=collection_name,
        )
    except Exception as e:
        st.error(f"Error loading existing database: {str(e)}")
        return None
def process_default_document(force_rebuild=False):
    """Load the default Pakistan-law vector DB from disk, or (re)build it from the PDF.

    On success the DB and QA chain are stored in session state, custom-document
    mode is switched off, and True is returned; False signals failure.
    """
    # Fast path: reuse the persisted database unless a rebuild was requested.
    if check_default_db_exists() and not force_rebuild:
        st.info("Loading existing Pakistan law database...")
        db = load_existing_vectordb(DEFAULT_COLLECTION_NAME)
        if db is not None:
            st.session_state.vectordb = db
            setup_qa_chain()
            st.session_state.using_custom_docs = False
            return True

    # Slow path: (re)build from the bundled PDF, which must be present.
    if not os.path.exists(DEFAULT_DOCUMENT_PATH):
        st.error(f"Default document {DEFAULT_DOCUMENT_PATH} not found. Please make sure it exists.")
        return False

    embeddings = setup_embeddings()
    try:
        with st.spinner("Building Pakistan law database (this may take a few minutes)..."):
            pages = PyPDFLoader(DEFAULT_DOCUMENT_PATH).load()
            # Tag every page so answers can cite the official source.
            for page in pages:
                page.metadata["source"] = "Pakistan Laws (Official)"
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
            )
            db = Chroma.from_documents(
                documents=splitter.split_documents(pages),
                embedding=embeddings,
                collection_name=DEFAULT_COLLECTION_NAME,
                persist_directory=CHROMA_PERSIST_DIR,
            )
            # Explicitly persist so later sessions can take the fast path above.
            db.persist()
            st.session_state.vectordb = db
            setup_qa_chain()
            st.session_state.using_custom_docs = False
            return True
    except Exception as e:
        st.error(f"Error processing default document: {str(e)}")
        return False
def check_custom_db_exists(collection_name):
    """Return True if a persisted custom collection directory exists for *collection_name*.

    NOTE(review): same caveat as check_default_db_exists — the dir-per-collection
    layout only matches older Chroma versions; verify against the installed one.
    """
    # Direct boolean return instead of the if/return True/return False pattern.
    return os.path.exists(os.path.join(CHROMA_PERSIST_DIR, collection_name))
def process_custom_documents(uploaded_files):
    """Build a per-user vector database from uploaded PDF files.

    Each upload is written to a temp file so PyPDFLoader can read it from disk.
    Fix: the temp file is now removed in a ``finally`` block, so it is no longer
    leaked when ``loader.load()`` raises (the old code skipped ``os.unlink`` on
    the error path). Returns True if at least one document was indexed.
    """
    embeddings = setup_embeddings()
    collection_name = st.session_state.custom_collection_name
    documents = []
    for uploaded_file in uploaded_files:
        # Save file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name
        try:
            loader = PyPDFLoader(tmp_path)
            file_docs = loader.load()
            # Tag pages with the originating filename for source attribution.
            for doc in file_docs:
                doc.metadata["source"] = uploaded_file.name
            documents.extend(file_docs)
        except Exception as e:
            st.error(f"Error processing {uploaded_file.name}: {str(e)}")
            continue
        finally:
            # Always clean up the temp file, even when loading fails (was leaked before).
            os.unlink(tmp_path)

    if not documents:
        return False

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_documents(documents)
    with st.spinner("Building custom document database..."):
        # If a previous custom DB exists for this user, delete it first:
        # re-open the old collection handle purely to drop it.
        if check_custom_db_exists(collection_name):
            temp_db = Chroma(
                persist_directory=CHROMA_PERSIST_DIR,
                embedding_function=embeddings,
                collection_name=collection_name
            )
            temp_db.delete_collection()
        # Create new vector store
        db = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            collection_name=collection_name,
            persist_directory=CHROMA_PERSIST_DIR
        )
        # Explicitly persist to disk
        db.persist()
        st.session_state.vectordb = db
        setup_qa_chain()
        st.session_state.using_custom_docs = True
        return True
def setup_qa_chain():
    """Build the RetrievalQA chain from the current vector DB and cache it in session state.

    Silently does nothing when no vector DB has been loaded yet; callers fall
    back to an "initializing" message in that case.
    """
    if st.session_state.vectordb:
        llm = setup_llm()
        # Create prompt template
        template = """You are a helpful legal assistant specializing in Pakistani law.
Use the following context to answer the question. If you don't know the answer based on the context,
say that you don't have enough information, but provide general legal information if possible.
Context: {context}
Question: {question}
Answer:"""
        prompt = ChatPromptTemplate.from_template(template)
        # Create the QA chain; "stuff" packs the top-3 retrieved chunks into one prompt,
        # and source documents are returned so answers can cite them.
        st.session_state.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=st.session_state.vectordb.as_retriever(search_kwargs={"k": 3}),
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=True
        )
def parse_generated_questions(content):
    """Parse up to 3 question strings out of an LLM response (module-internal helper).

    Extracted from generate_similar_questions so the pure text-parsing logic is
    testable in isolation. Tries numbered-list extraction ("1. ...") first, then
    falls back to line splitting; only strings containing "?" survive.
    """
    questions = re.findall(r"\d+\.\s+(.*?)(?=\d+\.|$)", content, re.DOTALL)
    if not questions:
        questions = content.split("\n")
        questions = [q.strip() for q in questions if q.strip() and not q.startswith("Similar") and "?" in q]
    # Clean and limit to 3 questions
    questions = [q.strip().replace("\n", " ") for q in questions if "?" in q]
    return questions[:3]

def generate_similar_questions(question, docs):
    """Generate up to 3 follow-up questions grounded in the retrieved documents.

    Returns [] on any LLM failure so the chat UI degrades gracefully.
    """
    llm = setup_llm()
    # Only the two most relevant chunks are used, to keep the prompt short.
    context = "\n".join([doc.page_content for doc in docs[:2]])
    # Prompt to generate similar questions
    prompt = f"""Based on the following user question and legal context, generate 3 similar questions that the user might also be interested in.
Make the questions specific, related to Pakistani law, and directly relevant to the original question.
Original Question: {question}
Legal Context: {context}
Generate exactly 3 similar questions:"""
    try:
        response = llm.invoke(prompt)
        return parse_generated_questions(response.content)
    except Exception as e:
        print(f"Error generating similar questions: {e}")
        return []
def get_answer(question):
    """Answer *question* via the QA chain, appending a list of cited sources.

    Lazily loads the default law database on first use; returns a placeholder
    string when the chain is still unavailable. Side effect: refreshes
    st.session_state.similar_questions from the retrieved context.

    Fix: sources is a set, so joining it directly produced a nondeterministic
    ordering across runs — the source list is now sorted.
    """
    # If default documents haven't been processed yet, try to load them.
    if not st.session_state.vectordb:
        with st.spinner("Loading Pakistan law database..."):
            process_default_document()

    if not st.session_state.qa_chain:
        return "Initializing the knowledge base. Please try again in a moment."

    result = st.session_state.qa_chain({"query": question})
    answer = result["result"]
    source_docs = result.get("source_documents", [])
    # Offer follow-up questions based on the same retrieved context.
    st.session_state.similar_questions = generate_similar_questions(question, source_docs)
    # Collect unique source names; sorted for a deterministic display order.
    sources = {doc.metadata["source"] for doc in source_docs if "source" in doc.metadata}
    if sources:
        answer += f"\n\nSources: {', '.join(sorted(sources))}"
    return answer
def main():
    """Render the Streamlit app: title, sidebar resource management, chat history,
    related-question buttons, and the chat input loop.

    Streamlit re-executes this top to bottom on every interaction; chat history
    lives in st.session_state.messages and is replayed each run.
    """
    st.title("Pakistan Law AI Agent")
    # Determine current mode
    if st.session_state.using_custom_docs:
        st.subheader("Training on your personal resources")
    else:
        st.subheader("Powered by Pakistan law database")
    # Sidebar for uploading documents and switching modes
    with st.sidebar:
        st.header("Resource Management")
        # Option to return to default documents
        if st.session_state.using_custom_docs:
            if st.button("Return to Official Database"):
                with st.spinner("Loading official Pakistan law database..."):
                    process_default_document()
                st.success("Switched to official Pakistan law database!")
                st.session_state.messages.append(AIMessage(content="Switched to official Pakistan law database. You can now ask legal questions."))
                st.rerun()
        # Option to rebuild the default database
        if not st.session_state.using_custom_docs:
            if st.button("Rebuild Official Database"):
                with st.spinner("Rebuilding official Pakistan law database..."):
                    process_default_document(force_rebuild=True)
                st.success("Official database rebuilt successfully!")
                st.rerun()
        # Option to upload custom documents
        st.header("Upload Custom Legal Documents")
        uploaded_files = st.file_uploader(
            "Upload PDF files containing legal documents",
            type=["pdf"],
            accept_multiple_files=True
        )
        if st.button("Train on Uploaded Documents") and uploaded_files:
            with st.spinner("Processing your documents..."):
                success = process_custom_documents(uploaded_files)
                if success:
                    st.success("Your documents processed successfully!")
                    st.session_state.messages.append(AIMessage(content="Custom legal documents loaded successfully. You are now training on your personal resources."))
                    st.rerun()
    # Display chat messages (replayed from session state on every rerun)
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        else:
            with st.chat_message("assistant", avatar="⚖️"):
                st.write(message.content)
    # Display similar questions if available
    if st.session_state.similar_questions:
        st.markdown("#### Related Questions:")
        cols = st.columns(len(st.session_state.similar_questions))
        for i, question in enumerate(st.session_state.similar_questions):
            if cols[i].button(question, key=f"similar_q_{i}"):
                # Add selected question as user input
                st.session_state.messages.append(HumanMessage(content=question))
                # Generate and display assistant response
                with st.chat_message("assistant", avatar="⚖️"):
                    with st.spinner("Thinking..."):
                        response = get_answer(question)
                        st.write(response)
                # Add assistant response to chat history
                st.session_state.messages.append(AIMessage(content=response))
                st.rerun()
    # Input for new question
    if user_input := st.chat_input("Ask a legal question..."):
        # Add user message to chat history
        st.session_state.messages.append(HumanMessage(content=user_input))
        # Display user message
        with st.chat_message("user"):
            st.write(user_input)
        # Generate and display assistant response
        with st.chat_message("assistant", avatar="⚖️"):
            with st.spinner("Thinking..."):
                response = get_answer(user_input)
                st.write(response)
        # Add assistant response to chat history
        st.session_state.messages.append(AIMessage(content=response))
        st.rerun()

if __name__ == "__main__":
    main()