import streamlit as st
import torch
import fitz  # PyMuPDF
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric

# --- Page Configuration ---
# st.set_page_config must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Sahay AI 🇮🇳",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)


# --- Caching for Performance ---
@st.cache_resource
def load_llm():
    """Load the smaller, CPU-friendly model (FLAN-T5-Base) as a LangChain LLM.

    Cached with ``st.cache_resource`` so the model is loaded once and reused
    across reruns instead of on every interaction.

    Returns:
        HuggingFacePipeline wrapping a ``text2text-generation`` pipeline.
    """
    llm_model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
    return HuggingFacePipeline(pipeline=pipe)


class _PdfLoadError(RuntimeError):
    """Raised when the PDF cannot be read or yields no extractable text."""


@st.cache_resource
def _build_vector_db(pdf_path):
    """Extract, chunk and embed the PDF into a FAISS index (cached on success only).

    Failures are signalled by raising rather than returning ``None``. Streamlit
    does not cache a call that raises, so once the PDF is fixed or uploaded the
    next session retries instead of replaying a cached failure.

    Raises:
        _PdfLoadError: the file could not be opened/read, or contains no text.
    """
    try:
        # Context manager closes the document handle even if extraction fails.
        with fitz.open(pdf_path) as doc:
            # Join pages with a newline so words at page boundaries don't fuse.
            text = "\n".join(page.get_text() for page in doc)
    except Exception as e:  # PyMuPDF raises several unrelated types for bad/missing files
        raise _PdfLoadError(
            f"Error reading PDF: {e}. Ensure '{pdf_path}' is in the main project directory."
        ) from e

    # Whitespace-only output (e.g. an image-only PDF) is as useless as no text.
    if not text.strip():
        raise _PdfLoadError("Could not extract text from PDF.")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    return FAISS.from_documents(docs, embedding_model)


def load_and_process_pdf(pdf_path):
    """Load and embed the PDF using IBM's multilingual embedding model.

    Args:
        pdf_path: path to the source PDF.

    Returns:
        A FAISS vector store over 1000-character chunks (150 overlap) of the
        PDF text, or ``None`` after showing an ``st.error`` if the PDF could
        not be read or contained no text.
    """
    try:
        return _build_vector_db(pdf_path)
    except _PdfLoadError as e:
        st.error(str(e))
        return None


# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    """Build a retrieval-augmented conversational chain over the PDF index.

    Args:
        _llm: LangChain LLM used for question condensing and answering.
        _vector_db: vector store whose retriever returns the top 3 chunks.

    Returns:
        ConversationalRetrievalChain whose ``invoke`` result contains
        ``"answer"`` and ``"source_documents"``. It owns a fresh
        ConversationBufferMemory, so it is meant to be kept per session
        (the caller stores it in ``st.session_state``).
    """
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme.
Use the context to answer the question precisely.
If the question is not related to the context, state that you can only answer questions about the PM-KISAN scheme.
Do not make up information.

Context: {context}

Question: {question}

Helpful Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    # output_key='answer' is required because return_source_documents=True makes
    # the chain return several outputs and the memory must know which to store.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm,
        retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": QA_PROMPT},
    )
    return chain


# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    """Render a simulated AIF360 statistical-parity check in the main pane.

    NOTE(review): the outcome column is hard-coded to all-favorable ([1, 1, 1, 1]),
    so P(fav | female) - P(fav | male) = 1 - 1 and the displayed SPD is 0.0000
    by construction. Feed real retriever outcomes here for a meaningful audit.
    """
    st.subheader("🤖 IBM AIF360 - Fairness Audit")
    st.info("A simulation to check for bias in our information retriever.")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    # AIF360 needs numeric protected attributes: male -> 1 (privileged), female -> 0.
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
    aif_dataset = StandardDataset(
        df_for_aif,
        label_name='favorable_outcome',
        favorable_classes=[1],
        protected_attribute_names=['gender'],
        privileged_classes=[[1]],
    )
    metric = BinaryLabelDatasetMetric(
        aif_dataset,
        unprivileged_groups=[{'gender': 0}],
        privileged_groups=[{'gender': 1}],
    )
    spd = metric.statistical_parity_difference()
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")


# --- Main Application UI ---
if __name__ == "__main__":
    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("🇮🇳 Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True

    st.header("Chat with Sahay AI 💬")
    st.markdown("Your trusted guide to the PM-KISAN scheme.")

    # One-shot flag: show the audit on the rerun triggered by the button, then clear it.
    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

    # Build the chain once per browser session (it holds that session's chat memory).
    if "qa_chain" not in st.session_state:
        with st.spinner("🚀 Initializing Sahay AI..."):
            llm = load_llm()
            vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
        if vector_db is not None:
            st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
        else:
            st.error("Application could not start. Is the PDF uploaded?")
            st.stop()

    # Streamlit reruns the whole script per interaction, so replay the history.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            try:
                with st.spinner("🧠 Thinking..."):
                    result = st.session_state.qa_chain.invoke({"question": prompt})
            except Exception as exc:  # UI boundary: report model/retrieval failures instead of crashing
                st.error(f"Sorry, I could not answer that: {exc}")
            else:
                response = result["answer"]
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})