William9999 commited on
Commit
583a375
·
verified ·
1 Parent(s): e6f51a4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +366 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Import necessary libraries: here you will use streamlit library to run a text search demo, please make sure to install it.
2
+ # !pip install streamlit sentence-transformers gdown matplotlib
3
+ # !pip install pyngrok
4
+ import subprocess
5
+
6
# Install the runtime dependencies with pip before the third-party imports
# below are attempted; check=True makes a failed install raise
# subprocess.CalledProcessError.
# NOTE(review): this runs every time the script is executed and the same
# packages are listed in requirements.txt — confirm it is still needed.
subprocess.run([
    "pip", "install",
    "streamlit",
    "sentence-transformers",
    "gdown",
    "matplotlib",
    "tf-keras" # add tf-keras to the dependency list
], check=True)
14
+
15
+ import streamlit as st
16
+ import numpy as np
17
+ import numpy.linalg as la
18
+ import pickle
19
+ import os
20
+ import gdown
21
+ from sentence_transformers import SentenceTransformer
22
+ import matplotlib.pyplot as plt
23
+ import math
24
+ import os
25
+ import subprocess
26
+
27
+ ### Some predefined utility functions for you to load the text embeddings
28
+
29
+ # Function to Load Glove Embeddings
30
def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
    """Unpickle and return the embeddings object stored at ``glove_path``.

    The pickle is decoded with the ``latin1`` encoding.
    """
    with open(glove_path, "rb") as handle:
        return pickle.load(handle, encoding="latin1")
35
+
36
def get_model_id_gdrive(model_type):
    """Return the Google Drive file ids for a GloVe model.

    Args:
        model_type: One of "25d", "50d" or "100d".

    Returns:
        A ``(word_index_id, embeddings_id)`` tuple of Google Drive file ids.

    Raises:
        ValueError: If ``model_type`` is not one of the supported models.
    """
    # model_type -> (word_index_id, embeddings_id)
    drive_ids = {
        "25d": (
            "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8",
            "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2",
        ),
        "50d": (
            "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9",
            "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ",
        ),
        "100d": (
            "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq",
            "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp",
        ),
    }
    if model_type not in drive_ids:
        # Previously an unknown model fell through the if/elif chain and
        # failed with an opaque UnboundLocalError on the return statement.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {sorted(drive_ids)}"
        )
    return drive_ids[model_type]
48
+
49
def download_glove_embeddings_gdrive(model_type):
    """Download the word-index pickle and the embeddings array for a model.

    The Google Drive ids come from ``get_model_id_gdrive``; both files are
    fetched with gdown and saved under ``*_temp`` names built from
    ``model_type``.
    """
    word_index_id, embeddings_id = get_model_id_gdrive(model_type)

    # Local file names the two downloads are written to
    word_index_temp = f"word_index_dict_{model_type}_temp.pkl"
    embeddings_temp = f"embeddings_{model_type}_temp.npy"

    # Word-index pickle first ...
    print("Downloading word index dictionary....\n")
    gdown.download(id=word_index_id, output=word_index_temp, quiet=False)

    # ... then the embeddings numpy file
    print("Donwloading embedings...\n\n")
    gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
64
+
65
+ # @st.cache_data()
66
def load_glove_embeddings_gdrive(model_type):
    """Load the previously downloaded GloVe files for ``model_type``.

    Reads ``word_index_dict_<model_type>_temp.pkl`` and
    ``embeddings_<model_type>_temp.npy`` using relative paths.

    Args:
        model_type: Model identifier used to build the two file names.

    Returns:
        A ``(word_index_dict, embeddings)`` tuple: the unpickled word-index
        object and the array returned by ``np.load``.
    """
    word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
    embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"

    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(word_index_temp, "rb") as word_index_file:
        word_index_dict = pickle.load(word_index_file, encoding="latin")

    embeddings = np.load(embeddings_temp)

    return word_index_dict, embeddings
77
+
78
@st.cache_resource()
def load_sentence_transformer_model(model_name):
    """Construct the SentenceTransformer for ``model_name``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the model object.
    """
    return SentenceTransformer(model_name)
82
+
83
def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
    """Encode ``sentence`` with a sentence-transformer model.

    Args:
        sentence: Text to embed.
        model_name: Model name handed to ``load_sentence_transformer_model``.
            The default, all-MiniLM-L6-v2, yields 384-dimensional embeddings:
            https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

    Returns:
        The result of the model's ``encode`` call. If encoding raises, a zero
        vector is returned instead: length 384 for the default model and
        length 512 for any other model name.
    """
    model = load_sentence_transformer_model(model_name)

    try:
        return model.encode(sentence)
    except Exception:
        # Best-effort fallback kept from the original behavior, but narrowed
        # from a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
        fallback_dim = 384 if model_name == "all-MiniLM-L6-v2" else 512
        return np.zeros(fallback_dim)
98
+
99
def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
    """Look up the GloVe vector for one word, ignoring case.

    A word missing from ``word_index_dict`` maps to a zero vector whose length
    is the number preceding the "d" in ``model_type`` (e.g. "50d" -> 50).
    """
    key = word.lower()
    if key not in word_index_dict:
        dimension = int(model_type.split("d")[0])
        return np.zeros(dimension)
    return embeddings[word_index_dict[key]]
107
+
108
def get_category_embeddings(embeddings_metadata):
    """Embed every category and store the results in the session state.

    The categories are the space-separated pieces of
    ``st.session_state.categories``. The embeddings are written to
    ``st.session_state["cat_embed_" + model_name]``, which is reset to an
    empty dict at the start of every call.
    """
    model_name = embeddings_metadata["model_name"]
    cache_key = "cat_embed_" + model_name
    st.session_state[cache_key] = {}

    for category in st.session_state.categories.split(" "):
        if category in st.session_state[cache_key]:
            # Repeated category: already embedded earlier in this loop
            continue
        # An empty model name falls back to the encoder's default model
        encode_kwargs = {"model_name": model_name} if model_name else {}
        st.session_state[cache_key][category] = get_sentence_transformer_embeddings(
            category, **encode_kwargs
        )
123
+
124
def update_category_embeddings(embeddings_metadata):
    """Recompute the category embeddings by delegating to get_category_embeddings."""
    get_category_embeddings(embeddings_metadata)
129
+
130
+ ### Plotting utility functions
131
+
132
def plot_piechart(sorted_cosine_scores_items):
    """Render a single pie chart of similarity scores in the Streamlit app.

    Each item of ``sorted_cosine_scores_items`` is indexed as
    ``(category_index, score)``; the index picks the slice label from the
    space-separated ``st.session_state.categories``.
    """
    scores = np.array([item[1] for item in sorted_cosine_scores_items])
    category_names = st.session_state.categories.split(" ")
    labels = [category_names[item[0]] for item in sorted_cosine_scores_items]

    fig, ax = plt.subplots()
    ax.pie(scores, labels=labels, autopct="%1.1f%%")
    st.pyplot(fig)  # hand the finished figure to Streamlit
146
+
147
def plot_piechart_helper(sorted_cosine_scores_items):
    """Build a small exploded pie chart and return the figure without rendering it.

    Each item of ``sorted_cosine_scores_items`` is indexed as
    ``(category_index, score)``; labels come from the space-separated
    ``st.session_state.categories``.
    """
    scores = np.array([item[1] for item in sorted_cosine_scores_items])
    category_names = st.session_state.categories.split(" ")
    labels = [category_names[item[0]] for item in sorted_cosine_scores_items]

    fig, ax = plt.subplots(figsize=(3, 3))

    # The first slice is always offset by 0.2. With exactly three slices the
    # second is offset by 0.1; with more than three the third is offset by 0.05.
    slice_count = len(labels)
    explode = np.zeros(slice_count)
    explode[0] = 0.2
    if slice_count == 3:
        explode[1] = 0.1
    elif slice_count > 3:
        explode[2] = 0.05

    ax.pie(scores, labels=labels, autopct="%1.1f%%", explode=explode)
    return fig
174
+
175
def plot_piecharts(sorted_cosine_scores_models):
    """Draw two stacked pie charts, one per model, in a single figure.

    Nothing is drawn unless ``sorted_cosine_scores_models`` holds exactly two
    models. Each model maps to items indexed as ``(category_index, score)``.
    """
    category_names = st.session_state.categories.split(" ")
    per_model_items = [
        sorted_cosine_scores_models[model] for model in sorted_cosine_scores_models
    ]

    if len(sorted_cosine_scores_models) != 2:
        return

    fig, axes = plt.subplots(2)
    for ax, items in zip(axes, per_model_items):
        labels = [category_names[item[0]] for item in items]
        scores = np.array([item[1] for item in items])
        ax.pie(scores, labels=labels, autopct="%1.1f%%")

    st.pyplot(fig)
203
+
204
def plot_alatirchart(sorted_cosine_scores_models):
    """Show one pie chart per model, each inside its own Streamlit tab.

    The keys of ``sorted_cosine_scores_models`` become the tab titles.
    """
    models = list(sorted_cosine_scores_models.keys())
    tabs = st.tabs(models)

    # Build every figure first, then place each one in its matching tab
    figs = {
        name: plot_piechart_helper(sorted_cosine_scores_models[name])
        for name in models
    }
    for tab, name in zip(tabs, models):
        with tab:
            st.pyplot(figs[name])
214
+
215
+ ### Your Part To Complete: Follow the instructions in each function below to complete the similarity calculation between text embeddings
216
+
217
+ # Task I: Compute Cosine Similarity
218
def cosine_similarity(x, y):
    """Return the exponentiated cosine similarity ``exp(cos(x, y))``.

    Args:
        x: First vector.
        y: Second vector, same length as ``x``.

    Returns:
        ``exp`` of the cosine of the angle between ``x`` and ``y``. When
        either vector has zero norm the cosine is undefined; it is treated as
        0, so the result is ``exp(0) == 1.0`` rather than NaN.
    """
    denominator = la.norm(x) * la.norm(y)
    if denominator == 0:
        # Zero vectors occur for out-of-vocabulary words; avoid 0/0 -> NaN
        # leaking into the ranking and the pie chart.
        return np.exp(0.0)
    return np.exp(np.dot(x, y) / denominator)
228
+
229
+ # Task II: Average Glove Embedding Calculation
230
def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
    """Return the mean GloVe embedding of the words in ``sentence``.

    Args:
        sentence: Text that is split on whitespace into words.
        word_index_dict: Mapping from word to its row in ``embeddings``.
        embeddings: Table of GloVe vectors indexed by those rows.
        model_type: Embedding size, either as an int (``50``) or in the
            "<n>d" string form (``"50d"``).

    Returns:
        The sum of the per-word vectors divided by the number of words. A
        sentence with no words yields a zero vector of the model's dimension.
    """
    # Normalize to the "<n>d" form: the int default used to crash on .split().
    dimension = int(str(model_type).split("d")[0])
    model_type = f"{dimension}d"

    averaged = np.zeros(dimension)
    words = sentence.split()
    if not words:
        # Empty or whitespace-only input: avoid dividing by zero (NaN result).
        return averaged

    for word in words:
        averaged += get_glove_embeddings(word, word_index_dict, embeddings, model_type)
    return averaged / len(words)
245
+
246
+ # Task III: Sort the cosine similarity
247
def get_sorted_cosine_similarity(embeddings_metadata):
    """Rank the categories by similarity to the search text.

    The search text is read from ``st.session_state.text_search`` and the
    categories from the space-separated ``st.session_state.categories``.

    Args:
        embeddings_metadata: Dict with an "embedding_model" key. For "glove"
            it must also hold "word_index_dict", "embeddings" and
            "model_type"; for any other value it must hold "model_name".

    Returns:
        A list of ``(category_index, score)`` pairs sorted by descending
        score, where the score is the exponentiated cosine similarity.
    """
    categories = st.session_state.categories.split(" ")
    cosine_sim = {}
    if embeddings_metadata["embedding_model"] == "glove":
        word_index_dict = embeddings_metadata["word_index_dict"]
        embeddings = embeddings_metadata["embeddings"]
        model_type = embeddings_metadata["model_type"]

        input_embedding = averaged_glove_embeddings_gdrive(
            st.session_state.text_search, word_index_dict, embeddings, model_type
        )

        for index, category in enumerate(categories):
            category_embedding = averaged_glove_embeddings_gdrive(
                category, word_index_dict, embeddings, model_type
            )
            cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)

    else:
        model_name = embeddings_metadata["model_name"]
        cache_key = "cat_embed_" + model_name

        # Rebuild the cache when it is absent OR stale. The cache survives
        # across reruns, so after the user edits the categories it can lack
        # the new entries, which used to raise KeyError in the loop below.
        if cache_key not in st.session_state or any(
            category not in st.session_state[cache_key] for category in categories
        ):
            update_category_embeddings(embeddings_metadata)

        category_embeddings = st.session_state[cache_key]

        input_embedding = get_sentence_transformer_embeddings(
            st.session_state.text_search, model_name=model_name
        )
        for index, category in enumerate(categories):
            cosine_sim[index] = cosine_similarity(input_embedding, category_embeddings[category])

    # Highest similarity first
    sorted_cosine_sim = sorted(cosine_sim.items(), key=lambda x: x[1], reverse=True)
    return sorted_cosine_sim
286
+
287
+ ### Below is the main function, creating the app demo for text search engine using the text embeddings.
288
+
289
# Script entry point: builds the Streamlit page, makes sure the GloVe files
# for the selected model are on disk, scores the categories against the
# search text with both embedding models and plots the results.
if __name__ == "__main__":
    # Initialize session state variables
    if "categories" not in st.session_state:
        st.session_state["categories"] = "Flowers Colors Cars Weather Food"

    if "text_search" not in st.session_state:
        st.session_state["text_search"] = "Roses are red, trucks are blue, and Seattle is grey right now"

    # Sidebar: background text about GloVe and the model-size selector
    st.sidebar.title("GloVe Twitter")
    st.sidebar.markdown(
        """
GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).

Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
"""
    )

    # index=1 preselects "50d"
    model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d", "100d"), index=1)

    st.title("Search Based Retrieval Demo")
    st.subheader(
        "Pass in space separated categories you want this search demo to be about."
    )
    # The widget keys match the session-state entries read by the helper functions
    st.text_input(
        label="Categories", key="categories", value=st.session_state["categories"]
    )

    st.subheader("Pass in an input word or even a sentence")
    st.text_input(
        label="Input your sentence",
        key="text_search",
        value=st.session_state["text_search"],
    )

    # Download the GloVe files only when either one is missing on disk
    embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
    word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
    if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
        with st.spinner("Downloading glove embeddings..."):
            download_glove_embeddings_gdrive(model_type)

    word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)

    # Only score and plot when there is some search text
    if st.session_state.text_search:
        # GloVe scores
        embeddings_metadata = {
            "embedding_model": "glove",
            "word_index_dict": word_index_dict,
            "embeddings": embeddings,
            "model_type": model_type,
        }
        with st.spinner("Obtaining Cosine similarity for Glove..."):
            sorted_cosine_sim_glove = get_sorted_cosine_similarity(embeddings_metadata)

        # Sentence-transformer scores
        embeddings_metadata = {
            "embedding_model": "transformers",
            "model_name": "all-MiniLM-L6-v2"
        }
        with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
            sorted_cosine_sim_transformer = get_sorted_cosine_similarity(embeddings_metadata)

        st.subheader(
            "Closest word I have between: "
            + st.session_state.categories
            + " as per different Embeddings"
        )

        # One tab per embedding model
        plot_alatirchart(
            {
                "glove_" + str(model_type): sorted_cosine_sim_glove,
                "sentence_transformer_384": sorted_cosine_sim_transformer,
            }
        )

        # NOTE(review): the footer is assumed to sit inside this `if`; the
        # original indentation was not recoverable — confirm.
        st.write("")
        st.write(
            "Demo developed by [Your Name](https://www.linkedin.com/in/your_id/ - Optional)"
        )
366
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gdown==5.2.0
2
+ matplotlib==3.9.2
3
+ sentence-transformers==3.4.0
4
+ streamlit==1.30.0
5
+ tf-keras