trantuan1701's picture
Add multi-turn conversation feature
f3a5d80
raw
history blame
1.5 kB
from collections import deque
from difflib import SequenceMatcher
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain
from .llm import gemini_llm
related_prompt = PromptTemplate(
input_variables=["q1", "q2"],
template="""
Bạn là một trợ lý thông minh. Nhiệm vụ của bạn là xác định xem hai câu hỏi hiện tại, có phải là nối tiếp hay cùng mục đích, liên quan tới ngữ cảnh trước đó không:
Câu hỏi hiện tại: {q1}
Hội thoại trước đó: {q2} (nếu không có gì, trả lời là không)
Hai câu hỏi này có liên quan không? Trả lời ngắn gọn: Có hoặc Không.
"""
)
check_related_chain = LLMChain(llm=gemini_llm, prompt=related_prompt)
class ShortTermMemory:
def __init__(self, maxlen=3):
self.cache = deque(maxlen=maxlen)
self.restrict = ""
def is_similar(self, q1, q2):
response = check_related_chain.invoke({"q1": q1, "q2": q2})
print(response['text'])
return "có" in response['text'].lower()
def related_to_cache(self, query):
return self.is_similar(query, self.get_memory_text())
def add(self, query, answer):
text_entry = f"Người dùng hỏi: {query}\n Hệ thống trả lời: {answer}"
self.cache.append(text_entry)
def reset(self):
self.cache.clear()
self.restrict = ""
def get_memory_text(self):
return "\n".join(self.cache)