How to load CONCH v1.5 weights?

#2
by credit1 - opened

Hi Mahmood Lab team,

Thank you for your excellent work. We've been using your models regularly for our pathology image research.

I have a question about using CONCH v1.5. The weight file structure seems different from the previous version of CONCH, and since the model structure and loading logic also appear slightly different, I wanted to confirm the correct usage.
Could you kindly provide a brief guide on how to load the v1.5 weights? If possible, a minimal example or a reference to a sample script would be greatly appreciated.

Also, it looks like the v1.5 weights contain parameters only for the vision encoder. Are the previous CONCH text encoder parameters still compatible with v1.5?

Thank you in advance!

import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import timm

def get_eval_transforms_conchv1_5(img_resize: int = 448):
    # Standard evaluation preprocessing: resize the short side, center-crop
    # to the model input size, force RGB, and normalize with ImageNet stats.
    transform = transforms.Compose(
        [
            transforms.Resize(
                img_resize, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(img_resize),
            transforms.Lambda(
                lambda img: img.convert("RGB") if img.mode != "RGB" else img
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
    return transform


def get_encoder_conchv1_5():
    # CONCH v1.5 vision encoder: ViT-L/16 backbone run at 448x448 with
    # layer scale enabled (init_values); num_classes=0 strips the head.
    ckpt_path = "/data/ckpt/conchv1.5/pytorch_model_vision.bin"
    vision_tower = timm.create_model(
        "vit_large_patch16_224",
        img_size=448,
        patch_size=16,
        init_values=1.0,
        num_classes=0,
        dynamic_img_size=True,
    )
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing_keys, unexpected_keys = vision_tower.load_state_dict(
        state_dict, strict=False
    )
    print("missing keys:", missing_keys)
    print("unexpected keys:", unexpected_keys)
    print("CONCH v1.5 parameters:", sum(p.numel() for p in vision_tower.parameters()))
    vision_tower.eval()
    return vision_tower


if __name__ == "__main__":
    model = get_encoder_conchv1_5()
    transform = get_eval_transforms_conchv1_5()
    print(model)
    print(transform)
    # Sanity check: preprocess a random image and run it through the encoder.
    x = torch.rand(3, 256, 512)
    print(x.shape)
    x = transform(to_pil_image(x))
    print(x.shape)
    with torch.inference_mode():
        output = model.forward_features(x.unsqueeze(0))
    print(output.shape)
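
For reference (my own inference from the config, not from the official docs): with a 448x448 input and 16x16 patches, forward_features should return a token tensor of shape (1, 785, 1024) for this ViT-L/16 setup, i.e. 784 patch tokens plus a class token.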

You can give this a try.

AI for Pathology Image Analysis Lab @ HMS / BWH

We recommend using the Trident library for model loading, see https://github.com/mahmoodlab/TRIDENT/blob/main/trident/patch_encoder_models/load.py#L26.

from trident.patch_encoder_models import encoder_factory

model = encoder_factory(model_name='conch_v15')

This logic applies to all of our models (UNI, UNI2, CONCH, TITAN, and other models from the community).
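
For a quick sanity check, here is a minimal sketch of running a dummy patch through the factory-loaded encoder. It assumes the returned object is a plain torch.nn.Module and that CONCH v1.5 expects 448x448 RGB input; consult the Trident repo for the matching eval transforms.

import torch
from trident.patch_encoder_models import encoder_factory

model = encoder_factory(model_name='conch_v15')
model.eval()

# Dummy batch of one 448x448 RGB patch; in practice, apply the
# preprocessing transforms shipped with the encoder instead.
x = torch.rand(1, 3, 448, 448)
with torch.inference_mode():
    feats = model(x)
print(feats.shape)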

Hope this helps,
Best, Guillaume

Wow, thank you very much! This code repository is incredibly helpful for my research!!

Thank you! I was able to successfully load the model thanks to your help.

I have one more question: is the text model compatible with the original CONCH?

I think it is likely incompatible, since the two versions were trained as different models.
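
If you want to verify this yourself, a rough way is to compare parameter names across the two checkpoints. The paths below are placeholders, and the "text" prefix is an assumption about how the v1 checkpoint names its text-encoder parameters, so adjust both to your files.

import torch

# Hypothetical local paths; substitute your own checkpoint locations.
v1_ckpt = torch.load("/data/ckpt/conch/pytorch_model.bin", map_location="cpu")
v15_ckpt = torch.load("/data/ckpt/conchv1.5/pytorch_model_vision.bin", map_location="cpu")

# Assumes the v1 text-encoder parameters are prefixed with "text".
text_keys = {k for k in v1_ckpt if k.startswith("text")}
overlap = text_keys & set(v15_ckpt)
print(f"text keys in v1: {len(text_keys)}, also present in v1.5: {len(overlap)}")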
