# savtadepth/app/app_savta.py
# Author: Abid Ali Awan
# Last change: fix model path configuration to prevent duplicate directory (commit 81cb914)
from pathlib import Path
import gradio as gr
from fastai.vision.all import *
#######################
# Data & Learner #
#######################
class ImageImageDataLoaders(DataLoaders):
    """DataLoaders builder for image-to-image tasks (RGB input, grayscale target)."""

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(
        cls,
        path: Path,
        filenames,
        label_func,
        valid_pct: float = 0.2,
        seed: int | None = None,
        item_transforms=None,
        batch_transforms=None,
        **kwargs,
    ):
        """Create DataLoaders whose targets come from applying *label_func* to each file."""
        # Input block is a colour image; target block is a single-channel (B/W) image.
        blocks = (ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW))
        pipeline = DataBlock(
            blocks=blocks,
            get_y=label_func,
            splitter=RandomSplitter(valid_pct, seed=seed),
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        # Remaining kwargs (bs, num_workers, ...) are forwarded to from_dblock.
        return cls.from_dblock(pipeline, filenames, path=path, **kwargs)
def get_y_fn(x: Path) -> Path:
    """Identity label function: the target image is the input image itself.

    Only used to satisfy the DataBlock API while rebuilding the model
    architecture; real labels are not needed at inference time.
    """
    return x
def create_data(data_path: Path):
    """Build a minimal image-to-image loader from the .jpg files in *data_path*.

    The result exists only so ``unet_learner`` can infer the network
    architecture before trained weights are loaded: batch size 1, no
    validation split, no worker processes.
    """
    image_files = get_files(data_path, extensions=".jpg")
    return ImageImageDataLoaders.from_label_func(
        data_path,
        filenames=image_files,
        label_func=get_y_fn,
        valid_pct=0.0,
        seed=42,
        bs=1,
        num_workers=0,
    )
# Initialize learner with architecture
# Rebuild the U-Net from a tiny example dataset; the trained weights are
# restored from disk below — no training happens here.
data = create_data(Path("examples"))
learner = unet_learner(
    data,
    resnet34,
    n_out=3,  # three output channels, matching the training configuration
    loss_func=MSELossFlat(),
    path=".",  # root for artifacts; model_dir is resolved relative to this
    model_dir="models",  # so weights are read from ./models (no duplicated path segment)
)
# Loads "./models/model.pth" saved during training.
learner.load("model")
#####################
# Inference Logic #
#####################
def predict_depth(input_img: PILImage) -> PILImageBW:
    """Run the depth model on *input_img* and return an 8-bit grayscale depth map."""
    prediction = learner.predict(input_img)
    depth_map = prediction[0]  # first element is the decoded prediction
    return PILImageBW.create(depth_map).convert("L")
#####################
#    Gradio UI      #
#####################
# Page title shown in the browser tab and as the heading.
title = "📷 SavtaDepth WebApp"
# Short HTML blurb rendered under the heading.
description_md = """
<p style="text-align:center;font-size:1.05rem;max-width:760px;margin:auto;">
Upload an RGB image on the left and get a grayscale depth map on the right.
</p>
"""
# Footer with project and demo links.
footer_html = """
<p style='text-align:center;font-size:0.9rem;'>
<a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>Project on DAGsHub</a> •
<a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a>
</p>
"""
# Clickable example inputs (paths relative to the app's working directory).
examples = [["examples/00008.jpg"], ["examples/00045.jpg"]]
# Input widget: RGB image at a fixed display size.
input_component = gr.Image(width=640, height=480, label="Input RGB")
# Output widget: single-channel ("L" mode) depth image.
output_component = gr.Image(label="Predicted Depth", image_mode="L")
# Components render in creation order inside the Blocks context:
# heading, description, the Interface itself, then the footer.
with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<center><h1>{title}</h1></center>")
    gr.HTML(description_md)
    gr.Interface(
        fn=predict_depth,
        inputs=input_component,
        outputs=output_component,
        examples=examples,
        cache_examples=False,  # examples are re-predicted on click, not precomputed
    )
    gr.HTML(footer_html)
if __name__ == "__main__":
    # queue() serializes requests so concurrent predictions don't contend for the model.
    demo.queue().launch()