Spaces: Running on Zero
| import warnings | |
| import torch | |
| from kornia.feature import LoFTR as LoFTR_ | |
| from kornia.feature.loftr.loftr import default_cfg | |
| from hloc import logger | |
| from ..utils.base_model import BaseModel | |
class LoFTR(BaseModel):
    """hloc matcher wrapping kornia's LoFTR network.

    ``conf["weights"]`` selects either a kornia pretrained-weights name
    (forwarded to ``kornia.feature.LoFTR``) or ``"minima_loftr_outdoor"``,
    which loads the MINIMA LoFTR checkpoint from its release URL.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        # Keep at most this many matches (highest confidence first);
        # a negative value or None disables the limit.
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build ``self.net`` from ``conf``.

        Args:
            conf: model configuration; reads "weights", "match_threshold"
                and "sinkhorn_iterations".
        """
        import copy

        # Work on a private copy: kornia's default_cfg is a module-level dict,
        # and mutating it in place would leak thr / skh_iters / temp_bug_fix
        # into every other LoFTR instance created in this process.
        cfg = copy.deepcopy(default_cfg)
        cfg["match_coarse"]["thr"] = conf["match_threshold"]
        cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]

        if conf["weights"] == "minima_loftr_outdoor":
            # The MINIMA checkpoint is used with coarse.temp_bug_fix enabled.
            cfg["coarse"]["temp_bug_fix"] = True
            # pretrained=None skips downloading kornia's own weights, which
            # would be overwritten by the MINIMA state dict right away.
            self.net = LoFTR_(pretrained=None, config=cfg)
            model_web_path = (
                "https://github.com/LSXI7/storage/releases/download/"
                "MINIMA/minima_loftr.ckpt"
            )
            checkpoint = torch.hub.load_state_dict_from_url(
                model_web_path, map_location=torch.device("cpu")
            )
            self.net.load_state_dict(checkpoint["state_dict"])
        else:
            self.net = LoFTR_(pretrained=conf["weights"], config=cfg)
        logger.info(f"Loaded LoFTR with weights {conf['weights']}")

    def _forward(self, data):
        """Match the image pair in ``data`` and return hloc-style predictions.

        Args:
            data: dict with at least "image0" and "image1" (optionally
                "mask0"/"mask1" and other keys, passed through unchanged).

        Returns:
            The network's output dict with image indices swapped back,
            "confidence" renamed to "scores", and, when ``max_keypoints`` is
            non-negative, only the ``max_keypoints`` most confident matches.
        """
        # For consistency with hloc pairs, we refine kpts in image0!
        rename = {
            "keypoints0": "keypoints1",
            "keypoints1": "keypoints0",
            "image0": "image1",
            "image1": "image0",
            "mask0": "mask1",
            "mask1": "mask0",
        }
        # Keys without a swapped counterpart are forwarded as-is instead of
        # raising KeyError.
        data_ = {rename.get(k, k): v for k, v in data.items()}
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            pred = self.net(data_)

        scores = pred["confidence"]
        top_k = self.conf["max_keypoints"]
        # A negative limit means "keep everything". The previous
        # `len(scores) > top_k` test was always true for -1, and
        # `[:-1]` then silently dropped the weakest match.
        if top_k is not None and 0 <= top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            # Index every per-match tensor identically so they stay aligned
            # (kornia also returns per-match "batch_indexes").
            for key in ("keypoints0", "keypoints1", "batch_indexes"):
                if key in pred:
                    pred[key] = pred[key][keep]
            scores = scores[keep]

        # Switch back indices
        pred = {rename.get(k, k): v for k, v in pred.items()}
        pred["scores"] = scores
        del pred["confidence"]
        return pred