"""
LangGraph Agent State and Processing Nodes
"""
from typing import Dict, List, Optional, TypedDict, Annotated
from langchain_core.documents import Document
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
import json
import re
from src.agent.prompts import (
INTENT_CLASSIFICATION_PROMPT,
QUERY_ENHANCEMENT_PROMPT,
RESPONSE_GENERATION_PROMPT,
get_system_prompt_by_intent,
)
class ViettelPayState(TypedDict):
"""State for ViettelPay agent workflow with message history support"""
# Message history for multi-turn conversation
messages: Annotated[List[AnyMessage], add_messages]
    # Intent classification
intent: Optional[str]
confidence: Optional[float]
# Query enhancement
enhanced_query: Optional[str]
# Knowledge retrieval
retrieved_docs: Optional[List[Document]]
# Conversation context (cached to avoid repeated computation)
conversation_context: Optional[str]
# Response type metadata
    response_type: Optional[str]  # "script", "generated", or "error"
# Metadata
error: Optional[str]
processing_info: Optional[Dict]
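# Usage note (illustrative, not part of the workflow itself): callers typically
# seed only "messages" and let the nodes below fill in the remaining keys, e.g.
#   from langchain_core.messages import HumanMessage
#   initial_state = {"messages": [HumanMessage(content="Nạp cước bị lỗi 606")]}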
def get_conversation_context(messages: List[AnyMessage], max_messages: int = 3) -> str:
"""
Extract conversation context from message history
Args:
messages: List of conversation messages
max_messages: Maximum number of recent messages to include
Returns:
Formatted conversation context string
"""
if len(messages) <= 1:
return ""
context = "\n\nLịch sử cuộc hội thoại:\n"
    # Take up to max_messages recent messages, excluding the very last
    # (current) message, which is handled separately (e.g. by intent classification)
    recent_messages = messages[-(max_messages + 1) : -1]
for msg in recent_messages:
# Handle different message types more robustly
if hasattr(msg, "type"):
if msg.type == "human":
role = "Người dùng"
elif msg.type == "ai":
role = "Trợ lý"
else:
role = f"Unknown-{msg.type}"
elif hasattr(msg, "role"):
if msg.role in ["user", "human"]:
role = "Người dùng"
elif msg.role in ["assistant", "ai"]:
role = "Trợ lý"
else:
role = f"Unknown-{msg.role}"
else:
role = "Unknown"
        # Optional truncation to avoid token overflow (currently disabled)
        # content = msg.content[:1000] + "..." if len(msg.content) > 1000 else msg.content
        content = msg.content
context += f"{role}: {content}\n"
return context
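def get_last_user_message(messages: List[AnyMessage]) -> Optional[str]:
    """Return the content of the most recent human/user message, or None.

    Extracted helper: this is the same reversed scan that the processing
    nodes below previously inlined, handling both ``.type``-style and
    ``.role``-style message objects.
    """
    for msg in reversed(messages):
        if hasattr(msg, "type") and msg.type == "human":
            return msg.content
        if hasattr(msg, "role") and msg.role == "user":
            return msg.content
    return None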
def classify_intent_node(state: ViettelPayState, llm_client) -> ViettelPayState:
"""Node for intent classification using LLM with conversation context"""
# Get the latest user message
messages = state["messages"]
if not messages:
return {
**state,
"intent": "unclear",
"confidence": 0.0,
"error": "No messages found",
}
    # Find the last human/user message
    user_message = get_last_user_message(messages)
if not user_message:
return {
**state,
"intent": "unclear",
"confidence": 0.0,
"error": "No user message found",
}
try:
# Get conversation context for better intent classification
conversation_context = get_conversation_context(messages)
# Intent classification prompt with context using the prompts file
classification_prompt = INTENT_CLASSIFICATION_PROMPT.format(
conversation_context=conversation_context, user_message=user_message
)
# Get classification using the pre-initialized LLM client
response = llm_client.generate(classification_prompt, temperature=0.1)
# print(f"🔍 Raw LLM response: {response}")
# Parse JSON response
try:
# Try to extract JSON from response (in case there's extra text)
response_clean = response.strip()
# Look for JSON object in the response
json_match = re.search(r"\{.*\}", response_clean, re.DOTALL)
if json_match:
json_str = json_match.group()
result = json.loads(json_str)
else:
# Try parsing the whole response
result = json.loads(response_clean)
intent = result.get("intent", "unclear")
confidence = result.get("confidence", 0.5)
explanation = result.get("explanation", "")
except (json.JSONDecodeError, AttributeError) as e:
print(f"❌ JSON parsing failed: {e}")
print(f" Raw response: {response}")
            # Fallback: infer the intent from keywords in the raw response text
            response_lower = response.lower()
if any(
word in response_lower for word in ["lỗi", "error", "606", "mã lỗi"]
):
intent = "error_help"
confidence = 0.7
elif any(word in response_lower for word in ["xin chào", "hello", "chào"]):
intent = "greeting"
confidence = 0.8
elif any(word in response_lower for word in ["hủy", "cancel", "thủ tục"]):
intent = "procedure_guide"
confidence = 0.7
elif any(
word in response_lower for word in ["nạp", "cước", "dịch vụ", "faq"]
):
intent = "faq"
confidence = 0.7
else:
intent = "unclear"
confidence = 0.3
print(f"🔄 Fallback classification: {intent} (confidence: {confidence})")
explanation = "Fallback classification due to JSON parse error"
# print(f"🎯 Intent classified: {intent} (confidence: {confidence})")
return {
**state,
"intent": intent,
"confidence": confidence,
"conversation_context": conversation_context, # Save context for reuse
"processing_info": {
"classification_raw": response,
"explanation": explanation,
"context_used": bool(conversation_context.strip()),
},
}
except Exception as e:
print(f"❌ Intent classification error: {e}")
return {**state, "intent": "unclear", "confidence": 0.0, "error": str(e)}
def query_enhancement_node(state: ViettelPayState, llm_client) -> ViettelPayState:
"""Node for enhancing search query using conversation context"""
# Get the latest user message
messages = state["messages"]
if not messages:
return {**state, "enhanced_query": "", "error": "No messages found"}
    # Find the last human/user message
    user_message = get_last_user_message(messages)
if not user_message:
return {**state, "enhanced_query": "", "error": "No user message found"}
try:
# Use saved conversation context if available, otherwise get it
conversation_context = state.get("conversation_context")
if conversation_context is None:
conversation_context = get_conversation_context(messages)
# If no context, use original message
if not conversation_context.strip():
print(f"🔍 No context available, using original query: {user_message}")
return {**state, "enhanced_query": user_message}
# Query enhancement prompt using the prompts file
enhancement_prompt = QUERY_ENHANCEMENT_PROMPT.format(
conversation_context=conversation_context, user_message=user_message
)
# Get enhanced query
enhanced_query = llm_client.generate(enhancement_prompt, temperature=0.1)
enhanced_query = enhanced_query.strip()
print(f"🔍 Original query: {user_message}")
print(f"🚀 Enhanced query: {enhanced_query}")
return {**state, "enhanced_query": enhanced_query}
except Exception as e:
print(f"❌ Query enhancement error: {e}")
# Fallback to original message
return {**state, "enhanced_query": user_message, "error": str(e)}
def knowledge_retrieval_node(
state: ViettelPayState, knowledge_retriever
) -> ViettelPayState:
"""Node for knowledge retrieval using pre-initialized ViettelKnowledgeBase"""
# Use enhanced query if available, otherwise fall back to extracting from messages
enhanced_query = state.get("enhanced_query", "")
if not enhanced_query:
# Fallback: extract from messages
messages = state["messages"]
if not messages:
return {**state, "retrieved_docs": [], "error": "No messages found"}
        # Use the last human/user message as the query
        enhanced_query = get_last_user_message(messages)
if not enhanced_query:
return {**state, "retrieved_docs": [], "error": "No query available"}
try:
if not knowledge_retriever:
raise ValueError("Knowledge retriever not available")
# Retrieve relevant documents using enhanced query and pre-initialized ViettelKnowledgeBase
retrieved_docs = knowledge_retriever.search(enhanced_query, top_k=10)
print(
f"📚 Retrieved {len(retrieved_docs)} documents for enhanced query: {enhanced_query}"
)
return {**state, "retrieved_docs": retrieved_docs}
except Exception as e:
print(f"❌ Knowledge retrieval error: {e}")
return {**state, "retrieved_docs": [], "error": str(e)}
def script_response_node(state: ViettelPayState) -> ViettelPayState:
"""Node for script-based responses"""
from src.agent.scripts import ConversationScripts
from langchain_core.messages import AIMessage
intent = state.get("intent", "")
try:
# Load scripts
scripts = ConversationScripts("./viettelpay_docs/processed/kich_ban.csv")
# Map intents to script types
intent_to_script = {
"greeting": "greeting",
"out_of_scope": "out_of_scope",
"human_request": "human_request_attempt_1", # Could be enhanced later
"unclear": "ask_for_clarity",
}
script_type = intent_to_script.get(intent)
if script_type and scripts.has_script(script_type):
response_text = scripts.get_script(script_type)
print(f"📋 Using script response: {script_type}")
# Add AI message to the conversation
ai_message = AIMessage(content=response_text)
return {**state, "messages": [ai_message], "response_type": "script"}
else:
# Fallback script
fallback_response = (
"Xin lỗi, em chưa hiểu rõ yêu cầu của anh/chị. Vui lòng thử lại."
)
ai_message = AIMessage(content=fallback_response)
print(f"📋 Using fallback script for intent: {intent}")
return {**state, "messages": [ai_message], "response_type": "script"}
except Exception as e:
print(f"❌ Script response error: {e}")
fallback_response = "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau."
ai_message = AIMessage(content=fallback_response)
return {
**state,
"messages": [ai_message],
"response_type": "error",
"error": str(e),
}
def generate_response_node(state: ViettelPayState, llm_client) -> ViettelPayState:
"""Node for LLM-based response generation with conversation context"""
from langchain_core.messages import AIMessage
# Get the latest user message and conversation history
messages = state["messages"]
if not messages:
ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
return {**state, "messages": [ai_message], "response_type": "error"}
    # Find the last human/user message
    user_message = get_last_user_message(messages)
if not user_message:
ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
return {**state, "messages": [ai_message], "response_type": "error"}
intent = state.get("intent", "")
retrieved_docs = state.get("retrieved_docs", [])
enhanced_query = state.get("enhanced_query", "")
try:
# Build context from retrieved documents using original content
context = ""
if retrieved_docs:
context = "\n\n".join(
[
f"[{doc.metadata.get('doc_type', 'unknown')}] {doc.metadata.get('original_content', doc.page_content)}"
for doc in retrieved_docs
]
)
# Use saved conversation context if available, otherwise get it
conversation_context = state.get("conversation_context")
if conversation_context is None:
conversation_context = get_conversation_context(messages, max_messages=6)
# Get system prompt based on intent using the prompts file
system_prompt = get_system_prompt_by_intent(intent)
# Build full prompt with both knowledge context and conversation context using the prompts file
generation_prompt = RESPONSE_GENERATION_PROMPT.format(
system_prompt=system_prompt,
context=context,
conversation_context=conversation_context,
user_message=user_message,
enhanced_query=enhanced_query,
)
# Generate response using the pre-initialized LLM client
response_text = llm_client.generate(generation_prompt, temperature=0.1)
print(f"🤖 Generated response for intent: {intent}")
# Add AI message to the conversation
ai_message = AIMessage(content=response_text)
return {**state, "messages": [ai_message], "response_type": "generated"}
except Exception as e:
print(f"❌ Response generation error: {e}")
error_response = "Xin lỗi, em gặp lỗi khi xử lý yêu cầu. Vui lòng thử lại sau."
ai_message = AIMessage(content=error_response)
return {
**state,
"messages": [ai_message],
"response_type": "error",
"error": str(e),
}
# Routing functions for conditional edges
def route_after_intent_classification(state: ViettelPayState) -> str:
"""Route to appropriate node after intent classification"""
intent = state.get("intent", "unclear")
# Script-based intents (no knowledge retrieval needed)
script_intents = {"greeting", "out_of_scope", "human_request", "unclear"}
if intent in script_intents:
return "script_response"
else:
# Knowledge-based intents need query enhancement first
return "query_enhancement"
def route_after_query_enhancement(state: ViettelPayState) -> str:
"""Route after query enhancement (always to knowledge retrieval)"""
return "knowledge_retrieval"
def route_after_knowledge_retrieval(state: ViettelPayState) -> str:
"""Route after knowledge retrieval (always to generation)"""
return "generate_response"