Spaces:
Sleeping
Sleeping
Commit
·
f3a5d80
1
Parent(s):
81926e1
Add multi-turn conversation feature
Browse files- chatbot/Custom_chain.py +42 -0
- chatbot/__pycache__/core.cpython-310.pyc +0 -0
- chatbot/__pycache__/memory.cpython-310.pyc +0 -0
- chatbot/__pycache__/prompts.cpython-310.pyc +0 -0
- chatbot/core.py +40 -13
- chatbot/memory.py +44 -2
- chatbot/prompts.py +27 -3
- preprocessing_data.ipynb +0 -0
chatbot/Custom_chain.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chains.base import Chain
|
2 |
+
from langchain.schema import BaseRetriever
|
3 |
+
from langchain.llms import BaseLLM
|
4 |
+
from langchain.prompts import PromptTemplate
|
5 |
+
from pydantic import Field
|
6 |
+
from typing import Dict, Any
|
7 |
+
|
8 |
+
class MyCustomMemoryRetrievalChain(Chain):
|
9 |
+
"""
|
10 |
+
Custom chain cho phép truyền question, memory.
|
11 |
+
Lấy docs từ retriever, trộn với prompt, gọi LLM.
|
12 |
+
"""
|
13 |
+
|
14 |
+
llm: BaseLLM = Field(...)
|
15 |
+
retriever: BaseRetriever = Field(...)
|
16 |
+
prompt: PromptTemplate = Field(...)
|
17 |
+
output_key: str = "result"
|
18 |
+
|
19 |
+
@property
|
20 |
+
def input_keys(self) -> list:
|
21 |
+
return ["question", "memory"]
|
22 |
+
|
23 |
+
@property
|
24 |
+
def output_keys(self) -> list:
|
25 |
+
return [self.output_key]
|
26 |
+
|
27 |
+
def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, Any]:
|
28 |
+
question = inputs["question"]
|
29 |
+
memory = inputs["memory"]
|
30 |
+
|
31 |
+
docs = self.retriever.get_relevant_documents(question)
|
32 |
+
context = "\n".join(doc.page_content for doc in docs)
|
33 |
+
|
34 |
+
final_prompt = self.prompt.format(
|
35 |
+
question=question,
|
36 |
+
memory=memory,
|
37 |
+
context=context
|
38 |
+
)
|
39 |
+
|
40 |
+
answer = self.llm(final_prompt)
|
41 |
+
|
42 |
+
return {self.output_key: answer}
|
chatbot/__pycache__/core.cpython-310.pyc
CHANGED
Binary files a/chatbot/__pycache__/core.cpython-310.pyc and b/chatbot/__pycache__/core.cpython-310.pyc differ
|
|
chatbot/__pycache__/memory.cpython-310.pyc
CHANGED
Binary files a/chatbot/__pycache__/memory.cpython-310.pyc and b/chatbot/__pycache__/memory.cpython-310.pyc differ
|
|
chatbot/__pycache__/prompts.cpython-310.pyc
CHANGED
Binary files a/chatbot/__pycache__/prompts.cpython-310.pyc and b/chatbot/__pycache__/prompts.cpython-310.pyc differ
|
|
chatbot/core.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
from .llm import gemini_llm
|
2 |
from .retrieval import load_vectordb
|
3 |
-
from .
|
4 |
-
from .
|
5 |
-
from langchain.chains import ConversationalRetrievalChain
|
6 |
from .metadata_selfquery import metadata_field_info
|
7 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
8 |
from langchain.retrievers.self_query.qdrant import QdrantTranslator
|
9 |
-
|
10 |
-
|
|
|
11 |
vector_store = load_vectordb()
|
12 |
|
13 |
def classify_query(query):
|
@@ -21,22 +21,49 @@ retriever = SelfQueryRetriever.from_llm(
|
|
21 |
metadata_field_info=metadata_field_info,
|
22 |
structured_query_translator= QdrantTranslator(metadata_key="metadata"),
|
23 |
search_type="similarity",
|
24 |
-
search_kwargs={"k": 10
|
25 |
)
|
26 |
|
27 |
-
|
28 |
llm=gemini_llm,
|
29 |
retriever=retriever,
|
30 |
-
memory=memory,
|
31 |
return_source_documents= False,
|
32 |
-
|
33 |
output_key="result"
|
34 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def get_chat_response(user_input: str) -> str:
|
37 |
-
|
38 |
-
response = qa_chain({"question": classified_query})
|
39 |
|
40 |
-
memory.
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
return response["result"]
|
|
|
1 |
from .llm import gemini_llm
|
2 |
from .retrieval import load_vectordb
|
3 |
+
from .prompts import chat_prompt_no_memory, chat_prompt_memory, classification_prompt, category_tree_json
|
4 |
+
from langchain.chains import RetrievalQA
|
|
|
5 |
from .metadata_selfquery import metadata_field_info
|
6 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
7 |
from langchain.retrievers.self_query.qdrant import QdrantTranslator
|
8 |
+
from .memory import ShortTermMemory
|
9 |
+
from .Custom_chain import MyCustomMemoryRetrievalChain
|
10 |
+
memory = ShortTermMemory()
|
11 |
vector_store = load_vectordb()
|
12 |
|
13 |
def classify_query(query):
|
|
|
21 |
metadata_field_info=metadata_field_info,
|
22 |
structured_query_translator= QdrantTranslator(metadata_key="metadata"),
|
23 |
search_type="similarity",
|
24 |
+
search_kwargs={"k": 10}
|
25 |
)
|
26 |
|
27 |
+
qa_chain1 = RetrievalQA.from_chain_type(
|
28 |
llm=gemini_llm,
|
29 |
retriever=retriever,
|
|
|
30 |
return_source_documents= False,
|
31 |
+
chain_type_kwargs={"prompt": chat_prompt_no_memory},
|
32 |
output_key="result"
|
33 |
)
|
34 |
+
qa_chain2 = MyCustomMemoryRetrievalChain(
|
35 |
+
llm= gemini_llm,
|
36 |
+
retriever= retriever,
|
37 |
+
prompt= chat_prompt_memory,
|
38 |
+
output_key="result"
|
39 |
+
)
|
40 |
+
|
41 |
+
# qa_chain2 = RetrievalQA.from_chain_type(
|
42 |
+
# llm=gemini_llm,
|
43 |
+
# retriever=retriever,
|
44 |
+
# return_source_documents=False,
|
45 |
+
# chain_type_kwargs={
|
46 |
+
# "prompt": chat_prompt_memory,
|
47 |
+
# "document_variable_name": "context"
|
48 |
+
# },
|
49 |
+
# output_key="result"
|
50 |
+
# )
|
51 |
+
|
52 |
|
53 |
def get_chat_response(user_input: str) -> str:
|
54 |
+
restriction = classify_query(user_input.strip())
|
|
|
55 |
|
56 |
+
if memory.related_to_cache(user_input):
|
57 |
+
print("Liên quan tới câu trước")
|
58 |
+
response = qa_chain2({"question": user_input + memory.restrict + restriction, "memory": memory.get_memory_text()})
|
59 |
+
|
60 |
+
else:
|
61 |
+
memory.reset()
|
62 |
+
classified_query = user_input + restriction
|
63 |
+
memory.restrict = restriction
|
64 |
+
print("Không liên quan tới câu trước")
|
65 |
+
response = qa_chain1({"query": classified_query})
|
66 |
+
print(restriction)
|
67 |
+
memory.add(user_input, response["result"])
|
68 |
+
print(memory.restrict)
|
69 |
return response["result"]
|
chatbot/memory.py
CHANGED
@@ -1,3 +1,45 @@
|
|
1 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
|
|
1 |
+
from collections import deque
|
2 |
+
from difflib import SequenceMatcher
|
3 |
+
from langchain_core.prompts import PromptTemplate
|
4 |
+
from langchain.chains import LLMChain
|
5 |
+
from .llm import gemini_llm
|
6 |
+
|
7 |
+
related_prompt = PromptTemplate(
|
8 |
+
input_variables=["q1", "q2"],
|
9 |
+
template="""
|
10 |
+
Bạn là một trợ lý thông minh. Nhiệm vụ của bạn là xác định xem hai câu hỏi hiện tại, có phải là nối tiếp hay cùng mục đích, liên quan tới ngữ cảnh trước đó không:
|
11 |
+
|
12 |
+
Câu hỏi hiện tại: {q1}
|
13 |
+
Hội thoại trước đó: {q2} (nếu không có gì, trả lời là không)
|
14 |
+
|
15 |
+
Hai câu hỏi này có liên quan không? Trả lời ngắn gọn: Có hoặc Không.
|
16 |
+
"""
|
17 |
+
)
|
18 |
+
|
19 |
+
check_related_chain = LLMChain(llm=gemini_llm, prompt=related_prompt)
|
20 |
+
|
21 |
+
class ShortTermMemory:
|
22 |
+
def __init__(self, maxlen=3):
|
23 |
+
self.cache = deque(maxlen=maxlen)
|
24 |
+
self.restrict = ""
|
25 |
+
|
26 |
+
def is_similar(self, q1, q2):
|
27 |
+
response = check_related_chain.invoke({"q1": q1, "q2": q2})
|
28 |
+
print(response['text'])
|
29 |
+
return "có" in response['text'].lower()
|
30 |
+
|
31 |
+
|
32 |
+
def related_to_cache(self, query):
|
33 |
+
return self.is_similar(query, self.get_memory_text())
|
34 |
+
|
35 |
+
def add(self, query, answer):
|
36 |
+
text_entry = f"Người dùng hỏi: {query}\n Hệ thống trả lời: {answer}"
|
37 |
+
self.cache.append(text_entry)
|
38 |
+
|
39 |
+
def reset(self):
|
40 |
+
self.cache.clear()
|
41 |
+
self.restrict = ""
|
42 |
+
def get_memory_text(self):
|
43 |
+
return "\n".join(self.cache)
|
44 |
+
|
45 |
|
|
chatbot/prompts.py
CHANGED
@@ -3,7 +3,7 @@ from .llm import gemini_llm
|
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
import json
|
5 |
|
6 |
-
|
7 |
input_variables=["context", "question"],
|
8 |
template="""
|
9 |
Bạn là trợ lý AI bán hàng của **Rạng Đông Store**, chuyên hỗ trợ khách hàng tìm kiếm và lựa chọn các sản phẩm chiếu sáng và gia dụng chất lượng cao.
|
@@ -43,7 +43,7 @@ Bạn là trợ lý AI bán hàng của **Rạng Đông Store**, chuyên hỗ tr
|
|
43 |
`Giá`, `Công suất`, `Góc chiếu`, `Độ rọi`.
|
44 |
- Nếu là **bình giữ nhiệt hoặc phích nước**, hãy nêu rõ:
|
45 |
`Giá`, `Dung tích`, `Thời gian giữ nhiệt`.
|
46 |
-
|
47 |
---
|
48 |
|
49 |
### Câu hỏi từ khách hàng:
|
@@ -229,5 +229,29 @@ classification_prompt = ChatPromptTemplate.from_messages([
|
|
229 |
"Hãy đọc câu hỏi của khách hàng và xác định danh mục thích hợp L1, L2, L3"
|
230 |
"Và giá thấp nhất khách hàng mua, giá cao nhất khách hàng mua"),
|
231 |
("human", "Câu hỏi: {query}. Hãy trả về danh mục thích hợp."
|
232 |
-
"Trả lời theo định dạng sau:
|
|
|
|
|
|
|
233 |
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
import json
|
5 |
|
6 |
+
chat_prompt_no_memory = PromptTemplate(
|
7 |
input_variables=["context", "question"],
|
8 |
template="""
|
9 |
Bạn là trợ lý AI bán hàng của **Rạng Đông Store**, chuyên hỗ trợ khách hàng tìm kiếm và lựa chọn các sản phẩm chiếu sáng và gia dụng chất lượng cao.
|
|
|
43 |
`Giá`, `Công suất`, `Góc chiếu`, `Độ rọi`.
|
44 |
- Nếu là **bình giữ nhiệt hoặc phích nước**, hãy nêu rõ:
|
45 |
`Giá`, `Dung tích`, `Thời gian giữ nhiệt`.
|
46 |
+
- Ghi đầy đủ các thông tin khác về thông số kỹ thuật, mô tả sản phẩm
|
47 |
---
|
48 |
|
49 |
### Câu hỏi từ khách hàng:
|
|
|
229 |
"Hãy đọc câu hỏi của khách hàng và xác định danh mục thích hợp L1, L2, L3"
|
230 |
"Và giá thấp nhất khách hàng mua, giá cao nhất khách hàng mua"),
|
231 |
("human", "Câu hỏi: {query}. Hãy trả về danh mục thích hợp."
|
232 |
+
"Trả lời theo định dạng sau: L1(Nếu xác định được):... - L2(Nếu xác định được):.... -L3(Nếu xác định được):... - Giá thấp nhất(Nếu xác định được):... - Giá cao nhất(nếu xác định được):..."
|
233 |
+
"Nếu chỉ có thông tin về giá vẫn trả lời Giá thấp nhất:.. - Giá cao nhât:...."
|
234 |
+
"Nếu không tìm được yếu tố nào thì bỏ trống (không cần ghi yếu tố không tìm thấy ví dụ (L1: ... - Giá thấp nhất 100 000))"
|
235 |
+
"Nếu không tìm thấy danh mục nào phù hợp trả về nội dung sau: '...'")
|
236 |
])
|
237 |
+
|
238 |
+
|
239 |
+
chat_prompt_memory = PromptTemplate(
|
240 |
+
input_variables=["context", "question", "memory"],
|
241 |
+
template="""
|
242 |
+
Bạn là một trợ lý AI bán hàng của Rạng Đông Store, chuyên hỗ trợ khách hàng tìm kiếm và lựa chọn các sản phẩm chiếu sáng và gia dụng chất lượng cao.
|
243 |
+
|
244 |
+
Dưới đây là các thông tin bạn cần để trả lời khách hàng:
|
245 |
+
|
246 |
+
Lịch sử trò chuyện trước đó với khách hàng:
|
247 |
+
{memory}
|
248 |
+
|
249 |
+
Thông tin sản phẩm liên quan đến câu hỏi:
|
250 |
+
{context}
|
251 |
+
|
252 |
+
Câu hỏi hiện tại của khách hàng (tập trung trả lời câu hỏi này, quan trọng nhất):
|
253 |
+
{question}
|
254 |
+
|
255 |
+
Hãy dựa vào các thông tin trên để đưa ra câu trả lời chính xác thân thiện nhằm hỗ trợ khách hàng một cách hiệu quả nhất.
|
256 |
+
"""
|
257 |
+
)
|
preprocessing_data.ipynb
ADDED
File without changes
|