davanstrien (HF staff) committed
Commit 79f2ae1 · 1 Parent(s): 1e1cb2a

refactor to duckdb

Files changed (2)
  1. Dockerfile +15 -8
  2. main.py +175 -194
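
For reference, a minimal sketch of how the two endpoints added in this refactor could be exercised once the Space is running. The base URL, example query, and example dataset id are illustrative assumptions, not part of the commit; only the endpoint paths, parameters, and response fields come from main.py below.

    import httpx  # any HTTP client works; httpx is assumed here

    BASE_URL = "http://localhost:7860"  # assumption: local run of the uvicorn CMD from the Dockerfile

    # Semantic text search over the DuckDB-backed embeddings
    resp = httpx.get(f"{BASE_URL}/search/datasets", params={"query": "historic newspapers", "k": 5})
    for r in resp.json()["results"]:
        print(r["dataset_id"], round(r["similarity"], 3), r["likes"], r["downloads"])

    # Nearest neighbours of an existing dataset id (hypothetical id, for illustration only)
    resp = httpx.get(f"{BASE_URL}/similarity/datasets", params={"dataset_id": "org/dataset", "k": 5})
    print(resp.json()["results"])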
Dockerfile CHANGED
@@ -1,7 +1,18 @@
-FROM python:3.12
+FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
+
+# Copy requirements file
+COPY ./requirements.txt /code/requirements.txt
+
+# Install dependencies using uv (while still root)
+RUN uv pip install --system --no-cache-dir -r /code/requirements.txt
+
 # Set up a new user named "user" with user ID 1000
 RUN useradd -m -u 1000 user
 
+# Create data directory with proper permissions
+RUN mkdir -p /data && chown -R user:user /data
+RUN chown -R user:user /code
+
 # Switch to the "user" user
 USER user
 
@@ -9,14 +20,10 @@ USER user
 ENV HOME=/home/user \
     PATH=/home/user/.local/bin:$PATH
 
-# Set the working directory to the user's home directory
-WORKDIR $HOME/code
+# Set the working directory
 WORKDIR /code
 
-COPY ./requirements.txt /code/requirements.txt
-
-RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
-
-COPY . .
+# Copy the rest of the application
+COPY --chown=user:user . .
 
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--log-config=log_conf.yaml"]
main.py CHANGED
@@ -1,252 +1,233 @@
 import logging
-from contextlib import asynccontextmanager
-from typing import List, Optional
-
-import chromadb
-from cashews import cache
-from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
+import os
+from typing import List
+import sys
+import duckdb
+from cashews import cache  # Add this import
 from fastapi import FastAPI, HTTPException, Query
-from httpx import AsyncClient
-from huggingface_hub import DatasetCard
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from starlette.responses import RedirectResponse
-from starlette.status import (
-    HTTP_403_FORBIDDEN,
-    HTTP_404_NOT_FOUND,
-    HTTP_500_INTERNAL_SERVER_ERROR,
-)
-
-from load_card_data import card_embedding_function, refresh_card_data
-from load_viewer_data import refresh_viewer_data
-from utils import get_save_path, get_collection, get_chroma_client
+from sentence_transformers import SentenceTransformer
+from contextlib import asynccontextmanager
 
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER
 # Set up logging
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Set up caching
-cache.setup("mem://?check_interval=10&size=1000")
-
-# Initialize Chroma client
-client = get_chroma_client()
-
-async_client = AsyncClient(
-    follow_redirects=True,
-)
+LOCAL = False
+if sys.platform == "darwin":
+    LOCAL = True
+DATA_DIR = "data" if LOCAL else "/data"
+# Configure cache
+cache.setup("mem://", size_limit="4gb")
 
 
+# Initialize FastAPI app
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Startup: refresh data and initialize collection
-    logger.info("Starting up the application")
-    try:
-        # Refresh data
-        logger.info("Starting refresh of card data")
-        refresh_card_data()
-        logger.info("Card data refresh completed")
-        logger.info("Starting refresh of viewer data")
-        await refresh_viewer_data()
-        logger.info("Viewer data refresh completed")
-        logger.info("Data refresh completed successfully")
-    except Exception as e:
-        logger.error(f"Error during startup: {str(e)}")
-        logger.warning("Application starting with potential data issues")
+    # Startup: nothing special needed here since model and DB are initialized at module level
     yield
-
-    # Shutdown: perform any cleanup
-    logger.info("Shutting down the application")
-    # Add any cleanup code here if needed
+    # Cleanup
+    await cache.close()
+    con.close()
 
 
 app = FastAPI(lifespan=lifespan)
 
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=[
+        "https://*.hf.space",  # Allow all Hugging Face Spaces
+        "https://*.huggingface.co",  # Allow all Hugging Face domains
+        # "http://localhost:5500",  # Allow localhost:5500 # TODO remove before prod
+    ],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
-@app.get("/", include_in_schema=False)
-def root():
-    return RedirectResponse(url="/docs")
-
-
-async def try_get_card(hub_id: str) -> Optional[str]:
+# Initialize model and DuckDB
+model = SentenceTransformer("nomic-ai/modernbert-embed-base", device="cpu")
+embedding_dim = model.get_sentence_embedding_dimension()
+
+# Database setup with fallback
+db_path = f"{DATA_DIR}/vector_store.db"
+try:
+    # Create directory if it doesn't exist
+    os.makedirs(os.path.dirname(db_path), exist_ok=True)
+    con = duckdb.connect(db_path)
+    logger.info(f"Connected to persistent database at {db_path}")
+except (OSError, PermissionError) as e:
+    logger.warning(
+        f"Could not create/access {db_path}. Falling back to in-memory database. Error: {e}"
+    )
+    con = duckdb.connect(":memory:")
+
+# Initialize VSS extension
+con.sql("INSTALL vss; LOAD vss;")
+con.sql("SET hnsw_enable_experimental_persistence=true;")
+
+
+def setup_database():
     try:
-        response = await async_client.get(
-            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
+        # Create table with properly typed embeddings
+        con.sql(f"""
+            CREATE TABLE IF NOT EXISTS model_cards AS
+            SELECT *, embeddings::FLOAT[{embedding_dim}] as embeddings_float
+            FROM 'hf://datasets/davanstrien/outputs-embeddings/**/*.parquet';
+        """)
+
+        # Check if index exists
+        index_exists = (
+            con.sql("""
+            SELECT COUNT(*) as count
+            FROM duckdb_indexes
+            WHERE index_name = 'my_hnsw_index';
+        """).fetchone()[0]
+            > 0
         )
-        if response.status_code == 200:
-            card = DatasetCard(response.text)
-            return card.text
+
+        if index_exists:
+            # Drop existing index
+            con.sql("DROP INDEX my_hnsw_index;")
+            logger.info("Dropped existing HNSW index")
+
+        # Create/Recreate HNSW index
+        con.sql("""
+            CREATE INDEX my_hnsw_index ON model_cards
+            USING HNSW (embeddings_float) WITH (metric = 'cosine');
+        """)
+        logger.info("Created/Recreated HNSW index")
+
+        # Log the number of rows in the database
+        row_count = con.sql("SELECT COUNT(*) as count FROM model_cards").fetchone()[0]
+        logger.info(f"Database initialized with {row_count:,} rows")
+
     except Exception as e:
-        logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
-        return None
+        logger.error(f"Setup error: {e}")
+
+
+# Run setup on startup
+setup_database()
 
 
 class QueryResult(BaseModel):
     dataset_id: str
     similarity: float
+    summary: str
+    likes: int
+    downloads: int
 
 
 class QueryResponse(BaseModel):
     results: List[QueryResult]
 
 
-class DatasetCardNotFoundError(HTTPException):
-    def __init__(self, dataset_id: str):
-        super().__init__(
-            status_code=HTTP_404_NOT_FOUND,
-            detail=f"No dataset card available for dataset: {dataset_id}",
-        )
-
+@app.get("/")
+async def redirect_to_docs():
+    from fastapi.responses import RedirectResponse
 
-class DatasetNotForAllAudiencesError(HTTPException):
-    def __init__(self, dataset_id: str):
-        super().__init__(
-            status_code=HTTP_403_FORBIDDEN,
-            detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
-        )
+    return RedirectResponse(url="/docs")
 
 
-@app.get("/similar", response_model=QueryResponse)
-@cache(ttl="1h")
-async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
-    embedding_function = card_embedding_function()
-    collection = get_collection(client, embedding_function, "dataset_cards")
+@app.get("/search/datasets", response_model=QueryResponse)
+@cache(ttl="10m")
+async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)):
     try:
-        logger.info(f"Querying dataset: {dataset_id}")
-        # Get the embedding for the given dataset_id
-        result = collection.get(ids=[dataset_id], include=["embeddings"])
-        if not result.get("embeddings"):
-            logger.info(f"Dataset not found: {dataset_id}")
-            try:
-                card = await try_get_card(dataset_id)
-                if card is None:
-                    raise DatasetCardNotFoundError(dataset_id)
-                embeddings = embedding_function(card)
-                collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
-                logger.info(f"Dataset {dataset_id} added to collection")
-                result = collection.get(ids=[dataset_id], include=["embeddings"])
-                if result.get("not-for-all-audiences"):
-                    raise DatasetNotForAllAudiencesError(dataset_id)
-            except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
-                raise
-            except Exception as e:
-                logger.error(
-                    f"Error adding dataset {dataset_id} to collection: {str(e)}"
-                )
-                raise DatasetCardNotFoundError(dataset_id) from e
-
-        embedding = result["embeddings"][0]
-
-        # Query the collection for similar datasets
-        query_result = collection.query(
-            query_embeddings=[embedding], n_results=n, include=["distances"]
-        )
-
-        if not query_result["ids"]:
-            logger.info(f"No similar datasets found for: {dataset_id}")
-            raise HTTPException(
-                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
-            )
-
-        # Prepare the response
+        query_embedding = model.encode(f"search_query: {query}").tolist()
+
+        # Updated SQL query to include likes and downloads
+        result = con.sql(f"""
+            SELECT
+                datasetId as dataset_id,
+                1 - array_cosine_distance(
+                    embeddings_float::FLOAT[{embedding_dim}],
+                    {query_embedding}::FLOAT[{embedding_dim}]
+                ) as similarity,
+                summary,
+                likes,
+                downloads
+            FROM model_cards
+            ORDER BY similarity DESC
+            LIMIT {k};
+        """).df()
+
+        # Updated result conversion
         results = [
-            QueryResult(dataset_id=id, similarity=1 - distance)
-            for id, distance in zip(
-                query_result["ids"][0], query_result["distances"][0]
+            QueryResult(
+                dataset_id=row["dataset_id"],
+                similarity=float(row["similarity"]),
+                summary=row["summary"],
+                likes=int(row["likes"]),
+                downloads=int(row["downloads"]),
             )
+            for _, row in result.iterrows()
         ]
 
-        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
         return QueryResponse(results=results)
 
-    except (HTTPException, DatasetCardNotFoundError):
-        raise
     except Exception as e:
-        logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
-        raise HTTPException(
-            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="An unexpected error occurred.",
-        ) from e
+        logger.error(f"Search error: {str(e)}")
+        raise HTTPException(status_code=500, detail="Search failed")
 
 
-@app.get("/similar-text", response_model=QueryResponse)
-@cache(ttl="1h")
-async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
+@app.get("/similarity/datasets", response_model=QueryResponse)
+@cache(ttl="10m")
+async def find_similar_datasets(
+    dataset_id: str, k: int = Query(default=5, ge=1, le=100)
+):
     try:
-        logger.info(f"Querying datasets by text: {query}")
-        collection = client.get_collection(
-            name="dataset_cards", embedding_function=card_embedding_function()
-        )
-        print(query)
-        query_result = collection.query(
-            query_texts=query, n_results=n, include=["distances"]
-        )
-        print(query_result)
-
-        if not query_result["ids"]:
-            logger.info(f"No similar datasets found for query: {query}")
+        # First, get the embedding for the input dataset_id
+        reference_embedding = con.sql(f"""
+            SELECT embeddings_float
+            FROM model_cards
+            WHERE datasetId = '{dataset_id}'
+            LIMIT 1;
+        """).df()
+
+        if reference_embedding.empty:
             raise HTTPException(
-                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
+                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
             )
 
-        # Prepare the response
+        # Updated similarity search query to include likes and downloads
+        result = con.sql(f"""
+            SELECT
+                datasetId as dataset_id,
+                1 - array_cosine_distance(
+                    embeddings_float::FLOAT[{embedding_dim}],
+                    (SELECT embeddings_float FROM model_cards WHERE datasetId = '{dataset_id}' LIMIT 1)
+                ) as similarity,
+                summary,
+                likes,
+                downloads
+            FROM model_cards
+            WHERE datasetId != '{dataset_id}'
+            ORDER BY similarity DESC
+            LIMIT {k};
+        """).df()
+
+        # Updated result conversion
         results = [
-            QueryResult(dataset_id=str(id), similarity=float(1 - distance))
-            for id, distance in zip(
-                query_result["ids"][0], query_result["distances"][0]
+            QueryResult(
+                dataset_id=row["dataset_id"],
+                similarity=float(row["similarity"]),
+                summary=row["summary"],
+                likes=int(row["likes"]),
+                downloads=int(row["downloads"]),
             )
+            for _, row in result.iterrows()
         ]
-        logger.info(f"Found {len(results)} similar datasets for query: {query}")
-        return QueryResponse(results=results)
-
-    except Exception as e:
-        logger.error(f"Error querying datasets by text {query}: {str(e)}")
-        raise HTTPException(
-            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="An unexpected error occurred.",
-        ) from e
-
 
-@app.get("/search-viewer", response_model=QueryResponse)
-@cache(ttl="1h")
-async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
-    try:
-        embedding_function = SentenceTransformerEmbeddingFunction(
-            model_name="davanstrien/query-to-dataset-viewer-descriptions",
-            trust_remote_code=True,
-        )
-        collection = client.get_collection(
-            name="dataset-viewer-descriptions",
-            embedding_function=embedding_function,
-        )
-        query = f"USER_QUERY: {query}"
-        query_result = collection.query(
-            query_texts=query, n_results=n, include=["distances"]
-        )
-        print(query_result)
-
-        if not query_result["ids"]:
-            logger.info(f"No similar datasets found for query: {query}")
-            raise HTTPException(
-                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
-            )
-
-        # Prepare the response
-        results = [
-            QueryResult(dataset_id=str(id), similarity=float(1 - distance))
-            for id, distance in zip(
-                query_result["ids"][0], query_result["distances"][0]
-            )
-        ]
-        logger.info(f"Found {len(results)} similar datasets for query: {query}")
         return QueryResponse(results=results)
 
+    except HTTPException:
+        raise
     except Exception as e:
-        logger.error(f"Error querying datasets by text {query}: {str(e)}")
-        raise HTTPException(
-            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="An unexpected error occurred.",
-        ) from e
+        logger.error(f"Similarity search error: {str(e)}")
+        raise HTTPException(status_code=500, detail="Similarity search failed")
 
 
 if __name__ == "__main__":