SankethHonavar committed
Commit 76b04ec · Parent: 18c83ea

Deploy LLM Medical Chatbot with FAISS

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/medmcqa_index/index.faiss filter=lfs diff=lfs merge=lfs -text
+ data/medmcqa_index/index.pkl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
+ __pycache__/
+ *.pyc
+ *.pkl
+ *.db
+ *.log
+ .env
+ .venv/
+ .ipynb_checkpoints/
+ *.sqlite3
+ *.DS_Store
+ try.py
+
+ # Ignore everything in data folder
+ data/*
+
+ # But allow medmcqa_index folder and its contents
+ !data/medmcqa_index/
+ !data/medmcqa_index/**
Dockerfile CHANGED
@@ -2,14 +2,11 @@ FROM python:3.10
  
  WORKDIR /app
  
- # Copy and install dependencies first
  COPY requirements.txt .
  RUN pip install --no-cache-dir -r requirements.txt
  
- # Copy all project files and folders
  COPY . .
  
  EXPOSE 7860
  
- # Run your Streamlit app (entry point)
- CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,179 @@
+ import streamlit as st
+ from retriever import load_vector_store
+ from langgraph_graph import generate_answer
+ from time import sleep
+
+ # Load vector DB
+ db = load_vector_store()
+
+ st.set_page_config(page_title="MedMCQA Chatbot", page_icon="🩺")
+
+ # 🌗 Theme toggle sidebar
+ with st.sidebar:
+     st.title("🩺 MedMCQA Chatbot")
+     theme_mode = st.radio("🌓 Theme", ["Light", "Dark"], horizontal=True)
+
+ # 🌓 Apply selected theme
+ if theme_mode == "Dark":
+     st.markdown("""
+         <style>
+         :root { --text-color: #eee; }
+         body, .stApp {
+             background-color: #1e1e1e !important;
+             color: var(--text-color) !important;
+         }
+         .stTextInput input {
+             background-color: #333 !important;
+             color: var(--text-color) !important;
+         }
+         .stTextInput label {
+             color: var(--text-color) !important;
+         }
+         input::placeholder {
+             color: #bbb !important;
+         }
+         .stButton>button {
+             background-color: #444 !important;
+             color: var(--text-color) !important;
+         }
+         </style>
+     """, unsafe_allow_html=True)
+ else:
+     st.markdown("""
+         <style>
+         :root { --text-color: #111; }
+         body, .stApp {
+             background-color: #ffffff !important;
+             color: var(--text-color) !important;
+         }
+         .stTextInput input {
+             background-color: #f0f0f0 !important;
+             color: var(--text-color) !important;
+         }
+         .stTextInput label {
+             color: var(--text-color) !important;
+         }
+         input::placeholder {
+             color: #444 !important;
+         }
+         .stButton>button {
+             background-color: #e0e0e0 !important;
+             color: var(--text-color) !important;
+         }
+         </style>
+     """, unsafe_allow_html=True)
+
+ # 🧠 App title
+ st.header("🩺 MedMCQA Chatbot")
+ st.caption("Ask a medical question and get answers drawn from the MedMCQA dataset only. If no match is found, the bot declines gracefully.")
+
+ # ✏️ Query box
+ query = st.text_input(
+     "🔍 Enter your medical question:",
+     placeholder="e.g., What is the mechanism of Aspirin?",
+     label_visibility="visible"
+ )
+
+ # 🚀 Answer generation
+ if query:
+     results = db.similarity_search(query, k=3)
+     context = "\n\n".join(doc.page_content for doc in results)
+
+     with st.spinner("🧠 Generating answer..."):
+         response = generate_answer(query, context)
+
+     st.markdown("""
+         <style>
+         .fade-in {
+             animation: fadeIn 0.7s ease-in;
+         }
+         @keyframes fadeIn {
+             0% { opacity: 0; transform: translateY(20px); }
+             100% { opacity: 1; transform: translateY(0); }
+         }
+         </style>
+     """, unsafe_allow_html=True)
+
+     st.markdown("<div class='fade-in'><h4>🧠 Answer:</h4></div>", unsafe_allow_html=True)
+     answer_placeholder = st.empty()
+     final_text = ""
+     for char in response:
+         final_text += char
+         answer_placeholder.markdown(f"<div class='fade-in'>{final_text}</div>", unsafe_allow_html=True)
+         sleep(0.01)
+
+     with st.expander("🔎 Top Matches"):
+         for i, doc in enumerate(results, 1):
+             content = doc.page_content
+             if query.lower() in content.lower():
+                 content = content.replace(query, f"**{query}**")
+             st.markdown(f"**Result {i}:**\n\n{content}")
+
+ # 📬 Sidebar Contact
+ with st.sidebar:
+     st.markdown("---")
+     st.markdown("### 📬 Contact")
+     st.markdown("[📧 Email](mailto:[email protected])")
+     st.markdown("[🔗 LinkedIn](https://linkedin.com/in/sankethhonavar)")
+     st.markdown("[💻 GitHub](https://github.com/sankethhonavar)")
+
+ # ✨ Floating icons (right side, top aligned)
+ st.markdown("""
+ <style>
+ .floating-button {
+     position: fixed;
+     top: 80px;
+     right: 20px;
+     display: flex;
+     flex-direction: column;
+     gap: 12px;
+     z-index: 9999;
+ }
+ .floating-button a {
+     background-color: #0077b5;
+     color: white;
+     padding: 10px 14px;
+     border-radius: 50%;
+     text-align: center;
+     font-size: 20px;
+     text-decoration: none;
+     box-shadow: 2px 2px 8px rgba(0, 0, 0, 0.3);
+     transition: background-color 0.3s;
+ }
+ .floating-button a:hover {
+     background-color: #005983;
+ }
+ .floating-button a.email {
+     background-color: #444444;
+ }
+ .floating-button a.email:hover {
+     background-color: #222222;
+ }
+ .floating-button a.github {
+     background-color: #171515;
+ }
+ .floating-button a.github:hover {
+     background-color: #000000;
+ }
+ </style>
+
+ <div class="floating-button">
+     <a href="mailto:[email protected]" class="email" title="Email Me">
+         <img src="https://img.icons8.com/ios-filled/25/ffffff/new-post.png" alt="Email"/>
+     </a>
+     <a href="https://linkedin.com/in/sankethhonavar" target="_blank" title="LinkedIn">
+         <img src="https://img.icons8.com/ios-filled/25/ffffff/linkedin.png" alt="LinkedIn"/>
+     </a>
+     <a href="https://github.com/SankethHonavar" target="_blank" class="github" title="GitHub">
+         <img src="https://img.icons8.com/ios-filled/25/ffffff/github.png" alt="GitHub"/>
+     </a>
+ </div>
+ """, unsafe_allow_html=True)
+
+ # 📄 Footer
+ st.markdown("""
+ ---
+ <p style='text-align: center; font-size: 0.9rem; color: grey'>
+ Made with ❤️ by <a href='https://linkedin.com/in/sankethhonavar' target='_blank'>Sanketh Honavar</a>
+ </p>
+ """, unsafe_allow_html=True)
data/medmcqa_index/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34d54d6522c7d7b29d217d765eb4553f125c9f0d0d4a817cd466e885bef2d145
+ size 7680045
data/medmcqa_index/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56f440b2cdb8220a6eee18440355834f439554987e4001c0a97513a5ac5a10d8
+ size 4297348
dataset_loader.py ADDED
@@ -0,0 +1,18 @@
+ # dataset_loader.py
+ from datasets import load_dataset
+
+ def load_medmcqa_subset(limit=5000):
+     dataset = load_dataset("medmcqa", split="train").select(range(limit))  # keep only the first `limit` examples
+
+     def format_entry(entry):
+         return {
+             "question": entry["question"],
+             "formatted": (
+                 f"Q: {entry['question']}\n"
+                 f"A. {entry['opa']} B. {entry['opb']} C. {entry['opc']} D. {entry['opd']}\n"
+                 f"Correct Answer: {entry['cop']}\n"
+                 f"Explanation: {entry['exp']}"
+             )
+         }
+
+     return [format_entry(entry) for entry in dataset]
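
A quick sanity check for the loader, assuming the standard MedMCQA fields (`question`, `opa`–`opd`, `cop`, `exp`); illustrative only, not part of this commit:

    from dataset_loader import load_medmcqa_subset

    subset = load_medmcqa_subset(limit=10)
    print(len(subset))             # 10 entries
    print(subset[0]["formatted"])  # Q / options / answer / explanation block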
langgraph_graph.py ADDED
@@ -0,0 +1,46 @@
+ from retriever import retrieve_relevant_docs
+ from langchain_core.prompts import PromptTemplate
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain.chains import create_retrieval_chain
+ from langchain_google_genai import ChatGoogleGenerativeAI
+
+ # LLM used for both the document chain and the fallback answer
+ llm = ChatGoogleGenerativeAI(model="models/gemini-1.5-flash", temperature=0.3)
+
+ # Define the structured prompt
+ prompt = PromptTemplate.from_template("""
+ You are a helpful medical assistant. Use only the dataset context below to answer.
+
+ Context:
+ {context}
+
+ Question: {input}
+
+ If you are unsure, say "Sorry, I couldn't find an answer based on the dataset." Do not guess.
+ """)
+
+ # Build the document chain and wrap it in a retrieval chain
+ document_chain = create_stuff_documents_chain(llm, prompt)
+ retriever_chain = create_retrieval_chain(retrieve_relevant_docs(), document_chain)
+
+ # Expose the chain for the Streamlit app
+ graph = retriever_chain
+
+ # Manual fallback used when the context has already been retrieved
+ def generate_answer(query: str, context: str) -> str:
+     if not context.strip():
+         return "Sorry, I couldn't find an answer based on the dataset."
+
+     # Reuse the module-level LLM rather than instantiating a second client
+     fallback_prompt = f"""
+ You are a helpful medical assistant. Use only the dataset context below to answer.
+
+ Context:
+ {context}
+
+ Question: {query}
+
+ If you are unsure, say "Sorry, I couldn't find an answer based on the dataset." Do not guess.
+ """
+     response = llm.invoke(fallback_prompt)
+     return response.content.strip()
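
For reference, the `graph` chain built by `create_retrieval_chain` takes a dict with an `input` key and returns `input`, `context` (the retrieved Documents), and `answer`. A minimal smoke test, assuming `GOOGLE_API_KEY` is set in the environment for the Gemini calls (illustrative only):

    import os
    from langgraph_graph import graph

    assert os.environ.get("GOOGLE_API_KEY"), "Gemini calls require GOOGLE_API_KEY"

    result = graph.invoke({"input": "What is the mechanism of Aspirin?"})
    print(result["answer"])        # generated answer text
    print(len(result["context"]))  # number of retrieved Documents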
main.py ADDED
@@ -0,0 +1,29 @@
+ from retriever import load_vector_store
+ from langgraph_graph import generate_answer
+
+
+ def medchat(query):
+     """
+     Full MedMCQA pipeline:
+     1. Retrieve the top matches from the FAISS index.
+     2. Prompt the LLM with a strict instruction to avoid hallucination.
+     """
+     retriever = load_vector_store()
+     matches = retriever.similarity_search(query, k=3)
+     context = "\n\n".join(match.page_content for match in matches)
+
+     # generate_answer builds the anti-hallucination prompt itself,
+     # so pass it the raw query and the retrieved context.
+     return generate_answer(query, context)
+
+
+ if __name__ == "__main__":
+     print("\n🩺 MedMCQA Chatbot")
+     print("Ask a medical question and get answers from the MedMCQA dataset.\n")
+
+     while True:
+         user_q = input("Ask a medical question (or type 'exit'): ")
+         if user_q.lower() == "exit":
+             break
+         response = medchat(user_q)
+         print("\n🧠 Answer:", response, "\n")
retriever.py ADDED
@@ -0,0 +1,69 @@
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.documents import Document
+ from dataset_loader import load_medmcqa_subset
+ from tqdm import tqdm  # Progress bar for visibility while formatting entries
+ import os
+
+ def retrieve_relevant_docs():
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+     db = FAISS.load_local("data/medmcqa_index", embeddings, allow_dangerous_deserialization=True)
+     return db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+
+ def create_vector_store():
+     examples = load_medmcqa_subset()
+
+     # Format each entry into a LangChain Document, with a progress bar
+     docs = [
+         Document(
+             page_content=e["formatted"],
+             metadata={"question": e["question"]}
+         )
+         for e in tqdm(examples, desc="📚 Formatting MedMCQA examples")
+     ]
+
+     # Create the embedding model
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+     # Build and save the FAISS index
+     db = FAISS.from_documents(docs, embeddings)
+     os.makedirs("data", exist_ok=True)
+     db.save_local("data/medmcqa_index")
+     print("✅ Vector DB saved at data/medmcqa_index")
+
+
+ def load_vector_store():
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+     return FAISS.load_local("data/medmcqa_index", embeddings, allow_dangerous_deserialization=True)
+
+
+ if __name__ == "__main__":
+     from langchain.prompts import PromptTemplate
+     from langchain_core.output_parsers import StrOutputParser
+     from langchain_google_genai import ChatGoogleGenerativeAI
+
+     # Load DB
+     db = load_vector_store()
+     query = "What is the treatment for asthma?"
+     docs = db.similarity_search(query, k=4)
+
+     # Prompt template
+     prompt_template = PromptTemplate.from_template(
+         """
+ You are a helpful medical assistant. Use only the dataset context below to answer.
+
+ Context:
+ {context}
+
+ Question:
+ {question}
+
+ If you are unsure, say 'Sorry, I couldn't find an answer based on the dataset.'
+ """
+     )
+
+     # LLM
+     llm = ChatGoogleGenerativeAI(model="models/gemini-1.5-flash", temperature=0.3)
+
+     chain = prompt_template | llm | StrOutputParser()
+     print("\n\n🧠 Answer:\n", chain.invoke({"context": "\n\n".join(d.page_content for d in docs), "question": query}))
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
- import altair as alt
- import numpy as np
- import pandas as pd
- import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))