"""Gradio demo: expose ``ResUnetInfer.infer`` through an image-in / image-out UI.

Running this module builds the inference wrapper, assembles a
``gr.Interface`` around it, and starts the Gradio server.
"""

import os

import gradio as gr

from src.run.unet.inference import ResUnetInfer

# All paths are relative to the working directory the app is started from.
MODEL_PATH = "./checkpoint/resunet/decoder.pt"
CONFIG_PATH = "./src/models/unet/config/resnet_config.yml"
SAMPLE_DIR = "./sample/"
DEFAULT_IMAGE = "./sample/bird_plane.jpeg"


def _list_examples(sample_dir):
    """Return one example row per usable file found in *sample_dir*.

    Each row is a one-element list holding the file path, because the
    interface below has exactly one input component.

    Names are sorted because ``os.listdir`` returns entries in arbitrary
    order, which would otherwise reshuffle the example gallery between
    runs or machines. Sub-directories and hidden files (``.DS_Store``,
    ``.gitkeep``, ...) are skipped since they are not images.

    Raises:
        OSError: if *sample_dir* does not exist or cannot be read
            (propagated from ``os.listdir``).
    """
    rows = []
    for name in sorted(os.listdir(sample_dir)):
        path = os.path.join(sample_dir, name)
        if name.startswith(".") or not os.path.isfile(path):
            continue
        rows.append([path])
    return rows


# Inference wrapper built from the checkpoint and its config file.
infer = ResUnetInfer(
    model_path=MODEL_PATH,
    config_path=CONFIG_PATH,
)

demo = gr.Interface(
    fn=infer.infer,
    inputs=[
        gr.Image(
            # NOTE(review): `shape` is accepted by Gradio 3.x `gr.Image`;
            # confirm the pinned Gradio version still supports it.
            shape=(224, 224),
            label="Input Image",
            # Pre-populates the input with a bundled sample image.
            value=DEFAULT_IMAGE,
        )
    ],
    outputs=[
        gr.Image(),
    ],
    examples=_list_examples(SAMPLE_DIR),
)

# Launched at import time, exactly as before, so existing launch commands
# keep working unchanged.
demo.launch()