import torch
from transformers import AutoTokenizer, AutoModel

# --- Single-sentence CLS embedding -------------------------------------------
# Embeds one example sentence with the pretrained encoder and reports the
# shape of the vector taken from the first token position.

sentence = "The patient has a right pneumothorax."

# Hugging Face Hub identifier shared by the tokenizer and the model.
model_name = "IAMJB/RadEvalModernBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Prefer the GPU when one is visible; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout)

# Encode the sentence as a batch of one and move every tensor to `device`.
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    truncation=True,
    padding=True,
).to(device)

# Forward pass without autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    # NOTE(review): hidden_states[-1] is not guaranteed to equal
    # outputs.last_hidden_state (some architectures apply a final norm after
    # the last layer); kept as-is to match the original embeddings.
    last_hidden_state = outputs.hidden_states[-1]
    # Position 0 of every sequence is used as the CLS token.
    cls_embedding = last_hidden_state[:, 0, :]

print(f"Sentence: {sentence}")
print(f"Embedding shape: {cls_embedding.shape}")

Similarity heatmap example

import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

def get_cls_embeddings(model, tokenizer, texts, device):
    """Embed each text on its own and return the vector at its first token.

    Args:
        model: Hugging Face encoder, called with ``output_hidden_states=True``.
        tokenizer: Tokenizer matching ``model``; called on one text at a time.
        texts: Iterable of strings to embed, processed in order.
        device: Device the tokenized tensors are moved to before the forward pass.

    Returns:
        ``np.ndarray`` stacking one embedding per text, in input order.
    """

    def _embed_one(text):
        # Tokenize a single text (a batch of one) and move it to the device.
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        batch = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            # NOTE(review): the last entry of hidden_states is used rather than
            # outputs.last_hidden_state; the two can differ for some models.
            final_layer = model(**batch, output_hidden_states=True).hidden_states[-1]
            # First token of the only sequence in the batch, as a 1-D array.
            return final_layer[:, 0, :].cpu().numpy()[0]

    return np.array([_embed_one(text) for text in texts])

def compute_similarities(embeddings):
    """Return the pairwise cosine-similarity matrix of the rows of *embeddings*.

    Args:
        embeddings: 2-D array-like of shape ``(n, d)``, one vector per row.

    Returns:
        ``(n, n)`` array whose entry ``[i, j]`` is the cosine similarity of
        rows ``i`` and ``j``. A row with zero norm has no direction; it is kept
        as a zero vector, so its similarities are 0 instead of NaN.
    """
    embeddings = np.asarray(embeddings)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Dividing a zero row by 1 keeps it zero. Otherwise 0/0 yields NaN, and the
    # matrix product spreads that NaN across the row's whole row and column.
    norms[norms == 0] = 1.0
    normalized_embeddings = embeddings / norms
    # Dot products of unit vectors are cosine similarities.
    return normalized_embeddings @ normalized_embeddings.T

def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Draw an annotated heatmap of a similarity matrix, save it, then show it.

    Args:
        similarity_matrix: Square ``(n, n)`` array of similarity scores.
        labels: ``n`` tick labels, used for both axes in matrix order.
        output_path: File the figure is written to (300 dpi).
    """
    plt.figure(figsize=(10, 8))

    # Lower end of the colour scale: the smallest similarity, clamped to
    # [0, 1] so it never exceeds the fixed upper end (vmax=1.0). Float rounding
    # can push self-similarities slightly above 1 (e.g. with a single label the
    # minimum is the diagonal), and vmin > vmax makes matplotlib raise.
    min_val = float(np.clip(np.min(similarity_matrix), 0.0, 1.0))
    if min_val >= 1.0:
        # Degenerate range (everything at 1): fall back to the full [0, 1] scale.
        min_val = 0.0

    sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt=".3f",
        cmap="viridis",  # Better colormap for distinguishing high values
        vmin=min_val,
        vmax=1.0,
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={"label": "Similarity"},
    )

    plt.title("CLS Token Embedding Similarities")
    # Rotate the x labels *before* tight_layout so the layout accounts for
    # their rotated extent; otherwise the displayed figure can clip them.
    plt.xticks(rotation=90)
    plt.tight_layout()

    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")

    plt.show()

def main():
    """Embed a fixed set of radiology phrases and plot their CLS similarities.

    Loads the tokenizer and encoder from the Hugging Face Hub, embeds each
    phrase with ``get_cls_embeddings``, builds the cosine-similarity matrix
    with ``compute_similarities``, and saves/shows it via ``plot_heatmap``.
    """
    # Hub identifier shared by the tokenizer and the model. It must be a
    # string: written bare, IAMJB/RadEvalModernBERT is evaluated as a division
    # of two undefined names and raises NameError.
    model_name = "IAMJB/RadEvalModernBERT"

    # Medical terms to compare
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    # Prefer the GPU when one is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the tokenizer and model, then switch the model to evaluation mode.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, "cls_embedding_similarities.png")

    print("Done!")


if __name__ == "__main__":
    main()

image/png

Downloads last month
3,153
Safetensors
Model size
149M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Space using IAMJB/RadEvalModernBERT 1