"""Colorized depth-map inference with the ``RGBD-SOD/dptdepth`` checkpoint.

The model and its image processor are loaded once, at import time, so every
call to :func:`inference` reuses the same instances.
"""

from typing import Dict

import numpy as np
import torch
from matplotlib import cm
from PIL import Image
from torch import Tensor
from transformers import AutoImageProcessor, AutoModel

# Both objects are built from code hosted with the checkpoint
# (trust_remote_code=True), so only point this at repositories you trust.
# The cache directories are relative paths, i.e. resolved against the current
# working directory.
model = AutoModel.from_pretrained(
    "RGBD-SOD/dptdepth", trust_remote_code=True, cache_dir="model_cache"
)
image_processor = AutoImageProcessor.from_pretrained(
    "RGBD-SOD/dptdepth", trust_remote_code=True, cache_dir="image_processor_cache"
)


def inference(rgb: Image.Image) -> Image.Image:
    """Run the model on *rgb* and render its prediction with ``gist_earth``.

    Args:
        rgb: Input image in any PIL mode; it is converted to RGB first.

    Returns:
        The model's post-processed prediction mapped through Matplotlib's
        ``gist_earth`` colormap and converted to an 8-bit PIL image (the
        colormap yields RGBA, so the image carries an alpha channel).
    """
    rgb = rgb.convert(mode="RGB")
    preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
        {
            "rgb": rgb,
        }
    )

    # Pure inference: disable autograd so no computation graph (and the
    # intermediate activations it keeps alive) is built on every request.
    with torch.no_grad():
        output: Dict[str, Tensor] = model(preprocessed_sample["rgb"])

    # PIL reports size as (width, height); the processor is handed
    # [height, width] — presumably the target resolution to resize the
    # prediction back to the input image's size.
    postprocessed_sample: np.ndarray = image_processor.postprocess(
        output["logits"], [rgb.size[1], rgb.size[0]]
    )

    # NOTE(review): Matplotlib colormaps expect float input in [0, 1]; this
    # assumes postprocess already normalizes the prediction — confirm.
    colored = cm.gist_earth(postprocessed_sample)  # RGBA floats in [0, 1]
    prediction = Image.fromarray(np.uint8(colored * 255))
    return prediction


if __name__ == "__main__":
    pass