GitHub Actions Bot committed
Commit c379a6e
Changes from ggruber193/polars-docu-chat-rag
Browse files
- app.py +84 -0
- requirements.txt +10 -0
- src/config.py +19 -0
- src/data_processing/process_markdown.py +53 -0
- src/database/qdrant_store.py +83 -0
- src/embeddings.py +37 -0
- src/rag_lanchain.py +159 -0
- src/retrieval/retrieval.py +6 -0
- src/testing.py +19 -0
- src/utils.py +12 -0
app.py
ADDED
@@ -0,0 +1,84 @@
import gradio as gr
from uuid import uuid4

from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore

from src.rag_lanchain import graph_builder


memory = MemorySaver()
in_memory_store = InMemoryStore()
graph = graph_builder.compile(checkpointer=memory, store=in_memory_store)


def respond(msg, config):
    role_dict = {"ai": "assistant", "human": "user"}
    if len(msg) == 0:
        gr.Warning("Chat messages cannot be empty")
        # Re-render the existing conversation from the latest checkpoint.
        history = []
        for hist in graph.get_state_history(config):
            history = [{"role": role_dict.get(i.type, i.type), "content": i.content} for i in
                       hist.values["messages"]]
            break
        return "", history
    events = graph.stream(
        {"messages": [{"role": "user", "content": msg}]},
        config,
        stream_mode="values",
    )
    events = list(events)
    conversation = events[-1]["messages"]
    conversation = [{"role": role_dict.get(i.type, i.type), "content": i.content} for i in conversation]
    return "", conversation


def init_chat_state():
    # One LangGraph thread per browser session.
    return {"configurable": {"thread_id": str(uuid4()).replace('-', '_')}}


css = """
.centered-container {
    max-width: 1000px;
    margin: 0 auto;
}
"""

THEME = gr.themes.Ocean()

demo = gr.Blocks(theme=THEME, fill_width=False, fill_height=True, css=css)

with demo:
    config_state = gr.State(init_chat_state)
    with gr.Column(elem_classes="centered-container"):
        gr.Markdown("""
        # 💬 Polars Python Chatbot
        ### Ask anything about the [Polars](https://pola-rs.github.io/polars/) Python package!
        ### This chatbot uses a database of embeddings generated from the official documentation to help you find accurate and relevant answers about using Polars for data manipulation in Python.
        """)

        chatbot = gr.Chatbot(
            label=None,
            type="messages",
            show_label=False,
            height=400,
        )

        with gr.Row(equal_height=True):
            msg = gr.Textbox(
                placeholder="Type your message here...",
                show_label=False,
                lines=3,
                max_lines=3,
                scale=5,
            )
            send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Row():
            clear = gr.ClearButton([msg, chatbot], value="Clear Chat", variant="secondary")

    send_btn.click(respond, [msg, config_state], [msg, chatbot])
    msg.submit(respond, [msg, config_state], [msg, chatbot])

if __name__ == '__main__':
    demo.launch()
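One caveat in the wiring above: the Clear button only empties the textbox and the chat display, while the LangGraph thread referenced by `config_state` keeps its checkpointed history. A minimal sketch of how one might also rotate the thread on clear (this extra wiring is an assumption, not part of the commit):

# Hypothetical addition: give the session a fresh thread_id when the chat is cleared,
# so a cleared chat does not silently remember the previous conversation.
clear.click(init_chat_state, None, config_state)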
requirements.txt
ADDED
@@ -0,0 +1,10 @@
beautifulsoup4~=4.13.4
markdown~=3.8
langchain~=0.3.23
transformers~=4.51.3
torch~=2.6.0
tqdm~=4.67.1
qdrant_client
langgraph~=0.3.31
gradio~=5.25.2
langchain_google_genai~=2.1.3
src/config.py
ADDED
@@ -0,0 +1,19 @@
import os


EMBEDDING_MODEL = "thenlper/gte-small"

QDRANT_COLLECTION_NAME = "polars-documentation"
QDRANT_URL = os.environ.get("QDRANT_URL", "")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
CHAT_API_KEY = os.environ.get("CHAT_API_KEY", "")


def get_qdrant_config():
    # Imported lazily so importing config does not require qdrant_client.
    from qdrant_client import models
    QDRANT_COLLECTION_CONFIG = {
        "collection_name": QDRANT_COLLECTION_NAME,
        # gte-small produces 384-dimensional embeddings; cosine distance for similarity search.
        "vectors_config": models.VectorParams(size=384, distance=models.Distance.COSINE),  # on_disk=True
        # "hnsw_config": models.HnswConfigDiff(on_disk=True)
    }
    return QDRANT_COLLECTION_CONFIG
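For context, `get_qdrant_config()` is presumably meant to be unpacked into the Qdrant client when the collection is first created. A minimal sketch, assuming a reachable Qdrant instance at `QDRANT_URL` (illustrative, not part of the commit):

from qdrant_client import QdrantClient
from src.config import get_qdrant_config, QDRANT_URL, QDRANT_API_KEY

client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
config = get_qdrant_config()
if not client.collection_exists(config["collection_name"]):
    # Creates the collection with the VectorParams(size=384, COSINE) defined above.
    client.create_collection(**config)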
src/data_processing/process_markdown.py
ADDED
@@ -0,0 +1,53 @@
from typing import Any

from bs4 import BeautifulSoup
from langchain_core.documents import Document
from markdown import markdown
from pathlib import Path
from langchain.text_splitter import MarkdownTextSplitter, MarkdownHeaderTextSplitter, TextSplitter

from src.utils import batched


def read_markdown_file(path: str | Path) -> tuple[str, str]:
    path = Path(path)
    with open(path, 'r', encoding="utf8") as f_r:
        text = f_r.read()

    # text = markdown(text)
    # text = ''.join(BeautifulSoup(text).findAll(text=True))
    return text, str(path)


def split_markdown(md: str | list[str],
                   metadata: dict[str, Any] | list[dict[str, Any]] = None,
                   chunk_size=512,
                   overlap=64,
                   splitter: TextSplitter = None) -> list[Document]:
    # A single markdown string takes a single metadata dict; lists must be parallel.
    if isinstance(md, str):
        md = [md]
        if isinstance(metadata, list):
            raise ValueError("metadata should be a single dict")
        metadata = [metadata]
    if splitter is None:
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]
        # First split on headers (keeping them in the text), then into fixed-size chunks.
        md = [MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False).split_text(i) for i in md]
        metadata = [{**metadata[i], **text.metadata} for i, text_split in enumerate(md) for text in text_split]
        md = [j.page_content for i in md for j in i]
        splitter = MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)

    docs = splitter.create_documents(md, metadata)
    return docs


def process_markdown_files(paths: list[str | Path], batch_size=1, chunk_size=512, overlap=64):
    for files in batched(paths, batch_size):
        mds_w_paths = [read_markdown_file(i) for i in files]
        metadata = [{"path": md_path} for _, md_path in mds_w_paths]
        md = [md for md, _ in mds_w_paths]
        docs = split_markdown(md, metadata, chunk_size=chunk_size, overlap=overlap)
        yield [i.page_content for i in docs], [i.metadata for i in docs]
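A minimal end-to-end indexing sketch, showing how these generators are presumably meant to feed the embedder and the Qdrant store (the `data/polars-docu` path is an assumption taken from src/testing.py; the snippet itself is not part of the commit):

from pathlib import Path
from qdrant_client import QdrantClient

from src.config import get_qdrant_config, QDRANT_URL, QDRANT_API_KEY, QDRANT_COLLECTION_NAME
from src.data_processing.process_markdown import process_markdown_files
from src.database.qdrant_store import QdrantStore
from src.embeddings import TextEmbedder

paths = sorted(Path("data/polars-docu").rglob("*.md"))  # assumed docs checkout location
store = QdrantStore(QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY), get_qdrant_config())
embedder = TextEmbedder()

for texts, metadata in process_markdown_files(paths, batch_size=8):
    vectors = embedder.embed_text(texts)
    # get_topk_points_single() reads payload["text"], so each chunk is stored under that key.
    payloads = [{"text": t, **m} for t, m in zip(texts, metadata)]
    store.upsert_points(vectors, payloads, QDRANT_COLLECTION_NAME)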
src/database/qdrant_store.py
ADDED
@@ -0,0 +1,83 @@
from typing import Any

from qdrant_client import QdrantClient, models
from uuid import uuid4

from src.config import QDRANT_COLLECTION_NAME, QDRANT_URL, QDRANT_API_KEY, EMBEDDING_MODEL
from src.embeddings import TextEmbedder


class QdrantStore:
    def __init__(self, client: QdrantClient, collection_config=None):
        self.client = client
        self.collection_names = set(i.name for i in client.get_collections().collections)

        if collection_config is not None:
            self.create_collection(collection_config)

    def create_collection(self, collection_config: dict):
        collection_name = collection_config["collection_name"]
        if not self.client.collection_exists(collection_name):
            self.client.create_collection(**collection_config)
        self.collection_names.add(collection_name)

    def _check_collection_name(self, collection_name):
        if collection_name not in self.collection_names:
            raise ValueError(f"Collection: {collection_name} does not exist.")

    def upsert_points(self,
                      vectors: Any | list[Any],
                      payloads: dict | list[dict],
                      collection_name: str):
        self._check_collection_name(collection_name)

        ids = [str(uuid4()) for _ in payloads]

        self.client.upsert(
            collection_name=collection_name,
            points=models.Batch(
                ids=ids,
                payloads=payloads,
                vectors=vectors
            )
        )

    def delete_points(self,
                      filters: dict[str, list[models.FieldCondition]],
                      collection_name: str):
        self._check_collection_name(collection_name)

        self.client.delete(
            collection_name=collection_name,
            points_selector=models.Filter(**filters)
        )

    def delete_points_by_match(self,
                               key_value: tuple[str, list[str] | str],
                               collection_name: str):
        key, values = key_value
        if isinstance(values, str):
            values = [values]
        points_filter = {"must": [models.FieldCondition(key=key, match=models.MatchAny(any=values))]}
        self.delete_points(points_filter, collection_name)

    def get_topk_points_single(self,
                               query: Any | str,
                               collection_name: str,
                               k=5):
        responses = self.client.query_points(collection_name=collection_name,
                                             query=query,
                                             limit=k)

        return [i.payload["text"] for i in responses.points]


if __name__ == '__main__':
    client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
    qdrant_store = QdrantStore(client)
    embedding_model = TextEmbedder(modelname=EMBEDDING_MODEL)
    query = "How to filter a dataframe"
    query_emb = embedding_model.embed_text(query)
    responses = qdrant_store.get_topk_points_single(query_emb[0], collection_name=QDRANT_COLLECTION_NAME)
    print(responses)
src/embeddings.py
ADDED
@@ -0,0 +1,37 @@
import torch
from transformers import AutoModel, AutoTokenizer
from torch import Tensor
import torch.nn.functional as F

from src.config import EMBEDDING_MODEL
from src.utils import batched


class TextEmbedder:
    def __init__(self, modelname=EMBEDDING_MODEL, max_length=512):
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        self.model = AutoModel.from_pretrained(modelname)
        self.max_length = max_length

    @staticmethod
    def average_pool(last_hidden_states: Tensor,
                     attention_mask: Tensor) -> Tensor:
        # Mean over token embeddings, ignoring padding positions.
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def embed_text(self, text: str | list[str], batch_size=128):
        if isinstance(text, str):
            text = [text]

        outputs = []

        for batch in batched(text, n=batch_size):
            batch_dict = self.tokenizer(list(batch), max_length=self.max_length, padding=True, truncation=True, return_tensors='pt')
            with torch.no_grad():  # inference only; skip building the autograd graph
                output = self.model(**batch_dict)
            embeddings = self.average_pool(output.last_hidden_state, batch_dict['attention_mask'])

            # embeddings = F.normalize(embeddings, p=2, dim=1)
            # scores = (embeddings[:1] @ embeddings[1:].T) * 100

            embeddings = embeddings.tolist()
            outputs += embeddings
        return outputs
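A quick smoke test of the embedder (downloads the gte-small weights on first use; the output dimension should match the `size=384` in src/config.py):

from src.embeddings import TextEmbedder

embedder = TextEmbedder()
vectors = embedder.embed_text(["How do I filter a DataFrame?", "lazy evaluation in polars"])
print(len(vectors), len(vectors[0]))  # 2 vectors, 384 dimensions each for thenlper/gte-small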
src/rag_lanchain.py
ADDED
@@ -0,0 +1,159 @@
from typing import TypedDict, Any, List
import os
from functools import partial

from langgraph.constants import END
from qdrant_client import QdrantClient

from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.rate_limiters import InMemoryRateLimiter, BaseRateLimiter
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage

from langgraph.graph import START, StateGraph, Graph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver

from src.database.qdrant_store import QdrantStore
from src.embeddings import TextEmbedder
from src.config import EMBEDDING_MODEL, QDRANT_COLLECTION_NAME, CHAT_API_KEY, QDRANT_URL, QDRANT_API_KEY

RAG_PROMPT_STR = """
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
\n
{context}
"""
RAG_PROMPT = PromptTemplate.from_template(RAG_PROMPT_STR)

embedding_model = TextEmbedder(modelname=EMBEDDING_MODEL)

client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
qdrant_store = QdrantStore(client)

rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.25,  # one request every 4 seconds
    check_every_n_seconds=0.1,  # wake up every 100 ms to check whether a request is allowed
    max_bucket_size=15,  # controls the maximum burst size
)

llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-001",
    google_api_key=CHAT_API_KEY,
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

# init_chat_model("google_vertexai:gemini-2.0-flash", rate_limiter=rate_limiter, )


class State(TypedDict):
    question: str
    context: List[str]
    answer: str


def query_or_respond(state: MessagesState):
    # Let the model decide whether to call the retrieve tool or answer directly.
    llm_with_tools = llm.bind_tools([retrieve])
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}


@tool
def retrieve(query: str):
    """Retrieve information related to a query, specific to the python polars package"""
    if qdrant_store is not None:
        query_emb = embedding_model.embed_text(query)
        retrieved_docs = qdrant_store.get_topk_points_single(query_emb[0], QDRANT_COLLECTION_NAME, k=5)
    else:
        retrieved_docs = []
    return '\n\n'.join(retrieved_docs)


def generate(state: MessagesState):
    # Collect the tool messages produced by the most recent retrieval step.
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]
    docs_content = "\n\n".join(message.content for message in tool_messages)
    system_message_content = RAG_PROMPT_STR.format(context=docs_content)
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages

    response = llm.invoke(prompt)
    return {"messages": [response]}


tools = ToolNode([retrieve])

graph_builder = StateGraph(MessagesState)

graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

if __name__ == '__main__':
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)
    config = {"configurable": {"thread_id": "def234"}}

    user_input = "Hi there! My name is Will."

    # The config is the **second positional argument** to stream() or invoke()!
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    for event in events:
        event["messages"][-1].pretty_print()

    print(graph.get_state(config))
    print(memory.get(config))

    # Same thread id: the checkpointer should recall the name.
    user_input = "Remember my name?"
    config = {"configurable": {"thread_id": "def234"}}

    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    for event in events:
        event["messages"][-1].pretty_print()

    print(graph.get_state(config))
    print(memory.get(config))

    # Fresh thread id: no shared checkpoint, so the model should not know the name.
    user_input = "Remember my name?"
    config = {"configurable": {"thread_id": "ddef234"}}

    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    for event in events:
        event["messages"][-1].pretty_print()
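The `__main__` block above only exercises the checkpointer; a short sketch of a question that should route through the `retrieve` tool (assumes the Qdrant collection is already populated and `CHAT_API_KEY` is set; not part of the commit):

from langgraph.checkpoint.memory import MemorySaver
from src.rag_lanchain import graph_builder

graph = graph_builder.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "polars-demo"}}
result = graph.invoke(
    {"messages": [{"role": "user", "content": "How do I filter a polars DataFrame by a column value?"}]},
    config,
)
print(result["messages"][-1].content)  # answer grounded in the retrieved documentation chunks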
src/retrieval/retrieval.py
ADDED
@@ -0,0 +1,6 @@
from src.database.qdrant_store import QdrantStore
from src.embeddings import TextEmbedder


def embed_query(query: str | list[str]):
    pass
src/testing.py
ADDED
@@ -0,0 +1,19 @@
from bs4 import BeautifulSoup
from markdown import markdown
from langchain.text_splitter import MarkdownTextSplitter


# Raw string so the Windows backslashes are not treated as escape sequences.
path = r"D:\PycharmProjects\polargs-docu-chat-rag\data\polars-docu\concepts\data-types-and-structures.md"

with open(path, 'r', encoding="utf8") as f_r:
    test_md = f_r.read()

# Render the markdown to HTML, then strip the tags to get plain text.
html = markdown(test_md)
text = ''.join(BeautifulSoup(html, "html.parser").find_all(string=True))

print(text[:10])

splitter = MarkdownTextSplitter(chunk_size=512, chunk_overlap=64)

docs = splitter.create_documents([text])
print(docs)
src/utils.py
ADDED
@@ -0,0 +1,12 @@
from itertools import islice


def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch
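This mirrors the itertools.batched recipe (a built-in in newer Python versions), kept here as a backport so the Space does not depend on the interpreter version. For reference:

from src.utils import batched

print(list(batched("ABCDEFG", 3)))               # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
print(list(batched(range(6), 2, strict=True)))   # [(0, 1), (2, 3), (4, 5)]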