File size: 2,444 Bytes
602e9df
 
f3a5d80
 
602e9df
 
 
f3a5d80
 
 
602e9df
 
 
 
 
 
 
 
 
 
 
 
 
f3a5d80
602e9df
 
f3a5d80
602e9df
 
 
f3a5d80
602e9df
 
f3a5d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602e9df
 
f3a5d80
602e9df
f3a5d80
 
 
 
 
 
 
 
 
 
 
 
 
602e9df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from .llm import gemini_llm  
from .retrieval import load_vectordb
from .prompts import chat_prompt_no_memory, chat_prompt_memory, classification_prompt, category_tree_json
from langchain.chains import RetrievalQA
from .metadata_selfquery import metadata_field_info
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from .memory import ShortTermMemory
from .Custom_chain import MyCustomMemoryRetrievalChain
# Module-level singletons shared across all requests: the short-term
# conversation memory and the product vector store (loaded once at import).
memory = ShortTermMemory()
vector_store = load_vectordb()

def classify_query(query):
    """Classify *query* against the category tree using the LLM.

    Formats ``classification_prompt`` with the query and the JSON category
    tree, then returns the raw LLM response.
    """
    prompt = classification_prompt.format(
        query=query,
        category_tree=category_tree_json,
    )
    return gemini_llm.invoke(prompt)

# Self-querying retriever: the LLM converts a natural-language question into
# a structured filter over `metadata_field_info`, translated into Qdrant
# filter syntax, then a similarity search returns the top 10 matches.
retriever = SelfQueryRetriever.from_llm(
    llm=gemini_llm,
    vectorstore=vector_store,
    # Content description shown to the LLM (Vietnamese): "Product info
    # including a short description, hierarchical category and the price
    # the customer is searching for".
    document_contents="Thông tin sản phẩm gồm mô tả ngắn và danh mục phân cấp, giá mà khách hàng tìm kiếm",
    metadata_field_info=metadata_field_info,
    # Document metadata lives under the "metadata" key in the Qdrant payload.
    structured_query_translator= QdrantTranslator(metadata_key="metadata"),
    search_type="similarity",
    search_kwargs={"k": 10}  # number of documents fetched per query
)

# Chain for questions unrelated to the previous turn: plain RetrievalQA with
# a prompt that carries no conversation memory.
qa_chain1 = RetrievalQA.from_chain_type(
    llm=gemini_llm, 
    retriever=retriever, 
    return_source_documents= False,  
    chain_type_kwargs={"prompt": chat_prompt_no_memory},
    output_key="result" 
)
# Chain for follow-up questions: custom chain that additionally injects the
# short-term memory text into `chat_prompt_memory`.
qa_chain2 = MyCustomMemoryRetrievalChain(
    llm= gemini_llm,
    retriever= retriever,
    prompt= chat_prompt_memory,
    output_key="result"
)

# Legacy alternative for qa_chain2: a RetrievalQA-based memory chain, kept
# for reference after being replaced by MyCustomMemoryRetrievalChain above.
# qa_chain2 = RetrievalQA.from_chain_type(
#     llm=gemini_llm,
#     retriever=retriever,
#     return_source_documents=False,
#     chain_type_kwargs={
#         "prompt": chat_prompt_memory,
#         "document_variable_name": "context"
#     },
#     output_key="result"
# )


def get_chat_response(user_input: str) -> str:
    """Answer *user_input*, routing to the memory-aware or plain QA chain.

    The input is first classified into a category restriction. If it relates
    to the cached conversation, the memory chain (qa_chain2) answers using
    the stored restriction plus the memory text; otherwise the memory is
    reset, the new restriction is cached, and the memoryless chain
    (qa_chain1) answers. The (question, answer) pair is then appended to the
    short-term memory before the answer is returned.
    """
    restriction = classify_query(user_input.strip())
    is_followup = memory.related_to_cache(user_input)

    if is_followup:
        # Follow-up turn: reuse the restriction cached from the previous
        # question together with the newly classified one.
        print("Liên quan tới câu trước")
        payload = {
            "question": user_input + memory.restrict + restriction,
            "memory": memory.get_memory_text(),
        }
        response = qa_chain2(payload)
    else:
        # New topic: discard the cached conversation and remember the new
        # restriction for future follow-ups.
        memory.reset()
        memory.restrict = restriction
        print("Không liên quan tới câu trước")
        response = qa_chain1({"query": user_input + restriction})

    print(restriction)
    answer = response["result"]
    memory.add(user_input, answer)
    print(memory.restrict)
    return answer