from typing import Callable, Optional
from threading import Lock
from secrets import compare_digest

from modules import shared
from modules.api.api import decode_base64_to_image
from modules.call_queue import queue_lock
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials

from tagger import utils
from tagger import api_models as models


class Api:
    """Register the tagger HTTP endpoints on a FastAPI application.

    Two routes are added under ``prefix``: ``interrogate`` (POST) and
    ``interrogators`` (GET). When ``shared.cmd_opts.api_auth`` is set, both
    routes are protected with HTTP Basic authentication.
    """

    def __init__(
        self,
        app: FastAPI,
        queue_lock: Lock,
        prefix: Optional[str] = None,
    ) -> None:
        """Store the collaborators and register the routes.

        Args:
            app: Application the routes are added to.
            queue_lock: Lock held while an interrogator is running.
            prefix: Optional path prefix joined to each route with ``/``.

        Raises:
            ValueError: If an ``api_auth`` entry contains no ``:`` separator.
        """
        # Always defined so that ``auth`` never hits a missing attribute.
        self.credentials = {}

        if shared.cmd_opts.api_auth:
            # ``api_auth`` is a comma-separated list of ``user:password``.
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first colon only: passwords may contain ":".
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.app = app
        self.queue_lock = queue_lock
        self.prefix = prefix

        self.add_api_route(
            'interrogate',
            self.endpoint_interrogate,
            methods=['POST'],
            response_model=models.TaggerInterrogateResponse
        )

        self.add_api_route(
            'interrogators',
            self.endpoint_interrogators,
            methods=['GET'],
            response_model=models.InterrogatorsResponse
        )

    def auth(self, creds: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Validate HTTP Basic credentials against the configured users.

        Returns:
            ``True`` when the username is known and the password matches.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when
                the username is unknown or the password does not match.
        """
        expected = self.credentials.get(creds.username)

        # compare_digest() only accepts ASCII-only ``str``; comparing the
        # UTF-8 bytes keeps the constant-time comparison and avoids a
        # TypeError (HTTP 500) for non-ASCII passwords.
        if expected is not None and compare_digest(
            creds.password.encode("utf-8"),
            expected.encode("utf-8"),
        ):
            return True

        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={
                "WWW-Authenticate": "Basic"
            })

    def add_api_route(self, path: str, endpoint: Callable, **kwargs):
        """Add ``endpoint`` to the app, applying the prefix and auth.

        Args:
            path: Route path; ``self.prefix`` is prepended when it is truthy.
            endpoint: Callable handling the route.
            **kwargs: Forwarded unchanged to ``app.add_api_route``.

        Returns:
            Whatever ``self.app.add_api_route`` returns.
        """
        if self.prefix:
            path = f'{self.prefix}/{path}'

        if shared.cmd_opts.api_auth:
            # Only attach the auth dependency when credentials are configured.
            kwargs = {'dependencies': [Depends(self.auth)], **kwargs}

        return self.app.add_api_route(path, endpoint, **kwargs)

    def endpoint_interrogate(self, req: models.TaggerInterrogateRequest):
        """Run the requested interrogator on the supplied base64 image.

        Raises:
            HTTPException: 404 when ``req.image`` is ``None`` or when
                ``req.model`` is not a key of ``utils.interrogators``.
        """
        if req.image is None:
            raise HTTPException(404, 'Image not found')

        if req.model not in utils.interrogators:
            raise HTTPException(404, 'Model not found')

        image = decode_base64_to_image(req.image)
        interrogator = utils.interrogators[req.model]

        # Hold the shared queue lock only for the interrogation itself.
        with self.queue_lock:
            ratings, tags = interrogator.interrogate(image)

        # The caption merges the ratings with the post-processed tags; on a
        # key collision the post-processed tag value wins.
        return models.TaggerInterrogateResponse(
            caption={
                **ratings,
                **interrogator.postprocess_tags(
                    tags,
                    req.threshold
                )
            })

    def endpoint_interrogators(self):
        """Return the names of all registered interrogators."""
        return models.InterrogatorsResponse(
            models=list(utils.interrogators)
        )


def on_app_started(_, app: FastAPI):
    """Mount the tagger API on ``app`` under ``/tagger/v1``.

    The first positional argument is ignored.
    NOTE(review): presumably registered elsewhere as an app-started
    callback; confirm against the caller.
    """
    Api(app, queue_lock, '/tagger/v1')