Update README.md
Browse files
README.md
CHANGED
@@ -72,8 +72,8 @@ caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_len
|
|
72 |
with torch.no_grad():
|
73 |
image_feature = model.get_image_features(image_input)
|
74 |
text_feature = model.get_text_features(caption_input,walk_short_pos=walk_short_pos)
|
75 |
-
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
76 |
-
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
77 |
|
78 |
logits_per_image = image_feature @ text_feature.T
|
79 |
probs = logits_per_image.softmax(dim=1)
|
@@ -102,8 +102,9 @@ with torch.no_grad():
|
|
102 |
captions = ["white cat"]
|
103 |
caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids, dtype=torch.long, device=device)
|
104 |
text_feature = model.get_text_features(caption_input,walk_short_pos=True)
|
105 |
-
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
106 |
-
dense_image_feature = dense_image_feature / dense_image_feature.norm(dim=-1, keepdim=True)
|
|
|
107 |
|
108 |
|
109 |
similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T
|
|
|
72 |
with torch.no_grad():
|
73 |
image_feature = model.get_image_features(image_input)
|
74 |
text_feature = model.get_text_features(caption_input,walk_short_pos=walk_short_pos)
|
75 |
+
image_feature = image_feature / image_feature.norm(p=2, dim=-1, keepdim=True)
|
76 |
+
text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
|
77 |
|
78 |
logits_per_image = image_feature @ text_feature.T
|
79 |
probs = logits_per_image.softmax(dim=1)
|
|
|
102 |
captions = ["white cat"]
|
103 |
caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids, dtype=torch.long, device=device)
|
104 |
text_feature = model.get_text_features(caption_input,walk_short_pos=True)
|
105 |
+
text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
|
106 |
+
dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)
|
107 |
+
|
108 |
|
109 |
|
110 |
similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T
|