import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import chromadb
import openai
from chromadb.utils import embedding_functions

logger = logging.getLogger(__name__)


class RAGSystem:
    """Retrieval-Augmented Generation system for chatbot functionality."""

    def __init__(self, openai_api_key: str, persist_directory: str = "chroma_db"):
        self.client = openai.OpenAI(api_key=openai_api_key)

        # Initialize ChromaDB with on-disk persistence
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # Embedding function shared by all collections
        self.embedding_function = embedding_functions.DefaultEmbeddingFunction()

        # Collections for different document types
        self.pdf_collection = self._get_or_create_collection("pdf_documents")
        self.lecture_collection = self._get_or_create_collection("lecture_content")

    def _get_or_create_collection(self, name: str):
        """Get an existing collection or create a new one."""
        # get_or_create_collection is idempotent; the metadata is only
        # applied when the collection is first created.
        return self.chroma_client.get_or_create_collection(
            name=name,
            embedding_function=self.embedding_function,
            metadata={"description": f"Collection for {name}"},
        )

    def add_pdf_content(self, session_id: str, pdf_content: str,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Add PDF content to the vector database."""
        try:
            # Split content into overlapping chunks
            chunks = self._split_text(pdf_content, chunk_size=1000, overlap=200)

            # Prepare documents for insertion
            documents = []
            metadatas = []
            ids = []

            base_metadata = {
                "session_id": session_id,
                "document_type": "pdf",
                "added_at": datetime.now().isoformat(),
                **(metadata or {}),
            }

            for i, chunk in enumerate(chunks):
                doc_id = f"{session_id}_pdf_{i}_{uuid.uuid4().hex[:8]}"
                documents.append(chunk)
                metadatas.append({
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_id": doc_id,
                })
                ids.append(doc_id)

            # Add to collection
            self.pdf_collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids,
            )

            logger.info(f"Added {len(chunks)} PDF chunks for session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add PDF content: {e}")
            return False

    def add_lecture_content(self, session_id: str, lecture_content: str,
                            metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Add lecture content to the vector database."""
        try:
            # Split content into overlapping chunks
            chunks = self._split_text(lecture_content, chunk_size=1000, overlap=200)

            documents = []
            metadatas = []
            ids = []

            base_metadata = {
                "session_id": session_id,
                "document_type": "lecture",
                "added_at": datetime.now().isoformat(),
                **(metadata or {}),
            }

            for i, chunk in enumerate(chunks):
                doc_id = f"{session_id}_lecture_{i}_{uuid.uuid4().hex[:8]}"
                documents.append(chunk)
                metadatas.append({
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_id": doc_id,
                })
                ids.append(doc_id)

            # Add to collection
            self.lecture_collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids,
            )

            logger.info(f"Added {len(chunks)} lecture chunks for session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add lecture content: {e}")
            return False

    def retrieve_relevant_content(self, session_id: str, query: str,
                                  n_results: int = 5) -> Dict[str, Any]:
        """Retrieve the most relevant stored chunks for a query."""
        try:
            # Search both collections, scoped to this session
            pdf_results = self.pdf_collection.query(
                query_texts=[query],
                n_results=n_results,
                where={"session_id": session_id},
            )
            lecture_results = self.lecture_collection.query(
                query_texts=[query],
                n_results=n_results,
                where={"session_id": session_id},
            )

            # Combine and rank results from both sources
            all_results = []
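            # Note: Chroma's query() returns distances where smaller values
            # mean higher similarity (L2 distance in the default embedding
            # space), so the ascending sort below ranks best matches first.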
            # Process PDF results
            if pdf_results['documents'] and pdf_results['documents'][0]:
                for i, doc in enumerate(pdf_results['documents'][0]):
                    all_results.append({
                        'content': doc,
                        'metadata': pdf_results['metadatas'][0][i],
                        'distance': pdf_results['distances'][0][i],
                        'source': 'pdf',
                    })

            # Process lecture results
            if lecture_results['documents'] and lecture_results['documents'][0]:
                for i, doc in enumerate(lecture_results['documents'][0]):
                    all_results.append({
                        'content': doc,
                        'metadata': lecture_results['metadatas'][0][i],
                        'distance': lecture_results['distances'][0][i],
                        'source': 'lecture',
                    })

            # Sort by relevance (ascending distance)
            all_results.sort(key=lambda x: x['distance'])

            return {
                'success': True,
                'results': all_results[:n_results],
                'total_found': len(all_results),
            }

        except Exception as e:
            logger.error(f"Content retrieval failed: {e}")
            return {
                'success': False,
                'results': [],
                'total_found': 0,
                'error': str(e),
            }

    def _split_text(self, text: str, chunk_size: int = 1000,
                    overlap: int = 200) -> List[str]:
        """Split text into overlapping chunks, preferring sentence boundaries."""
        if len(text) <= chunk_size:
            return [text]

        chunks = []
        start = 0

        while start < len(text):
            end = start + chunk_size

            # Try to end at a sentence boundary
            if end < len(text):
                # Look for sentence endings within the last 100 characters
                search_start = max(end - 100, start)
                sentence_ends = []
                for punct in ['. ', '! ', '? ', '\n\n']:
                    pos = text.rfind(punct, search_start, end)
                    if pos > start:
                        sentence_ends.append(pos + len(punct))
                if sentence_ends:
                    end = max(sentence_ends)

            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)

            # Step back by the overlap, while guaranteeing forward progress
            # even if overlap is close to chunk_size
            start = max(end - overlap, start + 1)
            if start >= len(text):
                break

        return chunks

    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """Get statistics about stored content for a session."""
        try:
            # Count PDF chunks
            pdf_count = len(self.pdf_collection.get(
                where={"session_id": session_id}
            )['ids'])

            # Count lecture chunks
            lecture_count = len(self.lecture_collection.get(
                where={"session_id": session_id}
            )['ids'])

            return {
                'pdf_chunks': pdf_count,
                'lecture_chunks': lecture_count,
                'total_chunks': pdf_count + lecture_count,
            }

        except Exception as e:
            logger.error(f"Failed to get session stats: {e}")
            return {
                'pdf_chunks': 0,
                'lecture_chunks': 0,
                'total_chunks': 0,
            }

    def clear_session_data(self, session_id: str) -> bool:
        """Clear all stored data for a specific session."""
        try:
            # Get all document IDs for this session
            pdf_ids = self.pdf_collection.get(
                where={"session_id": session_id}
            )['ids']
            lecture_ids = self.lecture_collection.get(
                where={"session_id": session_id}
            )['ids']

            # Delete documents, if any
            if pdf_ids:
                self.pdf_collection.delete(ids=pdf_ids)
            if lecture_ids:
                self.lecture_collection.delete(ids=lecture_ids)

            logger.info(f"Cleared data for session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to clear session data: {e}")
            return False
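
# --- Minimal usage sketch (illustrative only, not part of the module) ---
# Assumes OPENAI_API_KEY is set in the environment and that upstream code has
# already extracted plain text from a PDF; the session name, filename, and
# placeholder text below are hypothetical.
if __name__ == "__main__":
    import os

    logging.basicConfig(level=logging.INFO)
    rag = RAGSystem(openai_api_key=os.environ["OPENAI_API_KEY"])

    session = "demo-session"
    pdf_text = "Self-attention lets each token weigh every other token. " * 40

    rag.add_pdf_content(session, pdf_text, metadata={"filename": "demo.pdf"})
    print(rag.get_session_stats(session))

    hits = rag.retrieve_relevant_content(session, "What does self-attention do?")
    for hit in hits['results']:
        print(hit['source'], round(hit['distance'], 3), hit['content'][:60])

    rag.clear_session_data(session)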