back to it! first real try
Browse files- app.py +138 -3
- requirements.txt +5 -0
- utils.py +28 -0
app.py
CHANGED
@@ -1,7 +1,142 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
demo.launch()
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
import os
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import netCDF4 as nc
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from utils import fig2img, create_gif_from_frames
|
8 |
+
|
9 |
+
# Global dictionary to store datasets in memory
|
10 |
+
datasets = {}
|
11 |
+
dataset_metadata = {"CE-RPUI": {"key":"data", "features":["density", "horz_velocity", "vertical_velocity", "pressure", "energy"]}, "CE-RM": {"key":"solution", "features":["density", "horz_velocity", "vertical_velocity", "pressure", "energy"]}, "NS-PwC": {"key":"velocity", "features":["horz_velocity", "vertical_velocity", "passive_tracer"]}}
|
12 |
+
feature_label_map = {"density": "Density", "horz_velocity": "Horizontal Velocity", "vertical_velocity": "Vertical Velocity", "pressure": "Pressure", "energy": "Energy", "passive_tracer": "Passive Tracer"}
|
13 |
+
|
14 |
+
def load_dataset_into_memory(dataset_name):
|
15 |
+
if dataset_name not in datasets:
|
16 |
+
print(f"Loading dataset: {dataset_name}...")
|
17 |
+
|
18 |
+
dataset_path = hf_hub_download(
|
19 |
+
repo_id=f"camlab-ethz/{dataset_name}",
|
20 |
+
filename=dataset_metadata[dataset_name]['key'] + "_0.nc",
|
21 |
+
repo_type="dataset",
|
22 |
+
token=os.getenv("HF_TOKEN")
|
23 |
+
)
|
24 |
+
|
25 |
+
dataset = nc.Dataset(dataset_path)
|
26 |
+
datasets[dataset_name] = dataset.variables[dataset_metadata[dataset_name]['key']] # Store only the data variable
|
27 |
+
|
28 |
+
return datasets[dataset_name]
|
29 |
+
|
30 |
+
# Function to generate GIFs dynamically
|
31 |
+
def generate_gifs(dataset_name, index):
|
32 |
+
data = load_dataset_into_memory(dataset_name) # Load dataset from memory
|
33 |
+
first_traj = data[index] # Extract trajectory at given index
|
34 |
+
|
35 |
+
# Extract features
|
36 |
+
if dataset_name == "NS-PwC":
|
37 |
+
features = {
|
38 |
+
"horz_velocity": first_traj[:, 0, :, :],
|
39 |
+
"vertical_velocity": first_traj[:, 1, :, :],
|
40 |
+
"passive_tracer": first_traj[:, 2, :, :],
|
41 |
+
}
|
42 |
+
else:
|
43 |
+
features = {
|
44 |
+
"density": first_traj[:, 0, :, :],
|
45 |
+
"horz_velocity": first_traj[:, 1, :, :],
|
46 |
+
"vertical_velocity": first_traj[:, 2, :, :],
|
47 |
+
"pressure": first_traj[:, 3, :, :],
|
48 |
+
"energy": first_traj[:, 4, :, :]
|
49 |
+
}
|
50 |
+
|
51 |
+
gif_paths = []
|
52 |
+
output_dir = "episode_gifs"
|
53 |
+
os.makedirs(output_dir, exist_ok=True) # Ensure output directory exists
|
54 |
+
|
55 |
+
for feature_name, feature in features.items():
|
56 |
+
gif_filename = f"{output_dir}/{dataset_name.lower()}_{feature_name}_{index}.gif"
|
57 |
+
|
58 |
+
# Only generate if the GIF doesn't already exist
|
59 |
+
if not os.path.exists(gif_filename):
|
60 |
+
print(f"Generating GIF: {gif_filename}")
|
61 |
+
frames = []
|
62 |
+
for i in range(len(feature)):
|
63 |
+
fig = plt.figure(figsize=(1.28, 1.28), dpi=100, frameon=False)
|
64 |
+
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
65 |
+
ax.set_axis_off()
|
66 |
+
fig.add_axes(ax)
|
67 |
+
|
68 |
+
ax.imshow(feature[i, :, :])
|
69 |
+
frames.append(fig2img(fig))
|
70 |
+
plt.close(fig)
|
71 |
+
|
72 |
+
create_gif_from_frames(frames, gif_filename)
|
73 |
+
else:
|
74 |
+
print(f"Using cached GIF: {gif_filename}")
|
75 |
+
|
76 |
+
gif_paths.append(gif_filename)
|
77 |
+
|
78 |
+
return gif_paths # Return paths to update the Gradio UI
|
79 |
+
|
80 |
+
|
81 |
+
gif_css = """
|
82 |
+
#gif-label {
|
83 |
+
text-align: center;
|
84 |
+
width: 100%;
|
85 |
+
display: flex;
|
86 |
+
justify-content: center;
|
87 |
+
}
|
88 |
+
"""
|
89 |
+
|
90 |
+
css = """
|
91 |
+
#label {
|
92 |
+
text-align: center;
|
93 |
+
width: 100%;
|
94 |
+
}
|
95 |
+
"""
|
96 |
+
|
97 |
+
with gr.Blocks(css=css) as demo:
|
98 |
+
gr.Markdown(
|
99 |
+
"""
|
100 |
+
# Poseidon: Efficient Foundation Models for PDEs
|
101 |
+
|
102 |
+
Partial Differential Equations (PDEs) are fundamental in describing various physical phenomena, such as heat conduction, fluid dynamics, and electromagnetism. Poseidon leverages advanced modeling techniques to provide efficient solutions for complex PDEs, enabling faster computations and deeper insights into these phenomena.
|
103 |
+
"""
|
104 |
+
)
|
105 |
+
|
106 |
+
with gr.Row():
|
107 |
+
dataset_selector = gr.Dropdown(choices=["CE-RPUI", "CE-RM", "NS-PwC"], value="CE-RPUI", label="Select Dataset")
|
108 |
+
index_slider = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Select Index")
|
109 |
+
|
110 |
+
gif_outputs = []
|
111 |
+
label_outputs = []
|
112 |
+
with gr.Row(equal_height=True) as gif_container:
|
113 |
+
for _ in range(5):
|
114 |
+
with gr.Column(scale=1, min_width=0):
|
115 |
+
output = gr.Image(type="filepath", show_label=False, container=False)
|
116 |
+
gif_outputs.append(output)
|
117 |
+
|
118 |
+
with gr.Column(): # Wrap label inside a column
|
119 |
+
label = gr.Markdown(elem_id="label")
|
120 |
+
label_outputs.append(label)
|
121 |
+
|
122 |
+
def update_layout(dataset_name, index):
|
123 |
+
features = dataset_metadata[dataset_name]["features"]
|
124 |
+
gif_paths = generate_gifs(dataset_name, index)
|
125 |
+
|
126 |
+
output_values = []
|
127 |
+
label_values = []
|
128 |
+
for i in range(5):
|
129 |
+
if i < len(features):
|
130 |
+
output_values.append(gr.update(value=gif_paths[i], visible=True))
|
131 |
+
label_values.append(gr.update(value=f"**{feature_label_map[features[i]]}**", visible=True))
|
132 |
+
else:
|
133 |
+
output_values.append(gr.update(visible=False))
|
134 |
+
label_values.append(gr.update(visible=False))
|
135 |
+
|
136 |
+
return output_values + label_values
|
137 |
+
|
138 |
+
dataset_selector.change(update_layout, inputs=[dataset_selector, index_slider], outputs=gif_outputs + label_outputs)
|
139 |
+
index_slider.change(update_layout, inputs=[dataset_selector, index_slider], outputs=gif_outputs + label_outputs)
|
140 |
|
|
|
141 |
demo.launch()
|
142 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
matplotlib
|
3 |
+
netCDF4
|
4 |
+
huggingface_hub
|
5 |
+
pillow
|
utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
def fig2img(fig):
|
6 |
+
"""Convert a Matplotlib figure to a PIL Image and return it"""
|
7 |
+
# Save figure to a temporary buffer.
|
8 |
+
buf = io.BytesIO()
|
9 |
+
fig.savefig(buf)
|
10 |
+
buf.seek(0)
|
11 |
+
|
12 |
+
# Create PIL image from buffer
|
13 |
+
img = Image.open(buf)
|
14 |
+
return img
|
15 |
+
|
16 |
+
def create_gif_from_frames(frames, filename):
|
17 |
+
"""Create a GIF from a list of PIL Image frames"""
|
18 |
+
# Create output directory if it doesn't exist
|
19 |
+
os.makedirs('episode_gifs', exist_ok=True)
|
20 |
+
|
21 |
+
# Save the frames as GIF
|
22 |
+
frames[0].save(
|
23 |
+
f'{filename}',
|
24 |
+
save_all=True,
|
25 |
+
append_images=frames[1:],
|
26 |
+
duration=200, # Duration for each frame in milliseconds
|
27 |
+
loop=0
|
28 |
+
)
|