RP3D-DiagModel

About Checkpoint

The detailed parameters we used for training are as follows:

start_class: 0
end_class: 5569
backbone: 'resnet'
level: 'articles'     # represents the disorder level
depth: 32
ltype: 'MultiLabel'   # represents the Binary Cross Entropy Loss
augment: True         # represents the medical data augmentation
split: 'late'         # represents the late fusion strategy
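With ltype: 'MultiLabel', each disorder is treated as an independent binary target, so a single scan can be positive for several disorders at once. The plain-Python sketch below illustrates the binary cross-entropy such a setup computes on per-class logits. It is an illustration only: the function name multilabel_bce_with_logits is hypothetical, and the project's training code presumably relies on a PyTorch loss (for example torch.nn.BCEWithLogitsLoss) rather than this helper.

```python
import math
from typing import Sequence


def multilabel_bce_with_logits(logits: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross-entropy over classes for one sample (hypothetical helper).

    Each class gets its own sigmoid, so several classes can be positive at the
    same time (multi-label), unlike softmax cross-entropy where exactly one
    class is correct.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Toy example: 4 disorder classes, the scan is positive for classes 0 and 2.
logits = [2.0, -1.5, 0.3, -3.0]
targets = [1, 0, 1, 0]
print(round(multilabel_bce_with_logits(logits, targets), 4))
```

The stable form avoids computing log(sigmoid(x)) directly, which would underflow for large negative logits; this is the same trick used by logit-based BCE losses in common frameworks.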

Load Model

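import torch
# Load the backbone (RadNet must be imported from the project's code)
# num_classes, backbone, depth, ltype, augment, fuse, ke, encoded and adapter must be defined beforehand; see the training parameters above
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype,
               augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin", map_location="cpu")
# strict=False reports mismatched parameter names instead of raising an error
missing, unexpect = model.load_state_dict(pretrained_weights, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpect)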
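Because strict=False does not raise on mismatched parameter names, it can be worth checking how much of the model the checkpoint actually covers. The sketch below compares two collections of key names; summarize_key_match and KeyMatch are hypothetical names introduced for illustration, and the toy keys stand in for model.state_dict().keys() and pretrained_weights.keys().

```python
from typing import Iterable, List, NamedTuple


class KeyMatch(NamedTuple):
    missing: List[str]       # in the model but absent from the checkpoint (keep their init values)
    unexpected: List[str]    # in the checkpoint but not used by the model
    loaded_fraction: float   # share of model keys found in the checkpoint


def summarize_key_match(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> KeyMatch:
    """Hypothetical helper: compare parameter names before or after
    load_state_dict(..., strict=False)."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    fraction = (len(model_set) - len(missing)) / len(model_set) if model_set else 0.0
    return KeyMatch(missing, unexpected, fraction)


# Toy key names standing in for real state-dict keys.
model_keys = ["backbone.conv1.weight", "backbone.fc.weight", "classifier.weight"]
ckpt_keys = ["backbone.conv1.weight", "backbone.fc.weight", "text_encoder.weight"]
report = summarize_key_match(model_keys, ckpt_keys)
print(report.missing, report.unexpected, f"{report.loaded_fraction:.0%}")
```

With the variables above, this would be called as summarize_key_match(model.state_dict().keys(), pretrained_weights.keys()). A loaded fraction far below 100% often points to a naming mismatch (such as a leftover 'module.' prefix) rather than a genuinely different architecture.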

# If ke is set to True, also load the MedCPT text encoder
medcpt = MedCPT_clinical(bert_model_name='ncbi/MedCPT-Query-Encoder')
checkpoint = torch.load('path/to/epoch_state.pt', map_location='cpu')['state_dict']
# Strip the 'module.' prefix typically added when training with (Distributed)DataParallel
load_checkpoint = {key.replace('module.', ''): value for key, value in checkpoint.items()}
missing, unexpect = medcpt.load_state_dict(load_checkpoint, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpect)
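The key.replace('module.', '') call removes the substring anywhere in a parameter name. For this checkpoint that is probably harmless, but a slightly more defensive variant strips the prefix only where it starts a key. The helper name strip_prefix is hypothetical, and the toy dictionary uses plain numbers in place of tensors.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Remove `prefix` only where it starts a key (hypothetical helper).

    Checkpoints saved from a model wrapped in torch.nn.DataParallel or
    DistributedDataParallel store parameter names with a leading "module."
    prefix. Unlike str.replace, this leaves later occurrences of the
    substring inside a name untouched.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy state dict; plain numbers stand in for tensors.
toy_checkpoint = {
    "module.bert.embeddings.word_embeddings.weight": 1,
    "module.head.module.proj.weight": 2,  # hypothetical nested name containing "module."
    "logit_scale": 3,
}
print(strip_prefix(toy_checkpoint))
```

In the snippet above, load_checkpoint = strip_prefix(checkpoint) would replace the dictionary comprehension.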

Why do we provide this checkpoint?

All the early-fusion checkpoints can be further fine-tuned from this checkpoint. If you need checkpoints trained with different parameter settings, there are two possible ways:

Fine-tune from this checkpoint

Set the following fields in the training configuration:

checkpoint: "None"
safetensor: path to this checkpoint (pytorch_model.bin)

Contact Us

Email the author: [email protected]

About Dataset

Please refer to RP3D-DiagDS

For more information, please refer to our instructions on GitHub for downloading and using the dataset.
