IAMJB committed on
Commit
452c7ff
·
verified ·
1 Parent(s): b5dbe8d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +124 -1
README.md CHANGED
@@ -31,4 +31,127 @@ with torch.no_grad():
31
 
32
  print("Sentence:", sentence)
33
  print("Embedding shape:", cls_embedding.shape)
34
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  print("Sentence:", sentence)
33
  print("Embedding shape:", cls_embedding.shape)
34
+ ```
35
+
36
+
37
+
38
+ ### Similarity heatmap example
39
+
40
+
41
+ ```python
42
+ import argparse
43
+ import numpy as np
44
+ import matplotlib.pyplot as plt
45
+ import torch
46
+ import seaborn as sns
47
+ from transformers import AutoTokenizer, AutoModel
48
+
49
def get_cls_embeddings(model, tokenizer, texts, device):
    """Get CLS token embeddings for a list of texts.

    Each text is tokenized and run through the model on its own, so no
    padding tokens are ever present in a forward pass.

    Args:
        model: Callable model that accepts the tokenizer's output as keyword
            arguments plus ``output_hidden_states=True`` and returns an object
            exposing ``hidden_states`` (a sequence of per-layer tensors).
        tokenizer: Callable that turns one string into a mapping of tensors
            when called with ``return_tensors="pt"``.
        texts: Iterable of strings. A single bare string is treated as a
            one-element list instead of being iterated character by character.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A 2-D NumPy array with one row per text holding the first-token (CLS)
        vector of the last entry in ``hidden_states``. For empty input the
        result is an empty array of shape ``(0, 0)`` so it stays 2-D.
    """
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        # Tokenize the text and move every tensor to the target device.
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Last entry of hidden_states, first token (CLS), first batch item.
        cls_embedding = outputs.hidden_states[-1][:, 0, :]
        embeddings.append(cls_embedding.cpu().numpy()[0])

    if not embeddings:
        # Keep the result 2-D so downstream axis=1 operations still work.
        return np.empty((0, 0), dtype=np.float32)
    return np.array(embeddings)
68
+
69
def compute_similarities(embeddings):
    """Compute the pairwise cosine-similarity matrix of row embeddings.

    Args:
        embeddings: 2-D array-like of shape ``(n, d)``, one embedding per row.

    Returns:
        An ``(n, n)`` NumPy array whose entry ``[i, j]`` is the cosine
        similarity between rows ``i`` and ``j``. A row that is entirely zero
        has no direction; its similarities are reported as 0.0 rather than
        NaN.
    """
    embeddings = np.asarray(embeddings)

    # Per-row L2 norms, kept as a column so the division broadcasts row-wise.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Avoid 0/0 for all-zero rows: dividing zeros by 1 leaves them at zero.
    norms[norms == 0] = 1.0
    normalized_embeddings = embeddings / norms

    # Dot products of unit vectors are cosine similarities.
    return np.matmul(normalized_embeddings, normalized_embeddings.T)
76
+
77
def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Generate a heatmap visualization of the similarity matrix.

    The figure is written to ``output_path`` and then displayed.

    Args:
        similarity_matrix: 2-D array of pairwise similarities to draw.
        labels: Tick labels used for both the x and y axes.
        output_path: File path the rendered figure is saved to.
    """
    plt.figure(figsize=(10, 8))

    # Start the color scale at the smallest similarity, but never below 0,
    # so the colormap range is spent on the values actually present.
    min_val = max(0.0, np.min(similarity_matrix))

    sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt=".3f",
        cmap="viridis",  # Better colormap for distinguishing high values
        vmin=min_val,
        vmax=1.0,
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={"label": "Similarity"},
    )

    plt.title("CLS Token Embedding Similarities")

    # Rotate x-axis labels for readability *before* computing the layout,
    # so tight_layout accounts for the rotated label extents.
    plt.xticks(rotation=90)
    plt.tight_layout()

    # Save the figure
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")

    # Show the plot
    plt.show()
110
+
111
def main():
    """Embed a set of radiology phrases and plot their CLS similarities.

    Loads the tokenizer and model from the Hugging Face Hub, computes one CLS
    embedding per phrase, builds the cosine-similarity matrix, and saves and
    shows it as a heatmap.
    """
    # Hub repository ID. It must be a string literal: the bare expression
    # IAMJB/RadEvalModernBERT would be evaluated as a division of two
    # undefined names and raise NameError.
    model_name = "IAMJB/RadEvalModernBERT"

    # Medical terms to compare
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the model and switch it to inference mode
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # Get CLS token embeddings for the medical terms
    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    # Compute similarities
    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    # Plot and save the heatmap
    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, "cls_embedding_similarities.png")

    print("Done!")


if __name__ == "__main__":
    main()
155
+ ```
156
+
157
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62716952bcef985363db8485/6mzZ5_Xz2ovl3a6TlAzxo.png)