"""
Main ingestion script for loading the NuminaMath dataset into Qdrant.
"""
import logging
import os
import time

from datasets import load_dataset
from dotenv import load_dotenv
from tqdm import tqdm

# Load environment variables from a local .env file
load_dotenv()
# Configuration settings
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "numinamath")
DATASET_NAME = "AI-MO/NuminaMath-CoT"
DATASET_SPLIT = "train"
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
VECTOR_SIZE = 384  # all-MiniLM-L6-v2 produces 384-dimensional embeddings
DISTANCE_METRIC = "Cosine"
BATCH_SIZE = 100
MAX_SAMPLES = None  # set to an int to ingest only the first MAX_SAMPLES rows
# Validation
if not QDRANT_URL or not QDRANT_API_KEY:
raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")
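
# Example .env layout (illustrative only; the URL and key below are placeholders,
# not real values):
#   QDRANT_URL=https://your-cluster.example.qdrant.io:6333
#   QDRANT_API_KEY=<your-api-key>
#   QDRANT_COLLECTION=numinamath   # optional; defaults to "numinamath"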
# Project-local helper modules
from utils import EmbeddingGenerator, batch_process_dataset
from qdrant_manager import QdrantManager
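
# The helper modules are not shown in this file; from their call sites below,
# the assumed (unverified) interfaces are:
#   EmbeddingGenerator(model_name).embed_text(list[str]) -> list of vectors
#   QdrantManager(url, api_key).create_collection(collection_name, vector_size, distance) -> bool
#   QdrantManager.upsert_points(collection_name, points_data, embeddings) -> None
#   QdrantManager.get_collection_info(collection_name) -> collection info or None
#   batch_process_dataset(dataset, batch_size) -> list of batches, where each
#       batch is a list of dicts that include a 'text' key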
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def main():
    """Main ingestion pipeline."""
    try:
        # Initialize components
        logger.info("Initializing components...")
        embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        # Load dataset
        logger.info(f"Loading dataset: {DATASET_NAME}")
        if MAX_SAMPLES:
            dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
            logger.info(f"Loaded {len(dataset)} samples (limited)")
        else:
            dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
            logger.info(f"Loaded full dataset: {len(dataset)} samples")

        # Create Qdrant collection
        logger.info(f"Creating collection: {QDRANT_COLLECTION}")
        success = qdrant_manager.create_collection(
            collection_name=QDRANT_COLLECTION,
            vector_size=VECTOR_SIZE,
            distance=DISTANCE_METRIC,
        )
        if not success:
            logger.error("Failed to create collection")
            return

        # Process dataset in batches
        logger.info("Processing dataset in batches...")
        batches = batch_process_dataset(dataset, BATCH_SIZE)
        total_processed = 0
        total_batches = len(batches)

        for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
            try:
                # Extract texts for embedding
                texts = [item['text'] for item in batch_data]

                # Generate embeddings
                logger.info(f"Generating embeddings for batch {batch_idx + 1}/{total_batches}")
                embeddings = embedding_generator.embed_text(texts)

                # Upsert to Qdrant
                logger.info(f"Uploading batch {batch_idx + 1} to Qdrant...")
                qdrant_manager.upsert_points(
                    collection_name=QDRANT_COLLECTION,
                    points_data=batch_data,
                    embeddings=embeddings,
                )

                total_processed += len(batch_data)
                logger.info(f"Progress: {total_processed}/{len(dataset)} items processed")

                # Small delay to avoid overwhelming the API
                time.sleep(0.5)
            except Exception as e:
                # Log and skip the failed batch rather than aborting the whole run
                logger.error(f"Error processing batch {batch_idx + 1}: {e}")
                continue

        # Final summary
        logger.info("Ingestion completed!")
        logger.info(f"Total items processed: {total_processed}")

        # Get collection info
        collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
        if collection_info:
            logger.info(f"Collection status: {collection_info.status}")
            # Note: newer Qdrant releases report points_count; vectors_count may be None there.
            logger.info(f"Vectors count: {collection_info.vectors_count}")
    except Exception as e:
        logger.error(f"Fatal error in ingestion pipeline: {e}")
        raise


if __name__ == "__main__":
    main()
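
# --- Illustrative follow-up: querying the ingested collection ---
# A minimal retrieval sketch using the standard qdrant-client and
# sentence-transformers APIs directly (QdrantManager's search interface, if any,
# is not shown in this file). Kept as comments so this script stays ingestion-only.
#
#   from qdrant_client import QdrantClient
#   from sentence_transformers import SentenceTransformer
#
#   client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
#   model = SentenceTransformer(EMBEDDING_MODEL)
#   query_vector = model.encode("Solve x^2 - 5x + 6 = 0").tolist()
#   hits = client.search(
#       collection_name=QDRANT_COLLECTION,
#       query_vector=query_vector,
#       limit=5,
#   )
#   for hit in hits:
#       print(hit.score, hit.payload)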