# Source: Hugging Face Space file by cgeorgiaw (HF Staff).
# Commit 2777d26: "adding upload option" (4.36 kB).
import gradio as gr
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import pathlib
import json
import pandas as pd
from evaluation import load_boundary, load_boundaries
from constellaration import forward_model, initial_guess
from constellaration.boozer import boozer
from constellaration.utils import (
file_exporter,
visualization,
visualization_utils,
)
# Hugging Face organization that owns the ConStellaration benchmark artifacts.
organization = "proxima-fusion"
# Hub dataset repository holding leaderboard submission results.
results_repo = "/".join((organization, "constellaration-bench-results"))
def read_result_from_hub(filename):
    """Download one file from the results dataset repository on the Hub.

    Args:
        filename: Path of the file inside ``results_repo``.

    Returns:
        The local path returned by ``hf_hub_download``.
    """
    download_kwargs = dict(repo_id=results_repo, repo_type="dataset", filename=filename)
    return hf_hub_download(**download_kwargs)
def make_visual(boundary):
    """Return the plot produced by ``visualization.plot_surface`` for ``boundary``.

    The result is what the callbacks in ``gradio_interface`` hand to ``gr.Plot``.
    """
    return visualization.plot_surface(boundary)
# Intro text shown at the top of the app. Kept at module level so its Markdown
# headers are not indented (4+ leading spaces would turn them into a code block).
_INTRO_MARKDOWN = """
# Welcome to the ConStellaration Boundary Explorer!
### Here, you can visualize submissions to the ConStellaration Leaderboard, generate and visualize new random boundaries, or upload and visualize your own!
"""


def gradio_interface() -> gr.Blocks:
    """Build the ConStellaration Boundary Explorer Gradio app.

    A radio button switches between three input modes, each with its own column:

    * ``Leaderboard``: choose an entry from the ``results_repo`` dataset and
      plot its ``boundary_json``. ``mhd_stable`` entries are rejected.
    * ``Upload``: upload a ``.json`` boundary file and plot it.
    * ``Generate``: build a rotating-ellipse boundary via
      ``initial_guess.generate_rotating_ellipse`` and plot it.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown(_INTRO_MARKDOWN)

        # Loaded once, when the UI is built. The dropdown choices and the
        # lookups below reflect the dataset as of that moment.
        ds = load_dataset(results_repo, split="train")
        full_df = pd.DataFrame(ds)
        filenames = full_df["result_filename"].to_list()

        mode_selector = gr.Radio(
            choices=["Leaderboard", "Upload", "Generate"],
            label="Select input method:",
            value="Leaderboard",
        )

        with gr.Row():
            with gr.Column(visible=True) as leaderboard_ui:
                dropdown = gr.Dropdown(
                    choices=filenames,
                    label="Choose a leaderboard entry",
                    # An empty results dataset would make filenames[0] raise
                    # IndexError and prevent the app from starting at all.
                    value=filenames[0] if filenames else None,
                )
                rld_btn = gr.Button(value="Reload")
            with gr.Column(visible=False) as upload_ui:
                upload_box = gr.File(file_types=[".json"], label="Upload your boundary file")
            with gr.Column(visible=False) as generate_ui:
                aspect_ratio = gr.Number(label="Aspect Ratio", value=3)
                elongation = gr.Number(label="Elongation", value=0.5)
                rotational_transform = gr.Number(label="Rotational Transform", value=0.4)
                # precision=0 makes Gradio deliver an integer count, not 3.0.
                n_field_periods = gr.Number(
                    label="Number of Field Periods", value=3, precision=0
                )
                generate_btn = gr.Button(value="Generate")

        plot = gr.Plot()

        def update_ui(mode):
            """Show only the input column that matches the selected ``mode``."""
            return (
                gr.update(visible=(mode == "Leaderboard")),
                gr.update(visible=(mode == "Upload")),
                gr.update(visible=(mode == "Generate")),
            )

        mode_selector.change(
            update_ui,
            inputs=[mode_selector],
            outputs=[leaderboard_ui, upload_ui, generate_ui],
        )

        def get_boundary_from_leaderboard(selected_file):
            """Plot the boundary of the leaderboard row whose filename is ``selected_file``.

            Raises:
                gr.Error: if no row matches (e.g. the dropdown was cleared), or
                    the entry is an ``mhd_stable`` submission.
            """
            matches = full_df[full_df["result_filename"] == selected_file]
            if matches.empty:
                # A cleared dropdown (None) or stale value matches nothing.
                # Report it instead of crashing on .iloc[0].
                raise gr.Error("Please choose a leaderboard entry.")
            row = matches.iloc[0]
            if row["problem_type"] == "mhd_stable":
                raise gr.Error("Sorry this isn't implemented for mhd_stable submissions yet!")
            boundary = load_boundary(row["boundary_json"])
            return make_visual(boundary)

        dropdown.change(get_boundary_from_leaderboard, dropdown, plot)
        rld_btn.click(get_boundary_from_leaderboard, dropdown, plot)

        def get_boundary_vis_from_upload(uploaded_file):
            """Read an uploaded JSON boundary file and plot it.

            Raises:
                gr.Error: if no file is present (this also fires when the
                    upload is cleared).
            """
            if uploaded_file is None:
                raise gr.Error("Please upload a file.")
            # Depending on the Gradio version/config this is either a filepath
            # string or a tempfile-like object exposing ``.name``. Accept both.
            path = getattr(uploaded_file, "name", uploaded_file)
            data = pathlib.Path(path).read_text(encoding="utf-8")
            boundary = load_boundary(data)
            return make_visual(boundary)

        upload_box.change(get_boundary_vis_from_upload, inputs=[upload_box], outputs=[plot])

        def generate_random_boundary(
            aspect_ratio, elongation, rotational_transform, n_field_periods
        ):
            """Generate a rotating-ellipse boundary from the UI parameters and plot it.

            Raises:
                gr.Error: if any parameter field is empty.
            """
            params = (aspect_ratio, elongation, rotational_transform, n_field_periods)
            if any(value is None for value in params):
                raise gr.Error("Please fill in all generation parameters.")
            boundary = initial_guess.generate_rotating_ellipse(
                aspect_ratio=aspect_ratio,
                elongation=elongation,
                rotational_transform=rotational_transform,
                # The period count is an integer. Guard against float input
                # even though the component uses precision=0.
                n_field_periods=int(n_field_periods),
            )
            return make_visual(boundary)

        generate_btn.click(
            generate_random_boundary,
            [aspect_ratio, elongation, rotational_transform, n_field_periods],
            plot,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the Explorer app and start the Gradio server.
    app = gradio_interface()
    app.launch()