""" | |
Main ingestion script for loading Nuinamath dataset into Qdrant. | |
""" | |
import logging
import os
import time

from datasets import load_dataset
from dotenv import load_dotenv
from tqdm import tqdm
# Load environment variables
load_dotenv()
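# load_dotenv() reads key=value pairs from a local .env file into os.environ;
# by default it does not override variables already set in the environment.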
# Configuration settings
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
DATASET_NAME = "AI-MO/NuminaMath-CoT"
DATASET_SPLIT = "train"
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
VECTOR_SIZE = 384  # all-MiniLM-L6-v2 produces 384-dimensional embeddings
DISTANCE_METRIC = "Cosine"
BATCH_SIZE = 100
MAX_SAMPLES = None  # set to an int to cap how many samples are ingested
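# Optional sanity check (a sketch, assuming sentence-transformers is installed):
#   from sentence_transformers import SentenceTransformer
#   dim = SentenceTransformer(EMBEDDING_MODEL).get_sentence_embedding_dimension()
#   assert dim == VECTOR_SIZE, f"Model emits {dim}-dim vectors, expected {VECTOR_SIZE}"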
# Validation
if not QDRANT_URL or not QDRANT_API_KEY:
    raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")
# Local helpers are imported after validation so a missing .env fails fast
# before the heavier embedding-model dependencies are pulled in.
from utils import EmbeddingGenerator, batch_process_dataset
from qdrant_manager import QdrantManager
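# Assumed interfaces for the local helpers (a sketch; the real modules may differ):
#   EmbeddingGenerator(model_name).embed_text(list[str]) -> list of vectors
#   batch_process_dataset(dataset, batch_size) -> list of batches of dicts with a 'text' key
#   QdrantManager(url, api_key) exposing create_collection / upsert_points / get_collection_info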
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def main():
    """Main ingestion pipeline."""
    try:
        # Initialize components
        logger.info("Initializing components...")
        embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        # Load dataset
        logger.info(f"Loading dataset: {DATASET_NAME}")
        if MAX_SAMPLES:
            dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
            logger.info(f"Loaded {len(dataset)} samples (limited)")
        else:
            dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
            logger.info(f"Loaded full dataset: {len(dataset)} samples")
        # Create Qdrant collection
        logger.info(f"Creating collection: {QDRANT_COLLECTION}")
        success = qdrant_manager.create_collection(
            collection_name=QDRANT_COLLECTION,
            vector_size=VECTOR_SIZE,
            distance=DISTANCE_METRIC,
        )
        if not success:
            logger.error("Failed to create collection")
            return
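        # For reference, a minimal qdrant-client equivalent of what the wrapper is
        # assumed to do (mapping "Cosine" onto models.Distance.COSINE):
        #   client.create_collection(
        #       collection_name=QDRANT_COLLECTION,
        #       vectors_config=models.VectorParams(size=VECTOR_SIZE, distance=models.Distance.COSINE),
        #   )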
        # Process dataset in batches
        logger.info("Processing dataset in batches...")
        batches = batch_process_dataset(dataset, BATCH_SIZE)
        total_processed = 0
        total_batches = len(batches)
        for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
            try:
                # Extract texts for embedding
                texts = [item['text'] for item in batch_data]

                # Generate embeddings
                logger.info(f"Generating embeddings for batch {batch_idx + 1}/{total_batches}")
                embeddings = embedding_generator.embed_text(texts)

                # Upsert to Qdrant
                logger.info(f"Uploading batch {batch_idx + 1} to Qdrant...")
                qdrant_manager.upsert_points(
                    collection_name=QDRANT_COLLECTION,
                    points_data=batch_data,
                    embeddings=embeddings,
                )
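                # Assumption: upsert_points pairs each item with its embedding,
                # building PointStruct entries (vector=embedding, payload=item
                # fields) before calling upsert on the collection.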
                total_processed += len(batch_data)
                logger.info(f"Progress: {total_processed}/{len(dataset)} items processed")

                # Small delay to avoid overwhelming the API
                time.sleep(0.5)
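                # Design note: a fixed 0.5 s pause is a crude rate limit; retrying
                # failed batches with exponential backoff would be more robust.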
            except Exception as e:
                logger.error(f"Error processing batch {batch_idx + 1}: {e}")
                continue
        # Final summary
        logger.info("Ingestion completed!")
        logger.info(f"Total items processed: {total_processed}")

        # Get collection info
        collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
        if collection_info:
            logger.info(f"Collection status: {collection_info.status}")
            logger.info(f"Vectors count: {collection_info.vectors_count}")
    except Exception as e:
        logger.error(f"Fatal error in ingestion pipeline: {e}")
        raise

if __name__ == "__main__":
    main()