qingshan777 commited on
Commit
39a06ee
·
verified ·
1 Parent(s): c3dfc69

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -4
README.md CHANGED
@@ -72,8 +72,8 @@ caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_len
72
  with torch.no_grad():
73
  image_feature = model.get_image_features(image_input)
74
  text_feature = model.get_text_features(caption_input,walk_short_pos=walk_short_pos)
75
- image_feature = image_feature / image_feature.norm(dim=1, keepdim=True)
76
- text_feature = text_feature / text_feature.norm(dim=1, keepdim=True)
77
 
78
  logits_per_image = image_feature @ text_feature.T
79
  probs = logits_per_image.softmax(dim=1)
@@ -102,8 +102,9 @@ with torch.no_grad():
102
  captions = ["white cat"]
103
  caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids, dtype=torch.long, device=device)
104
  text_feature = model.get_text_features(caption_input,walk_short_pos=True)
105
- text_feature = text_feature / text_feature.norm(dim=1, keepdim=True)
106
- dense_image_feature = dense_image_feature / dense_image_feature.norm(dim=1, keepdim=True)
 
107
 
108
 
109
  similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T
 
72
  with torch.no_grad():
73
  image_feature = model.get_image_features(image_input)
74
  text_feature = model.get_text_features(caption_input,walk_short_pos=walk_short_pos)
75
+ image_feature = image_feature / image_feature.norm(p=2, dim=-1, keepdim=True)
76
+ text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
77
 
78
  logits_per_image = image_feature @ text_feature.T
79
  probs = logits_per_image.softmax(dim=1)
 
102
  captions = ["white cat"]
103
  caption_input = torch.tensor(tokenizer(captions, max_length=77, padding="max_length", truncation=True).input_ids, dtype=torch.long, device=device)
104
  text_feature = model.get_text_features(caption_input,walk_short_pos=True)
105
+ text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
106
+ dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)
107
+
108
 
109
 
110
  similarity = dense_image_feature.squeeze() @ text_feature.squeeze().T