cgeorgiaw HF Staff commited on
Commit
d160535
·
1 Parent(s): 02d93ed

back to it! first real try

Browse files
Files changed (3) hide show
  1. app.py +138 -3
  2. requirements.txt +5 -0
  3. utils.py +28 -0
app.py CHANGED
@@ -1,7 +1,142 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
 
1
  import gradio as gr
2
 
3
+ import os
4
+ import matplotlib.pyplot as plt
5
+ import netCDF4 as nc
6
+ from huggingface_hub import hf_hub_download
7
+ from utils import fig2img, create_gif_from_frames
8
+
9
+ # Global dictionary to store datasets in memory
10
+ datasets = {}
11
+ dataset_metadata = {"CE-RPUI": {"key":"data", "features":["density", "horz_velocity", "vertical_velocity", "pressure", "energy"]}, "CE-RM": {"key":"solution", "features":["density", "horz_velocity", "vertical_velocity", "pressure", "energy"]}, "NS-PwC": {"key":"velocity", "features":["horz_velocity", "vertical_velocity", "passive_tracer"]}}
12
+ feature_label_map = {"density": "Density", "horz_velocity": "Horizontal Velocity", "vertical_velocity": "Vertical Velocity", "pressure": "Pressure", "energy": "Energy", "passive_tracer": "Passive Tracer"}
13
+
14
+ def load_dataset_into_memory(dataset_name):
15
+ if dataset_name not in datasets:
16
+ print(f"Loading dataset: {dataset_name}...")
17
+
18
+ dataset_path = hf_hub_download(
19
+ repo_id=f"camlab-ethz/{dataset_name}",
20
+ filename=dataset_metadata[dataset_name]['key'] + "_0.nc",
21
+ repo_type="dataset",
22
+ token=os.getenv("HF_TOKEN")
23
+ )
24
+
25
+ dataset = nc.Dataset(dataset_path)
26
+ datasets[dataset_name] = dataset.variables[dataset_metadata[dataset_name]['key']] # Store only the data variable
27
+
28
+ return datasets[dataset_name]
29
+
30
+ # Function to generate GIFs dynamically
31
+ def generate_gifs(dataset_name, index):
32
+ data = load_dataset_into_memory(dataset_name) # Load dataset from memory
33
+ first_traj = data[index] # Extract trajectory at given index
34
+
35
+ # Extract features
36
+ if dataset_name == "NS-PwC":
37
+ features = {
38
+ "horz_velocity": first_traj[:, 0, :, :],
39
+ "vertical_velocity": first_traj[:, 1, :, :],
40
+ "passive_tracer": first_traj[:, 2, :, :],
41
+ }
42
+ else:
43
+ features = {
44
+ "density": first_traj[:, 0, :, :],
45
+ "horz_velocity": first_traj[:, 1, :, :],
46
+ "vertical_velocity": first_traj[:, 2, :, :],
47
+ "pressure": first_traj[:, 3, :, :],
48
+ "energy": first_traj[:, 4, :, :]
49
+ }
50
+
51
+ gif_paths = []
52
+ output_dir = "episode_gifs"
53
+ os.makedirs(output_dir, exist_ok=True) # Ensure output directory exists
54
+
55
+ for feature_name, feature in features.items():
56
+ gif_filename = f"{output_dir}/{dataset_name.lower()}_{feature_name}_{index}.gif"
57
+
58
+ # Only generate if the GIF doesn't already exist
59
+ if not os.path.exists(gif_filename):
60
+ print(f"Generating GIF: {gif_filename}")
61
+ frames = []
62
+ for i in range(len(feature)):
63
+ fig = plt.figure(figsize=(1.28, 1.28), dpi=100, frameon=False)
64
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
65
+ ax.set_axis_off()
66
+ fig.add_axes(ax)
67
+
68
+ ax.imshow(feature[i, :, :])
69
+ frames.append(fig2img(fig))
70
+ plt.close(fig)
71
+
72
+ create_gif_from_frames(frames, gif_filename)
73
+ else:
74
+ print(f"Using cached GIF: {gif_filename}")
75
+
76
+ gif_paths.append(gif_filename)
77
+
78
+ return gif_paths # Return paths to update the Gradio UI
79
+
80
+
81
+ gif_css = """
82
+ #gif-label {
83
+ text-align: center;
84
+ width: 100%;
85
+ display: flex;
86
+ justify-content: center;
87
+ }
88
+ """
89
+
90
+ css = """
91
+ #label {
92
+ text-align: center;
93
+ width: 100%;
94
+ }
95
+ """
96
+
97
+ with gr.Blocks(css=css) as demo:
98
+ gr.Markdown(
99
+ """
100
+ # Poseidon: Efficient Foundation Models for PDEs
101
+
102
+ Partial Differential Equations (PDEs) are fundamental in describing various physical phenomena, such as heat conduction, fluid dynamics, and electromagnetism. Poseidon leverages advanced modeling techniques to provide efficient solutions for complex PDEs, enabling faster computations and deeper insights into these phenomena.
103
+ """
104
+ )
105
+
106
+ with gr.Row():
107
+ dataset_selector = gr.Dropdown(choices=["CE-RPUI", "CE-RM", "NS-PwC"], value="CE-RPUI", label="Select Dataset")
108
+ index_slider = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Select Index")
109
+
110
+ gif_outputs = []
111
+ label_outputs = []
112
+ with gr.Row(equal_height=True) as gif_container:
113
+ for _ in range(5):
114
+ with gr.Column(scale=1, min_width=0):
115
+ output = gr.Image(type="filepath", show_label=False, container=False)
116
+ gif_outputs.append(output)
117
+
118
+ with gr.Column(): # Wrap label inside a column
119
+ label = gr.Markdown(elem_id="label")
120
+ label_outputs.append(label)
121
+
122
+ def update_layout(dataset_name, index):
123
+ features = dataset_metadata[dataset_name]["features"]
124
+ gif_paths = generate_gifs(dataset_name, index)
125
+
126
+ output_values = []
127
+ label_values = []
128
+ for i in range(5):
129
+ if i < len(features):
130
+ output_values.append(gr.update(value=gif_paths[i], visible=True))
131
+ label_values.append(gr.update(value=f"**{feature_label_map[features[i]]}**", visible=True))
132
+ else:
133
+ output_values.append(gr.update(visible=False))
134
+ label_values.append(gr.update(visible=False))
135
+
136
+ return output_values + label_values
137
+
138
+ dataset_selector.change(update_layout, inputs=[dataset_selector, index_slider], outputs=gif_outputs + label_outputs)
139
+ index_slider.change(update_layout, inputs=[dataset_selector, index_slider], outputs=gif_outputs + label_outputs)
140
 
 
141
  demo.launch()
142
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ matplotlib
3
+ netCDF4
4
+ huggingface_hub
5
+ pillow
utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from PIL import Image
4
+
5
+ def fig2img(fig):
6
+ """Convert a Matplotlib figure to a PIL Image and return it"""
7
+ # Save figure to a temporary buffer.
8
+ buf = io.BytesIO()
9
+ fig.savefig(buf)
10
+ buf.seek(0)
11
+
12
+ # Create PIL image from buffer
13
+ img = Image.open(buf)
14
+ return img
15
+
16
+ def create_gif_from_frames(frames, filename):
17
+ """Create a GIF from a list of PIL Image frames"""
18
+ # Create output directory if it doesn't exist
19
+ os.makedirs('episode_gifs', exist_ok=True)
20
+
21
+ # Save the frames as GIF
22
+ frames[0].save(
23
+ f'{filename}',
24
+ save_all=True,
25
+ append_images=frames[1:],
26
+ duration=200, # Duration for each frame in milliseconds
27
+ loop=0
28
+ )