nn function now compares vectors of target word only with vectors within the same model
Browse files- app.py +33 -34
- word2vec.py +67 -6
app.py
CHANGED
|
@@ -12,6 +12,9 @@ from streamlit_tags import st_tags, st_tags_sidebar
|
|
| 12 |
|
| 13 |
st.set_page_config(page_title="Ancient Greek Word2Vec", layout="centered")
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
# Horizontal menu
|
| 16 |
active_tab = option_menu(None, ["Nearest neighbours", "Cosine similarity", "3D graph", 'Dictionary'],
|
| 17 |
menu_icon="cast", default_index=0, orientation="horizontal")
|
|
@@ -29,59 +32,55 @@ if active_tab == "Nearest neighbours":
|
|
| 29 |
all_words = load_compressed_word_list(compressed_word_list_filename)
|
| 30 |
eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
with st.container():
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
|
| 41 |
-
with col2:
|
| 42 |
-
time_slice = st.selectbox("Time slice", eligible_models)
|
| 43 |
|
| 44 |
models = st.multiselect(
|
| 45 |
"Select models to search for neighbours",
|
| 46 |
-
|
| 47 |
)
|
| 48 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
| 49 |
|
| 50 |
-
nearest_neighbours_button = st.button("Find nearest neighbours")
|
| 51 |
|
| 52 |
# If the button to calculate nearest neighbours is clicked
|
| 53 |
-
if
|
| 54 |
-
|
| 55 |
-
# Rewrite timeslices to model names: Archaic -> archaic_cbow
|
| 56 |
-
if time_slice == 'Hellenistic':
|
| 57 |
-
time_slice = 'hellen'
|
| 58 |
-
elif time_slice == 'Early Roman':
|
| 59 |
-
time_slice = 'early_roman'
|
| 60 |
-
elif time_slice == 'Late Roman':
|
| 61 |
-
time_slice = 'late_roman'
|
| 62 |
-
|
| 63 |
-
time_slice = time_slice.lower() + "_cbow"
|
| 64 |
-
|
| 65 |
|
| 66 |
# Check if all fields are filled in
|
| 67 |
-
if validate_nearest_neighbours(word,
|
| 68 |
st.error('Please fill in all fields')
|
| 69 |
else:
|
| 70 |
# Rewrite models to list of all loaded models
|
| 71 |
models = load_selected_models(models)
|
| 72 |
|
| 73 |
-
nearest_neighbours = get_nearest_neighbours(word,
|
| 74 |
-
|
| 75 |
-
df = pd.DataFrame(
|
| 76 |
-
nearest_neighbours,
|
| 77 |
-
columns=["Word", "Time slice", "Similarity"],
|
| 78 |
-
index = range(1, len(nearest_neighbours) + 1)
|
| 79 |
-
)
|
| 80 |
-
st.table(df)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
|
|
|
| 83 |
# Store content in a temporary file
|
| 84 |
-
tmp_file = store_df_in_temp_file(
|
| 85 |
|
| 86 |
# Open the temporary file and read its content
|
| 87 |
with open(tmp_file, "rb") as file:
|
|
@@ -91,7 +90,7 @@ if active_tab == "Nearest neighbours":
|
|
| 91 |
st.download_button(
|
| 92 |
"Download results",
|
| 93 |
data=file_byte,
|
| 94 |
-
file_name = f'nearest_neighbours_{word}
|
| 95 |
mime='application/octet-stream'
|
| 96 |
)
|
| 97 |
|
|
|
|
| 12 |
|
| 13 |
st.set_page_config(page_title="Ancient Greek Word2Vec", layout="centered")
|
| 14 |
|
| 15 |
+
def click_nn_button():
|
| 16 |
+
st.session_state.nearest_neighbours = not st.session_state.nearest_neighbours
|
| 17 |
+
|
| 18 |
# Horizontal menu
|
| 19 |
active_tab = option_menu(None, ["Nearest neighbours", "Cosine similarity", "3D graph", 'Dictionary'],
|
| 20 |
menu_icon="cast", default_index=0, orientation="horizontal")
|
|
|
|
| 32 |
all_words = load_compressed_word_list(compressed_word_list_filename)
|
| 33 |
eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
|
| 34 |
|
| 35 |
+
if 'nearest_neighbours' not in st.session_state:
|
| 36 |
+
st.session_state.nearest_neighbours = False
|
| 37 |
+
|
| 38 |
with st.container():
|
| 39 |
+
|
| 40 |
+
word = st.multiselect("Enter a word", all_words, max_selections=1)
|
| 41 |
+
if len(word) > 0:
|
| 42 |
+
word = word[0]
|
| 43 |
+
|
| 44 |
+
# Check which models contain the word
|
| 45 |
+
eligible_models = check_word_in_models(word)
|
| 46 |
|
|
|
|
|
|
|
| 47 |
|
| 48 |
models = st.multiselect(
|
| 49 |
"Select models to search for neighbours",
|
| 50 |
+
eligible_models
|
| 51 |
)
|
| 52 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
| 53 |
|
| 54 |
+
nearest_neighbours_button = st.button("Find nearest neighbours", on_click = click_nn_button)
|
| 55 |
|
| 56 |
# If the button to calculate nearest neighbours is clicked
|
| 57 |
+
if st.session_state.nearest_neighbours:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
# Check if all fields are filled in
|
| 60 |
+
if validate_nearest_neighbours(word, n, models) == False:
|
| 61 |
st.error('Please fill in all fields')
|
| 62 |
else:
|
| 63 |
# Rewrite models to list of all loaded models
|
| 64 |
models = load_selected_models(models)
|
| 65 |
|
| 66 |
+
nearest_neighbours = get_nearest_neighbours(word, n, models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
all_dfs = []
|
| 69 |
+
|
| 70 |
+
# Create dataframes
|
| 71 |
+
for model in nearest_neighbours.keys():
|
| 72 |
+
st.write(f"### {model}")
|
| 73 |
+
df = pd.DataFrame(
|
| 74 |
+
nearest_neighbours[model],
|
| 75 |
+
columns = ['Word', 'Cosine Similarity']
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
all_dfs.append((model, df))
|
| 79 |
+
st.table(df)
|
| 80 |
|
| 81 |
+
|
| 82 |
# Store content in a temporary file
|
| 83 |
+
tmp_file = store_df_in_temp_file(all_dfs)
|
| 84 |
|
| 85 |
# Open the temporary file and read its content
|
| 86 |
with open(tmp_file, "rb") as file:
|
|
|
|
| 90 |
st.download_button(
|
| 91 |
"Download results",
|
| 92 |
data=file_byte,
|
| 93 |
+
file_name = f'nearest_neighbours_{word}_TEST.xlsx',
|
| 94 |
mime='application/octet-stream'
|
| 95 |
)
|
| 96 |
|
word2vec.py
CHANGED
|
@@ -148,11 +148,11 @@ def get_cosine_similarity_one_word(word, time_slice1, time_slice2):
|
|
| 148 |
|
| 149 |
|
| 150 |
|
| 151 |
-
def validate_nearest_neighbours(word,
|
| 152 |
'''
|
| 153 |
Validate the input of the nearest neighbours function
|
| 154 |
'''
|
| 155 |
-
if word == '' or
|
| 156 |
return False
|
| 157 |
return True
|
| 158 |
|
|
@@ -198,7 +198,7 @@ def convert_time_name_to_model(time_name):
|
|
| 198 |
elif time_name == 'archaic':
|
| 199 |
return 'Archaic'
|
| 200 |
|
| 201 |
-
def
|
| 202 |
'''
|
| 203 |
Return the nearest neighbours of a word
|
| 204 |
|
|
@@ -243,6 +243,51 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
|
|
| 243 |
|
| 244 |
|
| 245 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
def get_nearest_neighbours_vectors(word, time_slice_model, n=15):
|
|
@@ -287,7 +332,7 @@ def write_to_file(data):
|
|
| 287 |
return temp_file_path
|
| 288 |
|
| 289 |
|
| 290 |
-
def store_df_in_temp_file(
|
| 291 |
'''
|
| 292 |
Store the dataframe in a temporary file
|
| 293 |
'''
|
|
@@ -300,9 +345,25 @@ def store_df_in_temp_file(df):
|
|
| 300 |
# Create random tmp file name
|
| 301 |
_, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=".xlsx", dir=temp_dir)
|
| 302 |
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
with pd.ExcelWriter(temp_file_path, engine='xlsxwriter') as writer:
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
return temp_file_path
|
| 308 |
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
|
| 151 |
+
def validate_nearest_neighbours(word, n, models):
|
| 152 |
'''
|
| 153 |
Validate the input of the nearest neighbours function
|
| 154 |
'''
|
| 155 |
+
if word == '' or n == '' or models == []:
|
| 156 |
return False
|
| 157 |
return True
|
| 158 |
|
|
|
|
| 198 |
elif time_name == 'archaic':
|
| 199 |
return 'Archaic'
|
| 200 |
|
| 201 |
+
def get_nearest_neighbours2(word, n=10, models=load_all_models()):
|
| 202 |
'''
|
| 203 |
Return the nearest neighbours of a word
|
| 204 |
|
|
|
|
| 243 |
|
| 244 |
|
| 245 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_nearest_neighbours(target_word, n=10, models=load_all_models()):
|
| 249 |
+
"""
|
| 250 |
+
Return the nearest neighbours of a word for the given models
|
| 251 |
+
|
| 252 |
+
word: the word for which the nearest neighbours are calculated
|
| 253 |
+
n: the number of nearest neighbours to return (default: 10)
|
| 254 |
+
models: list of tuples with the name of the time slice and the word2vec model (default: all in ./models)
|
| 255 |
+
|
| 256 |
+
Return: { 'model_name': [(word, cosine_similarity), ...], ... }
|
| 257 |
+
"""
|
| 258 |
+
nearest_neighbours = {}
|
| 259 |
+
|
| 260 |
+
# Iterate over models and compute nearest neighbours
|
| 261 |
+
for model in models:
|
| 262 |
+
model_neighbours = []
|
| 263 |
+
model_name = convert_model_to_time_name(model[0])
|
| 264 |
+
model = model[1]
|
| 265 |
+
vector_1 = get_word_vector(model, target_word)
|
| 266 |
+
|
| 267 |
+
# Iterate over all words of the model
|
| 268 |
+
for word, index in model.wv.key_to_index.items():
|
| 269 |
+
vector_2 = get_word_vector(model, word)
|
| 270 |
+
cosine_sim = cosine_similarity(vector_1, vector_2)
|
| 271 |
+
|
| 272 |
+
# If the list of nearest neighbours is not full yet, add the current word
|
| 273 |
+
if len(model_neighbours) < n:
|
| 274 |
+
model_neighbours.append((word, cosine_sim))
|
| 275 |
+
else:
|
| 276 |
+
# If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
|
| 277 |
+
smallest_neighbour = min(model_neighbours, key=lambda x: x[1])
|
| 278 |
+
if cosine_sim > smallest_neighbour[1]:
|
| 279 |
+
model_neighbours.remove(smallest_neighbour)
|
| 280 |
+
model_neighbours.append((word, cosine_sim))
|
| 281 |
+
|
| 282 |
+
# Sort the nearest neighbours by cosine similarity
|
| 283 |
+
model_neighbours = sorted(model_neighbours, key=lambda x: x[1], reverse=True)
|
| 284 |
+
|
| 285 |
+
# Add the model name and the nearest neighbours to the dictionary
|
| 286 |
+
nearest_neighbours[model_name] = model_neighbours
|
| 287 |
+
|
| 288 |
+
return nearest_neighbours
|
| 289 |
+
|
| 290 |
+
|
| 291 |
|
| 292 |
|
| 293 |
def get_nearest_neighbours_vectors(word, time_slice_model, n=15):
|
|
|
|
| 332 |
return temp_file_path
|
| 333 |
|
| 334 |
|
| 335 |
+
def store_df_in_temp_file(all_dfs):
|
| 336 |
'''
|
| 337 |
Store the dataframe in a temporary file
|
| 338 |
'''
|
|
|
|
| 345 |
# Create random tmp file name
|
| 346 |
_, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=".xlsx", dir=temp_dir)
|
| 347 |
|
| 348 |
+
|
| 349 |
+
# Concatenate all dataframes
|
| 350 |
+
df = pd.concat([df for _, df in all_dfs], axis=1, keys=[model for model, _ in all_dfs])
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# Create an ExcelWriter object
|
| 354 |
with pd.ExcelWriter(temp_file_path, engine='xlsxwriter') as writer:
|
| 355 |
+
# Create a new sheet
|
| 356 |
+
worksheet = writer.book.add_worksheet('Results')
|
| 357 |
+
|
| 358 |
+
# Write text before DataFrames
|
| 359 |
+
start_row = 0
|
| 360 |
+
for model, df in all_dfs:
|
| 361 |
+
# Write model name as text
|
| 362 |
+
worksheet.write(start_row, 0, f"Model: {model}")
|
| 363 |
+
# Write DataFrame
|
| 364 |
+
df.to_excel(writer, sheet_name='Results', index=False, startrow=start_row + 1, startcol=0)
|
| 365 |
+
# Update start_row for the next model
|
| 366 |
+
start_row += df.shape[0] + 3 # Add some space between models
|
| 367 |
|
| 368 |
return temp_file_path
|
| 369 |
|