Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
4d185df
1
Parent(s):
a5c714c
try adding missing cards
Browse files
main.py
CHANGED
|
@@ -7,8 +7,9 @@ from cashews import cache
|
|
| 7 |
from fastapi import FastAPI, HTTPException, Query
|
| 8 |
from pydantic import BaseModel
|
| 9 |
from starlette.responses import RedirectResponse
|
| 10 |
-
|
| 11 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
|
|
|
| 12 |
|
| 13 |
# Set up logging
|
| 14 |
logging.basicConfig(
|
|
@@ -24,6 +25,10 @@ SAVE_PATH = get_save_path()
|
|
| 24 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
| 25 |
collection = None
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
class QueryResult(BaseModel):
|
| 29 |
dataset_id: str
|
|
@@ -69,6 +74,19 @@ def root():
|
|
| 69 |
return RedirectResponse(url="/docs")
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
@app.get("/query", response_model=Optional[QueryResponse])
|
| 73 |
@cache(ttl="1h")
|
| 74 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
@@ -76,10 +94,20 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 76 |
logger.info(f"Querying dataset: {dataset_id}")
|
| 77 |
# Get the embedding for the given dataset_id
|
| 78 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
| 79 |
-
|
| 80 |
-
if not result["embeddings"]:
|
| 81 |
logger.info(f"Dataset not found: {dataset_id}")
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
embedding = result["embeddings"][0]
|
| 85 |
|
|
|
|
| 7 |
from fastapi import FastAPI, HTTPException, Query
|
| 8 |
from pydantic import BaseModel
|
| 9 |
from starlette.responses import RedirectResponse
|
| 10 |
+
from httpx import AsyncClient
|
| 11 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
| 12 |
+
from huggingface_hub import DatasetCard
|
| 13 |
|
| 14 |
# Set up logging
|
| 15 |
logging.basicConfig(
|
|
|
|
| 25 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
| 26 |
collection = None
|
| 27 |
|
| 28 |
+
async_client = AsyncClient(
|
| 29 |
+
follow_redirects=True,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
|
| 33 |
class QueryResult(BaseModel):
|
| 34 |
dataset_id: str
|
|
|
|
| 74 |
return RedirectResponse(url="/docs")
|
| 75 |
|
| 76 |
|
| 77 |
+
async def try_get_card(hub_id: str) -> Optional[str]:
    """Best-effort fetch of a dataset's README card text from the Hub.

    Requests ``README.md`` from the ``main`` branch of the dataset repo
    ``hub_id`` through the module-level ``async_client`` and parses the
    response body with ``DatasetCard``.

    Args:
        hub_id: Dataset repository id, interpolated verbatim into the URL.

    Returns:
        The ``text`` attribute of the parsed ``DatasetCard`` when the request
        returns HTTP 200; ``None`` for any other status code or when fetching
        or parsing raises.
    """
    url = f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
    try:
        response = await async_client.get(url)
        if response.status_code != 200:
            # Previously this case fell through silently; log it so a missing
            # card can be told apart from one that was never requested.
            logger.info(
                f"No card fetched for hub_id {hub_id}: "
                f"HTTP {response.status_code}"
            )
            return None
        return DatasetCard(response.text).text
    except Exception as e:
        # Deliberately broad: this helper is best-effort and callers only
        # check for None. logger.exception keeps the traceback in the logs.
        logger.exception(f"Error fetching card for hub_id {hub_id}: {str(e)}")
        return None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
@app.get("/query", response_model=Optional[QueryResponse])
|
| 91 |
@cache(ttl="1h")
|
| 92 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
|
| 94 |
logger.info(f"Querying dataset: {dataset_id}")
|
| 95 |
# Get the embedding for the given dataset_id
|
| 96 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
| 97 |
+
if not result.get("embeddings"):
|
|
|
|
| 98 |
logger.info(f"Dataset not found: {dataset_id}")
|
| 99 |
+
try:
|
| 100 |
+
embedding_function = get_embedding_function()
|
| 101 |
+
card = await try_get_card(dataset_id)
|
| 102 |
+
embeddings = embedding_function(card)
|
| 103 |
+
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
| 104 |
+
logger.info(f"Dataset {dataset_id} added to collection")
|
| 105 |
+
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
| 106 |
+
except Exception as e:
|
| 107 |
+
logger.error(
|
| 108 |
+
f"Error adding dataset {dataset_id} to collection: {str(e)}"
|
| 109 |
+
)
|
| 110 |
+
raise HTTPException(status_code=404, detail="Dataset not found") from e
|
| 111 |
|
| 112 |
embedding = result["embeddings"][0]
|
| 113 |
|