"""Chat entry point: classify a user query, retrieve documents, and answer.

Builds a self-query retriever over the vector store returned by
``load_vectordb`` (filters are translated with ``QdrantTranslator``), wraps it
in two QA chains (one without and one with short-term memory), and exposes
``get_chat_response`` which picks between them per turn.
"""
import logging

from .llm import gemini_llm
from .retrieval import load_vectordb
from .prompts import (
    chat_prompt_no_memory,
    chat_prompt_memory,
    classification_prompt,
    category_tree_json,
)
from langchain.chains import RetrievalQA
from .metadata_selfquery import metadata_field_info
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from .memory import ShortTermMemory
from .Custom_chain import MyCustomMemoryRetrievalChain

logger = logging.getLogger(__name__)

# Module-level singletons shared by every call to get_chat_response.
memory = ShortTermMemory()
vector_store = load_vectordb()


def classify_query(query):
    """Ask the LLM to classify ``query`` against the category tree.

    Args:
        query: The user's question text.

    Returns:
        The raw result of ``gemini_llm.invoke``. Depending on the model type
        this may be a plain string or a message object exposing ``.content``.
    """
    response = gemini_llm.invoke(
        classification_prompt.format(query=query, category_tree=category_tree_json)
    )
    return response


def _as_text(value):
    """Return ``value.content`` for message-like objects, else ``value`` itself."""
    # Chat models return a message object rather than a str. Plain strings have
    # no ``.content`` attribute, so they pass through unchanged.
    content = getattr(value, "content", None)
    return value if content is None else content


retriever = SelfQueryRetriever.from_llm(
    llm=gemini_llm,
    vectorstore=vector_store,
    document_contents="Thông tin sản phẩm gồm mô tả ngắn và danh mục phân cấp, giá mà khách hàng tìm kiếm",
    metadata_field_info=metadata_field_info,
    structured_query_translator=QdrantTranslator(metadata_key="metadata"),
    search_type="similarity",
    search_kwargs={"k": 10},
)

# Stateless QA: used when the new question is unrelated to the cached turn.
qa_chain1 = RetrievalQA.from_chain_type(
    llm=gemini_llm,
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": chat_prompt_no_memory},
    output_key="result",
)

# Memory-aware QA: receives the conversation history alongside the question.
qa_chain2 = MyCustomMemoryRetrievalChain(
    llm=gemini_llm,
    retriever=retriever,
    prompt=chat_prompt_memory,
    output_key="result",
)


def get_chat_response(user_input: str) -> str:
    """Answer ``user_input``, using conversation memory when it is related.

    The query is first classified to obtain a restriction string, which is
    appended to the question sent to the retriever. If ``memory`` reports that
    the question relates to the cached conversation, the memory-aware chain is
    used. Otherwise memory is reset and the stateless chain is used.

    Args:
        user_input: The user's message.

    Returns:
        The chain's ``"result"`` text. The same text is also stored in
        ``memory``.
    """
    restriction = _as_text(classify_query(user_input.strip()))

    if memory.related_to_cache(user_input):
        logger.debug("Liên quan tới câu trước")
        # Both the restriction remembered from the earlier turn and the fresh
        # one are appended (kept as in the original design).
        response = qa_chain2({
            "question": user_input + memory.restrict + restriction,
            "memory": memory.get_memory_text(),
        })
    else:
        # New topic: drop the previous context and remember this restriction.
        memory.reset()
        classified_query = user_input + restriction
        memory.restrict = restriction
        logger.debug("Không liên quan tới câu trước")
        response = qa_chain1({"query": classified_query})

    logger.debug("restriction: %s", restriction)
    memory.add(user_input, response["result"])
    logger.debug("memory.restrict: %s", memory.restrict)
    return response["result"]