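# Spaces: Sleeping
# (Page-header residue apparently captured from the Hugging Face Spaces web view; not part of the program.)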
import logging
import os

import matplotlib

# Select the non-interactive Agg backend before any later import can pull in pyplot.
matplotlib.use('Agg')

import gradio as gr
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from huggingface_hub import hf_hub_download
import gym_laser  # Registers env name for gym.make()
# Dropdown label -> model source.
#   None      : no model, actions are sampled at random
#   "upload"  : placeholder for the entry whose model comes from the upload button
#   other str : identifier used to build the Hugging Face repo and file names
# TODO: add models by hosting them on huggingface
PRETRAINED_MODELS = {
    "Random Policy": None,
    "Upload Custom Model": "upload",
    "SAC-UDR(1.5,2.5)": "sac-udr-narrow",
    "SAC-UDR(1.0,9.0)": "sac-udr-wide-extra",
}

# Step budget of a single simulation run; large so the loop behaves as "always on".
MAX_STEPS = 100_000
def get_model_path(model_id):
    """Return the relative path of the local ``.zip`` archive for ``model_id``."""
    archive_name = "{}.zip".format(model_id)
    return "pretrained-policies/" + archive_name
def load_pretrained_model(model_id):
    """Download ``<model_id>.zip`` from the ``fracapuano/<model_id>`` Hub repo and load it with SAC."""
    repo = f"fracapuano/{model_id}"
    archive = f"{model_id}.zip"
    checkpoint_path = hf_hub_download(repo_id=repo, filename=archive)
    return SAC.load(checkpoint_path)
def make_env_fn():
    """Build one ``LaserEnv`` instance configured to render RGB arrays."""
    env_kwargs = {"render_mode": "rgb_array"}
    return gym.make("LaserEnv", **env_kwargs)
def initialize_environment():
    """Create the environment and the initial session state on app load.

    The environment is a single ``make_env_fn`` instance wrapped in a
    ``DummyVecEnv`` and a 5-frame ``VecFrameStack``.

    Returns:
        The state dict (keys: ``env``, ``obs``, ``model``, ``step_num``,
        ``current_b_integral``, ``model_filename``) on success, or ``None``
        when the environment could not be created. The function feeds a
        single output component, so it must always return a single value;
        the other handlers treat a ``None`` state as "not ready".
    """
    try:
        env = VecFrameStack(DummyVecEnv([make_env_fn]), n_stack=5)
        obs = env.reset()
    except Exception:
        # App-load boundary: keep the UI alive and record the reason server-side.
        logging.getLogger(__name__).exception("Failed to initialize the environment")
        return None

    return {
        "env": env,
        "obs": obs,
        "model": None,  # None means actions are sampled at random
        "step_num": 0,
        "current_b_integral": 2.0,  # Store current B-integral in state
        "model_filename": "Random Policy",  # Default model name
    }
def load_selected_model(state, model_selection, uploaded_file):
    """Load the model chosen in the dropdown and restart the episode.

    Args:
        state: Session state dict created by ``initialize_environment``.
        model_selection: Dropdown label; ``"Random Policy"``,
            ``"Upload Custom Model"`` or a key of ``PRETRAINED_MODELS``.
        uploaded_file: Value of the upload button; either a path string or an
            object exposing the path as ``.name``, or ``None`` if nothing has
            been uploaded yet.

    Returns:
        ``(state, update)`` - always exactly two values, matching the two
        output components wired to this handler. Problems are reported to the
        user through ``gr.Warning`` instead of an extra return value.
    """
    if not isinstance(state, dict) or "env" not in state:
        return state, gr.update()

    try:
        if model_selection == "Random Policy":
            model, model_filename = None, "Random Policy"
        elif model_selection == "Upload Custom Model":
            if uploaded_file is None:
                gr.Warning("Please upload a model file.")
                return state, gr.update()
            # Accept both a plain path and a file-like wrapper with ``.name``.
            upload_path = getattr(uploaded_file, "name", uploaded_file)
            model = SAC.load(upload_path)
            model_filename = os.path.basename(upload_path)
        else:
            model = load_pretrained_model(PRETRAINED_MODELS[model_selection])
            model_filename = model_selection
        obs = state["env"].reset()
    except Exception as e:
        logging.getLogger(__name__).exception("Error loading model")
        gr.Warning(f"Error loading model: {e}")
        return state, gr.update()

    # Commit only after everything succeeded so a failed load leaves the
    # previous model and episode untouched.
    state["model"] = model
    state["model_filename"] = model_filename
    state["obs"] = obs
    state["step_num"] = 0
    return state, gr.update()
def update_b_integral(state, b_integral):
    """Record a new B-integral in the session state; the running simulation is not restarted."""
    if state is None:
        return state
    state["current_b_integral"] = b_integral
    return state
def run_continuous_simulation(state):
    """Step the environment up to ``MAX_STEPS`` times, yielding one frame per step.

    The model and the B-integral are re-read from ``state`` on every step, so
    changes made to the state dict while the loop runs take effect on the
    next step.

    Args:
        state: Session state dict created by ``initialize_environment``.

    Yields:
        ``(state, frame)`` - always exactly two values, matching the two
        output components wired to this generator. When the state is missing
        or has no environment, a single ``(state, None)`` is yielded and the
        generator stops.
    """
    if not isinstance(state, dict) or "env" not in state:
        yield state, None
        return

    env = state["env"]
    obs = state["obs"]
    step_num = state.get("step_num", 0)

    # Run for a large number of steps to simulate "always-on"
    for _step in range(MAX_STEPS):
        model = state.get("model")
        current_b = state.get("current_b_integral", 2.0)

        # Apply the current B-integral value from state
        env.envs[0].unwrapped.laser.B = float(current_b)

        if model is not None:
            action, _ = model.predict(obs, deterministic=True)
        else:
            # Random policy: reshape the sample to a batch of one action.
            action = env.action_space.sample().reshape(1, -1)

        obs, _, done, _ = env.step(action)
        frame = env.render()

        if done[0]:
            obs = env.reset()
            step_num = 0
        else:
            step_num += 1

        state["obs"] = obs
        state["step_num"] = step_num
        yield state, frame
# UI layout and event wiring. The simulation starts on page load and is
# restarted whenever a different model is selected or uploaded.
with gr.Blocks(css="body {zoom: 90%}") as demo:
    gr.Markdown("# Shaping Laser Pulses with Reinforcement Learning")

    with gr.Tab("Demo"):
        # Per-session state dict (environment, observation, model, ...).
        sim_state = gr.State()

        with gr.Row():
            b_slider = gr.Slider(
                minimum=0,
                maximum=10,
                step=0.5,
                value=2.0,
                label="B-integral",
                info="Adjust nonlinearity live during simulation.",
            )

        with gr.Row():
            image_display = gr.Image(label="Environment Render", interactive=False, height=360)

        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=list(PRETRAINED_MODELS.keys()),
                    value="Random Policy",
                    label="Model Selection",
                    info="Choose a pre-trained model or upload your own"
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        model_uploader = gr.UploadButton(
                            "Upload Model (.zip)",
                            file_types=['.zip'],
                            elem_id="model-upload",
                            visible=False  # Initially hidden
                        )

        # Show/hide upload button based on selection
        def update_upload_visibility(selection):
            """Return an update that shows the upload button only for the custom-model entry."""
            return gr.update(visible=(selection == "Upload Custom Model"))

        model_selector.change(
            fn=update_upload_visibility,
            inputs=[model_selector],
            outputs=[model_uploader]
        )

        # On page load, initialize and start the continuous simulation
        init_event = demo.load(
            fn=initialize_environment,
            inputs=None,
            outputs=[sim_state]
        )

        continuous_event = init_event.then(
            fn=run_continuous_simulation,
            inputs=[sim_state],
            outputs=[sim_state, image_display]
        )

        # NOTE(review): the two chains below cancel only ``continuous_event``
        # (the loop started on page load). Loops started by their own
        # ``.then(...)`` steps are not in any ``cancels`` list - confirm
        # whether overlapping loops after repeated model changes are intended.

        # When model selection changes, load the selected model
        model_change_event = model_selector.change(
            fn=load_selected_model,
            inputs=[sim_state, model_selector, model_uploader],
            outputs=[sim_state, model_uploader],
            cancels=[continuous_event]
        ).then(
            fn=run_continuous_simulation,
            inputs=[sim_state],
            outputs=[sim_state, image_display]
        )

        # When a custom model is uploaded, load it
        model_upload_event = model_uploader.upload(
            fn=load_selected_model,
            inputs=[sim_state, model_selector, model_uploader],
            outputs=[sim_state, model_uploader],
            cancels=[continuous_event]
        ).then(
            fn=run_continuous_simulation,
            inputs=[sim_state],
            outputs=[sim_state, image_display]
        )

        # When B-integral slider changes, just update the value in state (no restart needed)
        b_slider.change(
            fn=update_b_integral,
            inputs=[sim_state, b_slider],
            outputs=[sim_state]
        )

    with gr.Tab("About"):
        # Explicit encoding: the default for open() depends on the host locale.
        with open("copy.md", "r", encoding="utf-8") as f:
            gr.Markdown(f.read())

demo.launch()