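# app.py - Sahay AI: a Streamlit RAG chatbot for the PM-KISAN scheme.
# Run with: streamlit run app.py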
import streamlit as st
import torch
import fitz # PyMuPDF
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric
# --- Page Configuration ---
st.set_page_config(
    page_title="Sahay AI 🇮🇳",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- Caching for Performance ---
@st.cache_resource
def load_llm():
    """Loads a small, CPU-friendly instruction-tuned model (FLAN-T5-Base)."""
    llm_model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
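    # Wrap the raw transformers pipeline so LangChain can drive it as an LLM.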
    return HuggingFacePipeline(pipeline=pipe)
@st.cache_resource
def load_and_process_pdf(pdf_path):
    """Loads and embeds the PDF using IBM's multilingual Granite model."""
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
        doc.close()
        if not text.strip():
            st.error("Could not extract text from PDF.")
            return None
    except Exception as e:
        st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
        return None
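    # Split the extracted text into overlapping chunks so each piece fits the
    # embedding model's context window and neighbouring chunks share context.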
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db
# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme. Use the context to answer the question precisely. If the question is not related to the context, state that you can only answer questions about the PM-KISAN scheme. Do not make up information.
Context: {context}
Question: {question}
Helpful Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
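    # output_key='answer' tells the memory which of the chain's return values
    # (it also returns source documents) to record as chat history.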
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
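    # Retrieve the 3 most similar chunks (k=3) for each question and answer
    # using the custom prompt defined above.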
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm, retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
        memory=memory, return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain
# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    st.subheader("🤖 IBM AIF360 - Fairness Audit")
    st.info("A simulation to check for bias in our information retriever.")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
    spd = metric.statistical_parity_difference()
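    # SPD = P(favorable | unprivileged) - P(favorable | privileged);
    # 0.0 means both groups receive favorable outcomes at the same rate.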
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
# --- Main Application UI ---
if __name__ == "__main__":
    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("🇮🇳 Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True
    st.header("Chat with Sahay AI 💬")
    st.markdown("Your trusted guide to the PM-KISAN scheme.")
    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]
    if "qa_chain" not in st.session_state:
        with st.spinner("🚀 Initializing Sahay AI..."):
            llm = load_llm()
            vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
            if vector_db:
                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
            else:
                st.error("Application could not start. Is the PDF uploaded?")
                st.stop()
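    # Streamlit reruns this script on every interaction, so replay the
    # stored conversation before handling new input.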
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("🧠 Thinking..."):
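                # The chain returns a dict; the generated reply is under
                # 'answer' (retrieved chunks are under 'source_documents').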
                result = st.session_state.qa_chain.invoke({"question": prompt})
                response = result["answer"]
                st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})