import torch
from transformers import AutoTokenizer, AutoModel
# Minimal example: embed a single sentence and print the embedding's shape.
sentence = "The patient has a right pneumothorax."
model_name = "IAMJB/RadEvalModernBERT"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

inputs = tokenizer(
    sentence, return_tensors="pt", truncation=True, padding=True
).to(device)

# Gradients are not needed for inference.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    # Last entry of the hidden-state tuple returned by the model.
    last_hidden_state = outputs.hidden_states[-1]
    # Vector at sequence position 0 (used here as the CLS embedding).
    cls_embedding = last_hidden_state[:, 0, :]

print("Sentence:", sentence)
print("Embedding shape:", cls_embedding.shape)
# Similarity heatmap example
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
def get_cls_embeddings(model, tokenizer, texts, device):
    """Get CLS token embeddings for a list of texts.

    Each text is tokenized and passed through the model on its own. The
    vector at sequence position 0 of the last entry in
    ``outputs.hidden_states`` is used as that text's embedding.

    Args:
        model: Callable invoked as ``model(**inputs, output_hidden_states=True)``
            whose result exposes a ``hidden_states`` sequence of tensors.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", padding=True, truncation=True)``
            that returns a mapping of tensors.
        texts: Iterable of strings to embed.
        device: Device the tokenized tensors are moved to before the
            forward pass.

    Returns:
        A float32 ``np.ndarray`` with one row per text. When ``texts`` is
        empty the result has shape ``(0, 0)`` so it is still 2-D.
    """
    embeddings = []
    # A single no-grad context covers every forward pass in the loop.
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1]
            cls_embedding = last_hidden_state[:, 0, :]
            # NumPy has no bfloat16 dtype, so cast to float32 before converting;
            # this is a no-op for models that already run in float32.
            embeddings.append(cls_embedding[0].float().cpu().numpy())
    if not embeddings:
        # Keep the result 2-D so row-wise operations downstream still work.
        return np.empty((0, 0), dtype=np.float32)
    return np.array(embeddings)
def compute_similarities(embeddings):
    """Compute pairwise cosine similarity between embeddings.

    Args:
        embeddings: Array-like of shape ``(n, d)`` with one embedding per row.

    Returns:
        ``np.ndarray`` of shape ``(n, n)`` where entry ``[i, j]`` is the
        cosine similarity between rows ``i`` and ``j``. A row whose norm is
        zero has no direction; its similarities (including with itself) are
        reported as 0 rather than NaN.
    """
    embeddings = np.asarray(embeddings)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Dividing by a zero norm would fill the row with NaN; dividing by 1
    # leaves the all-zero row unchanged instead.
    norms[norms == 0] = 1.0
    normalized_embeddings = embeddings / norms
    similarity_matrix = normalized_embeddings @ normalized_embeddings.T
    return similarity_matrix
def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Generate a heatmap visualization of the similarity matrix.

    The annotated heatmap is saved to ``output_path``, shown with
    ``plt.show()``, and the figure is then closed.

    Args:
        similarity_matrix: 2-D array of similarity values to plot.
        labels: Tick labels used for both the x and y axes.
        output_path: File path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # The colour scale starts at the smallest similarity but never below 0,
    # so any negative values share the lowest colour.
    min_val = max(0.0, np.min(similarity_matrix))
    sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt=".3f",
        cmap="viridis",
        vmin=min_val,
        vmax=1.0,
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={"label": "Similarity"},
        ax=ax,
    )
    ax.set_title("CLS Token Embedding Similarities")

    # Rotate the labels first so tight_layout accounts for their final size.
    ax.tick_params(axis="x", labelrotation=90)
    fig.tight_layout()

    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def main():
    """Embed a fixed list of radiology phrases and plot their similarities.

    Loads the ``IAMJB/RadEvalModernBERT`` tokenizer and model, computes one
    CLS embedding per phrase, builds the pairwise cosine-similarity matrix
    and saves it as a heatmap to ``cls_embedding_similarities.png``.
    """
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model identifier shared by the tokenizer and the model; it must be
    # passed to from_pretrained as a string.
    model_name = "IAMJB/RadEvalModernBERT"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, "cls_embedding_similarities.png")
    print("Done!")


if __name__ == "__main__":
    main()
