mshamrai commited on
Commit
f9b063b
·
1 Parent(s): 983aedb

chore: init demo

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. app.py +248 -0
  3. constants.py +217 -0
  4. requirements.txt +3 -0
  5. utils.py +270 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .gradio
2
+ __pycache__
3
+ plots
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import pickle
5
+ import os
6
+ from sklearn.manifold import TSNE
7
+ import matplotlib.pyplot as plt
8
+ from utils import (plot_distances_tsne,
9
+ plot_distances_umap,
10
+ cluster_languages_hdbscan,
11
+ cluster_languages_kmeans,
12
+ plot_mst,
13
+ cluster_languages_by_families,
14
+ cluster_languages_by_subfamilies,
15
+ filter_languages_by_families)
16
+ from functools import partial
17
+
18
+
19
# Ordered list of language names; every distance matrix below is indexed in
# this order (row/column i corresponds to languages[i]).
# NOTE(review): the "../../results" relative paths assume the app is launched
# from this file's directory -- confirm against the deployment setup.
with open("../../results/languages_list.pkl", "rb") as f:
    languages = pickle.load(f)

# Corpora and models for which pairwise language-distance matrices were
# precomputed offline.
DATASETS = ["wikimedia/wikipedia", "uonlp/CulturaX", "HuggingFaceFW/fineweb-2"]
MODELS = ["mistralai/Mistral-7B-v0.1", "google/gemma-3-4b-pt", "meta-llama/Llama-3.2-1B"]

# distance_matrices[dataset][model] is a square NumPy array of pairwise
# language distances; languages missing from a dataset/model pair are
# presumably NaN rows/columns (see filter_languages_nan below).
distance_matrices = {
    dataset: {
        model: np.load(os.path.join("../../results", dataset, model, "distances_matrix.npy"))
        for model in MODELS
    }
    for dataset in DATASETS
}

# Distances averaged over all dataset/model combinations.
average_distances_matrix = np.load("../../results/average_distances_matrix.npy")
34
+
35
+
36
def filter_languages_nan(model, dataset, use_average):
    """
    Select the distance matrix for (model, dataset) -- or the averaged
    matrix -- and drop every language whose distances are NaN.

    Returns:
    - (matrix, names): the filtered square matrix and the matching
      language names as a NumPy array.
    """
    source = average_distances_matrix if use_average else distance_matrices[dataset][model]
    # NaNs in the first row mark languages without data; assumes NaNs form
    # whole rows/columns -- TODO confirm against the matrix producer.
    keep = ~np.isnan(source[0])
    return source[np.ix_(keep, keep)], np.array(languages)[keep]
47
+
48
+
49
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """
    Return the n languages closest to *selected_language*.

    Looks up the selected language's row in the chosen (or averaged)
    distance matrix, sorts ascending by distance, drops the language
    itself, and returns the head of the resulting DataFrame with a rank
    column ("index") and distances rounded to 4 decimals.
    """
    source = average_distances_matrix if use_average else distance_matrices[dataset][model]
    sel_idx = languages.index(selected_language)

    table = (
        pd.DataFrame({"Language": languages, "Distance": source[sel_idx]})
        .sort_values(by="Distance")
        .drop(index=sel_idx)      # labels survive sort_values, so this removes the language itself
        .reset_index(drop=True)
        .reset_index()            # expose the rank as an "index" column
    )
    table["Distance"] = table["Distance"].round(4)
    return table.head(n)
67
+
68
def update_languages(model, dataset):
    """
    Return the list of languages that have valid (non-NaN) distances in
    the matrix for the given model and dataset.
    """
    first_row = distance_matrices[dataset][model][0]
    return [lang for lang, dist in zip(languages, first_row) if not np.isnan(dist)]
76
+
77
+
78
def update_language_options(model, dataset, language, use_average):
    """
    Rebuild the language dropdown for the current model/dataset selection,
    keeping the current language when it is still available and falling
    back to the first choice otherwise.
    """
    choices = languages if use_average else update_languages(model, dataset)
    selected = language if language in choices else choices[0]
    return gr.Dropdown(label="Language", choices=choices, value=selected)
86
+
87
+
88
def toggle_inputs(use_average):
    """
    Show the model and dataset selectors only when per-model distances are
    used; hide and disable both when the averaged matrix is selected.
    """
    enabled = not use_average
    return (
        gr.update(interactive=enabled, visible=enabled),
        gr.update(interactive=enabled, visible=enabled),
    )
93
+
94
# Monotonically increasing counter so each exported PDF gets a unique name.
i = 0

def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
    """
    Build a 2-D visualization of the language distance matrix.

    Parameters:
    - model, dataset: which precomputed matrix to use (ignored when
      use_average is True).
    - use_average: use the matrix averaged over all model/dataset pairs.
    - cluster_method: "HDBSCAN", "KMeans", "Family" or "Subfamily".
    - cluster_method_param: minimum cluster size (HDBSCAN) or number of
      clusters (KMeans); unused for the family-based methods.
    - plot_fn: plotting backend (t-SNE / UMAP / MST) with the signature
      (model, dataset, use_average, matrix, languages, clusters, legends).

    Returns:
    - the matplotlib Figure produced by plot_fn (also saved as a PDF
      under plots/).

    Raises:
    - ValueError: if cluster_method is not one of the four known methods.
    """
    global i

    updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)

    if cluster_method == "HDBSCAN":
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix, updated_languages, min_cluster_size=cluster_method_param
        )
        legends = None  # numeric cluster ids serve as legend labels
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix, updated_languages, n_clusters=cluster_method_param
        )
        legends = None
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
    fig.tight_layout()
    # Fix: "plots" is gitignored and may not exist on a fresh checkout,
    # which would make savefig raise FileNotFoundError. Create it first.
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
    i += 1
    return fig
131
+
132
+
133
# ---------------------------------------------------------------------------
# Gradio UI: three tabs sharing the model / dataset / "use average" controls.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    # When checked, the averaged matrix is used and the model/dataset
    # selectors are hidden (see toggle_inputs).
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset",
            choices=DATASETS,
            value=DATASETS[0]
        )

    # --- Tab 1: table of the N languages closest to a selected language ---
    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(label="Language", choices=languages, value=languages[0])
            top_n_input = gr.Slider(label="Top N", minimum=1, maximum=30, step=1, value=10)

        output_table = gr.Dataframe(label="Similar Languages")

        # Changing the model/dataset may invalidate the selected language, so
        # the dropdown choices are refreshed; every input change also
        # recomputes the similarity table.
        model_input.change(fn=update_language_options, inputs=[model_input, dataset_input, language_input, average_checkbox], outputs=language_input)
        dataset_input.change(fn=update_language_options, inputs=[model_input, dataset_input, language_input, average_checkbox], outputs=language_input)
        language_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
        model_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
        dataset_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)
        top_n_input.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)

        # Toggling the average hides/shows the selectors, then refreshes the
        # language list and the table.
        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input]
        )

        average_checkbox.change(fn=update_language_options, inputs=[model_input, dataset_input, language_input, average_checkbox], outputs=language_input)
        average_checkbox.change(fn=get_similar_languages, inputs=[model_input, dataset_input, language_input, average_checkbox, top_n_input], outputs=output_table)


    # --- Tab 2: 2-D projections (t-SNE / UMAP / MST) with clustering ---
    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(label="Cluster Method", choices=["HDBSCAN", "KMeans", "Family", "Subfamily"], value="HDBSCAN")
            clusters_input = gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2)

        def update_clusters_input_option(cluster_method):
            # The slider's meaning depends on the method: minimum cluster
            # size for HDBSCAN, cluster count for KMeans, hidden for the
            # family-based methods (they take no parameter).
            if cluster_method == "HDBSCAN":
                return gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2, visible=True, interactive=True)
            elif cluster_method == "KMeans":
                return gr.Slider(label="Number of Clusters", minimum=2, maximum=20, step=1, value=2, visible=True, interactive=True)
            else:
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(fn=update_clusters_input_option, inputs=[cluster_method_input], outputs=clusters_input)

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")

        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        # One shared handler; only the plotting backend bound via
        # functools.partial differs between the three buttons.
        plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
                               inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
                               outputs=plot_output)
        plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
                               inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
                               outputs=plot_output)
        plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
                              inputs=[model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input],
                              outputs=plot_output)

    # --- Tab 3: MST restricted to chosen families, colored by subfamily ---
    with gr.Tab(label="Language Families Subplot"):

        checked_families_input = gr.CheckboxGroup(label="Language Families",
                                                  choices=[
                                                      'Afroasiatic',
                                                      'Austroasiatic',
                                                      'Austronesian',
                                                      'Constructed',
                                                      'Creole',
                                                      'Dravidian',
                                                      'Germanic',
                                                      'Indo-European',
                                                      'Japonic',
                                                      'Kartvelian',
                                                      'Koreanic',
                                                      'Language Isolate',
                                                      'Niger-Congo',
                                                      'Northeast Caucasian',
                                                      'Romance',
                                                      'Sino-Tibetan',
                                                      'Turkic',
                                                      'Uralic'
                                                  ],
                                                  value=["Indo-European"])
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
            plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
        plot_family_output = gr.Plot(label="Families Plot")

        def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
            """Draw an MST limited to *families*, colored by subfamily."""
            global i

            updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
            updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)

            clusters, legends = cluster_languages_by_subfamilies(updated_languages)
            fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
            fig.tight_layout()
            # NOTE(review): assumes the plots/ directory already exists; it
            # is gitignored, so this can fail on a fresh checkout -- verify.
            fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
            i += 1
            return fig

        plot_family_button.click(fn=plot_families_subfamilies,
                                 inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
                                 outputs=plot_family_output)


demo.launch(share=True)
constants.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Language name -> subfamily label, used by utils.cluster_languages_by_subfamilies
# to build plot legends of the form "Family (Subfamily)".
language_subfamilies = {
    "Afrikaans": "West Germanic",
    "Albanian": "Albanian",
    "Arabic": "Semitic",
    "Egyptian Arabic": "Semitic",
    "Aragonese": "Romance",
    "Armenian": "Armenian",
    "Asturian": "Romance",
    "Azerbaijani": "Oghuz",
    "Bashkir": "Kypchak",
    "Basque": "Language Isolate",
    "Bavarian": "Austro-Bavarian",
    "Belarusian": "East Slavic",
    "Bengali": "Eastern Indo-Aryan",
    "Bishnupriya Manipuri": "Eastern Indo-Aryan",
    "Bosnian": "South Slavic",
    "Breton": "Brythonic",
    "Bulgarian": "South Slavic",
    "Burmese": "Burmish",
    "Catalan": "Romance",
    "Cebuano": "Central Philippine",
    "Chechen": "Nakh-Daghestanian",
    "Chinese (Simplified)": "Sinitic",
    "Chinese (Traditional)": "Sinitic",
    "Min Nan Chinese": "Sinitic",
    "Chuvash": "Oghur",
    "Croatian": "South Slavic",
    "Czech": "West Slavic",
    "Danish": "North Germanic",
    "Dutch": "West Germanic",
    "English": "West Germanic",
    "Estonian": "Finnic",
    "Finnish": "Finnic",
    "French": "Gallo-Romance",
    "Galician": "Gallo-Romance",
    "Georgian": "Kartvelian",
    "German": "West Germanic",
    "Greek": "Hellenic",
    "Gujarati": "Gujarati",
    "Haitian": "French-based Creole",
    "Hebrew": "Semitic",
    "Hindi": "Central Indo-Aryan",
    "Hungarian": "Ugric",
    "Icelandic": "North Germanic",
    "Ido": "Constructed",
    "Indonesian": "Malayic",
    "Irish": "Goidelic",
    "Italian": "Italo-Dalmatian",
    "Japanese": "Japonic",
    "Javanese": "Javanic",
    "Kannada": "Southern Dravidian",
    "Kazakh": "Kypchak",
    "Kirghiz": "Kypchak",
    "Korean": "Koreanic",
    "Latin": "Italic",
    "Latvian": "Baltic",
    "Lithuanian": "Baltic",
    "Lombard": "Gallo-Italic",
    "Low Saxon": "West Germanic",
    "Luxembourgish": "West Germanic",
    "Macedonian": "South Slavic",
    "Malagasy": "Malayic",
    "Malay": "Malayic",
    "Malayalam": "Southern Dravidian",
    "Marathi": "Central Indo-Aryan",
    "Minangkabau": "Malayic",
    "Nepali": "Eastern Indo-Aryan",
    "Newar": "Newaric",
    "Norwegian (Bokmal)": "North Germanic",
    "Norwegian (Nynorsk)": "North Germanic",
    "Occitan": "Gallo-Romance",
    "Persian (Farsi)": "Iranian",
    "Piedmontese": "Gallo-Italic",
    "Polish": "West Slavic",
    "Portuguese": "Iberian Romance",
    "Punjabi": "Punjabi",
    "Romanian": "Eastern Romance",
    "Russian": "East Slavic",
    "Scots": "West Germanic",
    "Serbian": "South Slavic",
    "Serbo-Croatian": "South Slavic",
    "Sicilian": "Italo-Dalmatian",
    "Slovak": "West Slavic",
    "Slovenian": "South Slavic",
    "South Azerbaijani": "Oghuz",
    "Spanish": "Iberian Romance",
    "Sundanese": "Sundic",
    "Swahili": "Bantu",
    "Swedish": "North Germanic",
    "Tagalog": "Central Philippine",
    "Tajik": "Iranian",
    "Tamil": "Southern Dravidian",
    "Tatar": "Kypchak",
    "Telugu": "Southern Dravidian",
    "Turkish": "Oghuz",
    "Ukrainian": "East Slavic",
    "Urdu": "Central Indo-Aryan",
    "Uzbek": "Karluk",
    "Vietnamese": "Vietic",
    "Volapük": "Constructed",
    "Waray-Waray": "Central Philippine",
    "Welsh": "Brythonic",
    "West Frisian": "West Germanic",
    "Western Punjabi": "Punjabi",
    "Yoruba": "Yoruboid",
    "Esperanto": "Constructed",
    "Crimean Tatar": "Kypchak"
}

# Language name -> family label, used by utils for filtering and coloring.
# NOTE(review): "Germanic" and "Romance" are kept as separate top-level
# entries alongside "Indo-European" -- presumably intentional for
# finer-grained grouping in the UI; confirm before "fixing".
language_families = {
    "Afrikaans": "Germanic",
    "Albanian": "Indo-European",
    "Arabic": "Afroasiatic",
    "Egyptian Arabic": "Afroasiatic",
    "Aragonese": "Romance",
    "Armenian": "Indo-European",
    "Asturian": "Romance",
    "Azerbaijani": "Turkic",
    "Bashkir": "Turkic",
    "Basque": "Language Isolate",
    "Bavarian": "Germanic",
    "Belarusian": "Indo-European",
    "Bengali": "Indo-European",
    "Bishnupriya Manipuri": "Indo-European",
    "Bosnian": "Indo-European",
    "Breton": "Indo-European",
    "Bulgarian": "Indo-European",
    "Burmese": "Sino-Tibetan",
    "Catalan": "Romance",
    "Cebuano": "Austronesian",
    "Chechen": "Northeast Caucasian",
    "Chinese (Simplified)": "Sino-Tibetan",
    "Chinese (Traditional)": "Sino-Tibetan",
    "Min Nan Chinese": "Sino-Tibetan",
    "Chuvash": "Turkic",
    "Croatian": "Indo-European",
    "Czech": "Indo-European",
    "Danish": "Germanic",
    "Dutch": "Germanic",
    "English": "Germanic",
    "Estonian": "Uralic",
    "Finnish": "Uralic",
    "French": "Romance",
    "Galician": "Romance",
    "Georgian": "Kartvelian",
    "German": "Germanic",
    "Greek": "Indo-European",
    "Gujarati": "Indo-European",
    "Haitian": "Creole",
    "Hebrew": "Afroasiatic",
    "Hindi": "Indo-European",
    "Hungarian": "Uralic",
    "Icelandic": "Germanic",
    "Ido": "Constructed",
    "Indonesian": "Austronesian",
    "Irish": "Indo-European",
    "Italian": "Romance",
    "Japanese": "Japonic",
    "Javanese": "Austronesian",
    "Kannada": "Dravidian",
    "Kazakh": "Turkic",
    "Kirghiz": "Turkic",
    "Korean": "Koreanic",
    "Latin": "Indo-European",
    "Latvian": "Indo-European",
    "Lithuanian": "Indo-European",
    "Lombard": "Romance",
    "Low Saxon": "Germanic",
    "Luxembourgish": "Germanic",
    "Macedonian": "Indo-European",
    "Malagasy": "Austronesian",
    "Malay": "Austronesian",
    "Malayalam": "Dravidian",
    "Marathi": "Indo-European",
    "Minangkabau": "Austronesian",
    "Nepali": "Indo-European",
    "Newar": "Sino-Tibetan",
    "Norwegian (Bokmal)": "Germanic",
    "Norwegian (Nynorsk)": "Germanic",
    "Occitan": "Romance",
    "Persian (Farsi)": "Indo-European",
    "Piedmontese": "Romance",
    "Polish": "Indo-European",
    "Portuguese": "Romance",
    "Punjabi": "Indo-European",
    "Romanian": "Romance",
    "Russian": "Indo-European",
    "Scots": "Germanic",
    "Serbian": "Indo-European",
    "Serbo-Croatian": "Indo-European",
    "Sicilian": "Romance",
    "Slovak": "Indo-European",
    "Slovenian": "Indo-European",
    "South Azerbaijani": "Turkic",
    "Spanish": "Romance",
    "Sundanese": "Austronesian",
    "Swahili": "Niger-Congo",
    "Swedish": "Germanic",
    "Tagalog": "Austronesian",
    "Tajik": "Indo-European",
    "Tamil": "Dravidian",
    "Tatar": "Turkic",
    "Telugu": "Dravidian",
    "Turkish": "Turkic",
    "Ukrainian": "Indo-European",
    "Urdu": "Indo-European",
    "Uzbek": "Turkic",
    "Vietnamese": "Austroasiatic",
    "Volapük": "Constructed",
    "Waray-Waray": "Austronesian",
    "Welsh": "Indo-European",
    "West Frisian": "Germanic",
    "Western Punjabi": "Indo-European",
    "Yoruba": "Niger-Congo",
    "Esperanto": "Constructed",
    "Crimean Tatar": "Turkic"
}
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==5.23.3
2
+ networkx==3.4.2
3
+ umap-learn==0.5.7
utils.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ from sklearn.cluster import HDBSCAN
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from sklearn.manifold import TSNE
6
+ import umap
7
+ from sklearn.cluster import KMeans
8
+ from scipy.spatial import KDTree
9
+ from adjustText import adjust_text
10
+ from constants import language_families, language_subfamilies
11
+
12
+
13
def filter_languages_by_families(matrix, languages, families):
    """
    Restrict the distance matrix and language list to the given families.

    Parameters:
    - matrix: square distance matrix aligned with *languages*.
    - languages: list of language names.
    - families: collection of family names to keep.

    Returns:
    - (filtered_matrix, filtered_languages): the sub-matrix and the names
      of the languages whose family is in *families*.
    """
    keep = [idx for idx, lang in enumerate(languages) if language_families[lang] in families]
    kept_names = [languages[idx] for idx in keep]
    return matrix[np.ix_(keep, keep)], kept_names
29
+
30
+
31
def get_dynamic_color_map(n_colors):
    """
    Build a list of *n_colors* distinct RGBA colors.

    Parameters:
    - n_colors: int, the number of distinct colors required.

    Returns:
    - list of RGBA tuples, sampled from the qualitative "tab20" palette
      for up to 20 colors, otherwise from the continuous "hsv" palette.
    """
    palette = plt.get_cmap("tab20" if n_colors <= 20 else "hsv")
    return [palette(idx / n_colors) for idx in range(n_colors)]
44
+
45
+
46
def cluster_languages_by_families(languages):
    """
    Assign each language an integer cluster id based on its family.

    Parameters:
    - languages: list of language names (keys of language_families).

    Returns:
    - clusters: list of ints, one cluster id per input language.
    - legend: sorted list of the distinct family names; cluster ids
      index into it.
    """
    lang_families = [language_families[lang] for lang in languages]
    legend = sorted(set(lang_families))
    # Build the name -> position map once instead of calling list.index
    # for every language (O(n^2) -> O(n)); same resulting ids.
    position = {family: idx for idx, family in enumerate(legend)}
    clusters = [position[family] for family in lang_families]
    return clusters, legend
51
+
52
+
53
def cluster_languages_by_subfamilies(languages):
    """
    Assign each language an integer cluster id based on the combined
    "Family (Subfamily)" label.

    Parameters:
    - languages: list of language names (keys of both constants dicts).

    Returns:
    - clusters: list of ints, one cluster id per input language.
    - legend: sorted list of the distinct "Family (Subfamily)" labels;
      cluster ids index into it.
    """
    labels = [language_families[lang] + f" ({language_subfamilies[lang]})" for lang in languages]
    legend = sorted(set(labels))
    # Build the label -> position map once instead of calling list.index
    # for every language (O(n^2) -> O(n)); same resulting ids.
    position = {label: idx for idx, label in enumerate(legend)}
    clusters = [position[label] for label in labels]
    return clusters, legend
58
+
59
+
60
def plot_mst(model, dataset, use_average, matrix, languages, clusters, legend=None, fig_size=(20,20)):
    """
    Plot the Minimum Spanning Tree (MST) of the language distance matrix.

    Parameters:
    - model, dataset, use_average: identify the source of *matrix*; only
      used for the title, which is currently commented out.
    - matrix: 2D NumPy array (N x N) of pairwise distances between languages.
    - languages: list of length N containing the node labels.
    - clusters: list of length N with a cluster id per node; ids select
      node colors and legend entries.
    - legend: optional sequence/mapping indexed by cluster id giving the
      legend label for each cluster; numeric ids are used when omitted.
    - fig_size: (width, height) of the figure in inches.

    Returns:
    - the matplotlib Figure containing the drawn MST.
    """
    # Build a complete weighted graph from the distance matrix, adding each
    # undirected edge once (upper triangle only, i < j).
    G = nx.Graph()

    N = len(languages)

    for i in range(N):
        for j in range(i + 1, N):
            G.add_edge(i, j, weight=matrix[i, j])

    # Compute the Minimum Spanning Tree using NetworkX's built-in function.
    mst = nx.minimum_spanning_tree(G)

    # Kamada-Kawai layout takes edge weights into account, so languages
    # joined by short edges end up near each other.
    pos = nx.kamada_kawai_layout(mst, weight='weight')

    # Assign one color per distinct cluster id.
    unique_clusters = sorted(set(clusters))
    cmap = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}

    node_colors = [cluster_colors.get(cluster) for cluster in clusters]

    fig, ax = plt.subplots(figsize=fig_size)

    # Edges first, then nodes on top.
    nx.draw_networkx_edges(mst, pos, edge_color='gray', ax=ax)

    nx.draw_networkx_nodes(mst, pos, node_color=node_colors, node_size=100, ax=ax, alpha=0.7)

    # Create text objects (instead of networkx labels) so adjust_text can
    # reposition them afterwards.
    texts = []
    for i, label in enumerate(languages):
        x, y = pos[i]
        texts.append(ax.text(x, y, label, fontsize=10))

    # Nudge the labels to minimize overlap.
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Fall back to numeric cluster ids when no legend labels are given;
    # legend[cluster] works for both dict and list legends because cluster
    # ids are small ints.
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, alpha=0.7, label=legend[cluster])
        for cluster in unique_clusters
    ]
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    # Hide the axes; only the tree itself is informative.
    ax.axis('off')
    # ax.set_title(f"Minimum Spanning Tree of Languages ({'Average' if use_average else f'{model}, {dataset}'})")

    return fig
127
+
128
def cluster_languages_kmeans(dist_matrix, languages, n_clusters=5):
    """
    Cluster languages with KMeans and drop singleton clusters.

    Each row of the distance matrix is treated as that language's feature
    vector; the fixed random_state keeps the assignment reproducible.

    Parameters:
    - dist_matrix: 2D NumPy array (N x N) of pairwise language distances.
    - languages: list of length N with the language names.
    - n_clusters: int, the number of clusters to form.

    Returns:
    - filtered_matrix: sub-matrix for the languages that were kept.
    - filtered_languages: NumPy array of the kept language names.
    - filtered_clusters: NumPy array of their cluster labels.
    """
    labels = KMeans(n_clusters=n_clusters, random_state=23).fit_predict(dist_matrix)

    # Keep only languages whose cluster contains more than one member.
    multi_member = np.where(np.bincount(labels) > 1)[0]
    keep = np.isin(labels, multi_member)

    return (
        dist_matrix[np.ix_(keep, keep)],
        np.array(languages)[keep],
        np.array(labels)[keep],
    )
159
+
160
+
161
def cluster_languages_hdbscan(dist_matrix, languages, min_cluster_size=2):
    """
    Cluster languages with HDBSCAN on the precomputed distance matrix and
    drop the noise points (label -1).

    Parameters:
    - dist_matrix: 2D NumPy array (N x N) of pairwise language distances.
    - languages: list of length N with the language names.
    - min_cluster_size: int, the minimum size of clusters.

    Returns:
    - filtered_matrix: sub-matrix for the languages that were kept.
    - filtered_languages: NumPy array of the kept language names.
    - filtered_clusters: NumPy array of their cluster labels.
    """
    model = HDBSCAN(
        metric='precomputed', min_cluster_size=min_cluster_size
    )
    labels = model.fit_predict(dist_matrix)

    # HDBSCAN marks noise as -1; keep everything else.
    keep = np.flatnonzero(labels != -1)

    return (
        dist_matrix[np.ix_(keep, keep)],
        np.array(languages)[keep],
        np.array(labels)[keep],
    )
184
+
185
+
186
def plot_distances_tsne(model, dataset, use_average, matrix, languages, clusters, legend=None):
    """
    Project the language distance matrix to 2-D with t-SNE and scatter-plot
    it, coloring each language by its cluster id.

    Parameters:
    - model, dataset, use_average: identify the source of *matrix*; only
      used in the plot title.
    - matrix: 2D NumPy array (N x N) of precomputed pairwise distances.
    - languages: list of length N with the point labels.
    - clusters: list of length N with a cluster id per language.
    - legend: optional sequence/mapping indexed by cluster id giving the
      legend label; numeric ids are used when omitted.

    Returns:
    - the matplotlib Figure with the scatter plot.
    """
    # init="random" because the default PCA init is unavailable with
    # metric="precomputed"; the fixed random_state keeps layouts stable.
    tsne = TSNE(n_components=2, random_state=23, metric="precomputed", init="random")
    tsne_results = tsne.fit_transform(matrix)

    # Assign one color per distinct cluster id.
    unique_clusters = sorted(set(clusters))
    cmap = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}

    fig, ax = plt.subplots(figsize=(16, 12))
    scatter = ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=[cluster_colors[cluster] for cluster in clusters], alpha=0.7)

    # Create text objects (instead of annotating directly) so adjust_text
    # can reposition them afterwards.
    texts = []
    for i, label in enumerate(languages):
        x, y = tsne_results[i, 0], tsne_results[i, 1]
        texts.append(ax.text(x, y, label, fontsize=10))

    # Nudge the labels to minimize overlap.
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Fall back to numeric cluster ids when no legend labels are given.
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
        for cluster in unique_clusters
    ]
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    ax.set_title(f"t-SNE Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    return fig
227
+
228
+
229
def plot_distances_umap(model, dataset, use_average, matrix, languages, clusters, legend=None):
    """
    Project the language distance matrix to 2-D with UMAP and scatter-plot
    it, coloring each language by its cluster id.

    Parameters:
    - model, dataset, use_average: identify the source of *matrix*; only
      used in the plot title.
    - matrix: 2D NumPy array (N x N) of precomputed pairwise distances.
    - languages: list of length N with the point labels.
    - clusters: list of length N with a cluster id per language.
    - legend: optional sequence/mapping indexed by cluster id giving the
      legend label; numeric ids are used when omitted.

    Returns:
    - the matplotlib Figure with the scatter plot.
    """

    # metric="precomputed" makes UMAP consume the distance matrix directly;
    # the fixed random_state keeps layouts stable between calls.
    umap_model = umap.UMAP(metric="precomputed", random_state=23)
    umap_results = umap_model.fit_transform(matrix)

    # Assign one color per distinct cluster id.
    unique_clusters = sorted(set(clusters))
    cmap = get_dynamic_color_map(len(unique_clusters))
    cluster_colors = {cluster: cmap[i] for i, cluster in enumerate(unique_clusters)}

    fig, ax = plt.subplots(figsize=(16, 12))
    scatter = ax.scatter(umap_results[:, 0], umap_results[:, 1], c=[cluster_colors[cluster] for cluster in clusters], alpha=0.7)

    # Create text objects (instead of annotating directly) so adjust_text
    # can reposition them afterwards.
    texts = []
    for i, label in enumerate(languages):
        x, y = umap_results[i, 0], umap_results[i, 1]
        texts.append(ax.text(x, y, label, fontsize=10))

    # Nudge the labels to minimize overlap.
    adjust_text(texts, expand_text=(1.05, 1.2))

    # Fall back to numeric cluster ids when no legend labels are given.
    if legend is None:
        legend = {cluster: str(cluster) for cluster in unique_clusters}
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cluster_colors[cluster], markersize=10, label=legend[cluster])
        for cluster in unique_clusters
    ]
    ax.legend(handles=legend_handles, title="Clusters", loc="best")

    ax.set_title(f"UMAP Visualization of Language Distances ({'Average' if use_average else f'{model}, {dataset}'})")
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")
    return fig