# Provenance (condensed from web file-viewer page text): author trantuan1701,
# commit f3a5d80 "Add multi-turn conversation feature", file size 2.44 kB.
from .llm import gemini_llm
from .retrieval import load_vectordb
from .prompts import chat_prompt_no_memory, chat_prompt_memory, classification_prompt, category_tree_json
from langchain.chains import RetrievalQA
from .metadata_selfquery import metadata_field_info
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from .memory import ShortTermMemory
from .Custom_chain import MyCustomMemoryRetrievalChain
# Module-level short-term conversation memory; get_chat_response reads and
# mutates this single shared instance.
memory = ShortTermMemory()
# Vector store obtained once at import time; it backs the self-query retriever
# built below.
vector_store = load_vectordb()
def classify_query(query):
    """Ask the LLM to classify *query* against the category tree.

    The classification prompt is filled with the query text and the
    category tree JSON, then sent to ``gemini_llm``.

    Args:
        query: The user's question text.

    Returns:
        Whatever ``gemini_llm.invoke`` returns for the formatted prompt.
        NOTE(review): get_chat_response concatenates this value onto a
        ``str``, so it is presumably a plain string -- confirm against the
        LLM wrapper in ``.llm``.
    """
    classification_request = classification_prompt.format(
        query=query,
        category_tree=category_tree_json,
    )
    return gemini_llm.invoke(classification_request)
# Self-query retriever over the product vector store. It is shared by both QA
# chains defined below.
retriever = SelfQueryRetriever.from_llm(
    llm=gemini_llm,
    vectorstore=vector_store,
    metadata_field_info=metadata_field_info,
    # Vietnamese description of the indexed documents (roughly: "product
    # information consisting of a short description, hierarchical category
    # and the price customers search for"). It is a runtime string passed to
    # the retriever, so it must stay verbatim.
    document_contents="Thông tin sản phẩm gồm mô tả ngắn và danh mục phân cấp, giá mà khách hàng tìm kiếm",
    # Structured queries are translated for Qdrant, with document metadata
    # stored under the "metadata" key.
    structured_query_translator=QdrantTranslator(metadata_key="metadata"),
    search_type="similarity",
    search_kwargs={"k": 10},
)
# Memory-less QA chain: get_chat_response uses it when the incoming question
# is not related to the cached conversation. The answer is published under the
# "result" key and source documents are not returned.
qa_chain1 = RetrievalQA.from_chain_type(
    llm=gemini_llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": chat_prompt_no_memory},
    return_source_documents=False,
    output_key="result",
)
# Memory-aware QA chain: get_chat_response calls it with "question" and
# "memory" inputs for follow-up questions. Like qa_chain1, it publishes its
# answer under the "result" key.
qa_chain2 = MyCustomMemoryRetrievalChain(
    llm=gemini_llm,
    retriever=retriever,
    prompt=chat_prompt_memory,
    output_key="result",
)
# Disabled alternative: build qa_chain2 with the stock RetrievalQA chain
# instead of MyCustomMemoryRetrievalChain.
# qa_chain2 = RetrievalQA.from_chain_type(
#     llm=gemini_llm,
#     retriever=retriever,
#     return_source_documents=False,
#     chain_type_kwargs={
#         "prompt": chat_prompt_memory,
#         "document_variable_name": "context"
#     },
#     output_key="result"
# )
def get_chat_response(user_input: str) -> str:
    """Answer one user message, reusing conversation memory for follow-ups.

    The message is first classified with ``classify_query`` (on the stripped
    text). Then:

    * If ``memory.related_to_cache`` reports the message as related to the
      cached conversation, the memory-aware chain ``qa_chain2`` is called
      with the message, the stored restriction and the new restriction
      concatenated as the question, plus the memory text.
    * Otherwise the memory is reset, the new restriction is stored on
      ``memory.restrict`` and the memory-less chain ``qa_chain1`` is called
      with the message followed by the restriction.

    In both cases the exchange is added to ``memory`` afterwards.

    Args:
        user_input: Raw message text from the user.

    Returns:
        The ``"result"`` entry of the chain's response.
    """
    restriction = classify_query(user_input.strip())

    is_follow_up = memory.related_to_cache(user_input)
    if not is_follow_up:
        # Fresh topic: drop the old context and remember the new restriction.
        memory.reset()
        classified_query = user_input + restriction
        memory.restrict = restriction
        # Debug trace, Vietnamese for "not related to the previous question".
        print("Không liên quan tới câu trước")
        response = qa_chain1({"query": classified_query})
    else:
        # Debug trace, Vietnamese for "related to the previous question".
        print("Liên quan tới câu trước")
        follow_up_question = user_input + memory.restrict + restriction
        conversation_text = memory.get_memory_text()
        response = qa_chain2(
            {"question": follow_up_question, "memory": conversation_text}
        )

    print(restriction)
    answer = response["result"]
    memory.add(user_input, answer)
    print(memory.restrict)
    return answer