import os
import torch
import torch.backends.cudnn as cudnn
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Enable CUDA optimizations if available
if torch.cuda.is_available():
    cudnn.benchmark = True
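# Note: cudnn.benchmark autotunes cuDNN convolution kernels for fixed input sizes;
# for a transformer-only model such as Flan-T5 the gain is small, but it is harmless.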

# Step 1: Load the PDF and create a vector store
@st.cache_resource  # cache the index across Streamlit reruns so the PDF is only embedded once
def load_pdf_to_vectorstore(pdf_path):
    # Load and split PDF
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=20,
        separators=["\n\n", "\n", ".", " ", ""]
    )
    chunks = text_splitter.split_documents(documents)

    # Create embeddings and vector store
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vectorstore = FAISS.from_documents(chunks, embeddings)
    return vectorstore
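
# Optional sketch: persisting the index would avoid re-embedding the PDF on future runs.
# "faiss_index" is an arbitrary directory name, and recent langchain_community releases
# require allow_dangerous_deserialization=True when reloading a locally saved index:
#
#   vectorstore.save_local("faiss_index")
#   vectorstore = FAISS.load_local(
#       "faiss_index", embeddings, allow_dangerous_deserialization=True
#   )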

# Step 2: Initialize the LaMini model
@st.cache_resource  # cache the pipeline so Streamlit reruns don't reload the model
def setup_model():
    model_id = "MBZUAI/LaMini-Flan-T5-248M"  # Using smaller model for faster inference
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        # Removed low_cpu_mem_usage parameter
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    if torch.cuda.is_available():
        model = model.cuda()
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=False,
        temperature=0.3,
        top_p=0.95,
        device=0 if torch.cuda.is_available() else -1,
        batch_size=1
    )
    return pipe
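
# Note: with do_sample=False the pipeline decodes greedily, so the temperature and
# top_p values above have no effect; set do_sample=True if sampling is actually wanted.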

# Step 3: Generate a response using the model and vector store
def generate_response(pipe, vectorstore, user_input):
    # Get relevant context
    docs = vectorstore.similarity_search(user_input, k=2)
    context = "\n".join([
        f"Page {doc.metadata.get('page', 'unknown')}: {doc.page_content}"
        for doc in docs
    ])

    # Create prompt
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
        Using the following medical text excerpts, answer the question.
        If the information isn't clearly provided in the context, or if you're unsure, please say so and recommend consulting a healthcare professional.
        Context: {context}
        Question: {question}
        Answer (citing relevant page numbers when possible):"""
    )

    # Format the prompt and run the pipeline directly
    prompt_text = prompt.format(context=context, question=user_input)
    response = pipe(prompt_text)[0]['generated_text']
    return response
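
# k=2 above keeps the retrieved context short; a larger k gives broader coverage but
# makes the prompt longer and risks truncation by the tokenizer.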

# Cache responses for repeated questions
@st.cache_data
def cached_generate_response(user_input, _pipe, _vectorstore):
    return generate_response(_pipe, _vectorstore, user_input)
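
# st.cache_data skips hashing any parameter whose name starts with an underscore,
# which is why the unhashable pipeline and vector store are passed as _pipe and
# _vectorstore above; only user_input is used as the cache key.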

# Batch processing for multiple questions
def batch_generate_responses(pipe, vectorstore, questions, batch_size=4):
    responses = []
    for i in range(0, len(questions), batch_size):
        batch = questions[i:i + batch_size]
        batch_responses = [generate_response(pipe, vectorstore, q) for q in batch]
        responses.extend(batch_responses)
    return responses
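
# Note: the questions are still answered one at a time inside each batch; true batched
# inference would require passing a list of prompts to the pipeline in a single call.
# Hypothetical usage sketch:
#
#   answers = batch_generate_responses(
#       pipe, vectorstore,
#       ["What is hypertension?", "How is anemia diagnosed?"],
#   )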

# Streamlit UI
def main():
    st.title("Medical Chatbot Assistant 🏥")

    # Use the PDF file from the root directory
    pdf_path = "Medical_book.pdf"

    if os.path.exists(pdf_path):
        # Initialize progress
        progress_text = "Operation in progress. Please wait."

        # Load vector store and model with progress indication
        with st.spinner("Loading PDF and initializing model..."):
            vectorstore = load_pdf_to_vectorstore(pdf_path)
            pipe = setup_model()
        st.success("Ready to answer questions!")

        # Create a chat-like interface
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask your medical question:"):
            # Add user message to chat history
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate and display response
            with st.chat_message("assistant"):
                with st.spinner("Generating response..."):
                    response = cached_generate_response(prompt, pipe, vectorstore)
                st.markdown(response)

            # Add assistant message to chat history
            st.session_state.messages.append({"role": "assistant", "content": response})
    else:
        st.error("The file 'Medical_book.pdf' was not found in the root directory.")

if __name__ == "__main__":
    main()
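
# To launch locally (assuming this file is saved as app.py, the default entry point
# for a Streamlit Space):
#   streamlit run app.py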