|
""" |
|
Main app for LISA RAG chatbot based on langchain. |
|
""" |
|
|
|
import os |
|
import time |
|
import re |
|
import gradio as gr |
|
import pickle |
|
|
|
from pathlib import Path |
|
from dotenv import load_dotenv |
|
|
|
from huggingface_hub import login |
|
from langchain.vectorstores import FAISS |
|
|
|
from llms import get_groq_chat |
|
from documents import load_pdf_as_docs, load_xml_as_docs |
|
from vectorestores import get_faiss_vectorestore |
|
|
|
|
|
|
|
|
|
|
|
|
|
load_dotenv() |
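
# Log in to the Hugging Face Hub (required for hub-hosted models/embeddings);
# the Tavily key is read from the environment by the web-search retriever below.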
|
|
|
|
|
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"] |
|
login(HUGGINGFACEHUB_API_TOKEN) |
|
TAVILY_API_KEY = os.environ["TAVILY_API_KEY"] |
|
|
|
|
|
|
|
|
|
database_root = "./data/db" |
|
document_path = "./data/documents" |
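
# The parsed documents, chunks, docstore, and FAISS index are precomputed
# offline and cached under database_root; startup only deserializes them.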
|
|
|
|
|
|
|
def load_from_pickle(filename): |
|
with open(filename, "rb") as file: |
|
return pickle.load(file) |
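

# "docs" holds the full parent documents; "docs_chunks" holds the smaller
# child chunks used for BM25 retrieval further below.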
|
|
|
|
|
|
|
docs = load_from_pickle(os.path.join(database_root, "docs.pkl")) |
|
|
|
|
|
document_chunks = load_from_pickle(os.path.join(database_root, "docs_chunks.pkl")) |
|
|
|
|
|
from embeddings import get_jinaai_embeddings |
|
|
|
embeddings = get_jinaai_embeddings(device="auto") |
|
print("embedding loaded") |
|
|
|
|
|
vectorstore = FAISS.load_local( |
|
os.path.join(database_root, "faiss_index"), |
|
embeddings, |
|
allow_dangerous_deserialization=True, |
|
) |
|
print("vectorestore loaded") |
|
|
|
|
|
from retrievers import get_parent_doc_retriever, get_rerank_retriever |
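
# Parent-document retrieval: dense search runs over small chunks in the FAISS
# index, but the larger parent documents from the docstore are returned for
# more context. add_documents=False because the index was populated offline.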
|
|
|
docstore = load_from_pickle(os.path.join(database_root, "docstore.pkl")) |
|
parent_doc_retriever = get_parent_doc_retriever(
|
docs, |
|
vectorstore, |
|
save_path_root=database_root, |
|
docstore=docstore, |
|
add_documents=False, |
|
) |
|
|
|
|
|
from langchain.retrievers import BM25Retriever, EnsembleRetriever |
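
# Hybrid retrieval: sparse BM25 over the chunk texts, combined 50/50 with the
# dense parent-document retriever.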
|
|
|
bm25_retriever = BM25Retriever.from_documents( |
|
document_chunks, k=5 |
|
) |
|
|
|
|
|
ensemble_retriever = EnsembleRetriever( |
|
    retrievers=[bm25_retriever, parent_doc_retriever], weights=[0.5, 0.5]
|
) |
|
|
|
|
|
from rerank import BgeRerank |
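
# Rerank the ensemble results with the BGE reranker before they reach the LLM.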
|
|
|
reranker = BgeRerank() |
|
rerank_retriever = get_rerank_retriever(ensemble_retriever, reranker) |
|
print("rerank loaded") |
|
|
|
|
|
|
|
llm = get_groq_chat(model_name="llama-3.3-70b-versatile") |
|
|
|
|
|
|
|
from ragchain import RAGChain |
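
# Build the conversational RAG chain; add_citation=True makes the LLM emit
# [n] source markers, which are post-processed by the helpers below.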
|
|
|
rag_chain = RAGChain() |
|
lisa_qa_conversation = rag_chain.create(rerank_retriever, llm, add_citation=True) |
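# Example invocation (the payload shape used by the Gradio handlers below):
#   lisa_qa_conversation({"question": "...", "chat_history": []})
#   -> {"answer": str, "source_documents": [Document, ...]}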
|
|
|
|
|
from langchain_community.retrievers import TavilySearchAPIRetriever |
|
from langchain.chains import RetrievalQAWithSourcesChain |
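
# Optional web-search path (the "Search web" checkbox): retrieve the top 4
# Tavily results and answer with a sources-aware QA chain.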
|
|
|
web_search_retriever = TavilySearchAPIRetriever(k=4) |
|
web_qa_chain = RetrievalQAWithSourcesChain.from_chain_type( |
|
llm, retriever=web_search_retriever, return_source_documents=True |
|
) |
|
print("chains loaded") |
|
|
|
|
|
|
|
def check_input_text(text):
    """Validate the user question; an empty input aborts the event chain."""

    if not text:
        gr.Warning("Please input a question.")
        raise TypeError("Empty question.")  # aborts downstream .success() handlers

    return True
|
|
|
|
|
def add_text(history, text): |
|
"""Add conversation to history message.""" |
|
|
|
history = history + [(text, None)] |
|
yield history, "" |
|
|
|
|
|
def postprocess_remove_cite_misinfo(text, allowed_max_cite_num=6): |
|
"""Heuristic removal of misinfo. of citations.""" |
|
|
|
|
|
if "References:\n[" in text: |
|
text = text.split("References:\n")[0] |
|
|
|
source_ids = re.findall(r"(\[.*?\]+)", text) |
|
pattern = r"(,*? *?\[.*?\]+)" |
|
print(f"source ids by re: {source_ids}") |
|
|
|
|
|
def replace_and_increment(match): |
|
|
|
match_str = match.group(1) |
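
        # Ranges like [1-3] or [1–3] and non-numeric ids like [i] are treated
        # as unverifiable citations and removed outright.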
|
|
|
|
|
|
|
if "–" in match_str or "-" in match_str: |
|
return "" |
|
|
|
|
|
if "i" in match_str: |
|
return "" |
|
|
|
|
|
|
|
pattern = r"(\d+)" |
|
nums = re.findall(pattern, match_str) |
|
if nums: |
|
nums_list = [] |
|
for n in nums: |
|
if int(n) <= allowed_max_cite_num: |
|
nums_list.append("[[" + n + "]]") |
|
|
|
else: |
|
return "" |
|
|
|
if re.search("^,", match_str): |
|
return ( |
|
'<sup><span style="color:#F27F0C">' |
|
+ ", " |
|
+ ", ".join(nums_list) |
|
+ "</span></sup>" |
|
) |
|
|
|
return ( |
|
'<sup><span style="color:#F27F0C">' |
|
+ " " |
|
+ ", ".join(nums_list) |
|
+ "</span></sup>" |
|
) |
|
|
|
|
|
new_text = re.sub(pattern, replace_and_increment, text) |
|
|
|
|
|
if "\n\n [" in new_text: |
|
new_text = new_text.split("\n\n [")[0] |
|
if "\n\n[" in new_text: |
|
new_text = new_text.split("\n\n[")[0] |
|
|
|
|
|
new_text = new_text.strip() |
|
|
|
return new_text |
|
|
|
|
|
def postprocess_citation(text, source_docs): |
|
"""Postprocess text for extracting citations.""" |
|
|
|
|
|
|
|
source_ids = re.findall(r"\[(\d*)\]", text) |
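
    # Markers produced by the LLM are 1-based; shift them to 0-based indices
    # into source_docs.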
|
|
|
|
|
aligned_source_ids = list(map(lambda x: int(x) - 1, source_ids)) |
|
|
|
|
|
    candidate_source_ids = list(range(len(source_docs)))
    # Deduplicate and sort so citations render in ascending order.
    filtered_source_ids = sorted(
        {i for i in aligned_source_ids if i in candidate_source_ids}
    )
    filtered_docs = [source_docs[i] for i in filtered_source_ids]
|
output_markdown = "" |
|
for i, d in zip(filtered_source_ids, filtered_docs): |
|
|
|
|
|
index = i + 1 |
|
source = d.metadata["source"] |
|
content = d.page_content.strip().replace("\n", " ") |
|
source_info = f"<b>[[{index}]] {source}</b>" |
|
item = f""" |
|
<details> |
|
<summary>{source_info}</summary> |
|
|
|
<blockquote cite=""> |
|
<p>{content}</p> |
|
</blockquote> |
|
</details> |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
output_markdown += item |
|
|
|
|
|
return output_markdown |
|
|
|
|
|
def postprocess_web_citation(text, qa_result): |
|
"""Postprocess text for extracting web citations.""" |
|
|
|
|
|
|
|
if qa_result["sources"]: |
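        # "sources" is a comma-separated string returned by
        # RetrievalQAWithSourcesChain; strip stray angle brackets from the URLs.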
|
|
|
web_sources = qa_result["sources"].split(",") |
|
web_sources = [ |
|
s.strip().replace(">", "").replace("<", "").replace(",", "") |
|
for s in web_sources |
|
] |
|
else: |
|
web_sources = [doc.metadata["source"] for doc in qa_result["source_documents"]] |
|
output_markdown = "" |
|
for i, d in enumerate(web_sources): |
|
index = i + 1 |
|
source = d |
|
item = f""" |
|
<p><a href="{source}/" target="_blank" rel="noopener noreferrer">[{index}]. {source}</a></p> |
|
|
|
""" |
|
output_markdown += item |
|
return output_markdown |
|
|
|
|
|
def bot_lisa(history, flag_web_search): |
|
"""Get answer from LLM.""" |
|
|
|
if not flag_web_search: |
|
result = lisa_qa_conversation( |
|
{ |
|
"question": history[-1][0], |
|
"chat_history": history[:-1], |
|
} |
|
) |
|
if result is None: |
|
raise gr.Error("Sorry, failed to get answer from LLM, please try again.") |
|
|
|
|
|
print(f"Answer: {result['answer']}") |
|
print(f"Source document: {result['source_documents']}") |
|
|
|
answer_text = result["answer"].strip() |
|
|
|
answer_text = postprocess_remove_cite_misinfo(answer_text) |
|
|
|
|
|
citation_text = postprocess_citation(answer_text, result["source_documents"]) |
|
|
|
else: |
|
result = web_qa_chain( |
|
{ |
|
"question": history[-1][0], |
|
|
|
} |
|
) |
|
if result is None: |
|
raise gr.Error("Sorry, failed to get answer from LLM, please try again.") |
|
|
|
|
|
answer_text = result["answer"].strip() |
|
citation_text = postprocess_web_citation(answer_text, result) |
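
    # Stream the answer character by character for a typing effect in the UI.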
|
|
|
|
|
|
|
|
|
|
|
history[-1][1] = "" |
|
for character in answer_text: |
|
time.sleep(0.002) |
|
history[-1][1] += character |
|
yield history, citation_text |
|
|
|
|
|
def bot(history, qa_conversation):
    """Get answer from the LLM for a user-uploaded document."""

    # The gr.State holds a "placeholder" string until a document is loaded,
    # so check for a callable chain rather than only None.
    if not callable(qa_conversation):
        gr.Warning("Please upload a document first.")
        return
|
|
|
result = qa_conversation( |
|
{ |
|
"question": history[-1][0], |
|
"chat_history": history[:-1], |
|
} |
|
) |
|
|
|
if result is None: |
|
return "", "" |
|
|
|
print(f"Source document: {result['source_documents']}") |
|
answer_text = result["answer"].strip() |
|
|
|
answer_text = postprocess_remove_cite_misinfo(answer_text) |
|
|
|
citation_text = postprocess_citation(answer_text, result["source_documents"]) |
|
|
|
history[-1][1] = "" |
|
for character in answer_text: |
|
time.sleep(0.002) |
|
history[-1][1] += character |
|
yield history, citation_text |
|
|
|
|
|
def document_changes(doc_path): |
|
"""Parse user document.""" |
|
|
|
max_file_num = 3 |
|
|
|
if doc_path is None: |
|
gr.Warning("Please choose a document first and wait until uploaded.") |
|
return ( |
|
"Please choose a document and wait until uploaded.", |
|
None, |
|
) |
|
|
|
print("now reading document") |
|
print(f"file is located at {doc_path[0]}") |
|
|
|
    documents = []
    for doc in doc_path[:max_file_num]:
        file_extension = Path(doc).suffix
        if file_extension == ".pdf":
            documents.extend(load_pdf_as_docs(doc))
        elif file_extension == ".xml":
            documents.extend(load_xml_as_docs(doc))
        else:
            gr.Warning(f"Unsupported file type: {file_extension}")
|
|
|
print("now creating vectordatabase") |
|
|
|
vectorstore = get_faiss_vectorestore(embeddings) |
|
parent_doc_retriever = get_parent_doc_retriever(documents, vectorstore) |
|
rerank_retriever = get_rerank_retriever(parent_doc_retriever, reranker) |
|
|
|
print("now getting llm model") |
|
|
|
llm = get_groq_chat(model_name="llama-3.1-70b-versatile") |
|
|
|
rag_chain = RAGChain() |
|
|
|
|
|
qa_conversation = rag_chain.create(rerank_retriever, llm, add_citation=True) |
|
|
|
|
|
|
|
|
|
file_name = Path(doc_path[0]).name |
|
return f"Ready for {file_name} etc.", qa_conversation |
|
|
|
|
|
|
|
def main(): |
|
"""Gradio interface.""" |
|
|
|
with gr.Blocks() as demo: |
|
|
|
|
|
|
|
gr.Markdown("## LISA - Lithium Ion Solid-state Assistant") |
|
        gr.Markdown(
            """
            Q&A research assistant for efficient knowledge management, in battery science and beyond.
            Based on a RAG architecture and powered by Large Language Models (LLMs)."""
        )
|
|
|
with gr.Tab("LISA ⚡"): |
|
with gr.Row(): |
|
with gr.Column(scale=7): |
|
|
|
chatbot = gr.Chatbot( |
|
[], |
|
elem_id="chatbot", |
|
label="Document Assistant", |
|
bubble_full_width=False, |
|
show_copy_button=True, |
|
|
|
) |
|
|
|
user_txt = gr.Textbox( |
|
label="Question", |
|
placeholder="Type in the question and press Enter/click Submit", |
|
) |
|
|
|
with gr.Accordion("Advanced", open=False): |
|
flag_web_search = gr.Checkbox( |
|
label="Search web", info="Search information from Internet" |
|
) |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=1): |
|
submit_btn = gr.Button("Submit", variant="primary") |
|
with gr.Column(scale=1): |
|
clear_btn = gr.Button("Clear", variant="stop") |
|
|
|
|
|
|
|
|
|
gr.Examples( |
|
examples=[ |
|
"Please name two common solid electrolytes.", |
|
"Please name two common oxide solid electrolytes.", |
|
"Please tell me what is solid-state battery.", |
|
"How to synthesize gc-LPSC?", |
|
"Please tell me the purpose of Kadi4Mat.", |
|
"Who is working on Kadi4Mat?", |
|
"Can you recommend a paper to get a deeper understanding of Kadi4Mat?", |
|
|
|
], |
|
inputs=user_txt, |
|
outputs=chatbot, |
|
fn=add_text, |
|
label="Try asking...", |
|
|
|
cache_examples=False, |
|
examples_per_page=3, |
|
) |
|
|
|
|
|
|
|
with gr.Column(scale=3): |
|
with gr.Tab("References"): |
|
doc_citation = gr.HTML( |
|
"<p>References used in answering the question will be displayed below.</p>" |
|
) |
|
|
|
with gr.Tab("Setting"): |
|
|
|
|
|
gr.Markdown("More in DEV...") |
|
|
|
|
|
user_txt.submit(check_input_text, user_txt, None).success( |
|
add_text, [chatbot, user_txt], [chatbot, user_txt] |
|
).then(bot_lisa, [chatbot, flag_web_search], [chatbot, doc_citation]) |
|
|
|
submit_btn.click(check_input_text, user_txt, None).success( |
|
add_text, |
|
[chatbot, user_txt], |
|
[chatbot, user_txt], |
|
|
|
|
|
).then(bot_lisa, [chatbot, flag_web_search], [chatbot, doc_citation]) |
|
|
|
clear_btn.click(lambda: None, None, chatbot, queue=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Tab("Upload document 📚"): |
|
qa_conversation = gr.State( |
|
"placeholder", time_to_live=3600 |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=7, variant="chat_panel"): |
|
chatbot_docqa = gr.Chatbot( |
|
[], |
|
elem_id="chatbot_docqa", |
|
label="Document Assistant", |
|
show_copy_button=True, |
|
likeable=True, |
|
) |
|
docqa_question = gr.Textbox( |
|
label="Question", |
|
placeholder="Type in the question and press Enter/click Submit", |
|
) |
|
with gr.Row(): |
|
with gr.Column(scale=50): |
|
docqa_submit_btn = gr.Button("Submit", variant="primary") |
|
with gr.Column(scale=50): |
|
docqa_clear_btn = gr.Button("Clear", variant="stop") |
|
|
|
gr.Examples( |
|
examples=[ |
|
"Summarize the paper", |
|
"Summarize the paper in 3 bullet points", |
|
|
|
"What are the contributions of this paper", |
|
"Explain the practical implications of this paper", |
|
"Methods used in this paper", |
|
"What data has been used in this paper", |
|
"Results of the paper", |
|
"Conclusions from the paper", |
|
"Limitations of this paper", |
|
"Future works suggested in this paper", |
|
], |
|
inputs=docqa_question, |
|
outputs=chatbot_docqa, |
|
fn=add_text, |
|
label="Example questions for single document.", |
|
|
|
cache_examples=False, |
|
examples_per_page=4, |
|
) |
|
|
|
|
|
with gr.Column(scale=3): |
|
with gr.Tab("Load"): |
|
|
|
with gr.Row(): |
|
gr.HTML( |
|
"Upload pdf/xml file(s), click the Load file button. After preprocessing, you can start asking questions about the document. (Please do not share sensitive document)" |
|
) |
|
with gr.Row(): |
|
uploaded_doc = gr.File( |
|
label="Upload pdf/xml (max. 3) file(s)", |
|
file_count="multiple", |
|
file_types=[".pdf", ".xml"], |
|
type="filepath", |
|
height=100, |
|
) |
|
with gr.Row(): |
|
langchain_status = gr.Textbox( |
|
label="Status", placeholder="", interactive=False |
|
) |
|
load_document = gr.Button("Load file") |
|
with gr.Tab("References"): |
|
doc_citation_user_doc = gr.HTML( |
|
"References used in answering the question will be displayed below." |
|
) |
|
with gr.Tab("Setting"): |
|
gr.Markdown("More in DEV...") |
|
|
|
|
|
load_document.click( |
|
document_changes, |
|
inputs=[uploaded_doc], |
|
outputs=[ |
|
langchain_status, |
|
qa_conversation, |
|
], |
|
queue=False, |
|
) |
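
            # Same validate -> append -> answer chain, now routed through the
            # per-session document QA state.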
|
|
|
docqa_question.submit(check_input_text, docqa_question).success( |
|
add_text, |
|
[chatbot_docqa, docqa_question], |
|
[chatbot_docqa, docqa_question], |
|
).then( |
|
bot, |
|
[chatbot_docqa, qa_conversation], |
|
[chatbot_docqa, doc_citation_user_doc], |
|
) |
|
|
|
docqa_submit_btn.click(check_input_text, docqa_question).success( |
|
add_text, |
|
[chatbot_docqa, docqa_question], |
|
[chatbot_docqa, docqa_question], |
|
).then( |
|
bot, |
|
[chatbot_docqa, qa_conversation], |
|
[chatbot_docqa, doc_citation_user_doc], |
|
) |
|
|
|
|
|
|
|
with gr.Tab("Preview feature 🔬"): |
|
|
|
with gr.Tab("Vision LM 🖼"): |
|
vision_tmp_link = ( |
|
"https://kadi-iam-lisa-vlm.hf.space/" |
|
) |
|
                with gr.Blocks(css="""footer {visibility: hidden}""") as preview_tab:
|
gr.HTML( |
|
"""<iframe src="{}" style="width:100%; height:1024px; overflow:auto"></iframe>""".format( |
|
vision_tmp_link |
|
) |
|
) |
|
|
|
|
|
|
|
with gr.Tab("KadiChat 💬"): |
|
kadichat_tmp_link = ( |
|
"https://kadi-iam-kadichat.hf.space/" |
|
) |
|
                with gr.Blocks(css="""footer {visibility: hidden}""") as preview_tab:
|
gr.HTML( |
|
"""<iframe src="{}" style="width:100%; height:1024px; overflow:auto"></iframe>""".format( |
|
kadichat_tmp_link |
|
) |
|
) |
|
|
|
|
|
with gr.Tab("RAG enhanced with Knowledge Graph (dev) 🔎"): |
|
kg_tmp_link = "https://kadi-iam-kadikgraph.static.hf.space/index.html" |
|
gr.Markdown( |
|
"[If rendering fails, look at the graph here](https://kadi-iam-kadikgraph.static.hf.space)" |
|
) |
|
                with gr.Blocks(css="""footer {visibility: hidden}""") as preview_tab:
|
gr.HTML( |
|
"""<iframe |
|
src="{}" |
|
frameborder="0" |
|
width="850" |
|
height="450" |
|
></iframe> |
|
""".format( |
|
kg_tmp_link |
|
) |
|
) |
|
|
|
|
|
with gr.Tab("About 📝"): |
|
with gr.Tab("Dev. info"): |
|
gr.Markdown( |
|
""" |
|
                This system is being developed by the [Kadi Team at IAM-MMS, KIT](https://kadi.iam.kit.edu/kadi-ai), in collaboration with various groups from different scientific backgrounds.
|
|
|
                Changelog:

                - 23-10-2024: Add Kadi knowledge graph as a test for Knowledge Graph-RAG
                - 18-10-2024: Add linkage to Kadi
                - 02-10-2024: Code cleaning, release code soon
                - 26-09-2024: Switch Vision-LLM to Mistral via API
                - 31-08-2024: Make document parsing a preprocessing step and cache the vector database
                - 31-05-2024: Add Vision-LLM and draft Knowledge Graph-RAG (*preview*)
                - 21-05-2024: Add web search in settings (*experimental*)
                - 15-03-2024: Add evaluation and improve citation feature
                - 20-02-2024: Add citation feature (*experimental*)
                - 16-02-2024: Add support for XML files
                - 12-02-2024: Set up demo on Hugging Face
                - 16-01-2024: Build first demo version
                - 23-11-2023: Draft concept
|
|
|
|
|
                In development:

                - Metadata parsing
                - More robust citation feature
                - Conversational chat


                Current limitations:

                - Conversational chat (chat with history context) is not supported yet
                - Only 3 files can be uploaded, for testing purposes
|
|
|
                *Note: The model may produce incorrect statements. Users should treat these outputs as suggestions or starting points, not as definitive or accurate facts.*
|
""" |
|
) |
|
|
|
with gr.Tab("What's included?"): |
|
from paper_list import paper_list_str |
|
|
|
gr.Markdown( |
|
f"Currently, LISA includes the following open/free access pulications/documents/websites:\n\n {paper_list_str}" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
demo.queue(max_size=8, default_concurrency_limit=4).launch(share=True) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|