"""
ViettelPay AI Agent using LangGraph
Multi-turn conversation support with short-term memory using InMemorySaver
"""

from typing import Dict, Optional
from functools import partial
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.messages import HumanMessage

from src.agent.nodes import (
    ViettelPayState,
    classify_intent_node,
    query_enhancement_node,
    knowledge_retrieval_node,
    script_response_node,
    generate_response_node,
    route_after_intent_classification,
)

# Import configuration utility
from src.utils.config import get_knowledge_base_path, get_llm_provider


class ViettelPayAgent:
    """Main ViettelPay AI Agent using LangGraph workflow with multi-turn conversation support"""

    def __init__(
        self,
        knowledge_base_path: Optional[str] = None,
        scripts_file: Optional[str] = None,
        llm_provider: Optional[str] = None,
    ):
        knowledge_base_path = knowledge_base_path or get_knowledge_base_path()
        scripts_file = scripts_file or "./viettelpay_docs/processed/kich_ban.csv"
        llm_provider = llm_provider or get_llm_provider()

        self.knowledge_base_path = knowledge_base_path
        self.scripts_file = scripts_file
        self.llm_provider = llm_provider

        # Initialize LLM client once during agent creation
        print(f"🧠 Initializing LLM client ({self.llm_provider})...")
        from src.llm.llm_client import LLMClientFactory

        self.llm_client = LLMClientFactory.create_client(self.llm_provider)
        print(f"✅ LLM client initialized and ready")

        # Initialize knowledge retriever once during agent creation
        print(f"📚 Initializing knowledge retriever...")
        try:
            from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase

            self.knowledge_base = ViettelKnowledgeBase()
            ensemble_retriever = self.knowledge_base.load_knowledge_base(
                knowledge_base_path
            )
            if not ensemble_retriever:
                raise ValueError(
                    f"Knowledge base not found at {knowledge_base_path}. Run build_database_script.py first."
                )
            print(f"✅ Knowledge retriever initialized and ready")
        except Exception as e:
            print(f"⚠️ Knowledge retriever initialization failed: {e}")
            self.knowledge_base = None

        # Initialize checkpointer for short-term memory
        self.checkpointer = InMemorySaver()

        # Build workflow with pre-initialized components
        self.workflow = self._build_workflow()
        self.app = self.workflow.compile(checkpointer=self.checkpointer)

        print("✅ ViettelPay Agent initialized with multi-turn conversation support")

    def _build_workflow(self) -> StateGraph:
        """Build LangGraph workflow with pre-initialized components"""

        # Create workflow graph
        workflow = StateGraph(ViettelPayState)

        # Create node functions with pre-bound components using functools.partial
        # This eliminates the need to initialize components in each node call
        classify_intent_with_llm = partial(
            classify_intent_node, llm_client=self.llm_client
        )
        query_enhancement_with_llm = partial(
            query_enhancement_node, llm_client=self.llm_client
        )
        knowledge_retrieval_with_retriever = partial(
            knowledge_retrieval_node, knowledge_retriever=self.knowledge_base
        )
        generate_response_with_llm = partial(
            generate_response_node, llm_client=self.llm_client
        )

        # Add nodes (some with pre-bound components, some without)
        workflow.add_node("classify_intent", classify_intent_with_llm)
        workflow.add_node("query_enhancement", query_enhancement_with_llm)
        workflow.add_node("knowledge_retrieval", knowledge_retrieval_with_retriever)
        workflow.add_node(
            "script_response", script_response_node
        )  # No pre-bound components needed
        workflow.add_node("generate_response", generate_response_with_llm)

        # Set entry point
        workflow.set_entry_point("classify_intent")

        # Add conditional routing after intent classification
        workflow.add_conditional_edges(
            "classify_intent",
            route_after_intent_classification,
            {
                "script_response": "script_response",
                "query_enhancement": "query_enhancement",
            },
        )

        # Script responses go directly to end
        workflow.add_edge("script_response", END)

        # Query enhancement goes to knowledge retrieval
        workflow.add_edge("query_enhancement", "knowledge_retrieval")

        # Knowledge retrieval goes to response generation
        workflow.add_edge("knowledge_retrieval", "generate_response")
        workflow.add_edge("generate_response", END)

        print("🔄 LangGraph workflow built successfully with optimized component usage")
        return workflow

    def process_message(self, user_message: str, thread_id: str = "default") -> Dict:
        """Process a user message in a multi-turn conversation"""

        print(f"\n💬 Processing message: '{user_message}' (thread: {thread_id})")
        print("=" * 50)

        # Create configuration with thread_id for conversation memory
        config = {"configurable": {"thread_id": thread_id}}

        try:
            # Create human message
            human_message = HumanMessage(content=user_message)

            # Initialize state with the new message
            # Note: conversation_context is set to None so it gets recomputed with fresh message history
            initial_state = {
                "messages": [human_message],
                "intent": None,
                "confidence": None,
                "enhanced_query": None,
                "retrieved_docs": None,
                "conversation_context": None,  # Reset to ensure fresh context computation
                "response_type": None,
                "error": None,
                "processing_info": None,
            }

            # Run workflow with memory
            result = self.app.invoke(initial_state, config)

            # Extract the response from the final message, which should be the
            # AI reply appended by the workflow
            messages = result.get("messages", [])
            if messages:
                last_message = messages[-1]
                if hasattr(last_message, "content"):
                    response = last_message.content
                else:
                    response = str(last_message)
            else:
                response = "Xin lỗi, em không thể xử lý yêu cầu này."

            response_type = result.get("response_type", "unknown")
            intent = result.get("intent", "unknown")
            confidence = result.get("confidence", 0.0)
            enhanced_query = result.get("enhanced_query", "")
            error = result.get("error")

            # Build response info
            response_info = {
                "response": response,
                "intent": intent,
                "confidence": confidence,
                "response_type": response_type,
                "enhanced_query": enhanced_query,
                "success": error is None,
                "error": error,
                "thread_id": thread_id,
                "message_count": len(messages),
            }

            print(f"✅ Response generated successfully")
            print(f"   Intent: {intent} (confidence: {confidence})")
            print(f"   Type: {response_type}")
            if enhanced_query and enhanced_query != user_message:
                print(f"   Enhanced query: {enhanced_query}")
            print(f"   Thread: {thread_id}")

            return response_info

        except Exception as e:
            print(f"❌ Workflow error: {e}")

            return {
                "response": "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau.",
                "intent": "error",
                "confidence": 0.0,
                "response_type": "error",
                "enhanced_query": "",
                "success": False,
                "error": str(e),
                "thread_id": thread_id,
                "message_count": 0,
            }

    def chat(self, user_message: str, thread_id: str = "default") -> str:
        """Simple chat interface - returns just the response text"""
        result = self.process_message(user_message, thread_id)
        return result["response"]

    def get_conversation_history(self, thread_id: str = "default") -> list:
        """Get conversation history for a specific thread"""
        try:
            config = {"configurable": {"thread_id": thread_id}}

            # Get the current state to access message history
            current_state = self.app.get_state(config)

            if current_state and current_state.values.get("messages"):
                messages = current_state.values["messages"]
                history = []

                for msg in messages:
                    if hasattr(msg, "type") and hasattr(msg, "content"):
                        role = "user" if msg.type == "human" else "assistant"
                        history.append({"role": role, "content": msg.content})
                    elif hasattr(msg, "role") and hasattr(msg, "content"):
                        history.append({"role": msg.role, "content": msg.content})

                return history
            else:
                return []

        except Exception as e:
            print(f"❌ Error getting conversation history: {e}")
            return []

    def clear_conversation(self, thread_id: str = "default") -> bool:
        """Clear conversation history for a specific thread"""
        try:
            # Note: InMemorySaver has no universal clear API; in-memory
            # conversations otherwise persist until the app restarts. For a
            # persistent backend you'd implement thread deletion explicitly.
            print(f"📝 Conversation clearing requested for thread: {thread_id}")
            print("   Note: InMemorySaver conversations clear on app restart")
            return True
        except Exception as e:
            print(f"❌ Error clearing conversation: {e}")
            return False

    def get_workflow_info(self) -> Dict:
        """Get information about the workflow structure"""
        return {
            "nodes": [
                "classify_intent",
                "query_enhancement",
                "knowledge_retrieval",
                "script_response",
                "generate_response",
            ],
            "entry_point": "classify_intent",
            "knowledge_base_path": self.knowledge_base_path,
            "scripts_file": self.scripts_file,
            "llm_provider": self.llm_provider,
            "memory_type": "InMemorySaver",
            "multi_turn": True,
            "query_enhancement": True,
            "optimizations": {
                "llm_client": "Single initialization with functools.partial",
                "knowledge_retriever": "Single initialization with functools.partial",
                "conversation_context": "Cached in state to avoid repeated computation",
            },
        }

    def health_check(self) -> Dict:
        """Check if all components are working"""

        health_status = {
            "agent": True,
            "workflow": True,
            "memory": True,
            "llm": False,
            "knowledge_base": False,
            "scripts": False,
            "overall": False,
        }

        try:
            # Test LLM client (already initialized)
            test_response = self.llm_client.generate("Hello", temperature=0.1)
            health_status["llm"] = bool(test_response)
            print("✅ LLM client working")

        except Exception as e:
            print(f"⚠️ LLM health check failed: {e}")
            health_status["llm"] = False

        try:
            # Test memory/checkpointer
            test_config = {"configurable": {"thread_id": "health_check"}}
            test_state = {"messages": [HumanMessage(content="test")]}

            # Try to invoke with memory
            self.app.invoke(test_state, test_config)
            health_status["memory"] = True
            print("✅ Memory/checkpointer working")

        except Exception as e:
            print(f"⚠️ Memory health check failed: {e}")
            health_status["memory"] = False

        try:
            # Test knowledge base (using pre-initialized retriever)
            if self.knowledge_base:
                # Run a simple search to verify the retriever responds
                test_docs = self.knowledge_base.search("test", top_k=1)
                health_status["knowledge_base"] = test_docs is not None
                print("✅ Knowledge retriever working")
            else:
                health_status["knowledge_base"] = False
                print("❌ Knowledge retriever not initialized")

        except Exception as e:
            print(f"⚠️ Knowledge base health check failed: {e}")
            health_status["knowledge_base"] = False

        try:
            # Test scripts
            from src.agent.scripts import ConversationScripts

            scripts = ConversationScripts(self.scripts_file)
            health_status["scripts"] = len(scripts.get_all_script_types()) > 0

        except Exception as e:
            print(f"⚠️ Scripts health check failed: {e}")

        # Overall health
        health_status["overall"] = all(
            [
                health_status["agent"],
                health_status["memory"],
                health_status["llm"],
                health_status["knowledge_base"],
                health_status["scripts"],
            ]
        )

        return health_status


# Usage example and testing
if __name__ == "__main__":
    # Initialize agent
    agent = ViettelPayAgent()

    # Health check
    print("\n🏥 Health Check:")
    health = agent.health_check()
    for component, status in health.items():
        status_icon = "✅" if status else "❌"
        print(f"   {component}: {status_icon}")

    if not health["overall"]:
        print("\n⚠️ Some components are not healthy. Check requirements and data files.")
        raise SystemExit(1)

    print(f"\n🤖 Agent ready! Workflow info: {agent.get_workflow_info()}")

    # Test multi-turn conversation with query enhancement
    test_thread = "test_conversation"

    print(
        f"\n🧪 Testing multi-turn conversation with query enhancement (thread: {test_thread}):"
    )

    test_messages = [
        "Xin chào!",
        "Mã lỗi 606 là gì?",
        "Làm sao khắc phục?",  # This should be enhanced to "làm sao khắc phục lỗi 606"
        "Còn lỗi nào khác tương tự không?",  # This should be enhanced with error context
        "Cảm ơn bạn!",
    ]

    for i, message in enumerate(test_messages, 1):
        print(f"\n--- Turn {i} ---")
        result = agent.process_message(message, test_thread)
        print(f"User: {message}")
        print(f"Bot: {result['response'][:150]}...")

        if result.get("enhanced_query") and result["enhanced_query"] != message:
            print(f"🚀 Query enhanced: {result['enhanced_query']}")

        # Show conversation history
        if i > 1:
            history = agent.get_conversation_history(test_thread)
            print(f"History length: {len(history)} messages")

    print(f"\n📜 Final conversation history:")
    history = agent.get_conversation_history(test_thread)
    for i, msg in enumerate(history, 1):
        print(f"  {i}. {msg['role']}: {msg['content'][:100]}...")