bhoomika19 committed
Commit 6874d8b · 1 Parent(s): 61f25c3

phase 1 - data storage in qdrant and retrieval

.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ venv/
+ __pycache__/
database/README.md ADDED
@@ -0,0 +1,27 @@
+ # Database Module - Math Agentic RAG
+
+ This module handles knowledge base creation and retrieval for the Math Agentic RAG system.
+
+ ## Files Overview
+
+ ### Core Files
+ - **`utils.py`** - Utility functions for embedding generation and data processing
+ - **`qdrant_manager.py`** - Qdrant vector database client wrapper
+ - **`ingest.py`** - Main ingestion script for loading the dataset into Qdrant (includes config)
+ - **`test_retrieval.py`** - Test script for validating retrieval functionality (includes config)
+
+ ### Dependencies
+ - **`requirements.txt`** - Python package dependencies
+
+ ## Usage
+
+ 1. **Set Up Environment Variables**: Ensure the `.env` file has Qdrant credentials
+ 2. **Install Dependencies**: `pip install -r requirements.txt`
+ 3. **Ingest Data**: `python ingest.py`
+ 4. **Test Retrieval**: `python test_retrieval.py`
+
+ ## Current Status
+ - ✅ Dataset: NuminaMath (5,000 mathematical problems)
+ - ✅ Vector DB: Qdrant Cloud
+ - ✅ Embedding Model: all-MiniLM-L6-v2 (384 dimensions)
+ - ✅ Status: Ready for Phase 2
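
For reference, a minimal `.env` with the variables these scripts read might look like the following (placeholder values, not real credentials; `QDRANT_COLLECTION` is optional and defaults to `nuinamath`):

```
QDRANT_URL=https://<your-cluster-id>.<region>.cloud.qdrant.io:6333
QDRANT_API_KEY=<your-qdrant-api-key>
QDRANT_COLLECTION=nuinamath
```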
database/ingest.py ADDED
@@ -0,0 +1,115 @@
+ """
+ Main ingestion script for loading the NuminaMath dataset into Qdrant.
+ """
+ import logging
+ import os
+ import time
+ from datasets import load_dataset
+ from dotenv import load_dotenv
+ from tqdm import tqdm
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configuration settings
+ QDRANT_URL = os.getenv("QDRANT_URL")
+ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+ QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
+ DATASET_NAME = "AI-MO/NuminaMath-CoT"
+ DATASET_SPLIT = "train"
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+ VECTOR_SIZE = 384
+ DISTANCE_METRIC = "Cosine"
+ BATCH_SIZE = 100
+ MAX_SAMPLES = None  # Set to an integer (e.g. 5000) to ingest only a subset
+
+ # Validate credentials before importing the pipeline modules
+ if not QDRANT_URL or not QDRANT_API_KEY:
+     raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")
+
+ from utils import EmbeddingGenerator, batch_process_dataset
+ from qdrant_manager import QdrantManager
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ def main():
+     """Main ingestion pipeline."""
+     try:
+         # Initialize components
+         logger.info("Initializing components...")
+         embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
+         qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
+
+         # Load dataset
+         logger.info(f"Loading dataset: {DATASET_NAME}")
+         if MAX_SAMPLES:
+             dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
+             logger.info(f"Loaded {len(dataset)} samples (limited)")
+         else:
+             dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
+             logger.info(f"Loaded full dataset: {len(dataset)} samples")
+
+         # Create Qdrant collection
+         logger.info(f"Creating collection: {QDRANT_COLLECTION}")
+         success = qdrant_manager.create_collection(
+             collection_name=QDRANT_COLLECTION,
+             vector_size=VECTOR_SIZE,
+             distance=DISTANCE_METRIC
+         )
+
+         if not success:
+             logger.error("Failed to create collection")
+             return
+
+         # Process dataset in batches
+         logger.info("Processing dataset in batches...")
+         batches = batch_process_dataset(dataset, BATCH_SIZE)
+
+         total_processed = 0
+         total_batches = len(batches)
+
+         for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
+             try:
+                 # Extract texts for embedding
+                 texts = [item['text'] for item in batch_data]
+
+                 # Generate embeddings
+                 logger.info(f"Generating embeddings for batch {batch_idx + 1}/{total_batches}")
+                 embeddings = embedding_generator.embed_text(texts)
+
+                 # Upsert to Qdrant
+                 logger.info(f"Uploading batch {batch_idx + 1} to Qdrant...")
+                 qdrant_manager.upsert_points(
+                     collection_name=QDRANT_COLLECTION,
+                     points_data=batch_data,
+                     embeddings=embeddings
+                 )
+
+                 total_processed += len(batch_data)
+                 logger.info(f"Progress: {total_processed}/{len(dataset)} items processed")
+
+                 # Small delay to avoid overwhelming the API
+                 time.sleep(0.5)
+
+             except Exception as e:
+                 logger.error(f"Error processing batch {batch_idx + 1}: {e}")
+                 continue
+
+         # Final summary
+         logger.info("Ingestion completed!")
+         logger.info(f"Total items processed: {total_processed}")
+
+         # Get collection info
+         collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
+         if collection_info:
+             logger.info(f"Collection status: {collection_info.status}")
+             logger.info(f"Vectors count: {collection_info.vectors_count}")
+
+     except Exception as e:
+         logger.error(f"Fatal error in ingestion pipeline: {e}")
+         raise
+
+ if __name__ == "__main__":
+     main()
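
As a quick sanity check after ingestion, the collection's point count can be compared against the dataset size. A minimal sketch, assuming the same `.env` credentials (illustrative only, not part of the commit):

```python
# Illustrative post-ingestion check; reuses the same environment variables as ingest.py.
import os

from dotenv import load_dotenv
from qdrant_client import QdrantClient

load_dotenv()
client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))

# exact=True requests a full count rather than the approximate cached value
result = client.count(collection_name=os.getenv("QDRANT_COLLECTION", "nuinamath"), exact=True)
print(f"Points in collection: {result.count}")
```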
database/qdrant_manager.py ADDED
@@ -0,0 +1,137 @@
+ """
+ Qdrant client wrapper for vector database operations.
+ """
+ import logging
+ import time
+ from typing import List, Dict, Any
+ from qdrant_client import QdrantClient
+ from qdrant_client.models import Distance, VectorParams, PointStruct
+
+ logger = logging.getLogger(__name__)
+
+ class QdrantManager:
+     """Manages Qdrant vector database operations."""
+
+     def __init__(self, url: str, api_key: str):
+         """Initialize the Qdrant client."""
+         self.client = QdrantClient(url=url, api_key=api_key)
+         logger.info(f"Connected to Qdrant at {url}")
+
+     def create_collection(self, collection_name: str, vector_size: int, distance: str = "Cosine"):
+         """
+         Create a new collection in Qdrant.
+
+         Args:
+             collection_name: Name of the collection
+             vector_size: Dimension of vectors
+             distance: Distance metric (Cosine, Euclidean, Dot)
+         """
+         try:
+             # Check if the collection already exists
+             collections = self.client.get_collections().collections
+             existing_names = [col.name for col in collections]
+
+             if collection_name in existing_names:
+                 logger.info(f"Collection '{collection_name}' already exists")
+                 return True
+
+             # Create a new collection
+             distance_map = {
+                 "Cosine": Distance.COSINE,
+                 "Euclidean": Distance.EUCLID,
+                 "Dot": Distance.DOT
+             }
+
+             self.client.create_collection(
+                 collection_name=collection_name,
+                 vectors_config=VectorParams(
+                     size=vector_size,
+                     distance=distance_map.get(distance, Distance.COSINE)
+                 )
+             )
+             logger.info(f"Created collection '{collection_name}' with vector size {vector_size}")
+             return True
+
+         except Exception as e:
+             logger.error(f"Error creating collection: {e}")
+             return False
+
+     def upsert_points(self, collection_name: str, points_data: List[Dict[str, Any]],
+                       embeddings: List[List[float]], max_retries: int = 3):
+         """
+         Upsert points into a Qdrant collection with retry logic.
+
+         Args:
+             collection_name: Name of the collection
+             points_data: List of point data dictionaries
+             embeddings: List of embedding vectors
+             max_retries: Maximum number of retry attempts
+         """
+         points = []
+         for data, embedding in zip(points_data, embeddings):
+             point = PointStruct(
+                 id=data['id'],
+                 vector=embedding,
+                 payload={
+                     'problem': data['problem'],
+                     'solution': data['solution'],
+                     'source': data['source']
+                 }
+             )
+             points.append(point)
+
+         # Retry logic for transient network issues
+         for attempt in range(max_retries):
+             try:
+                 self.client.upsert(
+                     collection_name=collection_name,
+                     points=points
+                 )
+                 logger.info(f"Successfully upserted {len(points)} points")
+                 return True
+
+             except Exception as e:
+                 logger.warning(f"Attempt {attempt + 1} failed: {e}")
+                 if attempt < max_retries - 1:
+                     time.sleep(2 ** attempt)  # Exponential backoff
+                 else:
+                     logger.error(f"Failed to upsert points after {max_retries} attempts")
+                     raise
+
+     def search_similar(self, collection_name: str, query_vector: List[float],
+                        limit: int = 3, score_threshold: float = 0.0):
+         """
+         Search for similar vectors in the collection.
+
+         Args:
+             collection_name: Name of the collection
+             query_vector: Query embedding vector
+             limit: Number of results to return
+             score_threshold: Minimum similarity score
+
+         Returns:
+             Search results from Qdrant
+         """
+         try:
+             results = self.client.search(
+                 collection_name=collection_name,
+                 query_vector=query_vector,
+                 limit=limit,
+                 score_threshold=score_threshold
+             )
+             logger.info(f"Found {len(results)} similar results")
+             return results
+
+         except Exception as e:
+             logger.error(f"Error searching collection: {e}")
+             return []
+
+     def get_collection_info(self, collection_name: str):
+         """Get information about a collection."""
+         try:
+             info = self.client.get_collection(collection_name)
+             logger.info(f"Collection info: {info}")
+             return info
+         except Exception as e:
+             logger.error(f"Error getting collection info: {e}")
+             return None
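
A minimal usage sketch of how these methods compose, mirroring `ingest.py` and `test_retrieval.py` at a smaller scale (the collection name `demo` and the sample point are hypothetical):

```python
# Illustrative only; assumes QDRANT_URL and QDRANT_API_KEY are set in .env.
import os

from dotenv import load_dotenv

from qdrant_manager import QdrantManager
from utils import EmbeddingGenerator

load_dotenv()
manager = QdrantManager(os.getenv("QDRANT_URL"), os.getenv("QDRANT_API_KEY"))
embedder = EmbeddingGenerator("all-MiniLM-L6-v2")

# 384 matches all-MiniLM-L6-v2's embedding dimension
manager.create_collection("demo", vector_size=384, distance="Cosine")

# upsert_points expects dicts with 'id', 'problem', 'solution', 'source' keys
point = {
    "id": "11111111-1111-1111-1111-111111111111",  # Qdrant accepts UUID strings as IDs
    "problem": "What is 2 + 2?",
    "solution": "4",
    "source": "demo",
}
vector = embedder.embed_single_text("Question: What is 2 + 2?\nAnswer: 4")
manager.upsert_points("demo", [point], [vector])

# search_similar returns ScoredPoint objects with .score and .payload
for hit in manager.search_similar("demo", embedder.embed_single_text("2 plus 2"), limit=1):
    print(hit.score, hit.payload["problem"])
```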
database/requirements.txt ADDED
@@ -0,0 +1,15 @@
+ # Dataset loading and processing
+ datasets==2.18.0
+ pandas
+
+ # For embedding generation
+ sentence-transformers==2.2.2
+
+ # For Qdrant client (VectorDB)
+ qdrant-client==1.8.0
+
+ # For environment variables
+ python-dotenv
+
+ # For progress tracking
+ tqdm
database/test_retrieval.py ADDED
@@ -0,0 +1,93 @@
+ """
+ Test script for retrieving similar math problems from Qdrant.
+ """
+ import logging
+ import os
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configuration settings
+ QDRANT_URL = os.getenv("QDRANT_URL")
+ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+ QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+
+ from utils import EmbeddingGenerator, format_retrieval_results
+ from qdrant_manager import QdrantManager
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def test_retrieval():
+     """Test the retrieval system with sample math questions."""
+
+     # Sample test questions
+     test_questions = [
+         "What is the value of x in 3x + 5 = 20?",
+         "How do you find the area of a triangle given 3 sides?",
+         "Solve for y: 2y - 7 = 15",
+         "What is the derivative of x^2 + 3x?",
+         "Find the common difference of an arithmetic sequence"
+     ]
+
+     try:
+         # Initialize components
+         logger.info("Initializing retrieval system...")
+         embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
+         qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
+
+         # Test each question
+         for i, question in enumerate(test_questions, 1):
+             print(f"\n{'='*60}")
+             print(f"TEST QUERY {i}: {question}")
+             print('='*60)
+
+             # Generate an embedding for the question
+             query_embedding = embedding_generator.embed_single_text(question)
+
+             # Search for similar problems
+             results = qdrant_manager.search_similar(
+                 collection_name=QDRANT_COLLECTION,
+                 query_vector=query_embedding,
+                 limit=3,
+                 score_threshold=0.1
+             )
+
+             # Format and display results
+             formatted_results = format_retrieval_results(results)
+             print(formatted_results)
+
+     except Exception as e:
+         logger.error(f"Error in retrieval test: {e}")
+
+ def test_collection_status():
+     """Check the status of the Qdrant collection."""
+     try:
+         qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
+
+         print(f"\n{'='*40}")
+         print("COLLECTION STATUS")
+         print('='*40)
+
+         info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
+         if info:
+             print(f"Collection Name: {QDRANT_COLLECTION}")
+             print(f"Status: {info.status}")
+             print(f"Vectors Count: {info.vectors_count}")
+             print(f"Vector Size: {info.config.params.vectors.size}")
+             print(f"Distance Metric: {info.config.params.vectors.distance}")
+         else:
+             print("Collection not found or an error occurred")
+
+     except Exception as e:
+         logger.error(f"Error checking collection status: {e}")
+
+ if __name__ == "__main__":
+     print("Testing Qdrant Collection Status...")
+     test_collection_status()
+
+     print("\n\nTesting Retrieval System...")
+     test_retrieval()
database/utils.py ADDED
@@ -0,0 +1,117 @@
+ """
+ Utility functions for data processing and embedding generation.
+ """
+ import logging
+ import uuid
+ from typing import List, Dict, Any
+ from sentence_transformers import SentenceTransformer
+ from datasets import Dataset
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class EmbeddingGenerator:
+     """Handles text embedding generation using sentence transformers."""
+
+     def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+         """Initialize the embedding model."""
+         logger.info(f"Loading embedding model: {model_name}")
+         self.model = SentenceTransformer(model_name)
+         self.model_name = model_name
+
+     def embed_text(self, texts: List[str]) -> List[List[float]]:
+         """Generate embeddings for a list of texts."""
+         logger.info(f"Generating embeddings for {len(texts)} texts")
+         embeddings = self.model.encode(texts, show_progress_bar=True)
+         return embeddings.tolist()
+
+     def embed_single_text(self, text: str) -> List[float]:
+         """Generate an embedding for a single text."""
+         embedding = self.model.encode([text])
+         return embedding[0].tolist()
+
+ def preprocess_dataset_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Preprocess a single dataset entry to create combined text for embedding.
+
+     Args:
+         entry: Dictionary containing 'problem' and 'solution' keys
+
+     Returns:
+         Processed entry with a 'text' field for embedding
+     """
+     problem = entry.get('problem', '')
+     solution = entry.get('solution', '')
+
+     # Create combined text for embedding
+     combined_text = f"Question: {problem}\nAnswer: {solution}"
+
+     return {
+         'id': str(uuid.uuid4()),
+         'text': combined_text,
+         'problem': problem,
+         'solution': solution,
+         'source': entry.get('source', 'unknown')
+     }
+
+ def batch_process_dataset(dataset: Dataset, batch_size: int = 100) -> List[List[Dict[str, Any]]]:
+     """
+     Process a dataset in batches for memory efficiency.
+
+     Args:
+         dataset: HuggingFace dataset
+         batch_size: Number of items per batch
+
+     Returns:
+         List of batches, each containing processed entries
+     """
+     batches = []
+     total_items = len(dataset)
+
+     logger.info(f"Processing {total_items} items in batches of {batch_size}")
+
+     for i in range(0, total_items, batch_size):
+         batch_end = min(i + batch_size, total_items)
+         batch_data = dataset[i:batch_end]  # Slicing a Dataset returns a dict of column lists
+
+         # Process each item in the batch
+         processed_batch = []
+         for j in range(len(batch_data['problem'])):
+             entry = {
+                 'problem': batch_data['problem'][j],
+                 'solution': batch_data['solution'][j],
+                 'source': batch_data['source'][j]
+             }
+             processed_entry = preprocess_dataset_entry(entry)
+             processed_batch.append(processed_entry)
+
+         batches.append(processed_batch)
+         logger.info(f"Processed batch {len(batches)}/{(total_items + batch_size - 1) // batch_size}")
+
+     return batches
+
+ def format_retrieval_results(results: List[Any]) -> str:
+     """
+     Format retrieval results for display.
+
+     Args:
+         results: List of scored points returned by Qdrant
+
+     Returns:
+         Formatted string for display
+     """
+     if not results:
+         return "No results found."
+
+     output = []
+     for i, result in enumerate(results, 1):
+         payload = result.payload
+         score = result.score
+
+         output.append(f"\n--- Result {i} (Score: {score:.4f}) ---")
+         output.append(f"Question: {payload['problem']}")
+         output.append(f"Answer: {payload['solution'][:200]}{'...' if len(payload['solution']) > 200 else ''}")  # Truncate long answers
+         output.append("-" * 50)
+
+     return "\n".join(output)
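
A toy illustration of the batching helpers above (a sketch; the real pipeline in `ingest.py` runs these over the full NuminaMath split):

```python
# Toy example of preprocess_dataset_entry / batch_process_dataset.
from datasets import Dataset

from utils import batch_process_dataset

toy = Dataset.from_dict({
    "problem": ["Solve x + 1 = 2.", "Compute 3 * 4."],
    "solution": ["x = 1", "12"],
    "source": ["demo", "demo"],
})

batches = batch_process_dataset(toy, batch_size=1)
print(len(batches), len(batches[0]))  # -> 2 1
print(batches[0][0]["text"])          # -> Question: Solve x + 1 = 2. / Answer: x = 1
```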