from pathlib import Path import gradio as gr from fastai.vision.all import * ####################### # Data & Learner # ####################### class ImageImageDataLoaders(DataLoaders): """Create DataLoaders for image→image tasks.""" @classmethod @delegates(DataLoaders.from_dblock) def from_label_func( cls, path: Path, filenames, label_func, valid_pct: float = 0.2, seed: int | None = None, item_transforms=None, batch_transforms=None, **kwargs, ): dblock = DataBlock( blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)), get_y=label_func, splitter=RandomSplitter(valid_pct, seed=seed), item_tfms=item_transforms, batch_tfms=batch_transforms, ) return cls.from_dblock(dblock, filenames, path=path, **kwargs) def get_y_fn(x: Path) -> Path: """Return same image as label for architecture initialization.""" return x def create_data(data_path: Path): """Create minimal data loader for model architecture initialization.""" fnames = get_files(data_path, extensions=".jpg") return ImageImageDataLoaders.from_label_func( data_path, seed=42, bs=1, num_workers=0, valid_pct=0.0, filenames=fnames, label_func=get_y_fn, ) # Initialize learner with architecture data = create_data(Path("examples")) learner = unet_learner( data, resnet34, n_out=3, loss_func=MSELossFlat(), path=".", model_dir="models", ) learner.load("model") ##################### # Inference Logic # ##################### def predict_depth(input_img: PILImage) -> PILImageBW: depth, *_ = learner.predict(input_img) return PILImageBW.create(depth).convert("L") ##################### # Gradio UI # ##################### title = "📷 SavtaDepth WebApp" description_md = """
Upload an RGB image on the left and get a grayscale depth map on the right.
""" footer_html = """Project on DAGsHub • Google Colab Demo
""" examples = [["examples/00008.jpg"], ["examples/00045.jpg"]] input_component = gr.Image(width=640, height=480, label="Input RGB") output_component = gr.Image(label="Predicted Depth", image_mode="L") with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo: gr.Markdown(f"