Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	add: ripe
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- README.md +2 -1
- config/config.yaml +13 -2
- imcui/hloc/extract_features.py +11 -0
- imcui/hloc/extractors/ripe.py +46 -0
- imcui/third_party/RIPE/.gitignore +179 -0
- imcui/third_party/RIPE/LICENSE +35 -0
- imcui/third_party/RIPE/LICENSE_DALF_DISK +201 -0
- imcui/third_party/RIPE/README.md +367 -0
- imcui/third_party/RIPE/app.py +272 -0
- imcui/third_party/RIPE/assets/all_souls_000013.jpg +3 -0
- imcui/third_party/RIPE/assets/all_souls_000055.jpg +3 -0
- imcui/third_party/RIPE/assets/teaser_image.png +3 -0
- imcui/third_party/RIPE/conda_env.yml +26 -0
- imcui/third_party/RIPE/conf/backbones/resnet.yaml +6 -0
- imcui/third_party/RIPE/conf/backbones/vgg.yaml +5 -0
- imcui/third_party/RIPE/conf/data/disk_megadepth.yaml +12 -0
- imcui/third_party/RIPE/conf/data/megadepth+acdc.yaml +33 -0
- imcui/third_party/RIPE/conf/data/megadepth+tokyo.yaml +29 -0
- imcui/third_party/RIPE/conf/descriptor_loss/contrastive_loss.yaml +3 -0
- imcui/third_party/RIPE/conf/inl_th/constant.yaml +2 -0
- imcui/third_party/RIPE/conf/inl_th/exp_decay.yaml +4 -0
- imcui/third_party/RIPE/conf/matcher/concurrent_mnn_poselib.yaml +8 -0
- imcui/third_party/RIPE/conf/train.yaml +89 -0
- imcui/third_party/RIPE/conf/upsampler/hypercolumn_features.yaml +2 -0
- imcui/third_party/RIPE/conf/upsampler/interpolate_sparse2D.yaml +1 -0
- imcui/third_party/RIPE/data/download_disk_data.sh +43 -0
- imcui/third_party/RIPE/demo.py +51 -0
- imcui/third_party/RIPE/ripe/__init__.py +1 -0
- imcui/third_party/RIPE/ripe/benchmarks/imw_2020.py +320 -0
- imcui/third_party/RIPE/ripe/data/__init__.py +0 -0
- imcui/third_party/RIPE/ripe/data/data_transforms.py +204 -0
- imcui/third_party/RIPE/ripe/data/datasets/__init__.py +0 -0
- imcui/third_party/RIPE/ripe/data/datasets/acdc.py +154 -0
- imcui/third_party/RIPE/ripe/data/datasets/dataset_combinator.py +88 -0
- imcui/third_party/RIPE/ripe/data/datasets/disk_imw.py +160 -0
- imcui/third_party/RIPE/ripe/data/datasets/disk_megadepth.py +157 -0
- imcui/third_party/RIPE/ripe/data/datasets/tokyo247.py +134 -0
- imcui/third_party/RIPE/ripe/losses/__init__.py +0 -0
- imcui/third_party/RIPE/ripe/losses/contrastive_loss.py +88 -0
- imcui/third_party/RIPE/ripe/matcher/__init__.py +0 -0
- imcui/third_party/RIPE/ripe/matcher/concurrent_matcher.py +97 -0
- imcui/third_party/RIPE/ripe/matcher/pose_estimator_poselib.py +31 -0
- imcui/third_party/RIPE/ripe/model_zoo/__init__.py +1 -0
- imcui/third_party/RIPE/ripe/model_zoo/vgg_hyper.py +39 -0
- imcui/third_party/RIPE/ripe/models/__init__.py +0 -0
- imcui/third_party/RIPE/ripe/models/backbones/__init__.py +0 -0
- imcui/third_party/RIPE/ripe/models/backbones/backbone_base.py +61 -0
- imcui/third_party/RIPE/ripe/models/backbones/vgg.py +99 -0
- imcui/third_party/RIPE/ripe/models/backbones/vgg_utils.py +143 -0
- imcui/third_party/RIPE/ripe/models/ripe.py +303 -0
    	
        README.md
    CHANGED
    
    | @@ -44,8 +44,9 @@ The tool currently supports various popular image matching algorithms, namely: | |
| 44 |  | 
| 45 | 
             
            | Algorithm        | Supported | Conference/Journal | Year | GitHub Link |
         | 
| 46 | 
             
            |------------------|-----------|--------------------|------|-------------|
         | 
| 47 | 
            -
            |  | 
| 48 | 
             
            | RDD            | ✅ | CVPR    | 2025 | [Link](https://github.com/xtcpete/rdd)  |
         | 
|  | |
| 49 | 
             
            | DaD            | ✅ | ARXIV   | 2025 | [Link](https://github.com/Parskatt/dad) |
         | 
| 50 | 
             
            | MINIMA         | ✅ | ARXIV   | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
         | 
| 51 | 
             
            | XoFTR          | ✅ | CVPR    | 2024 | [Link](https://github.com/OnderT/XoFTR) |
         | 
|  | |
| 44 |  | 
| 45 | 
             
            | Algorithm        | Supported | Conference/Journal | Year | GitHub Link |
         | 
| 46 | 
             
            |------------------|-----------|--------------------|------|-------------|
         | 
| 47 | 
            +
            | RIPE           | ✅ | ICCV    | 2025 | [Link](https://github.com/fraunhoferhhi/RIPE)  |
         | 
| 48 | 
             
            | RDD            | ✅ | CVPR    | 2025 | [Link](https://github.com/xtcpete/rdd)  |
         | 
| 49 | 
            +
            | LiftFeat       | ✅ | ICRA    | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
         | 
| 50 | 
             
            | DaD            | ✅ | ARXIV   | 2025 | [Link](https://github.com/Parskatt/dad) |
         | 
| 51 | 
             
            | MINIMA         | ✅ | ARXIV   | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
         | 
| 52 | 
             
            | XoFTR          | ✅ | CVPR    | 2024 | [Link](https://github.com/OnderT/XoFTR) |
         | 
    	
        config/config.yaml
    CHANGED
    
    | @@ -267,6 +267,17 @@ matcher_zoo: | |
| 267 | 
             
                  paper: https://arxiv.org/abs/2505.0342
         | 
| 268 | 
             
                  project: null
         | 
| 269 | 
             
                  display: true
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 270 | 
             
              rdd(sparse):
         | 
| 271 | 
             
                matcher: NN-mutual
         | 
| 272 | 
             
                feature: rdd
         | 
| @@ -274,7 +285,7 @@ matcher_zoo: | |
| 274 | 
             
                info:
         | 
| 275 | 
             
                  name: RDD(sparse) #dispaly name
         | 
| 276 | 
             
                  source: "CVPR 2025"
         | 
| 277 | 
            -
                  github:  | 
| 278 | 
             
                  paper: https://arxiv.org/abs/2505.08013
         | 
| 279 | 
             
                  project: https://xtcpete.github.io/rdd
         | 
| 280 | 
             
                  display: true
         | 
| @@ -284,7 +295,7 @@ matcher_zoo: | |
| 284 | 
             
                info:
         | 
| 285 | 
             
                  name: RDD(dense) #dispaly name
         | 
| 286 | 
             
                  source: "CVPR 2025"
         | 
| 287 | 
            -
                  github:  | 
| 288 | 
             
                  paper: https://arxiv.org/abs/2505.08013
         | 
| 289 | 
             
                  project: https://xtcpete.github.io/rdd
         | 
| 290 | 
             
                  display: true
         | 
|  | |
| 267 | 
             
                  paper: https://arxiv.org/abs/2505.0342
         | 
| 268 | 
             
                  project: null
         | 
| 269 | 
             
                  display: true
         | 
| 270 | 
            +
              ripe(+mnn):
         | 
| 271 | 
            +
                matcher: NN-mutual
         | 
| 272 | 
            +
                feature: ripe
         | 
| 273 | 
            +
                dense: false
         | 
| 274 | 
            +
                info:
         | 
| 275 | 
            +
                  name: RIPE #dispaly name
         | 
| 276 | 
            +
                  source: "ICCV 2025"
         | 
| 277 | 
            +
                  github: https://github.com/fraunhoferhhi/RIPE
         | 
| 278 | 
            +
                  paper: https://arxiv.org/abs/2507.04839
         | 
| 279 | 
            +
                  project: https://fraunhoferhhi.github.io/RIPE
         | 
| 280 | 
            +
                  display: true
         | 
| 281 | 
             
              rdd(sparse):
         | 
| 282 | 
             
                matcher: NN-mutual
         | 
| 283 | 
             
                feature: rdd
         | 
|  | |
| 285 | 
             
                info:
         | 
| 286 | 
             
                  name: RDD(sparse) #dispaly name
         | 
| 287 | 
             
                  source: "CVPR 2025"
         | 
| 288 | 
            +
                  github: https://github.com/xtcpete/rdd
         | 
| 289 | 
             
                  paper: https://arxiv.org/abs/2505.08013
         | 
| 290 | 
             
                  project: https://xtcpete.github.io/rdd
         | 
| 291 | 
             
                  display: true
         | 
|  | |
| 295 | 
             
                info:
         | 
| 296 | 
             
                  name: RDD(dense) #dispaly name
         | 
| 297 | 
             
                  source: "CVPR 2025"
         | 
| 298 | 
            +
                  github: https://github.com/xtcpete/rdd
         | 
| 299 | 
             
                  paper: https://arxiv.org/abs/2505.08013
         | 
| 300 | 
             
                  project: https://xtcpete.github.io/rdd
         | 
| 301 | 
             
                  display: true
         | 
    	
        imcui/hloc/extract_features.py
    CHANGED
    
    | @@ -236,6 +236,17 @@ confs = { | |
| 236 | 
             
                        "resize_max": 1600,
         | 
| 237 | 
             
                    },
         | 
| 238 | 
             
                },
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 239 | 
             
                "aliked-n16-rot": {
         | 
| 240 | 
             
                    "output": "feats-aliked-n16-rot",
         | 
| 241 | 
             
                    "model": {
         | 
|  | |
| 236 | 
             
                        "resize_max": 1600,
         | 
| 237 | 
             
                    },
         | 
| 238 | 
             
                },
         | 
| 239 | 
            +
                "ripe": {
         | 
| 240 | 
            +
                    "output": "feats-ripe-n2048-r1600",
         | 
| 241 | 
            +
                    "model": {
         | 
| 242 | 
            +
                        "name": "ripe",
         | 
| 243 | 
            +
                        "max_keypoints": 2048,
         | 
| 244 | 
            +
                    },
         | 
| 245 | 
            +
                    "preprocessing": {
         | 
| 246 | 
            +
                        "grayscale": False,
         | 
| 247 | 
            +
                        "resize_max": 1600,
         | 
| 248 | 
            +
                    },
         | 
| 249 | 
            +
                },
         | 
| 250 | 
             
                "aliked-n16-rot": {
         | 
| 251 | 
             
                    "output": "feats-aliked-n16-rot",
         | 
| 252 | 
             
                    "model": {
         | 
    	
        imcui/hloc/extractors/ripe.py
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            from pathlib import Path
         | 
| 3 | 
            +
            from ..utils.base_model import BaseModel
         | 
| 4 | 
            +
            from .. import logger, MODEL_REPO_ID
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            ripe_path = Path(__file__).parent / "../../third_party/RIPE"
         | 
| 7 | 
            +
            sys.path.append(str(ripe_path))
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ripe import vgg_hyper
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class RIPE(BaseModel):
         | 
| 13 | 
            +
                default_conf = {
         | 
| 14 | 
            +
                    "keypoint_threshold": 0.05,
         | 
| 15 | 
            +
                    "max_keypoints": 5000,
         | 
| 16 | 
            +
                    "model_name": "weights_ripe.pth",
         | 
| 17 | 
            +
                }
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                required_inputs = ["image"]
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def _init(self, conf):
         | 
| 22 | 
            +
                    logger.info("Loading RIPE model...")
         | 
| 23 | 
            +
                    model_path = self._download_model(
         | 
| 24 | 
            +
                        repo_id=MODEL_REPO_ID,
         | 
| 25 | 
            +
                        filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
         | 
| 26 | 
            +
                    )
         | 
| 27 | 
            +
                    self.net = vgg_hyper(Path(model_path))
         | 
| 28 | 
            +
                    logger.info("Loading RIPE model done!")
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def _forward(self, data):
         | 
| 31 | 
            +
                    keypoints, descriptors, scores = self.net.detectAndCompute(
         | 
| 32 | 
            +
                        data["image"], threshold=0.5, top_k=2048
         | 
| 33 | 
            +
                    )
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    if self.conf["max_keypoints"] < len(keypoints):
         | 
| 36 | 
            +
                        idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
         | 
| 37 | 
            +
                        keypoints = keypoints[idxs, :2]
         | 
| 38 | 
            +
                        descriptors = descriptors[idxs]
         | 
| 39 | 
            +
                        scores = scores[idxs]
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    pred = {
         | 
| 42 | 
            +
                        "keypoints": keypoints[None],
         | 
| 43 | 
            +
                        "descriptors": descriptors[None].permute(0, 2, 1),
         | 
| 44 | 
            +
                        "scores": scores[None],
         | 
| 45 | 
            +
                    }
         | 
| 46 | 
            +
                    return pred
         | 
    	
        imcui/third_party/RIPE/.gitignore
    ADDED
    
    | @@ -0,0 +1,179 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Byte-compiled / optimized / DLL files
         | 
| 2 | 
            +
            __pycache__/
         | 
| 3 | 
            +
            *.py[cod]
         | 
| 4 | 
            +
            *$py.class
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # C extensions
         | 
| 7 | 
            +
            *.so
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Distribution / packaging
         | 
| 10 | 
            +
            .Python
         | 
| 11 | 
            +
            build/
         | 
| 12 | 
            +
            develop-eggs/
         | 
| 13 | 
            +
            dist/
         | 
| 14 | 
            +
            downloads/
         | 
| 15 | 
            +
            eggs/
         | 
| 16 | 
            +
            .eggs/
         | 
| 17 | 
            +
            lib/
         | 
| 18 | 
            +
            lib64/
         | 
| 19 | 
            +
            parts/
         | 
| 20 | 
            +
            sdist/
         | 
| 21 | 
            +
            var/
         | 
| 22 | 
            +
            wheels/
         | 
| 23 | 
            +
            pip-wheel-metadata/
         | 
| 24 | 
            +
            share/python-wheels/
         | 
| 25 | 
            +
            *.egg-info/
         | 
| 26 | 
            +
            .installed.cfg
         | 
| 27 | 
            +
            *.egg
         | 
| 28 | 
            +
            MANIFEST
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # PyInstaller
         | 
| 31 | 
            +
            #  Usually these files are written by a python script from a template
         | 
| 32 | 
            +
            #  before PyInstaller builds the exe, so as to inject date/other infos into it.
         | 
| 33 | 
            +
            *.manifest
         | 
| 34 | 
            +
            *.spec
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            # Installer logs
         | 
| 37 | 
            +
            pip-log.txt
         | 
| 38 | 
            +
            pip-delete-this-directory.txt
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            # Unit test / coverage reports
         | 
| 41 | 
            +
            htmlcov/
         | 
| 42 | 
            +
            .tox/
         | 
| 43 | 
            +
            .nox/
         | 
| 44 | 
            +
            .coverage
         | 
| 45 | 
            +
            .coverage.*
         | 
| 46 | 
            +
            .cache
         | 
| 47 | 
            +
            nosetests.xml
         | 
| 48 | 
            +
            coverage.xml
         | 
| 49 | 
            +
            *.cover
         | 
| 50 | 
            +
            *.py,cover
         | 
| 51 | 
            +
            .hypothesis/
         | 
| 52 | 
            +
            .pytest_cache/
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            # Translations
         | 
| 55 | 
            +
            *.mo
         | 
| 56 | 
            +
            *.pot
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            # Django stuff:
         | 
| 59 | 
            +
            *.log
         | 
| 60 | 
            +
            local_settings.py
         | 
| 61 | 
            +
            db.sqlite3
         | 
| 62 | 
            +
            db.sqlite3-journal
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            # Flask stuff:
         | 
| 65 | 
            +
            instance/
         | 
| 66 | 
            +
            .webassets-cache
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            # Scrapy stuff:
         | 
| 69 | 
            +
            .scrapy
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            # Sphinx documentation
         | 
| 72 | 
            +
            docs/_build/
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            # PyBuilder
         | 
| 75 | 
            +
            target/
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            # Jupyter Notebook
         | 
| 78 | 
            +
            .ipynb_checkpoints
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            # IPython
         | 
| 81 | 
            +
            profile_default/
         | 
| 82 | 
            +
            ipython_config.py
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            # pyenv
         | 
| 85 | 
            +
            .python-version
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            # pipenv
         | 
| 88 | 
            +
            #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
         | 
| 89 | 
            +
            #   However, in case of collaboration, if having platform-specific dependencies or dependencies
         | 
| 90 | 
            +
            #   having no cross-platform support, pipenv may install dependencies that don't work, or not
         | 
| 91 | 
            +
            #   install all needed dependencies.
         | 
| 92 | 
            +
            #Pipfile.lock
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            # PEP 582; used by e.g. github.com/David-OConnor/pyflow
         | 
| 95 | 
            +
            __pypackages__/
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            # Celery stuff
         | 
| 98 | 
            +
            celerybeat-schedule
         | 
| 99 | 
            +
            celerybeat.pid
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            # SageMath parsed files
         | 
| 102 | 
            +
            *.sage.py
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            # Environments
         | 
| 105 | 
            +
            .venv
         | 
| 106 | 
            +
            env/
         | 
| 107 | 
            +
            venv/
         | 
| 108 | 
            +
            ENV/
         | 
| 109 | 
            +
            env.bak/
         | 
| 110 | 
            +
            venv.bak/
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            # Spyder project settings
         | 
| 113 | 
            +
            .spyderproject
         | 
| 114 | 
            +
            .spyproject
         | 
| 115 | 
            +
             | 
| 116 | 
            +
            # Rope project settings
         | 
| 117 | 
            +
            .ropeproject
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            # mkdocs documentation
         | 
| 120 | 
            +
            /site
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            # mypy
         | 
| 123 | 
            +
            .mypy_cache/
         | 
| 124 | 
            +
            .dmypy.json
         | 
| 125 | 
            +
            dmypy.json
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            # Pyre type checker
         | 
| 128 | 
            +
            .pyre/
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            ### VisualStudioCode
         | 
| 131 | 
            +
            .vscode/*
         | 
| 132 | 
            +
            !.vscode/settings.json
         | 
| 133 | 
            +
            !.vscode/tasks.json
         | 
| 134 | 
            +
            !.vscode/launch.json
         | 
| 135 | 
            +
            !.vscode/extensions.json
         | 
| 136 | 
            +
            *.code-workspace
         | 
| 137 | 
            +
            **/.vscode
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            # JetBrains
         | 
| 140 | 
            +
            .idea/
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            # ignore outputs
         | 
| 143 | 
            +
            /outputs/
         | 
| 144 | 
            +
             | 
| 145 | 
            +
            # ignore logs
         | 
| 146 | 
            +
            /logs/
         | 
| 147 | 
            +
            tmp.py
         | 
| 148 | 
            +
            .env
         | 
| 149 | 
            +
             | 
| 150 | 
            +
            # ignore pretrained pytorch models
         | 
| 151 | 
            +
            *.pth
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            # ignore lightning_logs
         | 
| 154 | 
            +
            /lightning_logs/*
         | 
| 155 | 
            +
             | 
| 156 | 
            +
            # ignore built apptainer images
         | 
| 157 | 
            +
            *.sif
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            # ignore the outputs server on the cluster
         | 
| 160 | 
            +
            /output/*
         | 
| 161 | 
            +
            # ignore .out files generated from the cluster
         | 
| 162 | 
            +
            *.out
         | 
| 163 | 
            +
            # ignore hparams_search folder
         | 
| 164 | 
            +
            /hparams_search_configs/*
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            *.o
         | 
| 167 | 
            +
            *.pkl
         | 
| 168 | 
            +
            *.ninja_deps
         | 
| 169 | 
            +
            *.ninja_log
         | 
| 170 | 
            +
            *.ninja
         | 
| 171 | 
            +
             | 
| 172 | 
            +
            /misc/*
         | 
| 173 | 
            +
            /tmp/*
         | 
| 174 | 
            +
            /apptainer_env.box/*
         | 
| 175 | 
            +
            /scripts/tmp_build/*
         | 
| 176 | 
            +
            /checkpoints
         | 
| 177 | 
            +
            /pretrained_weights
         | 
| 178 | 
            +
            /results_supple_cvpr
         | 
| 179 | 
            +
            /ext_files
         | 
    	
        imcui/third_party/RIPE/LICENSE
    ADDED
    
    | @@ -0,0 +1,35 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Software Copyright License for Academic Use of RIPE, Version 2.0
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            © Copyright (2025) Fraunhofer-Gesellschaft zur Förderung der angewandten Forschung e.V.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            1. INTRODUCTION
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            RIPE which means any source code, object code or binary files provided by Fraunhofer excluding third party software and materials, is made available under this Software Copyright License.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            2. COPYRIGHT LICENSE
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            Internal use of RIPE, in source and binary forms, with or without modification, is permitted without payment of copyright license fees for non-commercial purposes of evaluation, testing and academic research.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            No right or license, express or implied, is granted to any part of RIPE except and solely to the extent as expressly set forth herein. Any commercial use or exploitation of RIPE and/or any modifications thereto under this license are prohibited.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            For any other use of RIPE than permitted by this software copyright license You need another license from Fraunhofer. In such case please contact Fraunhofer under the CONTACT INFORMATION below.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            3. LIMITED PATENT LICENSE
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            If Fraunhofer patents are implemented by RIPE their use is permitted for internal non-commercial purposes of evaluation, testing and academic research. No patent grant is provided for any other use, including but not limited to commercial use or exploitation.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            Fraunhofer provides no warranty of patent non-infringement with respect to RIPE.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            4. PLACE OF JURISDICTION
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            German law shall apply to all disputes arising from the use of the licensed software. A court in Munich shall have local jurisdiction.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            5. DISCLAIMER
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            RIPE is provided by Fraunhofer "AS IS" and WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES, including but not limited to the implied warranties of fitness for a particular purpose. IN NO EVENT SHALL FRAUNHOFER BE LIABLE for any direct, indirect, incidental, special, exemplary, or consequential damages, including but not limited to procurement of substitute goods or services; loss of use, data, or profits, or business interruption, however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence), arising in any way out of the use of the Fraunhofer Software, even if advised of the possibility of such damage.
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            6. CONTACT INFORMATION
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            Fraunhofer-Institut für Nachrichtentechnik, Heinrich-Hertz-Institut, HHI
         | 
| 34 | 
            +
            Einsteinufer 37, 10587 Berlin, Germany
         | 
| 35 | 
    	
        imcui/third_party/RIPE/LICENSE_DALF_DISK
    ADDED
    
    | @@ -0,0 +1,201 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Apache License
         | 
| 2 | 
            +
                                       Version 2.0, January 2004
         | 
| 3 | 
            +
                                    http://www.apache.org/licenses/
         | 
| 4 | 
            +
             | 
| 5 | 
            +
               TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
         | 
| 6 | 
            +
             | 
| 7 | 
            +
               1. Definitions.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                  "License" shall mean the terms and conditions for use, reproduction,
         | 
| 10 | 
            +
                  and distribution as defined by Sections 1 through 9 of this document.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                  "Licensor" shall mean the copyright owner or entity authorized by
         | 
| 13 | 
            +
                  the copyright owner that is granting the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                  "Legal Entity" shall mean the union of the acting entity and all
         | 
| 16 | 
            +
                  other entities that control, are controlled by, or are under common
         | 
| 17 | 
            +
                  control with that entity. For the purposes of this definition,
         | 
| 18 | 
            +
                  "control" means (i) the power, direct or indirect, to cause the
         | 
| 19 | 
            +
                  direction or management of such entity, whether by contract or
         | 
| 20 | 
            +
                  otherwise, or (ii) ownership of fifty percent (50%) or more of the
         | 
| 21 | 
            +
                  outstanding shares, or (iii) beneficial ownership of such entity.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                  "You" (or "Your") shall mean an individual or Legal Entity
         | 
| 24 | 
            +
                  exercising permissions granted by this License.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                  "Source" form shall mean the preferred form for making modifications,
         | 
| 27 | 
            +
                  including but not limited to software source code, documentation
         | 
| 28 | 
            +
                  source, and configuration files.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                  "Object" form shall mean any form resulting from mechanical
         | 
| 31 | 
            +
                  transformation or translation of a Source form, including but
         | 
| 32 | 
            +
                  not limited to compiled object code, generated documentation,
         | 
| 33 | 
            +
                  and conversions to other media types.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                  "Work" shall mean the work of authorship, whether in Source or
         | 
| 36 | 
            +
                  Object form, made available under the License, as indicated by a
         | 
| 37 | 
            +
                  copyright notice that is included in or attached to the work
         | 
| 38 | 
            +
                  (an example is provided in the Appendix below).
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                  "Derivative Works" shall mean any work, whether in Source or Object
         | 
| 41 | 
            +
                  form, that is based on (or derived from) the Work and for which the
         | 
| 42 | 
            +
                  editorial revisions, annotations, elaborations, or other modifications
         | 
| 43 | 
            +
                  represent, as a whole, an original work of authorship. For the purposes
         | 
| 44 | 
            +
                  of this License, Derivative Works shall not include works that remain
         | 
| 45 | 
            +
                  separable from, or merely link (or bind by name) to the interfaces of,
         | 
| 46 | 
            +
                  the Work and Derivative Works thereof.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                  "Contribution" shall mean any work of authorship, including
         | 
| 49 | 
            +
                  the original version of the Work and any modifications or additions
         | 
| 50 | 
            +
                  to that Work or Derivative Works thereof, that is intentionally
         | 
| 51 | 
            +
                  submitted to Licensor for inclusion in the Work by the copyright owner
         | 
| 52 | 
            +
                  or by an individual or Legal Entity authorized to submit on behalf of
         | 
| 53 | 
            +
                  the copyright owner. For the purposes of this definition, "submitted"
         | 
| 54 | 
            +
                  means any form of electronic, verbal, or written communication sent
         | 
| 55 | 
            +
                  to the Licensor or its representatives, including but not limited to
         | 
| 56 | 
            +
                  communication on electronic mailing lists, source code control systems,
         | 
| 57 | 
            +
                  and issue tracking systems that are managed by, or on behalf of, the
         | 
| 58 | 
            +
                  Licensor for the purpose of discussing and improving the Work, but
         | 
| 59 | 
            +
                  excluding communication that is conspicuously marked or otherwise
         | 
| 60 | 
            +
                  designated in writing by the copyright owner as "Not a Contribution."
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                  "Contributor" shall mean Licensor and any individual or Legal Entity
         | 
| 63 | 
            +
                  on behalf of whom a Contribution has been received by Licensor and
         | 
| 64 | 
            +
                  subsequently incorporated within the Work.
         | 
| 65 | 
            +
             | 
| 66 | 
            +
               2. Grant of Copyright License. Subject to the terms and conditions of
         | 
| 67 | 
            +
                  this License, each Contributor hereby grants to You a perpetual,
         | 
| 68 | 
            +
                  worldwide, non-exclusive, no-charge, royalty-free, irrevocable
         | 
| 69 | 
            +
                  copyright license to reproduce, prepare Derivative Works of,
         | 
| 70 | 
            +
                  publicly display, publicly perform, sublicense, and distribute the
         | 
| 71 | 
            +
                  Work and such Derivative Works in Source or Object form.
         | 
| 72 | 
            +
             | 
| 73 | 
            +
               3. Grant of Patent License. Subject to the terms and conditions of
         | 
| 74 | 
            +
                  this License, each Contributor hereby grants to You a perpetual,
         | 
| 75 | 
            +
                  worldwide, non-exclusive, no-charge, royalty-free, irrevocable
         | 
| 76 | 
            +
                  (except as stated in this section) patent license to make, have made,
         | 
| 77 | 
            +
                  use, offer to sell, sell, import, and otherwise transfer the Work,
         | 
| 78 | 
            +
                  where such license applies only to those patent claims licensable
         | 
| 79 | 
            +
                  by such Contributor that are necessarily infringed by their
         | 
| 80 | 
            +
                  Contribution(s) alone or by combination of their Contribution(s)
         | 
| 81 | 
            +
                  with the Work to which such Contribution(s) was submitted. If You
         | 
| 82 | 
            +
                  institute patent litigation against any entity (including a
         | 
| 83 | 
            +
                  cross-claim or counterclaim in a lawsuit) alleging that the Work
         | 
| 84 | 
            +
                  or a Contribution incorporated within the Work constitutes direct
         | 
| 85 | 
            +
                  or contributory patent infringement, then any patent licenses
         | 
| 86 | 
            +
                  granted to You under this License for that Work shall terminate
         | 
| 87 | 
            +
                  as of the date such litigation is filed.
         | 
| 88 | 
            +
             | 
| 89 | 
            +
               4. Redistribution. You may reproduce and distribute copies of the
         | 
| 90 | 
            +
                  Work or Derivative Works thereof in any medium, with or without
         | 
| 91 | 
            +
                  modifications, and in Source or Object form, provided that You
         | 
| 92 | 
            +
                  meet the following conditions:
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                  (a) You must give any other recipients of the Work or
         | 
| 95 | 
            +
                      Derivative Works a copy of this License; and
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                  (b) You must cause any modified files to carry prominent notices
         | 
| 98 | 
            +
                      stating that You changed the files; and
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                  (c) You must retain, in the Source form of any Derivative Works
         | 
| 101 | 
            +
                      that You distribute, all copyright, patent, trademark, and
         | 
| 102 | 
            +
                      attribution notices from the Source form of the Work,
         | 
| 103 | 
            +
                      excluding those notices that do not pertain to any part of
         | 
| 104 | 
            +
                      the Derivative Works; and
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                  (d) If the Work includes a "NOTICE" text file as part of its
         | 
| 107 | 
            +
                      distribution, then any Derivative Works that You distribute must
         | 
| 108 | 
            +
                      include a readable copy of the attribution notices contained
         | 
| 109 | 
            +
                      within such NOTICE file, excluding those notices that do not
         | 
| 110 | 
            +
                      pertain to any part of the Derivative Works, in at least one
         | 
| 111 | 
            +
                      of the following places: within a NOTICE text file distributed
         | 
| 112 | 
            +
                      as part of the Derivative Works; within the Source form or
         | 
| 113 | 
            +
                      documentation, if provided along with the Derivative Works; or,
         | 
| 114 | 
            +
                      within a display generated by the Derivative Works, if and
         | 
| 115 | 
            +
                      wherever such third-party notices normally appear. The contents
         | 
| 116 | 
            +
                      of the NOTICE file are for informational purposes only and
         | 
| 117 | 
            +
                      do not modify the License. You may add Your own attribution
         | 
| 118 | 
            +
                      notices within Derivative Works that You distribute, alongside
         | 
| 119 | 
            +
                      or as an addendum to the NOTICE text from the Work, provided
         | 
| 120 | 
            +
                      that such additional attribution notices cannot be construed
         | 
| 121 | 
            +
                      as modifying the License.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                  You may add Your own copyright statement to Your modifications and
         | 
| 124 | 
            +
                  may provide additional or different license terms and conditions
         | 
| 125 | 
            +
                  for use, reproduction, or distribution of Your modifications, or
         | 
| 126 | 
            +
                  for any such Derivative Works as a whole, provided Your use,
         | 
| 127 | 
            +
                  reproduction, and distribution of the Work otherwise complies with
         | 
| 128 | 
            +
                  the conditions stated in this License.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
               5. Submission of Contributions. Unless You explicitly state otherwise,
         | 
| 131 | 
            +
                  any Contribution intentionally submitted for inclusion in the Work
         | 
| 132 | 
            +
                  by You to the Licensor shall be under the terms and conditions of
         | 
| 133 | 
            +
                  this License, without any additional terms or conditions.
         | 
| 134 | 
            +
                  Notwithstanding the above, nothing herein shall supersede or modify
         | 
| 135 | 
            +
                  the terms of any separate license agreement you may have executed
         | 
| 136 | 
            +
                  with Licensor regarding such Contributions.
         | 
| 137 | 
            +
             | 
| 138 | 
            +
               6. Trademarks. This License does not grant permission to use the trade
         | 
| 139 | 
            +
                  names, trademarks, service marks, or product names of the Licensor,
         | 
| 140 | 
            +
                  except as required for reasonable and customary use in describing the
         | 
| 141 | 
            +
                  origin of the Work and reproducing the content of the NOTICE file.
         | 
| 142 | 
            +
             | 
| 143 | 
            +
               7. Disclaimer of Warranty. Unless required by applicable law or
         | 
| 144 | 
            +
                  agreed to in writing, Licensor provides the Work (and each
         | 
| 145 | 
            +
                  Contributor provides its Contributions) on an "AS IS" BASIS,
         | 
| 146 | 
            +
                  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
         | 
| 147 | 
            +
                  implied, including, without limitation, any warranties or conditions
         | 
| 148 | 
            +
                  of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
         | 
| 149 | 
            +
                  PARTICULAR PURPOSE. You are solely responsible for determining the
         | 
| 150 | 
            +
                  appropriateness of using or redistributing the Work and assume any
         | 
| 151 | 
            +
                  risks associated with Your exercise of permissions under this License.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
               8. Limitation of Liability. In no event and under no legal theory,
         | 
| 154 | 
            +
                  whether in tort (including negligence), contract, or otherwise,
         | 
| 155 | 
            +
                  unless required by applicable law (such as deliberate and grossly
         | 
| 156 | 
            +
                  negligent acts) or agreed to in writing, shall any Contributor be
         | 
| 157 | 
            +
                  liable to You for damages, including any direct, indirect, special,
         | 
| 158 | 
            +
                  incidental, or consequential damages of any character arising as a
         | 
| 159 | 
            +
                  result of this License or out of the use or inability to use the
         | 
| 160 | 
            +
                  Work (including but not limited to damages for loss of goodwill,
         | 
| 161 | 
            +
                  work stoppage, computer failure or malfunction, or any and all
         | 
| 162 | 
            +
                  other commercial damages or losses), even if such Contributor
         | 
| 163 | 
            +
                  has been advised of the possibility of such damages.
         | 
| 164 | 
            +
             | 
| 165 | 
            +
               9. Accepting Warranty or Additional Liability. While redistributing
         | 
| 166 | 
            +
                  the Work or Derivative Works thereof, You may choose to offer,
         | 
| 167 | 
            +
                  and charge a fee for, acceptance of support, warranty, indemnity,
         | 
| 168 | 
            +
                  or other liability obligations and/or rights consistent with this
         | 
| 169 | 
            +
                  License. However, in accepting such obligations, You may act only
         | 
| 170 | 
            +
                  on Your own behalf and on Your sole responsibility, not on behalf
         | 
| 171 | 
            +
                  of any other Contributor, and only if You agree to indemnify,
         | 
| 172 | 
            +
                  defend, and hold each Contributor harmless for any liability
         | 
| 173 | 
            +
                  incurred by, or claims asserted against, such Contributor by reason
         | 
| 174 | 
            +
                  of your accepting any such warranty or additional liability.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
               END OF TERMS AND CONDITIONS
         | 
| 177 | 
            +
             | 
| 178 | 
            +
               APPENDIX: How to apply the Apache License to your work.
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                  To apply the Apache License to your work, attach the following
         | 
| 181 | 
            +
                  boilerplate notice, with the fields enclosed by brackets "[]"
         | 
| 182 | 
            +
                  replaced with your own identifying information. (Don't include
         | 
| 183 | 
            +
                  the brackets!)  The text should be enclosed in the appropriate
         | 
| 184 | 
            +
                  comment syntax for the file format. We also recommend that a
         | 
| 185 | 
            +
                  file or class name and description of purpose be included on the
         | 
| 186 | 
            +
                  same "printed page" as the copyright notice for easier
         | 
| 187 | 
            +
                  identification within third-party archives.
         | 
| 188 | 
            +
             | 
| 189 | 
            +
               Copyright [yyyy] [name of copyright owner]
         | 
| 190 | 
            +
             | 
| 191 | 
            +
               Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 192 | 
            +
               you may not use this file except in compliance with the License.
         | 
| 193 | 
            +
               You may obtain a copy of the License at
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 196 | 
            +
             | 
| 197 | 
            +
               Unless required by applicable law or agreed to in writing, software
         | 
| 198 | 
            +
               distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 199 | 
            +
               WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 200 | 
            +
               See the License for the specific language governing permissions and
         | 
| 201 | 
            +
               limitations under the License.
         | 
    	
        imcui/third_party/RIPE/README.md
    ADDED
    
    | @@ -0,0 +1,367 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            <p align="center">
         | 
| 3 | 
            +
              <h1 align="center"> <ins>RIPE</ins>:<br> Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction <br><br>🌊🌺 ICCV 2025 🌺🌊</h1>
         | 
| 4 | 
            +
              <p align="center">
         | 
| 5 | 
            +
                <a href="https://scholar.google.com/citations?user=ybMR38kAAAAJ">Johannes Künzel</a>
         | 
| 6 | 
            +
                ·
         | 
| 7 | 
            +
                <a href="https://scholar.google.com/citations?user=5yTuyGIAAAAJ">Anna Hilsmann</a>
         | 
| 8 | 
            +
                ·
         | 
| 9 | 
            +
                <a href="https://scholar.google.com/citations?user=BCElyCkAAAAJ">Peter Eisert</a>
         | 
| 10 | 
            +
              </p>
         | 
| 11 | 
            +
              <h2 align="center"><p>
         | 
| 12 | 
            +
                <a href="https://arxiv.org/abs/2507.04839" align="center">Arxiv</a> | 
         | 
| 13 | 
            +
                <a href="https://fraunhoferhhi.github.io/RIPE/" align="center">Project Page</a> |
         | 
| 14 | 
            +
                <a href="https://huggingface.co/spaces/JohannesK14/RIPE" align="center">🤗Demo🤗</a>
         | 
| 15 | 
            +
              </p></h2>  
         | 
| 16 | 
            +
              <div align="center"></div>
         | 
| 17 | 
            +
            </p>
         | 
| 18 | 
            +
            <br/>
         | 
| 19 | 
            +
            <p align="center">
         | 
| 20 | 
            +
                <img src="assets/teaser_image.png" alt="example" width=80%>
         | 
| 21 | 
            +
                <br>
         | 
| 22 | 
            +
                <em>RIPE demonstrates that keypoint detection and description can be learned from image pairs only - no depth, no pose, no artificial augmentation required.</em>
         | 
| 23 | 
            +
            </p>
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            ## Setup
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            💡**Alternative**💡 Install nothing locally and try our Hugging Face demo: [🤗Demo🤗](https://huggingface.co/spaces/JohannesK14/RIPE)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            1. Install mamba by following the instructions given here: [Mamba Installation](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            2. Create a new environment with:
         | 
| 32 | 
            +
            ```bash
         | 
| 33 | 
            +
            mamba create -f conda_env.yml
         | 
| 34 | 
            +
            mamba activate ripe-env
         | 
| 35 | 
            +
            ```
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            ## How to use
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            Or just check [demo.py](demo.py)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            ```python
         | 
| 42 | 
            +
            import cv2
         | 
| 43 | 
            +
            import kornia.feature as KF
         | 
| 44 | 
            +
            import kornia.geometry as KG
         | 
| 45 | 
            +
            import matplotlib.pyplot as plt
         | 
| 46 | 
            +
            import numpy as np
         | 
| 47 | 
            +
            import torch
         | 
| 48 | 
            +
            from torchvision.io import decode_image
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            from ripe import vgg_hyper
         | 
| 51 | 
            +
            from ripe.utils.utils import cv2_matches_from_kornia, resize_image, to_cv_kpts
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            model = vgg_hyper().to(dev)
         | 
| 56 | 
            +
            model.eval()
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            image1 = resize_image(decode_image("assets/all_souls_000013.jpg").float().to(dev) / 255.0)
         | 
| 59 | 
            +
            image2 = resize_image(decode_image("assets/all_souls_000055.jpg").float().to(dev) / 255.0)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
         | 
| 62 | 
            +
            kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            matcher = KF.DescriptorMatcher("mnn")  # threshold is not used with mnn
         | 
| 65 | 
            +
            match_dists, match_idxs = matcher(desc_1, desc_2)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            matched_pts_1 = kpts_1[match_idxs[:, 0]]
         | 
| 68 | 
            +
            matched_pts_2 = kpts_2[match_idxs[:, 1]]
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=1.0)(matched_pts_1, matched_pts_2)
         | 
| 71 | 
            +
            matchesMask = mask.int().ravel().tolist()
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            result_ransac = cv2.drawMatches(
         | 
| 74 | 
            +
                (image1.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
         | 
| 75 | 
            +
                to_cv_kpts(kpts_1, score_1),
         | 
| 76 | 
            +
                (image2.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
         | 
| 77 | 
            +
                to_cv_kpts(kpts_2, score_2),
         | 
| 78 | 
            +
                cv2_matches_from_kornia(match_dists, match_idxs),
         | 
| 79 | 
            +
                None,
         | 
| 80 | 
            +
                matchColor=(0, 255, 0),
         | 
| 81 | 
            +
                matchesMask=matchesMask,
         | 
| 82 | 
            +
                # matchesMask=None, # without RANSAC filtering
         | 
| 83 | 
            +
                singlePointColor=(0, 0, 255),
         | 
| 84 | 
            +
                flags=cv2.DrawMatchesFlags_DEFAULT,
         | 
| 85 | 
            +
            )
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            plt.imshow(result_ransac)
         | 
| 88 | 
            +
            plt.axis("off")
         | 
| 89 | 
            +
            plt.tight_layout()
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            plt.show()
         | 
| 92 | 
            +
            # plt.savefig("result_ransac.png")
         | 
| 93 | 
            +
            ```
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            ## Reproduce the results
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            ### MegaDepth 1500 & HPatches
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            1. Download and install [Glue Factory](https://github.com/cvg/glue-factory)
         | 
| 100 | 
            +
            2. Add this repo as a submodule to Glue Factory:
         | 
| 101 | 
            +
            ```bash
         | 
| 102 | 
            +
            cd glue-factory
         | 
| 103 | 
            +
            git submodule add https://github.com/fraunhoferhhi/RIPE.git thirdparty/ripe
         | 
| 104 | 
            +
            ```
         | 
| 105 | 
            +
            3. Create the new file ripe.py under gluefactory/models/extractors/ with the following content:
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                <details>
         | 
| 108 | 
            +
                <summary>ripe.py</summary>
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                ```python
         | 
| 111 | 
            +
                import sys
         | 
| 112 | 
            +
                from pathlib import Path
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                import torch
         | 
| 115 | 
            +
                import torchvision.transforms as transforms
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                from ..base_model import BaseModel
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                ripe_path = Path(__file__).parent / "../../../thirdparty/ripe"
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                print(f"RIPE Path: {ripe_path.resolve()}")
         | 
| 122 | 
            +
                # check if the path exists
         | 
| 123 | 
            +
                if not ripe_path.exists():
         | 
| 124 | 
            +
                    raise RuntimeError(f"RIPE path not found: {ripe_path}")
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                sys.path.append(str(ripe_path))
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                from ripe import vgg_hyper
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
                class RIPE(BaseModel):
         | 
| 132 | 
            +
                    default_conf = {
         | 
| 133 | 
            +
                        "name": "RIPE",
         | 
| 134 | 
            +
                        "model_path": None,
         | 
| 135 | 
            +
                        "chunk": 4,
         | 
| 136 | 
            +
                        "dense_outputs": False,
         | 
| 137 | 
            +
                        "threshold": 1.0,
         | 
| 138 | 
            +
                        "top_k": 2048,
         | 
| 139 | 
            +
                    }
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    required_data_keys = ["image"]
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    # Initialize the line matcher
         | 
| 144 | 
            +
                    def _init(self, conf):
         | 
| 145 | 
            +
                        self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         | 
| 146 | 
            +
                        self.model = vgg_hyper(model_path=conf.model_path)
         | 
| 147 | 
            +
                        self.model.eval()
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        self.set_initialized()
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    def _forward(self, data):
         | 
| 152 | 
            +
                        image = data["image"]
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                        keypoints, scores, descriptors = [], [], []
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                        chunk = self.conf.chunk
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                        for i in range(0, image.shape[0], chunk):
         | 
| 159 | 
            +
                            if self.conf.dense_outputs:
         | 
| 160 | 
            +
                                raise NotImplementedError("Dense outputs are not supported")
         | 
| 161 | 
            +
                            else:
         | 
| 162 | 
            +
                                im = image[: min(image.shape[0], i + chunk)]
         | 
| 163 | 
            +
                                im = self.normalizer(im)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                                H, W = im.shape[-2:]
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                                kpt, desc, score = self.model.detectAndCompute(
         | 
| 168 | 
            +
                                    im,
         | 
| 169 | 
            +
                                    threshold=self.conf.threshold,
         | 
| 170 | 
            +
                                    top_k=self.conf.top_k,
         | 
| 171 | 
            +
                                )
         | 
| 172 | 
            +
                            keypoints += [kpt.squeeze(0)]
         | 
| 173 | 
            +
                            scores += [score.squeeze(0)]
         | 
| 174 | 
            +
                            descriptors += [desc.squeeze(0)]
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                            del kpt
         | 
| 177 | 
            +
                            del desc
         | 
| 178 | 
            +
                            del score
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                        keypoints = torch.stack(keypoints, 0)
         | 
| 181 | 
            +
                        scores = torch.stack(scores, 0)
         | 
| 182 | 
            +
                        descriptors = torch.stack(descriptors, 0)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                        pred = {
         | 
| 185 | 
            +
                            # "keypoints": keypoints.to(image) + 0.5,
         | 
| 186 | 
            +
                            "keypoints": keypoints.to(image),
         | 
| 187 | 
            +
                            "keypoint_scores": scores.to(image),
         | 
| 188 | 
            +
                            "descriptors": descriptors.to(image),
         | 
| 189 | 
            +
                        }
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                        return pred
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    def loss(self, pred, data):
         | 
| 194 | 
            +
                        raise NotImplementedError
         | 
| 195 | 
            +
                ```
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                </details>
         | 
| 198 | 
            +
             | 
| 199 | 
            +
            4. Create ripe+NN.yaml in gluefactory/configs with the following content:
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                <details>
         | 
| 202 | 
            +
                <summary>ripe+NN.yaml</summary>
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                ```yaml
         | 
| 205 | 
            +
                model:
         | 
| 206 | 
            +
                    name: two_view_pipeline
         | 
| 207 | 
            +
                    extractor:
         | 
| 208 | 
            +
                        name: extractors.ripe
         | 
| 209 | 
            +
                        threshold: 1.0
         | 
| 210 | 
            +
                        top_k: 2048
         | 
| 211 | 
            +
                    matcher:
         | 
| 212 | 
            +
                        name: matchers.nearest_neighbor_matcher
         | 
| 213 | 
            +
                benchmarks:
         | 
| 214 | 
            +
                    megadepth1500:
         | 
| 215 | 
            +
                      data:
         | 
| 216 | 
            +
                        preprocessing:
         | 
| 217 | 
            +
                          side: long
         | 
| 218 | 
            +
                          resize: 1600
         | 
| 219 | 
            +
                      eval:
         | 
| 220 | 
            +
                        estimator: poselib
         | 
| 221 | 
            +
                        ransac_th: 0.5
         | 
| 222 | 
            +
                    hpatches:
         | 
| 223 | 
            +
                      eval:
         | 
| 224 | 
            +
                        estimator: poselib
         | 
| 225 | 
            +
                        ransac_th: 0.5
         | 
| 226 | 
            +
                      model:
         | 
| 227 | 
            +
                        extractor:
         | 
| 228 | 
            +
                          top_k: 1024  # overwrite config above
         | 
| 229 | 
            +
                ```
         | 
| 230 | 
            +
             | 
| 231 | 
            +
            5. Run the MegaDepth 1500 evaluation script:
         | 
| 232 | 
            +
             | 
| 233 | 
            +
            ```bash
         | 
| 234 | 
            +
            python -m gluefactory.eval.megadepth1500 --conf ripe+NN # for MegaDepth 1500
         | 
| 235 | 
            +
            ```
         | 
| 236 | 
            +
             | 
| 237 | 
            +
            Should result in: 
         | 
| 238 | 
            +
             | 
| 239 | 
            +
            ```bash
         | 
| 240 | 
            +
            'rel_pose_error@10°': 0.6834,
         | 
| 241 | 
            +
            'rel_pose_error@20°': 0.7803,
         | 
| 242 | 
            +
            'rel_pose_error@5°': 0.5511,
         | 
| 243 | 
            +
            ```
         | 
| 244 | 
            +
             | 
| 245 | 
            +
            6. Run the HPatches evaluation script:
         | 
| 246 | 
            +
             | 
| 247 | 
            +
            ```bash
         | 
| 248 | 
            +
            python -m gluefactory.eval.hpatches --conf ripe+NN # for HPatches
         | 
| 249 | 
            +
            ```
         | 
| 250 | 
            +
             | 
| 251 | 
            +
            Should result in:
         | 
| 252 | 
            +
             | 
| 253 | 
            +
            ```bash
         | 
| 254 | 
            +
            'H_error_ransac@1px': 0.3793,
         | 
| 255 | 
            +
            'H_error_ransac@3px': 0.5893,
         | 
| 256 | 
            +
            'H_error_ransac@5px': 0.692,
         | 
| 257 | 
            +
            ```
         | 
| 258 | 
            +
             | 
| 259 | 
            +
             | 
| 260 | 
            +
             | 
| 261 | 
            +
            ## Training
         | 
| 262 | 
            +
             | 
| 263 | 
            +
            1. Create a .env file with the following content:
         | 
| 264 | 
            +
            ```bash
         | 
| 265 | 
            +
            OUTPUT_DIR="/output"
         | 
| 266 | 
            +
            DATA_DIR="/data"
         | 
| 267 | 
            +
            ```
         | 
| 268 | 
            +
             | 
| 269 | 
            +
            2. Download the required datasets:
         | 
| 270 | 
            +
                    
         | 
| 271 | 
            +
                <details>
         | 
| 272 | 
            +
                <summary>DISK Megadepth subset</summary>
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                To download the dataset used by [DISK](https://github.com/cvlab-epfl/disk) execute the following commands:
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                ```bash
         | 
| 277 | 
            +
                cd data
         | 
| 278 | 
            +
                bash download_disk_data.sh
         | 
| 279 | 
            +
                ```
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                </details>
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                <details>
         | 
| 284 | 
            +
                <summary>Tokyo 24/7</summary>
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                - ⚠️**Optional**⚠️: Only if you are interest in the model used in Section 4.6 of the paper!
         | 
| 287 | 
            +
                - Download the Tokyo 24/7 query images from here: [Tokyo 24/7 Query Images V3](http://www.ok.ctrl.titech.ac.jp/~torii/project/247/download/247query_v3.zip) from the official [website](http://www.ok.ctrl.titech.ac.jp/~torii/project/247/_).
         | 
| 288 | 
            +
                - extract them into data/Tolyo_Query_V3
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                ```bash
         | 
| 291 | 
            +
                Tokyo_Query_V3/
         | 
| 292 | 
            +
                ├── 00001.csv
         | 
| 293 | 
            +
                ├── 00001.jpg
         | 
| 294 | 
            +
                ├── 00002.csv
         | 
| 295 | 
            +
                ├── 00002.jpg
         | 
| 296 | 
            +
                ├── ...
         | 
| 297 | 
            +
                ├── 01125.csv
         | 
| 298 | 
            +
                ├── 01125.jpg
         | 
| 299 | 
            +
                ├── Readme.txt
         | 
| 300 | 
            +
                └── Readme.txt~
         | 
| 301 | 
            +
                ```
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                </details>
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                <details>
         | 
| 306 | 
            +
                <summary>ACDC</summary>
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                - ⚠️**Optional**⚠️: Only if you are interest in the model used in Section 6.1 (supplementary) of the paper!
         | 
| 309 | 
            +
                - Download the RGB images from here: [ACDC RGB Images](https://acdc.vision.ee.ethz.ch/rgb_anon_trainvaltest.zip)
         | 
| 310 | 
            +
                - extract them into data/ACDC
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                ```bash
         | 
| 313 | 
            +
                ACDC/
         | 
| 314 | 
            +
                rgb_anon
         | 
| 315 | 
            +
                ├── fog
         | 
| 316 | 
            +
                │   ├── test
         | 
| 317 | 
            +
                │   │   ├── GOPR0475
         | 
| 318 | 
            +
                │   │   ├── GOPR0477
         | 
| 319 | 
            +
                │   ├── test_ref
         | 
| 320 | 
            +
                │   │   ├── GOPR0475
         | 
| 321 | 
            +
                │   │   ├── GOPR0477
         | 
| 322 | 
            +
                │   ├── train
         | 
| 323 | 
            +
                │   │   ├── GOPR0475
         | 
| 324 | 
            +
                │   │   ├── GOPR0476
         | 
| 325 | 
            +
                ├── night
         | 
| 326 | 
            +
                ```
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                </details>
         | 
| 329 | 
            +
             | 
| 330 | 
            +
            3. Run the training script:
         | 
| 331 | 
            +
             | 
| 332 | 
            +
            ```bash
         | 
| 333 | 
            +
            python ripe/train.py --config-name train project_name=train name=reproduce wandb_mode=offline
         | 
| 334 | 
            +
            ```
         | 
| 335 | 
            +
             | 
| 336 | 
            +
            You can also easily switch setting from the command line, e.g. to addionally train on the Tokyo 24/7 dataset:
         | 
| 337 | 
            +
            ```bash
         | 
| 338 | 
            +
            python ripe/train.py --config-name train project_name=train name=reproduce wandb_mode=offline data=megadepth+tokyo
         | 
| 339 | 
            +
            ```
         | 
| 340 | 
            +
             | 
| 341 | 
            +
            ## Acknowledgements
         | 
| 342 | 
            +
             | 
| 343 | 
            +
            Our code is partly based on the following repositories:
         | 
| 344 | 
            +
            - [DALF](https://github.com/verlab/DALF_CVPR_2023) Apache License 2.0
         | 
| 345 | 
            +
            - [DeDoDe](https://github.com/Parskatt/DeDoDe) MIT License
         | 
| 346 | 
            +
            - [DISK](https://github.com/cvlab-epfl/disk) Apache License 2.0
         | 
| 347 | 
            +
             | 
| 348 | 
            +
            Our evaluation was based on the following repositories:
         | 
| 349 | 
            +
            - [Glue Factory](https://github.com/cvg/glue-factory)
         | 
| 350 | 
            +
            - [hloc](https://github.com/cvg/Hierarchical-Localization)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
            We would like to thank the authors of these repositories for their great work and for making their code available.
         | 
| 353 | 
            +
             | 
| 354 | 
            +
            Our project webpage is based on the [Acadamic Project Page Template](https://github.com/eliahuhorwitz/Academic-project-page-template) by Eliahu Horwitz.
         | 
| 355 | 
            +
             | 
| 356 | 
            +
            ## BibTex Citation
         | 
| 357 | 
            +
             | 
| 358 | 
            +
            ```
         | 
| 359 | 
            +
             | 
| 360 | 
            +
            @article{ripe2025, 
         | 
| 361 | 
            +
            year = {2025}, 
         | 
| 362 | 
            +
            title = {{RIPE: Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction}}, 
         | 
| 363 | 
            +
            author = {Künzel, Johannes and Hilsmann, Anna and Eisert, Peter}, 
         | 
| 364 | 
            +
            journal = {arXiv}, 
         | 
| 365 | 
            +
            eprint = {2507.04839}, 
         | 
| 366 | 
            +
            }
         | 
| 367 | 
            +
            ```
         | 
    	
        imcui/third_party/RIPE/app.py
    ADDED
    
    | @@ -0,0 +1,272 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # This is a small gradio interface to access our RIPE keypoint extractor.
         | 
| 2 | 
            +
            # You can either upload two images or use one of the example image pairs.
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import gradio as gr
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ripe import vgg_hyper
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            SEED = 32000
         | 
| 12 | 
            +
            os.environ["PYTHONHASHSEED"] = str(SEED)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import random
         | 
| 15 | 
            +
            from pathlib import Path
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import numpy as np
         | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            torch.manual_seed(SEED)
         | 
| 21 | 
            +
            np.random.seed(SEED)
         | 
| 22 | 
            +
            random.seed(SEED)
         | 
| 23 | 
            +
            import cv2
         | 
| 24 | 
            +
            import kornia.feature as KF
         | 
| 25 | 
            +
            import kornia.geometry as KG
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            from ripe.utils.utils import cv2_matches_from_kornia, to_cv_kpts
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            MIN_SIZE = 512
         | 
| 30 | 
            +
            MAX_SIZE = 768
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            description_text = """
         | 
| 33 | 
            +
            <p align='center'>
         | 
| 34 | 
            +
              <h1 align='center'>🌊🌺 ICCV 2025 🌺🌊</h1>
         | 
| 35 | 
            +
              <p align='center'>
         | 
| 36 | 
            +
                <a href='https://scholar.google.com/citations?user=ybMR38kAAAAJ'>Johannes Künzel</a> · 
         | 
| 37 | 
            +
                <a href='https://scholar.google.com/citations?user=5yTuyGIAAAAJ'>Anna Hilsmann</a> · 
         | 
| 38 | 
            +
                <a href='https://scholar.google.com/citations?user=BCElyCkAAAAJ'>Peter Eisert</a>
         | 
| 39 | 
            +
              </p>
         | 
| 40 | 
            +
              <h2 align='center'>
         | 
| 41 | 
            +
                <a href='???'>Arxiv</a> | 
         | 
| 42 | 
            +
                <a href='???'>Project Page</a> | 
         | 
| 43 | 
            +
                <a href='???'>Code</a>
         | 
| 44 | 
            +
              </h2>
         | 
| 45 | 
            +
            </p>
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            <br/>
         | 
| 48 | 
            +
            <div align='center'>
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            ### This demo showcases our new keypoint extractor model, RIPE (Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction).
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            ### RIPE is trained without requiring pose or depth supervision or artificial augmentations. By leveraging reinforcement learning, it learns to extract keypoints solely based on whether an image pair depicts the same scene or not.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            ### For more detailed information, please refer to our [paper](link to be added).
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            The demo code extracts the 2048-top keypoints from the two input images. It uses the mutual nearest neighbor (MNN) descriptor matcher from kornia to find matches between the two images.
         | 
| 57 | 
            +
            If the number of matches is greater than 8, it applies RANSAC to filter out outliers based on the inlier threshold provided by the user.
         | 
| 58 | 
            +
            Images are resized to fit within a maximum size of 2048x2048 pixels with maintained aspect ratio.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            </div>
         | 
| 61 | 
            +
            """
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            path_weights = Path(
         | 
| 64 | 
            +
                "/media/jwkuenzel/work/projects/CVG_Reinforced_Keypoints/output/train/ablation_iccv/inlier_threshold/1571243/2025-02-19/14-00-10_789013/model_inlier_threshold_best.pth"
         | 
| 65 | 
            +
            )
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            model = vgg_hyper(path_weights)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def get_new_image_size(image, min_size=1600, max_size=2048):
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
                Get a new size for the image that is scaled to fit between min_size and max_size while maintaining the aspect ratio.
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                Args:
         | 
| 75 | 
            +
                    image (PIL.Image): Input image.
         | 
| 76 | 
            +
                    min_size (int): Minimum allowed size for width and height.
         | 
| 77 | 
            +
                    max_size (int): Maximum allowed size for width and height.
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                Returns:
         | 
| 80 | 
            +
                    tuple: New size (width, height) for the image.
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
                width, height = image.size
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                aspect_ratio = width / height
         | 
| 85 | 
            +
                if width > height:
         | 
| 86 | 
            +
                    new_width = max(min_size, min(max_size, width))
         | 
| 87 | 
            +
                    new_height = int(new_width / aspect_ratio)
         | 
| 88 | 
            +
                else:
         | 
| 89 | 
            +
                    new_height = max(min_size, min(max_size, height))
         | 
| 90 | 
            +
                    new_width = int(new_height * aspect_ratio)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                new_size = (new_width, new_height)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                return new_size
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            def extract_keypoints(image1, image2, inl_th):
         | 
| 98 | 
            +
                """
         | 
| 99 | 
            +
                Extract keypoints from two input images using the RIPE model.
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                Args:
         | 
| 102 | 
            +
                    image1 (PIL.Image): First input image.
         | 
| 103 | 
            +
                    image2 (PIL.Image): Second input image.
         | 
| 104 | 
            +
                    inl_th (float): RANSAC inlier threshold.
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                Returns:
         | 
| 107 | 
            +
                    dict: A dictionary containing keypoints and matches.
         | 
| 108 | 
            +
                """
         | 
| 109 | 
            +
                log_text = "Extracting keypoints and matches with RIPE\n"
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                log_text += f"Image 1 size: {image1.size}\n"
         | 
| 112 | 
            +
                log_text += f"Image 2 size: {image2.size}\n"
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                # check not larger than 2048x2048
         | 
| 115 | 
            +
                new_size = get_new_image_size(image1, min_size=MIN_SIZE, max_size=MAX_SIZE)
         | 
| 116 | 
            +
                image1 = image1.resize(new_size)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                new_size = get_new_image_size(image2, min_size=MIN_SIZE, max_size=MAX_SIZE)
         | 
| 119 | 
            +
                image2 = image2.resize(new_size)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                log_text += f"Resized Image 1 size: {image1.size}\n"
         | 
| 122 | 
            +
                log_text += f"Resized Image 2 size: {image2.size}\n"
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 125 | 
            +
                model.to(dev)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                image1 = image1.convert("RGB")
         | 
| 128 | 
            +
                image2 = image2.convert("RGB")
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                image1_original = image1.copy()
         | 
| 131 | 
            +
                image2_original = image2.copy()
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # convert PIL images to numpy arrays
         | 
| 134 | 
            +
                image1_original = np.array(image1_original)
         | 
| 135 | 
            +
                image2_original = np.array(image2_original)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                # convert PIL images to tensors
         | 
| 138 | 
            +
                image1 = torch.tensor(np.array(image1)).permute(2, 0, 1).float() / 255.0
         | 
| 139 | 
            +
                image2 = torch.tensor(np.array(image2)).permute(2, 0, 1).float() / 255.0
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                image1 = image1.to(dev).unsqueeze(0)  # Add batch dimension
         | 
| 142 | 
            +
                image2 = image2.to(dev).unsqueeze(0)  # Add batch dimension
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
         | 
| 145 | 
            +
                kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                log_text += f"Number of keypoints in image 1: {kpts_1.shape[0]}\n"
         | 
| 148 | 
            +
                log_text += f"Number of keypoints in image 2: {kpts_2.shape[0]}\n"
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                matcher = KF.DescriptorMatcher("mnn")  # threshold is not used with mnn
         | 
| 151 | 
            +
                match_dists, match_idxs = matcher(desc_1, desc_2)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                log_text += f"Number of MNN matches: {match_idxs.shape[0]}\n"
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                cv2_matches = cv2_matches_from_kornia(match_dists, match_idxs)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                do_ransac = match_idxs.shape[0] > 8
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                if do_ransac:
         | 
| 160 | 
            +
                    matched_pts_1 = kpts_1[match_idxs[:, 0]]
         | 
| 161 | 
            +
                    matched_pts_2 = kpts_2[match_idxs[:, 1]]
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=inl_th)(matched_pts_1, matched_pts_2)
         | 
| 164 | 
            +
                    matchesMask = mask.int().ravel().tolist()
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    log_text += f"RANSAC found {mask.sum().item()} inliers out of {mask.shape[0]} matches with an inlier threshold of {inl_th}.\n"
         | 
| 167 | 
            +
                else:
         | 
| 168 | 
            +
                    log_text += "Not enough matches for RANSAC, skipping RANSAC step.\n"
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                kpts_1 = to_cv_kpts(kpts_1, score_1)
         | 
| 171 | 
            +
                kpts_2 = to_cv_kpts(kpts_2, score_2)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                keypoints_raw_1 = cv2.drawKeypoints(image1_original, kpts_1, image1_original, color=(0, 255, 0))
         | 
| 174 | 
            +
                keypoints_raw_2 = cv2.drawKeypoints(image2_original, kpts_2, image2_original, color=(0, 255, 0))
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                # pad height smaller image to match the height of the larger image
         | 
| 177 | 
            +
                if keypoints_raw_1.shape[0] < keypoints_raw_2.shape[0]:
         | 
| 178 | 
            +
                    pad_height = keypoints_raw_2.shape[0] - keypoints_raw_1.shape[0]
         | 
| 179 | 
            +
                    keypoints_raw_1 = np.pad(
         | 
| 180 | 
            +
                        keypoints_raw_1, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
         | 
| 181 | 
            +
                    )
         | 
| 182 | 
            +
                elif keypoints_raw_1.shape[0] > keypoints_raw_2.shape[0]:
         | 
| 183 | 
            +
                    pad_height = keypoints_raw_1.shape[0] - keypoints_raw_2.shape[0]
         | 
| 184 | 
            +
                    keypoints_raw_2 = np.pad(
         | 
| 185 | 
            +
                        keypoints_raw_2, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
         | 
| 186 | 
            +
                    )
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                # concatenate keypoints images horizontally
         | 
| 189 | 
            +
                keypoints_raw = np.concatenate((keypoints_raw_1, keypoints_raw_2), axis=1)
         | 
| 190 | 
            +
                keypoints_raw_pil = Image.fromarray(keypoints_raw)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                result_raw = cv2.drawMatches(
         | 
| 193 | 
            +
                    image1_original,
         | 
| 194 | 
            +
                    kpts_1,
         | 
| 195 | 
            +
                    image2_original,
         | 
| 196 | 
            +
                    kpts_2,
         | 
| 197 | 
            +
                    cv2_matches,
         | 
| 198 | 
            +
                    None,
         | 
| 199 | 
            +
                    matchColor=(0, 255, 0),
         | 
| 200 | 
            +
                    matchesMask=None,
         | 
| 201 | 
            +
                    # matchesMask=None,
         | 
| 202 | 
            +
                    flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
         | 
| 203 | 
            +
                )
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                if not do_ransac:
         | 
| 206 | 
            +
                    result_ransac = None
         | 
| 207 | 
            +
                else:
         | 
| 208 | 
            +
                    result_ransac = cv2.drawMatches(
         | 
| 209 | 
            +
                        image1_original,
         | 
| 210 | 
            +
                        kpts_1,
         | 
| 211 | 
            +
                        image2_original,
         | 
| 212 | 
            +
                        kpts_2,
         | 
| 213 | 
            +
                        cv2_matches,
         | 
| 214 | 
            +
                        None,
         | 
| 215 | 
            +
                        matchColor=(0, 255, 0),
         | 
| 216 | 
            +
                        matchesMask=matchesMask,
         | 
| 217 | 
            +
                        singlePointColor=(0, 0, 255),
         | 
| 218 | 
            +
                        flags=cv2.DrawMatchesFlags_DEFAULT,
         | 
| 219 | 
            +
                    )
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB for display
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                # convert to PIL Image
         | 
| 224 | 
            +
                result_raw_pil = Image.fromarray(result_raw)
         | 
| 225 | 
            +
                if result_ransac is not None:
         | 
| 226 | 
            +
                    result_ransac_pil = Image.fromarray(result_ransac)
         | 
| 227 | 
            +
                else:
         | 
| 228 | 
            +
                    result_ransac_pil = None
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                return log_text, result_ransac_pil, result_raw_pil, keypoints_raw_pil
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
            demo = gr.Interface(
         | 
| 234 | 
            +
                fn=extract_keypoints,
         | 
| 235 | 
            +
                inputs=[
         | 
| 236 | 
            +
                    gr.Image(type="pil", label="Image 1"),
         | 
| 237 | 
            +
                    gr.Image(type="pil", label="Image 2"),
         | 
| 238 | 
            +
                    gr.Slider(
         | 
| 239 | 
            +
                        minimum=0.1,
         | 
| 240 | 
            +
                        maximum=3.0,
         | 
| 241 | 
            +
                        step=0.1,
         | 
| 242 | 
            +
                        value=0.5,
         | 
| 243 | 
            +
                        label="RANSAC inlier threshold",
         | 
| 244 | 
            +
                        info="Threshold for RANSAC inlier detection. Lower values may yield fewer inliers but more robust matches.",
         | 
| 245 | 
            +
                    ),
         | 
| 246 | 
            +
                ],
         | 
| 247 | 
            +
                outputs=[
         | 
| 248 | 
            +
                    gr.Textbox(type="text", label="Log"),
         | 
| 249 | 
            +
                    gr.Image(type="pil", label="Keypoints and Matches (RANSAC)"),
         | 
| 250 | 
            +
                    gr.Image(type="pil", label="Keypoints and Matches"),
         | 
| 251 | 
            +
                    gr.Image(type="pil", label="Keypoint Detection Results"),
         | 
| 252 | 
            +
                ],
         | 
| 253 | 
            +
                title="RIPE: Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction",
         | 
| 254 | 
            +
                description=description_text,
         | 
| 255 | 
            +
                examples=[
         | 
| 256 | 
            +
                    [
         | 
| 257 | 
            +
                        "assets_gradio/all_souls_000013.jpg",
         | 
| 258 | 
            +
                        "assets_gradio/all_souls_000055.jpg",
         | 
| 259 | 
            +
                    ],
         | 
| 260 | 
            +
                    [
         | 
| 261 | 
            +
                        "assets_gradio/167170681_0e5c42fd21_o.jpg",
         | 
| 262 | 
            +
                        "assets_gradio/170804731_6bf4fbecd4_o.jpg",
         | 
| 263 | 
            +
                    ],
         | 
| 264 | 
            +
                    [
         | 
| 265 | 
            +
                        "assets_gradio/4171014767_0fe879b783_o.jpg",
         | 
| 266 | 
            +
                        "assets_gradio/4174108353_20422632d6_o.jpg",
         | 
| 267 | 
            +
                    ],
         | 
| 268 | 
            +
                ],
         | 
| 269 | 
            +
                flagging_mode="never",
         | 
| 270 | 
            +
                theme="default",
         | 
| 271 | 
            +
            )
         | 
| 272 | 
            +
            demo.launch()
         | 
    	
        imcui/third_party/RIPE/assets/all_souls_000013.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        imcui/third_party/RIPE/assets/all_souls_000055.jpg
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        imcui/third_party/RIPE/assets/teaser_image.png
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        imcui/third_party/RIPE/conda_env.yml
    ADDED
    
    | @@ -0,0 +1,26 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            name: ripe-env
         | 
| 2 | 
            +
            channels:
         | 
| 3 | 
            +
              - conda-forge
         | 
| 4 | 
            +
            dependencies:
         | 
| 5 | 
            +
              - python
         | 
| 6 | 
            +
              - cmake
         | 
| 7 | 
            +
              - eigen # for poselib
         | 
| 8 | 
            +
              - pytorch=2.6=*cuda*
         | 
| 9 | 
            +
              - torchvision
         | 
| 10 | 
            +
              - pip
         | 
| 11 | 
            +
              # others
         | 
| 12 | 
            +
              - pudb # debugger
         | 
| 13 | 
            +
              - pip:
         | 
| 14 | 
            +
                  - lightning>=2.0.0
         | 
| 15 | 
            +
                  - setuptools
         | 
| 16 | 
            +
                  - poselib @ git+https://github.com/PoseLib/PoseLib.git@56d158f744d3561b0b70174e6d8ca9a7fc9bd9c1
         | 
| 17 | 
            +
                  - hydra-core
         | 
| 18 | 
            +
                  - opencv-python
         | 
| 19 | 
            +
                  - torchmetrics
         | 
| 20 | 
            +
                  - pyrootutils # standardizing the project root setup
         | 
| 21 | 
            +
                  - rich
         | 
| 22 | 
            +
                  - matplotlib
         | 
| 23 | 
            +
                  - kornia
         | 
| 24 | 
            +
                  - numpy
         | 
| 25 | 
            +
                  - wandb
         | 
| 26 | 
            +
                  - h5py
         | 
    	
        imcui/third_party/RIPE/conf/backbones/resnet.yaml
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.models.backbones.resnet.ResNet
         | 
| 2 | 
            +
            nchannels: 3
         | 
| 3 | 
            +
            pretrained: True
         | 
| 4 | 
            +
            use_instance_norm: False
         | 
| 5 | 
            +
            mode: dect
         | 
| 6 | 
            +
            num_layers: 4
         | 
    	
        imcui/third_party/RIPE/conf/backbones/vgg.yaml
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.models.backbones.vgg.VGG
         | 
| 2 | 
            +
            nchannels: 3
         | 
| 3 | 
            +
            pretrained: True
         | 
| 4 | 
            +
            use_instance_norm: False
         | 
| 5 | 
            +
            mode: dect
         | 
    	
        imcui/third_party/RIPE/conf/data/disk_megadepth.yaml
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.data.datasets.disk_megadepth.DISK_Megadepth
         | 
| 2 | 
            +
            root: ${oc.env:DATA_DIR}/disk-data
         | 
| 3 | 
            +
            stage: train
         | 
| 4 | 
            +
            max_scene_size: 10000
         | 
| 5 | 
            +
            transforms:
         | 
| 6 | 
            +
              _target_: ripe.data.data_transforms.Compose
         | 
| 7 | 
            +
              transforms:
         | 
| 8 | 
            +
                - _target_: ripe.data.data_transforms.Normalize
         | 
| 9 | 
            +
                  mean: [0.485, 0.456, 0.406]
         | 
| 10 | 
            +
                  std: [0.229, 0.224, 0.225]
         | 
| 11 | 
            +
                - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
         | 
| 12 | 
            +
                  target_size_longer_side: 560
         | 
    	
        imcui/third_party/RIPE/conf/data/megadepth+acdc.yaml
    ADDED
    
    | @@ -0,0 +1,33 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.data.datasets.dataset_combinator.DatasetCombinator
         | 
| 2 | 
            +
            mode: custom
         | 
| 3 | 
            +
            weights:
         | 
| 4 | 
            +
              - 0.2
         | 
| 5 | 
            +
              - 0.8
         | 
| 6 | 
            +
            datasets:
         | 
| 7 | 
            +
              - _target_: ripe.data.datasets.acdc.ACDC
         | 
| 8 | 
            +
                root: ${oc.env:DATA_DIR}/ACDC
         | 
| 9 | 
            +
                stage: train
         | 
| 10 | 
            +
                condition: all
         | 
| 11 | 
            +
                transforms:
         | 
| 12 | 
            +
                  _target_: ripe.data.data_transforms.Compose
         | 
| 13 | 
            +
                  transforms:
         | 
| 14 | 
            +
                    - _target_: ripe.data.data_transforms.Normalize
         | 
| 15 | 
            +
                      mean: [0.485, 0.456, 0.406]
         | 
| 16 | 
            +
                      std: [0.229, 0.224, 0.225]
         | 
| 17 | 
            +
                    - _target_: ripe.data.data_transforms.Crop # to remove the car hood from some images
         | 
| 18 | 
            +
                      crop_height: 896
         | 
| 19 | 
            +
                      crop_width: 1920
         | 
| 20 | 
            +
                    - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
         | 
| 21 | 
            +
                      target_size_longer_side: 560
         | 
| 22 | 
            +
              - _target_: ripe.data.datasets.disk_megadepth.DISK_Megadepth
         | 
| 23 | 
            +
                root: ${oc.env:DATA_DIR}/disk-data
         | 
| 24 | 
            +
                stage: train
         | 
| 25 | 
            +
                max_scene_size: 10000
         | 
| 26 | 
            +
                transforms:
         | 
| 27 | 
            +
                  _target_: ripe.data.data_transforms.Compose
         | 
| 28 | 
            +
                  transforms:
         | 
| 29 | 
            +
                    - _target_: ripe.data.data_transforms.Normalize
         | 
| 30 | 
            +
                      mean: [0.485, 0.456, 0.406]
         | 
| 31 | 
            +
                      std: [0.229, 0.224, 0.225]
         | 
| 32 | 
            +
                    - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
         | 
| 33 | 
            +
                      target_size_longer_side: 560
         | 
    	
        imcui/third_party/RIPE/conf/data/megadepth+tokyo.yaml
    ADDED
    
    | @@ -0,0 +1,29 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.data.datasets.dataset_combinator.DatasetCombinator
         | 
| 2 | 
            +
            mode: custom
         | 
| 3 | 
            +
            weights:
         | 
| 4 | 
            +
              - 0.2
         | 
| 5 | 
            +
              - 0.8
         | 
| 6 | 
            +
            datasets:
         | 
| 7 | 
            +
              - _target_: ripe.data.datasets.tokyo_query_v3.TokyoQueryV3
         | 
| 8 | 
            +
                root: ${oc.env:DATA_DIR}/Tokyo_Query_V3
         | 
| 9 | 
            +
                stage: train
         | 
| 10 | 
            +
                transforms:
         | 
| 11 | 
            +
                  _target_: ripe.data.data_transforms.Compose
         | 
| 12 | 
            +
                  transforms:
         | 
| 13 | 
            +
                    - _target_: ripe.data.data_transforms.Normalize
         | 
| 14 | 
            +
                      mean: [0.485, 0.456, 0.406]
         | 
| 15 | 
            +
                      std: [0.229, 0.224, 0.225]
         | 
| 16 | 
            +
                    - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
         | 
| 17 | 
            +
                      target_size_longer_side: 560 # like DeDoDe
         | 
| 18 | 
            +
              - _target_: ripe.data.datasets.disk_megadepth.DISK_Megadepth
         | 
| 19 | 
            +
                root: ${oc.env:DATA_DIR}/disk-data
         | 
| 20 | 
            +
                stage: train
         | 
| 21 | 
            +
                max_scene_size: 10000
         | 
| 22 | 
            +
                transforms:
         | 
| 23 | 
            +
                  _target_: ripe.data.data_transforms.Compose
         | 
| 24 | 
            +
                  transforms:
         | 
| 25 | 
            +
                    - _target_: ripe.data.data_transforms.Normalize
         | 
| 26 | 
            +
                      mean: [0.485, 0.456, 0.406]
         | 
| 27 | 
            +
                      std: [0.229, 0.224, 0.225]
         | 
| 28 | 
            +
                    - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
         | 
| 29 | 
            +
                      target_size_longer_side: 560
         | 
    	
        imcui/third_party/RIPE/conf/descriptor_loss/contrastive_loss.yaml
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.losses.contrastive_loss.ContrastiveLoss
         | 
| 2 | 
            +
            pos_margin: 0.2
         | 
| 3 | 
            +
            neg_margin: 0.2
         | 
    	
        imcui/third_party/RIPE/conf/inl_th/constant.yaml
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.scheduler.constant.ConstantScheduler
         | 
| 2 | 
            +
            value: 1.0
         | 
    	
        imcui/third_party/RIPE/conf/inl_th/exp_decay.yaml
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.scheduler.expDecay.ExpDecay
         | 
| 2 | 
            +
            a: 2.5
         | 
| 3 | 
            +
            b: 0.0005
         | 
| 4 | 
            +
            c: 0.5
         | 
    	
        imcui/third_party/RIPE/conf/matcher/concurrent_mnn_poselib.yaml
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.matcher.concurrent_matcher.ConcurrentMatcher
         | 
| 2 | 
            +
            min_num_matches: 8
         | 
| 3 | 
            +
            matcher:
         | 
| 4 | 
            +
              _target_: kornia.feature.DescriptorMatcher
         | 
| 5 | 
            +
              match_mode: "mnn"
         | 
| 6 | 
            +
              th: 0.8
         | 
| 7 | 
            +
            robust_estimator:
         | 
| 8 | 
            +
              _target_: ripe.matcher.pose_estimator_poselib.PoseLibRelativePoseEstimator
         | 
    	
        imcui/third_party/RIPE/conf/train.yaml
    ADDED
    
    | @@ -0,0 +1,89 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            defaults:
         | 
| 2 | 
            +
              - data: disk_megadepth # megadepth+acdc or megadepth+tokyo
         | 
| 3 | 
            +
              - backbones: vgg
         | 
| 4 | 
            +
              - upsampler: hypercolumn_features # interpolate_sparse2D
         | 
| 5 | 
            +
              - matcher: concurrent_mnn_poselib
         | 
| 6 | 
            +
              - descriptor_loss: contrastive_loss # none to deactivate
         | 
| 7 | 
            +
              - inl_th: constant # exp_decay
         | 
| 8 | 
            +
              - _self_
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            project_name: ???
         | 
| 11 | 
            +
            name: ???
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            hydra:
         | 
| 14 | 
            +
              run:
         | 
| 15 | 
            +
                dir: ${oc.env:OUTPUT_DIR}/${project_name}/${name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d}/${now:%H-%M-%S}
         | 
| 16 | 
            +
            output_dir: ${hydra:runtime.output_dir}
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            num_gpus: 1
         | 
| 19 | 
            +
            # precision: "32-true"
         | 
| 20 | 
            +
            precision: "bf16-mixed" # numerically more stable
         | 
| 21 | 
            +
            # precision: "16-mixed"
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            log_interval: 50 # log every N steps/ batches
         | 
| 24 | 
            +
            wandb_mode: online
         | 
| 25 | 
            +
            val_interval: 2000
         | 
| 26 | 
            +
            conf_inference:
         | 
| 27 | 
            +
              threshold: 0.5
         | 
| 28 | 
            +
              top_k: 2048
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            desc_loss_weight: 5.0 # 0.0 to deactivate, also deactivates 1x1 conv
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            num_workers: 8
         | 
| 33 | 
            +
            batch_size: 6
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            transformation_model: fundamental
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            network:
         | 
| 38 | 
            +
              _target_: ripe.models.ripe.RIPE
         | 
| 39 | 
            +
              _partial_: true
         | 
| 40 | 
            +
              window_size: 8
         | 
| 41 | 
            +
              non_linearity_dect:
         | 
| 42 | 
            +
                _target_: torch.nn.Identity
         | 
| 43 | 
            +
                # _target_: torch.nn.ReLU
         | 
| 44 | 
            +
              desc_shares:
         | 
| 45 | 
            +
                null
         | 
| 46 | 
            +
                # - 64
         | 
| 47 | 
            +
                # - 64
         | 
| 48 | 
            +
                # - 64
         | 
| 49 | 
            +
                # - 64
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            lr: 0.001 # 0.001 makes it somewhat unstable
         | 
| 52 | 
            +
            fp_penalty: -1e-7 # -1e-7
         | 
| 53 | 
            +
            kp_penalty: -7e-7 # -7e-7
         | 
| 54 | 
            +
            num_grad_accs: 4
         | 
| 55 | 
            +
            reward_type: inlier # inlier_ratio , inlier+inlier_ratio
         | 
| 56 | 
            +
            no_filtering_negatives: False
         | 
| 57 | 
            +
            descriptor_dim: 256
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            lr_scheduler:
         | 
| 60 | 
            +
              _partial_: true
         | 
| 61 | 
            +
              _target_: ripe.scheduler.linearLR.StepLinearLR
         | 
| 62 | 
            +
              num_steps: ${num_steps}
         | 
| 63 | 
            +
              initial_lr: ${lr}
         | 
| 64 | 
            +
              final_lr: 1e-6
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            use_whitening: false
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            selected_only: False
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            padding_filter_mode: ignore
         | 
| 71 | 
            +
            # padding_filter_mode: punish
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            num_steps: 80000
         | 
| 74 | 
            +
             | 
| 75 | 
            +
            alpha_scheduler: # 1.0 after 1/3 of the steps
         | 
| 76 | 
            +
              _target_: ripe.scheduler.linear_with_plateaus.LinearWithPlateaus
         | 
| 77 | 
            +
              start_val: 0.0
         | 
| 78 | 
            +
              end_val: 1.0
         | 
| 79 | 
            +
              steps_total: ${num_steps}
         | 
| 80 | 
            +
              rel_length_start_plateau: 0.0
         | 
| 81 | 
            +
              rel_length_end_plateu: 0.6666666
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            beta_scheduler: # linear increase over all steps
         | 
| 84 | 
            +
              _target_: ripe.scheduler.linear_with_plateaus.LinearWithPlateaus
         | 
| 85 | 
            +
              start_val: 0.0
         | 
| 86 | 
            +
              end_val: 1.0
         | 
| 87 | 
            +
              steps_total: ${num_steps}
         | 
| 88 | 
            +
              rel_length_start_plateau: 0.0
         | 
| 89 | 
            +
              rel_length_end_plateu: 0.0
         | 
    	
        imcui/third_party/RIPE/conf/upsampler/hypercolumn_features.yaml
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.models.upsampler.hypercolumn_features.HyperColumnFeatures
         | 
| 2 | 
            +
            mode: bilinear
         | 
    	
        imcui/third_party/RIPE/conf/upsampler/interpolate_sparse2D.yaml
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            _target_: ripe.models.upsampler.interpolate_sparse2d.InterpolateSparse2d
         | 
    	
        imcui/third_party/RIPE/data/download_disk_data.sh
    ADDED
    
    | @@ -0,0 +1,43 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #/usr/bin/env bash
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # get the data (zipped)
         | 
| 4 | 
            +
            # wget -r https://datasets.epfl.ch/disk-data/index.html
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            cd datasets.epfl.ch/disk-data;
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # check for MD5 match
         | 
| 9 | 
            +
            # md5sum -c md5sum.txt;
         | 
| 10 | 
            +
            # if [ $? ]; then
         | 
| 11 | 
            +
            #     echo "MD5 mismatch (corrupt download)";
         | 
| 12 | 
            +
            #     return 1;
         | 
| 13 | 
            +
            # fi
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # create a crude progress counter
         | 
| 16 | 
            +
            ITER=1;
         | 
| 17 | 
            +
            TOTAL=138;
         | 
| 18 | 
            +
            # unzip test scenes
         | 
| 19 | 
            +
            cd imw2020-val/scenes;
         | 
| 20 | 
            +
            for SCENE_TAR in *.tar.gz; do
         | 
| 21 | 
            +
                echo "Unzipping $SCENE_TAR ($ITER / $TOTAL)";
         | 
| 22 | 
            +
                tar -xz --strip-components=3 -f $SCENE_TAR;
         | 
| 23 | 
            +
                rm $SCENE_TAR;
         | 
| 24 | 
            +
                ITER=$(($ITER+1));
         | 
| 25 | 
            +
            done
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # unzip megadepth scenes
         | 
| 28 | 
            +
            cd ../../megadepth/scenes;
         | 
| 29 | 
            +
            for SCENE_TAR in *.tar; do
         | 
| 30 | 
            +
                echo "Unzipping $SCENE_TAR ($ITER / $TOTAL)";
         | 
| 31 | 
            +
                tar -x --strip-components=3 -f $SCENE_TAR;
         | 
| 32 | 
            +
                rm $SCENE_TAR;
         | 
| 33 | 
            +
                ITER=$(($ITER+1));
         | 
| 34 | 
            +
            done
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            cd ../../../../
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            mv datasets.epfl.ch/disk-data ./
         | 
| 39 | 
            +
            rm -rf datasets.epfl.ch
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
    	
        imcui/third_party/RIPE/demo.py
    ADDED
    
    | @@ -0,0 +1,51 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import cv2
         | 
| 2 | 
            +
            import kornia.feature as KF
         | 
| 3 | 
            +
            import kornia.geometry as KG
         | 
| 4 | 
            +
            import matplotlib.pyplot as plt
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from torchvision.io import decode_image
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ripe import vgg_hyper
         | 
| 10 | 
            +
            from ripe.utils.utils import cv2_matches_from_kornia, resize_image, to_cv_kpts
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            model = vgg_hyper().to(dev)
         | 
| 15 | 
            +
            model.eval()
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            image1 = resize_image(decode_image("assets/all_souls_000013.jpg").float().to(dev) / 255.0)
         | 
| 18 | 
            +
            image2 = resize_image(decode_image("assets/all_souls_000055.jpg").float().to(dev) / 255.0)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
         | 
| 21 | 
            +
            kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            matcher = KF.DescriptorMatcher("mnn")  # threshold is not used with mnn
         | 
| 24 | 
            +
            match_dists, match_idxs = matcher(desc_1, desc_2)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            matched_pts_1 = kpts_1[match_idxs[:, 0]]
         | 
| 27 | 
            +
            matched_pts_2 = kpts_2[match_idxs[:, 1]]
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=1.0)(matched_pts_1, matched_pts_2)
         | 
| 30 | 
            +
            matchesMask = mask.int().ravel().tolist()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            result_ransac = cv2.drawMatches(
         | 
| 33 | 
            +
                (image1.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
         | 
| 34 | 
            +
                to_cv_kpts(kpts_1, score_1),
         | 
| 35 | 
            +
                (image2.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
         | 
| 36 | 
            +
                to_cv_kpts(kpts_2, score_2),
         | 
| 37 | 
            +
                cv2_matches_from_kornia(match_dists, match_idxs),
         | 
| 38 | 
            +
                None,
         | 
| 39 | 
            +
                matchColor=(0, 255, 0),
         | 
| 40 | 
            +
                matchesMask=matchesMask,
         | 
| 41 | 
            +
                # matchesMask=None, # without RANSAC filtering
         | 
| 42 | 
            +
                singlePointColor=(0, 0, 255),
         | 
| 43 | 
            +
                flags=cv2.DrawMatchesFlags_DEFAULT,
         | 
| 44 | 
            +
            )
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            plt.imshow(result_ransac)
         | 
| 47 | 
            +
            plt.axis("off")
         | 
| 48 | 
            +
            plt.tight_layout()
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            # plt.show()
         | 
| 51 | 
            +
            plt.savefig("result_ransac.png")
         | 
    	
        imcui/third_party/RIPE/ripe/__init__.py
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            from .model_zoo import vgg_hyper  # noqa: F401
         | 
    	
        imcui/third_party/RIPE/ripe/benchmarks/imw_2020.py
    ADDED
    
    | @@ -0,0 +1,320 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from pathlib import Path
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import cv2
         | 
| 5 | 
            +
            import kornia.feature as KF
         | 
| 6 | 
            +
            import matplotlib.pyplot as plt
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import poselib
         | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            from tqdm import tqdm
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from ripe import utils
         | 
| 13 | 
            +
            from ripe.data.data_transforms import Compose, Normalize, Resize
         | 
| 14 | 
            +
            from ripe.data.datasets.disk_imw import DISK_IMW
         | 
| 15 | 
            +
            from ripe.utils.pose_error import AUCMetric, relative_pose_error
         | 
| 16 | 
            +
            from ripe.utils.utils import (
         | 
| 17 | 
            +
                cv2_matches_from_kornia,
         | 
| 18 | 
            +
                cv_resize_and_pad_to_shape,
         | 
| 19 | 
            +
                to_cv_kpts,
         | 
| 20 | 
            +
            )
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class IMW_2020_Benchmark:
         | 
| 26 | 
            +
                def __init__(
         | 
| 27 | 
            +
                    self,
         | 
| 28 | 
            +
                    use_predefined_subset: bool = True,
         | 
| 29 | 
            +
                    conf_inference=None,
         | 
| 30 | 
            +
                    edge_input_divisible_by=None,
         | 
| 31 | 
            +
                ):
         | 
| 32 | 
            +
                    data_dir = os.getenv("DATA_DIR")
         | 
| 33 | 
            +
                    if data_dir is None:
         | 
| 34 | 
            +
                        raise ValueError("Environment variable DATA_DIR is not set.")
         | 
| 35 | 
            +
                    root_path = Path(data_dir) / "disk-data"
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    self.data = DISK_IMW(
         | 
| 38 | 
            +
                        str(
         | 
| 39 | 
            +
                            root_path
         | 
| 40 | 
            +
                        ),  # Resize only to ensure that the input size is divisible the value of edge_input_divisible_by
         | 
| 41 | 
            +
                        transforms=Compose(
         | 
| 42 | 
            +
                            [
         | 
| 43 | 
            +
                                Resize(None, edge_input_divisible_by),
         | 
| 44 | 
            +
                                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         | 
| 45 | 
            +
                            ]
         | 
| 46 | 
            +
                        ),
         | 
| 47 | 
            +
                    )
         | 
| 48 | 
            +
                    self.ids_subset = None
         | 
| 49 | 
            +
                    self.results = []
         | 
| 50 | 
            +
                    self.conf_inference = conf_inference
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    # fmt: off
         | 
| 53 | 
            +
                    if use_predefined_subset:
         | 
| 54 | 
            +
                        self.ids_subset = [4921, 3561, 3143, 6040, 802, 6828, 5338, 9275, 10764, 10085, 5124, 11355, 7, 10027, 2161, 4433, 6887, 3311, 10766,
         | 
| 55 | 
            +
                                           11451, 11433, 8539, 2581, 10300, 10562, 1723, 8803, 6275, 10140, 11487, 6238, 638, 8092, 9979, 201, 10394, 3414,
         | 
| 56 | 
            +
                                           9002, 7456, 2431, 632, 6589, 9265, 9889, 3139, 7890, 10619, 4899, 675, 176, 4309, 4814, 3833, 3519, 148, 4560, 10705,
         | 
| 57 | 
            +
                                           3744, 1441, 4049, 1791, 5106, 575, 1540, 1105, 6791, 1383, 9344, 501, 2504, 4335, 8992, 10970, 10786, 10405, 9317,
         | 
| 58 | 
            +
                                           5279, 1396, 5044, 9408, 11125, 10417, 7627, 7480, 1358, 7738, 5461, 10178, 9226, 8106, 2766, 6216, 4032, 7298, 259,
         | 
| 59 | 
            +
                                           3021, 2645, 8756, 7513, 3163, 2510, 6701, 6684, 3159, 9689, 7425, 6066, 1904, 6382, 3052, 777, 6277, 7409, 5997, 2987,
         | 
| 60 | 
            +
                                           11316, 2894, 4528, 1927, 10366, 8605, 2726, 1886, 2416, 2164, 3352, 2997, 6636, 6765, 5609, 3679, 76, 10956, 3612, 6699,
         | 
| 61 | 
            +
                                           1741, 8811, 3755, 1285, 9520, 2476, 3977, 370, 9823, 1834, 7551, 6227, 7303, 6399, 4758, 10713, 5050, 380, 11056, 7620,
         | 
| 62 | 
            +
                                           4826, 6090, 9011, 7523, 7355, 8021, 9801, 1801, 6522, 7138, 10017, 8732, 6402, 3116, 4031, 6088, 3975, 9841, 9082, 9412,
         | 
| 63 | 
            +
                                           5406, 217, 2385, 8791, 8361, 494, 4319, 5275, 3274, 335, 6731, 207, 10095, 3068, 5996, 3951, 2808, 5877, 6134, 7772, 10042,
         | 
| 64 | 
            +
                                           8574, 5501, 10885, 7871]
         | 
| 65 | 
            +
                        # self.ids_subset = self.ids_subset[:10]
         | 
| 66 | 
            +
                    # fmt: on
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def evaluate_sample(self, model, sample, dev):
         | 
| 69 | 
            +
                    img_1 = sample["src_image"].unsqueeze(0).to(dev)
         | 
| 70 | 
            +
                    img_2 = sample["trg_image"].unsqueeze(0).to(dev)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    scale_h_1, scale_w_1 = (
         | 
| 73 | 
            +
                        sample["orig_size_src"][0] / img_1.shape[2],
         | 
| 74 | 
            +
                        sample["orig_size_src"][1] / img_1.shape[3],
         | 
| 75 | 
            +
                    )
         | 
| 76 | 
            +
                    scale_h_2, scale_w_2 = (
         | 
| 77 | 
            +
                        sample["orig_size_trg"][0] / img_2.shape[2],
         | 
| 78 | 
            +
                        sample["orig_size_trg"][1] / img_2.shape[3],
         | 
| 79 | 
            +
                    )
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    M = None
         | 
| 82 | 
            +
                    info = {}
         | 
| 83 | 
            +
                    kpts_1, desc_1, score_1 = None, None, None
         | 
| 84 | 
            +
                    kpts_2, desc_2, score_2 = None, None, None
         | 
| 85 | 
            +
                    match_dists, match_idxs = None, None
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    try:
         | 
| 88 | 
            +
                        kpts_1, desc_1, score_1 = model.detectAndCompute(img_1, **self.conf_inference)
         | 
| 89 | 
            +
                        kpts_2, desc_2, score_2 = model.detectAndCompute(img_2, **self.conf_inference)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                        if kpts_1.dim() == 3:
         | 
| 92 | 
            +
                            assert kpts_1.shape[0] == 1 and kpts_2.shape[0] == 1, "Batch size must be 1"
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                            kpts_1, desc_1, score_1 = (
         | 
| 95 | 
            +
                                kpts_1.squeeze(0),
         | 
| 96 | 
            +
                                desc_1[0].squeeze(0),
         | 
| 97 | 
            +
                                score_1[0].squeeze(0),
         | 
| 98 | 
            +
                            )
         | 
| 99 | 
            +
                            kpts_2, desc_2, score_2 = (
         | 
| 100 | 
            +
                                kpts_2.squeeze(0),
         | 
| 101 | 
            +
                                desc_2[0].squeeze(0),
         | 
| 102 | 
            +
                                score_2[0].squeeze(0),
         | 
| 103 | 
            +
                            )
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                        scale_1 = torch.tensor([scale_w_1, scale_h_1], dtype=torch.float).to(dev)
         | 
| 106 | 
            +
                        scale_2 = torch.tensor([scale_w_2, scale_h_2], dtype=torch.float).to(dev)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                        kpts_1 = kpts_1 * scale_1
         | 
| 109 | 
            +
                        kpts_2 = kpts_2 * scale_2
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                        matcher = KF.DescriptorMatcher("mnn")  # threshold is not used with mnn
         | 
| 112 | 
            +
                        match_dists, match_idxs = matcher(desc_1, desc_2)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                        matched_pts_1 = kpts_1[match_idxs[:, 0]]
         | 
| 115 | 
            +
                        matched_pts_2 = kpts_2[match_idxs[:, 1]]
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                        camera_1 = sample["src_camera"]
         | 
| 118 | 
            +
                        camera_2 = sample["trg_camera"]
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                        M, info = poselib.estimate_relative_pose(
         | 
| 121 | 
            +
                            matched_pts_1.cpu().numpy(),
         | 
| 122 | 
            +
                            matched_pts_2.cpu().numpy(),
         | 
| 123 | 
            +
                            camera_1.to_cameradict(),
         | 
| 124 | 
            +
                            camera_2.to_cameradict(),
         | 
| 125 | 
            +
                            {
         | 
| 126 | 
            +
                                "max_epipolar_error": 0.5,
         | 
| 127 | 
            +
                            },
         | 
| 128 | 
            +
                            {},
         | 
| 129 | 
            +
                        )
         | 
| 130 | 
            +
                    except RuntimeError as e:
         | 
| 131 | 
            +
                        if "No keypoints detected" in str(e):
         | 
| 132 | 
            +
                            pass
         | 
| 133 | 
            +
                        else:
         | 
| 134 | 
            +
                            raise e
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    success = M is not None
         | 
| 137 | 
            +
                    if success:
         | 
| 138 | 
            +
                        M = {
         | 
| 139 | 
            +
                            "R": torch.tensor(M.R, dtype=torch.float),
         | 
| 140 | 
            +
                            "t": torch.tensor(M.t, dtype=torch.float),
         | 
| 141 | 
            +
                        }
         | 
| 142 | 
            +
                        inl = info["inliers"]
         | 
| 143 | 
            +
                    else:
         | 
| 144 | 
            +
                        M = {
         | 
| 145 | 
            +
                            "R": torch.eye(3, dtype=torch.float),
         | 
| 146 | 
            +
                            "t": torch.zeros((3), dtype=torch.float),
         | 
| 147 | 
            +
                        }
         | 
| 148 | 
            +
                        inl = np.zeros((0,)).astype(bool)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    t_err, r_err = relative_pose_error(sample["s2t_R"].cpu(), sample["s2t_T"].cpu(), M["R"], M["t"])
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    rel_pose_error = max(t_err.item(), r_err.item()) if success else np.inf
         | 
| 153 | 
            +
                    ransac_inl = np.sum(inl)
         | 
| 154 | 
            +
                    ransac_inl_ratio = np.mean(inl)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    if success:
         | 
| 157 | 
            +
                        assert match_dists is not None and match_idxs is not None, "Matches must be computed"
         | 
| 158 | 
            +
                        cv_keypoints_src = to_cv_kpts(kpts_1, score_1)
         | 
| 159 | 
            +
                        cv_keypoints_trg = to_cv_kpts(kpts_2, score_2)
         | 
| 160 | 
            +
                        cv_matches = cv2_matches_from_kornia(match_dists, match_idxs)
         | 
| 161 | 
            +
                        cv_mask = [int(m) for m in inl]
         | 
| 162 | 
            +
                    else:
         | 
| 163 | 
            +
                        cv_keypoints_src, cv_keypoints_trg = [], []
         | 
| 164 | 
            +
                        cv_matches, cv_mask = [], []
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    estimation = {
         | 
| 167 | 
            +
                        "success": success,
         | 
| 168 | 
            +
                        "M_0to1": M,
         | 
| 169 | 
            +
                        "inliers": torch.tensor(inl).to(img_1),
         | 
| 170 | 
            +
                        "rel_pose_error": rel_pose_error,
         | 
| 171 | 
            +
                        "ransac_inl": ransac_inl,
         | 
| 172 | 
            +
                        "ransac_inl_ratio": ransac_inl_ratio,
         | 
| 173 | 
            +
                        "path_src_image": sample["src_path"],
         | 
| 174 | 
            +
                        "path_trg_image": sample["trg_path"],
         | 
| 175 | 
            +
                        "cv_keypoints_src": cv_keypoints_src,
         | 
| 176 | 
            +
                        "cv_keypoints_trg": cv_keypoints_trg,
         | 
| 177 | 
            +
                        "cv_matches": cv_matches,
         | 
| 178 | 
            +
                        "cv_mask": cv_mask,
         | 
| 179 | 
            +
                    }
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    return estimation
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                def evaluate(self, model, dev, progress_bar=False):
         | 
| 184 | 
            +
                    model.eval()
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    # reset results
         | 
| 187 | 
            +
                    self.results = []
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    for idx in tqdm(
         | 
| 190 | 
            +
                        self.ids_subset if self.ids_subset is not None else range(len(self.data)),
         | 
| 191 | 
            +
                        disable=not progress_bar,
         | 
| 192 | 
            +
                    ):
         | 
| 193 | 
            +
                        sample = self.data[idx]
         | 
| 194 | 
            +
                        self.results.append(self.evaluate_sample(model, sample, dev))
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                def get_auc(self, threshold=5, downsampled=False):
         | 
| 197 | 
            +
                    if len(self.results) == 0:
         | 
| 198 | 
            +
                        raise ValueError("No results to log. Run evaluate first.")
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    summary_results = self.calc_auc(downsampled=downsampled)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    return summary_results[f"rel_pose_error@{threshold}°{'__original' if not downsampled else '__downsampled'}"]
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                def plot_results(self, num_samples=10, logger=None, step=None, downsampled=False):
         | 
| 205 | 
            +
                    if len(self.results) == 0:
         | 
| 206 | 
            +
                        raise ValueError("No results to plot. Run evaluate first.")
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    plot_data = []
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    for result in self.results[:num_samples]:
         | 
| 211 | 
            +
                        img1 = cv2.imread(result["path_src_image"])
         | 
| 212 | 
            +
                        img2 = cv2.imread(result["path_trg_image"])
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                        # from BGR to RGB
         | 
| 215 | 
            +
                        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
         | 
| 216 | 
            +
                        img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        plt_matches = cv2.drawMatches(
         | 
| 219 | 
            +
                            img1,
         | 
| 220 | 
            +
                            result["cv_keypoints_src"],
         | 
| 221 | 
            +
                            img2,
         | 
| 222 | 
            +
                            result["cv_keypoints_trg"],
         | 
| 223 | 
            +
                            result["cv_matches"],
         | 
| 224 | 
            +
                            None,
         | 
| 225 | 
            +
                            matchColor=None,
         | 
| 226 | 
            +
                            matchesMask=result["cv_mask"],
         | 
| 227 | 
            +
                            flags=cv2.DrawMatchesFlags_DEFAULT,
         | 
| 228 | 
            +
                        )
         | 
| 229 | 
            +
                        file_name = (
         | 
| 230 | 
            +
                            Path(result["path_src_image"]).parent.parent.name
         | 
| 231 | 
            +
                            + "_"
         | 
| 232 | 
            +
                            + Path(result["path_src_image"]).stem
         | 
| 233 | 
            +
                            + Path(result["path_trg_image"]).stem
         | 
| 234 | 
            +
                            + ("_downsampled" if downsampled else "")
         | 
| 235 | 
            +
                            + ".png"
         | 
| 236 | 
            +
                        )
         | 
| 237 | 
            +
                        # print rel_pose_error on image
         | 
| 238 | 
            +
                        plt_matches = cv2.putText(
         | 
| 239 | 
            +
                            plt_matches,
         | 
| 240 | 
            +
                            f"rel_pose_error: {result['rel_pose_error']:.2f} num_inliers: {result['ransac_inl']} inl_ratio: {result['ransac_inl_ratio']:.2f} num_matches: {len(result['cv_matches'])} num_keypoints: {len(result['cv_keypoints_src'])}/{len(result['cv_keypoints_trg'])}",
         | 
| 241 | 
            +
                            (10, 30),
         | 
| 242 | 
            +
                            cv2.FONT_HERSHEY_SIMPLEX,
         | 
| 243 | 
            +
                            1,
         | 
| 244 | 
            +
                            (0, 0, 0),
         | 
| 245 | 
            +
                            2,
         | 
| 246 | 
            +
                            cv2.LINE_8,
         | 
| 247 | 
            +
                        )
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                        plot_data.append({"file_name": file_name, "image": plt_matches})
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    if logger is None:
         | 
| 252 | 
            +
                        log.info("No logger provided. Using plt to plot results.")
         | 
| 253 | 
            +
                        for image in plot_data:
         | 
| 254 | 
            +
                            plt.imsave(
         | 
| 255 | 
            +
                                image["file_name"],
         | 
| 256 | 
            +
                                cv_resize_and_pad_to_shape(image["image"], (1024, 2048)),
         | 
| 257 | 
            +
                            )
         | 
| 258 | 
            +
                            plt.close()
         | 
| 259 | 
            +
                    else:
         | 
| 260 | 
            +
                        import wandb
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                        log.info(f"Logging images to wandb with step={step}")
         | 
| 263 | 
            +
                        if not downsampled:
         | 
| 264 | 
            +
                            logger.log(
         | 
| 265 | 
            +
                                {
         | 
| 266 | 
            +
                                    "examples": [
         | 
| 267 | 
            +
                                        wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
         | 
| 268 | 
            +
                                    ]
         | 
| 269 | 
            +
                                },
         | 
| 270 | 
            +
                                step=step,
         | 
| 271 | 
            +
                            )
         | 
| 272 | 
            +
                        else:
         | 
| 273 | 
            +
                            logger.log(
         | 
| 274 | 
            +
                                {
         | 
| 275 | 
            +
                                    "examples_downsampled": [
         | 
| 276 | 
            +
                                        wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
         | 
| 277 | 
            +
                                    ]
         | 
| 278 | 
            +
                                },
         | 
| 279 | 
            +
                                step=step,
         | 
| 280 | 
            +
                            )
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                def log_results(self, logger=None, step=None, downsampled=False):
         | 
| 283 | 
            +
                    if len(self.results) == 0:
         | 
| 284 | 
            +
                        raise ValueError("No results to log. Run evaluate first.")
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    summary_results = self.calc_auc(downsampled=downsampled)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    if logger is not None:
         | 
| 289 | 
            +
                        logger.log(summary_results, step=step)
         | 
| 290 | 
            +
                    else:
         | 
| 291 | 
            +
                        log.warning("No logger provided. Printing results instead.")
         | 
| 292 | 
            +
                        print(self.calc_auc())
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                def print_results(self):
         | 
| 295 | 
            +
                    if len(self.results) == 0:
         | 
| 296 | 
            +
                        raise ValueError("No results to print. Run evaluate first.")
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    print(self.calc_auc())
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                def calc_auc(self, auc_thresholds=None, downsampled=False):
         | 
| 301 | 
            +
                    if auc_thresholds is None:
         | 
| 302 | 
            +
                        auc_thresholds = [5, 10, 20]
         | 
| 303 | 
            +
                    if not isinstance(auc_thresholds, list):
         | 
| 304 | 
            +
                        auc_thresholds = [auc_thresholds]
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    if len(self.results) == 0:
         | 
| 307 | 
            +
                        raise ValueError("No results to calculate auc. Run evaluate first.")
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                    rel_pose_errors = [r["rel_pose_error"] for r in self.results]
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    pose_aucs = AUCMetric(auc_thresholds, rel_pose_errors).compute()
         | 
| 312 | 
            +
                    assert isinstance(pose_aucs, list) and len(pose_aucs) == len(auc_thresholds)
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    ext = "_downsampled" if downsampled else "_original"
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    summary = {}
         | 
| 317 | 
            +
                    for i, ath in enumerate(auc_thresholds):
         | 
| 318 | 
            +
                        summary[f"rel_pose_error@{ath}°_{ext}"] = pose_aucs[i]
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    return summary
         | 
    	
        imcui/third_party/RIPE/ripe/data/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        imcui/third_party/RIPE/ripe/data/data_transforms.py
    ADDED
    
    | @@ -0,0 +1,204 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import collections
         | 
| 2 | 
            +
            import collections.abc
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import kornia.geometry as KG
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from torchvision.transforms import functional as TF
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Compose:
         | 
| 11 | 
            +
                """Composes several transforms together. The transforms are applied in the order they are passed in.
         | 
| 12 | 
            +
                Args:        transforms (list): A list of transforms to be applied.
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def __init__(self, transforms):
         | 
| 16 | 
            +
                    self.transforms = transforms
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                def __call__(self, src, trg, src_mask, trg_mask, h):
         | 
| 19 | 
            +
                    for t in self.transforms:
         | 
| 20 | 
            +
                        src, trg, src_mask, trg_mask, h = t(src, trg, src_mask, trg_mask, h)
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    return src, trg, src_mask, trg_mask, h
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class Transform:
         | 
| 26 | 
            +
                """Base class for all transforms. It provides a method to apply a transformation function to the input images and masks.
         | 
| 27 | 
            +
                Args:
         | 
| 28 | 
            +
                    src (torch.Tensor): The source image tensor.
         | 
| 29 | 
            +
                    trg (torch.Tensor): The target image tensor.
         | 
| 30 | 
            +
                    src_mask (torch.Tensor): The source image mask tensor.
         | 
| 31 | 
            +
                    trg_mask (torch.Tensor): The target image mask tensor.
         | 
| 32 | 
            +
                    h (torch.Tensor): The homography matrix tensor.
         | 
| 33 | 
            +
                Returns:
         | 
| 34 | 
            +
                    tuple: A tuple containing the transformed source image, the transformed target image, the transformed source mask,
         | 
| 35 | 
            +
                    the transformed target mask and the updated homography matrix.
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def __init__(self):
         | 
| 39 | 
            +
                    pass
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def apply_transform(self, src, trg, src_mask, trg_mask, h, transfrom_function):
         | 
| 42 | 
            +
                    src, trg, src_mask, trg_mask, h = transfrom_function(src, trg, src_mask, trg_mask, h)
         | 
| 43 | 
            +
                    return src, trg, src_mask, trg_mask, h
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            class Normalize(Transform):
         | 
| 47 | 
            +
                def __init__(self, mean, std):
         | 
| 48 | 
            +
                    self.mean = mean
         | 
| 49 | 
            +
                    self.std = std
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def __call__(self, src, trg, src_mask, trg_mask, h):
         | 
| 52 | 
            +
                    return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def transform_function(self, src, trg, src_mask, trg_mask, h):
         | 
| 55 | 
            +
                    src = TF.normalize(src, mean=self.mean, std=self.std)
         | 
| 56 | 
            +
                    trg = TF.normalize(trg, mean=self.mean, std=self.std)
         | 
| 57 | 
            +
                    return src, trg, src_mask, trg_mask, h
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            class ResizeAndPadWithHomography(Transform):
         | 
| 61 | 
            +
                def __init__(self, target_size_longer_side=768):
         | 
| 62 | 
            +
                    self.target_size = target_size_longer_side
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def __call__(self, src, trg, src_mask, trg_mask, h):
         | 
| 65 | 
            +
                    return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def transform_function(self, src, trg, src_mask, trg_mask, h):
         | 
| 68 | 
            +
                    src_w, src_h = src.shape[-1], src.shape[-2]
         | 
| 69 | 
            +
                    trg_w, trg_h = trg.shape[-1], trg.shape[-2]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    # Resizing logic for both images
         | 
| 72 | 
            +
                    scale_src, new_src_w, new_src_h = self.compute_resize(src_w, src_h)
         | 
| 73 | 
            +
                    scale_trg, new_trg_w, new_trg_h = self.compute_resize(trg_w, trg_h)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # Resize both images
         | 
| 76 | 
            +
                    src_resized = TF.resize(src, [new_src_h, new_src_w])
         | 
| 77 | 
            +
                    trg_resized = TF.resize(trg, [new_trg_h, new_trg_w])
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    src_mask_resized = TF.resize(src_mask, [new_src_h, new_src_w])
         | 
| 80 | 
            +
                    trg_mask_resized = TF.resize(trg_mask, [new_trg_h, new_trg_w])
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # Pad the resized images to be square (768x768)
         | 
| 83 | 
            +
                    src_padded, src_padding = self.apply_padding(src_resized, new_src_w, new_src_h)
         | 
| 84 | 
            +
                    trg_padded, trg_padding = self.apply_padding(trg_resized, new_trg_w, new_trg_h)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    src_mask_padded, _ = self.apply_padding(src_mask_resized, new_src_w, new_src_h)
         | 
| 87 | 
            +
                    trg_mask_padded, _ = self.apply_padding(trg_mask_resized, new_trg_w, new_trg_h)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    # Update the homography matrix
         | 
| 90 | 
            +
                    h = self.update_homography(h, scale_src, src_padding, scale_trg, trg_padding)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    return src_padded, trg_padded, src_mask_padded, trg_mask_padded, h
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def compute_resize(self, w, h):
         | 
| 95 | 
            +
                    if w > h:
         | 
| 96 | 
            +
                        scale = self.target_size / w
         | 
| 97 | 
            +
                        new_w = self.target_size
         | 
| 98 | 
            +
                        new_h = int(h * scale)
         | 
| 99 | 
            +
                    else:
         | 
| 100 | 
            +
                        scale = self.target_size / h
         | 
| 101 | 
            +
                        new_h = self.target_size
         | 
| 102 | 
            +
                        new_w = int(w * scale)
         | 
| 103 | 
            +
                    return scale, new_w, new_h
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                def apply_padding(self, img, new_w, new_h):
         | 
| 106 | 
            +
                    pad_w = (self.target_size - new_w) // 2
         | 
| 107 | 
            +
                    pad_h = (self.target_size - new_h) // 2
         | 
| 108 | 
            +
                    padding = [
         | 
| 109 | 
            +
                        pad_w,
         | 
| 110 | 
            +
                        pad_h,
         | 
| 111 | 
            +
                        self.target_size - new_w - pad_w,
         | 
| 112 | 
            +
                        self.target_size - new_h - pad_h,
         | 
| 113 | 
            +
                    ]
         | 
| 114 | 
            +
                    img_padded = TF.pad(img, padding, fill=0)  # Zero-pad
         | 
| 115 | 
            +
                    return img_padded, padding
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def update_homography(self, h, scale_src, padding_src, scale_trg, padding_trg):
         | 
| 118 | 
            +
                    # Create the scaling matrices
         | 
| 119 | 
            +
                    scale_matrix_src = np.array([[scale_src, 0, 0], [0, scale_src, 0], [0, 0, 1]])
         | 
| 120 | 
            +
                    scale_matrix_trg = np.array([[scale_trg, 0, 0], [0, scale_trg, 0], [0, 0, 1]])
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    # Create the padding translation matrices
         | 
| 123 | 
            +
                    pad_matrix_src = np.array([[1, 0, padding_src[0]], [0, 1, padding_src[1]], [0, 0, 1]])
         | 
| 124 | 
            +
                    pad_matrix_trg = np.array([[1, 0, -padding_trg[0]], [0, 1, -padding_trg[1]], [0, 0, 1]])
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    # Update the homography: apply scaling and translation
         | 
| 127 | 
            +
                    h_updated = (
         | 
| 128 | 
            +
                        pad_matrix_trg
         | 
| 129 | 
            +
                        @ scale_matrix_trg
         | 
| 130 | 
            +
                        @ h.numpy()
         | 
| 131 | 
            +
                        @ np.linalg.inv(scale_matrix_src)
         | 
| 132 | 
            +
                        @ np.linalg.inv(pad_matrix_src)
         | 
| 133 | 
            +
                    )
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    return torch.from_numpy(h_updated).float()
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            class Resize(Transform):
         | 
| 139 | 
            +
                def __init__(self, output_size, edge_divisible_by=None, side="long", antialias=True):
         | 
| 140 | 
            +
                    self.output_size = output_size
         | 
| 141 | 
            +
                    self.edge_divisible_by = edge_divisible_by
         | 
| 142 | 
            +
                    self.side = side
         | 
| 143 | 
            +
                    self.antialias = antialias
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                def __call__(self, src, trg, src_mask, trg_mask, h):
         | 
| 146 | 
            +
                    return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                def transform_function(self, src, trg, src_mask, trg_mask, h):
         | 
| 149 | 
            +
                    new_size_src = self.get_new_image_size(src)
         | 
| 150 | 
            +
                    new_size_trg = self.get_new_image_size(trg)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    src, T_src = self.resize(src, new_size_src)
         | 
| 153 | 
            +
                    trg, T_trg = self.resize(trg, new_size_trg)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    src_mask, _ = self.resize(src_mask, new_size_src)
         | 
| 156 | 
            +
                    trg_mask, _ = self.resize(trg_mask, new_size_trg)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    h = torch.from_numpy(T_trg @ h.numpy() @ T_src).float()
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    return src, trg, src_mask, trg_mask, h
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def resize(self, img, size):
         | 
| 163 | 
            +
                    h, w = img.shape[-2:]
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    img = KG.transform.resize(
         | 
| 166 | 
            +
                        img,
         | 
| 167 | 
            +
                        size,
         | 
| 168 | 
            +
                        side=self.side,
         | 
| 169 | 
            +
                        antialias=self.antialias,
         | 
| 170 | 
            +
                        align_corners=None,
         | 
| 171 | 
            +
                        interpolation="bilinear",
         | 
| 172 | 
            +
                    )
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
         | 
| 175 | 
            +
                    T = np.diag([scale[0].item(), scale[1].item(), 1])
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    return img, T
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def get_new_image_size(self, img):
         | 
| 180 | 
            +
                    h, w = img.shape[-2:]
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    if isinstance(self.output_size, collections.abc.Iterable):
         | 
| 183 | 
            +
                        assert len(self.output_size) == 2
         | 
| 184 | 
            +
                        return tuple(self.output_size)
         | 
| 185 | 
            +
                    if self.output_size is None:  # keep the original size, but possibly make it divisible by edge_divisible_by
         | 
| 186 | 
            +
                        size = (h, w)
         | 
| 187 | 
            +
                    else:
         | 
| 188 | 
            +
                        side_size = self.output_size
         | 
| 189 | 
            +
                        aspect_ratio = w / h
         | 
| 190 | 
            +
                        if self.side not in ("short", "long", "vert", "horz"):
         | 
| 191 | 
            +
                            raise ValueError(f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{self.side}'")
         | 
| 192 | 
            +
                        if self.side == "vert":
         | 
| 193 | 
            +
                            size = side_size, int(side_size * aspect_ratio)
         | 
| 194 | 
            +
                        elif self.side == "horz":
         | 
| 195 | 
            +
                            size = int(side_size / aspect_ratio), side_size
         | 
| 196 | 
            +
                        elif (self.side == "short") ^ (aspect_ratio < 1.0):
         | 
| 197 | 
            +
                            size = side_size, int(side_size * aspect_ratio)
         | 
| 198 | 
            +
                        else:
         | 
| 199 | 
            +
                            size = int(side_size / aspect_ratio), side_size
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    if self.edge_divisible_by is not None:
         | 
| 202 | 
            +
                        df = self.edge_divisible_by
         | 
| 203 | 
            +
                        size = list(map(lambda x: int(x // df * df), size))
         | 
| 204 | 
            +
                    return size
         | 
    	
        imcui/third_party/RIPE/ripe/data/datasets/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        imcui/third_party/RIPE/ripe/data/datasets/acdc.py
    ADDED
    
    | @@ -0,0 +1,154 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from pathlib import Path
         | 
| 2 | 
            +
            from typing import Any, Callable, Dict, Optional
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from torch.utils.data import Dataset
         | 
| 6 | 
            +
            from torchvision.io import read_image
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from ripe import utils
         | 
| 9 | 
            +
            from ripe.data.data_transforms import Compose
         | 
| 10 | 
            +
            from ripe.utils.utils import get_other_random_id
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class ACDC(Dataset):
         | 
| 16 | 
            +
                def __init__(
         | 
| 17 | 
            +
                    self,
         | 
| 18 | 
            +
                    root: Path,
         | 
| 19 | 
            +
                    stage: str = "train",
         | 
| 20 | 
            +
                    condition: str = "rain",
         | 
| 21 | 
            +
                    transforms: Optional[Callable] = None,
         | 
| 22 | 
            +
                    positive_only: bool = False,
         | 
| 23 | 
            +
                ) -> None:
         | 
| 24 | 
            +
                    self.root = root
         | 
| 25 | 
            +
                    self.stage = stage
         | 
| 26 | 
            +
                    self.condition = condition
         | 
| 27 | 
            +
                    self.transforms = transforms
         | 
| 28 | 
            +
                    self.positive_only = positive_only
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    if isinstance(self.root, str):
         | 
| 31 | 
            +
                        self.root = Path(self.root)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    if not self.root.exists():
         | 
| 34 | 
            +
                        raise FileNotFoundError(f"Dataset not found at {self.root}")
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    if transforms is None:
         | 
| 37 | 
            +
                        self.transforms = Compose([])
         | 
| 38 | 
            +
                    else:
         | 
| 39 | 
            +
                        self.transforms = transforms
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    if self.stage not in ["train", "val", "test", "pred"]:
         | 
| 42 | 
            +
                        raise RuntimeError(
         | 
| 43 | 
            +
                            "Unknown option "
         | 
| 44 | 
            +
                            + self.stage
         | 
| 45 | 
            +
                            + " as training stage variable. Valid options: 'train', 'val', 'test' and 'pred'"
         | 
| 46 | 
            +
                        )
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    if self.stage == "pred":  # prediction uses the test set
         | 
| 49 | 
            +
                        self.stage = "test"
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    if self.stage in ["val", "test", "pred"]:
         | 
| 52 | 
            +
                        self.positive_only = True
         | 
| 53 | 
            +
                        log.info(f"{self.stage} stage: Using only positive pairs!")
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    weather_conditions = ["fog", "night", "rain", "snow"]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    if self.condition not in weather_conditions + ["all"]:
         | 
| 58 | 
            +
                        raise RuntimeError(
         | 
| 59 | 
            +
                            "Unknown option "
         | 
| 60 | 
            +
                            + self.condition
         | 
| 61 | 
            +
                            + " as weather condition variable. Valid options: 'fog', 'night', 'rain', 'snow' and 'all'"
         | 
| 62 | 
            +
                        )
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    self.weather_condition_query = weather_conditions if self.condition == "all" else [self.condition]
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self._read_sample_files()
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    if positive_only:
         | 
| 69 | 
            +
                        log.warning("Using only positive pairs!")
         | 
| 70 | 
            +
                    log.info(f"Found {len(self.src_images)} source images and {len(self.trg_images)} target images.")
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def _read_sample_files(self):
         | 
| 73 | 
            +
                    file_name_pattern_ref = "_ref_anon.png"
         | 
| 74 | 
            +
                    file_name_pattern = "_rgb_anon.png"
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    self.trg_images = []
         | 
| 77 | 
            +
                    self.src_images = []
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    for weather_condition in self.weather_condition_query:
         | 
| 80 | 
            +
                        rgb_files = sorted(
         | 
| 81 | 
            +
                            list(self.root.glob("rgb_anon/" + weather_condition + "/" + self.stage + "/**/*" + file_name_pattern)),
         | 
| 82 | 
            +
                            key=lambda i: i.stem[:21],
         | 
| 83 | 
            +
                        )
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                        src_images = sorted(
         | 
| 86 | 
            +
                            list(
         | 
| 87 | 
            +
                                self.root.glob(
         | 
| 88 | 
            +
                                    "rgb_anon/" + weather_condition + "/" + self.stage + "_ref" + "/**/*" + file_name_pattern_ref
         | 
| 89 | 
            +
                                )
         | 
| 90 | 
            +
                            ),
         | 
| 91 | 
            +
                            key=lambda i: i.stem[:21],
         | 
| 92 | 
            +
                        )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                        self.trg_images += rgb_files
         | 
| 95 | 
            +
                        self.src_images += src_images
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                def __len__(self) -> int:
         | 
| 98 | 
            +
                    if self.positive_only:
         | 
| 99 | 
            +
                        return len(self.trg_images)
         | 
| 100 | 
            +
                    return 2 * len(self.trg_images)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def __getitem__(self, idx: int) -> Dict[str, Any]:
         | 
| 103 | 
            +
                    sample: Any = {}
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    positive_sample = (idx % 2 == 0) or (self.positive_only)
         | 
| 106 | 
            +
                    if not self.positive_only:
         | 
| 107 | 
            +
                        idx = idx // 2
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    sample["label"] = positive_sample
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    if positive_sample:
         | 
| 112 | 
            +
                        sample["src_path"] = str(self.src_images[idx])
         | 
| 113 | 
            +
                        sample["trg_path"] = str(self.trg_images[idx])
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                        assert self.src_images[idx].stem[:21] == self.trg_images[idx].stem[:21], (
         | 
| 116 | 
            +
                            f"Source and target image mismatch: {self.src_images[idx]} vs {self.trg_images[idx]}"
         | 
| 117 | 
            +
                        )
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        src_img = read_image(sample["src_path"])
         | 
| 120 | 
            +
                        trg_img = read_image(sample["trg_path"])
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                        homography = torch.eye(3, dtype=torch.float32)
         | 
| 123 | 
            +
                    else:
         | 
| 124 | 
            +
                        sample["src_path"] = str(self.src_images[idx])
         | 
| 125 | 
            +
                        idx_other = get_other_random_id(idx, len(self) // 2)
         | 
| 126 | 
            +
                        sample["trg_path"] = str(self.trg_images[idx_other])
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                        assert self.src_images[idx].stem[:21] != self.trg_images[idx_other].stem[:21], (
         | 
| 129 | 
            +
                            f"Source and target image match for negative sample: {self.src_images[idx]} vs {self.trg_images[idx_other]}"
         | 
| 130 | 
            +
                        )
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                        src_img = read_image(sample["src_path"])
         | 
| 133 | 
            +
                        trg_img = read_image(sample["trg_path"])
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                        homography = torch.zeros((3, 3), dtype=torch.float32)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    src_img = src_img / 255.0
         | 
| 138 | 
            +
                    trg_img = trg_img / 255.0
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    _, H, W = src_img.shape
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    src_mask = torch.ones((1, H, W), dtype=torch.uint8)
         | 
| 143 | 
            +
                    trg_mask = torch.ones((1, H, W), dtype=torch.uint8)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    if self.transforms:
         | 
| 146 | 
            +
                        src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    sample["src_image"] = src_img
         | 
| 149 | 
            +
                    sample["trg_image"] = trg_img
         | 
| 150 | 
            +
                    sample["src_mask"] = src_mask.to(torch.bool)
         | 
| 151 | 
            +
                    sample["trg_mask"] = trg_mask.to(torch.bool)
         | 
| 152 | 
            +
                    sample["homography"] = homography
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    return sample
         | 
    	
        imcui/third_party/RIPE/ripe/data/datasets/dataset_combinator.py
    ADDED
    
    | @@ -0,0 +1,88 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from ripe import utils
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class DatasetCombinator:
         | 
| 9 | 
            +
                """Combines multiple datasets into one. Length of the combined dataset is the length of the
         | 
| 10 | 
            +
                longest dataset. Shorter datasets are looped over.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                Args:
         | 
| 13 | 
            +
                    datasets: List of datasets to combine.
         | 
| 14 | 
            +
                    mode: How to sample from the datasets. Can be either "uniform" or "weighted".
         | 
| 15 | 
            +
                        In "uniform" mode, each dataset is sampled with equal probability.
         | 
| 16 | 
            +
                        In "weighted" mode, each dataset is sampled with probability proportional to its length.
         | 
| 17 | 
            +
                """
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                def __init__(self, datasets, mode="uniform", weights=None):
         | 
| 20 | 
            +
                    self.datasets = datasets
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    names_datasets = [type(ds).__name__ for ds in self.datasets]
         | 
| 23 | 
            +
                    self.lengths = [len(ds) for ds in datasets]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    if mode == "weighted":
         | 
| 26 | 
            +
                        self.probs_datasets = [length / sum(self.lengths) for length in self.lengths]
         | 
| 27 | 
            +
                    elif mode == "uniform":
         | 
| 28 | 
            +
                        self.probs_datasets = [1 / len(self.datasets) for _ in self.datasets]
         | 
| 29 | 
            +
                    elif mode == "custom":
         | 
| 30 | 
            +
                        assert weights is not None, "Weights must be provided in custom mode"
         | 
| 31 | 
            +
                        assert len(weights) == len(datasets), "Number of weights must match number of datasets"
         | 
| 32 | 
            +
                        assert sum(weights) == 1.0, "Weights must sum to 1"
         | 
| 33 | 
            +
                        self.probs_datasets = weights
         | 
| 34 | 
            +
                    else:
         | 
| 35 | 
            +
                        raise ValueError(f"Unknown mode {mode}")
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    log.info("Got the following datasets: ")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    for name, length, prob in zip(names_datasets, self.lengths, self.probs_datasets):
         | 
| 40 | 
            +
                        log.info(f"{name} with {length} samples and probability {prob}")
         | 
| 41 | 
            +
                    log.info(f"Total number of samples: {sum(self.lengths)}")
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    self.num_samples = max(self.lengths)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    self.dataset_dist = torch.distributions.Categorical(probs=torch.tensor(self.probs_datasets))
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def __len__(self):
         | 
| 48 | 
            +
                    return self.num_samples
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                def __getitem__(self, idx: int):
         | 
| 51 | 
            +
                    positive_sample = idx % 2 == 0
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    if positive_sample:
         | 
| 54 | 
            +
                        dataset_idx = self.dataset_dist.sample().item()
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                        idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
         | 
| 57 | 
            +
                        while idx % 2 == 1:
         | 
| 58 | 
            +
                            idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                        return self.datasets[dataset_idx][idx]
         | 
| 61 | 
            +
                    else:
         | 
| 62 | 
            +
                        dataset_idx_1 = self.dataset_dist.sample().item()
         | 
| 63 | 
            +
                        dataset_idx_2 = self.dataset_dist.sample().item()
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                        if dataset_idx_1 == dataset_idx_2:
         | 
| 66 | 
            +
                            idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
         | 
| 67 | 
            +
                            while idx % 2 == 0:
         | 
| 68 | 
            +
                                idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
         | 
| 69 | 
            +
                            return self.datasets[dataset_idx_1][idx]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                        else:
         | 
| 72 | 
            +
                            idx_1 = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
         | 
| 73 | 
            +
                            idx_2 = torch.randint(0, self.lengths[dataset_idx_2], (1,)).item()
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                            sample_1 = self.datasets[dataset_idx_1][idx_1]
         | 
| 76 | 
            +
                            sample_2 = self.datasets[dataset_idx_2][idx_2]
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                            sample = {
         | 
| 79 | 
            +
                                "label": False,
         | 
| 80 | 
            +
                                "src_path": sample_1["src_path"],
         | 
| 81 | 
            +
                                "trg_path": sample_2["trg_path"],
         | 
| 82 | 
            +
                                "src_image": sample_1["src_image"],
         | 
| 83 | 
            +
                                "trg_image": sample_2["trg_image"],
         | 
| 84 | 
            +
                                "src_mask": sample_1["src_mask"],
         | 
| 85 | 
            +
                                "trg_mask": sample_2["trg_mask"],
         | 
| 86 | 
            +
                                "homography": sample_2["homography"],
         | 
| 87 | 
            +
                            }
         | 
| 88 | 
            +
                            return sample
         | 
    	
        imcui/third_party/RIPE/ripe/data/datasets/disk_imw.py
    ADDED
    
    | @@ -0,0 +1,160 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
            from itertools import accumulate
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            from typing import Any, Callable, Dict, Optional, Tuple
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from torch.utils.data import Dataset
         | 
| 9 | 
            +
            from torchvision.io import read_image
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from ripe import utils
         | 
| 12 | 
            +
            from ripe.data.data_transforms import Compose
         | 
| 13 | 
            +
            from ripe.utils.image_utils import Camera, cameras2F
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class DISK_IMW(Dataset):
         | 
| 19 | 
            +
                def __init__(
         | 
| 20 | 
            +
                    self,
         | 
| 21 | 
            +
                    root: str,
         | 
| 22 | 
            +
                    stage: str = "val",
         | 
| 23 | 
            +
                    # condition: str = "rain",
         | 
| 24 | 
            +
                    transforms: Optional[Callable] = None,
         | 
| 25 | 
            +
                ) -> None:
         | 
| 26 | 
            +
                    self.root = root
         | 
| 27 | 
            +
                    self.stage = stage
         | 
| 28 | 
            +
                    self.transforms = transforms
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    if isinstance(self.root, str):
         | 
| 31 | 
            +
                        self.root = Path(self.root)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    if not self.root.exists():
         | 
| 34 | 
            +
                        raise FileNotFoundError(f"Dataset not found at {self.root}")
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    if transforms is None:
         | 
| 37 | 
            +
                        self.transforms = Compose([])
         | 
| 38 | 
            +
                    else:
         | 
| 39 | 
            +
                        self.transforms = transforms
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    if self.stage not in ["val"]:
         | 
| 42 | 
            +
                        raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'train'")
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    json_path = self.root / "imw2020-val" / "dataset.json"
         | 
| 45 | 
            +
                    with open(json_path) as json_file:
         | 
| 46 | 
            +
                        json_data = json.load(json_file)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    self.scenes = []
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    for scene in json_data:
         | 
| 51 | 
            +
                        self.scenes.append(Scene(self.root / "imw2020-val", json_data[scene]))
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    self.tuples_per_scene = [len(scene) for scene in self.scenes]
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def __len__(self) -> int:
         | 
| 56 | 
            +
                    return sum(self.tuples_per_scene)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __getitem__(self, idx: int) -> Dict[str, Any]:
         | 
| 59 | 
            +
                    sample: Any = {}
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    sample["src_path"], sample["trg_path"], path_calib_src, path_calib_trg = self.scenes[i_scene][i_image]
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    cam_src = Camera.from_calibration_file(path_calib_src)
         | 
| 66 | 
            +
                    cam_trg = Camera.from_calibration_file(path_calib_trg)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    F = self.get_F(cam_src, cam_trg)
         | 
| 69 | 
            +
                    s2t_R, s2t_T = self.get_relative_pose(cam_src, cam_trg)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    src_img = read_image(sample["src_path"]) / 255.0
         | 
| 72 | 
            +
                    trg_img = read_image(sample["trg_path"]) / 255.0
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    _, H_src, W_src = src_img.shape
         | 
| 75 | 
            +
                    _, H_trg, W_trg = trg_img.shape
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
         | 
| 78 | 
            +
                    trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    H = torch.eye(3)
         | 
| 81 | 
            +
                    if self.transforms:
         | 
| 82 | 
            +
                        src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, H)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    # check if transformations in self.transforms. Only Normalize is allowed
         | 
| 85 | 
            +
                    for t in self.transforms.transforms:
         | 
| 86 | 
            +
                        if t.__class__.__name__ not in ["Normalize", "Resize"]:
         | 
| 87 | 
            +
                            raise ValueError(f"Transform {t.__class__.__name__} not allowed in DISK_IMW dataset")
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    sample["src_image"] = src_img
         | 
| 90 | 
            +
                    sample["trg_image"] = trg_img
         | 
| 91 | 
            +
                    sample["orig_size_src"] = (H_src, W_src)
         | 
| 92 | 
            +
                    sample["orig_size_trg"] = (H_trg, W_trg)
         | 
| 93 | 
            +
                    sample["src_mask"] = src_mask.to(torch.bool)
         | 
| 94 | 
            +
                    sample["trg_mask"] = trg_mask.to(torch.bool)
         | 
| 95 | 
            +
                    sample["F"] = F
         | 
| 96 | 
            +
                    sample["s2t_R"] = s2t_R
         | 
| 97 | 
            +
                    sample["s2t_T"] = s2t_T
         | 
| 98 | 
            +
                    sample["src_camera"] = cam_src
         | 
| 99 | 
            +
                    sample["trg_camera"] = cam_trg
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    return sample
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def get_relative_pose(self, cam_src: Camera, cam_trg: Camera) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 104 | 
            +
                    R = cam_trg.R @ cam_src.R.T
         | 
| 105 | 
            +
                    T = cam_trg.t - R @ cam_src.t
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    return R, T
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def get_F(self, cam_src: Camera, cam_trg: Camera) -> torch.Tensor:
         | 
| 110 | 
            +
                    F = cameras2F(cam_src, cam_trg)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    return F
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
         | 
| 115 | 
            +
                    accumulated_tuples = accumulate(self.tuples_per_scene)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    if idx >= sum(self.tuples_per_scene):
         | 
| 118 | 
            +
                        raise IndexError(f"Index {idx} out of bounds")
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    idx_scene = None
         | 
| 121 | 
            +
                    for i, accumulated_tuple in enumerate(accumulated_tuples):
         | 
| 122 | 
            +
                        idx_scene = i
         | 
| 123 | 
            +
                        if idx < accumulated_tuple:
         | 
| 124 | 
            +
                            break
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    idx_image = idx - sum(self.tuples_per_scene[:idx_scene])
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    return idx_scene, idx_image
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
         | 
| 131 | 
            +
                    possible_scene_ids = list(range(len(self.scenes)))
         | 
| 132 | 
            +
                    possible_scene_ids.remove(scene_id_to_exclude)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    idx_scene = random.choice(possible_scene_ids)
         | 
| 135 | 
            +
                    idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    return idx_scene, idx_image
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            class Scene:
         | 
| 141 | 
            +
                def __init__(self, root_path, scene_data: Dict[str, Any]) -> None:
         | 
| 142 | 
            +
                    self.root_path = root_path
         | 
| 143 | 
            +
                    self.image_path = Path(scene_data["image_path"])
         | 
| 144 | 
            +
                    self.calib_path = Path(scene_data["calib_path"])
         | 
| 145 | 
            +
                    self.image_names = scene_data["images"]
         | 
| 146 | 
            +
                    self.tuples = scene_data["tuples"]
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                def __len__(self) -> int:
         | 
| 149 | 
            +
                    return len(self.tuples)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def __getitem__(self, idx: int) -> Dict[str, Any]:
         | 
| 152 | 
            +
                    idx_1 = self.tuples[idx][0]
         | 
| 153 | 
            +
                    idx_2 = self.tuples[idx][1]
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1]) + ".jpg"
         | 
| 156 | 
            +
                    path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2]) + ".jpg"
         | 
| 157 | 
            +
                    path_calib_1 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_1])) + ".h5"
         | 
| 158 | 
            +
                    path_calib_2 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_2])) + ".h5"
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    return path_image_1, path_image_2, path_calib_1, path_calib_2
         | 
    	
        imcui/third_party/RIPE/ripe/data/datasets/disk_megadepth.py
    ADDED
    
    | @@ -0,0 +1,157 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
            from itertools import accumulate
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            from typing import Any, Callable, Dict, Optional, Tuple
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from torch.utils.data import Dataset
         | 
| 9 | 
            +
            from torchvision.io import read_image
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from ripe import utils
         | 
| 12 | 
            +
            from ripe.data.data_transforms import Compose
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class DISK_Megadepth(Dataset):
         | 
| 18 | 
            +
                def __init__(
         | 
| 19 | 
            +
                    self,
         | 
| 20 | 
            +
                    root: str,
         | 
| 21 | 
            +
                    max_scene_size: int,
         | 
| 22 | 
            +
                    stage: str = "train",
         | 
| 23 | 
            +
                    # condition: str = "rain",
         | 
| 24 | 
            +
                    transforms: Optional[Callable] = None,
         | 
| 25 | 
            +
                    positive_only: bool = False,
         | 
| 26 | 
            +
                ) -> None:
         | 
| 27 | 
            +
                    self.root = root
         | 
| 28 | 
            +
                    self.stage = stage
         | 
| 29 | 
            +
                    self.transforms = transforms
         | 
| 30 | 
            +
                    self.positive_only = positive_only
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    if isinstance(self.root, str):
         | 
| 33 | 
            +
                        self.root = Path(self.root)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    if not self.root.exists():
         | 
| 36 | 
            +
                        raise FileNotFoundError(f"Dataset not found at {self.root}")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    if transforms is None:
         | 
| 39 | 
            +
                        self.transforms = Compose([])
         | 
| 40 | 
            +
                    else:
         | 
| 41 | 
            +
                        self.transforms = transforms
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    if self.stage not in ["train"]:
         | 
| 44 | 
            +
                        raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'train'")
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    json_path = self.root / "megadepth" / "dataset.json"
         | 
| 47 | 
            +
                    with open(json_path) as json_file:
         | 
| 48 | 
            +
                        json_data = json.load(json_file)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    self.scenes = []
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    for scene in json_data:
         | 
| 53 | 
            +
                        self.scenes.append(Scene(self.root / "megadepth", json_data[scene], max_scene_size))
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    self.tuples_per_scene = [len(scene) for scene in self.scenes]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    if positive_only:
         | 
| 58 | 
            +
                        log.warning("Using only positive pairs!")
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __len__(self) -> int:
         | 
| 61 | 
            +
                    if self.positive_only:
         | 
| 62 | 
            +
                        return sum(self.tuples_per_scene)
         | 
| 63 | 
            +
                    return 2 * sum(self.tuples_per_scene)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def __getitem__(self, idx: int) -> Dict[str, Any]:
         | 
| 66 | 
            +
                    sample: Any = {}
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    positive_sample = idx % 2 == 0 or self.positive_only
         | 
| 69 | 
            +
                    if not self.positive_only:
         | 
| 70 | 
            +
                        idx = idx // 2
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    sample["label"] = positive_sample
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    if positive_sample:
         | 
| 77 | 
            +
                        sample["src_path"], sample["trg_path"] = self.scenes[i_scene][i_image]
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                        homography = torch.eye(3, dtype=torch.float32)
         | 
| 80 | 
            +
                    else:
         | 
| 81 | 
            +
                        sample["src_path"], _ = self.scenes[i_scene][i_image]
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                        i_scene_other, i_image_other = self._get_other_random_scene_and_image_id(i_scene)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                        sample["trg_path"], _ = self.scenes[i_scene_other][i_image_other]
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                        homography = torch.zeros((3, 3), dtype=torch.float32)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    src_img = read_image(sample["src_path"]) / 255.0
         | 
| 90 | 
            +
                    trg_img = read_image(sample["trg_path"]) / 255.0
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    _, H_src, W_src = src_img.shape
         | 
| 93 | 
            +
                    _, H_trg, W_trg = trg_img.shape
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
         | 
| 96 | 
            +
                    trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    if self.transforms:
         | 
| 99 | 
            +
                        src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    sample["src_image"] = src_img
         | 
| 102 | 
            +
                    sample["trg_image"] = trg_img
         | 
| 103 | 
            +
                    sample["src_mask"] = src_mask.to(torch.bool)
         | 
| 104 | 
            +
                    sample["trg_mask"] = trg_mask.to(torch.bool)
         | 
| 105 | 
            +
                    sample["homography"] = homography
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    return sample
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
         | 
| 110 | 
            +
                    accumulated_tuples = accumulate(self.tuples_per_scene)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    if idx >= sum(self.tuples_per_scene):
         | 
| 113 | 
            +
                        raise IndexError(f"Index {idx} out of bounds")
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    idx_scene = None
         | 
| 116 | 
            +
                    for i, accumulated_tuple in enumerate(accumulated_tuples):
         | 
| 117 | 
            +
                        idx_scene = i
         | 
| 118 | 
            +
                        if idx < accumulated_tuple:
         | 
| 119 | 
            +
                            break
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    idx_image = idx - sum(self.tuples_per_scene[:idx_scene])
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    return idx_scene, idx_image
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
         | 
| 126 | 
            +
                    possible_scene_ids = list(range(len(self.scenes)))
         | 
| 127 | 
            +
                    possible_scene_ids.remove(scene_id_to_exclude)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    idx_scene = random.choice(possible_scene_ids)
         | 
| 130 | 
            +
                    idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    return idx_scene, idx_image
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            class Scene:
         | 
| 136 | 
            +
                def __init__(self, root_path, scene_data: Dict[str, Any], max_size_scene) -> None:
         | 
| 137 | 
            +
                    self.root_path = root_path
         | 
| 138 | 
            +
                    self.image_path = Path(scene_data["image_path"])
         | 
| 139 | 
            +
                    self.image_names = scene_data["images"]
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    # randomly sample tuples
         | 
| 142 | 
            +
                    if max_size_scene > 0:
         | 
| 143 | 
            +
                        self.tuples = random.sample(scene_data["tuples"], min(max_size_scene, len(scene_data["tuples"])))
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                def __len__(self) -> int:
         | 
| 146 | 
            +
                    return len(self.tuples)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                def __getitem__(self, idx: int) -> Tuple[str, str]:
         | 
| 149 | 
            +
                    idx_1, idx_2 = random.sample([0, 1, 2], 2)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    idx_1 = self.tuples[idx][idx_1]
         | 
| 152 | 
            +
                    idx_2 = self.tuples[idx][idx_2]
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1])
         | 
| 155 | 
            +
                    path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2])
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    return path_image_1, path_image_2
         | 
    	
        imcui/third_party/RIPE/ripe/data/datasets/tokyo247.py
    ADDED
    
    | @@ -0,0 +1,134 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
            from glob import glob
         | 
| 4 | 
            +
            from typing import Any, Callable, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from torch.utils.data import Dataset
         | 
| 8 | 
            +
            from torchvision.io import read_image
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from ripe import utils
         | 
| 11 | 
            +
            from ripe.data.data_transforms import Compose
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class Tokyo247(Dataset):
         | 
| 17 | 
            +
                def __init__(
         | 
| 18 | 
            +
                    self,
         | 
| 19 | 
            +
                    root: str,
         | 
| 20 | 
            +
                    stage: str = "train",
         | 
| 21 | 
            +
                    transforms: Optional[Callable] = None,
         | 
| 22 | 
            +
                    positive_only: bool = False,
         | 
| 23 | 
            +
                ):
         | 
| 24 | 
            +
                    if stage != "train":
         | 
| 25 | 
            +
                        raise ValueError("Tokyo247Dataset only supports the 'train' stage.")
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    # check if the root directory exists
         | 
| 28 | 
            +
                    if not os.path.isdir(root):
         | 
| 29 | 
            +
                        raise FileNotFoundError(f"Directory {root} does not exist.")
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    self.root_dir = root
         | 
| 32 | 
            +
                    self.transforms = transforms if transforms is not None else Compose([])
         | 
| 33 | 
            +
                    self.positive_only = positive_only
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    self.image_paths = []
         | 
| 36 | 
            +
                    self.positive_pairs = []
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    # Collect images grouped by location folder
         | 
| 39 | 
            +
                    self.locations = {}
         | 
| 40 | 
            +
                    for location_rough in sorted(os.listdir(self.root_dir)):
         | 
| 41 | 
            +
                        location_rough_path = os.path.join(self.root_dir, location_rough)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                        # check if the location_rough_path is a directory
         | 
| 44 | 
            +
                        if not os.path.isdir(location_rough_path):
         | 
| 45 | 
            +
                            continue
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                        for location_fine in sorted(os.listdir(location_rough_path)):
         | 
| 48 | 
            +
                            location_fine_path = os.path.join(self.root_dir, location_rough, location_fine)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                            if os.path.isdir(location_fine_path):
         | 
| 51 | 
            +
                                images = sorted(
         | 
| 52 | 
            +
                                    glob(os.path.join(location_fine_path, "*.png")),
         | 
| 53 | 
            +
                                    key=lambda i: int(i[-7:-4]),
         | 
| 54 | 
            +
                                )
         | 
| 55 | 
            +
                                if len(images) >= 12:
         | 
| 56 | 
            +
                                    self.locations[location_fine] = images
         | 
| 57 | 
            +
                                    self.image_paths.extend(images)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    # Generate positive pairs
         | 
| 60 | 
            +
                    for _, images in self.locations.items():
         | 
| 61 | 
            +
                        for i in range(len(images) - 1):
         | 
| 62 | 
            +
                            self.positive_pairs.append((images[i], images[i + 1]))
         | 
| 63 | 
            +
                        self.positive_pairs.append((images[-1], images[0]))
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    if positive_only:
         | 
| 66 | 
            +
                        log.warning("Using only positive pairs!")
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    log.info(f"Found {len(self.positive_pairs)} image pairs.")
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def __len__(self):
         | 
| 71 | 
            +
                    if self.positive_only:
         | 
| 72 | 
            +
                        return len(self.positive_pairs)
         | 
| 73 | 
            +
                    return 2 * len(self.positive_pairs)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def __getitem__(self, idx):
         | 
| 76 | 
            +
                    sample: Any = {}
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    positive_sample = (idx % 2 == 0) or (self.positive_only)
         | 
| 79 | 
            +
                    if not self.positive_only:
         | 
| 80 | 
            +
                        idx = idx // 2
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    sample["label"] = positive_sample
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    if positive_sample:  # Positive pair
         | 
| 85 | 
            +
                        img1_path, img2_path = self.positive_pairs[idx]
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                        assert os.path.dirname(img1_path) == os.path.dirname(img2_path), (
         | 
| 88 | 
            +
                            f"Source and target image mismatch: {img1_path} vs {img2_path}"
         | 
| 89 | 
            +
                        )
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                        homography = torch.eye(3, dtype=torch.float32)
         | 
| 92 | 
            +
                    else:  # Negative pair
         | 
| 93 | 
            +
                        img1_path = random.choice(self.image_paths)
         | 
| 94 | 
            +
                        img2_path = random.choice(self.image_paths)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                        # Ensure images are from different folders
         | 
| 97 | 
            +
                        esc = 0
         | 
| 98 | 
            +
                        while os.path.dirname(img1_path) == os.path.dirname(img2_path):
         | 
| 99 | 
            +
                            img2_path = random.choice(self.image_paths)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                            esc += 1
         | 
| 102 | 
            +
                            if esc > 100:
         | 
| 103 | 
            +
                                raise RuntimeError("Could not find a negative pair.")
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                        assert os.path.dirname(img1_path) != os.path.dirname(img2_path), (
         | 
| 106 | 
            +
                            f"Source and target image match for negative pair: {img1_path} vs {img2_path}"
         | 
| 107 | 
            +
                        )
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                        homography = torch.zeros((3, 3), dtype=torch.float32)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    sample["src_path"] = img1_path
         | 
| 112 | 
            +
                    sample["trg_path"] = img2_path
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    # Load images
         | 
| 115 | 
            +
                    src_img = read_image(sample["src_path"]) / 255.0
         | 
| 116 | 
            +
                    trg_img = read_image(sample["trg_path"]) / 255.0
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    _, H_src, W_src = src_img.shape
         | 
| 119 | 
            +
                    _, H_trg, W_trg = src_img.shape
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
         | 
| 122 | 
            +
                    trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    # Apply transformations
         | 
| 125 | 
            +
                    if self.transforms:
         | 
| 126 | 
            +
                        src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    sample["src_image"] = src_img
         | 
| 129 | 
            +
                    sample["trg_image"] = trg_img
         | 
| 130 | 
            +
                    sample["src_mask"] = src_mask.to(torch.bool)
         | 
| 131 | 
            +
                    sample["trg_mask"] = trg_mask.to(torch.bool)
         | 
| 132 | 
            +
                    sample["homography"] = homography
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    return sample
         | 
    	
        imcui/third_party/RIPE/ripe/losses/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        imcui/third_party/RIPE/ripe/losses/contrastive_loss.py
    ADDED
    
    | @@ -0,0 +1,88 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def second_nearest_neighbor(desc1, desc2):
         | 
| 7 | 
            +
                if desc2.shape[0] < 2:  # We cannot perform snn check, so output empty matches
         | 
| 8 | 
            +
                    raise ValueError("desc2 should have at least 2 descriptors")
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                dist = torch.cdist(desc1, desc2, p=2)
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                vals, idxs = torch.topk(dist, 2, dim=1, largest=False)
         | 
| 13 | 
            +
                idxs_in_2 = idxs[:, 1]
         | 
| 14 | 
            +
                idxs_in_1 = torch.arange(0, idxs_in_2.size(0), device=dist.device)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                matches_idxs = torch.cat([idxs_in_1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                return vals[:, 1].view(-1, 1), matches_idxs
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def contrastive_loss(
         | 
| 22 | 
            +
                desc1,
         | 
| 23 | 
            +
                desc2,
         | 
| 24 | 
            +
                matches,
         | 
| 25 | 
            +
                inliers,
         | 
| 26 | 
            +
                label,
         | 
| 27 | 
            +
                logits_1,
         | 
| 28 | 
            +
                logits_2,
         | 
| 29 | 
            +
                pos_margin=1.0,
         | 
| 30 | 
            +
                neg_margin=1.0,
         | 
| 31 | 
            +
            ):
         | 
| 32 | 
            +
                if inliers.sum() < 8:  # if there are too few inliers, calculate loss on all matches
         | 
| 33 | 
            +
                    inliers = torch.ones_like(inliers)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                matched_inliers_descs1 = desc1[matches[:, 0][inliers]]
         | 
| 36 | 
            +
                matched_inliers_descs2 = desc2[matches[:, 1][inliers]]
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                if logits_1 is not None and logits_2 is not None:
         | 
| 39 | 
            +
                    matched_inliers_logits1 = logits_1[matches[:, 0][inliers]]
         | 
| 40 | 
            +
                    matched_inliers_logits2 = logits_2[matches[:, 1][inliers]]
         | 
| 41 | 
            +
                    logits = torch.minimum(matched_inliers_logits1, matched_inliers_logits2)
         | 
| 42 | 
            +
                else:
         | 
| 43 | 
            +
                    logits = torch.ones_like(matches[:, 0][inliers])
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                if label:
         | 
| 46 | 
            +
                    snn_match_dists_1, idx1 = second_nearest_neighbor(matched_inliers_descs1, desc2)
         | 
| 47 | 
            +
                    snn_match_dists_2, idx2 = second_nearest_neighbor(matched_inliers_descs2, desc1)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    dists = torch.hstack((snn_match_dists_1, snn_match_dists_2))
         | 
| 50 | 
            +
                    min_dists_idx = torch.min(dists, dim=1).indices.unsqueeze(1)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    dists_hard = torch.gather(dists, 1, min_dists_idx).squeeze(-1)
         | 
| 53 | 
            +
                    dists_pos = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    contrastive_loss = torch.clamp(pos_margin + dists_pos - dists_hard, min=0.0)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    contrastive_loss = contrastive_loss * logits
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8)  # small epsilon to avoid division by zero
         | 
| 60 | 
            +
                else:
         | 
| 61 | 
            +
                    dists = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
         | 
| 62 | 
            +
                    contrastive_loss = torch.clamp(neg_margin - dists, min=0.0)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    contrastive_loss = contrastive_loss * logits
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8)  # small epsilon to avoid division by zero
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                return contrastive_loss
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            class ContrastiveLoss(nn.Module):
         | 
| 72 | 
            +
                def __init__(self, pos_margin=1.0, neg_margin=1.0):
         | 
| 73 | 
            +
                    super().__init__()
         | 
| 74 | 
            +
                    self.pos_margin = pos_margin
         | 
| 75 | 
            +
                    self.neg_margin = neg_margin
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def forward(self, desc1, desc2, matches, inliers, label, logits_1=None, logits_2=None):
         | 
| 78 | 
            +
                    return contrastive_loss(
         | 
| 79 | 
            +
                        desc1,
         | 
| 80 | 
            +
                        desc2,
         | 
| 81 | 
            +
                        matches,
         | 
| 82 | 
            +
                        inliers,
         | 
| 83 | 
            +
                        label,
         | 
| 84 | 
            +
                        logits_1,
         | 
| 85 | 
            +
                        logits_2,
         | 
| 86 | 
            +
                        self.pos_margin,
         | 
| 87 | 
            +
                        self.neg_margin,
         | 
| 88 | 
            +
                    )
         | 
    	
        imcui/third_party/RIPE/ripe/matcher/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        imcui/third_party/RIPE/ripe/matcher/concurrent_matcher.py
    ADDED
    
    | @@ -0,0 +1,97 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import concurrent.futures
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class ConcurrentMatcher:
         | 
| 7 | 
            +
                """A class that performs matching and geometric filtering in parallel using a thread pool executor.
         | 
| 8 | 
            +
                It matches keypoints from two sets of descriptors and applies a robust estimator to filter the matches based on geometric constraints.
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                Args:
         | 
| 11 | 
            +
                    matcher (callable): A callable that takes two sets of descriptors and returns distances and indices of matches.
         | 
| 12 | 
            +
                    robust_estimator (callable): A callable that estimates a geometric transformation and returns inliers.
         | 
| 13 | 
            +
                    min_num_matches (int, optional): Minimum number of matches required to perform geometric filtering. Defaults to 8.
         | 
| 14 | 
            +
                    max_workers (int, optional): Maximum number of threads in the thread pool executor. Defaults to 12.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __init__(self, matcher, robust_estimator, min_num_matches=8, max_workers=12):
         | 
| 18 | 
            +
                    self.matcher = matcher
         | 
| 19 | 
            +
                    self.robust_estimator = robust_estimator
         | 
| 20 | 
            +
                    self.min_num_matches = min_num_matches
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                @torch.no_grad()
         | 
| 25 | 
            +
                def __call__(
         | 
| 26 | 
            +
                    self,
         | 
| 27 | 
            +
                    kpts1,
         | 
| 28 | 
            +
                    kpts2,
         | 
| 29 | 
            +
                    pdesc1,
         | 
| 30 | 
            +
                    pdesc2,
         | 
| 31 | 
            +
                    selected_mask1,
         | 
| 32 | 
            +
                    selected_mask2,
         | 
| 33 | 
            +
                    inl_th,
         | 
| 34 | 
            +
                    label=None,
         | 
| 35 | 
            +
                ):
         | 
| 36 | 
            +
                    dev = pdesc1.device
         | 
| 37 | 
            +
                    B = pdesc1.shape[0]
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    batch_rel_idx_matches = [None] * B
         | 
| 40 | 
            +
                    batch_idx_matches = [None] * B
         | 
| 41 | 
            +
                    future_results = [None] * B
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    for b in range(B):
         | 
| 44 | 
            +
                        if selected_mask1[b].sum() < 16 or selected_mask2[b].sum() < 16:
         | 
| 45 | 
            +
                            continue
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                        dists, idx_matches = self.matcher(pdesc1[b][selected_mask1[b]], pdesc2[b][selected_mask2[b]])
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                        batch_rel_idx_matches[b] = idx_matches.clone()
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                        # calculate ABSOLUTE indexes
         | 
| 52 | 
            +
                        idx_matches[:, 0] = torch.nonzero(selected_mask1[b], as_tuple=False)[idx_matches[:, 0]].squeeze()
         | 
| 53 | 
            +
                        idx_matches[:, 1] = torch.nonzero(selected_mask2[b], as_tuple=False)[idx_matches[:, 1]].squeeze()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                        batch_idx_matches[b] = idx_matches
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                        # if not enough matches
         | 
| 58 | 
            +
                        if idx_matches.shape[0] < self.min_num_matches:
         | 
| 59 | 
            +
                            ransac_inliers = torch.zeros((idx_matches.shape[0]), device=dev).bool()
         | 
| 60 | 
            +
                            future_results[b] = (None, ransac_inliers)
         | 
| 61 | 
            +
                            continue
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                        # use label information to exclude negative pairs from geometric filtering process -> enforces more descriminative descriptors
         | 
| 64 | 
            +
                        if label is not None and label[b] == 0:
         | 
| 65 | 
            +
                            ransac_inliers = torch.ones((idx_matches.shape[0]), device=dev).bool()
         | 
| 66 | 
            +
                            future_results[b] = (None, ransac_inliers)
         | 
| 67 | 
            +
                            continue
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                        mkpts1 = kpts1[b][idx_matches[:, 0]]
         | 
| 70 | 
            +
                        mkpts2 = kpts2[b][idx_matches[:, 1]]
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                        future_results[b] = self.executor.submit(self.robust_estimator, mkpts1, mkpts2, inl_th)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    batch_ransac_inliers = [None] * B
         | 
| 75 | 
            +
                    batch_Fm = [None] * B
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    for b in range(B):
         | 
| 78 | 
            +
                        future_result = future_results[b]
         | 
| 79 | 
            +
                        if future_result is None:
         | 
| 80 | 
            +
                            ransac_inliers = None
         | 
| 81 | 
            +
                            Fm = None
         | 
| 82 | 
            +
                        elif isinstance(future_result, tuple):
         | 
| 83 | 
            +
                            Fm, ransac_inliers = future_result
         | 
| 84 | 
            +
                        else:
         | 
| 85 | 
            +
                            Fm, ransac_inliers = future_result.result()
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                            # if no inliers
         | 
| 88 | 
            +
                            if ransac_inliers.sum() == 0:
         | 
| 89 | 
            +
                                ransac_inliers = ransac_inliers.squeeze(
         | 
| 90 | 
            +
                                    -1
         | 
| 91 | 
            +
                                )  # kornia.geometry.ransac.RANSAC returns (N, 1) tensor if no inliers and (N,) tensor if inliers
         | 
| 92 | 
            +
                                Fm = None
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                        batch_ransac_inliers[b] = ransac_inliers
         | 
| 95 | 
            +
                        batch_Fm[b] = Fm
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    return batch_rel_idx_matches, batch_idx_matches, batch_ransac_inliers, batch_Fm
         | 
    	
        imcui/third_party/RIPE/ripe/matcher/pose_estimator_poselib.py
    ADDED
    
    | @@ -0,0 +1,31 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import poselib
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class PoseLibRelativePoseEstimator:
         | 
| 6 | 
            +
                """PoseLibRelativePoseEstimator estimates the fundamental matrix using poselib library.
         | 
| 7 | 
            +
                It uses the poselib's estimate_fundamental function to compute the fundamental matrix and inliers based on the provided points.
         | 
| 8 | 
            +
                Args:
         | 
| 9 | 
            +
                    None
         | 
| 10 | 
            +
                """
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def __init__(self):
         | 
| 13 | 
            +
                    pass
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def __call__(self, pts0, pts1, inl_th):
         | 
| 16 | 
            +
                    F, info = poselib.estimate_fundamental(
         | 
| 17 | 
            +
                        pts0.cpu().numpy(),
         | 
| 18 | 
            +
                        pts1.cpu().numpy(),
         | 
| 19 | 
            +
                        {
         | 
| 20 | 
            +
                            "max_epipolar_error": inl_th,
         | 
| 21 | 
            +
                        },
         | 
| 22 | 
            +
                    )
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    success = F is not None
         | 
| 25 | 
            +
                    if success:
         | 
| 26 | 
            +
                        inliers = info.pop("inliers")
         | 
| 27 | 
            +
                        inliers = torch.tensor(inliers, dtype=torch.bool, device=pts0.device)
         | 
| 28 | 
            +
                    else:
         | 
| 29 | 
            +
                        inliers = torch.zeros(pts0.shape[0], dtype=torch.bool, device=pts0.device)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    return F, inliers
         | 
    	
        imcui/third_party/RIPE/ripe/model_zoo/__init__.py
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            from .vgg_hyper import vgg_hyper  # noqa: F401
         | 
    	
        imcui/third_party/RIPE/ripe/model_zoo/vgg_hyper.py
    ADDED
    
    | @@ -0,0 +1,39 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from pathlib import Path
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from ripe.models.backbones.vgg import VGG
         | 
| 6 | 
            +
            from ripe.models.ripe import RIPE
         | 
| 7 | 
            +
            from ripe.models.upsampler.hypercolumn_features import HyperColumnFeatures
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def vgg_hyper(model_path: Path = None, desc_shares=None):
         | 
| 11 | 
            +
                if model_path is None:
         | 
| 12 | 
            +
                    # check if the weights file exists in the current directory
         | 
| 13 | 
            +
                    model_path = Path("/tmp/ripe_weights.pth")
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                    if model_path.exists():
         | 
| 16 | 
            +
                        print(f"Using existing weights from {model_path}")
         | 
| 17 | 
            +
                    else:
         | 
| 18 | 
            +
                        print("Weights file not found. Downloading ...")
         | 
| 19 | 
            +
                        torch.hub.download_url_to_file(
         | 
| 20 | 
            +
                            "https://cvg.hhi.fraunhofer.de/RIPE/ripe_weights.pth",
         | 
| 21 | 
            +
                            "/tmp/ripe_weights.pth",
         | 
| 22 | 
            +
                        )
         | 
| 23 | 
            +
                else:
         | 
| 24 | 
            +
                    if not model_path.exists():
         | 
| 25 | 
            +
                        print(f"Error: {model_path} does not exist.")
         | 
| 26 | 
            +
                        raise FileNotFoundError(f"Error: {model_path} does not exist.")
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                backbone = VGG(pretrained=False)
         | 
| 29 | 
            +
                upsampler = HyperColumnFeatures()
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                extractor = RIPE(
         | 
| 32 | 
            +
                    net=backbone,
         | 
| 33 | 
            +
                    upsampler=upsampler,
         | 
| 34 | 
            +
                    desc_shares=desc_shares,
         | 
| 35 | 
            +
                )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                extractor.load_state_dict(torch.load(model_path, map_location="cpu"))
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                return extractor
         | 
    	
        imcui/third_party/RIPE/ripe/models/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        imcui/third_party/RIPE/ripe/models/backbones/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        imcui/third_party/RIPE/ripe/models/backbones/backbone_base.py
    ADDED
    
    | @@ -0,0 +1,61 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class BackboneBase(nn.Module):
         | 
| 6 | 
            +
                """Base class for backbone networks. Provides a standard interface for preprocessing inputs and
         | 
| 7 | 
            +
                defining encoder dimensions.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                Args:
         | 
| 10 | 
            +
                    nchannels (int): Number of input channels.
         | 
| 11 | 
            +
                    use_instance_norm (bool): Whether to apply instance normalization.
         | 
| 12 | 
            +
                """
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def __init__(self, nchannels=3, use_instance_norm=False):
         | 
| 15 | 
            +
                    super().__init__()
         | 
| 16 | 
            +
                    assert nchannels > 0, "Number of channels must be positive."
         | 
| 17 | 
            +
                    self.nchannels = nchannels
         | 
| 18 | 
            +
                    self.use_instance_norm = use_instance_norm
         | 
| 19 | 
            +
                    self.norm = nn.InstanceNorm2d(nchannels) if use_instance_norm else None
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def get_dim_layers_encoder(self):
         | 
| 22 | 
            +
                    """Get dimensions of encoder layers."""
         | 
| 23 | 
            +
                    raise NotImplementedError("Subclasses must implement this method.")
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def _forward(self, x):
         | 
| 26 | 
            +
                    """Define the forward pass for the backbone."""
         | 
| 27 | 
            +
                    raise NotImplementedError("Subclasses must implement this method.")
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def forward(self, x: torch.Tensor, preprocess=True):
         | 
| 30 | 
            +
                    """Forward pass with optional preprocessing.
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    Args:
         | 
| 33 | 
            +
                        x (Tensor): Input tensor.
         | 
| 34 | 
            +
                        preprocess (bool): Whether to apply channel reduction.
         | 
| 35 | 
            +
                    """
         | 
| 36 | 
            +
                    if preprocess:
         | 
| 37 | 
            +
                        if x.dim() != 4:
         | 
| 38 | 
            +
                            if x.dim() == 2 and x.shape[0] > 3 and x.shape[1] > 3:
         | 
| 39 | 
            +
                                x = x.unsqueeze(0).unsqueeze(0)
         | 
| 40 | 
            +
                            elif x.dim() == 3:
         | 
| 41 | 
            +
                                x = x.unsqueeze(0)
         | 
| 42 | 
            +
                            else:
         | 
| 43 | 
            +
                                raise ValueError(f"Unexpected input shape: {x.shape}")
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        if self.nchannels == 1 and x.shape[1] != 1:
         | 
| 46 | 
            +
                            if len(x.shape) == 4:  # Assumes (batch, channel, height, width)
         | 
| 47 | 
            +
                                x = torch.mean(x, axis=1, keepdim=True)
         | 
| 48 | 
            +
                            else:
         | 
| 49 | 
            +
                                raise ValueError(f"Unexpected input shape: {x.shape}")
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                        #
         | 
| 52 | 
            +
                        if self.nchannels == 3 and x.shape[1] == 1:
         | 
| 53 | 
            +
                            if len(x.shape) == 4:
         | 
| 54 | 
            +
                                x = x.repeat(1, 3, 1, 1)
         | 
| 55 | 
            +
                            else:
         | 
| 56 | 
            +
                                raise ValueError(f"Unexpected input shape: {x.shape}")
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    if self.use_instance_norm:
         | 
| 59 | 
            +
                        x = self.norm(x)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    return self._forward(x)
         | 
    	
        imcui/third_party/RIPE/ripe/models/backbones/vgg.py
    ADDED
    
    | @@ -0,0 +1,99 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from .backbone_base import BackboneBase
         | 
| 7 | 
            +
            from .vgg_utils import VGG19, ConvRefiner, Decoder
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class VGG(BackboneBase):
         | 
| 11 | 
            +
                def __init__(self, nchannels=3, pretrained=True, use_instance_norm=True, mode="dect"):
         | 
| 12 | 
            +
                    super().__init__(nchannels=nchannels, use_instance_norm=use_instance_norm)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                    self.nchannels = nchannels
         | 
| 15 | 
            +
                    self.mode = mode
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                    if self.mode not in ["dect", "desc", "dect+desc"]:
         | 
| 18 | 
            +
                        raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    NUM_OUTPUT_CHANNELS, hidden_blocks = self._get_mode_params(mode)
         | 
| 21 | 
            +
                    conv_refiner = self._create_conv_refiner(NUM_OUTPUT_CHANNELS, hidden_blocks)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    self.encoder = VGG19(pretrained=pretrained, num_input_channels=nchannels)
         | 
| 24 | 
            +
                    self.decoder = Decoder(conv_refiner, num_prototypes=NUM_OUTPUT_CHANNELS)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def _get_mode_params(self, mode):
         | 
| 27 | 
            +
                    """Get the number of output channels and the number of hidden blocks for the ConvRefiner.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    Depending on the mode, the ConvRefiner will have a different number of output channels.
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    if mode == "dect":
         | 
| 33 | 
            +
                        return 1, 8
         | 
| 34 | 
            +
                    elif mode == "desc":
         | 
| 35 | 
            +
                        return 256, 5
         | 
| 36 | 
            +
                    elif mode == "dect+desc":
         | 
| 37 | 
            +
                        return 256 + 1, 8
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def _create_conv_refiner(self, num_output_channels, hidden_blocks):
         | 
| 40 | 
            +
                    return nn.ModuleDict(
         | 
| 41 | 
            +
                        {
         | 
| 42 | 
            +
                            "8": ConvRefiner(
         | 
| 43 | 
            +
                                512,
         | 
| 44 | 
            +
                                512,
         | 
| 45 | 
            +
                                256 + num_output_channels,
         | 
| 46 | 
            +
                                hidden_blocks=hidden_blocks,
         | 
| 47 | 
            +
                                residual=True,
         | 
| 48 | 
            +
                            ),
         | 
| 49 | 
            +
                            "4": ConvRefiner(
         | 
| 50 | 
            +
                                256 + 256,
         | 
| 51 | 
            +
                                256,
         | 
| 52 | 
            +
                                128 + num_output_channels,
         | 
| 53 | 
            +
                                hidden_blocks=hidden_blocks,
         | 
| 54 | 
            +
                                residual=True,
         | 
| 55 | 
            +
                            ),
         | 
| 56 | 
            +
                            "2": ConvRefiner(
         | 
| 57 | 
            +
                                128 + 128,
         | 
| 58 | 
            +
                                128,
         | 
| 59 | 
            +
                                64 + num_output_channels,
         | 
| 60 | 
            +
                                hidden_blocks=hidden_blocks,
         | 
| 61 | 
            +
                                residual=True,
         | 
| 62 | 
            +
                            ),
         | 
| 63 | 
            +
                            "1": ConvRefiner(
         | 
| 64 | 
            +
                                64 + 64,
         | 
| 65 | 
            +
                                64,
         | 
| 66 | 
            +
                                1 + num_output_channels,
         | 
| 67 | 
            +
                                hidden_blocks=hidden_blocks,
         | 
| 68 | 
            +
                                residual=True,
         | 
| 69 | 
            +
                            ),
         | 
| 70 | 
            +
                        }
         | 
| 71 | 
            +
                    )
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def get_dim_layers_encoder(self):
         | 
| 74 | 
            +
                    return self.encoder.get_dim_layers()
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def _forward(self, x):
         | 
| 77 | 
            +
                    features, sizes = self.encoder(x)
         | 
| 78 | 
            +
                    output = 0
         | 
| 79 | 
            +
                    context = None
         | 
| 80 | 
            +
                    scales = self.decoder.scales
         | 
| 81 | 
            +
                    for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
         | 
| 82 | 
            +
                        delta_descriptor, context = self.decoder(feature_map, scale=scale, context=context)
         | 
| 83 | 
            +
                        output = output + delta_descriptor
         | 
| 84 | 
            +
                        if idx < len(scales) - 1:
         | 
| 85 | 
            +
                            size = sizes[-(idx + 2)]
         | 
| 86 | 
            +
                            output = F.interpolate(output, size=size, mode="bilinear", align_corners=False)
         | 
| 87 | 
            +
                            context = F.interpolate(context, size=size, mode="bilinear", align_corners=False)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    if self.mode == "dect":
         | 
| 90 | 
            +
                        return {"heatmap": output, "coarse_descs": features}
         | 
| 91 | 
            +
                    elif self.mode == "desc":
         | 
| 92 | 
            +
                        return {"fine_descs": output, "coarse_descs": features}
         | 
| 93 | 
            +
                    elif self.mode == "dect+desc":
         | 
| 94 | 
            +
                        logits = output[:, :1].contiguous()
         | 
| 95 | 
            +
                        descs = output[:, 1:].contiguous()
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                        return {"heatmap": logits, "fine_descs": descs, "coarse_descs": features}
         | 
| 98 | 
            +
                    else:
         | 
| 99 | 
            +
                        raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
         | 
    	
        imcui/third_party/RIPE/ripe/models/backbones/vgg_utils.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torchvision.models as tvm
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from ripe import utils
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class Decoder(nn.Module):
         | 
| 13 | 
            +
                def __init__(self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs) -> None:
         | 
| 14 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 15 | 
            +
                    self.layers = layers
         | 
| 16 | 
            +
                    self.scales = self.layers.keys()
         | 
| 17 | 
            +
                    self.super_resolution = super_resolution
         | 
| 18 | 
            +
                    self.num_prototypes = num_prototypes
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def forward(self, features, context=None, scale=None):
         | 
| 21 | 
            +
                    if context is not None:
         | 
| 22 | 
            +
                        features = torch.cat((features, context), dim=1)
         | 
| 23 | 
            +
                    stuff = self.layers[scale](features)
         | 
| 24 | 
            +
                    logits, context = (
         | 
| 25 | 
            +
                        stuff[:, : self.num_prototypes],
         | 
| 26 | 
            +
                        stuff[:, self.num_prototypes :],
         | 
| 27 | 
            +
                    )
         | 
| 28 | 
            +
                    return logits, context
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class ConvRefiner(nn.Module):
         | 
| 32 | 
            +
                def __init__(
         | 
| 33 | 
            +
                    self,
         | 
| 34 | 
            +
                    in_dim=6,
         | 
| 35 | 
            +
                    hidden_dim=16,
         | 
| 36 | 
            +
                    out_dim=2,
         | 
| 37 | 
            +
                    dw=True,
         | 
| 38 | 
            +
                    kernel_size=5,
         | 
| 39 | 
            +
                    hidden_blocks=5,
         | 
| 40 | 
            +
                    residual=False,
         | 
| 41 | 
            +
                ):
         | 
| 42 | 
            +
                    super().__init__()
         | 
| 43 | 
            +
                    self.block1 = self.create_block(
         | 
| 44 | 
            +
                        in_dim,
         | 
| 45 | 
            +
                        hidden_dim,
         | 
| 46 | 
            +
                        dw=False,
         | 
| 47 | 
            +
                        kernel_size=1,
         | 
| 48 | 
            +
                    )
         | 
| 49 | 
            +
                    self.hidden_blocks = nn.Sequential(
         | 
| 50 | 
            +
                        *[
         | 
| 51 | 
            +
                            self.create_block(
         | 
| 52 | 
            +
                                hidden_dim,
         | 
| 53 | 
            +
                                hidden_dim,
         | 
| 54 | 
            +
                                dw=dw,
         | 
| 55 | 
            +
                                kernel_size=kernel_size,
         | 
| 56 | 
            +
                            )
         | 
| 57 | 
            +
                            for hb in range(hidden_blocks)
         | 
| 58 | 
            +
                        ]
         | 
| 59 | 
            +
                    )
         | 
| 60 | 
            +
                    self.hidden_blocks = self.hidden_blocks
         | 
| 61 | 
            +
                    self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
         | 
| 62 | 
            +
                    self.residual = residual
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def create_block(
         | 
| 65 | 
            +
                    self,
         | 
| 66 | 
            +
                    in_dim,
         | 
| 67 | 
            +
                    out_dim,
         | 
| 68 | 
            +
                    dw=True,
         | 
| 69 | 
            +
                    kernel_size=5,
         | 
| 70 | 
            +
                    bias=True,
         | 
| 71 | 
            +
                    norm_type=nn.BatchNorm2d,
         | 
| 72 | 
            +
                ):
         | 
| 73 | 
            +
                    num_groups = 1 if not dw else in_dim
         | 
| 74 | 
            +
                    if dw:
         | 
| 75 | 
            +
                        assert out_dim % in_dim == 0, "outdim must be divisible by indim for depthwise"
         | 
| 76 | 
            +
                    conv1 = nn.Conv2d(
         | 
| 77 | 
            +
                        in_dim,
         | 
| 78 | 
            +
                        out_dim,
         | 
| 79 | 
            +
                        kernel_size=kernel_size,
         | 
| 80 | 
            +
                        stride=1,
         | 
| 81 | 
            +
                        padding=kernel_size // 2,
         | 
| 82 | 
            +
                        groups=num_groups,
         | 
| 83 | 
            +
                        bias=bias,
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels=out_dim)
         | 
| 86 | 
            +
                    relu = nn.ReLU(inplace=True)
         | 
| 87 | 
            +
                    conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
         | 
| 88 | 
            +
                    return nn.Sequential(conv1, norm, relu, conv2)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def forward(self, feats):
         | 
| 91 | 
            +
                    b, c, hs, ws = feats.shape
         | 
| 92 | 
            +
                    x0 = self.block1(feats)
         | 
| 93 | 
            +
                    x = self.hidden_blocks(x0)
         | 
| 94 | 
            +
                    if self.residual:
         | 
| 95 | 
            +
                        x = (x + x0) / 1.4
         | 
| 96 | 
            +
                    x = self.out_conv(x)
         | 
| 97 | 
            +
                    return x
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            class VGG19(nn.Module):
         | 
| 101 | 
            +
                def __init__(self, pretrained=False, num_input_channels=3) -> None:
         | 
| 102 | 
            +
                    super().__init__()
         | 
| 103 | 
            +
                    self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         | 
| 104 | 
            +
                    # Maxpool layers: 6, 13, 26, 39
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    if num_input_channels != 3:
         | 
| 107 | 
            +
                        log.info(f"Changing input channels from 3 to {num_input_channels}")
         | 
| 108 | 
            +
                        self.layers[0] = nn.Conv2d(num_input_channels, 64, 3, 1, 1)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def get_dim_layers(self):
         | 
| 111 | 
            +
                    return [64, 128, 256, 512]
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                def forward(self, x, **kwargs):
         | 
| 114 | 
            +
                    feats = []
         | 
| 115 | 
            +
                    sizes = []
         | 
| 116 | 
            +
                    for layer in self.layers:
         | 
| 117 | 
            +
                        if isinstance(layer, nn.MaxPool2d):
         | 
| 118 | 
            +
                            feats.append(x)
         | 
| 119 | 
            +
                            sizes.append(x.shape[-2:])
         | 
| 120 | 
            +
                        x = layer(x)
         | 
| 121 | 
            +
                    return feats, sizes
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
            class VGG(nn.Module):
         | 
| 125 | 
            +
                def __init__(self, size="19", pretrained=False) -> None:
         | 
| 126 | 
            +
                    super().__init__()
         | 
| 127 | 
            +
                    if size == "11":
         | 
| 128 | 
            +
                        self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
         | 
| 129 | 
            +
                    elif size == "13":
         | 
| 130 | 
            +
                        self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
         | 
| 131 | 
            +
                    elif size == "19":
         | 
| 132 | 
            +
                        self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         | 
| 133 | 
            +
                    # Maxpool layers: 6, 13, 26, 39
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def forward(self, x, **kwargs):
         | 
| 136 | 
            +
                    feats = []
         | 
| 137 | 
            +
                    sizes = []
         | 
| 138 | 
            +
                    for layer in self.layers:
         | 
| 139 | 
            +
                        if isinstance(layer, nn.MaxPool2d):
         | 
| 140 | 
            +
                            feats.append(x)
         | 
| 141 | 
            +
                            sizes.append(x.shape[-2:])
         | 
| 142 | 
            +
                        x = layer(x)
         | 
| 143 | 
            +
                    return feats, sizes
         | 
    	
        imcui/third_party/RIPE/ripe/models/ripe.py
    ADDED
    
    | @@ -0,0 +1,303 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import List, Optional
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from ripe import utils
         | 
| 9 | 
            +
            from ripe.utils.utils import gridify
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            log = utils.get_pylogger(__name__)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class KeypointSampler(nn.Module):
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
                Sample keypoints according to a Heatmap
         | 
| 17 | 
            +
                Adapted from: https://github.com/verlab/DALF_CVPR_2023/blob/main/modules/models/DALF.py
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def __init__(self, window_size=8):
         | 
| 21 | 
            +
                    super().__init__()
         | 
| 22 | 
            +
                    self.window_size = window_size
         | 
| 23 | 
            +
                    self.idx_cells = None  # Cache for meshgrid indices
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def sample(self, grid):
         | 
| 26 | 
            +
                    """
         | 
| 27 | 
            +
                    Sample keypoints given a grid where each cell has logits stacked in last dimension
         | 
| 28 | 
            +
                    Input
         | 
| 29 | 
            +
                      grid: [B, C, H//w, W//w, w*w]
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    Returns
         | 
| 32 | 
            +
                      log_probs: [B, C, H//w, W//w ] - logprobs of selected samples
         | 
| 33 | 
            +
                      choices: [B, C, H//w, W//w] indices of choices
         | 
| 34 | 
            +
                      accept_mask: [B, C, H//w, W//w] mask of accepted keypoints
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    """
         | 
| 37 | 
            +
                    chooser = torch.distributions.Categorical(logits=grid)
         | 
| 38 | 
            +
                    choices = chooser.sample()
         | 
| 39 | 
            +
                    logits_selected = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    flipper = torch.distributions.Bernoulli(logits=logits_selected)
         | 
| 42 | 
            +
                    accepted_choices = flipper.sample()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # Sum log-probabilities is equivalent to multiplying the probabilities
         | 
| 45 | 
            +
                    log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted_choices)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    accept_mask = accepted_choices.gt(0)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    return (
         | 
| 50 | 
            +
                        log_probs.squeeze(1),
         | 
| 51 | 
            +
                        choices,
         | 
| 52 | 
            +
                        accept_mask.squeeze(1),
         | 
| 53 | 
            +
                        logits_selected.squeeze(1),
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                def precompute_idx_cells(self, H, W, device):
         | 
| 57 | 
            +
                    idx_cells = gridify(
         | 
| 58 | 
            +
                        torch.dstack(
         | 
| 59 | 
            +
                            torch.meshgrid(
         | 
| 60 | 
            +
                                torch.arange(H, dtype=torch.float32, device=device),
         | 
| 61 | 
            +
                                torch.arange(W, dtype=torch.float32, device=device),
         | 
| 62 | 
            +
                            )
         | 
| 63 | 
            +
                        )
         | 
| 64 | 
            +
                        .permute(2, 0, 1)
         | 
| 65 | 
            +
                        .unsqueeze(0)
         | 
| 66 | 
            +
                        .expand(1, -1, -1, -1),
         | 
| 67 | 
            +
                        window_size=self.window_size,
         | 
| 68 | 
            +
                    )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    return idx_cells
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def forward(self, x, mask_padding=None):
         | 
| 73 | 
            +
                    """
         | 
| 74 | 
            +
                    Sample keypoints from a heatmap
         | 
| 75 | 
            +
                    Input
         | 
| 76 | 
            +
                      x: [B, C, H, W] Heatmap
         | 
| 77 | 
            +
                      mask_padding: [B, 1, H, W] Mask for padding (optional)
         | 
| 78 | 
            +
                    Returns
         | 
| 79 | 
            +
                        keypoints: [B, H//w, W//w, 2] Keypoints in (x, y) format
         | 
| 80 | 
            +
                        log_probs: [B, H//w, W//w] Log probabilities of selected keypoints
         | 
| 81 | 
            +
                        mask: [B, H//w, W//w] Mask of accepted keypoints
         | 
| 82 | 
            +
                        mask_padding: [B, 1, H//w, W//w] Mask of padding (optional)
         | 
| 83 | 
            +
                        logits_selected: [B, H//w, W//w] Logits of selected keypoints
         | 
| 84 | 
            +
                    """
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    B, C, H, W = x.shape
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    keypoint_cells = gridify(x, self.window_size)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    mask_padding = (
         | 
| 91 | 
            +
                        (torch.min(gridify(mask_padding, self.window_size), dim=4).values) if mask_padding is not None else None
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    if self.idx_cells is None or self.idx_cells.shape[2:4] != (
         | 
| 95 | 
            +
                        H // self.window_size,
         | 
| 96 | 
            +
                        W // self.window_size,
         | 
| 97 | 
            +
                    ):
         | 
| 98 | 
            +
                        self.idx_cells = self.precompute_idx_cells(H, W, x.device)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    log_probs, idx, mask, logits_selected = self.sample(keypoint_cells)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    keypoints = (
         | 
| 103 | 
            +
                        torch.gather(
         | 
| 104 | 
            +
                            self.idx_cells.expand(B, -1, -1, -1, -1),
         | 
| 105 | 
            +
                            -1,
         | 
| 106 | 
            +
                            idx.repeat(1, 2, 1, 1).unsqueeze(-1),
         | 
| 107 | 
            +
                        )
         | 
| 108 | 
            +
                        .squeeze(-1)
         | 
| 109 | 
            +
                        .permute(0, 2, 3, 1)
         | 
| 110 | 
            +
                    )
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    # flip keypoints to (x, y) format
         | 
| 113 | 
            +
                    return keypoints.flip(-1), log_probs, mask, mask_padding, logits_selected
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            class RIPE(nn.Module):
         | 
| 117 | 
            +
                """
         | 
| 118 | 
            +
                Base class for extracting keypoints and descriptors
         | 
| 119 | 
            +
                Input
         | 
| 120 | 
            +
                  x: [B, C, H, W] Images
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                Returns
         | 
| 123 | 
            +
                  kpts:
         | 
| 124 | 
            +
                    list of size [B] with detected keypoints
         | 
| 125 | 
            +
                  descs:
         | 
| 126 | 
            +
                    list of size [B] with descriptors
         | 
| 127 | 
            +
                """
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                def __init__(
         | 
| 130 | 
            +
                    self,
         | 
| 131 | 
            +
                    net,
         | 
| 132 | 
            +
                    upsampler,
         | 
| 133 | 
            +
                    window_size: int = 8,
         | 
| 134 | 
            +
                    non_linearity_dect=None,
         | 
| 135 | 
            +
                    desc_shares: Optional[List[int]] = None,
         | 
| 136 | 
            +
                    descriptor_dim: int = 256,
         | 
| 137 | 
            +
                    device=None,
         | 
| 138 | 
            +
                ):
         | 
| 139 | 
            +
                    super().__init__()
         | 
| 140 | 
            +
                    self.net = net
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    self.detector = KeypointSampler(window_size)
         | 
| 143 | 
            +
                    self.upsampler = upsampler
         | 
| 144 | 
            +
                    self.sampler = None
         | 
| 145 | 
            +
                    self.window_size = window_size
         | 
| 146 | 
            +
                    self.non_linearity_dect = non_linearity_dect if non_linearity_dect is not None else nn.Identity()
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    log.info(f"Training with window size {window_size}.")
         | 
| 149 | 
            +
                    log.info(f"Use {non_linearity_dect} as final non-linearity before the detection heatmap.")
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    dim_coarse_desc = self.get_dim_raw_desc()
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    if desc_shares is not None:
         | 
| 154 | 
            +
                        assert upsampler.name == "HyperColumnFeatures", (
         | 
| 155 | 
            +
                            "Individual descriptor convolutions are only supported with HyperColumnFeatures"
         | 
| 156 | 
            +
                        )
         | 
| 157 | 
            +
                        assert len(desc_shares) == 4, "desc_shares should have 4 elements"
         | 
| 158 | 
            +
                        assert sum(desc_shares) == descriptor_dim, f"sum of desc_shares should be {descriptor_dim}"
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                        self.conv_dim_reduction_coarse_desc = nn.ModuleList()
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                        for dim_in, dim_out in zip(dim_coarse_desc, desc_shares):
         | 
| 163 | 
            +
                            log.info(f"Training dim reduction descriptor with {dim_in} -> {dim_out} 1x1 conv")
         | 
| 164 | 
            +
                            self.conv_dim_reduction_coarse_desc.append(
         | 
| 165 | 
            +
                                nn.Conv1d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)
         | 
| 166 | 
            +
                            )
         | 
| 167 | 
            +
                    else:
         | 
| 168 | 
            +
                        if descriptor_dim is not None:
         | 
| 169 | 
            +
                            log.info(f"Training dim reduction descriptor with {sum(dim_coarse_desc)} -> {descriptor_dim} 1x1 conv")
         | 
| 170 | 
            +
                            self.conv_dim_reduction_coarse_desc = nn.Conv1d(
         | 
| 171 | 
            +
                                sum(dim_coarse_desc),
         | 
| 172 | 
            +
                                descriptor_dim,
         | 
| 173 | 
            +
                                kernel_size=1,
         | 
| 174 | 
            +
                                stride=1,
         | 
| 175 | 
            +
                                padding=0,
         | 
| 176 | 
            +
                            )
         | 
| 177 | 
            +
                        else:
         | 
| 178 | 
            +
                            log.warning(
         | 
| 179 | 
            +
                                f"No descriptor dimension specified, no 1x1 conv will be applied! Direct usage of {sum(dim_coarse_desc)}-dimensional raw descriptor"
         | 
| 180 | 
            +
                            )
         | 
| 181 | 
            +
                            self.conv_dim_reduction_coarse_desc = nn.Identity()
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                def get_dim_raw_desc(self):
         | 
| 184 | 
            +
                    layers_dims_encoder = self.net.get_dim_layers_encoder()
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    if self.upsampler.name == "InterpolateSparse2d":
         | 
| 187 | 
            +
                        return [layers_dims_encoder[-1]]
         | 
| 188 | 
            +
                    elif self.upsampler.name == "HyperColumnFeatures":
         | 
| 189 | 
            +
                        return layers_dims_encoder
         | 
| 190 | 
            +
                    else:
         | 
| 191 | 
            +
                        raise ValueError(f"Unknown interpolator {self.upsampler.name}")
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                @torch.inference_mode()
         | 
| 194 | 
            +
                def detectAndCompute(self, img, threshold=0.5, top_k=2048, output_aux=False):
         | 
| 195 | 
            +
                    self.train(False)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    if img.dim() == 3:
         | 
| 198 | 
            +
                        img = img.unsqueeze(0)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    out = self(img, training=False)
         | 
| 201 | 
            +
                    B, K, H, W = out["heatmap"].shape
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    assert B == 1, "Batch size should be 1"
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    kpts = [{"xy": self.NMS(out["heatmap"][b], threshold)} for b in range(B)]
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    if top_k is not None:
         | 
| 208 | 
            +
                        for b in range(B):
         | 
| 209 | 
            +
                            scores = out["heatmap"][b].squeeze(0)[kpts[b]["xy"][:, 1].long(), kpts[b]["xy"][:, 0].long()]
         | 
| 210 | 
            +
                            sorted_idx = torch.argsort(-scores)
         | 
| 211 | 
            +
                            kpts[b]["xy"] = kpts[b]["xy"][sorted_idx[:top_k]]
         | 
| 212 | 
            +
                            if "logprobs" in kpts[b]:
         | 
| 213 | 
            +
                                kpts[b]["logprobs"] = kpts[b]["xy"][sorted_idx[:top_k]]
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    if kpts[0]["xy"].shape[0] == 0:
         | 
| 216 | 
            +
                        raise RuntimeError("No keypoints detected")
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    # the following works for batch size 1 only
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    descs = self.get_descs(out["coarse_descs"], img, kpts[0]["xy"].unsqueeze(0), H, W)
         | 
| 221 | 
            +
                    descs = descs.squeeze(0)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    score_map = out["heatmap"][0].squeeze(0)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    kpts = kpts[0]["xy"]
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    scores = score_map[kpts[:, 1], kpts[:, 0]]
         | 
| 228 | 
            +
                    scores /= score_map.max()
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    sort_idx = torch.argsort(-scores)
         | 
| 231 | 
            +
                    kpts, descs, scores = kpts[sort_idx], descs[sort_idx], scores[sort_idx]
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    if output_aux:
         | 
| 234 | 
            +
                        return (
         | 
| 235 | 
            +
                            kpts.float(),
         | 
| 236 | 
            +
                            descs,
         | 
| 237 | 
            +
                            scores,
         | 
| 238 | 
            +
                            {
         | 
| 239 | 
            +
                                "heatmap": out["heatmap"],
         | 
| 240 | 
            +
                                "descs": out["coarse_descs"],
         | 
| 241 | 
            +
                                "conv": self.conv_dim_reduction_coarse_desc,
         | 
| 242 | 
            +
                            },
         | 
| 243 | 
            +
                        )
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                    return kpts.float(), descs, scores
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                def NMS(self, x, threshold=3.0, kernel_size=3):
         | 
| 248 | 
            +
                    pad = kernel_size // 2
         | 
| 249 | 
            +
                    local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    pos = (x == local_max) & (x > threshold)
         | 
| 252 | 
            +
                    return pos.nonzero()[..., 1:].flip(-1)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                def get_descs(self, feature_map, guidance, kpts, H, W):
         | 
| 255 | 
            +
                    descs = self.upsampler(feature_map, kpts, H, W)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                    if isinstance(self.conv_dim_reduction_coarse_desc, nn.ModuleList):
         | 
| 258 | 
            +
                        # individual descriptor convolutions for each layer
         | 
| 259 | 
            +
                        desc_conv = []
         | 
| 260 | 
            +
                        for desc, conv in zip(descs, self.conv_dim_reduction_coarse_desc):
         | 
| 261 | 
            +
                            desc_conv.append(conv(desc.permute(0, 2, 1)).permute(0, 2, 1))
         | 
| 262 | 
            +
                        desc = torch.cat(desc_conv, dim=-1)
         | 
| 263 | 
            +
                    else:
         | 
| 264 | 
            +
                        desc = torch.cat(descs, dim=-1)
         | 
| 265 | 
            +
                        desc = self.conv_dim_reduction_coarse_desc(desc.permute(0, 2, 1)).permute(0, 2, 1)
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    desc = F.normalize(desc, dim=2)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    return desc
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                def forward(self, x, mask_padding=None, training=False):
         | 
| 272 | 
            +
                    B, C, H, W = x.shape
         | 
| 273 | 
            +
                    out = self.net(x)
         | 
| 274 | 
            +
                    out["heatmap"] = self.non_linearity_dect(out["heatmap"])
         | 
| 275 | 
            +
                    # print(out['map'].shape, out['descr'].shape)
         | 
| 276 | 
            +
                    if training:
         | 
| 277 | 
            +
                        kpts, log_probs, mask, mask_padding, logits_selected = self.detector(out["heatmap"], mask_padding)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                        filter_A = kpts[:, :, :, 0] >= 16
         | 
| 280 | 
            +
                        filter_B = kpts[:, :, :, 1] >= 16
         | 
| 281 | 
            +
                        filter_C = kpts[:, :, :, 0] < W - 16
         | 
| 282 | 
            +
                        filter_D = kpts[:, :, :, 1] < H - 16
         | 
| 283 | 
            +
                        filter_all = filter_A * filter_B * filter_C * filter_D
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                        mask = mask * filter_all
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                        return (
         | 
| 288 | 
            +
                            kpts.view(B, -1, 2),
         | 
| 289 | 
            +
                            log_probs.view(B, -1),
         | 
| 290 | 
            +
                            mask.view(B, -1),
         | 
| 291 | 
            +
                            mask_padding.view(B, -1),
         | 
| 292 | 
            +
                            logits_selected.view(B, -1),
         | 
| 293 | 
            +
                            out,
         | 
| 294 | 
            +
                        )
         | 
| 295 | 
            +
                    else:
         | 
| 296 | 
            +
                        return out
         | 
| 297 | 
            +
             | 
| 298 | 
            +
             | 
| 299 | 
            +
            def output_number_trainable_params(model):
         | 
| 300 | 
            +
                model_parameters = filter(lambda p: p.requires_grad, model.parameters())
         | 
| 301 | 
            +
                nb_params = sum([np.prod(p.size()) for p in model_parameters])
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                print(f"Number of trainable parameters: {nb_params:d}")
         | 
