# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

monkey_patch_maniskill()
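# Monkey patches must be applied before importing the modules they affect,
# hence this call precedes the ManiSkill-dependent imports below.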

import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Literal

import gymnasium as gym
import numpy as np
import torch
import tyro
from mani_skill.utils.wrappers import RecordEpisode
from tqdm import tqdm

import embodied_gen.envs.pick_embodiedgen  # noqa: F401 (registers the Gym env)
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.log import logger
from embodied_gen.utils.simulation import FrankaPandaGrasper


@dataclass
class ParallelSimConfig:
    """CLI parameters for parallel SAPIEN simulation."""

    # Environment configuration
    layout_file: str
    """Path to the layout JSON file."""
    output_dir: str
    """Directory to save recorded videos."""
    gym_env_name: str = "PickEmbodiedGen-v1"
    """Name of the Gym environment to use."""
    num_envs: int = 4
    """Number of parallel environments."""
    render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
    """Rendering mode: "rgb_array" or "hybrid"."""
    enable_shadow: bool = True
    """Whether to enable shadows in rendering."""
    control_mode: str = "pd_joint_pos"
    """Control mode for the agent."""

    # Recording configuration
    max_steps_per_video: int = 1000
    """Maximum steps to record per video."""
    save_trajectory: bool = False
    """Whether to save trajectory data."""

    # Simulation parameters
    seed: int = 0
    """Random seed for environment reset."""
    warmup_steps: int = 50
    """Number of warmup steps before action computation."""
    reach_target_only: bool = True
    """Whether to only reach the grasp target instead of executing the full action."""
def entrypoint(**kwargs):
    # `kwargs` is always a dict when called with `**`, so an emptiness
    # check suffices to decide between CLI and programmatic configuration.
    if not kwargs:
        cfg = tyro.cli(ParallelSimConfig)
    else:
        cfg = ParallelSimConfig(**kwargs)

    env = gym.make(
        cfg.gym_env_name,
        num_envs=cfg.num_envs,
        render_mode=cfg.render_mode,
        enable_shadow=cfg.enable_shadow,
        layout_file=cfg.layout_file,
        control_mode=cfg.control_mode,
    )
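    # Wrap the env to record rollouts: videos are written under
    # `cfg.output_dir`, plus trajectory files when `save_trajectory=True`.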
    env = RecordEpisode(
        env,
        cfg.output_dir,
        max_steps_per_video=cfg.max_steps_per_video,
        save_trajectory=cfg.save_trajectory,
    )
    env.reset(seed=cfg.seed)

    # Hold the initial joint configuration while the scene settles.
    default_action = env.unwrapped.agent.init_qpos[:, :8]
    for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
        # action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(default_action)

    grasper = FrankaPandaGrasper(
        env.unwrapped.agent,
        env.unwrapped.sim_config.control_freq,
    )
    with open(cfg.layout_file, "r") as f:
        layout_data = LayoutInfo.from_dict(json.load(f))
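    # Only the layout's `relation` mapping is used below: its
    # MANIPULATED_OBJS entry lists the object nodes to grasp.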
    actions = defaultdict(list)
    # Plan a grasp reach pose for each manipulated object in each env.
    for env_idx in range(env.num_envs):
        actors = env.unwrapped.env_actors[f"env{env_idx}"]
        for node in layout_data.relation[
            Scene3DItemEnum.MANIPULATED_OBJS.value
        ]:
            action = grasper.compute_grasp_action(
                actor=actors[node]._objs[0],
                reach_target_only=cfg.reach_target_only,
                env_idx=env_idx,
            )
            actions[node].append(action)
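
    # `actions[node]` now holds one (num_steps, action_dim) plan per env,
    # or None where no grasp plan was found for that env.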
    # Execute the planned actions for each manipulated object in each env.
    for node in actions:
        max_env_steps = 0
        for env_idx in range(env.num_envs):
            if actions[node][env_idx] is None:
                continue
            max_env_steps = max(max_env_steps, len(actions[node][env_idx]))

        # Pad shorter (or failed) plans with the default hold action so all
        # envs can be stepped in lockstep.
        action_tensor = np.ones(
            (max_env_steps, env.num_envs, env.action_space.shape[-1])
        )
        action_tensor *= default_action[None, ...]
        for env_idx in range(env.num_envs):
            action = actions[node][env_idx]
            if action is None:
                continue
            action_tensor[: len(action), env_idx, :] = action

        for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
            action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
            env.unwrapped.agent.set_action(action)
            obs, reward, terminated, truncated, info = env.step(action)

    env.close()
    logger.info(f"Results saved in {cfg.output_dir}")


if __name__ == "__main__":
    entrypoint()