lib search API ready
- app_modules/init.py +16 -2
- app_modules/llm_chat_chain.py +1 -1
- app_modules/llm_inference.py +12 -5
- app_modules/llm_qa_chain.py +12 -3
- app_modules/llm_summarize_chain.py +1 -1
- server.py +13 -6
- test.py +7 -4
- web +1 -1
app_modules/init.py
CHANGED

```diff
@@ -79,14 +79,28 @@ def app_init(initQAChain: bool = True):
 
     print(f"Completed in {end - start:.3f}s")
 
-    vectorstore = load_vectorstor(index_path)
+    vectorstore = load_vectorstor(using_faiss, index_path, embeddings)
+
+    doc_id_to_vectorstore_mapping = {}
+    rootdir = index_path
+    for file in os.listdir(rootdir):
+        d = os.path.join(rootdir, file)
+        if os.path.isdir(d):
+            v = load_vectorstor(using_faiss, d, embeddings)
+            doc_id_to_vectorstore_mapping[file] = v
+
+    # print(doc_id_to_vectorstore_mapping)
 
     start = timer()
     llm_loader = LLMLoader(llm_model_type)
    llm_loader.init(
        n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type
    )
-    qa_chain = QAChain(vectorstore, llm_loader) if initQAChain else None
+    qa_chain = (
+        QAChain(vectorstore, llm_loader, doc_id_to_vectorstore_mapping)
+        if initQAChain
+        else None
+    )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
 
```
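The loop above assumes the index folder holds one sub-folder per document id, each containing its own persisted index, alongside the combined index at the top level. A minimal sketch of what the `load_vectorstor` helper could look like under that assumption (the FAISS/Chroma calls follow common LangChain usage; the repo's actual implementation may differ):

```python
from langchain.vectorstores import FAISS, Chroma


def load_vectorstor(using_faiss: bool, index_path: str, embeddings):
    """Sketch only: load one persisted index folder as a vectorstore."""
    if using_faiss:
        # A FAISS index previously written with FAISS.save_local(index_path).
        return FAISS.load_local(index_path, embeddings)
    # Otherwise treat the folder as a persisted Chroma collection.
    return Chroma(embedding_function=embeddings, persist_directory=index_path)
```

With such a layout, `doc_id_to_vectorstore_mapping["paper_001"]` (an illustrative id) points at the index built from just that document, while `vectorstore` continues to cover the whole corpus.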
app_modules/llm_chat_chain.py
CHANGED

```diff
@@ -27,7 +27,7 @@ class ChatChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         template = (
             get_llama_2_prompt_template()
             if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
```
app_modules/llm_inference.py
CHANGED

```diff
@@ -22,12 +22,12 @@ class LLMInference(metaclass=abc.ABCMeta):
         self.chain = None
 
     @abc.abstractmethod
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         pass
 
-    def get_chain(self) -> Chain:
+    def get_chain(self, inputs) -> Chain:
         if self.chain is None:
-            self.chain = self.create_chain()
+            self.chain = self.create_chain(inputs)
 
         return self.chain
 
@@ -48,7 +48,7 @@
         try:
             self.llm_loader.streamer.reset(q)
 
-            chain = self.get_chain()
+            chain = self.get_chain(inputs)
             result = (
                 self._run_chain_with_streaming_handler(
                     chain, inputs, streaming_handler, testing
@@ -61,7 +61,14 @@
             result["answer"] = remove_extra_spaces(result["answer"])
 
             source_path = os.environ.get("SOURCE_PATH")
-            if source_path is not None and len(source_path) > 0:
+            base_url = os.environ.get("PDF_FILE_BASE_URL")
+            if base_url is not None and len(base_url) > 0:
+                documents = result["source_documents"]
+                for doc in documents:
+                    source = doc.metadata["source"]
+                    title = source.split("/")[-1]
+                    doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
+            elif source_path is not None and len(source_path) > 0:
                 documents = result["source_documents"]
                 for doc in documents:
                     source = doc.metadata["source"]
```
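When `PDF_FILE_BASE_URL` is set, the new branch rewrites each source document's local path into a link. The same transformation in isolation (base URL and file name are made-up values):

```python
import urllib.parse

base_url = "https://example.com/pdfs/"        # hypothetical PDF_FILE_BASE_URL value
source = "data/pdfs/Annual Report 2022.pdf"   # a doc.metadata["source"] value

title = source.split("/")[-1]                 # "Annual Report 2022.pdf"
url = f"{base_url}{urllib.parse.quote(title)}"
print(url)  # https://example.com/pdfs/Annual%20Report%202022.pdf
```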
app_modules/llm_qa_chain.py
CHANGED

```diff
@@ -8,14 +8,23 @@ from app_modules.llm_inference import LLMInference
 class QAChain(LLMInference):
     vectorstore: VectorStore
 
-    def __init__(self, vectorstore, llm_loader):
+    def __init__(self, vectorstore, llm_loader, doc_id_to_vectorstore_mapping=None):
         super().__init__(llm_loader)
         self.vectorstore = vectorstore
+        self.doc_id_to_vectorstore_mapping = doc_id_to_vectorstore_mapping
+
+    def get_chain(self, inputs) -> Chain:
+        return self.create_chain(inputs)
+
+    def create_chain(self, inputs) -> Chain:
+        vectorstore = self.vectorstore
+        if "chat_id" in inputs:
+            if inputs["chat_id"] in self.doc_id_to_vectorstore_mapping:
+                vectorstore = self.doc_id_to_vectorstore_mapping[inputs["chat_id"]]
 
-    def create_chain(self) -> Chain:
         qa = ConversationalRetrievalChain.from_llm(
             self.llm_loader.llm,
-            self.vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
+            vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
             max_tokens_limit=self.llm_loader.max_tokens_limit,
             return_source_documents=True,
         )
```
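QAChain now overrides get_chain to rebuild the chain on every request, so the retriever can follow the chat_id carried in the inputs. A hedged usage sketch (the document id is illustrative and must match a key of doc_id_to_vectorstore_mapping):

```python
# Ask against a single document's index: "paper_001" is a made-up id that
# would correspond to a sub-folder of the index path loaded in app_init.
inputs = {"question": "Which dataset was used?", "chat_history": [], "chat_id": "paper_001"}
result = qa_chain.call_chain(inputs, None)

# Without a matching chat_id, retrieval falls back to the combined vectorstore.
inputs_all = {"question": "Which dataset was used?", "chat_history": []}
result_all = qa_chain.call_chain(inputs_all, None)
```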
app_modules/llm_summarize_chain.py
CHANGED

```diff
@@ -23,7 +23,7 @@ class SummarizeChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
-    def create_chain(self) -> Chain:
+    def create_chain(self, inputs) -> Chain:
         use_llama_2_prompt_template = (
             os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
         )
```
server.py
CHANGED

```diff
@@ -28,11 +28,11 @@ class ChatResponse(BaseModel):
 
 def do_chat(
     question: str,
-    history: Optional[List] =
+    history: Optional[List] = None,
     chat_id: Optional[str] = None,
     streaming_handler: any = None,
 ):
-    if
+    if history is not None:
         chat_history = []
         if chat_history_enabled:
             for element in history:
@@ -41,7 +41,8 @@ def do_chat(
 
     start = timer()
     result = qa_chain.call_chain(
-        {"question": question, "chat_history": chat_history},
+        {"question": question, "chat_history": chat_history, "chat_id": chat_id},
+        streaming_handler,
     )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
@@ -61,20 +62,26 @@
 
 @serving(websocket=True)
 def chat(
-    question: str,
+    question: str,
+    history: Optional[List] = None,
+    chat_id: Optional[str] = None,
+    **kwargs,
 ) -> str:
     print("question@chat:", question)
     streaming_handler = kwargs.get("streaming_handler")
     result = do_chat(question, history, chat_id, streaming_handler)
     resp = ChatResponse(
-        sourceDocs=result["source_documents"] if
+        sourceDocs=result["source_documents"] if history is not None else []
     )
     return json.dumps(resp.dict())
 
 
 @serving
 def chat_sync(
-    question: str,
+    question: str,
+    history: Optional[List] = None,
+    chat_id: Optional[str] = None,
+    **kwargs,
 ) -> str:
     print("question@chat_sync:", question)
     result = do_chat(question, history, chat_id, None)
```
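A sketch of how the widened signatures flow through do_chat; the history entries and document id are illustrative, and the exact shape of history elements depends on code outside this diff:

```python
# Illustrative call, mirroring what the chat/chat_sync endpoints now do.
history = [("What is the paper about?", "It studies retrieval-augmented chat.")]  # assumed (q, a) pairs
result = do_chat(
    "Which model was fine-tuned?",
    history=history,
    chat_id="paper_001",       # hypothetical document id
    streaming_handler=None,
)
print(result["answer"])
# "url" is only present when PDF_FILE_BASE_URL is configured (see llm_inference.py).
print([doc.metadata.get("url") for doc in result["source_documents"]])
```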
test.py
CHANGED

```diff
@@ -30,6 +30,7 @@ class MyCustomHandler(BaseCallbackHandler):
 
 
 chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
+chat_id = sys.argv[2] if len(sys.argv) > 2 else None
 questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"
 
@@ -68,8 +69,9 @@ while True:
     custom_handler.reset()
 
     start = timer()
+    inputs = {"question": query, "chat_history": chat_history, "chat_id": chat_id}
     result = qa_chain.call_chain(
-        {"question": query, "chat_history": chat_history},
+        inputs,
         custom_handler,
         None,
         True,
@@ -87,13 +89,14 @@
     if standalone_question is not None:
         print(f"Load relevant documents for standalone question: {standalone_question}")
         start = timer()
-        qa = qa_chain.get_chain()
+        qa = qa_chain.get_chain(inputs)
         docs = qa.retriever.get_relevant_documents(standalone_question)
         end = timer()
-
-        # print(docs)
         print(f"Completed in {end - start:.3f}s")
 
+    if chatting:
+        print(docs)
+
     if chat_history_enabled == "true":
         chat_history.append((query, result["answer"]))
 
```
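With the extra positional argument, the test script can target one of the per-document indexes, e.g. `python test.py chat paper_001` (the id is illustrative and must match a sub-folder of the index path). The id is forwarded as `chat_id` in every `call_chain` input, and the retrieved documents are printed after each answer when chatting.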
web
CHANGED

```diff
@@ -1 +1 @@
-Subproject commit
+Subproject commit 15f2b72afe6170badfb982c7adba585af30d578a
```