---
license: apache-2.0
---

# RP3D-DiagModel

## About Checkpoint

The parameters we used for training are as follows:

```
start_class: 0
end_class: 5569
backbone: 'resnet'
level: 'articles'    # represents the disorder level
depth: 32
ltype: 'MultiLabel'  # represents the Binary Cross Entropy Loss
augment: True        # represents the medical data augmentation
split: 'late'        # represents the late fusion strategy
```

### Load Model

```
import torch

# RadNet and MedCPT_clinical are provided by the RP3D-Diag code on GitHub.
# Set num_classes, backbone, depth, ltype, augment, fuse, ke, encoded and adapter
# to match the training parameters above.

# Load the backbone
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype,
               augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin", map_location="cpu")
missing, unexpect = model.load_state_dict(pretrained_weights, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpect)

# If KE is set to True, load the text encoder
medcpt = MedCPT_clinical(bert_model_name='ncbi/MedCPT-Query-Encoder')
checkpoint = torch.load('path/to/epoch_state.pt', map_location='cpu')['state_dict']
# Remove the 'module.' prefix typically added by (Distributed)DataParallel training
load_checkpoint = {key.replace('module.', ''): value for key, value in checkpoint.items()}
missing, unexpect = medcpt.load_state_dict(load_checkpoint, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpect)
```

## Why do we provide this checkpoint?

All early-fusion checkpoints can be further fine-tuned from this checkpoint. If you need checkpoints trained with different parameter settings, there are two options:

### Finetune from this checkpoint

```
checkpoint: "None"
safetensor: path to this checkpoint (pytorch_model.bin)
```

### Contact Us

Email the author: three-world@sjtu.edu.cn

## About Dataset

Please refer to [RP3D-DiagDS](https://huggingface.co/datasets/QiaoyuZheng/RP3D-DiagDS).

For more information, please refer to our instructions on [GitHub](https://github.com/qiaoyu-zheng/RP3D-Diag) to download and use it.
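A note on the text-encoder loading step under **Load Model**: `key.replace('module.', '')` removes *every* occurrence of `module.` in a key, not just the leading wrapper prefix, so a key such as `head.module.bias` would also be rewritten. Below is a minimal sketch of a stricter alternative. The helper name `strip_prefix` is hypothetical and not part of the RP3D-Diag code. It only touches keys that *start* with the prefix, and it refuses to silently overwrite entries when two keys collide after stripping. It works on any mapping, so it accepts a PyTorch state dict of tensors as-is.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> "OrderedDict[str, V]":
    """Return a copy of `state_dict` with `prefix` removed from the start of each key.

    Hypothetical helper: unlike `key.replace(prefix, '')`, occurrences of
    `prefix` in the middle of a key are left untouched.
    """
    stripped: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        # Only drop the prefix when it is at the very beginning of the key.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        # Two original keys mapping to the same name would silently lose weights.
        if new_key in stripped:
            raise KeyError(f"duplicate key after stripping prefix: {new_key!r}")
        stripped[new_key] = value
    return stripped


if __name__ == "__main__":
    demo = {"module.bert.embeddings.weight": 1, "module.head.module.bias": 2, "logit_scale": 3}
    print(list(strip_prefix(demo)))
    # ['bert.embeddings.weight', 'head.module.bias', 'logit_scale']
```

Under this assumption, the comprehension in the loading snippet would become `load_checkpoint = strip_prefix(checkpoint)`, with the result passed to `medcpt.load_state_dict(..., strict=False)` as before. Printing the missing and unexpected keys afterwards remains the quickest way to confirm that the key names line up.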