Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update build_rag.py
Browse files- build_rag.py +18 -11
    	
        build_rag.py
    CHANGED
    
    | @@ -1,5 +1,3 @@ | |
| 1 | 
            -
            # build_rag.py (Updated with Normalization and Cosine Distance)
         | 
| 2 | 
            -
             | 
| 3 | 
             
            import json
         | 
| 4 | 
             
            import os
         | 
| 5 | 
             
            import pandas as pd
         | 
| @@ -15,12 +13,14 @@ import traceback | |
| 15 | 
             
            # --- Configuration ---
         | 
| 16 | 
             
            CHROMA_PATH = "chroma_db"
         | 
| 17 | 
             
            COLLECTION_NAME = "bible_verses"
         | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
|  | |
|  | |
| 20 | 
             
            STATUS_FILE = "build_status.log"
         | 
| 21 | 
             
            JSON_DIRECTORY = 'bible_json'
         | 
| 22 | 
             
            CHUNK_SIZE = 3
         | 
| 23 | 
            -
            EMBEDDING_BATCH_SIZE = 16
         | 
| 24 | 
             
            # (BOOK_ID_TO_NAME dictionary remains the same)
         | 
| 25 | 
             
            BOOK_ID_TO_NAME = {
         | 
| 26 | 
             
                1: "Genesis", 2: "Exodus", 3: "Leviticus", 4: "Numbers", 5: "Deuteronomy",
         | 
| @@ -44,6 +44,12 @@ def update_status(message): | |
| 44 | 
             
                with open(STATUS_FILE, "w") as f:
         | 
| 45 | 
             
                    f.write(message)
         | 
| 46 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
            def process_bible_json_files(directory_path: str, chunk_size: int) -> pd.DataFrame:
         | 
| 48 | 
             
                # (This function is unchanged)
         | 
| 49 | 
             
                all_verses = []
         | 
| @@ -84,7 +90,6 @@ def main(): | |
| 84 | 
             
                    shutil.rmtree(CHROMA_PATH)
         | 
| 85 | 
             
                client = chromadb.PersistentClient(path=CHROMA_PATH)
         | 
| 86 |  | 
| 87 | 
            -
                # *** FIX 1: SET THE DISTANCE FUNCTION FOR THE COLLECTION ***
         | 
| 88 | 
             
                collection = client.create_collection(
         | 
| 89 | 
             
                    name=COLLECTION_NAME,
         | 
| 90 | 
             
                    metadata={"hnsw:space": "cosine"} # Use cosine distance
         | 
| @@ -99,16 +104,18 @@ def main(): | |
| 99 | 
             
                    batch_df = bible_chunks_df.iloc[i:i+EMBEDDING_BATCH_SIZE]
         | 
| 100 | 
             
                    texts = batch_df['text'].tolist()
         | 
| 101 |  | 
| 102 | 
            -
                     | 
|  | |
| 103 | 
             
                    with torch.no_grad():
         | 
| 104 | 
            -
                         | 
| 105 |  | 
| 106 | 
            -
                    #  | 
| 107 | 
            -
                     | 
|  | |
| 108 |  | 
| 109 | 
             
                    collection.add(
         | 
| 110 | 
             
                        ids=[str(j) for j in range(i, i + len(batch_df))],
         | 
| 111 | 
            -
                        embeddings= | 
| 112 | 
             
                        documents=texts,
         | 
| 113 | 
             
                        metadatas=batch_df[['reference', 'version']].to_dict('records')
         | 
| 114 | 
             
                    )
         | 
|  | |
|  | |
|  | |
| 1 | 
             
            import json
         | 
| 2 | 
             
            import os
         | 
| 3 | 
             
            import pandas as pd
         | 
|  | |
| 13 | 
             
            # --- Configuration ---
         | 
| 14 | 
             
            CHROMA_PATH = "chroma_db"
         | 
| 15 | 
             
            COLLECTION_NAME = "bible_verses"
         | 
| 16 | 
            +
            # *** CHANGE 1: UPDATE THE MODEL NAME ***
         | 
| 17 | 
            +
            MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
         | 
| 18 | 
            +
            # *** CHANGE 2: UPDATE THE DATASET REPO NAME TO AVOID CONFUSION ***
         | 
| 19 | 
            +
            DATASET_REPO = "broadfield-dev/bible-chromadb-mpnet"
         | 
| 20 | 
             
            STATUS_FILE = "build_status.log"
         | 
| 21 | 
             
            JSON_DIRECTORY = 'bible_json'
         | 
| 22 | 
             
            CHUNK_SIZE = 3
         | 
| 23 | 
            +
            EMBEDDING_BATCH_SIZE = 16 # Adjust based on available VRAM
         | 
| 24 | 
             
            # (BOOK_ID_TO_NAME dictionary remains the same)
         | 
| 25 | 
             
            BOOK_ID_TO_NAME = {
         | 
| 26 | 
             
                1: "Genesis", 2: "Exodus", 3: "Leviticus", 4: "Numbers", 5: "Deuteronomy",
         | 
|  | |
| 44 | 
             
                with open(STATUS_FILE, "w") as f:
         | 
| 45 | 
             
                    f.write(message)
         | 
| 46 |  | 
| 47 | 
            +
            # Mean Pooling Function - Take attention mask into account for correct averaging
         | 
| 48 | 
            +
            def mean_pooling(model_output, attention_mask):
         | 
| 49 | 
            +
                token_embeddings = model_output[0] #First element of model_output contains all token embeddings
         | 
| 50 | 
            +
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         | 
| 51 | 
            +
                return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
         | 
| 52 | 
            +
             | 
| 53 | 
             
            def process_bible_json_files(directory_path: str, chunk_size: int) -> pd.DataFrame:
         | 
| 54 | 
             
                # (This function is unchanged)
         | 
| 55 | 
             
                all_verses = []
         | 
|  | |
| 90 | 
             
                    shutil.rmtree(CHROMA_PATH)
         | 
| 91 | 
             
                client = chromadb.PersistentClient(path=CHROMA_PATH)
         | 
| 92 |  | 
|  | |
| 93 | 
             
                collection = client.create_collection(
         | 
| 94 | 
             
                    name=COLLECTION_NAME,
         | 
| 95 | 
             
                    metadata={"hnsw:space": "cosine"} # Use cosine distance
         | 
|  | |
| 104 | 
             
                    batch_df = bible_chunks_df.iloc[i:i+EMBEDDING_BATCH_SIZE]
         | 
| 105 | 
             
                    texts = batch_df['text'].tolist()
         | 
| 106 |  | 
| 107 | 
            +
                    # *** CHANGE 3: USE THE CORRECT POOLING STRATEGY FOR SBERT MODELS ***
         | 
| 108 | 
            +
                    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(model.device)
         | 
| 109 | 
             
                    with torch.no_grad():
         | 
| 110 | 
            +
                        model_output = model(**encoded_input)
         | 
| 111 |  | 
| 112 | 
            +
                    # Perform pooling and normalization
         | 
| 113 | 
            +
                    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
         | 
| 114 | 
            +
                    normalized_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
         | 
| 115 |  | 
| 116 | 
             
                    collection.add(
         | 
| 117 | 
             
                        ids=[str(j) for j in range(i, i + len(batch_df))],
         | 
| 118 | 
            +
                        embeddings=normalized_embeddings.cpu().tolist(), # Convert to list
         | 
| 119 | 
             
                        documents=texts,
         | 
| 120 | 
             
                        metadatas=batch_df[['reference', 'version']].to_dict('records')
         | 
| 121 | 
             
                    )
         |