QiaoyuZheng
commited on
Update README.md
Browse files
README.md
CHANGED
@@ -23,9 +23,20 @@ split: 'late' # represents the late fusion strategy
|
|
23 |
### Load Model
|
24 |
|
25 |
```
|
|
|
26 |
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype, augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
|
27 |
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin")
|
28 |
missing, unexpect = model.load_state_dict(pretrained_weights,strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
```
|
30 |
|
31 |
## Why we provide this checkpoint?
|
|
|
23 |
### Load Model
|
24 |
|
25 |
```
|
26 |
+
# Load backnone
|
27 |
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype, augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
|
28 |
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin")
|
29 |
missing, unexpect = model.load_state_dict(pretrained_weights,strict=False)
|
30 |
+
print("missing_cpt:", missing)
|
31 |
+
print("unexpect_cpt:", unexpect)
|
32 |
+
|
33 |
+
# If KE is set True, load text encoder
|
34 |
+
medcpt = MedCPT_clinical(bert_model_name = 'ncbi/MedCPT-Query-Encoder')
|
35 |
+
checkpoint = torch.load('path/to/epoch_state.pt',map_location='cpu')['state_dict']
|
36 |
+
load_checkpoint = {key.replace('module.', ''): value for key, value in checkpoint.items()}
|
37 |
+
missing, unexpect = medcpt.load_state_dict(load_checkpoint, strict=False)
|
38 |
+
print("missing_cpt:", missing)
|
39 |
+
print("unexpect_cpt:", unexpect)
|
40 |
```
|
41 |
|
42 |
## Why we provide this checkpoint?
|