import os
import torch
import torch.backends.cudnn as cudnn
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Enable CUDA optimizations if available
if torch.cuda.is_available():
    cudnn.benchmark = True
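# Note: cudnn.benchmark autotunes cuDNN convolution kernels for fixed input shapes;
# it has little effect on a convolution-free transformer like T5, but is harmless.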
# Step 1: Load the PDF and create a vector store
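# st.cache_resource keeps the FAISS index in memory across Streamlit reruns,
# so the PDF is parsed and embedded only once per process.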
@st.cache_resource
def load_pdf_to_vectorstore(pdf_path):
    # Load and split PDF
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=20,
        separators=["\n\n", "\n", ".", " ", ""]
    )
    chunks = text_splitter.split_documents(documents)

    # Create embeddings and vector store
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vectorstore = FAISS.from_documents(chunks, embeddings)
    return vectorstore
# Step 2: Initialize the LaMini model
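# LaMini-Flan-T5-248M is a small instruction-tuned Flan-T5 variant from MBZUAI;
# its 248M-parameter size keeps inference fast enough for an interactive chat app.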
@st.cache_resource
def setup_model():
    model_id = "MBZUAI/LaMini-Flan-T5-248M"  # smaller model for faster inference
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    if torch.cuda.is_available():
        model = model.cuda()

    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=False,  # greedy decoding for deterministic answers
        device=0 if torch.cuda.is_available() else -1,
        batch_size=1
    )
    return pipe
# Step 3: Generate a response using the model and vector store
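# Retrieval-augmented generation: the chunks most similar to the question are
# fetched from FAISS and inserted into the prompt as context for the model.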
def generate_response(pipe, vectorstore, user_input):
    # Get relevant context
    docs = vectorstore.similarity_search(user_input, k=2)
    context = "\n".join([
        f"Page {doc.metadata.get('page', 'unknown')}: {doc.page_content}"
        for doc in docs
    ])

    # Create prompt
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Using the following medical text excerpts, answer the question.
If the information isn't clearly provided in the context, or if you're unsure, please say so and recommend consulting a healthcare professional.
Context: {context}
Question: {question}
Answer (citing relevant page numbers when possible):"""
    )

    # Format the prompt and run it through the pipeline
    prompt_text = prompt.format(context=context, question=user_input)
    response = pipe(prompt_text)[0]['generated_text']
    return response
# Cache responses for repeated questions
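# The leading underscore on _pipe and _vectorstore tells st.cache_data not to hash
# those (unhashable) arguments; only user_input is used as the cache key.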
@st.cache_data
def cached_generate_response(user_input, _pipe, _vectorstore):
    return generate_response(_pipe, _vectorstore, user_input)
# Batch processing for multiple questions
def batch_generate_responses(pipe, vectorstore, questions, batch_size=4):
    responses = []
    for i in range(0, len(questions), batch_size):
        batch = questions[i:i + batch_size]
        batch_responses = [generate_response(pipe, vectorstore, q) for q in batch]
        responses.extend(batch_responses)
    return responses
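# Note: this helper is not called by the Streamlit UI below; it is kept around,
# presumably for bulk/offline question answering over the same vector store.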
# Streamlit UI
def main():
    st.title("Medical Chatbot Assistant 🏥")

    # Use the PDF file from the root directory
    pdf_path = "Medical_book.pdf"

    if os.path.exists(pdf_path):
        # Load vector store and model with progress indication
        with st.spinner("Loading PDF and initializing model..."):
            vectorstore = load_pdf_to_vectorstore(pdf_path)
            pipe = setup_model()
        st.success("Ready to answer questions!")

        # Create a chat-like interface
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask your medical question:"):
            # Add user message to chat history
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate and display response
            with st.chat_message("assistant"):
                with st.spinner("Generating response..."):
                    response = cached_generate_response(prompt, pipe, vectorstore)
                    st.markdown(response)

            # Add assistant message to chat history
            st.session_state.messages.append({"role": "assistant", "content": response})
    else:
        st.error("The file 'Medical_book.pdf' was not found in the root directory.")


if __name__ == "__main__":
    main()