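"""Chat pipeline for the product assistant.

Classifies each question against the category tree, retrieves products through a
self-querying Qdrant retriever, and answers with either a memory-free RetrievalQA
chain or a custom memory-aware chain, depending on whether the question follows
up on the previous turn.
"""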
from langchain.chains import RetrievalQA
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.qdrant import QdrantTranslator

from .Custom_chain import MyCustomMemoryRetrievalChain
from .llm import gemini_llm
from .memory import ShortTermMemory
from .metadata_selfquery import metadata_field_info
from .prompts import (
    category_tree_json,
    chat_prompt_memory,
    chat_prompt_no_memory,
    classification_prompt,
)
from .retrieval import load_vectordb

# Module-level state: short-term conversation memory and the product vector store.
memory = ShortTermMemory()
vector_store = load_vectordb()

def classify_query(query):
    """Classify the query against the category tree and return the LLM response."""
    response = gemini_llm.invoke(
        classification_prompt.format(query=query, category_tree=category_tree_json)
    )
    return response

# Self-querying retriever: the LLM turns natural-language constraints into Qdrant
# metadata filters before the top-k similarity search.
retriever = SelfQueryRetriever.from_llm(
    llm=gemini_llm,
    vectorstore=vector_store,
    # Vietnamese: "Product information consisting of a short description, a
    # hierarchical category, and the price the customer is looking for."
    document_contents="Thông tin sản phẩm gồm mô tả ngắn và danh mục phân cấp, giá mà khách hàng tìm kiếm",
    metadata_field_info=metadata_field_info,
    structured_query_translator=QdrantTranslator(metadata_key="metadata"),
    search_type="similarity",
    search_kwargs={"k": 10},
)

# Chain for standalone questions: no conversation memory in the prompt.
qa_chain1 = RetrievalQA.from_chain_type(
    llm=gemini_llm,
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": chat_prompt_no_memory},
    output_key="result",
)

# Chain for follow-up questions: the custom chain injects short-term memory into the prompt.
qa_chain2 = MyCustomMemoryRetrievalChain(
    llm=gemini_llm,
    retriever=retriever,
    prompt=chat_prompt_memory,
    output_key="result",
)
# qa_chain2 = RetrievalQA.from_chain_type(
# llm=gemini_llm,
# retriever=retriever,
# return_source_documents=False,
# chain_type_kwargs={
# "prompt": chat_prompt_memory,
# "document_variable_name": "context"
# },
# output_key="result"
# )

def get_chat_response(user_input: str) -> str:
    """Answer a user question, routing follow-ups to the memory-aware chain."""
    restriction = classify_query(user_input.strip())
    if memory.related_to_cache(user_input):
        print("Related to the previous question")
        response = qa_chain2({
            "question": user_input + memory.restrict + restriction,
            "memory": memory.get_memory_text(),
        })
    else:
        memory.reset()
        classified_query = user_input + restriction
        memory.restrict = restriction
        print("Not related to the previous question")
        response = qa_chain1({"query": classified_query})
    print(restriction)
    memory.add(user_input, response["result"])
    print(memory.restrict)
    return response["result"]
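
# Minimal usage sketch (assumption: because of the relative imports this module is
# part of a package, so import it rather than running it directly; the module path
# below is hypothetical):
#
#   from app.chat import get_chat_response
#   print(get_chat_response("Do you have any men's jackets under 500,000 VND?"))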