import os
import re
import tempfile
import uuid

import chromadb
import streamlit as st
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.schema import HumanMessage, AIMessage
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
# Page configuration
st.set_page_config(page_title="Pakistan Law AI Agent", page_icon="⚖️")

# Constants
DEFAULT_GROQ_API_KEY = os.getenv("GROQ_API_KEY")
MODEL_NAME = "llama-3.3-70b-versatile"
DEFAULT_DOCUMENT_PATH = "lawbook.pdf"  # Path to the bundled Pakistan laws PDF
DEFAULT_COLLECTION_NAME = "pakistan_laws_default"
CHROMA_PERSIST_DIR = "./chroma_db"
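# The Groq API key is read from the environment; the app cannot create a model
# without it. A typical setup, assuming a Unix shell (key value hypothetical):
#   export GROQ_API_KEY="gsk_..."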
# Session state initialization
if "messages" not in st.session_state:
    st.session_state.messages = []
if "user_id" not in st.session_state:
    st.session_state.user_id = str(uuid.uuid4())
if "vectordb" not in st.session_state:
    st.session_state.vectordb = None
if "llm" not in st.session_state:
    st.session_state.llm = None
if "qa_chain" not in st.session_state:
    st.session_state.qa_chain = None
if "similar_questions" not in st.session_state:
    st.session_state.similar_questions = []
if "using_custom_docs" not in st.session_state:
    st.session_state.using_custom_docs = False
if "custom_collection_name" not in st.session_state:
    st.session_state.custom_collection_name = f"custom_laws_{st.session_state.user_id}"
@st.cache_resource
def setup_embeddings():
    """Set up the embeddings model (cached so the model loads only once per server)."""
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
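# A minimal retrieval sanity check, assuming the same model; handy in a REPL
# rather than inside the app (query text is illustrative):
#   emb = setup_embeddings()
#   vec = emb.embed_query("punishment for theft")  # 384-dim vector for MiniLM-L6-v2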
def setup_llm():
    """Set up the language model, reusing the session-cached instance."""
    if st.session_state.llm is None:
        if not DEFAULT_GROQ_API_KEY:
            st.error("GROQ_API_KEY is not set. Export it before starting the app.")
            st.stop()
        st.session_state.llm = ChatGroq(
            model_name=MODEL_NAME,
            groq_api_key=DEFAULT_GROQ_API_KEY,
            temperature=0.2,  # low temperature for more factual answers
        )
    return st.session_state.llm
def check_default_db_exists():
    """Check if the default document database already exists."""
    # Delegates to the generic collection check defined below
    return check_custom_db_exists(DEFAULT_COLLECTION_NAME)
def load_existing_vectordb(collection_name):
    """Load an existing vector database from disk."""
    embeddings = setup_embeddings()
    try:
        db = Chroma(
            persist_directory=CHROMA_PERSIST_DIR,
            embedding_function=embeddings,
            collection_name=collection_name,
        )
        return db
    except Exception as e:
        st.error(f"Error loading existing database: {str(e)}")
        return None
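# Quick retrieval smoke test, assuming the default collection has already been
# built; useful when answers look off (query text is illustrative):
#   db = load_existing_vectordb(DEFAULT_COLLECTION_NAME)
#   for doc in db.similarity_search("bail provisions", k=3):
#       print(doc.metadata.get("source"), doc.page_content[:80])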
def process_default_document(force_rebuild=False):
    """Process the default Pakistan laws document, or load it from disk if available."""
    # Reuse the existing database unless a rebuild was requested
    if check_default_db_exists() and not force_rebuild:
        st.info("Loading existing Pakistan law database...")
        db = load_existing_vectordb(DEFAULT_COLLECTION_NAME)
        if db is not None:
            st.session_state.vectordb = db
            setup_qa_chain()
            st.session_state.using_custom_docs = False
            return True
    # If the database doesn't exist or a rebuild was forced, create it
    if not os.path.exists(DEFAULT_DOCUMENT_PATH):
        st.error(f"Default document {DEFAULT_DOCUMENT_PATH} not found. Please make sure it exists.")
        return False
    embeddings = setup_embeddings()
    try:
        with st.spinner("Building Pakistan law database (this may take a few minutes)..."):
            loader = PyPDFLoader(DEFAULT_DOCUMENT_PATH)
            documents = loader.load()
            # Add a source label to the metadata
            for doc in documents:
                doc.metadata["source"] = "Pakistan Laws (Official)"
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            )
            chunks = text_splitter.split_documents(documents)
            # Rebuilding into an existing collection would append duplicate
            # chunks, so drop the old collection first
            if check_default_db_exists():
                Chroma(
                    persist_directory=CHROMA_PERSIST_DIR,
                    embedding_function=embeddings,
                    collection_name=DEFAULT_COLLECTION_NAME,
                ).delete_collection()
            # Create the vector store
            db = Chroma.from_documents(
                documents=chunks,
                embedding=embeddings,
                collection_name=DEFAULT_COLLECTION_NAME,
                persist_directory=CHROMA_PERSIST_DIR
            )
            # chromadb >= 0.4 persists automatically when persist_directory is
            # set; the explicit call is only needed on older clients
            try:
                db.persist()
            except Exception:
                pass
            st.session_state.vectordb = db
            setup_qa_chain()
            st.session_state.using_custom_docs = False
            return True
    except Exception as e:
        st.error(f"Error processing default document: {str(e)}")
        return False
def check_custom_db_exists(collection_name):
    """Check whether a collection already exists in the persisted Chroma store.

    Chroma stores collections in internal UUID-named directories, so checking
    for a path named after the collection is unreliable; ask the client instead.
    """
    if not os.path.exists(CHROMA_PERSIST_DIR):
        return False
    client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIR)
    # Newer chromadb versions return names; older ones return Collection objects
    return collection_name in [getattr(c, "name", c) for c in client.list_collections()]
def process_custom_documents(uploaded_files):
    """Process user-uploaded PDF documents."""
    embeddings = setup_embeddings()
    collection_name = st.session_state.custom_collection_name
    documents = []
    for uploaded_file in uploaded_files:
        # Save the upload to a temporary file so PyPDFLoader can read it
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name
        # Load the document; the finally clause removes the temp file even on failure
        try:
            loader = PyPDFLoader(tmp_path)
            file_docs = loader.load()
            # Record the original filename in the metadata
            for doc in file_docs:
                doc.metadata["source"] = uploaded_file.name
            documents.extend(file_docs)
        except Exception as e:
            st.error(f"Error processing {uploaded_file.name}: {str(e)}")
            continue
        finally:
            # Clean up the temp file
            os.unlink(tmp_path)
    if documents:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        chunks = text_splitter.split_documents(documents)
        # Create the vector store
        with st.spinner("Building custom document database..."):
            # If a previous custom DB exists for this user, delete it first so
            # old chunks don't mix with the new upload
            if check_custom_db_exists(collection_name):
                Chroma(
                    persist_directory=CHROMA_PERSIST_DIR,
                    embedding_function=embeddings,
                    collection_name=collection_name,
                ).delete_collection()
            # Create the new vector store
            db = Chroma.from_documents(
                documents=chunks,
                embedding=embeddings,
                collection_name=collection_name,
                persist_directory=CHROMA_PERSIST_DIR
            )
            # chromadb >= 0.4 persists automatically; older clients need the call
            try:
                db.persist()
            except Exception:
                pass
            st.session_state.vectordb = db
            setup_qa_chain()
            st.session_state.using_custom_docs = True
            return True
    return False
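# The 1000/200 splitter settings used above are a common starting point for
# legal text: roughly 1000 characters keeps a clause intact, and the
# 200-character overlap keeps section headings attached to the body that
# follows them. These are tuning assumptions, not values mandated by LangChain.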
def setup_qa_chain():
    """Set up the QA chain for the RAG system."""
    if st.session_state.vectordb:
        llm = setup_llm()
        # Create the prompt template
        template = """You are a helpful legal assistant specializing in Pakistani law.
Use the following context to answer the question. If you don't know the answer based on the context,
say that you don't have enough information, but provide general legal information if possible.

Context: {context}

Question: {question}

Answer:"""
        prompt = ChatPromptTemplate.from_template(template)
        # Create the QA chain
        st.session_state.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=st.session_state.vectordb.as_retriever(search_kwargs={"k": 3}),
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=True
        )
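# Direct invocation sketch (assuming the chain has been built), handy for
# testing outside the chat UI. RetrievalQA returns a dict with "result" and,
# because return_source_documents=True, "source_documents":
#   out = st.session_state.qa_chain.invoke({"query": "What is qisas?"})
#   print(out["result"], len(out["source_documents"]))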
def generate_similar_questions(question, docs):
    """Generate similar questions based on the retrieved documents."""
    llm = setup_llm()
    # Extract key content from the top documents
    context = "\n".join([doc.page_content for doc in docs[:2]])
    # Prompt the model for related questions
    prompt = f"""Based on the following user question and legal context, generate 3 similar questions that the user might also be interested in.
Make the questions specific, related to Pakistani law, and directly relevant to the original question.

Original Question: {question}

Legal Context: {context}

Generate exactly 3 similar questions:"""
    try:
        response = llm.invoke(prompt)
        # Extract numbered questions ("1. ...") from the response
        questions = re.findall(r"\d+\.\s+(.*?)(?=\d+\.|$)", response.content, re.DOTALL)
        if not questions:
            # Fall back to line splitting if the model didn't number its output
            questions = response.content.split("\n")
            questions = [q.strip() for q in questions if q.strip() and not q.startswith("Similar") and "?" in q]
        # Clean up and keep at most 3 questions
        questions = [q.strip().replace("\n", " ") for q in questions if "?" in q]
        return questions[:3]
    except Exception as e:
        print(f"Error generating similar questions: {e}")
        return []
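# The regex above expects output shaped like a numbered list, e.g.:
#   1. What is the limitation period for civil appeals?
#   2. How are appeals filed in the Lahore High Court?
#   3. Which court hears first appeals in civil matters?
# (Illustrative questions only; actual model output will vary.)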
def get_answer(question):
    """Get an answer from the QA chain."""
    # If the default documents haven't been processed yet, try to load them
    if not st.session_state.vectordb:
        with st.spinner("Loading Pakistan law database..."):
            process_default_document()
    if st.session_state.qa_chain:
        # Chain __call__ is deprecated in recent LangChain; use invoke()
        result = st.session_state.qa_chain.invoke({"query": question})
        answer = result["result"]
        # Generate similar questions from the retrieved documents
        source_docs = result.get("source_documents", [])
        st.session_state.similar_questions = generate_similar_questions(question, source_docs)
        # Append source information (sorted for a stable display order)
        sources = set()
        for doc in source_docs:
            if "source" in doc.metadata:
                sources.add(doc.metadata["source"])
        if sources:
            answer += f"\n\nSources: {', '.join(sorted(sources))}"
        return answer
    else:
        return "Initializing the knowledge base. Please try again in a moment."
def main():
    st.title("Pakistan Law AI Agent")
    # Show which corpus is currently active
    if st.session_state.using_custom_docs:
        st.subheader("Answering from your uploaded documents")
    else:
        st.subheader("Powered by the Pakistan law database")
    # Sidebar for uploading documents and switching corpora
    with st.sidebar:
        st.header("Resource Management")
        # Option to return to the default documents
        if st.session_state.using_custom_docs:
            if st.button("Return to Official Database"):
                with st.spinner("Loading official Pakistan law database..."):
                    process_default_document()
                st.success("Switched to official Pakistan law database!")
                st.session_state.messages.append(AIMessage(content="Switched to the official Pakistan law database. You can now ask legal questions."))
                st.rerun()
        # Option to rebuild the default database
        if not st.session_state.using_custom_docs:
            if st.button("Rebuild Official Database"):
                with st.spinner("Rebuilding official Pakistan law database..."):
                    process_default_document(force_rebuild=True)
                st.success("Official database rebuilt successfully!")
                st.rerun()
        # Option to upload custom documents
        st.header("Upload Custom Legal Documents")
        uploaded_files = st.file_uploader(
            "Upload PDF files containing legal documents",
            type=["pdf"],
            accept_multiple_files=True
        )
        if st.button("Index Uploaded Documents") and uploaded_files:
            with st.spinner("Processing your documents..."):
                success = process_custom_documents(uploaded_files)
                if success:
                    st.success("Your documents were processed successfully!")
                    st.session_state.messages.append(AIMessage(content="Custom legal documents loaded successfully. Answers will now come from your uploaded documents."))
                    st.rerun()
    # Display the chat history
    for message in st.session_state.messages:
        if isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.write(message.content)
        else:
            with st.chat_message("assistant", avatar="⚖️"):
                st.write(message.content)
    # Display similar questions if available
    if st.session_state.similar_questions:
        st.markdown("#### Related Questions:")
        cols = st.columns(len(st.session_state.similar_questions))
        for i, question in enumerate(st.session_state.similar_questions):
            if cols[i].button(question, key=f"similar_q_{i}"):
                # Treat the selected question as user input
                st.session_state.messages.append(HumanMessage(content=question))
                # Generate and display the assistant response
                with st.chat_message("assistant", avatar="⚖️"):
                    with st.spinner("Thinking..."):
                        response = get_answer(question)
                        st.write(response)
                # Add the assistant response to the chat history
                st.session_state.messages.append(AIMessage(content=response))
                st.rerun()
    # Input for a new question
    if user_input := st.chat_input("Ask a legal question..."):
        # Add the user message to the chat history
        st.session_state.messages.append(HumanMessage(content=user_input))
        # Display the user message
        with st.chat_message("user"):
            st.write(user_input)
        # Generate and display the assistant response
        with st.chat_message("assistant", avatar="⚖️"):
            with st.spinner("Thinking..."):
                response = get_answer(user_input)
                st.write(response)
        # Add the assistant response to the chat history
        st.session_state.messages.append(AIMessage(content=response))
        st.rerun()
if __name__ == "__main__":
    main()