Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
b5f94b5
1
Parent(s):
1d74113
refactor
Browse files
main.py
CHANGED
|
@@ -16,8 +16,9 @@ from starlette.status import (
|
|
| 16 |
HTTP_500_INTERNAL_SERVER_ERROR,
|
| 17 |
)
|
| 18 |
|
| 19 |
-
from load_card_data import
|
| 20 |
from load_viewer_data import refresh_viewer_data
|
|
|
|
| 21 |
|
| 22 |
# Set up logging
|
| 23 |
logging.basicConfig(
|
|
@@ -31,7 +32,7 @@ cache.setup("mem://?check_interval=10&size=1000")
|
|
| 31 |
# Initialize Chroma client
|
| 32 |
SAVE_PATH = get_save_path()
|
| 33 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
| 34 |
-
|
| 35 |
|
| 36 |
async_client = AsyncClient(
|
| 37 |
follow_redirects=True,
|
|
@@ -40,33 +41,20 @@ async_client = AsyncClient(
|
|
| 40 |
|
| 41 |
@asynccontextmanager
|
| 42 |
async def lifespan(app: FastAPI):
|
| 43 |
-
global collection
|
| 44 |
# Startup: refresh data and initialize collection
|
| 45 |
logger.info("Starting up the application")
|
| 46 |
try:
|
| 47 |
-
# Create or get the collection
|
| 48 |
-
logger.info("Initializing embedding function")
|
| 49 |
-
embedding_function = get_embedding_function()
|
| 50 |
-
logger.info("Creating or getting collection")
|
| 51 |
-
collection = client.get_or_create_collection(
|
| 52 |
-
name="dataset_cards", embedding_function=embedding_function
|
| 53 |
-
)
|
| 54 |
-
logger.info("Collection initialized successfully")
|
| 55 |
-
|
| 56 |
# Refresh data
|
| 57 |
logger.info("Starting refresh of card data")
|
| 58 |
refresh_card_data()
|
| 59 |
logger.info("Card data refresh completed")
|
| 60 |
-
|
| 61 |
logger.info("Starting refresh of viewer data")
|
| 62 |
await refresh_viewer_data()
|
| 63 |
logger.info("Viewer data refresh completed")
|
| 64 |
-
|
| 65 |
logger.info("Data refresh completed successfully")
|
| 66 |
except Exception as e:
|
| 67 |
logger.error(f"Error during startup: {str(e)}")
|
| 68 |
logger.warning("Application starting with potential data issues")
|
| 69 |
-
|
| 70 |
yield
|
| 71 |
|
| 72 |
# Shutdown: perform any cleanup
|
|
@@ -123,6 +111,8 @@ class DatasetNotForAllAudiencesError(HTTPException):
|
|
| 123 |
@app.get("/similar", response_model=QueryResponse)
|
| 124 |
@cache(ttl="1h")
|
| 125 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
|
|
|
|
| 126 |
try:
|
| 127 |
logger.info(f"Querying dataset: {dataset_id}")
|
| 128 |
# Get the embedding for the given dataset_id
|
|
@@ -130,7 +120,6 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 130 |
if not result.get("embeddings"):
|
| 131 |
logger.info(f"Dataset not found: {dataset_id}")
|
| 132 |
try:
|
| 133 |
-
embedding_function = get_embedding_function()
|
| 134 |
card = await try_get_card(dataset_id)
|
| 135 |
if card is None:
|
| 136 |
raise DatasetCardNotFoundError(dataset_id)
|
|
@@ -182,13 +171,13 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 182 |
) from e
|
| 183 |
|
| 184 |
|
| 185 |
-
@app.
|
| 186 |
@cache(ttl="1h")
|
| 187 |
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
|
| 188 |
try:
|
| 189 |
logger.info(f"Querying datasets by text: {query}")
|
| 190 |
collection = client.get_collection(
|
| 191 |
-
name="dataset_cards", embedding_function=get_embedding_function()
|
| 192 |
)
|
| 193 |
print(query)
|
| 194 |
query_result = collection.query(
|
|
@@ -220,7 +209,7 @@ async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)
|
|
| 220 |
) from e
|
| 221 |
|
| 222 |
|
| 223 |
-
@app.
|
| 224 |
@cache(ttl="1h")
|
| 225 |
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
|
| 226 |
try:
|
|
|
|
| 16 |
HTTP_500_INTERNAL_SERVER_ERROR,
|
| 17 |
)
|
| 18 |
|
| 19 |
+
from load_card_data import card_embedding_function, refresh_card_data
|
| 20 |
from load_viewer_data import refresh_viewer_data
|
| 21 |
+
from utils import get_save_path, get_collection
|
| 22 |
|
| 23 |
# Set up logging
|
| 24 |
logging.basicConfig(
|
|
|
|
| 32 |
# Initialize Chroma client
|
| 33 |
SAVE_PATH = get_save_path()
|
| 34 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
| 35 |
+
|
| 36 |
|
| 37 |
async_client = AsyncClient(
|
| 38 |
follow_redirects=True,
|
|
|
|
| 41 |
|
| 42 |
@asynccontextmanager
|
| 43 |
async def lifespan(app: FastAPI):
|
|
|
|
| 44 |
# Startup: refresh data and initialize collection
|
| 45 |
logger.info("Starting up the application")
|
| 46 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
# Refresh data
|
| 48 |
logger.info("Starting refresh of card data")
|
| 49 |
refresh_card_data()
|
| 50 |
logger.info("Card data refresh completed")
|
|
|
|
| 51 |
logger.info("Starting refresh of viewer data")
|
| 52 |
await refresh_viewer_data()
|
| 53 |
logger.info("Viewer data refresh completed")
|
|
|
|
| 54 |
logger.info("Data refresh completed successfully")
|
| 55 |
except Exception as e:
|
| 56 |
logger.error(f"Error during startup: {str(e)}")
|
| 57 |
logger.warning("Application starting with potential data issues")
|
|
|
|
| 58 |
yield
|
| 59 |
|
| 60 |
# Shutdown: perform any cleanup
|
|
|
|
| 111 |
@app.get("/similar", response_model=QueryResponse)
|
| 112 |
@cache(ttl="1h")
|
| 113 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
| 114 |
+
embedding_function = card_embedding_function()
|
| 115 |
+
collection = get_collection(client, embedding_function, "dataset_cards")
|
| 116 |
try:
|
| 117 |
logger.info(f"Querying dataset: {dataset_id}")
|
| 118 |
# Get the embedding for the given dataset_id
|
|
|
|
| 120 |
if not result.get("embeddings"):
|
| 121 |
logger.info(f"Dataset not found: {dataset_id}")
|
| 122 |
try:
|
|
|
|
| 123 |
card = await try_get_card(dataset_id)
|
| 124 |
if card is None:
|
| 125 |
raise DatasetCardNotFoundError(dataset_id)
|
|
|
|
| 171 |
) from e
|
| 172 |
|
| 173 |
|
| 174 |
+
@app.get("/similar-text", response_model=QueryResponse)
|
| 175 |
@cache(ttl="1h")
|
| 176 |
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
|
| 177 |
try:
|
| 178 |
logger.info(f"Querying datasets by text: {query}")
|
| 179 |
collection = client.get_collection(
|
| 180 |
+
name="dataset_cards", embedding_function=card_embedding_function()
|
| 181 |
)
|
| 182 |
print(query)
|
| 183 |
query_result = collection.query(
|
|
|
|
| 209 |
) from e
|
| 210 |
|
| 211 |
|
| 212 |
+
@app.get("/search-viewer", response_model=QueryResponse)
|
| 213 |
@cache(ttl="1h")
|
| 214 |
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
|
| 215 |
try:
|