import base64
import io

import numpy as np
import torch
import torchxrayvision as xrv
from PIL import Image

# Input resolution the "res224" DenseNet weights are named for.
_INPUT_SIZE = 224

# Populated by init(); predict() refuses to run until init() has been called.
model = None
transform = None


def _compose(*steps):
    """Return a callable that applies each of ``steps`` in order."""

    def _apply(img):
        for step in steps:
            img = step(img)
        return img

    return _apply


def init():
    """
    Called once at container startup.

    Loads the DenseNet model via torchxrayvision (requested with
    ``from_hf_hub=True``), puts it in eval mode, and builds the
    preprocessing transform: a square center crop followed by a resize
    to 224x224.
    """
    global model, transform

    model_name = "densenet121-res224-chex"
    model = xrv.models.get_model(model_name, from_hf_hub=True)
    model.eval()

    # XRayCenterCrop takes no arguments: it center-crops the longer side so
    # the image becomes square. XRayResizer then scales it to the model's
    # input size. Both expect a (channels, H, W) array.
    transform = _compose(
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(_INPUT_SIZE),
    )


def _decode_image(data_uri):
    """
    Decode a base64 payload into a single-channel (grayscale) PIL image.

    Accepts either a bare base64 string or a Data URI such as
    ``data:image/png;base64,<payload>``; everything after the last comma
    is treated as the base64 payload.

    Raises:
        ValueError: if the payload is empty.
    """
    b64 = data_uri.split(",")[-1]
    if not b64:
        raise ValueError("request JSON must contain a non-empty 'image' field")
    # The model takes one input channel, so decode straight to 8-bit
    # grayscale ("L") instead of RGB.
    return Image.open(io.BytesIO(base64.b64decode(b64))).convert("L")


def predict(request):
    """
    Called on each inference request.

    Expects ``request.json`` to be a JSON object like ``{"image": "..."}``,
    where the value is a base64 string or base64 Data URI.

    Returns:
        dict with ``"scores"`` (``model(...).tolist()`` for a batch of one,
        i.e. a list containing one list of per-label scores) and
        ``"labels"`` (``model.pathologies``, aligned with the inner list).

    Raises:
        RuntimeError: if init() has not been called yet.
        ValueError: if the ``"image"`` field is missing or empty.
        Decoding errors from ``base64`` / PIL propagate for malformed input.
    """
    if model is None or transform is None:
        raise RuntimeError("init() must be called before predict()")

    # 1) Decode base64 Data URI to a grayscale image
    payload = request.json or {}
    img = _decode_image(payload.get("image", ""))

    # 2) To numpy & rescale 8-bit [0, 255] into torchxrayvision's expected range
    arr = xrv.datasets.normalize(np.asarray(img), 255)
    arr = arr[None, ...]  # (H, W) -> (1, H, W): add the channel axis

    # 3) Center crop + resize, then add the batch axis -> (1, 1, 224, 224)
    arr = transform(arr)
    tensor = torch.from_numpy(np.ascontiguousarray(arr)).float().unsqueeze(0)

    # 4) Inference
    with torch.no_grad():
        scores = model(tensor).tolist()

    # 5) Return scores + pathologies
    return {"scores": scores, "labels": list(model.pathologies)}