Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import pickle | |
import os | |
from sklearn.manifold import TSNE | |
import matplotlib.pyplot as plt | |
from utils import (plot_distances_tsne, | |
plot_distances_umap, | |
cluster_languages_hdbscan, | |
cluster_languages_kmeans, | |
plot_mst, | |
cluster_languages_by_families, | |
cluster_languages_by_subfamilies, | |
filter_languages_by_families) | |
from functools import partial | |
# Language names in matrix order: position k of this list labels row/column k
# of every distance matrix loaded below.
with open("../../results/languages_list.pkl", "rb") as f:
    languages = pickle.load(f)

DATASETS = ["wikimedia/wikipedia", "uonlp/CulturaX", "HuggingFaceFW/fineweb-2"]
MODELS = ["mistralai/Mistral-7B-v0.1", "google/gemma-3-4b-pt", "meta-llama/Llama-3.2-1B"]


def _load_distance_matrix(dataset, model):
    """Load ``../../results/<dataset>/<model>/distances_matrix.npy``."""
    return np.load(os.path.join("../../results", dataset, model, "distances_matrix.npy"))


# Nested lookup: distance_matrices[dataset][model] -> matrix loaded from disk.
# Paths are relative to the current working directory, not to this file.
distance_matrices = {
    dataset: {model: _load_distance_matrix(dataset, model) for model in MODELS}
    for dataset in DATASETS
}

# Matrix used instead of a per-model/per-dataset one when "Use Average
# Distances" is ticked in the UI.
average_distances_matrix = np.load("../../results/average_distances_matrix.npy")
def filter_languages_nan(model, dataset, use_average):
    """
    Restrict the distance matrix to languages that have data.

    The source matrix is ``average_distances_matrix`` when ``use_average`` is
    truthy, otherwise ``distance_matrices[dataset][model]``. A language is
    treated as missing when its entry in the first row of that matrix is NaN.

    Returns:
        (matrix, languages): the square sub-matrix over the kept languages
        (same mask applied to rows and columns) and a NumPy array of the kept
        language names, in matrix order.
    """
    source = average_distances_matrix if use_average else distance_matrices[dataset][model]
    # Row 0 serves as the reference for which languages hold data.
    keep = np.logical_not(np.isnan(source[0]))
    kept_languages = np.array(languages)[keep]
    kept_matrix = source[np.ix_(keep, keep)]
    return kept_matrix, kept_languages
def get_similar_languages(model, dataset, selected_language, use_average, n):
    """
    Return the ``n`` languages closest to ``selected_language``.

    Distances are read from ``average_distances_matrix`` when ``use_average``
    is truthy, otherwise from ``distance_matrices[dataset][model]``.

    The selected language itself is excluded, and so is any language whose
    distance is NaN (no data for this model/dataset). Rows are sorted by
    ascending distance.

    Returns:
        A DataFrame with columns ``index`` (0-based rank), ``Language`` and
        ``Distance`` (rounded to 4 decimals). It has no rows when
        ``selected_language`` is empty or not a known language.
    """
    if selected_language not in languages:
        # The dropdown can be cleared (None) or hold a value that is no longer
        # offered; list.index would raise ValueError in that case.
        return pd.DataFrame(columns=["index", "Language", "Distance"])

    if use_average:
        matrix = average_distances_matrix
    else:
        matrix = distance_matrices[dataset][model]

    selected_language_index = languages.index(selected_language)
    distances = matrix[selected_language_index]
    df = pd.DataFrame({"Language": languages, "Distance": distances})

    # Drop the query language by its original positional label (the default
    # RangeIndex matches positions in `languages`).
    df = df.drop(index=selected_language_index)
    # Languages without data have NaN distances; sort_values would push them
    # to the bottom and they would be displayed as "NaN" rows.
    df = df.dropna(subset=["Distance"])

    # The first reset_index renumbers the rows 0..k-1 in sorted order; the
    # second exposes that rank as an "index" column.
    df = df.sort_values(by="Distance").reset_index(drop=True).reset_index()
    df["Distance"] = df["Distance"].round(4)

    # Slider values may be delivered as floats; DataFrame.head needs an int.
    return df.head(int(n))
def update_languages(model, dataset):
    """
    List the languages that have data for one model/dataset pair.

    A language is kept when its entry in row 0 of
    ``distance_matrices[dataset][model]`` is not NaN. The averaged matrix is
    never consulted here.

    Returns:
        A list of language names (NumPy string scalars) in matrix order.
    """
    reference_row = distance_matrices[dataset][model][0]
    has_data = np.logical_not(np.isnan(reference_row))
    return list(np.asarray(languages)[has_data])
def update_language_options(model, dataset, language, use_average):
    """
    Rebuild the "Language" dropdown for the current selection.

    When ``use_average`` is truthy every language is offered. Otherwise only
    the languages returned by ``update_languages(model, dataset)`` are offered.

    The current ``language`` is kept if it is still offered. If not, the first
    offered language is selected, or nothing is selected when no language has
    data.

    Returns:
        A new ``gr.Dropdown`` with the updated choices and value.
    """
    if use_average:
        choices = languages
    else:
        choices = update_languages(model, dataset)

    if language not in choices:
        # Guard the fallback: indexing an empty list would raise IndexError
        # inside the Gradio callback.
        language = choices[0] if len(choices) > 0 else None

    return gr.Dropdown(label="Language", choices=choices, value=language)
def toggle_inputs(use_average):
    """
    Show/enable or hide/disable the model and dataset dropdowns.

    Both dropdowns are hidden and made non-interactive when ``use_average`` is
    truthy, and shown and made interactive otherwise.

    Returns:
        A pair of ``gr.update`` objects: one for the model dropdown, one for
        the dataset dropdown.
    """
    enabled = not use_average
    model_update = gr.update(interactive=enabled, visible=enabled)
    dataset_update = gr.update(interactive=enabled, visible=enabled)
    return model_update, dataset_update
# Running counter giving each exported plot a unique file name
# (plots/plot_<i>.pdf). It is incremented after every export.
i = 0


def plot_distances(model, dataset, use_average, cluster_method, cluster_method_param, plot_fn):
    """
    Cluster the languages, draw them with ``plot_fn``, and export the figure.

    Args:
        model, dataset, use_average: Select the distance matrix. Languages
            without data are removed by ``filter_languages_nan``.
        cluster_method: One of "HDBSCAN", "KMeans", "Family" or "Subfamily".
        cluster_method_param: The HDBSCAN ``min_cluster_size`` or the KMeans
            ``n_clusters``. It is ignored for "Family" and "Subfamily".
        plot_fn: Projection/plot helper, called as
            ``plot_fn(model, dataset, use_average, matrix, languages, clusters, legends)``.
            It must return a matplotlib figure.

    Returns:
        The figure produced by ``plot_fn``. The figure is also saved to
        ``plots/plot_<i>.pdf``.

    Raises:
        ValueError: If ``cluster_method`` is not recognised.
    """
    global i
    updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)

    if cluster_method == "HDBSCAN":
        # Slider values may arrive as floats; clustering sizes must be ints.
        filtered_matrix, filtered_languages, clusters = cluster_languages_hdbscan(
            updated_matrix, updated_languages, min_cluster_size=int(cluster_method_param)
        )
        legends = None
    elif cluster_method == "KMeans":
        filtered_matrix, filtered_languages, clusters = cluster_languages_kmeans(
            updated_matrix, updated_languages, n_clusters=int(cluster_method_param)
        )
        legends = None
    elif cluster_method == "Family":
        clusters, legends = cluster_languages_by_families(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    elif cluster_method == "Subfamily":
        clusters, legends = cluster_languages_by_subfamilies(updated_languages)
        filtered_matrix = updated_matrix
        filtered_languages = updated_languages
    else:
        raise ValueError("Invalid cluster method")

    fig = plot_fn(model, dataset, use_average, filtered_matrix, filtered_languages, clusters, legends)
    fig.tight_layout()
    # savefig does not create directories; without this the first export on a
    # fresh checkout fails with FileNotFoundError.
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
    i += 1
    return fig
with gr.Blocks() as demo:
    gr.Markdown("## Language Distance Explorer")
    average_checkbox = gr.Checkbox(label="Use Average Distances", value=False)
    with gr.Row():
        model_input = gr.Dropdown(label="Model", choices=MODELS, value=MODELS[0])
        dataset_input = gr.Dropdown(
            label="Dataset",
            choices=DATASETS,
            value=DATASETS[0]
        )

    with gr.Tab(label="Closest Languages Table"):
        with gr.Row():
            language_input = gr.Dropdown(label="Language", choices=languages, value=languages[0])
            top_n_input = gr.Slider(label="Top N", minimum=1, maximum=30, step=1, value=10)
        output_table = gr.Dataframe(label="Similar Languages")

        language_option_inputs = [model_input, dataset_input, language_input, average_checkbox]
        table_inputs = [model_input, dataset_input, language_input, average_checkbox, top_n_input]

        # When the model, dataset or averaging mode changes, refresh the
        # language choices first and only then recompute the table. As two
        # independent listeners, the table would be computed with the language
        # value captured when the event fired, which may no longer be valid.
        model_input.change(
            fn=update_language_options, inputs=language_option_inputs, outputs=language_input
        ).then(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        dataset_input.change(
            fn=update_language_options, inputs=language_option_inputs, outputs=language_input
        ).then(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        language_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)
        top_n_input.change(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)

        average_checkbox.change(
            fn=toggle_inputs,
            inputs=[average_checkbox],
            outputs=[model_input, dataset_input]
        )
        average_checkbox.change(
            fn=update_language_options, inputs=language_option_inputs, outputs=language_input
        ).then(fn=get_similar_languages, inputs=table_inputs, outputs=output_table)

    with gr.Tab(label="Distance Plot"):
        with gr.Row():
            cluster_method_input = gr.Dropdown(label="Cluster Method", choices=["HDBSCAN", "KMeans", "Family", "Subfamily"], value="HDBSCAN")
            clusters_input = gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2)

        def update_clusters_input_option(cluster_method):
            """
            Reconfigure the cluster-parameter slider for ``cluster_method``.

            HDBSCAN gets a "Minimum Elements in a Cluster" slider (2-10) and
            KMeans a "Number of Clusters" slider (2-20). For any other method
            the slider is hidden and disabled.
            """
            if cluster_method == "HDBSCAN":
                return gr.Slider(label="Minimum Elements in a Cluster", minimum=2, maximum=10, step=1, value=2, visible=True, interactive=True)
            elif cluster_method == "KMeans":
                return gr.Slider(label="Number of Clusters", minimum=2, maximum=20, step=1, value=2, visible=True, interactive=True)
            else:
                return gr.update(interactive=False, visible=False)

        cluster_method_input.change(fn=update_clusters_input_option, inputs=[cluster_method_input], outputs=clusters_input)

        with gr.Row():
            plot_tsne_button = gr.Button("Plot t-SNE")
            plot_umap_button = gr.Button("Plot UMAP")
            plot_mst_button = gr.Button("Plot MST")
        with gr.Row():
            plot_output = gr.Plot(label="Distance Plot")

        plot_inputs = [model_input, dataset_input, average_checkbox, cluster_method_input, clusters_input]
        plot_tsne_button.click(fn=partial(plot_distances, plot_fn=plot_distances_tsne),
                               inputs=plot_inputs,
                               outputs=plot_output)
        plot_umap_button.click(fn=partial(plot_distances, plot_fn=plot_distances_umap),
                               inputs=plot_inputs,
                               outputs=plot_output)
        plot_mst_button.click(fn=partial(plot_distances, plot_fn=plot_mst),
                              inputs=plot_inputs,
                              outputs=plot_output)

    with gr.Tab(label="Language Families Subplot"):
        checked_families_input = gr.CheckboxGroup(label="Language Families",
                                                  choices=[
                                                      'Afroasiatic',
                                                      'Austroasiatic',
                                                      'Austronesian',
                                                      'Constructed',
                                                      'Creole',
                                                      'Dravidian',
                                                      'Germanic',
                                                      'Indo-European',
                                                      'Japonic',
                                                      'Kartvelian',
                                                      'Koreanic',
                                                      'Language Isolate',
                                                      'Niger-Congo',
                                                      'Northeast Caucasian',
                                                      'Romance',
                                                      'Sino-Tibetan',
                                                      'Turkic',
                                                      'Uralic'
                                                  ],
                                                  value=["Indo-European"])
        with gr.Row():
            plot_family_button = gr.Button("Plot Families")
            plot_figsize_h_input = gr.Slider(label="Figure Height", minimum=5, maximum=30, step=1, value=15)
            plot_figsize_w_input = gr.Slider(label="Figure Width", minimum=5, maximum=30, step=1, value=15)
        plot_family_output = gr.Plot(label="Families Plot")

        def plot_families_subfamilies(families, model, dataset, use_average, figsize_h, figsize_w):
            """
            Draw an MST of the languages belonging to the selected ``families``.

            Languages without data are removed by ``filter_languages_nan``. The
            rest are restricted to ``families`` and grouped with
            ``cluster_languages_by_subfamilies``, then plotted by ``plot_mst``
            with figure size ``(figsize_w, figsize_h)``.

            Returns:
                The figure, which is also saved to ``plots/plot_<i>.pdf``.
            """
            global i
            updated_matrix, updated_languages = filter_languages_nan(model, dataset, use_average)
            updated_matrix, updated_languages = filter_languages_by_families(updated_matrix, updated_languages, families)
            clusters, legends = cluster_languages_by_subfamilies(updated_languages)
            fig = plot_mst(model, dataset, use_average, updated_matrix, updated_languages, clusters, legends, fig_size=(figsize_w, figsize_h))
            fig.tight_layout()
            # savefig does not create directories; make sure the export target exists.
            os.makedirs("plots", exist_ok=True)
            fig.savefig(f"plots/plot_{i}.pdf", format="pdf")
            i += 1
            return fig

        plot_family_button.click(fn=plot_families_subfamilies,
                                 inputs=[checked_families_input, model_input, dataset_input, average_checkbox, plot_figsize_h_input, plot_figsize_w_input],
                                 outputs=plot_family_output)

demo.launch(share=True)