import functools
import os
import sys

import gradio as gr
import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import transforms as T

# Make the repository root importable so the local `cubercnn` package resolves.
sys.path.append(os.getcwd())
np.set_printoptions(suppress=True)

from cubercnn.config import get_cfg_defaults
from cubercnn.modeling.meta_arch import build_model
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone  # this must be here even though it is not used
from cubercnn import util, vis

CONFIG_FILE = "configs/Omni_combined.yaml"
MODEL_WEIGHTS = "output/weak_cube_r-cnn/model_final.pth"
CATEGORY_PATH = 'configs/category_meta.json'

# Resize so the shortest edge is 500 px, capping the longest edge at 1000 px.
_MIN_SIZE = 500
_MAX_SIZE = 1000
_AUGMENTATIONS = T.AugmentationList([T.ResizeShortestEdge(_MIN_SIZE, _MAX_SIZE, "choice")])

# Uploaded images carry no intrinsics; focal length is assumed to be this multiple of h / 2.
_FOCAL_LENGTH_NDC = 4.0


@functools.lru_cache(maxsize=1)
def load_model_config():
    """Build the model, load MODEL_WEIGHTS, switch to eval mode, and return it.

    Cached so the network is constructed and the checkpoint read only once,
    instead of on every Gradio request.
    """
    _cfg, model = main(CONFIG_FILE, MODEL_WEIGHTS)
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_category_names():
    """Return the tuple of class names ('thing_classes') from CATEGORY_PATH, read once."""
    category_path = CATEGORY_PATH
    # store locally if needed
    if category_path.startswith(util.CubeRCNNHandler.PREFIX):
        category_path = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, category_path)
    metadata = util.load_json(category_path)
    return tuple(metadata['thing_classes'])


def do_test(im, threshold):
    """Run the model on one image and render its 3D box predictions.

    Args:
        im: image array as delivered by ``gr.Image`` (indexed as H x W x C below), or None.
        threshold: detections with a score below this value are not drawn.

    Returns:
        ``(im_drawn_rgb, im_topdown)``: the resized image with projected boxes
        overlaid, and a top-down view (``None`` when no detection passes the
        threshold). Returns ``(None, None)`` when no image is supplied.
    """
    if im is None:
        return None, None

    model = load_model_config()
    cats = _load_category_names()

    aug_input = T.AugInput(im)
    tfms = _AUGMENTATIONS(aug_input)
    im = tfms.apply_image(im)
    h, w = im.shape[:2]

    # Assumed pinhole intrinsics: principal point at the image centre.
    focal_length = _FOCAL_LENGTH_NDC * h / 2
    px, py = w / 2, h / 2
    K = np.array([
        [focal_length, 0.0, px],
        [0.0, focal_length, py],
        [0.0, 0.0, 1.0],
    ])

    # NOTE(review): the image is fed in the channel order Gradio supplies
    # (presumably RGB); detectron2-style models often expect cfg.INPUT.FORMAT
    # (default "BGR") — confirm against the training pipeline.
    batched = [{
        'image': torch.as_tensor(np.ascontiguousarray(im.transpose(2, 0, 1))),
        'height': h,
        'width': w,
        'K': K,
    }]

    with torch.no_grad():
        dets = model(batched)[0]['instances']

    meshes = []
    meshes_text = []
    # Guard kept: prediction fields are only read when there is at least one detection.
    if len(dets) > 0:
        for idx, (center_cam, dimensions, pose, score, cat_idx) in enumerate(zip(
            dets.pred_center_cam, dets.pred_dimensions, dets.pred_pose, dets.scores, dets.pred_classes
        )):
            if score < threshold:
                continue
            cat = cats[int(cat_idx)]
            bbox3D = center_cam.tolist() + dimensions.tolist()
            meshes_text.append('{} {:.2f}'.format(cat, float(score)))
            # Colour is keyed on the raw detection index, so it is stable w.r.t. the threshold.
            color = [c / 255.0 for c in util.get_color(idx)]
            meshes.append(util.mesh_cuboid(bbox3D, pose.tolist(), color=color))

    if meshes:
        im_drawn_rgb, im_topdown, _ = vis.draw_scene_view(
            im, K, meshes, text=meshes_text, scale=im.shape[0],
            blend_weight=0.5, blend_weight_overlay=0.85,
        )
        im_drawn_rgb, im_topdown = im_drawn_rgb.astype(np.uint8), im_topdown.astype(np.uint8)
    else:
        im_drawn_rgb, im_topdown = im.astype(np.uint8), None

    return im_drawn_rgb, im_topdown


def setup(config_file):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    # store locally if needed
    if config_file.startswith(util.CubeRCNNHandler.PREFIX):
        config_file = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, config_file)

    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg


def main(config_file, weigths=None):
    """Build the model described by ``config_file`` and load ``weigths`` into it.

    The misspelled parameter name ``weigths`` is kept for backward compatibility.

    Returns:
        ``(cfg, model)``.
    """
    cfg = setup(config_file)
    model = build_model(cfg)

    # resume=False: always load the requested weights. With resume=True,
    # detectron2 would prefer a 'last_checkpoint' found in the save dir.
    DetectionCheckpointer(model).resume_or_load(weigths, resume=False)
    return cfg, model


if __name__ == "__main__":
    title = 'Weak Cube R-CNN'
    description = (
        "This showcases our model [`Weak Cube RCNN`](https://arxiv.org/abs/2504.13297). "
        "To create Weak Cube RCNN, we modify the framework by replacing its 3D loss functions "
        "with ones based solely on 2D annotations. Our methods rely heavily on external, strong "
        "generalised deep learning models to infer spatial information in scenes. Experimental "
        "results show that all models perform comparably to an annotation time-equalised "
        "Cube R-CNN, whereof the pseudo ground truth method achieves the highest accuracy. "
        "The results show the methods' ability to understand scenes in 3D, providing "
        "satisfactory visual results. Although not precise enough for centimetre accurate "
        "measurements, the method provides a solid foundation for further research. \n "
        "Check out the code on [`GitHub`](https://github.com/AndreasLH/Weak-Cube-R-CNN)"
    )

    demo = gr.Interface(
        title=title,
        fn=do_test,
        inputs=[
            gr.Image(label="Input Image"),
            gr.Slider(0, 1, value=0.5, label="Threshold", info="Only show predictions with a confidence above this threshold"),
        ],
        outputs=[gr.Image(label="Predictions"), gr.Image(label="Top view")],
        description=description,
        flagging_mode="never",
        examples=[["examples/ex2.jpg"], ["examples/ex1.jpg"]],
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)