import streamlit as st
import torch
import fitz  # PyMuPDF
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric

# --- Page Configuration ---
st.set_page_config(
    page_title="Sahay AI ๐Ÿ‡ฎ๐Ÿ‡ณ",
    page_icon="๐Ÿค–",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --- Caching for Performance ---
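# st.cache_resource keeps the loaded model and the vector store in memory across
# Streamlit reruns, so the heavy initialization below runs only once per server process.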
@st.cache_resource
def load_llm():
    """Loads the smaller, CPU-friendly model (FLAN-T5-Base)."""
    llm_model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
    return HuggingFacePipeline(pipeline=pipe)

@st.cache_resource
def load_and_process_pdf(pdf_path):
    """Loads and embeds the PDF using IBM's multilingual model."""
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
        if not text:
            st.error("Could not extract text from PDF.")
            return None
    except Exception as e:
        st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
        return None
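    # Split the extracted text into overlapping chunks so retrieval returns focused passages.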

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
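    # Build an in-memory FAISS index over the chunk embeddings for similarity search.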
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db

# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme. Use the context to answer the question precisely. If the question is not related to the context, state that you can only answer questions about the PM-KISAN scheme. Do not make up information.
    Context: {context}
    Question: {question}
    Helpful Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
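    # The chain condenses the chat history and the new question into a standalone query,
    # retrieves the top-3 matching chunks, and answers using the prompt above.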
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm, retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
        memory=memory, return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain

# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    st.subheader("๐Ÿค– IBM AIF360 - Fairness Audit")
    st.info("A simulation to check for bias in our information retriever.")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
    
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
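    # SPD = P(favorable | unprivileged) - P(favorable | privileged); values near 0 indicate parity.
    # Every row in this simulated sample is favorable, so the expected result is 0.0.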
    spd = metric.statistical_parity_difference()
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")

# --- Main Application UI ---
if __name__ == "__main__":
    
    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("๐Ÿ‡ฎ๐Ÿ‡ณ Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True

    st.header("Chat with Sahay AI ๐Ÿ’ฌ")
    st.markdown("Your trusted guide to the PM-KISAN scheme.")

    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False
    
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

    if "qa_chain" not in st.session_state:
        with st.spinner("๐Ÿš€ Initializing Sahay AI..."):
            llm = load_llm()
            vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
            if vector_db:
                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
            else:
                st.error("Application could not start. Is the PDF uploaded?")
                st.stop()

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("๐Ÿง  Thinking..."):
                result = st.session_state.qa_chain.invoke({"question": prompt})
                response = result["answer"]
                st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})