Spaces:
Sleeping
Sleeping
Update server.py
Browse files
server.py
CHANGED
@@ -27,6 +27,7 @@ import timm
|
|
27 |
import torch
|
28 |
import uvicorn
|
29 |
from fastapi import FastAPI, File, HTTPException, UploadFile
|
|
|
30 |
from PIL import Image
|
31 |
from pydantic import BaseModel, Field
|
32 |
from pydantic_settings import BaseSettings
|
@@ -447,6 +448,8 @@ def create_app(settings: Settings) -> FastAPI:
|
|
447 |
description="An API for tagging images using an ONNX model.",
|
448 |
version="1.0.1", # Incremented version
|
449 |
lifespan=lifespan,
|
|
|
|
|
450 |
)
|
451 |
app.state = AppState(settings)
|
452 |
return app
|
@@ -467,7 +470,20 @@ def get_tagger(app: FastAPI) -> Tagger:
|
|
467 |
def add_endpoints(app: FastAPI):
|
468 |
tagger_dependency = lambda: get_tagger(app)
|
469 |
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
async def tag_batch(
|
472 |
tags_threshold: TaggerArgs = TaggerArgs(),
|
473 |
file: UploadFile = File(
|
@@ -506,6 +522,7 @@ def add_endpoints(app: FastAPI):
|
|
506 |
),
|
507 |
)
|
508 |
|
|
|
509 |
@app.get("/status", response_model=StatusResponse, summary="Get server status")
|
510 |
async def status():
|
511 |
tagger = tagger_dependency()
|
|
|
27 |
import torch
|
28 |
import uvicorn
|
29 |
from fastapi import FastAPI, File, HTTPException, UploadFile
|
30 |
+
from fastapi.responses import RedirectResponse
|
31 |
from PIL import Image
|
32 |
from pydantic import BaseModel, Field
|
33 |
from pydantic_settings import BaseSettings
|
|
|
448 |
description="An API for tagging images using an ONNX model.",
|
449 |
version="1.0.1", # Incremented version
|
450 |
lifespan=lifespan,
|
451 |
+
docs_url="/docs",
|
452 |
+
|
453 |
)
|
454 |
app.state = AppState(settings)
|
455 |
return app
|
|
|
470 |
def add_endpoints(app: FastAPI):
|
471 |
tagger_dependency = lambda: get_tagger(app)
|
472 |
|
473 |
+
# Root welcome/docs page
|
474 |
+
@app.get("/", include_in_schema=False)
|
475 |
+
async def root():
|
476 |
+
if app.docs_url:
|
477 |
+
return RedirectResponse(url=app.docs_url)
|
478 |
+
elif app.redoc_url:
|
479 |
+
return RedirectResponse(url=app.redoc_url)
|
480 |
+
return HTMLResponse(
|
481 |
+
content="<h1>Welcome to the Tagger API</h1><p>Use /batch to tag images.</p>",
|
482 |
+
status_code=200
|
483 |
+
)
|
484 |
+
|
485 |
+
# Tagging endpoint at /batch
|
486 |
+
@app.post("/batch", response_model=BatchTaggingResponse, summary="Tag a batch of images")
|
487 |
async def tag_batch(
|
488 |
tags_threshold: TaggerArgs = TaggerArgs(),
|
489 |
file: UploadFile = File(
|
|
|
522 |
),
|
523 |
)
|
524 |
|
525 |
+
# Status endpoint
|
526 |
@app.get("/status", response_model=StatusResponse, summary="Get server status")
|
527 |
async def status():
|
528 |
tagger = tagger_dependency()
|