mtwesley committed
Commit 252940e · 1 Parent(s): 6d429bd

removing too many comments

Files changed (2)
  1. .gitignore +2 -0
  2. app.py +40 -116
.gitignore CHANGED
@@ -20,6 +20,8 @@ wheels/
 .installed.cfg
 *.egg
 
+old/*
+
 # Virtual Environment
 venv/
 env/
app.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
 import os
 import re
 import json
-import demjson3 # Using demjson3 for potentially less strict JSON parsing
+import demjson3
 import requests
 import faiss
 import numpy as np
 import multiprocessing
-import time # For adding slight delay if needed
+import time
 
 from huggingface_hub import hf_hub_download, login
 from sentence_transformers import SentenceTransformer
@@ -31,27 +31,25 @@ except KeyError:
 
 # Model and RAG configuration
 MODEL_REPO_ID = "bartowski/gemma-2-2b-it-GGUF"
-MODEL_FILENAME = "gemma-2-2b-it-Q8_0.gguf" # Using Q8 quantization from notebook
+MODEL_FILENAME = "gemma-2-2b-it-Q8_0.gguf"
 EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
 DOCS_PATH = "docs"
-FAISS_INDEX_PATH = "bps_faiss.index" # Path to save/load the index
+FAISS_INDEX_PATH = "bps_faiss.index"
 
 # LLM parameters
-N_CTX = 2048 # Context window size
-MAX_TOKENS_RESPONSE = 350 # Max tokens for the LLM response generation
-TEMPERATURE = 0.5 # LLM temperature
-N_THREADS = multiprocessing.cpu_count() - 1 # Use half the CPU cores
+N_CTX = 2048
+MAX_TOKENS_RESPONSE = 350
+TEMPERATURE = 0.5
+N_THREADS = multiprocessing.cpu_count() - 1
 
 # RAG parameters
-TOP_K_DOCS = 3 # Number of relevant document chunks to retrieve
+TOP_K_DOCS = 3
 
 # Import prompts from prompts.py
-# NOTE: Error message here is for app stability, not part of LLM prompts.
 try:
     from prompts import system_prompt, json_prompt, initial_school_search_prompt
 except ImportError:
     st.error("Could not import prompts from prompts.py. Make sure the file exists.")
-    # Define fallbacks directly if import fails - Using exact text from notebook prompts
     system_prompt = """
 You are a professional assistant that answers questions about enrollment in Boston Public Schools.
 Be friendly and helpful. Families will ask questions and provide information, such as the child's residence, grade, and school preference.
@@ -66,19 +64,15 @@ Keep the conversation going and ask questions one at a time until you have all i
     st.stop()
 
 
-# --- Helper Functions (ported from Notebook) ---
+# --- Helper Functions ---
 
 def clean_reply_text(reply: str) -> str:
     """Removes potential JSON blocks and cleans up common LLM artifacts."""
-    # Remove ```json ... ``` blocks or similar markdown code blocks
     reply = re.sub(r"```[jJ][sS][oO][nN]?\s*(\{.*?\})\s*```", "", reply, flags=re.DOTALL)
-    # Remove trailing JSON object if it's at the very end after potential whitespace
     reply = re.sub(r"\s*\{.*\}\s*$", "", reply, flags=re.DOTALL)
-    # Remove stray backticks, `json` keywords, and unmatched brackets
     reply = re.sub(r"`", "", reply)
     reply = re.sub(r"(?i)\bjson\b", "", reply)
-    reply = re.sub(r"[\[\]]", "", reply) # Remove only square brackets if dangling
-    # Collapse multiple blank lines
+    reply = re.sub(r"[\[\]]", "", reply)
     reply = re.sub(r"\n{2,}", "\n", reply)
     return reply.strip()
 
@@ -88,12 +82,10 @@ def extract_reply_and_json(text: str) -> tuple[str, dict]:
     Uses demjson3 for potentially more lenient parsing.
     """
     json_part = {}
-    reply_part = text # Default reply is the whole text initially
+    reply_part = text
 
-    # Find the last potential JSON object (heuristic: starts with { ends with })
     last_brace_open = text.rfind('{')
     if last_brace_open != -1:
-        # Try to find the matching closing brace
         brace_level = 0
         last_brace_close = -1
         potential_json_str = text[last_brace_open:]
@@ -109,47 +101,34 @@ def extract_reply_and_json(text: str) -> tuple[str, dict]:
         if last_brace_close != -1:
             json_str = text[last_brace_open : last_brace_close + 1]
             try:
-                # Use demjson3 to decode
                 parsed = demjson3.decode(json_str)
                 if isinstance(parsed, dict):
                     json_part = parsed
-                    # If JSON is successfully parsed, assume text before it is the reply
                     reply_part = text[:last_brace_open].strip()
-                else:
-                    # Parsed but not a dict, might be noise, keep original reply
-                    pass
-            except demjson3.JSONDecodeError as e:
-                # Decoding failed, assume it wasn't valid JSON
-                # print(f"JSON decode failed: {e}")
-                # print(f"Offending string segment:\n{json_str}")
-                pass # Keep original reply_part and empty json_part
-
-    # Clean the reply part further
+            except demjson3.JSONDecodeError:
+                pass
+
     cleaned_reply = clean_reply_text(reply_part)
 
-    # --- MODIFICATION: Removed the fallback message I added ---
-    # If cleaning removed everything, return empty string for reply.
-    # The notebook didn't specify fallback behavior here.
     if not cleaned_reply and json_part:
         cleaned_reply = clean_reply_text(text[:last_brace_open])
     elif not cleaned_reply and not json_part:
-        cleaned_reply = "" # Return empty if nothing is left
+        cleaned_reply = ""
 
     return cleaned_reply, json_part
 
 
 def geocode_address(address: str) -> tuple[float | None, float | None]:
     """Turn a free‑form address into (lat, lon) using Geoapify."""
-    # NOTE: Error/warning messages here are for app stability/user feedback, not part of LLM prompts.
     if not GEOAPIFY_KEY:
         return None, None
     try:
         resp = requests.get(
             "https://api.geoapify.com/v1/geocode/search",
             params={"text": address, "limit": 1, "apiKey": GEOAPIFY_KEY},
-            timeout=10 # Add a timeout
+            timeout=10
         )
-        resp.raise_for_status() # Raise an exception for bad status codes
+        resp.raise_for_status()
         features = resp.json().get("features", [])
         if not features:
             return None, None
@@ -164,7 +143,6 @@ def geocode_address(address: str) -> tuple[float | None, float | None]:
 
 
 def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> list[dict]:
     """Get nearby schools using Geoapify."""
-    # NOTE: Error/warning messages here are for app stability/user feedback, not part of LLM prompts.
     if not GEOAPIFY_KEY:
         return []
@@ -177,12 +155,12 @@ def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> lis
         resp = requests.get(
             "https://api.geoapify.com/v2/places",
             params={
-                "categories": "education.school", # Use the specific category
+                "categories": "education.school",
                 "filter": f"circle:{lon},{lat},{radius}",
                 "limit": limit,
                 "apiKey": GEOAPIFY_KEY,
             },
-            timeout=10 # Add a timeout
+            timeout=10
         )
         resp.raise_for_status()
 
@@ -190,7 +168,7 @@ def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> lis
         for feat in resp.json().get("features", []):
             prop = feat.get("properties", {})
             name = prop.get("name")
-            addr = prop.get("formatted") # Full address string
+            addr = prop.get("formatted")
             if name and addr:
                 schools.append({"name": name, "address": addr})
         return schools
@@ -204,17 +182,14 @@ def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> lis
 def build_school_search_prompt(address: str) -> str:
     """Builds the prompt section listing nearby schools."""
     if not address:
-        # Use the exact initial prompt text from prompts.py (originating from notebook)
         return initial_school_search_prompt
 
     nearby_schools = get_nearby_schools(address, radius=2000, limit=10)
 
     if not nearby_schools:
-        # This text is generated dynamically based on API results, not a fixed prompt string.
         return f"No schools found near '{address}'. Please ensure the address is correct or try a broader area if applicable."
 
     school_list_str = "\n".join(f"- {s['name']}: {s['address']}" for s in nearby_schools)
-    # This text is also generated dynamically.
     return (
         f"Based on the residence '{address}', here are some nearby schools:\n{school_list_str}\n\n"
         "Use this information and the provided documents to answer eligibility questions for the user's grade level."
@@ -224,35 +199,27 @@ def update_context(context_json: dict, new_data: dict) -> tuple[dict, bool]:
     """
     Updates context_json in-place based on new_data extracted from LLM response.
     Returns the updated context and a boolean indicating if residence changed.
-    (Logic directly based on notebook implementation)
     """
     residence_changed = False
     current_res = context_json.get("residence", "").strip()
     new_res = new_data.get("residence", "").strip()
 
-    # Update residence only if it's new and different
     if new_res and new_res != current_res:
         context_json["residence"] = new_res
         residence_changed = True
-    # Handle case where residence might be explicitly cleared in new_data
     elif "residence" in new_data and not new_res and current_res:
         context_json["residence"] = ""
-        residence_changed = True # Clearing is also a change
+        residence_changed = True
 
-    # Update other fields (skip 'residence') if they exist in new_data and are different
     for key, value in new_data.items():
         if key != "residence":
-            # Ensure comparison handles various types by converting to string for check
             new_val_str = str(value).strip() if value is not None else ""
             old_val_str = str(context_json.get(key, "")).strip()
 
-            # Update if new value is provided and different from old value
             if new_val_str and new_val_str != old_val_str:
-                context_json[key] = value # Store original type from new_data
-            # Update if key is in new_data, new value is empty, but old value was not
+                context_json[key] = value
             elif key in new_data and not new_val_str and old_val_str:
-                # Handle explicit clearing of other fields
-                context_json[key] = "" # Store empty string or appropriate null value
+                context_json[key] = ""
 
     return context_json, residence_changed
 
@@ -261,7 +228,6 @@ def update_context(context_json: dict, new_data: dict) -> tuple[dict, bool]:
 @st.cache_resource
 def load_embedding_model():
     """Loads the Sentence Transformer model."""
-    # NOTE: Error message here is for app stability, not part of LLM prompts.
     try:
         return SentenceTransformer(EMBEDDING_MODEL_NAME)
     except Exception as e:
@@ -271,7 +237,6 @@ def load_embedding_model():
 @st.cache_data
 def load_documents(docs_path: str) -> tuple[list[str], list[str]]:
     """Loads text documents from the specified directory."""
-    # NOTE: Error/warning messages here are for app stability, not part of LLM prompts.
     doc_texts = []
     filenames = []
     if not os.path.isdir(docs_path):
@@ -297,7 +262,6 @@ def load_documents(docs_path: str) -> tuple[list[str], list[str]]:
 @st.cache_resource(show_spinner="Creating document embeddings and FAISS index...")
 def create_faiss_index(_embedder, doc_texts):
     """Creates FAISS index from document texts."""
-    # NOTE: Error messages here are for app stability, not part of LLM prompts.
     if not doc_texts:
         return None
     try:
@@ -306,12 +270,10 @@ def create_faiss_index(_embedder, doc_texts):
             st.error("Embedding failed, no document embeddings generated.")
             return None
 
-        # Normalize embeddings for Inner Product (IP) search
         faiss.normalize_L2(doc_embeddings)
         dimension = doc_embeddings.shape[1]
-        index = faiss.IndexFlatIP(dimension) # Using Inner Product
+        index = faiss.IndexFlatIP(dimension)
         index.add(doc_embeddings)
-        # Option to save/load index can be added here if needed
         return index
     except Exception as e:
         st.error(f"Error creating FAISS index: {e}")
@@ -319,7 +281,6 @@ def create_faiss_index(_embedder, doc_texts):
 
 def query_docs(query: str, _index, _embedder, doc_texts, top_k=TOP_K_DOCS) -> list[str]:
     """Queries the FAISS index to retrieve relevant document chunks."""
-    # NOTE: Error/warning messages here are for app stability, not part of LLM prompts.
     if _index is None or not doc_texts:
         return []
     try:
@@ -327,10 +288,9 @@ def query_docs(query: str, _index, _embedder, doc_texts, top_k=TOP_K_DOCS) -> li
         if query_embedding is None or query_embedding.shape[0] == 0:
             st.warning("Failed to generate query embedding.")
             return []
-        faiss.normalize_L2(query_embedding) # Normalize query embedding
+        faiss.normalize_L2(query_embedding)
         distances, indices = _index.search(query_embedding, top_k)
 
-        # Return the text of the k nearest neighbors
         return [doc_texts[i] for i in indices[0] if i != -1]
     except Exception as e:
         st.error(f"Error querying FAISS index: {e}")
@@ -345,8 +305,8 @@ def load_llm():
         model_path = hf_hub_download(
             repo_id=MODEL_REPO_ID,
             filename=MODEL_FILENAME,
-            local_dir="models", # Download to a local 'models' directory
-            local_dir_use_symlinks=False # Avoid symlinks issues in some environments
+            local_dir="models",
+            local_dir_use_symlinks=False
         )
         st.success(f"Model found at: {model_path}")
     except Exception as e:
@@ -359,7 +319,7 @@ def load_llm():
             model_path=model_path,
             n_ctx=N_CTX,
             n_threads=N_THREADS,
-            verbose=False # Set to True for more llama.cpp logging
+            verbose=False
         )
         return llm
     except Exception as e:
@@ -370,43 +330,34 @@ def load_llm():
 
 def build_full_prompt(
     context_json: dict,
-    school_search_prompt: str, # This now correctly uses initial_school_search_prompt from prompts.py when address is missing
+    school_search_prompt: str,
     history: list[dict],
-    max_history=5 # Keep last 5 turns (user + assistant)
+    max_history=5
 ) -> str:
-    """Builds the final prompt string for the LLM, using exact prompt texts where specified."""
+    """Builds the final prompt string for the LLM."""
 
-    # 1. Get the latest user input from history
     last_user_input = ""
     if history and history[-1]["role"] == "user":
         last_user_input = history[-1]["content"]
 
-    # 2. Create a query string for RAG (last user input + context summary)
     summary_info = context_json.get("summary", "")
     rag_query = f"{last_user_input}\n\nContext Summary: {summary_info}".strip()
 
-    # 3. Retrieve relevant documents
     retrieved_docs = query_docs(rag_query, faiss_index, embedder, doc_texts_global, top_k=TOP_K_DOCS)
     docs_context_str = "\n\n---\n\n".join(retrieved_docs)
     if docs_context_str:
-        # This text is dynamically generated based on RAG results
         docs_context_str = f"DOCUMENT CONTEXT:\n{docs_context_str}\n---"
     else:
-        # This text is dynamically generated
         docs_context_str = "DOCUMENT CONTEXT: None available."
 
-
-    # 4. Format conversation history
-    recent_history = history[-(max_history * 2):] # Get last N turns
+    recent_history = history[-(max_history * 2):]
     conversation = []
     for msg in recent_history:
         role = "User" if msg["role"] == "user" else "Assistant"
-        # History content comes directly from user input or previous LLM output
         conversation.append(f"{role}: {msg['content']}")
 
     conversation_str = "\n".join(conversation)
 
-    # 5. Assemble the final prompt using exact texts from prompts.py where applicable
     prompt = f"""{system_prompt}
 
 {docs_context_str}
@@ -421,112 +372,85 @@ SCHOOL SEARCH INFO:
 
 CONVERSATION HISTORY:
 {conversation_str}
-Assistant:""" # Note: Ends with "Assistant:", prompting the model to respond
+Assistant:"""
 
     return prompt
 
 
 # --- Streamlit App UI and Logic ---
 
-# NOTE: UI text here is for the Streamlit interface, not part of LLM prompts.
 st.set_page_config(page_title="Boston School Choice Chatbot", page_icon="🏫", layout="wide")
 st.title("Boston Public Schools Enrollment Assistant 🏫")
 st.markdown("Ask questions about enrolling in Boston Public Schools. I can help find nearby schools if you provide a residence address.")
 
-# Load models and data
 llm = load_llm()
 embedder = load_embedding_model()
-doc_texts_global, filenames_global = load_documents(DOCS_PATH) # Load once
+doc_texts_global, filenames_global = load_documents(DOCS_PATH)
 faiss_index = create_faiss_index(embedder, doc_texts_global)
 
-# Initialize session state
 if "messages" not in st.session_state:
-    st.session_state.messages = [] # Stores chat history
+    st.session_state.messages = []
 if "context_json" not in st.session_state:
-    # Initial context structure from notebook
     st.session_state.context_json = {
         "residence": "",
         "grade": "",
         "school_choice": "",
-        "summary": "" # LLM can update this summary
+        "summary": ""
     }
 if "school_search" not in st.session_state:
-    # Use the exact initial prompt text from prompts.py
     st.session_state.school_search = initial_school_search_prompt
 
-# Display chat messages from history
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
-# --- Main Chat Loop ---
-# NOTE: UI text here (chat_input prompt) is for the Streamlit interface.
 if prompt := st.chat_input("What is your question? (e.g., 'I live at 123 Main St, my child is going into grade 2')"):
-    # Add user message to history and display it
    st.session_state.messages.append({"role": "user", "content": prompt})
     with st.chat_message("user"):
         st.markdown(prompt)
 
-    # Prepare and generate response
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
-        # NOTE: UI text here is for the Streamlit interface.
         message_placeholder.markdown("Thinking...")
 
-        # Build the prompt using current state and exact prompt texts
         full_prompt = build_full_prompt(
             st.session_state.context_json,
             st.session_state.school_search,
             st.session_state.messages
         )
 
-        # Display prompt for debugging if needed (optional)
-        # with st.expander("DEBUG: View Full Prompt"):
-        #     st.text(full_prompt)
-
         try:
-            # Call the LLM
             response = llm(
                 full_prompt,
                 max_tokens=MAX_TOKENS_RESPONSE,
                 temperature=TEMPERATURE,
-                stop=["\nUser:", "\nAssistant:", "<|end_header_id|>", "<|eot_id|>"], # Stop tokens
-                echo=False # Don't echo the prompt in the output
+                stop=["\nUser:", "\nAssistant:", "<|end_header_id|>", "<|eot_id|>"],
+                echo=False
            )
             raw_output = response["choices"][0]["text"].strip()
 
-            # Parse the response using the corrected function
             reply_text, new_data = extract_reply_and_json(raw_output)
 
-            # Update context JSON state using the notebook's logic
             updated_context, residence_changed = update_context(st.session_state.context_json, new_data)
             st.session_state.context_json = updated_context
 
-            # If residence changed, update the school search prompt text for the *next* turn
             if residence_changed:
-                # This calls build_school_search_prompt which uses initial_school_search_prompt if address is now empty
                 st.session_state.school_search = build_school_search_prompt(st.session_state.context_json.get("residence", ""))
 
-            # Display the final reply from LLM (or empty string if parsing failed)
-            message_placeholder.markdown(reply_text if reply_text else "_Assistant had trouble generating a response._") # Provide minimal feedback if reply is empty
+            message_placeholder.markdown(reply_text if reply_text else "_Assistant had trouble generating a response._")
 
-            # Add assistant response (or empty string) to history
             st.session_state.messages.append({"role": "assistant", "content": reply_text})
 
         except Exception as e:
-            # NOTE: Error message here is for app stability, not part of LLM prompts.
            st.error(f"An error occurred during response generation: {e}")
             error_message = "Sorry, I encountered an error processing your request."
             message_placeholder.markdown(error_message)
             st.session_state.messages.append({"role": "assistant", "content": error_message})
 
 
-# Optional: Display current context JSON for debugging
-# NOTE: UI text here is for the Streamlit interface.
 with st.sidebar:
     st.subheader("ℹ️ Current Context")
     st.json(st.session_state.context_json)
     st.subheader("🏫 School Search Status")
-    # Display the current text being used for the school search part of the prompt
     st.text(st.session_state.school_search)