88hours committed
Commit: d1fd97e · 1 Parent(s): 60e35a0

Remove Hugging Face download

requirements.txt CHANGED
@@ -7,4 +7,5 @@ youtube_transcript_api
 torch
 transformers
 matplotlib
-seaborn
+seaborn
+datasets
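
The added datasets requirement backs the direct load_dataset call that s5-how-to-umap.py now makes (see its diff below). A minimal sketch of that usage, assuming the dataset name shown in that diff:

from datasets import load_dataset

# Downloads the image/caption dataset (or reuses the local cache).
dataset = load_dataset("yashikota/cat-image-dataset", trust_remote_code=True)
print(dataset)  # DatasetDict listing its splits and column names
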
s2_download_data.py CHANGED
@@ -1,9 +1,6 @@
 import requests
 from PIL import Image
 from IPython.display import display
-import huggingface_hub
-from huggingface_hub import list_datasets
-from huggingface_hub import HfApi
 
 # You can use your own uploaded images and captions.
 # You will be responsible for the legal use of images that
@@ -49,19 +46,4 @@ def download_images():
         caption = img['caption']
         display(image)
         print(caption)
-
-def load_data_from_huggingface(hf_dataset_name):
-
-    api = HfApi()
-
-    #list models from huggingface
-
-    #models = list(api.list_models())
-
-    #list datasets from huggingface
-
-    #datasets = list(api.list_datasets())
-
-
-    return api.list_datasets(search=hf_dataset_name)
-
+
s3_data_to_vector_embedding.py CHANGED
@@ -58,5 +58,4 @@ def save_embeddings():
         print(embedding['cross_modal_embeddings'][0].shape) #<class 'torch.Tensor'>
         torch.save(embedding['cross_modal_embeddings'][0], img['tensor_path'] + '.pt')
 
-save_embeddings()
 
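The only change here removes the module-level save_embeddings() call, so importing bt_embeddings_from_local from this file (as s5-how-to-umap.py does) no longer recomputes and saves embeddings as a side effect. If the script should still be runnable on its own, a standard entry-point guard would do it; this is a sketch, not part of the commit:

if __name__ == "__main__":
    # Only run the embedding export when the file is executed directly,
    # not when it is imported by another module.
    save_embeddings()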
 
s5-how-to-umap.py CHANGED
@@ -6,12 +6,13 @@ import pandas as pd
 from tqdm import tqdm
 import matplotlib.pyplot as plt
 import seaborn as sns
-from s2_download_data import load_data_from_huggingface
-from utils import prepare_dataset_for_umap_visualization as data_prep
 from s3_data_to_vector_embedding import bt_embeddings_from_local
 import random
 import numpy as np
 import torch
+from sklearn.model_selection import train_test_split
+from datasets import load_dataset
+
 # prompt templates
 templates = [
     'a picture of {}',
@@ -23,10 +24,20 @@ templates = [
 def data_prep(hf_dataset_name, templates=templates, test_size=1000):
     # load Huggingface dataset (download if needed)
 
-    #dataset = load_dataset(hf_dataset_name, trust_remote_code=True)
-    dataset = load_data_from_huggingface(hf_dataset_name)
+    dataset = load_dataset(hf_dataset_name, trust_remote_code=True)
+    #dataset = load_data_from_huggingface(hf_dataset_name)
+    def display_list(lst, indent=0):
+        for item in lst:
+            if isinstance(item, list):
+                display_list(item, indent + 2)
+            else:
+                print(' ' * indent + str(item))
+
+    # Example usage:
+    display_list(dataset)
     # split dataset with specific test_size
-    train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
+    train_test_dataset = train_test_split(dataset, test_size=test_size)
+
     # get the test dataset
     test_dataset = train_test_dataset['test']
     img_txt_pairs = []
@@ -36,30 +47,47 @@ def data_prep(hf_dataset_name, templates=templates, test_size=1000):
             'pil_img' : test_dataset[i]['image']
         })
     return img_txt_pairs
-
-# compute BridgeTower embeddings for cat image-text pairs
-def load_cat_and_car_embeddings():
-
-    # prepare image_text pairs
 
-    # for the first 50 data of Huggingface dataset
-    # "yashikota/cat-image-dataset"
-    cat_img_txt_pairs = data_prep("yashikota/cat-image-dataset",
-                                  "cat", test_size=50)
+# load cat and car image-text pairs
+def load_pairs_from_dataset(dataset_name, file_name):
 
-    # for the first 50 data of Huggingface dataset
-    # "tanganke/stanford_cars"
-    car_img_txt_pairs = data_prep("tanganke/stanford_cars",
-                                  "car", test_size=50)
+    def load_dataset_locally(file_name):
+        with open(file_name, 'r') as f:
+            dataset = f.readlines()
+        return dataset
+
+    def save_dataset_locally(dataset_list, file_name):
+        with open(file_name, 'w') as f:
+            for item in dataset_list:
+                f.write("%s\n" % item)
 
-    # display an example of a cat image-text pair data
-    display(cat_img_txt_pairs[0]['caption'])
-    display(cat_img_txt_pairs[0]['pil_img'])
 
-    # display an example of a car image-text pair data
-    display(car_img_txt_pairs[0]['caption'])
-    display(car_img_txt_pairs[0]['pil_img'])
+    def check_dataset_locally(file_name):
+        if (path.exists(file_name)):
+            return True
+        return False
+
+    if (check_dataset_locally(file_name)):
+        print('Dataset already exists')
+        img_txt_pairs = load_dataset_locally(file_name)
+    else:
+        print('Downloading dataset')
+
+        img_txt_pairs = data_prep(dataset_name, test_size=50)
+        save_dataset_locally(img_txt_pairs, file_name)
+    return img_txt_pairs
+
 
+def load_all_dataset():
+
+    cat_img_txt_pairs = load_pairs_from_dataset("yashikota/cat-image-dataset", './shared_data/cat_img_txt_pairs.txt')
+    car_img_txt_pairs = load_pairs_from_dataset("tanganke/stanford_cars", './shared_data/car_img_txt_pairs.txt')
+
+    return cat_img_txt_pairs, car_img_txt_pairs
+# compute BridgeTower embeddings for cat image-text pairs
+def load_cat_and_car_embeddings():
+    # prepare image_text pairs
+    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
     def save_embeddings(embedding, path):
         torch.save(embedding, path)
 
@@ -68,18 +96,18 @@ def load_cat_and_car_embeddings():
         caption = img_txt_pair['caption']
         return bt_embeddings_from_local(caption, pil_img)
 
-    def load_all_embeddings_from_image_text_pairs(file_name):
-        cat_embeddings = []
+    def load_all_embeddings_from_image_text_pairs(img_txt_pairs, file_name):
+        embeddings = []
         for img_txt_pair in tqdm(
-            cat_img_txt_pairs,
-            total=len(cat_img_txt_pairs)
+            img_txt_pairs,
+            total=len(img_txt_pairs)
         ):
             pil_img = img_txt_pair['pil_img']
            caption = img_txt_pair['caption']
            embedding = load_embeddings(caption, pil_img)
-            cat_embeddings.append(embedding)
+            embeddings.append(embedding)
         save_embeddings(cat_embeddings, file_name)
-        return cat_embeddings
+        return embeddings
 
 
     cat_embeddings = []
@@ -87,12 +115,12 @@ def load_cat_and_car_embeddings():
     if (path.exists('./shared_data/cat_embeddings.pt')):
         cat_embeddings = torch.load('./shared_data/cat_embeddings.pt')
     else:
-        cat_embeddings = load_all_embeddings_from_image_text_pairs('./shared_data/cat_embeddings.pt')
+        cat_embeddings = load_all_embeddings_from_image_text_pairs(cat_img_txt_pairs, './shared_data/cat_embeddings.pt')
 
     if (path.exists('./shared_data/car_embeddings.pt')):
         car_embeddings = torch.load('./shared_data/car_embeddings.pt')
     else:
-        car_embeddings = load_all_embeddings_from_image_text_pairs('./shared_data/car_embeddings.pt')
+        car_embeddings = load_all_embeddings_from_image_text_pairs(car_img_txt_pairs, './shared_data/car_embeddings.pt')
 
     return cat_embeddings, car_embeddings
 
@@ -135,5 +163,15 @@ def show_umap_visualization():
     plt.xlabel('X')
     plt.ylabel('Y')
     plt.show()
-
-load_cat_and_car_embeddings()
+
+def run():
+    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
+    # display an example of a cat image-text pair data
+    display(cat_img_txt_pairs[0]['caption'])
+    display(cat_img_txt_pairs[0]['pil_img'])
+
+    # display an example of a car image-text pair data
+    display(car_img_txt_pairs[0]['caption'])
+    display(car_img_txt_pairs[0]['pil_img'])
+
+run()
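
In the new data_prep, scikit-learn's train_test_split is applied to the loaded dataset and the result is then indexed with ['test']; scikit-learn returns a list of splits rather than a dict, so that lookup only works with a dict-like split. The pre-change line dataset['train'].train_test_split(test_size=test_size) used the split method built into datasets, which does return a DatasetDict with a 'test' key. A sketch combining the new load_dataset call with that split method, offered as an alternative rather than what this commit ships:

from datasets import load_dataset

def data_prep_sketch(hf_dataset_name, test_size=1000):
    # Load the Hugging Face dataset, then split off a test set with the
    # datasets library's own train_test_split, which returns a DatasetDict
    # with 'train' and 'test' keys, so a ['test'] lookup keeps working.
    dataset = load_dataset(hf_dataset_name, trust_remote_code=True)
    train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
    return train_test_dataset['test']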