# viettelpay-chatbot/src/knowledge_base/viettel_knowledge_base.py
"""
ViettelPay Knowledge Base with Contextual Retrieval
This updated version:
- Uses ContextualWordProcessor for all document processing
- Integrates OpenAI for contextual enhancement
- Processes all doc/docx files from a parent folder
- Removes CSV processor dependency
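
Typical usage (a minimal sketch; paths are illustrative):
    kb = ViettelKnowledgeBase()
    kb.build_knowledge_base("./viettelpay_docs", persist_dir="./knowledge_base")
    docs = kb.search("lỗi 606", top_k=5)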
"""
import os
# import torch
from typing import List, Optional
from pathlib import Path
from openai import OpenAI
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.runnables import ConfigurableField
from langchain_cohere.rerank import CohereRerank
# Use newest import paths for langchain
try:
from langchain_chroma import Chroma
except ImportError:
from langchain_community.vectorstores import Chroma
# Use the new HuggingFaceEmbeddings from langchain-huggingface
try:
from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
from langchain_community.embeddings import HuggingFaceEmbeddings
from src.processor.contextual_word_processor import ContextualWordProcessor
from src.processor.text_utils import VietnameseTextProcessor
# Import configuration utility
from src.utils.config import get_cohere_api_key, get_openai_api_key, get_embedding_model
class ViettelKnowledgeBase:
"""ViettelPay knowledge base with contextual retrieval enhancement"""
    def __init__(self, embedding_model: Optional[str] = None):
        """
        Initialize the knowledge base
        Args:
            embedding_model: Vietnamese embedding model to use; falls back to
                get_embedding_model() from the project configuration when None
        """
        embedding_model = embedding_model or get_embedding_model()
# Initialize Vietnamese text processor
self.text_processor = VietnameseTextProcessor()
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = "cpu"
print(f"[INFO] Using device: {self.device}")
        # Initialize embeddings; trust_remote_code is required by some Vietnamese embedding models
model_kwargs = {"device": self.device, "trust_remote_code": True}
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model, model_kwargs=model_kwargs
)
# Initialize retrievers as None
self.chroma_retriever = None
self.bm25_retriever = None
self.ensemble_retriever = None
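        # Cohere cross-encoder reranker applied after ensemble fusion;
        # rerank-v3.5 is multilingual, which suits Vietnamese queries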
self.reranker = CohereRerank(
model="rerank-v3.5",
cohere_api_key=get_cohere_api_key(),
)
def build_knowledge_base(
self,
documents_folder: str,
persist_dir: str = "./knowledge_base",
reset: bool = True,
openai_api_key: Optional[str] = None,
) -> None:
"""
Build knowledge base from all Word documents in a folder
Args:
documents_folder: Path to folder containing doc/docx files
persist_dir: Directory to persist the knowledge base
reset: Whether to reset existing knowledge base
openai_api_key: OpenAI API key for contextual enhancement (optional)
Returns:
None. Use the search() method to perform searches.
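
        Example (illustrative paths; the API key may also come from configuration):
            kb = ViettelKnowledgeBase()
            kb.build_knowledge_base(
                documents_folder="./viettelpay_docs",
                persist_dir="./knowledge_base",
                openai_api_key=None,  # falls back to get_openai_api_key()
            )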
"""
print(
"[INFO] Building ViettelPay knowledge base with contextual enhancement..."
)
# Initialize OpenAI client for contextual enhancement if API key provided
openai_client = None
        if openai_api_key:
            openai_client = OpenAI(api_key=openai_api_key)
            print("[INFO] OpenAI client initialized for contextual enhancement")
        elif get_openai_api_key():
            api_key = get_openai_api_key()
            openai_client = OpenAI(api_key=api_key)
            print("[INFO] OpenAI client initialized from configuration")
        else:
            print(
                "[WARNING] No OpenAI API key provided. Contextual enhancement disabled."
            )
# Initialize the contextual word processor with OpenAI client
word_processor = ContextualWordProcessor(llm_client=openai_client)
# Find all Word documents in the folder
word_files = self._find_word_documents(documents_folder)
if not word_files:
raise ValueError(f"No Word documents found in {documents_folder}")
print(f"[INFO] Found {len(word_files)} Word documents to process")
# Process all documents
all_documents = self._process_all_word_files(word_files, word_processor)
print(f"[INFO] Total documents processed: {len(all_documents)}")
        # Create directories (BM25 is no longer pickled to disk, so only Chroma persists)
        os.makedirs(persist_dir, exist_ok=True)
        chroma_dir = os.path.join(persist_dir, "chroma")
# Build ChromaDB retriever (uses contextualized content)
print("[INFO] Building ChromaDB retriever with contextualized content...")
self.chroma_retriever = self._build_chroma_retriever(
all_documents, chroma_dir, reset
)
        # Build BM25 retriever (uses contextualized content with Vietnamese tokenization)
        print("[INFO] Building BM25 retriever with Vietnamese tokenization...")
        self.bm25_retriever = self._build_bm25_retriever(all_documents)
# Create ensemble retriever with configurable top-k
print("[INFO] Creating ensemble retriever...")
self.ensemble_retriever = self._build_retriever(
self.bm25_retriever, self.chroma_retriever
)
print("[SUCCESS] Contextual knowledge base built successfully!")
print("[INFO] Use kb.search(query, top_k) to perform searches.")
def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool:
"""
Load existing knowledge base from disk and rebuild BM25 from ChromaDB documents
Args:
persist_dir: Directory where the knowledge base is stored
Returns:
bool: True if loaded successfully, False otherwise
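
        Example (assumes a knowledge base was previously built at this path):
            kb = ViettelKnowledgeBase()
            if kb.load_knowledge_base("./knowledge_base"):
                docs = kb.search("lỗi 606", top_k=5)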
"""
print("[INFO] Loading knowledge base from disk...")
chroma_dir = os.path.join(persist_dir, "chroma")
try:
            # Load ChromaDB; re-attach the configurable search_kwargs field so that
            # search() can override k at query time (mirrors _build_chroma_retriever)
            if os.path.exists(chroma_dir):
                vectorstore = Chroma(
                    persist_directory=chroma_dir, embedding_function=self.embeddings
                )
                self.chroma_retriever = vectorstore.as_retriever(
                    search_kwargs={"k": 5}  # default value
                ).configurable_fields(
                    search_kwargs=ConfigurableField(id="chroma_search_kwargs")
                )
                print("[SUCCESS] ChromaDB loaded")
else:
print("[ERROR] ChromaDB not found")
return False
# Extract all documents from ChromaDB to rebuild BM25
print("[INFO] Extracting documents from ChromaDB to rebuild BM25...")
try:
# Get all documents and metadata from ChromaDB
all_docs = vectorstore.get(include=["documents", "metadatas"])
documents = all_docs["documents"]
metadatas = all_docs["metadatas"]
                # Reconstruct Document objects
                doc_objects = []
                for doc_content, metadata in zip(documents, metadatas):
                    # Metadata may be None for some entries
                    doc_objects.append(
                        Document(page_content=doc_content, metadata=metadata or {})
                    )
print(f"[INFO] Extracted {len(doc_objects)} documents from ChromaDB")
                # Rebuild BM25 retriever from the extracted documents
                self.bm25_retriever = self._build_bm25_retriever(doc_objects)
except Exception as e:
print(f"[ERROR] Error rebuilding BM25 from ChromaDB: {e}")
return False
# Create ensemble retriever with configurable top-k
self.ensemble_retriever = self._build_retriever(
self.bm25_retriever, self.chroma_retriever
)
print("[SUCCESS] Knowledge base loaded successfully!")
print("[INFO] Use kb.search(query, top_k) to perform searches.")
return True
except Exception as e:
print(f"[ERROR] Error loading knowledge base: {e}")
return False
def search(self, query: str, top_k: int = 10) -> List[Document]:
"""
        Main search method: ensemble retrieval followed by Cohere reranking
        Args:
            query: Search query
            top_k: Number of documents to return after reranking (default: 10)
        Returns:
            List of retrieved documents, each with a "relevance_score" in metadata
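
        Example (illustrative query; requires a built or loaded knowledge base):
            docs = kb.search("lỗi 606", top_k=3)
            for doc in docs:
                print(doc.metadata["relevance_score"], doc.page_content[:80])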
"""
if not self.ensemble_retriever:
raise ValueError(
"Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
)
        # Over-retrieve from each retriever (5x top_k) so the reranker has a
        # wider candidate pool from which to pick the final top_k
        config = {
            "configurable": {
                "bm25_k": top_k * 5,
                "chroma_search_kwargs": {"k": top_k * 5},
            }
        }
results = self.ensemble_retriever.invoke(query, config=config)
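        # CohereRerank.rerank returns dicts with "index" (position in `results`)
        # and "relevance_score"; rebuild Documents in reranked order below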
reranked_results = self.reranker.rerank(results, query, top_n=top_k)
final_results = []
for rerank_item in reranked_results:
# Get the original document using the index
original_doc = results[rerank_item["index"]]
# Create a new document with the relevance score added to metadata
reranked_doc = Document(
page_content=original_doc.page_content,
metadata={
**original_doc.metadata,
"relevance_score": rerank_item["relevance_score"],
},
)
final_results.append(reranked_doc)
return final_results
def get_stats(self) -> dict:
"""Get statistics about the knowledge base"""
stats = {}
        if self.chroma_retriever:
            try:
                # configurable_fields() wraps the retriever; the underlying
                # retriever is exposed on .default when wrapped
                retriever = getattr(self.chroma_retriever, "default", self.chroma_retriever)
                stats["chroma_documents"] = retriever.vectorstore._collection.count()
            except Exception:
                stats["chroma_documents"] = "Unknown"
        if self.bm25_retriever:
            try:
                bm25 = getattr(self.bm25_retriever, "default", self.bm25_retriever)
                stats["bm25_documents"] = len(bm25.docs)
            except Exception:
                stats["bm25_documents"] = "Unknown"
stats["ensemble_available"] = self.ensemble_retriever is not None
stats["device"] = self.device
stats["vietnamese_tokenizer"] = "Vietnamese BM25 tokenizer (underthesea)"
return stats
def _find_word_documents(self, folder_path: str) -> List[str]:
"""
Find all Word documents (.doc, .docx) in the given folder
Args:
folder_path: Path to the folder to search
Returns:
List of full paths to Word documents
"""
word_files = []
folder = Path(folder_path)
if not folder.exists():
raise FileNotFoundError(f"Folder not found: {folder_path}")
        # Search for Word documents (non-recursive; "*.doc" does not also match ".docx" files)
for pattern in ["*.doc", "*.docx"]:
word_files.extend(folder.glob(pattern))
# Convert to string paths and sort for consistent processing order
word_files = [str(f) for f in word_files]
word_files.sort()
print(f"[INFO] Found Word documents: {[Path(f).name for f in word_files]}")
return word_files
def _process_all_word_files(
self, word_files: List[str], word_processor: ContextualWordProcessor
) -> List[Document]:
"""Process all Word files into unified chunks with contextual enhancement"""
all_documents = []
for file_path in word_files:
try:
print(f"[INFO] Processing: {Path(file_path).name}")
chunks = word_processor.process_word_document(file_path)
all_documents.extend(chunks)
# Print processing stats for this file
stats = word_processor.get_document_stats(chunks)
print(
f"[SUCCESS] Processed {Path(file_path).name}: {len(chunks)} chunks"
)
print(f" - Contextualized: {stats.get('contextualized_docs', 0)}")
print(
f" - Non-contextualized: {stats.get('non_contextualized_docs', 0)}"
)
except Exception as e:
print(f"[ERROR] Error processing {Path(file_path).name}: {e}")
return all_documents
def _build_retriever(self, bm25_retriever, chroma_retriever):
"""
Build ensemble retriever with configurable top-k parameters
Args:
bm25_retriever: BM25 retriever with configurable fields
chroma_retriever: Chroma retriever with configurable fields
Returns:
EnsembleRetriever with configurable retrievers
"""
        return EnsembleRetriever(
            retrievers=[bm25_retriever, chroma_retriever],
            weights=[0.2, 0.8],  # favor semantic search over lexical BM25
        )
def _build_chroma_retriever(
self, documents: List[Document], chroma_dir: str, reset: bool
):
"""Build ChromaDB retriever with configurable search parameters"""
if reset and os.path.exists(chroma_dir):
import shutil
shutil.rmtree(chroma_dir)
print("[INFO] Removed existing ChromaDB for rebuild")
# Create Chroma vectorstore (uses contextualized content)
vectorstore = Chroma.from_documents(
documents=documents, embedding=self.embeddings, persist_directory=chroma_dir
)
# Create retriever with configurable search_kwargs
retriever = vectorstore.as_retriever(
search_kwargs={"k": 5} # default value
).configurable_fields(
search_kwargs=ConfigurableField(
id="chroma_search_kwargs",
name="Chroma Search Kwargs",
description="Search kwargs for Chroma DB retriever",
)
)
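        # The "chroma_search_kwargs" id lets callers override k per invocation, e.g.:
        #   retriever.invoke(query, config={"configurable": {"chroma_search_kwargs": {"k": 20}}})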
print(
f"[SUCCESS] ChromaDB created with {len(documents)} contextualized documents"
)
return retriever
    def _build_bm25_retriever(self, documents: List[Document]):
"""Build BM25 retriever with Vietnamese tokenization and configurable k parameter"""
# Note: We no longer save BM25 to pickle file to avoid Streamlit Cloud compatibility issues
# BM25 will be rebuilt from ChromaDB documents when loading the knowledge base
# Create BM25 retriever with Vietnamese tokenizer as preprocess_func
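        # The tokenizer word-segments Vietnamese text before BM25 indexing, e.g.
        # "nạp tiền điện thoại" -> ["nạp_tiền", "điện_thoại"] (illustrative output;
        # exact tokens depend on VietnameseTextProcessor / underthesea)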
print("[INFO] Using Vietnamese tokenizer for BM25 on contextualized content...")
retriever = BM25Retriever.from_documents(
documents=documents,
preprocess_func=self.text_processor.bm25_tokenizer,
k=5, # default value
).configurable_fields(
k=ConfigurableField(
id="bm25_k",
name="BM25 Top K",
description="Number of documents to return from BM25",
)
)
print(
f"[SUCCESS] BM25 retriever created with {len(documents)} contextualized documents"
)
return retriever
def test_contextual_kb(kb: ViettelKnowledgeBase, test_queries: List[str]):
"""Test function for the contextual knowledge base"""
print("\n[INFO] Testing Contextual Knowledge Base")
print("=" * 60)
for i, query in enumerate(test_queries, 1):
print(f"\n#{i} Query: '{query}'")
print("-" * 40)
try:
# Test ensemble search with configurable top-k
results = kb.search(query, top_k=3)
if results:
for j, doc in enumerate(results, 1):
content_preview = doc.page_content[:150].replace("\n", " ")
doc_type = doc.metadata.get("doc_type", "unknown")
has_context = doc.metadata.get("has_context", False)
context_indicator = (
" [CONTEXTUAL]" if has_context else " [ORIGINAL]"
)
print(
f" {j}. [{doc_type}]{context_indicator} {content_preview}..."
)
else:
print(" No results found")
except Exception as e:
print(f" [ERROR] Error: {e}")
# Example usage
if __name__ == "__main__":
# Initialize knowledge base
kb = ViettelKnowledgeBase(
embedding_model="dangvantuan/vietnamese-document-embedding"
)
# Build knowledge base from a folder of Word documents
documents_folder = "./viettelpay_docs" # Folder containing .doc/.docx files
try:
        # Build knowledge base (pass an OpenAI API key here for contextual enhancement)
        kb.build_knowledge_base(
            documents_folder,
            "./contextual_kb",
            reset=True,
            openai_api_key="your-openai-api-key-here",  # or None to fall back to get_openai_api_key()
        )
# Alternative: Load existing knowledge base
# success = kb.load_knowledge_base("./contextual_kb")
# if not success:
# print("[ERROR] Failed to load knowledge base")
# Test queries
test_queries = [
"lỗi 606",
"không nạp được tiền",
"hướng dẫn nạp cước",
"quy định hủy giao dịch",
"mệnh giá thẻ cào",
]
# Test the knowledge base
test_contextual_kb(kb, test_queries)
        # Example of runtime configuration for different top-k values
        print("\n[INFO] Example of runtime configuration:")
        print("=" * 50)
# Search with different top-k values
sample_query = "lỗi 606"
# Search with top_k=3
results1 = kb.search(sample_query, top_k=3)
print(f"Search with top_k=3: {len(results1)} total results")
# Search with top_k=8
results2 = kb.search(sample_query, top_k=8)
print(f"Search with top_k=8: {len(results2)} total results")
# Show stats
print(f"\n[INFO] Knowledge Base Stats: {kb.get_stats()}")
except Exception as e:
print(f"[ERROR] Error building knowledge base: {e}")
print("[INFO] Make sure you have:")
print(" 1. Valid OpenAI API key")
print(" 2. Word documents in the specified folder")
print(" 3. Required dependencies installed (openai, markitdown, etc.)")