Spaces:
Sleeping
Sleeping
from .llm import gemini_llm | |
from .retrieval import load_vectordb | |
from .prompts import chat_prompt_no_memory, chat_prompt_memory, classification_prompt, category_tree_json | |
from langchain.chains import RetrievalQA | |
from .metadata_selfquery import metadata_field_info | |
from langchain.retrievers.self_query.base import SelfQueryRetriever | |
from langchain.retrievers.self_query.qdrant import QdrantTranslator | |
from .memory import ShortTermMemory | |
from .Custom_chain import MyCustomMemoryRetrievalChain | |
# Module-level singletons shared by every call to get_chat_response():
# - memory: short-term conversation state, used to decide whether a new
#   question is a follow-up and to carry the previous category restriction.
# - vector_store: the vector database searched by the retriever below.
memory: ShortTermMemory = ShortTermMemory()
vector_store = load_vectordb()
def classify_query(query):
    """Ask the LLM to classify *query* against the product category tree.

    The classification prompt is filled with the user query and the JSON
    category tree, then sent to ``gemini_llm``.

    Args:
        query: The user question (the caller passes it already stripped).

    Returns:
        The model's answer as plain text. LangChain chat models return a
        message object from ``invoke`` instead of a ``str``. In that case its
        ``content`` is extracted, so callers can safely concatenate the result
        with other strings. A client that already returns ``str`` is passed
        through unchanged.
    """
    prompt = classification_prompt.format(query=query, category_tree=category_tree_json)
    response = gemini_llm.invoke(prompt)
    # Chat models -> message with `.content`; plain LLMs -> str (no `.content`).
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)
# Self-query retriever over `vector_store`. `gemini_llm` is given the field
# descriptions in `metadata_field_info` so it can build metadata filters. The
# filters are translated to Qdrant syntax, with document metadata read from
# the "metadata" payload key. Up to 10 similarity matches are returned.
# document_contents (Vietnamese) = "Product information consisting of a short
# description, a hierarchical category and the price customers search for".
retriever = SelfQueryRetriever.from_llm(
    llm=gemini_llm,
    vectorstore=vector_store,
    metadata_field_info=metadata_field_info,
    document_contents="Thông tin sản phẩm gồm mô tả ngắn và danh mục phân cấp, giá mà khách hàng tìm kiếm",
    structured_query_translator=QdrantTranslator(metadata_key="metadata"),
    search_kwargs={"k": 10},
    search_type="similarity",
)
# Stateless QA chain, used by get_chat_response() when a question is NOT a
# follow-up. It answers from retrieved documents with chat_prompt_no_memory
# and exposes only the answer text under the "result" key.
qa_chain1 = RetrievalQA.from_chain_type(
    llm=gemini_llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": chat_prompt_no_memory},
    return_source_documents=False,
    output_key="result",
)

# Memory-aware QA chain, used by get_chat_response() for follow-up questions.
# Callers pass the conversation history ("memory") alongside the question.
qa_chain2 = MyCustomMemoryRetrievalChain(
    llm=gemini_llm,
    retriever=retriever,
    prompt=chat_prompt_memory,
    output_key="result",
)
# qa_chain2 = RetrievalQA.from_chain_type( | |
# llm=gemini_llm, | |
# retriever=retriever, | |
# return_source_documents=False, | |
# chain_type_kwargs={ | |
# "prompt": chat_prompt_memory, | |
# "document_variable_name": "context" | |
# }, | |
# output_key="result" | |
# ) | |
def _join_query_parts(*parts):
    """Join the non-empty query fragments with single spaces."""
    # Skipping empty/None parts avoids stray separators (e.g. before any
    # restriction has been stored in memory).
    return " ".join(part for part in parts if part)


def get_chat_response(user_input: str) -> str:
    """Answer *user_input* with retrieval QA, reusing memory for follow-ups.

    The input is first classified (see ``classify_query``) to obtain a category
    restriction that is appended to the question sent to the QA chain.

    * If ``memory.related_to_cache`` says the input relates to the cached
      conversation, ``qa_chain2`` is used. Its question carries the stored
      restriction plus the new one, and the conversation history is passed in.
    * Otherwise the memory is reset, the new restriction is stored in
      ``memory.restrict``, and the stateless ``qa_chain1`` is used.

    Either way the (input, answer) turn is then added to memory.

    Args:
        user_input: The raw user question.

    Returns:
        The chain's answer text (``response["result"]``).
    """
    restriction = classify_query(user_input.strip())
    if memory.related_to_cache(user_input):
        print("Liên quan tới câu trước")  # "Related to the previous question"
        # NOTE(review): memory.restrict is assumed to be a str set by an
        # earlier new-topic question. Empty values are skipped when joining.
        question = _join_query_parts(user_input, memory.restrict, restriction)
        response = qa_chain2({"question": question, "memory": memory.get_memory_text()})
    else:
        memory.reset()
        classified_query = _join_query_parts(user_input, restriction)
        memory.restrict = restriction
        print("Không liên quan tới câu trước")  # "Not related to the previous question"
        response = qa_chain1({"query": classified_query})
    print(restriction)
    answer = response["result"]
    memory.add(user_input, answer)
    print(memory.restrict)
    return answer