frozen8569 committed on
Commit
0b29405
·
verified ·
1 Parent(s): 3581769

Create app.py

Files changed (1)
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
import streamlit as st
import torch
import fitz  # PyMuPDF
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric

# --- Page Configuration ---
st.set_page_config(
    page_title="Sahay AI 🇮🇳",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --- Caching for Performance ---
@st.cache_resource
def load_llm():
    """Loads the smaller, CPU-friendly model (FLAN-T5-Base)."""
    llm_model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
    return HuggingFacePipeline(pipeline=pipe)
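
# A quick sanity check of load_llm() outside the app, e.g. from a REPL
# (illustrative snippet, not part of the app's control flow):
#
#     llm = load_llm()
#     print(llm.invoke("Answer in one word: is PM-KISAN an Indian scheme?"))
#
# HuggingFacePipeline is a LangChain LLM, so invoke() takes a prompt string
# and returns the generated text directly.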

@st.cache_resource
def load_and_process_pdf(pdf_path):
    """Loads and embeds the PDF using IBM's multilingual model."""
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
        if not text:
            st.error("Could not extract text from PDF.")
            return None
    except Exception as e:
        st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])

    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db
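
# Retrieval can be verified directly against the FAISS index, without the
# conversational chain (illustrative snippet; the query text is just an example):
#
#     db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
#     for doc in db.similarity_search("Who is eligible for PM-KISAN?", k=3):
#         print(doc.page_content[:120])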

# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme. Use the context to answer the question precisely. If the question is not related to the context, state that you can only answer questions about the PM-KISAN scheme. Do not make up information.
Context: {context}
Question: {question}
Helpful Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm, retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
        memory=memory, return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain
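
# The chain takes a {"question": ...} mapping; chat history is supplied by the
# attached memory, and return_source_documents=True exposes the retrieved
# chunks (illustrative snippet):
#
#     chain = create_conversational_chain(llm, vector_db)
#     result = chain.invoke({"question": "How much does a farmer receive per year?"})
#     print(result["answer"])            # generated answer
#     print(result["source_documents"])  # the k=3 retrieved chunks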

# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    st.subheader("🤖 IBM AIF360 - Fairness Audit")
    st.info("A simulation to check for bias in our information retriever.")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]

    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
    spd = metric.statistical_parity_difference()
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
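
# Statistical parity difference is P(favorable | unprivileged) minus
# P(favorable | privileged); 0 means parity, and negative values disadvantage
# the unprivileged group. For the toy data above it can be checked by hand:
#
#     p_unpriv = 2 / 2           # both gender == 0 rows are favorable
#     p_priv = 2 / 2             # both gender == 1 rows are favorable
#     spd = p_unpriv - p_priv    # 0.0, i.e. the simulation shows no disparity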

# --- Main Application UI ---
if __name__ == "__main__":

    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("🇮🇳 Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True

    st.header("Chat with Sahay AI 💬")
    st.markdown("Your trusted guide to the PM-KISAN scheme.")

    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

    if "qa_chain" not in st.session_state:
        with st.spinner("🚀 Initializing Sahay AI..."):
            llm = load_llm()
            vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
            if vector_db:
                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
            else:
                st.error("Application could not start. Is the PDF uploaded?")
                st.stop()

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("🧠 Thinking..."):
                result = st.session_state.qa_chain.invoke({"question": prompt})
                response = result["answer"]
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})
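
To try the commit locally, one would need the packages implied by the imports (names inferred from the code above; the commit pins no versions) and the scheme PDF next to app.py:

    pip install streamlit torch pymupdf transformers langchain langchain-community faiss-cpu sentence-transformers aif360 pandas
    streamlit run app.py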