Jialun He commited on
Commit
2c612d2
·
1 Parent(s): c97e291

preload vector store

Browse files
Files changed (4) hide show
  1. .gitignore +49 -0
  2. agent.py +53 -29
  3. supabase_docs.csv +0 -0
  4. util.py +56 -0
.gitignore ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+ ENV/
29
+ env.bak/
30
+ venv.bak/
31
+ .conda/
32
+
33
+ # VS Code
34
+ .vscode/
35
+ *.code-workspace
36
+ .history/
37
+
38
+ # Jupyter Notebook
39
+ .ipynb_checkpoints
40
+
41
+ # Local development files
42
+ *.log
43
+ .DS_Store
44
+ Thumbs.db
45
+
46
+ # Project specific
47
+ .env
48
+ *.db
49
+ *.sqlite3
agent.py CHANGED
@@ -1,25 +1,30 @@
1
  """LangGraph Agent"""
 
2
  import os
 
3
  from dotenv import load_dotenv
4
- from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_openai import ChatOpenAI
9
- from langchain.agents import initialize_agent, Tool
10
- from langchain_groq import ChatGroq
11
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
- from langchain_community.document_loaders import WikipediaLoader
14
- from langchain_community.document_loaders import ArxivLoader
15
  from langchain_community.vectorstores import SupabaseVectorStore
16
- from langchain_core.messages import SystemMessage, HumanMessage
17
  from langchain_core.tools import tool
18
- from langchain.tools.retriever import create_retriever_tool
 
 
 
 
 
 
 
 
 
19
  from supabase.client import Client, create_client
20
 
21
  load_dotenv()
22
 
 
23
  @tool
24
  def multiply(a: int, b: int) -> int:
25
  """Multiply two numbers.
@@ -29,6 +34,7 @@ def multiply(a: int, b: int) -> int:
29
  """
30
  return a * b
31
 
 
32
  @tool
33
  def add(a: int, b: int) -> int:
34
  """Add two numbers.
@@ -38,6 +44,7 @@ def add(a: int, b: int) -> int:
38
  """
39
  return a + b
40
 
 
41
  @tool
42
  def subtract(a: int, b: int) -> int:
43
  """Subtract two numbers.
@@ -47,6 +54,7 @@ def subtract(a: int, b: int) -> int:
47
  """
48
  return a - b
49
 
 
50
  @tool
51
  def divide(a: int, b: int) -> int:
52
  """Divide two numbers.
@@ -58,6 +66,7 @@ def divide(a: int, b: int) -> int:
58
  raise ValueError("Cannot divide by zero.")
59
  return a / b
60
 
 
61
  @tool
62
  def modulus(a: int, b: int) -> int:
63
  """Get the modulus of two numbers.
@@ -67,6 +76,7 @@ def modulus(a: int, b: int) -> int:
67
  """
68
  return a % b
69
 
 
70
  @tool
71
  def wiki_search(query: str) -> str:
72
  """Search Wikipedia for a query and return maximum 2 results.
@@ -77,17 +87,22 @@ def wiki_search(query: str) -> str:
77
  [
78
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
79
  for doc in search_docs
80
- ])
 
81
  return {"wiki_results": formatted_search_docs}
82
 
 
83
  @tool
84
  def web_search(query: str) -> str:
85
  """Search Tavily for a query and return maximum 3 results.
86
  Args:
87
  query: The search query."""
88
- search_docs = TavilySearchResults(max_results=3).invoke(query) # Fixed: pass query as positional argument
 
 
89
  return {"web_results": search_docs} # Also fixed the return type issue
90
 
 
91
  @tool
92
  def arvix_search(query: str) -> str:
93
  """Search Arxiv for a query and return maximum 3 result.
@@ -98,24 +113,24 @@ def arvix_search(query: str) -> str:
98
  [
99
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
100
  for doc in search_docs
101
- ])
 
102
  return {"arvix_results": formatted_search_docs}
103
 
104
 
105
  def test_supabase_connection():
106
  load_dotenv()
107
-
108
  try:
109
  supabase = create_client(
110
- os.environ.get("SUPABASE_URL"),
111
- os.environ.get("SUPABASE_SERVICE_KEY")
112
  )
113
-
114
  # Test query
115
- result = supabase.table('documents').select("*").limit(1).execute()
116
  print("Connection successful!")
117
  return True
118
-
119
  except Exception as e:
120
  print(f"Connection failed: {e}")
121
  return False
@@ -129,13 +144,15 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
129
  sys_msg = SystemMessage(content=system_prompt)
130
 
131
  # build a retriever
132
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 
 
133
  supabase: Client = create_client(
134
- os.environ.get("SUPABASE_URL"),
135
- os.environ.get("SUPABASE_SERVICE_KEY"))
136
  vector_store = SupabaseVectorStore(
137
  client=supabase,
138
- embedding= embeddings,
139
  table_name="documents",
140
  query_name="match_documents_langchain",
141
  )
@@ -158,16 +175,21 @@ tools = [
158
  arvix_search,
159
  ]
160
 
 
161
  # Build graph function
162
- def build_graph(provider: str = "openai"):
163
  """Build the graph"""
164
  # Load environment variables from .env file
165
  if provider == "google":
166
  # Google Gemini
167
- llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-05-20", temperature=0)
 
 
168
  elif provider == "groq":
169
  # Groq https://console.groq.com/docs/models
170
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
 
 
171
  elif provider == "openai":
172
  # OpenAI
173
  llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
@@ -192,7 +214,9 @@ def build_graph(provider: str = "openai"):
192
  """Retriever node"""
193
  try:
194
  # Use the vector store to find similar questions
195
- similar_question = vector_store.similarity_search(state["messages"][0].content)
 
 
196
  if not similar_question:
197
  raise ValueError("No similar questions found.")
198
  except Exception as e:
 
1
  """LangGraph Agent"""
2
+
3
  import os
4
+
5
  from dotenv import load_dotenv
6
+ from langchain.agents import Tool, initialize_agent
7
+ from langchain.tools.retriever import create_retriever_tool
8
+ from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
 
 
 
 
 
9
  from langchain_community.tools.tavily_search import TavilySearchResults
 
 
10
  from langchain_community.vectorstores import SupabaseVectorStore
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
  from langchain_core.tools import tool
13
+ from langchain_google_genai import ChatGoogleGenerativeAI
14
+ from langchain_groq import ChatGroq
15
+ from langchain_huggingface import (
16
+ ChatHuggingFace,
17
+ HuggingFaceEmbeddings,
18
+ HuggingFaceEndpoint,
19
+ )
20
+ from langchain_openai import ChatOpenAI
21
+ from langgraph.graph import START, MessagesState, StateGraph
22
+ from langgraph.prebuilt import ToolNode, tools_condition
23
  from supabase.client import Client, create_client
24
 
25
  load_dotenv()
26
 
27
+
28
  @tool
29
  def multiply(a: int, b: int) -> int:
30
  """Multiply two numbers.
 
34
  """
35
  return a * b
36
 
37
+
38
  @tool
39
  def add(a: int, b: int) -> int:
40
  """Add two numbers.
 
44
  """
45
  return a + b
46
 
47
+
48
  @tool
49
  def subtract(a: int, b: int) -> int:
50
  """Subtract two numbers.
 
54
  """
55
  return a - b
56
 
57
+
58
  @tool
59
  def divide(a: int, b: int) -> int:
60
  """Divide two numbers.
 
66
  raise ValueError("Cannot divide by zero.")
67
  return a / b
68
 
69
+
70
  @tool
71
  def modulus(a: int, b: int) -> int:
72
  """Get the modulus of two numbers.
 
76
  """
77
  return a % b
78
 
79
+
80
  @tool
81
  def wiki_search(query: str) -> str:
82
  """Search Wikipedia for a query and return maximum 2 results.
 
87
  [
88
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
89
  for doc in search_docs
90
+ ]
91
+ )
92
  return {"wiki_results": formatted_search_docs}
93
 
94
+
95
  @tool
96
  def web_search(query: str) -> str:
97
  """Search Tavily for a query and return maximum 3 results.
98
  Args:
99
  query: The search query."""
100
+ search_docs = TavilySearchResults(max_results=3).invoke(
101
+ query
102
+ ) # Fixed: pass query as positional argument
103
  return {"web_results": search_docs} # Also fixed the return type issue
104
 
105
+
106
  @tool
107
  def arvix_search(query: str) -> str:
108
  """Search Arxiv for a query and return maximum 3 result.
 
113
  [
114
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
115
  for doc in search_docs
116
+ ]
117
+ )
118
  return {"arvix_results": formatted_search_docs}
119
 
120
 
121
  def test_supabase_connection():
122
  load_dotenv()
123
+
124
  try:
125
  supabase = create_client(
126
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
 
127
  )
128
+
129
  # Test query
130
+ result = supabase.table("documents").select("*").limit(1).execute()
131
  print("Connection successful!")
132
  return True
133
+
134
  except Exception as e:
135
  print(f"Connection failed: {e}")
136
  return False
 
144
  sys_msg = SystemMessage(content=system_prompt)
145
 
146
  # build a retriever
147
+ embeddings = HuggingFaceEmbeddings(
148
+ model_name="sentence-transformers/all-mpnet-base-v2"
149
+ ) # dim=768
150
  supabase: Client = create_client(
151
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
152
+ )
153
  vector_store = SupabaseVectorStore(
154
  client=supabase,
155
+ embedding=embeddings,
156
  table_name="documents",
157
  query_name="match_documents_langchain",
158
  )
 
175
  arvix_search,
176
  ]
177
 
178
+
179
  # Build graph function
180
+ def build_graph(provider: str = "google"):
181
  """Build the graph"""
182
  # Load environment variables from .env file
183
  if provider == "google":
184
  # Google Gemini
185
+ llm = ChatGoogleGenerativeAI(
186
+ model="gemini-2.5-flash-preview-05-20", temperature=0
187
+ )
188
  elif provider == "groq":
189
  # Groq https://console.groq.com/docs/models
190
+ llm = ChatGroq(
191
+ model="qwen-qwq-32b", temperature=0
192
+ ) # optional : qwen-qwq-32b gemma2-9b-it
193
  elif provider == "openai":
194
  # OpenAI
195
  llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
 
214
  """Retriever node"""
215
  try:
216
  # Use the vector store to find similar questions
217
+ similar_question = vector_store.similarity_search(
218
+ state["messages"][0].content
219
+ )
220
  if not similar_question:
221
  raise ValueError("No similar questions found.")
222
  except Exception as e:
supabase_docs.csv ADDED
The diff for this file is too large to render. See raw diff
 
util.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+
5
+ import pandas as pd
6
+ from dotenv import load_dotenv
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import SupabaseVectorStore
9
+ from supabase.client import create_client
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class SupabaseConnector:
16
+ def __init__(self):
17
+ load_dotenv()
18
+ self.supabase = create_client(
19
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
20
+ )
21
+ self.embeddings = HuggingFaceEmbeddings(
22
+ model_name="sentence-transformers/all-mpnet-base-v2"
23
+ )
24
+ self.vector_store = SupabaseVectorStore(
25
+ client=self.supabase,
26
+ embedding=self.embeddings,
27
+ table_name="documents",
28
+ query_name="match_documents_langchain",
29
+ )
30
+
31
+ def upload_csv(self, file_path: str, batch_size: int = 100):
32
+ """
33
+ Upload documents from supabase_docs.csv to Supabase vector store.
34
+ Only 'content' and parsed 'metadata' are used.
35
+ """
36
+ df = pd.read_csv(file_path)
37
+ logger.info(f"Loaded {len(df)} records from {file_path}")
38
+
39
+ # Parse metadata column from string to dict
40
+ df["metadata"] = df["metadata"].apply(
41
+ lambda x: ast.literal_eval(x) if isinstance(x, str) else {}
42
+ )
43
+
44
+ for i in range(0, len(df), batch_size):
45
+ batch = df.iloc[i : i + batch_size]
46
+ texts = batch["content"].tolist()
47
+ metadatas = batch["metadata"].tolist()
48
+ self.vector_store.add_texts(texts=texts, metadatas=metadatas)
49
+ logger.info(f"Uploaded batch {i//batch_size + 1}")
50
+
51
+ logger.info("CSV upload completed.")
52
+
53
+
54
+ if __name__ == "__main__":
55
+ connector = SupabaseConnector()
56
+ connector.upload_csv("supabase_docs.csv")