Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import gradio as gr | |
| from fastai.vision.all import * | |
| ####################### | |
| # Data & Learner # | |
| ####################### | |
class ImageImageDataLoaders(DataLoaders):
    """Create DataLoaders for image→image tasks (colour input, grayscale target)."""

    # Alternate constructor: it receives the class as ``cls`` and builds the
    # instance through ``cls.from_dblock``; without ``@classmethod`` the first
    # positional argument (the data path) would be bound to ``cls`` instead.
    @classmethod
    def from_label_func(
        cls,
        path: Path,
        filenames,
        label_func,
        valid_pct: float = 0.2,
        seed: int | None = None,
        item_transforms=None,
        batch_transforms=None,
        **kwargs,
    ):
        """Build DataLoaders whose target image for each input is given by ``label_func``.

        Args:
            path: Root directory forwarded to ``from_dblock``.
            filenames: Input image paths.
            label_func: Maps an input item to its target image.
            valid_pct: Fraction of items put in the validation split by
                ``RandomSplitter``.
            seed: Seed for the random train/validation split.
            item_transforms: Per-item transforms (passed as ``item_tfms``).
            batch_transforms: Per-batch transforms (passed as ``batch_tfms``).
            **kwargs: Forwarded to ``from_dblock`` (e.g. ``bs``, ``num_workers``).

        Returns:
            An instance of ``cls`` built from the DataBlock.
        """
        dblock = DataBlock(
            # Inputs are opened as PILImage, targets as single-channel PILImageBW.
            blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
            get_y=label_func,
            splitter=RandomSplitter(valid_pct, seed=seed),
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        return cls.from_dblock(dblock, filenames, path=path, **kwargs)
def get_y_fn(x: Path) -> Path:
    """Label function that maps every input image to itself.

    The target is irrelevant here: the DataLoaders built with it exist only so
    the model architecture can be instantiated before weights are loaded.
    """
    target = x
    return target
def create_data(data_path: Path):
    """Create a minimal DataLoaders used only to initialise the model architecture.

    Every ``.jpg`` found by ``get_files`` under ``data_path`` is paired with
    itself as target (via ``get_y_fn``). No validation split is made
    (``valid_pct=0.0``), the batch size is 1 and no worker processes are used.

    Args:
        data_path: Directory searched for ``.jpg`` images.

    Returns:
        An ``ImageImageDataLoaders`` over the found images.

    Raises:
        FileNotFoundError: if no ``.jpg`` image is found under ``data_path``.
    """
    fnames = get_files(data_path, extensions=".jpg")
    if not fnames:
        # Fail here with an actionable message rather than somewhere later,
        # when the learner is built on top of an empty dataset.
        raise FileNotFoundError(
            f"No .jpg images found under {data_path}; at least one is required "
            "to initialise the model architecture."
        )
    return ImageImageDataLoaders.from_label_func(
        data_path,
        seed=42,
        bs=1,
        num_workers=0,
        valid_pct=0.0,
        filenames=fnames,
        label_func=get_y_fn,
    )
# Build a U-Net with a ResNet-34 backbone on top of the minimal DataLoaders,
# then restore the trained weights saved as "model" under ./models.
data = create_data(Path() / "examples")
learner = unet_learner(
    data,
    resnet34,
    path=".",
    model_dir="models",
    loss_func=MSELossFlat(),
    n_out=3,
)
learner.load("model")
| ##################### | |
| # Inference Logic # | |
| ##################### | |
def predict_depth(input_img):
    """Predict a grayscale depth map for one input image.

    Args:
        input_img: Image supplied by the Gradio input component. With
            ``gr.Image``'s default ``type="numpy"`` this is an array. It is
            ``None`` when the user submits without uploading an image.

    Returns:
        A PIL image converted to mode ``"L"`` (8-bit grayscale).

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message.
    """
    if input_img is None:
        raise gr.Error("Please upload an RGB image first.")
    # Only the first element of ``Learner.predict``'s result (the decoded
    # prediction) is used; the remaining outputs are discarded.
    depth, *_ = learner.predict(input_img)
    return PILImageBW.create(depth).convert("L")
| ##################### | |
| # Gradio UI # | |
| ##################### | |
# Static page content shown around the prediction interface.
title = "📷 SavtaDepth WebApp"
description_md = """
<p style="text-align:center;font-size:1.05rem;max-width:760px;margin:auto;">
Upload an RGB image on the left and get a grayscale depth map on the right.
</p>
"""
footer_html = """
<p style='text-align:center;font-size:0.9rem;'>
<a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>Project on DAGsHub</a> •
<a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a>
</p>
"""
# One single-input example row per bundled sample image.
examples = [[f"examples/{stem}.jpg"] for stem in ("00008", "00045")]

# Widgets are created up front and handed to the Interface below.
input_component = gr.Image(label="Input RGB", width=640, height=480)
output_component = gr.Image(image_mode="L", label="Predicted Depth")

# Page layout: heading, description, prediction interface, footer links.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(f"<center><h1>{title}</h1></center>")
    gr.HTML(description_md)
    gr.Interface(
        fn=predict_depth,
        inputs=input_component,
        outputs=output_component,
        cache_examples=False,
        examples=examples,
    )
    gr.HTML(footer_html)
if __name__ == "__main__":
    # Enable Gradio's request queue (queue() returns the same Blocks), then serve.
    demo.queue()
    demo.launch()