removing too many comments
Files changed:
- .gitignore +2 -0
- app.py +40 -116
.gitignore CHANGED
@@ -20,6 +20,8 @@ wheels/
 .installed.cfg
 *.egg
 
+old/*
+
 # Virtual Environment
 venv/
 env/
app.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
 import os
 import re
 import json
-import demjson3
+import demjson3
 import requests
 import faiss
 import numpy as np
 import multiprocessing
-import time
+import time
 
 from huggingface_hub import hf_hub_download, login
 from sentence_transformers import SentenceTransformer
@@ -31,27 +31,25 @@ except KeyError:
 
 # Model and RAG configuration
 MODEL_REPO_ID = "bartowski/gemma-2-2b-it-GGUF"
-MODEL_FILENAME = "gemma-2-2b-it-Q8_0.gguf"
+MODEL_FILENAME = "gemma-2-2b-it-Q8_0.gguf"
 EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
 DOCS_PATH = "docs"
-FAISS_INDEX_PATH = "bps_faiss.index"
+FAISS_INDEX_PATH = "bps_faiss.index"
 
 # LLM parameters
-N_CTX = 2048
-MAX_TOKENS_RESPONSE = 350
-TEMPERATURE = 0.5
-N_THREADS = multiprocessing.cpu_count() - 1
+N_CTX = 2048
+MAX_TOKENS_RESPONSE = 350
+TEMPERATURE = 0.5
+N_THREADS = multiprocessing.cpu_count() - 1
 
 # RAG parameters
-TOP_K_DOCS = 3
+TOP_K_DOCS = 3
 
 # Import prompts from prompts.py
-# NOTE: Error message here is for app stability, not part of LLM prompts.
 try:
     from prompts import system_prompt, json_prompt, initial_school_search_prompt
 except ImportError:
     st.error("Could not import prompts from prompts.py. Make sure the file exists.")
-    # Define fallbacks directly if import fails - Using exact text from notebook prompts
     system_prompt = """
 You are a professional assistant that answers questions about enrollment in Boston Public Schools.
 Be friendly and helpful. Families will ask questions and provide information, such as the child's residence, grade, and school preference.
@@ -66,19 +64,15 @@ Keep the conversation going and ask questions one at a time until you have all i
     st.stop()
 
 
-# --- Helper Functions
+# --- Helper Functions ---
 
 def clean_reply_text(reply: str) -> str:
     """Removes potential JSON blocks and cleans up common LLM artifacts."""
-    # Remove ```json ... ``` blocks or similar markdown code blocks
     reply = re.sub(r"```[jJ][sS][oO][nN]?\s*(\{.*?\})\s*```", "", reply, flags=re.DOTALL)
-    # Remove trailing JSON object if it's at the very end after potential whitespace
     reply = re.sub(r"\s*\{.*\}\s*$", "", reply, flags=re.DOTALL)
-    # Remove stray backticks, `json` keywords, and unmatched brackets
     reply = re.sub(r"`", "", reply)
     reply = re.sub(r"(?i)\bjson\b", "", reply)
-    reply = re.sub(r"[\[\]]", "", reply)
+    reply = re.sub(r"[\[\]]", "", reply)
-    # Collapse multiple blank lines
     reply = re.sub(r"\n{2,}", "\n", reply)
     return reply.strip()
 
@@ -88,12 +82,10 @@ def extract_reply_and_json(text: str) -> tuple[str, dict]:
     Uses demjson3 for potentially more lenient parsing.
     """
     json_part = {}
-    reply_part = text
+    reply_part = text
 
-    # Find the last potential JSON object (heuristic: starts with { ends with })
     last_brace_open = text.rfind('{')
     if last_brace_open != -1:
-        # Try to find the matching closing brace
         brace_level = 0
         last_brace_close = -1
         potential_json_str = text[last_brace_open:]
@@ -109,47 +101,34 @@ def extract_reply_and_json(text: str) -> tuple[str, dict]:
         if last_brace_close != -1:
             json_str = text[last_brace_open : last_brace_close + 1]
             try:
-                # Use demjson3 to decode
                 parsed = demjson3.decode(json_str)
                 if isinstance(parsed, dict):
                     json_part = parsed
-                    # If JSON is successfully parsed, assume text before it is the reply
                     reply_part = text[:last_brace_open].strip()
-
-
-
-            except demjson3.JSONDecodeError as e:
-                # Decoding failed, assume it wasn't valid JSON
-                # print(f"JSON decode failed: {e}")
-                # print(f"Offending string segment:\n{json_str}")
-                pass # Keep original reply_part and empty json_part
-
-    # Clean the reply part further
+            except demjson3.JSONDecodeError:
+                pass
+
     cleaned_reply = clean_reply_text(reply_part)
 
-    # --- MODIFICATION: Removed the fallback message I added ---
-    # If cleaning removed everything, return empty string for reply.
-    # The notebook didn't specify fallback behavior here.
     if not cleaned_reply and json_part:
         cleaned_reply = clean_reply_text(text[:last_brace_open])
     elif not cleaned_reply and not json_part:
-        cleaned_reply = ""
+        cleaned_reply = ""
 
     return cleaned_reply, json_part
 
 
 def geocode_address(address: str) -> tuple[float | None, float | None]:
     """Turn a free‑form address into (lat, lon) using Geoapify."""
-    # NOTE: Error/warning messages here are for app stability/user feedback, not part of LLM prompts.
     if not GEOAPIFY_KEY:
         return None, None
     try:
         resp = requests.get(
             "https://api.geoapify.com/v1/geocode/search",
             params={"text": address, "limit": 1, "apiKey": GEOAPIFY_KEY},
-            timeout=10
+            timeout=10
         )
-        resp.raise_for_status()
+        resp.raise_for_status()
         features = resp.json().get("features", [])
         if not features:
             return None, None
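The loop elided between the two hunks above performs a brace-level scan to find the `}` that closes the last `{`. A minimal sketch of such a scan, using the variable names visible in the surrounding context; the actual loop body never appears in this diff, so this is an assumed reconstruction:

```python
# Hypothetical reconstruction of the elided brace-matching scan.
def find_matching_close(text: str, last_brace_open: int) -> int:
    brace_level = 0
    last_brace_close = -1
    for i, ch in enumerate(text[last_brace_open:], start=last_brace_open):
        if ch == "{":
            brace_level += 1
        elif ch == "}":
            brace_level -= 1
            if brace_level == 0:
                last_brace_close = i  # index of the brace closing the last '{'
                break
    return last_brace_close
```

For example, `find_matching_close('reply {"a": {"b": 1}}', 6)` returns 20, the index of the final `}`, correctly skipping the nested object.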
@@ -164,7 +143,6 @@ def geocode_address(address: str) -> tuple[float | None, float | None]:
 
 def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> list[dict]:
     """Get nearby schools using Geoapify."""
-    # NOTE: Error/warning messages here are for app stability/user feedback, not part of LLM prompts.
     if not GEOAPIFY_KEY:
         return []
 
@@ -177,12 +155,12 @@ def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> lis
         resp = requests.get(
             "https://api.geoapify.com/v2/places",
             params={
-                "categories": "education.school",
+                "categories": "education.school",
                 "filter": f"circle:{lon},{lat},{radius}",
                 "limit": limit,
                 "apiKey": GEOAPIFY_KEY,
             },
-            timeout=10
+            timeout=10
         )
         resp.raise_for_status()
 
@@ -190,7 +168,7 @@ def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> lis
         for feat in resp.json().get("features", []):
             prop = feat.get("properties", {})
             name = prop.get("name")
-            addr = prop.get("formatted")
+            addr = prop.get("formatted")
             if name and addr:
                 schools.append({"name": name, "address": addr})
         return schools
@@ -204,17 +182,14 @@ def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> lis
 def build_school_search_prompt(address: str) -> str:
     """Builds the prompt section listing nearby schools."""
     if not address:
-        # Use the exact initial prompt text from prompts.py (originating from notebook)
         return initial_school_search_prompt
 
     nearby_schools = get_nearby_schools(address, radius=2000, limit=10)
 
     if not nearby_schools:
-        # This text is generated dynamically based on API results, not a fixed prompt string.
         return f"No schools found near '{address}'. Please ensure the address is correct or try a broader area if applicable."
 
     school_list_str = "\n".join(f"- {s['name']}: {s['address']}" for s in nearby_schools)
-    # This text is also generated dynamically.
     return (
         f"Based on the residence '{address}', here are some nearby schools:\n{school_list_str}\n\n"
         "Use this information and the provided documents to answer eligibility questions for the user's grade level."
@@ -224,35 +199,27 @@ def update_context(context_json: dict, new_data: dict) -> tuple[dict, bool]:
     """
     Updates context_json in-place based on new_data extracted from LLM response.
    Returns the updated context and a boolean indicating if residence changed.
-    (Logic directly based on notebook implementation)
     """
     residence_changed = False
     current_res = context_json.get("residence", "").strip()
     new_res = new_data.get("residence", "").strip()
 
-    # Update residence only if it's new and different
     if new_res and new_res != current_res:
         context_json["residence"] = new_res
         residence_changed = True
-    # Handle case where residence might be explicitly cleared in new_data
     elif "residence" in new_data and not new_res and current_res:
         context_json["residence"] = ""
-        residence_changed = True
+        residence_changed = True
 
-    # Update other fields (skip 'residence') if they exist in new_data and are different
     for key, value in new_data.items():
         if key != "residence":
-            # Ensure comparison handles various types by converting to string for check
             new_val_str = str(value).strip() if value is not None else ""
             old_val_str = str(context_json.get(key, "")).strip()
 
-            # Update if new value is provided and different from old value
             if new_val_str and new_val_str != old_val_str:
-                context_json[key] = value
+                context_json[key] = value
-            # Update if key is in new_data, new value is empty, but old value was not
             elif key in new_data and not new_val_str and old_val_str:
-
-                context_json[key] = "" # Store empty string or appropriate null value
+                context_json[key] = ""
 
     return context_json, residence_changed
 
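A quick illustration of the retained `update_context` semantics, with sample values invented for the example:

```python
ctx = {"residence": "12 Oak St", "grade": "2", "school_choice": "", "summary": ""}
new = {"residence": "34 Elm St", "grade": "2", "school_choice": "Hurley K-8"}

ctx, residence_changed = update_context(ctx, new)
assert residence_changed is True             # residence differed, so it was replaced
assert ctx["residence"] == "34 Elm St"
assert ctx["grade"] == "2"                   # identical value: left untouched
assert ctx["school_choice"] == "Hurley K-8"  # new non-empty value: stored
```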
@@ -261,7 +228,6 @@ def update_context(context_json: dict, new_data: dict) -> tuple[dict, bool]:
 @st.cache_resource
 def load_embedding_model():
     """Loads the Sentence Transformer model."""
-    # NOTE: Error message here is for app stability, not part of LLM prompts.
     try:
         return SentenceTransformer(EMBEDDING_MODEL_NAME)
     except Exception as e:
@@ -271,7 +237,6 @@ def load_embedding_model():
 @st.cache_data
 def load_documents(docs_path: str) -> tuple[list[str], list[str]]:
     """Loads text documents from the specified directory."""
-    # NOTE: Error/warning messages here are for app stability, not part of LLM prompts.
     doc_texts = []
     filenames = []
     if not os.path.isdir(docs_path):
@@ -297,7 +262,6 @@ def load_documents(docs_path: str) -> tuple[list[str], list[str]]:
 @st.cache_resource(show_spinner="Creating document embeddings and FAISS index...")
 def create_faiss_index(_embedder, doc_texts):
     """Creates FAISS index from document texts."""
-    # NOTE: Error messages here are for app stability, not part of LLM prompts.
     if not doc_texts:
         return None
     try:
@@ -306,12 +270,10 @@ def create_faiss_index(_embedder, doc_texts):
             st.error("Embedding failed, no document embeddings generated.")
             return None
 
-        # Normalize embeddings for Inner Product (IP) search
         faiss.normalize_L2(doc_embeddings)
         dimension = doc_embeddings.shape[1]
-        index = faiss.IndexFlatIP(dimension)
+        index = faiss.IndexFlatIP(dimension)
         index.add(doc_embeddings)
-        # Option to save/load index can be added here if needed
         return index
     except Exception as e:
         st.error(f"Error creating FAISS index: {e}")
@@ -319,7 +281,6 @@ def create_faiss_index(_embedder, doc_texts):
 
 def query_docs(query: str, _index, _embedder, doc_texts, top_k=TOP_K_DOCS) -> list[str]:
     """Queries the FAISS index to retrieve relevant document chunks."""
-    # NOTE: Error/warning messages here are for app stability, not part of LLM prompts.
     if _index is None or not doc_texts:
         return []
     try:
@@ -327,10 +288,9 @@ def query_docs(query: str, _index, _embedder, doc_texts, top_k=TOP_K_DOCS) -> li
         if query_embedding is None or query_embedding.shape[0] == 0:
             st.warning("Failed to generate query embedding.")
             return []
-        faiss.normalize_L2(query_embedding)
+        faiss.normalize_L2(query_embedding)
         distances, indices = _index.search(query_embedding, top_k)
 
-        # Return the text of the k nearest neighbors
         return [doc_texts[i] for i in indices[0] if i != -1]
     except Exception as e:
         st.error(f"Error querying FAISS index: {e}")
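Because both the document vectors (in `create_faiss_index` above) and the query vector are L2-normalized before the `IndexFlatIP` search, the returned `distances` are cosine similarities. A self-contained sketch of that pattern with toy vectors:

```python
import faiss
import numpy as np

rng = np.random.default_rng(0)
doc_embeddings = rng.normal(size=(4, 8)).astype("float32")  # 4 toy "documents"
query = rng.normal(size=(1, 8)).astype("float32")

faiss.normalize_L2(doc_embeddings)  # in-place; unit vectors make IP == cosine
faiss.normalize_L2(query)

index = faiss.IndexFlatIP(doc_embeddings.shape[1])
index.add(doc_embeddings)

distances, indices = index.search(query, 3)  # top-3 cosine scores and row ids
print(indices[0], distances[0])
```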
@@ -345,8 +305,8 @@ def load_llm():
         model_path = hf_hub_download(
             repo_id=MODEL_REPO_ID,
             filename=MODEL_FILENAME,
-            local_dir="models",
-            local_dir_use_symlinks=False
+            local_dir="models",
+            local_dir_use_symlinks=False
         )
         st.success(f"Model found at: {model_path}")
     except Exception as e:
@@ -359,7 +319,7 @@ def load_llm():
             model_path=model_path,
             n_ctx=N_CTX,
             n_threads=N_THREADS,
-            verbose=False
+            verbose=False
         )
         return llm
     except Exception as e:
@@ -370,43 +330,34 @@ def load_llm():
 
 def build_full_prompt(
     context_json: dict,
-    school_search_prompt: str,
+    school_search_prompt: str,
     history: list[dict],
-    max_history=5
+    max_history=5
 ) -> str:
-    """Builds the final prompt string for the LLM
+    """Builds the final prompt string for the LLM."""
 
-    # 1. Get the latest user input from history
     last_user_input = ""
     if history and history[-1]["role"] == "user":
         last_user_input = history[-1]["content"]
 
-    # 2. Create a query string for RAG (last user input + context summary)
     summary_info = context_json.get("summary", "")
     rag_query = f"{last_user_input}\n\nContext Summary: {summary_info}".strip()
 
-    # 3. Retrieve relevant documents
     retrieved_docs = query_docs(rag_query, faiss_index, embedder, doc_texts_global, top_k=TOP_K_DOCS)
     docs_context_str = "\n\n---\n\n".join(retrieved_docs)
     if docs_context_str:
-        # This text is dynamically generated based on RAG results
         docs_context_str = f"DOCUMENT CONTEXT:\n{docs_context_str}\n---"
     else:
-        # This text is dynamically generated
         docs_context_str = "DOCUMENT CONTEXT: None available."
 
-
-    # 4. Format conversation history
-    recent_history = history[-(max_history * 2):] # Get last N turns
+    recent_history = history[-(max_history * 2):]
     conversation = []
     for msg in recent_history:
         role = "User" if msg["role"] == "user" else "Assistant"
-        # History content comes directly from user input or previous LLM output
         conversation.append(f"{role}: {msg['content']}")
 
     conversation_str = "\n".join(conversation)
 
-    # 5. Assemble the final prompt using exact texts from prompts.py where applicable
     prompt = f"""{system_prompt}
 
 {docs_context_str}
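For orientation, the retrieval query built above concatenates the newest user turn with the running context summary before it hits the FAISS index. A small runnable illustration mirroring that line (sample turn and summary invented):

```python
last_user_input = "What grade cutoffs apply for K1?"  # invented example turn
summary_info = "Family at 123 Main St, child age 4"   # invented running summary
rag_query = f"{last_user_input}\n\nContext Summary: {summary_info}".strip()
print(rag_query)
```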
@@ -421,112 +372,85 @@ SCHOOL SEARCH INFO:
 
 CONVERSATION HISTORY:
 {conversation_str}
-Assistant:"""
+Assistant:"""
 
     return prompt
 
 
 # --- Streamlit App UI and Logic ---
 
-# NOTE: UI text here is for the Streamlit interface, not part of LLM prompts.
 st.set_page_config(page_title="Boston School Choice Chatbot", page_icon="🏫", layout="wide")
 st.title("Boston Public Schools Enrollment Assistant 🏫")
 st.markdown("Ask questions about enrolling in Boston Public Schools. I can help find nearby schools if you provide a residence address.")
 
-# Load models and data
 llm = load_llm()
 embedder = load_embedding_model()
-doc_texts_global, filenames_global = load_documents(DOCS_PATH)
+doc_texts_global, filenames_global = load_documents(DOCS_PATH)
 faiss_index = create_faiss_index(embedder, doc_texts_global)
 
-# Initialize session state
 if "messages" not in st.session_state:
-    st.session_state.messages = []
+    st.session_state.messages = []
 if "context_json" not in st.session_state:
-    # Initial context structure from notebook
     st.session_state.context_json = {
         "residence": "",
         "grade": "",
         "school_choice": "",
-        "summary": ""
+        "summary": ""
     }
 if "school_search" not in st.session_state:
-    # Use the exact initial prompt text from prompts.py
     st.session_state.school_search = initial_school_search_prompt
 
-# Display chat messages from history
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
-# --- Main Chat Loop ---
-# NOTE: UI text here (chat_input prompt) is for the Streamlit interface.
 if prompt := st.chat_input("What is your question? (e.g., 'I live at 123 Main St, my child is going into grade 2')"):
-    # Add user message to history and display it
     st.session_state.messages.append({"role": "user", "content": prompt})
     with st.chat_message("user"):
         st.markdown(prompt)
 
-    # Prepare and generate response
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
-        # NOTE: UI text here is for the Streamlit interface.
         message_placeholder.markdown("Thinking...")
 
-        # Build the prompt using current state and exact prompt texts
         full_prompt = build_full_prompt(
             st.session_state.context_json,
             st.session_state.school_search,
             st.session_state.messages
         )
 
-        # Display prompt for debugging if needed (optional)
-        # with st.expander("DEBUG: View Full Prompt"):
-        # st.text(full_prompt)
-
         try:
-            # Call the LLM
             response = llm(
                 full_prompt,
                 max_tokens=MAX_TOKENS_RESPONSE,
                 temperature=TEMPERATURE,
-                stop=["\nUser:", "\nAssistant:", "<|end_header_id|>", "<|eot_id|>"],
-                echo=False
+                stop=["\nUser:", "\nAssistant:", "<|end_header_id|>", "<|eot_id|>"],
+                echo=False
             )
             raw_output = response["choices"][0]["text"].strip()
 
-            # Parse the response using the corrected function
             reply_text, new_data = extract_reply_and_json(raw_output)
 
-            # Update context JSON state using the notebook's logic
             updated_context, residence_changed = update_context(st.session_state.context_json, new_data)
             st.session_state.context_json = updated_context
 
-            # If residence changed, update the school search prompt text for the *next* turn
             if residence_changed:
-                # This calls build_school_search_prompt which uses initial_school_search_prompt if address is now empty
                 st.session_state.school_search = build_school_search_prompt(st.session_state.context_json.get("residence", ""))
 
-
-            message_placeholder.markdown(reply_text if reply_text else "_Assistant had trouble generating a response._") # Provide minimal feedback if reply is empty
+            message_placeholder.markdown(reply_text if reply_text else "_Assistant had trouble generating a response._")
 
-            # Add assistant response (or empty string) to history
             st.session_state.messages.append({"role": "assistant", "content": reply_text})
 
         except Exception as e:
-            # NOTE: Error message here is for app stability, not part of LLM prompts.
             st.error(f"An error occurred during response generation: {e}")
             error_message = "Sorry, I encountered an error processing your request."
             message_placeholder.markdown(error_message)
             st.session_state.messages.append({"role": "assistant", "content": error_message})
 
 
-# Optional: Display current context JSON for debugging
-# NOTE: UI text here is for the Streamlit interface.
 with st.sidebar:
     st.subheader("ℹ️ Current Context")
     st.json(st.session_state.context_json)
     st.subheader("🏫 School Search Status")
-    # Display the current text being used for the school search part of the prompt
     st.text(st.session_state.school_search)
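Putting the pieces together: the chat loop assumes the model appends a JSON context object after its natural-language reply, which `extract_reply_and_json` then splits. An invented sample output, shown only to illustrate the expected shape:

```python
raw_output = (
    "Thanks! I've noted your address. Which grade is your child entering?\n"
    '{"residence": "123 Main St", "grade": "", "school_choice": "", '
    '"summary": "Family at 123 Main St"}'
)

reply_text, new_data = extract_reply_and_json(raw_output)
# reply_text == "Thanks! I've noted your address. Which grade is your child entering?"
# new_data == {"residence": "123 Main St", "grade": "", "school_choice": "",
#              "summary": "Family at 123 Main St"}
```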