broadfield-dev committed
Commit 4dc4b99 · verified · 1 Parent(s): 390fc75

Update build_rag.py

Files changed (1)
  1. build_rag.py +18 -11
build_rag.py CHANGED
@@ -1,5 +1,3 @@
- # build_rag.py (Updated with Normalization and Cosine Distance)
-
  import json
  import os
  import pandas as pd
@@ -15,12 +13,14 @@ import traceback
  # --- Configuration ---
  CHROMA_PATH = "chroma_db"
  COLLECTION_NAME = "bible_verses"
- MODEL_NAME = "google/embeddinggemma-300m"
- DATASET_REPO = "broadfield-dev/bible-chromadb-gemma"
+ # *** CHANGE 1: UPDATE THE MODEL NAME ***
+ MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+ # *** CHANGE 2: UPDATE THE DATASET REPO NAME TO AVOID CONFUSION ***
+ DATASET_REPO = "broadfield-dev/bible-chromadb-mpnet"
  STATUS_FILE = "build_status.log"
  JSON_DIRECTORY = 'bible_json'
  CHUNK_SIZE = 3
- EMBEDDING_BATCH_SIZE = 16
+ EMBEDDING_BATCH_SIZE = 16  # Adjust based on available VRAM
  # (BOOK_ID_TO_NAME dictionary remains the same)
  BOOK_ID_TO_NAME = {
      1: "Genesis", 2: "Exodus", 3: "Leviticus", 4: "Numbers", 5: "Deuteronomy",
@@ -44,6 +44,12 @@ def update_status(message):
      with open(STATUS_FILE, "w") as f:
          f.write(message)

+ # Mean pooling function - take the attention mask into account for correct averaging
+ def mean_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
  def process_bible_json_files(directory_path: str, chunk_size: int) -> pd.DataFrame:
      # (This function is unchanged)
      all_verses = []
@@ -84,7 +90,6 @@ def main():
          shutil.rmtree(CHROMA_PATH)
      client = chromadb.PersistentClient(path=CHROMA_PATH)

-     # *** FIX 1: SET THE DISTANCE FUNCTION FOR THE COLLECTION ***
      collection = client.create_collection(
          name=COLLECTION_NAME,
          metadata={"hnsw:space": "cosine"}  # Use cosine distance
@@ -99,16 +104,18 @@ def main():
          batch_df = bible_chunks_df.iloc[i:i+EMBEDDING_BATCH_SIZE]
          texts = batch_df['text'].tolist()

-         inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
+         # *** CHANGE 3: USE THE CORRECT POOLING STRATEGY FOR SBERT MODELS ***
+         encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(model.device)
          with torch.no_grad():
-             outputs = model(**inputs)
+             model_output = model(**encoded_input)

-         # *** FIX 2: NORMALIZE THE EMBEDDINGS ***
-         embeddings = F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=1)
+         # Perform pooling and normalization
+         sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+         normalized_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

          collection.add(
              ids=[str(j) for j in range(i, i + len(batch_df))],
-             embeddings=embeddings.cpu().tolist(),  # Convert to list after normalization
+             embeddings=normalized_embeddings.cpu().tolist(),  # Convert to list
              documents=texts,
              metadatas=batch_df[['reference', 'version']].to_dict('records')
          )
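
Anything that queries this collection has to embed text the same way the build script does (tokenize, mean-pool with the attention mask, then L2-normalize), otherwise the cosine distances stored by ChromaDB are not comparable. The sketch below shows what that query side might look like, assuming the database was built by this script and reusing its config values; the search helper and the example query are illustrative, not part of the repository.

# query_sketch.py - illustrative only; everything except the config constants is an assumption
import chromadb
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

CHROMA_PATH = "chroma_db"
COLLECTION_NAME = "bible_verses"
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()

def mean_pooling(model_output, attention_mask):
    # Same attention-mask-aware averaging as in build_rag.py
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def search(query: str, n_results: int = 5):
    # Embed the query exactly like the documents: tokenize, mean-pool, normalize
    encoded = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded)
    embedding = F.normalize(mean_pooling(output, encoded["attention_mask"]), p=2, dim=1)

    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection(name=COLLECTION_NAME)
    return collection.query(query_embeddings=embedding.cpu().tolist(), n_results=n_results)

if __name__ == "__main__":
    results = search("In the beginning God created the heaven and the earth")
    for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]):
        print(f"{meta['reference']} ({meta['version']})  distance={dist:.3f}")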
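
Mean pooling followed by L2 normalization is also the recipe described on the all-mpnet-base-v2 model card, so if the sentence-transformers package happens to be installed it offers a quick cross-check of the manual transformers pipeline. This is only a suggested sanity check, not something the commit adds.

# Optional sanity check (assumes the sentence-transformers package is available)
from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
texts = ["In the beginning God created the heaven and the earth."]

# normalize_embeddings=True yields unit-length vectors, matching F.normalize(..., p=2, dim=1)
reference = st_model.encode(texts, normalize_embeddings=True)
print(reference.shape)  # (1, 768) for all-mpnet-base-v2

# If `manual` held the tokenizer + mean_pooling + F.normalize output for the same texts,
# the two should agree to within floating-point tolerance, e.g.:
# assert np.allclose(reference, manual.cpu().numpy(), atol=1e-5)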
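
Because the vectors are unit length by the time they reach collection.add, the cosine space configured via hnsw:space reduces to one minus a dot product: distances near 0 mean near-duplicates, distances near 2 mean opposite directions. A toy illustration, not taken from the repo:

import numpy as np

a = np.array([0.6, 0.8])   # unit-length vector
b = np.array([0.8, 0.6])   # unit-length vector
cosine_distance = 1.0 - float(a @ b)
print(cosine_distance)     # 0.04 -> very close in direction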
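
The DATASET_REPO rename suggests the finished chroma_db directory gets pushed to a Hugging Face dataset repo, but that step is not shown in this diff. If it is needed, one plausible way to do it with huggingface_hub is sketched below; treat the whole snippet as an assumption rather than the script's actual upload code.

# Hypothetical upload step (not part of this commit)
from huggingface_hub import HfApi

api = HfApi()  # relies on a token from `huggingface-cli login` or the HF_TOKEN env var
api.create_repo(repo_id="broadfield-dev/bible-chromadb-mpnet", repo_type="dataset", exist_ok=True)
api.upload_folder(
    folder_path="chroma_db",                        # CHROMA_PATH from the script
    repo_id="broadfield-dev/bible-chromadb-mpnet",  # DATASET_REPO from the script
    repo_type="dataset",
)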