Spaces:
Sleeping
Sleeping
import numpy as np | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D | |
from scipy.stats import gaussian_kde | |
import gradio as gr | |
from pathlib import Path | |
import gradio as gr | |
import plotly.graph_objects as go | |
import re | |
import ast | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
def convert_google_sheet_url(url):
    """Rewrite a Google Sheets URL into its CSV-export URL.

    ``https://docs.google.com/spreadsheets/d/<ID>[/edit...]`` becomes
    ``https://docs.google.com/spreadsheets/d/<ID>/export?[gid=<GID>&]format=csv``.

    The sheet-tab id (gid) is carried over whether it appears as a fragment
    (``#gid=N``) or as a query parameter (``?gid=N`` / ``&gid=N``). Without a gid
    the export URL omits it. Text that does not match the pattern is returned
    unchanged.
    """
    # Group 1: spreadsheet id; group 2: optional "/edit..." tail that may carry a gid.
    pattern = r'https://docs\.google\.com/spreadsheets/d/([a-zA-Z0-9_-]+)(/edit.*)?'

    def _to_export_url(match):
        tail = match.group(2) or ''
        # The previous pattern only recognised "/edit#gid=N", so the
        # "/edit?gid=N#gid=N" form silently lost its gid.
        gid = re.search(r'[#?&]gid=(\d+)', tail)
        gid_param = f'gid={gid.group(1)}&' if gid else ''
        return f'https://docs.google.com/spreadsheets/d/{match.group(1)}/export?{gid_param}format=csv'

    return re.sub(pattern, _to_export_url, url)
# Source Google Sheet. Presumably one row per category (the 'Categories' column)
# and one column per model, each cell a Python-literal list -- TODO confirm.
# Earlier sheet, kept for reference:
# url = "https://docs.google.com/spreadsheets/d/1dlTjKJrGVwRDU8m-hT53IdSluRAsWXftnx5uRqnq4yE/edit?gid=0#gid=0"
url = "https://docs.google.com/spreadsheets/d/1MY0-DOitMZGnib73BAaSKg0TI7i5V1CXP8dF6jAgKWc/edit?gid=293606167#gid=293606167"
new_url = convert_google_sheet_url(url)
df = pd.read_csv(new_url)

# df1: the sheet as-is, with 'Categories' moved into the row index.
df1 = df.copy()
df1.set_index('Categories', inplace=True)

# df: the transposed sheet. The first row of the transpose holds the category
# names, so it is promoted to the header and the 'Categories' row is dropped.
transposed_df = df.transpose()
transposed_df.columns = transposed_df.iloc[0]
df = transposed_df.drop(["Categories"])

# Empty cells become an empty-list literal so every cell parses uniformly.
df = df.fillna("[]")
df1 = df1.fillna("[]")


def _parse_literal_cell(value):
    """Return the object spelled by a Python-literal string cell; pass non-strings through.

    ``ast.literal_eval`` accepts only literals, so sheet contents cannot execute code.
    """
    return ast.literal_eval(value) if isinstance(value, str) else value


# Parse every cell of both frames. No column needs skipping: df1 holds
# 'Categories' in its index, and df's columns are the categories themselves.
for frame in (df, df1):
    for col in frame.columns:
        frame[col] = frame[col].apply(_parse_literal_cell)

# First category column, restricted to models that have at least one record.
cols = df.columns
column_data = df[cols[0]]
filtered_column_data = column_data[column_data.apply(lambda x: x != [])]
def get_score(avg_kl_div, kl_div, missing, extra, common, *, wc=1, wm=1.5, we=1.5):
    """Return the combined weighted score of one record.

    score = (kl_div / avg_kl_div) * (((we*extra)**2 + (wm*missing)**2) / (wc*common)**2 - 2)

    This is (e^2 - c^2)/c^2 + (m^2 - c^2)/c^2 on weighted counts, scaled by the
    record's KL divergence relative to the average. ``kl_div == -1`` marks a
    missing KL value and is replaced by the average (scale factor 1).

    Args:
        avg_kl_div: average KL divergence used for normalisation.
        kl_div: this record's KL divergence, or -1 if unavailable.
        missing, extra, common: the record's missing / extra / common counts.
        wc, wm, we: weights for the common, missing and extra counts.

    Raises:
        ZeroDivisionError: if ``avg_kl_div`` is 0 or ``wc * common`` is 0.
    """
    if kl_div == -1:
        kl_div = avg_kl_div
    kl_div_factor = kl_div / avg_kl_div

    weighted_extra = (we * extra) ** 2
    weighted_missing = (wm * missing) ** 2
    # The common count uses its own weight; the original reused the extra
    # weight here, which left the common-count weight (Wc=1) unused.
    weighted_common = (wc * common) ** 2

    return kl_div_factor * ((weighted_extra + weighted_missing) / weighted_common - 2)
def get_individual_score(avg_kl_div, kl_div, e_or_m, common):
    """Score one record's extra (or missing) count against its common count.

    score = avg_kl_div + sqrt((1 + e/c) * e * (e + 1) / 2), with e = ``e_or_m`` and
    c = ``common``.

    The value depends only on ``avg_kl_div``, ``e_or_m`` and ``common``.
    ``kl_div`` is kept for interface compatibility: an earlier variant scaled
    the score by ``kl_div / avg_kl_div``, but the current formula ignores it.

    Raises:
        ZeroDivisionError: if ``common`` is 0.
    """
    # The old, unused ``kl_div / avg_kl_div`` factor was removed. It raised
    # ZeroDivisionError whenever avg_kl_div was 0 (e.g. no record had a KL value).
    ratio = 1 + (e_or_m / common)
    return avg_kl_div + (ratio * (e_or_m * (e_or_m + 1)) / 2) ** 0.5
def get_entity_scores(ans4):
    """Compute sorted, trimmed missing- and extra-entity scores for a list of records.

    Each record ``r`` is read positionally:
    - ``r[0]`` is its KL divergence, with ``-1`` meaning "not available";
    - ``r[1]`` feeds the missing score;
    - ``r[2]`` feeds the extra score;
    - ``r[3]`` is the common count.

    The average KL divergence is taken over records whose ``r[0]`` is not -1
    (0 when there are none). It is passed to ``get_individual_score``.

    Returns:
        (missing_scores, extra_scores): each sorted ascending and cut to its
        lowest ``int(0.95 * len)`` entries.
    """
    # Mean of the available KL values; 0 when none are available.
    known_kl = [record[0] for record in ans4 if record[0] != -1]
    avg_kl_div = sum(known_kl) / len(known_kl) if known_kl else 0

    missing_scores = []
    extra_scores = []
    for record in ans4:
        extra_scores.append(get_individual_score(avg_kl_div, record[0], record[2], record[3]))
        missing_scores.append(get_individual_score(avg_kl_div, record[0], record[1], record[3]))

    missing_scores.sort()
    extra_scores.sort()

    # Both lists have one entry per record; keep the lowest 95% of each.
    keep = int(0.95 * len(extra_scores))
    return missing_scores[:keep], extra_scores[:keep]
# Category shown by the app: the first column of the model-by-category frame.
compare = df.columns[0]
column_data = df[compare]
# Keep only the models (rows) that have at least one record for this category.
filtered_column_data = column_data[column_data.apply(lambda cell: cell != [])]

variables = filtered_column_data.to_list()
models = filtered_column_data.index.to_list()

# Palette assigned to models in order. Models beyond its length receive no
# entry in color_dict.
color_schemes = [
    '#d60000',  # Red
    '#2f5282',  # Navy Blue
    '#f15cd8',  # Pink
    '#66abb7',  # Light Teal
    '#ce7391',  # Rose
    '#6bdb7a',  # Light Green
    '#ea8569',  # Coral
    '#b36cc9',  # Lavender
    '#ffd700',  # Gold
    '#ff7f0e',  # Orange
    '#1f77b4',  # Blue
    '#2ca02c',  # Green
]
colors = color_schemes[:len(models)]

# model -> its parsed records, and model -> its colour.
values_dict = dict(zip(models, variables))
color_dict = dict(zip(models, colors))
import numpy as np | |
import plotly.graph_objects as go | |
from scipy.stats import gaussian_kde | |
import plotly.express as px | |
def adjust_kde_range(data, increment=25, threshold=0.00005):
    """Evaluate a Gaussian KDE of ``data`` on a grid wide enough for both tails to vanish.

    The grid starts ``increment`` beyond the data's min and max. It is widened by
    ``increment`` on each side until the density at both ends is below
    ``threshold``.

    Returns:
        (x_values, y_values): 1000 evenly spaced grid points and the KDE evaluated there.
    """
    density = gaussian_kde(data)
    lower = min(data) - increment
    upper = max(data) + increment

    def _evaluate(lo, hi):
        grid = np.linspace(lo, hi, 1000)
        return grid, density(grid)

    x_values, y_values = _evaluate(lower, upper)
    while not (y_values[0] < threshold and y_values[-1] < threshold):
        lower -= increment
        upper += increment
        x_values, y_values = _evaluate(lower, upper)
    return x_values, y_values
def compute_kde_ranges(missing_scores, extra_scores):
    """Evaluate KDEs of the missing scores and the negated extra scores, plus plot limits.

    Args:
        missing_scores: sequence of missing-entity scores.
        extra_scores: sequence of extra-entity scores. They are negated (mirrored
            about 0) before the KDE.

    Returns:
        (x_missing, y_missing, x_extra, y_extra, x_range, y_range), where:
        - each x/y pair is a KDE grid returned by ``adjust_kde_range``;
        - ``x_range`` spans both grids;
        - ``y_range`` is ``[-peak_extra, 1.25 * peak_missing]``.
    """
    x_missing, y_missing = adjust_kde_range(np.array(missing_scores))
    # Negate the extra scores so both distributions share the same alignment.
    x_extra, y_extra = adjust_kde_range(-np.array(extra_scores))

    peak_extra = max(y_extra)
    peak_miss = max(y_missing)

    x_range = [min(min(x_missing), min(x_extra)), max(max(x_missing), max(x_extra))]
    # Lower bound mirrors the extra-score peak; 25% headroom above the missing peak.
    y_range = [-peak_extra, peak_miss * 1.25]
    return x_missing, y_missing, x_extra, y_extra, x_range, y_range
def calculate_ticks(x_min, x_max, num_ticks=20):
    """Return ``num_ticks`` evenly spaced tick positions from x_min to x_max inclusive.

    ``np.linspace`` is used instead of ``np.arange`` with a float step. Rounding
    in arange can yield an extra tick past ``x_max``, and arange fails outright
    when ``x_min == x_max`` (zero step).
    """
    return np.linspace(x_min, x_max, int(num_ticks))
def plot_filled_surface(x, z, y_level, color):
    """
    Create a 3D mesh that fills the area between a KDE curve and z = 0.

    The mesh lies in the vertical plane y = ``y_level``. Vertex layout for n points:
    - vertices 0..n-1 are the curve points (x[p], z[p]);
    - vertices n..2n-1 are the baseline points (x, 0) in reversed order, so the
      baseline point under x[p] has index 2n-1-p.
    Each curve segment p -> p+1 is filled down to the baseline with a quad made
    of two triangles.

    Args:
        x: curve x coordinates.
        z: curve heights, same length as ``x``.
        y_level: constant y coordinate of the plane.
        color: fill colour.

    Returns:
        A ``go.Mesh3d`` trace (opacity 0.5, legend group ``'filling'``).
    """
    x = np.asarray(x)
    z = np.asarray(z)
    num_pts = len(x)

    x_full = np.concatenate([x, x[::-1]])  # curve points, then baseline reversed
    z_full = np.concatenate([z, np.zeros_like(z)])  # KDE heights, then baseline at 0
    y_full = np.full_like(x_full, y_level)  # flat plane for this model

    seg = np.arange(num_pts - 1)
    top_left = seg
    top_right = seg + 1
    base_left = 2 * num_pts - 1 - seg   # baseline vertex under x[p]
    base_right = 2 * num_pts - 2 - seg  # baseline vertex under x[p + 1]

    # Two triangles per segment: (top_left, top_right, base_left) and
    # (top_right, base_right, base_left). The original indices paired curve
    # point p with the baseline point under x[n-1-p], fanning crossing triangles.
    i = np.concatenate([top_left, top_right])
    j = np.concatenate([top_right, base_right])
    k = np.concatenate([base_left, base_left])

    return go.Mesh3d(
        x=x_full, y=y_full, z=z_full,
        i=i, j=j, k=k,
        opacity=0.5,
        color=color,
        showscale=False,
        legendgroup='filling'
    )
def plot_kde_3d(values_dict, models, color_dict, compare):
    """Build a 3D Plotly figure of each model's missing/extra score densities.

    Each model occupies the vertical plane y = <its index in ``models``>. Within
    that plane:
    - the KDE of the missing scores is drawn with z >= 0 (blue line);
    - the KDE of the negated extra scores is drawn with z <= 0 (red line);
    - each curve is filled down to z = 0 in the model's colour.
    A z = 0 line per model, named after the model, supplies the legend entries.

    Args:
        values_dict: model -> records passed to ``get_entity_scores``.
        models: models to plot, in y-axis order.
        color_dict: model -> colour. Missing entries fall back to translucent black.
        compare: category name, used in the title and the output file name.

    Returns:
        The ``go.Figure``. As a side effect the figure is also written to
        ``<compare>.html`` in the current working directory. Path separators
        (``/`` or backslash) in ``compare`` are replaced by ``_`` so the file
        always lands there.
    """
    default_color = 'rgba(0, 0, 0, 0.5)'
    fig = go.Figure()
    model_y_positions = {model: i for i, model in enumerate(models)}
    x_ranges = []
    y_ranges = []

    for model in models:
        y_pos = model_y_positions[model]
        missing_scores, extra_scores = get_entity_scores(values_dict[model])
        # KDE curves plus this model's x/y extents.
        x_m, y_m, x_e, y_e, x_range, y_range = compute_kde_ranges(missing_scores, extra_scores)
        x_ranges.append(x_range)
        y_ranges.append(y_range)

        color = color_dict.get(model, default_color)
        # Filled areas between each density curve and z = 0.
        fig.add_trace(plot_filled_surface(x_m, y_m, y_pos, color))
        fig.add_trace(plot_filled_surface(x_e, -y_e, y_pos, color))
        # Outline curves; legend entries come from the zero lines added below.
        fig.add_trace(go.Scatter3d(
            x=x_m,
            y=[y_pos] * len(x_m),
            z=y_m,
            mode='lines',
            line=dict(color='blue'),
            showlegend=False
        ))
        fig.add_trace(go.Scatter3d(
            x=x_e,
            y=[y_pos] * len(x_e),
            z=-y_e,
            mode='lines',
            line=dict(color='red'),
            showlegend=False
        ))

    # Global limits across all models.
    x_min = min(r[0] for r in x_ranges)
    x_max = max(r[1] for r in x_ranges)
    y_min = min(r[0] for r in y_ranges)
    y_max = max(r[1] for r in y_ranges)

    x_ticks = calculate_ticks(np.floor(x_min), np.ceil(x_max))
    y_ticks = list(model_y_positions.values())
    # Positions are distinct enumerate indices, so the keys line up with y_ticks.
    y_ticktext = list(model_y_positions.keys())
    z_ticks = calculate_ticks(y_min, y_max)

    # A z = 0 reference line per model, carrying the model's legend entry.
    for model in models:
        y_pos = model_y_positions[model]
        fig.add_trace(go.Scatter3d(
            x=[x_min, x_max],
            y=[y_pos, y_pos],
            z=[0, 0],
            mode='lines',
            line=dict(color=color_dict.get(model, default_color)),
            name=model,
        ))

    fig.update_layout(
        title=f'3D KDE Plots for {compare}',
        scene=dict(
            xaxis_title='Score',
            yaxis_title='Model',
            zaxis_title='Density',
            xaxis=dict(
                range=[x_min, x_max],
                tickvals=x_ticks,
                ticktext=[f'{tick:.2f}' for tick in x_ticks]
            ),
            yaxis=dict(
                tickvals=y_ticks,
                ticktext=y_ticktext
            ),
            zaxis=dict(
                range=[y_min, y_max],
                tickvals=z_ticks,
                ticktext=[f'{tick:.4f}' for tick in z_ticks]
            ),
            camera=dict(
                eye=dict(x=1.25, y=1.25, z=1.25)
            )
        ),
        autosize=True,
        width=1200*.75,
        height=800*.75
    )

    # Category names come from the sheet; strip path separators so the file
    # cannot fail to open or be written outside the working directory.
    safe_name = re.sub(r'[\\/]', '_', str(compare))
    fig.write_html(f"{safe_name}.html")
    return fig
# Module-level settings not read anywhere else in this file
# (NOTE(review): apparently left over from an earlier HTML-embedding approach -- confirm and remove).
html_file_path = '3d_plot.html'
title = 'My 3D Plot'


def display_plot():
    """Gradio callback: build the 3D KDE figure for the module-level category ``compare``."""
    return plot_kde_3d(values_dict, models, color_dict, compare)
# Gradio UI with no inputs: the figure returned by display_plot is shown in a gr.Plot output.
interface = gr.Interface(
    fn=display_plot,
    inputs=[],
    outputs=gr.Plot(),
    title='Plotly 3D Plot in Gradio',
    description='This app displays a 3D Plotly plot directly in the Gradio interface.',
    live=False,
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    interface.launch()