Image Classification
densenet
vision
lucazhou2000 commited on
Commit
deb2df9
·
verified ·
1 Parent(s): 79e006c

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +46 -0
inference.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ import torchxrayvision as xrv
8
+
9
+ def init():
10
+ """
11
+ Called once at container startup.
12
+ Loads the DenseNet model from torchxrayvision (using HF Hub weights)
13
+ and sets up the crop transform.
14
+ """
15
+ global model, transform
16
+ model_name = "densenet121-res224-chex"
17
+ model = xrv.models.get_model(model_name, from_hf_hub=True)
18
+ model.eval()
19
+ # Center‐crop to a square patch around the lung
20
+ transform = xrv.datasets.XRayCenterCrop(pad=32)
21
+
22
+ def predict(request):
23
+ """
24
+ Called on each inference request.
25
+ Expects a JSON payload like {"image": "..."}.
26
+ Returns a dict with scores and labels.
27
+ """
28
+ # 1) Decode base64 Data URI
29
+ data_uri = request.json.get("image", "")
30
+ b64 = data_uri.split(",")[-1]
31
+ img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
32
+
33
+ # 2) To numpy array & normalize
34
+ arr = np.array(img)
35
+ arr = xrv.datasets.normalize(arr, 255) # scale pixel values
36
+
37
+ # 3) Center crop & to tensor
38
+ arr = transform(arr) # H×W → cropped H×W
39
+ tensor = torch.tensor(arr).permute(2, 0, 1).float().unsqueeze(0)
40
+
41
+ # 4) Inference
42
+ with torch.no_grad():
43
+ scores = model(tensor).tolist()
44
+
45
+ # 5) Return scores + pathologies
46
+ return {"scores": scores, "labels": model.pathologies}