import streamlit as st
import torch
import fitz # PyMuPDF
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
# For Fairness Audit
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric
# --- Page Configuration ---
st.set_page_config(
    page_title="Sahay AI 🇮🇳",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- Caching for Performance ---
@st.cache_resource
def load_llm():
    """Loads the smaller, CPU-friendly model (FLAN-T5-Base)."""
    llm_model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
    return HuggingFacePipeline(pipeline=pipe)
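
# Note: @st.cache_resource keeps the FLAN-T5-Base model (roughly 250M parameters)
# in memory across Streamlit reruns, so the download and initialisation cost is
# paid only once per running container.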
@st.cache_resource
def load_and_process_pdf(pdf_path):
    """Loads and embeds the PDF using IBM's multilingual model."""
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
        if not text:
            st.error("Could not extract text from PDF.")
            return None
    except Exception as e:
        st.error(f"Error reading PDF: {e}. Ensure 'PMKisanSamanNidhi.PDF' is in the main project directory.")
        return None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db
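
# Note: this function is the retrieval half of the RAG pipeline. The PDF text is
# split into ~1000-character chunks with 150-character overlap, embedded with IBM's
# Granite multilingual model, and indexed in FAISS; @st.cache_resource ensures the
# PDF is parsed and embedded only once rather than on every rerun.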
# --- Conversational Chain ---
def create_conversational_chain(_llm, _vector_db):
    prompt_template = """You are a polite AI assistant for the PM-KISAN scheme. Use the context to answer the question precisely. If the question is not related to the context, state that you can only answer questions about the PM-KISAN scheme. Do not make up information.
Context: {context}
Question: {question}
Helpful Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=_llm,
        retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain
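
# Note: ConversationalRetrievalChain first condenses the chat history and the new
# question into a standalone query, retrieves the top k=3 chunks from FAISS, and
# then answers using QA_PROMPT; return_source_documents=True keeps the retrieved
# chunks in the result dict in case source attribution is needed later.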
# --- IBM AIF360 Fairness Audit ---
def run_fairness_audit():
    st.subheader("🤖 IBM AIF360 - Fairness Audit")
    st.info("A simulation to check for bias in our information retriever.")
    df_display = pd.DataFrame({'gender_text': ['male', 'male', 'female', 'female']})
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = [1, 1, 1, 1]
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
    spd = metric.statistical_parity_difference()
    st.metric(label="**Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
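
# Note: SPD = P(favorable | unprivileged) - P(favorable | privileged). Because every
# row above is assigned a favorable outcome, this demo yields 0.0 (perfect parity);
# in a real audit the outcomes would come from retriever/LLM results, not constants.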
# --- Main Application UI ---
if __name__ == "__main__":
    with st.sidebar:
        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
        st.title("🇮🇳 Sahay AI")
        st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
        if st.button("Run Fairness Audit", use_container_width=True):
            st.session_state.run_audit = True

    st.header("Chat with Sahay AI 💬")
    st.markdown("Your trusted guide to the PM-KISAN scheme.")

    if st.session_state.get('run_audit', False):
        run_fairness_audit()
        st.session_state.run_audit = False

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

    if "qa_chain" not in st.session_state:
        with st.spinner("🚀 Initializing Sahay AI..."):
            llm = load_llm()
            vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
            if vector_db:
                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
            else:
                st.error("Application could not start. Is the PDF uploaded?")
                st.stop()
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("🧠 Thinking..."):
                result = st.session_state.qa_chain.invoke({"question": prompt})
                response = result["answer"]
                st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
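
# A minimal sketch of how to run this app locally, assuming the imported dependencies
# (streamlit, torch, transformers, langchain, langchain-community, faiss-cpu,
# sentence-transformers, pymupdf, aif360, pandas) are installed and the file is saved
# as app.py (filename assumed):
#   streamlit run app.py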