minhan6559 commited on
Commit
60d1d13
·
verified ·
1 Parent(s): 8e735a4

Upload 73 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env +9 -0
  2. .gitattributes +4 -35
  3. .gitignore +2 -0
  4. .streamlit/secrets.toml +23 -0
  5. evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json +0 -0
  6. evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json +0 -0
  7. evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json +0 -0
  8. evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json +0 -0
  9. evaluation_data/results/intent_classification/viettelpay_intent_results.json +0 -0
  10. evaluation_data/results/multi_turn_retrieval/viettelpay_multiturn_results.json +0 -0
  11. evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json +0 -0
  12. knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/data_level0.bin +3 -0
  13. knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/header.bin +0 -0
  14. knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/length.bin +0 -0
  15. knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/link_lists.bin +0 -0
  16. knowledge_base/chroma/chroma.sqlite3 +3 -0
  17. requirements.txt +31 -2
  18. src/__pycache__/knowledge_base_builder.cpython-310.pyc +0 -0
  19. src/__pycache__/knowledge_base_builder.cpython-312.pyc +0 -0
  20. src/__pycache__/simplified_knowledge_base.cpython-310.pyc +0 -0
  21. src/agent/__pycache__/memory.cpython-311.pyc +0 -0
  22. src/agent/__pycache__/nodes.cpython-310.pyc +0 -0
  23. src/agent/__pycache__/nodes.cpython-311.pyc +0 -0
  24. src/agent/__pycache__/prompts.cpython-311.pyc +0 -0
  25. src/agent/__pycache__/scripts.cpython-310.pyc +0 -0
  26. src/agent/__pycache__/scripts.cpython-311.pyc +0 -0
  27. src/agent/__pycache__/viettelpay_agent.cpython-310.pyc +0 -0
  28. src/agent/__pycache__/viettelpay_agent.cpython-311.pyc +0 -0
  29. src/agent/nodes.py +463 -0
  30. src/agent/prompts.py +125 -0
  31. src/agent/scripts.py +157 -0
  32. src/agent/viettelpay_agent.py +416 -0
  33. src/evaluation/__pycache__/prompts.cpython-311.pyc +0 -0
  34. src/evaluation/__pycache__/single_turn_retrieval.cpython-311.pyc +0 -0
  35. src/evaluation/intent_classification.py +901 -0
  36. src/evaluation/multi_turn_retrieval.py +815 -0
  37. src/evaluation/prompts.py +318 -0
  38. src/evaluation/single_turn_retrieval.py +844 -0
  39. src/knowledge_base/__pycache__/builder.cpython-310.pyc +0 -0
  40. src/knowledge_base/__pycache__/builder.cpython-311.pyc +0 -0
  41. src/knowledge_base/__pycache__/viettel_knowledge_base.cpython-311.pyc +0 -0
  42. src/knowledge_base/viettel_knowledge_base.py +521 -0
  43. src/llm/__pycache__/langchain_models.cpython-311.pyc +0 -0
  44. src/llm/__pycache__/llm_client.cpython-310.pyc +0 -0
  45. src/llm/__pycache__/llm_client.cpython-311.pyc +0 -0
  46. src/llm/llm_client.py +181 -0
  47. src/processor/__pycache__/contextual_word_processor.cpython-311.pyc +0 -0
  48. src/processor/__pycache__/csv_processor.cpython-310.pyc +0 -0
  49. src/processor/__pycache__/csv_processor.cpython-311.pyc +0 -0
  50. src/processor/__pycache__/csv_processor.cpython-312.pyc +0 -0
.env ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # .env file
2
+ GEMINI_API_KEY="AIzaSyAn1HDv1_zU4TQbwwUtBIOo7tu5iKVBWho"
3
+ GOOGLE_API_KEY="AIzaSyAn1HDv1_zU4TQbwwUtBIOo7tu5iKVBWho"
4
+ OPENAI_API_KEY="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"
5
+
6
+ COHERE_API_KEY="D6PBHYizSmWFqzHoMWafV65yJelDh6X3Xg0ghIue"
7
+
8
+ # Production
9
+ # COHERE_API_KEY="VFaacPxkjW0L4HaijiBXuKYWqgYj8XkAo3o5uMWu"
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/data_level0.bin filter=lfs diff=lfs merge=lfs -text
4
+ knowledge_base/chroma/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ secrets.toml
.streamlit/secrets.toml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Streamlit Secrets Configuration
2
+ # Copy your API keys here for production
3
+ # For local development, you can still use .env file
4
+
5
+ [api_keys]
6
+ GEMINI_API_KEY = "AIzaSyAn1HDv1_zU4TQbwwUtBIOo7tu5iKVBWho"
7
+ OPENAI_API_KEY = "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"
8
+ COHERE_API_KEY = "D6PBHYizSmWFqzHoMWafV65yJelDh6X3Xg0ghIue"
9
+
10
+ # Production
11
+ # COHERE_API_KEY="VFaacPxkjW0L4HaijiBXuKYWqgYj8XkAo3o5uMWu"
12
+
13
+ # Database and storage paths
14
+ [paths]
15
+ KNOWLEDGE_BASE_PATH = "./knowledge_base"
16
+ DOCUMENTS_FOLDER = "./viettelpay_docs"
17
+
18
+ # Model configurations
19
+ [models]
20
+ EMBEDDING_MODEL = "dangvantuan/vietnamese-document-embedding"
21
+ LLM_PROVIDER = "gemini"
22
+ GEMINI_MODEL = "gemini-2.0-flash"
23
+ OPENAI_MODEL = "gpt-4o-mini"
evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluation_data/results/intent_classification/viettelpay_intent_results.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluation_data/results/multi_turn_retrieval/viettelpay_multiturn_results.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json ADDED
The diff for this file is too large to render. See raw diff
 
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23add52afbe7588391f32d3deffb581b2663d2e2ad8851aba7de25e6b3f66761
3
+ size 32120000
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/header.bin ADDED
Binary file (100 Bytes). View file
 
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/length.bin ADDED
Binary file (40 kB). View file
 
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/link_lists.bin ADDED
File without changes
knowledge_base/chroma/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b962ed5c081a92e80ccb83ef40c1dfb530542ca161a4cefe9fc6ccebebf23e75
3
+ size 1937408
requirements.txt CHANGED
@@ -1,3 +1,32 @@
1
- altair
2
  pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
  pandas
3
+ python-docx
4
+ langchain-chroma==0.2.4
5
+ chromadb==1.0.12
6
+ langchain_cohere==0.4.4
7
+ langchain==0.3.25
8
+ langgraph
9
+ langchain_community==0.3.24
10
+ google-generativeai
11
+ openai
12
+ sentence-transformers==4.1.0
13
+ pydantic
14
+ python-dotenv
15
+ PyYAML
16
+ tqdm
17
+ loguru
18
+ scikit-learn==1.6.1
19
+ protobuf<3.21
20
+ grpcio-status==1.48.2
21
+ torch
22
+ transformers==4.52.4
23
+ rank_bm25==0.2.2
24
+ markitdown[docx]
25
+ underthesea
26
+ pyvi
27
+ langchain-huggingface==0.2.0
28
+ streamlit
29
+ langmem
30
+ dotenv
31
+ numpy==1.26.4
32
+ python-docx
src/__pycache__/knowledge_base_builder.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
src/__pycache__/knowledge_base_builder.cpython-312.pyc ADDED
Binary file (7.03 kB). View file
 
src/__pycache__/simplified_knowledge_base.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
src/agent/__pycache__/memory.cpython-311.pyc ADDED
Binary file (4.13 kB). View file
 
src/agent/__pycache__/nodes.cpython-310.pyc ADDED
Binary file (8.57 kB). View file
 
src/agent/__pycache__/nodes.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
src/agent/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (8.36 kB). View file
 
src/agent/__pycache__/scripts.cpython-310.pyc ADDED
Binary file (7.02 kB). View file
 
src/agent/__pycache__/scripts.cpython-311.pyc ADDED
Binary file (9.54 kB). View file
 
src/agent/__pycache__/viettelpay_agent.cpython-310.pyc ADDED
Binary file (5.74 kB). View file
 
src/agent/__pycache__/viettelpay_agent.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
src/agent/nodes.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LangGraph Agent State and Processing Nodes
3
+ """
4
+
5
+ from typing import Dict, List, Optional, TypedDict, Annotated
6
+ from langchain.schema import Document
7
+ from langchain_core.messages import AnyMessage
8
+ from langgraph.graph.message import add_messages
9
+ import json
10
+ import re
11
+
12
+ from src.agent.prompts import (
13
+ INTENT_CLASSIFICATION_PROMPT,
14
+ QUERY_ENHANCEMENT_PROMPT,
15
+ RESPONSE_GENERATION_PROMPT,
16
+ get_system_prompt_by_intent,
17
+ )
18
+
19
+
20
+ class ViettelPayState(TypedDict):
21
+ """State for ViettelPay agent workflow with message history support"""
22
+
23
+ # Message history for multi-turn conversation
24
+ messages: Annotated[List[AnyMessage], add_messages]
25
+
26
+ # Processing
27
+ intent: Optional[str]
28
+ confidence: Optional[float]
29
+
30
+ # Query enhancement
31
+ enhanced_query: Optional[str]
32
+
33
+ # Knowledge retrieval
34
+ retrieved_docs: Optional[List[Document]]
35
+
36
+ # Conversation context (cached to avoid repeated computation)
37
+ conversation_context: Optional[str]
38
+
39
+ # Response type metadata
40
+ response_type: Optional[str] # "script" or "generated"
41
+
42
+ # Metadata
43
+ error: Optional[str]
44
+ processing_info: Optional[Dict]
45
+
46
+
47
+ def get_conversation_context(messages: List[AnyMessage], max_messages: int = 3) -> str:
48
+ """
49
+ Extract conversation context from message history
50
+
51
+ Args:
52
+ messages: List of conversation messages
53
+ max_messages: Maximum number of recent messages to include
54
+
55
+ Returns:
56
+ Formatted conversation context string
57
+ """
58
+ if len(messages) <= 1:
59
+ return ""
60
+
61
+ context = "\n\nLịch sử cuộc hội thoại:\n"
62
+ # Get recent messages (excluding the current/last message for intent classification)
63
+ recent_messages = messages[
64
+ -(max_messages + 1) : -1
65
+ ] # Exclude the very last message
66
+
67
+ for msg in recent_messages:
68
+ # Handle different message types more robustly
69
+ if hasattr(msg, "type"):
70
+ if msg.type == "human":
71
+ role = "Người dùng"
72
+ elif msg.type == "ai":
73
+ role = "Trợ lý"
74
+ else:
75
+ role = f"Unknown-{msg.type}"
76
+ elif hasattr(msg, "role"):
77
+ if msg.role in ["user", "human"]:
78
+ role = "Người dùng"
79
+ elif msg.role in ["assistant", "ai"]:
80
+ role = "Trợ lý"
81
+ else:
82
+ role = f"Unknown-{msg.role}"
83
+ else:
84
+ role = "Unknown"
85
+
86
+ # Limit message length to avoid token overflow
87
+ # content = msg.content[:1000] + "..." if len(msg.content) > 1000 else msg.content
88
+ content = msg.content
89
+ context += f"{role}: {content}\n"
90
+ # print(context)
91
+ return context
92
+
93
+
94
+ def classify_intent_node(state: ViettelPayState, llm_client) -> ViettelPayState:
95
+ """Node for intent classification using LLM with conversation context"""
96
+
97
+ # Get the latest user message
98
+ messages = state["messages"]
99
+ if not messages:
100
+ return {
101
+ **state,
102
+ "intent": "unclear",
103
+ "confidence": 0.0,
104
+ "error": "No messages found",
105
+ }
106
+
107
+ # Find the last human/user message
108
+ user_message = None
109
+ for msg in reversed(messages):
110
+ if hasattr(msg, "type") and msg.type == "human":
111
+ user_message = msg.content
112
+ break
113
+ elif hasattr(msg, "role") and msg.role == "user":
114
+ user_message = msg.content
115
+ break
116
+
117
+ if not user_message:
118
+ return {
119
+ **state,
120
+ "intent": "unclear",
121
+ "confidence": 0.0,
122
+ "error": "No user message found",
123
+ }
124
+
125
+ try:
126
+ # Get conversation context for better intent classification
127
+ conversation_context = get_conversation_context(messages)
128
+
129
+ # Intent classification prompt with context using the prompts file
130
+ classification_prompt = INTENT_CLASSIFICATION_PROMPT.format(
131
+ conversation_context=conversation_context, user_message=user_message
132
+ )
133
+
134
+ # Get classification using the pre-initialized LLM client
135
+ response = llm_client.generate(classification_prompt, temperature=0.1)
136
+
137
+ # print(f"🔍 Raw LLM response: {response}")
138
+
139
+ # Parse JSON response
140
+ try:
141
+ # Try to extract JSON from response (in case there's extra text)
142
+ response_clean = response.strip()
143
+
144
+ # Look for JSON object in the response
145
+ json_match = re.search(r"\{.*\}", response_clean, re.DOTALL)
146
+ if json_match:
147
+ json_str = json_match.group()
148
+ result = json.loads(json_str)
149
+ else:
150
+ # Try parsing the whole response
151
+ result = json.loads(response_clean)
152
+
153
+ intent = result.get("intent", "unclear")
154
+ confidence = result.get("confidence", 0.5)
155
+ explanation = result.get("explanation", "")
156
+
157
+ # print(
158
+ # f"✅ JSON parsed successfully: intent={intent}, confidence={confidence}"
159
+ # )
160
+
161
+ except (json.JSONDecodeError, AttributeError) as e:
162
+ print(f"❌ JSON parsing failed: {e}")
163
+ print(f" Raw response: {response}")
164
+
165
+ # Fallback: try to extract intent from text
166
+ response_lower = response.lower()
167
+ if any(
168
+ word in response_lower for word in ["lỗi", "error", "606", "mã lỗi"]
169
+ ):
170
+ intent = "error_help"
171
+ confidence = 0.7
172
+ elif any(word in response_lower for word in ["xin chào", "hello", "chào"]):
173
+ intent = "greeting"
174
+ confidence = 0.8
175
+ elif any(word in response_lower for word in ["hủy", "cancel", "thủ tục"]):
176
+ intent = "procedure_guide"
177
+ confidence = 0.7
178
+ elif any(
179
+ word in response_lower for word in ["nạp", "cước", "dịch vụ", "faq"]
180
+ ):
181
+ intent = "faq"
182
+ confidence = 0.7
183
+ else:
184
+ intent = "unclear"
185
+ confidence = 0.3
186
+
187
+ print(f"🔄 Fallback classification: {intent} (confidence: {confidence})")
188
+ explanation = "Fallback classification due to JSON parse error"
189
+
190
+ # print(f"🎯 Intent classified: {intent} (confidence: {confidence})")
191
+
192
+ return {
193
+ **state,
194
+ "intent": intent,
195
+ "confidence": confidence,
196
+ "conversation_context": conversation_context, # Save context for reuse
197
+ "processing_info": {
198
+ "classification_raw": response,
199
+ "explanation": explanation,
200
+ "context_used": bool(conversation_context.strip()),
201
+ },
202
+ }
203
+
204
+ except Exception as e:
205
+ print(f"❌ Intent classification error: {e}")
206
+ return {**state, "intent": "unclear", "confidence": 0.0, "error": str(e)}
207
+
208
+
209
+ def query_enhancement_node(state: ViettelPayState, llm_client) -> ViettelPayState:
210
+ """Node for enhancing search query using conversation context"""
211
+
212
+ # Get the latest user message
213
+ messages = state["messages"]
214
+ if not messages:
215
+ return {**state, "enhanced_query": "", "error": "No messages found"}
216
+
217
+ # Find the last human/user message
218
+ user_message = None
219
+ for msg in reversed(messages):
220
+ if hasattr(msg, "type") and msg.type == "human":
221
+ user_message = msg.content
222
+ break
223
+ elif hasattr(msg, "role") and msg.role == "user":
224
+ user_message = msg.content
225
+ break
226
+
227
+ if not user_message:
228
+ return {**state, "enhanced_query": "", "error": "No user message found"}
229
+
230
+ try:
231
+ # Use saved conversation context if available, otherwise get it
232
+ conversation_context = state.get("conversation_context")
233
+ if conversation_context is None:
234
+ conversation_context = get_conversation_context(messages)
235
+
236
+ # If no context, use original message
237
+ if not conversation_context.strip():
238
+ print(f"🔍 No context available, using original query: {user_message}")
239
+ return {**state, "enhanced_query": user_message}
240
+
241
+ # Query enhancement prompt using the prompts file
242
+ enhancement_prompt = QUERY_ENHANCEMENT_PROMPT.format(
243
+ conversation_context=conversation_context, user_message=user_message
244
+ )
245
+
246
+ # Get enhanced query
247
+ enhanced_query = llm_client.generate(enhancement_prompt, temperature=0.1)
248
+ enhanced_query = enhanced_query.strip()
249
+
250
+ print(f"🔍 Original query: {user_message}")
251
+ print(f"🚀 Enhanced query: {enhanced_query}")
252
+
253
+ return {**state, "enhanced_query": enhanced_query}
254
+
255
+ except Exception as e:
256
+ print(f"❌ Query enhancement error: {e}")
257
+ # Fallback to original message
258
+ return {**state, "enhanced_query": user_message, "error": str(e)}
259
+
260
+
261
+ def knowledge_retrieval_node(
262
+ state: ViettelPayState, knowledge_retriever
263
+ ) -> ViettelPayState:
264
+ """Node for knowledge retrieval using pre-initialized ViettelKnowledgeBase"""
265
+
266
+ # Use enhanced query if available, otherwise fall back to extracting from messages
267
+ enhanced_query = state.get("enhanced_query", "")
268
+
269
+ if not enhanced_query:
270
+ # Fallback: extract from messages
271
+ messages = state["messages"]
272
+ if not messages:
273
+ return {**state, "retrieved_docs": [], "error": "No messages found"}
274
+
275
+ # Find the last human/user message
276
+ for msg in reversed(messages):
277
+ if hasattr(msg, "type") and msg.type == "human":
278
+ enhanced_query = msg.content
279
+ break
280
+ elif hasattr(msg, "role") and msg.role == "user":
281
+ enhanced_query = msg.content
282
+ break
283
+
284
+ if not enhanced_query:
285
+ return {**state, "retrieved_docs": [], "error": "No query available"}
286
+
287
+ try:
288
+ if not knowledge_retriever:
289
+ raise ValueError("Knowledge retriever not available")
290
+
291
+ # Retrieve relevant documents using enhanced query and pre-initialized ViettelKnowledgeBase
292
+ retrieved_docs = knowledge_retriever.search(enhanced_query, top_k=10)
293
+
294
+ print(
295
+ f"📚 Retrieved {len(retrieved_docs)} documents for enhanced query: {enhanced_query}"
296
+ )
297
+
298
+ return {**state, "retrieved_docs": retrieved_docs}
299
+
300
+ except Exception as e:
301
+ print(f"❌ Knowledge retrieval error: {e}")
302
+ return {**state, "retrieved_docs": [], "error": str(e)}
303
+
304
+
305
+ def script_response_node(state: ViettelPayState) -> ViettelPayState:
306
+ """Node for script-based responses"""
307
+
308
+ from src.agent.scripts import ConversationScripts
309
+ from langchain_core.messages import AIMessage
310
+
311
+ intent = state.get("intent", "")
312
+
313
+ try:
314
+ # Load scripts
315
+ scripts = ConversationScripts("./viettelpay_docs/processed/kich_ban.csv")
316
+
317
+ # Map intents to script types
318
+ intent_to_script = {
319
+ "greeting": "greeting",
320
+ "out_of_scope": "out_of_scope",
321
+ "human_request": "human_request_attempt_1", # Could be enhanced later
322
+ "unclear": "ask_for_clarity",
323
+ }
324
+
325
+ script_type = intent_to_script.get(intent)
326
+
327
+ if script_type and scripts.has_script(script_type):
328
+ response_text = scripts.get_script(script_type)
329
+ print(f"📋 Using script response: {script_type}")
330
+
331
+ # Add AI message to the conversation
332
+ ai_message = AIMessage(content=response_text)
333
+
334
+ return {**state, "messages": [ai_message], "response_type": "script"}
335
+
336
+ else:
337
+ # Fallback script
338
+ fallback_response = (
339
+ "Xin lỗi, em chưa hiểu rõ yêu cầu của anh/chị. Vui lòng thử lại."
340
+ )
341
+ ai_message = AIMessage(content=fallback_response)
342
+
343
+ print(f"📋 Using fallback script for intent: {intent}")
344
+
345
+ return {**state, "messages": [ai_message], "response_type": "script"}
346
+
347
+ except Exception as e:
348
+ print(f"❌ Script response error: {e}")
349
+ fallback_response = "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau."
350
+ ai_message = AIMessage(content=fallback_response)
351
+
352
+ return {
353
+ **state,
354
+ "messages": [ai_message],
355
+ "response_type": "error",
356
+ "error": str(e),
357
+ }
358
+
359
+
360
+ def generate_response_node(state: ViettelPayState, llm_client) -> ViettelPayState:
361
+ """Node for LLM-based response generation with conversation context"""
362
+
363
+ from langchain_core.messages import AIMessage
364
+
365
+ # Get the latest user message and conversation history
366
+ messages = state["messages"]
367
+ if not messages:
368
+ ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
369
+ return {**state, "messages": [ai_message], "response_type": "error"}
370
+
371
+ # Find the last human/user message
372
+ user_message = None
373
+ for msg in reversed(messages):
374
+ if hasattr(msg, "type") and msg.type == "human":
375
+ user_message = msg.content
376
+ break
377
+ elif hasattr(msg, "role") and msg.role == "user":
378
+ user_message = msg.content
379
+ break
380
+
381
+ if not user_message:
382
+ ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
383
+ return {**state, "messages": [ai_message], "response_type": "error"}
384
+
385
+ intent = state.get("intent", "")
386
+ retrieved_docs = state.get("retrieved_docs", [])
387
+ enhanced_query = state.get("enhanced_query", "")
388
+
389
+ try:
390
+ # Build context from retrieved documents using original content
391
+ context = ""
392
+ if retrieved_docs:
393
+ context = "\n\n".join(
394
+ [
395
+ f"[{doc.metadata.get('doc_type', 'unknown')}] {doc.metadata.get('original_content', doc.page_content)}"
396
+ for doc in retrieved_docs
397
+ ]
398
+ )
399
+
400
+ # Use saved conversation context if available, otherwise get it
401
+ conversation_context = state.get("conversation_context")
402
+ if conversation_context is None:
403
+ conversation_context = get_conversation_context(messages, max_messages=6)
404
+
405
+ # Get system prompt based on intent using the prompts file
406
+ system_prompt = get_system_prompt_by_intent(intent)
407
+
408
+ # Build full prompt with both knowledge context and conversation context using the prompts file
409
+ generation_prompt = RESPONSE_GENERATION_PROMPT.format(
410
+ system_prompt=system_prompt,
411
+ context=context,
412
+ conversation_context=conversation_context,
413
+ user_message=user_message,
414
+ enhanced_query=enhanced_query,
415
+ )
416
+
417
+ # Generate response using the pre-initialized LLM client
418
+ response_text = llm_client.generate(generation_prompt, temperature=0.1)
419
+
420
+ print(f"🤖 Generated response for intent: {intent}")
421
+
422
+ # Add AI message to the conversation
423
+ ai_message = AIMessage(content=response_text)
424
+
425
+ return {**state, "messages": [ai_message], "response_type": "generated"}
426
+
427
+ except Exception as e:
428
+ print(f"❌ Response generation error: {e}")
429
+ error_response = "Xin lỗi, em gặp lỗi khi xử lý yêu cầu. Vui lòng thử lại sau."
430
+ ai_message = AIMessage(content=error_response)
431
+
432
+ return {
433
+ **state,
434
+ "messages": [ai_message],
435
+ "response_type": "error",
436
+ "error": str(e),
437
+ }
438
+
439
+
440
+ # Routing function for conditional edges
441
+ def route_after_intent_classification(state: ViettelPayState) -> str:
442
+ """Route to appropriate node after intent classification"""
443
+
444
+ intent = state.get("intent", "unclear")
445
+
446
+ # Script-based intents (no knowledge retrieval needed)
447
+ script_intents = {"greeting", "out_of_scope", "human_request", "unclear"}
448
+
449
+ if intent in script_intents:
450
+ return "script_response"
451
+ else:
452
+ # Knowledge-based intents need query enhancement first
453
+ return "query_enhancement"
454
+
455
+
456
+ def route_after_query_enhancement(state: ViettelPayState) -> str:
457
+ """Route after query enhancement (always to knowledge retrieval)"""
458
+ return "knowledge_retrieval"
459
+
460
+
461
+ def route_after_knowledge_retrieval(state: ViettelPayState) -> str:
462
+ """Route after knowledge retrieval (always to generation)"""
463
+ return "generate_response"
src/agent/prompts.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates for ViettelPay AI Agent
3
+ All prompts using Vietnamese language for ViettelPay Pro customer support
4
+ """
5
+
6
+ # Intent Classification Prompt (JSON format for better parsing)
7
+ INTENT_CLASSIFICATION_PROMPT = """
8
+ Bạn là hệ thống phân loại ý định cho ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
9
+ Phân tích tin nhắn của người dùng và trả về ý định chính.
10
+
11
+ Các loại ý định:
12
+ * **`greeting`**: Chỉ là lời chào hỏi thuần túy, không có câu hỏi hoặc yêu cầu cụ thể nào khác. Nếu tin nhắn có cả lời chào VÀ câu hỏi thì phân loại theo các ý định khác, không phải greeting.
13
+ * *Ví dụ:* "chào em", "hello shop", "xin chào ạ"
14
+ * *Không phải greeting:* "xin chào, cho hỏi về lỗi 606" → đây là error_help
15
+ * **`faq`**: Các câu hỏi đáp chung, tìm hiểu về dịch vụ, tính năng, v.v.
16
+ * *Ví dụ:* "App có bán thẻ game không?", "ViettelPay Pro nạp tiền được cho mạng nao?"
17
+ * **`error_help`**: Báo cáo sự cố, hỏi về mã lỗi cụ thể.
18
+ * *Ví dụ:* "Giao dịch báo lỗi 606", "tại sao tôi không thanh toán được?", "lỗi này là gì?"
19
+ * **`procedure_guide`**: Hỏi về các bước cụ thể để thực hiện một tác vụ.
20
+ * *Ví dụ:* "làm thế nào để hủy giao dịch?", "chỉ tôi cách lấy lại mã thẻ cào", "hướng dẫn nạp cước"
21
+ * **`human_request`**: Yêu cầu được nói chuyện trực tiếp với nhân viên hỗ trợ.
22
+ * *Ví dụ:* "cho tôi gặp người thật", "nối máy cho tổng đài", "em k hiểu, cho gặp ai đó"
23
+ * **`out_of_scope`**: Câu hỏi ngoài phạm vi ViettelPay (thời tiết, chính trị, v.v.), không liên quan gì đến các dịch vụ tài chính, viễn thông của Viettel.
24
+ * *Ví dụ:* "dự báo thời tiết hôm nay?", "giá xăng bao nhiêu?", "cách nấu phở"
25
+ * **`unclear`**: Câu hỏi không rõ ràng, thiếu thông tin cụ thể, cần người dùng bổ sung thêm chi tiết để có thể hỗ trợ hiệu quả.
26
+ * *Ví dụ:* "lỗi", "giúp với", "gd", "???", "ko hiểu", "bị lỗi giờ sao đây", "không thực hiện được", "sao vậy", "tại sao thế"
27
+
28
+ **Bối cảnh cuộc trò chuyện:**
29
+ <conversation_context>
30
+ {conversation_context}
31
+ </conversation_context>
32
+
33
+ **Tin nhắn mới của người dùng:**
34
+ <user_message>
35
+ {user_message}
36
+ </user_message>
37
+
38
+ Hãy phân tích dựa trên cả ngữ cảnh cuộc hội thoại và tin nhắn mới nhất của người dùng.
39
+
40
+ QUAN TRỌNG: Chỉ trả về JSON thuần túy, không có text khác. Format chính xác:
41
+ {{"intent": "tên_ý_định", "confidence": 0.9, "explanation": "lý do ngắn gọn"}}
42
+ """
43
+
44
+ # Query Enhancement Prompt for contextual search improvement
45
+ QUERY_ENHANCEMENT_PROMPT = """
46
+ **Nhiệm vụ:** Bạn là một trợ lý chuyên gia của ViettelPay Pro.
47
+ ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
48
+ Nhiệm vụ của bạn là đọc cuộc trò chuyện và tin nhắn mới nhất của người dùng để tạo ra một truy vấn tìm kiếm (search query) duy nhất, tối ưu cho cơ sở dữ liệu nội bộ.
49
+
50
+ **Bối cảnh cuộc trò chuyện:**
51
+ <conversation_context>
52
+ {conversation_context}
53
+ </conversation_context>
54
+
55
+ **Tin nhắn mới của người dùng:**
56
+ <user_message>
57
+ {user_message}
58
+ </user_message
59
+
60
+ **Quy tắc tạo truy vấn:**
61
+ 1. **Kết hợp Ngữ cảnh:** Phân tích toàn bộ cuộc trò chuyện để tạo ra một truy vấn đầy đủ ý nghĩa, nắm bắt được mục tiêu thực sự của người dùng.
62
+ 2. **Làm rõ & Cụ thể:** Thay thế các đại từ (ví dụ: "nó", "cái đó") bằng các chủ thể hoặc thuật ngữ cụ thể đã được đề cập (ví dụ: "liên kết ngân hàng", "rút tiền tại ATM", "mã lỗi 101").
63
+ 3. **Tích hợp Thuật ngữ:** Tích hợp một cách **tự nhiên** các từ khóa và thuật ngữ chuyên ngành của ViettelPay Pro (ví dụ: "giao dịch", "nạp cước", "chiết khấu", "OTP", "hoa hồng").
64
+ 4. **Duy trì Tính tự nhiên (QUAN TRỌNG):** Truy vấn phải là một câu hỏi hoặc một cụm từ hoàn chỉnh, tự nhiên bằng tiếng Việt. **Tránh tạo ra danh sách từ khóa rời rạc.**
65
+ * **Tốt:** "cách tính hoa hồng khi nạp thẻ điện thoại cho khách"
66
+ * **Không tốt:** "hoa hồng nạp thẻ điện thoại"
67
+ 5. **Giữ lại Ý định Gốc:** Truy vấn phải phản ánh chính xác câu hỏi của người dùng, không thêm thông tin hoặc suy diễn.
68
+ 6. Thêm vài câu sử dụng từ đồng nghĩa và cách diễn đạt khác nhau trong tiếng Việt để tăng khả năng tìm kiếm
69
+
70
+ **ĐẦU RA:** CHỈ trả về một chuỗi truy vấn tìm kiếm đã được cải thiện. Không thêm lời giải thích.
71
+ """
72
+
73
+ # System Prompt for Error Help responses
74
+ ERROR_HELP_SYSTEM_PROMPT = """Bạn là chuyên gia hỗ trợ kỹ thuật ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
75
+ Thể hiện sự cảm thông với khó khăn của người dùng.
76
+ Cung cấp giải pháp cụ thể, từng bước.
77
+ Nếu cần hỗ trợ thêm, hướng dẫn liên hệ tổng đài.
78
+ Nếu có lịch sử cuộc hội thoại, hãy tham khảo để đưa ra câu trả lời phù hợp và có tính liên kết."""
79
+
80
+ # System Prompt for Procedure Guide responses
81
+ PROCEDURE_GUIDE_SYSTEM_PROMPT = """Bạn là hướng dẫn viên ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
82
+ Cung cấp hướng dẫn từng bước rõ ràng.
83
+ Bao gồm link video nếu có trong thông tin.
84
+ Sử dụng format có số thứ tự cho các bước.
85
+ Nếu có lịch sử cuộc hội thoại, hãy tham khảo để đưa ra câu trả lời phù hợp và có tính liên kết."""
86
+
87
+ # Default System Prompt for general responses
88
+ DEFAULT_SYSTEM_PROMPT = """Bạn là trợ lý ảo ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
89
+ Trả lời câu hỏi dựa trên thông tin được cung cấp.
90
+ Giọng điệu thân thiện, chuyên nghiệp.
91
+ Sử dụng "Anh/chị" khi xưng hô.
92
+ Nếu có lịch sử cuộc hội thoại, hãy tham khảo để đưa ra câu trả lời phù hợp và có tính liên kết."""
93
+
94
+ # Response Generation Template with context and knowledge base integration
95
+ RESPONSE_GENERATION_PROMPT = """<system_prompt>
96
+ {system_prompt}
97
+ </system_prompt>
98
+
99
+ **Thông tin tham khảo từ cơ sở tri thức:**
100
+ <knowledge_base_context>
101
+ {context}
102
+ </knowledge_base_context>
103
+
104
+ **Bối cảnh cuộc trò chuyện:**
105
+ <conversation_context>
106
+ {conversation_context}
107
+ </conversation_context>
108
+
109
+ **Tin nhắn mới của người dùng:**
110
+ <user_message>
111
+ {user_message}
112
+ </user_message>
113
+
114
+ Hãy trả lời câu hỏi dựa trên thông tin tham khảo và lịch sử cuộc hội thoại (nếu có). Nếu không có thông tin phù hợp, hãy nói rằng bạn cần thêm thông tin.
115
+ """
116
+
117
+
118
+ def get_system_prompt_by_intent(intent: str) -> str:
119
+ """Get appropriate system prompt based on intent classification"""
120
+ if intent == "error_help":
121
+ return ERROR_HELP_SYSTEM_PROMPT
122
+ elif intent == "procedure_guide":
123
+ return PROCEDURE_GUIDE_SYSTEM_PROMPT
124
+ else: # faq, etc.
125
+ return DEFAULT_SYSTEM_PROMPT
src/agent/scripts.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation Scripts Handler
3
+ Manages predefined scripts for standard conversation scenarios
4
+ """
5
+
6
+ import os
7
+ import pandas as pd
8
+ from typing import Dict, Optional
9
+
10
+
11
+ class ConversationScripts:
12
+ """Handler for conversation scripts and standard responses"""
13
+
14
+ def __init__(self, scripts_file: Optional[str] = None):
15
+ self.scripts = {}
16
+ self.scripts_file = scripts_file
17
+
18
+ # Default built-in scripts (fallback)
19
+ self._load_default_scripts()
20
+
21
+ # Load from file if provided
22
+ if scripts_file and os.path.exists(scripts_file):
23
+ self._load_scripts_from_file(scripts_file)
24
+ print(f"✅ Loaded conversation scripts from {scripts_file}")
25
+ else:
26
+ print("⚠️ Using default built-in scripts")
27
+
28
+ def _load_default_scripts(self):
29
+ """Load default conversation scripts"""
30
+ self.scripts = {
31
+ "greeting": """Xin chào! Em là Trợ lý ảo Viettelpay Pro sẵn sàng hỗ trợ Anh/chị!
32
+ Hiện tại, Trợ lý ảo đang trong giai đoạn thử nghiệm hỗ trợ nghiệp vụ cước viễn thông, thẻ cào và thẻ game với các nội dung sau:
33
+ - Hướng dẫn sử dụng
34
+ - Chính sách phí bán hàng
35
+ - Tìm hiểu quy định hủy giao dịch
36
+ - Hướng dẫn xử lý một số lỗi thường gặp.
37
+
38
+ Anh/Chị vui lòng bấm vào từng chủ đề để xem chi tiết.
39
+ Nếu thông tin chưa đáp ứng nhu cầu, Anh/Chị hãy đặt lại câu hỏi để em tiếp tục hỗ trợ ạ!""",
40
+ "out_of_scope": """Cảm ơn Anh/chị đã đặt câu hỏi!
41
+ Trợ lý ảo Viettelpay Pro đang thử nghiệm và cập nhật kiến thức nghiệp vụ để hỗ trợ Anh/chị tốt hơn. Vì vậy, rất tiếc nhu cầu hiện tại của Anh/chị nằm ngoài khả năng hỗ trợ của Trợ lý ảo.
42
+
43
+ Để được hỗ trợ chính xác và đầy đủ hơn, Anh/chị vui lòng gửi yêu cầu hỗ trợ tại đây""",
44
+ "human_request_attempt_1": """Anh/Chị vui lòng chia sẻ thêm nội dung cần hỗ trợ, Em rất mong được giải đáp trực tiếp để tiết kiệm thời gian của Anh/Chị ạ!""",
45
+ "human_request_attempt_2": """Rất tiếc! Hiện tại hệ thống chưa có Tư vấn viên hỗ trợ trực tuyến.
46
+ Tuy nhiên, Anh/chị vẫn có thể yêu cầu hỗ trợ được trợ giúp qua các hình thức sau:
47
+ 📌 1. Đặt câu hỏi ngay tại đây, Trợ lý ảo ViettelPay Pro luôn sẵn sàng hỗ trợ Anh/chị trong phạm vi nghiệp vụ thử nghiệm (nghiệp vụ cước viễn thông, thẻ cào, thẻ game):
48
+ ✅ Hướng dẫn sử dụng
49
+ ✅ Chính sách phí bán hàng
50
+ ✅ Tìm hiểu về quy định hủy giao dịch
51
+ ✅ Hướng dẫn xử lý một số lỗi thường gặp.
52
+ 📌 2. Tìm hiểu thông tin nghiệp vụ tại mục:
53
+ 📚 Hướng dẫn, hỗ trợ: Các video hướng dẫn nghiệp vụ
54
+ 💡Thông báo: Các tin tức nghiệp vụ và tin nâng cấp hệ thống/tin sự cố.
55
+ 📌 3. Gửi yêu cầu hỗ trợ tại đây
56
+ Hoặc gọi Tổng đài 1789 nhánh 5 trong trường hợp khẩn cấp""",
57
+ "confirmation_check": """Anh/Chị có thắc mắc thêm vấn đề nào liên quan đến nội dung em vừa cung cấp không ạ?""",
58
+ "closing": """Hy vọng những thông tin vừa rồi đã đáp ứng nhu cầu của Anh/chị.
59
+ Nếu cần hỗ trợ thêm, Anh/Chị hãy đặt câu hỏi để em tiếp tục hỗ trợ ạ!
60
+ 🌟 Chúc Anh/chị một ngày thật vui và thành công!""",
61
+ "ask_for_clarity": """Em chưa hiểu rõ yêu cầu của Anh/chị. Anh/chị có thể chia sẻ cụ thể hơn được không ạ?""",
62
+ "empathy_error": """Em hiểu Anh/chị đang gặp khó khăn với lỗi này. Để hỗ trợ Anh/chị tốt nhất, em sẽ tìm hiểu và đưa ra hướng giải quyết cụ thể.""",
63
+ }
64
+
65
+ def _load_scripts_from_file(self, file_path: str):
66
+ """Load scripts from CSV file (kich_ban.csv format)"""
67
+ try:
68
+ df = pd.read_csv(file_path)
69
+
70
+ # Map CSV scenarios to script keys
71
+ scenario_mapping = {
72
+ "Chào hỏi": "greeting",
73
+ "Trao đổi thông tin chính": "out_of_scope", # First occurrence
74
+ "Trước khi kết thúc phiên": "confirmation_check",
75
+ "Kết thúc": "closing",
76
+ }
77
+
78
+ for _, row in df.iterrows():
79
+ scenario_type = row.get("Loại kịch bản", "")
80
+ situation = row.get("Tình huống", "")
81
+ script = row.get("Kịch bản chốt", "")
82
+
83
+ # Handle specific mappings
84
+ if scenario_type == "Chào hỏi":
85
+ self.scripts["greeting"] = script
86
+ elif scenario_type == "Trao đổi thông tin chính":
87
+ if "ngoài nghiệp vụ" in situation:
88
+ self.scripts["out_of_scope"] = script
89
+ elif "gặp tư vấn viên" in situation:
90
+ if "Lần 1" in script:
91
+ self.scripts["human_request_attempt_1"] = (
92
+ script.split("Lần 1:")[1].split("Lần 2:")[0].strip()
93
+ )
94
+ if "Lần 2" in script:
95
+ self.scripts["human_request_attempt_2"] = script.split(
96
+ "Lần 2:"
97
+ )[1].strip()
98
+ elif "không đủ ý" in situation:
99
+ self.scripts["ask_for_clarity"] = (
100
+ "Em chưa hiểu rõ yêu cầu của Anh/chị. Anh/chị có thể chia sẻ cụ thể hơn được không ạ?"
101
+ )
102
+ elif "lỗi" in situation:
103
+ self.scripts["empathy_error"] = (
104
+ "Em hiểu Anh/chị đang gặp khó khăn với lỗi này."
105
+ )
106
+ elif scenario_type == "Trước khi kết thúc phiên":
107
+ self.scripts["confirmation_check"] = script
108
+ elif scenario_type == "Kết thúc":
109
+ self.scripts["closing"] = script
110
+
111
+ except Exception as e:
112
+ print(f"⚠️ Error loading scripts from file: {e}")
113
+ print("Using default scripts instead")
114
+
115
+ def get_script(self, script_type: str) -> Optional[str]:
116
+ """Get script by type"""
117
+ return self.scripts.get(script_type)
118
+
119
+ def has_script(self, script_type: str) -> bool:
120
+ """Check if script exists"""
121
+ return script_type in self.scripts
122
+
123
+ def get_all_script_types(self) -> list:
124
+ """Get all available script types"""
125
+ return list(self.scripts.keys())
126
+
127
+ def add_script(self, script_type: str, script_content: str):
128
+ """Add or update a script"""
129
+ self.scripts[script_type] = script_content
130
+
131
+ def get_stats(self) -> dict:
132
+ """Get statistics about loaded scripts"""
133
+ return {
134
+ "total_scripts": len(self.scripts),
135
+ "script_types": list(self.scripts.keys()),
136
+ "source": "file" if self.scripts_file else "default",
137
+ }
138
+
139
+
140
+ # Usage example
141
+ if __name__ == "__main__":
142
+ # Test with default scripts
143
+ scripts = ConversationScripts()
144
+ print("📊 Scripts Stats:", scripts.get_stats())
145
+
146
+ # Test specific scripts
147
+ greeting = scripts.get_script("greeting")
148
+ print(f"\n👋 Greeting Script:\n{greeting}")
149
+
150
+ # Test loading from file (if available)
151
+ try:
152
+ scripts_with_file = ConversationScripts(
153
+ "./viettelpay_docs/processed/kich_ban.csv"
154
+ )
155
+ print(f"\n📊 File-based Scripts Stats: {scripts_with_file.get_stats()}")
156
+ except Exception as e:
157
+ print(f"File loading test failed: {e}")
src/agent/viettelpay_agent.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ViettelPay AI Agent using LangGraph
3
+ Multi-turn conversation support with short-term memory using InMemorySaver
4
+ """
5
+
6
+ import os
7
+ from typing import Dict, Optional
8
+ from functools import partial
9
+ from langgraph.graph import StateGraph, END
10
+ from langgraph.checkpoint.memory import InMemorySaver
11
+ from langchain_core.messages import HumanMessage
12
+
13
+ from src.agent.nodes import (
14
+ ViettelPayState,
15
+ classify_intent_node,
16
+ query_enhancement_node,
17
+ knowledge_retrieval_node,
18
+ script_response_node,
19
+ generate_response_node,
20
+ route_after_intent_classification,
21
+ route_after_query_enhancement,
22
+ route_after_knowledge_retrieval,
23
+ )
24
+
25
+ # Import configuration utility
26
+ from src.utils.config import get_knowledge_base_path, get_llm_provider
27
+
28
+
29
+ class ViettelPayAgent:
30
+ """Main ViettelPay AI Agent using LangGraph workflow with multi-turn conversation support"""
31
+
32
+ def __init__(
33
+ self,
34
+ knowledge_base_path: str = None,
35
+ scripts_file: Optional[str] = None,
36
+ llm_provider: str = None,
37
+ ):
38
+ knowledge_base_path = knowledge_base_path or get_knowledge_base_path()
39
+ scripts_file = scripts_file or "./viettelpay_docs/processed/kich_ban.csv"
40
+ llm_provider = llm_provider or get_llm_provider()
41
+
42
+ self.knowledge_base_path = knowledge_base_path
43
+ self.scripts_file = scripts_file
44
+ self.llm_provider = llm_provider
45
+
46
+ # Initialize LLM client once during agent creation
47
+ print(f"🧠 Initializing LLM client ({self.llm_provider})...")
48
+ from src.llm.llm_client import LLMClientFactory
49
+
50
+ self.llm_client = LLMClientFactory.create_client(self.llm_provider)
51
+ print(f"✅ LLM client initialized and ready")
52
+
53
+ # Initialize knowledge retriever once during agent creation
54
+ print(f"📚 Initializing knowledge retriever...")
55
+ try:
56
+ from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
57
+
58
+ self.knowledge_base = ViettelKnowledgeBase()
59
+ ensemble_retriever = self.knowledge_base.load_knowledge_base(
60
+ knowledge_base_path
61
+ )
62
+ if not ensemble_retriever:
63
+ raise ValueError(
64
+ f"Knowledge base not found at {knowledge_base_path}. Run build_database_script.py first."
65
+ )
66
+ print(f"✅ Knowledge retriever initialized and ready")
67
+ except Exception as e:
68
+ print(f"⚠️ Knowledge retriever initialization failed: {e}")
69
+ self.knowledge_base = None
70
+
71
+ # Initialize checkpointer for short-term memory
72
+ self.checkpointer = InMemorySaver()
73
+
74
+ # Build workflow with pre-initialized components
75
+ self.workflow = self._build_workflow()
76
+ self.app = self.workflow.compile(checkpointer=self.checkpointer)
77
+
78
+ print("✅ ViettelPay Agent initialized with multi-turn conversation support")
79
+
80
+ def _build_workflow(self) -> StateGraph:
81
+ """Build LangGraph workflow with pre-initialized components"""
82
+
83
+ # Create workflow graph
84
+ workflow = StateGraph(ViettelPayState)
85
+
86
+ # Create node functions with pre-bound components using functools.partial
87
+ # This eliminates the need to initialize components in each node call
88
+ classify_intent_with_llm = partial(
89
+ classify_intent_node, llm_client=self.llm_client
90
+ )
91
+ query_enhancement_with_llm = partial(
92
+ query_enhancement_node, llm_client=self.llm_client
93
+ )
94
+ knowledge_retrieval_with_retriever = partial(
95
+ knowledge_retrieval_node, knowledge_retriever=self.knowledge_base
96
+ )
97
+ generate_response_with_llm = partial(
98
+ generate_response_node, llm_client=self.llm_client
99
+ )
100
+
101
+ # Add nodes (some with pre-bound components, some without)
102
+ workflow.add_node("classify_intent", classify_intent_with_llm)
103
+ workflow.add_node("query_enhancement", query_enhancement_with_llm)
104
+ workflow.add_node("knowledge_retrieval", knowledge_retrieval_with_retriever)
105
+ workflow.add_node(
106
+ "script_response", script_response_node
107
+ ) # No pre-bound components needed
108
+ workflow.add_node("generate_response", generate_response_with_llm)
109
+
110
+ # Set entry point
111
+ workflow.set_entry_point("classify_intent")
112
+
113
+ # Add conditional routing after intent classification
114
+ workflow.add_conditional_edges(
115
+ "classify_intent",
116
+ route_after_intent_classification,
117
+ {
118
+ "script_response": "script_response",
119
+ "query_enhancement": "query_enhancement",
120
+ },
121
+ )
122
+
123
+ # Script responses go directly to end
124
+ workflow.add_edge("script_response", END)
125
+
126
+ # Query enhancement goes to knowledge retrieval
127
+ workflow.add_edge("query_enhancement", "knowledge_retrieval")
128
+
129
+ # Knowledge retrieval goes to response generation
130
+ workflow.add_edge("knowledge_retrieval", "generate_response")
131
+ workflow.add_edge("generate_response", END)
132
+
133
+ print("🔄 LangGraph workflow built successfully with optimized component usage")
134
+ return workflow
135
+
136
+ def process_message(self, user_message: str, thread_id: str = "default") -> Dict:
137
+ """Process a user message in a multi-turn conversation"""
138
+
139
+ print(f"\n💬 Processing message: '{user_message}' (thread: {thread_id})")
140
+ print("=" * 50)
141
+
142
+ # Create configuration with thread_id for conversation memory
143
+ config = {"configurable": {"thread_id": thread_id}}
144
+
145
+ try:
146
+ # Create human message
147
+ human_message = HumanMessage(content=user_message)
148
+
149
+ # Initialize state with the new message
150
+ # Note: conversation_context is set to None so it gets recomputed with fresh message history
151
+ initial_state = {
152
+ "messages": [human_message],
153
+ "intent": None,
154
+ "confidence": None,
155
+ "enhanced_query": None,
156
+ "retrieved_docs": None,
157
+ "conversation_context": None, # Reset to ensure fresh context computation
158
+ "response_type": None,
159
+ "error": None,
160
+ "processing_info": None,
161
+ }
162
+
163
+ # Run workflow with memory
164
+ result = self.app.invoke(initial_state, config)
165
+
166
+ # Extract response from the last AI message
167
+ messages = result.get("messages", [])
168
+ if messages:
169
+ # Get the last AI message
170
+ last_message = messages[-1]
171
+ if hasattr(last_message, "content"):
172
+ response = last_message.content
173
+ else:
174
+ response = str(last_message)
175
+ else:
176
+ response = "Xin lỗi, em không thể xử lý yêu cầu này."
177
+
178
+ response_type = result.get("response_type", "unknown")
179
+ intent = result.get("intent", "unknown")
180
+ confidence = result.get("confidence", 0.0)
181
+ enhanced_query = result.get("enhanced_query", "")
182
+ error = result.get("error")
183
+
184
+ # Build response info
185
+ response_info = {
186
+ "response": response,
187
+ "intent": intent,
188
+ "confidence": confidence,
189
+ "response_type": response_type,
190
+ "enhanced_query": enhanced_query,
191
+ "success": error is None,
192
+ "error": error,
193
+ "thread_id": thread_id,
194
+ "message_count": len(messages),
195
+ }
196
+
197
+ print(f"✅ Response generated successfully")
198
+ print(f" Intent: {intent} (confidence: {confidence})")
199
+ print(f" Type: {response_type}")
200
+ if enhanced_query and enhanced_query != user_message:
201
+ print(f" Enhanced query: {enhanced_query}")
202
+ print(f" Thread: {thread_id}")
203
+
204
+ return response_info
205
+
206
+ except Exception as e:
207
+ print(f"❌ Workflow error: {e}")
208
+
209
+ return {
210
+ "response": "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau.",
211
+ "intent": "error",
212
+ "confidence": 0.0,
213
+ "response_type": "error",
214
+ "enhanced_query": "",
215
+ "success": False,
216
+ "error": str(e),
217
+ "thread_id": thread_id,
218
+ "message_count": 0,
219
+ }
220
+
221
+ def chat(self, user_message: str, thread_id: str = "default") -> str:
222
+ """Simple chat interface - returns just the response text"""
223
+ result = self.process_message(user_message, thread_id)
224
+ return result["response"]
225
+
226
+ def get_conversation_history(self, thread_id: str = "default") -> list:
227
+ """Get conversation history for a specific thread"""
228
+ try:
229
+ config = {"configurable": {"thread_id": thread_id}}
230
+
231
+ # Get the current state to access message history
232
+ current_state = self.app.get_state(config)
233
+
234
+ if current_state and current_state.values.get("messages"):
235
+ messages = current_state.values["messages"]
236
+ history = []
237
+
238
+ for msg in messages:
239
+ if hasattr(msg, "type") and hasattr(msg, "content"):
240
+ role = "user" if msg.type == "human" else "assistant"
241
+ history.append({"role": role, "content": msg.content})
242
+ elif hasattr(msg, "role") and hasattr(msg, "content"):
243
+ history.append({"role": msg.role, "content": msg.content})
244
+
245
+ return history
246
+ else:
247
+ return []
248
+
249
+ except Exception as e:
250
+ print(f"❌ Error getting conversation history: {e}")
251
+ return []
252
+
253
+ def clear_conversation(self, thread_id: str = "default") -> bool:
254
+ """Clear conversation history for a specific thread"""
255
+ try:
256
+ # Note: InMemorySaver doesn't have a direct clear method
257
+ # The conversation will be cleared when the app is restarted
258
+ # For persistent memory, you'd need to implement a clear method
259
+ print(f"📝 Conversation clearing requested for thread: {thread_id}")
260
+ print(" Note: InMemorySaver conversations clear on app restart")
261
+ return True
262
+ except Exception as e:
263
+ print(f"❌ Error clearing conversation: {e}")
264
+ return False
265
+
266
+ def get_workflow_info(self) -> Dict:
267
+ """Get information about the workflow structure"""
268
+ return {
269
+ "nodes": [
270
+ "classify_intent",
271
+ "query_enhancement",
272
+ "knowledge_retrieval",
273
+ "script_response",
274
+ "generate_response",
275
+ ],
276
+ "entry_point": "classify_intent",
277
+ "knowledge_base_path": self.knowledge_base_path,
278
+ "scripts_file": self.scripts_file,
279
+ "llm_provider": self.llm_provider,
280
+ "memory_type": "InMemorySaver",
281
+ "multi_turn": True,
282
+ "query_enhancement": True,
283
+ "optimizations": {
284
+ "llm_client": "Single initialization with functools.partial",
285
+ "knowledge_retriever": "Single initialization with functools.partial",
286
+ "conversation_context": "Cached in state to avoid repeated computation",
287
+ },
288
+ }
289
+
290
+ def health_check(self) -> Dict:
291
+ """Check if all components are working"""
292
+
293
+ health_status = {
294
+ "agent": True,
295
+ "workflow": True,
296
+ "memory": True,
297
+ "llm": False,
298
+ "knowledge_base": False,
299
+ "scripts": False,
300
+ "overall": False,
301
+ }
302
+
303
+ try:
304
+ # Test LLM client (already initialized)
305
+ test_response = self.llm_client.generate("Hello", temperature=0.1)
306
+ health_status["llm"] = bool(test_response)
307
+ print("✅ LLM client working")
308
+
309
+ except Exception as e:
310
+ print(f"⚠️ LLM health check failed: {e}")
311
+ health_status["llm"] = False
312
+
313
+ try:
314
+ # Test memory/checkpointer
315
+ test_config = {"configurable": {"thread_id": "health_check"}}
316
+ test_state = {"messages": [HumanMessage(content="test")]}
317
+
318
+ # Try to invoke with memory
319
+ self.app.invoke(test_state, test_config)
320
+ health_status["memory"] = True
321
+ print("✅ Memory/checkpointer working")
322
+
323
+ except Exception as e:
324
+ print(f"⚠️ Memory health check failed: {e}")
325
+ health_status["memory"] = False
326
+
327
+ try:
328
+ # Test knowledge base (using pre-initialized retriever)
329
+ if self.knowledge_base:
330
+ # Test a simple search to verify it's working
331
+ test_docs = self.knowledge_base.search("test", top_k=1)
332
+ health_status["knowledge_base"] = True
333
+ print("✅ Knowledge retriever working")
334
+ else:
335
+ health_status["knowledge_base"] = False
336
+ print("❌ Knowledge retriever not initialized")
337
+
338
+ except Exception as e:
339
+ print(f"⚠️ Knowledge base health check failed: {e}")
340
+ health_status["knowledge_base"] = False
341
+
342
+ try:
343
+ # Test scripts
344
+ from src.agent.scripts import ConversationScripts
345
+
346
+ scripts = ConversationScripts(self.scripts_file)
347
+ health_status["scripts"] = len(scripts.get_all_script_types()) > 0
348
+
349
+ except Exception as e:
350
+ print(f"⚠️ Scripts health check failed: {e}")
351
+
352
+ # Overall health
353
+ health_status["overall"] = all(
354
+ [
355
+ health_status["agent"],
356
+ health_status["memory"],
357
+ health_status["llm"],
358
+ health_status["knowledge_base"],
359
+ health_status["scripts"],
360
+ ]
361
+ )
362
+
363
+ return health_status
364
+
365
+
366
+ # Usage example and testing
367
+ if __name__ == "__main__":
368
+ # Initialize agent
369
+ agent = ViettelPayAgent()
370
+
371
+ # Health check
372
+ print("\n🏥 Health Check:")
373
+ health = agent.health_check()
374
+ for component, status in health.items():
375
+ status_icon = "✅" if status else "❌"
376
+ print(f" {component}: {status_icon}")
377
+
378
+ if not health["overall"]:
379
+ print("\n⚠️ Some components are not healthy. Check requirements and data files.")
380
+ exit(1)
381
+
382
+ print(f"\n🤖 Agent ready! Workflow info: {agent.get_workflow_info()}")
383
+
384
+ # Test multi-turn conversation with query enhancement
385
+ test_thread = "test_conversation"
386
+
387
+ print(
388
+ f"\n🧪 Testing multi-turn conversation with query enhancement (thread: {test_thread}):"
389
+ )
390
+
391
+ test_messages = [
392
+ "Xin chào!",
393
+ "Mã lỗi 606 là gì?",
394
+ "Làm sao khắc phục?", # This should be enhanced to "làm sao khắc phục lỗi 606"
395
+ "Còn lỗi nào khác tương tự không?", # This should be enhanced with error context
396
+ "Cảm ơn bạn!",
397
+ ]
398
+
399
+ for i, message in enumerate(test_messages, 1):
400
+ print(f"\n--- Turn {i} ---")
401
+ result = agent.process_message(message, test_thread)
402
+ print(f"User: {message}")
403
+ print(f"Bot: {result['response'][:150]}...")
404
+
405
+ if result.get("enhanced_query") and result["enhanced_query"] != message:
406
+ print(f"🚀 Query enhanced: {result['enhanced_query']}")
407
+
408
+ # Show conversation history
409
+ if i > 1:
410
+ history = agent.get_conversation_history(test_thread)
411
+ print(f"History length: {len(history)} messages")
412
+
413
+ print(f"\n📜 Final conversation history:")
414
+ history = agent.get_conversation_history(test_thread)
415
+ for i, msg in enumerate(history, 1):
416
+ print(f" {i}. {msg['role']}: {msg['content'][:100]}...")
src/evaluation/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (16.4 kB). View file
 
src/evaluation/__pycache__/single_turn_retrieval.cpython-311.pyc ADDED
Binary file (39.1 kB). View file
 
src/evaluation/intent_classification.py ADDED
@@ -0,0 +1,901 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplified Intent Classification Evaluation for ViettelPay AI Agent
3
+ Removed pattern-based generation, improved chunk mixing, and configurable conversations per chunk
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import time
11
+ import random
12
+ from typing import Dict, List, Optional
13
+ from pathlib import Path
14
+ from collections import defaultdict, Counter
15
+ import pandas as pd
16
+ from tqdm import tqdm
17
+ import re
18
+ import numpy as np
19
+
20
+ # Load environment variables from .env file
21
+ from dotenv import load_dotenv
22
+
23
+ load_dotenv()
24
+
25
+ # Add the project root to Python path so we can import from src
26
+ project_root = Path(__file__).parent.parent.parent
27
+ sys.path.insert(0, str(project_root))
28
+
29
+ # Import existing components
30
+ from src.evaluation.prompts import INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT
31
+ from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
32
+ from src.llm.llm_client import LLMClientFactory
33
+ from src.agent.nodes import classify_intent_node, ViettelPayState
34
+ from langchain_core.messages import HumanMessage
35
+
36
+
37
+ class IntentDatasetCreator:
38
+ """Simplified intent classification dataset creator with two strategies"""
39
+
40
+ def __init__(
41
+ self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
42
+ ):
43
+ """Initialize with Gemini API key and optional knowledge base"""
44
+ self.llm_client = LLMClientFactory.create_client(
45
+ "gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
46
+ )
47
+ self.knowledge_base = knowledge_base
48
+ self.dataset = {
49
+ "conversations": {},
50
+ "generation_methods": {},
51
+ "intent_distribution": {},
52
+ "metadata": {
53
+ "total_conversations": 0,
54
+ "total_user_messages": 0,
55
+ "creation_timestamp": time.time(),
56
+ },
57
+ }
58
+
59
+ print("✅ IntentDatasetCreator initialized (simplified version)")
60
+
61
+ def generate_json_response(
62
+ self, prompt: str, max_retries: int = 3
63
+ ) -> Optional[Dict]:
64
+ """Generate response and parse as JSON with retries"""
65
+ for attempt in range(max_retries):
66
+ try:
67
+ response = self.llm_client.generate(prompt, temperature=0.1)
68
+
69
+ if response:
70
+ response_text = response.strip()
71
+ json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
72
+ if json_match:
73
+ json_text = json_match.group()
74
+ return json.loads(json_text)
75
+ else:
76
+ return json.loads(response_text)
77
+
78
+ except json.JSONDecodeError as e:
79
+ print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
80
+ if attempt == max_retries - 1:
81
+ print(f"❌ Failed to parse JSON after {max_retries} attempts")
82
+
83
+ except Exception as e:
84
+ print(f"⚠️ API error (attempt {attempt + 1}): {e}")
85
+ if attempt < max_retries - 1:
86
+ time.sleep(2**attempt)
87
+
88
+ return None
89
+
90
+ def get_all_chunks(self) -> List[Dict]:
91
+ """Get ALL chunks from ChromaDB vectorstore"""
92
+ print(f"📚 Retrieving ALL chunks from ChromaDB vectorstore...")
93
+
94
+ if not self.knowledge_base:
95
+ raise ValueError("Knowledge base not provided.")
96
+
97
+ try:
98
+ if (
99
+ not hasattr(self.knowledge_base, "chroma_retriever")
100
+ or not self.knowledge_base.chroma_retriever
101
+ ):
102
+ raise ValueError("ChromaDB retriever not found in knowledge base")
103
+
104
+ vectorstore = self.knowledge_base.chroma_retriever.vectorstore
105
+ all_docs = vectorstore.get(include=["documents", "metadatas"])
106
+
107
+ documents = all_docs["documents"]
108
+ metadatas = all_docs["metadatas"]
109
+
110
+ all_chunks = []
111
+ seen_content_hashes = set()
112
+
113
+ for i, (content, metadata) in enumerate(zip(documents, metadatas)):
114
+ content_hash = hash(content[:300])
115
+
116
+ if (
117
+ content_hash not in seen_content_hashes
118
+ and len(content.strip()) > 100
119
+ ):
120
+ chunk_info = {
121
+ "id": f"chunk_{len(all_chunks)}",
122
+ "content": content,
123
+ "metadata": metadata or {},
124
+ }
125
+ all_chunks.append(chunk_info)
126
+ seen_content_hashes.add(content_hash)
127
+
128
+ print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
129
+ return all_chunks
130
+
131
+ except Exception as e:
132
+ print(f"❌ Error accessing ChromaDB: {e}")
133
+ return []
134
+
135
+ def generate_single_chunk_conversations(
136
+ self, chunk: Dict, num_conversations: int = 3
137
+ ) -> List[Dict]:
138
+ """Generate conversations from single chunk"""
139
+ content = chunk["content"]
140
+
141
+ generation_instruction = "Tạo cuộc hội thoại tập trung vào chủ đề chính của tài liệu. Bao gồm cả các intent phổ biến như greeting, unclear, human_request để tăng tính đa dạng"
142
+
143
+ prompt = INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT.format(
144
+ num_conversations=num_conversations,
145
+ content=content,
146
+ generation_instruction=generation_instruction,
147
+ )
148
+
149
+ response_json = self.generate_json_response(prompt)
150
+
151
+ if response_json and "conversations" in response_json:
152
+ conversations = response_json["conversations"]
153
+ valid_conversations = []
154
+
155
+ for i, conversation in enumerate(conversations):
156
+ if "turns" in conversation and len(conversation["turns"]) >= 1:
157
+ valid_turns = []
158
+ for turn in conversation["turns"]:
159
+ if "user" in turn and "intent" in turn:
160
+ valid_turns.append(turn)
161
+
162
+ if valid_turns:
163
+ conv_obj = {
164
+ "id": f"single_{chunk['id']}_{i}",
165
+ "turns": valid_turns,
166
+ "generation_method": "single_chunk",
167
+ "source_chunks": [chunk["id"]],
168
+ "chunk_metadata": [chunk["metadata"]],
169
+ }
170
+ valid_conversations.append(conv_obj)
171
+ return valid_conversations
172
+ else:
173
+ print(f"⚠️ No valid conversations generated for chunk {chunk['id']}")
174
+ return []
175
+
176
+ def generate_multi_chunk_conversations(
177
+ self, chunks: List[Dict], num_conversations: int = 3
178
+ ) -> List[Dict]:
179
+ """Generate conversations from multiple chunks (2-3 chunks)"""
180
+ # Combine content from multiple chunks
181
+ combined_content = ""
182
+ for i, chunk in enumerate(chunks):
183
+ combined_content += f"\n\n--- Chủ đề {i+1} ---\n" + chunk["content"]
184
+
185
+ generation_instruction = f"Tạo cuộc hội thoại tự nhiên kết hợp {len(chunks)} chủ đề khác nhau. Người dùng có thể chuyển từ chủ đề này sang chủ đề khác. Đặc biệt bao gồm các intent như greeting, unclear, human_request để cuộc hội thoại thực tế hơn"
186
+
187
+ prompt = INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT.format(
188
+ num_conversations=num_conversations,
189
+ content=combined_content,
190
+ generation_instruction=generation_instruction,
191
+ )
192
+
193
+ response_json = self.generate_json_response(prompt)
194
+
195
+ if response_json and "conversations" in response_json:
196
+ conversations = response_json["conversations"]
197
+ valid_conversations = []
198
+
199
+ for i, conversation in enumerate(conversations):
200
+ if "turns" in conversation and len(conversation["turns"]) >= 1:
201
+ valid_turns = []
202
+ for turn in conversation["turns"]:
203
+ if "user" in turn and "intent" in turn:
204
+ valid_turns.append(turn)
205
+
206
+ if valid_turns:
207
+ conv_obj = {
208
+ "id": f"multi_{'-'.join([c['id'] for c in chunks])}_{i}",
209
+ "turns": valid_turns,
210
+ "generation_method": "multi_chunk",
211
+ "source_chunks": [c["id"] for c in chunks],
212
+ "chunk_metadata": [c["metadata"] for c in chunks],
213
+ }
214
+ valid_conversations.append(conv_obj)
215
+
216
+ print(
217
+ f"✅ Generated {len(valid_conversations)} conversations for multi-chunk {[c['id'] for c in chunks]}"
218
+ )
219
+ return valid_conversations
220
+ else:
221
+ print(
222
+ f"⚠️ No valid conversations generated for chunks {[c['id'] for c in chunks]}"
223
+ )
224
+ return []
225
+
226
+ def create_intent_dataset(
227
+ self,
228
+ num_conversations_per_chunk: int = 3,
229
+ save_path: str = "evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json",
230
+ ) -> Dict:
231
+ """Create intent classification dataset using two strategies only"""
232
+ print(f"\n🚀 Creating intent classification dataset...")
233
+ print(f" Conversations per chunk: {num_conversations_per_chunk}")
234
+
235
+ # Step 1: Get all chunks
236
+ all_chunks = self.get_all_chunks()
237
+ if not all_chunks:
238
+ raise ValueError("No chunks found in knowledge base!")
239
+
240
+ total_chunks = len(all_chunks)
241
+ print(f"✅ Using all {total_chunks} chunks and shuffle them")
242
+ random.shuffle(all_chunks)
243
+
244
+ # Step 2: Split chunks for two strategies (60% single, 40% multi)
245
+ split_point = int(total_chunks * 0.6)
246
+ single_chunks = all_chunks[:split_point]
247
+ multi_chunks = all_chunks[split_point:]
248
+
249
+ print(f"📊 Distribution plan:")
250
+ print(
251
+ f" • Single chunk: {len(single_chunks)} chunks → ~{len(single_chunks) * num_conversations_per_chunk} conversations"
252
+ )
253
+ print(
254
+ f" • Multi chunk: {len(multi_chunks)} chunks → ~{len(multi_chunks) // 2.5 * num_conversations_per_chunk} conversations"
255
+ )
256
+
257
+ all_conversations = []
258
+
259
+ # Step 3: Generate single-chunk conversations
260
+ print(f"\n💬 Generating single-chunk conversations...")
261
+ for chunk in tqdm(single_chunks, desc="Single-chunk conversations"):
262
+ conversations = self.generate_single_chunk_conversations(
263
+ chunk, num_conversations_per_chunk
264
+ )
265
+ all_conversations.extend(conversations)
266
+ time.sleep(0.1)
267
+
268
+ # Step 4: Generate multi-chunk conversations (2-3 chunks randomly)
269
+ print(f"\n🔀 Generating multi-chunk conversations...")
270
+ random.shuffle(multi_chunks) # Randomize order
271
+
272
+ i = 0
273
+ while i < len(multi_chunks):
274
+ # Randomly choose to use 2 or 3 chunks
275
+ chunk_count = random.choice([2, 3])
276
+ chunk_group = multi_chunks[i : i + chunk_count]
277
+
278
+ # Only proceed if we have at least 2 chunks
279
+ if len(chunk_group) >= 2:
280
+ conversations = self.generate_multi_chunk_conversations(
281
+ chunk_group, num_conversations_per_chunk
282
+ )
283
+ all_conversations.extend(conversations)
284
+ time.sleep(0.1)
285
+
286
+ i += chunk_count
287
+
288
+ # Step 5: Track generation methods and intent distribution
289
+ method_stats = defaultdict(int)
290
+ intent_counts = Counter()
291
+
292
+ for conv in all_conversations:
293
+ method_stats[conv["generation_method"]] += 1
294
+ for turn in conv["turns"]:
295
+ intent_counts[turn["intent"]] += 1
296
+
297
+ # Step 6: Populate dataset structure
298
+ self.dataset["conversations"] = {conv["id"]: conv for conv in all_conversations}
299
+
300
+ self.dataset["generation_methods"] = dict(method_stats)
301
+ self.dataset["intent_distribution"] = dict(intent_counts)
302
+
303
+ # Step 7: Update metadata
304
+ total_user_messages = sum(len(conv["turns"]) for conv in all_conversations)
305
+
306
+ self.dataset["metadata"].update(
307
+ {
308
+ "total_conversations": len(all_conversations),
309
+ "total_user_messages": total_user_messages,
310
+ "chunks_used": total_chunks,
311
+ "conversations_per_chunk": num_conversations_per_chunk,
312
+ "generation_distribution": dict(method_stats),
313
+ "completion_timestamp": time.time(),
314
+ }
315
+ )
316
+
317
+ # Step 8: Save dataset
318
+ os.makedirs(
319
+ os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
320
+ exist_ok=True,
321
+ )
322
+
323
+ with open(save_path, "w", encoding="utf-8") as f:
324
+ json.dump(self.dataset, f, ensure_ascii=False, indent=2)
325
+
326
+ print(f"\n✅ Intent classification dataset created successfully!")
327
+ print(f" 📁 Saved to: {save_path}")
328
+ print(f" 📊 Statistics:")
329
+ print(f" • Total conversations: {len(all_conversations)}")
330
+ print(f" • Total user messages: {total_user_messages}")
331
+ print(f" • Conversations per chunk: {num_conversations_per_chunk}")
332
+ print(f" • Generation methods: {dict(method_stats)}")
333
+ print(f" • Intent distribution: {dict(intent_counts)}")
334
+
335
+ return self.dataset
336
+
337
+
338
+ class IntentClassificationEvaluator:
339
+ """Evaluator for intent classification performance with method-specific analysis"""
340
+
341
+ def __init__(self, dataset: Dict, llm_client):
342
+ """Initialize evaluator with dataset and LLM client"""
343
+ self.dataset = dataset
344
+ self.llm_client = llm_client
345
+
346
+ # Define expected intents
347
+ self.expected_intents = [
348
+ "greeting",
349
+ "faq",
350
+ "error_help",
351
+ "procedure_guide",
352
+ "human_request",
353
+ "out_of_scope",
354
+ "unclear",
355
+ ]
356
+
357
+ # Critical intents for business
358
+ self.critical_intents = ["error_help", "human_request"]
359
+
360
+ # Define flow mappings based on agent routing logic
361
+ self.script_based_intents = {
362
+ "greeting",
363
+ "out_of_scope",
364
+ "human_request",
365
+ "unclear",
366
+ }
367
+ self.knowledge_based_intents = {
368
+ "faq",
369
+ "error_help",
370
+ "procedure_guide",
371
+ }
372
+
373
+ def _get_intent_flow(self, intent: str) -> str:
374
+ """Classify intent into flow type based on agent routing logic"""
375
+ if intent in self.script_based_intents:
376
+ return "script_based"
377
+ elif intent in self.knowledge_based_intents:
378
+ return "knowledge_based"
379
+ else:
380
+ return "unknown"
381
+
382
+ def _make_json_serializable(self, obj):
383
+ """Convert numpy types to native Python types for JSON serialization"""
384
+ try:
385
+ import numpy as np
386
+
387
+ if isinstance(obj, dict):
388
+ return {k: self._make_json_serializable(v) for k, v in obj.items()}
389
+ elif isinstance(obj, list):
390
+ return [self._make_json_serializable(item) for item in obj]
391
+ elif isinstance(obj, np.integer):
392
+ return int(obj)
393
+ elif isinstance(obj, np.floating):
394
+ return float(obj)
395
+ elif isinstance(obj, np.ndarray):
396
+ return obj.tolist()
397
+ else:
398
+ return obj
399
+ except ImportError:
400
+ # If numpy is not available, just return the object as-is
401
+ if isinstance(obj, dict):
402
+ return {k: self._make_json_serializable(v) for k, v in obj.items()}
403
+ elif isinstance(obj, list):
404
+ return [self._make_json_serializable(item) for item in obj]
405
+ else:
406
+ return obj
407
+
408
+ def calculate_essential_metrics(
409
+ self, ground_truth: List[str], predictions: List[str]
410
+ ) -> Dict:
411
+ """Calculate only essential metrics: accuracy, macro, per-class"""
412
+ try:
413
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
414
+
415
+ overall_accuracy = accuracy_score(ground_truth, predictions)
416
+
417
+ # Calculate macro metrics (equal weight per intent)
418
+ precision, recall, f1, support = precision_recall_fscore_support(
419
+ ground_truth, predictions, average="macro", zero_division=0
420
+ )
421
+
422
+ macro_metrics = {
423
+ "macro_precision": precision,
424
+ "macro_recall": recall,
425
+ "macro_f1": f1,
426
+ }
427
+
428
+ # Calculate per-class metrics
429
+ precision_per_class, recall_per_class, f1_per_class, support_per_class = (
430
+ precision_recall_fscore_support(
431
+ ground_truth, predictions, average=None, zero_division=0
432
+ )
433
+ )
434
+
435
+ # Get unique labels
436
+ unique_labels = sorted(list(set(ground_truth + predictions)))
437
+
438
+ per_class_metrics = {}
439
+ for i, label in enumerate(unique_labels):
440
+ if i < len(precision_per_class):
441
+ per_class_metrics[label] = {
442
+ "precision": float(precision_per_class[i]),
443
+ "recall": float(recall_per_class[i]),
444
+ "f1": float(f1_per_class[i]),
445
+ "support": int(
446
+ support_per_class[i] if i < len(support_per_class) else 0
447
+ ),
448
+ }
449
+
450
+ # Calculate critical intent recall
451
+ critical_recall = {}
452
+ for intent in self.critical_intents:
453
+ if intent in per_class_metrics:
454
+ critical_recall[intent] = per_class_metrics[intent]["recall"]
455
+
456
+ return {
457
+ "overall_accuracy": float(overall_accuracy),
458
+ "macro_precision": float(macro_metrics["macro_precision"]),
459
+ "macro_recall": float(macro_metrics["macro_recall"]),
460
+ "macro_f1": float(macro_metrics["macro_f1"]),
461
+ "per_class_metrics": per_class_metrics,
462
+ "critical_intent_recall": {
463
+ k: float(v) for k, v in critical_recall.items()
464
+ },
465
+ }
466
+
467
+ except ImportError:
468
+ print("⚠️ scikit-learn not installed. Using basic accuracy only.")
469
+ overall_accuracy = sum(
470
+ 1 for gt, pred in zip(ground_truth, predictions) if gt == pred
471
+ ) / len(predictions)
472
+
473
+ return {"overall_accuracy": float(overall_accuracy)}
474
+
475
+ def evaluate_intent_classification(self) -> Dict:
476
+ """Evaluate intent classification performance with method and flow breakdown"""
477
+ print(f"\n🎯 Running intent classification evaluation...")
478
+
479
+ conversations = self.dataset["conversations"]
480
+
481
+ # Initialize tracking
482
+ all_predictions = []
483
+ all_ground_truth = []
484
+ method_results = defaultdict(lambda: {"predictions": [], "ground_truth": []})
485
+ flow_results = defaultdict(lambda: {"predictions": [], "ground_truth": []})
486
+ conversation_results = {}
487
+
488
+ # Process each conversation
489
+ for conv_id, conv_data in tqdm(
490
+ conversations.items(), desc="Evaluating conversations"
491
+ ):
492
+ generation_method = conv_data.get("generation_method", "unknown")
493
+
494
+ conversation_results[conv_id] = {
495
+ "turns": [],
496
+ "accuracy": 0,
497
+ "generation_method": generation_method,
498
+ }
499
+
500
+ correct_predictions = 0
501
+ total_turns = len(conv_data["turns"])
502
+
503
+ # Process each turn in the conversation
504
+ for turn_idx, turn in enumerate(conv_data["turns"]):
505
+ user_message = turn["user"]
506
+ ground_truth_intent = turn["intent"]
507
+
508
+ try:
509
+ # Create messages in the format expected by classify_intent_node
510
+ messages = [HumanMessage(content=user_message)]
511
+
512
+ # Create a mock state for the intent classification node
513
+ state = ViettelPayState(messages=messages)
514
+
515
+ # Use the classify_intent_node directly
516
+ result_state = classify_intent_node(state, self.llm_client)
517
+ predicted_intent = result_state.get("intent", "unclear")
518
+
519
+ # Track results
520
+ is_correct = predicted_intent == ground_truth_intent
521
+ if is_correct:
522
+ correct_predictions += 1
523
+
524
+ # Add to overall tracking
525
+ all_predictions.append(predicted_intent)
526
+ all_ground_truth.append(ground_truth_intent)
527
+
528
+ # Add to method-specific tracking
529
+ method_results[generation_method]["predictions"].append(
530
+ predicted_intent
531
+ )
532
+ method_results[generation_method]["ground_truth"].append(
533
+ ground_truth_intent
534
+ )
535
+
536
+ # Add to flow-specific tracking
537
+ ground_truth_flow = self._get_intent_flow(ground_truth_intent)
538
+ predicted_flow = self._get_intent_flow(predicted_intent)
539
+
540
+ flow_results[ground_truth_flow]["predictions"].append(
541
+ predicted_intent
542
+ )
543
+ flow_results[ground_truth_flow]["ground_truth"].append(
544
+ ground_truth_intent
545
+ )
546
+
547
+ conversation_results[conv_id]["turns"].append(
548
+ {
549
+ "turn": turn_idx + 1,
550
+ "user_message": user_message,
551
+ "ground_truth": ground_truth_intent,
552
+ "predicted": predicted_intent,
553
+ "correct": is_correct,
554
+ }
555
+ )
556
+
557
+ except Exception as e:
558
+ print(f"⚠️ Error processing turn {turn_idx} in {conv_id}: {e}")
559
+ # Use "unclear" as fallback prediction
560
+ all_predictions.append("unclear")
561
+ all_ground_truth.append(ground_truth_intent)
562
+ method_results[generation_method]["predictions"].append("unclear")
563
+ method_results[generation_method]["ground_truth"].append(
564
+ ground_truth_intent
565
+ )
566
+
567
+ # Add to flow-specific tracking (for errors)
568
+ ground_truth_flow = self._get_intent_flow(ground_truth_intent)
569
+ flow_results[ground_truth_flow]["predictions"].append("unclear")
570
+ flow_results[ground_truth_flow]["ground_truth"].append(
571
+ ground_truth_intent
572
+ )
573
+
574
+ # Calculate conversation accuracy
575
+ conversation_results[conv_id]["accuracy"] = float(
576
+ correct_predictions / total_turns if total_turns > 0 else 0
577
+ )
578
+
579
+ # Calculate overall metrics
580
+ overall_metrics = self.calculate_essential_metrics(
581
+ all_ground_truth, all_predictions
582
+ )
583
+
584
+ # Calculate method-specific metrics
585
+ method_metrics = {}
586
+ for method, method_data in method_results.items():
587
+ if method_data["predictions"]: # Ensure we have data
588
+ method_metrics[method] = self.calculate_essential_metrics(
589
+ method_data["ground_truth"], method_data["predictions"]
590
+ )
591
+ method_metrics[method]["total_messages"] = len(
592
+ method_data["predictions"]
593
+ )
594
+
595
+ # Calculate flow-specific metrics
596
+ flow_metrics = {}
597
+ for flow, flow_data in flow_results.items():
598
+ if flow_data["predictions"]: # Ensure we have data
599
+ flow_metrics[flow] = self.calculate_essential_metrics(
600
+ flow_data["ground_truth"], flow_data["predictions"]
601
+ )
602
+ flow_metrics[flow]["total_messages"] = len(flow_data["predictions"])
603
+
604
+ results = {
605
+ "overall_metrics": overall_metrics,
606
+ "method_specific_metrics": method_metrics,
607
+ "flow_specific_metrics": flow_metrics,
608
+ "conversation_results": conversation_results,
609
+ "intent_distribution": {
610
+ "ground_truth": dict(Counter(all_ground_truth)),
611
+ "predicted": dict(Counter(all_predictions)),
612
+ },
613
+ "generation_methods": self.dataset.get("generation_methods", {}),
614
+ }
615
+
616
+ # Make sure all values are JSON serializable
617
+ results = self._make_json_serializable(results)
618
+
619
+ return results
620
+
621
+ def print_evaluation_results(self, results: Dict):
622
+ """Print comprehensive evaluation results"""
623
+ print(f"\n🎯 INTENT CLASSIFICATION EVALUATION RESULTS")
624
+ print("=" * 60)
625
+
626
+ # Overall performance
627
+ overall = results["overall_metrics"]
628
+ print(f"\n📊 Overall Performance:")
629
+ print(f" Accuracy: {overall['overall_accuracy']:.3f}")
630
+ if "macro_precision" in overall:
631
+ print(f" Macro Precision: {overall['macro_precision']:.3f}")
632
+ print(f" Macro Recall: {overall['macro_recall']:.3f}")
633
+ print(f" Macro F1: {overall['macro_f1']:.3f}")
634
+
635
+ # Per-class performance
636
+ if "per_class_metrics" in overall:
637
+ print(f"\n📋 Per-Class Performance:")
638
+ print(
639
+ f"{'Intent':<15} {'Precision':<10} {'Recall':<10} {'F1':<10} {'Support':<10}"
640
+ )
641
+ print("-" * 65)
642
+
643
+ per_class = overall["per_class_metrics"]
644
+ for intent in self.expected_intents:
645
+ if intent in per_class:
646
+ metrics = per_class[intent]
647
+ print(
648
+ f"{intent:<15} {metrics['precision']:<10.3f} {metrics['recall']:<10.3f} {metrics['f1']:<10.3f} {metrics['support']:<10}"
649
+ )
650
+
651
+ # Critical intents performance
652
+ if "critical_intent_recall" in overall:
653
+ print(f"\n🚨 Critical Intent Performance:")
654
+ for intent, recall in overall["critical_intent_recall"].items():
655
+ status = "✅" if recall >= 0.85 else "⚠️" if recall >= 0.75 else "❌"
656
+ print(f" {status} {intent}: Recall = {recall:.3f}")
657
+
658
+ # Method-specific performance
659
+ print(f"\n🔄 Performance by Generation Method:")
660
+ method_metrics = results["method_specific_metrics"]
661
+ if method_metrics:
662
+ print(f"{'Method':<20} {'Accuracy':<10} {'Macro F1':<10} {'Messages':<10}")
663
+ print("-" * 55)
664
+
665
+ for method, metrics in method_metrics.items():
666
+ accuracy = metrics["overall_accuracy"]
667
+ macro_f1 = metrics.get("macro_f1", 0)
668
+ total_msgs = metrics["total_messages"]
669
+ print(
670
+ f"{method:<20} {accuracy:<10.3f} {macro_f1:<10.3f} {total_msgs:<10}"
671
+ )
672
+
673
+ # Flow-specific performance
674
+ print(f"\n🔀 Performance by Agent Flow:")
675
+ flow_metrics = results["flow_specific_metrics"]
676
+ if flow_metrics:
677
+ print(
678
+ f"{'Flow Type':<20} {'Accuracy':<10} {'Macro F1':<10} {'Messages':<10}"
679
+ )
680
+ print("-" * 55)
681
+
682
+ for flow, metrics in flow_metrics.items():
683
+ accuracy = metrics["overall_accuracy"]
684
+ macro_f1 = metrics.get("macro_f1", 0)
685
+ total_msgs = metrics["total_messages"]
686
+ flow_display = f"{flow}_flow"
687
+ print(
688
+ f"{flow_display:<20} {accuracy:<10.3f} {macro_f1:<10.3f} {total_msgs:<10}"
689
+ )
690
+
691
+ # Intent distribution comparison
692
+ print(f"\n📈 Intent Distribution:")
693
+ gt_dist = results["intent_distribution"]["ground_truth"]
694
+ pred_dist = results["intent_distribution"]["predicted"]
695
+
696
+ print(f"{'Intent':<15} {'Ground Truth':<15} {'Predicted':<15}")
697
+ print("-" * 50)
698
+
699
+ all_intents = set(list(gt_dist.keys()) + list(pred_dist.keys()))
700
+ for intent in sorted(all_intents):
701
+ gt_count = gt_dist.get(intent, 0)
702
+ pred_count = pred_dist.get(intent, 0)
703
+ print(f"{intent:<15} {gt_count:<15} {pred_count:<15}")
704
+
705
+ # Method insights
706
+ print(f"\n💡 Method-Specific Insights:")
707
+ if method_metrics:
708
+ method_accuracies = {
709
+ method: metrics["overall_accuracy"]
710
+ for method, metrics in method_metrics.items()
711
+ }
712
+ best_method = max(
713
+ method_accuracies.keys(), key=lambda k: method_accuracies[k]
714
+ )
715
+ worst_method = min(
716
+ method_accuracies.keys(), key=lambda k: method_accuracies[k]
717
+ )
718
+
719
+ print(
720
+ f" • Best performing method: {best_method} ({method_accuracies[best_method]:.3f})"
721
+ )
722
+ print(
723
+ f" • Most challenging method: {worst_method} ({method_accuracies[worst_method]:.3f})"
724
+ )
725
+ print(
726
+ f" • Performance gap: {method_accuracies[best_method] - method_accuracies[worst_method]:.3f}"
727
+ )
728
+
729
+ # Flow insights
730
+ print(f"\n🔀 Flow-Specific Insights:")
731
+ if flow_metrics:
732
+ flow_accuracies = {
733
+ flow: metrics["overall_accuracy"]
734
+ for flow, metrics in flow_metrics.items()
735
+ }
736
+
737
+ if len(flow_accuracies) >= 2:
738
+ best_flow = max(
739
+ flow_accuracies.keys(), key=lambda k: flow_accuracies[k]
740
+ )
741
+ worst_flow = min(
742
+ flow_accuracies.keys(), key=lambda k: flow_accuracies[k]
743
+ )
744
+
745
+ print(
746
+ f" • Best performing flow: {best_flow} ({flow_accuracies[best_flow]:.3f})"
747
+ )
748
+ print(
749
+ f" • Most challenging flow: {worst_flow} ({flow_accuracies[worst_flow]:.3f})"
750
+ )
751
+ print(
752
+ f" • Flow performance gap: {flow_accuracies[best_flow] - flow_accuracies[worst_flow]:.3f}"
753
+ )
754
+
755
+ # Provide interpretation
756
+ if (
757
+ "script_based" in flow_accuracies
758
+ and "knowledge_based" in flow_accuracies
759
+ ):
760
+ script_acc = flow_accuracies["script_based"]
761
+ kb_acc = flow_accuracies["knowledge_based"]
762
+
763
+ if script_acc > kb_acc:
764
+ print(
765
+ f" • Script-based intents are easier to classify ({script_acc:.3f} vs {kb_acc:.3f})"
766
+ )
767
+ elif kb_acc > script_acc:
768
+ print(
769
+ f" • Knowledge-based intents are easier to classify ({kb_acc:.3f} vs {script_acc:.3f})"
770
+ )
771
+ else:
772
+ print(
773
+ f" • Both flows perform similarly ({script_acc:.3f} vs {kb_acc:.3f})"
774
+ )
775
+ else:
776
+ for flow, accuracy in flow_accuracies.items():
777
+ print(f" • {flow} flow accuracy: {accuracy:.3f}")
778
+
779
+ # Success criteria check
780
+ print(f"\n✅ Success Criteria Check:")
781
+ accuracy = overall["overall_accuracy"]
782
+ if accuracy >= 0.80:
783
+ print(f" 🎉 GOOD: Overall accuracy {accuracy:.3f} >= 0.80")
784
+ elif accuracy >= 0.75:
785
+ print(f" ⚠️ OKAY: Overall accuracy {accuracy:.3f} >= 0.75")
786
+ else:
787
+ print(f" ❌ NEEDS WORK: Overall accuracy {accuracy:.3f} < 0.75")
788
+
789
+
790
+ def main():
791
+ """Main function for simplified intent classification evaluation"""
792
+ parser = argparse.ArgumentParser(
793
+ description="Simplified ViettelPay Intent Classification Evaluation"
794
+ )
795
+ parser.add_argument(
796
+ "--mode",
797
+ choices=["create", "evaluate", "full"],
798
+ default="full",
799
+ help="Mode: create dataset, evaluate, or full pipeline",
800
+ )
801
+ parser.add_argument(
802
+ "--dataset-path",
803
+ default="evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json",
804
+ help="Path to intent dataset",
805
+ )
806
+ parser.add_argument(
807
+ "--results-path",
808
+ default="evaluation_data/results/intent_classification/viettelpay_intent_results.json",
809
+ help="Path to save evaluation results",
810
+ )
811
+ parser.add_argument(
812
+ "--conversations-per-chunk",
813
+ type=int,
814
+ default=3,
815
+ help="Number of conversations per chunk (default: 3)",
816
+ )
817
+ parser.add_argument(
818
+ "--knowledge-base-path",
819
+ default="./knowledge_base",
820
+ help="Path to knowledge base",
821
+ )
822
+
823
+ args = parser.parse_args()
824
+
825
+ # Configuration
826
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
827
+
828
+ if not GEMINI_API_KEY:
829
+ print("❌ Please set GEMINI_API_KEY environment variable")
830
+ return
831
+
832
+ try:
833
+ # Initialize components based on mode
834
+ kb = None
835
+ if args.mode in ["create", "full"]:
836
+ # Initialize knowledge base only if creating dataset
837
+ print("🔧 Initializing ViettelPay knowledge base...")
838
+ kb = ViettelKnowledgeBase()
839
+ if not kb.load_knowledge_base(args.knowledge_base_path):
840
+ print(
841
+ "❌ Failed to load knowledge base. Please run build_database_script.py first."
842
+ )
843
+ return
844
+
845
+ # Step 1: Create dataset if requested
846
+ if args.mode in ["create", "full"]:
847
+ print(f"\n🎯 Creating simplified intent classification dataset...")
848
+ creator = IntentDatasetCreator(GEMINI_API_KEY, kb)
849
+
850
+ dataset = creator.create_intent_dataset(
851
+ num_conversations_per_chunk=args.conversations_per_chunk,
852
+ save_path=args.dataset_path,
853
+ )
854
+
855
+ # Step 2: Evaluate if requested
856
+ if args.mode in ["evaluate", "full"]:
857
+ print(f"\n📊 Evaluating intent classification...")
858
+
859
+ # Load dataset if not created in this run
860
+ if args.mode == "evaluate":
861
+ if not os.path.exists(args.dataset_path):
862
+ print(f"❌ Dataset not found: {args.dataset_path}")
863
+ return
864
+
865
+ with open(args.dataset_path, "r", encoding="utf-8") as f:
866
+ dataset = json.load(f)
867
+
868
+ # Initialize LLM client for intent classification
869
+ print("🤖 Initializing LLM client for intent classification...")
870
+ llm_client = LLMClientFactory.create_client(
871
+ "gemini", api_key=GEMINI_API_KEY, model="gemini-2.0-flash"
872
+ )
873
+
874
+ # Run evaluation
875
+ evaluator = IntentClassificationEvaluator(dataset, llm_client)
876
+ results = evaluator.evaluate_intent_classification()
877
+ evaluator.print_evaluation_results(results)
878
+
879
+ # Save results
880
+ if args.results_path:
881
+ with open(args.results_path, "w", encoding="utf-8") as f:
882
+ json.dump(results, f, ensure_ascii=False, indent=2)
883
+ print(f"\n💾 Results saved to: {args.results_path}")
884
+
885
+ print(f"\n✅ Intent classification evaluation completed successfully!")
886
+ print(f"\n💡 Summary improvements made:")
887
+ print(f" • Removed pattern-based generation for simplicity")
888
+ print(f" • Added configurable conversations-per-chunk (default: 3)")
889
+ print(f" • Improved chunk mixing (random 2-3 chunks)")
890
+ print(f" • Enhanced prompts to include non-topic intents")
891
+ print(f" • Added flow-specific analysis (script-based vs knowledge-based)")
892
+
893
+ except Exception as e:
894
+ print(f"❌ Error in main execution: {e}")
895
+ import traceback
896
+
897
+ traceback.print_exc()
898
+
899
+
900
+ if __name__ == "__main__":
901
+ main()
src/evaluation/multi_turn_retrieval.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Turn Conversation Retrieval Evaluation for ViettelPay RAG System
3
+ Generates multi-turn conversations and evaluates retrieval performance
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import time
11
+ from typing import Dict, List, Tuple, Optional, Union
12
+ from pathlib import Path
13
+ from collections import defaultdict
14
+ import pandas as pd
15
+ from tqdm import tqdm
16
+ import re
17
+
18
+ # Load environment variables from .env file
19
+ from dotenv import load_dotenv
20
+
21
+ load_dotenv()
22
+
23
+ # Add the project root to Python path so we can import from src
24
+ project_root = Path(__file__).parent.parent.parent
25
+ sys.path.insert(0, str(project_root))
26
+
27
+ # Import existing components
28
+ from src.evaluation.prompts import MULTI_TURN_CONVERSATION_GENERATION_PROMPT
29
+ from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
30
+ from src.evaluation.single_turn_retrieval import SingleTurnRetrievalEvaluator
31
+ from src.llm.llm_client import LLMClientFactory, BaseLLMClient
32
+ from src.agent.nodes import query_enhancement_node, ViettelPayState
33
+ from langchain_core.messages import HumanMessage
34
+
35
+
36
+ class MultiTurnDatasetCreator:
37
+ """Multi-turn conversation dataset creator for ViettelPay evaluation"""
38
+
39
+ def __init__(
40
+ self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
41
+ ):
42
+ """
43
+ Initialize with Gemini API key and optional knowledge base
44
+
45
+ Args:
46
+ gemini_api_key: Google AI API key for Gemini
47
+ knowledge_base: Pre-initialized ViettelKnowledgeBase instance
48
+ """
49
+ self.llm_client = LLMClientFactory.create_client(
50
+ "gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
51
+ )
52
+ self.knowledge_base = knowledge_base
53
+ self.dataset = {
54
+ "conversations": {},
55
+ "documents": {},
56
+ "metadata": {
57
+ "total_chunks_processed": 0,
58
+ "conversations_generated": 0,
59
+ "creation_timestamp": time.time(),
60
+ },
61
+ }
62
+
63
+ print("✅ MultiTurnDatasetCreator initialized with Gemini 2.0 Flash")
64
+
65
+ def generate_json_response(
66
+ self, prompt: str, max_retries: int = 3
67
+ ) -> Optional[Dict]:
68
+ """
69
+ Generate response and parse as JSON with retries
70
+
71
+ Args:
72
+ prompt: Input prompt
73
+ max_retries: Maximum number of retry attempts
74
+
75
+ Returns:
76
+ Parsed JSON response or None if failed
77
+ """
78
+ for attempt in range(max_retries):
79
+ try:
80
+ response = self.llm_client.generate(prompt, temperature=0.1)
81
+
82
+ if response:
83
+ # Clean response text
84
+ response_text = response.strip()
85
+
86
+ # Extract JSON from response (handle cases with extra text)
87
+ json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
88
+ if json_match:
89
+ json_text = json_match.group()
90
+ return json.loads(json_text)
91
+ else:
92
+ # Try parsing the whole response
93
+ return json.loads(response_text)
94
+
95
+ except json.JSONDecodeError as e:
96
+ print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
97
+ if attempt == max_retries - 1:
98
+ print(f"❌ Failed to parse JSON after {max_retries} attempts")
99
+ print(
100
+ f"Raw response: {response if 'response' in locals() else 'No response'}"
101
+ )
102
+
103
+ except Exception as e:
104
+ print(f"⚠️ API error (attempt {attempt + 1}): {e}")
105
+ if attempt < max_retries - 1:
106
+ time.sleep(2**attempt) # Exponential backoff
107
+
108
+ return None
109
+
110
+ def get_all_chunks(self) -> List[Dict]:
111
+ """
112
+ Get ALL chunks directly from ChromaDB vectorstore
113
+ Reuse the same method from single-turn evaluation
114
+
115
+ Returns:
116
+ List of all document chunks with content and metadata
117
+ """
118
+ print(f"📚 Retrieving ALL chunks directly from ChromaDB vectorstore...")
119
+
120
+ if not self.knowledge_base:
121
+ raise ValueError(
122
+ "Knowledge base not provided. Please initialize with a ViettelKnowledgeBase instance."
123
+ )
124
+
125
+ try:
126
+ # Access the ChromaDB vectorstore directly
127
+ if (
128
+ not hasattr(self.knowledge_base, "chroma_retriever")
129
+ or not self.knowledge_base.chroma_retriever
130
+ ):
131
+ raise ValueError("ChromaDB retriever not found in knowledge base")
132
+
133
+ # Get the vectorstore from the retriever
134
+ vectorstore = self.knowledge_base.chroma_retriever.vectorstore
135
+
136
+ # Get all documents directly from ChromaDB
137
+ print(" Accessing ChromaDB collection...")
138
+ all_docs = vectorstore.get(include=["documents", "metadatas"])
139
+
140
+ documents = all_docs["documents"]
141
+ metadatas = all_docs["metadatas"]
142
+
143
+ print(f" Found {len(documents)} documents in ChromaDB")
144
+
145
+ # Convert to our expected format
146
+ all_chunks = []
147
+ seen_content_hashes = set()
148
+
149
+ for i, (content, metadata) in enumerate(zip(documents, metadatas)):
150
+ # Create content hash for deduplication
151
+ content_hash = hash(content[:300])
152
+
153
+ if (
154
+ content_hash not in seen_content_hashes
155
+ and len(content.strip()) > 100
156
+ ):
157
+ chunk_info = {
158
+ "id": f"chunk_{len(all_chunks)}",
159
+ "content": content,
160
+ "metadata": metadata or {},
161
+ "source": "chromadb_direct",
162
+ "content_length": len(content),
163
+ "original_index": i,
164
+ }
165
+ all_chunks.append(chunk_info)
166
+ seen_content_hashes.add(content_hash)
167
+
168
+ print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
169
+
170
+ # Sort by content length (longer chunks first)
171
+ all_chunks.sort(key=lambda x: x["content_length"], reverse=True)
172
+
173
+ return all_chunks
174
+
175
+ except Exception as e:
176
+ print(f"❌ Error accessing ChromaDB directly: {e}")
177
+ return []
178
+
179
+ def generate_conversations_for_chunk(
180
+ self, chunk: Dict, num_conversations: int = 2
181
+ ) -> List[Dict]:
182
+ """
183
+ Generate multi-turn conversations for a single chunk using Gemini
184
+
185
+ Args:
186
+ chunk: Chunk dictionary with content and metadata
187
+ num_conversations: Number of conversations to generate per chunk
188
+
189
+ Returns:
190
+ List of conversation dictionaries
191
+ """
192
+ content = chunk["content"]
193
+
194
+ prompt = MULTI_TURN_CONVERSATION_GENERATION_PROMPT.format(
195
+ num_conversations=num_conversations, content=content
196
+ )
197
+
198
+ response_json = self.generate_json_response(prompt)
199
+
200
+ if response_json and "conversations" in response_json:
201
+ conversations = response_json["conversations"]
202
+
203
+ # Create conversation objects with metadata
204
+ conversation_objects = []
205
+ for i, conversation in enumerate(conversations):
206
+ if len(conversation.get("turns", [])) >= 2: # At least 2 turns
207
+ conversation_obj = {
208
+ "id": f"conv_{chunk['id']}_{i}",
209
+ "turns": conversation["turns"],
210
+ "conversation_type": conversation.get("type", "general"),
211
+ "source_chunk": chunk["id"],
212
+ "chunk_metadata": chunk["metadata"],
213
+ "generation_method": "gemini_json",
214
+ }
215
+ conversation_objects.append(conversation_obj)
216
+
217
+ return conversation_objects
218
+ else:
219
+ print(f"⚠️ No valid conversations generated for chunk {chunk['id']}")
220
+ return []
221
+
222
+ def create_multi_turn_dataset(
223
+ self,
224
+ conversations_per_chunk: int = 2,
225
+ save_path: str = "evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json",
226
+ ) -> Dict:
227
+ """
228
+ Create multi-turn conversation dataset using ALL chunks
229
+
230
+ Args:
231
+ conversations_per_chunk: Number of conversations to generate per chunk
232
+ save_path: Path to save the dataset JSON file
233
+
234
+ Returns:
235
+ Complete dataset dictionary with conversations
236
+ """
237
+ print(f"\n🚀 Creating multi-turn conversation dataset...")
238
+ print(f" Target: Process ALL chunks from knowledge base")
239
+ print(f" Conversations per chunk: {conversations_per_chunk}")
240
+
241
+ # Step 1: Get all chunks
242
+ all_chunks = self.get_all_chunks()
243
+ total_chunks = len(all_chunks)
244
+
245
+ if total_chunks == 0:
246
+ raise ValueError("No chunks found in knowledge base!")
247
+
248
+ print(f"✅ Found {total_chunks} chunks to process")
249
+
250
+ # Step 2: Generate conversations for all chunks
251
+ print(f"\n💬 Generating conversations for {total_chunks} chunks...")
252
+ all_conversations = []
253
+
254
+ for chunk in tqdm(all_chunks, desc="Generating conversations"):
255
+ conversations = self.generate_conversations_for_chunk(
256
+ chunk, conversations_per_chunk
257
+ )
258
+ all_conversations.extend(conversations)
259
+ time.sleep(0.2) # Rate limiting for Gemini API
260
+
261
+ # Step 3: Populate dataset structure
262
+ self.dataset["documents"] = {
263
+ chunk["id"]: chunk["content"] for chunk in all_chunks
264
+ }
265
+ self.dataset["conversations"] = {
266
+ conv["id"]: {
267
+ "turns": conv["turns"],
268
+ "conversation_type": conv["conversation_type"],
269
+ "source_chunk": conv["source_chunk"],
270
+ "chunk_metadata": conv["chunk_metadata"],
271
+ "generation_method": conv["generation_method"],
272
+ }
273
+ for conv in all_conversations
274
+ }
275
+
276
+ # Step 4: Update metadata
277
+ self.dataset["metadata"].update(
278
+ {
279
+ "total_chunks_processed": total_chunks,
280
+ "conversations_generated": len(all_conversations),
281
+ "conversations_per_chunk": conversations_per_chunk,
282
+ "completion_timestamp": time.time(),
283
+ }
284
+ )
285
+
286
+ # Step 5: Save dataset
287
+ os.makedirs(
288
+ os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
289
+ exist_ok=True,
290
+ )
291
+
292
+ with open(save_path, "w", encoding="utf-8") as f:
293
+ json.dump(self.dataset, f, ensure_ascii=False, indent=2)
294
+
295
+ print(f"\n✅ Multi-turn conversation dataset created successfully!")
296
+ print(f" 📁 Saved to: {save_path}")
297
+ print(f" 📊 Statistics:")
298
+ print(f" • Chunks processed: {total_chunks}")
299
+ print(f" • Conversations generated: {len(all_conversations)}")
300
+ print(
301
+ f" • Avg conversations per chunk: {len(all_conversations)/total_chunks:.1f}"
302
+ )
303
+
304
+ return self.dataset
305
+
306
+
307
+ class ConversationEnhancer:
308
+ """Convert multi-turn conversations to enhanced queries using existing query enhancement"""
309
+
310
+ def __init__(self, gemini_api_key: str):
311
+ """Initialize with Gemini API key for query enhancement"""
312
+ self.llm_client = LLMClientFactory.create_client(
313
+ "gemini", api_key=gemini_api_key, model="gemini-2.0-flash-lite"
314
+ )
315
+ print("✅ ConversationEnhancer initialized")
316
+
317
+ def enhance_conversation(self, conversation_turns: List[Dict]) -> str:
318
+ """
319
+ Convert a multi-turn conversation to an enhanced query
320
+
321
+ Args:
322
+ conversation_turns: List of turn dictionaries with role and content
323
+
324
+ Returns:
325
+ Enhanced query string
326
+ """
327
+ try:
328
+ # Create messages in the format expected by query_enhancement_node
329
+ messages = []
330
+ for turn in conversation_turns:
331
+ if turn["role"] == "user":
332
+ messages.append(HumanMessage(content=turn["content"]))
333
+
334
+ # Create a mock state for the query enhancement node
335
+ state = ViettelPayState(messages=messages)
336
+
337
+ # Use the existing query enhancement node
338
+ enhanced_state = query_enhancement_node(state, self.llm_client)
339
+
340
+ enhanced_query = enhanced_state.get("enhanced_query", "")
341
+
342
+ if not enhanced_query:
343
+ # Fallback: concatenate all user messages
344
+ user_messages = [
345
+ turn["content"]
346
+ for turn in conversation_turns
347
+ if turn["role"] == "user"
348
+ ]
349
+ enhanced_query = " ".join(user_messages)
350
+
351
+ return enhanced_query
352
+
353
+ except Exception as e:
354
+ print(f"❌ Error enhancing conversation: {e}")
355
+ # Fallback: concatenate all user messages
356
+ user_messages = [
357
+ turn["content"] for turn in conversation_turns if turn["role"] == "user"
358
+ ]
359
+ return " ".join(user_messages)
360
+
361
+ def convert_dataset_to_single_turn_format(
362
+ self,
363
+ multi_turn_dataset: Dict,
364
+ save_path: str = "evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json",
365
+ ) -> Dict:
366
+ """
367
+ Convert multi-turn conversation dataset to single-turn format with enhanced queries
368
+
369
+ Args:
370
+ multi_turn_dataset: Multi-turn conversation dataset
371
+ save_path: Path to save the converted dataset
372
+
373
+ Returns:
374
+ Single-turn format dataset
375
+ """
376
+ print(f"\n🔄 Converting multi-turn conversations to enhanced queries...")
377
+
378
+ conversations = multi_turn_dataset["conversations"]
379
+ documents = multi_turn_dataset["documents"]
380
+
381
+ # Initialize single-turn format dataset
382
+ single_turn_dataset = {
383
+ "queries": {},
384
+ "documents": documents,
385
+ "conversation_metadata": {},
386
+ "metadata": {
387
+ "total_conversations_processed": len(conversations),
388
+ "enhanced_queries_generated": 0,
389
+ "conversion_timestamp": time.time(),
390
+ "original_dataset_metadata": multi_turn_dataset.get("metadata", {}),
391
+ },
392
+ }
393
+
394
+ enhanced_count = 0
395
+
396
+ # Process each conversation
397
+ for conv_id, conv_data in tqdm(
398
+ conversations.items(), desc="Enhancing conversations"
399
+ ):
400
+ try:
401
+ # Extract turns
402
+ turns = conv_data["turns"]
403
+
404
+ # Enhance conversation to single query
405
+ enhanced_query = self.enhance_conversation(turns)
406
+
407
+ if enhanced_query and len(enhanced_query.strip()) > 5:
408
+ single_turn_dataset["queries"][conv_id] = enhanced_query
409
+ single_turn_dataset["conversation_metadata"][conv_id] = {
410
+ "original_conversation": turns,
411
+ "conversation_type": conv_data.get(
412
+ "conversation_type", "general"
413
+ ),
414
+ "source_chunk": conv_data["source_chunk"],
415
+ "chunk_metadata": conv_data.get("chunk_metadata", {}),
416
+ "generation_method": conv_data.get(
417
+ "generation_method", "unknown"
418
+ ),
419
+ }
420
+ enhanced_count += 1
421
+
422
+ time.sleep(0.1) # Small delay for rate limiting
423
+
424
+ except Exception as e:
425
+ print(f"⚠️ Error processing conversation {conv_id}: {e}")
426
+ continue
427
+
428
+ # Update metadata
429
+ single_turn_dataset["metadata"]["enhanced_queries_generated"] = enhanced_count
430
+
431
+ # Save converted dataset
432
+ os.makedirs(
433
+ os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
434
+ exist_ok=True,
435
+ )
436
+
437
+ with open(save_path, "w", encoding="utf-8") as f:
438
+ json.dump(single_turn_dataset, f, ensure_ascii=False, indent=2)
439
+
440
+ print(f"✅ Conversion completed successfully!")
441
+ print(f" 📁 Saved to: {save_path}")
442
+ print(f" 📊 Statistics:")
443
+ print(f" • Conversations processed: {len(conversations)}")
444
+ print(f" • Enhanced queries generated: {enhanced_count}")
445
+ print(f" • Success rate: {enhanced_count/len(conversations)*100:.1f}%")
446
+
447
+ return single_turn_dataset
448
+
449
+
450
+ class MultiTurnEvaluator:
451
+ """Extended evaluator for multi-turn conversation retrieval with additional analysis"""
452
+
453
+ def __init__(self, dataset: Dict, knowledge_base: ViettelKnowledgeBase):
454
+ """
455
+ Initialize evaluator with dataset and knowledge base
456
+
457
+ Args:
458
+ dataset: Evaluation dataset in single-turn format (from converted multi-turn)
459
+ knowledge_base: ViettelKnowledgeBase instance to evaluate
460
+ """
461
+ self.dataset = dataset
462
+ self.knowledge_base = knowledge_base
463
+ self.single_turn_evaluator = SingleTurnRetrievalEvaluator(
464
+ dataset, knowledge_base
465
+ )
466
+
467
+ def _get_conversation_metadata(self, query_id: str) -> Dict:
468
+ """
469
+ Get conversation metadata for a query, handling both formats
470
+
471
+ Args:
472
+ query_id: Query identifier
473
+
474
+ Returns:
475
+ Metadata dictionary
476
+ """
477
+ # First try conversation_metadata (multi-turn format)
478
+ conversation_metadata = self.dataset.get("conversation_metadata", {})
479
+ if query_id in conversation_metadata:
480
+ return conversation_metadata[query_id]
481
+
482
+ # Fallback to question_metadata (single-turn format)
483
+ question_metadata = self.dataset.get("question_metadata", {})
484
+ if query_id in question_metadata:
485
+ # Convert single-turn format to multi-turn format for consistency
486
+ meta = question_metadata[query_id]
487
+ return {
488
+ "conversation_type": "single_turn",
489
+ "source_chunk": meta.get("source_chunk"),
490
+ "original_conversation": [
491
+ {"role": "user", "content": self.dataset["queries"][query_id]}
492
+ ],
493
+ "chunk_metadata": meta.get("chunk_metadata", {}),
494
+ "generation_method": meta.get("generation_method", "unknown"),
495
+ }
496
+
497
+ return {}
498
+
499
+ def evaluate_multi_turn_performance(
500
+ self, k_values: List[int] = [1, 3, 5, 10]
501
+ ) -> Dict:
502
+ """
503
+ Evaluate multi-turn conversation retrieval performance
504
+
505
+ Args:
506
+ k_values: List of k values to evaluate
507
+
508
+ Returns:
509
+ Dictionary with evaluation results and multi-turn specific analysis
510
+ """
511
+ print(f"\n🔍 Running multi-turn conversation evaluation...")
512
+
513
+ # Step 1: Run standard single-turn evaluation
514
+ base_results = self.single_turn_evaluator.evaluate(k_values)
515
+
516
+ # Step 2: Add multi-turn specific analysis
517
+ # Analyze by conversation type
518
+ results_by_type = defaultdict(
519
+ lambda: {"hit_rates": {k: [] for k in k_values}, "rr_scores": []}
520
+ )
521
+
522
+ for query_id, query_result in base_results["per_query_results"].items():
523
+ conv_meta = self._get_conversation_metadata(query_id)
524
+ conv_type = conv_meta.get("conversation_type", "unknown")
525
+
526
+ # Add to type-specific results
527
+ results_by_type[conv_type]["rr_scores"].append(query_result.get("rr", 0))
528
+ for k in k_values:
529
+ hit_rate = query_result.get("hit_rates", {}).get(k, 0)
530
+ results_by_type[conv_type]["hit_rates"][k].append(hit_rate)
531
+
532
+ # Calculate averages by conversation type
533
+ type_analysis = {}
534
+ for conv_type, type_results in results_by_type.items():
535
+ type_analysis[conv_type] = {
536
+ "hit_rates": {
537
+ k: sum(hits) / len(hits) if hits else 0
538
+ for k, hits in type_results["hit_rates"].items()
539
+ },
540
+ "mrr": (
541
+ sum(type_results["rr_scores"]) / len(type_results["rr_scores"])
542
+ if type_results["rr_scores"]
543
+ else 0
544
+ ),
545
+ "total_conversations": len(type_results["rr_scores"]),
546
+ }
547
+
548
+ # Analyze conversation length impact
549
+ turn_length_analysis = self._analyze_by_conversation_length(
550
+ base_results, k_values
551
+ )
552
+
553
+ # Combine results
554
+ multi_turn_results = {
555
+ **base_results, # Include all base results
556
+ "conversation_type_analysis": type_analysis,
557
+ "turn_length_analysis": turn_length_analysis,
558
+ "multi_turn_metadata": {
559
+ "evaluation_type": "multi_turn_conversation",
560
+ "conversation_types": list(type_analysis.keys()),
561
+ "total_conversation_types": len(type_analysis),
562
+ },
563
+ }
564
+
565
+ return multi_turn_results
566
+
567
+ def _analyze_by_conversation_length(
568
+ self, base_results: Dict, k_values: List[int]
569
+ ) -> Dict:
570
+ """Analyze performance by conversation turn length"""
571
+
572
+ length_analysis = defaultdict(
573
+ lambda: {"hit_rates": {k: [] for k in k_values}, "rr_scores": []}
574
+ )
575
+
576
+ for query_id, query_result in base_results["per_query_results"].items():
577
+ conv_meta = self._get_conversation_metadata(query_id)
578
+ original_conv = conv_meta.get("original_conversation", [])
579
+ turn_count = len(
580
+ [turn for turn in original_conv if turn.get("role") == "user"]
581
+ )
582
+
583
+ # Categorize by turn length
584
+ if turn_count == 1:
585
+ length_category = "1_turn" # Single-turn questions
586
+ elif turn_count == 2:
587
+ length_category = "2_turns"
588
+ elif turn_count == 3:
589
+ length_category = "3_turns"
590
+ elif turn_count >= 4:
591
+ length_category = "4+_turns"
592
+ else:
593
+ length_category = "unknown_turns"
594
+
595
+ # Add to length-specific results
596
+ length_analysis[length_category]["rr_scores"].append(
597
+ query_result.get("rr", 0)
598
+ )
599
+ for k in k_values:
600
+ hit_rate = query_result.get("hit_rates", {}).get(k, 0)
601
+ length_analysis[length_category]["hit_rates"][k].append(hit_rate)
602
+
603
+ # Calculate averages by turn length
604
+ final_length_analysis = {}
605
+ for length_cat, length_results in length_analysis.items():
606
+ final_length_analysis[length_cat] = {
607
+ "hit_rates": {
608
+ k: sum(hits) / len(hits) if hits else 0
609
+ for k, hits in length_results["hit_rates"].items()
610
+ },
611
+ "mrr": (
612
+ sum(length_results["rr_scores"]) / len(length_results["rr_scores"])
613
+ if length_results["rr_scores"]
614
+ else 0
615
+ ),
616
+ "total_conversations": len(length_results["rr_scores"]),
617
+ }
618
+
619
+ return final_length_analysis
620
+
621
+ def print_multi_turn_results(self, results: Dict):
622
+ """Print multi-turn evaluation results with additional analysis"""
623
+
624
+ # Print base results first
625
+ self.single_turn_evaluator.print_evaluation_results(results)
626
+
627
+ # Print multi-turn specific analysis
628
+ print(f"\n🔍 MULTI-TURN SPECIFIC ANALYSIS")
629
+ print("=" * 60)
630
+
631
+ # Conversation type analysis
632
+ type_analysis = results.get("conversation_type_analysis", {})
633
+ if type_analysis:
634
+ print(f"\n📊 Performance by Conversation Type:")
635
+ print(f"{'Type':<20} {'MRR':<8} {'Hit@5':<8} {'Count':<8}")
636
+ print("-" * 50)
637
+
638
+ for conv_type, analysis in type_analysis.items():
639
+ mrr = analysis["mrr"]
640
+ hit_at_5 = analysis["hit_rates"].get(5, 0) * 100
641
+ count = analysis["total_conversations"]
642
+ print(f"{conv_type:<20} {mrr:<8.3f} {hit_at_5:<8.1f}% {count:<8}")
643
+
644
+ # Turn length analysis
645
+ length_analysis = results.get("turn_length_analysis", {})
646
+ if length_analysis:
647
+ print(f"\n📊 Performance by Conversation Length:")
648
+ print(f"{'Length':<12} {'MRR':<8} {'Hit@5':<8} {'Count':<8}")
649
+ print("-" * 40)
650
+
651
+ for length_cat, analysis in length_analysis.items():
652
+ mrr = analysis["mrr"]
653
+ hit_at_5 = analysis["hit_rates"].get(5, 0) * 100
654
+ count = analysis["total_conversations"]
655
+ print(f"{length_cat:<12} {mrr:<8.3f} {hit_at_5:<8.1f}% {count:<8}")
656
+
657
+ print(f"\n💡 Multi-Turn Insights:")
658
+
659
+ # Best performing conversation type
660
+ if type_analysis:
661
+ best_type = max(type_analysis.keys(), key=lambda k: type_analysis[k]["mrr"])
662
+ worst_type = min(
663
+ type_analysis.keys(), key=lambda k: type_analysis[k]["mrr"]
664
+ )
665
+ print(
666
+ f" • Best conversation type: {best_type} (MRR: {type_analysis[best_type]['mrr']:.3f})"
667
+ )
668
+ print(
669
+ f" • Worst conversation type: {worst_type} (MRR: {type_analysis[worst_type]['mrr']:.3f})"
670
+ )
671
+
672
+ # Turn length insights
673
+ if length_analysis:
674
+ best_length = max(
675
+ length_analysis.keys(), key=lambda k: length_analysis[k]["mrr"]
676
+ )
677
+ print(
678
+ f" • Best performing length: {best_length} (MRR: {length_analysis[best_length]['mrr']:.3f})"
679
+ )
680
+
681
+
682
+ def main():
683
+ """Main function for multi-turn conversation evaluation"""
684
+ parser = argparse.ArgumentParser(
685
+ description="ViettelPay Multi-Turn Conversation Retrieval Evaluation"
686
+ )
687
+ parser.add_argument(
688
+ "--mode",
689
+ choices=["create", "enhance", "evaluate", "full"],
690
+ default="full",
691
+ help="Mode: create conversations, enhance to queries, evaluate, or full pipeline",
692
+ )
693
+ parser.add_argument(
694
+ "--conversations-dataset",
695
+ default="evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json",
696
+ help="Path to multi-turn conversations dataset",
697
+ )
698
+ parser.add_argument(
699
+ "--enhanced-dataset",
700
+ default="evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json",
701
+ help="Path to enhanced queries dataset",
702
+ )
703
+ parser.add_argument(
704
+ "--results-path",
705
+ default="evaluation_data/results/multi_turn_retrieval/viettelpay_multiturn_results.json",
706
+ help="Path to save evaluation results",
707
+ )
708
+ parser.add_argument(
709
+ "--conversations-per-chunk",
710
+ type=int,
711
+ default=3,
712
+ help="Number of conversations per chunk",
713
+ )
714
+ parser.add_argument(
715
+ "--k-values",
716
+ nargs="+",
717
+ type=int,
718
+ default=[1, 3, 5, 10],
719
+ help="K values for evaluation",
720
+ )
721
+ parser.add_argument(
722
+ "--knowledge-base-path",
723
+ default="./knowledge_base",
724
+ help="Path to knowledge base",
725
+ )
726
+
727
+ args = parser.parse_args()
728
+
729
+ # Configuration
730
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
731
+
732
+ if not GEMINI_API_KEY:
733
+ print("❌ Please set GEMINI_API_KEY environment variable")
734
+ return
735
+
736
+ try:
737
+ # Initialize knowledge base
738
+ print("🔧 Initializing ViettelPay knowledge base...")
739
+ kb = ViettelKnowledgeBase()
740
+ if not kb.load_knowledge_base(args.knowledge_base_path):
741
+ print(
742
+ "❌ Failed to load knowledge base. Please run build_database_script.py first."
743
+ )
744
+ return
745
+
746
+ # Step 1: Create multi-turn conversations if requested
747
+ if args.mode in ["create", "full"]:
748
+ print(f"\n🎯 Creating multi-turn conversation dataset...")
749
+ creator = MultiTurnDatasetCreator(GEMINI_API_KEY, kb)
750
+
751
+ conversations_dataset = creator.create_multi_turn_dataset(
752
+ conversations_per_chunk=args.conversations_per_chunk,
753
+ save_path=args.conversations_dataset,
754
+ )
755
+
756
+ # Step 2: Enhance conversations to queries if requested
757
+ if args.mode in ["enhance", "full"]:
758
+ print(f"\n⚡ Converting conversations to enhanced queries...")
759
+
760
+ # Load conversations if not created in this run
761
+ if args.mode == "enhance":
762
+ if not os.path.exists(args.conversations_dataset):
763
+ print(
764
+ f"❌ Conversations dataset not found: {args.conversations_dataset}"
765
+ )
766
+ return
767
+
768
+ with open(args.conversations_dataset, "r", encoding="utf-8") as f:
769
+ conversations_dataset = json.load(f)
770
+
771
+ # Enhance conversations
772
+ enhancer = ConversationEnhancer(GEMINI_API_KEY)
773
+ enhanced_dataset = enhancer.convert_dataset_to_single_turn_format(
774
+ conversations_dataset, args.enhanced_dataset
775
+ )
776
+
777
+ # Step 3: Evaluate if requested
778
+ if args.mode in ["evaluate", "full"]:
779
+ print(f"\n📊 Evaluating multi-turn conversation retrieval...")
780
+
781
+ # Load enhanced dataset if not created in this run
782
+ if args.mode == "evaluate":
783
+ if not os.path.exists(args.enhanced_dataset):
784
+ print(f"❌ Enhanced dataset not found: {args.enhanced_dataset}")
785
+ return
786
+
787
+ with open(args.enhanced_dataset, "r", encoding="utf-8") as f:
788
+ enhanced_dataset = json.load(f)
789
+
790
+ # Run evaluation
791
+ evaluator = MultiTurnEvaluator(enhanced_dataset, kb)
792
+ results = evaluator.evaluate_multi_turn_performance(k_values=args.k_values)
793
+ evaluator.print_multi_turn_results(results)
794
+
795
+ # Save results
796
+ if args.results_path:
797
+ with open(args.results_path, "w", encoding="utf-8") as f:
798
+ json.dump(results, f, ensure_ascii=False, indent=2)
799
+ print(f"\n💾 Results saved to: {args.results_path}")
800
+
801
+ print(f"\n✅ Multi-turn evaluation completed successfully!")
802
+ print(f"\n💡 Next steps:")
803
+ print(f" 1. Compare multi-turn vs single-turn performance")
804
+ print(f" 2. Analyze conversation types that work best")
805
+ print(f" 3. Optimize query enhancement for multi-turn scenarios")
806
+
807
+ except Exception as e:
808
+ print(f"❌ Error in main execution: {e}")
809
+ import traceback
810
+
811
+ traceback.print_exc()
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
src/evaluation/prompts.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompts for ViettelPay Synthetic Evaluation Dataset Creation
3
+ Simplified version for MRR and Hit Rate evaluation only
4
+ """
5
+
6
+ # Question Generation Prompt (JSON format for better parsing)
7
+ QUESTION_GENERATION_PROMPT = """Bạn là chuyên gia tạo câu hỏi đánh giá cho hệ thống ViettelPay Pro.
8
+ ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
9
+ Dựa trên đoạn văn bản sau từ tài liệu hướng dẫn, hãy tạo ra {num_questions} câu hỏi đa dạng:
10
+
11
+ <context>
12
+ {content}
13
+ </context>
14
+
15
+ Tạo các loại câu hỏi:
16
+ 1. Câu hỏi trực tiếp về thông tin trong đoạn văn
17
+ 2. Câu hỏi về cách thực hiện hoặc quy trình
18
+ 3. Câu hỏi về lỗi, vấn đề hoặc troubleshooting
19
+ 4. Câu hỏi về quy định, chính sách, phí
20
+
21
+ Yêu cầu cho mỗi câu hỏi:
22
+ - Tự nhiên như khách hàng ViettelPay thật sẽ hỏi
23
+ - Có thể trả lời được từ đoạn văn bản đã cho
24
+ - Ngắn gọn (5-20 từ)
25
+ - Sử dụng tiếng Việt thông dụng
26
+ - Đa dạng về loại câu hỏi và độ phức tạp
27
+
28
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
29
+ {{
30
+ "questions": [
31
+ "Câu hỏi đầu tiên?",
32
+ "Câu hỏi thứ hai?",
33
+ "Câu hỏi thứ ba?"
34
+ ]
35
+ }}
36
+
37
+ CHỈ trả về JSON, không có text khác."""
38
+
39
+ # Multi-Turn Conversation Generation Prompt (JSON format)
40
+ MULTI_TURN_CONVERSATION_GENERATION_PROMPT = """Bạn là một chuyên gia trong việc tạo dữ liệu huấn luyện cho chatbot hỗ trợ khách hàng của ứng dụng ViettelPay Pro.
41
+ ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
42
+ Nhiệm vụ của bạn là tạo ra **{num_conversations}** cuộc hội thoại đa lượt, chân thực và tự nhiên giữa người dùng (Đại lý/Điểm bán) và Trợ lý AI. Toàn bộ nội dung cuộc hội thoại phải dựa **hoàn toàn** vào thông tin được cung cấp trong tài liệu dưới đây.
43
+
44
+ <context>
45
+ {content}
46
+ </context>
47
+
48
+ Hãy tạo ra các cuộc hội thoại xoay quanh những kịch bản sau:
49
+
50
+ 1. **Giải quyết vấn đề (`error_resolution`):** Người dùng gặp lỗi, giao dịch thất bại, hoặc một tính năng không hoạt động như mong đợi. Họ muốn tìm hiểu nguyên nhân và cách khắc phục.
51
+ * *Ví dụ luồng hội thoại:* Báo lỗi -> Hỏi về nguyên nhân sâu xa -> Hỏi cách để tránh lỗi này trong tương lai.
52
+
53
+ 2. **Hướng dẫn thực hiện (`procedure_guide`):** Người dùng muốn biết cách thực hiện một tác vụ cụ thể. Cuộc hội thoại nên đi sâu vào các bước, điều kiện, hoặc các chi tiết liên quan.
54
+ * *Ví dụ luồng hội thoại:* Hỏi cách thực hiện một dịch vụ -> Hỏi về một bước cụ thể -> Hỏi về một trường hợp đặc biệt (ví dụ: "nếu làm cho nhà mạng khác thì sao?").
55
+
56
+ 3. **Tra cứu thông tin (`policy_info`):** Người dùng có câu hỏi về chính sách, phí, hạn mức, hoặc các quy định của dịch vụ.
57
+ * *Ví dụ luồng hội thoại:* Hỏi về một quy định chung -> Hỏi về một trường hợp áp dụng cụ thể -> Hỏi về các ngoại lệ.
58
+
59
+ **YÊU CẦU QUAN TRỌNG:**
60
+
61
+ * **Dòng chảy tự nhiên:** Mỗi lượt hỏi của người dùng phải là một phản ứng logic, tự nhiên sau khi nhận được câu trả lời (tưởng tượng) từ AI. Hãy hình dung AI đã đưa ra câu trả lời hữu ích nhưng chưa đầy đủ, khiến người dùng phải hỏi thêm để làm rõ.
62
+ * **Chân thực như người dùng thật:**
63
+ * Sử dụng ngôn ngữ đời thường, ngắn gọn, đi thẳng vào vấn đề.
64
+ * Có thể dùng các từ viết tắt phổ biến (vd: "sđt", "tk", "gd", "đk").
65
+ * Giọng điệu có thể thể hiện sự bối rối, cần hỗ trợ gấp hoặc tò mò.
66
+ * **Bám sát tài liệu:** **Không** được tự ý sáng tạo thông tin, chính sách, hoặc tính năng không có trong phần `<context>`.
67
+ * Tất cả các lượt câu hỏi, đặc biệt là câu hỏi cuối cùng, phải có thể trả lời được từ thông tin trong tài liệu `<context>`.
68
+ * **Cấu trúc:** Mỗi cuộc hội thoại phải có từ 2 đến 3 lượt hỏi từ phía người dùng.
69
+ * **Ngôn ngữ:** Tiếng Việt.
70
+
71
+ Ví dụ cuộc hội thoại:
72
+ ```
73
+ Lượt 1: "mã lỗi 606 là gì vậy?"
74
+ Lượt 2: "làm sao để khắc phục lỗi này?"
75
+ ```
76
+
77
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
78
+ {{
79
+ "conversations": [
80
+ {{
81
+ "type": "error_resolution",
82
+ "turns": [
83
+ {{"role": "user", "content": "mã lỗi 606 là gì vậy?"}},
84
+ {{"role": "user", "content": "làm sao để khắc phục lỗi này?"}}
85
+ ]
86
+ }},
87
+ {{
88
+ "type": "procedure_inquiry",
89
+ "turns": [
90
+ {{"role": "user", "content": "nạp cước điện thoại như thế nào?"}},
91
+ {{"role": "user", "content": "có thể nạp cho số Viettel không?"}},
92
+ {{"role": "user", "content": "có các mệnh giá nào?"}}
93
+ ]
94
+ }}
95
+ ]
96
+ }}
97
+
98
+ CHỈ trả về JSON, không có text khác."""
99
+
100
+
101
+ # Quality Check Prompt for Generated Questions
102
+ QUESTION_QUALITY_CHECK_PROMPT = """Đánh giá chất lượng của câu hỏi được tạo ra cho hệ thống ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
103
+
104
+ Câu hỏi: {question}
105
+ Đoạn văn gốc: {context}
106
+
107
+ Tiêu chí đánh giá:
108
+ 1. Clarity (Rõ ràng): Câu hỏi có dễ hiểu không?
109
+ 2. Answerability (Có thể trả lời): Có thể trả lời từ đoạn văn không?
110
+ 3. Naturalness (Tự nhiên): Có giống cách khách hàng thật hỏi không?
111
+ 4. Relevance (Liên quan): Có phù hợp với nội dung ViettelPay không?
112
+
113
+ Mỗi tiêu chí từ 1-5 điểm (5 là tốt nhất).
114
+
115
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
116
+ {{
117
+ "clarity": 5,
118
+ "answerability": 4,
119
+ "naturalness": 5,
120
+ "relevance": 5,
121
+ "overall_score": 4.75,
122
+ "keep_question": true,
123
+ "feedback": "Câu hỏi tốt, rõ ràng và tự nhiên"
124
+ }}
125
+
126
+ CHỈ trả về JSON, không có text khác."""
127
+
128
+ # Context Quality Check Prompt
129
+ CONTEXT_QUALITY_CHECK_PROMPT = """Đánh giá chất lượng của đoạn văn bản ViettelPay Pro để tạo câu hỏi. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
130
+
131
+ Đoạn văn bản:
132
+ {content}
133
+
134
+ Tiêu chí đánh giá:
135
+ 1. Clarity (Rõ ràng): Thông tin có dễ hiểu không?
136
+ 2. Completeness (Đầy đủ): Có đủ thông tin để tạo câu hỏi không?
137
+ 3. Structure (Cấu trúc): Có tổ chức tốt không?
138
+ 4. Relevance (Liên quan): Có phù hợp với ViettelPay không?
139
+ 5. Information_density (Mật độ thông tin): Có đủ thông tin hữu ích không?
140
+
141
+ Mỗi tiêu chí từ 1-5 điểm (5 là tốt nhất).
142
+
143
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
144
+ {{
145
+ "clarity": 5,
146
+ "completeness": 4,
147
+ "structure": 4,
148
+ "relevance": 5,
149
+ "information_density": 4,
150
+ "overall_score": 4.4,
151
+ "use_context": true,
152
+ "feedback": "Đoạn văn tốt, có thể tạo câu hỏi chất lượng"
153
+ }}
154
+
155
+ CHỈ trả về JSON, không có text khác."""
156
+
157
+ # Question Evolution/Variation Prompt
158
+ QUESTION_EVOLUTION_PROMPT = """Tạo các biến thể của câu hỏi ViettelPay Pro để tăng tính đa dạng. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
159
+
160
+ Câu hỏi gốc: {original_question}
161
+ Ngữ cảnh: {context}
162
+
163
+ Tạo 3 biến thể khác nhau:
164
+ 1. Phiên bản casual/thông tục (cách nói hàng ngày)
165
+ 2. Phiên bản formal/lịch sự (cách nói trang trọng)
166
+ 3. Phiên bản cụ thể (thêm chi tiết, tình huống cụ thể)
167
+
168
+ Yêu cầu:
169
+ - Giữ nguyên ý nghĩa cốt lõi
170
+ - Vẫn có thể trả lời từ cùng ngữ cảnh
171
+ - Tự nhiên với người dùng Việt Nam
172
+ - Đa dạng về cách diễn đạt
173
+
174
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
175
+ {{
176
+ "original_question": "{original_question}",
177
+ "variations": [
178
+ "Phiên bản casual",
179
+ "Phiên bản formal",
180
+ "Phiên bản cụ thể"
181
+ ]
182
+ }}
183
+
184
+ CHỈ trả về JSON, không có text khác."""
185
+
186
+ # Dataset Statistics Prompt
187
+ DATASET_STATS_PROMPT = """Phân tích thống kê dataset đánh giá ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
188
+
189
+ Dữ liệu:
190
+ - Tổng số câu hỏi: {total_questions}
191
+ - Tổng số documents: {total_documents}
192
+ - Câu hỏi theo loại: {question_types}
193
+
194
+ Tạo báo cáo thống kê và đề xuất cải thiện.
195
+
196
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
197
+ {{
198
+ "coverage_analysis": {{
199
+ "error_handling": "20%",
200
+ "procedures": "30%",
201
+ "policies": "25%",
202
+ "faq": "25%"
203
+ }},
204
+ "quality_metrics": {{
205
+ "avg_questions_per_doc": 2.1,
206
+ "question_diversity": "high"
207
+ }},
208
+ "recommendations": [
209
+ "Tăng câu hỏi về error handling",
210
+ "Cân bằng độ khó của câu hỏi"
211
+ ]
212
+ }}
213
+
214
+ CHỈ trả về JSON, không có text khác."""
215
+
216
+ # Error Analysis Prompt
217
+ ERROR_ANALYSIS_PROMPT = """Phân tích lỗi trong quá trình đánh giá retrieval ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
218
+
219
+ Kết quả đánh giá:
220
+ {evaluation_results}
221
+
222
+ Xác định:
223
+ 1. Câu hỏi có hiệu suất thấp (Hit Rate < 0.3)
224
+ 2. Loại lỗi thường gặp
225
+ 3. Nguyên nhân gốc rễ
226
+ 4. Đề xuất cải thiện
227
+
228
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
229
+ {{
230
+ "low_performance_queries": [
231
+ {{"query": "câu hỏi", "hit_rate": 0.2, "issue": "từ khóa không rõ ràng"}}
232
+ ],
233
+ "common_error_types": [
234
+ "Thiếu từ khóa chính",
235
+ "Ngữ cảnh không đủ",
236
+ "Chunking không tối ưu"
237
+ ],
238
+ "improvement_suggestions": [
239
+ "Cải thiện chunking strategy",
240
+ "Thêm synonyms cho từ khóa"
241
+ ]
242
+ }}
243
+
244
+ CHỈ trả về JSON, không có text khác."""
245
+
246
+ # Intent Classification Prompt
247
+ # Updated Intent Classification Conversation Generation Prompt with chunk mixing support
248
+ # Improved Intent Classification Conversation Generation Prompt
249
+ INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT = """Bạn là chuyên gia tạo dữ liệu đánh giá cho hệ thống phân loại ý định (intent classification) của Trợ lý AI trên ứng dụng ViettelPay Pro.
250
+ ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
251
+
252
+ Nhiệm vụ của bạn là tạo ra **{num_conversations}** cuộc hội thoại đa lượt thực tế. Mỗi tin nhắn của người dùng phải được gán một nhãn `intent` chính xác.
253
+ `intent` là ý định của người dùng trong câu hỏi hiện tại và liên quan đến các lượt hỏi trước đó trong cuộc hội thoại.
254
+
255
+ **1. Định nghĩa các loại ý định (Bắt buộc phải tuân theo):**
256
+
257
+ * **`greeting`**: Chỉ là lời chào hỏi thuần túy, không có câu hỏi hoặc yêu cầu cụ thể nào khác. Nếu tin nhắn có cả lời chào VÀ câu hỏi thì phân loại theo các ý định khác, không phải greeting.
258
+ * *Ví dụ:* "chào em", "hello shop", "xin chào ạ"
259
+ * *Không phải greeting:* "xin chào, cho hỏi về lỗi 606" → đây là error_help
260
+ * **`faq`**: Các câu hỏi đáp chung, tìm hiểu về dịch vụ, tính năng, v.v.
261
+ * *Ví dụ:* "App có bán thẻ game không?", "ViettelPay Pro nạp tiền được cho mạng nao?"
262
+ * **`error_help`**: Báo cáo sự cố, hỏi về mã lỗi cụ thể.
263
+ * *Ví dụ:* "Giao dịch báo lỗi 606", "tại sao tôi không thanh toán được?", "lỗi này là gì?"
264
+ * **`procedure_guide`**: Hỏi về các bước cụ thể để thực hiện một tác vụ.
265
+ * *Ví dụ:* "làm thế nào để hủy giao dịch?", "chỉ tôi cách lấy lại mã thẻ cào", "hướng dẫn nạp cước"
266
+ * **`human_request`**: Yêu cầu được nói chuyện trực tiếp với nhân viên hỗ trợ.
267
+ * *Ví dụ:* "cho tôi gặp người thật", "nối máy cho tổng đài", "em k hiểu, cho gặp ai đó"
268
+ * **`out_of_scope`**: Câu hỏi ngoài phạm vi ViettelPay (thời tiết, chính trị, v.v.), không liên quan gì đến các dịch vụ tài chính, viễn thông của Viettel.
269
+ * *Ví dụ:* "dự báo thời tiết hôm nay?", "giá xăng bao nhiêu?", "cách nấu phở"
270
+ * **`unclear`**: Câu hỏi không rõ ràng, thiếu thông tin cụ thể, cần người dùng bổ sung thêm chi tiết để có thể hỗ trợ hiệu quả.
271
+ * *Ví dụ:* "lỗi", "giúp với", "gd", "???", "ko hiểu", "bị lỗi giờ sao đây", "không thực hiện được", "sao vậy", "tại sao thế"
272
+
273
+ **2. Nguồn kiến thức (Context):**
274
+
275
+ Sử dụng tài liệu dưới đây để lấy các mã lỗi (vd: 606, W02), tên dịch vụ (vd: Gạch nợ cước), và các tình huống thực tế để xây dựng cuộc hội thoại.
276
+
277
+ <context>
278
+ {content}
279
+ </context>
280
+
281
+ **3. Yêu cầu về kịch bản hội thoại:**
282
+
283
+ * Mỗi cuộc hội thoại có t��� 2 đến 4 lượt hỏi từ người dùng. Tùy chỉnh sao cho phù hợp với tài liệu.
284
+ * {generation_instruction}
285
+ * Tạo các cuộc hội thoại đa dạng, không lặp lại.
286
+ * **QUAN TRỌNG - Intent Mixing**: Khoảng 70% tin nhắn nên liên quan đến context, nhưng 30% tin nhắn nên là các intent tự nhiên khác như:
287
+ - `greeting` ở đầu cuộc hội thoại
288
+ - `unclear` khi người dùng hỏi không rõ ràng, trợ lý cần thêm thông tin từ người dùng
289
+ - `human_request` khi họ muốn hỗ trợ trực tiếp
290
+ - `out_of_scope` khi họ hỏi không liên quan đến ViettelPay Pro
291
+ * Ngôn ngữ phải tự nhiên như người dùng thật, có thể dùng từ viết tắt, và đi thẳng vào vấn đề.
292
+ * Tạo các tình huống thực tế như người dùng thật sự sẽ hỏi. Không cần phải kết thúc bằng cảm ơn hay tạm biệt.
293
+ * Các câu hỏi nên dễ để phân loại ý định, không cần phải suy nghĩ quá lâu.
294
+
295
+ **4. Ví dụ cuộc hội thoại thực tế:**
296
+ {{
297
+ "conversations": [
298
+ {{
299
+ "turns": [
300
+ {{"user": "chào em", "intent": "greeting"}},
301
+ {{"user": "mã lỗi 606 là gì z?", "intent": "error_help"}},
302
+ {{"user": "em ko hiểu, gặp ai đó dc ko?", "intent": "human_request"}}
303
+ ]
304
+ }},
305
+ {{
306
+ "turns": [
307
+ {{"user": "nạp cước như nào?", "intent": "procedure_guide"}},
308
+ {{"user": "hôm nay trời đẹp nhỉ", "intent": "out_of_scope"}},
309
+ {{"user": "ờ quay lại, có phí ko?", "intent": "faq"}}
310
+ ]
311
+ }}
312
+ ]
313
+ }}
314
+
315
+ **5. Định dạng đầu ra (Output):**
316
+
317
+ QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như ví dụ trên.
318
+ CHỈ trả về JSON, không có text khác."""
src/evaluation/single_turn_retrieval.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Single Turn Synthetic Retrieval Evaluation Dataset Creator for ViettelPay RAG System
3
+ Uses Google Gemini 2.0 Flash with JSON responses for better parsing
4
+ Simplified version with only MRR and hit rate evaluation (no qrels generation)
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+ import argparse
11
+ import time
12
+ from typing import Dict, List, Tuple, Optional, Union
13
+ from pathlib import Path
14
+ from collections import defaultdict
15
+ import pandas as pd
16
+ from tqdm import tqdm
17
+ import re
18
+
19
+ # Load environment variables from .env file
20
+ from dotenv import load_dotenv
21
+
22
+ load_dotenv()
23
+
24
+ # Add the project root to Python path so we can import from src
25
+ project_root = Path(__file__).parent.parent.parent
26
+ sys.path.insert(0, str(project_root))
27
+
28
+ # Import prompts (only the ones we need)
29
+ from src.evaluation.prompts import (
30
+ QUESTION_GENERATION_PROMPT,
31
+ QUESTION_QUALITY_CHECK_PROMPT,
32
+ CONTEXT_QUALITY_CHECK_PROMPT,
33
+ QUESTION_EVOLUTION_PROMPT,
34
+ )
35
+
36
+ # Import your existing knowledge base and LLM client
37
+ from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
38
+ from src.llm.llm_client import LLMClientFactory, BaseLLMClient
39
+
40
+
41
+ class SingleTurnDatasetCreator:
42
+ """Single turn synthetic evaluation dataset creator with JSON responses and all chunks processing"""
43
+
44
+ def __init__(
45
+ self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
46
+ ):
47
+ """
48
+ Initialize with Gemini API key and optional knowledge base
49
+
50
+ Args:
51
+ gemini_api_key: Google AI API key for Gemini
52
+ knowledge_base: Pre-initialized ViettelKnowledgeBase instance
53
+ """
54
+ self.llm_client = LLMClientFactory.create_client(
55
+ "gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
56
+ )
57
+ self.knowledge_base = knowledge_base
58
+ self.dataset = {
59
+ "queries": {},
60
+ "documents": {},
61
+ "metadata": {
62
+ "total_chunks_processed": 0,
63
+ "questions_generated": 0,
64
+ "creation_timestamp": time.time(),
65
+ },
66
+ }
67
+
68
+ print("✅ SingleTurnDatasetCreator initialized with Gemini 2.0 Flash")
69
+
70
+ def generate_json_response(
71
+ self, prompt: str, max_retries: int = 3
72
+ ) -> Optional[Dict]:
73
+ """
74
+ Generate response and parse as JSON with retries
75
+
76
+ Args:
77
+ prompt: Input prompt
78
+ max_retries: Maximum number of retry attempts
79
+
80
+ Returns:
81
+ Parsed JSON response or None if failed
82
+ """
83
+ for attempt in range(max_retries):
84
+ try:
85
+ response = self.llm_client.generate(prompt, temperature=0.1)
86
+
87
+ if response:
88
+ # Clean response text
89
+ response_text = response.strip()
90
+
91
+ # Extract JSON from response (handle cases with extra text)
92
+ json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
93
+ if json_match:
94
+ json_text = json_match.group()
95
+ return json.loads(json_text)
96
+ else:
97
+ # Try parsing the whole response
98
+ return json.loads(response_text)
99
+
100
+ except json.JSONDecodeError as e:
101
+ print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
102
+ if attempt == max_retries - 1:
103
+ print(f"❌ Failed to parse JSON after {max_retries} attempts")
104
+ print(
105
+ f"Raw response: {response if 'response' in locals() else 'No response'}"
106
+ )
107
+
108
+ except Exception as e:
109
+ print(f"⚠️ API error (attempt {attempt + 1}): {e}")
110
+ if attempt < max_retries - 1:
111
+ time.sleep(2**attempt) # Exponential backoff
112
+
113
+ return None
114
+
115
+ def get_all_chunks(self) -> List[Dict]:
116
+ """
117
+ Get ALL chunks directly from ChromaDB vectorstore (no sampling)
118
+
119
+ Returns:
120
+ List of all document chunks with content and metadata
121
+ """
122
+ print(f"📚 Retrieving ALL chunks directly from ChromaDB vectorstore...")
123
+
124
+ if not self.knowledge_base:
125
+ raise ValueError(
126
+ "Knowledge base not provided. Please initialize with a ViettelKnowledgeBase instance."
127
+ )
128
+
129
+ try:
130
+ # Access the ChromaDB vectorstore directly
131
+ if (
132
+ not hasattr(self.knowledge_base, "chroma_retriever")
133
+ or not self.knowledge_base.chroma_retriever
134
+ ):
135
+ raise ValueError("ChromaDB retriever not found in knowledge base")
136
+
137
+ # Get the vectorstore from the retriever
138
+ vectorstore = self.knowledge_base.chroma_retriever.vectorstore
139
+
140
+ # Get all documents directly from ChromaDB
141
+ print(" Accessing ChromaDB collection...")
142
+ all_docs = vectorstore.get(include=["documents", "metadatas"])
143
+
144
+ documents = all_docs["documents"]
145
+ metadatas = all_docs["metadatas"]
146
+
147
+ print(f" Found {len(documents)} documents in ChromaDB")
148
+ print(f" Sample document preview:")
149
+ for i, doc in enumerate(documents[:3]):
150
+ print(f" Doc {i+1}: {doc[:100]}...")
151
+
152
+ # Convert to our expected format
153
+ all_chunks = []
154
+ seen_content_hashes = set()
155
+
156
+ for i, (content, metadata) in enumerate(zip(documents, metadatas)):
157
+ # Create content hash for deduplication (just in case)
158
+ content_hash = hash(content[:300])
159
+
160
+ if (
161
+ content_hash not in seen_content_hashes
162
+ and len(content.strip()) > 50
163
+ ):
164
+ chunk_info = {
165
+ "id": f"chunk_{len(all_chunks)}",
166
+ "content": content,
167
+ "metadata": metadata or {},
168
+ "source": "chromadb_direct",
169
+ "content_length": len(content),
170
+ "original_index": i,
171
+ }
172
+ all_chunks.append(chunk_info)
173
+ seen_content_hashes.add(content_hash)
174
+ else:
175
+ if content_hash in seen_content_hashes:
176
+ print(f" ⚠️ Skipping duplicate content at index {i}")
177
+ else:
178
+ print(
179
+ f" ⚠️ Skipping short content at index {i} (length: {len(content.strip())})"
180
+ )
181
+
182
+ print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
183
+ print(
184
+ f" Filtered out {len(documents) - len(all_chunks)} duplicates/short chunks"
185
+ )
186
+
187
+ # Sort by content length (longer chunks first, usually more informative)
188
+ all_chunks.sort(key=lambda x: x["content_length"], reverse=True)
189
+
190
+ # Display statistics
191
+ avg_length = sum(chunk["content_length"] for chunk in all_chunks) / len(
192
+ all_chunks
193
+ )
194
+ min_length = min(chunk["content_length"] for chunk in all_chunks)
195
+ max_length = max(chunk["content_length"] for chunk in all_chunks)
196
+
197
+ print(f" 📊 Chunk Statistics:")
198
+ print(f" • Average length: {avg_length:.0f} characters")
199
+ print(f" • Min length: {min_length} characters")
200
+ print(f" • Max length: {max_length} characters")
201
+
202
+ return all_chunks
203
+
204
+ except Exception as e:
205
+ print(f"❌ Error accessing ChromaDB directly: {e}")
206
+ print(f" Falling back to search-based method...")
207
+ return self._get_all_chunks_fallback()
208
+
209
+ def _get_all_chunks_fallback(self) -> List[Dict]:
210
+ """
211
+ Fallback method using search queries if direct ChromaDB access fails
212
+
213
+ Returns:
214
+ List of document chunks retrieved via search
215
+ """
216
+ print(f"🔄 Using fallback search-based chunk retrieval...")
217
+
218
+ # Use comprehensive search terms to capture most content
219
+ comprehensive_queries = [
220
+ "ViettelPay",
221
+ "nạp",
222
+ "cước",
223
+ "giao dịch",
224
+ "thanh toán",
225
+ "lỗi",
226
+ "hủy",
227
+ "thẻ",
228
+ "chuyển",
229
+ "tiền",
230
+ "quy định",
231
+ "phí",
232
+ "dịch vụ",
233
+ "tài khoản",
234
+ "ngân hàng",
235
+ "OTP",
236
+ "PIN",
237
+ "mã",
238
+ "số",
239
+ "điện thoại",
240
+ "internet",
241
+ "truyền hình",
242
+ "homephone",
243
+ "cố định",
244
+ "game",
245
+ "Viettel",
246
+ "Mobifone",
247
+ # Add some Vietnamese words that might not be captured above
248
+ "ứng dụng",
249
+ "khách hàng",
250
+ "hỗ trợ",
251
+ "kiểm tra",
252
+ "xác nhận",
253
+ "bảo mật",
254
+ ]
255
+
256
+ all_chunks = []
257
+ seen_content_hashes = set()
258
+
259
+ for query in comprehensive_queries:
260
+ try:
261
+ # Search with large k to get as many chunks as possible
262
+ docs = self.knowledge_base.search(query, top_k=50)
263
+
264
+ for doc in docs:
265
+ # Create content hash for deduplication
266
+ content_hash = hash(doc.page_content[:300])
267
+
268
+ if (
269
+ content_hash not in seen_content_hashes
270
+ and len(doc.page_content.strip()) > 50
271
+ ):
272
+ chunk_info = {
273
+ "id": f"chunk_{len(all_chunks)}",
274
+ "content": doc.page_content,
275
+ "metadata": doc.metadata,
276
+ "source": f"search_{query}",
277
+ "content_length": len(doc.page_content),
278
+ }
279
+ all_chunks.append(chunk_info)
280
+ seen_content_hashes.add(content_hash)
281
+
282
+ except Exception as e:
283
+ print(f"⚠️ Error searching for '{query}': {e}")
284
+ continue
285
+
286
+ print(f"✅ Fallback method retrieved {len(all_chunks)} unique chunks")
287
+
288
+ # Sort by content length
289
+ all_chunks.sort(key=lambda x: x["content_length"], reverse=True)
290
+
291
+ return all_chunks
292
+
293
+ def generate_questions_for_chunk(
294
+ self, chunk: Dict, num_questions: int = 2
295
+ ) -> List[Dict]:
296
+ """
297
+ Generate questions for a single chunk using Gemini with JSON response
298
+
299
+ Args:
300
+ chunk: Chunk dictionary with content and metadata
301
+ num_questions: Number of questions to generate per chunk
302
+
303
+ Returns:
304
+ List of question dictionaries with metadata
305
+ """
306
+ content = chunk["content"]
307
+
308
+ prompt = QUESTION_GENERATION_PROMPT.format(
309
+ num_questions=num_questions, content=content
310
+ )
311
+
312
+ response_json = self.generate_json_response(prompt)
313
+
314
+ if response_json and "questions" in response_json:
315
+ questions = response_json["questions"]
316
+
317
+ # Create question objects with metadata
318
+ question_objects = []
319
+ for i, question_text in enumerate(questions):
320
+ if len(question_text.strip()) > 5: # Filter very short questions
321
+ question_obj = {
322
+ "id": f"q_{chunk['id']}_{i}",
323
+ "text": question_text.strip(),
324
+ "source_chunk": chunk["id"],
325
+ "chunk_metadata": chunk["metadata"],
326
+ "generation_method": "gemini_json",
327
+ }
328
+ question_objects.append(question_obj)
329
+
330
+ return question_objects
331
+ else:
332
+ print(f"⚠️ No valid questions generated for chunk {chunk['id']}")
333
+ return []
334
+
335
+ def check_context_quality(self, chunk: Dict) -> bool:
336
+ """
337
+ Check if a chunk is suitable for question generation
338
+
339
+ Args:
340
+ chunk: Chunk dictionary
341
+
342
+ Returns:
343
+ True if chunk should be used, False otherwise
344
+ """
345
+ content = chunk["content"]
346
+
347
+ # Basic checks first
348
+ if len(content.strip()) < 100:
349
+ return False
350
+
351
+ # Use Gemini for quality assessment
352
+ prompt = CONTEXT_QUALITY_CHECK_PROMPT.format(content=content[:1000])
353
+
354
+ response_json = self.generate_json_response(prompt)
355
+
356
+ if response_json:
357
+ return response_json.get("use_context", True)
358
+ else:
359
+ # Fallback to basic heuristics
360
+ return len(content.strip()) > 100 and len(content.split()) > 20
361
+
362
+ def create_complete_dataset(
363
+ self,
364
+ questions_per_chunk: int = 2,
365
+ save_path: str = "evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval_dataset.json",
366
+ quality_check: bool = True,
367
+ ) -> Dict:
368
+ """
369
+ Create complete synthetic evaluation dataset using ALL chunks
370
+
371
+ Args:
372
+ questions_per_chunk: Number of questions to generate per chunk
373
+ save_path: Path to save the dataset JSON file
374
+ quality_check: Whether to perform quality checks on chunks
375
+
376
+ Returns:
377
+ Complete dataset dictionary
378
+ """
379
+ print(f"\n🚀 Creating simplified synthetic evaluation dataset...")
380
+ print(f" Target: Process ALL chunks from knowledge base")
381
+ print(f" Questions per chunk: {questions_per_chunk}")
382
+ print(f" Quality check: {quality_check}")
383
+ print(f" Evaluation method: MRR and Hit Rates only (no qrels)")
384
+
385
+ # Step 1: Get all chunks
386
+ all_chunks = self.get_all_chunks()
387
+ total_chunks = len(all_chunks)
388
+
389
+ if total_chunks == 0:
390
+ raise ValueError("No chunks found in knowledge base!")
391
+
392
+ print(f"✅ Found {total_chunks} chunks to process")
393
+
394
+ # Step 2: Quality filtering (optional)
395
+ if quality_check:
396
+ print(f"\n🔍 Performing quality checks on chunks...")
397
+ quality_chunks = []
398
+
399
+ for chunk in tqdm(all_chunks, desc="Quality checking"):
400
+ if self.check_context_quality(chunk):
401
+ quality_chunks.append(chunk)
402
+ time.sleep(0.1) # Rate limiting
403
+
404
+ print(
405
+ f"✅ {len(quality_chunks)}/{total_chunks} chunks passed quality check"
406
+ )
407
+ chunks_to_process = quality_chunks
408
+ else:
409
+ chunks_to_process = all_chunks
410
+
411
+ # Step 3: Generate questions for all chunks
412
+ print(f"\n📝 Generating questions for {len(chunks_to_process)} chunks...")
413
+ all_questions = []
414
+
415
+ for chunk in tqdm(chunks_to_process, desc="Generating questions"):
416
+ questions = self.generate_questions_for_chunk(chunk, questions_per_chunk)
417
+ all_questions.extend(questions)
418
+ time.sleep(0.2) # Rate limiting for Gemini API
419
+
420
+ print(
421
+ f"✅ Generated {len(all_questions)} questions from {len(chunks_to_process)} chunks"
422
+ )
423
+
424
+ # Step 4: Populate dataset structure
425
+ self.dataset["documents"] = {
426
+ chunk["id"]: chunk["content"] for chunk in chunks_to_process
427
+ }
428
+ self.dataset["queries"] = {q["id"]: q["text"] for q in all_questions}
429
+
430
+ # Add question metadata
431
+ question_metadata = {
432
+ q["id"]: {
433
+ "source_chunk": q["source_chunk"],
434
+ "chunk_metadata": q["chunk_metadata"],
435
+ "generation_method": q["generation_method"],
436
+ }
437
+ for q in all_questions
438
+ }
439
+
440
+ self.dataset["question_metadata"] = question_metadata
441
+
442
+ # Step 5: Update metadata
443
+ self.dataset["metadata"].update(
444
+ {
445
+ "total_chunks_processed": len(chunks_to_process),
446
+ "total_chunks_available": total_chunks,
447
+ "questions_generated": len(all_questions),
448
+ "questions_per_chunk": questions_per_chunk,
449
+ "quality_check_enabled": quality_check,
450
+ "evaluation_method": "mrr_hit_rates_only",
451
+ "completion_timestamp": time.time(),
452
+ }
453
+ )
454
+
455
+ # Step 6: Save dataset
456
+ os.makedirs(
457
+ os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
458
+ exist_ok=True,
459
+ )
460
+
461
+ with open(save_path, "w", encoding="utf-8") as f:
462
+ json.dump(self.dataset, f, ensure_ascii=False, indent=2)
463
+
464
+ print(f"\n✅ COMPLETE dataset created successfully!")
465
+ print(f" 📁 Saved to: {save_path}")
466
+ print(f" 📊 Statistics:")
467
+ print(f" • Chunks processed: {len(chunks_to_process)}/{total_chunks}")
468
+ print(f" • Questions generated: {len(all_questions)}")
469
+ print(f" • Evaluation method: MRR and Hit Rates only")
470
+ print(
471
+ f" • Coverage: {len(chunks_to_process)/total_chunks*100:.1f}% of knowledge base"
472
+ )
473
+
474
+ return self.dataset
475
+
476
+ def load_dataset(self, dataset_path: str) -> Dict:
477
+ """Load dataset from JSON file with metadata"""
478
+ with open(dataset_path, "r", encoding="utf-8") as f:
479
+ self.dataset = json.load(f)
480
+
481
+ metadata = self.dataset.get("metadata", {})
482
+
483
+ print(f"📖 Loaded dataset from {dataset_path}")
484
+ print(f" 📊 Dataset Statistics:")
485
+ print(f" • Queries: {len(self.dataset['queries'])}")
486
+ print(f" • Documents: {len(self.dataset['documents'])}")
487
+ print(f" • Created: {time.ctime(metadata.get('creation_timestamp', 0))}")
488
+
489
+ return self.dataset
490
+
491
+
492
+ class SingleTurnRetrievalEvaluator:
493
+ """Simplified retrieval evaluator with only MRR and hit rates"""
494
+
495
+ def __init__(self, dataset: Dict, knowledge_base: ViettelKnowledgeBase):
496
+ """
497
+ Initialize evaluator with dataset and knowledge base
498
+
499
+ Args:
500
+ dataset: Evaluation dataset with queries and documents
501
+ knowledge_base: ViettelKnowledgeBase instance to evaluate
502
+ """
503
+ self.dataset = dataset
504
+ self.knowledge_base = knowledge_base
505
+ self.results = {}
506
+
507
+ def _match_retrieved_documents(self, retrieved_docs) -> List[str]:
508
+ """
509
+ Enhanced document matching with multiple strategies
510
+
511
+ Args:
512
+ retrieved_docs: Retrieved Document objects from knowledge base
513
+
514
+ Returns:
515
+ List of matched document IDs
516
+ """
517
+ matched_ids = []
518
+
519
+ for doc in retrieved_docs:
520
+ # Strategy 1: Try to find exact content match
521
+ doc_id = self._find_exact_content_match(doc.page_content)
522
+
523
+ if not doc_id:
524
+ # Strategy 2: Try fuzzy content matching
525
+ doc_id = self._find_fuzzy_content_match(doc.page_content)
526
+
527
+ if doc_id:
528
+ matched_ids.append(doc_id)
529
+
530
+ return matched_ids
531
+
532
+ def _find_exact_content_match(self, retrieved_content: str) -> Optional[str]:
533
+ """Find exact content match"""
534
+ for doc_id, doc_content in self.dataset["documents"].items():
535
+ if retrieved_content.strip() == doc_content.strip():
536
+ return doc_id
537
+ return None
538
+
539
+ def _find_fuzzy_content_match(
540
+ self, retrieved_content: str, min_overlap: int = 50
541
+ ) -> Optional[str]:
542
+ """Find fuzzy content match with word overlap"""
543
+ best_match_id = None
544
+ best_overlap = 0
545
+
546
+ retrieved_words = set(retrieved_content.lower().split())
547
+
548
+ for doc_id, doc_content in self.dataset["documents"].items():
549
+ doc_words = set(doc_content.lower().split())
550
+ overlap = len(retrieved_words & doc_words)
551
+
552
+ if overlap > best_overlap and overlap >= min_overlap:
553
+ best_overlap = overlap
554
+ best_match_id = doc_id
555
+
556
+ return best_match_id
557
+
558
+ def _safe_average(self, values: List[float]) -> float:
559
+ """Calculate average safely handling empty lists"""
560
+ return sum(values) / len(values) if values else 0.0
561
+
562
+ def evaluate(self, k_values: List[int] = [1, 3, 5, 10]) -> Dict:
563
+ """
564
+ Simplified evaluation with only MRR and hit rates
565
+
566
+ This method checks if the source document (where the question was generated from)
567
+ is retrieved among the top-k results.
568
+
569
+ Args:
570
+ k_values: List of k values to evaluate
571
+
572
+ Returns:
573
+ Dictionary with MRR and hit rate results
574
+ """
575
+ print(f"\n🔍 Running simplified evaluation (MRR and Hit Rates only)...")
576
+ print(f" 📊 K values: {k_values}")
577
+ print(f" 📚 Total queries: {len(self.dataset['queries'])}")
578
+
579
+ # Initialize results
580
+ hit_rates = {k: [] for k in k_values}
581
+ rr_scores = [] # Reciprocal Rank scores for MRR calculation
582
+ query_results = {}
583
+ failed_queries = []
584
+
585
+ # Process each query
586
+ for query_id, query_text in tqdm(
587
+ self.dataset["queries"].items(), desc="Evaluating queries"
588
+ ):
589
+ try:
590
+ # Get source document from metadata - handle both single-turn and multi-turn formats
591
+ source_chunk_id = None
592
+
593
+ # Try question_metadata first (single-turn format)
594
+ question_meta = self.dataset.get("question_metadata", {}).get(
595
+ query_id, {}
596
+ )
597
+ if question_meta:
598
+ source_chunk_id = question_meta.get("source_chunk")
599
+
600
+ # If not found, try conversation_metadata (multi-turn format)
601
+ if not source_chunk_id:
602
+ conversation_meta = self.dataset.get(
603
+ "conversation_metadata", {}
604
+ ).get(query_id, {})
605
+ if conversation_meta:
606
+ source_chunk_id = conversation_meta.get("source_chunk")
607
+
608
+ if not source_chunk_id:
609
+ print(f"⚠️ No source chunk info for query {query_id}")
610
+ continue
611
+
612
+ # Get retrieval results
613
+ retrieved_docs = self.knowledge_base.search(
614
+ query_text, top_k=max(k_values)
615
+ )
616
+ retrieved_doc_ids = self._match_retrieved_documents(retrieved_docs)
617
+
618
+ # Check if source document is in top-k for each k
619
+ query_results[query_id] = {
620
+ "query": query_text,
621
+ "source_chunk": source_chunk_id,
622
+ "retrieved": retrieved_doc_ids,
623
+ "hit_rates": {},
624
+ }
625
+
626
+ # Calculate Reciprocal Rank (MRR) - once per query
627
+ if source_chunk_id in retrieved_doc_ids:
628
+ source_rank = (
629
+ retrieved_doc_ids.index(source_chunk_id) + 1
630
+ ) # 1-indexed rank
631
+ rr_score = 1.0 / source_rank
632
+ else:
633
+ rr_score = 0.0
634
+
635
+ query_results[query_id]["rr"] = rr_score
636
+ query_results[query_id]["source_rank"] = (
637
+ source_rank if rr_score > 0 else None
638
+ )
639
+ rr_scores.append(rr_score)
640
+
641
+ for k in k_values:
642
+ top_k_docs = retrieved_doc_ids[:k]
643
+ hit = 1 if source_chunk_id in top_k_docs else 0
644
+ hit_rates[k].append(hit)
645
+ query_results[query_id]["hit_rates"][k] = hit
646
+
647
+ except Exception as e:
648
+ print(f"❌ Error evaluating query {query_id}: {e}")
649
+ failed_queries.append((query_id, str(e)))
650
+ continue
651
+
652
+ # Calculate average metrics
653
+ avg_hit_rates = {}
654
+ avg_rr = sum(rr_scores) / len(rr_scores) if rr_scores else 0.0
655
+
656
+ for k in k_values:
657
+ avg_hit_rates[k] = self._safe_average(hit_rates[k])
658
+
659
+ results = {
660
+ "hit_rates": avg_hit_rates,
661
+ "mrr": avg_rr,
662
+ "per_query_results": query_results,
663
+ "failed_queries": failed_queries,
664
+ "summary": {
665
+ "total_queries": len(self.dataset["queries"]),
666
+ "evaluated_queries": len(query_results),
667
+ "failed_queries": len(failed_queries),
668
+ "success_rate": len(query_results) / len(self.dataset["queries"]) * 100,
669
+ "k_values": k_values,
670
+ "evaluation_type": "mrr_hit_rates_only",
671
+ "evaluation_timestamp": time.time(),
672
+ },
673
+ }
674
+
675
+ return results
676
+
677
+ def print_evaluation_results(self, results: Dict):
678
+ """Print simplified evaluation results"""
679
+ print(f"\n�� SIMPLIFIED EVALUATION RESULTS (MRR + Hit Rates)")
680
+ print("=" * 60)
681
+
682
+ print(f"\n📈 Hit Rates (Source Document Found in Top-K):")
683
+ print(f"{'K':<5} {'Hit Rate':<12} {'Percentage':<12}")
684
+ print("-" * 30)
685
+
686
+ for k in sorted(results["hit_rates"].keys()):
687
+ hit_rate = results["hit_rates"][k]
688
+ percentage = hit_rate * 100
689
+ print(f"{k:<5} {hit_rate:<12.4f} {percentage:<12.1f}%")
690
+
691
+ # Display MRR separately since it's not k-dependent
692
+ mrr = results["mrr"]
693
+ print(f"\n📊 Mean Reciprocal Rank (MRR): {mrr:.4f}")
694
+ print(f" • MRR measures the average reciprocal rank of the source document")
695
+ print(f" • Higher is better (max = 1.0 if all sources are rank 1)")
696
+
697
+ print(f"\n📊 Hit Rate Summary:")
698
+ for k in sorted(results["hit_rates"].keys()):
699
+ hit_rate = results["hit_rates"][k]
700
+ percentage = hit_rate * 100
701
+ print(
702
+ f" • Top-{k}: {percentage:.1f}% of questions find their source document"
703
+ )
704
+
705
+ # Summary stats
706
+ summary = results["summary"]
707
+ print(f"\n📋 Evaluation Summary:")
708
+ print(f" • Total queries: {summary['total_queries']}")
709
+ print(f" • Successfully evaluated: {summary['evaluated_queries']}")
710
+ print(f" • Failed queries: {summary['failed_queries']}")
711
+ print(f" • Success rate: {summary['success_rate']:.1f}%")
712
+ print(f" • Evaluation type: {summary['evaluation_type']}")
713
+
714
+ # Simple interpretation
715
+ avg_hit_rate_5 = results["hit_rates"].get(5, 0)
716
+ mrr = results["mrr"]
717
+ print(f"\n🎯 Quick Interpretation:")
718
+ if avg_hit_rate_5 > 0.8:
719
+ print(
720
+ f" ✅ Excellent: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}"
721
+ )
722
+ elif avg_hit_rate_5 > 0.6:
723
+ print(f" 👍 Good: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}")
724
+ elif avg_hit_rate_5 > 0.4:
725
+ print(f" ⚠️ Fair: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}")
726
+ else:
727
+ print(f" ❌ Poor: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}")
728
+
729
+
730
+ def main():
731
+ """Main function with argument parsing for separate operations"""
732
+ parser = argparse.ArgumentParser(
733
+ description="ViettelPay Retrieval Evaluation Dataset Creator (Simplified)"
734
+ )
735
+ parser.add_argument(
736
+ "--mode",
737
+ choices=["create", "evaluate", "both"],
738
+ default="both",
739
+ help="Mode: create dataset, evaluate only, or both",
740
+ )
741
+ parser.add_argument(
742
+ "--dataset-path",
743
+ default="evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json",
744
+ help="Path to dataset file",
745
+ )
746
+ parser.add_argument(
747
+ "--results-path",
748
+ default="evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json",
749
+ help="Path to save evaluation results",
750
+ )
751
+ parser.add_argument(
752
+ "--questions-per-chunk",
753
+ type=int,
754
+ default=3,
755
+ help="Number of questions per chunk",
756
+ )
757
+ parser.add_argument(
758
+ "--k-values",
759
+ nargs="+",
760
+ type=int,
761
+ default=[1, 3, 5, 10],
762
+ help="K values for evaluation",
763
+ )
764
+ parser.add_argument(
765
+ "--quality-check",
766
+ action="store_true",
767
+ help="Enable quality checking for chunks",
768
+ )
769
+ parser.add_argument(
770
+ "--knowledge-base-path",
771
+ default="./knowledge_base",
772
+ help="Path to knowledge base",
773
+ )
774
+
775
+ args = parser.parse_args()
776
+
777
+ # Configuration
778
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
779
+
780
+ if not GEMINI_API_KEY:
781
+ print("❌ Please set GEMINI_API_KEY environment variable")
782
+ return
783
+
784
+ try:
785
+ # Initialize knowledge base
786
+ print("🔧 Initializing ViettelPay knowledge base...")
787
+ kb = ViettelKnowledgeBase()
788
+ if not kb.load_knowledge_base(args.knowledge_base_path):
789
+ print(
790
+ "❌ Failed to load knowledge base. Please run build_database_script.py first."
791
+ )
792
+ return
793
+
794
+ # Create dataset if requested
795
+ if args.mode in ["create", "both"]:
796
+ print(f"\n🎯 Creating synthetic evaluation dataset...")
797
+ creator = SingleTurnDatasetCreator(GEMINI_API_KEY, kb)
798
+
799
+ dataset = creator.create_complete_dataset(
800
+ questions_per_chunk=args.questions_per_chunk,
801
+ save_path=args.dataset_path,
802
+ quality_check=args.quality_check,
803
+ )
804
+
805
+ # Evaluate if requested
806
+ if args.mode in ["evaluate", "both"]:
807
+ print(f"\n⚡ Evaluating retrieval performance...")
808
+
809
+ # Load dataset if not created in this run
810
+ if args.mode == "evaluate":
811
+ if not os.path.exists(args.dataset_path):
812
+ print(f"❌ Dataset file not found: {args.dataset_path}")
813
+ return
814
+
815
+ creator = SingleTurnDatasetCreator(GEMINI_API_KEY, kb)
816
+ dataset = creator.load_dataset(args.dataset_path)
817
+
818
+ # Run evaluation
819
+ evaluator = SingleTurnRetrievalEvaluator(dataset, kb)
820
+ results = evaluator.evaluate(k_values=args.k_values)
821
+ evaluator.print_evaluation_results(results)
822
+
823
+ # Save results
824
+ if args.results_path:
825
+ with open(args.results_path, "w", encoding="utf-8") as f:
826
+ json.dump(results, f, ensure_ascii=False, indent=2)
827
+ print(f"\n💾 Results saved to: {args.results_path}")
828
+
829
+ print(f"\n✅ Operation completed successfully!")
830
+ print(f"\n💡 Next steps:")
831
+ print(f" 1. Review the MRR and hit rate results")
832
+ print(f" 2. Identify queries with low performance")
833
+ print(f" 3. Optimize your retrieval system")
834
+ print(f" 4. Re-run evaluation to measure progress")
835
+
836
+ except Exception as e:
837
+ print(f"❌ Error in main execution: {e}")
838
+ import traceback
839
+
840
+ traceback.print_exc()
841
+
842
+
843
+ if __name__ == "__main__":
844
+ main()
src/knowledge_base/__pycache__/builder.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
src/knowledge_base/__pycache__/builder.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
src/knowledge_base/__pycache__/viettel_knowledge_base.cpython-311.pyc ADDED
Binary file (24 kB). View file
 
src/knowledge_base/viettel_knowledge_base.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ViettelPay Knowledge Base with Contextual Retrieval
3
+
4
+ This updated version:
5
+ - Uses ContextualWordProcessor for all document processing
6
+ - Integrates OpenAI for contextual enhancement
7
+ - Processes all doc/docx files from a parent folder
8
+ - Removes CSV processor dependency
9
+ """
10
+
11
+ import os
12
+ import pickle
13
+
14
+ # import torch
15
+ from typing import List, Optional
16
+ from pathlib import Path
17
+ from openai import OpenAI
18
+
19
+ from langchain.schema import Document
20
+ from langchain.retrievers import EnsembleRetriever
21
+ from langchain_community.retrievers import BM25Retriever
22
+ from langchain_core.runnables import ConfigurableField
23
+ from langchain_cohere.rerank import CohereRerank
24
+
25
+ # Use newest import paths for langchain
26
+ try:
27
+ from langchain_chroma import Chroma
28
+ except ImportError:
29
+ from langchain_community.vectorstores import Chroma
30
+
31
+ # Use the new HuggingFaceEmbeddings from langchain-huggingface
32
+ try:
33
+ from langchain_huggingface import HuggingFaceEmbeddings
34
+ except ImportError:
35
+ from langchain_community.embeddings import HuggingFaceEmbeddings
36
+
37
+ from src.processor.contextual_word_processor import ContextualWordProcessor
38
+ from src.processor.text_utils import VietnameseTextProcessor
39
+
40
+ # Import configuration utility
41
+ from src.utils.config import get_cohere_api_key, get_openai_api_key, get_embedding_model
42
+
43
+
44
+ class ViettelKnowledgeBase:
45
+ """ViettelPay knowledge base with contextual retrieval enhancement"""
46
+
47
+ def __init__(self, embedding_model: str = None):
48
+ """
49
+ Initialize the knowledge base
50
+
51
+ Args:
52
+ embedding_model: Vietnamese embedding model to use
53
+ """
54
+ embedding_model = embedding_model or get_embedding_model()
55
+
56
+ # Initialize Vietnamese text processor
57
+ self.text_processor = VietnameseTextProcessor()
58
+
59
+ # self.device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ self.device = "cpu"
61
+ print(f"[INFO] Using device: {self.device}")
62
+
63
+ # Initialize embeddings with GPU support and trust_remote_code
64
+ model_kwargs = {"device": self.device, "trust_remote_code": True}
65
+
66
+ self.embeddings = HuggingFaceEmbeddings(
67
+ model_name=embedding_model, model_kwargs=model_kwargs
68
+ )
69
+
70
+ # Initialize retrievers as None
71
+ self.chroma_retriever = None
72
+ self.bm25_retriever = None
73
+ self.ensemble_retriever = None
74
+
75
+ self.reranker = CohereRerank(
76
+ model="rerank-v3.5",
77
+ cohere_api_key=get_cohere_api_key(),
78
+ )
79
+
80
+ def build_knowledge_base(
81
+ self,
82
+ documents_folder: str,
83
+ persist_dir: str = "./knowledge_base",
84
+ reset: bool = True,
85
+ openai_api_key: Optional[str] = None,
86
+ ) -> None:
87
+ """
88
+ Build knowledge base from all Word documents in a folder
89
+
90
+ Args:
91
+ documents_folder: Path to folder containing doc/docx files
92
+ persist_dir: Directory to persist the knowledge base
93
+ reset: Whether to reset existing knowledge base
94
+ openai_api_key: OpenAI API key for contextual enhancement (optional)
95
+
96
+ Returns:
97
+ None. Use the search() method to perform searches.
98
+ """
99
+
100
+ print(
101
+ "[INFO] Building ViettelPay knowledge base with contextual enhancement..."
102
+ )
103
+
104
+ # Initialize OpenAI client for contextual enhancement if API key provided
105
+ openai_client = None
106
+ if openai_api_key:
107
+ openai_client = OpenAI(api_key=openai_api_key)
108
+ print(f"[INFO] OpenAI client initialized for contextual enhancement")
109
+ elif get_openai_api_key():
110
+ api_key = get_openai_api_key()
111
+ openai_client = OpenAI(api_key=api_key)
112
+ print(f"[INFO] OpenAI client initialized from configuration")
113
+ else:
114
+ print(
115
+ f"[WARNING] No OpenAI API key provided. Contextual enhancement disabled."
116
+ )
117
+
118
+ # Initialize the contextual word processor with OpenAI client
119
+ word_processor = ContextualWordProcessor(llm_client=openai_client)
120
+
121
+ # Find all Word documents in the folder
122
+ word_files = self._find_word_documents(documents_folder)
123
+
124
+ if not word_files:
125
+ raise ValueError(f"No Word documents found in {documents_folder}")
126
+
127
+ print(f"[INFO] Found {len(word_files)} Word documents to process")
128
+
129
+ # Process all documents
130
+ all_documents = self._process_all_word_files(word_files, word_processor)
131
+ print(f"[INFO] Total documents processed: {len(all_documents)}")
132
+
133
+ # Create directories
134
+ os.makedirs(persist_dir, exist_ok=True)
135
+ chroma_dir = os.path.join(persist_dir, "chroma")
136
+ bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")
137
+
138
+ # Build ChromaDB retriever (uses contextualized content)
139
+ print("[INFO] Building ChromaDB retriever with contextualized content...")
140
+ self.chroma_retriever = self._build_chroma_retriever(
141
+ all_documents, chroma_dir, reset
142
+ )
143
+
144
+ # Build BM25 retriever (uses contextualized content with Vietnamese tokenization)
145
+ print("[INFO] Building BM25 retriever with Vietnamese tokenization...")
146
+ self.bm25_retriever = self._build_bm25_retriever(
147
+ all_documents, bm25_path, reset
148
+ )
149
+
150
+ # Create ensemble retriever with configurable top-k
151
+ print("[INFO] Creating ensemble retriever...")
152
+ self.ensemble_retriever = self._build_retriever(
153
+ self.bm25_retriever, self.chroma_retriever
154
+ )
155
+
156
+ print("[SUCCESS] Contextual knowledge base built successfully!")
157
+ print("[INFO] Use kb.search(query, top_k) to perform searches.")
158
+
159
+ def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool:
160
+ """
161
+ Load existing knowledge base from disk and rebuild BM25 from ChromaDB documents
162
+
163
+ Args:
164
+ persist_dir: Directory where the knowledge base is stored
165
+
166
+ Returns:
167
+ bool: True if loaded successfully, False otherwise
168
+ """
169
+
170
+ print("[INFO] Loading knowledge base from disk...")
171
+
172
+ chroma_dir = os.path.join(persist_dir, "chroma")
173
+
174
+ try:
175
+ # Load ChromaDB
176
+ if os.path.exists(chroma_dir):
177
+ vectorstore = Chroma(
178
+ persist_directory=chroma_dir, embedding_function=self.embeddings
179
+ )
180
+
181
+ self.chroma_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
182
+ print("[SUCCESS] ChromaDB loaded")
183
+ else:
184
+ print("[ERROR] ChromaDB not found")
185
+ return False
186
+
187
+ # Extract all documents from ChromaDB to rebuild BM25
188
+ print("[INFO] Extracting documents from ChromaDB to rebuild BM25...")
189
+ try:
190
+ # Get all documents and metadata from ChromaDB
191
+ all_docs = vectorstore.get(include=["documents", "metadatas"])
192
+
193
+ documents = all_docs["documents"]
194
+ metadatas = all_docs["metadatas"]
195
+
196
+ # Reconstruct Document objects
197
+ doc_objects = []
198
+ for i, (doc_content, metadata) in enumerate(zip(documents, metadatas)):
199
+ # Handle case where metadata might be None
200
+ if metadata is None:
201
+ metadata = {}
202
+
203
+ doc_obj = Document(page_content=doc_content, metadata=metadata)
204
+ doc_objects.append(doc_obj)
205
+
206
+ print(f"[INFO] Extracted {len(doc_objects)} documents from ChromaDB")
207
+
208
+ # Rebuild BM25 retriever using existing method
209
+ self.bm25_retriever = self._build_bm25_retriever(
210
+ documents=doc_objects,
211
+ bm25_path=None, # Not used anymore
212
+ reset=False, # Not relevant for rebuilding
213
+ )
214
+
215
+ except Exception as e:
216
+ print(f"[ERROR] Error rebuilding BM25 from ChromaDB: {e}")
217
+ return False
218
+
219
+ # Create ensemble retriever with configurable top-k
220
+ self.ensemble_retriever = self._build_retriever(
221
+ self.bm25_retriever, self.chroma_retriever
222
+ )
223
+
224
+ print("[SUCCESS] Knowledge base loaded successfully!")
225
+ print("[INFO] Use kb.search(query, top_k) to perform searches.")
226
+ return True
227
+
228
+ except Exception as e:
229
+ print(f"[ERROR] Error loading knowledge base: {e}")
230
+ return False
231
+
232
+ def search(self, query: str, top_k: int = 10) -> List[Document]:
233
+ """
234
+ Main search method using ensemble retriever with configurable top-k
235
+
236
+ Args:
237
+ query: Search query
238
+ top_k: Number of documents to return from each retriever (default: 5)
239
+
240
+ Returns:
241
+ List of retrieved documents
242
+ """
243
+ if not self.ensemble_retriever:
244
+ raise ValueError(
245
+ "Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
246
+ )
247
+
248
+ # Build config based on top_k parameter
249
+ config = {
250
+ "configurable": {
251
+ "bm25_k": top_k * 5,
252
+ "chroma_search_kwargs": {"k": top_k * 5},
253
+ }
254
+ }
255
+
256
+ results = self.ensemble_retriever.invoke(query, config=config)
257
+ reranked_results = self.reranker.rerank(results, query, top_n=top_k)
258
+
259
+ final_results = []
260
+ for rerank_item in reranked_results:
261
+ # Get the original document using the index
262
+ original_doc = results[rerank_item["index"]]
263
+
264
+ # Create a new document with the relevance score added to metadata
265
+ reranked_doc = Document(
266
+ page_content=original_doc.page_content,
267
+ metadata={
268
+ **original_doc.metadata,
269
+ "relevance_score": rerank_item["relevance_score"],
270
+ },
271
+ )
272
+ final_results.append(reranked_doc)
273
+
274
+ return final_results
275
+
276
+ def get_stats(self) -> dict:
277
+ """Get statistics about the knowledge base"""
278
+ stats = {}
279
+
280
+ if self.chroma_retriever:
281
+ try:
282
+ vectorstore = self.chroma_retriever.vectorstore
283
+ collection = vectorstore._collection
284
+ stats["chroma_documents"] = collection.count()
285
+ except:
286
+ stats["chroma_documents"] = "Unknown"
287
+
288
+ if self.bm25_retriever:
289
+ try:
290
+ stats["bm25_documents"] = len(self.bm25_retriever.docs)
291
+ except:
292
+ stats["bm25_documents"] = "Unknown"
293
+
294
+ stats["ensemble_available"] = self.ensemble_retriever is not None
295
+ stats["device"] = self.device
296
+ stats["vietnamese_tokenizer"] = "Vietnamese BM25 tokenizer (underthesea)"
297
+
298
+ return stats
299
+
300
+ def _find_word_documents(self, folder_path: str) -> List[str]:
301
+ """
302
+ Find all Word documents (.doc, .docx) in the given folder
303
+
304
+ Args:
305
+ folder_path: Path to the folder to search
306
+
307
+ Returns:
308
+ List of full paths to Word documents
309
+ """
310
+ word_files = []
311
+ folder = Path(folder_path)
312
+
313
+ if not folder.exists():
314
+ raise FileNotFoundError(f"Folder not found: {folder_path}")
315
+
316
+ # Search for Word documents
317
+ for pattern in ["*.doc", "*.docx"]:
318
+ word_files.extend(folder.glob(pattern))
319
+
320
+ # Convert to string paths and sort for consistent processing order
321
+ word_files = [str(f) for f in word_files]
322
+ word_files.sort()
323
+
324
+ print(f"[INFO] Found Word documents: {[Path(f).name for f in word_files]}")
325
+ return word_files
326
+
327
+ def _process_all_word_files(
328
+ self, word_files: List[str], word_processor: ContextualWordProcessor
329
+ ) -> List[Document]:
330
+ """Process all Word files into unified chunks with contextual enhancement"""
331
+ all_documents = []
332
+
333
+ for file_path in word_files:
334
+ try:
335
+ print(f"[INFO] Processing: {Path(file_path).name}")
336
+ chunks = word_processor.process_word_document(file_path)
337
+ all_documents.extend(chunks)
338
+
339
+ # Print processing stats for this file
340
+ stats = word_processor.get_document_stats(chunks)
341
+ print(
342
+ f"[SUCCESS] Processed {Path(file_path).name}: {len(chunks)} chunks"
343
+ )
344
+ print(f" - Contextualized: {stats.get('contextualized_docs', 0)}")
345
+ print(
346
+ f" - Non-contextualized: {stats.get('non_contextualized_docs', 0)}"
347
+ )
348
+
349
+ except Exception as e:
350
+ print(f"[ERROR] Error processing {Path(file_path).name}: {e}")
351
+
352
+ return all_documents
353
+
354
+ def _build_retriever(self, bm25_retriever, chroma_retriever):
355
+ """
356
+ Build ensemble retriever with configurable top-k parameters
357
+
358
+ Args:
359
+ bm25_retriever: BM25 retriever with configurable fields
360
+ chroma_retriever: Chroma retriever with configurable fields
361
+
362
+ Returns:
363
+ EnsembleRetriever with configurable retrievers
364
+ """
365
+ return EnsembleRetriever(
366
+ retrievers=[bm25_retriever, chroma_retriever],
367
+ weights=[0.2, 0.8], # Slightly favor semantic search
368
+ )
369
+
370
+ def _build_chroma_retriever(
371
+ self, documents: List[Document], chroma_dir: str, reset: bool
372
+ ):
373
+ """Build ChromaDB retriever with configurable search parameters"""
374
+
375
+ if reset and os.path.exists(chroma_dir):
376
+ import shutil
377
+
378
+ shutil.rmtree(chroma_dir)
379
+ print("[INFO] Removed existing ChromaDB for rebuild")
380
+
381
+ # Create Chroma vectorstore (uses contextualized content)
382
+ vectorstore = Chroma.from_documents(
383
+ documents=documents, embedding=self.embeddings, persist_directory=chroma_dir
384
+ )
385
+
386
+ # Create retriever with configurable search_kwargs
387
+ retriever = vectorstore.as_retriever(
388
+ search_kwargs={"k": 5} # default value
389
+ ).configurable_fields(
390
+ search_kwargs=ConfigurableField(
391
+ id="chroma_search_kwargs",
392
+ name="Chroma Search Kwargs",
393
+ description="Search kwargs for Chroma DB retriever",
394
+ )
395
+ )
396
+
397
+ print(
398
+ f"[SUCCESS] ChromaDB created with {len(documents)} contextualized documents"
399
+ )
400
+ return retriever
401
+
402
+ def _build_bm25_retriever(
403
+ self, documents: List[Document], bm25_path: str, reset: bool
404
+ ):
405
+ """Build BM25 retriever with Vietnamese tokenization and configurable k parameter"""
406
+
407
+ # Note: We no longer save BM25 to pickle file to avoid Streamlit Cloud compatibility issues
408
+ # BM25 will be rebuilt from ChromaDB documents when loading the knowledge base
409
+
410
+ # Create BM25 retriever with Vietnamese tokenizer as preprocess_func
411
+ print("[INFO] Using Vietnamese tokenizer for BM25 on contextualized content...")
412
+ retriever = BM25Retriever.from_documents(
413
+ documents=documents,
414
+ preprocess_func=self.text_processor.bm25_tokenizer,
415
+ k=5, # default value
416
+ ).configurable_fields(
417
+ k=ConfigurableField(
418
+ id="bm25_k",
419
+ name="BM25 Top K",
420
+ description="Number of documents to return from BM25",
421
+ )
422
+ )
423
+
424
+ print(
425
+ f"[SUCCESS] BM25 retriever created with {len(documents)} contextualized documents"
426
+ )
427
+ return retriever
428
+
429
+
430
+ def test_contextual_kb(kb: ViettelKnowledgeBase, test_queries: List[str]):
431
+ """Test function for the contextual knowledge base"""
432
+
433
+ print("\n[INFO] Testing Contextual Knowledge Base")
434
+ print("=" * 60)
435
+
436
+ for i, query in enumerate(test_queries, 1):
437
+ print(f"\n#{i} Query: '{query}'")
438
+ print("-" * 40)
439
+
440
+ try:
441
+ # Test ensemble search with configurable top-k
442
+ results = kb.search(query, top_k=3)
443
+
444
+ if results:
445
+ for j, doc in enumerate(results, 1):
446
+ content_preview = doc.page_content[:150].replace("\n", " ")
447
+ doc_type = doc.metadata.get("doc_type", "unknown")
448
+ has_context = doc.metadata.get("has_context", False)
449
+ context_indicator = (
450
+ " [CONTEXTUAL]" if has_context else " [ORIGINAL]"
451
+ )
452
+ print(
453
+ f" {j}. [{doc_type}]{context_indicator} {content_preview}..."
454
+ )
455
+ else:
456
+ print(" No results found")
457
+
458
+ except Exception as e:
459
+ print(f" [ERROR] Error: {e}")
460
+
461
+
462
+ # Example usage
463
+ if __name__ == "__main__":
464
+ # Initialize knowledge base
465
+ kb = ViettelKnowledgeBase(
466
+ embedding_model="dangvantuan/vietnamese-document-embedding"
467
+ )
468
+
469
+ # Build knowledge base from a folder of Word documents
470
+ documents_folder = "./viettelpay_docs" # Folder containing .doc/.docx files
471
+
472
+ try:
473
+ # Build knowledge base (pass OpenAI API key here for contextual enhancement)
474
+ kb.build_knowledge_base(
475
+ documents_folder,
476
+ "./contextual_kb",
477
+ reset=True,
478
+ openai_api_key="your-openai-api-key-here", # or None to use env variable
479
+ )
480
+
481
+ # Alternative: Load existing knowledge base
482
+ # success = kb.load_knowledge_base("./contextual_kb")
483
+ # if not success:
484
+ # print("[ERROR] Failed to load knowledge base")
485
+
486
+ # Test queries
487
+ test_queries = [
488
+ "lỗi 606",
489
+ "không nạp được tiền",
490
+ "hướng dẫn nạp cước",
491
+ "quy định hủy giao dịch",
492
+ "mệnh giá thẻ cào",
493
+ ]
494
+
495
+ # Test the knowledge base
496
+ test_contextual_kb(kb, test_queries)
497
+
498
+ # Example of runtime configuration for different top-k values
499
+ print(f"\n[INFO] Example of runtime configuration:")
500
+ print("=" * 50)
501
+
502
+ # Search with different top-k values
503
+ sample_query = "lỗi 606"
504
+
505
+ # Search with top_k=3
506
+ results1 = kb.search(sample_query, top_k=3)
507
+ print(f"Search with top_k=3: {len(results1)} total results")
508
+
509
+ # Search with top_k=8
510
+ results2 = kb.search(sample_query, top_k=8)
511
+ print(f"Search with top_k=8: {len(results2)} total results")
512
+
513
+ # Show stats
514
+ print(f"\n[INFO] Knowledge Base Stats: {kb.get_stats()}")
515
+
516
+ except Exception as e:
517
+ print(f"[ERROR] Error building knowledge base: {e}")
518
+ print("[INFO] Make sure you have:")
519
+ print(" 1. Valid OpenAI API key")
520
+ print(" 2. Word documents in the specified folder")
521
+ print(" 3. Required dependencies installed (openai, markitdown, etc.)")
src/llm/__pycache__/langchain_models.cpython-311.pyc ADDED
Binary file (4.49 kB). View file
 
src/llm/__pycache__/llm_client.cpython-310.pyc ADDED
Binary file (5.65 kB). View file
 
src/llm/__pycache__/llm_client.cpython-311.pyc ADDED
Binary file (9.31 kB). View file
 
src/llm/llm_client.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Client Abstraction Layer
3
+ Supports multiple LLM providers without hardcoding
4
+ """
5
+
6
+ from abc import ABC, abstractmethod
7
+ from typing import Dict, Any, Optional
8
+ import os
9
+
10
+ # Import configuration utility
11
+ from src.utils.config import get_gemini_api_key, get_openai_api_key
12
+
13
+
14
+ class BaseLLMClient(ABC):
15
+ """Abstract base class for LLM clients"""
16
+
17
+ def __init__(self, **kwargs):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def generate(self, prompt: str, **kwargs) -> str:
22
+ """Generate response from prompt"""
23
+ pass
24
+
25
+ @abstractmethod
26
+ def is_available(self) -> bool:
27
+ """Check if LLM service is available"""
28
+ pass
29
+
30
+
31
+ class GeminiClient(BaseLLMClient):
32
+ """Google Gemini client implementation"""
33
+
34
+ def __init__(self, api_key: Optional[str] = None, model: str = "gemini-2.0-flash"):
35
+ self.api_key = api_key or get_gemini_api_key()
36
+ self.model = model
37
+
38
+ if not self.api_key:
39
+ raise ValueError("Gemini API key not provided")
40
+
41
+ try:
42
+ import google.generativeai as genai
43
+
44
+ genai.configure(api_key=self.api_key)
45
+ self.client = genai.GenerativeModel(self.model)
46
+ print(f"✅ Gemini client initialized with model: {self.model}")
47
+ except ImportError:
48
+ raise ImportError("google-generativeai package not installed")
49
+
50
+ def generate(self, prompt: str, **kwargs) -> str:
51
+ """Generate response using Gemini"""
52
+ try:
53
+ # Set default temperature to 0.1 for consistency
54
+ generation_config = {
55
+ "temperature": kwargs.get("temperature", 0.1),
56
+ "top_p": kwargs.get("top_p", 0.8),
57
+ "top_k": kwargs.get("top_k", 40),
58
+ "max_output_tokens": kwargs.get("max_output_tokens", 2048),
59
+ }
60
+
61
+ response = self.client.generate_content(
62
+ prompt, generation_config=generation_config
63
+ )
64
+ return response.text
65
+ except Exception as e:
66
+ print(f"❌ Gemini generation error: {e}")
67
+ raise
68
+
69
+ def is_available(self) -> bool:
70
+ """Check Gemini availability"""
71
+ try:
72
+ test_response = self.client.generate_content("Hello")
73
+ return bool(test_response.text)
74
+ except:
75
+ return False
76
+
77
+
78
+ class OpenAIClient(BaseLLMClient):
79
+ """OpenAI client implementation"""
80
+
81
+ def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4"):
82
+ self.api_key = api_key or get_openai_api_key()
83
+ self.model = model
84
+
85
+ if not self.api_key:
86
+ raise ValueError("OpenAI API key not provided")
87
+
88
+ try:
89
+ import openai
90
+
91
+ self.client = openai.OpenAI(api_key=self.api_key)
92
+ print(f"✅ OpenAI client initialized with model: {self.model}")
93
+ except ImportError:
94
+ raise ImportError("openai package not installed")
95
+
96
+ def generate(self, prompt: str, **kwargs) -> str:
97
+ """Generate response using OpenAI"""
98
+ try:
99
+ # Set default temperature to 0.1 for consistency
100
+ openai_kwargs = {
101
+ "temperature": kwargs.get("temperature", 0.1),
102
+ "top_p": kwargs.get("top_p", 1.0),
103
+ "max_tokens": kwargs.get("max_tokens", 2048),
104
+ }
105
+ # Remove any Gemini-specific parameters
106
+ openai_kwargs.update(
107
+ {
108
+ k: v
109
+ for k, v in kwargs.items()
110
+ if k
111
+ in [
112
+ "temperature",
113
+ "top_p",
114
+ "max_tokens",
115
+ "frequency_penalty",
116
+ "presence_penalty",
117
+ ]
118
+ }
119
+ )
120
+
121
+ response = self.client.chat.completions.create(
122
+ model=self.model,
123
+ messages=[{"role": "user", "content": prompt}],
124
+ **openai_kwargs,
125
+ )
126
+ return response.choices[0].message.content
127
+ except Exception as e:
128
+ print(f"❌ OpenAI generation error: {e}")
129
+ raise
130
+
131
+ def is_available(self) -> bool:
132
+ """Check OpenAI availability"""
133
+ try:
134
+ response = self.client.chat.completions.create(
135
+ model=self.model,
136
+ messages=[{"role": "user", "content": "Hello"}],
137
+ max_tokens=5,
138
+ )
139
+ return bool(response.choices[0].message.content)
140
+ except:
141
+ return False
142
+
143
+
144
+ class LLMClientFactory:
145
+ """Factory for creating LLM clients"""
146
+
147
+ SUPPORTED_PROVIDERS = {
148
+ "gemini": GeminiClient,
149
+ "openai": OpenAIClient,
150
+ }
151
+
152
+ @classmethod
153
+ def create_client(self, provider: str = "gemini", **kwargs) -> BaseLLMClient:
154
+ """Create LLM client by provider name"""
155
+
156
+ if provider not in self.SUPPORTED_PROVIDERS:
157
+ raise ValueError(
158
+ f"Unsupported provider: {provider}. Supported: {list(self.SUPPORTED_PROVIDERS.keys())}"
159
+ )
160
+
161
+ client_class = self.SUPPORTED_PROVIDERS[provider]
162
+ return client_class(**kwargs)
163
+
164
+ @classmethod
165
+ def get_available_providers(cls) -> list:
166
+ """Get list of available providers"""
167
+ return list(cls.SUPPORTED_PROVIDERS.keys())
168
+
169
+
170
+ # Usage example
171
+ if __name__ == "__main__":
172
+ # Test Gemini client
173
+ try:
174
+ client = LLMClientFactory.create_client("gemini")
175
+ if client.is_available():
176
+ response = client.generate("Xin chào, bạn có khỏe không?")
177
+ print(f"Response: {response}")
178
+ else:
179
+ print("Gemini not available")
180
+ except Exception as e:
181
+ print(f"Error: {e}")
src/processor/__pycache__/contextual_word_processor.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
src/processor/__pycache__/csv_processor.cpython-310.pyc ADDED
Binary file (4.85 kB). View file
 
src/processor/__pycache__/csv_processor.cpython-311.pyc ADDED
Binary file (9.4 kB). View file
 
src/processor/__pycache__/csv_processor.cpython-312.pyc ADDED
Binary file (10.4 kB). View file