GitHub Actions Bot committed
Commit c379a6e (0 parents)

Changes from ggruber193/polars-docu-chat-rag
app.py ADDED
@@ -0,0 +1,84 @@
+ import gradio as gr
+ from uuid import uuid4
+
+ from langgraph.checkpoint.memory import MemorySaver
+ from langgraph.store.memory import InMemoryStore
+
+ from src.rag_lanchain import graph_builder
+
+
+ memory = MemorySaver()
+ in_memory_store = InMemoryStore()
+ graph = graph_builder.compile(checkpointer=memory, store=in_memory_store)
+
+
+ def respond(msg, config):
+     role_dict = {"ai": "assistant", "human": "user"}
+     if len(msg) == 0:
+         gr.Warning("Chat messages cannot be empty")
+         # On empty input, re-render the conversation stored for this thread instead of streaming.
+         history = []
+         for hist in graph.get_state_history(config):
+             history = [{"role": role_dict.get(i.type, i.type), "content": i.content}
+                        for i in hist.values["messages"]]
+             break
+         return "", history
+     events = graph.stream(
+         {"messages": [{"role": "user", "content": msg}]},
+         config,
+         stream_mode="values",
+     )
+     events = list(events)
+     # The last event holds the full message list, including the new assistant reply.
+     conversation = events[-1]["messages"]
+     conversation = [{"role": role_dict.get(i.type, i.type), "content": i.content} for i in conversation]
+     return "", conversation
+
+
+ def init_chat_state():
+     # One LangGraph thread per browser session, so conversations don't leak between users.
+     return {"configurable": {"thread_id": str(uuid4()).replace('-', '_')}}
+
+
+ css = """
+ .centered-container {
+     max-width: 1000px;
+     margin: 0 auto;
+ }
+ """
+
+ THEME = gr.themes.Ocean()
+
+ demo = gr.Blocks(theme=THEME, fill_width=False, fill_height=True, css=css)
+
+ with demo:
+     config_state = gr.State(init_chat_state)
+     with gr.Column(elem_classes="centered-container"):
+         gr.Markdown("""
+         # 💬 Polars Python Chatbot
+         ### Ask anything about the [Polars](https://pola-rs.github.io/polars/) Python package!
+         ### This chatbot uses a database of embeddings generated from the official documentation to help you find accurate and relevant answers about using Polars for data manipulation in Python.
+         """)
+
+         chatbot = gr.Chatbot(
+             label=None,
+             type="messages",
+             show_label=False,
+             height=400,
+         )
+
+         with gr.Row(equal_height=True):
+             msg = gr.Textbox(
+                 placeholder="Type your message here...",
+                 show_label=False,
+                 lines=3,
+                 max_lines=3,
+                 scale=5,
+             )
+             send_btn = gr.Button("Send", variant="primary", scale=1)
+
+         with gr.Row():
+             clear = gr.ClearButton([msg, chatbot], value="Clear Chat", variant="secondary")
+
+     send_btn.click(respond, [msg, config_state], [msg, chatbot])
+     msg.submit(respond, [msg, config_state], [msg, chatbot])
+
+ if __name__ == '__main__':
+     demo.launch()
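
A minimal headless check of the chat loop, assuming the QDRANT_* and CHAT_API_KEY environment variables from src/config.py are set (the sample question is illustrative):

    # Hypothetical smoke test: drive respond() directly, without the Gradio UI.
    cfg = init_chat_state()
    _, history = respond("How do I read a CSV file with polars?", cfg)
    print(history[-1]["content"])  # the assistant's answer
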
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ beautifulsoup4~=4.13.4
+ markdown~=3.8
+ langchain~=0.3.23
+ transformers~=4.51.3
+ torch~=2.6.0
+ tqdm~=4.67.1
+ qdrant_client
+ langgraph~=0.3.31
+ gradio~=5.25.2
+ langchain_google_genai~=2.1.3
src/config.py ADDED
@@ -0,0 +1,19 @@
+ import os
+
+
+ EMBEDDING_MODEL = "thenlper/gte-small"
+
+ QDRANT_COLLECTION_NAME = "polars-documentation"
+ QDRANT_URL = os.environ.get("QDRANT_URL", "")
+ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
+ CHAT_API_KEY = os.environ.get("CHAT_API_KEY", "")
+
+
+ def get_qdrant_config():
+     # Imported lazily so modules that only need the constants don't pull in qdrant_client.
+     from qdrant_client import models
+     QDRANT_COLLECTION_CONFIG = {
+         "collection_name": QDRANT_COLLECTION_NAME,
+         # gte-small produces 384-dimensional embeddings; cosine distance matches its training.
+         "vectors_config": models.VectorParams(size=384, distance=models.Distance.COSINE),  # on_disk=True
+         # "hnsw_config": models.HnswConfigDiff(on_disk=True)
+     }
+     return QDRANT_COLLECTION_CONFIG
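
A sketch of how this config is consumed when bootstrapping the vector store (QdrantStore, defined below in src/database/qdrant_store.py, creates the collection on first use):

    from qdrant_client import QdrantClient
    from src.config import get_qdrant_config, QDRANT_URL, QDRANT_API_KEY
    from src.database.qdrant_store import QdrantStore

    client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
    store = QdrantStore(client, collection_config=get_qdrant_config())  # no-op if the collection exists
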
src/data_processing/process_markdown.py ADDED
@@ -0,0 +1,53 @@
+ from typing import Any
+
+ from bs4 import BeautifulSoup
+ from langchain_core.documents import Document
+ from markdown import markdown
+ from pathlib import Path
+ from langchain.text_splitter import MarkdownTextSplitter, MarkdownHeaderTextSplitter, TextSplitter
+
+ from src.utils import batched
+
+
+ def read_markdown_file(path: str | Path) -> tuple[str, str]:
+     path = Path(path)
+     with open(path, 'r', encoding="utf8") as f_r:
+         text = f_r.read()
+
+     # text = markdown(text)
+     # text = ''.join(BeautifulSoup(text).findAll(text=True))
+     return text, str(path)
+
+
+ def split_markdown(md: str | list[str],
+                    metadata: dict[str, Any] | list[dict[str, Any]],
+                    chunk_size=512,
+                    overlap=64,
+                    splitter: TextSplitter = None) -> list[Document]:
+     if isinstance(md, str):
+         md = [md]
+         if isinstance(metadata, list):
+             raise ValueError("metadata should be a single dict when md is a single string")
+         metadata = [metadata]
+     if splitter is None:
+         # Split on headers first so each chunk keeps its section context, then size-split.
+         headers_to_split_on = [
+             ("#", "Header 1"),
+             ("##", "Header 2"),
+             ("###", "Header 3"),
+         ]
+         md = [MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False).split_text(i) for i in md]
+         metadata = [{**metadata[i], **text.metadata} for i, text_split in enumerate(md) for text in text_split]
+         md = [j.page_content for i in md for j in i]
+         splitter = MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
+
+     docs = splitter.create_documents(md, metadata)
+     return docs
+
+
+ def process_markdown_files(paths: list[str | Path], batch_size=1, chunk_size=512, overlap=64):
+     for files in batched(paths, batch_size):
+         mds_w_paths = [read_markdown_file(i) for i in files]
+         metadata = [{"path": md_path} for _, md_path in mds_w_paths]
+         md = [md for md, _ in mds_w_paths]
+         docs = split_markdown(md, metadata, chunk_size=chunk_size, overlap=overlap)
+         yield [i.page_content for i in docs], [i.metadata for i in docs]
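
A minimal ingestion sketch tying this generator to the embedder and vector store; the data/polars-docu glob is an assumption, and the "text" payload key is what QdrantStore.get_topk_points_single reads back at query time:

    from pathlib import Path
    from src.config import EMBEDDING_MODEL, QDRANT_COLLECTION_NAME
    from src.embeddings import TextEmbedder

    embedder = TextEmbedder(modelname=EMBEDDING_MODEL)
    paths = list(Path("data/polars-docu").rglob("*.md"))  # assumed documentation location
    for texts, metadatas in process_markdown_files(paths, batch_size=8):
        vectors = embedder.embed_text(texts)
        payloads = [{**meta, "text": text} for text, meta in zip(texts, metadatas)]
        # qdrant_store.upsert_points(vectors, payloads, QDRANT_COLLECTION_NAME)  # see src/database/qdrant_store.py
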
src/database/qdrant_store.py ADDED
@@ -0,0 +1,83 @@
+ from typing import Any
+
+ from qdrant_client import QdrantClient, models
+ from uuid import uuid4
+
+ from src.config import QDRANT_COLLECTION_NAME, QDRANT_URL, QDRANT_API_KEY, EMBEDDING_MODEL
+ from src.embeddings import TextEmbedder
+
+
+ class QdrantStore:
+     def __init__(self, client: QdrantClient, collection_config=None):
+         self.client = client
+         self.collection_names = {i.name for i in client.get_collections().collections}
+
+         if collection_config is not None:
+             self.create_collection(collection_config)
+
+     def create_collection(self, collection_config: dict):
+         collection_name = collection_config["collection_name"]
+         if not self.client.collection_exists(collection_name):
+             self.client.create_collection(**collection_config)
+         self.collection_names.add(collection_name)
+
+     def _check_collection_name(self, collection_name):
+         if collection_name not in self.collection_names:
+             raise ValueError(f"Collection: {collection_name} does not exist.")
+
+     def upsert_points(self,
+                       vectors: Any | list[Any],
+                       payloads: dict | list[dict],
+                       collection_name: str):
+         self._check_collection_name(collection_name)
+
+         # Random UUIDs as point ids; re-ingesting the same document creates new points.
+         ids = [str(uuid4()) for _ in payloads]
+
+         self.client.upsert(
+             collection_name=collection_name,
+             points=models.Batch(
+                 ids=ids,
+                 payloads=payloads,
+                 vectors=vectors
+             )
+         )
+
+     def delete_points(self,
+                       filters: dict[str, list[models.FieldCondition]],
+                       collection_name: str):
+         self._check_collection_name(collection_name)
+
+         self.client.delete(
+             collection_name=collection_name,
+             points_selector=models.Filter(**filters)
+         )
+
+     def delete_points_by_match(self,
+                                key_value: tuple[str, list[str] | str],
+                                collection_name: str):
+         key, values = key_value
+         if isinstance(values, str):
+             values = [values]
+         match_filter = {"must": [models.FieldCondition(key=key, match=models.MatchAny(any=values))]}
+         self.delete_points(match_filter, collection_name)
+
+     def get_topk_points_single(self,
+                                query: Any | str,
+                                collection_name: str,
+                                k=5):
+         responses = self.client.query_points(collection_name=collection_name,
+                                              query=query,
+                                              limit=k)
+
+         return [i.payload["text"] for i in responses.points]
+
+
+ if __name__ == '__main__':
+     client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
+     qdrant_store = QdrantStore(client)
+     embedding_model = TextEmbedder(modelname=EMBEDDING_MODEL)
+     query = "How to filter a dataframe"
+     query_emb = embedding_model.embed_text(query)
+     responses = qdrant_store.get_topk_points_single(query_emb[0], collection_name=QDRANT_COLLECTION_NAME)
+     print(responses)
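
Because every chunk carries the "path" metadata written during ingestion, a changed document can be re-indexed by deleting its old points first; a sketch with a hypothetical path value:

    # "path" matches the metadata key set in process_markdown_files.
    qdrant_store.delete_points_by_match(("path", "data/polars-docu/concepts/lazy-api.md"),
                                        QDRANT_COLLECTION_NAME)
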
src/embeddings.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+ from torch import Tensor
+ import torch.nn.functional as F  # for the commented-out normalization below
+
+ from src.config import EMBEDDING_MODEL
+ from src.utils import batched
+
+
+ class TextEmbedder:
+     def __init__(self, modelname=EMBEDDING_MODEL, max_length=512):
+         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
+         self.model = AutoModel.from_pretrained(modelname)
+         self.max_length = max_length
+
+     @staticmethod
+     def average_pool(last_hidden_states: Tensor,
+                      attention_mask: Tensor) -> Tensor:
+         # Mean over token embeddings, with padding positions masked out.
+         last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+         return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+     def embed_text(self, text: str | list[str], batch_size=128):
+         if isinstance(text, str):
+             text = [text]
+
+         outputs = []
+
+         for batch in batched(text, n=batch_size):
+             batch_dict = self.tokenizer(list(batch), max_length=self.max_length, padding=True,
+                                         truncation=True, return_tensors='pt')
+             with torch.no_grad():  # inference only; skip autograd bookkeeping
+                 output = self.model(**batch_dict)
+             embeddings = self.average_pool(output.last_hidden_state, batch_dict['attention_mask'])
+
+             # embeddings = F.normalize(embeddings, p=2, dim=1)
+             # scores = (embeddings[:1] @ embeddings[1:].T) * 100
+
+             embeddings = embeddings.tolist()
+             outputs += embeddings
+         return outputs
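
A quick sanity check of the embedder; the two sample sentences are illustrative:

    import torch
    import torch.nn.functional as F

    embedder = TextEmbedder()
    vecs = torch.tensor(embedder.embed_text(["filter rows of a dataframe",
                                             "select a subset of columns"]))
    vecs = F.normalize(vecs, p=2, dim=1)
    print((vecs[0] @ vecs[1]).item())  # cosine similarity between the two sentences
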
src/rag_lanchain.py ADDED
@@ -0,0 +1,159 @@
+ from typing import TypedDict, List
+
+ from langgraph.constants import END
+ from qdrant_client import QdrantClient
+
+ from langchain_core.prompts import PromptTemplate
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_core.rate_limiters import InMemoryRateLimiter
+ from langchain_core.tools import tool
+ from langchain_core.messages import SystemMessage
+
+ from langgraph.graph import StateGraph, MessagesState
+ from langgraph.prebuilt import ToolNode, tools_condition
+ from langgraph.checkpoint.memory import MemorySaver
+
+ from src.database.qdrant_store import QdrantStore
+ from src.embeddings import TextEmbedder
+ from src.config import EMBEDDING_MODEL, QDRANT_COLLECTION_NAME, CHAT_API_KEY, QDRANT_URL, QDRANT_API_KEY
+
+ RAG_PROMPT_STR = """
+ You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
+
+ {context}
+ """
+ RAG_PROMPT = PromptTemplate.from_template(RAG_PROMPT_STR)
+
+ embedding_model = TextEmbedder(modelname=EMBEDDING_MODEL)
+
+ client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
+ qdrant_store = QdrantStore(client)
+
+ rate_limiter = InMemoryRateLimiter(
+     requests_per_second=0.25,  # at most one request every 4 seconds
+     check_every_n_seconds=0.1,  # wake up every 100 ms to check whether a request is allowed
+     max_bucket_size=15,  # controls the maximum burst size
+ )
+
+ llm = ChatGoogleGenerativeAI(
+     model="gemini-2.0-flash-001",
+     google_api_key=CHAT_API_KEY,
+     temperature=0,
+     max_tokens=None,
+     timeout=None,
+     max_retries=2,
+     rate_limiter=rate_limiter,  # throttle outgoing Gemini calls
+ )
+
+ # init_chat_model("google_vertexai:gemini-2.0-flash", rate_limiter=rate_limiter, )
+
+
+ class State(TypedDict):
+     question: str
+     context: List[str]
+     answer: str
+
+
+ def query_or_respond(state: MessagesState):
+     # Let the model decide whether to answer directly or call the retrieve tool.
+     llm_with_tools = llm.bind_tools([retrieve])
+     response = llm_with_tools.invoke(state["messages"])
+     return {"messages": [response]}
+
+
+ @tool
+ def retrieve(query: str):
+     """Retrieve information related to a query, specific to the python polars package"""
+     if qdrant_store is None:
+         return ""
+     query_emb = embedding_model.embed_text(query)
+     retrieved_docs = qdrant_store.get_topk_points_single(query_emb[0], QDRANT_COLLECTION_NAME, k=5)
+     return '\n\n'.join(retrieved_docs)
+
+
+ def generate(state: MessagesState):
+     # Collect the tool messages produced by the most recent retrieval step.
+     recent_tool_messages = []
+     for message in reversed(state["messages"]):
+         if message.type == "tool":
+             recent_tool_messages.append(message)
+         else:
+             break
+     tool_messages = recent_tool_messages[::-1]
+     # Join the retrieved chunks into plain text rather than interpolating message objects.
+     context = "\n\n".join(message.content for message in tool_messages)
+     system_message_content = RAG_PROMPT_STR.format(context=context)
+     conversation_messages = [
+         message
+         for message in state["messages"]
+         if message.type in ("human", "system")
+         or (message.type == "ai" and not message.tool_calls)
+     ]
+     prompt = [SystemMessage(system_message_content)] + conversation_messages
+
+     response = llm.invoke(prompt)
+     return {"messages": [response]}
+
+
+ tools = ToolNode([retrieve])
+
+ graph_builder = StateGraph(MessagesState)
+
+ graph_builder.add_node(query_or_respond)
+ graph_builder.add_node(tools)
+ graph_builder.add_node(generate)
+
+ graph_builder.set_entry_point("query_or_respond")
+ graph_builder.add_conditional_edges(
+     "query_or_respond",
+     tools_condition,
+     {END: END, "tools": "tools"},
+ )
+ graph_builder.add_edge("tools", "generate")
+ graph_builder.add_edge("generate", END)
+
+ if __name__ == '__main__':
+     memory = MemorySaver()
+     graph = graph_builder.compile(checkpointer=memory)
+     config = {"configurable": {"thread_id": "def234"}}
+
+     user_input = "Hi there! My name is Will."
+
+     # The config is the **second positional argument** to stream() or invoke()!
+     events = graph.stream(
+         {"messages": [{"role": "user", "content": user_input}]},
+         config,
+         stream_mode="values",
+     )
+     for event in events:
+         event["messages"][-1].pretty_print()
+
+     print(graph.get_state(config))
+     print(memory.get(config))
+
+     # Same thread: the checkpointer should recall the name.
+     user_input = "Remember my name?"
+     events = graph.stream(
+         {"messages": [{"role": "user", "content": user_input}]},
+         config,
+         stream_mode="values",
+     )
+     for event in events:
+         event["messages"][-1].pretty_print()
+
+     print(graph.get_state(config))
+     print(memory.get(config))
+
+     # Fresh thread: no shared memory, so the model should not know the name.
+     user_input = "Remember my name?"
+     config = {"configurable": {"thread_id": "ddef234"}}
+     events = graph.stream(
+         {"messages": [{"role": "user", "content": user_input}]},
+         config,
+         stream_mode="values",
+     )
+     for event in events:
+         event["messages"][-1].pretty_print()
+
src/retrieval/retrieval.py ADDED
@@ -0,0 +1,6 @@
+ from src.database.qdrant_store import QdrantStore
+ from src.embeddings import TextEmbedder
+
+
+ def embed_query(query: str | list[str]):
+     pass
src/testing.py ADDED
@@ -0,0 +1,19 @@
+ from bs4 import BeautifulSoup
+ from markdown import markdown
+ from langchain.text_splitter import MarkdownTextSplitter
+
+
+ # Raw string so the Windows backslashes are not read as escape sequences.
+ path = r"D:\PycharmProjects\polargs-docu-chat-rag\data\polars-docu\concepts\data-types-and-structures.md"
+
+ with open(path, 'r', encoding="utf8") as f_r:
+     test_md = f_r.read()
+
+ # Render the markdown to HTML, then strip the tags to get plain text.
+ html = markdown(test_md)
+ text = ''.join(BeautifulSoup(html, "html.parser").find_all(string=True))
+
+ print(text[:10])
+
+ splitter = MarkdownTextSplitter(chunk_size=512, chunk_overlap=64)
+
+ docs = splitter.create_documents([text])
+ print(docs)
src/utils.py ADDED
@@ -0,0 +1,12 @@
+ from itertools import islice
+
+
+ def batched(iterable, n, *, strict=False):
+     # batched('ABCDEFG', 3) → ABC DEF G
+     if n < 1:
+         raise ValueError('n must be at least one')
+     iterator = iter(iterable)
+     while batch := tuple(islice(iterator, n)):
+         if strict and len(batch) != n:
+             raise ValueError('batched(): incomplete batch')
+         yield batch
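
This is the stdlib itertools recipe (itertools.batched is built in from Python 3.12), kept here so the project runs on older interpreters. A quick check:

    print(list(batched("ABCDEFG", 3)))  # [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]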