# Source: viettelpay-chatbot / src/agent/viettelpay_agent.py
"""
ViettelPay AI Agent using LangGraph
Multi-turn conversation support with short-term memory using InMemorySaver
"""
import os
from typing import Dict, Optional
from functools import partial
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.messages import HumanMessage
from src.agent.nodes import (
ViettelPayState,
classify_intent_node,
query_enhancement_node,
knowledge_retrieval_node,
script_response_node,
generate_response_node,
route_after_intent_classification,
route_after_query_enhancement,
route_after_knowledge_retrieval,
)
# Import configuration utility
from src.utils.config import get_knowledge_base_path, get_llm_provider
class ViettelPayAgent:
    """Main ViettelPay AI Agent using LangGraph workflow with multi-turn conversation support.

    The heavy components (LLM client, knowledge retriever, checkpointer) are
    created once in ``__init__`` and bound into the graph nodes with
    ``functools.partial``, so per-message processing performs no
    re-initialization. Conversation memory is per-``thread_id`` and held in an
    in-process ``InMemorySaver`` (lost on restart).
    """

    def __init__(
        self,
        knowledge_base_path: Optional[str] = None,
        scripts_file: Optional[str] = None,
        llm_provider: Optional[str] = None,
    ):
        """Initialize components and compile the LangGraph workflow.

        Args:
            knowledge_base_path: Path to the persisted knowledge base.
                Falls back to ``get_knowledge_base_path()`` when omitted.
            scripts_file: CSV file with scripted conversation responses.
                Defaults to ``./viettelpay_docs/processed/kich_ban.csv``.
            llm_provider: LLM provider name. Falls back to
                ``get_llm_provider()`` when omitted.
        """
        # Resolve constructor arguments against configured defaults.
        knowledge_base_path = knowledge_base_path or get_knowledge_base_path()
        scripts_file = scripts_file or "./viettelpay_docs/processed/kich_ban.csv"
        llm_provider = llm_provider or get_llm_provider()
        self.knowledge_base_path = knowledge_base_path
        self.scripts_file = scripts_file
        self.llm_provider = llm_provider

        # Initialize LLM client once during agent creation.
        print(f"🧠 Initializing LLM client ({self.llm_provider})...")
        # Local import keeps the heavy LLM dependency out of module import time.
        from src.llm.llm_client import LLMClientFactory

        self.llm_client = LLMClientFactory.create_client(self.llm_provider)
        print(f"✅ LLM client initialized and ready")

        # Initialize knowledge retriever once during agent creation.
        print(f"📚 Initializing knowledge retriever...")
        try:
            from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase

            self.knowledge_base = ViettelKnowledgeBase()
            ensemble_retriever = self.knowledge_base.load_knowledge_base(
                knowledge_base_path
            )
            # A falsy return means no persisted index exists at the given path.
            if not ensemble_retriever:
                raise ValueError(
                    f"Knowledge base not found at {knowledge_base_path}. Run build_database_script.py first."
                )
            print(f"✅ Knowledge retriever initialized and ready")
        except Exception as e:
            # Degrade gracefully: retrieval-dependent paths will be unavailable,
            # but the agent object is still constructed.
            print(f"⚠️ Knowledge retriever initialization failed: {e}")
            self.knowledge_base = None

        # Checkpointer provides short-term (in-process) conversation memory.
        self.checkpointer = InMemorySaver()
        # Build and compile the workflow with the pre-initialized components.
        self.workflow = self._build_workflow()
        self.app = self.workflow.compile(checkpointer=self.checkpointer)
        print("✅ ViettelPay Agent initialized with multi-turn conversation support")

    def _build_workflow(self) -> StateGraph:
        """Build the LangGraph workflow with pre-initialized components.

        Graph shape:
            classify_intent -> (script_response -> END)
                            -> (query_enhancement -> knowledge_retrieval
                                -> generate_response -> END)

        Returns:
            The uncompiled ``StateGraph``; the caller compiles it with the
            checkpointer.
        """
        # Create workflow graph over the shared agent state.
        workflow = StateGraph(ViettelPayState)

        # Bind the already-created components into the node callables with
        # functools.partial, so nodes never re-initialize clients per call.
        classify_intent_with_llm = partial(
            classify_intent_node, llm_client=self.llm_client
        )
        query_enhancement_with_llm = partial(
            query_enhancement_node, llm_client=self.llm_client
        )
        knowledge_retrieval_with_retriever = partial(
            knowledge_retrieval_node, knowledge_retriever=self.knowledge_base
        )
        generate_response_with_llm = partial(
            generate_response_node, llm_client=self.llm_client
        )

        # Add nodes (some with pre-bound components, some without).
        workflow.add_node("classify_intent", classify_intent_with_llm)
        workflow.add_node("query_enhancement", query_enhancement_with_llm)
        workflow.add_node("knowledge_retrieval", knowledge_retrieval_with_retriever)
        workflow.add_node(
            "script_response", script_response_node
        )  # No pre-bound components needed
        workflow.add_node("generate_response", generate_response_with_llm)

        # Every conversation turn starts with intent classification.
        workflow.set_entry_point("classify_intent")

        # Conditional routing: scripted intents short-circuit to a canned
        # response, everything else goes through the RAG pipeline.
        workflow.add_conditional_edges(
            "classify_intent",
            route_after_intent_classification,
            {
                "script_response": "script_response",
                "query_enhancement": "query_enhancement",
            },
        )
        # Script responses go directly to end.
        workflow.add_edge("script_response", END)
        # Query enhancement goes to knowledge retrieval.
        workflow.add_edge("query_enhancement", "knowledge_retrieval")
        # Knowledge retrieval goes to response generation.
        workflow.add_edge("knowledge_retrieval", "generate_response")
        workflow.add_edge("generate_response", END)

        print("🔄 LangGraph workflow built successfully with optimized component usage")
        return workflow

    def process_message(self, user_message: str, thread_id: str = "default") -> Dict:
        """Process a user message in a multi-turn conversation.

        Args:
            user_message: The raw user utterance.
            thread_id: Conversation identifier; messages with the same id
                share checkpointed memory.

        Returns:
            A dict with keys ``response``, ``intent``, ``confidence``,
            ``response_type``, ``enhanced_query``, ``success``, ``error``,
            ``thread_id`` and ``message_count``. Never raises: workflow
            failures are reported via ``success=False`` / ``error``.
        """
        print(f"\n💬 Processing message: '{user_message}' (thread: {thread_id})")
        print("=" * 50)

        # thread_id selects the checkpointed conversation in the InMemorySaver.
        config = {"configurable": {"thread_id": thread_id}}
        try:
            # Wrap the raw text as a LangChain human message.
            human_message = HumanMessage(content=user_message)

            # Initialize state with the new message.
            # Note: conversation_context is set to None so it gets recomputed
            # with fresh message history on every turn.
            initial_state = {
                "messages": [human_message],
                "intent": None,
                "confidence": None,
                "enhanced_query": None,
                "retrieved_docs": None,
                "conversation_context": None,  # Reset to ensure fresh context computation
                "response_type": None,
                "error": None,
                "processing_info": None,
            }

            # Run workflow with memory (previous turns are merged in by the
            # checkpointer via the config above).
            result = self.app.invoke(initial_state, config)

            # Extract response from the last message in the accumulated history;
            # presumably the final node appended an AI message — the code only
            # requires a .content attribute.
            messages = result.get("messages", [])
            if messages:
                last_message = messages[-1]
                if hasattr(last_message, "content"):
                    response = last_message.content
                else:
                    response = str(last_message)
            else:
                # Fallback apology (Vietnamese): "Sorry, I can't process this request."
                response = "Xin lỗi, em không thể xử lý yêu cầu này."

            response_type = result.get("response_type", "unknown")
            intent = result.get("intent", "unknown")
            confidence = result.get("confidence", 0.0)
            enhanced_query = result.get("enhanced_query", "")
            error = result.get("error")

            # Build the structured response payload returned to callers.
            response_info = {
                "response": response,
                "intent": intent,
                "confidence": confidence,
                "response_type": response_type,
                "enhanced_query": enhanced_query,
                "success": error is None,
                "error": error,
                "thread_id": thread_id,
                "message_count": len(messages),
            }

            print(f"✅ Response generated successfully")
            print(f"   Intent: {intent} (confidence: {confidence})")
            print(f"   Type: {response_type}")
            if enhanced_query and enhanced_query != user_message:
                print(f"   Enhanced query: {enhanced_query}")
            print(f"   Thread: {thread_id}")
            return response_info
        except Exception as e:
            # Any workflow failure is converted into an error payload so the
            # caller-facing contract (a dict) is always honored.
            print(f"❌ Workflow error: {e}")
            return {
                "response": "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau.",
                "intent": "error",
                "confidence": 0.0,
                "response_type": "error",
                "enhanced_query": "",
                "success": False,
                "error": str(e),
                "thread_id": thread_id,
                "message_count": 0,
            }

    def chat(self, user_message: str, thread_id: str = "default") -> str:
        """Simple chat interface — returns just the response text."""
        result = self.process_message(user_message, thread_id)
        return result["response"]

    def get_conversation_history(self, thread_id: str = "default") -> list:
        """Get conversation history for a specific thread.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts in turn order;
            empty list when the thread has no state or an error occurs.
        """
        try:
            config = {"configurable": {"thread_id": thread_id}}
            # Read the checkpointed state to access the message history.
            current_state = self.app.get_state(config)
            if current_state and current_state.values.get("messages"):
                messages = current_state.values["messages"]
                history = []
                for msg in messages:
                    # LangChain message objects expose .type ("human"/"ai"/...)
                    if hasattr(msg, "type") and hasattr(msg, "content"):
                        role = "user" if msg.type == "human" else "assistant"
                        history.append({"role": role, "content": msg.content})
                    # Plain dict-like/OpenAI-style messages expose .role directly.
                    elif hasattr(msg, "role") and hasattr(msg, "content"):
                        history.append({"role": msg.role, "content": msg.content})
                return history
            else:
                return []
        except Exception as e:
            print(f"❌ Error getting conversation history: {e}")
            return []

    def clear_conversation(self, thread_id: str = "default") -> bool:
        """Clear conversation history for a specific thread.

        Currently a no-op that only logs the request: InMemorySaver offers no
        per-thread clear here, so memory is released only on process restart.

        Returns:
            True when the request was acknowledged, False on error.
        """
        try:
            # Note: InMemorySaver doesn't have a direct clear method.
            # The conversation will be cleared when the app is restarted.
            # For persistent memory, you'd need to implement a clear method.
            print(f"📝 Conversation clearing requested for thread: {thread_id}")
            print("   Note: InMemorySaver conversations clear on app restart")
            return True
        except Exception as e:
            print(f"❌ Error clearing conversation: {e}")
            return False

    def get_workflow_info(self) -> Dict:
        """Get a static description of the workflow structure and configuration."""
        return {
            "nodes": [
                "classify_intent",
                "query_enhancement",
                "knowledge_retrieval",
                "script_response",
                "generate_response",
            ],
            "entry_point": "classify_intent",
            "knowledge_base_path": self.knowledge_base_path,
            "scripts_file": self.scripts_file,
            "llm_provider": self.llm_provider,
            "memory_type": "InMemorySaver",
            "multi_turn": True,
            "query_enhancement": True,
            "optimizations": {
                "llm_client": "Single initialization with functools.partial",
                "knowledge_retriever": "Single initialization with functools.partial",
                "conversation_context": "Cached in state to avoid repeated computation",
            },
        }

    def health_check(self) -> Dict:
        """Check if all components are working.

        Returns:
            A dict of component-name -> bool plus an ``overall`` flag.
            NOTE(review): the memory probe invokes the full compiled workflow,
            so this check may trigger real LLM calls.
        """
        # "agent" and "workflow" are assumed healthy once construction succeeded.
        health_status = {
            "agent": True,
            "workflow": True,
            "memory": True,
            "llm": False,
            "knowledge_base": False,
            "scripts": False,
            "overall": False,
        }
        try:
            # Test LLM client (already initialized) with a trivial generation.
            test_response = self.llm_client.generate("Hello", temperature=0.1)
            health_status["llm"] = bool(test_response)
            print("✅ LLM client working")
        except Exception as e:
            print(f"⚠️ LLM health check failed: {e}")
            health_status["llm"] = False
        try:
            # Test memory/checkpointer on a dedicated thread id.
            test_config = {"configurable": {"thread_id": "health_check"}}
            test_state = {"messages": [HumanMessage(content="test")]}
            # Try to invoke with memory (runs the graph end-to-end once).
            self.app.invoke(test_state, test_config)
            health_status["memory"] = True
            print("✅ Memory/checkpointer working")
        except Exception as e:
            print(f"⚠️ Memory health check failed: {e}")
            health_status["memory"] = False
        try:
            # Test knowledge base (using pre-initialized retriever).
            if self.knowledge_base:
                # A simple search verifies the index is loadable/queryable;
                # the returned docs themselves are intentionally unused.
                test_docs = self.knowledge_base.search("test", top_k=1)
                health_status["knowledge_base"] = True
                print("✅ Knowledge retriever working")
            else:
                health_status["knowledge_base"] = False
                print("❌ Knowledge retriever not initialized")
        except Exception as e:
            print(f"⚠️ Knowledge base health check failed: {e}")
            health_status["knowledge_base"] = False
        try:
            # Test scripts: healthy when the CSV parses and defines at least
            # one script type.
            from src.agent.scripts import ConversationScripts

            scripts = ConversationScripts(self.scripts_file)
            health_status["scripts"] = len(scripts.get_all_script_types()) > 0
        except Exception as e:
            print(f"⚠️ Scripts health check failed: {e}")
        # Overall health: all critical components must pass.
        health_status["overall"] = all(
            [
                health_status["agent"],
                health_status["memory"],
                health_status["llm"],
                health_status["knowledge_base"],
                health_status["scripts"],
            ]
        )
        return health_status
# Usage example and testing
if __name__ == "__main__":
    # Usage example and smoke test: build the agent, verify component health,
    # then run a scripted multi-turn conversation.
    agent = ViettelPayAgent()

    # Health check before exercising the conversation flow.
    print("\n🏥 Health Check:")
    health = agent.health_check()
    for component, status in health.items():
        status_icon = "✅" if status else "❌"
        print(f" {component}: {status_icon}")
    if not health["overall"]:
        print("\n⚠️ Some components are not healthy. Check requirements and data files.")
        # Builtin exit() comes from the `site` module and is not guaranteed in
        # every run mode; SystemExit is always available and equivalent here.
        raise SystemExit(1)

    print(f"\n🤖 Agent ready! Workflow info: {agent.get_workflow_info()}")

    # Test multi-turn conversation with query enhancement.
    test_thread = "test_conversation"
    print(
        f"\n🧪 Testing multi-turn conversation with query enhancement (thread: {test_thread}):"
    )
    test_messages = [
        "Xin chào!",
        "Mã lỗi 606 là gì?",
        "Làm sao khắc phục?",  # This should be enhanced to "làm sao khắc phục lỗi 606"
        "Còn lỗi nào khác tương tự không?",  # This should be enhanced with error context
        "Cảm ơn bạn!",
    ]
    for turn, message in enumerate(test_messages, 1):
        print(f"\n--- Turn {turn} ---")
        result = agent.process_message(message, test_thread)
        print(f"User: {message}")
        print(f"Bot: {result['response'][:150]}...")
        if result.get("enhanced_query") and result["enhanced_query"] != message:
            print(f"🚀 Query enhanced: {result['enhanced_query']}")
        # Show conversation history growth after the first turn.
        if turn > 1:
            history = agent.get_conversation_history(test_thread)
            print(f"History length: {len(history)} messages")

    # Final dump of the whole conversation (distinct loop variable avoids
    # shadowing the turn counter above).
    print(f"\n📜 Final conversation history:")
    history = agent.get_conversation_history(test_thread)
    for msg_no, msg in enumerate(history, 1):
        print(f" {msg_no}. {msg['role']}: {msg['content'][:100]}...")