qingshan777 committed on
Commit
c3dfc69
·
verified ·
1 Parent(s): 73a4b50

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -0
README.md CHANGED
@@ -72,6 +72,8 @@ caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_len
72
  with torch.no_grad():
73
  image_feature = model.get_image_features(image_input)
74
  text_feature = model.get_text_features(caption_input,walk_short_pos=walk_short_pos)
 
 
75
 
76
  logits_per_image = image_feature @ text_feature.T
77
  probs = logits_per_image.softmax(dim=1)
@@ -100,6 +102,9 @@ with torch.no_grad():
100
  captions = ["white cat"]
101
  caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids, dtype=torch.long, device=device)
102
  text_feature = model.get_text_features(caption_input,walk_short_pos=True)
 
 
 
103
 
104
  similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T
105
  similarity = similarity.cpu().numpy()
 
72
  with torch.no_grad():
73
  image_feature = model.get_image_features(image_input)
74
  text_feature = model.get_text_features(caption_input,walk_short_pos=walk_short_pos)
75
+ image_feature = image_feature / image_feature.norm(dim=1, keepdim=True)
76
+ text_feature = text_feature / text_feature.norm(dim=1, keepdim=True)
77
 
78
  logits_per_image = image_feature @ text_feature.T
79
  probs = logits_per_image.softmax(dim=1)
 
102
  captions = ["white cat"]
103
  caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids, dtype=torch.long, device=device)
104
  text_feature = model.get_text_features(caption_input,walk_short_pos=True)
105
+ text_feature = text_feature / text_feature.norm(dim=1, keepdim=True)
106
+ dense_image_feature = dense_image_feature / dense_image_feature.norm(dim=1, keepdim=True)
107
+
108
 
109
  similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T
110
  similarity = similarity.cpu().numpy()