"""FastAPI service exposing CLIP Interrogator image-to-prompt inference.

POST an image plus a ``mode`` form field to ``/inference/`` and receive the
generated prompt(s) as JSON.
"""

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import torch
from clip_interrogator import Config, Interrogator

app = FastAPI()

# Allow CORS for all origins (adjust as needed for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set up the CLIP Interrogator. On CPU-only hosts, offload the BLIP model to
# save memory; on CUDA hosts keep it resident for speed.
config = Config()
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
config.blip_offload = not torch.cuda.is_available()
config.chunk_size = 2048
config.flavor_intermediate_count = 512
config.blip_num_beams = 64

ci = Interrogator(config)


@app.post("/inference/")
async def interrogate_images(
    file: UploadFile = File(...),
    mode: str = Form(...),
    best_max_flavors: int = Form(...),
):
    """Generate a text prompt describing the uploaded image.

    Form fields:
        file: the image to describe (any format Pillow can open).
        mode: 'best' (slow, uses ``best_max_flavors``), 'classic', or
            anything else for the fast interrogator.
        best_max_flavors: flavor count for 'best' mode; FastAPI validates
            and coerces this to int, so no extra cast is needed.

    Returns:
        200 with ``{"prompt_results": [prompt]}`` on success,
        500 with ``{"error": message}`` on any failure.
    """
    try:
        contents = await file.read()
        # Force RGB so downstream models never see palette/alpha images.
        image = Image.open(io.BytesIO(contents)).convert('RGB')

        if mode == 'best':
            prompt_result = ci.interrogate(image, max_flavors=best_max_flavors)
        elif mode == 'classic':
            prompt_result = ci.interrogate_classic(image)
        else:
            prompt_result = ci.interrogate_fast(image)

        return JSONResponse(content={"prompt_results": [prompt_result]})
    except Exception as e:
        # Top-level service boundary: surface the failure as a 500 rather
        # than crashing the worker.
        return JSONResponse(content={"error": str(e)}, status_code=500)