Shio-Koube commited on
Commit
db966aa
·
verified ·
1 Parent(s): 108dadc

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +18 -1
server.py CHANGED
@@ -27,6 +27,7 @@ import timm
27
  import torch
28
  import uvicorn
29
  from fastapi import FastAPI, File, HTTPException, UploadFile
 
30
  from PIL import Image
31
  from pydantic import BaseModel, Field
32
  from pydantic_settings import BaseSettings
@@ -447,6 +448,8 @@ def create_app(settings: Settings) -> FastAPI:
447
  description="An API for tagging images using an ONNX model.",
448
  version="1.0.1", # Incremented version
449
  lifespan=lifespan,
 
 
450
  )
451
  app.state = AppState(settings)
452
  return app
@@ -467,7 +470,20 @@ def get_tagger(app: FastAPI) -> Tagger:
467
  def add_endpoints(app: FastAPI):
468
  tagger_dependency = lambda: get_tagger(app)
469
 
470
- @app.post("/", response_model=BatchTaggingResponse, summary="Tag a batch of images")
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  async def tag_batch(
472
  tags_threshold: TaggerArgs = TaggerArgs(),
473
  file: UploadFile = File(
@@ -506,6 +522,7 @@ def add_endpoints(app: FastAPI):
506
  ),
507
  )
508
 
 
509
  @app.get("/status", response_model=StatusResponse, summary="Get server status")
510
  async def status():
511
  tagger = tagger_dependency()
 
27
  import torch
28
  import uvicorn
29
  from fastapi import FastAPI, File, HTTPException, UploadFile
30
+ from fastapi.responses import RedirectResponse
31
  from PIL import Image
32
  from pydantic import BaseModel, Field
33
  from pydantic_settings import BaseSettings
 
448
  description="An API for tagging images using an ONNX model.",
449
  version="1.0.1", # Incremented version
450
  lifespan=lifespan,
451
+ docs_url="/docs",
452
+
453
  )
454
  app.state = AppState(settings)
455
  return app
 
470
  def add_endpoints(app: FastAPI):
471
  tagger_dependency = lambda: get_tagger(app)
472
 
473
+ # Root welcome/docs page
474
+ @app.get("/", include_in_schema=False)
475
+ async def root():
476
+ if app.docs_url:
477
+ return RedirectResponse(url=app.docs_url)
478
+ elif app.redoc_url:
479
+ return RedirectResponse(url=app.redoc_url)
480
+ return HTMLResponse(
481
+ content="<h1>Welcome to the Tagger API</h1><p>Use /batch to tag images.</p>",
482
+ status_code=200
483
+ )
484
+
485
+ # Tagging endpoint at /batch
486
+ @app.post("/batch", response_model=BatchTaggingResponse, summary="Tag a batch of images")
487
  async def tag_batch(
488
  tags_threshold: TaggerArgs = TaggerArgs(),
489
  file: UploadFile = File(
 
522
  ),
523
  )
524
 
525
+ # Status endpoint
526
  @app.get("/status", response_model=StatusResponse, summary="Get server status")
527
  async def status():
528
  tagger = tagger_dependency()