# app.py
import streamlit as st
import os
from io import BytesIO
from PyPDF2 import PdfReader
from PyPDF2.errors import PdfReadError
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import faiss
import uuid
from dotenv import load_dotenv
import requests
import pandas as pd
from pandas.errors import ParserError
from docx import Document

# Load environment variables
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "").strip()
RAG_ACCESS_KEY = os.getenv("RAG_ACCESS_KEY")
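# RAG_ACCESS_KEY is the value users must type into the sidebar before the upload
# and question-answering features are unlocked.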

if not HUGGINGFACEHUB_API_TOKEN:
    st.warning("Hugging Face API token not found! Please set HUGGINGFACEHUB_API_TOKEN in your .env file.")

# Initialize session state
if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None
if "history" not in st.session_state:
    st.session_state.history = []
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []
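# The four keys above persist across Streamlit reruns:
#   vectorstore    - the shared FAISS store built by process_input()
#   history        - (question, answer) pairs shown under "Chat History"
#   authenticated  - set once the RAG access key has been verified
#   uploaded_files - names of files already embedded, used to block duplicates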

# File processing logic
def process_input(input_data):
    # Initialize progress bar and status
    progress_bar = st.progress(0)
    status = st.empty()

    # Step 1: Read file in memory
    status.text("Reading file...")
    progress_bar.progress(0.20)
    file_name = input_data.name
    file_extension = file_name.lower().split('.')[-1]
    documents = ""

    # Step 2: Extract text based on file type
    status.text("Extracting text...")
    progress_bar.progress(0.40)
    try:
        if file_extension == 'pdf':
            try:
                pdf_reader = PdfReader(BytesIO(input_data.read()))
                documents = "".join([page.extract_text() or "" for page in pdf_reader.pages])
            except PdfReadError as e:
                raise RuntimeError(f"Failed to read PDF: {str(e)}")
        elif file_extension in ['xls', 'xlsx']:
            try:
                # openpyxl reads only .xlsx; legacy .xls needs the xlrd engine
                excel_engine = 'openpyxl' if file_extension == 'xlsx' else 'xlrd'
                df = pd.read_excel(BytesIO(input_data.read()), engine=excel_engine)
                documents = df.to_string(index=False)
            except ParserError as e:
                raise RuntimeError(f"Failed to parse Excel file: {str(e)}")
        elif file_extension in ['doc', 'docx']:
            try:
                # python-docx understands .docx only; a legacy .doc file will raise here
                doc = Document(BytesIO(input_data.read()))
                documents = "\n".join([para.text for para in doc.paragraphs if para.text])
            except Exception as e:
                raise RuntimeError(f"Failed to read DOC/DOCX: {str(e)}")
        elif file_extension == 'txt':
            raw_bytes = input_data.read()  # read once; a second read() would return empty bytes
            try:
                documents = raw_bytes.decode('utf-8')
            except UnicodeDecodeError:
                documents = raw_bytes.decode('latin-1')
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")
        if not documents.strip():
            raise RuntimeError("No text extracted from the file.")
    except Exception as e:
        raise RuntimeError(f"Failed to process file: {str(e)}")

    # Step 3: Split text
    status.text("Splitting text into chunks...")
    progress_bar.progress(0.60)
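    # ~1000-character chunks with a 100-character overlap keep neighbouring chunks
    # contextually linked without making individual embeddings too long.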
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_text(documents)
    chunk_count = len(texts)
    if chunk_count == 0:
        raise RuntimeError("No text chunks created for embedding.")

    # Step 4: Create embeddings
    status.text(f"Embedding {chunk_count} chunks...")
    progress_bar.progress(0.80)
    try:
        hf_embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2",
            model_kwargs={'device': 'cpu'}
        )
    except Exception as e:
        raise RuntimeError(f"Failed to initialize embeddings: {str(e)}")

    # Step 5: Initialize or append to FAISS vector store
    status.text("Building or updating vector store...")
    progress_bar.progress(1.0)
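    # IndexFlatL2 performs exact (brute-force) L2 search, which is plenty for the
    # number of chunks a few uploads produce; later uploads append to the same index.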
    try:
        if st.session_state.vectorstore is None:
            dimension = len(hf_embeddings.embed_query("test"))
            index = faiss.IndexFlatL2(dimension)
            vector_store = FAISS(
                embedding_function=hf_embeddings,
                index=index,
                docstore=InMemoryDocstore({}),
                index_to_docstore_id={}
            )
        else:
            vector_store = st.session_state.vectorstore
        # Add texts to vector store
        uuids = [str(uuid.uuid4()) for _ in texts]
        vector_store.add_texts(texts, ids=uuids)
    except Exception as e:
        raise RuntimeError(f"Failed to update vector store: {str(e)}")

    # Complete processing
    status.text("Processing complete!")
    st.session_state.uploaded_files.append(file_name)
    st.success(f"Embedded {chunk_count} chunks from {file_name}")
    return vector_store

# Question-answering logic
def answer_question(vectorstore, query):
    if not HUGGINGFACEHUB_API_TOKEN:
        raise RuntimeError("Missing Hugging Face API token. Please set it in your .env file.")
    try:
        llm = HuggingFaceHub(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            model_kwargs={"temperature": 0.7, "max_length": 512},
            huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
        )
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(f"Failed to initialize LLM: {str(e)}. Check model availability or API token.")

    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    prompt_template = PromptTemplate(
        template="Use the context to answer the question concisely:\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:",
        input_variables=["context", "question"]
    )
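    # The retriever returns the 3 nearest chunks; the "stuff" chain simply pastes
    # them into the {context} slot of the prompt above.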
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt_template}
    )
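    # The model sometimes repeats parts of the prompt, so only the text after the
    # last "Answer:" marker is returned to the caller.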
    try:
        result = qa_chain({"query": query})
        return result["result"].split("Answer:")[-1].strip()
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(f"Error querying LLM: {str(e)}. Please try again or check model endpoint.")

# Sidebar with BSNL logo, authentication, and controls
with st.sidebar:
    try:
        st.image("bsnl_logo.png", width=200)
    except Exception:
        st.warning("BSNL logo not found.")

    st.header("RAG Control Panel")
    api_key_input = st.text_input("Enter RAG Access Key", type="password")

    # Blue button styles
    st.markdown("""
        <style>
        .auth-button button, .delete-button button {
            background-color: #007BFF !important;
            color: white !important;
            font-weight: bold;
            border-radius: 8px;
            padding: 10px 20px;
            border: none;
            transition: all 0.3s ease;
            width: 100%;
        }
        .auth-button button:hover, .delete-button button:hover {
            background-color: #0056b3 !important;
            transform: scale(1.05);
        }
        </style>
    """, unsafe_allow_html=True)

    # Authenticate button
    with st.container():
        st.markdown('<div class="auth-button">', unsafe_allow_html=True)
        if st.button("Authenticate"):
            if RAG_ACCESS_KEY is not None and api_key_input == RAG_ACCESS_KEY:
                st.session_state.authenticated = True
                st.success("Authentication successful!")
            else:
                st.error("Invalid API key.")
        st.markdown('</div>', unsafe_allow_html=True)

    if st.session_state.authenticated:
        # Display uploaded files
        if st.session_state.uploaded_files:
            st.subheader("Uploaded Files")
            for file_name in st.session_state.uploaded_files:
                st.write(f"- {file_name}")

        # File uploader
        input_data = st.file_uploader("Upload a file (PDF, XLS/XLSX, DOC/DOCX, TXT)", type=["pdf", "xls", "xlsx", "doc", "docx", "txt"])
        if st.button("Process File") and input_data is not None:
            if input_data.name in st.session_state.uploaded_files:
                st.warning(f"File '{input_data.name}' has already been processed. Please upload a different file or delete the vector store.")
            else:
                try:
                    vector_store = process_input(input_data)
                    st.session_state.vectorstore = vector_store
                except PermissionError as e:
                    st.error(f"File upload failed: Permission error - {str(e)}. Check file system access.")
                except OSError as e:
                    st.error(f"File upload failed: OS error - {str(e)}. Check server configuration.")
                except ValueError as e:
                    st.error(f"File upload failed: {str(e)} (Invalid file format).")
                except RuntimeError as e:
                    st.error(f"File upload failed: {str(e)} (Exception type: {type(e).__name__}).")
                except Exception as e:
                    st.error(f"File upload failed: {str(e)} (Exception type: {type(e).__name__}). Please try again or check server logs.")

        # Delete vector store button
        if st.session_state.vectorstore is not None:
            st.markdown('<div class="delete-button">', unsafe_allow_html=True)
            if st.button("Delete Vector Store"):
                st.session_state.vectorstore = None
                st.session_state.uploaded_files = []
                st.success("Vector store deleted successfully.")
            st.markdown('</div>', unsafe_allow_html=True)

        st.subheader("Chat History")
        for i, (q, a) in enumerate(st.session_state.history):
            st.write(f"**Q{i+1}:** {q}")
            st.write(f"**A{i+1}:** {a}")
            st.markdown("---")

# Main app UI
def main():
    st.markdown("""
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
        .stApp {
            background-color: #FFFFFF;
            font-family: 'Roboto', sans-serif;
            color: #333333;
        }
        .stTextInput > div > div > input {
            background-color: #FFFFFF;
            color: #333333;
            border-radius: 8px;
            border: 1px solid #007BFF;
            padding: 10px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        .stButton > button {
            background-color: #007BFF;
            color: white;
            border-radius: 8px;
            padding: 10px 20px;
            border: none;
            transition: all 0.3s ease;
            box-shadow: 0 2px 4px rgba(0,0,0,0.2);
        }
        .stButton > button:hover {
            background-color: #0056b3;
            transform: scale(1.05);
        }
        .stSidebar {
            background-color: #F5F5F5;
            padding: 20px;
            border-right: 2px solid #007BFF;
        }
        </style>
    """, unsafe_allow_html=True)

    st.title("RAG Q&A App with Mistral AI")
    st.markdown("Welcome to the BSNL RAG App! Upload a PDF, XLS/XLSX, DOC/DOCX, or TXT file and ask questions. Files are stored in the vector store until explicitly deleted.", unsafe_allow_html=True)

    if not st.session_state.authenticated:
        st.warning("Please authenticate using the sidebar.")
        return
    if st.session_state.vectorstore is None:
        st.info("Please upload and process a file.")
        return

    query = st.text_input("Enter your question:")
    if st.button("Submit") and query:
        with st.spinner("Generating answer..."):
            try:
                answer = answer_question(st.session_state.vectorstore, query)
                st.session_state.history.append((query, answer))
                st.write("**Answer:**", answer)
            except Exception as e:
                st.error(f"Error generating answer: {str(e)}")

if __name__ == "__main__":
    main()
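
# A minimal environment sketch inferred from the imports above (package names are
# the usual PyPI ones; pin versions as needed for your deployment):
#   pip install streamlit PyPDF2 langchain langchain-community huggingface_hub \
#       sentence-transformers faiss-cpu python-dotenv requests pandas openpyxl xlrd python-docx
# and a .env file alongside app.py containing:
#   HUGGINGFACEHUB_API_TOKEN=<your Hugging Face token>
#   RAG_ACCESS_KEY=<the key users must enter in the sidebar>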