How to load CONCH v1.5 weights?
Hi Mahmood Lab team,
Thank you for your excellent work. We've been using your models regularly in our pathology image research.
I have a question about using CONCH v1.5. The weight file structure seems different from the previous CONCH release, and since the model architecture and loading logic also appear to differ slightly, I wanted to confirm the correct usage.
Could you kindly provide a brief guide on how to apply the v1.5 weights? If possible, a minimal example or reference to a sample script would be greatly appreciated.
Also, it looks like the v1.5 weights only contain parameters for the vision encoder. Are the previous CONCH text encoder parameters still compatible with v1.5?
Thank you in advance!
import logging

import timm
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

logging.basicConfig(level=logging.INFO)


def get_eval_transforms_conchv1_5(img_resize: int = 448):
    # Resize the shorter side, center-crop to a square, force RGB,
    # and apply ImageNet normalization.
    transform = transforms.Compose(
        [
            transforms.Resize(
                img_resize, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(img_resize),
            transforms.Lambda(
                lambda img: img.convert("RGB") if img.mode != "RGB" else img
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
    return transform


def get_encoder_conchv1_5():
    ckpt_path = "/data/ckpt/conchv1.5/pytorch_model_vision.bin"
    # ViT-L/16 backbone at 448x448 resolution, with layer scale and no classifier head.
    vision_tower = timm.create_model(
        "vit_large_patch16_224",
        img_size=448,
        patch_size=16,
        init_values=1.0,
        num_classes=0,
        dynamic_img_size=True,
    )
    # original_forward = vision_tower.forward
    # vision_tower.forward = lambda x: vision_tower.forward_features(x)
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing_keys, unexpected_keys = vision_tower.load_state_dict(
        state_dict, strict=False
    )
    print("missing_keys, unexpected_keys: ", missing_keys, unexpected_keys)
    print("ConchV1.5 parameters: ", sum(p.numel() for p in vision_tower.parameters()))
    vision_tower.eval()
    return vision_tower


if __name__ == "__main__":
    model = get_encoder_conchv1_5()
    transform = get_eval_transforms_conchv1_5()
    print(model)
    print(transform)
    # Dummy image tensor (C, H, W); converted to PIL before the eval transforms.
    img = torch.rand(3, 256, 512)
    print(img.shape)
    img = transform(to_pil_image(img))
    print(img.shape)
    with torch.no_grad():
        output = model.forward_features(img.unsqueeze(0))
    print(output.shape)
You can give this a try.
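In case it helps, here is a small follow-up sketch for turning the token output of forward_features into a single patch-level embedding. It assumes timm's standard ViT layout where index 0 is the class token; this is just one common pooling choice, not necessarily the official CONCH v1.5 feature extraction.

# Hedged sketch: pool the per-token output of forward_features into one vector.
# Assumes the usual timm ViT layout (class token at index 0); this may differ
# from the pooling used in the official CONCH v1.5 pipeline.
with torch.no_grad():
    tokens = model.forward_features(img.unsqueeze(0))  # (1, num_tokens, 1024)
cls_embedding = tokens[:, 0]             # (1, 1024) class-token embedding
mean_embedding = tokens[:, 1:].mean(1)   # (1, 1024) mean of the patch tokens
print(cls_embedding.shape, mean_embedding.shape)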
We recommend using the Trident library for model loading, see https://github.com/mahmoodlab/TRIDENT/blob/main/trident/patch_encoder_models/load.py#L26.
from trident.patch_encoder_models import encoder_factory
model = encoder_factory(model_name='conch_v15')
This logic applies to all our models (UNI, UNI2, CONCH, TITAN, and other models from the community).
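For completeness, a minimal usage sketch with Trident follows. It assumes encoder_factory returns a standard torch.nn.Module; the 448x448 input size below is a placeholder, so check load.py for the exact preprocessing expected by conch_v15.

import torch
from trident.patch_encoder_models import encoder_factory

# Build the CONCH v1.5 patch encoder; weights are resolved inside the factory.
model = encoder_factory(model_name='conch_v15')
model.eval()

# Dummy preprocessed patch (assumed 448x448 RGB); in practice, apply the
# transforms that Trident defines for this encoder.
dummy = torch.rand(1, 3, 448, 448)
with torch.no_grad():
    features = model(dummy)
print(features.shape)  # expected: a (1, feature_dim) patch embedding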
Hope this helps,
Best, Guillaume
Wow, thank you very much! This code repository is incredibly helpful for my research!!
Thank you! I was able to successfully load the model thanks to your help.
I have one more question: would it be alright to ask whether the text encoder is compatible with the original CONCH?
I think it might be incompatible, given the different models used during training.