from typing import Optional, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import chromadb
import logging
from load_data import get_save_path, refresh_data
from cashews import cache
# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching
cache.setup("mem://?check_interval=10&size=10000")
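# The "mem://" backend keeps cached values in process memory; size caps the
# number of stored entries and check_interval (seconds) controls how often
# expired keys are purged.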
# Initialize Chroma client
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
collection = client.get_collection("dataset_cards")
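# Note: get_collection expects the "dataset_cards" collection to already exist
# at SAVE_PATH; unlike get_or_create_collection, it raises an error if the
# collection is missing.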
class QueryResult(BaseModel):
    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    results: List[QueryResult]
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: refresh data
    logger.info("Starting up the application")
    try:
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        logger.error(f"Error during data refresh: {str(e)}")
    yield  # Here the app is running and handling requests
    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed
app = FastAPI(lifespan=lifespan)
# The route path and cache TTL below are illustrative assumptions: the endpoint
# needs a route decorator to be reachable, and the cashews cache configured
# above is otherwise unused.
@app.get("/similarity", response_model=Optional[QueryResponse])
@cache(ttl="10m")
async def api_query_dataset(
    dataset_id: str, n: int = Query(default=10, ge=1, le=100)
):
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if not result["embeddings"]:
            logger.info(f"Dataset not found: {dataset_id}")
            raise HTTPException(status_code=404, detail="Dataset not found")
        embedding = result["embeddings"][0]
        # Query the collection for similar datasets
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        if not query_result["ids"]:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return None
        # Prepare the response; Chroma returns distances, so similarity is 1 - distance
        results = [
            QueryResult(dataset_id=id, similarity=1 - distance)
            for id, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # Let the 404 above propagate instead of being converted into a 500
        raise
    except Exception as e:
        logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
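# Example request (assuming the "/similarity" route registered above and a
# dataset id that exists in the Chroma collection):
#
#   curl "http://localhost:8000/similarity?dataset_id=stanfordnlp/imdb&n=5"
#
# The response is JSON shaped like:
#   {"results": [{"dataset_id": "...", "similarity": 0.87}, ...]}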