Spaces:
Sleeping
Sleeping
from langchain.chains.base import Chain | |
from langchain.schema import BaseRetriever | |
from langchain.llms import BaseLLM | |
from langchain.prompts import PromptTemplate | |
from pydantic import Field | |
from typing import Dict, Any | |
class MyCustomMemoryRetrievalChain(Chain):
    """Custom retrieval chain that also passes caller-supplied memory to the prompt.

    Takes a question and a memory value, fetches documents from the
    retriever, combines them with the prompt, and calls the LLM.

    Inputs:
        question: Query used both for document retrieval and as the
            ``question`` prompt variable.
        memory: Value passed unchanged as the ``memory`` prompt variable
            (presumably conversation history - confirm against callers).

    Output:
        A dict with a single key, ``output_key`` (default ``"result"``),
        holding the LLM's response.

    The prompt is formatted with the ``question``, ``memory`` and ``context``
    variables. ``context`` is the retrieved documents' ``page_content``
    values joined with newlines.
    """

    llm: BaseLLM = Field(...)
    retriever: BaseRetriever = Field(...)
    prompt: PromptTemplate = Field(...)
    output_key: str = "result"

    # LangChain's Chain declares input_keys/output_keys as abstract *properties*
    # and reads them as attributes (e.g. set(self.input_keys) when validating
    # inputs). As plain methods they would yield a non-iterable bound method.
    @property
    def input_keys(self) -> list:
        """Return the input keys this chain requires: question and memory."""
        return ["question", "memory"]

    @property
    def output_keys(self) -> list:
        """Return the single output key this chain produces."""
        return [self.output_key]

    def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, Any]:
        """Retrieve context for the question, fill the prompt, and call the LLM.

        Args:
            inputs: Mapping that contains the ``"question"`` and ``"memory"`` keys.
            run_manager: Optional callback manager supplied by ``Chain``. When
                present, a child manager is forwarded to the LLM call so the
                LLM's callback events are nested under this chain run.

        Returns:
            ``{self.output_key: answer}``, where ``answer`` is the LLM's response.
        """
        question = inputs["question"]
        memory = inputs["memory"]

        docs = self.retriever.get_relevant_documents(question)
        context = "\n".join(doc.page_content for doc in docs)

        final_prompt = self.prompt.format(
            question=question,
            memory=memory,
            context=context,
        )

        # Propagate callbacks (tracing/streaming handlers) to the LLM call
        # instead of silently dropping the run manager.
        callbacks = run_manager.get_child() if run_manager is not None else None
        answer = self.llm(final_prompt, callbacks=callbacks)
        return {self.output_key: answer}