Spaces:
Running
Running
Upload 73 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env +9 -0
- .gitattributes +4 -35
- .gitignore +2 -0
- .streamlit/secrets.toml +23 -0
- evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json +0 -0
- evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json +0 -0
- evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json +0 -0
- evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json +0 -0
- evaluation_data/results/intent_classification/viettelpay_intent_results.json +0 -0
- evaluation_data/results/multi_turn_retrieval/viettelpay_multiturn_results.json +0 -0
- evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json +0 -0
- knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/data_level0.bin +3 -0
- knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/header.bin +0 -0
- knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/length.bin +0 -0
- knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/link_lists.bin +0 -0
- knowledge_base/chroma/chroma.sqlite3 +3 -0
- requirements.txt +31 -2
- src/__pycache__/knowledge_base_builder.cpython-310.pyc +0 -0
- src/__pycache__/knowledge_base_builder.cpython-312.pyc +0 -0
- src/__pycache__/simplified_knowledge_base.cpython-310.pyc +0 -0
- src/agent/__pycache__/memory.cpython-311.pyc +0 -0
- src/agent/__pycache__/nodes.cpython-310.pyc +0 -0
- src/agent/__pycache__/nodes.cpython-311.pyc +0 -0
- src/agent/__pycache__/prompts.cpython-311.pyc +0 -0
- src/agent/__pycache__/scripts.cpython-310.pyc +0 -0
- src/agent/__pycache__/scripts.cpython-311.pyc +0 -0
- src/agent/__pycache__/viettelpay_agent.cpython-310.pyc +0 -0
- src/agent/__pycache__/viettelpay_agent.cpython-311.pyc +0 -0
- src/agent/nodes.py +463 -0
- src/agent/prompts.py +125 -0
- src/agent/scripts.py +157 -0
- src/agent/viettelpay_agent.py +416 -0
- src/evaluation/__pycache__/prompts.cpython-311.pyc +0 -0
- src/evaluation/__pycache__/single_turn_retrieval.cpython-311.pyc +0 -0
- src/evaluation/intent_classification.py +901 -0
- src/evaluation/multi_turn_retrieval.py +815 -0
- src/evaluation/prompts.py +318 -0
- src/evaluation/single_turn_retrieval.py +844 -0
- src/knowledge_base/__pycache__/builder.cpython-310.pyc +0 -0
- src/knowledge_base/__pycache__/builder.cpython-311.pyc +0 -0
- src/knowledge_base/__pycache__/viettel_knowledge_base.cpython-311.pyc +0 -0
- src/knowledge_base/viettel_knowledge_base.py +521 -0
- src/llm/__pycache__/langchain_models.cpython-311.pyc +0 -0
- src/llm/__pycache__/llm_client.cpython-310.pyc +0 -0
- src/llm/__pycache__/llm_client.cpython-311.pyc +0 -0
- src/llm/llm_client.py +181 -0
- src/processor/__pycache__/contextual_word_processor.cpython-311.pyc +0 -0
- src/processor/__pycache__/csv_processor.cpython-310.pyc +0 -0
- src/processor/__pycache__/csv_processor.cpython-311.pyc +0 -0
- src/processor/__pycache__/csv_processor.cpython-312.pyc +0 -0
.env
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# .env file
|
2 |
+
GEMINI_API_KEY="AIzaSyAn1HDv1_zU4TQbwwUtBIOo7tu5iKVBWho"
|
3 |
+
GOOGLE_API_KEY="AIzaSyAn1HDv1_zU4TQbwwUtBIOo7tu5iKVBWho"
|
4 |
+
OPENAI_API_KEY="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"
|
5 |
+
|
6 |
+
COHERE_API_KEY="D6PBHYizSmWFqzHoMWafV65yJelDh6X3Xg0ghIue"
|
7 |
+
|
8 |
+
# Production
|
9 |
+
# COHERE_API_KEY="VFaacPxkjW0L4HaijiBXuKYWqgYj8XkAo3o5uMWu"
|
.gitattributes
CHANGED
@@ -1,35 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
# Auto detect text files and perform LF normalization
|
2 |
+
* text=auto
|
3 |
+
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/data_level0.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
knowledge_base/chroma/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
secrets.toml
|
.streamlit/secrets.toml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Streamlit Secrets Configuration
|
2 |
+
# Copy your API keys here for production
|
3 |
+
# For local development, you can still use .env file
|
4 |
+
|
5 |
+
[api_keys]
|
6 |
+
GEMINI_API_KEY = "AIzaSyAn1HDv1_zU4TQbwwUtBIOo7tu5iKVBWho"
|
7 |
+
OPENAI_API_KEY = "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"
|
8 |
+
COHERE_API_KEY = "D6PBHYizSmWFqzHoMWafV65yJelDh6X3Xg0ghIue"
|
9 |
+
|
10 |
+
# Production
|
11 |
+
# COHERE_API_KEY="VFaacPxkjW0L4HaijiBXuKYWqgYj8XkAo3o5uMWu"
|
12 |
+
|
13 |
+
# Database and storage paths
|
14 |
+
[paths]
|
15 |
+
KNOWLEDGE_BASE_PATH = "./knowledge_base"
|
16 |
+
DOCUMENTS_FOLDER = "./viettelpay_docs"
|
17 |
+
|
18 |
+
# Model configurations
|
19 |
+
[models]
|
20 |
+
EMBEDDING_MODEL = "dangvantuan/vietnamese-document-embedding"
|
21 |
+
LLM_PROVIDER = "gemini"
|
22 |
+
GEMINI_MODEL = "gemini-2.0-flash"
|
23 |
+
OPENAI_MODEL = "gpt-4o-mini"
|
evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation_data/results/intent_classification/viettelpay_intent_results.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation_data/results/multi_turn_retrieval/viettelpay_multiturn_results.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:23add52afbe7588391f32d3deffb581b2663d2e2ad8851aba7de25e6b3f66761
|
3 |
+
size 32120000
|
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/header.bin
ADDED
Binary file (100 Bytes). View file
|
|
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/length.bin
ADDED
Binary file (40 kB). View file
|
|
knowledge_base/chroma/c8c2137c-264c-4fe5-a301-20b02985da11/link_lists.bin
ADDED
File without changes
|
knowledge_base/chroma/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b962ed5c081a92e80ccb83ef40c1dfb530542ca161a4cefe9fc6ccebebf23e75
|
3 |
+
size 1937408
|
requirements.txt
CHANGED
@@ -1,3 +1,32 @@
|
|
1 |
-
|
2 |
pandas
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# requirements.txt
|
2 |
pandas
|
3 |
+
python-docx
|
4 |
+
langchain-chroma==0.2.4
|
5 |
+
chromadb==1.0.12
|
6 |
+
langchain_cohere==0.4.4
|
7 |
+
langchain==0.3.25
|
8 |
+
langgraph
|
9 |
+
langchain_community==0.3.24
|
10 |
+
google-generativeai
|
11 |
+
openai
|
12 |
+
sentence-transformers==4.1.0
|
13 |
+
pydantic
|
14 |
+
python-dotenv
|
15 |
+
PyYAML
|
16 |
+
tqdm
|
17 |
+
loguru
|
18 |
+
scikit-learn==1.6.1
|
19 |
+
protobuf<3.21
|
20 |
+
grpcio-status==1.48.2
|
21 |
+
torch
|
22 |
+
transformers==4.52.4
|
23 |
+
rank_bm25==0.2.2
|
24 |
+
markitdown[docx]
|
25 |
+
underthesea
|
26 |
+
pyvi
|
27 |
+
langchain-huggingface==0.2.0
|
28 |
+
streamlit
|
29 |
+
langmem
|
30 |
+
dotenv
|
31 |
+
numpy==1.26.4
|
32 |
+
python-docx
|
src/__pycache__/knowledge_base_builder.cpython-310.pyc
ADDED
Binary file (11.1 kB). View file
|
|
src/__pycache__/knowledge_base_builder.cpython-312.pyc
ADDED
Binary file (7.03 kB). View file
|
|
src/__pycache__/simplified_knowledge_base.cpython-310.pyc
ADDED
Binary file (10.9 kB). View file
|
|
src/agent/__pycache__/memory.cpython-311.pyc
ADDED
Binary file (4.13 kB). View file
|
|
src/agent/__pycache__/nodes.cpython-310.pyc
ADDED
Binary file (8.57 kB). View file
|
|
src/agent/__pycache__/nodes.cpython-311.pyc
ADDED
Binary file (17.2 kB). View file
|
|
src/agent/__pycache__/prompts.cpython-311.pyc
ADDED
Binary file (8.36 kB). View file
|
|
src/agent/__pycache__/scripts.cpython-310.pyc
ADDED
Binary file (7.02 kB). View file
|
|
src/agent/__pycache__/scripts.cpython-311.pyc
ADDED
Binary file (9.54 kB). View file
|
|
src/agent/__pycache__/viettelpay_agent.cpython-310.pyc
ADDED
Binary file (5.74 kB). View file
|
|
src/agent/__pycache__/viettelpay_agent.cpython-311.pyc
ADDED
Binary file (17 kB). View file
|
|
src/agent/nodes.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LangGraph Agent State and Processing Nodes
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, List, Optional, TypedDict, Annotated
|
6 |
+
from langchain.schema import Document
|
7 |
+
from langchain_core.messages import AnyMessage
|
8 |
+
from langgraph.graph.message import add_messages
|
9 |
+
import json
|
10 |
+
import re
|
11 |
+
|
12 |
+
from src.agent.prompts import (
|
13 |
+
INTENT_CLASSIFICATION_PROMPT,
|
14 |
+
QUERY_ENHANCEMENT_PROMPT,
|
15 |
+
RESPONSE_GENERATION_PROMPT,
|
16 |
+
get_system_prompt_by_intent,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class ViettelPayState(TypedDict):
|
21 |
+
"""State for ViettelPay agent workflow with message history support"""
|
22 |
+
|
23 |
+
# Message history for multi-turn conversation
|
24 |
+
messages: Annotated[List[AnyMessage], add_messages]
|
25 |
+
|
26 |
+
# Processing
|
27 |
+
intent: Optional[str]
|
28 |
+
confidence: Optional[float]
|
29 |
+
|
30 |
+
# Query enhancement
|
31 |
+
enhanced_query: Optional[str]
|
32 |
+
|
33 |
+
# Knowledge retrieval
|
34 |
+
retrieved_docs: Optional[List[Document]]
|
35 |
+
|
36 |
+
# Conversation context (cached to avoid repeated computation)
|
37 |
+
conversation_context: Optional[str]
|
38 |
+
|
39 |
+
# Response type metadata
|
40 |
+
response_type: Optional[str] # "script" or "generated"
|
41 |
+
|
42 |
+
# Metadata
|
43 |
+
error: Optional[str]
|
44 |
+
processing_info: Optional[Dict]
|
45 |
+
|
46 |
+
|
47 |
+
def get_conversation_context(messages: List[AnyMessage], max_messages: int = 3) -> str:
|
48 |
+
"""
|
49 |
+
Extract conversation context from message history
|
50 |
+
|
51 |
+
Args:
|
52 |
+
messages: List of conversation messages
|
53 |
+
max_messages: Maximum number of recent messages to include
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
Formatted conversation context string
|
57 |
+
"""
|
58 |
+
if len(messages) <= 1:
|
59 |
+
return ""
|
60 |
+
|
61 |
+
context = "\n\nLịch sử cuộc hội thoại:\n"
|
62 |
+
# Get recent messages (excluding the current/last message for intent classification)
|
63 |
+
recent_messages = messages[
|
64 |
+
-(max_messages + 1) : -1
|
65 |
+
] # Exclude the very last message
|
66 |
+
|
67 |
+
for msg in recent_messages:
|
68 |
+
# Handle different message types more robustly
|
69 |
+
if hasattr(msg, "type"):
|
70 |
+
if msg.type == "human":
|
71 |
+
role = "Người dùng"
|
72 |
+
elif msg.type == "ai":
|
73 |
+
role = "Trợ lý"
|
74 |
+
else:
|
75 |
+
role = f"Unknown-{msg.type}"
|
76 |
+
elif hasattr(msg, "role"):
|
77 |
+
if msg.role in ["user", "human"]:
|
78 |
+
role = "Người dùng"
|
79 |
+
elif msg.role in ["assistant", "ai"]:
|
80 |
+
role = "Trợ lý"
|
81 |
+
else:
|
82 |
+
role = f"Unknown-{msg.role}"
|
83 |
+
else:
|
84 |
+
role = "Unknown"
|
85 |
+
|
86 |
+
# Limit message length to avoid token overflow
|
87 |
+
# content = msg.content[:1000] + "..." if len(msg.content) > 1000 else msg.content
|
88 |
+
content = msg.content
|
89 |
+
context += f"{role}: {content}\n"
|
90 |
+
# print(context)
|
91 |
+
return context
|
92 |
+
|
93 |
+
|
94 |
+
def classify_intent_node(state: ViettelPayState, llm_client) -> ViettelPayState:
|
95 |
+
"""Node for intent classification using LLM with conversation context"""
|
96 |
+
|
97 |
+
# Get the latest user message
|
98 |
+
messages = state["messages"]
|
99 |
+
if not messages:
|
100 |
+
return {
|
101 |
+
**state,
|
102 |
+
"intent": "unclear",
|
103 |
+
"confidence": 0.0,
|
104 |
+
"error": "No messages found",
|
105 |
+
}
|
106 |
+
|
107 |
+
# Find the last human/user message
|
108 |
+
user_message = None
|
109 |
+
for msg in reversed(messages):
|
110 |
+
if hasattr(msg, "type") and msg.type == "human":
|
111 |
+
user_message = msg.content
|
112 |
+
break
|
113 |
+
elif hasattr(msg, "role") and msg.role == "user":
|
114 |
+
user_message = msg.content
|
115 |
+
break
|
116 |
+
|
117 |
+
if not user_message:
|
118 |
+
return {
|
119 |
+
**state,
|
120 |
+
"intent": "unclear",
|
121 |
+
"confidence": 0.0,
|
122 |
+
"error": "No user message found",
|
123 |
+
}
|
124 |
+
|
125 |
+
try:
|
126 |
+
# Get conversation context for better intent classification
|
127 |
+
conversation_context = get_conversation_context(messages)
|
128 |
+
|
129 |
+
# Intent classification prompt with context using the prompts file
|
130 |
+
classification_prompt = INTENT_CLASSIFICATION_PROMPT.format(
|
131 |
+
conversation_context=conversation_context, user_message=user_message
|
132 |
+
)
|
133 |
+
|
134 |
+
# Get classification using the pre-initialized LLM client
|
135 |
+
response = llm_client.generate(classification_prompt, temperature=0.1)
|
136 |
+
|
137 |
+
# print(f"🔍 Raw LLM response: {response}")
|
138 |
+
|
139 |
+
# Parse JSON response
|
140 |
+
try:
|
141 |
+
# Try to extract JSON from response (in case there's extra text)
|
142 |
+
response_clean = response.strip()
|
143 |
+
|
144 |
+
# Look for JSON object in the response
|
145 |
+
json_match = re.search(r"\{.*\}", response_clean, re.DOTALL)
|
146 |
+
if json_match:
|
147 |
+
json_str = json_match.group()
|
148 |
+
result = json.loads(json_str)
|
149 |
+
else:
|
150 |
+
# Try parsing the whole response
|
151 |
+
result = json.loads(response_clean)
|
152 |
+
|
153 |
+
intent = result.get("intent", "unclear")
|
154 |
+
confidence = result.get("confidence", 0.5)
|
155 |
+
explanation = result.get("explanation", "")
|
156 |
+
|
157 |
+
# print(
|
158 |
+
# f"✅ JSON parsed successfully: intent={intent}, confidence={confidence}"
|
159 |
+
# )
|
160 |
+
|
161 |
+
except (json.JSONDecodeError, AttributeError) as e:
|
162 |
+
print(f"❌ JSON parsing failed: {e}")
|
163 |
+
print(f" Raw response: {response}")
|
164 |
+
|
165 |
+
# Fallback: try to extract intent from text
|
166 |
+
response_lower = response.lower()
|
167 |
+
if any(
|
168 |
+
word in response_lower for word in ["lỗi", "error", "606", "mã lỗi"]
|
169 |
+
):
|
170 |
+
intent = "error_help"
|
171 |
+
confidence = 0.7
|
172 |
+
elif any(word in response_lower for word in ["xin chào", "hello", "chào"]):
|
173 |
+
intent = "greeting"
|
174 |
+
confidence = 0.8
|
175 |
+
elif any(word in response_lower for word in ["hủy", "cancel", "thủ tục"]):
|
176 |
+
intent = "procedure_guide"
|
177 |
+
confidence = 0.7
|
178 |
+
elif any(
|
179 |
+
word in response_lower for word in ["nạp", "cước", "dịch vụ", "faq"]
|
180 |
+
):
|
181 |
+
intent = "faq"
|
182 |
+
confidence = 0.7
|
183 |
+
else:
|
184 |
+
intent = "unclear"
|
185 |
+
confidence = 0.3
|
186 |
+
|
187 |
+
print(f"🔄 Fallback classification: {intent} (confidence: {confidence})")
|
188 |
+
explanation = "Fallback classification due to JSON parse error"
|
189 |
+
|
190 |
+
# print(f"🎯 Intent classified: {intent} (confidence: {confidence})")
|
191 |
+
|
192 |
+
return {
|
193 |
+
**state,
|
194 |
+
"intent": intent,
|
195 |
+
"confidence": confidence,
|
196 |
+
"conversation_context": conversation_context, # Save context for reuse
|
197 |
+
"processing_info": {
|
198 |
+
"classification_raw": response,
|
199 |
+
"explanation": explanation,
|
200 |
+
"context_used": bool(conversation_context.strip()),
|
201 |
+
},
|
202 |
+
}
|
203 |
+
|
204 |
+
except Exception as e:
|
205 |
+
print(f"❌ Intent classification error: {e}")
|
206 |
+
return {**state, "intent": "unclear", "confidence": 0.0, "error": str(e)}
|
207 |
+
|
208 |
+
|
209 |
+
def query_enhancement_node(state: ViettelPayState, llm_client) -> ViettelPayState:
|
210 |
+
"""Node for enhancing search query using conversation context"""
|
211 |
+
|
212 |
+
# Get the latest user message
|
213 |
+
messages = state["messages"]
|
214 |
+
if not messages:
|
215 |
+
return {**state, "enhanced_query": "", "error": "No messages found"}
|
216 |
+
|
217 |
+
# Find the last human/user message
|
218 |
+
user_message = None
|
219 |
+
for msg in reversed(messages):
|
220 |
+
if hasattr(msg, "type") and msg.type == "human":
|
221 |
+
user_message = msg.content
|
222 |
+
break
|
223 |
+
elif hasattr(msg, "role") and msg.role == "user":
|
224 |
+
user_message = msg.content
|
225 |
+
break
|
226 |
+
|
227 |
+
if not user_message:
|
228 |
+
return {**state, "enhanced_query": "", "error": "No user message found"}
|
229 |
+
|
230 |
+
try:
|
231 |
+
# Use saved conversation context if available, otherwise get it
|
232 |
+
conversation_context = state.get("conversation_context")
|
233 |
+
if conversation_context is None:
|
234 |
+
conversation_context = get_conversation_context(messages)
|
235 |
+
|
236 |
+
# If no context, use original message
|
237 |
+
if not conversation_context.strip():
|
238 |
+
print(f"🔍 No context available, using original query: {user_message}")
|
239 |
+
return {**state, "enhanced_query": user_message}
|
240 |
+
|
241 |
+
# Query enhancement prompt using the prompts file
|
242 |
+
enhancement_prompt = QUERY_ENHANCEMENT_PROMPT.format(
|
243 |
+
conversation_context=conversation_context, user_message=user_message
|
244 |
+
)
|
245 |
+
|
246 |
+
# Get enhanced query
|
247 |
+
enhanced_query = llm_client.generate(enhancement_prompt, temperature=0.1)
|
248 |
+
enhanced_query = enhanced_query.strip()
|
249 |
+
|
250 |
+
print(f"🔍 Original query: {user_message}")
|
251 |
+
print(f"🚀 Enhanced query: {enhanced_query}")
|
252 |
+
|
253 |
+
return {**state, "enhanced_query": enhanced_query}
|
254 |
+
|
255 |
+
except Exception as e:
|
256 |
+
print(f"❌ Query enhancement error: {e}")
|
257 |
+
# Fallback to original message
|
258 |
+
return {**state, "enhanced_query": user_message, "error": str(e)}
|
259 |
+
|
260 |
+
|
261 |
+
def knowledge_retrieval_node(
|
262 |
+
state: ViettelPayState, knowledge_retriever
|
263 |
+
) -> ViettelPayState:
|
264 |
+
"""Node for knowledge retrieval using pre-initialized ViettelKnowledgeBase"""
|
265 |
+
|
266 |
+
# Use enhanced query if available, otherwise fall back to extracting from messages
|
267 |
+
enhanced_query = state.get("enhanced_query", "")
|
268 |
+
|
269 |
+
if not enhanced_query:
|
270 |
+
# Fallback: extract from messages
|
271 |
+
messages = state["messages"]
|
272 |
+
if not messages:
|
273 |
+
return {**state, "retrieved_docs": [], "error": "No messages found"}
|
274 |
+
|
275 |
+
# Find the last human/user message
|
276 |
+
for msg in reversed(messages):
|
277 |
+
if hasattr(msg, "type") and msg.type == "human":
|
278 |
+
enhanced_query = msg.content
|
279 |
+
break
|
280 |
+
elif hasattr(msg, "role") and msg.role == "user":
|
281 |
+
enhanced_query = msg.content
|
282 |
+
break
|
283 |
+
|
284 |
+
if not enhanced_query:
|
285 |
+
return {**state, "retrieved_docs": [], "error": "No query available"}
|
286 |
+
|
287 |
+
try:
|
288 |
+
if not knowledge_retriever:
|
289 |
+
raise ValueError("Knowledge retriever not available")
|
290 |
+
|
291 |
+
# Retrieve relevant documents using enhanced query and pre-initialized ViettelKnowledgeBase
|
292 |
+
retrieved_docs = knowledge_retriever.search(enhanced_query, top_k=10)
|
293 |
+
|
294 |
+
print(
|
295 |
+
f"📚 Retrieved {len(retrieved_docs)} documents for enhanced query: {enhanced_query}"
|
296 |
+
)
|
297 |
+
|
298 |
+
return {**state, "retrieved_docs": retrieved_docs}
|
299 |
+
|
300 |
+
except Exception as e:
|
301 |
+
print(f"❌ Knowledge retrieval error: {e}")
|
302 |
+
return {**state, "retrieved_docs": [], "error": str(e)}
|
303 |
+
|
304 |
+
|
305 |
+
def script_response_node(state: ViettelPayState) -> ViettelPayState:
|
306 |
+
"""Node for script-based responses"""
|
307 |
+
|
308 |
+
from src.agent.scripts import ConversationScripts
|
309 |
+
from langchain_core.messages import AIMessage
|
310 |
+
|
311 |
+
intent = state.get("intent", "")
|
312 |
+
|
313 |
+
try:
|
314 |
+
# Load scripts
|
315 |
+
scripts = ConversationScripts("./viettelpay_docs/processed/kich_ban.csv")
|
316 |
+
|
317 |
+
# Map intents to script types
|
318 |
+
intent_to_script = {
|
319 |
+
"greeting": "greeting",
|
320 |
+
"out_of_scope": "out_of_scope",
|
321 |
+
"human_request": "human_request_attempt_1", # Could be enhanced later
|
322 |
+
"unclear": "ask_for_clarity",
|
323 |
+
}
|
324 |
+
|
325 |
+
script_type = intent_to_script.get(intent)
|
326 |
+
|
327 |
+
if script_type and scripts.has_script(script_type):
|
328 |
+
response_text = scripts.get_script(script_type)
|
329 |
+
print(f"📋 Using script response: {script_type}")
|
330 |
+
|
331 |
+
# Add AI message to the conversation
|
332 |
+
ai_message = AIMessage(content=response_text)
|
333 |
+
|
334 |
+
return {**state, "messages": [ai_message], "response_type": "script"}
|
335 |
+
|
336 |
+
else:
|
337 |
+
# Fallback script
|
338 |
+
fallback_response = (
|
339 |
+
"Xin lỗi, em chưa hiểu rõ yêu cầu của anh/chị. Vui lòng thử lại."
|
340 |
+
)
|
341 |
+
ai_message = AIMessage(content=fallback_response)
|
342 |
+
|
343 |
+
print(f"📋 Using fallback script for intent: {intent}")
|
344 |
+
|
345 |
+
return {**state, "messages": [ai_message], "response_type": "script"}
|
346 |
+
|
347 |
+
except Exception as e:
|
348 |
+
print(f"❌ Script response error: {e}")
|
349 |
+
fallback_response = "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau."
|
350 |
+
ai_message = AIMessage(content=fallback_response)
|
351 |
+
|
352 |
+
return {
|
353 |
+
**state,
|
354 |
+
"messages": [ai_message],
|
355 |
+
"response_type": "error",
|
356 |
+
"error": str(e),
|
357 |
+
}
|
358 |
+
|
359 |
+
|
360 |
+
def generate_response_node(state: ViettelPayState, llm_client) -> ViettelPayState:
|
361 |
+
"""Node for LLM-based response generation with conversation context"""
|
362 |
+
|
363 |
+
from langchain_core.messages import AIMessage
|
364 |
+
|
365 |
+
# Get the latest user message and conversation history
|
366 |
+
messages = state["messages"]
|
367 |
+
if not messages:
|
368 |
+
ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
|
369 |
+
return {**state, "messages": [ai_message], "response_type": "error"}
|
370 |
+
|
371 |
+
# Find the last human/user message
|
372 |
+
user_message = None
|
373 |
+
for msg in reversed(messages):
|
374 |
+
if hasattr(msg, "type") and msg.type == "human":
|
375 |
+
user_message = msg.content
|
376 |
+
break
|
377 |
+
elif hasattr(msg, "role") and msg.role == "user":
|
378 |
+
user_message = msg.content
|
379 |
+
break
|
380 |
+
|
381 |
+
if not user_message:
|
382 |
+
ai_message = AIMessage(content="Xin lỗi, em không thể xử lý yêu cầu này.")
|
383 |
+
return {**state, "messages": [ai_message], "response_type": "error"}
|
384 |
+
|
385 |
+
intent = state.get("intent", "")
|
386 |
+
retrieved_docs = state.get("retrieved_docs", [])
|
387 |
+
enhanced_query = state.get("enhanced_query", "")
|
388 |
+
|
389 |
+
try:
|
390 |
+
# Build context from retrieved documents using original content
|
391 |
+
context = ""
|
392 |
+
if retrieved_docs:
|
393 |
+
context = "\n\n".join(
|
394 |
+
[
|
395 |
+
f"[{doc.metadata.get('doc_type', 'unknown')}] {doc.metadata.get('original_content', doc.page_content)}"
|
396 |
+
for doc in retrieved_docs
|
397 |
+
]
|
398 |
+
)
|
399 |
+
|
400 |
+
# Use saved conversation context if available, otherwise get it
|
401 |
+
conversation_context = state.get("conversation_context")
|
402 |
+
if conversation_context is None:
|
403 |
+
conversation_context = get_conversation_context(messages, max_messages=6)
|
404 |
+
|
405 |
+
# Get system prompt based on intent using the prompts file
|
406 |
+
system_prompt = get_system_prompt_by_intent(intent)
|
407 |
+
|
408 |
+
# Build full prompt with both knowledge context and conversation context using the prompts file
|
409 |
+
generation_prompt = RESPONSE_GENERATION_PROMPT.format(
|
410 |
+
system_prompt=system_prompt,
|
411 |
+
context=context,
|
412 |
+
conversation_context=conversation_context,
|
413 |
+
user_message=user_message,
|
414 |
+
enhanced_query=enhanced_query,
|
415 |
+
)
|
416 |
+
|
417 |
+
# Generate response using the pre-initialized LLM client
|
418 |
+
response_text = llm_client.generate(generation_prompt, temperature=0.1)
|
419 |
+
|
420 |
+
print(f"🤖 Generated response for intent: {intent}")
|
421 |
+
|
422 |
+
# Add AI message to the conversation
|
423 |
+
ai_message = AIMessage(content=response_text)
|
424 |
+
|
425 |
+
return {**state, "messages": [ai_message], "response_type": "generated"}
|
426 |
+
|
427 |
+
except Exception as e:
|
428 |
+
print(f"❌ Response generation error: {e}")
|
429 |
+
error_response = "Xin lỗi, em gặp lỗi khi xử lý yêu cầu. Vui lòng thử lại sau."
|
430 |
+
ai_message = AIMessage(content=error_response)
|
431 |
+
|
432 |
+
return {
|
433 |
+
**state,
|
434 |
+
"messages": [ai_message],
|
435 |
+
"response_type": "error",
|
436 |
+
"error": str(e),
|
437 |
+
}
|
438 |
+
|
439 |
+
|
440 |
+
# Routing function for conditional edges
|
441 |
+
def route_after_intent_classification(state: ViettelPayState) -> str:
|
442 |
+
"""Route to appropriate node after intent classification"""
|
443 |
+
|
444 |
+
intent = state.get("intent", "unclear")
|
445 |
+
|
446 |
+
# Script-based intents (no knowledge retrieval needed)
|
447 |
+
script_intents = {"greeting", "out_of_scope", "human_request", "unclear"}
|
448 |
+
|
449 |
+
if intent in script_intents:
|
450 |
+
return "script_response"
|
451 |
+
else:
|
452 |
+
# Knowledge-based intents need query enhancement first
|
453 |
+
return "query_enhancement"
|
454 |
+
|
455 |
+
|
456 |
+
def route_after_query_enhancement(state: ViettelPayState) -> str:
|
457 |
+
"""Route after query enhancement (always to knowledge retrieval)"""
|
458 |
+
return "knowledge_retrieval"
|
459 |
+
|
460 |
+
|
461 |
+
def route_after_knowledge_retrieval(state: ViettelPayState) -> str:
|
462 |
+
"""Route after knowledge retrieval (always to generation)"""
|
463 |
+
return "generate_response"
|
src/agent/prompts.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Prompt templates for ViettelPay AI Agent
|
3 |
+
All prompts using Vietnamese language for ViettelPay Pro customer support
|
4 |
+
"""
|
5 |
+
|
6 |
+
# Intent Classification Prompt (JSON format for better parsing)
|
7 |
+
INTENT_CLASSIFICATION_PROMPT = """
|
8 |
+
Bạn là hệ thống phân loại ý định cho ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
9 |
+
Phân tích tin nhắn của người dùng và trả về ý định chính.
|
10 |
+
|
11 |
+
Các loại ý định:
|
12 |
+
* **`greeting`**: Chỉ là lời chào hỏi thuần túy, không có câu hỏi hoặc yêu cầu cụ thể nào khác. Nếu tin nhắn có cả lời chào VÀ câu hỏi thì phân loại theo các ý định khác, không phải greeting.
|
13 |
+
* *Ví dụ:* "chào em", "hello shop", "xin chào ạ"
|
14 |
+
* *Không phải greeting:* "xin chào, cho hỏi về lỗi 606" → đây là error_help
|
15 |
+
* **`faq`**: Các câu hỏi đáp chung, tìm hiểu về dịch vụ, tính năng, v.v.
|
16 |
+
* *Ví dụ:* "App có bán thẻ game không?", "ViettelPay Pro nạp tiền được cho mạng nao?"
|
17 |
+
* **`error_help`**: Báo cáo sự cố, hỏi về mã lỗi cụ thể.
|
18 |
+
* *Ví dụ:* "Giao dịch báo lỗi 606", "tại sao tôi không thanh toán được?", "lỗi này là gì?"
|
19 |
+
* **`procedure_guide`**: Hỏi về các bước cụ thể để thực hiện một tác vụ.
|
20 |
+
* *Ví dụ:* "làm thế nào để hủy giao dịch?", "chỉ tôi cách lấy lại mã thẻ cào", "hướng dẫn nạp cước"
|
21 |
+
* **`human_request`**: Yêu cầu được nói chuyện trực tiếp với nhân viên hỗ trợ.
|
22 |
+
* *Ví dụ:* "cho tôi gặp người thật", "nối máy cho tổng đài", "em k hiểu, cho gặp ai đó"
|
23 |
+
* **`out_of_scope`**: Câu hỏi ngoài phạm vi ViettelPay (thời tiết, chính trị, v.v.), không liên quan gì đến các dịch vụ tài chính, viễn thông của Viettel.
|
24 |
+
* *Ví dụ:* "dự báo thời tiết hôm nay?", "giá xăng bao nhiêu?", "cách nấu phở"
|
25 |
+
* **`unclear`**: Câu hỏi không rõ ràng, thiếu thông tin cụ thể, cần người dùng bổ sung thêm chi tiết để có thể hỗ trợ hiệu quả.
|
26 |
+
* *Ví dụ:* "lỗi", "giúp với", "gd", "???", "ko hiểu", "bị lỗi giờ sao đây", "không thực hiện được", "sao vậy", "tại sao thế"
|
27 |
+
|
28 |
+
**Bối cảnh cuộc trò chuyện:**
|
29 |
+
<conversation_context>
|
30 |
+
{conversation_context}
|
31 |
+
</conversation_context>
|
32 |
+
|
33 |
+
**Tin nhắn mới của người dùng:**
|
34 |
+
<user_message>
|
35 |
+
{user_message}
|
36 |
+
</user_message>
|
37 |
+
|
38 |
+
Hãy phân tích dựa trên cả ngữ cảnh cuộc hội thoại và tin nhắn mới nhất của người dùng.
|
39 |
+
|
40 |
+
QUAN TRỌNG: Chỉ trả về JSON thuần túy, không có text khác. Format chính xác:
|
41 |
+
{{"intent": "tên_ý_định", "confidence": 0.9, "explanation": "lý do ngắn gọn"}}
|
42 |
+
"""
|
43 |
+
|
44 |
+
# Query Enhancement Prompt for contextual search improvement
|
45 |
+
QUERY_ENHANCEMENT_PROMPT = """
|
46 |
+
**Nhiệm vụ:** Bạn là một trợ lý chuyên gia của ViettelPay Pro.
|
47 |
+
ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
48 |
+
Nhiệm vụ của bạn là đọc cuộc trò chuyện và tin nhắn mới nhất của người dùng để tạo ra một truy vấn tìm kiếm (search query) duy nhất, tối ưu cho cơ sở dữ liệu nội bộ.
|
49 |
+
|
50 |
+
**Bối cảnh cuộc trò chuyện:**
|
51 |
+
<conversation_context>
|
52 |
+
{conversation_context}
|
53 |
+
</conversation_context>
|
54 |
+
|
55 |
+
**Tin nhắn mới của người dùng:**
|
56 |
+
<user_message>
|
57 |
+
{user_message}
|
58 |
+
</user_message
|
59 |
+
|
60 |
+
**Quy tắc tạo truy vấn:**
|
61 |
+
1. **Kết hợp Ngữ cảnh:** Phân tích toàn bộ cuộc trò chuyện để tạo ra một truy vấn đầy đủ ý nghĩa, nắm bắt được mục tiêu thực sự của người dùng.
|
62 |
+
2. **Làm rõ & Cụ thể:** Thay thế các đại từ (ví dụ: "nó", "cái đó") bằng các chủ thể hoặc thuật ngữ cụ thể đã được đề cập (ví dụ: "liên kết ngân hàng", "rút tiền tại ATM", "mã lỗi 101").
|
63 |
+
3. **Tích hợp Thuật ngữ:** Tích hợp một cách **tự nhiên** các từ khóa và thuật ngữ chuyên ngành của ViettelPay Pro (ví dụ: "giao dịch", "nạp cước", "chiết khấu", "OTP", "hoa hồng").
|
64 |
+
4. **Duy trì Tính tự nhiên (QUAN TRỌNG):** Truy vấn phải là một câu hỏi hoặc một cụm từ hoàn chỉnh, tự nhiên bằng tiếng Việt. **Tránh tạo ra danh sách từ khóa rời rạc.**
|
65 |
+
* **Tốt:** "cách tính hoa hồng khi nạp thẻ điện thoại cho khách"
|
66 |
+
* **Không tốt:** "hoa hồng nạp thẻ điện thoại"
|
67 |
+
5. **Giữ lại Ý định Gốc:** Truy vấn phải phản ánh chính xác câu hỏi của người dùng, không thêm thông tin hoặc suy diễn.
|
68 |
+
6. Thêm vài câu sử dụng từ đồng nghĩa và cách diễn đạt khác nhau trong tiếng Việt để tăng khả năng tìm kiếm
|
69 |
+
|
70 |
+
**ĐẦU RA:** CHỈ trả về một chuỗi truy vấn tìm kiếm đã được cải thiện. Không thêm lời giải thích.
|
71 |
+
"""
|
72 |
+
|
73 |
+
# System Prompt for Error Help responses
|
74 |
+
ERROR_HELP_SYSTEM_PROMPT = """Bạn là chuyên gia hỗ trợ kỹ thuật ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
75 |
+
Thể hiện sự cảm thông với khó khăn của người dùng.
|
76 |
+
Cung cấp giải pháp cụ thể, từng bước.
|
77 |
+
Nếu cần hỗ trợ thêm, hướng dẫn liên hệ tổng đài.
|
78 |
+
Nếu có lịch sử cuộc hội thoại, hãy tham khảo để đưa ra câu trả lời phù hợp và có tính liên kết."""
|
79 |
+
|
80 |
+
# System Prompt for Procedure Guide responses
|
81 |
+
PROCEDURE_GUIDE_SYSTEM_PROMPT = """Bạn là hướng dẫn viên ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
82 |
+
Cung cấp hướng dẫn từng bước rõ ràng.
|
83 |
+
Bao gồm link video nếu có trong thông tin.
|
84 |
+
Sử dụng format có số thứ tự cho các bước.
|
85 |
+
Nếu có lịch sử cuộc hội thoại, hãy tham khảo để đưa ra câu trả lời phù hợp và có tính liên kết."""
|
86 |
+
|
87 |
+
# Default System Prompt for general responses
|
88 |
+
DEFAULT_SYSTEM_PROMPT = """Bạn là trợ lý ảo ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
89 |
+
Trả lời câu hỏi dựa trên thông tin được cung cấp.
|
90 |
+
Giọng điệu thân thiện, chuyên nghiệp.
|
91 |
+
Sử dụng "Anh/chị" khi xưng hô.
|
92 |
+
Nếu có lịch sử cuộc hội thoại, hãy tham khảo để đưa ra câu trả lời phù hợp và có tính liên kết."""
|
93 |
+
|
94 |
+
# Response Generation Template with context and knowledge base integration
|
95 |
+
RESPONSE_GENERATION_PROMPT = """<system_prompt>
|
96 |
+
{system_prompt}
|
97 |
+
</system_prompt>
|
98 |
+
|
99 |
+
**Thông tin tham khảo từ cơ sở tri thức:**
|
100 |
+
<knowledge_base_context>
|
101 |
+
{context}
|
102 |
+
</knowledge_base_context>
|
103 |
+
|
104 |
+
**Bối cảnh cuộc trò chuyện:**
|
105 |
+
<conversation_context>
|
106 |
+
{conversation_context}
|
107 |
+
</conversation_context>
|
108 |
+
|
109 |
+
**Tin nhắn mới của người dùng:**
|
110 |
+
<user_message>
|
111 |
+
{user_message}
|
112 |
+
</user_message>
|
113 |
+
|
114 |
+
Hãy trả lời câu hỏi dựa trên thông tin tham khảo và lịch sử cuộc hội thoại (nếu có). Nếu không có thông tin phù hợp, hãy nói rằng bạn cần thêm thông tin.
|
115 |
+
"""
|
116 |
+
|
117 |
+
|
118 |
+
def get_system_prompt_by_intent(intent: str) -> str:
|
119 |
+
"""Get appropriate system prompt based on intent classification"""
|
120 |
+
if intent == "error_help":
|
121 |
+
return ERROR_HELP_SYSTEM_PROMPT
|
122 |
+
elif intent == "procedure_guide":
|
123 |
+
return PROCEDURE_GUIDE_SYSTEM_PROMPT
|
124 |
+
else: # faq, etc.
|
125 |
+
return DEFAULT_SYSTEM_PROMPT
|
src/agent/scripts.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Conversation Scripts Handler
|
3 |
+
Manages predefined scripts for standard conversation scenarios
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
from typing import Dict, Optional
|
9 |
+
|
10 |
+
|
11 |
+
class ConversationScripts:
|
12 |
+
"""Handler for conversation scripts and standard responses"""
|
13 |
+
|
14 |
+
def __init__(self, scripts_file: Optional[str] = None):
|
15 |
+
self.scripts = {}
|
16 |
+
self.scripts_file = scripts_file
|
17 |
+
|
18 |
+
# Default built-in scripts (fallback)
|
19 |
+
self._load_default_scripts()
|
20 |
+
|
21 |
+
# Load from file if provided
|
22 |
+
if scripts_file and os.path.exists(scripts_file):
|
23 |
+
self._load_scripts_from_file(scripts_file)
|
24 |
+
print(f"✅ Loaded conversation scripts from {scripts_file}")
|
25 |
+
else:
|
26 |
+
print("⚠️ Using default built-in scripts")
|
27 |
+
|
28 |
+
def _load_default_scripts(self):
|
29 |
+
"""Load default conversation scripts"""
|
30 |
+
self.scripts = {
|
31 |
+
"greeting": """Xin chào! Em là Trợ lý ảo Viettelpay Pro sẵn sàng hỗ trợ Anh/chị!
|
32 |
+
Hiện tại, Trợ lý ảo đang trong giai đoạn thử nghiệm hỗ trợ nghiệp vụ cước viễn thông, thẻ cào và thẻ game với các nội dung sau:
|
33 |
+
- Hướng dẫn sử dụng
|
34 |
+
- Chính sách phí bán hàng
|
35 |
+
- Tìm hiểu quy định hủy giao dịch
|
36 |
+
- Hướng dẫn xử lý một số lỗi thường gặp.
|
37 |
+
|
38 |
+
Anh/Chị vui lòng bấm vào từng chủ đề để xem chi tiết.
|
39 |
+
Nếu thông tin chưa đáp ứng nhu cầu, Anh/Chị hãy đặt lại câu hỏi để em tiếp tục hỗ trợ ạ!""",
|
40 |
+
"out_of_scope": """Cảm ơn Anh/chị đã đặt câu hỏi!
|
41 |
+
Trợ lý ảo Viettelpay Pro đang thử nghiệm và cập nhật kiến thức nghiệp vụ để hỗ trợ Anh/chị tốt hơn. Vì vậy, rất tiếc nhu cầu hiện tại của Anh/chị nằm ngoài khả năng hỗ trợ của Trợ lý ảo.
|
42 |
+
|
43 |
+
Để được hỗ trợ chính xác và đầy đủ hơn, Anh/chị vui lòng gửi yêu cầu hỗ trợ tại đây""",
|
44 |
+
"human_request_attempt_1": """Anh/Chị vui lòng chia sẻ thêm nội dung cần hỗ trợ, Em rất mong được giải đáp trực tiếp để tiết kiệm thời gian của Anh/Chị ạ!""",
|
45 |
+
"human_request_attempt_2": """Rất tiếc! Hiện tại hệ thống chưa có Tư vấn viên hỗ trợ trực tuyến.
|
46 |
+
Tuy nhiên, Anh/chị vẫn có thể yêu cầu hỗ trợ được trợ giúp qua các hình thức sau:
|
47 |
+
📌 1. Đặt câu hỏi ngay tại đây, Trợ lý ảo ViettelPay Pro luôn sẵn sàng hỗ trợ Anh/chị trong phạm vi nghiệp vụ thử nghiệm (nghiệp vụ cước viễn thông, thẻ cào, thẻ game):
|
48 |
+
✅ Hướng dẫn sử dụng
|
49 |
+
✅ Chính sách phí bán hàng
|
50 |
+
✅ Tìm hiểu về quy định hủy giao dịch
|
51 |
+
✅ Hướng dẫn xử lý một số lỗi thường gặp.
|
52 |
+
📌 2. Tìm hiểu thông tin nghiệp vụ tại mục:
|
53 |
+
📚 Hướng dẫn, hỗ trợ: Các video hướng dẫn nghiệp vụ
|
54 |
+
💡Thông báo: Các tin tức nghiệp vụ và tin nâng cấp hệ thống/tin sự cố.
|
55 |
+
📌 3. Gửi yêu cầu hỗ trợ tại đây
|
56 |
+
Hoặc gọi Tổng đài 1789 nhánh 5 trong trường hợp khẩn cấp""",
|
57 |
+
"confirmation_check": """Anh/Chị có thắc mắc thêm vấn đề nào liên quan đến nội dung em vừa cung cấp không ạ?""",
|
58 |
+
"closing": """Hy vọng những thông tin vừa rồi đã đáp ứng nhu cầu của Anh/chị.
|
59 |
+
Nếu cần hỗ trợ thêm, Anh/Chị hãy đặt câu hỏi để em tiếp tục hỗ trợ ạ!
|
60 |
+
🌟 Chúc Anh/chị một ngày thật vui và thành công!""",
|
61 |
+
"ask_for_clarity": """Em chưa hiểu rõ yêu cầu của Anh/chị. Anh/chị có thể chia sẻ cụ thể hơn được không ạ?""",
|
62 |
+
"empathy_error": """Em hiểu Anh/chị đang gặp khó khăn với lỗi này. Để hỗ trợ Anh/chị tốt nhất, em sẽ tìm hiểu và đưa ra hướng giải quyết cụ thể.""",
|
63 |
+
}
|
64 |
+
|
65 |
+
def _load_scripts_from_file(self, file_path: str):
|
66 |
+
"""Load scripts from CSV file (kich_ban.csv format)"""
|
67 |
+
try:
|
68 |
+
df = pd.read_csv(file_path)
|
69 |
+
|
70 |
+
# Map CSV scenarios to script keys
|
71 |
+
scenario_mapping = {
|
72 |
+
"Chào hỏi": "greeting",
|
73 |
+
"Trao đổi thông tin chính": "out_of_scope", # First occurrence
|
74 |
+
"Trước khi kết thúc phiên": "confirmation_check",
|
75 |
+
"Kết thúc": "closing",
|
76 |
+
}
|
77 |
+
|
78 |
+
for _, row in df.iterrows():
|
79 |
+
scenario_type = row.get("Loại kịch bản", "")
|
80 |
+
situation = row.get("Tình huống", "")
|
81 |
+
script = row.get("Kịch bản chốt", "")
|
82 |
+
|
83 |
+
# Handle specific mappings
|
84 |
+
if scenario_type == "Chào hỏi":
|
85 |
+
self.scripts["greeting"] = script
|
86 |
+
elif scenario_type == "Trao đổi thông tin chính":
|
87 |
+
if "ngoài nghiệp vụ" in situation:
|
88 |
+
self.scripts["out_of_scope"] = script
|
89 |
+
elif "gặp tư vấn viên" in situation:
|
90 |
+
if "Lần 1" in script:
|
91 |
+
self.scripts["human_request_attempt_1"] = (
|
92 |
+
script.split("Lần 1:")[1].split("Lần 2:")[0].strip()
|
93 |
+
)
|
94 |
+
if "Lần 2" in script:
|
95 |
+
self.scripts["human_request_attempt_2"] = script.split(
|
96 |
+
"Lần 2:"
|
97 |
+
)[1].strip()
|
98 |
+
elif "không đủ ý" in situation:
|
99 |
+
self.scripts["ask_for_clarity"] = (
|
100 |
+
"Em chưa hiểu rõ yêu cầu của Anh/chị. Anh/chị có thể chia sẻ cụ thể hơn được không ạ?"
|
101 |
+
)
|
102 |
+
elif "lỗi" in situation:
|
103 |
+
self.scripts["empathy_error"] = (
|
104 |
+
"Em hiểu Anh/chị đang gặp khó khăn với lỗi này."
|
105 |
+
)
|
106 |
+
elif scenario_type == "Trước khi kết thúc phiên":
|
107 |
+
self.scripts["confirmation_check"] = script
|
108 |
+
elif scenario_type == "Kết thúc":
|
109 |
+
self.scripts["closing"] = script
|
110 |
+
|
111 |
+
except Exception as e:
|
112 |
+
print(f"⚠️ Error loading scripts from file: {e}")
|
113 |
+
print("Using default scripts instead")
|
114 |
+
|
115 |
+
def get_script(self, script_type: str) -> Optional[str]:
|
116 |
+
"""Get script by type"""
|
117 |
+
return self.scripts.get(script_type)
|
118 |
+
|
119 |
+
def has_script(self, script_type: str) -> bool:
|
120 |
+
"""Check if script exists"""
|
121 |
+
return script_type in self.scripts
|
122 |
+
|
123 |
+
def get_all_script_types(self) -> list:
|
124 |
+
"""Get all available script types"""
|
125 |
+
return list(self.scripts.keys())
|
126 |
+
|
127 |
+
def add_script(self, script_type: str, script_content: str):
|
128 |
+
"""Add or update a script"""
|
129 |
+
self.scripts[script_type] = script_content
|
130 |
+
|
131 |
+
def get_stats(self) -> dict:
|
132 |
+
"""Get statistics about loaded scripts"""
|
133 |
+
return {
|
134 |
+
"total_scripts": len(self.scripts),
|
135 |
+
"script_types": list(self.scripts.keys()),
|
136 |
+
"source": "file" if self.scripts_file else "default",
|
137 |
+
}
|
138 |
+
|
139 |
+
|
140 |
+
# Usage example
|
141 |
+
if __name__ == "__main__":
|
142 |
+
# Test with default scripts
|
143 |
+
scripts = ConversationScripts()
|
144 |
+
print("📊 Scripts Stats:", scripts.get_stats())
|
145 |
+
|
146 |
+
# Test specific scripts
|
147 |
+
greeting = scripts.get_script("greeting")
|
148 |
+
print(f"\n👋 Greeting Script:\n{greeting}")
|
149 |
+
|
150 |
+
# Test loading from file (if available)
|
151 |
+
try:
|
152 |
+
scripts_with_file = ConversationScripts(
|
153 |
+
"./viettelpay_docs/processed/kich_ban.csv"
|
154 |
+
)
|
155 |
+
print(f"\n📊 File-based Scripts Stats: {scripts_with_file.get_stats()}")
|
156 |
+
except Exception as e:
|
157 |
+
print(f"File loading test failed: {e}")
|
src/agent/viettelpay_agent.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ViettelPay AI Agent using LangGraph
|
3 |
+
Multi-turn conversation support with short-term memory using InMemorySaver
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
from typing import Dict, Optional
|
8 |
+
from functools import partial
|
9 |
+
from langgraph.graph import StateGraph, END
|
10 |
+
from langgraph.checkpoint.memory import InMemorySaver
|
11 |
+
from langchain_core.messages import HumanMessage
|
12 |
+
|
13 |
+
from src.agent.nodes import (
|
14 |
+
ViettelPayState,
|
15 |
+
classify_intent_node,
|
16 |
+
query_enhancement_node,
|
17 |
+
knowledge_retrieval_node,
|
18 |
+
script_response_node,
|
19 |
+
generate_response_node,
|
20 |
+
route_after_intent_classification,
|
21 |
+
route_after_query_enhancement,
|
22 |
+
route_after_knowledge_retrieval,
|
23 |
+
)
|
24 |
+
|
25 |
+
# Import configuration utility
|
26 |
+
from src.utils.config import get_knowledge_base_path, get_llm_provider
|
27 |
+
|
28 |
+
|
29 |
+
class ViettelPayAgent:
|
30 |
+
"""Main ViettelPay AI Agent using LangGraph workflow with multi-turn conversation support"""
|
31 |
+
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
knowledge_base_path: str = None,
|
35 |
+
scripts_file: Optional[str] = None,
|
36 |
+
llm_provider: str = None,
|
37 |
+
):
|
38 |
+
knowledge_base_path = knowledge_base_path or get_knowledge_base_path()
|
39 |
+
scripts_file = scripts_file or "./viettelpay_docs/processed/kich_ban.csv"
|
40 |
+
llm_provider = llm_provider or get_llm_provider()
|
41 |
+
|
42 |
+
self.knowledge_base_path = knowledge_base_path
|
43 |
+
self.scripts_file = scripts_file
|
44 |
+
self.llm_provider = llm_provider
|
45 |
+
|
46 |
+
# Initialize LLM client once during agent creation
|
47 |
+
print(f"🧠 Initializing LLM client ({self.llm_provider})...")
|
48 |
+
from src.llm.llm_client import LLMClientFactory
|
49 |
+
|
50 |
+
self.llm_client = LLMClientFactory.create_client(self.llm_provider)
|
51 |
+
print(f"✅ LLM client initialized and ready")
|
52 |
+
|
53 |
+
# Initialize knowledge retriever once during agent creation
|
54 |
+
print(f"📚 Initializing knowledge retriever...")
|
55 |
+
try:
|
56 |
+
from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
|
57 |
+
|
58 |
+
self.knowledge_base = ViettelKnowledgeBase()
|
59 |
+
ensemble_retriever = self.knowledge_base.load_knowledge_base(
|
60 |
+
knowledge_base_path
|
61 |
+
)
|
62 |
+
if not ensemble_retriever:
|
63 |
+
raise ValueError(
|
64 |
+
f"Knowledge base not found at {knowledge_base_path}. Run build_database_script.py first."
|
65 |
+
)
|
66 |
+
print(f"✅ Knowledge retriever initialized and ready")
|
67 |
+
except Exception as e:
|
68 |
+
print(f"⚠️ Knowledge retriever initialization failed: {e}")
|
69 |
+
self.knowledge_base = None
|
70 |
+
|
71 |
+
# Initialize checkpointer for short-term memory
|
72 |
+
self.checkpointer = InMemorySaver()
|
73 |
+
|
74 |
+
# Build workflow with pre-initialized components
|
75 |
+
self.workflow = self._build_workflow()
|
76 |
+
self.app = self.workflow.compile(checkpointer=self.checkpointer)
|
77 |
+
|
78 |
+
print("✅ ViettelPay Agent initialized with multi-turn conversation support")
|
79 |
+
|
80 |
+
def _build_workflow(self) -> StateGraph:
|
81 |
+
"""Build LangGraph workflow with pre-initialized components"""
|
82 |
+
|
83 |
+
# Create workflow graph
|
84 |
+
workflow = StateGraph(ViettelPayState)
|
85 |
+
|
86 |
+
# Create node functions with pre-bound components using functools.partial
|
87 |
+
# This eliminates the need to initialize components in each node call
|
88 |
+
classify_intent_with_llm = partial(
|
89 |
+
classify_intent_node, llm_client=self.llm_client
|
90 |
+
)
|
91 |
+
query_enhancement_with_llm = partial(
|
92 |
+
query_enhancement_node, llm_client=self.llm_client
|
93 |
+
)
|
94 |
+
knowledge_retrieval_with_retriever = partial(
|
95 |
+
knowledge_retrieval_node, knowledge_retriever=self.knowledge_base
|
96 |
+
)
|
97 |
+
generate_response_with_llm = partial(
|
98 |
+
generate_response_node, llm_client=self.llm_client
|
99 |
+
)
|
100 |
+
|
101 |
+
# Add nodes (some with pre-bound components, some without)
|
102 |
+
workflow.add_node("classify_intent", classify_intent_with_llm)
|
103 |
+
workflow.add_node("query_enhancement", query_enhancement_with_llm)
|
104 |
+
workflow.add_node("knowledge_retrieval", knowledge_retrieval_with_retriever)
|
105 |
+
workflow.add_node(
|
106 |
+
"script_response", script_response_node
|
107 |
+
) # No pre-bound components needed
|
108 |
+
workflow.add_node("generate_response", generate_response_with_llm)
|
109 |
+
|
110 |
+
# Set entry point
|
111 |
+
workflow.set_entry_point("classify_intent")
|
112 |
+
|
113 |
+
# Add conditional routing after intent classification
|
114 |
+
workflow.add_conditional_edges(
|
115 |
+
"classify_intent",
|
116 |
+
route_after_intent_classification,
|
117 |
+
{
|
118 |
+
"script_response": "script_response",
|
119 |
+
"query_enhancement": "query_enhancement",
|
120 |
+
},
|
121 |
+
)
|
122 |
+
|
123 |
+
# Script responses go directly to end
|
124 |
+
workflow.add_edge("script_response", END)
|
125 |
+
|
126 |
+
# Query enhancement goes to knowledge retrieval
|
127 |
+
workflow.add_edge("query_enhancement", "knowledge_retrieval")
|
128 |
+
|
129 |
+
# Knowledge retrieval goes to response generation
|
130 |
+
workflow.add_edge("knowledge_retrieval", "generate_response")
|
131 |
+
workflow.add_edge("generate_response", END)
|
132 |
+
|
133 |
+
print("🔄 LangGraph workflow built successfully with optimized component usage")
|
134 |
+
return workflow
|
135 |
+
|
136 |
+
def process_message(self, user_message: str, thread_id: str = "default") -> Dict:
|
137 |
+
"""Process a user message in a multi-turn conversation"""
|
138 |
+
|
139 |
+
print(f"\n💬 Processing message: '{user_message}' (thread: {thread_id})")
|
140 |
+
print("=" * 50)
|
141 |
+
|
142 |
+
# Create configuration with thread_id for conversation memory
|
143 |
+
config = {"configurable": {"thread_id": thread_id}}
|
144 |
+
|
145 |
+
try:
|
146 |
+
# Create human message
|
147 |
+
human_message = HumanMessage(content=user_message)
|
148 |
+
|
149 |
+
# Initialize state with the new message
|
150 |
+
# Note: conversation_context is set to None so it gets recomputed with fresh message history
|
151 |
+
initial_state = {
|
152 |
+
"messages": [human_message],
|
153 |
+
"intent": None,
|
154 |
+
"confidence": None,
|
155 |
+
"enhanced_query": None,
|
156 |
+
"retrieved_docs": None,
|
157 |
+
"conversation_context": None, # Reset to ensure fresh context computation
|
158 |
+
"response_type": None,
|
159 |
+
"error": None,
|
160 |
+
"processing_info": None,
|
161 |
+
}
|
162 |
+
|
163 |
+
# Run workflow with memory
|
164 |
+
result = self.app.invoke(initial_state, config)
|
165 |
+
|
166 |
+
# Extract response from the last AI message
|
167 |
+
messages = result.get("messages", [])
|
168 |
+
if messages:
|
169 |
+
# Get the last AI message
|
170 |
+
last_message = messages[-1]
|
171 |
+
if hasattr(last_message, "content"):
|
172 |
+
response = last_message.content
|
173 |
+
else:
|
174 |
+
response = str(last_message)
|
175 |
+
else:
|
176 |
+
response = "Xin lỗi, em không thể xử lý yêu cầu này."
|
177 |
+
|
178 |
+
response_type = result.get("response_type", "unknown")
|
179 |
+
intent = result.get("intent", "unknown")
|
180 |
+
confidence = result.get("confidence", 0.0)
|
181 |
+
enhanced_query = result.get("enhanced_query", "")
|
182 |
+
error = result.get("error")
|
183 |
+
|
184 |
+
# Build response info
|
185 |
+
response_info = {
|
186 |
+
"response": response,
|
187 |
+
"intent": intent,
|
188 |
+
"confidence": confidence,
|
189 |
+
"response_type": response_type,
|
190 |
+
"enhanced_query": enhanced_query,
|
191 |
+
"success": error is None,
|
192 |
+
"error": error,
|
193 |
+
"thread_id": thread_id,
|
194 |
+
"message_count": len(messages),
|
195 |
+
}
|
196 |
+
|
197 |
+
print(f"✅ Response generated successfully")
|
198 |
+
print(f" Intent: {intent} (confidence: {confidence})")
|
199 |
+
print(f" Type: {response_type}")
|
200 |
+
if enhanced_query and enhanced_query != user_message:
|
201 |
+
print(f" Enhanced query: {enhanced_query}")
|
202 |
+
print(f" Thread: {thread_id}")
|
203 |
+
|
204 |
+
return response_info
|
205 |
+
|
206 |
+
except Exception as e:
|
207 |
+
print(f"❌ Workflow error: {e}")
|
208 |
+
|
209 |
+
return {
|
210 |
+
"response": "Xin lỗi, em gặp lỗi kỹ thuật. Vui lòng thử lại sau.",
|
211 |
+
"intent": "error",
|
212 |
+
"confidence": 0.0,
|
213 |
+
"response_type": "error",
|
214 |
+
"enhanced_query": "",
|
215 |
+
"success": False,
|
216 |
+
"error": str(e),
|
217 |
+
"thread_id": thread_id,
|
218 |
+
"message_count": 0,
|
219 |
+
}
|
220 |
+
|
221 |
+
def chat(self, user_message: str, thread_id: str = "default") -> str:
|
222 |
+
"""Simple chat interface - returns just the response text"""
|
223 |
+
result = self.process_message(user_message, thread_id)
|
224 |
+
return result["response"]
|
225 |
+
|
226 |
+
def get_conversation_history(self, thread_id: str = "default") -> list:
|
227 |
+
"""Get conversation history for a specific thread"""
|
228 |
+
try:
|
229 |
+
config = {"configurable": {"thread_id": thread_id}}
|
230 |
+
|
231 |
+
# Get the current state to access message history
|
232 |
+
current_state = self.app.get_state(config)
|
233 |
+
|
234 |
+
if current_state and current_state.values.get("messages"):
|
235 |
+
messages = current_state.values["messages"]
|
236 |
+
history = []
|
237 |
+
|
238 |
+
for msg in messages:
|
239 |
+
if hasattr(msg, "type") and hasattr(msg, "content"):
|
240 |
+
role = "user" if msg.type == "human" else "assistant"
|
241 |
+
history.append({"role": role, "content": msg.content})
|
242 |
+
elif hasattr(msg, "role") and hasattr(msg, "content"):
|
243 |
+
history.append({"role": msg.role, "content": msg.content})
|
244 |
+
|
245 |
+
return history
|
246 |
+
else:
|
247 |
+
return []
|
248 |
+
|
249 |
+
except Exception as e:
|
250 |
+
print(f"❌ Error getting conversation history: {e}")
|
251 |
+
return []
|
252 |
+
|
253 |
+
def clear_conversation(self, thread_id: str = "default") -> bool:
|
254 |
+
"""Clear conversation history for a specific thread"""
|
255 |
+
try:
|
256 |
+
# Note: InMemorySaver doesn't have a direct clear method
|
257 |
+
# The conversation will be cleared when the app is restarted
|
258 |
+
# For persistent memory, you'd need to implement a clear method
|
259 |
+
print(f"📝 Conversation clearing requested for thread: {thread_id}")
|
260 |
+
print(" Note: InMemorySaver conversations clear on app restart")
|
261 |
+
return True
|
262 |
+
except Exception as e:
|
263 |
+
print(f"❌ Error clearing conversation: {e}")
|
264 |
+
return False
|
265 |
+
|
266 |
+
def get_workflow_info(self) -> Dict:
|
267 |
+
"""Get information about the workflow structure"""
|
268 |
+
return {
|
269 |
+
"nodes": [
|
270 |
+
"classify_intent",
|
271 |
+
"query_enhancement",
|
272 |
+
"knowledge_retrieval",
|
273 |
+
"script_response",
|
274 |
+
"generate_response",
|
275 |
+
],
|
276 |
+
"entry_point": "classify_intent",
|
277 |
+
"knowledge_base_path": self.knowledge_base_path,
|
278 |
+
"scripts_file": self.scripts_file,
|
279 |
+
"llm_provider": self.llm_provider,
|
280 |
+
"memory_type": "InMemorySaver",
|
281 |
+
"multi_turn": True,
|
282 |
+
"query_enhancement": True,
|
283 |
+
"optimizations": {
|
284 |
+
"llm_client": "Single initialization with functools.partial",
|
285 |
+
"knowledge_retriever": "Single initialization with functools.partial",
|
286 |
+
"conversation_context": "Cached in state to avoid repeated computation",
|
287 |
+
},
|
288 |
+
}
|
289 |
+
|
290 |
+
def health_check(self) -> Dict:
|
291 |
+
"""Check if all components are working"""
|
292 |
+
|
293 |
+
health_status = {
|
294 |
+
"agent": True,
|
295 |
+
"workflow": True,
|
296 |
+
"memory": True,
|
297 |
+
"llm": False,
|
298 |
+
"knowledge_base": False,
|
299 |
+
"scripts": False,
|
300 |
+
"overall": False,
|
301 |
+
}
|
302 |
+
|
303 |
+
try:
|
304 |
+
# Test LLM client (already initialized)
|
305 |
+
test_response = self.llm_client.generate("Hello", temperature=0.1)
|
306 |
+
health_status["llm"] = bool(test_response)
|
307 |
+
print("✅ LLM client working")
|
308 |
+
|
309 |
+
except Exception as e:
|
310 |
+
print(f"⚠️ LLM health check failed: {e}")
|
311 |
+
health_status["llm"] = False
|
312 |
+
|
313 |
+
try:
|
314 |
+
# Test memory/checkpointer
|
315 |
+
test_config = {"configurable": {"thread_id": "health_check"}}
|
316 |
+
test_state = {"messages": [HumanMessage(content="test")]}
|
317 |
+
|
318 |
+
# Try to invoke with memory
|
319 |
+
self.app.invoke(test_state, test_config)
|
320 |
+
health_status["memory"] = True
|
321 |
+
print("✅ Memory/checkpointer working")
|
322 |
+
|
323 |
+
except Exception as e:
|
324 |
+
print(f"⚠️ Memory health check failed: {e}")
|
325 |
+
health_status["memory"] = False
|
326 |
+
|
327 |
+
try:
|
328 |
+
# Test knowledge base (using pre-initialized retriever)
|
329 |
+
if self.knowledge_base:
|
330 |
+
# Test a simple search to verify it's working
|
331 |
+
test_docs = self.knowledge_base.search("test", top_k=1)
|
332 |
+
health_status["knowledge_base"] = True
|
333 |
+
print("✅ Knowledge retriever working")
|
334 |
+
else:
|
335 |
+
health_status["knowledge_base"] = False
|
336 |
+
print("❌ Knowledge retriever not initialized")
|
337 |
+
|
338 |
+
except Exception as e:
|
339 |
+
print(f"⚠️ Knowledge base health check failed: {e}")
|
340 |
+
health_status["knowledge_base"] = False
|
341 |
+
|
342 |
+
try:
|
343 |
+
# Test scripts
|
344 |
+
from src.agent.scripts import ConversationScripts
|
345 |
+
|
346 |
+
scripts = ConversationScripts(self.scripts_file)
|
347 |
+
health_status["scripts"] = len(scripts.get_all_script_types()) > 0
|
348 |
+
|
349 |
+
except Exception as e:
|
350 |
+
print(f"⚠️ Scripts health check failed: {e}")
|
351 |
+
|
352 |
+
# Overall health
|
353 |
+
health_status["overall"] = all(
|
354 |
+
[
|
355 |
+
health_status["agent"],
|
356 |
+
health_status["memory"],
|
357 |
+
health_status["llm"],
|
358 |
+
health_status["knowledge_base"],
|
359 |
+
health_status["scripts"],
|
360 |
+
]
|
361 |
+
)
|
362 |
+
|
363 |
+
return health_status
|
364 |
+
|
365 |
+
|
366 |
+
# Usage example and testing
|
367 |
+
if __name__ == "__main__":
|
368 |
+
# Initialize agent
|
369 |
+
agent = ViettelPayAgent()
|
370 |
+
|
371 |
+
# Health check
|
372 |
+
print("\n🏥 Health Check:")
|
373 |
+
health = agent.health_check()
|
374 |
+
for component, status in health.items():
|
375 |
+
status_icon = "✅" if status else "❌"
|
376 |
+
print(f" {component}: {status_icon}")
|
377 |
+
|
378 |
+
if not health["overall"]:
|
379 |
+
print("\n⚠️ Some components are not healthy. Check requirements and data files.")
|
380 |
+
exit(1)
|
381 |
+
|
382 |
+
print(f"\n🤖 Agent ready! Workflow info: {agent.get_workflow_info()}")
|
383 |
+
|
384 |
+
# Test multi-turn conversation with query enhancement
|
385 |
+
test_thread = "test_conversation"
|
386 |
+
|
387 |
+
print(
|
388 |
+
f"\n🧪 Testing multi-turn conversation with query enhancement (thread: {test_thread}):"
|
389 |
+
)
|
390 |
+
|
391 |
+
test_messages = [
|
392 |
+
"Xin chào!",
|
393 |
+
"Mã lỗi 606 là gì?",
|
394 |
+
"Làm sao khắc phục?", # This should be enhanced to "làm sao khắc phục lỗi 606"
|
395 |
+
"Còn lỗi nào khác tương tự không?", # This should be enhanced with error context
|
396 |
+
"Cảm ơn bạn!",
|
397 |
+
]
|
398 |
+
|
399 |
+
for i, message in enumerate(test_messages, 1):
|
400 |
+
print(f"\n--- Turn {i} ---")
|
401 |
+
result = agent.process_message(message, test_thread)
|
402 |
+
print(f"User: {message}")
|
403 |
+
print(f"Bot: {result['response'][:150]}...")
|
404 |
+
|
405 |
+
if result.get("enhanced_query") and result["enhanced_query"] != message:
|
406 |
+
print(f"🚀 Query enhanced: {result['enhanced_query']}")
|
407 |
+
|
408 |
+
# Show conversation history
|
409 |
+
if i > 1:
|
410 |
+
history = agent.get_conversation_history(test_thread)
|
411 |
+
print(f"History length: {len(history)} messages")
|
412 |
+
|
413 |
+
print(f"\n📜 Final conversation history:")
|
414 |
+
history = agent.get_conversation_history(test_thread)
|
415 |
+
for i, msg in enumerate(history, 1):
|
416 |
+
print(f" {i}. {msg['role']}: {msg['content'][:100]}...")
|
src/evaluation/__pycache__/prompts.cpython-311.pyc
ADDED
Binary file (16.4 kB). View file
|
|
src/evaluation/__pycache__/single_turn_retrieval.cpython-311.pyc
ADDED
Binary file (39.1 kB). View file
|
|
src/evaluation/intent_classification.py
ADDED
@@ -0,0 +1,901 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Simplified Intent Classification Evaluation for ViettelPay AI Agent
|
3 |
+
Removed pattern-based generation, improved chunk mixing, and configurable conversations per chunk
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
import time
|
11 |
+
import random
|
12 |
+
from typing import Dict, List, Optional
|
13 |
+
from pathlib import Path
|
14 |
+
from collections import defaultdict, Counter
|
15 |
+
import pandas as pd
|
16 |
+
from tqdm import tqdm
|
17 |
+
import re
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
# Load environment variables from .env file
|
21 |
+
from dotenv import load_dotenv
|
22 |
+
|
23 |
+
load_dotenv()
|
24 |
+
|
25 |
+
# Add the project root to Python path so we can import from src
|
26 |
+
project_root = Path(__file__).parent.parent.parent
|
27 |
+
sys.path.insert(0, str(project_root))
|
28 |
+
|
29 |
+
# Import existing components
|
30 |
+
from src.evaluation.prompts import INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT
|
31 |
+
from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
|
32 |
+
from src.llm.llm_client import LLMClientFactory
|
33 |
+
from src.agent.nodes import classify_intent_node, ViettelPayState
|
34 |
+
from langchain_core.messages import HumanMessage
|
35 |
+
|
36 |
+
|
37 |
+
class IntentDatasetCreator:
|
38 |
+
"""Simplified intent classification dataset creator with two strategies"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
|
42 |
+
):
|
43 |
+
"""Initialize with Gemini API key and optional knowledge base"""
|
44 |
+
self.llm_client = LLMClientFactory.create_client(
|
45 |
+
"gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
|
46 |
+
)
|
47 |
+
self.knowledge_base = knowledge_base
|
48 |
+
self.dataset = {
|
49 |
+
"conversations": {},
|
50 |
+
"generation_methods": {},
|
51 |
+
"intent_distribution": {},
|
52 |
+
"metadata": {
|
53 |
+
"total_conversations": 0,
|
54 |
+
"total_user_messages": 0,
|
55 |
+
"creation_timestamp": time.time(),
|
56 |
+
},
|
57 |
+
}
|
58 |
+
|
59 |
+
print("✅ IntentDatasetCreator initialized (simplified version)")
|
60 |
+
|
61 |
+
def generate_json_response(
|
62 |
+
self, prompt: str, max_retries: int = 3
|
63 |
+
) -> Optional[Dict]:
|
64 |
+
"""Generate response and parse as JSON with retries"""
|
65 |
+
for attempt in range(max_retries):
|
66 |
+
try:
|
67 |
+
response = self.llm_client.generate(prompt, temperature=0.1)
|
68 |
+
|
69 |
+
if response:
|
70 |
+
response_text = response.strip()
|
71 |
+
json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
|
72 |
+
if json_match:
|
73 |
+
json_text = json_match.group()
|
74 |
+
return json.loads(json_text)
|
75 |
+
else:
|
76 |
+
return json.loads(response_text)
|
77 |
+
|
78 |
+
except json.JSONDecodeError as e:
|
79 |
+
print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
|
80 |
+
if attempt == max_retries - 1:
|
81 |
+
print(f"❌ Failed to parse JSON after {max_retries} attempts")
|
82 |
+
|
83 |
+
except Exception as e:
|
84 |
+
print(f"⚠️ API error (attempt {attempt + 1}): {e}")
|
85 |
+
if attempt < max_retries - 1:
|
86 |
+
time.sleep(2**attempt)
|
87 |
+
|
88 |
+
return None
|
89 |
+
|
90 |
+
def get_all_chunks(self) -> List[Dict]:
|
91 |
+
"""Get ALL chunks from ChromaDB vectorstore"""
|
92 |
+
print(f"📚 Retrieving ALL chunks from ChromaDB vectorstore...")
|
93 |
+
|
94 |
+
if not self.knowledge_base:
|
95 |
+
raise ValueError("Knowledge base not provided.")
|
96 |
+
|
97 |
+
try:
|
98 |
+
if (
|
99 |
+
not hasattr(self.knowledge_base, "chroma_retriever")
|
100 |
+
or not self.knowledge_base.chroma_retriever
|
101 |
+
):
|
102 |
+
raise ValueError("ChromaDB retriever not found in knowledge base")
|
103 |
+
|
104 |
+
vectorstore = self.knowledge_base.chroma_retriever.vectorstore
|
105 |
+
all_docs = vectorstore.get(include=["documents", "metadatas"])
|
106 |
+
|
107 |
+
documents = all_docs["documents"]
|
108 |
+
metadatas = all_docs["metadatas"]
|
109 |
+
|
110 |
+
all_chunks = []
|
111 |
+
seen_content_hashes = set()
|
112 |
+
|
113 |
+
for i, (content, metadata) in enumerate(zip(documents, metadatas)):
|
114 |
+
content_hash = hash(content[:300])
|
115 |
+
|
116 |
+
if (
|
117 |
+
content_hash not in seen_content_hashes
|
118 |
+
and len(content.strip()) > 100
|
119 |
+
):
|
120 |
+
chunk_info = {
|
121 |
+
"id": f"chunk_{len(all_chunks)}",
|
122 |
+
"content": content,
|
123 |
+
"metadata": metadata or {},
|
124 |
+
}
|
125 |
+
all_chunks.append(chunk_info)
|
126 |
+
seen_content_hashes.add(content_hash)
|
127 |
+
|
128 |
+
print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
|
129 |
+
return all_chunks
|
130 |
+
|
131 |
+
except Exception as e:
|
132 |
+
print(f"❌ Error accessing ChromaDB: {e}")
|
133 |
+
return []
|
134 |
+
|
135 |
+
def generate_single_chunk_conversations(
|
136 |
+
self, chunk: Dict, num_conversations: int = 3
|
137 |
+
) -> List[Dict]:
|
138 |
+
"""Generate conversations from single chunk"""
|
139 |
+
content = chunk["content"]
|
140 |
+
|
141 |
+
generation_instruction = "Tạo cuộc hội thoại tập trung vào chủ đề chính của tài liệu. Bao gồm cả các intent phổ biến như greeting, unclear, human_request để tăng tính đa dạng"
|
142 |
+
|
143 |
+
prompt = INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT.format(
|
144 |
+
num_conversations=num_conversations,
|
145 |
+
content=content,
|
146 |
+
generation_instruction=generation_instruction,
|
147 |
+
)
|
148 |
+
|
149 |
+
response_json = self.generate_json_response(prompt)
|
150 |
+
|
151 |
+
if response_json and "conversations" in response_json:
|
152 |
+
conversations = response_json["conversations"]
|
153 |
+
valid_conversations = []
|
154 |
+
|
155 |
+
for i, conversation in enumerate(conversations):
|
156 |
+
if "turns" in conversation and len(conversation["turns"]) >= 1:
|
157 |
+
valid_turns = []
|
158 |
+
for turn in conversation["turns"]:
|
159 |
+
if "user" in turn and "intent" in turn:
|
160 |
+
valid_turns.append(turn)
|
161 |
+
|
162 |
+
if valid_turns:
|
163 |
+
conv_obj = {
|
164 |
+
"id": f"single_{chunk['id']}_{i}",
|
165 |
+
"turns": valid_turns,
|
166 |
+
"generation_method": "single_chunk",
|
167 |
+
"source_chunks": [chunk["id"]],
|
168 |
+
"chunk_metadata": [chunk["metadata"]],
|
169 |
+
}
|
170 |
+
valid_conversations.append(conv_obj)
|
171 |
+
return valid_conversations
|
172 |
+
else:
|
173 |
+
print(f"⚠️ No valid conversations generated for chunk {chunk['id']}")
|
174 |
+
return []
|
175 |
+
|
176 |
+
def generate_multi_chunk_conversations(
|
177 |
+
self, chunks: List[Dict], num_conversations: int = 3
|
178 |
+
) -> List[Dict]:
|
179 |
+
"""Generate conversations from multiple chunks (2-3 chunks)"""
|
180 |
+
# Combine content from multiple chunks
|
181 |
+
combined_content = ""
|
182 |
+
for i, chunk in enumerate(chunks):
|
183 |
+
combined_content += f"\n\n--- Chủ đề {i+1} ---\n" + chunk["content"]
|
184 |
+
|
185 |
+
generation_instruction = f"Tạo cuộc hội thoại tự nhiên kết hợp {len(chunks)} chủ đề khác nhau. Người dùng có thể chuyển từ chủ đề này sang chủ đề khác. Đặc biệt bao gồm các intent như greeting, unclear, human_request để cuộc hội thoại thực tế hơn"
|
186 |
+
|
187 |
+
prompt = INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT.format(
|
188 |
+
num_conversations=num_conversations,
|
189 |
+
content=combined_content,
|
190 |
+
generation_instruction=generation_instruction,
|
191 |
+
)
|
192 |
+
|
193 |
+
response_json = self.generate_json_response(prompt)
|
194 |
+
|
195 |
+
if response_json and "conversations" in response_json:
|
196 |
+
conversations = response_json["conversations"]
|
197 |
+
valid_conversations = []
|
198 |
+
|
199 |
+
for i, conversation in enumerate(conversations):
|
200 |
+
if "turns" in conversation and len(conversation["turns"]) >= 1:
|
201 |
+
valid_turns = []
|
202 |
+
for turn in conversation["turns"]:
|
203 |
+
if "user" in turn and "intent" in turn:
|
204 |
+
valid_turns.append(turn)
|
205 |
+
|
206 |
+
if valid_turns:
|
207 |
+
conv_obj = {
|
208 |
+
"id": f"multi_{'-'.join([c['id'] for c in chunks])}_{i}",
|
209 |
+
"turns": valid_turns,
|
210 |
+
"generation_method": "multi_chunk",
|
211 |
+
"source_chunks": [c["id"] for c in chunks],
|
212 |
+
"chunk_metadata": [c["metadata"] for c in chunks],
|
213 |
+
}
|
214 |
+
valid_conversations.append(conv_obj)
|
215 |
+
|
216 |
+
print(
|
217 |
+
f"✅ Generated {len(valid_conversations)} conversations for multi-chunk {[c['id'] for c in chunks]}"
|
218 |
+
)
|
219 |
+
return valid_conversations
|
220 |
+
else:
|
221 |
+
print(
|
222 |
+
f"⚠️ No valid conversations generated for chunks {[c['id'] for c in chunks]}"
|
223 |
+
)
|
224 |
+
return []
|
225 |
+
|
226 |
+
def create_intent_dataset(
|
227 |
+
self,
|
228 |
+
num_conversations_per_chunk: int = 3,
|
229 |
+
save_path: str = "evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json",
|
230 |
+
) -> Dict:
|
231 |
+
"""Create intent classification dataset using two strategies only"""
|
232 |
+
print(f"\n🚀 Creating intent classification dataset...")
|
233 |
+
print(f" Conversations per chunk: {num_conversations_per_chunk}")
|
234 |
+
|
235 |
+
# Step 1: Get all chunks
|
236 |
+
all_chunks = self.get_all_chunks()
|
237 |
+
if not all_chunks:
|
238 |
+
raise ValueError("No chunks found in knowledge base!")
|
239 |
+
|
240 |
+
total_chunks = len(all_chunks)
|
241 |
+
print(f"✅ Using all {total_chunks} chunks and shuffle them")
|
242 |
+
random.shuffle(all_chunks)
|
243 |
+
|
244 |
+
# Step 2: Split chunks for two strategies (60% single, 40% multi)
|
245 |
+
split_point = int(total_chunks * 0.6)
|
246 |
+
single_chunks = all_chunks[:split_point]
|
247 |
+
multi_chunks = all_chunks[split_point:]
|
248 |
+
|
249 |
+
print(f"📊 Distribution plan:")
|
250 |
+
print(
|
251 |
+
f" • Single chunk: {len(single_chunks)} chunks → ~{len(single_chunks) * num_conversations_per_chunk} conversations"
|
252 |
+
)
|
253 |
+
print(
|
254 |
+
f" • Multi chunk: {len(multi_chunks)} chunks → ~{len(multi_chunks) // 2.5 * num_conversations_per_chunk} conversations"
|
255 |
+
)
|
256 |
+
|
257 |
+
all_conversations = []
|
258 |
+
|
259 |
+
# Step 3: Generate single-chunk conversations
|
260 |
+
print(f"\n💬 Generating single-chunk conversations...")
|
261 |
+
for chunk in tqdm(single_chunks, desc="Single-chunk conversations"):
|
262 |
+
conversations = self.generate_single_chunk_conversations(
|
263 |
+
chunk, num_conversations_per_chunk
|
264 |
+
)
|
265 |
+
all_conversations.extend(conversations)
|
266 |
+
time.sleep(0.1)
|
267 |
+
|
268 |
+
# Step 4: Generate multi-chunk conversations (2-3 chunks randomly)
|
269 |
+
print(f"\n🔀 Generating multi-chunk conversations...")
|
270 |
+
random.shuffle(multi_chunks) # Randomize order
|
271 |
+
|
272 |
+
i = 0
|
273 |
+
while i < len(multi_chunks):
|
274 |
+
# Randomly choose to use 2 or 3 chunks
|
275 |
+
chunk_count = random.choice([2, 3])
|
276 |
+
chunk_group = multi_chunks[i : i + chunk_count]
|
277 |
+
|
278 |
+
# Only proceed if we have at least 2 chunks
|
279 |
+
if len(chunk_group) >= 2:
|
280 |
+
conversations = self.generate_multi_chunk_conversations(
|
281 |
+
chunk_group, num_conversations_per_chunk
|
282 |
+
)
|
283 |
+
all_conversations.extend(conversations)
|
284 |
+
time.sleep(0.1)
|
285 |
+
|
286 |
+
i += chunk_count
|
287 |
+
|
288 |
+
# Step 5: Track generation methods and intent distribution
|
289 |
+
method_stats = defaultdict(int)
|
290 |
+
intent_counts = Counter()
|
291 |
+
|
292 |
+
for conv in all_conversations:
|
293 |
+
method_stats[conv["generation_method"]] += 1
|
294 |
+
for turn in conv["turns"]:
|
295 |
+
intent_counts[turn["intent"]] += 1
|
296 |
+
|
297 |
+
# Step 6: Populate dataset structure
|
298 |
+
self.dataset["conversations"] = {conv["id"]: conv for conv in all_conversations}
|
299 |
+
|
300 |
+
self.dataset["generation_methods"] = dict(method_stats)
|
301 |
+
self.dataset["intent_distribution"] = dict(intent_counts)
|
302 |
+
|
303 |
+
# Step 7: Update metadata
|
304 |
+
total_user_messages = sum(len(conv["turns"]) for conv in all_conversations)
|
305 |
+
|
306 |
+
self.dataset["metadata"].update(
|
307 |
+
{
|
308 |
+
"total_conversations": len(all_conversations),
|
309 |
+
"total_user_messages": total_user_messages,
|
310 |
+
"chunks_used": total_chunks,
|
311 |
+
"conversations_per_chunk": num_conversations_per_chunk,
|
312 |
+
"generation_distribution": dict(method_stats),
|
313 |
+
"completion_timestamp": time.time(),
|
314 |
+
}
|
315 |
+
)
|
316 |
+
|
317 |
+
# Step 8: Save dataset
|
318 |
+
os.makedirs(
|
319 |
+
os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
|
320 |
+
exist_ok=True,
|
321 |
+
)
|
322 |
+
|
323 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
324 |
+
json.dump(self.dataset, f, ensure_ascii=False, indent=2)
|
325 |
+
|
326 |
+
print(f"\n✅ Intent classification dataset created successfully!")
|
327 |
+
print(f" 📁 Saved to: {save_path}")
|
328 |
+
print(f" 📊 Statistics:")
|
329 |
+
print(f" • Total conversations: {len(all_conversations)}")
|
330 |
+
print(f" • Total user messages: {total_user_messages}")
|
331 |
+
print(f" • Conversations per chunk: {num_conversations_per_chunk}")
|
332 |
+
print(f" • Generation methods: {dict(method_stats)}")
|
333 |
+
print(f" • Intent distribution: {dict(intent_counts)}")
|
334 |
+
|
335 |
+
return self.dataset
|
336 |
+
|
337 |
+
|
338 |
+
class IntentClassificationEvaluator:
|
339 |
+
"""Evaluator for intent classification performance with method-specific analysis"""
|
340 |
+
|
341 |
+
def __init__(self, dataset: Dict, llm_client):
|
342 |
+
"""Initialize evaluator with dataset and LLM client"""
|
343 |
+
self.dataset = dataset
|
344 |
+
self.llm_client = llm_client
|
345 |
+
|
346 |
+
# Define expected intents
|
347 |
+
self.expected_intents = [
|
348 |
+
"greeting",
|
349 |
+
"faq",
|
350 |
+
"error_help",
|
351 |
+
"procedure_guide",
|
352 |
+
"human_request",
|
353 |
+
"out_of_scope",
|
354 |
+
"unclear",
|
355 |
+
]
|
356 |
+
|
357 |
+
# Critical intents for business
|
358 |
+
self.critical_intents = ["error_help", "human_request"]
|
359 |
+
|
360 |
+
# Define flow mappings based on agent routing logic
|
361 |
+
self.script_based_intents = {
|
362 |
+
"greeting",
|
363 |
+
"out_of_scope",
|
364 |
+
"human_request",
|
365 |
+
"unclear",
|
366 |
+
}
|
367 |
+
self.knowledge_based_intents = {
|
368 |
+
"faq",
|
369 |
+
"error_help",
|
370 |
+
"procedure_guide",
|
371 |
+
}
|
372 |
+
|
373 |
+
def _get_intent_flow(self, intent: str) -> str:
|
374 |
+
"""Classify intent into flow type based on agent routing logic"""
|
375 |
+
if intent in self.script_based_intents:
|
376 |
+
return "script_based"
|
377 |
+
elif intent in self.knowledge_based_intents:
|
378 |
+
return "knowledge_based"
|
379 |
+
else:
|
380 |
+
return "unknown"
|
381 |
+
|
382 |
+
def _make_json_serializable(self, obj):
|
383 |
+
"""Convert numpy types to native Python types for JSON serialization"""
|
384 |
+
try:
|
385 |
+
import numpy as np
|
386 |
+
|
387 |
+
if isinstance(obj, dict):
|
388 |
+
return {k: self._make_json_serializable(v) for k, v in obj.items()}
|
389 |
+
elif isinstance(obj, list):
|
390 |
+
return [self._make_json_serializable(item) for item in obj]
|
391 |
+
elif isinstance(obj, np.integer):
|
392 |
+
return int(obj)
|
393 |
+
elif isinstance(obj, np.floating):
|
394 |
+
return float(obj)
|
395 |
+
elif isinstance(obj, np.ndarray):
|
396 |
+
return obj.tolist()
|
397 |
+
else:
|
398 |
+
return obj
|
399 |
+
except ImportError:
|
400 |
+
# If numpy is not available, just return the object as-is
|
401 |
+
if isinstance(obj, dict):
|
402 |
+
return {k: self._make_json_serializable(v) for k, v in obj.items()}
|
403 |
+
elif isinstance(obj, list):
|
404 |
+
return [self._make_json_serializable(item) for item in obj]
|
405 |
+
else:
|
406 |
+
return obj
|
407 |
+
|
408 |
+
def calculate_essential_metrics(
|
409 |
+
self, ground_truth: List[str], predictions: List[str]
|
410 |
+
) -> Dict:
|
411 |
+
"""Calculate only essential metrics: accuracy, macro, per-class"""
|
412 |
+
try:
|
413 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
414 |
+
|
415 |
+
overall_accuracy = accuracy_score(ground_truth, predictions)
|
416 |
+
|
417 |
+
# Calculate macro metrics (equal weight per intent)
|
418 |
+
precision, recall, f1, support = precision_recall_fscore_support(
|
419 |
+
ground_truth, predictions, average="macro", zero_division=0
|
420 |
+
)
|
421 |
+
|
422 |
+
macro_metrics = {
|
423 |
+
"macro_precision": precision,
|
424 |
+
"macro_recall": recall,
|
425 |
+
"macro_f1": f1,
|
426 |
+
}
|
427 |
+
|
428 |
+
# Calculate per-class metrics
|
429 |
+
precision_per_class, recall_per_class, f1_per_class, support_per_class = (
|
430 |
+
precision_recall_fscore_support(
|
431 |
+
ground_truth, predictions, average=None, zero_division=0
|
432 |
+
)
|
433 |
+
)
|
434 |
+
|
435 |
+
# Get unique labels
|
436 |
+
unique_labels = sorted(list(set(ground_truth + predictions)))
|
437 |
+
|
438 |
+
per_class_metrics = {}
|
439 |
+
for i, label in enumerate(unique_labels):
|
440 |
+
if i < len(precision_per_class):
|
441 |
+
per_class_metrics[label] = {
|
442 |
+
"precision": float(precision_per_class[i]),
|
443 |
+
"recall": float(recall_per_class[i]),
|
444 |
+
"f1": float(f1_per_class[i]),
|
445 |
+
"support": int(
|
446 |
+
support_per_class[i] if i < len(support_per_class) else 0
|
447 |
+
),
|
448 |
+
}
|
449 |
+
|
450 |
+
# Calculate critical intent recall
|
451 |
+
critical_recall = {}
|
452 |
+
for intent in self.critical_intents:
|
453 |
+
if intent in per_class_metrics:
|
454 |
+
critical_recall[intent] = per_class_metrics[intent]["recall"]
|
455 |
+
|
456 |
+
return {
|
457 |
+
"overall_accuracy": float(overall_accuracy),
|
458 |
+
"macro_precision": float(macro_metrics["macro_precision"]),
|
459 |
+
"macro_recall": float(macro_metrics["macro_recall"]),
|
460 |
+
"macro_f1": float(macro_metrics["macro_f1"]),
|
461 |
+
"per_class_metrics": per_class_metrics,
|
462 |
+
"critical_intent_recall": {
|
463 |
+
k: float(v) for k, v in critical_recall.items()
|
464 |
+
},
|
465 |
+
}
|
466 |
+
|
467 |
+
except ImportError:
|
468 |
+
print("⚠️ scikit-learn not installed. Using basic accuracy only.")
|
469 |
+
overall_accuracy = sum(
|
470 |
+
1 for gt, pred in zip(ground_truth, predictions) if gt == pred
|
471 |
+
) / len(predictions)
|
472 |
+
|
473 |
+
return {"overall_accuracy": float(overall_accuracy)}
|
474 |
+
|
475 |
+
def evaluate_intent_classification(self) -> Dict:
|
476 |
+
"""Evaluate intent classification performance with method and flow breakdown"""
|
477 |
+
print(f"\n🎯 Running intent classification evaluation...")
|
478 |
+
|
479 |
+
conversations = self.dataset["conversations"]
|
480 |
+
|
481 |
+
# Initialize tracking
|
482 |
+
all_predictions = []
|
483 |
+
all_ground_truth = []
|
484 |
+
method_results = defaultdict(lambda: {"predictions": [], "ground_truth": []})
|
485 |
+
flow_results = defaultdict(lambda: {"predictions": [], "ground_truth": []})
|
486 |
+
conversation_results = {}
|
487 |
+
|
488 |
+
# Process each conversation
|
489 |
+
for conv_id, conv_data in tqdm(
|
490 |
+
conversations.items(), desc="Evaluating conversations"
|
491 |
+
):
|
492 |
+
generation_method = conv_data.get("generation_method", "unknown")
|
493 |
+
|
494 |
+
conversation_results[conv_id] = {
|
495 |
+
"turns": [],
|
496 |
+
"accuracy": 0,
|
497 |
+
"generation_method": generation_method,
|
498 |
+
}
|
499 |
+
|
500 |
+
correct_predictions = 0
|
501 |
+
total_turns = len(conv_data["turns"])
|
502 |
+
|
503 |
+
# Process each turn in the conversation
|
504 |
+
for turn_idx, turn in enumerate(conv_data["turns"]):
|
505 |
+
user_message = turn["user"]
|
506 |
+
ground_truth_intent = turn["intent"]
|
507 |
+
|
508 |
+
try:
|
509 |
+
# Create messages in the format expected by classify_intent_node
|
510 |
+
messages = [HumanMessage(content=user_message)]
|
511 |
+
|
512 |
+
# Create a mock state for the intent classification node
|
513 |
+
state = ViettelPayState(messages=messages)
|
514 |
+
|
515 |
+
# Use the classify_intent_node directly
|
516 |
+
result_state = classify_intent_node(state, self.llm_client)
|
517 |
+
predicted_intent = result_state.get("intent", "unclear")
|
518 |
+
|
519 |
+
# Track results
|
520 |
+
is_correct = predicted_intent == ground_truth_intent
|
521 |
+
if is_correct:
|
522 |
+
correct_predictions += 1
|
523 |
+
|
524 |
+
# Add to overall tracking
|
525 |
+
all_predictions.append(predicted_intent)
|
526 |
+
all_ground_truth.append(ground_truth_intent)
|
527 |
+
|
528 |
+
# Add to method-specific tracking
|
529 |
+
method_results[generation_method]["predictions"].append(
|
530 |
+
predicted_intent
|
531 |
+
)
|
532 |
+
method_results[generation_method]["ground_truth"].append(
|
533 |
+
ground_truth_intent
|
534 |
+
)
|
535 |
+
|
536 |
+
# Add to flow-specific tracking
|
537 |
+
ground_truth_flow = self._get_intent_flow(ground_truth_intent)
|
538 |
+
predicted_flow = self._get_intent_flow(predicted_intent)
|
539 |
+
|
540 |
+
flow_results[ground_truth_flow]["predictions"].append(
|
541 |
+
predicted_intent
|
542 |
+
)
|
543 |
+
flow_results[ground_truth_flow]["ground_truth"].append(
|
544 |
+
ground_truth_intent
|
545 |
+
)
|
546 |
+
|
547 |
+
conversation_results[conv_id]["turns"].append(
|
548 |
+
{
|
549 |
+
"turn": turn_idx + 1,
|
550 |
+
"user_message": user_message,
|
551 |
+
"ground_truth": ground_truth_intent,
|
552 |
+
"predicted": predicted_intent,
|
553 |
+
"correct": is_correct,
|
554 |
+
}
|
555 |
+
)
|
556 |
+
|
557 |
+
except Exception as e:
|
558 |
+
print(f"⚠️ Error processing turn {turn_idx} in {conv_id}: {e}")
|
559 |
+
# Use "unclear" as fallback prediction
|
560 |
+
all_predictions.append("unclear")
|
561 |
+
all_ground_truth.append(ground_truth_intent)
|
562 |
+
method_results[generation_method]["predictions"].append("unclear")
|
563 |
+
method_results[generation_method]["ground_truth"].append(
|
564 |
+
ground_truth_intent
|
565 |
+
)
|
566 |
+
|
567 |
+
# Add to flow-specific tracking (for errors)
|
568 |
+
ground_truth_flow = self._get_intent_flow(ground_truth_intent)
|
569 |
+
flow_results[ground_truth_flow]["predictions"].append("unclear")
|
570 |
+
flow_results[ground_truth_flow]["ground_truth"].append(
|
571 |
+
ground_truth_intent
|
572 |
+
)
|
573 |
+
|
574 |
+
# Calculate conversation accuracy
|
575 |
+
conversation_results[conv_id]["accuracy"] = float(
|
576 |
+
correct_predictions / total_turns if total_turns > 0 else 0
|
577 |
+
)
|
578 |
+
|
579 |
+
# Calculate overall metrics
|
580 |
+
overall_metrics = self.calculate_essential_metrics(
|
581 |
+
all_ground_truth, all_predictions
|
582 |
+
)
|
583 |
+
|
584 |
+
# Calculate method-specific metrics
|
585 |
+
method_metrics = {}
|
586 |
+
for method, method_data in method_results.items():
|
587 |
+
if method_data["predictions"]: # Ensure we have data
|
588 |
+
method_metrics[method] = self.calculate_essential_metrics(
|
589 |
+
method_data["ground_truth"], method_data["predictions"]
|
590 |
+
)
|
591 |
+
method_metrics[method]["total_messages"] = len(
|
592 |
+
method_data["predictions"]
|
593 |
+
)
|
594 |
+
|
595 |
+
# Calculate flow-specific metrics
|
596 |
+
flow_metrics = {}
|
597 |
+
for flow, flow_data in flow_results.items():
|
598 |
+
if flow_data["predictions"]: # Ensure we have data
|
599 |
+
flow_metrics[flow] = self.calculate_essential_metrics(
|
600 |
+
flow_data["ground_truth"], flow_data["predictions"]
|
601 |
+
)
|
602 |
+
flow_metrics[flow]["total_messages"] = len(flow_data["predictions"])
|
603 |
+
|
604 |
+
results = {
|
605 |
+
"overall_metrics": overall_metrics,
|
606 |
+
"method_specific_metrics": method_metrics,
|
607 |
+
"flow_specific_metrics": flow_metrics,
|
608 |
+
"conversation_results": conversation_results,
|
609 |
+
"intent_distribution": {
|
610 |
+
"ground_truth": dict(Counter(all_ground_truth)),
|
611 |
+
"predicted": dict(Counter(all_predictions)),
|
612 |
+
},
|
613 |
+
"generation_methods": self.dataset.get("generation_methods", {}),
|
614 |
+
}
|
615 |
+
|
616 |
+
# Make sure all values are JSON serializable
|
617 |
+
results = self._make_json_serializable(results)
|
618 |
+
|
619 |
+
return results
|
620 |
+
|
621 |
+
def print_evaluation_results(self, results: Dict):
|
622 |
+
"""Print comprehensive evaluation results"""
|
623 |
+
print(f"\n🎯 INTENT CLASSIFICATION EVALUATION RESULTS")
|
624 |
+
print("=" * 60)
|
625 |
+
|
626 |
+
# Overall performance
|
627 |
+
overall = results["overall_metrics"]
|
628 |
+
print(f"\n📊 Overall Performance:")
|
629 |
+
print(f" Accuracy: {overall['overall_accuracy']:.3f}")
|
630 |
+
if "macro_precision" in overall:
|
631 |
+
print(f" Macro Precision: {overall['macro_precision']:.3f}")
|
632 |
+
print(f" Macro Recall: {overall['macro_recall']:.3f}")
|
633 |
+
print(f" Macro F1: {overall['macro_f1']:.3f}")
|
634 |
+
|
635 |
+
# Per-class performance
|
636 |
+
if "per_class_metrics" in overall:
|
637 |
+
print(f"\n📋 Per-Class Performance:")
|
638 |
+
print(
|
639 |
+
f"{'Intent':<15} {'Precision':<10} {'Recall':<10} {'F1':<10} {'Support':<10}"
|
640 |
+
)
|
641 |
+
print("-" * 65)
|
642 |
+
|
643 |
+
per_class = overall["per_class_metrics"]
|
644 |
+
for intent in self.expected_intents:
|
645 |
+
if intent in per_class:
|
646 |
+
metrics = per_class[intent]
|
647 |
+
print(
|
648 |
+
f"{intent:<15} {metrics['precision']:<10.3f} {metrics['recall']:<10.3f} {metrics['f1']:<10.3f} {metrics['support']:<10}"
|
649 |
+
)
|
650 |
+
|
651 |
+
# Critical intents performance
|
652 |
+
if "critical_intent_recall" in overall:
|
653 |
+
print(f"\n🚨 Critical Intent Performance:")
|
654 |
+
for intent, recall in overall["critical_intent_recall"].items():
|
655 |
+
status = "✅" if recall >= 0.85 else "⚠️" if recall >= 0.75 else "❌"
|
656 |
+
print(f" {status} {intent}: Recall = {recall:.3f}")
|
657 |
+
|
658 |
+
# Method-specific performance
|
659 |
+
print(f"\n🔄 Performance by Generation Method:")
|
660 |
+
method_metrics = results["method_specific_metrics"]
|
661 |
+
if method_metrics:
|
662 |
+
print(f"{'Method':<20} {'Accuracy':<10} {'Macro F1':<10} {'Messages':<10}")
|
663 |
+
print("-" * 55)
|
664 |
+
|
665 |
+
for method, metrics in method_metrics.items():
|
666 |
+
accuracy = metrics["overall_accuracy"]
|
667 |
+
macro_f1 = metrics.get("macro_f1", 0)
|
668 |
+
total_msgs = metrics["total_messages"]
|
669 |
+
print(
|
670 |
+
f"{method:<20} {accuracy:<10.3f} {macro_f1:<10.3f} {total_msgs:<10}"
|
671 |
+
)
|
672 |
+
|
673 |
+
# Flow-specific performance
|
674 |
+
print(f"\n🔀 Performance by Agent Flow:")
|
675 |
+
flow_metrics = results["flow_specific_metrics"]
|
676 |
+
if flow_metrics:
|
677 |
+
print(
|
678 |
+
f"{'Flow Type':<20} {'Accuracy':<10} {'Macro F1':<10} {'Messages':<10}"
|
679 |
+
)
|
680 |
+
print("-" * 55)
|
681 |
+
|
682 |
+
for flow, metrics in flow_metrics.items():
|
683 |
+
accuracy = metrics["overall_accuracy"]
|
684 |
+
macro_f1 = metrics.get("macro_f1", 0)
|
685 |
+
total_msgs = metrics["total_messages"]
|
686 |
+
flow_display = f"{flow}_flow"
|
687 |
+
print(
|
688 |
+
f"{flow_display:<20} {accuracy:<10.3f} {macro_f1:<10.3f} {total_msgs:<10}"
|
689 |
+
)
|
690 |
+
|
691 |
+
# Intent distribution comparison
|
692 |
+
print(f"\n📈 Intent Distribution:")
|
693 |
+
gt_dist = results["intent_distribution"]["ground_truth"]
|
694 |
+
pred_dist = results["intent_distribution"]["predicted"]
|
695 |
+
|
696 |
+
print(f"{'Intent':<15} {'Ground Truth':<15} {'Predicted':<15}")
|
697 |
+
print("-" * 50)
|
698 |
+
|
699 |
+
all_intents = set(list(gt_dist.keys()) + list(pred_dist.keys()))
|
700 |
+
for intent in sorted(all_intents):
|
701 |
+
gt_count = gt_dist.get(intent, 0)
|
702 |
+
pred_count = pred_dist.get(intent, 0)
|
703 |
+
print(f"{intent:<15} {gt_count:<15} {pred_count:<15}")
|
704 |
+
|
705 |
+
# Method insights
|
706 |
+
print(f"\n💡 Method-Specific Insights:")
|
707 |
+
if method_metrics:
|
708 |
+
method_accuracies = {
|
709 |
+
method: metrics["overall_accuracy"]
|
710 |
+
for method, metrics in method_metrics.items()
|
711 |
+
}
|
712 |
+
best_method = max(
|
713 |
+
method_accuracies.keys(), key=lambda k: method_accuracies[k]
|
714 |
+
)
|
715 |
+
worst_method = min(
|
716 |
+
method_accuracies.keys(), key=lambda k: method_accuracies[k]
|
717 |
+
)
|
718 |
+
|
719 |
+
print(
|
720 |
+
f" • Best performing method: {best_method} ({method_accuracies[best_method]:.3f})"
|
721 |
+
)
|
722 |
+
print(
|
723 |
+
f" • Most challenging method: {worst_method} ({method_accuracies[worst_method]:.3f})"
|
724 |
+
)
|
725 |
+
print(
|
726 |
+
f" • Performance gap: {method_accuracies[best_method] - method_accuracies[worst_method]:.3f}"
|
727 |
+
)
|
728 |
+
|
729 |
+
# Flow insights
|
730 |
+
print(f"\n🔀 Flow-Specific Insights:")
|
731 |
+
if flow_metrics:
|
732 |
+
flow_accuracies = {
|
733 |
+
flow: metrics["overall_accuracy"]
|
734 |
+
for flow, metrics in flow_metrics.items()
|
735 |
+
}
|
736 |
+
|
737 |
+
if len(flow_accuracies) >= 2:
|
738 |
+
best_flow = max(
|
739 |
+
flow_accuracies.keys(), key=lambda k: flow_accuracies[k]
|
740 |
+
)
|
741 |
+
worst_flow = min(
|
742 |
+
flow_accuracies.keys(), key=lambda k: flow_accuracies[k]
|
743 |
+
)
|
744 |
+
|
745 |
+
print(
|
746 |
+
f" • Best performing flow: {best_flow} ({flow_accuracies[best_flow]:.3f})"
|
747 |
+
)
|
748 |
+
print(
|
749 |
+
f" • Most challenging flow: {worst_flow} ({flow_accuracies[worst_flow]:.3f})"
|
750 |
+
)
|
751 |
+
print(
|
752 |
+
f" • Flow performance gap: {flow_accuracies[best_flow] - flow_accuracies[worst_flow]:.3f}"
|
753 |
+
)
|
754 |
+
|
755 |
+
# Provide interpretation
|
756 |
+
if (
|
757 |
+
"script_based" in flow_accuracies
|
758 |
+
and "knowledge_based" in flow_accuracies
|
759 |
+
):
|
760 |
+
script_acc = flow_accuracies["script_based"]
|
761 |
+
kb_acc = flow_accuracies["knowledge_based"]
|
762 |
+
|
763 |
+
if script_acc > kb_acc:
|
764 |
+
print(
|
765 |
+
f" • Script-based intents are easier to classify ({script_acc:.3f} vs {kb_acc:.3f})"
|
766 |
+
)
|
767 |
+
elif kb_acc > script_acc:
|
768 |
+
print(
|
769 |
+
f" • Knowledge-based intents are easier to classify ({kb_acc:.3f} vs {script_acc:.3f})"
|
770 |
+
)
|
771 |
+
else:
|
772 |
+
print(
|
773 |
+
f" • Both flows perform similarly ({script_acc:.3f} vs {kb_acc:.3f})"
|
774 |
+
)
|
775 |
+
else:
|
776 |
+
for flow, accuracy in flow_accuracies.items():
|
777 |
+
print(f" • {flow} flow accuracy: {accuracy:.3f}")
|
778 |
+
|
779 |
+
# Success criteria check
|
780 |
+
print(f"\n✅ Success Criteria Check:")
|
781 |
+
accuracy = overall["overall_accuracy"]
|
782 |
+
if accuracy >= 0.80:
|
783 |
+
print(f" 🎉 GOOD: Overall accuracy {accuracy:.3f} >= 0.80")
|
784 |
+
elif accuracy >= 0.75:
|
785 |
+
print(f" ⚠️ OKAY: Overall accuracy {accuracy:.3f} >= 0.75")
|
786 |
+
else:
|
787 |
+
print(f" ❌ NEEDS WORK: Overall accuracy {accuracy:.3f} < 0.75")
|
788 |
+
|
789 |
+
|
790 |
+
def main():
|
791 |
+
"""Main function for simplified intent classification evaluation"""
|
792 |
+
parser = argparse.ArgumentParser(
|
793 |
+
description="Simplified ViettelPay Intent Classification Evaluation"
|
794 |
+
)
|
795 |
+
parser.add_argument(
|
796 |
+
"--mode",
|
797 |
+
choices=["create", "evaluate", "full"],
|
798 |
+
default="full",
|
799 |
+
help="Mode: create dataset, evaluate, or full pipeline",
|
800 |
+
)
|
801 |
+
parser.add_argument(
|
802 |
+
"--dataset-path",
|
803 |
+
default="evaluation_data/datasets/intent_classification/viettelpay_intent_dataset.json",
|
804 |
+
help="Path to intent dataset",
|
805 |
+
)
|
806 |
+
parser.add_argument(
|
807 |
+
"--results-path",
|
808 |
+
default="evaluation_data/results/intent_classification/viettelpay_intent_results.json",
|
809 |
+
help="Path to save evaluation results",
|
810 |
+
)
|
811 |
+
parser.add_argument(
|
812 |
+
"--conversations-per-chunk",
|
813 |
+
type=int,
|
814 |
+
default=3,
|
815 |
+
help="Number of conversations per chunk (default: 3)",
|
816 |
+
)
|
817 |
+
parser.add_argument(
|
818 |
+
"--knowledge-base-path",
|
819 |
+
default="./knowledge_base",
|
820 |
+
help="Path to knowledge base",
|
821 |
+
)
|
822 |
+
|
823 |
+
args = parser.parse_args()
|
824 |
+
|
825 |
+
# Configuration
|
826 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
827 |
+
|
828 |
+
if not GEMINI_API_KEY:
|
829 |
+
print("❌ Please set GEMINI_API_KEY environment variable")
|
830 |
+
return
|
831 |
+
|
832 |
+
try:
|
833 |
+
# Initialize components based on mode
|
834 |
+
kb = None
|
835 |
+
if args.mode in ["create", "full"]:
|
836 |
+
# Initialize knowledge base only if creating dataset
|
837 |
+
print("🔧 Initializing ViettelPay knowledge base...")
|
838 |
+
kb = ViettelKnowledgeBase()
|
839 |
+
if not kb.load_knowledge_base(args.knowledge_base_path):
|
840 |
+
print(
|
841 |
+
"❌ Failed to load knowledge base. Please run build_database_script.py first."
|
842 |
+
)
|
843 |
+
return
|
844 |
+
|
845 |
+
# Step 1: Create dataset if requested
|
846 |
+
if args.mode in ["create", "full"]:
|
847 |
+
print(f"\n🎯 Creating simplified intent classification dataset...")
|
848 |
+
creator = IntentDatasetCreator(GEMINI_API_KEY, kb)
|
849 |
+
|
850 |
+
dataset = creator.create_intent_dataset(
|
851 |
+
num_conversations_per_chunk=args.conversations_per_chunk,
|
852 |
+
save_path=args.dataset_path,
|
853 |
+
)
|
854 |
+
|
855 |
+
# Step 2: Evaluate if requested
|
856 |
+
if args.mode in ["evaluate", "full"]:
|
857 |
+
print(f"\n📊 Evaluating intent classification...")
|
858 |
+
|
859 |
+
# Load dataset if not created in this run
|
860 |
+
if args.mode == "evaluate":
|
861 |
+
if not os.path.exists(args.dataset_path):
|
862 |
+
print(f"❌ Dataset not found: {args.dataset_path}")
|
863 |
+
return
|
864 |
+
|
865 |
+
with open(args.dataset_path, "r", encoding="utf-8") as f:
|
866 |
+
dataset = json.load(f)
|
867 |
+
|
868 |
+
# Initialize LLM client for intent classification
|
869 |
+
print("🤖 Initializing LLM client for intent classification...")
|
870 |
+
llm_client = LLMClientFactory.create_client(
|
871 |
+
"gemini", api_key=GEMINI_API_KEY, model="gemini-2.0-flash"
|
872 |
+
)
|
873 |
+
|
874 |
+
# Run evaluation
|
875 |
+
evaluator = IntentClassificationEvaluator(dataset, llm_client)
|
876 |
+
results = evaluator.evaluate_intent_classification()
|
877 |
+
evaluator.print_evaluation_results(results)
|
878 |
+
|
879 |
+
# Save results
|
880 |
+
if args.results_path:
|
881 |
+
with open(args.results_path, "w", encoding="utf-8") as f:
|
882 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
883 |
+
print(f"\n💾 Results saved to: {args.results_path}")
|
884 |
+
|
885 |
+
print(f"\n✅ Intent classification evaluation completed successfully!")
|
886 |
+
print(f"\n💡 Summary improvements made:")
|
887 |
+
print(f" • Removed pattern-based generation for simplicity")
|
888 |
+
print(f" • Added configurable conversations-per-chunk (default: 3)")
|
889 |
+
print(f" • Improved chunk mixing (random 2-3 chunks)")
|
890 |
+
print(f" • Enhanced prompts to include non-topic intents")
|
891 |
+
print(f" • Added flow-specific analysis (script-based vs knowledge-based)")
|
892 |
+
|
893 |
+
except Exception as e:
|
894 |
+
print(f"❌ Error in main execution: {e}")
|
895 |
+
import traceback
|
896 |
+
|
897 |
+
traceback.print_exc()
|
898 |
+
|
899 |
+
|
900 |
+
if __name__ == "__main__":
|
901 |
+
main()
|
src/evaluation/multi_turn_retrieval.py
ADDED
@@ -0,0 +1,815 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Multi-Turn Conversation Retrieval Evaluation for ViettelPay RAG System
|
3 |
+
Generates multi-turn conversations and evaluates retrieval performance
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
import time
|
11 |
+
from typing import Dict, List, Tuple, Optional, Union
|
12 |
+
from pathlib import Path
|
13 |
+
from collections import defaultdict
|
14 |
+
import pandas as pd
|
15 |
+
from tqdm import tqdm
|
16 |
+
import re
|
17 |
+
|
18 |
+
# Load environment variables from .env file
|
19 |
+
from dotenv import load_dotenv
|
20 |
+
|
21 |
+
load_dotenv()
|
22 |
+
|
23 |
+
# Add the project root to Python path so we can import from src
|
24 |
+
project_root = Path(__file__).parent.parent.parent
|
25 |
+
sys.path.insert(0, str(project_root))
|
26 |
+
|
27 |
+
# Import existing components
|
28 |
+
from src.evaluation.prompts import MULTI_TURN_CONVERSATION_GENERATION_PROMPT
|
29 |
+
from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
|
30 |
+
from src.evaluation.single_turn_retrieval import SingleTurnRetrievalEvaluator
|
31 |
+
from src.llm.llm_client import LLMClientFactory, BaseLLMClient
|
32 |
+
from src.agent.nodes import query_enhancement_node, ViettelPayState
|
33 |
+
from langchain_core.messages import HumanMessage
|
34 |
+
|
35 |
+
|
36 |
+
class MultiTurnDatasetCreator:
|
37 |
+
"""Multi-turn conversation dataset creator for ViettelPay evaluation"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
Initialize with Gemini API key and optional knowledge base
|
44 |
+
|
45 |
+
Args:
|
46 |
+
gemini_api_key: Google AI API key for Gemini
|
47 |
+
knowledge_base: Pre-initialized ViettelKnowledgeBase instance
|
48 |
+
"""
|
49 |
+
self.llm_client = LLMClientFactory.create_client(
|
50 |
+
"gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
|
51 |
+
)
|
52 |
+
self.knowledge_base = knowledge_base
|
53 |
+
self.dataset = {
|
54 |
+
"conversations": {},
|
55 |
+
"documents": {},
|
56 |
+
"metadata": {
|
57 |
+
"total_chunks_processed": 0,
|
58 |
+
"conversations_generated": 0,
|
59 |
+
"creation_timestamp": time.time(),
|
60 |
+
},
|
61 |
+
}
|
62 |
+
|
63 |
+
print("✅ MultiTurnDatasetCreator initialized with Gemini 2.0 Flash")
|
64 |
+
|
65 |
+
def generate_json_response(
|
66 |
+
self, prompt: str, max_retries: int = 3
|
67 |
+
) -> Optional[Dict]:
|
68 |
+
"""
|
69 |
+
Generate response and parse as JSON with retries
|
70 |
+
|
71 |
+
Args:
|
72 |
+
prompt: Input prompt
|
73 |
+
max_retries: Maximum number of retry attempts
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Parsed JSON response or None if failed
|
77 |
+
"""
|
78 |
+
for attempt in range(max_retries):
|
79 |
+
try:
|
80 |
+
response = self.llm_client.generate(prompt, temperature=0.1)
|
81 |
+
|
82 |
+
if response:
|
83 |
+
# Clean response text
|
84 |
+
response_text = response.strip()
|
85 |
+
|
86 |
+
# Extract JSON from response (handle cases with extra text)
|
87 |
+
json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
|
88 |
+
if json_match:
|
89 |
+
json_text = json_match.group()
|
90 |
+
return json.loads(json_text)
|
91 |
+
else:
|
92 |
+
# Try parsing the whole response
|
93 |
+
return json.loads(response_text)
|
94 |
+
|
95 |
+
except json.JSONDecodeError as e:
|
96 |
+
print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
|
97 |
+
if attempt == max_retries - 1:
|
98 |
+
print(f"❌ Failed to parse JSON after {max_retries} attempts")
|
99 |
+
print(
|
100 |
+
f"Raw response: {response if 'response' in locals() else 'No response'}"
|
101 |
+
)
|
102 |
+
|
103 |
+
except Exception as e:
|
104 |
+
print(f"⚠️ API error (attempt {attempt + 1}): {e}")
|
105 |
+
if attempt < max_retries - 1:
|
106 |
+
time.sleep(2**attempt) # Exponential backoff
|
107 |
+
|
108 |
+
return None
|
109 |
+
|
110 |
+
def get_all_chunks(self) -> List[Dict]:
|
111 |
+
"""
|
112 |
+
Get ALL chunks directly from ChromaDB vectorstore
|
113 |
+
Reuse the same method from single-turn evaluation
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
List of all document chunks with content and metadata
|
117 |
+
"""
|
118 |
+
print(f"📚 Retrieving ALL chunks directly from ChromaDB vectorstore...")
|
119 |
+
|
120 |
+
if not self.knowledge_base:
|
121 |
+
raise ValueError(
|
122 |
+
"Knowledge base not provided. Please initialize with a ViettelKnowledgeBase instance."
|
123 |
+
)
|
124 |
+
|
125 |
+
try:
|
126 |
+
# Access the ChromaDB vectorstore directly
|
127 |
+
if (
|
128 |
+
not hasattr(self.knowledge_base, "chroma_retriever")
|
129 |
+
or not self.knowledge_base.chroma_retriever
|
130 |
+
):
|
131 |
+
raise ValueError("ChromaDB retriever not found in knowledge base")
|
132 |
+
|
133 |
+
# Get the vectorstore from the retriever
|
134 |
+
vectorstore = self.knowledge_base.chroma_retriever.vectorstore
|
135 |
+
|
136 |
+
# Get all documents directly from ChromaDB
|
137 |
+
print(" Accessing ChromaDB collection...")
|
138 |
+
all_docs = vectorstore.get(include=["documents", "metadatas"])
|
139 |
+
|
140 |
+
documents = all_docs["documents"]
|
141 |
+
metadatas = all_docs["metadatas"]
|
142 |
+
|
143 |
+
print(f" Found {len(documents)} documents in ChromaDB")
|
144 |
+
|
145 |
+
# Convert to our expected format
|
146 |
+
all_chunks = []
|
147 |
+
seen_content_hashes = set()
|
148 |
+
|
149 |
+
for i, (content, metadata) in enumerate(zip(documents, metadatas)):
|
150 |
+
# Create content hash for deduplication
|
151 |
+
content_hash = hash(content[:300])
|
152 |
+
|
153 |
+
if (
|
154 |
+
content_hash not in seen_content_hashes
|
155 |
+
and len(content.strip()) > 100
|
156 |
+
):
|
157 |
+
chunk_info = {
|
158 |
+
"id": f"chunk_{len(all_chunks)}",
|
159 |
+
"content": content,
|
160 |
+
"metadata": metadata or {},
|
161 |
+
"source": "chromadb_direct",
|
162 |
+
"content_length": len(content),
|
163 |
+
"original_index": i,
|
164 |
+
}
|
165 |
+
all_chunks.append(chunk_info)
|
166 |
+
seen_content_hashes.add(content_hash)
|
167 |
+
|
168 |
+
print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
|
169 |
+
|
170 |
+
# Sort by content length (longer chunks first)
|
171 |
+
all_chunks.sort(key=lambda x: x["content_length"], reverse=True)
|
172 |
+
|
173 |
+
return all_chunks
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
print(f"❌ Error accessing ChromaDB directly: {e}")
|
177 |
+
return []
|
178 |
+
|
179 |
+
def generate_conversations_for_chunk(
|
180 |
+
self, chunk: Dict, num_conversations: int = 2
|
181 |
+
) -> List[Dict]:
|
182 |
+
"""
|
183 |
+
Generate multi-turn conversations for a single chunk using Gemini
|
184 |
+
|
185 |
+
Args:
|
186 |
+
chunk: Chunk dictionary with content and metadata
|
187 |
+
num_conversations: Number of conversations to generate per chunk
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
List of conversation dictionaries
|
191 |
+
"""
|
192 |
+
content = chunk["content"]
|
193 |
+
|
194 |
+
prompt = MULTI_TURN_CONVERSATION_GENERATION_PROMPT.format(
|
195 |
+
num_conversations=num_conversations, content=content
|
196 |
+
)
|
197 |
+
|
198 |
+
response_json = self.generate_json_response(prompt)
|
199 |
+
|
200 |
+
if response_json and "conversations" in response_json:
|
201 |
+
conversations = response_json["conversations"]
|
202 |
+
|
203 |
+
# Create conversation objects with metadata
|
204 |
+
conversation_objects = []
|
205 |
+
for i, conversation in enumerate(conversations):
|
206 |
+
if len(conversation.get("turns", [])) >= 2: # At least 2 turns
|
207 |
+
conversation_obj = {
|
208 |
+
"id": f"conv_{chunk['id']}_{i}",
|
209 |
+
"turns": conversation["turns"],
|
210 |
+
"conversation_type": conversation.get("type", "general"),
|
211 |
+
"source_chunk": chunk["id"],
|
212 |
+
"chunk_metadata": chunk["metadata"],
|
213 |
+
"generation_method": "gemini_json",
|
214 |
+
}
|
215 |
+
conversation_objects.append(conversation_obj)
|
216 |
+
|
217 |
+
return conversation_objects
|
218 |
+
else:
|
219 |
+
print(f"⚠️ No valid conversations generated for chunk {chunk['id']}")
|
220 |
+
return []
|
221 |
+
|
222 |
+
def create_multi_turn_dataset(
|
223 |
+
self,
|
224 |
+
conversations_per_chunk: int = 2,
|
225 |
+
save_path: str = "evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json",
|
226 |
+
) -> Dict:
|
227 |
+
"""
|
228 |
+
Create multi-turn conversation dataset using ALL chunks
|
229 |
+
|
230 |
+
Args:
|
231 |
+
conversations_per_chunk: Number of conversations to generate per chunk
|
232 |
+
save_path: Path to save the dataset JSON file
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
Complete dataset dictionary with conversations
|
236 |
+
"""
|
237 |
+
print(f"\n🚀 Creating multi-turn conversation dataset...")
|
238 |
+
print(f" Target: Process ALL chunks from knowledge base")
|
239 |
+
print(f" Conversations per chunk: {conversations_per_chunk}")
|
240 |
+
|
241 |
+
# Step 1: Get all chunks
|
242 |
+
all_chunks = self.get_all_chunks()
|
243 |
+
total_chunks = len(all_chunks)
|
244 |
+
|
245 |
+
if total_chunks == 0:
|
246 |
+
raise ValueError("No chunks found in knowledge base!")
|
247 |
+
|
248 |
+
print(f"✅ Found {total_chunks} chunks to process")
|
249 |
+
|
250 |
+
# Step 2: Generate conversations for all chunks
|
251 |
+
print(f"\n💬 Generating conversations for {total_chunks} chunks...")
|
252 |
+
all_conversations = []
|
253 |
+
|
254 |
+
for chunk in tqdm(all_chunks, desc="Generating conversations"):
|
255 |
+
conversations = self.generate_conversations_for_chunk(
|
256 |
+
chunk, conversations_per_chunk
|
257 |
+
)
|
258 |
+
all_conversations.extend(conversations)
|
259 |
+
time.sleep(0.2) # Rate limiting for Gemini API
|
260 |
+
|
261 |
+
# Step 3: Populate dataset structure
|
262 |
+
self.dataset["documents"] = {
|
263 |
+
chunk["id"]: chunk["content"] for chunk in all_chunks
|
264 |
+
}
|
265 |
+
self.dataset["conversations"] = {
|
266 |
+
conv["id"]: {
|
267 |
+
"turns": conv["turns"],
|
268 |
+
"conversation_type": conv["conversation_type"],
|
269 |
+
"source_chunk": conv["source_chunk"],
|
270 |
+
"chunk_metadata": conv["chunk_metadata"],
|
271 |
+
"generation_method": conv["generation_method"],
|
272 |
+
}
|
273 |
+
for conv in all_conversations
|
274 |
+
}
|
275 |
+
|
276 |
+
# Step 4: Update metadata
|
277 |
+
self.dataset["metadata"].update(
|
278 |
+
{
|
279 |
+
"total_chunks_processed": total_chunks,
|
280 |
+
"conversations_generated": len(all_conversations),
|
281 |
+
"conversations_per_chunk": conversations_per_chunk,
|
282 |
+
"completion_timestamp": time.time(),
|
283 |
+
}
|
284 |
+
)
|
285 |
+
|
286 |
+
# Step 5: Save dataset
|
287 |
+
os.makedirs(
|
288 |
+
os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
|
289 |
+
exist_ok=True,
|
290 |
+
)
|
291 |
+
|
292 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
293 |
+
json.dump(self.dataset, f, ensure_ascii=False, indent=2)
|
294 |
+
|
295 |
+
print(f"\n✅ Multi-turn conversation dataset created successfully!")
|
296 |
+
print(f" 📁 Saved to: {save_path}")
|
297 |
+
print(f" 📊 Statistics:")
|
298 |
+
print(f" • Chunks processed: {total_chunks}")
|
299 |
+
print(f" • Conversations generated: {len(all_conversations)}")
|
300 |
+
print(
|
301 |
+
f" • Avg conversations per chunk: {len(all_conversations)/total_chunks:.1f}"
|
302 |
+
)
|
303 |
+
|
304 |
+
return self.dataset
|
305 |
+
|
306 |
+
|
307 |
+
class ConversationEnhancer:
|
308 |
+
"""Convert multi-turn conversations to enhanced queries using existing query enhancement"""
|
309 |
+
|
310 |
+
def __init__(self, gemini_api_key: str):
|
311 |
+
"""Initialize with Gemini API key for query enhancement"""
|
312 |
+
self.llm_client = LLMClientFactory.create_client(
|
313 |
+
"gemini", api_key=gemini_api_key, model="gemini-2.0-flash-lite"
|
314 |
+
)
|
315 |
+
print("✅ ConversationEnhancer initialized")
|
316 |
+
|
317 |
+
def enhance_conversation(self, conversation_turns: List[Dict]) -> str:
|
318 |
+
"""
|
319 |
+
Convert a multi-turn conversation to an enhanced query
|
320 |
+
|
321 |
+
Args:
|
322 |
+
conversation_turns: List of turn dictionaries with role and content
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
Enhanced query string
|
326 |
+
"""
|
327 |
+
try:
|
328 |
+
# Create messages in the format expected by query_enhancement_node
|
329 |
+
messages = []
|
330 |
+
for turn in conversation_turns:
|
331 |
+
if turn["role"] == "user":
|
332 |
+
messages.append(HumanMessage(content=turn["content"]))
|
333 |
+
|
334 |
+
# Create a mock state for the query enhancement node
|
335 |
+
state = ViettelPayState(messages=messages)
|
336 |
+
|
337 |
+
# Use the existing query enhancement node
|
338 |
+
enhanced_state = query_enhancement_node(state, self.llm_client)
|
339 |
+
|
340 |
+
enhanced_query = enhanced_state.get("enhanced_query", "")
|
341 |
+
|
342 |
+
if not enhanced_query:
|
343 |
+
# Fallback: concatenate all user messages
|
344 |
+
user_messages = [
|
345 |
+
turn["content"]
|
346 |
+
for turn in conversation_turns
|
347 |
+
if turn["role"] == "user"
|
348 |
+
]
|
349 |
+
enhanced_query = " ".join(user_messages)
|
350 |
+
|
351 |
+
return enhanced_query
|
352 |
+
|
353 |
+
except Exception as e:
|
354 |
+
print(f"❌ Error enhancing conversation: {e}")
|
355 |
+
# Fallback: concatenate all user messages
|
356 |
+
user_messages = [
|
357 |
+
turn["content"] for turn in conversation_turns if turn["role"] == "user"
|
358 |
+
]
|
359 |
+
return " ".join(user_messages)
|
360 |
+
|
361 |
+
def convert_dataset_to_single_turn_format(
|
362 |
+
self,
|
363 |
+
multi_turn_dataset: Dict,
|
364 |
+
save_path: str = "evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json",
|
365 |
+
) -> Dict:
|
366 |
+
"""
|
367 |
+
Convert multi-turn conversation dataset to single-turn format with enhanced queries
|
368 |
+
|
369 |
+
Args:
|
370 |
+
multi_turn_dataset: Multi-turn conversation dataset
|
371 |
+
save_path: Path to save the converted dataset
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
Single-turn format dataset
|
375 |
+
"""
|
376 |
+
print(f"\n🔄 Converting multi-turn conversations to enhanced queries...")
|
377 |
+
|
378 |
+
conversations = multi_turn_dataset["conversations"]
|
379 |
+
documents = multi_turn_dataset["documents"]
|
380 |
+
|
381 |
+
# Initialize single-turn format dataset
|
382 |
+
single_turn_dataset = {
|
383 |
+
"queries": {},
|
384 |
+
"documents": documents,
|
385 |
+
"conversation_metadata": {},
|
386 |
+
"metadata": {
|
387 |
+
"total_conversations_processed": len(conversations),
|
388 |
+
"enhanced_queries_generated": 0,
|
389 |
+
"conversion_timestamp": time.time(),
|
390 |
+
"original_dataset_metadata": multi_turn_dataset.get("metadata", {}),
|
391 |
+
},
|
392 |
+
}
|
393 |
+
|
394 |
+
enhanced_count = 0
|
395 |
+
|
396 |
+
# Process each conversation
|
397 |
+
for conv_id, conv_data in tqdm(
|
398 |
+
conversations.items(), desc="Enhancing conversations"
|
399 |
+
):
|
400 |
+
try:
|
401 |
+
# Extract turns
|
402 |
+
turns = conv_data["turns"]
|
403 |
+
|
404 |
+
# Enhance conversation to single query
|
405 |
+
enhanced_query = self.enhance_conversation(turns)
|
406 |
+
|
407 |
+
if enhanced_query and len(enhanced_query.strip()) > 5:
|
408 |
+
single_turn_dataset["queries"][conv_id] = enhanced_query
|
409 |
+
single_turn_dataset["conversation_metadata"][conv_id] = {
|
410 |
+
"original_conversation": turns,
|
411 |
+
"conversation_type": conv_data.get(
|
412 |
+
"conversation_type", "general"
|
413 |
+
),
|
414 |
+
"source_chunk": conv_data["source_chunk"],
|
415 |
+
"chunk_metadata": conv_data.get("chunk_metadata", {}),
|
416 |
+
"generation_method": conv_data.get(
|
417 |
+
"generation_method", "unknown"
|
418 |
+
),
|
419 |
+
}
|
420 |
+
enhanced_count += 1
|
421 |
+
|
422 |
+
time.sleep(0.1) # Small delay for rate limiting
|
423 |
+
|
424 |
+
except Exception as e:
|
425 |
+
print(f"⚠️ Error processing conversation {conv_id}: {e}")
|
426 |
+
continue
|
427 |
+
|
428 |
+
# Update metadata
|
429 |
+
single_turn_dataset["metadata"]["enhanced_queries_generated"] = enhanced_count
|
430 |
+
|
431 |
+
# Save converted dataset
|
432 |
+
os.makedirs(
|
433 |
+
os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
|
434 |
+
exist_ok=True,
|
435 |
+
)
|
436 |
+
|
437 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
438 |
+
json.dump(single_turn_dataset, f, ensure_ascii=False, indent=2)
|
439 |
+
|
440 |
+
print(f"✅ Conversion completed successfully!")
|
441 |
+
print(f" 📁 Saved to: {save_path}")
|
442 |
+
print(f" 📊 Statistics:")
|
443 |
+
print(f" • Conversations processed: {len(conversations)}")
|
444 |
+
print(f" • Enhanced queries generated: {enhanced_count}")
|
445 |
+
print(f" • Success rate: {enhanced_count/len(conversations)*100:.1f}%")
|
446 |
+
|
447 |
+
return single_turn_dataset
|
448 |
+
|
449 |
+
|
450 |
+
class MultiTurnEvaluator:
|
451 |
+
"""Extended evaluator for multi-turn conversation retrieval with additional analysis"""
|
452 |
+
|
453 |
+
def __init__(self, dataset: Dict, knowledge_base: ViettelKnowledgeBase):
|
454 |
+
"""
|
455 |
+
Initialize evaluator with dataset and knowledge base
|
456 |
+
|
457 |
+
Args:
|
458 |
+
dataset: Evaluation dataset in single-turn format (from converted multi-turn)
|
459 |
+
knowledge_base: ViettelKnowledgeBase instance to evaluate
|
460 |
+
"""
|
461 |
+
self.dataset = dataset
|
462 |
+
self.knowledge_base = knowledge_base
|
463 |
+
self.single_turn_evaluator = SingleTurnRetrievalEvaluator(
|
464 |
+
dataset, knowledge_base
|
465 |
+
)
|
466 |
+
|
467 |
+
def _get_conversation_metadata(self, query_id: str) -> Dict:
|
468 |
+
"""
|
469 |
+
Get conversation metadata for a query, handling both formats
|
470 |
+
|
471 |
+
Args:
|
472 |
+
query_id: Query identifier
|
473 |
+
|
474 |
+
Returns:
|
475 |
+
Metadata dictionary
|
476 |
+
"""
|
477 |
+
# First try conversation_metadata (multi-turn format)
|
478 |
+
conversation_metadata = self.dataset.get("conversation_metadata", {})
|
479 |
+
if query_id in conversation_metadata:
|
480 |
+
return conversation_metadata[query_id]
|
481 |
+
|
482 |
+
# Fallback to question_metadata (single-turn format)
|
483 |
+
question_metadata = self.dataset.get("question_metadata", {})
|
484 |
+
if query_id in question_metadata:
|
485 |
+
# Convert single-turn format to multi-turn format for consistency
|
486 |
+
meta = question_metadata[query_id]
|
487 |
+
return {
|
488 |
+
"conversation_type": "single_turn",
|
489 |
+
"source_chunk": meta.get("source_chunk"),
|
490 |
+
"original_conversation": [
|
491 |
+
{"role": "user", "content": self.dataset["queries"][query_id]}
|
492 |
+
],
|
493 |
+
"chunk_metadata": meta.get("chunk_metadata", {}),
|
494 |
+
"generation_method": meta.get("generation_method", "unknown"),
|
495 |
+
}
|
496 |
+
|
497 |
+
return {}
|
498 |
+
|
499 |
+
def evaluate_multi_turn_performance(
|
500 |
+
self, k_values: List[int] = [1, 3, 5, 10]
|
501 |
+
) -> Dict:
|
502 |
+
"""
|
503 |
+
Evaluate multi-turn conversation retrieval performance
|
504 |
+
|
505 |
+
Args:
|
506 |
+
k_values: List of k values to evaluate
|
507 |
+
|
508 |
+
Returns:
|
509 |
+
Dictionary with evaluation results and multi-turn specific analysis
|
510 |
+
"""
|
511 |
+
print(f"\n🔍 Running multi-turn conversation evaluation...")
|
512 |
+
|
513 |
+
# Step 1: Run standard single-turn evaluation
|
514 |
+
base_results = self.single_turn_evaluator.evaluate(k_values)
|
515 |
+
|
516 |
+
# Step 2: Add multi-turn specific analysis
|
517 |
+
# Analyze by conversation type
|
518 |
+
results_by_type = defaultdict(
|
519 |
+
lambda: {"hit_rates": {k: [] for k in k_values}, "rr_scores": []}
|
520 |
+
)
|
521 |
+
|
522 |
+
for query_id, query_result in base_results["per_query_results"].items():
|
523 |
+
conv_meta = self._get_conversation_metadata(query_id)
|
524 |
+
conv_type = conv_meta.get("conversation_type", "unknown")
|
525 |
+
|
526 |
+
# Add to type-specific results
|
527 |
+
results_by_type[conv_type]["rr_scores"].append(query_result.get("rr", 0))
|
528 |
+
for k in k_values:
|
529 |
+
hit_rate = query_result.get("hit_rates", {}).get(k, 0)
|
530 |
+
results_by_type[conv_type]["hit_rates"][k].append(hit_rate)
|
531 |
+
|
532 |
+
# Calculate averages by conversation type
|
533 |
+
type_analysis = {}
|
534 |
+
for conv_type, type_results in results_by_type.items():
|
535 |
+
type_analysis[conv_type] = {
|
536 |
+
"hit_rates": {
|
537 |
+
k: sum(hits) / len(hits) if hits else 0
|
538 |
+
for k, hits in type_results["hit_rates"].items()
|
539 |
+
},
|
540 |
+
"mrr": (
|
541 |
+
sum(type_results["rr_scores"]) / len(type_results["rr_scores"])
|
542 |
+
if type_results["rr_scores"]
|
543 |
+
else 0
|
544 |
+
),
|
545 |
+
"total_conversations": len(type_results["rr_scores"]),
|
546 |
+
}
|
547 |
+
|
548 |
+
# Analyze conversation length impact
|
549 |
+
turn_length_analysis = self._analyze_by_conversation_length(
|
550 |
+
base_results, k_values
|
551 |
+
)
|
552 |
+
|
553 |
+
# Combine results
|
554 |
+
multi_turn_results = {
|
555 |
+
**base_results, # Include all base results
|
556 |
+
"conversation_type_analysis": type_analysis,
|
557 |
+
"turn_length_analysis": turn_length_analysis,
|
558 |
+
"multi_turn_metadata": {
|
559 |
+
"evaluation_type": "multi_turn_conversation",
|
560 |
+
"conversation_types": list(type_analysis.keys()),
|
561 |
+
"total_conversation_types": len(type_analysis),
|
562 |
+
},
|
563 |
+
}
|
564 |
+
|
565 |
+
return multi_turn_results
|
566 |
+
|
567 |
+
def _analyze_by_conversation_length(
|
568 |
+
self, base_results: Dict, k_values: List[int]
|
569 |
+
) -> Dict:
|
570 |
+
"""Analyze performance by conversation turn length"""
|
571 |
+
|
572 |
+
length_analysis = defaultdict(
|
573 |
+
lambda: {"hit_rates": {k: [] for k in k_values}, "rr_scores": []}
|
574 |
+
)
|
575 |
+
|
576 |
+
for query_id, query_result in base_results["per_query_results"].items():
|
577 |
+
conv_meta = self._get_conversation_metadata(query_id)
|
578 |
+
original_conv = conv_meta.get("original_conversation", [])
|
579 |
+
turn_count = len(
|
580 |
+
[turn for turn in original_conv if turn.get("role") == "user"]
|
581 |
+
)
|
582 |
+
|
583 |
+
# Categorize by turn length
|
584 |
+
if turn_count == 1:
|
585 |
+
length_category = "1_turn" # Single-turn questions
|
586 |
+
elif turn_count == 2:
|
587 |
+
length_category = "2_turns"
|
588 |
+
elif turn_count == 3:
|
589 |
+
length_category = "3_turns"
|
590 |
+
elif turn_count >= 4:
|
591 |
+
length_category = "4+_turns"
|
592 |
+
else:
|
593 |
+
length_category = "unknown_turns"
|
594 |
+
|
595 |
+
# Add to length-specific results
|
596 |
+
length_analysis[length_category]["rr_scores"].append(
|
597 |
+
query_result.get("rr", 0)
|
598 |
+
)
|
599 |
+
for k in k_values:
|
600 |
+
hit_rate = query_result.get("hit_rates", {}).get(k, 0)
|
601 |
+
length_analysis[length_category]["hit_rates"][k].append(hit_rate)
|
602 |
+
|
603 |
+
# Calculate averages by turn length
|
604 |
+
final_length_analysis = {}
|
605 |
+
for length_cat, length_results in length_analysis.items():
|
606 |
+
final_length_analysis[length_cat] = {
|
607 |
+
"hit_rates": {
|
608 |
+
k: sum(hits) / len(hits) if hits else 0
|
609 |
+
for k, hits in length_results["hit_rates"].items()
|
610 |
+
},
|
611 |
+
"mrr": (
|
612 |
+
sum(length_results["rr_scores"]) / len(length_results["rr_scores"])
|
613 |
+
if length_results["rr_scores"]
|
614 |
+
else 0
|
615 |
+
),
|
616 |
+
"total_conversations": len(length_results["rr_scores"]),
|
617 |
+
}
|
618 |
+
|
619 |
+
return final_length_analysis
|
620 |
+
|
621 |
+
def print_multi_turn_results(self, results: Dict):
|
622 |
+
"""Print multi-turn evaluation results with additional analysis"""
|
623 |
+
|
624 |
+
# Print base results first
|
625 |
+
self.single_turn_evaluator.print_evaluation_results(results)
|
626 |
+
|
627 |
+
# Print multi-turn specific analysis
|
628 |
+
print(f"\n🔍 MULTI-TURN SPECIFIC ANALYSIS")
|
629 |
+
print("=" * 60)
|
630 |
+
|
631 |
+
# Conversation type analysis
|
632 |
+
type_analysis = results.get("conversation_type_analysis", {})
|
633 |
+
if type_analysis:
|
634 |
+
print(f"\n📊 Performance by Conversation Type:")
|
635 |
+
print(f"{'Type':<20} {'MRR':<8} {'Hit@5':<8} {'Count':<8}")
|
636 |
+
print("-" * 50)
|
637 |
+
|
638 |
+
for conv_type, analysis in type_analysis.items():
|
639 |
+
mrr = analysis["mrr"]
|
640 |
+
hit_at_5 = analysis["hit_rates"].get(5, 0) * 100
|
641 |
+
count = analysis["total_conversations"]
|
642 |
+
print(f"{conv_type:<20} {mrr:<8.3f} {hit_at_5:<8.1f}% {count:<8}")
|
643 |
+
|
644 |
+
# Turn length analysis
|
645 |
+
length_analysis = results.get("turn_length_analysis", {})
|
646 |
+
if length_analysis:
|
647 |
+
print(f"\n📊 Performance by Conversation Length:")
|
648 |
+
print(f"{'Length':<12} {'MRR':<8} {'Hit@5':<8} {'Count':<8}")
|
649 |
+
print("-" * 40)
|
650 |
+
|
651 |
+
for length_cat, analysis in length_analysis.items():
|
652 |
+
mrr = analysis["mrr"]
|
653 |
+
hit_at_5 = analysis["hit_rates"].get(5, 0) * 100
|
654 |
+
count = analysis["total_conversations"]
|
655 |
+
print(f"{length_cat:<12} {mrr:<8.3f} {hit_at_5:<8.1f}% {count:<8}")
|
656 |
+
|
657 |
+
print(f"\n💡 Multi-Turn Insights:")
|
658 |
+
|
659 |
+
# Best performing conversation type
|
660 |
+
if type_analysis:
|
661 |
+
best_type = max(type_analysis.keys(), key=lambda k: type_analysis[k]["mrr"])
|
662 |
+
worst_type = min(
|
663 |
+
type_analysis.keys(), key=lambda k: type_analysis[k]["mrr"]
|
664 |
+
)
|
665 |
+
print(
|
666 |
+
f" • Best conversation type: {best_type} (MRR: {type_analysis[best_type]['mrr']:.3f})"
|
667 |
+
)
|
668 |
+
print(
|
669 |
+
f" • Worst conversation type: {worst_type} (MRR: {type_analysis[worst_type]['mrr']:.3f})"
|
670 |
+
)
|
671 |
+
|
672 |
+
# Turn length insights
|
673 |
+
if length_analysis:
|
674 |
+
best_length = max(
|
675 |
+
length_analysis.keys(), key=lambda k: length_analysis[k]["mrr"]
|
676 |
+
)
|
677 |
+
print(
|
678 |
+
f" • Best performing length: {best_length} (MRR: {length_analysis[best_length]['mrr']:.3f})"
|
679 |
+
)
|
680 |
+
|
681 |
+
|
682 |
+
def main():
|
683 |
+
"""Main function for multi-turn conversation evaluation"""
|
684 |
+
parser = argparse.ArgumentParser(
|
685 |
+
description="ViettelPay Multi-Turn Conversation Retrieval Evaluation"
|
686 |
+
)
|
687 |
+
parser.add_argument(
|
688 |
+
"--mode",
|
689 |
+
choices=["create", "enhance", "evaluate", "full"],
|
690 |
+
default="full",
|
691 |
+
help="Mode: create conversations, enhance to queries, evaluate, or full pipeline",
|
692 |
+
)
|
693 |
+
parser.add_argument(
|
694 |
+
"--conversations-dataset",
|
695 |
+
default="evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_conversations.json",
|
696 |
+
help="Path to multi-turn conversations dataset",
|
697 |
+
)
|
698 |
+
parser.add_argument(
|
699 |
+
"--enhanced-dataset",
|
700 |
+
default="evaluation_data/datasets/multi_turn_retrieval/viettelpay_multiturn_enhanced.json",
|
701 |
+
help="Path to enhanced queries dataset",
|
702 |
+
)
|
703 |
+
parser.add_argument(
|
704 |
+
"--results-path",
|
705 |
+
default="evaluation_data/results/multi_turn_retrieval/viettelpay_multiturn_results.json",
|
706 |
+
help="Path to save evaluation results",
|
707 |
+
)
|
708 |
+
parser.add_argument(
|
709 |
+
"--conversations-per-chunk",
|
710 |
+
type=int,
|
711 |
+
default=3,
|
712 |
+
help="Number of conversations per chunk",
|
713 |
+
)
|
714 |
+
parser.add_argument(
|
715 |
+
"--k-values",
|
716 |
+
nargs="+",
|
717 |
+
type=int,
|
718 |
+
default=[1, 3, 5, 10],
|
719 |
+
help="K values for evaluation",
|
720 |
+
)
|
721 |
+
parser.add_argument(
|
722 |
+
"--knowledge-base-path",
|
723 |
+
default="./knowledge_base",
|
724 |
+
help="Path to knowledge base",
|
725 |
+
)
|
726 |
+
|
727 |
+
args = parser.parse_args()
|
728 |
+
|
729 |
+
# Configuration
|
730 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
731 |
+
|
732 |
+
if not GEMINI_API_KEY:
|
733 |
+
print("❌ Please set GEMINI_API_KEY environment variable")
|
734 |
+
return
|
735 |
+
|
736 |
+
try:
|
737 |
+
# Initialize knowledge base
|
738 |
+
print("🔧 Initializing ViettelPay knowledge base...")
|
739 |
+
kb = ViettelKnowledgeBase()
|
740 |
+
if not kb.load_knowledge_base(args.knowledge_base_path):
|
741 |
+
print(
|
742 |
+
"❌ Failed to load knowledge base. Please run build_database_script.py first."
|
743 |
+
)
|
744 |
+
return
|
745 |
+
|
746 |
+
# Step 1: Create multi-turn conversations if requested
|
747 |
+
if args.mode in ["create", "full"]:
|
748 |
+
print(f"\n🎯 Creating multi-turn conversation dataset...")
|
749 |
+
creator = MultiTurnDatasetCreator(GEMINI_API_KEY, kb)
|
750 |
+
|
751 |
+
conversations_dataset = creator.create_multi_turn_dataset(
|
752 |
+
conversations_per_chunk=args.conversations_per_chunk,
|
753 |
+
save_path=args.conversations_dataset,
|
754 |
+
)
|
755 |
+
|
756 |
+
# Step 2: Enhance conversations to queries if requested
|
757 |
+
if args.mode in ["enhance", "full"]:
|
758 |
+
print(f"\n⚡ Converting conversations to enhanced queries...")
|
759 |
+
|
760 |
+
# Load conversations if not created in this run
|
761 |
+
if args.mode == "enhance":
|
762 |
+
if not os.path.exists(args.conversations_dataset):
|
763 |
+
print(
|
764 |
+
f"❌ Conversations dataset not found: {args.conversations_dataset}"
|
765 |
+
)
|
766 |
+
return
|
767 |
+
|
768 |
+
with open(args.conversations_dataset, "r", encoding="utf-8") as f:
|
769 |
+
conversations_dataset = json.load(f)
|
770 |
+
|
771 |
+
# Enhance conversations
|
772 |
+
enhancer = ConversationEnhancer(GEMINI_API_KEY)
|
773 |
+
enhanced_dataset = enhancer.convert_dataset_to_single_turn_format(
|
774 |
+
conversations_dataset, args.enhanced_dataset
|
775 |
+
)
|
776 |
+
|
777 |
+
# Step 3: Evaluate if requested
|
778 |
+
if args.mode in ["evaluate", "full"]:
|
779 |
+
print(f"\n📊 Evaluating multi-turn conversation retrieval...")
|
780 |
+
|
781 |
+
# Load enhanced dataset if not created in this run
|
782 |
+
if args.mode == "evaluate":
|
783 |
+
if not os.path.exists(args.enhanced_dataset):
|
784 |
+
print(f"❌ Enhanced dataset not found: {args.enhanced_dataset}")
|
785 |
+
return
|
786 |
+
|
787 |
+
with open(args.enhanced_dataset, "r", encoding="utf-8") as f:
|
788 |
+
enhanced_dataset = json.load(f)
|
789 |
+
|
790 |
+
# Run evaluation
|
791 |
+
evaluator = MultiTurnEvaluator(enhanced_dataset, kb)
|
792 |
+
results = evaluator.evaluate_multi_turn_performance(k_values=args.k_values)
|
793 |
+
evaluator.print_multi_turn_results(results)
|
794 |
+
|
795 |
+
# Save results
|
796 |
+
if args.results_path:
|
797 |
+
with open(args.results_path, "w", encoding="utf-8") as f:
|
798 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
799 |
+
print(f"\n💾 Results saved to: {args.results_path}")
|
800 |
+
|
801 |
+
print(f"\n✅ Multi-turn evaluation completed successfully!")
|
802 |
+
print(f"\n💡 Next steps:")
|
803 |
+
print(f" 1. Compare multi-turn vs single-turn performance")
|
804 |
+
print(f" 2. Analyze conversation types that work best")
|
805 |
+
print(f" 3. Optimize query enhancement for multi-turn scenarios")
|
806 |
+
|
807 |
+
except Exception as e:
|
808 |
+
print(f"❌ Error in main execution: {e}")
|
809 |
+
import traceback
|
810 |
+
|
811 |
+
traceback.print_exc()
|
812 |
+
|
813 |
+
|
814 |
+
if __name__ == "__main__":
|
815 |
+
main()
|
src/evaluation/prompts.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Prompts for ViettelPay Synthetic Evaluation Dataset Creation
|
3 |
+
Simplified version for MRR and Hit Rate evaluation only
|
4 |
+
"""
|
5 |
+
|
6 |
+
# Question Generation Prompt (JSON format for better parsing)
|
7 |
+
QUESTION_GENERATION_PROMPT = """Bạn là chuyên gia tạo câu hỏi đánh giá cho hệ thống ViettelPay Pro.
|
8 |
+
ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
9 |
+
Dựa trên đoạn văn bản sau từ tài liệu hướng dẫn, hãy tạo ra {num_questions} câu hỏi đa dạng:
|
10 |
+
|
11 |
+
<context>
|
12 |
+
{content}
|
13 |
+
</context>
|
14 |
+
|
15 |
+
Tạo các loại câu hỏi:
|
16 |
+
1. Câu hỏi trực tiếp về thông tin trong đoạn văn
|
17 |
+
2. Câu hỏi về cách thực hiện hoặc quy trình
|
18 |
+
3. Câu hỏi về lỗi, vấn đề hoặc troubleshooting
|
19 |
+
4. Câu hỏi về quy định, chính sách, phí
|
20 |
+
|
21 |
+
Yêu cầu cho mỗi câu hỏi:
|
22 |
+
- Tự nhiên như khách hàng ViettelPay thật sẽ hỏi
|
23 |
+
- Có thể trả lời được từ đoạn văn bản đã cho
|
24 |
+
- Ngắn gọn (5-20 từ)
|
25 |
+
- Sử dụng tiếng Việt thông dụng
|
26 |
+
- Đa dạng về loại câu hỏi và độ phức tạp
|
27 |
+
|
28 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
29 |
+
{{
|
30 |
+
"questions": [
|
31 |
+
"Câu hỏi đầu tiên?",
|
32 |
+
"Câu hỏi thứ hai?",
|
33 |
+
"Câu hỏi thứ ba?"
|
34 |
+
]
|
35 |
+
}}
|
36 |
+
|
37 |
+
CHỈ trả về JSON, không có text khác."""
|
38 |
+
|
39 |
+
# Multi-Turn Conversation Generation Prompt (JSON format)
|
40 |
+
MULTI_TURN_CONVERSATION_GENERATION_PROMPT = """Bạn là một chuyên gia trong việc tạo dữ liệu huấn luyện cho chatbot hỗ trợ khách hàng của ứng dụng ViettelPay Pro.
|
41 |
+
ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
42 |
+
Nhiệm vụ của bạn là tạo ra **{num_conversations}** cuộc hội thoại đa lượt, chân thực và tự nhiên giữa người dùng (Đại lý/Điểm bán) và Trợ lý AI. Toàn bộ nội dung cuộc hội thoại phải dựa **hoàn toàn** vào thông tin được cung cấp trong tài liệu dưới đây.
|
43 |
+
|
44 |
+
<context>
|
45 |
+
{content}
|
46 |
+
</context>
|
47 |
+
|
48 |
+
Hãy tạo ra các cuộc hội thoại xoay quanh những kịch bản sau:
|
49 |
+
|
50 |
+
1. **Giải quyết vấn đề (`error_resolution`):** Người dùng gặp lỗi, giao dịch thất bại, hoặc một tính năng không hoạt động như mong đợi. Họ muốn tìm hiểu nguyên nhân và cách khắc phục.
|
51 |
+
* *Ví dụ luồng hội thoại:* Báo lỗi -> Hỏi về nguyên nhân sâu xa -> Hỏi cách để tránh lỗi này trong tương lai.
|
52 |
+
|
53 |
+
2. **Hướng dẫn thực hiện (`procedure_guide`):** Người dùng muốn biết cách thực hiện một tác vụ cụ thể. Cuộc hội thoại nên đi sâu vào các bước, điều kiện, hoặc các chi tiết liên quan.
|
54 |
+
* *Ví dụ luồng hội thoại:* Hỏi cách thực hiện một dịch vụ -> Hỏi về một bước cụ thể -> Hỏi về một trường hợp đặc biệt (ví dụ: "nếu làm cho nhà mạng khác thì sao?").
|
55 |
+
|
56 |
+
3. **Tra cứu thông tin (`policy_info`):** Người dùng có câu hỏi về chính sách, phí, hạn mức, hoặc các quy định của dịch vụ.
|
57 |
+
* *Ví dụ luồng hội thoại:* Hỏi về một quy định chung -> Hỏi về một trường hợp áp dụng cụ thể -> Hỏi về các ngoại lệ.
|
58 |
+
|
59 |
+
**YÊU CẦU QUAN TRỌNG:**
|
60 |
+
|
61 |
+
* **Dòng chảy tự nhiên:** Mỗi lượt hỏi của người dùng phải là một phản ứng logic, tự nhiên sau khi nhận được câu trả lời (tưởng tượng) từ AI. Hãy hình dung AI đã đưa ra câu trả lời hữu ích nhưng chưa đầy đủ, khiến người dùng phải hỏi thêm để làm rõ.
|
62 |
+
* **Chân thực như người dùng thật:**
|
63 |
+
* Sử dụng ngôn ngữ đời thường, ngắn gọn, đi thẳng vào vấn đề.
|
64 |
+
* Có thể dùng các từ viết tắt phổ biến (vd: "sđt", "tk", "gd", "đk").
|
65 |
+
* Giọng điệu có thể thể hiện sự bối rối, cần hỗ trợ gấp hoặc tò mò.
|
66 |
+
* **Bám sát tài liệu:** **Không** được tự ý sáng tạo thông tin, chính sách, hoặc tính năng không có trong phần `<context>`.
|
67 |
+
* Tất cả các lượt câu hỏi, đặc biệt là câu hỏi cuối cùng, phải có thể trả lời được từ thông tin trong tài liệu `<context>`.
|
68 |
+
* **Cấu trúc:** Mỗi cuộc hội thoại phải có từ 2 đến 3 lượt hỏi từ phía người dùng.
|
69 |
+
* **Ngôn ngữ:** Tiếng Việt.
|
70 |
+
|
71 |
+
Ví dụ cuộc hội thoại:
|
72 |
+
```
|
73 |
+
Lượt 1: "mã lỗi 606 là gì vậy?"
|
74 |
+
Lượt 2: "làm sao để khắc phục lỗi này?"
|
75 |
+
```
|
76 |
+
|
77 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
78 |
+
{{
|
79 |
+
"conversations": [
|
80 |
+
{{
|
81 |
+
"type": "error_resolution",
|
82 |
+
"turns": [
|
83 |
+
{{"role": "user", "content": "mã lỗi 606 là gì vậy?"}},
|
84 |
+
{{"role": "user", "content": "làm sao để khắc phục lỗi này?"}}
|
85 |
+
]
|
86 |
+
}},
|
87 |
+
{{
|
88 |
+
"type": "procedure_inquiry",
|
89 |
+
"turns": [
|
90 |
+
{{"role": "user", "content": "nạp cước điện thoại như thế nào?"}},
|
91 |
+
{{"role": "user", "content": "có thể nạp cho số Viettel không?"}},
|
92 |
+
{{"role": "user", "content": "có các mệnh giá nào?"}}
|
93 |
+
]
|
94 |
+
}}
|
95 |
+
]
|
96 |
+
}}
|
97 |
+
|
98 |
+
CHỈ trả về JSON, không có text khác."""
|
99 |
+
|
100 |
+
|
101 |
+
# Quality Check Prompt for Generated Questions
|
102 |
+
QUESTION_QUALITY_CHECK_PROMPT = """Đánh giá chất lượng của câu hỏi được tạo ra cho hệ thống ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
103 |
+
|
104 |
+
Câu hỏi: {question}
|
105 |
+
Đoạn văn gốc: {context}
|
106 |
+
|
107 |
+
Tiêu chí đánh giá:
|
108 |
+
1. Clarity (Rõ ràng): Câu hỏi có dễ hiểu không?
|
109 |
+
2. Answerability (Có thể trả lời): Có thể trả lời từ đoạn văn không?
|
110 |
+
3. Naturalness (Tự nhiên): Có giống cách khách hàng thật hỏi không?
|
111 |
+
4. Relevance (Liên quan): Có phù hợp với nội dung ViettelPay không?
|
112 |
+
|
113 |
+
Mỗi tiêu chí từ 1-5 điểm (5 là tốt nhất).
|
114 |
+
|
115 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
116 |
+
{{
|
117 |
+
"clarity": 5,
|
118 |
+
"answerability": 4,
|
119 |
+
"naturalness": 5,
|
120 |
+
"relevance": 5,
|
121 |
+
"overall_score": 4.75,
|
122 |
+
"keep_question": true,
|
123 |
+
"feedback": "Câu hỏi tốt, rõ ràng và tự nhiên"
|
124 |
+
}}
|
125 |
+
|
126 |
+
CHỈ trả về JSON, không có text khác."""
|
127 |
+
|
128 |
+
# Context Quality Check Prompt
|
129 |
+
CONTEXT_QUALITY_CHECK_PROMPT = """Đánh giá chất lượng của đoạn văn bản ViettelPay Pro để tạo câu hỏi. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
130 |
+
|
131 |
+
Đoạn văn bản:
|
132 |
+
{content}
|
133 |
+
|
134 |
+
Tiêu chí đánh giá:
|
135 |
+
1. Clarity (Rõ ràng): Thông tin có dễ hiểu không?
|
136 |
+
2. Completeness (Đầy đủ): Có đủ thông tin để tạo câu hỏi không?
|
137 |
+
3. Structure (Cấu trúc): Có tổ chức tốt không?
|
138 |
+
4. Relevance (Liên quan): Có phù hợp với ViettelPay không?
|
139 |
+
5. Information_density (Mật độ thông tin): Có đủ thông tin hữu ích không?
|
140 |
+
|
141 |
+
Mỗi tiêu chí từ 1-5 điểm (5 là tốt nhất).
|
142 |
+
|
143 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
144 |
+
{{
|
145 |
+
"clarity": 5,
|
146 |
+
"completeness": 4,
|
147 |
+
"structure": 4,
|
148 |
+
"relevance": 5,
|
149 |
+
"information_density": 4,
|
150 |
+
"overall_score": 4.4,
|
151 |
+
"use_context": true,
|
152 |
+
"feedback": "Đoạn văn tốt, có thể tạo câu hỏi chất lượng"
|
153 |
+
}}
|
154 |
+
|
155 |
+
CHỈ trả về JSON, không có text khác."""
|
156 |
+
|
157 |
+
# Question Evolution/Variation Prompt
|
158 |
+
QUESTION_EVOLUTION_PROMPT = """Tạo các biến thể của câu hỏi ViettelPay Pro để tăng tính đa dạng. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
159 |
+
|
160 |
+
Câu hỏi gốc: {original_question}
|
161 |
+
Ngữ cảnh: {context}
|
162 |
+
|
163 |
+
Tạo 3 biến thể khác nhau:
|
164 |
+
1. Phiên bản casual/thông tục (cách nói hàng ngày)
|
165 |
+
2. Phiên bản formal/lịch sự (cách nói trang trọng)
|
166 |
+
3. Phiên bản cụ thể (thêm chi tiết, tình huống cụ thể)
|
167 |
+
|
168 |
+
Yêu cầu:
|
169 |
+
- Giữ nguyên ý nghĩa cốt lõi
|
170 |
+
- Vẫn có thể trả lời từ cùng ngữ cảnh
|
171 |
+
- Tự nhiên với người dùng Việt Nam
|
172 |
+
- Đa dạng về cách diễn đạt
|
173 |
+
|
174 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
175 |
+
{{
|
176 |
+
"original_question": "{original_question}",
|
177 |
+
"variations": [
|
178 |
+
"Phiên bản casual",
|
179 |
+
"Phiên bản formal",
|
180 |
+
"Phiên bản cụ thể"
|
181 |
+
]
|
182 |
+
}}
|
183 |
+
|
184 |
+
CHỈ trả về JSON, không có text khác."""
|
185 |
+
|
186 |
+
# Dataset Statistics Prompt
|
187 |
+
DATASET_STATS_PROMPT = """Phân tích thống kê dataset đánh giá ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
188 |
+
|
189 |
+
Dữ liệu:
|
190 |
+
- Tổng số câu hỏi: {total_questions}
|
191 |
+
- Tổng số documents: {total_documents}
|
192 |
+
- Câu hỏi theo loại: {question_types}
|
193 |
+
|
194 |
+
Tạo báo cáo thống kê và đề xuất cải thiện.
|
195 |
+
|
196 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
197 |
+
{{
|
198 |
+
"coverage_analysis": {{
|
199 |
+
"error_handling": "20%",
|
200 |
+
"procedures": "30%",
|
201 |
+
"policies": "25%",
|
202 |
+
"faq": "25%"
|
203 |
+
}},
|
204 |
+
"quality_metrics": {{
|
205 |
+
"avg_questions_per_doc": 2.1,
|
206 |
+
"question_diversity": "high"
|
207 |
+
}},
|
208 |
+
"recommendations": [
|
209 |
+
"Tăng câu hỏi về error handling",
|
210 |
+
"Cân bằng độ khó của câu hỏi"
|
211 |
+
]
|
212 |
+
}}
|
213 |
+
|
214 |
+
CHỈ trả về JSON, không có text khác."""
|
215 |
+
|
216 |
+
# Error Analysis Prompt
|
217 |
+
ERROR_ANALYSIS_PROMPT = """Phân tích lỗi trong quá trình đánh giá retrieval ViettelPay Pro. ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
218 |
+
|
219 |
+
Kết quả đánh giá:
|
220 |
+
{evaluation_results}
|
221 |
+
|
222 |
+
Xác định:
|
223 |
+
1. Câu hỏi có hiệu suất thấp (Hit Rate < 0.3)
|
224 |
+
2. Loại lỗi thường gặp
|
225 |
+
3. Nguyên nhân gốc rễ
|
226 |
+
4. Đề xuất cải thiện
|
227 |
+
|
228 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như sau:
|
229 |
+
{{
|
230 |
+
"low_performance_queries": [
|
231 |
+
{{"query": "câu hỏi", "hit_rate": 0.2, "issue": "từ khóa không rõ ràng"}}
|
232 |
+
],
|
233 |
+
"common_error_types": [
|
234 |
+
"Thiếu từ khóa chính",
|
235 |
+
"Ngữ cảnh không đủ",
|
236 |
+
"Chunking không tối ưu"
|
237 |
+
],
|
238 |
+
"improvement_suggestions": [
|
239 |
+
"Cải thiện chunking strategy",
|
240 |
+
"Thêm synonyms cho từ khóa"
|
241 |
+
]
|
242 |
+
}}
|
243 |
+
|
244 |
+
CHỈ trả về JSON, không có text khác."""
|
245 |
+
|
246 |
+
# Intent Classification Prompt
|
247 |
+
# Updated Intent Classification Conversation Generation Prompt with chunk mixing support
|
248 |
+
# Improved Intent Classification Conversation Generation Prompt
|
249 |
+
INTENT_CLASSIFICATION_CONVERSATION_GENERATION_PROMPT = """Bạn là chuyên gia tạo dữ liệu đánh giá cho hệ thống phân loại ý định (intent classification) của Trợ lý AI trên ứng dụng ViettelPay Pro.
|
250 |
+
ViettelPay Pro là ứng dụng chuyên biệt dành cho các đại lý và điểm bán của Viettel, giúp họ thực hiện các giao dịch tài chính và viễn thông cho khách hàng một cách nhanh chóng, an toàn và đơn giản.
|
251 |
+
|
252 |
+
Nhiệm vụ của bạn là tạo ra **{num_conversations}** cuộc hội thoại đa lượt thực tế. Mỗi tin nhắn của người dùng phải được gán một nhãn `intent` chính xác.
|
253 |
+
`intent` là ý định của người dùng trong câu hỏi hiện tại và liên quan đến các lượt hỏi trước đó trong cuộc hội thoại.
|
254 |
+
|
255 |
+
**1. Định nghĩa các loại ý định (Bắt buộc phải tuân theo):**
|
256 |
+
|
257 |
+
* **`greeting`**: Chỉ là lời chào hỏi thuần túy, không có câu hỏi hoặc yêu cầu cụ thể nào khác. Nếu tin nhắn có cả lời chào VÀ câu hỏi thì phân loại theo các ý định khác, không phải greeting.
|
258 |
+
* *Ví dụ:* "chào em", "hello shop", "xin chào ạ"
|
259 |
+
* *Không phải greeting:* "xin chào, cho hỏi về lỗi 606" → đây là error_help
|
260 |
+
* **`faq`**: Các câu hỏi đáp chung, tìm hiểu về dịch vụ, tính năng, v.v.
|
261 |
+
* *Ví dụ:* "App có bán thẻ game không?", "ViettelPay Pro nạp tiền được cho mạng nao?"
|
262 |
+
* **`error_help`**: Báo cáo sự cố, hỏi về mã lỗi cụ thể.
|
263 |
+
* *Ví dụ:* "Giao dịch báo lỗi 606", "tại sao tôi không thanh toán được?", "lỗi này là gì?"
|
264 |
+
* **`procedure_guide`**: Hỏi về các bước cụ thể để thực hiện một tác vụ.
|
265 |
+
* *Ví dụ:* "làm thế nào để hủy giao dịch?", "chỉ tôi cách lấy lại mã thẻ cào", "hướng dẫn nạp cước"
|
266 |
+
* **`human_request`**: Yêu cầu được nói chuyện trực tiếp với nhân viên hỗ trợ.
|
267 |
+
* *Ví dụ:* "cho tôi gặp người thật", "nối máy cho tổng đài", "em k hiểu, cho gặp ai đó"
|
268 |
+
* **`out_of_scope`**: Câu hỏi ngoài phạm vi ViettelPay (thời tiết, chính trị, v.v.), không liên quan gì đến các dịch vụ tài chính, viễn thông của Viettel.
|
269 |
+
* *Ví dụ:* "dự báo thời tiết hôm nay?", "giá xăng bao nhiêu?", "cách nấu phở"
|
270 |
+
* **`unclear`**: Câu hỏi không rõ ràng, thiếu thông tin cụ thể, cần người dùng bổ sung thêm chi tiết để có thể hỗ trợ hiệu quả.
|
271 |
+
* *Ví dụ:* "lỗi", "giúp với", "gd", "???", "ko hiểu", "bị lỗi giờ sao đây", "không thực hiện được", "sao vậy", "tại sao thế"
|
272 |
+
|
273 |
+
**2. Nguồn kiến thức (Context):**
|
274 |
+
|
275 |
+
Sử dụng tài liệu dưới đây để lấy các mã lỗi (vd: 606, W02), tên dịch vụ (vd: Gạch nợ cước), và các tình huống thực tế để xây dựng cuộc hội thoại.
|
276 |
+
|
277 |
+
<context>
|
278 |
+
{content}
|
279 |
+
</context>
|
280 |
+
|
281 |
+
**3. Yêu cầu về kịch bản hội thoại:**
|
282 |
+
|
283 |
+
* Mỗi cuộc hội thoại có t��� 2 đến 4 lượt hỏi từ người dùng. Tùy chỉnh sao cho phù hợp với tài liệu.
|
284 |
+
* {generation_instruction}
|
285 |
+
* Tạo các cuộc hội thoại đa dạng, không lặp lại.
|
286 |
+
* **QUAN TRỌNG - Intent Mixing**: Khoảng 70% tin nhắn nên liên quan đến context, nhưng 30% tin nhắn nên là các intent tự nhiên khác như:
|
287 |
+
- `greeting` ở đầu cuộc hội thoại
|
288 |
+
- `unclear` khi người dùng hỏi không rõ ràng, trợ lý cần thêm thông tin từ người dùng
|
289 |
+
- `human_request` khi họ muốn hỗ trợ trực tiếp
|
290 |
+
- `out_of_scope` khi họ hỏi không liên quan đến ViettelPay Pro
|
291 |
+
* Ngôn ngữ phải tự nhiên như người dùng thật, có thể dùng từ viết tắt, và đi thẳng vào vấn đề.
|
292 |
+
* Tạo các tình huống thực tế như người dùng thật sự sẽ hỏi. Không cần phải kết thúc bằng cảm ơn hay tạm biệt.
|
293 |
+
* Các câu hỏi nên dễ để phân loại ý định, không cần phải suy nghĩ quá lâu.
|
294 |
+
|
295 |
+
**4. Ví dụ cuộc hội thoại thực tế:**
|
296 |
+
{{
|
297 |
+
"conversations": [
|
298 |
+
{{
|
299 |
+
"turns": [
|
300 |
+
{{"user": "chào em", "intent": "greeting"}},
|
301 |
+
{{"user": "mã lỗi 606 là gì z?", "intent": "error_help"}},
|
302 |
+
{{"user": "em ko hiểu, gặp ai đó dc ko?", "intent": "human_request"}}
|
303 |
+
]
|
304 |
+
}},
|
305 |
+
{{
|
306 |
+
"turns": [
|
307 |
+
{{"user": "nạp cước như nào?", "intent": "procedure_guide"}},
|
308 |
+
{{"user": "hôm nay trời đẹp nhỉ", "intent": "out_of_scope"}},
|
309 |
+
{{"user": "ờ quay lại, có phí ko?", "intent": "faq"}}
|
310 |
+
]
|
311 |
+
}}
|
312 |
+
]
|
313 |
+
}}
|
314 |
+
|
315 |
+
**5. Định dạng đầu ra (Output):**
|
316 |
+
|
317 |
+
QUAN TRỌNG: Trả về kết quả dưới dạng JSON với format chính xác như ví dụ trên.
|
318 |
+
CHỈ trả về JSON, không có text khác."""
|
src/evaluation/single_turn_retrieval.py
ADDED
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Single Turn Synthetic Retrieval Evaluation Dataset Creator for ViettelPay RAG System
|
3 |
+
Uses Google Gemini 2.0 Flash with JSON responses for better parsing
|
4 |
+
Simplified version with only MRR and hit rate evaluation (no qrels generation)
|
5 |
+
"""
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import argparse
|
11 |
+
import time
|
12 |
+
from typing import Dict, List, Tuple, Optional, Union
|
13 |
+
from pathlib import Path
|
14 |
+
from collections import defaultdict
|
15 |
+
import pandas as pd
|
16 |
+
from tqdm import tqdm
|
17 |
+
import re
|
18 |
+
|
19 |
+
# Load environment variables from .env file
|
20 |
+
from dotenv import load_dotenv
|
21 |
+
|
22 |
+
load_dotenv()
|
23 |
+
|
24 |
+
# Add the project root to Python path so we can import from src
|
25 |
+
project_root = Path(__file__).parent.parent.parent
|
26 |
+
sys.path.insert(0, str(project_root))
|
27 |
+
|
28 |
+
# Import prompts (only the ones we need)
|
29 |
+
from src.evaluation.prompts import (
|
30 |
+
QUESTION_GENERATION_PROMPT,
|
31 |
+
QUESTION_QUALITY_CHECK_PROMPT,
|
32 |
+
CONTEXT_QUALITY_CHECK_PROMPT,
|
33 |
+
QUESTION_EVOLUTION_PROMPT,
|
34 |
+
)
|
35 |
+
|
36 |
+
# Import your existing knowledge base and LLM client
|
37 |
+
from src.knowledge_base.viettel_knowledge_base import ViettelKnowledgeBase
|
38 |
+
from src.llm.llm_client import LLMClientFactory, BaseLLMClient
|
39 |
+
|
40 |
+
|
41 |
+
class SingleTurnDatasetCreator:
|
42 |
+
"""Single turn synthetic evaluation dataset creator with JSON responses and all chunks processing"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self, gemini_api_key: str, knowledge_base: ViettelKnowledgeBase = None
|
46 |
+
):
|
47 |
+
"""
|
48 |
+
Initialize with Gemini API key and optional knowledge base
|
49 |
+
|
50 |
+
Args:
|
51 |
+
gemini_api_key: Google AI API key for Gemini
|
52 |
+
knowledge_base: Pre-initialized ViettelKnowledgeBase instance
|
53 |
+
"""
|
54 |
+
self.llm_client = LLMClientFactory.create_client(
|
55 |
+
"gemini", api_key=gemini_api_key, model="gemini-2.0-flash"
|
56 |
+
)
|
57 |
+
self.knowledge_base = knowledge_base
|
58 |
+
self.dataset = {
|
59 |
+
"queries": {},
|
60 |
+
"documents": {},
|
61 |
+
"metadata": {
|
62 |
+
"total_chunks_processed": 0,
|
63 |
+
"questions_generated": 0,
|
64 |
+
"creation_timestamp": time.time(),
|
65 |
+
},
|
66 |
+
}
|
67 |
+
|
68 |
+
print("✅ SingleTurnDatasetCreator initialized with Gemini 2.0 Flash")
|
69 |
+
|
70 |
+
def generate_json_response(
|
71 |
+
self, prompt: str, max_retries: int = 3
|
72 |
+
) -> Optional[Dict]:
|
73 |
+
"""
|
74 |
+
Generate response and parse as JSON with retries
|
75 |
+
|
76 |
+
Args:
|
77 |
+
prompt: Input prompt
|
78 |
+
max_retries: Maximum number of retry attempts
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
Parsed JSON response or None if failed
|
82 |
+
"""
|
83 |
+
for attempt in range(max_retries):
|
84 |
+
try:
|
85 |
+
response = self.llm_client.generate(prompt, temperature=0.1)
|
86 |
+
|
87 |
+
if response:
|
88 |
+
# Clean response text
|
89 |
+
response_text = response.strip()
|
90 |
+
|
91 |
+
# Extract JSON from response (handle cases with extra text)
|
92 |
+
json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
|
93 |
+
if json_match:
|
94 |
+
json_text = json_match.group()
|
95 |
+
return json.loads(json_text)
|
96 |
+
else:
|
97 |
+
# Try parsing the whole response
|
98 |
+
return json.loads(response_text)
|
99 |
+
|
100 |
+
except json.JSONDecodeError as e:
|
101 |
+
print(f"⚠️ JSON parsing error (attempt {attempt + 1}): {e}")
|
102 |
+
if attempt == max_retries - 1:
|
103 |
+
print(f"❌ Failed to parse JSON after {max_retries} attempts")
|
104 |
+
print(
|
105 |
+
f"Raw response: {response if 'response' in locals() else 'No response'}"
|
106 |
+
)
|
107 |
+
|
108 |
+
except Exception as e:
|
109 |
+
print(f"⚠️ API error (attempt {attempt + 1}): {e}")
|
110 |
+
if attempt < max_retries - 1:
|
111 |
+
time.sleep(2**attempt) # Exponential backoff
|
112 |
+
|
113 |
+
return None
|
114 |
+
|
115 |
+
def get_all_chunks(self) -> List[Dict]:
|
116 |
+
"""
|
117 |
+
Get ALL chunks directly from ChromaDB vectorstore (no sampling)
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
List of all document chunks with content and metadata
|
121 |
+
"""
|
122 |
+
print(f"📚 Retrieving ALL chunks directly from ChromaDB vectorstore...")
|
123 |
+
|
124 |
+
if not self.knowledge_base:
|
125 |
+
raise ValueError(
|
126 |
+
"Knowledge base not provided. Please initialize with a ViettelKnowledgeBase instance."
|
127 |
+
)
|
128 |
+
|
129 |
+
try:
|
130 |
+
# Access the ChromaDB vectorstore directly
|
131 |
+
if (
|
132 |
+
not hasattr(self.knowledge_base, "chroma_retriever")
|
133 |
+
or not self.knowledge_base.chroma_retriever
|
134 |
+
):
|
135 |
+
raise ValueError("ChromaDB retriever not found in knowledge base")
|
136 |
+
|
137 |
+
# Get the vectorstore from the retriever
|
138 |
+
vectorstore = self.knowledge_base.chroma_retriever.vectorstore
|
139 |
+
|
140 |
+
# Get all documents directly from ChromaDB
|
141 |
+
print(" Accessing ChromaDB collection...")
|
142 |
+
all_docs = vectorstore.get(include=["documents", "metadatas"])
|
143 |
+
|
144 |
+
documents = all_docs["documents"]
|
145 |
+
metadatas = all_docs["metadatas"]
|
146 |
+
|
147 |
+
print(f" Found {len(documents)} documents in ChromaDB")
|
148 |
+
print(f" Sample document preview:")
|
149 |
+
for i, doc in enumerate(documents[:3]):
|
150 |
+
print(f" Doc {i+1}: {doc[:100]}...")
|
151 |
+
|
152 |
+
# Convert to our expected format
|
153 |
+
all_chunks = []
|
154 |
+
seen_content_hashes = set()
|
155 |
+
|
156 |
+
for i, (content, metadata) in enumerate(zip(documents, metadatas)):
|
157 |
+
# Create content hash for deduplication (just in case)
|
158 |
+
content_hash = hash(content[:300])
|
159 |
+
|
160 |
+
if (
|
161 |
+
content_hash not in seen_content_hashes
|
162 |
+
and len(content.strip()) > 50
|
163 |
+
):
|
164 |
+
chunk_info = {
|
165 |
+
"id": f"chunk_{len(all_chunks)}",
|
166 |
+
"content": content,
|
167 |
+
"metadata": metadata or {},
|
168 |
+
"source": "chromadb_direct",
|
169 |
+
"content_length": len(content),
|
170 |
+
"original_index": i,
|
171 |
+
}
|
172 |
+
all_chunks.append(chunk_info)
|
173 |
+
seen_content_hashes.add(content_hash)
|
174 |
+
else:
|
175 |
+
if content_hash in seen_content_hashes:
|
176 |
+
print(f" ⚠️ Skipping duplicate content at index {i}")
|
177 |
+
else:
|
178 |
+
print(
|
179 |
+
f" ⚠️ Skipping short content at index {i} (length: {len(content.strip())})"
|
180 |
+
)
|
181 |
+
|
182 |
+
print(f"✅ Retrieved {len(all_chunks)} unique chunks from ChromaDB")
|
183 |
+
print(
|
184 |
+
f" Filtered out {len(documents) - len(all_chunks)} duplicates/short chunks"
|
185 |
+
)
|
186 |
+
|
187 |
+
# Sort by content length (longer chunks first, usually more informative)
|
188 |
+
all_chunks.sort(key=lambda x: x["content_length"], reverse=True)
|
189 |
+
|
190 |
+
# Display statistics
|
191 |
+
avg_length = sum(chunk["content_length"] for chunk in all_chunks) / len(
|
192 |
+
all_chunks
|
193 |
+
)
|
194 |
+
min_length = min(chunk["content_length"] for chunk in all_chunks)
|
195 |
+
max_length = max(chunk["content_length"] for chunk in all_chunks)
|
196 |
+
|
197 |
+
print(f" 📊 Chunk Statistics:")
|
198 |
+
print(f" • Average length: {avg_length:.0f} characters")
|
199 |
+
print(f" • Min length: {min_length} characters")
|
200 |
+
print(f" • Max length: {max_length} characters")
|
201 |
+
|
202 |
+
return all_chunks
|
203 |
+
|
204 |
+
except Exception as e:
|
205 |
+
print(f"❌ Error accessing ChromaDB directly: {e}")
|
206 |
+
print(f" Falling back to search-based method...")
|
207 |
+
return self._get_all_chunks_fallback()
|
208 |
+
|
209 |
+
def _get_all_chunks_fallback(self) -> List[Dict]:
|
210 |
+
"""
|
211 |
+
Fallback method using search queries if direct ChromaDB access fails
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
List of document chunks retrieved via search
|
215 |
+
"""
|
216 |
+
print(f"🔄 Using fallback search-based chunk retrieval...")
|
217 |
+
|
218 |
+
# Use comprehensive search terms to capture most content
|
219 |
+
comprehensive_queries = [
|
220 |
+
"ViettelPay",
|
221 |
+
"nạp",
|
222 |
+
"cước",
|
223 |
+
"giao dịch",
|
224 |
+
"thanh toán",
|
225 |
+
"lỗi",
|
226 |
+
"hủy",
|
227 |
+
"thẻ",
|
228 |
+
"chuyển",
|
229 |
+
"tiền",
|
230 |
+
"quy định",
|
231 |
+
"phí",
|
232 |
+
"dịch vụ",
|
233 |
+
"tài khoản",
|
234 |
+
"ngân hàng",
|
235 |
+
"OTP",
|
236 |
+
"PIN",
|
237 |
+
"mã",
|
238 |
+
"số",
|
239 |
+
"điện thoại",
|
240 |
+
"internet",
|
241 |
+
"truyền hình",
|
242 |
+
"homephone",
|
243 |
+
"cố định",
|
244 |
+
"game",
|
245 |
+
"Viettel",
|
246 |
+
"Mobifone",
|
247 |
+
# Add some Vietnamese words that might not be captured above
|
248 |
+
"ứng dụng",
|
249 |
+
"khách hàng",
|
250 |
+
"hỗ trợ",
|
251 |
+
"kiểm tra",
|
252 |
+
"xác nhận",
|
253 |
+
"bảo mật",
|
254 |
+
]
|
255 |
+
|
256 |
+
all_chunks = []
|
257 |
+
seen_content_hashes = set()
|
258 |
+
|
259 |
+
for query in comprehensive_queries:
|
260 |
+
try:
|
261 |
+
# Search with large k to get as many chunks as possible
|
262 |
+
docs = self.knowledge_base.search(query, top_k=50)
|
263 |
+
|
264 |
+
for doc in docs:
|
265 |
+
# Create content hash for deduplication
|
266 |
+
content_hash = hash(doc.page_content[:300])
|
267 |
+
|
268 |
+
if (
|
269 |
+
content_hash not in seen_content_hashes
|
270 |
+
and len(doc.page_content.strip()) > 50
|
271 |
+
):
|
272 |
+
chunk_info = {
|
273 |
+
"id": f"chunk_{len(all_chunks)}",
|
274 |
+
"content": doc.page_content,
|
275 |
+
"metadata": doc.metadata,
|
276 |
+
"source": f"search_{query}",
|
277 |
+
"content_length": len(doc.page_content),
|
278 |
+
}
|
279 |
+
all_chunks.append(chunk_info)
|
280 |
+
seen_content_hashes.add(content_hash)
|
281 |
+
|
282 |
+
except Exception as e:
|
283 |
+
print(f"⚠️ Error searching for '{query}': {e}")
|
284 |
+
continue
|
285 |
+
|
286 |
+
print(f"✅ Fallback method retrieved {len(all_chunks)} unique chunks")
|
287 |
+
|
288 |
+
# Sort by content length
|
289 |
+
all_chunks.sort(key=lambda x: x["content_length"], reverse=True)
|
290 |
+
|
291 |
+
return all_chunks
|
292 |
+
|
293 |
+
def generate_questions_for_chunk(
|
294 |
+
self, chunk: Dict, num_questions: int = 2
|
295 |
+
) -> List[Dict]:
|
296 |
+
"""
|
297 |
+
Generate questions for a single chunk using Gemini with JSON response
|
298 |
+
|
299 |
+
Args:
|
300 |
+
chunk: Chunk dictionary with content and metadata
|
301 |
+
num_questions: Number of questions to generate per chunk
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
List of question dictionaries with metadata
|
305 |
+
"""
|
306 |
+
content = chunk["content"]
|
307 |
+
|
308 |
+
prompt = QUESTION_GENERATION_PROMPT.format(
|
309 |
+
num_questions=num_questions, content=content
|
310 |
+
)
|
311 |
+
|
312 |
+
response_json = self.generate_json_response(prompt)
|
313 |
+
|
314 |
+
if response_json and "questions" in response_json:
|
315 |
+
questions = response_json["questions"]
|
316 |
+
|
317 |
+
# Create question objects with metadata
|
318 |
+
question_objects = []
|
319 |
+
for i, question_text in enumerate(questions):
|
320 |
+
if len(question_text.strip()) > 5: # Filter very short questions
|
321 |
+
question_obj = {
|
322 |
+
"id": f"q_{chunk['id']}_{i}",
|
323 |
+
"text": question_text.strip(),
|
324 |
+
"source_chunk": chunk["id"],
|
325 |
+
"chunk_metadata": chunk["metadata"],
|
326 |
+
"generation_method": "gemini_json",
|
327 |
+
}
|
328 |
+
question_objects.append(question_obj)
|
329 |
+
|
330 |
+
return question_objects
|
331 |
+
else:
|
332 |
+
print(f"⚠️ No valid questions generated for chunk {chunk['id']}")
|
333 |
+
return []
|
334 |
+
|
335 |
+
def check_context_quality(self, chunk: Dict) -> bool:
|
336 |
+
"""
|
337 |
+
Check if a chunk is suitable for question generation
|
338 |
+
|
339 |
+
Args:
|
340 |
+
chunk: Chunk dictionary
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
True if chunk should be used, False otherwise
|
344 |
+
"""
|
345 |
+
content = chunk["content"]
|
346 |
+
|
347 |
+
# Basic checks first
|
348 |
+
if len(content.strip()) < 100:
|
349 |
+
return False
|
350 |
+
|
351 |
+
# Use Gemini for quality assessment
|
352 |
+
prompt = CONTEXT_QUALITY_CHECK_PROMPT.format(content=content[:1000])
|
353 |
+
|
354 |
+
response_json = self.generate_json_response(prompt)
|
355 |
+
|
356 |
+
if response_json:
|
357 |
+
return response_json.get("use_context", True)
|
358 |
+
else:
|
359 |
+
# Fallback to basic heuristics
|
360 |
+
return len(content.strip()) > 100 and len(content.split()) > 20
|
361 |
+
|
362 |
+
def create_complete_dataset(
|
363 |
+
self,
|
364 |
+
questions_per_chunk: int = 2,
|
365 |
+
save_path: str = "evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval_dataset.json",
|
366 |
+
quality_check: bool = True,
|
367 |
+
) -> Dict:
|
368 |
+
"""
|
369 |
+
Create complete synthetic evaluation dataset using ALL chunks
|
370 |
+
|
371 |
+
Args:
|
372 |
+
questions_per_chunk: Number of questions to generate per chunk
|
373 |
+
save_path: Path to save the dataset JSON file
|
374 |
+
quality_check: Whether to perform quality checks on chunks
|
375 |
+
|
376 |
+
Returns:
|
377 |
+
Complete dataset dictionary
|
378 |
+
"""
|
379 |
+
print(f"\n🚀 Creating simplified synthetic evaluation dataset...")
|
380 |
+
print(f" Target: Process ALL chunks from knowledge base")
|
381 |
+
print(f" Questions per chunk: {questions_per_chunk}")
|
382 |
+
print(f" Quality check: {quality_check}")
|
383 |
+
print(f" Evaluation method: MRR and Hit Rates only (no qrels)")
|
384 |
+
|
385 |
+
# Step 1: Get all chunks
|
386 |
+
all_chunks = self.get_all_chunks()
|
387 |
+
total_chunks = len(all_chunks)
|
388 |
+
|
389 |
+
if total_chunks == 0:
|
390 |
+
raise ValueError("No chunks found in knowledge base!")
|
391 |
+
|
392 |
+
print(f"✅ Found {total_chunks} chunks to process")
|
393 |
+
|
394 |
+
# Step 2: Quality filtering (optional)
|
395 |
+
if quality_check:
|
396 |
+
print(f"\n🔍 Performing quality checks on chunks...")
|
397 |
+
quality_chunks = []
|
398 |
+
|
399 |
+
for chunk in tqdm(all_chunks, desc="Quality checking"):
|
400 |
+
if self.check_context_quality(chunk):
|
401 |
+
quality_chunks.append(chunk)
|
402 |
+
time.sleep(0.1) # Rate limiting
|
403 |
+
|
404 |
+
print(
|
405 |
+
f"✅ {len(quality_chunks)}/{total_chunks} chunks passed quality check"
|
406 |
+
)
|
407 |
+
chunks_to_process = quality_chunks
|
408 |
+
else:
|
409 |
+
chunks_to_process = all_chunks
|
410 |
+
|
411 |
+
# Step 3: Generate questions for all chunks
|
412 |
+
print(f"\n📝 Generating questions for {len(chunks_to_process)} chunks...")
|
413 |
+
all_questions = []
|
414 |
+
|
415 |
+
for chunk in tqdm(chunks_to_process, desc="Generating questions"):
|
416 |
+
questions = self.generate_questions_for_chunk(chunk, questions_per_chunk)
|
417 |
+
all_questions.extend(questions)
|
418 |
+
time.sleep(0.2) # Rate limiting for Gemini API
|
419 |
+
|
420 |
+
print(
|
421 |
+
f"✅ Generated {len(all_questions)} questions from {len(chunks_to_process)} chunks"
|
422 |
+
)
|
423 |
+
|
424 |
+
# Step 4: Populate dataset structure
|
425 |
+
self.dataset["documents"] = {
|
426 |
+
chunk["id"]: chunk["content"] for chunk in chunks_to_process
|
427 |
+
}
|
428 |
+
self.dataset["queries"] = {q["id"]: q["text"] for q in all_questions}
|
429 |
+
|
430 |
+
# Add question metadata
|
431 |
+
question_metadata = {
|
432 |
+
q["id"]: {
|
433 |
+
"source_chunk": q["source_chunk"],
|
434 |
+
"chunk_metadata": q["chunk_metadata"],
|
435 |
+
"generation_method": q["generation_method"],
|
436 |
+
}
|
437 |
+
for q in all_questions
|
438 |
+
}
|
439 |
+
|
440 |
+
self.dataset["question_metadata"] = question_metadata
|
441 |
+
|
442 |
+
# Step 5: Update metadata
|
443 |
+
self.dataset["metadata"].update(
|
444 |
+
{
|
445 |
+
"total_chunks_processed": len(chunks_to_process),
|
446 |
+
"total_chunks_available": total_chunks,
|
447 |
+
"questions_generated": len(all_questions),
|
448 |
+
"questions_per_chunk": questions_per_chunk,
|
449 |
+
"quality_check_enabled": quality_check,
|
450 |
+
"evaluation_method": "mrr_hit_rates_only",
|
451 |
+
"completion_timestamp": time.time(),
|
452 |
+
}
|
453 |
+
)
|
454 |
+
|
455 |
+
# Step 6: Save dataset
|
456 |
+
os.makedirs(
|
457 |
+
os.path.dirname(save_path) if os.path.dirname(save_path) else ".",
|
458 |
+
exist_ok=True,
|
459 |
+
)
|
460 |
+
|
461 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
462 |
+
json.dump(self.dataset, f, ensure_ascii=False, indent=2)
|
463 |
+
|
464 |
+
print(f"\n✅ COMPLETE dataset created successfully!")
|
465 |
+
print(f" 📁 Saved to: {save_path}")
|
466 |
+
print(f" 📊 Statistics:")
|
467 |
+
print(f" • Chunks processed: {len(chunks_to_process)}/{total_chunks}")
|
468 |
+
print(f" • Questions generated: {len(all_questions)}")
|
469 |
+
print(f" • Evaluation method: MRR and Hit Rates only")
|
470 |
+
print(
|
471 |
+
f" • Coverage: {len(chunks_to_process)/total_chunks*100:.1f}% of knowledge base"
|
472 |
+
)
|
473 |
+
|
474 |
+
return self.dataset
|
475 |
+
|
476 |
+
def load_dataset(self, dataset_path: str) -> Dict:
|
477 |
+
"""Load dataset from JSON file with metadata"""
|
478 |
+
with open(dataset_path, "r", encoding="utf-8") as f:
|
479 |
+
self.dataset = json.load(f)
|
480 |
+
|
481 |
+
metadata = self.dataset.get("metadata", {})
|
482 |
+
|
483 |
+
print(f"📖 Loaded dataset from {dataset_path}")
|
484 |
+
print(f" 📊 Dataset Statistics:")
|
485 |
+
print(f" • Queries: {len(self.dataset['queries'])}")
|
486 |
+
print(f" • Documents: {len(self.dataset['documents'])}")
|
487 |
+
print(f" • Created: {time.ctime(metadata.get('creation_timestamp', 0))}")
|
488 |
+
|
489 |
+
return self.dataset
|
490 |
+
|
491 |
+
|
492 |
+
class SingleTurnRetrievalEvaluator:
|
493 |
+
"""Simplified retrieval evaluator with only MRR and hit rates"""
|
494 |
+
|
495 |
+
def __init__(self, dataset: Dict, knowledge_base: ViettelKnowledgeBase):
|
496 |
+
"""
|
497 |
+
Initialize evaluator with dataset and knowledge base
|
498 |
+
|
499 |
+
Args:
|
500 |
+
dataset: Evaluation dataset with queries and documents
|
501 |
+
knowledge_base: ViettelKnowledgeBase instance to evaluate
|
502 |
+
"""
|
503 |
+
self.dataset = dataset
|
504 |
+
self.knowledge_base = knowledge_base
|
505 |
+
self.results = {}
|
506 |
+
|
507 |
+
def _match_retrieved_documents(self, retrieved_docs) -> List[str]:
|
508 |
+
"""
|
509 |
+
Enhanced document matching with multiple strategies
|
510 |
+
|
511 |
+
Args:
|
512 |
+
retrieved_docs: Retrieved Document objects from knowledge base
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
List of matched document IDs
|
516 |
+
"""
|
517 |
+
matched_ids = []
|
518 |
+
|
519 |
+
for doc in retrieved_docs:
|
520 |
+
# Strategy 1: Try to find exact content match
|
521 |
+
doc_id = self._find_exact_content_match(doc.page_content)
|
522 |
+
|
523 |
+
if not doc_id:
|
524 |
+
# Strategy 2: Try fuzzy content matching
|
525 |
+
doc_id = self._find_fuzzy_content_match(doc.page_content)
|
526 |
+
|
527 |
+
if doc_id:
|
528 |
+
matched_ids.append(doc_id)
|
529 |
+
|
530 |
+
return matched_ids
|
531 |
+
|
532 |
+
def _find_exact_content_match(self, retrieved_content: str) -> Optional[str]:
|
533 |
+
"""Find exact content match"""
|
534 |
+
for doc_id, doc_content in self.dataset["documents"].items():
|
535 |
+
if retrieved_content.strip() == doc_content.strip():
|
536 |
+
return doc_id
|
537 |
+
return None
|
538 |
+
|
539 |
+
def _find_fuzzy_content_match(
|
540 |
+
self, retrieved_content: str, min_overlap: int = 50
|
541 |
+
) -> Optional[str]:
|
542 |
+
"""Find fuzzy content match with word overlap"""
|
543 |
+
best_match_id = None
|
544 |
+
best_overlap = 0
|
545 |
+
|
546 |
+
retrieved_words = set(retrieved_content.lower().split())
|
547 |
+
|
548 |
+
for doc_id, doc_content in self.dataset["documents"].items():
|
549 |
+
doc_words = set(doc_content.lower().split())
|
550 |
+
overlap = len(retrieved_words & doc_words)
|
551 |
+
|
552 |
+
if overlap > best_overlap and overlap >= min_overlap:
|
553 |
+
best_overlap = overlap
|
554 |
+
best_match_id = doc_id
|
555 |
+
|
556 |
+
return best_match_id
|
557 |
+
|
558 |
+
def _safe_average(self, values: List[float]) -> float:
|
559 |
+
"""Calculate average safely handling empty lists"""
|
560 |
+
return sum(values) / len(values) if values else 0.0
|
561 |
+
|
562 |
+
def evaluate(self, k_values: List[int] = [1, 3, 5, 10]) -> Dict:
|
563 |
+
"""
|
564 |
+
Simplified evaluation with only MRR and hit rates
|
565 |
+
|
566 |
+
This method checks if the source document (where the question was generated from)
|
567 |
+
is retrieved among the top-k results.
|
568 |
+
|
569 |
+
Args:
|
570 |
+
k_values: List of k values to evaluate
|
571 |
+
|
572 |
+
Returns:
|
573 |
+
Dictionary with MRR and hit rate results
|
574 |
+
"""
|
575 |
+
print(f"\n🔍 Running simplified evaluation (MRR and Hit Rates only)...")
|
576 |
+
print(f" 📊 K values: {k_values}")
|
577 |
+
print(f" 📚 Total queries: {len(self.dataset['queries'])}")
|
578 |
+
|
579 |
+
# Initialize results
|
580 |
+
hit_rates = {k: [] for k in k_values}
|
581 |
+
rr_scores = [] # Reciprocal Rank scores for MRR calculation
|
582 |
+
query_results = {}
|
583 |
+
failed_queries = []
|
584 |
+
|
585 |
+
# Process each query
|
586 |
+
for query_id, query_text in tqdm(
|
587 |
+
self.dataset["queries"].items(), desc="Evaluating queries"
|
588 |
+
):
|
589 |
+
try:
|
590 |
+
# Get source document from metadata - handle both single-turn and multi-turn formats
|
591 |
+
source_chunk_id = None
|
592 |
+
|
593 |
+
# Try question_metadata first (single-turn format)
|
594 |
+
question_meta = self.dataset.get("question_metadata", {}).get(
|
595 |
+
query_id, {}
|
596 |
+
)
|
597 |
+
if question_meta:
|
598 |
+
source_chunk_id = question_meta.get("source_chunk")
|
599 |
+
|
600 |
+
# If not found, try conversation_metadata (multi-turn format)
|
601 |
+
if not source_chunk_id:
|
602 |
+
conversation_meta = self.dataset.get(
|
603 |
+
"conversation_metadata", {}
|
604 |
+
).get(query_id, {})
|
605 |
+
if conversation_meta:
|
606 |
+
source_chunk_id = conversation_meta.get("source_chunk")
|
607 |
+
|
608 |
+
if not source_chunk_id:
|
609 |
+
print(f"⚠️ No source chunk info for query {query_id}")
|
610 |
+
continue
|
611 |
+
|
612 |
+
# Get retrieval results
|
613 |
+
retrieved_docs = self.knowledge_base.search(
|
614 |
+
query_text, top_k=max(k_values)
|
615 |
+
)
|
616 |
+
retrieved_doc_ids = self._match_retrieved_documents(retrieved_docs)
|
617 |
+
|
618 |
+
# Check if source document is in top-k for each k
|
619 |
+
query_results[query_id] = {
|
620 |
+
"query": query_text,
|
621 |
+
"source_chunk": source_chunk_id,
|
622 |
+
"retrieved": retrieved_doc_ids,
|
623 |
+
"hit_rates": {},
|
624 |
+
}
|
625 |
+
|
626 |
+
# Calculate Reciprocal Rank (MRR) - once per query
|
627 |
+
if source_chunk_id in retrieved_doc_ids:
|
628 |
+
source_rank = (
|
629 |
+
retrieved_doc_ids.index(source_chunk_id) + 1
|
630 |
+
) # 1-indexed rank
|
631 |
+
rr_score = 1.0 / source_rank
|
632 |
+
else:
|
633 |
+
rr_score = 0.0
|
634 |
+
|
635 |
+
query_results[query_id]["rr"] = rr_score
|
636 |
+
query_results[query_id]["source_rank"] = (
|
637 |
+
source_rank if rr_score > 0 else None
|
638 |
+
)
|
639 |
+
rr_scores.append(rr_score)
|
640 |
+
|
641 |
+
for k in k_values:
|
642 |
+
top_k_docs = retrieved_doc_ids[:k]
|
643 |
+
hit = 1 if source_chunk_id in top_k_docs else 0
|
644 |
+
hit_rates[k].append(hit)
|
645 |
+
query_results[query_id]["hit_rates"][k] = hit
|
646 |
+
|
647 |
+
except Exception as e:
|
648 |
+
print(f"❌ Error evaluating query {query_id}: {e}")
|
649 |
+
failed_queries.append((query_id, str(e)))
|
650 |
+
continue
|
651 |
+
|
652 |
+
# Calculate average metrics
|
653 |
+
avg_hit_rates = {}
|
654 |
+
avg_rr = sum(rr_scores) / len(rr_scores) if rr_scores else 0.0
|
655 |
+
|
656 |
+
for k in k_values:
|
657 |
+
avg_hit_rates[k] = self._safe_average(hit_rates[k])
|
658 |
+
|
659 |
+
results = {
|
660 |
+
"hit_rates": avg_hit_rates,
|
661 |
+
"mrr": avg_rr,
|
662 |
+
"per_query_results": query_results,
|
663 |
+
"failed_queries": failed_queries,
|
664 |
+
"summary": {
|
665 |
+
"total_queries": len(self.dataset["queries"]),
|
666 |
+
"evaluated_queries": len(query_results),
|
667 |
+
"failed_queries": len(failed_queries),
|
668 |
+
"success_rate": len(query_results) / len(self.dataset["queries"]) * 100,
|
669 |
+
"k_values": k_values,
|
670 |
+
"evaluation_type": "mrr_hit_rates_only",
|
671 |
+
"evaluation_timestamp": time.time(),
|
672 |
+
},
|
673 |
+
}
|
674 |
+
|
675 |
+
return results
|
676 |
+
|
677 |
+
def print_evaluation_results(self, results: Dict):
|
678 |
+
"""Print simplified evaluation results"""
|
679 |
+
print(f"\n�� SIMPLIFIED EVALUATION RESULTS (MRR + Hit Rates)")
|
680 |
+
print("=" * 60)
|
681 |
+
|
682 |
+
print(f"\n📈 Hit Rates (Source Document Found in Top-K):")
|
683 |
+
print(f"{'K':<5} {'Hit Rate':<12} {'Percentage':<12}")
|
684 |
+
print("-" * 30)
|
685 |
+
|
686 |
+
for k in sorted(results["hit_rates"].keys()):
|
687 |
+
hit_rate = results["hit_rates"][k]
|
688 |
+
percentage = hit_rate * 100
|
689 |
+
print(f"{k:<5} {hit_rate:<12.4f} {percentage:<12.1f}%")
|
690 |
+
|
691 |
+
# Display MRR separately since it's not k-dependent
|
692 |
+
mrr = results["mrr"]
|
693 |
+
print(f"\n📊 Mean Reciprocal Rank (MRR): {mrr:.4f}")
|
694 |
+
print(f" • MRR measures the average reciprocal rank of the source document")
|
695 |
+
print(f" • Higher is better (max = 1.0 if all sources are rank 1)")
|
696 |
+
|
697 |
+
print(f"\n📊 Hit Rate Summary:")
|
698 |
+
for k in sorted(results["hit_rates"].keys()):
|
699 |
+
hit_rate = results["hit_rates"][k]
|
700 |
+
percentage = hit_rate * 100
|
701 |
+
print(
|
702 |
+
f" • Top-{k}: {percentage:.1f}% of questions find their source document"
|
703 |
+
)
|
704 |
+
|
705 |
+
# Summary stats
|
706 |
+
summary = results["summary"]
|
707 |
+
print(f"\n📋 Evaluation Summary:")
|
708 |
+
print(f" • Total queries: {summary['total_queries']}")
|
709 |
+
print(f" • Successfully evaluated: {summary['evaluated_queries']}")
|
710 |
+
print(f" • Failed queries: {summary['failed_queries']}")
|
711 |
+
print(f" • Success rate: {summary['success_rate']:.1f}%")
|
712 |
+
print(f" • Evaluation type: {summary['evaluation_type']}")
|
713 |
+
|
714 |
+
# Simple interpretation
|
715 |
+
avg_hit_rate_5 = results["hit_rates"].get(5, 0)
|
716 |
+
mrr = results["mrr"]
|
717 |
+
print(f"\n🎯 Quick Interpretation:")
|
718 |
+
if avg_hit_rate_5 > 0.8:
|
719 |
+
print(
|
720 |
+
f" ✅ Excellent: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}"
|
721 |
+
)
|
722 |
+
elif avg_hit_rate_5 > 0.6:
|
723 |
+
print(f" 👍 Good: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}")
|
724 |
+
elif avg_hit_rate_5 > 0.4:
|
725 |
+
print(f" ⚠️ Fair: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}")
|
726 |
+
else:
|
727 |
+
print(f" ❌ Poor: {avg_hit_rate_5*100:.1f}% hit rate@5, MRR = {mrr:.3f}")
|
728 |
+
|
729 |
+
|
730 |
+
def main():
|
731 |
+
"""Main function with argument parsing for separate operations"""
|
732 |
+
parser = argparse.ArgumentParser(
|
733 |
+
description="ViettelPay Retrieval Evaluation Dataset Creator (Simplified)"
|
734 |
+
)
|
735 |
+
parser.add_argument(
|
736 |
+
"--mode",
|
737 |
+
choices=["create", "evaluate", "both"],
|
738 |
+
default="both",
|
739 |
+
help="Mode: create dataset, evaluate only, or both",
|
740 |
+
)
|
741 |
+
parser.add_argument(
|
742 |
+
"--dataset-path",
|
743 |
+
default="evaluation_data/datasets/single_turn_retrieval/viettelpay_complete_eval.json",
|
744 |
+
help="Path to dataset file",
|
745 |
+
)
|
746 |
+
parser.add_argument(
|
747 |
+
"--results-path",
|
748 |
+
default="evaluation_data/results/single_turn_retrieval/viettelpay_eval_results.json",
|
749 |
+
help="Path to save evaluation results",
|
750 |
+
)
|
751 |
+
parser.add_argument(
|
752 |
+
"--questions-per-chunk",
|
753 |
+
type=int,
|
754 |
+
default=3,
|
755 |
+
help="Number of questions per chunk",
|
756 |
+
)
|
757 |
+
parser.add_argument(
|
758 |
+
"--k-values",
|
759 |
+
nargs="+",
|
760 |
+
type=int,
|
761 |
+
default=[1, 3, 5, 10],
|
762 |
+
help="K values for evaluation",
|
763 |
+
)
|
764 |
+
parser.add_argument(
|
765 |
+
"--quality-check",
|
766 |
+
action="store_true",
|
767 |
+
help="Enable quality checking for chunks",
|
768 |
+
)
|
769 |
+
parser.add_argument(
|
770 |
+
"--knowledge-base-path",
|
771 |
+
default="./knowledge_base",
|
772 |
+
help="Path to knowledge base",
|
773 |
+
)
|
774 |
+
|
775 |
+
args = parser.parse_args()
|
776 |
+
|
777 |
+
# Configuration
|
778 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
779 |
+
|
780 |
+
if not GEMINI_API_KEY:
|
781 |
+
print("❌ Please set GEMINI_API_KEY environment variable")
|
782 |
+
return
|
783 |
+
|
784 |
+
try:
|
785 |
+
# Initialize knowledge base
|
786 |
+
print("🔧 Initializing ViettelPay knowledge base...")
|
787 |
+
kb = ViettelKnowledgeBase()
|
788 |
+
if not kb.load_knowledge_base(args.knowledge_base_path):
|
789 |
+
print(
|
790 |
+
"❌ Failed to load knowledge base. Please run build_database_script.py first."
|
791 |
+
)
|
792 |
+
return
|
793 |
+
|
794 |
+
# Create dataset if requested
|
795 |
+
if args.mode in ["create", "both"]:
|
796 |
+
print(f"\n🎯 Creating synthetic evaluation dataset...")
|
797 |
+
creator = SingleTurnDatasetCreator(GEMINI_API_KEY, kb)
|
798 |
+
|
799 |
+
dataset = creator.create_complete_dataset(
|
800 |
+
questions_per_chunk=args.questions_per_chunk,
|
801 |
+
save_path=args.dataset_path,
|
802 |
+
quality_check=args.quality_check,
|
803 |
+
)
|
804 |
+
|
805 |
+
# Evaluate if requested
|
806 |
+
if args.mode in ["evaluate", "both"]:
|
807 |
+
print(f"\n⚡ Evaluating retrieval performance...")
|
808 |
+
|
809 |
+
# Load dataset if not created in this run
|
810 |
+
if args.mode == "evaluate":
|
811 |
+
if not os.path.exists(args.dataset_path):
|
812 |
+
print(f"❌ Dataset file not found: {args.dataset_path}")
|
813 |
+
return
|
814 |
+
|
815 |
+
creator = SingleTurnDatasetCreator(GEMINI_API_KEY, kb)
|
816 |
+
dataset = creator.load_dataset(args.dataset_path)
|
817 |
+
|
818 |
+
# Run evaluation
|
819 |
+
evaluator = SingleTurnRetrievalEvaluator(dataset, kb)
|
820 |
+
results = evaluator.evaluate(k_values=args.k_values)
|
821 |
+
evaluator.print_evaluation_results(results)
|
822 |
+
|
823 |
+
# Save results
|
824 |
+
if args.results_path:
|
825 |
+
with open(args.results_path, "w", encoding="utf-8") as f:
|
826 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
827 |
+
print(f"\n💾 Results saved to: {args.results_path}")
|
828 |
+
|
829 |
+
print(f"\n✅ Operation completed successfully!")
|
830 |
+
print(f"\n💡 Next steps:")
|
831 |
+
print(f" 1. Review the MRR and hit rate results")
|
832 |
+
print(f" 2. Identify queries with low performance")
|
833 |
+
print(f" 3. Optimize your retrieval system")
|
834 |
+
print(f" 4. Re-run evaluation to measure progress")
|
835 |
+
|
836 |
+
except Exception as e:
|
837 |
+
print(f"❌ Error in main execution: {e}")
|
838 |
+
import traceback
|
839 |
+
|
840 |
+
traceback.print_exc()
|
841 |
+
|
842 |
+
|
843 |
+
if __name__ == "__main__":
|
844 |
+
main()
|
src/knowledge_base/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
src/knowledge_base/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (18.7 kB). View file
|
|
src/knowledge_base/__pycache__/viettel_knowledge_base.cpython-311.pyc
ADDED
Binary file (24 kB). View file
|
|
src/knowledge_base/viettel_knowledge_base.py
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ViettelPay Knowledge Base with Contextual Retrieval
|
3 |
+
|
4 |
+
This updated version:
|
5 |
+
- Uses ContextualWordProcessor for all document processing
|
6 |
+
- Integrates OpenAI for contextual enhancement
|
7 |
+
- Processes all doc/docx files from a parent folder
|
8 |
+
- Removes CSV processor dependency
|
9 |
+
"""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import pickle
|
13 |
+
|
14 |
+
# import torch
|
15 |
+
from typing import List, Optional
|
16 |
+
from pathlib import Path
|
17 |
+
from openai import OpenAI
|
18 |
+
|
19 |
+
from langchain.schema import Document
|
20 |
+
from langchain.retrievers import EnsembleRetriever
|
21 |
+
from langchain_community.retrievers import BM25Retriever
|
22 |
+
from langchain_core.runnables import ConfigurableField
|
23 |
+
from langchain_cohere.rerank import CohereRerank
|
24 |
+
|
25 |
+
# Use newest import paths for langchain
|
26 |
+
try:
|
27 |
+
from langchain_chroma import Chroma
|
28 |
+
except ImportError:
|
29 |
+
from langchain_community.vectorstores import Chroma
|
30 |
+
|
31 |
+
# Use the new HuggingFaceEmbeddings from langchain-huggingface
|
32 |
+
try:
|
33 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
34 |
+
except ImportError:
|
35 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
36 |
+
|
37 |
+
from src.processor.contextual_word_processor import ContextualWordProcessor
|
38 |
+
from src.processor.text_utils import VietnameseTextProcessor
|
39 |
+
|
40 |
+
# Import configuration utility
|
41 |
+
from src.utils.config import get_cohere_api_key, get_openai_api_key, get_embedding_model
|
42 |
+
|
43 |
+
|
44 |
+
class ViettelKnowledgeBase:
|
45 |
+
"""ViettelPay knowledge base with contextual retrieval enhancement"""
|
46 |
+
|
47 |
+
def __init__(self, embedding_model: str = None):
|
48 |
+
"""
|
49 |
+
Initialize the knowledge base
|
50 |
+
|
51 |
+
Args:
|
52 |
+
embedding_model: Vietnamese embedding model to use
|
53 |
+
"""
|
54 |
+
embedding_model = embedding_model or get_embedding_model()
|
55 |
+
|
56 |
+
# Initialize Vietnamese text processor
|
57 |
+
self.text_processor = VietnameseTextProcessor()
|
58 |
+
|
59 |
+
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
60 |
+
self.device = "cpu"
|
61 |
+
print(f"[INFO] Using device: {self.device}")
|
62 |
+
|
63 |
+
# Initialize embeddings with GPU support and trust_remote_code
|
64 |
+
model_kwargs = {"device": self.device, "trust_remote_code": True}
|
65 |
+
|
66 |
+
self.embeddings = HuggingFaceEmbeddings(
|
67 |
+
model_name=embedding_model, model_kwargs=model_kwargs
|
68 |
+
)
|
69 |
+
|
70 |
+
# Initialize retrievers as None
|
71 |
+
self.chroma_retriever = None
|
72 |
+
self.bm25_retriever = None
|
73 |
+
self.ensemble_retriever = None
|
74 |
+
|
75 |
+
self.reranker = CohereRerank(
|
76 |
+
model="rerank-v3.5",
|
77 |
+
cohere_api_key=get_cohere_api_key(),
|
78 |
+
)
|
79 |
+
|
80 |
+
def build_knowledge_base(
|
81 |
+
self,
|
82 |
+
documents_folder: str,
|
83 |
+
persist_dir: str = "./knowledge_base",
|
84 |
+
reset: bool = True,
|
85 |
+
openai_api_key: Optional[str] = None,
|
86 |
+
) -> None:
|
87 |
+
"""
|
88 |
+
Build knowledge base from all Word documents in a folder
|
89 |
+
|
90 |
+
Args:
|
91 |
+
documents_folder: Path to folder containing doc/docx files
|
92 |
+
persist_dir: Directory to persist the knowledge base
|
93 |
+
reset: Whether to reset existing knowledge base
|
94 |
+
openai_api_key: OpenAI API key for contextual enhancement (optional)
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
None. Use the search() method to perform searches.
|
98 |
+
"""
|
99 |
+
|
100 |
+
print(
|
101 |
+
"[INFO] Building ViettelPay knowledge base with contextual enhancement..."
|
102 |
+
)
|
103 |
+
|
104 |
+
# Initialize OpenAI client for contextual enhancement if API key provided
|
105 |
+
openai_client = None
|
106 |
+
if openai_api_key:
|
107 |
+
openai_client = OpenAI(api_key=openai_api_key)
|
108 |
+
print(f"[INFO] OpenAI client initialized for contextual enhancement")
|
109 |
+
elif get_openai_api_key():
|
110 |
+
api_key = get_openai_api_key()
|
111 |
+
openai_client = OpenAI(api_key=api_key)
|
112 |
+
print(f"[INFO] OpenAI client initialized from configuration")
|
113 |
+
else:
|
114 |
+
print(
|
115 |
+
f"[WARNING] No OpenAI API key provided. Contextual enhancement disabled."
|
116 |
+
)
|
117 |
+
|
118 |
+
# Initialize the contextual word processor with OpenAI client
|
119 |
+
word_processor = ContextualWordProcessor(llm_client=openai_client)
|
120 |
+
|
121 |
+
# Find all Word documents in the folder
|
122 |
+
word_files = self._find_word_documents(documents_folder)
|
123 |
+
|
124 |
+
if not word_files:
|
125 |
+
raise ValueError(f"No Word documents found in {documents_folder}")
|
126 |
+
|
127 |
+
print(f"[INFO] Found {len(word_files)} Word documents to process")
|
128 |
+
|
129 |
+
# Process all documents
|
130 |
+
all_documents = self._process_all_word_files(word_files, word_processor)
|
131 |
+
print(f"[INFO] Total documents processed: {len(all_documents)}")
|
132 |
+
|
133 |
+
# Create directories
|
134 |
+
os.makedirs(persist_dir, exist_ok=True)
|
135 |
+
chroma_dir = os.path.join(persist_dir, "chroma")
|
136 |
+
bm25_path = os.path.join(persist_dir, "bm25_retriever.pkl")
|
137 |
+
|
138 |
+
# Build ChromaDB retriever (uses contextualized content)
|
139 |
+
print("[INFO] Building ChromaDB retriever with contextualized content...")
|
140 |
+
self.chroma_retriever = self._build_chroma_retriever(
|
141 |
+
all_documents, chroma_dir, reset
|
142 |
+
)
|
143 |
+
|
144 |
+
# Build BM25 retriever (uses contextualized content with Vietnamese tokenization)
|
145 |
+
print("[INFO] Building BM25 retriever with Vietnamese tokenization...")
|
146 |
+
self.bm25_retriever = self._build_bm25_retriever(
|
147 |
+
all_documents, bm25_path, reset
|
148 |
+
)
|
149 |
+
|
150 |
+
# Create ensemble retriever with configurable top-k
|
151 |
+
print("[INFO] Creating ensemble retriever...")
|
152 |
+
self.ensemble_retriever = self._build_retriever(
|
153 |
+
self.bm25_retriever, self.chroma_retriever
|
154 |
+
)
|
155 |
+
|
156 |
+
print("[SUCCESS] Contextual knowledge base built successfully!")
|
157 |
+
print("[INFO] Use kb.search(query, top_k) to perform searches.")
|
158 |
+
|
159 |
+
def load_knowledge_base(self, persist_dir: str = "./knowledge_base") -> bool:
|
160 |
+
"""
|
161 |
+
Load existing knowledge base from disk and rebuild BM25 from ChromaDB documents
|
162 |
+
|
163 |
+
Args:
|
164 |
+
persist_dir: Directory where the knowledge base is stored
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
bool: True if loaded successfully, False otherwise
|
168 |
+
"""
|
169 |
+
|
170 |
+
print("[INFO] Loading knowledge base from disk...")
|
171 |
+
|
172 |
+
chroma_dir = os.path.join(persist_dir, "chroma")
|
173 |
+
|
174 |
+
try:
|
175 |
+
# Load ChromaDB
|
176 |
+
if os.path.exists(chroma_dir):
|
177 |
+
vectorstore = Chroma(
|
178 |
+
persist_directory=chroma_dir, embedding_function=self.embeddings
|
179 |
+
)
|
180 |
+
|
181 |
+
self.chroma_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
182 |
+
print("[SUCCESS] ChromaDB loaded")
|
183 |
+
else:
|
184 |
+
print("[ERROR] ChromaDB not found")
|
185 |
+
return False
|
186 |
+
|
187 |
+
# Extract all documents from ChromaDB to rebuild BM25
|
188 |
+
print("[INFO] Extracting documents from ChromaDB to rebuild BM25...")
|
189 |
+
try:
|
190 |
+
# Get all documents and metadata from ChromaDB
|
191 |
+
all_docs = vectorstore.get(include=["documents", "metadatas"])
|
192 |
+
|
193 |
+
documents = all_docs["documents"]
|
194 |
+
metadatas = all_docs["metadatas"]
|
195 |
+
|
196 |
+
# Reconstruct Document objects
|
197 |
+
doc_objects = []
|
198 |
+
for i, (doc_content, metadata) in enumerate(zip(documents, metadatas)):
|
199 |
+
# Handle case where metadata might be None
|
200 |
+
if metadata is None:
|
201 |
+
metadata = {}
|
202 |
+
|
203 |
+
doc_obj = Document(page_content=doc_content, metadata=metadata)
|
204 |
+
doc_objects.append(doc_obj)
|
205 |
+
|
206 |
+
print(f"[INFO] Extracted {len(doc_objects)} documents from ChromaDB")
|
207 |
+
|
208 |
+
# Rebuild BM25 retriever using existing method
|
209 |
+
self.bm25_retriever = self._build_bm25_retriever(
|
210 |
+
documents=doc_objects,
|
211 |
+
bm25_path=None, # Not used anymore
|
212 |
+
reset=False, # Not relevant for rebuilding
|
213 |
+
)
|
214 |
+
|
215 |
+
except Exception as e:
|
216 |
+
print(f"[ERROR] Error rebuilding BM25 from ChromaDB: {e}")
|
217 |
+
return False
|
218 |
+
|
219 |
+
# Create ensemble retriever with configurable top-k
|
220 |
+
self.ensemble_retriever = self._build_retriever(
|
221 |
+
self.bm25_retriever, self.chroma_retriever
|
222 |
+
)
|
223 |
+
|
224 |
+
print("[SUCCESS] Knowledge base loaded successfully!")
|
225 |
+
print("[INFO] Use kb.search(query, top_k) to perform searches.")
|
226 |
+
return True
|
227 |
+
|
228 |
+
except Exception as e:
|
229 |
+
print(f"[ERROR] Error loading knowledge base: {e}")
|
230 |
+
return False
|
231 |
+
|
232 |
+
def search(self, query: str, top_k: int = 10) -> List[Document]:
|
233 |
+
"""
|
234 |
+
Main search method using ensemble retriever with configurable top-k
|
235 |
+
|
236 |
+
Args:
|
237 |
+
query: Search query
|
238 |
+
top_k: Number of documents to return from each retriever (default: 5)
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
List of retrieved documents
|
242 |
+
"""
|
243 |
+
if not self.ensemble_retriever:
|
244 |
+
raise ValueError(
|
245 |
+
"Knowledge base not loaded. Call build_knowledge_base() or load_knowledge_base() first."
|
246 |
+
)
|
247 |
+
|
248 |
+
# Build config based on top_k parameter
|
249 |
+
config = {
|
250 |
+
"configurable": {
|
251 |
+
"bm25_k": top_k * 5,
|
252 |
+
"chroma_search_kwargs": {"k": top_k * 5},
|
253 |
+
}
|
254 |
+
}
|
255 |
+
|
256 |
+
results = self.ensemble_retriever.invoke(query, config=config)
|
257 |
+
reranked_results = self.reranker.rerank(results, query, top_n=top_k)
|
258 |
+
|
259 |
+
final_results = []
|
260 |
+
for rerank_item in reranked_results:
|
261 |
+
# Get the original document using the index
|
262 |
+
original_doc = results[rerank_item["index"]]
|
263 |
+
|
264 |
+
# Create a new document with the relevance score added to metadata
|
265 |
+
reranked_doc = Document(
|
266 |
+
page_content=original_doc.page_content,
|
267 |
+
metadata={
|
268 |
+
**original_doc.metadata,
|
269 |
+
"relevance_score": rerank_item["relevance_score"],
|
270 |
+
},
|
271 |
+
)
|
272 |
+
final_results.append(reranked_doc)
|
273 |
+
|
274 |
+
return final_results
|
275 |
+
|
276 |
+
def get_stats(self) -> dict:
|
277 |
+
"""Get statistics about the knowledge base"""
|
278 |
+
stats = {}
|
279 |
+
|
280 |
+
if self.chroma_retriever:
|
281 |
+
try:
|
282 |
+
vectorstore = self.chroma_retriever.vectorstore
|
283 |
+
collection = vectorstore._collection
|
284 |
+
stats["chroma_documents"] = collection.count()
|
285 |
+
except:
|
286 |
+
stats["chroma_documents"] = "Unknown"
|
287 |
+
|
288 |
+
if self.bm25_retriever:
|
289 |
+
try:
|
290 |
+
stats["bm25_documents"] = len(self.bm25_retriever.docs)
|
291 |
+
except:
|
292 |
+
stats["bm25_documents"] = "Unknown"
|
293 |
+
|
294 |
+
stats["ensemble_available"] = self.ensemble_retriever is not None
|
295 |
+
stats["device"] = self.device
|
296 |
+
stats["vietnamese_tokenizer"] = "Vietnamese BM25 tokenizer (underthesea)"
|
297 |
+
|
298 |
+
return stats
|
299 |
+
|
300 |
+
def _find_word_documents(self, folder_path: str) -> List[str]:
|
301 |
+
"""
|
302 |
+
Find all Word documents (.doc, .docx) in the given folder
|
303 |
+
|
304 |
+
Args:
|
305 |
+
folder_path: Path to the folder to search
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
List of full paths to Word documents
|
309 |
+
"""
|
310 |
+
word_files = []
|
311 |
+
folder = Path(folder_path)
|
312 |
+
|
313 |
+
if not folder.exists():
|
314 |
+
raise FileNotFoundError(f"Folder not found: {folder_path}")
|
315 |
+
|
316 |
+
# Search for Word documents
|
317 |
+
for pattern in ["*.doc", "*.docx"]:
|
318 |
+
word_files.extend(folder.glob(pattern))
|
319 |
+
|
320 |
+
# Convert to string paths and sort for consistent processing order
|
321 |
+
word_files = [str(f) for f in word_files]
|
322 |
+
word_files.sort()
|
323 |
+
|
324 |
+
print(f"[INFO] Found Word documents: {[Path(f).name for f in word_files]}")
|
325 |
+
return word_files
|
326 |
+
|
327 |
+
def _process_all_word_files(
|
328 |
+
self, word_files: List[str], word_processor: ContextualWordProcessor
|
329 |
+
) -> List[Document]:
|
330 |
+
"""Process all Word files into unified chunks with contextual enhancement"""
|
331 |
+
all_documents = []
|
332 |
+
|
333 |
+
for file_path in word_files:
|
334 |
+
try:
|
335 |
+
print(f"[INFO] Processing: {Path(file_path).name}")
|
336 |
+
chunks = word_processor.process_word_document(file_path)
|
337 |
+
all_documents.extend(chunks)
|
338 |
+
|
339 |
+
# Print processing stats for this file
|
340 |
+
stats = word_processor.get_document_stats(chunks)
|
341 |
+
print(
|
342 |
+
f"[SUCCESS] Processed {Path(file_path).name}: {len(chunks)} chunks"
|
343 |
+
)
|
344 |
+
print(f" - Contextualized: {stats.get('contextualized_docs', 0)}")
|
345 |
+
print(
|
346 |
+
f" - Non-contextualized: {stats.get('non_contextualized_docs', 0)}"
|
347 |
+
)
|
348 |
+
|
349 |
+
except Exception as e:
|
350 |
+
print(f"[ERROR] Error processing {Path(file_path).name}: {e}")
|
351 |
+
|
352 |
+
return all_documents
|
353 |
+
|
354 |
+
def _build_retriever(self, bm25_retriever, chroma_retriever):
|
355 |
+
"""
|
356 |
+
Build ensemble retriever with configurable top-k parameters
|
357 |
+
|
358 |
+
Args:
|
359 |
+
bm25_retriever: BM25 retriever with configurable fields
|
360 |
+
chroma_retriever: Chroma retriever with configurable fields
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
EnsembleRetriever with configurable retrievers
|
364 |
+
"""
|
365 |
+
return EnsembleRetriever(
|
366 |
+
retrievers=[bm25_retriever, chroma_retriever],
|
367 |
+
weights=[0.2, 0.8], # Slightly favor semantic search
|
368 |
+
)
|
369 |
+
|
370 |
+
def _build_chroma_retriever(
|
371 |
+
self, documents: List[Document], chroma_dir: str, reset: bool
|
372 |
+
):
|
373 |
+
"""Build ChromaDB retriever with configurable search parameters"""
|
374 |
+
|
375 |
+
if reset and os.path.exists(chroma_dir):
|
376 |
+
import shutil
|
377 |
+
|
378 |
+
shutil.rmtree(chroma_dir)
|
379 |
+
print("[INFO] Removed existing ChromaDB for rebuild")
|
380 |
+
|
381 |
+
# Create Chroma vectorstore (uses contextualized content)
|
382 |
+
vectorstore = Chroma.from_documents(
|
383 |
+
documents=documents, embedding=self.embeddings, persist_directory=chroma_dir
|
384 |
+
)
|
385 |
+
|
386 |
+
# Create retriever with configurable search_kwargs
|
387 |
+
retriever = vectorstore.as_retriever(
|
388 |
+
search_kwargs={"k": 5} # default value
|
389 |
+
).configurable_fields(
|
390 |
+
search_kwargs=ConfigurableField(
|
391 |
+
id="chroma_search_kwargs",
|
392 |
+
name="Chroma Search Kwargs",
|
393 |
+
description="Search kwargs for Chroma DB retriever",
|
394 |
+
)
|
395 |
+
)
|
396 |
+
|
397 |
+
print(
|
398 |
+
f"[SUCCESS] ChromaDB created with {len(documents)} contextualized documents"
|
399 |
+
)
|
400 |
+
return retriever
|
401 |
+
|
402 |
+
def _build_bm25_retriever(
|
403 |
+
self, documents: List[Document], bm25_path: str, reset: bool
|
404 |
+
):
|
405 |
+
"""Build BM25 retriever with Vietnamese tokenization and configurable k parameter"""
|
406 |
+
|
407 |
+
# Note: We no longer save BM25 to pickle file to avoid Streamlit Cloud compatibility issues
|
408 |
+
# BM25 will be rebuilt from ChromaDB documents when loading the knowledge base
|
409 |
+
|
410 |
+
# Create BM25 retriever with Vietnamese tokenizer as preprocess_func
|
411 |
+
print("[INFO] Using Vietnamese tokenizer for BM25 on contextualized content...")
|
412 |
+
retriever = BM25Retriever.from_documents(
|
413 |
+
documents=documents,
|
414 |
+
preprocess_func=self.text_processor.bm25_tokenizer,
|
415 |
+
k=5, # default value
|
416 |
+
).configurable_fields(
|
417 |
+
k=ConfigurableField(
|
418 |
+
id="bm25_k",
|
419 |
+
name="BM25 Top K",
|
420 |
+
description="Number of documents to return from BM25",
|
421 |
+
)
|
422 |
+
)
|
423 |
+
|
424 |
+
print(
|
425 |
+
f"[SUCCESS] BM25 retriever created with {len(documents)} contextualized documents"
|
426 |
+
)
|
427 |
+
return retriever
|
428 |
+
|
429 |
+
|
430 |
+
def test_contextual_kb(kb: ViettelKnowledgeBase, test_queries: List[str]):
|
431 |
+
"""Test function for the contextual knowledge base"""
|
432 |
+
|
433 |
+
print("\n[INFO] Testing Contextual Knowledge Base")
|
434 |
+
print("=" * 60)
|
435 |
+
|
436 |
+
for i, query in enumerate(test_queries, 1):
|
437 |
+
print(f"\n#{i} Query: '{query}'")
|
438 |
+
print("-" * 40)
|
439 |
+
|
440 |
+
try:
|
441 |
+
# Test ensemble search with configurable top-k
|
442 |
+
results = kb.search(query, top_k=3)
|
443 |
+
|
444 |
+
if results:
|
445 |
+
for j, doc in enumerate(results, 1):
|
446 |
+
content_preview = doc.page_content[:150].replace("\n", " ")
|
447 |
+
doc_type = doc.metadata.get("doc_type", "unknown")
|
448 |
+
has_context = doc.metadata.get("has_context", False)
|
449 |
+
context_indicator = (
|
450 |
+
" [CONTEXTUAL]" if has_context else " [ORIGINAL]"
|
451 |
+
)
|
452 |
+
print(
|
453 |
+
f" {j}. [{doc_type}]{context_indicator} {content_preview}..."
|
454 |
+
)
|
455 |
+
else:
|
456 |
+
print(" No results found")
|
457 |
+
|
458 |
+
except Exception as e:
|
459 |
+
print(f" [ERROR] Error: {e}")
|
460 |
+
|
461 |
+
|
462 |
+
# Example usage
|
463 |
+
if __name__ == "__main__":
|
464 |
+
# Initialize knowledge base
|
465 |
+
kb = ViettelKnowledgeBase(
|
466 |
+
embedding_model="dangvantuan/vietnamese-document-embedding"
|
467 |
+
)
|
468 |
+
|
469 |
+
# Build knowledge base from a folder of Word documents
|
470 |
+
documents_folder = "./viettelpay_docs" # Folder containing .doc/.docx files
|
471 |
+
|
472 |
+
try:
|
473 |
+
# Build knowledge base (pass OpenAI API key here for contextual enhancement)
|
474 |
+
kb.build_knowledge_base(
|
475 |
+
documents_folder,
|
476 |
+
"./contextual_kb",
|
477 |
+
reset=True,
|
478 |
+
openai_api_key="your-openai-api-key-here", # or None to use env variable
|
479 |
+
)
|
480 |
+
|
481 |
+
# Alternative: Load existing knowledge base
|
482 |
+
# success = kb.load_knowledge_base("./contextual_kb")
|
483 |
+
# if not success:
|
484 |
+
# print("[ERROR] Failed to load knowledge base")
|
485 |
+
|
486 |
+
# Test queries
|
487 |
+
test_queries = [
|
488 |
+
"lỗi 606",
|
489 |
+
"không nạp được tiền",
|
490 |
+
"hướng dẫn nạp cước",
|
491 |
+
"quy định hủy giao dịch",
|
492 |
+
"mệnh giá thẻ cào",
|
493 |
+
]
|
494 |
+
|
495 |
+
# Test the knowledge base
|
496 |
+
test_contextual_kb(kb, test_queries)
|
497 |
+
|
498 |
+
# Example of runtime configuration for different top-k values
|
499 |
+
print(f"\n[INFO] Example of runtime configuration:")
|
500 |
+
print("=" * 50)
|
501 |
+
|
502 |
+
# Search with different top-k values
|
503 |
+
sample_query = "lỗi 606"
|
504 |
+
|
505 |
+
# Search with top_k=3
|
506 |
+
results1 = kb.search(sample_query, top_k=3)
|
507 |
+
print(f"Search with top_k=3: {len(results1)} total results")
|
508 |
+
|
509 |
+
# Search with top_k=8
|
510 |
+
results2 = kb.search(sample_query, top_k=8)
|
511 |
+
print(f"Search with top_k=8: {len(results2)} total results")
|
512 |
+
|
513 |
+
# Show stats
|
514 |
+
print(f"\n[INFO] Knowledge Base Stats: {kb.get_stats()}")
|
515 |
+
|
516 |
+
except Exception as e:
|
517 |
+
print(f"[ERROR] Error building knowledge base: {e}")
|
518 |
+
print("[INFO] Make sure you have:")
|
519 |
+
print(" 1. Valid OpenAI API key")
|
520 |
+
print(" 2. Word documents in the specified folder")
|
521 |
+
print(" 3. Required dependencies installed (openai, markitdown, etc.)")
|
src/llm/__pycache__/langchain_models.cpython-311.pyc
ADDED
Binary file (4.49 kB). View file
|
|
src/llm/__pycache__/llm_client.cpython-310.pyc
ADDED
Binary file (5.65 kB). View file
|
|
src/llm/__pycache__/llm_client.cpython-311.pyc
ADDED
Binary file (9.31 kB). View file
|
|
src/llm/llm_client.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM Client Abstraction Layer
|
3 |
+
Supports multiple LLM providers without hardcoding
|
4 |
+
"""
|
5 |
+
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
from typing import Dict, Any, Optional
|
8 |
+
import os
|
9 |
+
|
10 |
+
# Import configuration utility
|
11 |
+
from src.utils.config import get_gemini_api_key, get_openai_api_key
|
12 |
+
|
13 |
+
|
14 |
+
class BaseLLMClient(ABC):
|
15 |
+
"""Abstract base class for LLM clients"""
|
16 |
+
|
17 |
+
def __init__(self, **kwargs):
|
18 |
+
pass
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
22 |
+
"""Generate response from prompt"""
|
23 |
+
pass
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def is_available(self) -> bool:
|
27 |
+
"""Check if LLM service is available"""
|
28 |
+
pass
|
29 |
+
|
30 |
+
|
31 |
+
class GeminiClient(BaseLLMClient):
|
32 |
+
"""Google Gemini client implementation"""
|
33 |
+
|
34 |
+
def __init__(self, api_key: Optional[str] = None, model: str = "gemini-2.0-flash"):
|
35 |
+
self.api_key = api_key or get_gemini_api_key()
|
36 |
+
self.model = model
|
37 |
+
|
38 |
+
if not self.api_key:
|
39 |
+
raise ValueError("Gemini API key not provided")
|
40 |
+
|
41 |
+
try:
|
42 |
+
import google.generativeai as genai
|
43 |
+
|
44 |
+
genai.configure(api_key=self.api_key)
|
45 |
+
self.client = genai.GenerativeModel(self.model)
|
46 |
+
print(f"✅ Gemini client initialized with model: {self.model}")
|
47 |
+
except ImportError:
|
48 |
+
raise ImportError("google-generativeai package not installed")
|
49 |
+
|
50 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
51 |
+
"""Generate response using Gemini"""
|
52 |
+
try:
|
53 |
+
# Set default temperature to 0.1 for consistency
|
54 |
+
generation_config = {
|
55 |
+
"temperature": kwargs.get("temperature", 0.1),
|
56 |
+
"top_p": kwargs.get("top_p", 0.8),
|
57 |
+
"top_k": kwargs.get("top_k", 40),
|
58 |
+
"max_output_tokens": kwargs.get("max_output_tokens", 2048),
|
59 |
+
}
|
60 |
+
|
61 |
+
response = self.client.generate_content(
|
62 |
+
prompt, generation_config=generation_config
|
63 |
+
)
|
64 |
+
return response.text
|
65 |
+
except Exception as e:
|
66 |
+
print(f"❌ Gemini generation error: {e}")
|
67 |
+
raise
|
68 |
+
|
69 |
+
def is_available(self) -> bool:
|
70 |
+
"""Check Gemini availability"""
|
71 |
+
try:
|
72 |
+
test_response = self.client.generate_content("Hello")
|
73 |
+
return bool(test_response.text)
|
74 |
+
except:
|
75 |
+
return False
|
76 |
+
|
77 |
+
|
78 |
+
class OpenAIClient(BaseLLMClient):
|
79 |
+
"""OpenAI client implementation"""
|
80 |
+
|
81 |
+
def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4"):
|
82 |
+
self.api_key = api_key or get_openai_api_key()
|
83 |
+
self.model = model
|
84 |
+
|
85 |
+
if not self.api_key:
|
86 |
+
raise ValueError("OpenAI API key not provided")
|
87 |
+
|
88 |
+
try:
|
89 |
+
import openai
|
90 |
+
|
91 |
+
self.client = openai.OpenAI(api_key=self.api_key)
|
92 |
+
print(f"✅ OpenAI client initialized with model: {self.model}")
|
93 |
+
except ImportError:
|
94 |
+
raise ImportError("openai package not installed")
|
95 |
+
|
96 |
+
def generate(self, prompt: str, **kwargs) -> str:
|
97 |
+
"""Generate response using OpenAI"""
|
98 |
+
try:
|
99 |
+
# Set default temperature to 0.1 for consistency
|
100 |
+
openai_kwargs = {
|
101 |
+
"temperature": kwargs.get("temperature", 0.1),
|
102 |
+
"top_p": kwargs.get("top_p", 1.0),
|
103 |
+
"max_tokens": kwargs.get("max_tokens", 2048),
|
104 |
+
}
|
105 |
+
# Remove any Gemini-specific parameters
|
106 |
+
openai_kwargs.update(
|
107 |
+
{
|
108 |
+
k: v
|
109 |
+
for k, v in kwargs.items()
|
110 |
+
if k
|
111 |
+
in [
|
112 |
+
"temperature",
|
113 |
+
"top_p",
|
114 |
+
"max_tokens",
|
115 |
+
"frequency_penalty",
|
116 |
+
"presence_penalty",
|
117 |
+
]
|
118 |
+
}
|
119 |
+
)
|
120 |
+
|
121 |
+
response = self.client.chat.completions.create(
|
122 |
+
model=self.model,
|
123 |
+
messages=[{"role": "user", "content": prompt}],
|
124 |
+
**openai_kwargs,
|
125 |
+
)
|
126 |
+
return response.choices[0].message.content
|
127 |
+
except Exception as e:
|
128 |
+
print(f"❌ OpenAI generation error: {e}")
|
129 |
+
raise
|
130 |
+
|
131 |
+
def is_available(self) -> bool:
|
132 |
+
"""Check OpenAI availability"""
|
133 |
+
try:
|
134 |
+
response = self.client.chat.completions.create(
|
135 |
+
model=self.model,
|
136 |
+
messages=[{"role": "user", "content": "Hello"}],
|
137 |
+
max_tokens=5,
|
138 |
+
)
|
139 |
+
return bool(response.choices[0].message.content)
|
140 |
+
except:
|
141 |
+
return False
|
142 |
+
|
143 |
+
|
144 |
+
class LLMClientFactory:
|
145 |
+
"""Factory for creating LLM clients"""
|
146 |
+
|
147 |
+
SUPPORTED_PROVIDERS = {
|
148 |
+
"gemini": GeminiClient,
|
149 |
+
"openai": OpenAIClient,
|
150 |
+
}
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def create_client(self, provider: str = "gemini", **kwargs) -> BaseLLMClient:
|
154 |
+
"""Create LLM client by provider name"""
|
155 |
+
|
156 |
+
if provider not in self.SUPPORTED_PROVIDERS:
|
157 |
+
raise ValueError(
|
158 |
+
f"Unsupported provider: {provider}. Supported: {list(self.SUPPORTED_PROVIDERS.keys())}"
|
159 |
+
)
|
160 |
+
|
161 |
+
client_class = self.SUPPORTED_PROVIDERS[provider]
|
162 |
+
return client_class(**kwargs)
|
163 |
+
|
164 |
+
@classmethod
|
165 |
+
def get_available_providers(cls) -> list:
|
166 |
+
"""Get list of available providers"""
|
167 |
+
return list(cls.SUPPORTED_PROVIDERS.keys())
|
168 |
+
|
169 |
+
|
170 |
+
# Usage example
|
171 |
+
if __name__ == "__main__":
|
172 |
+
# Test Gemini client
|
173 |
+
try:
|
174 |
+
client = LLMClientFactory.create_client("gemini")
|
175 |
+
if client.is_available():
|
176 |
+
response = client.generate("Xin chào, bạn có khỏe không?")
|
177 |
+
print(f"Response: {response}")
|
178 |
+
else:
|
179 |
+
print("Gemini not available")
|
180 |
+
except Exception as e:
|
181 |
+
print(f"Error: {e}")
|
src/processor/__pycache__/contextual_word_processor.cpython-311.pyc
ADDED
Binary file (17.2 kB). View file
|
|
src/processor/__pycache__/csv_processor.cpython-310.pyc
ADDED
Binary file (4.85 kB). View file
|
|
src/processor/__pycache__/csv_processor.cpython-311.pyc
ADDED
Binary file (9.4 kB). View file
|
|
src/processor/__pycache__/csv_processor.cpython-312.pyc
ADDED
Binary file (10.4 kB). View file
|
|