File Rename
- lrn_vector_embeddings.py +109 -0
- s1_lrn_gradio.py +37 -0
- s2_download_data.py +67 -0
- s3_data_to_vector_embedding.py +62 -0
- s4_calculate_distance.py +83 -0
- s5-how-to-umap.py +126 -0
lrn_vector_embeddings.py
ADDED
@@ -0,0 +1,109 @@
import json
import os
import numpy as np
from numpy.linalg import norm
import cv2
from io import StringIO, BytesIO
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
from tqdm import tqdm
import base64
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM
import requests
from PIL import Image
import torch

url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg'
}


def bt_embeddings_from_local(text, image):

    model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

    processed_inputs = processor(image, text, padding=True, return_tensors="pt")

    #inputs = processor(prompt, base64_image, padding=True, return_tensors="pt")
    outputs = model(**processed_inputs)

    cross_modal_embeddings = outputs.cross_embeds
    text_embeddings = outputs.text_embeds
    image_embeddings = outputs.image_embeds
    return {
        'cross_modal_embeddings': cross_modal_embeddings,
        'text_embeddings': text_embeddings,
        'image_embeddings': image_embeddings
    }


def bt_scores_with_image_and_text_retrieval():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")

    # forward pass
    scores = dict()
    for text in texts:
        # prepare inputs
        encoding = processor(image, text, return_tensors="pt")
        outputs = model(**encoding)
        scores[text] = outputs.logits[0, 1].item()
    return scores


def bt_with_masked_input():
    url = "http://images.cocodataset.org/val2017/000000360943.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    text = "a <mask> looking out of the window"

    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")

    # prepare inputs
    encoding = processor(image, text, return_tensors="pt")

    # forward pass
    outputs = model(**encoding)

    token_ids = outputs.logits.argmax(dim=-1).squeeze(0).tolist()
    if isinstance(token_ids, list):
        results = processor.tokenizer.decode(token_ids)
    else:
        results = processor.tokenizer.decode([token_ids])

    print(results)
    return results

#res = bt_embeddingsl()
#print((res['text_embeddings']))
for img in [img1, img2, img3]:
    embeddings = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
    print(embeddings['cross_modal_embeddings'][0].shape)
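Note: the helpers bt_scores_with_image_and_text_retrieval() and bt_with_masked_input() above are defined but never exercised. A minimal usage sketch (assuming network access to the COCO image URLs and enough memory to load the BridgeTower checkpoints) could look like this:

# Sketch only: call the two otherwise-unused helpers and print their results.
scores = bt_scores_with_image_and_text_retrieval()
for text, score in scores.items():
    print(f"{text!r}: {score:.3f}")   # a higher logit means a better image-text match

filled_in = bt_with_masked_input()    # decodes the model's guess for the <mask> token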
s1_lrn_gradio.py
ADDED
@@ -0,0 +1,37 @@
import gradio as gr

def greet(name, intensity):  # Number of inputs should match the number of input components
    return "Hello, " + name + "!" * int(intensity)


basicDemo = gr.Interface(
    fn=greet,
    inputs=["text", "slider"],
    outputs=["text"],
)


with gr.Blocks() as blockDemo:
    gr.Markdown("Enter your Name and Intensity.")
    with gr.Row():
        inp1 = gr.Textbox(placeholder="What is your name?")
        inp2 = gr.Slider(minimum=1, maximum=100)
    out = gr.Textbox()
    btn = gr.Button("Run")
    btn.click(fn=greet, inputs=[inp1, inp2], outputs=out)


def random_response(messages, history):
    return "I am a bot. I don't understand human language. I can only say Hello. 🤖"

with gr.Blocks() as chatInterfaceDemo:
    with gr.Row():
        with gr.Column(scale=4):
            gr.Video(height=512, width=512, elem_id="video", interactive=False)
        with gr.Column(scale=7):
            gr.ChatInterface(
                fn=random_response,
                type="messages")

chatInterfaceDemo.launch(share=False)  # Share your demo with just 1 extra parameter 🚀
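Note: basicDemo and blockDemo above are built but never launched; only chatInterfaceDemo is served. One possible way to expose all three from a single script (a sketch that would replace the single chatInterfaceDemo.launch() call, assuming a Gradio version that provides gr.TabbedInterface) is:

# Sketch only: serve the three demos defined above behind one tabbed app.
demo = gr.TabbedInterface(
    [basicDemo, blockDemo, chatInterfaceDemo],
    tab_names=["Greeter", "Blocks greeter", "Chat"],
)
demo.launch(share=False)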
s2_download_data.py
ADDED
@@ -0,0 +1,67 @@
import os
import requests
from PIL import Image
from IPython.display import display
import huggingface_hub
from huggingface_hub import list_datasets
from huggingface_hub import HfApi

# You can use your own uploaded images and captions.
# You will be responsible for the legal use of images that
# you are going to use.

url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg'
}

def download_images():
    # download images
    os.makedirs('./shared_data', exist_ok=True)  # make sure the target folder exists before writing
    imgs = [img1, img2, img3]
    for img in imgs:
        data = requests.get(img['flickr_url']).content
        with open(img['image_path'], 'wb') as f:
            f.write(data)

    # preview each downloaded image together with its caption
    for img in [img1, img2, img3]:
        image = Image.open(img['image_path'])
        caption = img['caption']
        display(image)
        print(caption)

def load_data_from_huggingface(hf_dataset_name):

    api = HfApi()

    #list models from huggingface
    #models = list(api.list_models())

    #list datasets from huggingface
    #datasets = list(api.list_datasets())

    return api.list_datasets(search=hf_dataset_name)
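Note: load_data_from_huggingface() returns huggingface_hub DatasetInfo metadata records, not the dataset rows themselves. A small sketch of how its return value might be inspected (the search string is only an example):

# Sketch only: iterate over the dataset metadata returned by the Hub search.
for ds_info in load_data_from_huggingface("cat-image-dataset"):
    print(ds_info.id)
# Loading the actual rows would instead go through datasets.load_dataset(<dataset id>).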
s3_data_to_vector_embedding.py
ADDED
@@ -0,0 +1,62 @@
from numpy.linalg import norm
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
import torch
from PIL import Image


url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg',
    'tensor_path': './shared_data/motorcycle_1'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg',
    'tensor_path': './shared_data/motorcycle_2'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg',
    'tensor_path': './shared_data/cat_1'
}

def bt_embeddings_from_local(text, image):

    model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

    processed_inputs = processor(image, text, padding=True, return_tensors="pt")

    outputs = model(**processed_inputs)

    cross_modal_embeddings = outputs.cross_embeds
    text_embeddings = outputs.text_embeds
    image_embeddings = outputs.image_embeds
    return {
        'cross_modal_embeddings': cross_modal_embeddings,
        'text_embeddings': text_embeddings,
        'image_embeddings': image_embeddings
    }

def save_embeddings():
    for img in [img1, img2, img3]:
        embedding = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
        print(embedding['cross_modal_embeddings'][0].shape)  # <class 'torch.Tensor'>
        torch.save(embedding['cross_modal_embeddings'][0], img['tensor_path'] + '.pt')

# Guard the module-level call so that importing bt_embeddings_from_local from other
# scripts (e.g. s5-how-to-umap.py) does not recompute and overwrite the saved tensors.
if __name__ == '__main__':
    save_embeddings()
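A quick sketch of reading one of the tensors written by save_embeddings() back from disk; the printed shape depends on the hidden size of the BridgeTower-large checkpoint, so treat the expected value as an assumption:

# Sketch only: reload a saved cross-modal embedding and inspect it.
emb = torch.load(img1['tensor_path'] + '.pt')
print(type(emb), emb.shape)   # a 1-D torch.Tensor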
s4_calculate_distance.py
ADDED
@@ -0,0 +1,83 @@
import numpy as np
from numpy.linalg import norm
import torch
from IPython.display import display
import cv2

url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg',
    'tensor_path': './shared_data/motorcycle_1'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg',
    'tensor_path': './shared_data/motorcycle_2'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg',
    'tensor_path': './shared_data/cat_1'
}

def load_tensor(path):
    return torch.load(path)

def load_embeddings():
    ex1_embed = load_tensor(img1['tensor_path'] + '.pt')
    ex2_embed = load_tensor(img2['tensor_path'] + '.pt')
    ex3_embed = load_tensor(img3['tensor_path'] + '.pt')
    return ex1_embed.data.numpy(), ex2_embed.data.numpy(), ex3_embed.data.numpy()

def cosine_similarity(vec1, vec2):
    similarity = np.dot(vec1, vec2) / (norm(vec1) * norm(vec2))
    return similarity

def calculate_cosine_distance():
    ex1_embed, ex2_embed, ex3_embed = load_embeddings()
    similarity1 = cosine_similarity(ex1_embed, ex2_embed)
    similarity2 = cosine_similarity(ex1_embed, ex3_embed)
    similarity3 = cosine_similarity(ex2_embed, ex3_embed)
    return [similarity1, similarity2, similarity3]

def calculate_euclidean_distance():
    ex1_embed, ex2_embed, ex3_embed = load_embeddings()
    distance1 = cv2.norm(ex1_embed, ex2_embed, cv2.NORM_L2)
    distance2 = cv2.norm(ex1_embed, ex3_embed, cv2.NORM_L2)
    distance3 = cv2.norm(ex2_embed, ex3_embed, cv2.NORM_L2)
    return [distance1, distance2, distance3]

def show_cosine_distance():
    distances = calculate_cosine_distance()
    print("Cosine similarity between ex1_embed and ex2_embed is:")
    display(distances[0])
    print("Cosine similarity between ex1_embed and ex3_embed is:")
    display(distances[1])
    print("Cosine similarity between ex2_embed and ex3_embed is:")
    display(distances[2])

def show_euclidean_distance():
    distances = calculate_euclidean_distance()
    print("Euclidean distance between ex1_embed and ex2_embed is:")
    display(distances[0])
    print("Euclidean distance between ex1_embed and ex3_embed is:")
    display(distances[1])
    print("Euclidean distance between ex2_embed and ex3_embed is:")
    display(distances[2])

show_cosine_distance()
show_euclidean_distance()
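A tiny sanity check of the two metrics on toy vectors, independent of the saved embeddings (the values are illustrative only):

# Sketch only: cosine similarity vs. Euclidean distance on two hand-made vectors.
a = np.array([1.0, 0.0])
b = np.array([1.0, 1.0])
print(cosine_similarity(a, b))       # ~0.707, i.e. cos(45 degrees)
print(cv2.norm(a, b, cv2.NORM_L2))   # 1.0, the length of a - b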
s5-how-to-umap.py
ADDED
@@ -0,0 +1,126 @@
from IPython.display import display
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
# from s2_download_data import load_data_from_huggingface   # lists dataset metadata only, not the rows
# from utils import prepare_dataset_for_umap_visualization as data_prep   # superseded by the local data_prep below
from s3_data_to_vector_embedding import bt_embeddings_from_local
import random

# prompt templates
templates = [
    'a picture of {}',
    'an image of {}',
    'a nice {}',
    'a beautiful {}',
]

# function helps to prepare list image-text pairs from the first [test_size] data
def data_prep(hf_dataset_name, class_name, templates=templates, test_size=1000):
    # load Huggingface dataset (download if needed)
    dataset = load_dataset(hf_dataset_name, trust_remote_code=True)
    # split dataset with specific test_size
    train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
    # get the test dataset
    test_dataset = train_test_dataset['test']
    img_txt_pairs = []
    for i in range(len(test_dataset)):
        img_txt_pairs.append({
            # fill a randomly chosen template with the class name, e.g. 'a picture of cat'
            'caption': templates[random.randint(0, len(templates) - 1)].format(class_name),
            'pil_img': test_dataset[i]['image']
        })
    return img_txt_pairs

# prepare image_text pairs

# for the first 50 data of Huggingface dataset
# "yashikota/cat-image-dataset"
cat_img_txt_pairs = data_prep("yashikota/cat-image-dataset",
                              "cat", test_size=50)

# for the first 50 data of Huggingface dataset
# "tanganke/stanford_cars"
car_img_txt_pairs = data_prep("tanganke/stanford_cars",
                              "car", test_size=50)

# display an example of a cat image-text pair data
display(cat_img_txt_pairs[0]['caption'])
display(cat_img_txt_pairs[0]['pil_img'])

# display an example of a car image-text pair data
display(car_img_txt_pairs[0]['caption'])
display(car_img_txt_pairs[0]['pil_img'])

# compute BridgeTower embeddings for the cat and car image-text pairs
def load_cat_and_car_embeddings():

    def load_embeddings(img_txt_pair):
        pil_img = img_txt_pair['pil_img']
        caption = img_txt_pair['caption']
        # use the joint cross-modal embedding as a single 1-D numpy vector
        embedding = bt_embeddings_from_local(caption, pil_img)
        return embedding['cross_modal_embeddings'][0].detach().numpy()

    # embeddings for cat image-text pairs
    cat_embeddings = []
    for img_txt_pair in tqdm(
            cat_img_txt_pairs,
            total=len(cat_img_txt_pairs)):
        cat_embeddings.append(load_embeddings(img_txt_pair))

    # embeddings for car image-text pairs
    car_embeddings = []
    for img_txt_pair in tqdm(
            car_img_txt_pairs,
            total=len(car_img_txt_pairs)):
        car_embeddings.append(load_embeddings(img_txt_pair))
    return cat_embeddings, car_embeddings


# function transforms high-dimension vectors to 2D vectors using UMAP
def dimensionality_reduction(embed_arr, label):
    X_scaled = MinMaxScaler().fit_transform(embed_arr)
    print(X_scaled)
    mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled)
    df_emb = pd.DataFrame(mapper.embedding_, columns=["X", "Y"])
    df_emb["label"] = label
    print(df_emb)
    return df_emb

def show_umap_visualization():
    def reduce_dimensions():
        cat_embeddings, car_embeddings = load_cat_and_car_embeddings()
        # stack the embeddings of the cat and car examples into one numpy array
        all_embeddings = np.concatenate([cat_embeddings, car_embeddings])

        # prepare one label per embedding
        labels = ['cat'] * len(cat_embeddings) + ['car'] * len(car_embeddings)

        # compute dimensionality reduction for all examples
        reduced_dim_emb = dimensionality_reduction(all_embeddings, labels)
        return reduced_dim_emb

    reduced_dim_emb = reduce_dimensions()
    # Plot the 2D projection of the embeddings
    fig, ax = plt.subplots(figsize=(8, 6))  # Set figsize

    sns.set_style("whitegrid", {'axes.grid': False})
    sns.scatterplot(data=reduced_dim_emb,
                    x=reduced_dim_emb['X'],
                    y=reduced_dim_emb['Y'],
                    hue='label',
                    palette='bright')
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
    plt.title('Scatter plot of images of cats and cars using UMAP')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()
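Note: show_umap_visualization() is defined above but never invoked, so running the script only prepares the image-text pairs. A guarded entry point (a sketch) would trigger the plot when the file is executed directly:

# Sketch only: draw the UMAP scatter plot when this script is run directly.
if __name__ == '__main__':
    show_umap_visualization()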