mwalker22 committed on
Commit
d9897c8
·
1 Parent(s): 589384b

feat: Implement a Qdrant vector store using GTE-small embeddings. Initialized it with golf shot data collected over time.

Browse files
.gitignore CHANGED
@@ -204,5 +204,8 @@ dist-ssr
204
  *.sln
205
  *.sw?
206
 
207
- # Temporary directory exclusion
208
- frontend_backup/
 
 
 
 
204
  *.sln
205
  *.sw?
206
 
207
+ # Data files
208
+ data/raw/*
209
+ data/processed/*
210
+ !data/raw/.gitkeep
211
+ !data/processed/.gitkeep
data/processed/.gitkeep ADDED
File without changes
data/raw/.gitkeep ADDED
File without changes
notebooks/01_Embed_and_Store_Shots_GTE_with_Qdrant.ipynb ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e20d8915",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 📌 Notebook 1: Embed and Store Shot Data in Qdrant\n",
9
+ "\n",
10
+ "This notebook loads your cleaned shot data, embeds it using GTE-small (`thenlper/gte-small`), and stores the embeddings in Qdrant for use in retrieval and recommendation."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "b11634a3",
16
+ "metadata": {},
17
+ "source": [
18
+ "# Initial Setup"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 4,
24
+ "id": "edc009c4",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# Step 1: Initial Setup:\n",
29
+ "\n",
30
+ "# Load sentence-transformers and GTE-small model\n",
31
+ "from sentence_transformers import SentenceTransformer\n",
32
+ "model = SentenceTransformer('thenlper/gte-small')"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "8d5248e0",
38
+ "metadata": {},
39
+ "source": [
40
+ "## Load Shot Data From .csv"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 5,
46
+ "id": "dce5090d",
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "data": {
51
+ "text/html": [
52
+ "<div>\n",
53
+ "<style scoped>\n",
54
+ " .dataframe tbody tr th:only-of-type {\n",
55
+ " vertical-align: middle;\n",
56
+ " }\n",
57
+ "\n",
58
+ " .dataframe tbody tr th {\n",
59
+ " vertical-align: top;\n",
60
+ " }\n",
61
+ "\n",
62
+ " .dataframe thead th {\n",
63
+ " text-align: right;\n",
64
+ " }\n",
65
+ "</style>\n",
66
+ "<table border=\"1\" class=\"dataframe\">\n",
67
+ " <thead>\n",
68
+ " <tr style=\"text-align: right;\">\n",
69
+ " <th></th>\n",
70
+ " <th>Date</th>\n",
71
+ " <th>Club Type</th>\n",
72
+ " <th>Club Description</th>\n",
73
+ " <th>Carry Distance</th>\n",
74
+ " <th>Total Distance</th>\n",
75
+ " <th>Ball Speed</th>\n",
76
+ " <th>Club Speed</th>\n",
77
+ " <th>Spin Rate</th>\n",
78
+ " <th>Attack Angle</th>\n",
79
+ " <th>Descent Angle</th>\n",
80
+ " <th>Shot Classification</th>\n",
81
+ " </tr>\n",
82
+ " </thead>\n",
83
+ " <tbody>\n",
84
+ " <tr>\n",
85
+ " <th>0</th>\n",
86
+ " <td>2025-02-04 12:41:00</td>\n",
87
+ " <td>Driver</td>\n",
88
+ " <td>TopGolf - Driver (+1; N; 2.75T)</td>\n",
89
+ " <td>124.33</td>\n",
90
+ " <td>171.19</td>\n",
91
+ " <td>122.16</td>\n",
92
+ " <td>85.92</td>\n",
93
+ " <td>1154</td>\n",
94
+ " <td>2.95</td>\n",
95
+ " <td>11.33</td>\n",
96
+ " <td>Hook</td>\n",
97
+ " </tr>\n",
98
+ " <tr>\n",
99
+ " <th>1</th>\n",
100
+ " <td>2025-02-04 12:41:42</td>\n",
101
+ " <td>Driver</td>\n",
102
+ " <td>TopGolf - Driver (+1; N; 2.75T)</td>\n",
103
+ " <td>104.75</td>\n",
104
+ " <td>150.95</td>\n",
105
+ " <td>120.35</td>\n",
106
+ " <td>84.20</td>\n",
107
+ " <td>1666</td>\n",
108
+ " <td>2.45</td>\n",
109
+ " <td>8.19</td>\n",
110
+ " <td>Push Hook</td>\n",
111
+ " </tr>\n",
112
+ " <tr>\n",
113
+ " <th>2</th>\n",
114
+ " <td>2025-02-04 12:42:17</td>\n",
115
+ " <td>Driver</td>\n",
116
+ " <td>TopGolf - Driver (+1; N; 2.75T)</td>\n",
117
+ " <td>163.45</td>\n",
118
+ " <td>195.51</td>\n",
119
+ " <td>115.05</td>\n",
120
+ " <td>86.28</td>\n",
121
+ " <td>1227</td>\n",
122
+ " <td>4.30</td>\n",
123
+ " <td>23.02</td>\n",
124
+ " <td>Push</td>\n",
125
+ " </tr>\n",
126
+ " <tr>\n",
127
+ " <th>3</th>\n",
128
+ " <td>2025-02-04 12:43:05</td>\n",
129
+ " <td>Driver</td>\n",
130
+ " <td>TopGolf - Driver (+1; N; 2.75T)</td>\n",
131
+ " <td>162.57</td>\n",
132
+ " <td>192.56</td>\n",
133
+ " <td>110.91</td>\n",
134
+ " <td>81.96</td>\n",
135
+ " <td>1783</td>\n",
136
+ " <td>1.74</td>\n",
137
+ " <td>24.87</td>\n",
138
+ " <td>Push</td>\n",
139
+ " </tr>\n",
140
+ " <tr>\n",
141
+ " <th>4</th>\n",
142
+ " <td>2025-02-04 12:44:18</td>\n",
143
+ " <td>Driver</td>\n",
144
+ " <td>TopGolf - Driver (+1; N; 2.75T)</td>\n",
145
+ " <td>105.30</td>\n",
146
+ " <td>152.00</td>\n",
147
+ " <td>118.83</td>\n",
148
+ " <td>80.78</td>\n",
149
+ " <td>1478</td>\n",
150
+ " <td>1.29</td>\n",
151
+ " <td>8.67</td>\n",
152
+ " <td>Push Draw</td>\n",
153
+ " </tr>\n",
154
+ " </tbody>\n",
155
+ "</table>\n",
156
+ "</div>"
157
+ ],
158
+ "text/plain": [
159
+ " Date Club Type Club Description \\\n",
160
+ "0 2025-02-04 12:41:00 Driver TopGolf - Driver (+1; N; 2.75T) \n",
161
+ "1 2025-02-04 12:41:42 Driver TopGolf - Driver (+1; N; 2.75T) \n",
162
+ "2 2025-02-04 12:42:17 Driver TopGolf - Driver (+1; N; 2.75T) \n",
163
+ "3 2025-02-04 12:43:05 Driver TopGolf - Driver (+1; N; 2.75T) \n",
164
+ "4 2025-02-04 12:44:18 Driver TopGolf - Driver (+1; N; 2.75T) \n",
165
+ "\n",
166
+ " Carry Distance Total Distance Ball Speed Club Speed Spin Rate \\\n",
167
+ "0 124.33 171.19 122.16 85.92 1154 \n",
168
+ "1 104.75 150.95 120.35 84.20 1666 \n",
169
+ "2 163.45 195.51 115.05 86.28 1227 \n",
170
+ "3 162.57 192.56 110.91 81.96 1783 \n",
171
+ "4 105.30 152.00 118.83 80.78 1478 \n",
172
+ "\n",
173
+ " Attack Angle Descent Angle Shot Classification \n",
174
+ "0 2.95 11.33 Hook \n",
175
+ "1 2.45 8.19 Push Hook \n",
176
+ "2 4.30 23.02 Push \n",
177
+ "3 1.74 24.87 Push \n",
178
+ "4 1.29 8.67 Push Draw "
179
+ ]
180
+ },
181
+ "execution_count": 5,
182
+ "metadata": {},
183
+ "output_type": "execute_result"
184
+ }
185
+ ],
186
+ "source": [
187
+ "# Step 2: Load cleaned shot data\n",
188
+ "import pandas as pd\n",
189
+ "shot_data = pd.read_csv('../data/raw/cleaned_shot_data.csv')\n",
190
+ "shot_data.head()"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "56e952f9",
196
+ "metadata": {},
197
+ "source": [
198
+ "## Embed Shot Data"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 6,
204
+ "id": "a5052716",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "# Step 3: Format shot data into text chunks for embedding\n",
209
+ "def create_embedding_text(row):\n",
210
+ " return f\"{row['Club Type']} | Carry: {row['Carry Distance']} yds | Ball Speed: {row['Ball Speed']} mph | Classification: {row['Shot Classification']}\"\n",
211
+ "\n",
212
+ "texts = shot_data.apply(create_embedding_text, axis=1).tolist()"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 7,
218
+ "id": "fbd10b62",
219
+ "metadata": {},
220
+ "outputs": [
221
+ {
222
+ "name": "stderr",
223
+ "output_type": "stream",
224
+ "text": [
225
+ "Batches: 100%|██████████| 16/16 [00:17<00:00, 1.10s/it]\n"
226
+ ]
227
+ }
228
+ ],
229
+ "source": [
230
+ "# Generate embeddings for shot text\n",
231
+ "embeddings = model.encode(texts, show_progress_bar=True)"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "markdown",
236
+ "id": "b5744da3",
237
+ "metadata": {},
238
+ "source": [
239
+ "## Upload Embedded Shot Data to Qdrant"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 8,
245
+ "id": "d95b07f7",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "# Collect the Qdrant API key\n",
250
+ "from getpass import getpass\n",
251
+ "\n",
252
+ "qdrant_api_key = getpass('🔑 Enter your Qdrant API Key: ')\n"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": 11,
258
+ "id": "c4a90e06",
259
+ "metadata": {},
260
+ "outputs": [
261
+ {
262
+ "name": "stderr",
263
+ "output_type": "stream",
264
+ "text": [
265
+ "/var/folders/5p/gq47dsys3k5663k1r5z8s3c40000gn/T/ipykernel_54911/3543225837.py:11: DeprecationWarning: `recreate_collection` method is deprecated and will be removed in the future. Use `collection_exists` to check collection existence and `create_collection` instead.\n",
266
+ " client.recreate_collection(\n"
267
+ ]
268
+ },
269
+ {
270
+ "data": {
271
+ "text/plain": [
272
+ "True"
273
+ ]
274
+ },
275
+ "execution_count": 11,
276
+ "metadata": {},
277
+ "output_type": "execute_result"
278
+ }
279
+ ],
280
+ "source": [
281
+ "# Qdrant setup\n",
282
+ "from qdrant_client import QdrantClient\n",
283
+ "from qdrant_client.models import VectorParams, PointStruct, Distance\n",
284
+ "\n",
285
+ "client = QdrantClient(\n",
286
+ " url='https://6f592f43-f667-4234-ad3a-4f15ed5882ef.us-west-2-0.aws.cloud.qdrant.io:6333',\n",
287
+ " api_key=qdrant_api_key\n",
288
+ ")\n",
289
+ "\n",
290
+ "# Recreate the collection to flush old data\n",
291
+ "client.recreate_collection(\n",
292
+ " collection_name='golf_shot_vectors',\n",
293
+ " vectors_config=VectorParams(size=384, distance=Distance.COSINE)\n",
294
+ ")"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": 12,
300
+ "id": "f5c0f891",
301
+ "metadata": {},
302
+ "outputs": [
303
+ {
304
+ "data": {
305
+ "text/plain": [
306
+ "UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)"
307
+ ]
308
+ },
309
+ "execution_count": 12,
310
+ "metadata": {},
311
+ "output_type": "execute_result"
312
+ }
313
+ ],
314
+ "source": [
315
+ "# Upload embedded vectors to Qdrant\n",
316
+ "points = [\n",
317
+ " PointStruct(id=i, vector=embeddings[i], payload={'text': texts[i]})\n",
318
+ " for i in range(len(embeddings))\n",
319
+ "]\n",
320
+ "client.upsert(collection_name='golf_shot_vectors', points=points)"
321
+ ]
322
+ }
323
+ ],
324
+ "metadata": {
325
+ "kernelspec": {
326
+ "display_name": ".venv",
327
+ "language": "python",
328
+ "name": "python3"
329
+ },
330
+ "language_info": {
331
+ "codemirror_mode": {
332
+ "name": "ipython",
333
+ "version": 3
334
+ },
335
+ "file_extension": ".py",
336
+ "mimetype": "text/x-python",
337
+ "name": "python",
338
+ "nbconvert_exporter": "python",
339
+ "pygments_lexer": "ipython3",
340
+ "version": "3.12.9"
341
+ }
342
+ },
343
+ "nbformat": 4,
344
+ "nbformat_minor": 5
345
+ }
notebooks/02_Retrieve_and_Test_Recommendations_GTE_with_Qdrant.ipynb ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f36ee4e7",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 📌 Notebook 2: Retrieve and Test Recommendations (GTE-small)\n",
9
+ "This notebook allows you to enter a shot scenario and retrieve semantically similar historical shots using Qdrant with GTE-small embeddings.\n",
10
+ " "
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "aabcef6d",
16
+ "metadata": {},
17
+ "source": [
18
+ "## Initial Setup"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "b3a051bc",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/Users/mwalker/development/TAMARKDesigns/AI-Maker-Space/cohort-6/projects/session-05/AIE6-Golf-Agent/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
32
+ " from .autonotebook import tqdm as notebook_tqdm\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# Step 1: Initial Setup:\n",
38
+ "\n",
39
+ "# Load sentence-transformers and GTE-small model\n",
40
+ "from sentence_transformers import SentenceTransformer\n",
41
+ "model = SentenceTransformer('thenlper/gte-small')"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 2,
47
+ "id": "83fa3f1d",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "# Embed the query using GTE-small\n",
52
+ "\n",
53
+ "def get_embedding(text):\n",
54
+ " return model.encode(text)\n",
55
+ "\n",
56
+ "# Define a test query\n",
57
+ "query = \"What club should I use from 145 yards if I want to avoid a slice?\"\n",
58
+ "\n",
59
+ "query_vector = get_embedding(query)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 3,
65
+ "id": "bfa9fd99",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "# Collect the Qdrant API key\n",
70
+ "from getpass import getpass\n",
71
+ "\n",
72
+ "qdrant_api_key = getpass('🔑 Enter your Qdrant API Key: ')"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 14,
78
+ "id": "42a05a2e",
79
+ "metadata": {},
80
+ "outputs": [
81
+ {
82
+ "name": "stdout",
83
+ "output_type": "stream",
84
+ "text": [
85
+ "Score: 0.8943 | Text: Approach Wedge | Carry: 99.81 yds | Ball Speed: 74.13 mph | Classification: Slice\n",
86
+ "Score: 0.8941 | Text: Approach Wedge | Carry: 99.57 yds | Ball Speed: 74.27 mph | Classification: Slice\n",
87
+ "Score: 0.8869 | Text: Approach Wedge | Carry: 52.12 yds | Ball Speed: 52.12 mph | Classification: Push Slice\n",
88
+ "Score: 0.8868 | Text: Lob Wedge | Carry: 67.72 yds | Ball Speed: 59.75 mph | Classification: Slice\n",
89
+ "Score: 0.8867 | Text: Lob Wedge | Carry: 41.49 yds | Ball Speed: 45.45 mph | Classification: Slice\n"
90
+ ]
91
+ }
92
+ ],
93
+ "source": [
94
+ "# Retrieve top 5 most similar entries\n",
95
+ "from qdrant_client import QdrantClient\n",
96
+ "\n",
97
+ "client = QdrantClient(\n",
98
+ " url=\"https://6f592f43-f667-4234-ad3a-4f15ed5882ef.us-west-2-0.aws.cloud.qdrant.io:6333\",\n",
99
+ " api_key=qdrant_api_key\n",
100
+ ")\n",
101
+ "\n",
102
+ "collection_name = \"golf_shot_vectors\"\n",
103
+ "\n",
104
+ "search_result = client.query_points(\n",
105
+ " collection_name=collection_name,\n",
106
+ " query=query_vector, # Pass the vector directly\n",
107
+ " limit=5,\n",
108
+ " with_payload=True\n",
109
+ " )\n",
110
+ " \n",
111
+ "# Access the points from the QueryResponse\n",
112
+ "for point in search_result.points:\n",
113
+ " print(f\"Score: {point.score:.4f} | Text: {point.payload['text']}\")"
114
+ ]
115
+ }
116
+ ],
117
+ "metadata": {
118
+ "kernelspec": {
119
+ "display_name": ".venv",
120
+ "language": "python",
121
+ "name": "python3"
122
+ },
123
+ "language_info": {
124
+ "codemirror_mode": {
125
+ "name": "ipython",
126
+ "version": 3
127
+ },
128
+ "file_extension": ".py",
129
+ "mimetype": "text/x-python",
130
+ "name": "python",
131
+ "nbconvert_exporter": "python",
132
+ "pygments_lexer": "ipython3",
133
+ "version": "3.12.9"
134
+ }
135
+ },
136
+ "nbformat": 4,
137
+ "nbformat_minor": 5
138
+ }
pyproject.toml CHANGED
@@ -12,13 +12,17 @@ dependencies = [
12
  "pypdf", # for PDF processing
13
  "PyPDF2>=3.0.0",
14
  "numpy>=1.24.0",
 
15
  "langchain>=0.1.0",
16
  "langchain-community>=0.0.22",
17
- "langchain-openai>=0.0.8",
 
18
  "pytest>=7.0.0", # Added pytest dependency
19
- "pytest-asyncio>=0.23.0", # Added pytest-asyncio dependency
 
20
  "tavily-python>=0.3.0", # Added tavily dependency
21
  "langgraph>=0.0.15", # Added langgraph dependency
 
22
  ]
23
 
24
  [build-system]
 
12
  "pypdf", # for PDF processing
13
  "PyPDF2>=3.0.0",
14
  "numpy>=1.24.0",
15
+ "pandas>=2.0.0", # for data analysis
16
  "langchain>=0.1.0",
17
  "langchain-community>=0.0.22",
18
+ "langchain-openai>=0.0.8",
19
+ "qdrant-client>=1.14.2", # Vector storage
20
  "pytest>=7.0.0", # Added pytest dependency
21
+ "pytest-asyncio>=0.23.0", # Added pytest-asyncio dependency
22
+ "pytest-cov>=4.1.0", # Coverage reporting
23
  "tavily-python>=0.3.0", # Added tavily dependency
24
  "langgraph>=0.0.15", # Added langgraph dependency
25
+ "sentence-transformers", # used for embeddings
26
  ]
27
 
28
  [build-system]
uv.lock CHANGED
The diff for this file is too large to render. See raw diff