Spaces:
Runtime error
Runtime error
chore: init demo
Browse files- .gitignore +3 -0
- app.py +248 -0
- constants.py +217 -0
- requirements.txt +3 -0
- utils.py +270 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.gradio
|
2 |
+
__pycache__
|
3 |
+
plots
|
app.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import pickle
|
5 |
+
import os
|
6 |
+
from sklearn.manifold import TSNE
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from utils import (plot_distances_tsne,
|
9 |
+
plot_distances_umap,
|
10 |
+
cluster_languages_hdbscan,
|
11 |
+
cluster_languages_kmeans,
|
12 |
+
plot_mst,
|
13 |
+
cluster_languages_by_families,
|
14 |
+
cluster_languages_by_subfamilies,
|
15 |
+
filter_languages_by_families)
|
16 |
+
from functools import partial
|
17 |
+
|
18 |
+
|
19 |
+
# Ordered language names. The functions in this module index the distance
# matrices with positions taken from this list.
# NOTE(review): pickle can execute arbitrary code while loading - only point
# this at result files you produced yourself.
with open("../../results/languages_list.pkl", "rb") as f:
    languages = pickle.load(f)

DATASETS = ["wikimedia/wikipedia", "uonlp/CulturaX", "HuggingFaceFW/fineweb-2"]
MODELS = ["mistralai/Mistral-7B-v0.1", "google/gemma-3-4b-pt", "meta-llama/Llama-3.2-1B"]


def _load_distances(dataset, model):
    """Load the distance matrix stored on disk for one (dataset, model) pair."""
    return np.load(os.path.join("../../results", dataset, model, "distances_matrix.npy"))


# distance_matrices[dataset][model] -> matrix loaded from distances_matrix.npy
distance_matrices = {
    dataset: {model: _load_distances(dataset, model) for model in MODELS}
    for dataset in DATASETS
}

average_distances_matrix = np.load("../../results/average_distances_matrix.npy")
|
34 |
+
|
35 |
+
|
36 |
+
def filter_languages_nan(model, dataset, use_average):
    """
    Removes the languages whose entry in the first matrix row is NaN.

    Parameters:
    - model: key of the model in `distance_matrices` (ignored when use_average is truthy).
    - dataset: key of the dataset in `distance_matrices` (ignored when use_average is truthy).
    - use_average: when truthy, the averaged matrix is used instead of a per-model one.

    Returns:
    - updated_matrix: the matrix restricted to the kept rows and columns.
    - updated_languages: NumPy array with the names of the kept languages.
    """
    matrix = average_distances_matrix if use_average else distance_matrices[dataset][model]

    # Only row 0 is inspected: a language is kept when its distance to the
    # first language is not NaN.
    keep = ~np.isnan(matrix[0])

    kept_languages = np.array(languages)[keep]
    kept_matrix = matrix[keep, :][:, keep]
    return kept_matrix, kept_languages
|
47 |
+
|
48 |
+
|
49 |
+
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """
    Builds a table with the n languages closest to the selected one.

    Parameters:
    - model: key of the model in `distance_matrices` (ignored when use_average is truthy).
    - dataset: key of the dataset in `distance_matrices` (ignored when use_average is truthy).
    - selected_language: name of the reference language; must be present in `languages`.
    - use_average: when truthy, the averaged matrix is used instead of a per-model one.
    - n: number of rows to return.

    Returns:
    - DataFrame with columns "index" (0-based rank), "Language" and "Distance"
      (rounded to 4 decimals), ordered by ascending distance.
    """
    matrix = average_distances_matrix if use_average else distance_matrices[dataset][model]
    row = languages.index(selected_language)

    ranked = (
        pd.DataFrame({"Language": languages, "Distance": matrix[row]})
        .sort_values(by="Distance")
        # the reference language itself is not a candidate
        .drop(index=row)
        # first reset discards the original positions, second one exposes the rank as a column
        .reset_index(drop=True)
        .reset_index()
    )
    ranked["Distance"] = ranked["Distance"].round(4)
    return ranked.head(n)
|
67 |
+
|
68 |
+
def update_languages(model, dataset):
    """
    Lists the languages that have a non-NaN entry in the first row of the
    distance matrix of the given model and dataset.
    """
    first_row = distance_matrices[dataset][model][0]
    present = ~np.isnan(first_row)
    return list(np.array(languages)[present])
|
76 |
+
|
77 |
+
|
78 |
+
def update_language_options(model, dataset, language, use_average):
    """
    Rebuilds the language dropdown for the current model/dataset selection.

    With averaged distances every language is offered; otherwise only the ones
    returned by `update_languages`. If the current language is not among the
    choices, the first choice is selected instead.
    """
    choices = languages if use_average else update_languages(model, dataset)
    selected = language if language in choices else choices[0]
    return gr.Dropdown(label="Language", choices=choices, value=selected)
|
86 |
+
|
87 |
+
|
88 |
+
def toggle_inputs(use_average):
    """
    Returns two component updates (model and dataset dropdowns): hidden and
    non-interactive while averaged distances are in use, visible and
    interactive otherwise.
    """
    enabled = not use_average
    return (
        gr.update(interactive=enabled, visible=enabled),
        gr.update(interactive=enabled, visible=enabled),
    )
|
93 |
+
|
94 |
+
# Running counter used to give every saved plot a distinct file name.
i = 0


def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
    """
    Clusters the languages and draws them with the given plotting function.

    Parameters:
    - model: model key (ignored when use_average is truthy).
    - dataset: dataset key (ignored when use_average is truthy).
    - use_average: when truthy, the averaged distance matrix is used.
    - cluster_method: one of "HDBSCAN", "KMeans", "Family" or "Subfamily".
    - cluster_method_param: min_cluster_size for HDBSCAN, n_clusters for KMeans;
      unused by the other methods.
    - plot_fn: callable invoked as
      plot_fn(model, dataset, use_average, matrix, languages, clusters, legends)
      that returns a matplotlib figure.

    Returns:
    - fig: the figure produced by plot_fn; a PDF copy is written to plots/plot_<i>.pdf.

    Raises:
    - ValueError: if cluster_method is not one of the supported names.
    """
    global i

    updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)

    if cluster_method == "HDBSCAN":
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix, updated_languages, min_cluster_size=cluster_method_param
        )
        legends = None
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix, updated_languages, n_clusters=cluster_method_param
        )
        legends = None
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
    fig.tight_layout()

    # savefig does not create missing directories, so make sure the output
    # folder exists before writing into it.
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
    i += 1
    return fig
|
131 |
+
|
132 |
+
|
133 |
+
# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened listing - confirm the Row/Tab membership of each component.
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset",
            choices=DATASETS,
            value=DATASETS[0]
        )

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(label="Language", choices=languages, value=languages[0])
            top_n_input = gr.Slider(label="Top N", minimum=1, maximum=30, step=1, value=10)

        output_table = gr.Dataframe(label="Similar Languages")

        # Input lists shared by the event handlers registered below.
        language_option_inputs = [model_input, dataset_input, language_input, average_checkbox]
        table_inputs = [model_input, dataset_input, language_input, average_checkbox, top_n_input]

        # Changing the model or dataset refreshes the list of selectable languages.
        for trigger in (model_input, dataset_input):
            trigger.change(fn=update_language_options, inputs=language_option_inputs, outputs=language_input)

        # Any change of the table inputs recomputes the table.
        for trigger in (language_input, model_input, dataset_input, top_n_input):
            trigger.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)

        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input]
        )

        average_checkbox.change(fn=update_language_options, inputs=language_option_inputs, outputs=language_input)
        average_checkbox.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(label="Cluster Method", choices=["HDBSCAN", "KMeans", "Family", "Subfamily"], value="HDBSCAN")
            clusters_input = gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2)

        def update_clusters_input_option(cluster_method):
            """Reconfigure the parameter slider for the chosen clustering method, or hide it."""
            if cluster_method == "HDBSCAN":
                return gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2, visible=True, interactive=True)
            elif cluster_method == "KMeans":
                return gr.Slider(label="Number of Clusters", minimum=2, maximum=20, step=1, value=2, visible=True, interactive=True)
            else:
                # "Family" and "Subfamily" take no numeric parameter.
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(fn=update_clusters_input_option, inputs=[cluster_method_input], outputs=clusters_input)

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")

        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        plot_inputs = [model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input]

        plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
                               inputs=plot_inputs,
                               outputs=plot_output)
        plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
                               inputs=plot_inputs,
                               outputs=plot_output)
        plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
                              inputs=plot_inputs,
                              outputs=plot_output)

    with gr.Tab(label="Language Families Subplot"):

        checked_families_input = gr.CheckboxGroup(label="Language Families",
                                                  choices=[
                                                      'Afroasiatic',
                                                      'Austroasiatic',
                                                      'Austronesian',
                                                      'Constructed',
                                                      'Creole',
                                                      'Dravidian',
                                                      'Germanic',
                                                      'Indo-European',
                                                      'Japonic',
                                                      'Kartvelian',
                                                      'Koreanic',
                                                      'Language Isolate',
                                                      'Niger-Congo',
                                                      'Northeast Caucasian',
                                                      'Romance',
                                                      'Sino-Tibetan',
                                                      'Turkic',
                                                      'Uralic'
                                                  ],
                                                  value=["Indo-European"])
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
            plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
        plot_family_output = gr.Plot(label="Families Plot")

        def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
            """
            Draws the minimum spanning tree of the languages that belong to the
            selected families, coloured by family and subfamily, and saves a PDF
            copy to plots/plot_<i>.pdf.
            """
            global i

            updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
            updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)

            clusters, legends = cluster_languages_by_subfamilies(updated_languages)
            fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
            fig.tight_layout()

            # savefig does not create missing directories, so make sure the
            # output folder exists before writing into it.
            os.makedirs("plots", exist_ok=True)
            fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
            i += 1
            return fig

        plot_family_button.click(fn=plot_families_subfamilies,
                                 inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
                                 outputs=plot_family_output)
|
246 |
+
|
247 |
+
|
248 |
+
demo.launch(share=True)
|
constants.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
language_subfamilies = {
|
2 |
+
"Afrikaans": "West Germanic",
|
3 |
+
"Albanian": "Albanian",
|
4 |
+
"Arabic": "Semitic",
|
5 |
+
"Egyptian Arabic": "Semitic",
|
6 |
+
"Aragonese": "Romance",
|
7 |
+
"Armenian": "Armenian",
|
8 |
+
"Asturian": "Romance",
|
9 |
+
"Azerbaijani": "Oghuz",
|
10 |
+
"Bashkir": "Kypchak",
|
11 |
+
"Basque": "Language Isolate",
|
12 |
+
"Bavarian": "Austro-Bavarian",
|
13 |
+
"Belarusian": "East Slavic",
|
14 |
+
"Bengali": "Eastern Indo-Aryan",
|
15 |
+
"Bishnupriya Manipuri": "Eastern Indo-Aryan",
|
16 |
+
"Bosnian": "South Slavic",
|
17 |
+
"Breton": "Brythonic",
|
18 |
+
"Bulgarian": "South Slavic",
|
19 |
+
"Burmese": "Burmish",
|
20 |
+
"Catalan": "Romance",
|
21 |
+
"Cebuano": "Central Philippine",
|
22 |
+
"Chechen": "Nakh-Daghestanian",
|
23 |
+
"Chinese (Simplified)": "Sinitic",
|
24 |
+
"Chinese (Traditional)": "Sinitic",
|
25 |
+
"Min Nan Chinese": "Sinitic",
|
26 |
+
"Chuvash": "Oghur",
|
27 |
+
"Croatian": "South Slavic",
|
28 |
+
"Czech": "West Slavic",
|
29 |
+
"Danish": "North Germanic",
|
30 |
+
"Dutch": "West Germanic",
|
31 |
+
"English": "West Germanic",
|
32 |
+
"Estonian": "Finnic",
|
33 |
+
"Finnish": "Finnic",
|
34 |
+
"French": "Gallo-Romance",
|
35 |
+
"Galician": "Gallo-Romance",
|
36 |
+
"Georgian": "Kartvelian",
|
37 |
+
"German": "West Germanic",
|
38 |
+
"Greek": "Hellenic",
|
39 |
+
"Gujarati": "Gujarati",
|
40 |
+
"Haitian": "French-based Creole",
|
41 |
+
"Hebrew": "Semitic",
|
42 |
+
"Hindi": "Central Indo-Aryan",
|
43 |
+
"Hungarian": "Ugric",
|
44 |
+
"Icelandic": "North Germanic",
|
45 |
+
"Ido": "Constructed",
|
46 |
+
"Indonesian": "Malayic",
|
47 |
+
"Irish": "Goidelic",
|
48 |
+
"Italian": "Italo-Dalmatian",
|
49 |
+
"Japanese": "Japonic",
|
50 |
+
"Javanese": "Javanic",
|
51 |
+
"Kannada": "Southern Dravidian",
|
52 |
+
"Kazakh": "Kypchak",
|
53 |
+
"Kirghiz": "Kypchak",
|
54 |
+
"Korean": "Koreanic",
|
55 |
+
"Latin": "Italic",
|
56 |
+
"Latvian": "Baltic",
|
57 |
+
"Lithuanian": "Baltic",
|
58 |
+
"Lombard": "Gallo-Italic",
|
59 |
+
"Low Saxon": "West Germanic",
|
60 |
+
"Luxembourgish": "West Germanic",
|
61 |
+
"Macedonian": "South Slavic",
|
62 |
+
"Malagasy": "Malayic",
|
63 |
+
"Malay": "Malayic",
|
64 |
+
"Malayalam": "Southern Dravidian",
|
65 |
+
"Marathi": "Central Indo-Aryan",
|
66 |
+
"Minangkabau": "Malayic",
|
67 |
+
"Nepali": "Eastern Indo-Aryan",
|
68 |
+
"Newar": "Newaric",
|
69 |
+
"Norwegian (Bokmal)": "North Germanic",
|
70 |
+
"Norwegian (Nynorsk)": "North Germanic",
|
71 |
+
"Occitan": "Gallo-Romance",
|
72 |
+
"Persian (Farsi)": "Iranian",
|
73 |
+
"Piedmontese": "Gallo-Italic",
|
74 |
+
"Polish": "West Slavic",
|
75 |
+
"Portuguese": "Iberian Romance",
|
76 |
+
"Punjabi": "Punjabi",
|
77 |
+
"Romanian": "Eastern Romance",
|
78 |
+
"Russian": "East Slavic",
|
79 |
+
"Scots": "West Germanic",
|
80 |
+
"Serbian": "South Slavic",
|
81 |
+
"Serbo-Croatian": "South Slavic",
|
82 |
+
"Sicilian": "Italo-Dalmatian",
|
83 |
+
"Slovak": "West Slavic",
|
84 |
+
"Slovenian": "South Slavic",
|
85 |
+
"South Azerbaijani": "Oghuz",
|
86 |
+
"Spanish": "Iberian Romance",
|
87 |
+
"Sundanese": "Sundic",
|
88 |
+
"Swahili": "Bantu",
|
89 |
+
"Swedish": "North Germanic",
|
90 |
+
"Tagalog": "Central Philippine",
|
91 |
+
"Tajik": "Iranian",
|
92 |
+
"Tamil": "Southern Dravidian",
|
93 |
+
"Tatar": "Kypchak",
|
94 |
+
"Telugu": "Southern Dravidian",
|
95 |
+
"Turkish": "Oghuz",
|
96 |
+
"Ukrainian": "East Slavic",
|
97 |
+
"Urdu": "Central Indo-Aryan",
|
98 |
+
"Uzbek": "Karluk",
|
99 |
+
"Vietnamese": "Vietic",
|
100 |
+
"Volapük": "Constructed",
|
101 |
+
"Waray-Waray": "Central Philippine",
|
102 |
+
"Welsh": "Brythonic",
|
103 |
+
"West Frisian": "West Germanic",
|
104 |
+
"Western Punjabi": "Punjabi",
|
105 |
+
"Yoruba": "Yoruboid",
|
106 |
+
"Esperanto": "Constructed",
|
107 |
+
"Crimean Tatar": "Kypchak"
|
108 |
+
}
|
109 |
+
|
110 |
+
language_families = {
|
111 |
+
"Afrikaans": "Germanic",
|
112 |
+
"Albanian": "Indo-European",
|
113 |
+
"Arabic": "Afroasiatic",
|
114 |
+
"Egyptian Arabic": "Afroasiatic",
|
115 |
+
"Aragonese": "Romance",
|
116 |
+
"Armenian": "Indo-European",
|
117 |
+
"Asturian": "Romance",
|
118 |
+
"Azerbaijani": "Turkic",
|
119 |
+
"Bashkir": "Turkic",
|
120 |
+
"Basque": "Language Isolate",
|
121 |
+
"Bavarian": "Germanic",
|
122 |
+
"Belarusian": "Indo-European",
|
123 |
+
"Bengali": "Indo-European",
|
124 |
+
"Bishnupriya Manipuri": "Indo-European",
|
125 |
+
"Bosnian": "Indo-European",
|
126 |
+
"Breton": "Indo-European",
|
127 |
+
"Bulgarian": "Indo-European",
|
128 |
+
"Burmese": "Sino-Tibetan",
|
129 |
+
"Catalan": "Romance",
|
130 |
+
"Cebuano": "Austronesian",
|
131 |
+
"Chechen": "Northeast Caucasian",
|
132 |
+
"Chinese (Simplified)": "Sino-Tibetan",
|
133 |
+
"Chinese (Traditional)": "Sino-Tibetan",
|
134 |
+
"Min Nan Chinese": "Sino-Tibetan",
|
135 |
+
"Chuvash": "Turkic",
|
136 |
+
"Croatian": "Indo-European",
|
137 |
+
"Czech": "Indo-European",
|
138 |
+
"Danish": "Germanic",
|
139 |
+
"Dutch": "Germanic",
|
140 |
+
"English": "Germanic",
|
141 |
+
"Estonian": "Uralic",
|
142 |
+
"Finnish": "Uralic",
|
143 |
+
"French": "Romance",
|
144 |
+
"Galician": "Romance",
|
145 |
+
"Georgian": "Kartvelian",
|
146 |
+
"German": "Germanic",
|
147 |
+
"Greek": "Indo-European",
|
148 |
+
"Gujarati": "Indo-European",
|
149 |
+
"Haitian": "Creole",
|
150 |
+
"Hebrew": "Afroasiatic",
|
151 |
+
"Hindi": "Indo-European",
|
152 |
+
"Hungarian": "Uralic",
|
153 |
+
"Icelandic": "Germanic",
|
154 |
+
"Ido": "Constructed",
|
155 |
+
"Indonesian": "Austronesian",
|
156 |
+
"Irish": "Indo-European",
|
157 |
+
"Italian": "Romance",
|
158 |
+
"Japanese": "Japonic",
|
159 |
+
"Javanese": "Austronesian",
|
160 |
+
"Kannada": "Dravidian",
|
161 |
+
"Kazakh": "Turkic",
|
162 |
+
"Kirghiz": "Turkic",
|
163 |
+
"Korean": "Koreanic",
|
164 |
+
"Latin": "Indo-European",
|
165 |
+
"Latvian": "Indo-European",
|
166 |
+
"Lithuanian": "Indo-European",
|
167 |
+
"Lombard": "Romance",
|
168 |
+
"Low Saxon": "Germanic",
|
169 |
+
"Luxembourgish": "Germanic",
|
170 |
+
"Macedonian": "Indo-European",
|
171 |
+
"Malagasy": "Austronesian",
|
172 |
+
"Malay": "Austronesian",
|
173 |
+
"Malayalam": "Dravidian",
|
174 |
+
"Marathi": "Indo-European",
|
175 |
+
"Minangkabau": "Austronesian",
|
176 |
+
"Nepali": "Indo-European",
|
177 |
+
"Newar": "Sino-Tibetan",
|
178 |
+
"Norwegian (Bokmal)": "Germanic",
|
179 |
+
"Norwegian (Nynorsk)": "Germanic",
|
180 |
+
"Occitan": "Romance",
|
181 |
+
"Persian (Farsi)": "Indo-European",
|
182 |
+
"Piedmontese": "Romance",
|
183 |
+
"Polish": "Indo-European",
|
184 |
+
"Portuguese": "Romance",
|
185 |
+
"Punjabi": "Indo-European",
|
186 |
+
"Romanian": "Romance",
|
187 |
+
"Russian": "Indo-European",
|
188 |
+
"Scots": "Germanic",
|
189 |
+
"Serbian": "Indo-European",
|
190 |
+
"Serbo-Croatian": "Indo-European",
|
191 |
+
"Sicilian": "Romance",
|
192 |
+
"Slovak": "Indo-European",
|
193 |
+
"Slovenian": "Indo-European",
|
194 |
+
"South Azerbaijani": "Turkic",
|
195 |
+
"Spanish": "Romance",
|
196 |
+
"Sundanese": "Austronesian",
|
197 |
+
"Swahili": "Niger-Congo",
|
198 |
+
"Swedish": "Germanic",
|
199 |
+
"Tagalog": "Austronesian",
|
200 |
+
"Tajik": "Indo-European",
|
201 |
+
"Tamil": "Dravidian",
|
202 |
+
"Tatar": "Turkic",
|
203 |
+
"Telugu": "Dravidian",
|
204 |
+
"Turkish": "Turkic",
|
205 |
+
"Ukrainian": "Indo-European",
|
206 |
+
"Urdu": "Indo-European",
|
207 |
+
"Uzbek": "Turkic",
|
208 |
+
"Vietnamese": "Austroasiatic",
|
209 |
+
"Volapük": "Constructed",
|
210 |
+
"Waray-Waray": "Austronesian",
|
211 |
+
"Welsh": "Indo-European",
|
212 |
+
"West Frisian": "Germanic",
|
213 |
+
"Western Punjabi": "Indo-European",
|
214 |
+
"Yoruba": "Niger-Congo",
|
215 |
+
"Esperanto": "Constructed",
|
216 |
+
"Crimean Tatar": "Turkic"
|
217 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio==5.23.3
|
2 |
+
networkx==3.4.2
|
3 |
+
umap-learn==0.5.7
|
utils.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import networkx as nx
|
2 |
+
from sklearn.cluster import HDBSCAN
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.manifold import TSNE
|
6 |
+
import umap
|
7 |
+
from sklearn.cluster import KMeans
|
8 |
+
from scipy.spatial import KDTree
|
9 |
+
from adjustText import adjust_text
|
10 |
+
from constants import language_families, language_subfamilies
|
11 |
+
|
12 |
+
|
13 |
+
def filter_languages_by_families(matrix, languages, families):
    """
    Keeps only the languages that belong to one of the given families.

    Parameters:
    - matrix: 2D NumPy array of pairwise distances, indexed in the order of `languages`.
    - languages: sequence of language names; each must be a key of `language_families`.
    - families: collection of family names to keep.

    Returns:
    - filtered_matrix: the matrix restricted to the rows and columns of the kept languages.
    - filtered_languages: list of the kept language names, in their original order.
    """
    kept_positions = []
    kept_names = []
    for position, name in enumerate(languages):
        if language_families[name] in families:
            kept_positions.append(position)
            kept_names.append(name)

    # np.ix_ selects the same positions on both axes, keeping the result square
    return matrix[np.ix_(kept_positions, kept_positions)], kept_names
|
29 |
+
|
30 |
+
|
31 |
+
def get_dynamic_color_map(n_colors):
    """
    Builds a list of n_colors colors sampled from a matplotlib colormap.

    Parameters:
    - n_colors: int, the number of colors required.

    Returns:
    - list of color tuples as returned by the colormap; "tab20" is sampled for
      up to 20 colors and "hsv" for more.
    """
    palette_name = "tab20" if n_colors <= 20 else "hsv"
    palette = plt.get_cmap(palette_name)
    # sample the colormap at evenly spaced positions in [0, 1)
    return [palette(step / n_colors) for step in range(n_colors)]
|
44 |
+
|
45 |
+
|
46 |
+
def cluster_languages_by_families(languages):
    """
    Assigns every language a cluster id derived from its family.

    Parameters:
    - languages: sequence of language names; each must be a key of `language_families`.

    Returns:
    - clusters: list with, for each language, the position of its family in `legend`.
    - legend: alphabetically sorted list of the distinct families.
    """
    families = [language_families[name] for name in languages]
    legend = sorted(set(families))
    # legend entries are unique, so a lookup table gives the same ids as list.index
    family_id = {family: position for position, family in enumerate(legend)}
    clusters = [family_id[family] for family in families]
    return clusters, legend
|
51 |
+
|
52 |
+
|
53 |
+
def cluster_languages_by_subfamilies(languages):
    """
    Assigns every language a cluster id derived from its "Family (Subfamily)" label.

    Parameters:
    - languages: sequence of language names; each must be a key of both
      `language_families` and `language_subfamilies`.

    Returns:
    - clusters: list with, for each language, the position of its label in `legend`.
    - legend: alphabetically sorted list of the distinct "Family (Subfamily)" labels.
    """
    labels = [f"{language_families[name]} ({language_subfamilies[name]})" for name in languages]
    legend = sorted(set(labels))
    # legend entries are unique, so a lookup table gives the same ids as list.index
    label_id = {label: position for position, label in enumerate(legend)}
    clusters = [label_id[label] for label in labels]
    return clusters, legend
|
58 |
+
|
59 |
+
|
60 |
+
def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=None, fig_size=(20,20)):
    """
    Draws the Minimum Spanning Tree (MST) of the complete graph defined by a distance matrix.

    Parameters:
    - model, dataset, use_average: accepted for signature parity with the other
      plot functions; not used in the drawing.
    - matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the label of each node.
    - clusters: sequence of length N with the cluster id of each node.
    - legend: optional mapping/sequence indexed by cluster id that gives the legend
      text; when None the cluster id itself is shown.
    - fig_size: (width, height) passed to matplotlib as the figure size.

    Returns:
    - fig: the matplotlib figure.
    """
    node_count = len(languages)

    # Complete undirected graph: one weighted edge per upper-triangle entry (a < b).
    graph = nx.Graph()
    graph.add_weighted_edges_from(
        (a, b, matrix[a, b])
        for a in range(node_count)
        for b in range(a + 1, node_count)
    )

    mst = nx.minimum_spanning_tree(graph)

    # Kamada-Kawai layout, with the edge weights taken into account.
    pos = nx.kamada_kawai_layout(mst, weight='weight')

    # One color per distinct cluster id.
    unique_clusters = sorted(set(clusters))
    palette = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = dict(zip(unique_clusters, palette))
    node_colors = [cluster_colors[cluster] for cluster in clusters]

    fig, ax = plt.subplots(figsize=fig_size)
    nx.draw_networkx_edges(mst, pos, edge_color='gray', ax=ax)
    nx.draw_networkx_nodes(mst, pos, node_color=node_colors, node_size=100, ax=ax, alpha=0.7)

    # Labels are created as plain text objects so adjust_text can move them apart.
    texts = [
        ax.text(*pos[node], label, fontsize=10)
        for node, label in enumerate(languages)
    ]
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Legend: one marker per cluster.
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = []
    for cluster in unique_clusters:
        legend_handles.append(
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, alpha=0.7, label=legend[cluster])
        )
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    ax.axis('off')

    return fig
|
127 |
+
|
128 |
+
def cluster_languages_kmeans(dist_matrix, languages, n_clusters=5):
    """
    Clusters languages with KMeans and drops the ones that end up alone in a cluster.

    Each row of the distance matrix is passed to KMeans as that language's
    feature vector.

    Parameters:
    - dist_matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the language names.
    - n_clusters: int, the number of clusters to form.

    Returns:
    - filtered_matrix: the distance matrix restricted to the kept languages.
    - filtered_languages: NumPy array with the kept language names.
    - filtered_clusters: NumPy array with the cluster id of each kept language.
    """
    labels = KMeans(n_clusters=n_clusters, random_state=23).fit_predict(dist_matrix)

    # A language is kept when its cluster holds more than one member.
    cluster_sizes = np.bincount(labels)
    keep = cluster_sizes[labels] > 1

    filtered_matrix = dist_matrix[np.ix_(keep, keep)]
    filtered_languages = np.array(languages)[keep]
    filtered_clusters = labels[keep]

    return filtered_matrix, filtered_languages, filtered_clusters
|
159 |
+
|
160 |
+
|
161 |
+
def cluster_languages_hdbscan(dist_matrix, languages, min_cluster_size=2):
    """
    Clusters languages with HDBSCAN on a precomputed distance matrix and drops
    the ones labelled -1.

    Parameters:
    - dist_matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the language names.
    - min_cluster_size: int, the minimum size of clusters.

    Returns:
    - filtered_matrix: the distance matrix restricted to the kept languages.
    - filtered_languages: NumPy array with the kept language names.
    - filtered_clusters: NumPy array with the cluster id of each kept language.
    """
    model = HDBSCAN(metric='precomputed', min_cluster_size=min_cluster_size)
    labels = model.fit_predict(dist_matrix)

    # Positions of the points that received a real cluster id (anything but -1).
    kept = np.flatnonzero(labels != -1)

    return (
        dist_matrix[np.ix_(kept, kept)],
        np.array(languages)[kept],
        np.array(labels)[kept],
    )
|
184 |
+
|
185 |
+
|
186 |
+
def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters, legend=None):
    """
    Plots all languages from the distances matrix using t-SNE and colors them by clusters.

    Parameters:
    - model, dataset: names shown in the plot title when use_average is falsy.
    - use_average: when truthy the title says "Average" instead of model/dataset.
    - matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the label of each point.
    - clusters: sequence of length N with the cluster id of each point.
    - legend: optional mapping/sequence indexed by cluster id that gives the legend
      text; when None the cluster id itself is shown.

    Returns:
    - fig: the matplotlib figure.

    Raises:
    - ValueError: if fewer than two languages are given.
    """
    n_samples = len(languages)
    if n_samples < 2:
        raise ValueError("t-SNE needs at least two languages to plot")

    # scikit-learn requires perplexity < n_samples. The default (30) therefore
    # fails when clustering/filtering leaves 30 languages or fewer, so clamp it;
    # larger inputs keep the default value.
    perplexity = min(30.0, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=23, metric="precomputed", init="random", perplexity=perplexity)
    tsne_results = tsne.fit_transform(matrix)

    # Map each cluster to a color
    unique_clusters = sorted(set(clusters))
    cmap = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}

    fig, ax = plt.subplots(figsize=(16, 12))
    ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=[cluster_colors[cluster] for cluster in clusters], alpha=0.7)

    # Labels are created as plain text objects so adjust_text can move them apart.
    texts = []
    for i, label in enumerate(languages):
        x, y = tsne_results[i, 0], tsne_results[i, 1]
        texts.append(ax.text(x, y, label, fontsize=10))

    adjust_text(texts, expand_text=(1.05, 1.2))

    # Add a legend for clusters
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
        for cluster in unique_clusters
    ]
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    ax.set_title(f"t-SNE Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    return fig
|
227 |
+
|
228 |
+
|
229 |
+
def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters, legend=None):
    """
    Projects the languages to 2D with UMAP (precomputed distances) and draws
    them colored by cluster.

    Parameters:
    - model, dataset: names shown in the plot title when use_average is falsy.
    - use_average: when truthy the title says "Average" instead of model/dataset.
    - matrix: 2D NumPy array (N x N) with the pairwise distances between languages.
    - languages: sequence of length N with the label of each point.
    - clusters: sequence of length N with the cluster id of each point.
    - legend: optional mapping/sequence indexed by cluster id that gives the legend
      text; when None the cluster id itself is shown.

    Returns:
    - fig: the matplotlib figure.
    """
    embedding = umap.UMAP(metric="precomputed", random_state=23).fit_transform(matrix)
    xs = embedding[:, 0]
    ys = embedding[:, 1]

    # One color per distinct cluster id.
    unique_clusters = sorted(set(clusters))
    palette = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = dict(zip(unique_clusters, palette))
    point_colors = [cluster_colors[cluster] for cluster in clusters]

    fig, ax = plt.subplots(figsize=(16, 12))
    ax.scatter(xs, ys, c=point_colors, alpha=0.7)

    # Labels are created as plain text objects so adjust_text can move them apart.
    texts = [
        ax.text(embedding[idx, 0], embedding[idx, 1], label, fontsize=10)
        for idx, label in enumerate(languages)
    ]
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Legend: one marker per cluster.
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = []
    for cluster in unique_clusters:
        legend_handles.append(
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
        )
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    ax.set_title(f"UMAP Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")
    return fig
|