Clean code and add readme
- LISA_mini.ipynb +23 -25
- README.md +31 -0
- app.py +25 -20
- documents.py +51 -130
- embeddings.py +26 -15
- llms.py +16 -34
- preprocess_documents.py +9 -4
- ragchain.py +18 -5
- requirements.txt +1 -1
- rerank.py +3 -2
- retrievers.py +12 -7
- vectorestores.py +8 -3
LISA_mini.ipynb
CHANGED
@@ -1,8 +1,16 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "9267529d",
+   "metadata": {},
+   "source": [
+    "A mini version of LISA in a Jupyter notebook for easier testing and playing around."
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 2,
    "id": "adcfdba2",
    "metadata": {},
    "outputs": [],
@@ -18,14 +26,13 @@
    "from langchain.chains import ConversationalRetrievalChain\n",
    "from langchain.llms import HuggingFaceTextGenInference\n",
    "from langchain.chains.conversation.memory import (\n",
-   "    ConversationBufferMemory,\n",
    "    ConversationBufferWindowMemory,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 3,
   "id": "2d85c6d9",
   "metadata": {},
   "outputs": [],
@@ -68,7 +75,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 5,
   "id": "2d5bacd5",
   "metadata": {},
   "outputs": [],
@@ -107,7 +114,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 6,
   "id": "8cd31248",
   "metadata": {},
   "outputs": [],
@@ -140,21 +147,12 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
-  "id": "73d560de",
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "# Create retrievers"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": 8,
   "id": "e5796990",
   "metadata": {},
   "outputs": [],
   "source": [
+   "# Create retrievers\n",
    "# Some advanced RAG, with parent document retriever, hybrid-search and rerank\n",
    "\n",
    "# 1. ParentDocumentRetriever. Note: this will take a long time (~several minutes)\n",
@@ -178,7 +176,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 9,
   "id": "bc299740",
   "metadata": {},
   "outputs": [],
@@ -191,7 +189,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 10,
   "id": "2eb8bc8f",
   "metadata": {},
   "outputs": [],
@@ -214,7 +212,7 @@
    "\n",
    "from sentence_transformers import CrossEncoder\n",
    "\n",
-   "model_name = \"BAAI/bge-reranker-large\"
+   "model_name = \"BAAI/bge-reranker-large\"\n",
    "\n",
    "class BgeRerank(BaseDocumentCompressor):\n",
    "    model_name:str = model_name\n",
@@ -273,7 +271,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 11,
   "id": "af780912",
   "metadata": {},
   "outputs": [],
@@ -283,7 +281,7 @@
    "# Ensemble all above\n",
    "ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])\n",
    "\n",
-   "#
+   "# Rerank\n",
    "compressor = BgeRerank()\n",
    "rerank_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=ensemble_retriever\n",
@@ -292,7 +290,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 12,
   "id": "beb9ab21",
   "metadata": {},
   "outputs": [],
@@ -307,7 +305,7 @@
    "        self.return_messages = return_messages\n",
    "\n",
    "    def create(self, retriver, llm):\n",
-   "        memory = ConversationBufferWindowMemory(
+   "        memory = ConversationBufferWindowMemory(\n",
    "            memory_key=self.memory_key,\n",
    "            return_messages=self.return_messages,\n",
    "            output_key=self.output_key,\n",
@@ -634,7 +632,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "
+   "display_name": "lisa",
    "language": "python",
    "name": "python3"
   },
@@ -648,7 +646,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10
+   "version": "3.11.10"
  }
 },
 "nbformat": 4,
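The notebook builds a "small-to-big" ParentDocumentRetriever before the hybrid search and rerank steps: small child chunks are embedded for precise search, while the larger parent chunk is what gets handed to the LLM. A minimal sketch of that pattern is below; `vectorstore` and `docs` are assumed to come from earlier cells, and the chunk sizes mirror the values used elsewhere in this commit.

```python
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Large "parent" chunks for context, small "child" chunks for retrieval precision
parent_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256
)
child_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
)

parent_doc_retriver = ParentDocumentRetriever(
    vectorstore=vectorstore,    # FAISS/Chroma store built earlier (assumed)
    docstore=InMemoryStore(),   # holds the full parent chunks
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
    search_kwargs={"k": 10},
)
parent_doc_retriver.add_documents(docs)  # the slow step: embeds every child chunk
```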
README.md
CHANGED
@@ -11,3 +11,34 @@ startup_duration_timeout: 2h
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+LISA (Lithium Ion Solid-state Assistant) is a question-and-answer (Q&A) research assistant designed for efficient knowledge management with a primary focus on battery science, yet versatile enough to support broader scientific domains. Built on a Retrieval-Augmented Generation (RAG) architecture, LISA uses advanced Large Language Models (LLMs) to provide reliable, detailed answers to research questions.
+
+DEMO: https://huggingface.co/spaces/Kadi-IAM/LISA
+
+### Installation
+1. Clone the Repository:
+```bash
+git clone "link of this repo"
+cd LISA
+```
+
+2. Install Dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+3. Set Up the Knowledge Base
+Populate the knowledge base with relevant documents or research papers. Ensure that documents are in a format (pdf or xml) compatible with the RAG pipeline. By default, documents should be located at `data/documents`. After running the following command, some cache files are saved into `data/db`. ATTENTION: pickle is used to save these caches, so be careful about potential security risks.
+```bash
+python preprocess_documents.py
+```
+
+4. Running LISA
+Once setup is complete, run the following command to launch LISA:
+```bash
+python app.py
+```
+
+### About
+For more information on our work in intelligent research data management systems, please visit [KadiAI](https://kadi.iam.kit.edu/kadi-ai).
app.py
CHANGED
@@ -1,12 +1,15 @@
+"""
+Main app for LISA RAG chatbot based on langchain.
+"""
+
 import os
 import time
 import re
-
-from dotenv import load_dotenv
+import gradio as gr
 import pickle
 
-
-
+from pathlib import Path
+from dotenv import load_dotenv
 
 from huggingface_hub import login
 from langchain.vectorstores import FAISS
@@ -15,24 +18,21 @@ from llms import get_groq_chat
 from documents import load_pdf_as_docs, load_xml_as_docs
 from vectorestores import get_faiss_vectorestore
 
-
 # For debug
 # from langchain.globals import set_debug
 # set_debug(True)
 
-
 # Load and set env variables
 load_dotenv()
 
+# Set API keys
 HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
 login(HUGGINGFACEHUB_API_TOKEN)
 TAVILY_API_KEY = os.environ["TAVILY_API_KEY"]  # Search engine
 
-# Other settings
-os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
-
+# Set database path
 database_root = "./data/db"
 document_path = "./data/documents"
 
@@ -80,12 +80,13 @@ from langchain.retrievers import BM25Retriever, EnsembleRetriever
 
 bm25_retriever = BM25Retriever.from_documents(
     document_chunks, k=5
-)  # 1/2 of dense retriever, experimental value
+)  # k = 1/2 of dense retriever, experimental value
 
-# Ensemble all above
+# Ensemble all above retrievers
 ensemble_retriever = EnsembleRetriever(
     retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5]
 )
+
 # Reranker
 from rerank import BgeRerank
 
@@ -98,7 +99,7 @@ print("rerank loaded")
 llm = get_groq_chat(model_name="llama-3.1-70b-versatile")
 
 
-#
+# Create conversation qa chain (Note: conversation is not supported yet)
 from ragchain import RAGChain
 
 rag_chain = RAGChain()
@@ -108,13 +109,11 @@ lisa_qa_conversation = rag_chain.create(rerank_retriever, llm, add_citation=True
 from langchain_community.retrievers import TavilySearchAPIRetriever
 from langchain.chains import RetrievalQAWithSourcesChain
 
-web_search_retriever = TavilySearchAPIRetriever(
-    k=4
-)  # , include_raw_content=True)#, include_raw_content=True)
+web_search_retriever = TavilySearchAPIRetriever(k=4)  # , include_raw_content=True)
 web_qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
     llm, retriever=web_search_retriever, return_source_documents=True
 )
-print("
+print("chains loaded")
 
 
 # Gradio utils
@@ -136,7 +135,7 @@ def add_text(history, text):
 
 
 def postprocess_remove_cite_misinfo(text, allowed_max_cite_num=6):
-    """
+    """Heuristic removal of misinfo. of citations."""
 
     # Remove trailing references at end of text
     if "References:\n[" in text:
@@ -480,7 +479,7 @@ def main():
             # flag_web_search = gr.Checkbox(label="Search web", info="Search information from Internet")
             gr.Markdown("More in DEV...")
 
-        #
+        # Action functions
         user_txt.submit(check_input_text, user_txt, None).success(
             add_text, [chatbot, user_txt], [chatbot, user_txt]
         ).then(bot_lisa, [chatbot, flag_web_search], [chatbot, doc_citation])
@@ -575,6 +574,7 @@ def main():
         with gr.Tab("Setting"):
             gr.Markdown("More in DEV...")
 
+        # Actions
         load_document.click(
             document_changes,
             inputs=[uploaded_doc],  # , repo_id],
@@ -606,8 +606,9 @@ def main():
         )
 
     ##########################
-    # Preview
+    # Preview tabs
     with gr.Tab("Preview feature 🔬"):
+        # VLM model
         with gr.Tab("Vision LM 🖼"):
             vision_tmp_link = (
                 "https://kadi-iam-lisa-vlm.hf.space/"  # vision model link
@@ -620,6 +621,7 @@ def main():
             )
             # gr.Markdown("placeholder")
 
+        # OAuth2 linkage to Kadi-demo
         with gr.Tab("KadiChat 💬"):
             kadichat_tmp_link = (
                 "https://kadi-iam-kadichat.hf.space/"  # vision model link
@@ -631,9 +633,12 @@ def main():
                 )
             )
 
+        # Knowledge graph-enhanced RAG
        with gr.Tab("RAG enhanced with Knowledge Graph (dev) 🔎"):
            kg_tmp_link = "https://kadi-iam-kadikgraph.static.hf.space/index.html"
-            gr.Markdown(
+            gr.Markdown(
+                "[If rendering fails, look at the graph here](https://kadi-iam-kadikgraph.static.hf.space)"
+            )
            with gr.Blocks(css="""footer {visibility: hidden};""") as preview_tab:
                gr.HTML(
                    """<iframe
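app.py ends up with two chains: `lisa_qa_conversation` (RAG over the local corpus with citations) and `web_qa_chain` (Tavily web search). A hedged sketch of how they would be invoked is below; the question texts are illustrative and the output keys follow LangChain's usual conventions for `ConversationalRetrievalChain` and `RetrievalQAWithSourcesChain`, which may differ slightly by version.

```python
# Query the local-corpus chain (conversation history is passed but not yet used)
result = lisa_qa_conversation(
    {"question": "Which solid electrolytes are discussed in the corpus?", "chat_history": []}
)
print(result["answer"])                       # answer text with citation markers
for doc in result.get("source_documents", []):
    print(doc.metadata.get("source"))         # which paper each citation points to

# Query the web-search chain instead
web_result = web_qa_chain({"question": "Recent progress on solid-state batteries?"})
print(web_result["answer"], web_result["sources"])
```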
documents.py
CHANGED
@@ -1,25 +1,30 @@
+"""
+Parse documents, currently pdf and xml are supported.
+"""
+
 import os
-import shutil
 
 from langchain.document_loaders import (
     PyMuPDFLoader,
 )
 from langchain.docstore.document import Document
-
-from langchain.vectorstores import Chroma
-
 from langchain.text_splitter import (
-    RecursiveCharacterTextSplitter,
+    # RecursiveCharacterTextSplitter,
     SpacyTextSplitter,
 )
 
+
 def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
     """Load and parse pdf file(s)."""
-
-    if pdf_path.endswith(
+
+    if pdf_path.endswith(".pdf"):  # single file
         pdf_docs = [pdf_path]
     else:  # a directory
-        pdf_docs = [
+        pdf_docs = [
+            os.path.join(pdf_path, f)
+            for f in os.listdir(pdf_path)
+            if f.endswith(".pdf")
+        ]
 
     if load_kwargs is None:
         load_kwargs = {}
@@ -31,180 +36,96 @@ def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
         loader = loader_module(pdf, **load_kwargs)
         doc = loader.load()
         docs.extend(doc)
-
+
     return docs
 
+
 def load_xml_as_docs(xml_path, loader_module=None, load_kwargs=None):
     """Load and parse xml file(s)."""
-
+
     from bs4 import BeautifulSoup
     from unstructured.cleaners.core import group_broken_paragraphs
-
-    if xml_path.endswith(
+
+    if xml_path.endswith(".xml"):  # single file
         xml_docs = [xml_path]
     else:  # a directory
-        xml_docs = [
-
+        xml_docs = [
+            os.path.join(xml_path, f)
+            for f in os.listdir(xml_path)
+            if f.endswith(".xml")
+        ]
+
     if load_kwargs is None:
         load_kwargs = {}
 
     docs = []
     for xml_file in xml_docs:
-        # print("now reading file...")
         with open(xml_file) as fp:
-            soup = BeautifulSoup(
+            soup = BeautifulSoup(
+                fp, features="xml"
+            )  # txt is simply a string with your XML file
             pageText = soup.findAll(string=True)
-            parsed_text =
-            #
+            parsed_text = "\n".join(pageText)  # or " ".join, seems similar
+            # Clean text
            parsed_text_grouped = group_broken_paragraphs(parsed_text)
-
+
            # get metadata
            try:
                from lxml import etree as ET
+
                tree = ET.parse(xml_file)
 
                # Define namespace
                ns = {"tei": "http://www.tei-c.org/ns/1.0"}
                # Read Author personal names as an example
-                pers_name_elements = tree.xpath(
+                pers_name_elements = tree.xpath(
+                    "tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:author/tei:persName",
+                    namespaces=ns,
+                )
                first_per = pers_name_elements[0].text
                author_info = first_per + " et al"
 
-                title_elements = tree.xpath(
+                title_elements = tree.xpath(
+                    "tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:title", namespaces=ns
+                )
                title = title_elements[0].text
 
                # Combine source info
                source_info = "_".join([author_info, title])
            except:
                source_info = "unknown"
-
-            # maybe even better TODO: discuss with
+
+            # maybe even better parsing method. TODO: discuss with TUD
            # first_author = soup.find("author")
            # publication_year = soup.find("date", attrs={'type': 'published'})
            # title = soup.find("title")
            # source_info = [first_author, publication_year, title]
            # source_info_str = "_".join([info.text.strip() if info is not None else "unknown" for info in source_info])
-
-            doc =
+
+            doc = [
+                Document(
+                    page_content=parsed_text_grouped, metadata={"source": source_info}
+                )
+            ]
 
            docs.extend(doc)
-
+
     return docs
 
 
 def get_doc_chunks(docs, splitter=None):
     """Split docs into chunks."""
-
+
     if splitter is None:
-        # splitter = RecursiveCharacterTextSplitter(
+        # splitter = RecursiveCharacterTextSplitter(  # original default
         # #    separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256
         #     separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
         # )
+        # Spacy seems better
         splitter = SpacyTextSplitter.from_tiktoken_encoder(
             chunk_size=512,
             chunk_overlap=128,
         )
     chunks = splitter.split_documents(docs)
-
-    return chunks
-
 
-
-    # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-    # vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
-    if overwrite:
-        shutil.rmtree(persist_directory)  # Empty and reset db
-    db = Chroma.from_documents(documents=document_chunks, embedding=embeddings, persist_directory=persist_directory)
-    # db.delete_collection()
-    db.persist()
-    # db = None
-    # db = Chroma(persist_directory="db", embedding_function = embeddings, client_settings=CHROMA_SETTINGS)
-    # vectorstore = FAISS.from_documents(documents=document_chunks, embedding=embeddings)
-    return db
-
-
-class VectorstoreManager:
-
-    def __init__(self):
-        self.vectorstore_class = Chroma
-
-    def create_db(self, embeddings):
-        db = self.vectorstore_class(embedding_function=embeddings)
-
-        self.db = db
-        return db
-
-
-    def load_db(self, persist_directory, embeddings):
-        """Load local vectorestore."""
-
-        db = self.vectorstore_class(persist_directory=persist_directory, embedding_function=embeddings)
-        self.db = db
-
-        return db
-
-    def create_db_from_documents(self, document_chunks, embeddings, persist_directory="db", overwrite=False):
-        """Create db from documents."""
-
-        if overwrite:
-            shutil.rmtree(persist_directory)  # Empty and reset db
-        db = self.vectorstore_class.from_documents(documents=document_chunks, embedding=embeddings, persist_directory=persist_directory)
-        self.db = db
-
-        return db
-
-    def persist_db(self, persist_directory="db"):
-        """Persist db."""
-
-        assert self.db
-        self.db.persist()  # Chroma
-
-class RetrieverManager:
-    # some other retrievers Using Advanced Retrievers in LangChain https://www.comet.com/site/blog/using-advanced-retrievers-in-langchain/
-
-    def __init__(self, vectorstore, k=10):
-
-        self.vectorstore = vectorstore
-        self.retriever = vectorstore.as_retriever(search_kwargs={"k": k})  # search_kwargs={"k": 8}),
-
-    def get_rerank_retriver(self, base_retriever=None):
-
-        if base_retriever is None:
-            base_retriever = self.retriever
-        # with rerank
-        from rerank import BgeRerank
-        from langchain.retrievers import ContextualCompressionRetriever
-
-        compressor = BgeRerank()
-        compression_retriever = ContextualCompressionRetriever(
-            base_compressor=compressor, base_retriever=base_retriever
-        )
-
-        return compression_retriever
-
-    def get_parent_doc_retriver(self, documents, store_file="./store_location"):
-        # TODO need better design
-        # Ref: explain how it works: https://clusteredbytes.pages.dev/posts/2023/langchain-parent-document-retriever/
-        from langchain.storage.file_system import LocalFileStore
-        from langchain.storage import InMemoryStore
-        from langchain.storage._lc_store import create_kv_docstore
-        from langchain.retrievers import ParentDocumentRetriever
-        # Ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
-        # fs = LocalFileStore("./store_location")
-        # store = create_kv_docstore(fs)
-        docstore = InMemoryStore()
-
-        # TODO: how to better set this?
-        parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
-        child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128)
-
-        retriever = ParentDocumentRetriever(
-            vectorstore=self.vectorstore,
-            docstore=docstore,
-            child_splitter=child_splitter,
-            parent_splitter=parent_splitter,
-            search_kwargs={"k": 10}  # Better settings?
-        )
-        retriever.add_documents(documents)  # , ids=None)
-
-        return retriever
+
+    return chunks
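A minimal usage sketch of the loaders kept in documents.py (paths are illustrative; `SpacyTextSplitter` additionally needs a spaCy pipeline such as `en_core_web_sm` installed):

```python
from documents import load_pdf_as_docs, load_xml_as_docs, get_doc_chunks

docs = load_pdf_as_docs("./data/documents")    # every *.pdf in the folder
docs += load_xml_as_docs("./data/documents")   # plus every *.xml (TEI-style metadata)

chunks = get_doc_chunks(docs)                  # ~512-token chunks with 128 overlap
print(f"{len(docs)} documents -> {len(chunks)} chunks")
```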
embeddings.py
CHANGED
@@ -1,39 +1,50 @@
+"""
+Load embedding models from huggingface.
+"""
 
 import torch
 from langchain.embeddings import HuggingFaceEmbeddings
 
 
 def get_hf_embeddings(model_name=None):
-    """Get huggingface embedding."""
-
+    """Get huggingface embedding by name."""
+
     if model_name is None:
-        # Some candidates
+        # Some candidates
         # "BAAI/bge-m3" (good, though large and slow)
-        # "BAAI/bge-base-en-v1.5" ->
-        # "sentence-transformers/all-mpnet-base-v2"
-        #
-
-
+        # "BAAI/bge-base-en-v1.5" -> also good
+        # "sentence-transformers/all-mpnet-base-v2"
+        # "maidalun1020/bce-embedding-base_v1"
+        # "intfloat/multilingual-e5-large"
+        # Ref: https://huggingface.co/spaces/mteb/leaderboard
+        # https://huggingface.co/maidalun1020/bce-embedding-base_v1
+        model_name = "BAAI/bge-large-en-v1.5"
+
     embeddings = HuggingFaceEmbeddings(model_name=model_name)
-
+
     return embeddings
 
 
-def get_jinaai_embeddings(
+def get_jinaai_embeddings(
+    model_name="jinaai/jina-embeddings-v2-base-en", device="auto"
+):
     """Get jinaai embedding."""
-
+
     # device: cpu or cuda
     if device == "auto":
         device = "cuda" if torch.cuda.is_available() else "cpu"
     # For jinaai. Ref: https://github.com/langchain-ai/langchain/issues/6080
     from transformers import AutoModel
-
+
+    model = AutoModel.from_pretrained(
+        model_name, trust_remote_code=True
+    )  # -> will yield error, need bug fixing
 
     model_name = model_name
-    model_kwargs = {
+    model_kwargs = {"device": device, "trust_remote_code": True}
     embeddings = HuggingFaceEmbeddings(
         model_name=model_name,
         model_kwargs=model_kwargs,
     )
-
-    return embeddings
+
+    return embeddings
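A short sketch of combining these helpers with the FAISS store from vectorestores.py; `document_chunks` is assumed to come from `documents.get_doc_chunks`, and the query string is illustrative:

```python
from embeddings import get_hf_embeddings
from vectorestores import get_faiss_vectorestore

embeddings = get_hf_embeddings()                  # defaults to "BAAI/bge-large-en-v1.5"
vectorstore = get_faiss_vectorestore(embeddings)  # FAISS index seeded with one init text

vectorstore.add_documents(document_chunks)        # embed and index the chunks
hits = vectorstore.similarity_search("ionic conductivity of garnet electrolytes", k=5)
```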
llms.py
CHANGED
@@ -1,22 +1,22 @@
-
-from
+"""
+Load LLMs from huggingface, Groq, etc.
+"""
+
 from transformers import (
-    AutoModelForCausalLM,
+    # AutoModelForCausalLM,
     AutoTokenizer,
     pipeline,
 )
-from
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.llms import HuggingFacePipeline
 from langchain_groq import ChatGroq
-
-
-from langchain.chat_models import ChatOpenAI
 from langchain.llms import HuggingFaceTextGenInference
 
+# from langchain.chat_models import ChatOpenAI  # oai model
+
 
 def get_llm_hf_online(inference_api_url=""):
     """Get LLM using huggingface inference."""
-
+
     if not inference_api_url:  # default api url
         inference_api_url = (
             "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
@@ -35,20 +35,16 @@ def get_llm_hf_online(inference_api_url=""):
 
 
 def get_llm_hf_local(model_path):
-    """Get local LLM."""
-
-    model = LlamaForCausalLM.from_pretrained(
-        model_path, device_map="auto"
-    )
+    """Get local LLM from huggingface."""
+
+    model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
 
-    # print('making a pipeline...')
-    # max_length has typically been deprecated for max_new_tokens
     pipe = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=
+        max_new_tokens=2048,  # better setting?
         model_kwargs={"temperature": 0.1},  # better setting?
     )
     llm = HuggingFacePipeline(pipeline=pipe)
@@ -56,22 +52,8 @@ def get_llm_hf_local(model_path):
     return llm
 
 
-
-
-    """Get openai-like LLM."""
-
-    llm = ChatOpenAI(
-        model=model_name,
-        openai_api_key="EMPTY",
-        openai_api_base=inference_server_url,
-        max_tokens=1024,  # better setting?
-        temperature=0,
-    )
-
-    return llm
-
-
-def get_groq_chat(model_name="llama-3.1-70b-versatile"):
+def get_groq_chat(model_name="llama-3.1-70b-versatile"):
+    """Get LLM from Groq."""
 
     llm = ChatGroq(temperature=0, model_name=model_name)
-    return llm
+    return llm
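Usage sketch for the remaining helpers (the prompt is illustrative; `ChatGroq` expects a `GROQ_API_KEY` in the environment, which this repo does not show being set):

```python
from llms import get_groq_chat, get_llm_hf_online

llm = get_groq_chat()        # Groq-hosted "llama-3.1-70b-versatile"
# llm = get_llm_hf_online()  # or the HF Inference API endpoint (zephyr-7b-beta)

print(llm.invoke("In one sentence, what is a solid-state battery?"))
```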
preprocess_documents.py
CHANGED
@@ -1,15 +1,17 @@
 """
-Load and parse files (pdf) in the data/documents and save cached pkl files.
+Load and parse files (pdf) in the "data/documents" and save cached pkl files.
+It will load and parse files and save 4 caches:
+1. "docs.pkl" for loaded text documents
+2. "docs_chunks.pkl" for chunked text
+3. "docstore.pkl" for small-to-big retriever
+4. faiss_index for FAISS vector store
 """
 
 import os
 import pickle
 
 from dotenv import load_dotenv
-
-
 from huggingface_hub import login
-
 from documents import load_pdf_as_docs, get_doc_chunks
 from embeddings import get_jinaai_embeddings
 
@@ -23,11 +25,14 @@ login(HUGGINGFACEHUB_API_TOKEN)
 
 
 def save_to_pickle(obj, filename):
+    """Save obj to disk using pickle."""
+
     with open(filename, "wb") as file:
         pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)
 
 
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# Set database path, should be same as defined in "app.py"
 database_root = "./data/db"
 document_path = "./data/documents"
 
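A hedged sketch of reading these caches back at startup; the file names come from the docstring above, but how app.py actually consumes them is not shown in this diff, and newer LangChain versions may require an extra `allow_dangerous_deserialization=True` flag for `FAISS.load_local`:

```python
import os
import pickle

from langchain.vectorstores import FAISS
from embeddings import get_jinaai_embeddings

database_root = "./data/db"

with open(os.path.join(database_root, "docs.pkl"), "rb") as f:
    docs = pickle.load(f)              # only unpickle caches you created yourself
with open(os.path.join(database_root, "docs_chunks.pkl"), "rb") as f:
    document_chunks = pickle.load(f)

embeddings = get_jinaai_embeddings(device="auto")
vectorstore = FAISS.load_local(os.path.join(database_root, "faiss_index"), embeddings)
```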
ragchain.py
CHANGED
@@ -1,3 +1,7 @@
+"""
+Main RAG chain based on langchain.
+"""
+
 from langchain.chains import LLMChain
 
 from langchain.prompts import (
@@ -11,17 +15,17 @@ from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.conversation.memory import (
     ConversationBufferWindowMemory,
 )
-
-
 from langchain.chains import StuffDocumentsChain
 
 
 def get_cite_combine_docs_chain(llm):
+    """Get doc chain which adds metadata to text chunks."""
 
     # Ref: https://github.com/langchain-ai/langchain/issues/7239
     # Function to format each document with an index, source, and content.
     def format_document(doc, index, prompt):
         """Format a document into a string based on a prompt template."""
+
         # Create a dictionary with document content and metadata.
         base_info = {
             "page_content": doc.page_content,
@@ -40,7 +44,11 @@ def get_cite_combine_docs_chain(llm):
 
     # Custom chain class to handle document combination with source indices.
     class StuffDocumentsWithIndexChain(StuffDocumentsChain):
+        """Custom chain class to handle document combination with source indices."""
+
         def _get_inputs(self, docs, **kwargs):
+            """Overwrite _get_inputs to add metadata for text chunks."""
+
             # Format each document and combine them.
             doc_strings = [
                 format_document(doc, i, self.document_prompt)
@@ -58,6 +66,7 @@
             )
             return inputs
 
+    # Main prompt for RAG chain with citation
     # Ref: https://huggingface.co/spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
     # Define a chat prompt with instructions for citing documents.
     combine_doc_prompt = PromptTemplate(
@@ -103,6 +112,8 @@
 
 
 class RAGChain:
+    """Main RAG chain."""
+
     def __init__(
         self, memory_key="chat_history", output_key="answer", return_messages=True
     ):
@@ -111,14 +122,17 @@
         self.return_messages = return_messages
 
     def create(self, retriever, llm, add_citation=False):
-
+        """Create a rag chain instance."""
+
+        # Memory is kept for later support of conversational chat
+        memory = ConversationBufferWindowMemory(  # Or ConversationBufferMemory
             k=2,
             memory_key=self.memory_key,
             return_messages=self.return_messages,
             output_key=self.output_key,
         )
 
-        # https://github.com/langchain-ai/langchain/issues/4608
+        # Ref: https://github.com/langchain-ai/langchain/issues/4608
         conversation_chain = ConversationalRetrievalChain.from_llm(
             llm=llm,
             retriever=retriever,
@@ -127,7 +141,6 @@
             rephrase_question=False,  # disable rephrase, for test purpose
             get_chat_history=lambda x: x,
             # return_generated_question=True,  # for debug
-            # verbose=True,
             # combine_docs_chain_kwargs={"prompt": PROMPT},  # additional prompt control
             # condense_question_prompt=CONDENSE_QUESTION_PROMPT,  # additional prompt control
         )
requirements.txt
CHANGED
@@ -5,7 +5,7 @@ langchain-community==0.2.4
 text-generation
 pypdf
 pymupdf
-gradio
+gradio==4.44.1
 faiss-cpu
 chromadb
 rank-bm25
rerank.py
CHANGED
@@ -1,5 +1,6 @@
 """
-
+Rerank with cross encoder.
+Ref:
 https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
 https://github.com/langchain-ai/langchain/issues/13076
 """
@@ -7,7 +8,7 @@ https://github.com/langchain-ai/langchain/issues/13076
 from __future__ import annotations
 from typing import Optional, Sequence
 from langchain.schema import Document
-from langchain.pydantic_v1 import Extra
+from langchain.pydantic_v1 import Extra
 
 from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
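rerank.py wraps a cross-encoder as a LangChain document compressor; the notebook above shows the same `BgeRerank` class built on `BAAI/bge-reranker-large`. As a standalone sketch of what the cross-encoder does under the hood (query and passages are illustrative): each (query, passage) pair is scored jointly and passages are kept in score order.

```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("BAAI/bge-reranker-large")

query = "What limits the ionic conductivity of solid electrolytes?"
passages = [
    "Grain boundaries often dominate the total resistance of oxide electrolytes.",
    "The cathode slurry was mixed for two hours before coating.",
]

scores = reranker.predict([(query, p) for p in passages])  # one relevance score per pair
ranked = [p for p, s in sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)]
```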
retrievers.py
CHANGED
@@ -1,7 +1,10 @@
+"""
+Retrievers for text chunks.
+"""
+
 import os
 
 from langchain.text_splitter import (
-    CharacterTextSplitter,
     RecursiveCharacterTextSplitter,
     SpacyTextSplitter,
 )
@@ -9,6 +12,7 @@ from langchain.text_splitter import (
 from rerank import BgeRerank
 from langchain.retrievers import ContextualCompressionRetriever
 
+
 def get_parent_doc_retriever(
     documents,
     vectorstore,
@@ -40,12 +44,14 @@ def get_parent_doc_retriever(
         from langchain_rag.storage import SQLStore
 
         # Instantiate the SQLStore with the root path
-        docstore = SQLStore(
+        docstore = SQLStore(
+            namespace="test", db_url="sqlite:///parent_retrieval_db.db"
+        )  # TODO: WIP
     else:
         docstore = docstore  # TODO: add check
-        # raise # TODO implement
+        # raise # TODO implement other docstores
 
-    # TODO: how to better set
+    # TODO: how to better set these values?
     # parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
     # child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=64)
     parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
@@ -62,11 +68,11 @@ def get_parent_doc_retriever(
         docstore=docstore,
         child_splitter=child_splitter,
         parent_splitter=parent_splitter,
-        search_kwargs={"k": k},
+        search_kwargs={"k": k},
     )
 
     if add_documents:
-        retriever.add_documents(documents)
+        retriever.add_documents(documents)
 
     if save_vectorstore:
         vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))
@@ -80,7 +86,6 @@ def get_parent_doc_retriever(
 
         save_to_pickle(docstore, os.path.join(save_path_root, "docstore.pkl"))
 
-
     return retriever
 
 
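Hedged usage sketch for `get_parent_doc_retriever`: only the first two parameters are visible in this diff, so defaults for the remaining keyword arguments seen in the body (`k`, `add_documents`, the save paths) are assumed; `docs` and `vectorstore` would come from documents.py and vectorestores.py respectively.

```python
from retrievers import get_parent_doc_retriever

# docs: full documents from documents.load_pdf_as_docs
# vectorstore: FAISS store from vectorestores.get_faiss_vectorestore
retriever = get_parent_doc_retriever(docs, vectorstore)

results = retriever.get_relevant_documents("anode materials for solid-state cells")
```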
vectorestores.py
CHANGED
@@ -1,8 +1,13 @@
-
+"""
+Vector stores.
+"""
+
+from langchain.vectorstores import FAISS
+
 
 def get_faiss_vectorestore(embeddings):
     # Add extra text to init
     texts = ["LISA - Lithium Ion Solid-state Assistant"]
     vectorstore = FAISS.from_texts(texts, embeddings)
-
-    return vectorstore
+
+    return vectorstore