Create inference.py
Browse files- inference.py +46 -0
inference.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
import torchxrayvision as xrv
|
8 |
+
|
9 |
+
def init():
|
10 |
+
"""
|
11 |
+
Called once at container startup.
|
12 |
+
Loads the DenseNet model from torchxrayvision (using HF Hub weights)
|
13 |
+
and sets up the crop transform.
|
14 |
+
"""
|
15 |
+
global model, transform
|
16 |
+
model_name = "densenet121-res224-chex"
|
17 |
+
model = xrv.models.get_model(model_name, from_hf_hub=True)
|
18 |
+
model.eval()
|
19 |
+
# Center‐crop to a square patch around the lung
|
20 |
+
transform = xrv.datasets.XRayCenterCrop(pad=32)
|
21 |
+
|
22 |
+
def predict(request):
|
23 |
+
"""
|
24 |
+
Called on each inference request.
|
25 |
+
Expects a JSON payload like {"image": "data:image/jpeg;base64,/9j/4AAQ..."}.
|
26 |
+
Returns a dict with scores and labels.
|
27 |
+
"""
|
28 |
+
# 1) Decode base64 Data URI
|
29 |
+
data_uri = request.json.get("image", "")
|
30 |
+
b64 = data_uri.split(",")[-1]
|
31 |
+
img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
|
32 |
+
|
33 |
+
# 2) To numpy array & normalize
|
34 |
+
arr = np.array(img)
|
35 |
+
arr = xrv.datasets.normalize(arr, 255) # scale pixel values
|
36 |
+
|
37 |
+
# 3) Center crop & to tensor
|
38 |
+
arr = transform(arr) # H×W → cropped H×W
|
39 |
+
tensor = torch.tensor(arr).permute(2, 0, 1).float().unsqueeze(0)
|
40 |
+
|
41 |
+
# 4) Inference
|
42 |
+
with torch.no_grad():
|
43 |
+
scores = model(tensor).tolist()
|
44 |
+
|
45 |
+
# 5) Return scores + pathologies
|
46 |
+
return {"scores": scores, "labels": model.pathologies}
|