# %% [markdown]
# # 📌 Notebook 1: Embed and Store Shot Data in Qdrant
#
# This notebook loads your cleaned shot data, embeds it using `bge-base-en-v1.5`
# (or, optionally, `e5-base-v2`), and stores the embeddings in Qdrant for use in
# retrieval and recommendation.

# %% [markdown]
# # Initial Setup

# %%
# Step 1: Imports and configuration (everything a reader might tune lives here).
from pathlib import Path

import pandas as pd
from sentence_transformers import SentenceTransformer

# Cleaned shot export, relative to this notebook.
SHOT_DATA_PATH = Path('../data/raw/cleaned_shot_data.csv')

# Every column that create_embedding_text() reads from a row.
REQUIRED_COLUMNS = [
    'Date',
    'Club Type',
    'Club Description',
    'Carry Distance',
    'Total Distance',
    'Ball Speed',
    'Club Speed',
    'Spin Rate',
    'Attack Angle',
    'Descent Angle',
    'Shot Classification',
]

# Short model name -> Hugging Face model id.
EMBEDDING_MODEL_IDS = {
    'e5-base-v2': 'intfloat/e5-base-v2',
    'bge-base-en-v1.5': 'BAAI/bge-base-en-v1.5',
}

# Select embedding model
# Options: 'e5-base-v2', 'bge-base-en-v1.5'
embedding_model_choice = 'bge-base-en-v1.5'

# Use a model-specific collection name so vectors from different models never mix.
descriptive_collection_name = f"golf_shot_vectors_{embedding_model_choice.replace('-', '_')}"


# %%
def create_embedding_text(row):
    """Render one shot record as a short English description for embedding.

    Args:
        row: A mapping-like object (e.g. a DataFrame row) indexable by the
            column names listed in REQUIRED_COLUMNS.

    Returns:
        A single string describing the shot: date, distances, club,
        classification, and the measured ball/club metrics.
    """
    # These literals are embedded and stored verbatim as the payload text, so
    # they must contain only ordinary spaces (no hidden separator characters).
    return (
        f"On {row['Date']}, the golfer hit a shot {row['Total Distance']} yards with a carry of {row['Carry Distance']} yards "
        f"using a {row['Club Type']} ({row['Club Description']}). "
        f"The shot was classified as {row['Shot Classification']}. "
        f"The known contributing factors to this result were: "
        f"Ball speed: {row['Ball Speed']} mph. "
        f"Club speed: {row['Club Speed']} mph. "
        f"Spin rate: {row['Spin Rate']} rpm. "
        f"Attack angle: {row['Attack Angle']} degrees. "
        f"Descent angle: {row['Descent Angle']} degrees."
    )


def load_embedding_model(name: str) -> SentenceTransformer:
    """Instantiate the SentenceTransformer registered under a short name.

    Args:
        name: One of the keys of EMBEDDING_MODEL_IDS.

    Returns:
        The loaded SentenceTransformer model.

    Raises:
        ValueError: If ``name`` is not a supported model.
    """
    if name not in EMBEDDING_MODEL_IDS:
        raise ValueError(f'Unsupported model: {name}')
    return SentenceTransformer(EMBEDDING_MODEL_IDS[name])


# %% [markdown]
# ## Load Shot Data From .csv

# %%
# Step 2: Load cleaned shot data
shot_data = pd.read_csv(SHOT_DATA_PATH)

# Fail early with a readable message instead of a KeyError deep inside apply().
missing_columns = sorted(set(REQUIRED_COLUMNS) - set(shot_data.columns))
assert not missing_columns, f'Shot data is missing expected columns: {missing_columns}'
assert not shot_data.empty, f'No shots found in {SHOT_DATA_PATH}'

shot_data.head()

# %% [markdown]
# ## Embed Shot Data

# %%
# Step 3: Format shot data into text chunks for embedding
texts = shot_data.apply(create_embedding_text, axis=1).tolist()

# %%
# Step 4: Load the selected embedding model
model = load_embedding_model(embedding_model_choice)

# %%
# Step 5: Generate embeddings with model-specific passage formatting.
# The e5 branch prefixes each stored document with 'passage: '; other models
# receive the text unchanged.
if embedding_model_choice.startswith('e5'):
    texts_for_embedding = [f'passage: {t}' for t in texts]
else:
    texts_for_embedding = texts
embeddings = model.encode(texts_for_embedding, show_progress_bar=True)

# One vector per text, in the same order; the upload step pairs them by index.
assert len(embeddings) == len(texts), 'Embedding count does not match text count'
# %% [markdown]
# ## Upload Embedded Shot Data to Qdrant

# %%
# Collect the Qdrant API key.
# Prefer the QDRANT_API_KEY environment variable so the notebook can run
# unattended (Restart & Run All); fall back to an interactive prompt otherwise.
import os
from getpass import getpass

qdrant_api_key = os.environ.get('QDRANT_API_KEY') or getpass('🔑 Enter your Qdrant API Key: ')

# %%
# Qdrant setup
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, PointStruct, Distance

# Cluster endpoint; override with the QDRANT_URL environment variable.
QDRANT_URL = os.environ.get(
    'QDRANT_URL',
    'https://6f592f43-f667-4234-ad3a-4f15ed5882ef.us-west-2-0.aws.cloud.qdrant.io:6333',
)

client = QdrantClient(url=QDRANT_URL, api_key=qdrant_api_key)

# `embeddings`, `texts` and `descriptive_collection_name` come from the
# embedding section earlier in this notebook.
assert len(embeddings) > 0, 'No embeddings to upload'
assert len(embeddings) == len(texts), 'Embedding count does not match text count'

# Derive the vector size from the data instead of hard-coding it, so switching
# the embedding model cannot create a collection with the wrong dimension.
vector_size = len(embeddings[0])

# Recreate the collection to flush old data.
# `recreate_collection` is deprecated; drop-then-create is the documented
# replacement.
if client.collection_exists(collection_name=descriptive_collection_name):
    client.delete_collection(collection_name=descriptive_collection_name)
client.create_collection(
    collection_name=descriptive_collection_name,
    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)

# %%
# Upload embedded vectors to Qdrant.
# Point ids are the row positions; the payload keeps the exact text that was
# embedded. Vectors are converted to plain Python lists for serialization
# (assumes each row of `embeddings` exposes .tolist(), as array rows do).
points = [
    PointStruct(id=i, vector=vector.tolist(), payload={'text': text})
    for i, (vector, text) in enumerate(zip(embeddings, texts))
]
client.upsert(collection_name=descriptive_collection_name, points=points)