Luisgust commited on
Commit
2269978
·
verified ·
1 Parent(s): 770bb90

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +46 -0
main.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from PIL import Image
5
+ import io
6
+ import torch
7
+ from clip_interrogator import Config, Interrogator
8
+
9
+ app = FastAPI()
10
+
11
+ # Allow CORS for all origins (adjust as needed for production)
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Setup the CLIP Interrogator
21
+ config = Config()
22
+ config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
+ config.blip_offload = False if torch.cuda.is_available() else True
24
+ config.chunk_size = 2048
25
+ config.flavor_intermediate_count = 512
26
+ config.blip_num_beams = 64
27
+
28
+ ci = Interrogator(config)
29
+
30
+ @app.post("/inference/")
31
+ async def interrogate_images(file: UploadFile = File(...), mode: str = Form(...), best_max_flavors: int = Form(...)):
32
+ try:
33
+ contents = await file.read()
34
+ image = Image.open(io.BytesIO(contents)).convert('RGB')
35
+
36
+ if mode == 'best':
37
+ prompt_result = ci.interrogate(image, max_flavors=int(best_max_flavors))
38
+ elif mode == 'classic':
39
+ prompt_result = ci.interrogate_classic(image)
40
+ else:
41
+ prompt_result = ci.interrogate_fast(image)
42
+
43
+ return JSONResponse(content={"prompt_results": [prompt_result]})
44
+ except Exception as e:
45
+ return JSONResponse(content={"error": str(e)}, status_code=500)
46
+