Commit 6874d8b
Parent: 61f25c3
phase 1 - data storage in qdrant and retrieval
Files changed:
- .gitignore +3 -0
- database/README.md +27 -0
- database/ingest.py +115 -0
- database/qdrant_manager.py +137 -0
- database/requirements.txt +15 -0
- database/test_retrieval.py +93 -0
- database/utils.py +117 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+.env
+venv/
+__pycache__/
database/README.md
ADDED
@@ -0,0 +1,27 @@
+# Database Module - Math Agentic RAG
+
+This module handles knowledge base creation and retrieval for the Math Agentic RAG system.
+
+## Files Overview
+
+### Core Files
+- **`utils.py`** - Utility functions for embedding generation and data processing
+- **`qdrant_manager.py`** - Qdrant vector database client wrapper
+- **`ingest.py`** - Main ingestion script for loading the dataset into Qdrant (includes config)
+- **`test_retrieval.py`** - Testing script for validating retrieval functionality (includes config)
+
+### Dependencies
+- **`requirements.txt`** - Python package dependencies
+
+## Usage
+
+1. **Set Up Environment Variables**: Ensure the `.env` file has Qdrant credentials
+2. **Install Dependencies**: `pip install -r requirements.txt`
+3. **Ingest Data**: `python ingest.py`
+4. **Test Retrieval**: `python test_retrieval.py`
+
+## Current Status
+- ✅ Dataset: NuminaMath (5,000 mathematical problems)
+- ✅ Vector DB: Qdrant Cloud
+- ✅ Embedding Model: all-MiniLM-L6-v2 (384 dimensions)
+- ✅ Status: Ready for Phase 2
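
Note: the Usage steps above assume a `.env` file whose variable names match what `ingest.py` and `test_retrieval.py` read. A minimal sketch, with placeholder values rather than real credentials:

# Placeholder values - replace with your own Qdrant Cloud credentials
QDRANT_URL=https://your-cluster-url.qdrant.io:6333
QDRANT_API_KEY=your-qdrant-api-key
QDRANT_COLLECTION=nuinamath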
database/ingest.py
ADDED
@@ -0,0 +1,115 @@
+"""
+Main ingestion script for loading the NuminaMath dataset into Qdrant.
+"""
+import logging
+import os
+from datasets import load_dataset
+from tqdm import tqdm
+import time
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Configuration settings
+QDRANT_URL = os.getenv("QDRANT_URL")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
+DATASET_NAME = "AI-MO/NuminaMath-CoT"
+DATASET_SPLIT = "train"
+EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+VECTOR_SIZE = 384
+DISTANCE_METRIC = "Cosine"
+BATCH_SIZE = 100
+MAX_SAMPLES = None
+
+# Validation
+if not QDRANT_URL or not QDRANT_API_KEY:
+    raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")
+
+from utils import EmbeddingGenerator, batch_process_dataset
+from qdrant_manager import QdrantManager
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+def main():
+    """Main ingestion pipeline."""
+    try:
+        # Initialize components
+        logger.info("Initializing components...")
+        embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
+        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
+
+        # Load dataset
+        logger.info(f"Loading dataset: {DATASET_NAME}")
+        if MAX_SAMPLES:
+            dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
+            logger.info(f"Loaded {len(dataset)} samples (limited)")
+        else:
+            dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
+            logger.info(f"Loaded full dataset: {len(dataset)} samples")
+
+        # Create Qdrant collection
+        logger.info(f"Creating collection: {QDRANT_COLLECTION}")
+        success = qdrant_manager.create_collection(
+            collection_name=QDRANT_COLLECTION,
+            vector_size=VECTOR_SIZE,
+            distance=DISTANCE_METRIC
+        )
+
+        if not success:
+            logger.error("Failed to create collection")
+            return
+
+        # Process dataset in batches
+        logger.info("Processing dataset in batches...")
+        batches = batch_process_dataset(dataset, BATCH_SIZE)
+
+        total_processed = 0
+        total_batches = len(batches)
+
+        for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
+            try:
+                # Extract texts for embedding
+                texts = [item['text'] for item in batch_data]
+
+                # Generate embeddings
+                logger.info(f"Generating embeddings for batch {batch_idx + 1}/{total_batches}")
+                embeddings = embedding_generator.embed_text(texts)
+
+                # Upsert to Qdrant
+                logger.info(f"Uploading batch {batch_idx + 1} to Qdrant...")
+                qdrant_manager.upsert_points(
+                    collection_name=QDRANT_COLLECTION,
+                    points_data=batch_data,
+                    embeddings=embeddings
+                )
+
+                total_processed += len(batch_data)
+                logger.info(f"Progress: {total_processed}/{len(dataset)} items processed")
+
+                # Small delay to avoid overwhelming the API
+                time.sleep(0.5)
+
+            except Exception as e:
+                logger.error(f"Error processing batch {batch_idx + 1}: {e}")
+                continue
+
+        # Final summary
+        logger.info("Ingestion completed!")
+        logger.info(f"Total items processed: {total_processed}")
+
+        # Get collection info
+        collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
+        if collection_info:
+            logger.info(f"Collection status: {collection_info.status}")
+            logger.info(f"Vectors count: {collection_info.vectors_count}")
+
+    except Exception as e:
+        logger.error(f"Fatal error in ingestion pipeline: {e}")
+        raise
+
+if __name__ == "__main__":
+    main()
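
Note: with `MAX_SAMPLES = None`, the script above embeds the entire NuminaMath-CoT train split, which can take a long time on CPU. For a quick smoke test, one could override the setting, for example:

# Hypothetical smoke-test override (not part of this commit):
MAX_SAMPLES = 500  # ingest only the first 500 samples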
database/qdrant_manager.py
ADDED
@@ -0,0 +1,137 @@
+"""
+Qdrant client wrapper for vector database operations.
+"""
+import logging
+from typing import List, Dict, Any
+from qdrant_client import QdrantClient
+from qdrant_client.models import Distance, VectorParams, PointStruct
+import time
+
+logger = logging.getLogger(__name__)
+
+class QdrantManager:
+    """Manages Qdrant vector database operations."""
+
+    def __init__(self, url: str, api_key: str):
+        """Initialize Qdrant client."""
+        self.client = QdrantClient(url=url, api_key=api_key)
+        logger.info(f"Connected to Qdrant at {url}")
+
+    def create_collection(self, collection_name: str, vector_size: int, distance: str = "Cosine"):
+        """
+        Create a new collection in Qdrant.
+
+        Args:
+            collection_name: Name of the collection
+            vector_size: Dimension of vectors
+            distance: Distance metric (Cosine, Euclidean, Dot)
+        """
+        try:
+            # Check if collection already exists
+            collections = self.client.get_collections().collections
+            existing_names = [col.name for col in collections]
+
+            if collection_name in existing_names:
+                logger.info(f"Collection '{collection_name}' already exists")
+                return True
+
+            # Create new collection
+            distance_map = {
+                "Cosine": Distance.COSINE,
+                "Euclidean": Distance.EUCLID,
+                "Dot": Distance.DOT
+            }
+
+            self.client.create_collection(
+                collection_name=collection_name,
+                vectors_config=VectorParams(
+                    size=vector_size,
+                    distance=distance_map.get(distance, Distance.COSINE)
+                )
+            )
+            logger.info(f"Created collection '{collection_name}' with vector size {vector_size}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Error creating collection: {e}")
+            return False
+
+    def upsert_points(self, collection_name: str, points_data: List[Dict[str, Any]],
+                      embeddings: List[List[float]], max_retries: int = 3):
+        """
+        Upsert points into Qdrant collection with retry logic.
+
+        Args:
+            collection_name: Name of the collection
+            points_data: List of point data dictionaries
+            embeddings: List of embedding vectors
+            max_retries: Maximum number of retry attempts
+        """
+        points = []
+        for i, (data, embedding) in enumerate(zip(points_data, embeddings)):
+            point = PointStruct(
+                id=data['id'],
+                vector=embedding,
+                payload={
+                    'problem': data['problem'],
+                    'solution': data['solution'],
+                    'source': data['source']
+                }
+            )
+            points.append(point)
+
+        # Retry logic for network issues
+        for attempt in range(max_retries):
+            try:
+                self.client.upsert(
+                    collection_name=collection_name,
+                    points=points
+                )
+                logger.info(f"Successfully upserted {len(points)} points")
+                return True
+
+            except Exception as e:
+                logger.warning(f"Attempt {attempt + 1} failed: {e}")
+                if attempt < max_retries - 1:
+                    time.sleep(2 ** attempt)  # Exponential backoff
+                else:
+                    logger.error(f"Failed to upsert points after {max_retries} attempts")
+                    raise e
+
+    def search_similar(self, collection_name: str, query_vector: List[float],
+                       limit: int = 3, score_threshold: float = 0.0):
+        """
+        Search for similar vectors in the collection.
+
+        Args:
+            collection_name: Name of the collection
+            query_vector: Query embedding vector
+            limit: Number of results to return
+            score_threshold: Minimum similarity score
+
+        Returns:
+            Search results from Qdrant
+        """
+        try:
+            results = self.client.search(
+                collection_name=collection_name,
+                query_vector=query_vector,
+                limit=limit,
+                score_threshold=score_threshold
+            )
+            logger.info(f"Found {len(results)} similar results")
+            return results
+
+        except Exception as e:
+            logger.error(f"Error searching collection: {e}")
+            return []
+
+    def get_collection_info(self, collection_name: str):
+        """Get information about a collection."""
+        try:
+            info = self.client.get_collection(collection_name)
+            logger.info(f"Collection info: {info}")
+            return info
+        except Exception as e:
+            logger.error(f"Error getting collection info: {e}")
+            return None
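
For orientation, here is a minimal sketch of using `QdrantManager` together with `EmbeddingGenerator` from `utils.py` outside the ingestion pipeline. The collection name "demo" and the sample problem are illustrative; credentials come from the same `.env` variables as above:

import os
import uuid
from dotenv import load_dotenv
from utils import EmbeddingGenerator
from qdrant_manager import QdrantManager

load_dotenv()
manager = QdrantManager(os.getenv("QDRANT_URL"), os.getenv("QDRANT_API_KEY"))
manager.create_collection("demo", vector_size=384, distance="Cosine")

# Embed one problem/solution pair the same way ingest.py does
gen = EmbeddingGenerator("all-MiniLM-L6-v2")
item = {
    'id': str(uuid.uuid4()),
    'problem': 'Solve 3x + 5 = 20',
    'solution': 'x = 5',
    'source': 'demo',
}
vector = gen.embed_single_text(f"Question: {item['problem']}\nAnswer: {item['solution']}")
manager.upsert_points("demo", [item], [vector])

# Query it back
hits = manager.search_similar("demo", gen.embed_single_text("What is x if 3x + 5 = 20?"), limit=1)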
database/requirements.txt
ADDED
@@ -0,0 +1,15 @@
+# Dataset loading and processing
+datasets==2.18.0
+pandas
+
+# For embedding generation
+sentence-transformers==2.2.2
+
+# For Qdrant client (VectorDB)
+qdrant-client==1.8.0
+
+# For environment variables
+python-dotenv
+
+# For progress tracking
+tqdm
database/test_retrieval.py
ADDED
@@ -0,0 +1,93 @@
+"""
+Test script for retrieving similar math problems from Qdrant.
+"""
+import logging
+import os
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Configuration settings
+QDRANT_URL = os.getenv("QDRANT_URL")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
+EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+
+from utils import EmbeddingGenerator, format_retrieval_results
+from qdrant_manager import QdrantManager
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def test_retrieval():
+    """Test the retrieval system with sample math questions."""
+
+    # Sample test questions
+    test_questions = [
+        "What is the value of x in 3x + 5 = 20?",
+        "How do you find the area of a triangle given 3 sides?",
+        "Solve for y: 2y - 7 = 15",
+        "What is the derivative of x^2 + 3x?",
+        "Find the arithmetic sequence common difference"
+    ]
+
+    try:
+        # Initialize components
+        logger.info("Initializing retrieval system...")
+        embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
+        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
+
+        # Test each question
+        for i, question in enumerate(test_questions, 1):
+            print(f"\n{'='*60}")
+            print(f"TEST QUERY {i}: {question}")
+            print('='*60)
+
+            # Generate embedding for the question
+            query_embedding = embedding_generator.embed_single_text(question)
+
+            # Search for similar problems
+            results = qdrant_manager.search_similar(
+                collection_name=QDRANT_COLLECTION,
+                query_vector=query_embedding,
+                limit=3,
+                score_threshold=0.1
+            )
+
+            # Format and display results
+            formatted_results = format_retrieval_results(results)
+            print(formatted_results)
+
+    except Exception as e:
+        logger.error(f"Error in retrieval test: {e}")
+
+def test_collection_status():
+    """Check the status of the Qdrant collection."""
+    try:
+        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
+
+        print(f"\n{'='*40}")
+        print("COLLECTION STATUS")
+        print('='*40)
+
+        info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
+        if info:
+            print(f"Collection Name: {QDRANT_COLLECTION}")
+            print(f"Status: {info.status}")
+            print(f"Vectors Count: {info.vectors_count}")
+            print(f"Vector Size: {info.config.params.vectors.size}")
+            print(f"Distance Metric: {info.config.params.vectors.distance}")
+        else:
+            print("Collection not found or error occurred")
+
+    except Exception as e:
+        logger.error(f"Error checking collection status: {e}")
+
+if __name__ == "__main__":
+    print("Testing Qdrant Collection Status...")
+    test_collection_status()
+
+    print("\n\nTesting Retrieval System...")
+    test_retrieval()
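
Given `format_retrieval_results` in `utils.py`, each test query prints result blocks of the following shape. The score and retrieved text below are placeholders; actual values depend on the live collection:

--- Result 1 (Score: 0.7312) ---
Question: <retrieved problem text>
Answer: <first 200 characters of the retrieved solution>...
--------------------------------------------------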
database/utils.py
ADDED
@@ -0,0 +1,117 @@
+"""
+Utility functions for data processing and embedding generation.
+"""
+import logging
+from typing import List, Dict, Any
+from sentence_transformers import SentenceTransformer
+from datasets import Dataset
+import uuid
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class EmbeddingGenerator:
+    """Handles text embedding generation using sentence transformers."""
+
+    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+        """Initialize the embedding model."""
+        logger.info(f"Loading embedding model: {model_name}")
+        self.model = SentenceTransformer(model_name)
+        self.model_name = model_name
+
+    def embed_text(self, texts: List[str]) -> List[List[float]]:
+        """Generate embeddings for a list of texts."""
+        logger.info(f"Generating embeddings for {len(texts)} texts")
+        embeddings = self.model.encode(texts, show_progress_bar=True)
+        return embeddings.tolist()
+
+    def embed_single_text(self, text: str) -> List[float]:
+        """Generate embedding for a single text."""
+        embedding = self.model.encode([text])
+        return embedding[0].tolist()
+
+def preprocess_dataset_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Preprocess a single dataset entry to create combined text for embedding.
+
+    Args:
+        entry: Dictionary containing 'problem' and 'solution' keys
+
+    Returns:
+        Processed entry with 'text' field for embedding
+    """
+    problem = entry.get('problem', '')
+    solution = entry.get('solution', '')
+
+    # Create combined text for embedding
+    combined_text = f"Question: {problem}\nAnswer: {solution}"
+
+    return {
+        'id': str(uuid.uuid4()),
+        'text': combined_text,
+        'problem': problem,
+        'solution': solution,
+        'source': entry.get('source', 'unknown')
+    }
+
+def batch_process_dataset(dataset: Dataset, batch_size: int = 100) -> List[List[Dict[str, Any]]]:
+    """
+    Process dataset in batches for memory efficiency.
+
+    Args:
+        dataset: HuggingFace dataset
+        batch_size: Number of items per batch
+
+    Returns:
+        List of batches, each containing processed entries
+    """
+    batches = []
+    total_items = len(dataset)
+
+    logger.info(f"Processing {total_items} items in batches of {batch_size}")
+
+    for i in range(0, total_items, batch_size):
+        batch_end = min(i + batch_size, total_items)
+        batch_data = dataset[i:batch_end]
+
+        # Process each item in the batch
+        processed_batch = []
+        for j in range(len(batch_data['problem'])):
+            entry = {
+                'problem': batch_data['problem'][j],
+                'solution': batch_data['solution'][j],
+                'source': batch_data['source'][j]
+            }
+            processed_entry = preprocess_dataset_entry(entry)
+            processed_batch.append(processed_entry)
+
+        batches.append(processed_batch)
+        logger.info(f"Processed batch {len(batches)}/{(total_items + batch_size - 1) // batch_size}")
+
+    return batches
+
+def format_retrieval_results(results: List[Dict]) -> str:
+    """
+    Format retrieval results for display.
+
+    Args:
+        results: List of search results from Qdrant
+
+    Returns:
+        Formatted string for display
+    """
+    if not results:
+        return "No results found."
+
+    output = []
+    for i, result in enumerate(results, 1):
+        payload = result.payload
+        score = result.score
+
+        output.append(f"\n--- Result {i} (Score: {score:.4f}) ---")
+        output.append(f"Question: {payload['problem']}")
+        output.append(f"Answer: {payload['solution'][:200]}...")  # Truncate long answers
+        output.append("-" * 50)
+
+    return "\n".join(output)
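
To make the stored payload concrete, here is a small sketch of what `preprocess_dataset_entry` returns for a toy entry. The problem text is illustrative, and the id is a fresh random UUID on each call:

from utils import preprocess_dataset_entry

entry = {'problem': 'Solve 2y - 7 = 15', 'solution': 'y = 11', 'source': 'synthetic'}
processed = preprocess_dataset_entry(entry)
# processed['text'] == 'Question: Solve 2y - 7 = 15\nAnswer: y = 11'
# plus 'id' (random UUID) and the original 'problem', 'solution', 'source' fields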