broadfield-dev committed
Commit 390fc75 · verified · 1 Parent(s): 8304431

Update app.py

Files changed (1)
  1. app.py +15 -8
app.py CHANGED
@@ -1,5 +1,3 @@
-# app.py (Updated with Normalization for the query)
-
 import sys
 import subprocess
 from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
@@ -16,14 +14,22 @@ app.secret_key = os.urandom(24)
 
 CHROMA_PATH = "chroma_db"
 COLLECTION_NAME = "bible_verses"
-MODEL_NAME = "google/embeddinggemma-300m"
-DATASET_REPO = "broadfield-dev/bible-chromadb-gemma"
+# *** CHANGE 1: UPDATE THE MODEL NAME ***
+MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+# *** CHANGE 2: UPDATE THE DATASET REPO NAME ***
+DATASET_REPO = "broadfield-dev/bible-chromadb-mpnet"
 STATUS_FILE = "build_status.log"
 
 chroma_collection = None
 tokenizer = None
 embedding_model = None
 
+# Mean Pooling Function - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
 def load_resources():
     # (This function is unchanged)
     global chroma_collection, tokenizer, embedding_model
@@ -89,12 +95,13 @@ def search():
     if not user_query:
         return render_template('index.html', results=[])
 
-    inputs = tokenizer(user_query, return_tensors="pt")
+    # *** CHANGE 3: USE THE CORRECT POOLING STRATEGY FOR SBERT MODELS ***
+    encoded_input = tokenizer([user_query], padding=True, truncation=True, return_tensors='pt')
     with torch.no_grad():
-        outputs = embedding_model(**inputs)
+        model_output = embedding_model(**encoded_input)
 
-    # *** FIX: NORMALIZE THE QUERY EMBEDDING ***
-    query_embedding = F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=1)
+    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    query_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
 
    search_results = chroma_collection.query(
        query_embeddings=query_embedding.cpu().tolist(),
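
For context, the +15/-8 above swaps the embedding model from google/embeddinggemma-300m to sentence-transformers/all-mpnet-base-v2 and adopts that model's documented pooling recipe: attention-masked mean pooling over the token embeddings, followed by L2 normalization, so query vectors land in the same space as the corpus vectors stored in the Chroma collection. A minimal standalone sketch (not part of the commit; assumes transformers, sentence-transformers, and torch are installed, and the query string is an arbitrary placeholder) checks the hand-rolled path against the library's own output:

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

def mean_pooling(model_output, attention_mask):
    # Same function as in the commit: average token embeddings, ignoring padding.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

query = ["love your neighbor"]
encoded = tokenizer(query, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    output = model(**encoded)
manual = F.normalize(mean_pooling(output, encoded["attention_mask"]), p=2, dim=1)

# Reference: let sentence-transformers do pooling and normalization itself.
reference = SentenceTransformer(MODEL_NAME, device="cpu").encode(
    query, convert_to_tensor=True, normalize_embeddings=True
)
print(manual.shape)                                  # torch.Size([1, 768])
print(torch.allclose(manual, reference, atol=1e-5))  # expected: True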
 
 
 
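On the retrieval side, the normalized embedding is handed to chroma_collection.query as a plain Python list. A hypothetical end-to-end sketch follows, assuming the pre-built store from DATASET_REPO has already been downloaded to CHROMA_PATH; the PersistentClient setup is a guess (the diff only shows the constants, not how the collection is opened) and n_results=5 is an arbitrary choice:

import chromadb
from sentence_transformers import SentenceTransformer

# Assumed setup: chroma_db already holds the "bible_verses" collection
# built with the same model; this client call is not shown in the diff.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu")
client = chromadb.PersistentClient(path="chroma_db")
collection = client.get_collection("bible_verses")

# Embed the query the way app.py now does: masked mean pooling + L2 normalization.
query_embedding = model.encode(["love your neighbor"], normalize_embeddings=True)

results = collection.query(
    query_embeddings=query_embedding.tolist(),
    n_results=5,  # arbitrary
    include=["documents", "distances"],
)
# With unit-length vectors on both sides, both L2 and inner-product
# distances rank results identically to cosine similarity.
for verse, dist in zip(results["documents"][0], results["distances"][0]):
    print(f"{dist:.4f}  {verse}")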