Spaces:
Paused
Paused
| import collections | |
| import os | |
| import time | |
| from threading import Lock | |
| import glfw | |
| import imageio | |
| import mujoco | |
| import numpy as np | |
def _import_egl(width, height):
    """Build a headless EGL-backed OpenGL context of ``width`` x ``height``.

    The backend module is imported lazily so that merely importing this file
    does not require EGL to be available.
    """
    from mujoco import egl as mujoco_egl

    return mujoco_egl.GLContext(width, height)
def _import_glfw(width, height):
    """Build a GLFW-backed OpenGL context of ``width`` x ``height``.

    The backend module is imported lazily so that merely importing this file
    does not require a working GLFW installation.
    """
    from mujoco import glfw as mujoco_glfw

    return mujoco_glfw.GLContext(width, height)
def _import_osmesa(width, height):
    """Build an OSMesa (software) OpenGL context of ``width`` x ``height``.

    The backend module is imported lazily so that merely importing this file
    does not require OSMesa to be available.
    """
    from mujoco import osmesa as mujoco_osmesa

    return mujoco_osmesa.GLContext(width, height)
# OpenGL backend factories keyed by the names accepted in ``MUJOCO_GL``.
# Insertion order matters: when ``MUJOCO_GL`` is unset, the offscreen context
# tries the backends in this order and keeps the first one that succeeds.
_ALL_RENDERERS = collections.OrderedDict(
    glfw=_import_glfw,
    egl=_import_egl,
    osmesa=_import_osmesa,
)
class RenderContext:
    """Render context superclass for offscreen and window rendering.

    Owns the MuJoCo visualisation objects (scene, camera, options,
    perturbation and ``MjrContext``) for one ``model``/``data`` pair, and
    collects markers and text overlays that are drawn by ``render``.
    """

    def __init__(self, model, data, offscreen=True):
        """Create the visualisation objects and select the target framebuffer.

        Args:
            model: MuJoCo model to render; its ``vis.global_.offwidth`` /
                ``offheight`` define the render size.
            data: Simulation data matching ``model``.
            offscreen: Select the offscreen framebuffer when True, the window
                framebuffer otherwise.

        Raises:
            RuntimeError: If the requested framebuffer cannot be selected.
        """
        self.model = model
        self.data = data
        self.offscreen = offscreen
        self.offwidth = model.vis.global_.offwidth
        self.offheight = model.vis.global_.offheight
        max_geom = 1000

        # Update derived quantities (geom_xpos is read by _init_camera)
        # before building the scene.
        mujoco.mj_forward(self.model, self.data)

        self.scn = mujoco.MjvScene(self.model, max_geom)
        self.cam = mujoco.MjvCamera()
        self.vopt = mujoco.MjvOption()
        self.pert = mujoco.MjvPerturb()
        self.con = mujoco.MjrContext(self.model, mujoco.mjtFontScale.mjFONTSCALE_150)

        self._markers = []
        self._overlays = {}

        self._init_camera()
        self._set_mujoco_buffers()

    def _set_mujoco_buffers(self):
        """Bind ``self.con`` to the offscreen or window framebuffer.

        Raises:
            RuntimeError: If MuJoCo did not switch to the requested buffer.
        """
        if self.offscreen:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN, self.con)
            if self.con.currentBuffer != mujoco.mjtFramebuffer.mjFB_OFFSCREEN:
                raise RuntimeError("Offscreen rendering not supported")
        else:
            mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_WINDOW, self.con)
            if self.con.currentBuffer != mujoco.mjtFramebuffer.mjFB_WINDOW:
                raise RuntimeError("Window rendering not supported")

    def render(self, camera_id=None, segmentation=False):
        """Render the current state into an ``offwidth`` x ``offheight`` rect.

        Args:
            camera_id: ``None`` keeps the current camera; ``-1`` selects the
                free camera; any other value selects that fixed camera.
            segmentation: When True, the scene's SEGMENT and IDCOLOR render
                flags are enabled for this call only, so the output can be
                decoded with ``read_pixels(segmentation=True)``.
        """
        width, height = self.offwidth, self.offheight
        rect = mujoco.MjrRect(left=0, bottom=0, width=width, height=height)

        if camera_id is not None:
            if camera_id == -1:
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
            else:
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            self.cam.fixedcamid = camera_id

        mujoco.mjv_updateScene(
            self.model,
            self.data,
            self.vopt,
            self.pert,
            self.cam,
            mujoco.mjtCatBit.mjCAT_ALL,
            self.scn,
        )

        if segmentation:
            self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 1
            self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 1
        try:
            for marker_params in self._markers:
                self._add_marker_to_scene(marker_params)

            mujoco.mjr_render(rect, self.scn, self.con)

            for gridpos, (text1, text2) in self._overlays.items():
                mujoco.mjr_overlay(
                    mujoco.mjtFontScale.mjFONTSCALE_150,
                    gridpos,
                    rect,
                    text1.encode(),
                    text2.encode(),
                    self.con,
                )
        finally:
            # Always restore the flags; otherwise a failure above (e.g. running
            # out of marker geoms) would leave every later render segmented.
            if segmentation:
                self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 0
                self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 0

    @staticmethod
    def _rgb_to_segid(rgb_img):
        """Pack an ``(H, W, 3)`` uint8 ID-colour image into one int per pixel.

        Computes ``R + G * 2**8 + B * 2**16``. The channels are widened to
        int32 first: multiplying a uint8 array by ``2**8``/``2**16`` does not
        fit in uint8 and raises OverflowError under NumPy 2 promotion rules.
        """
        rgb = rgb_img.astype(np.int32)
        return rgb[..., 0] + (rgb[..., 1] << 8) + (rgb[..., 2] << 16)

    def read_pixels(self, depth=True, segmentation=False):
        """Read back the ``offwidth`` x ``offheight`` rect from the context.

        Args:
            depth: Also return the depth buffer.
            segmentation: Decode the colour buffer (expected to come from
                ``render(segmentation=True)``) into per-pixel
                ``[objtype, objid]`` pairs. Pixels whose decoded ID is 0 or
                larger than ``scn.ngeom`` map to ``[-1, -1]``.

        Returns:
            ``(H, W, 3)`` uint8 RGB, or ``(H, W, 2)`` int32 ids when
            ``segmentation``; paired with an ``(H, W)`` float32 depth image
            when ``depth`` is True.
        """
        width, height = self.offwidth, self.offheight
        rect = mujoco.MjrRect(left=0, bottom=0, width=width, height=height)

        rgb_arr = np.zeros(3 * rect.width * rect.height, dtype=np.uint8)
        depth_arr = np.zeros(rect.width * rect.height, dtype=np.float32)

        mujoco.mjr_readPixels(rgb_arr, depth_arr, rect, self.con)
        rgb_img = rgb_arr.reshape(rect.height, rect.width, 3)

        ret_img = rgb_img
        if segmentation:
            seg_img = self._rgb_to_segid(rgb_img)
            seg_img[seg_img >= (self.scn.ngeom + 1)] = 0
            # Row k + 1 holds the (objtype, objid) of the geom with segid k;
            # row 0 stays -1 and is used for unmatched pixels.
            seg_ids = np.full((self.scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)
            for i in range(self.scn.ngeom):
                geom = self.scn.geoms[i]
                if geom.segid != -1:
                    seg_ids[geom.segid + 1, 0] = geom.objtype
                    seg_ids[geom.segid + 1, 1] = geom.objid
            ret_img = seg_ids[seg_img]

        if depth:
            depth_img = depth_arr.reshape(rect.height, rect.width)
            return (ret_img, depth_img)
        return ret_img

    def _init_camera(self):
        """Point the free camera at the per-axis median geom position.

        The camera distance is set to the model's ``stat.extent``.
        """
        self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        self.cam.fixedcamid = -1
        for i in range(3):
            self.cam.lookat[i] = np.median(self.data.geom_xpos[:, i])
        self.cam.distance = self.model.stat.extent

    def add_overlay(self, gridpos: int, text1: str, text2: str):
        """Overlays text on the scene.

        Appends ``text1``/``text2`` as a new line to the left/right columns
        of the overlay at ``gridpos``.
        """
        if gridpos not in self._overlays:
            self._overlays[gridpos] = ["", ""]
        self._overlays[gridpos][0] += text1 + "\n"
        self._overlays[gridpos][1] += text2 + "\n"

    def add_marker(self, **marker_params):
        """Queue a decorative geom whose fields are given as keyword args."""
        self._markers.append(marker_params)

    @staticmethod
    def _set_marker_attr(g, key, value):
        """Assign one user-supplied marker field ``key=value`` onto geom ``g``.

        ``label`` accepts a string, or ``None`` to clear it. Sequences and
        arrays are reshaped into the existing field's shape. Numbers (Python
        or NumPy scalars) and ``mjtGeom`` values are assigned directly.

        Raises:
            ValueError: If a string is given for a field other than ``label``,
                the value's type is not supported for ``key``, or the geom
                has no such field.
        """
        if key == "label" and (value is None or isinstance(value, str)):
            g.label = "" if value is None else value
        elif isinstance(value, str):
            raise ValueError("Only label is a string in mjtGeom.")
        elif isinstance(value, (tuple, list, np.ndarray)):
            attr = getattr(g, key)
            attr[:] = np.asarray(value).reshape(attr.shape)
        elif isinstance(value, (int, float, np.integer, np.floating, mujoco.mjtGeom)):
            setattr(g, key, value)
        elif hasattr(g, key):
            raise ValueError(
                "mjtGeom has attr {} but type {} is invalid".format(key, type(value))
            )
        else:
            raise ValueError("mjtGeom doesn't have field %s" % key)

    def _add_marker_to_scene(self, marker):
        """Append ``marker`` as a decorative geom at the end of the scene.

        Raises:
            RuntimeError: If the scene already holds ``maxgeom`` geoms.
            ValueError: If a marker field is unknown or has an invalid type.
        """
        if self.scn.ngeom >= self.scn.maxgeom:
            raise RuntimeError("Ran out of geoms. maxgeom: %d" % self.scn.maxgeom)

        g = self.scn.geoms[self.scn.ngeom]
        # default values.
        g.dataid = -1
        g.objtype = mujoco.mjtObj.mjOBJ_UNKNOWN
        g.objid = -1
        g.category = mujoco.mjtCatBit.mjCAT_DECOR
        g.texid = -1
        g.texuniform = 0
        g.texrepeat[0] = 1
        g.texrepeat[1] = 1
        g.emission = 0
        g.specular = 0.5
        g.shininess = 0.5
        g.reflectance = 0
        g.type = mujoco.mjtGeom.mjGEOM_BOX
        g.size[:] = np.ones(3) * 0.1
        g.mat[:] = np.eye(3)
        g.rgba[:] = np.ones(4)

        for key, value in marker.items():
            self._set_marker_attr(g, key, value)

        # Only count the geom once every field was applied successfully.
        self.scn.ngeom += 1

    def close(self):
        """Override close in your rendering subclass to perform any necessary cleanup
        after env.close() is called.
        """
        pass
class RenderContextOffscreen(RenderContext):
    """Offscreen rendering class with opengl context."""

    def __init__(self, model, data):
        """Create an OpenGL context sized to the model's offscreen buffer.

        The context is made current before ``RenderContext.__init__`` runs,
        because that constructor builds the ``MjrContext``.
        """
        # We must make GLContext before MjrContext
        width = model.vis.global_.offwidth
        height = model.vis.global_.offheight
        self._get_opengl_backend(width, height)
        self.opengl_context.make_current()

        super().__init__(model, data, offscreen=True)

    def _get_opengl_backend(self, width, height):
        """Create ``self.opengl_context`` using the configured or first working backend.

        If ``MUJOCO_GL`` is set, that backend is used and any error from
        creating it propagates. Otherwise each entry of ``_ALL_RENDERERS`` is
        tried in order and the first that constructs successfully is kept.

        Raises:
            RuntimeError: If ``MUJOCO_GL`` names an unknown backend, or no
                backend could be created.
        """
        backend = os.environ.get("MUJOCO_GL")
        if backend is not None:
            # Only the lookup is guarded: a KeyError raised while *building*
            # the context must not be reported as a bad MUJOCO_GL value.
            try:
                make_context = _ALL_RENDERERS[backend]
            except KeyError:
                raise RuntimeError(
                    "Environment variable {} must be one of {!r}: got {!r}.".format(
                        "MUJOCO_GL", _ALL_RENDERERS.keys(), backend
                    )
                ) from None
            self.opengl_context = make_context(width, height)
        else:
            for name, make_context in _ALL_RENDERERS.items():
                try:
                    self.opengl_context = make_context(width, height)
                except Exception:  # noqa: BLE001 - probing; try the next backend
                    continue
                backend = name
                break
            if backend is None:
                raise RuntimeError(
                    "No OpenGL backend could be imported. Attempting to create a "
                    "rendering context will result in a RuntimeError."
                )

    def close(self):
        """Release the OpenGL context created for offscreen rendering.

        Calling it more than once is a no-op.
        """
        context = getattr(self, "opengl_context", None)
        if context is not None:
            context.free()
            self.opengl_context = None
class Viewer(RenderContext):
    """Class for window rendering in all MuJoCo environments.

    Opens a GLFW window sized to half of the primary monitor's video mode.
    It installs mouse, scroll and keyboard callbacks for camera control and
    display toggles. On each ``render`` call it draws the scene plus a help
    overlay.
    """

    def __init__(self, model, data):
        """Create the window, install input callbacks and set up rendering.

        Raises:
            RuntimeError: If GLFW cannot be initialised or the window cannot
                be created.
        """
        self._gui_lock = Lock()
        self._button_left_pressed = False
        self._button_right_pressed = False
        self._last_mouse_x = 0
        self._last_mouse_y = 0
        self._paused = False
        self._transparent = False
        self._contacts = False
        self._render_every_frame = True
        self._image_idx = 0
        self._image_path = "/tmp/frame_%07d.png"
        self._time_per_render = 1 / 60.0
        self._run_speed = 1.0
        self._loop_count = 0
        self._advance_by_one_step = False
        self._hide_menu = False

        # glfw init
        if not glfw.init():
            raise RuntimeError("Failed to initialize GLFW")
        width, height = glfw.get_video_mode(glfw.get_primary_monitor()).size
        self.window = glfw.create_window(width // 2, height // 2, "mujoco", None, None)
        if not self.window:
            self.window = None
            glfw.terminate()
            raise RuntimeError("Failed to create GLFW window")
        glfw.make_context_current(self.window)
        glfw.swap_interval(1)

        framebuffer_width, framebuffer_height = glfw.get_framebuffer_size(self.window)
        window_width, _ = glfw.get_window_size(self.window)
        # Framebuffer pixels per window coordinate; cursor positions are
        # multiplied by this before being used for camera motion.
        self._scale = framebuffer_width * 1.0 / window_width

        # set callbacks
        glfw.set_cursor_pos_callback(self.window, self._cursor_pos_callback)
        glfw.set_mouse_button_callback(self.window, self._mouse_button_callback)
        glfw.set_scroll_callback(self.window, self._scroll_callback)
        glfw.set_key_callback(self.window, self._key_callback)

        # get viewport
        self.viewport = mujoco.MjrRect(0, 0, framebuffer_width, framebuffer_height)

        super().__init__(model, data, offscreen=False)

    def _destroy_window(self):
        """Destroy the window and terminate GLFW; later calls are no-ops.

        GLFW does not allow this from inside one of its callbacks.
        """
        if self.window is None:
            return
        glfw.destroy_window(self.window)
        glfw.terminate()
        self.window = None

    def _key_callback(self, window, key, scancode, action, mods):
        """Handle keyboard shortcuts on key release (listed in the overlay)."""
        if action != glfw.RELEASE:
            return
        # Switch cameras
        elif key == glfw.KEY_TAB:
            self.cam.fixedcamid += 1
            self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            if self.cam.fixedcamid >= self.model.ncam:
                self.cam.fixedcamid = -1
                self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        # Pause simulation
        elif key == glfw.KEY_SPACE and self._paused is not None:
            self._paused = not self._paused
        # Advances simulation by one step.
        elif key == glfw.KEY_RIGHT and self._paused is not None:
            self._advance_by_one_step = True
            self._paused = True
        # Slows down simulation
        elif key == glfw.KEY_S:
            self._run_speed /= 2.0
        # Speeds up simulation
        elif key == glfw.KEY_F:
            self._run_speed *= 2.0
        # Turn off / turn on rendering every frame.
        elif key == glfw.KEY_D:
            self._render_every_frame = not self._render_every_frame
        # Capture screenshot
        elif key == glfw.KEY_T:
            fb_width, fb_height = glfw.get_framebuffer_size(self.window)
            img = np.zeros((fb_height, fb_width, 3), dtype=np.uint8)
            mujoco.mjr_readPixels(img, None, self.viewport, self.con)
            imageio.imwrite(self._image_path % self._image_idx, np.flipud(img))
            self._image_idx += 1
        # Display contact forces
        elif key == glfw.KEY_C:
            self._contacts = not self._contacts
            self.vopt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = self._contacts
            self.vopt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = self._contacts
        # Display coordinate frames
        elif key == glfw.KEY_E:
            self.vopt.frame = 1 - self.vopt.frame
        # Hide overlay menu
        elif key == glfw.KEY_H:
            self._hide_menu = not self._hide_menu
        # Make transparent
        elif key == glfw.KEY_R:
            self._transparent = not self._transparent
            if self._transparent:
                self.model.geom_rgba[:, 3] /= 5.0
            else:
                self.model.geom_rgba[:, 3] *= 5.0
        # Geom group visibility
        elif key in (glfw.KEY_0, glfw.KEY_1, glfw.KEY_2, glfw.KEY_3, glfw.KEY_4):
            self.vopt.geomgroup[key - glfw.KEY_0] ^= 1
        # Quit
        if key == glfw.KEY_ESCAPE:
            print("Pressed ESC")
            print("Quitting.")
            # Destroying the window / terminating GLFW is not allowed inside
            # a callback, so only request closing; render() tears the window
            # down on its next frame.
            glfw.set_window_should_close(window, True)

    def _cursor_pos_callback(self, window, xpos, ypos):
        """Move the camera while a mouse button is held.

        Right drag uses mjMOUSE_MOVE_*, left drag uses mjMOUSE_ROTATE_*.
        Shift selects the ``_H`` variant instead of ``_V``.
        """
        if not (self._button_left_pressed or self._button_right_pressed):
            return

        mod_shift = (
            glfw.get_key(window, glfw.KEY_LEFT_SHIFT) == glfw.PRESS
            or glfw.get_key(window, glfw.KEY_RIGHT_SHIFT) == glfw.PRESS
        )
        # Right button wins when both are held. The guard above ensures the
        # left button is pressed whenever the right one is not.
        if self._button_right_pressed:
            action = (
                mujoco.mjtMouse.mjMOUSE_MOVE_H
                if mod_shift
                else mujoco.mjtMouse.mjMOUSE_MOVE_V
            )
        else:
            action = (
                mujoco.mjtMouse.mjMOUSE_ROTATE_H
                if mod_shift
                else mujoco.mjtMouse.mjMOUSE_ROTATE_V
            )

        x = int(self._scale * xpos)
        y = int(self._scale * ypos)
        dx = x - self._last_mouse_x
        dy = y - self._last_mouse_y
        _, height = glfw.get_framebuffer_size(window)

        with self._gui_lock:
            mujoco.mjv_moveCamera(
                self.model, action, dx / height, dy / height, self.scn, self.cam
            )

        self._last_mouse_x = x
        self._last_mouse_y = y

    def _mouse_button_callback(self, window, button, act, mods):
        """Record which buttons are held and the scaled cursor position."""
        self._button_left_pressed = (
            glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_LEFT) == glfw.PRESS
        )
        self._button_right_pressed = (
            glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_RIGHT) == glfw.PRESS
        )

        x, y = glfw.get_cursor_pos(window)
        self._last_mouse_x = int(self._scale * x)
        self._last_mouse_y = int(self._scale * y)

    def _scroll_callback(self, window, x_offset, y_offset):
        """Zoom the camera proportionally to the vertical scroll offset."""
        with self._gui_lock:
            mujoco.mjv_moveCamera(
                self.model,
                mujoco.mjtMouse.mjMOUSE_ZOOM,
                0,
                -0.05 * y_offset,
                self.scn,
                self.cam,
            )

    def _create_overlay(self):
        """Fill ``self._overlays`` with the help menu and run statistics."""
        topleft = mujoco.mjtGridPos.mjGRID_TOPLEFT
        bottomleft = mujoco.mjtGridPos.mjGRID_BOTTOMLEFT

        if self._render_every_frame:
            self.add_overlay(topleft, "", "")
        else:
            self.add_overlay(
                topleft,
                "Run speed = %.3f x real time" % self._run_speed,
                "[S]lower, [F]aster",
            )
        self.add_overlay(
            topleft, "Ren[d]er every frame", "On" if self._render_every_frame else "Off"
        )
        self.add_overlay(
            topleft,
            "Switch camera (#cams = %d)" % (self.model.ncam + 1),
            "[Tab] (camera ID = %d)" % self.cam.fixedcamid,
        )
        self.add_overlay(topleft, "[C]ontact forces", "On" if self._contacts else "Off")
        self.add_overlay(topleft, "T[r]ansparent", "On" if self._transparent else "Off")
        if self._paused is not None:
            if not self._paused:
                self.add_overlay(topleft, "Stop", "[Space]")
            else:
                self.add_overlay(topleft, "Start", "[Space]")
                self.add_overlay(
                    topleft, "Advance simulation by one step", "[right arrow]"
                )
        self.add_overlay(
            topleft, "Referenc[e] frames", "On" if self.vopt.frame == 1 else "Off"
        )
        self.add_overlay(topleft, "[H]ide Menu", "")
        if self._image_idx > 0:
            fname = self._image_path % (self._image_idx - 1)
            self.add_overlay(topleft, "Cap[t]ure frame", "Saved as %s" % fname)
        else:
            self.add_overlay(topleft, "Cap[t]ure frame", "")
        self.add_overlay(topleft, "Toggle geomgroup visibility", "0-4")

        self.add_overlay(bottomleft, "FPS", "%d%s" % (1 / self._time_per_render, ""))
        self.add_overlay(
            bottomleft, "Solver iterations", str(self.data.solver_iter + 1)
        )
        self.add_overlay(
            bottomleft, "Step", str(round(self.data.time / self.model.opt.timestep))
        )
        self.add_overlay(bottomleft, "timestep", "%.5f" % self.model.opt.timestep)

    def render(self):
        """Draw the current state to the window, honouring pause/speed settings.

        While paused, this keeps redrawing (processing input) until unpaused,
        a single-step is requested, or the window is closed. Otherwise it
        draws once per call when rendering every frame, or as often as the
        run-speed budget allows. Queued markers are cleared afterwards.
        Once the window has been closed, this only clears the markers.
        """

        # mjv_updateScene, mjr_render, mjr_overlay
        def update():
            if self.window is None:
                return
            if glfw.window_should_close(self.window):
                # Safe here: we are not inside a GLFW callback.
                self._destroy_window()
                return

            # fill overlay items
            self._create_overlay()

            render_start = time.time()
            self.viewport.width, self.viewport.height = glfw.get_framebuffer_size(
                self.window
            )
            with self._gui_lock:
                # update scene
                mujoco.mjv_updateScene(
                    self.model,
                    self.data,
                    self.vopt,
                    self.pert,
                    self.cam,
                    mujoco.mjtCatBit.mjCAT_ALL.value,
                    self.scn,
                )
                # marker items
                for marker in self._markers:
                    self._add_marker_to_scene(marker)
                # render
                mujoco.mjr_render(self.viewport, self.scn, self.con)
                # overlay items
                if not self._hide_menu:
                    for gridpos, [t1, t2] in self._overlays.items():
                        mujoco.mjr_overlay(
                            mujoco.mjtFontScale.mjFONTSCALE_150,
                            gridpos,
                            self.viewport,
                            t1,
                            t2,
                            self.con,
                        )
                glfw.swap_buffers(self.window)
            glfw.poll_events()
            # Exponential moving average of the per-frame render cost.
            self._time_per_render = 0.9 * self._time_per_render + 0.1 * (
                time.time() - render_start
            )

            # clear overlay
            self._overlays.clear()

        if self._paused:
            # Stop spinning once the window is gone: no callback can unpause.
            while self._paused and self.window is not None:
                update()
                if self._advance_by_one_step:
                    self._advance_by_one_step = False
                    break
        else:
            self._loop_count += self.model.opt.timestep / (
                self._time_per_render * self._run_speed
            )
            if self._render_every_frame:
                self._loop_count = 1
            while self._loop_count > 0:
                update()
                self._loop_count -= 1

        # clear markers
        self._markers[:] = []

    def close(self):
        """Destroy the window and terminate GLFW; safe to call more than once."""
        self._destroy_window()