Francesco Capuano committed · Commit 529ed6b · Parent(s): efd04f3
Initial commit
This view is limited to 50 files because it contains too many changes. See raw diff.
- lerobot/__init__.py +217 -0
- lerobot/__version__.py +23 -0
- lerobot/common/constants.py +45 -0
- lerobot/common/datasets/backward_compatibility.py +68 -0
- lerobot/common/datasets/card_template.md +27 -0
- lerobot/common/datasets/compute_stats.py +176 -0
- lerobot/common/datasets/factory.py +118 -0
- lerobot/common/datasets/image_writer.py +178 -0
- lerobot/common/datasets/lerobot_dataset.py +1217 -0
- lerobot/common/datasets/online_buffer.py +384 -0
- lerobot/common/datasets/push_dataset_to_hub/utils.py +131 -0
- lerobot/common/datasets/sampler.py +61 -0
- lerobot/common/datasets/transforms.py +249 -0
- lerobot/common/datasets/utils.py +813 -0
- lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py +884 -0
- lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +664 -0
- lerobot/common/datasets/v21/_remove_language_instruction.py +87 -0
- lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py +54 -0
- lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +114 -0
- lerobot/common/datasets/v21/convert_stats.py +99 -0
- lerobot/common/datasets/video_utils.py +432 -0
- lerobot/common/envs/__init__.py +15 -0
- lerobot/common/envs/configs.py +156 -0
- lerobot/common/envs/factory.py +69 -0
- lerobot/common/envs/utils.py +127 -0
- lerobot/common/mocks/__init__.py +1 -0
- lerobot/common/mocks/cameras/__init__.py +0 -0
- lerobot/common/mocks/cameras/mock_cv2.py +101 -0
- lerobot/common/mocks/cameras/mock_pyrealsense2.py +148 -0
- lerobot/common/mocks/motors/__init__.py +1 -0
- lerobot/common/mocks/motors/mock_dynamixel_sdk.py +107 -0
- lerobot/common/mocks/motors/mock_scservo_sdk.py +125 -0
- lerobot/common/optim/__init__.py +15 -0
- lerobot/common/optim/factory.py +40 -0
- lerobot/common/optim/optimizers.py +118 -0
- lerobot/common/optim/schedulers.py +122 -0
- lerobot/common/policies/__init__.py +19 -0
- lerobot/common/policies/act/configuration_act.py +186 -0
- lerobot/common/policies/act/modeling_act.py +765 -0
- lerobot/common/policies/diffusion/configuration_diffusion.py +237 -0
- lerobot/common/policies/diffusion/modeling_diffusion.py +765 -0
- lerobot/common/policies/factory.py +157 -0
- lerobot/common/policies/normalize.py +254 -0
- lerobot/common/policies/pi0/configuration_pi0.py +149 -0
- lerobot/common/policies/pi0/conversion_scripts/benchmark.py +82 -0
- lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +131 -0
- lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py +84 -0
- lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py +437 -0
- lerobot/common/policies/pi0/flex_attention.py +141 -0
- lerobot/common/policies/pi0/modeling_pi0.py +732 -0
lerobot/__init__.py
ADDED
@@ -0,0 +1,217 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains lists of available environments, datasets and policies to reflect the current state of the LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.

Example:
    ```python
    import lerobot
    print(lerobot.available_envs)
    print(lerobot.available_tasks_per_env)
    print(lerobot.available_datasets)
    print(lerobot.available_datasets_per_env)
    print(lerobot.available_real_world_datasets)
    print(lerobot.available_policies)
    print(lerobot.available_policies_per_env)
    print(lerobot.available_robots)
    print(lerobot.available_cameras)
    print(lerobot.available_motors)
    ```

When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`

When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`

When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` and `available_policies_per_env` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""

import itertools

from lerobot.__version__ import __version__  # noqa: F401

# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND an environment name. The difference should be more obvious.
available_tasks_per_env = {
    "aloha": [
        "AlohaInsertion-v0",
        "AlohaTransferCube-v0",
    ],
    "pusht": ["PushT-v0"],
    "xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())

available_datasets_per_env = {
    "aloha": [
        "lerobot/aloha_sim_insertion_human",
        "lerobot/aloha_sim_insertion_scripted",
        "lerobot/aloha_sim_transfer_cube_human",
        "lerobot/aloha_sim_transfer_cube_scripted",
        "lerobot/aloha_sim_insertion_human_image",
        "lerobot/aloha_sim_insertion_scripted_image",
        "lerobot/aloha_sim_transfer_cube_human_image",
        "lerobot/aloha_sim_transfer_cube_scripted_image",
    ],
    # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
    # coupled with tests.
    "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
    "xarm": [
        "lerobot/xarm_lift_medium",
        "lerobot/xarm_lift_medium_replay",
        "lerobot/xarm_push_medium",
        "lerobot/xarm_push_medium_replay",
        "lerobot/xarm_lift_medium_image",
        "lerobot/xarm_lift_medium_replay_image",
        "lerobot/xarm_push_medium_image",
        "lerobot/xarm_push_medium_replay_image",
    ],
}

available_real_world_datasets = [
    "lerobot/aloha_mobile_cabinet",
    "lerobot/aloha_mobile_chair",
    "lerobot/aloha_mobile_elevator",
    "lerobot/aloha_mobile_shrimp",
    "lerobot/aloha_mobile_wash_pan",
    "lerobot/aloha_mobile_wipe_wine",
    "lerobot/aloha_static_battery",
    "lerobot/aloha_static_candy",
    "lerobot/aloha_static_coffee",
    "lerobot/aloha_static_coffee_new",
    "lerobot/aloha_static_cups_open",
    "lerobot/aloha_static_fork_pick_up",
    "lerobot/aloha_static_pingpong_test",
    "lerobot/aloha_static_pro_pencil",
    "lerobot/aloha_static_screw_driver",
    "lerobot/aloha_static_tape",
    "lerobot/aloha_static_thread_velcro",
    "lerobot/aloha_static_towel",
    "lerobot/aloha_static_vinh_cup",
    "lerobot/aloha_static_vinh_cup_left",
    "lerobot/aloha_static_ziploc_slide",
    "lerobot/umi_cup_in_the_wild",
    "lerobot/unitreeh1_fold_clothes",
    "lerobot/unitreeh1_rearrange_objects",
    "lerobot/unitreeh1_two_robot_greeting",
    "lerobot/unitreeh1_warehouse",
    "lerobot/nyu_rot_dataset",
    "lerobot/utokyo_saytap",
    "lerobot/imperialcollege_sawyer_wrist_cam",
    "lerobot/utokyo_xarm_bimanual",
    "lerobot/tokyo_u_lsmo",
    "lerobot/utokyo_pr2_opening_fridge",
    "lerobot/cmu_franka_exploration_dataset",
    "lerobot/cmu_stretch",
    "lerobot/asu_table_top",
    "lerobot/utokyo_pr2_tabletop_manipulation",
    "lerobot/utokyo_xarm_pick_and_place",
    "lerobot/ucsd_kitchen_dataset",
    "lerobot/austin_buds_dataset",
    "lerobot/dlr_sara_grid_clamp",
    "lerobot/conq_hose_manipulation",
    "lerobot/columbia_cairlab_pusht_real",
    "lerobot/dlr_sara_pour",
    "lerobot/dlr_edan_shared_control",
    "lerobot/ucsd_pick_and_place_dataset",
    "lerobot/berkeley_cable_routing",
    "lerobot/nyu_franka_play_dataset",
    "lerobot/austin_sirius_dataset",
    "lerobot/cmu_play_fusion",
    "lerobot/berkeley_gnm_sac_son",
    "lerobot/nyu_door_opening_surprising_effectiveness",
    "lerobot/berkeley_fanuc_manipulation",
    "lerobot/jaco_play",
    "lerobot/viola",
    "lerobot/kaist_nonprehensile",
    "lerobot/berkeley_mvp",
    "lerobot/uiuc_d3field",
    "lerobot/berkeley_gnm_recon",
    "lerobot/austin_sailor_dataset",
    "lerobot/utaustin_mutex",
    "lerobot/roboturk",
    "lerobot/stanford_hydra_dataset",
    "lerobot/berkeley_autolab_ur5",
    "lerobot/stanford_robocook",
    "lerobot/toto",
    "lerobot/fmb",
    "lerobot/droid_100",
    "lerobot/berkeley_rpt",
    "lerobot/stanford_kuka_multimodal_dataset",
    "lerobot/iamlab_cmu_pickup_insert",
    "lerobot/taco_play",
    "lerobot/berkeley_gnm_cory_hall",
    "lerobot/usc_cloth_sim",
]

available_datasets = sorted(
    set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
)

# lists all available policies from `lerobot/common/policies`
available_policies = [
    "act",
    "diffusion",
    "tdmpc",
    "vqbet",
]

# lists all available robots from `lerobot/common/robot_devices/robots`
available_robots = [
    "koch",
    "koch_bimanual",
    "aloha",
    "so100",
    "moss",
]

# lists all available cameras from `lerobot/common/robot_devices/cameras`
available_cameras = [
    "opencv",
    "intelrealsense",
]

# lists all available motors from `lerobot/common/robot_devices/motors`
available_motors = [
    "dynamixel",
    "feetech",
]

# keys and values refer to yaml files
available_policies_per_env = {
    "aloha": ["act"],
    "pusht": ["diffusion", "vqbet"],
    "xarm": ["tdmpc"],
    "koch_real": ["act_koch_real"],
    "aloha_real": ["act_aloha_real"],
}

env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [
    (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_policy_triplets = [
    (env, dataset, policy)
    for env, datasets in available_datasets_per_env.items()
    for dataset in datasets
    for policy in available_policies_per_env[env]
]
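Because the registry above is plain Python data, the derived pair/triplet lists can be consumed directly, for instance to parametrize tests over every environment/dataset combination. A minimal sketch of that usage, assuming `pytest` is available (the test function is hypothetical, not part of this file):

```python
# Minimal sketch: consuming the registry lists defined in lerobot/__init__.py.
# The pytest parametrization is an illustrative assumption, not library API.
import pytest

import lerobot


@pytest.mark.parametrize("env_name, dataset_repo_id", lerobot.env_dataset_pairs)
def test_dataset_is_registered_for_env(env_name, dataset_repo_id):
    # Every per-env dataset must also appear in the global sorted union.
    assert dataset_repo_id in lerobot.available_datasets
    assert env_name in lerobot.available_envs
```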
lerobot/__version__.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To enable `lerobot.__version__`"""

from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("lerobot")
except PackageNotFoundError:
    __version__ = "unknown"
lerobot/common/constants.py
ADDED
@@ -0,0 +1,45 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# keys
import os
from pathlib import Path

from huggingface_hub.constants import HF_HOME

OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"

# files & directories
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"
RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"

# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()

if "LEROBOT_HOME" in os.environ:
    raise ValueError(
        f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
        "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
    )
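The cache-path resolution above prefers an `HF_LEROBOT_HOME` environment variable and otherwise falls back to `HF_HOME/lerobot` (by default `~/.cache/huggingface/lerobot`). A small sketch of the same resolution logic in isolation; the override path is made up for illustration:

```python
# Sketch of the cache-path resolution performed in constants.py (paths illustrative).
import os
from pathlib import Path

os.environ["HF_LEROBOT_HOME"] = "~/my_robot_cache"  # hypothetical override
resolved = Path(os.getenv("HF_LEROBOT_HOME")).expanduser()
print(resolved)  # e.g. /home/<user>/my_robot_cache
# Without the override, the module uses Path(HF_HOME) / "lerobot",
# i.e. ~/.cache/huggingface/lerobot by default.
```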
lerobot/common/datasets/backward_compatibility.py
ADDED
@@ -0,0 +1,68 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import packaging.version

V2_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.

We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
    --repo-id {repo_id} \\
    --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
```

A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
sweatshirt.", ...

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While the current version of LeRobot is backward-compatible with it, your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
```

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
"""


class CompatibilityError(Exception): ...


class BackwardCompatibilityError(CompatibilityError):
    def __init__(self, repo_id: str, version: packaging.version.Version):
        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
        super().__init__(message)


class ForwardCompatibilityError(CompatibilityError):
    def __init__(self, repo_id: str, version: packaging.version.Version):
        message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
        super().__init__(message)
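A hedged sketch of how a caller might raise and surface one of these errors during a version check; the check itself and the example version string are assumptions for illustration, not code from this file:

```python
# Sketch: raising BackwardCompatibilityError from a hypothetical caller-side check.
import packaging.version

from lerobot.common.datasets.backward_compatibility import BackwardCompatibilityError

dataset_version = packaging.version.parse("1.6")  # hypothetical dataset codebase version

try:
    if dataset_version < packaging.version.parse("2.0"):
        raise BackwardCompatibilityError("lerobot/pusht", dataset_version)
except BackwardCompatibilityError as e:
    # The exception message is V2_MESSAGE with repo_id and version filled in,
    # including the conversion-script command shown above.
    print(e)
```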
lerobot/common/datasets/card_template.md
ADDED
@@ -0,0 +1,27 @@
---
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
{{ card_data }}
---

This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).

## Dataset Description

{{ dataset_description | default("", true) }}

- **Homepage:** {{ url | default("[More Information Needed]", true)}}
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
- **License:** {{ license | default("[More Information Needed]", true)}}

## Dataset Structure

{{ dataset_structure | default("[More Information Needed]", true)}}

## Citation

**BibTeX:**

```bibtex
{{ citation_bibtex | default("[More Information Needed]", true)}}
```
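The placeholders above use Jinja2 syntax with the `default(..., true)` filter, so empty or missing values fall back to the given default. LeRobot fills this template through its own card-creation utility (`create_lerobot_dataset_card`); the snippet below renders it with plain Jinja2 only to show how the placeholders behave, with made-up field values:

```python
# Hedged sketch: rendering card_template.md with plain Jinja2 (illustrative values).
from jinja2 import Template

with open("lerobot/common/datasets/card_template.md") as f:
    template_text = f.read()

card = Template(template_text).render(
    card_data="license: apache-2.0",  # hypothetical YAML front matter
    dataset_description="Recorded with a Koch arm.",  # hypothetical
    url=None,  # falsy, so the default "[More Information Needed]" is used
)
print(card)
```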
lerobot/common/datasets/compute_stats.py
ADDED
@@ -0,0 +1,176 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from lerobot.common.datasets.utils import load_image_as_numpy


def estimate_num_samples(
    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
    """Heuristic to estimate the number of samples based on dataset size.
    The power controls the sample growth relative to dataset size.
    Lower the power for fewer samples.

    For default arguments, we have:
    - from 1 to ~500, num_samples=100
    - at 1000, num_samples=177
    - at 2000, num_samples=299
    - at 5000, num_samples=594
    - at 10000, num_samples=1000
    - at 20000, num_samples=1681
    """
    if dataset_len < min_num_samples:
        min_num_samples = dataset_len
    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))


def sample_indices(data_len: int) -> list[int]:
    num_samples = estimate_num_samples(data_len)
    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()


def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
    _, height, width = img.shape

    if max(width, height) < max_size_threshold:
        # no downsampling needed
        return img

    downsample_factor = int(width / target_size) if width > height else int(height / target_size)
    return img[:, ::downsample_factor, ::downsample_factor]


def sample_images(image_paths: list[str]) -> np.ndarray:
    sampled_indices = sample_indices(len(image_paths))

    images = None
    for i, idx in enumerate(sampled_indices):
        path = image_paths[idx]
        # we load as uint8 to reduce memory usage
        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
        img = auto_downsample_height_width(img)

        if images is None:
            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)

        images[i] = img

    return images


def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
    return {
        "min": np.min(array, axis=axis, keepdims=keepdims),
        "max": np.max(array, axis=axis, keepdims=keepdims),
        "mean": np.mean(array, axis=axis, keepdims=keepdims),
        "std": np.std(array, axis=axis, keepdims=keepdims),
        "count": np.array([len(array)]),
    }


def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
    ep_stats = {}
    for key, data in episode_data.items():
        if features[key]["dtype"] == "string":
            continue  # HACK: we should receive np.arrays of strings
        elif features[key]["dtype"] in ["image", "video"]:
            ep_ft_array = sample_images(data)  # data is a list of image paths
            axes_to_reduce = (0, 2, 3)  # keep channel dim
            keepdims = True
        else:
            ep_ft_array = data  # data is already a np.ndarray
            axes_to_reduce = 0  # compute stats over the first axis
            keepdims = data.ndim == 1  # keep as np.array

        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)

        # finally, we normalize and remove batch dim for images
        if features[key]["dtype"] in ["image", "video"]:
            ep_stats[key] = {
                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
            }

    return ep_stats


def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
    for i in range(len(stats_list)):
        for fkey in stats_list[i]:
            for k, v in stats_list[i][fkey].items():
                if not isinstance(v, np.ndarray):
                    raise ValueError(
                        f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
                    )
                if v.ndim == 0:
                    raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
                if k == "count" and v.shape != (1,):
                    raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
                if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
                    raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")


def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
    """Aggregates stats for a single feature."""
    means = np.stack([s["mean"] for s in stats_ft_list])
    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
    counts = np.stack([s["count"] for s in stats_ft_list])
    total_count = counts.sum(axis=0)

    # Prepare weighted mean by matching number of dimensions
    while counts.ndim < means.ndim:
        counts = np.expand_dims(counts, axis=-1)

    # Compute the weighted mean
    weighted_means = means * counts
    total_mean = weighted_means.sum(axis=0) / total_count

    # Compute the variance using the parallel algorithm
    delta_means = means - total_mean
    weighted_variances = (variances + delta_means**2) * counts
    total_variance = weighted_variances.sum(axis=0) / total_count

    return {
        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
        "mean": total_mean,
        "std": np.sqrt(total_variance),
        "count": total_count,
    }


def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
    """Aggregate stats from multiple compute_stats outputs into a single set of stats.

    The final stats will have the union of all data keys from each of the stats dicts.

    For instance:
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_mean = (mean of all data, weighted by counts)
    - new_std = (std of all data)
    """

    _assert_type_and_shape(stats_list)

    data_keys = {key for stats in stats_list for key in stats}
    aggregated_stats = {key: {} for key in data_keys}

    for key in data_keys:
        stats_with_key = [stats[key] for stats in stats_list if key in stats]
        aggregated_stats[key] = aggregate_feature_stats(stats_with_key)

    return aggregated_stats
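The merge in `aggregate_feature_stats` follows the parallel mean/variance formula: the combined mean is the count-weighted mean of the per-chunk means, and the combined variance is the count-weighted average of each chunk's variance plus the squared deviation of its mean from the combined mean. A small numeric sketch with made-up per-episode stats:

```python
# Numeric sketch of aggregate_stats on two made-up per-episode stats dicts.
import numpy as np

from lerobot.common.datasets.compute_stats import aggregate_stats

ep0 = {
    "observation.state": {
        "min": np.array([0.0]),
        "max": np.array([1.0]),
        "mean": np.array([0.5]),
        "std": np.array([0.1]),
        "count": np.array([100]),
    }
}
ep1 = {
    "observation.state": {
        "min": np.array([-1.0]),
        "max": np.array([2.0]),
        "mean": np.array([1.0]),
        "std": np.array([0.2]),
        "count": np.array([300]),
    }
}

merged = aggregate_stats([ep0, ep1])
print(merged["observation.state"]["mean"])  # (100*0.5 + 300*1.0) / 400 = 0.875
print(merged["observation.state"]["min"])   # element-wise min -> [-1.0]
print(merged["observation.state"]["count"]) # [400]
```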
lerobot/common/datasets/factory.py
ADDED
@@ -0,0 +1,118 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat

import torch

from lerobot.common.datasets.lerobot_dataset import (
    LeRobotDataset,
    LeRobotDatasetMetadata,
    MultiLeRobotDataset,
)
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig

IMAGENET_STATS = {
    "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
    "std": [[[0.229]], [[0.224]], [[0.225]]],  # (c,1,1)
}


def resolve_delta_timestamps(
    cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
    """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.

    Args:
        cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
        ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
            delta_timestamps against.

    Returns:
        dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
            {
                "observation.state": [-0.04, -0.02, 0]
                "observation.action": [-0.02, 0, 0.02]
            }
            returns `None` if the resulting dict is empty.
    """
    delta_timestamps = {}
    for key in ds_meta.features:
        if key == "next.reward" and cfg.reward_delta_indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
        if key == "action" and cfg.action_delta_indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
        if key.startswith("observation.") and cfg.observation_delta_indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

    if len(delta_timestamps) == 0:
        delta_timestamps = None

    return delta_timestamps


def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
    """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.

    Args:
        cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.

    Raises:
        NotImplementedError: The MultiLeRobotDataset is currently deactivated.

    Returns:
        LeRobotDataset | MultiLeRobotDataset
    """
    image_transforms = (
        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
    )

    if isinstance(cfg.dataset.repo_id, str):
        ds_meta = LeRobotDatasetMetadata(
            cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
        )
        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
        dataset = LeRobotDataset(
            cfg.dataset.repo_id,
            root=cfg.dataset.root,
            episodes=cfg.dataset.episodes,
            delta_timestamps=delta_timestamps,
            image_transforms=image_transforms,
            revision=cfg.dataset.revision,
            video_backend=cfg.dataset.video_backend,
        )
    else:
        raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
        dataset = MultiLeRobotDataset(
            cfg.dataset.repo_id,
            # TODO(aliberts): add proper support for multi dataset
            # delta_timestamps=delta_timestamps,
            image_transforms=image_transforms,
            video_backend=cfg.dataset.video_backend,
        )
        logging.info(
            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
            f"{pformat(dataset.repo_id_to_index, indent=2)}"
        )

    if cfg.dataset.use_imagenet_stats:
        for key in dataset.meta.camera_keys:
            for stats_type, stats in IMAGENET_STATS.items():
                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

    return dataset
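The conversion done by `resolve_delta_timestamps` is simply `index / fps` per modality, so at 50 fps the observation delta indices `[-2, -1, 0]` become the timestamps `[-0.04, -0.02, 0.0]` shown in the docstring. A standalone sketch of that arithmetic (the fps value is assumed for illustration):

```python
# Standalone sketch of the delta-index -> delta-timestamp conversion used above.
fps = 50  # hypothetical dataset fps
observation_delta_indices = [-2, -1, 0]

delta_timestamps = {"observation.state": [i / fps for i in observation_delta_indices]}
print(delta_timestamps)  # {'observation.state': [-0.04, -0.02, 0.0]}
```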
lerobot/common/datasets/image_writer.py
ADDED
@@ -0,0 +1,178 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import queue
import threading
from pathlib import Path

import numpy as np
import PIL.Image
import torch


def safe_stop_image_writer(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                print("Waiting for image writer to terminate...")
                image_writer.stop()
            raise e

    return wrapper


def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
    # TODO(aliberts): handle 1 channel and 4 for depth images
    if image_array.ndim != 3:
        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

    if image_array.shape[0] == 3:
        # Transpose from pytorch convention (C, H, W) to (H, W, C)
        image_array = image_array.transpose(1, 2, 0)

    elif image_array.shape[-1] != 3:
        raise NotImplementedError(
            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
        )

    if image_array.dtype != np.uint8:
        if range_check:
            max_ = image_array.max().item()
            min_ = image_array.min().item()
            if max_ > 1.0 or min_ < 0.0:
                raise ValueError(
                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
                    "provide a uint8 image with values in the range [0, 255]."
                )

        image_array = (image_array * 255).astype(np.uint8)

    return PIL.Image.fromarray(image_array)


def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
    try:
        if isinstance(image, np.ndarray):
            img = image_array_to_pil_image(image)
        elif isinstance(image, PIL.Image.Image):
            img = image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")
        img.save(fpath)
    except Exception as e:
        print(f"Error writing image {fpath}: {e}")


def worker_thread_loop(queue: queue.Queue):
    while True:
        item = queue.get()
        if item is None:
            queue.task_done()
            break
        image_array, fpath = item
        write_image(image_array, fpath)
        queue.task_done()


def worker_process(queue: queue.Queue, num_threads: int):
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(target=worker_thread_loop, args=(queue,))
        t.daemon = True
        t.start()
        threads.append(t)
    for t in threads:
        t.join()


class AsyncImageWriter:
    """
    This class abstracts away the initialization of processes and/or threads used to
    save images on disk asynchronously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it creates a thread pool of size `num_threads`.
    When `num_processes>0`, it creates a process pool of size `num_processes`, where each subprocess starts
    its own thread pool of size `num_threads`.

    The optimal number of processes and threads depends on your computer's capabilities.
    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
    the number of threads. If it is still not stable, try to use 1 subprocess, or more.
    """

    def __init__(self, num_processes: int = 0, num_threads: int = 1):
        self.num_processes = num_processes
        self.num_threads = num_threads
        self.queue = None
        self.threads = []
        self.processes = []
        self._stopped = False

        if num_threads <= 0 and num_processes <= 0:
            raise ValueError("Number of threads and processes must be greater than zero.")

        if self.num_processes == 0:
            # Use threading
            self.queue = queue.Queue()
            for _ in range(self.num_threads):
                t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
                t.daemon = True
                t.start()
                self.threads.append(t)
        else:
            # Use multiprocessing
            self.queue = multiprocessing.JoinableQueue()
            for _ in range(self.num_processes):
                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
                p.daemon = True
                p.start()
                self.processes.append(p)

    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
        if isinstance(image, torch.Tensor):
            # Convert tensor to numpy array to minimize main process time
            image = image.cpu().numpy()
        self.queue.put((image, fpath))

    def wait_until_done(self):
        self.queue.join()

    def stop(self):
        if self._stopped:
            return

        if self.num_processes == 0:
            for _ in self.threads:
                self.queue.put(None)
            for t in self.threads:
                t.join()
        else:
            num_nones = self.num_processes * self.num_threads
            for _ in range(num_nones):
                self.queue.put(None)
            for p in self.processes:
                p.join()
                if p.is_alive():
                    p.terminate()
            self.queue.close()
            self.queue.join_thread()

        self._stopped = True
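A minimal usage sketch of `AsyncImageWriter` in the thread-only configuration recommended in the docstring; the output directory and random frames are made up for illustration:

```python
# Minimal usage sketch of AsyncImageWriter (thread-only mode, illustrative data).
from pathlib import Path

import numpy as np

from lerobot.common.datasets.image_writer import AsyncImageWriter

out_dir = Path("outputs/frames")  # hypothetical output directory
out_dir.mkdir(parents=True, exist_ok=True)

writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
    for i in range(10):
        # (C, H, W) uint8 frame, as produced by a camera pipeline
        frame = np.random.randint(0, 256, size=(3, 480, 640), dtype=np.uint8)
        writer.save_image(frame, out_dir / f"frame_{i:06d}.png")
    writer.wait_until_done()  # block until the queue is fully drained
finally:
    writer.stop()  # signal and join the worker threads
```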
lerobot/common/datasets/lerobot_dataset.py
ADDED
@@ -0,0 +1,1217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import contextlib
|
17 |
+
import logging
|
18 |
+
import shutil
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Callable
|
21 |
+
|
22 |
+
import datasets
|
23 |
+
import numpy as np
|
24 |
+
import packaging.version
|
25 |
+
import PIL.Image
|
26 |
+
import torch
|
27 |
+
import torch.utils
|
28 |
+
from datasets import concatenate_datasets, load_dataset
|
29 |
+
from huggingface_hub import HfApi, snapshot_download
|
30 |
+
from huggingface_hub.constants import REPOCARD_NAME
|
31 |
+
from huggingface_hub.errors import RevisionNotFoundError
|
32 |
+
|
33 |
+
from lerobot.common.constants import HF_LEROBOT_HOME
|
34 |
+
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
35 |
+
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
36 |
+
from lerobot.common.datasets.utils import (
|
37 |
+
DEFAULT_FEATURES,
|
38 |
+
DEFAULT_IMAGE_PATH,
|
39 |
+
INFO_PATH,
|
40 |
+
TASKS_PATH,
|
41 |
+
append_jsonlines,
|
42 |
+
backward_compatible_episodes_stats,
|
43 |
+
check_delta_timestamps,
|
44 |
+
check_timestamps_sync,
|
45 |
+
check_version_compatibility,
|
46 |
+
create_empty_dataset_info,
|
47 |
+
create_lerobot_dataset_card,
|
48 |
+
embed_images,
|
49 |
+
get_delta_indices,
|
50 |
+
get_episode_data_index,
|
51 |
+
get_features_from_robot,
|
52 |
+
get_hf_features_from_features,
|
53 |
+
get_safe_version,
|
54 |
+
hf_transform_to_torch,
|
55 |
+
is_valid_version,
|
56 |
+
load_episodes,
|
57 |
+
load_episodes_stats,
|
58 |
+
load_info,
|
59 |
+
load_stats,
|
60 |
+
load_tasks,
|
61 |
+
validate_episode_buffer,
|
62 |
+
validate_frame,
|
63 |
+
write_episode,
|
64 |
+
write_episode_stats,
|
65 |
+
write_info,
|
66 |
+
write_json,
|
67 |
+
)
|
68 |
+
from lerobot.common.datasets.video_utils import (
|
69 |
+
VideoFrame,
|
70 |
+
decode_video_frames,
|
71 |
+
encode_video_frames,
|
72 |
+
get_safe_default_codec,
|
73 |
+
get_video_info,
|
74 |
+
)
|
75 |
+
from lerobot.common.robot_devices.robots.utils import Robot
|
76 |
+
|
77 |
+
CODEBASE_VERSION = "v2.1"
|
78 |
+
|
79 |
+
|
80 |
+
class LeRobotDatasetMetadata:
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
repo_id: str,
|
84 |
+
root: str | Path | None = None,
|
85 |
+
revision: str | None = None,
|
86 |
+
force_cache_sync: bool = False,
|
87 |
+
):
|
88 |
+
self.repo_id = repo_id
|
89 |
+
self.revision = revision if revision else CODEBASE_VERSION
|
90 |
+
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
91 |
+
|
92 |
+
try:
|
93 |
+
if force_cache_sync:
|
94 |
+
raise FileNotFoundError
|
95 |
+
self.load_metadata()
|
96 |
+
except (FileNotFoundError, NotADirectoryError):
|
97 |
+
if is_valid_version(self.revision):
|
98 |
+
self.revision = get_safe_version(self.repo_id, self.revision)
|
99 |
+
|
100 |
+
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
101 |
+
self.pull_from_repo(allow_patterns="meta/")
|
102 |
+
self.load_metadata()
|
103 |
+
|
104 |
+
def load_metadata(self):
|
105 |
+
self.info = load_info(self.root)
|
106 |
+
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
107 |
+
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
108 |
+
self.episodes = load_episodes(self.root)
|
109 |
+
if self._version < packaging.version.parse("v2.1"):
|
110 |
+
self.stats = load_stats(self.root)
|
111 |
+
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
112 |
+
else:
|
113 |
+
self.episodes_stats = load_episodes_stats(self.root)
|
114 |
+
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
115 |
+
|
116 |
+
def pull_from_repo(
|
117 |
+
self,
|
118 |
+
allow_patterns: list[str] | str | None = None,
|
119 |
+
ignore_patterns: list[str] | str | None = None,
|
120 |
+
) -> None:
|
121 |
+
snapshot_download(
|
122 |
+
self.repo_id,
|
123 |
+
repo_type="dataset",
|
124 |
+
revision=self.revision,
|
125 |
+
local_dir=self.root,
|
126 |
+
allow_patterns=allow_patterns,
|
127 |
+
ignore_patterns=ignore_patterns,
|
128 |
+
)
|
129 |
+
|
130 |
+
@property
|
131 |
+
def _version(self) -> packaging.version.Version:
|
132 |
+
"""Codebase version used to create this dataset."""
|
133 |
+
return packaging.version.parse(self.info["codebase_version"])
|
134 |
+
|
135 |
+
def get_data_file_path(self, ep_index: int) -> Path:
|
136 |
+
ep_chunk = self.get_episode_chunk(ep_index)
|
137 |
+
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
138 |
+
return Path(fpath)
|
139 |
+
|
140 |
+
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
141 |
+
ep_chunk = self.get_episode_chunk(ep_index)
|
142 |
+
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
143 |
+
return Path(fpath)
|
144 |
+
|
145 |
+
def get_episode_chunk(self, ep_index: int) -> int:
|
146 |
+
return ep_index // self.chunks_size
|
147 |
+
|
148 |
+
@property
|
149 |
+
def data_path(self) -> str:
|
150 |
+
"""Formattable string for the parquet files."""
|
151 |
+
return self.info["data_path"]
|
152 |
+
|
153 |
+
@property
|
154 |
+
def video_path(self) -> str | None:
|
155 |
+
"""Formattable string for the video files."""
|
156 |
+
return self.info["video_path"]
|
157 |
+
|
158 |
+
@property
|
159 |
+
def robot_type(self) -> str | None:
|
160 |
+
"""Robot type used in recording this dataset."""
|
161 |
+
return self.info["robot_type"]
|
162 |
+
|
163 |
+
@property
|
164 |
+
def fps(self) -> int:
|
165 |
+
"""Frames per second used during data collection."""
|
166 |
+
return self.info["fps"]
|
167 |
+
|
168 |
+
@property
|
169 |
+
def features(self) -> dict[str, dict]:
|
170 |
+
"""All features contained in the dataset."""
|
171 |
+
return self.info["features"]
|
172 |
+
|
173 |
+
@property
|
174 |
+
def image_keys(self) -> list[str]:
|
175 |
+
"""Keys to access visual modalities stored as images."""
|
176 |
+
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
177 |
+
|
178 |
+
@property
|
179 |
+
def video_keys(self) -> list[str]:
|
180 |
+
"""Keys to access visual modalities stored as videos."""
|
181 |
+
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
182 |
+
|
183 |
+
@property
|
184 |
+
def camera_keys(self) -> list[str]:
|
185 |
+
"""Keys to access visual modalities (regardless of their storage method)."""
|
186 |
+
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
187 |
+
|
188 |
+
@property
|
189 |
+
def names(self) -> dict[str, list | dict]:
|
190 |
+
"""Names of the various dimensions of vector modalities."""
|
191 |
+
return {key: ft["names"] for key, ft in self.features.items()}
|
192 |
+
|
193 |
+
@property
|
194 |
+
def shapes(self) -> dict:
|
195 |
+
"""Shapes for the different features."""
|
196 |
+
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
197 |
+
|
198 |
+
@property
|
199 |
+
def total_episodes(self) -> int:
|
200 |
+
"""Total number of episodes available."""
|
201 |
+
return self.info["total_episodes"]
|
202 |
+
|
203 |
+
@property
|
204 |
+
def total_frames(self) -> int:
|
205 |
+
"""Total number of frames saved in this dataset."""
|
206 |
+
return self.info["total_frames"]
|
207 |
+
|
208 |
+
@property
|
209 |
+
def total_tasks(self) -> int:
|
210 |
+
"""Total number of different tasks performed in this dataset."""
|
211 |
+
return self.info["total_tasks"]
|
212 |
+
|
213 |
+
@property
|
214 |
+
def total_chunks(self) -> int:
|
215 |
+
"""Total number of chunks (groups of episodes)."""
|
216 |
+
return self.info["total_chunks"]
|
217 |
+
|
218 |
+
@property
|
219 |
+
def chunks_size(self) -> int:
|
220 |
+
"""Max number of episodes per chunk."""
|
221 |
+
return self.info["chunks_size"]
|
222 |
+
|
223 |
+
def get_task_index(self, task: str) -> int | None:
|
224 |
+
"""
|
225 |
+
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
226 |
+
otherwise return None.
|
227 |
+
"""
|
228 |
+
return self.task_to_task_index.get(task, None)
|
229 |
+
|
230 |
+
def add_task(self, task: str):
|
231 |
+
"""
|
232 |
+
Given a task in natural language, add it to the dictionary of tasks.
|
233 |
+
"""
|
234 |
+
if task in self.task_to_task_index:
|
235 |
+
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
236 |
+
|
237 |
+
task_index = self.info["total_tasks"]
|
238 |
+
self.task_to_task_index[task] = task_index
|
239 |
+
self.tasks[task_index] = task
|
240 |
+
self.info["total_tasks"] += 1
|
241 |
+
|
242 |
+
task_dict = {
|
243 |
+
"task_index": task_index,
|
244 |
+
"task": task,
|
245 |
+
}
|
246 |
+
append_jsonlines(task_dict, self.root / TASKS_PATH)
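As a minimal, file-free sketch of the bookkeeping done by get_task_index/add_task above (the task strings are made up):

task_to_task_index, tasks = {}, {}
for task in ["pick the cube", "place the cube", "pick the cube"]:
    if task not in task_to_task_index:      # get_task_index would return None here
        task_index = len(tasks)             # same role as info["total_tasks"]
        task_to_task_index[task] = task_index
        tasks[task_index] = task
print(task_to_task_index)  # {'pick the cube': 0, 'place the cube': 1}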
|
247 |
+
|
248 |
+
def save_episode(
|
249 |
+
self,
|
250 |
+
episode_index: int,
|
251 |
+
episode_length: int,
|
252 |
+
episode_tasks: list[str],
|
253 |
+
episode_stats: dict[str, dict],
|
254 |
+
) -> None:
|
255 |
+
self.info["total_episodes"] += 1
|
256 |
+
self.info["total_frames"] += episode_length
|
257 |
+
|
258 |
+
chunk = self.get_episode_chunk(episode_index)
|
259 |
+
if chunk >= self.total_chunks:
|
260 |
+
self.info["total_chunks"] += 1
|
261 |
+
|
262 |
+
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
263 |
+
self.info["total_videos"] += len(self.video_keys)
|
264 |
+
if len(self.video_keys) > 0:
|
265 |
+
self.update_video_info()
|
266 |
+
|
267 |
+
write_info(self.info, self.root)
|
268 |
+
|
269 |
+
episode_dict = {
|
270 |
+
"episode_index": episode_index,
|
271 |
+
"tasks": episode_tasks,
|
272 |
+
"length": episode_length,
|
273 |
+
}
|
274 |
+
self.episodes[episode_index] = episode_dict
|
275 |
+
write_episode(episode_dict, self.root)
|
276 |
+
|
277 |
+
self.episodes_stats[episode_index] = episode_stats
|
278 |
+
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
279 |
+
write_episode_stats(episode_index, episode_stats, self.root)
|
280 |
+
|
281 |
+
def update_video_info(self) -> None:
|
282 |
+
"""
|
283 |
+
Warning: this function writes info from the first episode's videos, implicitly assuming that all videos have
|
284 |
+
been encoded the same way. Also, this means it assumes the first episode exists.
|
285 |
+
"""
|
286 |
+
for key in self.video_keys:
|
287 |
+
if not self.features[key].get("info", None):
|
288 |
+
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
289 |
+
self.info["features"][key]["info"] = get_video_info(video_path)
|
290 |
+
|
291 |
+
def __repr__(self):
|
292 |
+
feature_keys = list(self.features)
|
293 |
+
return (
|
294 |
+
f"{self.__class__.__name__}({{\n"
|
295 |
+
f" Repository ID: '{self.repo_id}',\n"
|
296 |
+
f" Total episodes: '{self.total_episodes}',\n"
|
297 |
+
f" Total frames: '{self.total_frames}',\n"
|
298 |
+
f" Features: '{feature_keys}',\n"
|
299 |
+
"})',\n"
|
300 |
+
)
|
301 |
+
|
302 |
+
@classmethod
|
303 |
+
def create(
|
304 |
+
cls,
|
305 |
+
repo_id: str,
|
306 |
+
fps: int,
|
307 |
+
root: str | Path | None = None,
|
308 |
+
robot: Robot | None = None,
|
309 |
+
robot_type: str | None = None,
|
310 |
+
features: dict | None = None,
|
311 |
+
use_videos: bool = True,
|
312 |
+
) -> "LeRobotDatasetMetadata":
|
313 |
+
"""Creates metadata for a LeRobotDataset."""
|
314 |
+
obj = cls.__new__(cls)
|
315 |
+
obj.repo_id = repo_id
|
316 |
+
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
317 |
+
|
318 |
+
obj.root.mkdir(parents=True, exist_ok=False)
|
319 |
+
|
320 |
+
if robot is not None:
|
321 |
+
features = get_features_from_robot(robot, use_videos)
|
322 |
+
robot_type = robot.robot_type
|
323 |
+
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
324 |
+
logging.warning(
|
325 |
+
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
326 |
+
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
327 |
+
)
|
328 |
+
elif features is None:
|
329 |
+
raise ValueError(
|
330 |
+
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
# TODO(aliberts, rcadene): implement sanity check for features
|
334 |
+
features = {**features, **DEFAULT_FEATURES}
|
335 |
+
|
336 |
+
# check if none of the features contains a "/" in their names,
|
337 |
+
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
338 |
+
for key in features:
|
339 |
+
if "/" in key:
|
340 |
+
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
341 |
+
|
342 |
+
features = {**features, **DEFAULT_FEATURES}
|
343 |
+
|
344 |
+
obj.tasks, obj.task_to_task_index = {}, {}
|
345 |
+
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
346 |
+
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
347 |
+
if len(obj.video_keys) > 0 and not use_videos:
|
348 |
+
raise ValueError("The provided features contain video keys, which is incompatible with 'use_videos=False'.")
|
349 |
+
write_json(obj.info, obj.root / INFO_PATH)
|
350 |
+
obj.revision = None
|
351 |
+
return obj
|
352 |
+
|
353 |
+
|
354 |
+
class LeRobotDataset(torch.utils.data.Dataset):
|
355 |
+
def __init__(
|
356 |
+
self,
|
357 |
+
repo_id: str,
|
358 |
+
root: str | Path | None = None,
|
359 |
+
episodes: list[int] | None = None,
|
360 |
+
image_transforms: Callable | None = None,
|
361 |
+
delta_timestamps: dict[list[float]] | None = None,
|
362 |
+
tolerance_s: float = 1e-4,
|
363 |
+
revision: str | None = None,
|
364 |
+
force_cache_sync: bool = False,
|
365 |
+
download_videos: bool = True,
|
366 |
+
video_backend: str | None = None,
|
367 |
+
):
|
368 |
+
"""
|
369 |
+
2 modes are available for instantiating this class, depending on 2 different use cases:
|
370 |
+
|
371 |
+
1. Your dataset already exists:
|
372 |
+
- On your local disk in the 'root' folder. This is typically the case when you recorded your
|
373 |
+
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
|
374 |
+
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
|
375 |
+
internet connection).
|
376 |
+
|
377 |
+
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
|
378 |
+
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
|
379 |
+
the dataset from that address and load it, provided your dataset is compliant with
|
380 |
+
codebase_version v2.0. If your dataset has been created before this new format, you will be
|
381 |
+
prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at
|
382 |
+
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
|
383 |
+
|
384 |
+
|
385 |
+
2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
|
386 |
+
LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or porting an
|
387 |
+
existing dataset to the LeRobotDataset format.
|
388 |
+
|
389 |
+
|
390 |
+
In terms of files, LeRobotDataset encapsulates 3 main things:
|
391 |
+
- metadata:
|
392 |
+
- info contains various information about the dataset like shapes, keys, fps etc.
|
393 |
+
- stats stores the dataset statistics of the different modalities for normalization
|
394 |
+
- tasks contains the prompts for each task of the dataset, which can be used for
|
395 |
+
task-conditioned training.
|
396 |
+
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
397 |
+
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
398 |
+
|
399 |
+
A typical LeRobotDataset looks like this from its root path:
|
400 |
+
.
|
401 |
+
├── data
|
402 |
+
│ ├── chunk-000
|
403 |
+
│ │ ├── episode_000000.parquet
|
404 |
+
│ │ ├── episode_000001.parquet
|
405 |
+
│ │ ├── episode_000002.parquet
|
406 |
+
│ │ └── ...
|
407 |
+
│ ├── chunk-001
|
408 |
+
│ │ ├── episode_001000.parquet
|
409 |
+
│ │ ├── episode_001001.parquet
|
410 |
+
│ │ ├── episode_001002.parquet
|
411 |
+
│ │ └── ...
|
412 |
+
│ └── ...
|
413 |
+
├── meta
|
414 |
+
│ ├── episodes.jsonl
|
415 |
+
│ ├── info.json
|
416 |
+
│ ├── stats.json
|
417 |
+
│ └── tasks.jsonl
|
418 |
+
└── videos
|
419 |
+
├── chunk-000
|
420 |
+
│ ├── observation.images.laptop
|
421 |
+
│ │ ├── episode_000000.mp4
|
422 |
+
│ │ ├── episode_000001.mp4
|
423 |
+
│ │ ├── episode_000002.mp4
|
424 |
+
│ │ └── ...
|
425 |
+
│ ├── observation.images.phone
|
426 |
+
│ │ ├── episode_000000.mp4
|
427 |
+
│ │ ├── episode_000001.mp4
|
428 |
+
│ │ ├── episode_000002.mp4
|
429 |
+
│ │ └── ...
|
430 |
+
├── chunk-001
|
431 |
+
└── ...
|
432 |
+
|
433 |
+
Note that this file-based structure is designed to be as versatile as possible. The files are split by
|
434 |
+
episodes, which allows more granular control over which episodes one wants to use and download. The
|
435 |
+
structure of the dataset is entirely described in the info.json file, which can be easily downloaded
|
436 |
+
or viewed directly on the hub before downloading any actual data. The types of files used are very
|
437 |
+
simple and do not need complex tools to be read: only .parquet, .json and .mp4 files are used (plus .md
|
438 |
+
for the README).
|
439 |
+
|
440 |
+
Args:
|
441 |
+
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
442 |
+
will be stored under root/repo_id.
|
443 |
+
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
444 |
+
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
445 |
+
'~/.cache/huggingface/lerobot'.
|
446 |
+
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
447 |
+
their episode_index in this list. Defaults to None.
|
448 |
+
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
449 |
+
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
|
450 |
+
from videos or images). Defaults to None.
|
451 |
+
delta_timestamps (dict[list[float]] | None, optional): Mapping from data keys to lists of relative timestamps (in seconds) used to also load past/future frames for each sample. Defaults to None.
|
452 |
+
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
|
453 |
+
sync with the fps value. It is used at the init of the dataset to make sure that each
|
454 |
+
timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
|
455 |
+
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
456 |
+
multiples of 1/fps. Defaults to 1e-4.
|
457 |
+
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
458 |
+
commit hash. Defaults to current codebase version tag.
|
459 |
+
force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and files
|
460 |
+
are already present in the local cache, this will be faster. However, files loaded might not
|
461 |
+
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
|
462 |
+
False.
|
463 |
+
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
464 |
+
video files are already present on local disk, they won't be downloaded again. Defaults to
|
465 |
+
True.
|
466 |
+
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to 'torchcodec' when available on the platform; otherwise, defaults to 'pyav'.
|
467 |
+
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader', another Torchvision decoder.
|
468 |
+
"""
|
469 |
+
super().__init__()
|
470 |
+
self.repo_id = repo_id
|
471 |
+
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
472 |
+
self.image_transforms = image_transforms
|
473 |
+
self.delta_timestamps = delta_timestamps
|
474 |
+
self.episodes = episodes
|
475 |
+
self.tolerance_s = tolerance_s
|
476 |
+
self.revision = revision if revision else CODEBASE_VERSION
|
477 |
+
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
478 |
+
self.delta_indices = None
|
479 |
+
|
480 |
+
# Unused attributes
|
481 |
+
self.image_writer = None
|
482 |
+
self.episode_buffer = None
|
483 |
+
|
484 |
+
self.root.mkdir(exist_ok=True, parents=True)
|
485 |
+
|
486 |
+
# Load metadata
|
487 |
+
self.meta = LeRobotDatasetMetadata(
|
488 |
+
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
489 |
+
)
|
490 |
+
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
491 |
+
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
492 |
+
self.stats = aggregate_stats(episodes_stats)
|
493 |
+
|
494 |
+
# Load actual data
|
495 |
+
try:
|
496 |
+
if force_cache_sync:
|
497 |
+
raise FileNotFoundError
|
498 |
+
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
499 |
+
self.hf_dataset = self.load_hf_dataset()
|
500 |
+
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
501 |
+
self.revision = get_safe_version(self.repo_id, self.revision)
|
502 |
+
self.download_episodes(download_videos)
|
503 |
+
self.hf_dataset = self.load_hf_dataset()
|
504 |
+
|
505 |
+
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
506 |
+
|
507 |
+
# Check timestamps
|
508 |
+
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
509 |
+
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
510 |
+
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
511 |
+
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
512 |
+
|
513 |
+
# Setup delta_indices
|
514 |
+
if self.delta_timestamps is not None:
|
515 |
+
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
516 |
+
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
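A hedged usage sketch of the two modes described in the docstring; 'lerobot/pusht' is only used as an example of an existing hub dataset and 'my_user/my_dataset' is a placeholder.

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Mode 1: load an existing dataset (downloaded from the hub if not cached locally).
ds = LeRobotDataset("lerobot/pusht", episodes=[0, 1, 2])
frame = ds[0]  # dict of torch tensors, plus the "task" string

# Mode 2: create an empty dataset for recording (features shown elsewhere are illustrative).
# new_ds = LeRobotDataset.create("my_user/my_dataset", fps=30, features={...})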
|
517 |
+
|
518 |
+
def push_to_hub(
|
519 |
+
self,
|
520 |
+
branch: str | None = None,
|
521 |
+
tags: list | None = None,
|
522 |
+
license: str | None = "apache-2.0",
|
523 |
+
tag_version: bool = True,
|
524 |
+
push_videos: bool = True,
|
525 |
+
private: bool = False,
|
526 |
+
allow_patterns: list[str] | str | None = None,
|
527 |
+
upload_large_folder: bool = False,
|
528 |
+
**card_kwargs,
|
529 |
+
) -> None:
|
530 |
+
ignore_patterns = ["images/"]
|
531 |
+
if not push_videos:
|
532 |
+
ignore_patterns.append("videos/")
|
533 |
+
|
534 |
+
hub_api = HfApi()
|
535 |
+
hub_api.create_repo(
|
536 |
+
repo_id=self.repo_id,
|
537 |
+
private=private,
|
538 |
+
repo_type="dataset",
|
539 |
+
exist_ok=True,
|
540 |
+
)
|
541 |
+
if branch:
|
542 |
+
hub_api.create_branch(
|
543 |
+
repo_id=self.repo_id,
|
544 |
+
branch=branch,
|
545 |
+
revision=self.revision,
|
546 |
+
repo_type="dataset",
|
547 |
+
exist_ok=True,
|
548 |
+
)
|
549 |
+
|
550 |
+
upload_kwargs = {
|
551 |
+
"repo_id": self.repo_id,
|
552 |
+
"folder_path": self.root,
|
553 |
+
"repo_type": "dataset",
|
554 |
+
"revision": branch,
|
555 |
+
"allow_patterns": allow_patterns,
|
556 |
+
"ignore_patterns": ignore_patterns,
|
557 |
+
}
|
558 |
+
if upload_large_folder:
|
559 |
+
hub_api.upload_large_folder(**upload_kwargs)
|
560 |
+
else:
|
561 |
+
hub_api.upload_folder(**upload_kwargs)
|
562 |
+
|
563 |
+
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
564 |
+
card = create_lerobot_dataset_card(
|
565 |
+
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
566 |
+
)
|
567 |
+
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
568 |
+
|
569 |
+
if tag_version:
|
570 |
+
with contextlib.suppress(RevisionNotFoundError):
|
571 |
+
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
572 |
+
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
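For instance, a previously loaded or recorded dataset `ds` could be uploaded as below; the tags are arbitrary examples.

ds.push_to_hub(tags=["robotics", "lerobot"], private=False)
# For very large datasets, the resumable large-folder upload can be used instead:
# ds.push_to_hub(upload_large_folder=True)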
|
573 |
+
|
574 |
+
def pull_from_repo(
|
575 |
+
self,
|
576 |
+
allow_patterns: list[str] | str | None = None,
|
577 |
+
ignore_patterns: list[str] | str | None = None,
|
578 |
+
) -> None:
|
579 |
+
snapshot_download(
|
580 |
+
self.repo_id,
|
581 |
+
repo_type="dataset",
|
582 |
+
revision=self.revision,
|
583 |
+
local_dir=self.root,
|
584 |
+
allow_patterns=allow_patterns,
|
585 |
+
ignore_patterns=ignore_patterns,
|
586 |
+
)
|
587 |
+
|
588 |
+
def download_episodes(self, download_videos: bool = True) -> None:
|
589 |
+
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
590 |
+
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
591 |
+
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
592 |
+
in 'local_dir', they won't be downloaded again.
|
593 |
+
"""
|
594 |
+
# TODO(rcadene, aliberts): implement faster transfer
|
595 |
+
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
596 |
+
files = None
|
597 |
+
ignore_patterns = None if download_videos else "videos/"
|
598 |
+
if self.episodes is not None:
|
599 |
+
files = self.get_episodes_file_paths()
|
600 |
+
|
601 |
+
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
602 |
+
|
603 |
+
def get_episodes_file_paths(self) -> list[str]:
|
604 |
+
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
|
605 |
+
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
606 |
+
if len(self.meta.video_keys) > 0:
|
607 |
+
video_files = [
|
608 |
+
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
609 |
+
for vid_key in self.meta.video_keys
|
610 |
+
for ep_idx in episodes
|
611 |
+
]
|
612 |
+
fpaths += video_files
|
613 |
+
|
614 |
+
return fpaths
|
615 |
+
|
616 |
+
def load_hf_dataset(self) -> datasets.Dataset:
|
617 |
+
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
618 |
+
if self.episodes is None:
|
619 |
+
path = str(self.root / "data")
|
620 |
+
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
621 |
+
else:
|
622 |
+
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
623 |
+
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
624 |
+
|
625 |
+
# TODO(aliberts): hf_dataset.set_format("torch")
|
626 |
+
hf_dataset.set_transform(hf_transform_to_torch)
|
627 |
+
return hf_dataset
|
628 |
+
|
629 |
+
def create_hf_dataset(self) -> datasets.Dataset:
|
630 |
+
features = get_hf_features_from_features(self.features)
|
631 |
+
ft_dict = {col: [] for col in features}
|
632 |
+
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
633 |
+
|
634 |
+
# TODO(aliberts): hf_dataset.set_format("torch")
|
635 |
+
hf_dataset.set_transform(hf_transform_to_torch)
|
636 |
+
return hf_dataset
|
637 |
+
|
638 |
+
@property
|
639 |
+
def fps(self) -> int:
|
640 |
+
"""Frames per second used during data collection."""
|
641 |
+
return self.meta.fps
|
642 |
+
|
643 |
+
@property
|
644 |
+
def num_frames(self) -> int:
|
645 |
+
"""Number of frames in selected episodes."""
|
646 |
+
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
647 |
+
|
648 |
+
@property
|
649 |
+
def num_episodes(self) -> int:
|
650 |
+
"""Number of episodes selected."""
|
651 |
+
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
652 |
+
|
653 |
+
@property
|
654 |
+
def features(self) -> dict[str, dict]:
|
655 |
+
return self.meta.features
|
656 |
+
|
657 |
+
@property
|
658 |
+
def hf_features(self) -> datasets.Features:
|
659 |
+
"""Features of the hf_dataset."""
|
660 |
+
if self.hf_dataset is not None:
|
661 |
+
return self.hf_dataset.features
|
662 |
+
else:
|
663 |
+
return get_hf_features_from_features(self.features)
|
664 |
+
|
665 |
+
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
666 |
+
ep_start = self.episode_data_index["from"][ep_idx]
|
667 |
+
ep_end = self.episode_data_index["to"][ep_idx]
|
668 |
+
query_indices = {
|
669 |
+
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
670 |
+
for key, delta_idx in self.delta_indices.items()
|
671 |
+
}
|
672 |
+
padding = { # Pad values outside of current episode range
|
673 |
+
f"{key}_is_pad": torch.BoolTensor(
|
674 |
+
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
675 |
+
)
|
676 |
+
for key, delta_idx in self.delta_indices.items()
|
677 |
+
}
|
678 |
+
return query_indices, padding
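A standalone numeric sketch of the clamping/padding logic above; the fps and delta_timestamps values are made up.

fps = 10
delta_timestamps = {"action": [-0.1, 0.0, 0.1]}  # seconds relative to the current frame
delta_indices = {k: [round(d * fps) for d in v] for k, v in delta_timestamps.items()}
ep_start, ep_end, idx = 0, 50, 0                 # first frame of a 50-frame episode
query = [max(ep_start, min(ep_end - 1, idx + d)) for d in delta_indices["action"]]
is_pad = [(idx + d < ep_start) or (idx + d >= ep_end) for d in delta_indices["action"]]
print(query, is_pad)  # [0, 0, 1] [True, False, False]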
|
679 |
+
|
680 |
+
def _get_query_timestamps(
|
681 |
+
self,
|
682 |
+
current_ts: float,
|
683 |
+
query_indices: dict[str, list[int]] | None = None,
|
684 |
+
) -> dict[str, list[float]]:
|
685 |
+
query_timestamps = {}
|
686 |
+
for key in self.meta.video_keys:
|
687 |
+
if query_indices is not None and key in query_indices:
|
688 |
+
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
689 |
+
query_timestamps[key] = torch.stack(timestamps).tolist()
|
690 |
+
else:
|
691 |
+
query_timestamps[key] = [current_ts]
|
692 |
+
|
693 |
+
return query_timestamps
|
694 |
+
|
695 |
+
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
696 |
+
return {
|
697 |
+
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
698 |
+
for key, q_idx in query_indices.items()
|
699 |
+
if key not in self.meta.video_keys
|
700 |
+
}
|
701 |
+
|
702 |
+
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
703 |
+
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
704 |
+
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
705 |
+
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
706 |
+
the main process and a subprocess fails to access it.
|
707 |
+
"""
|
708 |
+
item = {}
|
709 |
+
for vid_key, query_ts in query_timestamps.items():
|
710 |
+
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
711 |
+
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
712 |
+
item[vid_key] = frames.squeeze(0)
|
713 |
+
|
714 |
+
return item
|
715 |
+
|
716 |
+
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
717 |
+
for key, val in padding.items():
|
718 |
+
item[key] = torch.BoolTensor(val)
|
719 |
+
return item
|
720 |
+
|
721 |
+
def __len__(self):
|
722 |
+
return self.num_frames
|
723 |
+
|
724 |
+
def __getitem__(self, idx) -> dict:
|
725 |
+
item = self.hf_dataset[idx]
|
726 |
+
ep_idx = item["episode_index"].item()
|
727 |
+
|
728 |
+
query_indices = None
|
729 |
+
if self.delta_indices is not None:
|
730 |
+
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
731 |
+
query_result = self._query_hf_dataset(query_indices)
|
732 |
+
item = {**item, **padding}
|
733 |
+
for key, val in query_result.items():
|
734 |
+
item[key] = val
|
735 |
+
|
736 |
+
if len(self.meta.video_keys) > 0:
|
737 |
+
current_ts = item["timestamp"].item()
|
738 |
+
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
739 |
+
video_frames = self._query_videos(query_timestamps, ep_idx)
|
740 |
+
item = {**video_frames, **item}
|
741 |
+
|
742 |
+
if self.image_transforms is not None:
|
743 |
+
image_keys = self.meta.camera_keys
|
744 |
+
for cam in image_keys:
|
745 |
+
item[cam] = self.image_transforms(item[cam])
|
746 |
+
|
747 |
+
# Add task as a string
|
748 |
+
task_idx = item["task_index"].item()
|
749 |
+
item["task"] = self.meta.tasks[task_idx]
|
750 |
+
|
751 |
+
return item
|
752 |
+
|
753 |
+
def __repr__(self):
|
754 |
+
feature_keys = list(self.features)
|
755 |
+
return (
|
756 |
+
f"{self.__class__.__name__}({{\n"
|
757 |
+
f" Repository ID: '{self.repo_id}',\n"
|
758 |
+
f" Number of selected episodes: '{self.num_episodes}',\n"
|
759 |
+
f" Number of selected samples: '{self.num_frames}',\n"
|
760 |
+
f" Features: '{feature_keys}',\n"
|
761 |
+
"})',\n"
|
762 |
+
)
|
763 |
+
|
764 |
+
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
765 |
+
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
766 |
+
ep_buffer = {}
|
767 |
+
# size and task are special cases that are not in self.features
|
768 |
+
ep_buffer["size"] = 0
|
769 |
+
ep_buffer["task"] = []
|
770 |
+
for key in self.features:
|
771 |
+
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
772 |
+
return ep_buffer
|
773 |
+
|
774 |
+
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
775 |
+
fpath = DEFAULT_IMAGE_PATH.format(
|
776 |
+
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
777 |
+
)
|
778 |
+
return self.root / fpath
|
779 |
+
|
780 |
+
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
781 |
+
if self.image_writer is None:
|
782 |
+
if isinstance(image, torch.Tensor):
|
783 |
+
image = image.cpu().numpy()
|
784 |
+
write_image(image, fpath)
|
785 |
+
else:
|
786 |
+
self.image_writer.save_image(image=image, fpath=fpath)
|
787 |
+
|
788 |
+
def add_frame(self, frame: dict) -> None:
|
789 |
+
"""
|
790 |
+
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
791 |
+
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
792 |
+
then needs to be called.
|
793 |
+
"""
|
794 |
+
# Convert torch to numpy if needed
|
795 |
+
for name in frame:
|
796 |
+
if isinstance(frame[name], torch.Tensor):
|
797 |
+
frame[name] = frame[name].numpy()
|
798 |
+
|
799 |
+
validate_frame(frame, self.features)
|
800 |
+
|
801 |
+
if self.episode_buffer is None:
|
802 |
+
self.episode_buffer = self.create_episode_buffer()
|
803 |
+
|
804 |
+
# Automatically add frame_index and timestamp to episode buffer
|
805 |
+
frame_index = self.episode_buffer["size"]
|
806 |
+
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
807 |
+
self.episode_buffer["frame_index"].append(frame_index)
|
808 |
+
self.episode_buffer["timestamp"].append(timestamp)
|
809 |
+
|
810 |
+
# Add frame features to episode_buffer
|
811 |
+
for key in frame:
|
812 |
+
if key == "task":
|
813 |
+
# Note: we associate the task in natural language to its task index during `save_episode`
|
814 |
+
self.episode_buffer["task"].append(frame["task"])
|
815 |
+
continue
|
816 |
+
|
817 |
+
if key not in self.features:
|
818 |
+
raise ValueError(
|
819 |
+
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
820 |
+
)
|
821 |
+
|
822 |
+
if self.features[key]["dtype"] in ["image", "video"]:
|
823 |
+
img_path = self._get_image_file_path(
|
824 |
+
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
825 |
+
)
|
826 |
+
if frame_index == 0:
|
827 |
+
img_path.parent.mkdir(parents=True, exist_ok=True)
|
828 |
+
self._save_image(frame[key], img_path)
|
829 |
+
self.episode_buffer[key].append(str(img_path))
|
830 |
+
else:
|
831 |
+
self.episode_buffer[key].append(frame[key])
|
832 |
+
|
833 |
+
self.episode_buffer["size"] += 1
|
834 |
+
|
835 |
+
def save_episode(self, episode_data: dict | None = None) -> None:
|
836 |
+
"""
|
837 |
+
This will save to disk the current episode in self.episode_buffer.
|
838 |
+
|
839 |
+
Args:
|
840 |
+
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
841 |
+
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
842 |
+
None.
|
843 |
+
"""
|
844 |
+
if not episode_data:
|
845 |
+
episode_buffer = self.episode_buffer
|
846 |
+
|
847 |
+
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
848 |
+
|
849 |
+
# size and task are special cases that won't be added to hf_dataset
|
850 |
+
episode_length = episode_buffer.pop("size")
|
851 |
+
tasks = episode_buffer.pop("task")
|
852 |
+
episode_tasks = list(set(tasks))
|
853 |
+
episode_index = episode_buffer["episode_index"]
|
854 |
+
|
855 |
+
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
856 |
+
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
857 |
+
|
858 |
+
# Add new tasks to the tasks dictionary
|
859 |
+
for task in episode_tasks:
|
860 |
+
task_index = self.meta.get_task_index(task)
|
861 |
+
if task_index is None:
|
862 |
+
self.meta.add_task(task)
|
863 |
+
|
864 |
+
# Given tasks in natural language, find their corresponding task indices
|
865 |
+
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
866 |
+
|
867 |
+
for key, ft in self.features.items():
|
868 |
+
# index, episode_index, task_index are already processed above, and image and video
|
869 |
+
# are processed separately by storing image path and frame info as meta data
|
870 |
+
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
871 |
+
continue
|
872 |
+
episode_buffer[key] = np.stack(episode_buffer[key])
|
873 |
+
|
874 |
+
self._wait_image_writer()
|
875 |
+
self._save_episode_table(episode_buffer, episode_index)
|
876 |
+
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
877 |
+
|
878 |
+
if len(self.meta.video_keys) > 0:
|
879 |
+
video_paths = self.encode_episode_videos(episode_index)
|
880 |
+
for key in self.meta.video_keys:
|
881 |
+
episode_buffer[key] = video_paths[key]
|
882 |
+
|
883 |
+
# `meta.save_episode` must be executed after encoding the videos
|
884 |
+
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
885 |
+
|
886 |
+
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
887 |
+
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
888 |
+
check_timestamps_sync(
|
889 |
+
episode_buffer["timestamp"],
|
890 |
+
episode_buffer["episode_index"],
|
891 |
+
ep_data_index_np,
|
892 |
+
self.fps,
|
893 |
+
self.tolerance_s,
|
894 |
+
)
|
895 |
+
|
896 |
+
video_files = list(self.root.rglob("*.mp4"))
|
897 |
+
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
898 |
+
|
899 |
+
parquet_files = list(self.root.rglob("*.parquet"))
|
900 |
+
assert len(parquet_files) == self.num_episodes
|
901 |
+
|
902 |
+
# delete images
|
903 |
+
img_dir = self.root / "images"
|
904 |
+
if img_dir.is_dir():
|
905 |
+
shutil.rmtree(self.root / "images")
|
906 |
+
|
907 |
+
if not episode_data: # Reset the buffer
|
908 |
+
self.episode_buffer = self.create_episode_buffer()
|
909 |
+
|
910 |
+
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
911 |
+
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
912 |
+
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
913 |
+
ep_dataset = embed_images(ep_dataset)
|
914 |
+
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
915 |
+
self.hf_dataset.set_transform(hf_transform_to_torch)
|
916 |
+
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
917 |
+
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
918 |
+
ep_dataset.to_parquet(ep_data_path)
|
919 |
+
|
920 |
+
def clear_episode_buffer(self) -> None:
|
921 |
+
episode_index = self.episode_buffer["episode_index"]
|
922 |
+
if self.image_writer is not None:
|
923 |
+
for cam_key in self.meta.camera_keys:
|
924 |
+
img_dir = self._get_image_file_path(
|
925 |
+
episode_index=episode_index, image_key=cam_key, frame_index=0
|
926 |
+
).parent
|
927 |
+
if img_dir.is_dir():
|
928 |
+
shutil.rmtree(img_dir)
|
929 |
+
|
930 |
+
# Reset the buffer
|
931 |
+
self.episode_buffer = self.create_episode_buffer()
|
932 |
+
|
933 |
+
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
934 |
+
if isinstance(self.image_writer, AsyncImageWriter):
|
935 |
+
logging.warning(
|
936 |
+
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
|
937 |
+
)
|
938 |
+
|
939 |
+
self.image_writer = AsyncImageWriter(
|
940 |
+
num_processes=num_processes,
|
941 |
+
num_threads=num_threads,
|
942 |
+
)
|
943 |
+
|
944 |
+
def stop_image_writer(self) -> None:
|
945 |
+
"""
|
946 |
+
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
947 |
+
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
948 |
+
"""
|
949 |
+
if self.image_writer is not None:
|
950 |
+
self.image_writer.stop()
|
951 |
+
self.image_writer = None
|
952 |
+
|
953 |
+
def _wait_image_writer(self) -> None:
|
954 |
+
"""Wait for asynchronous image writer to finish."""
|
955 |
+
if self.image_writer is not None:
|
956 |
+
self.image_writer.wait_until_done()
|
957 |
+
|
958 |
+
def encode_videos(self) -> None:
|
959 |
+
"""
|
960 |
+
Use ffmpeg to convert frames stored as png into mp4 videos.
|
961 |
+
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
|
962 |
+
since video encoding with ffmpeg is already using multithreading.
|
963 |
+
"""
|
964 |
+
for ep_idx in range(self.meta.total_episodes):
|
965 |
+
self.encode_episode_videos(ep_idx)
|
966 |
+
|
967 |
+
def encode_episode_videos(self, episode_index: int) -> dict:
|
968 |
+
"""
|
969 |
+
Use ffmpeg to convert frames stored as png into mp4 videos.
|
970 |
+
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
|
971 |
+
since video encoding with ffmpeg is already using multithreading.
|
972 |
+
"""
|
973 |
+
video_paths = {}
|
974 |
+
for key in self.meta.video_keys:
|
975 |
+
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
976 |
+
video_paths[key] = str(video_path)
|
977 |
+
if video_path.is_file():
|
978 |
+
# Skip if video is already encoded. Could be the case when resuming data recording.
|
979 |
+
continue
|
980 |
+
img_dir = self._get_image_file_path(
|
981 |
+
episode_index=episode_index, image_key=key, frame_index=0
|
982 |
+
).parent
|
983 |
+
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
984 |
+
|
985 |
+
return video_paths
|
986 |
+
|
987 |
+
@classmethod
|
988 |
+
def create(
|
989 |
+
cls,
|
990 |
+
repo_id: str,
|
991 |
+
fps: int,
|
992 |
+
root: str | Path | None = None,
|
993 |
+
robot: Robot | None = None,
|
994 |
+
robot_type: str | None = None,
|
995 |
+
features: dict | None = None,
|
996 |
+
use_videos: bool = True,
|
997 |
+
tolerance_s: float = 1e-4,
|
998 |
+
image_writer_processes: int = 0,
|
999 |
+
image_writer_threads: int = 0,
|
1000 |
+
video_backend: str | None = None,
|
1001 |
+
) -> "LeRobotDataset":
|
1002 |
+
"""Create a LeRobot Dataset from scratch in order to record data."""
|
1003 |
+
obj = cls.__new__(cls)
|
1004 |
+
obj.meta = LeRobotDatasetMetadata.create(
|
1005 |
+
repo_id=repo_id,
|
1006 |
+
fps=fps,
|
1007 |
+
root=root,
|
1008 |
+
robot=robot,
|
1009 |
+
robot_type=robot_type,
|
1010 |
+
features=features,
|
1011 |
+
use_videos=use_videos,
|
1012 |
+
)
|
1013 |
+
obj.repo_id = obj.meta.repo_id
|
1014 |
+
obj.root = obj.meta.root
|
1015 |
+
obj.revision = None
|
1016 |
+
obj.tolerance_s = tolerance_s
|
1017 |
+
obj.image_writer = None
|
1018 |
+
|
1019 |
+
if image_writer_processes or image_writer_threads:
|
1020 |
+
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
1021 |
+
|
1022 |
+
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
1023 |
+
obj.episode_buffer = obj.create_episode_buffer()
|
1024 |
+
|
1025 |
+
obj.episodes = None
|
1026 |
+
obj.hf_dataset = obj.create_hf_dataset()
|
1027 |
+
obj.image_transforms = None
|
1028 |
+
obj.delta_timestamps = None
|
1029 |
+
obj.delta_indices = None
|
1030 |
+
obj.episode_data_index = None
|
1031 |
+
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
1032 |
+
return obj
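Putting the pieces together, a minimal recording sketch could look like this; the repo id, feature definitions and task string are all assumptions, not a real robot configuration.

import numpy as np

features = {
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
}
dataset = LeRobotDataset.create("my_user/my_dataset", fps=30, features=features)
for _ in range(50):  # one 50-frame episode of dummy data
    dataset.add_frame({
        "observation.state": np.zeros(6, dtype=np.float32),
        "action": np.zeros(6, dtype=np.float32),
        "task": "pick the cube",
    })
dataset.save_episode()
# dataset.push_to_hub()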
|
1033 |
+
|
1034 |
+
|
1035 |
+
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
1036 |
+
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
1037 |
+
|
1038 |
+
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
1039 |
+
structure of `LeRobotDataset`.
|
1040 |
+
"""
|
1041 |
+
|
1042 |
+
def __init__(
|
1043 |
+
self,
|
1044 |
+
repo_ids: list[str],
|
1045 |
+
root: str | Path | None = None,
|
1046 |
+
episodes: dict | None = None,
|
1047 |
+
image_transforms: Callable | None = None,
|
1048 |
+
delta_timestamps: dict[list[float]] | None = None,
|
1049 |
+
tolerances_s: dict | None = None,
|
1050 |
+
download_videos: bool = True,
|
1051 |
+
video_backend: str | None = None,
|
1052 |
+
):
|
1053 |
+
super().__init__()
|
1054 |
+
self.repo_ids = repo_ids
|
1055 |
+
self.root = Path(root) if root else HF_LEROBOT_HOME
|
1056 |
+
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
1057 |
+
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
1058 |
+
# are handled by this class.
|
1059 |
+
self._datasets = [
|
1060 |
+
LeRobotDataset(
|
1061 |
+
repo_id,
|
1062 |
+
root=self.root / repo_id,
|
1063 |
+
episodes=episodes[repo_id] if episodes else None,
|
1064 |
+
image_transforms=image_transforms,
|
1065 |
+
delta_timestamps=delta_timestamps,
|
1066 |
+
tolerance_s=self.tolerances_s[repo_id],
|
1067 |
+
download_videos=download_videos,
|
1068 |
+
video_backend=video_backend,
|
1069 |
+
)
|
1070 |
+
for repo_id in repo_ids
|
1071 |
+
]
|
1072 |
+
|
1073 |
+
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
1074 |
+
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
1075 |
+
# to use PyTorch's default DataLoader collate function.
|
1076 |
+
self.disabled_features = set()
|
1077 |
+
intersection_features = set(self._datasets[0].features)
|
1078 |
+
for ds in self._datasets:
|
1079 |
+
intersection_features.intersection_update(ds.features)
|
1080 |
+
if len(intersection_features) == 0:
|
1081 |
+
raise RuntimeError(
|
1082 |
+
"Multiple datasets were provided but they had no keys common to all of them. "
|
1083 |
+
"The multi-dataset functionality currently only keeps common keys."
|
1084 |
+
)
|
1085 |
+
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
1086 |
+
extra_keys = set(ds.features).difference(intersection_features)
|
1087 |
+
logging.warning(
|
1088 |
+
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
1089 |
+
"other datasets."
|
1090 |
+
)
|
1091 |
+
self.disabled_features.update(extra_keys)
|
1092 |
+
|
1093 |
+
self.image_transforms = image_transforms
|
1094 |
+
self.delta_timestamps = delta_timestamps
|
1095 |
+
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
1096 |
+
# with multiple robots of different ranges. Instead we should have one normalization
|
1097 |
+
# per robot.
|
1098 |
+
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
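A hedged example of combining two datasets; the repo ids are placeholders.

multi_ds = MultiLeRobotDataset(["my_user/dataset_a", "my_user/dataset_b"], download_videos=False)
sample = multi_ds[0]
print(sample["dataset_index"])  # index of the sub-dataset this sample came from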
|
1099 |
+
|
1100 |
+
@property
|
1101 |
+
def repo_id_to_index(self):
|
1102 |
+
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
1103 |
+
|
1104 |
+
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
1105 |
+
"""
|
1106 |
+
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
1107 |
+
|
1108 |
+
@property
|
1109 |
+
def repo_index_to_id(self):
|
1110 |
+
"""Return the inverse mapping if repo_id_to_index."""
|
1111 |
+
return {v: k for k, v in self.repo_id_to_index.items()}
|
1112 |
+
|
1113 |
+
@property
|
1114 |
+
def fps(self) -> int:
|
1115 |
+
"""Frames per second used during data collection.
|
1116 |
+
|
1117 |
+
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
1118 |
+
"""
|
1119 |
+
return self._datasets[0].meta.info["fps"]
|
1120 |
+
|
1121 |
+
@property
|
1122 |
+
def video(self) -> bool:
|
1123 |
+
"""Returns True if this dataset loads video frames from mp4 files.
|
1124 |
+
|
1125 |
+
Returns False if it only loads images from png files.
|
1126 |
+
|
1127 |
+
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
1128 |
+
"""
|
1129 |
+
return self._datasets[0].meta.info.get("video", False)
|
1130 |
+
|
1131 |
+
@property
|
1132 |
+
def features(self) -> datasets.Features:
|
1133 |
+
features = {}
|
1134 |
+
for dataset in self._datasets:
|
1135 |
+
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
1136 |
+
return features
|
1137 |
+
|
1138 |
+
@property
|
1139 |
+
def camera_keys(self) -> list[str]:
|
1140 |
+
"""Keys to access image and video stream from cameras."""
|
1141 |
+
keys = []
|
1142 |
+
for key, feats in self.features.items():
|
1143 |
+
if isinstance(feats, (datasets.Image, VideoFrame)):
|
1144 |
+
keys.append(key)
|
1145 |
+
return keys
|
1146 |
+
|
1147 |
+
@property
|
1148 |
+
def video_frame_keys(self) -> list[str]:
|
1149 |
+
"""Keys to access video frames that requires to be decoded into images.
|
1150 |
+
|
1151 |
+
Note: It is empty if the dataset contains images only,
|
1152 |
+
or equal to `self.cameras` if the dataset contains videos only,
|
1153 |
+
or can even be a subset of `self.cameras` in the case of a mixed image/video dataset.
|
1154 |
+
"""
|
1155 |
+
video_frame_keys = []
|
1156 |
+
for key, feats in self.features.items():
|
1157 |
+
if isinstance(feats, VideoFrame):
|
1158 |
+
video_frame_keys.append(key)
|
1159 |
+
return video_frame_keys
|
1160 |
+
|
1161 |
+
@property
|
1162 |
+
def num_frames(self) -> int:
|
1163 |
+
"""Number of samples/frames."""
|
1164 |
+
return sum(d.num_frames for d in self._datasets)
|
1165 |
+
|
1166 |
+
@property
|
1167 |
+
def num_episodes(self) -> int:
|
1168 |
+
"""Number of episodes."""
|
1169 |
+
return sum(d.num_episodes for d in self._datasets)
|
1170 |
+
|
1171 |
+
@property
|
1172 |
+
def tolerance_s(self) -> float:
|
1173 |
+
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
1174 |
+
are not close enough to the requested frames. It is only used when `delta_timestamps`
|
1175 |
+
is provided or when loading video frames from mp4 files.
|
1176 |
+
"""
|
1177 |
+
# 1e-4 to account for possible numerical error
|
1178 |
+
return 1 / self.fps - 1e-4
|
1179 |
+
|
1180 |
+
def __len__(self):
|
1181 |
+
return self.num_frames
|
1182 |
+
|
1183 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
1184 |
+
if idx >= len(self):
|
1185 |
+
raise IndexError(f"Index {idx} out of bounds.")
|
1186 |
+
# Determine which dataset to get an item from based on the index.
|
1187 |
+
start_idx = 0
|
1188 |
+
dataset_idx = 0
|
1189 |
+
for dataset in self._datasets:
|
1190 |
+
if idx >= start_idx + dataset.num_frames:
|
1191 |
+
start_idx += dataset.num_frames
|
1192 |
+
dataset_idx += 1
|
1193 |
+
continue
|
1194 |
+
break
|
1195 |
+
else:
|
1196 |
+
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
1197 |
+
item = self._datasets[dataset_idx][idx - start_idx]
|
1198 |
+
item["dataset_index"] = torch.tensor(dataset_idx)
|
1199 |
+
for data_key in self.disabled_features:
|
1200 |
+
if data_key in item:
|
1201 |
+
del item[data_key]
|
1202 |
+
|
1203 |
+
return item
|
1204 |
+
|
1205 |
+
def __repr__(self):
|
1206 |
+
return (
|
1207 |
+
f"{self.__class__.__name__}(\n"
|
1208 |
+
f" Repository IDs: '{self.repo_ids}',\n"
|
1209 |
+
f" Number of Samples: {self.num_frames},\n"
|
1210 |
+
f" Number of Episodes: {self.num_episodes},\n"
|
1211 |
+
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
1212 |
+
f" Recorded Frames per Second: {self.fps},\n"
|
1213 |
+
f" Camera Keys: {self.camera_keys},\n"
|
1214 |
+
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
1215 |
+
f" Transformations: {self.image_transforms},\n"
|
1216 |
+
f")"
|
1217 |
+
)
|
lerobot/common/datasets/online_buffer.py
ADDED
@@ -0,0 +1,384 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""An online buffer for the online training loop in train.py
|
17 |
+
|
18 |
+
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
|
19 |
+
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
|
20 |
+
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
|
21 |
+
supports in-place slicing and mutation which is very handy for a dynamic buffer.
|
22 |
+
"""
|
23 |
+
|
24 |
+
import os
|
25 |
+
from pathlib import Path
|
26 |
+
from typing import Any
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
import torch
|
30 |
+
|
31 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
32 |
+
|
33 |
+
|
34 |
+
def _make_memmap_safe(**kwargs) -> np.memmap:
|
35 |
+
"""Make a numpy memmap with checks on available disk space first.
|
36 |
+
|
37 |
+
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
|
38 |
+
|
39 |
+
For information on dtypes:
|
40 |
+
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
|
41 |
+
"""
|
42 |
+
if kwargs["mode"].startswith("w"):
|
43 |
+
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
|
44 |
+
stats = os.statvfs(Path(kwargs["filename"]).parent)
|
45 |
+
available_space = stats.f_bavail * stats.f_frsize # bytes
|
46 |
+
if required_space >= available_space * 0.8:
|
47 |
+
raise RuntimeError(
|
48 |
+
f"You're about to take up {required_space} of {available_space} bytes available."
|
49 |
+
)
|
50 |
+
return np.memmap(**kwargs)
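A standalone sketch of what such a memmap call boils down to; the path and shape are arbitrary.

import numpy as np

buf = np.memmap("/tmp/example_buffer.dat", dtype=np.dtype("float32"), mode="w+", shape=(1000, 7))
buf[0] = np.arange(7, dtype=np.float32)  # behaves like a regular numpy array
buf.flush()                              # persists the data to disk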
|
51 |
+
|
52 |
+
|
53 |
+
class OnlineBuffer(torch.utils.data.Dataset):
|
54 |
+
"""FIFO data buffer for the online training loop in train.py.
|
55 |
+
|
56 |
+
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
|
57 |
+
loop in the same way that a LeRobotDataset would be used.
|
58 |
+
|
59 |
+
The underlying data structure will have data inserted in a circular fashion. Always insert after the
|
60 |
+
last index, and when you reach the end, wrap around to the start.
|
61 |
+
|
62 |
+
The data is stored in a numpy memmap.
|
63 |
+
"""
|
64 |
+
|
65 |
+
NEXT_INDEX_KEY = "_next_index"
|
66 |
+
OCCUPANCY_MASK_KEY = "_occupancy_mask"
|
67 |
+
INDEX_KEY = "index"
|
68 |
+
FRAME_INDEX_KEY = "frame_index"
|
69 |
+
EPISODE_INDEX_KEY = "episode_index"
|
70 |
+
TIMESTAMP_KEY = "timestamp"
|
71 |
+
IS_PAD_POSTFIX = "_is_pad"
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
write_dir: str | Path,
|
76 |
+
data_spec: dict[str, Any] | None,
|
77 |
+
buffer_capacity: int | None,
|
78 |
+
fps: float | None = None,
|
79 |
+
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
|
80 |
+
):
|
81 |
+
"""
|
82 |
+
The online buffer can be created from scratch, or you can load an existing online buffer by passing
|
83 |
+
a `write_dir` associated with an existing buffer.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
|
87 |
+
Note that if the files already exist, they are opened in read-write mode (used for training
|
88 |
+
resumption).
|
89 |
+
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
|
90 |
+
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
|
91 |
+
but note that "index", "frame_index" and "episode_index" are already accounted for by this
|
92 |
+
class, so you don't need to include them.
|
93 |
+
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
|
94 |
+
system's available disk space when choosing this.
|
95 |
+
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
|
96 |
+
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
|
97 |
+
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
|
98 |
+
converted to dict[str, np.ndarray] for optimization purposes.
|
99 |
+
|
100 |
+
"""
|
101 |
+
self.set_delta_timestamps(delta_timestamps)
|
102 |
+
self._fps = fps
|
103 |
+
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough to
|
104 |
+
# the requested frames. It is only used when `delta_timestamps` is provided.
|
105 |
+
# minus 1e-4 to account for possible numerical error
|
106 |
+
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
|
107 |
+
self._buffer_capacity = buffer_capacity
|
108 |
+
data_spec = self._make_data_spec(data_spec, buffer_capacity)
|
109 |
+
Path(write_dir).mkdir(parents=True, exist_ok=True)
|
110 |
+
self._data = {}
|
111 |
+
for k, v in data_spec.items():
|
112 |
+
self._data[k] = _make_memmap_safe(
|
113 |
+
filename=Path(write_dir) / k,
|
114 |
+
dtype=v["dtype"] if v is not None else None,
|
115 |
+
mode="r+" if (Path(write_dir) / k).exists() else "w+",
|
116 |
+
shape=tuple(v["shape"]) if v is not None else None,
|
117 |
+
)
|
118 |
+
|
119 |
+
@property
|
120 |
+
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
|
121 |
+
return self._delta_timestamps
|
122 |
+
|
123 |
+
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
|
124 |
+
"""Set delta_timestamps converting the values to numpy arrays.
|
125 |
+
|
126 |
+
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
|
127 |
+
need to be converted into numpy arrays.
|
128 |
+
"""
|
129 |
+
if value is not None:
|
130 |
+
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
|
131 |
+
else:
|
132 |
+
self._delta_timestamps = None
|
133 |
+
|
134 |
+
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
135 |
+
"""Makes the data spec for np.memmap."""
|
136 |
+
if any(k.startswith("_") for k in data_spec):
|
137 |
+
raise ValueError(
|
138 |
+
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
|
139 |
+
)
|
140 |
+
preset_keys = {
|
141 |
+
OnlineBuffer.INDEX_KEY,
|
142 |
+
OnlineBuffer.FRAME_INDEX_KEY,
|
143 |
+
OnlineBuffer.EPISODE_INDEX_KEY,
|
144 |
+
OnlineBuffer.TIMESTAMP_KEY,
|
145 |
+
}
|
146 |
+
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
|
147 |
+
raise ValueError(
|
148 |
+
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
|
149 |
+
f"The provided data_spec has {intersection}."
|
150 |
+
)
|
151 |
+
complete_data_spec = {
|
152 |
+
# _next_index will be a pointer to the next index that we should start filling from when we add
|
153 |
+
# more data.
|
154 |
+
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
155 |
+
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
156 |
+
# with real data rather than the dummy initialization.
|
157 |
+
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
158 |
+
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
159 |
+
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
160 |
+
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
161 |
+
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
162 |
+
}
|
163 |
+
for k, v in data_spec.items():
|
164 |
+
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
165 |
+
return complete_data_spec
|
166 |
+
|
167 |
+
def add_data(self, data: dict[str, np.ndarray]):
|
168 |
+
"""Add new data to the buffer, which could potentially mean shifting old data out.
|
169 |
+
|
170 |
+
The new data should contain all the frames (in order) of any number of episodes. The indices should
|
171 |
+
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
|
172 |
+
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
|
173 |
+
|
174 |
+
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
|
175 |
+
will be done in place!
|
176 |
+
"""
|
177 |
+
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
|
178 |
+
raise ValueError(f"Missing data keys: {missing_keys}")
|
179 |
+
new_data_length = len(data[self.data_keys[0]])
|
180 |
+
if not all(len(data[k]) == new_data_length for k in self.data_keys):
|
181 |
+
raise ValueError("All data items should have the same length")
|
182 |
+
|
183 |
+
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
|
184 |
+
|
185 |
+
# Sanity check to make sure that the new data indices start from 0.
|
186 |
+
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
|
187 |
+
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
188 |
+
|
189 |
+
# Shift the incoming indices if necessary.
|
190 |
+
if self.num_frames > 0:
|
191 |
+
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
192 |
+
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
193 |
+
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
194 |
+
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
195 |
+
|
196 |
+
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
|
197 |
+
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
|
198 |
+
for k in self.data_keys:
|
199 |
+
if n_surplus == 0:
|
200 |
+
slc = slice(next_index, next_index + new_data_length)
|
201 |
+
self._data[k][slc] = data[k]
|
202 |
+
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
|
203 |
+
else:
|
204 |
+
self._data[k][next_index:] = data[k][:-n_surplus]
|
205 |
+
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
|
206 |
+
self._data[k][:n_surplus] = data[k][-n_surplus:]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][:n_surplus] = True
|
207 |
+
if n_surplus == 0:
|
208 |
+
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
|
209 |
+
else:
|
210 |
+
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
|
211 |
+
|
212 |
+
@property
|
213 |
+
def data_keys(self) -> list[str]:
|
214 |
+
keys = set(self._data)
|
215 |
+
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
|
216 |
+
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
|
217 |
+
return sorted(keys)
|
218 |
+
|
219 |
+
@property
|
220 |
+
def fps(self) -> float | None:
|
221 |
+
return self._fps
|
222 |
+
|
223 |
+
@property
|
224 |
+
def num_episodes(self) -> int:
|
225 |
+
return len(
|
226 |
+
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
227 |
+
)
|
228 |
+
|
229 |
+
@property
|
230 |
+
def num_frames(self) -> int:
|
231 |
+
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
232 |
+
|
233 |
+
def __len__(self):
|
234 |
+
return self.num_frames
|
235 |
+
|
236 |
+
def _item_to_tensors(self, item: dict) -> dict:
|
237 |
+
item_ = {}
|
238 |
+
for k, v in item.items():
|
239 |
+
if isinstance(v, torch.Tensor):
|
240 |
+
item_[k] = v
|
241 |
+
elif isinstance(v, np.ndarray):
|
242 |
+
item_[k] = torch.from_numpy(v)
|
243 |
+
else:
|
244 |
+
item_[k] = torch.tensor(v)
|
245 |
+
return item_
|
246 |
+
|
247 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
248 |
+
if idx >= len(self) or idx < -len(self):
|
249 |
+
raise IndexError
|
250 |
+
|
251 |
+
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
|
252 |
+
|
253 |
+
if self.delta_timestamps is None:
|
254 |
+
return self._item_to_tensors(item)
|
255 |
+
|
256 |
+
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
|
257 |
+
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
|
258 |
+
episode_data_indices = np.where(
|
259 |
+
np.bitwise_and(
|
260 |
+
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
|
261 |
+
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
262 |
+
)
|
263 |
+
)[0]
|
264 |
+
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
265 |
+
|
266 |
+
for data_key in self.delta_timestamps:
|
267 |
+
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
268 |
+
# Get timestamps used as query to retrieve data of previous/future frames.
|
269 |
+
query_ts = current_ts + self.delta_timestamps[data_key]
|
270 |
+
|
271 |
+
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
|
272 |
+
# the episode.
|
273 |
+
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
|
274 |
+
argmin_ = np.argmin(dist, axis=1)
|
275 |
+
min_ = dist[np.arange(dist.shape[0]), argmin_]
|
276 |
+
|
277 |
+
is_pad = min_ > self.tolerance_s
|
278 |
+
|
279 |
+
# Check violated query timestamps are all outside the episode range.
|
280 |
+
assert (
|
281 |
+
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
282 |
+
).all(), (
|
283 |
+
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
284 |
+
") inside the episode range."
|
285 |
+
)
|
286 |
+
|
287 |
+
# Load frames for this data key.
|
288 |
+
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
|
289 |
+
|
290 |
+
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
|
291 |
+
|
292 |
+
return self._item_to_tensors(item)
|
293 |
+
|
294 |
+
def get_data_by_key(self, key: str) -> torch.Tensor:
|
295 |
+
"""Returns all data for a given data key as a Tensor."""
|
296 |
+
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
297 |
+
|
298 |
+
|
299 |
+
def compute_sampler_weights(
|
300 |
+
offline_dataset: LeRobotDataset,
|
301 |
+
offline_drop_n_last_frames: int = 0,
|
302 |
+
online_dataset: OnlineBuffer | None = None,
|
303 |
+
online_sampling_ratio: float | None = None,
|
304 |
+
online_drop_n_last_frames: int = 0,
|
305 |
+
) -> torch.Tensor:
|
306 |
+
"""Compute the sampling weights for the online training dataloader in train.py.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
offline_dataset: The LeRobotDataset used for offline pre-training.
|
310 |
+
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
|
311 |
+
online_dataset: The OnlineBuffer used in online training.
|
312 |
+
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
|
313 |
+
online dataset is provided, this value must also be provided.
|
314 |
+
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
|
315 |
+
dataset.
|
316 |
+
Returns:
|
317 |
+
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
|
318 |
+
|
319 |
+
Notes to maintainers:
|
320 |
+
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
|
321 |
+
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
|
322 |
+
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
|
323 |
+
is the ability to turn shuffling off.
|
324 |
+
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
325 |
+
included here to avoid adding complexity.
|
326 |
+
"""
|
327 |
+
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
328 |
+
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
329 |
+
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
330 |
+
raise ValueError(
|
331 |
+
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
332 |
+
)
|
333 |
+
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
334 |
+
|
335 |
+
weights = []
|
336 |
+
|
337 |
+
if len(offline_dataset) > 0:
|
338 |
+
offline_data_mask_indices = []
|
339 |
+
for start_index, end_index in zip(
|
340 |
+
offline_dataset.episode_data_index["from"],
|
341 |
+
offline_dataset.episode_data_index["to"],
|
342 |
+
strict=True,
|
343 |
+
):
|
344 |
+
offline_data_mask_indices.extend(
|
345 |
+
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
|
346 |
+
)
|
347 |
+
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
348 |
+
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
349 |
+
weights.append(
|
350 |
+
torch.full(
|
351 |
+
size=(len(offline_dataset),),
|
352 |
+
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
|
353 |
+
)
|
354 |
+
* offline_data_mask
|
355 |
+
)
|
356 |
+
|
357 |
+
if online_dataset is not None and len(online_dataset) > 0:
|
358 |
+
online_data_mask_indices = []
|
359 |
+
episode_indices = online_dataset.get_data_by_key("episode_index")
|
360 |
+
for episode_idx in torch.unique(episode_indices):
|
361 |
+
where_episode = torch.where(episode_indices == episode_idx)
|
362 |
+
start_index = where_episode[0][0]
|
363 |
+
end_index = where_episode[0][-1] + 1
|
364 |
+
online_data_mask_indices.extend(
|
365 |
+
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
|
366 |
+
)
|
367 |
+
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
|
368 |
+
online_data_mask[torch.tensor(online_data_mask_indices)] = True
|
369 |
+
weights.append(
|
370 |
+
torch.full(
|
371 |
+
size=(len(online_dataset),),
|
372 |
+
fill_value=online_sampling_ratio / online_data_mask.sum(),
|
373 |
+
)
|
374 |
+
* online_data_mask
|
375 |
+
)
|
376 |
+
|
377 |
+
weights = torch.cat(weights)
|
378 |
+
|
379 |
+
if weights.sum() == 0:
|
380 |
+
weights += 1 / len(weights)
|
381 |
+
else:
|
382 |
+
weights /= weights.sum()
|
383 |
+
|
384 |
+
return weights
|
lerobot/common/datasets/push_dataset_to_hub/utils.py
ADDED
@@ -0,0 +1,131 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict

import datasets
import numpy
import PIL
import torch

from lerobot.common.datasets.video_utils import encode_video_frames


def concatenate_episodes(ep_dicts):
    data_dict = {}

    keys = ep_dicts[0].keys()
    for key in keys:
        if torch.is_tensor(ep_dicts[0][key][0]):
            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
        else:
            if key not in data_dict:
                data_dict[key] = []
            for ep_dict in ep_dicts:
                for x in ep_dict[key]:
                    data_dict[key].append(x)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    def save_image(img_array, i, out_dir):
        img = PIL.Image.fromarray(img_array)
        img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)

    num_images = len(imgs_array)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]


def get_default_encoding() -> dict:
    """Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
    signature = inspect.signature(encode_video_frames)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
    }


def check_repo_id(repo_id: str) -> None:
    if len(repo_id.split("/")) != 2:
        raise ValueError(
            f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
            (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
        )


# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
    """
    Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

    Parameters:
    - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.

    Returns:
    - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
        - "from": A tensor containing the starting index of each episode.
        - "to": A tensor containing the ending index of each episode.
    """
    episode_data_index = {"from": [], "to": []}

    current_episode = None
    """
    The episode_index is a list of integers, each representing the episode index of the corresponding example.
    For instance, the following is a valid episode_index:
      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]

    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
        {
            "from": [0, 3, 7],
            "to": [3, 7, 12]
        }
    """
    if len(hf_dataset) == 0:
        episode_data_index = {
            "from": torch.tensor([]),
            "to": torch.tensor([]),
        }
        return episode_data_index
    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
        if episode_idx != current_episode:
            # We encountered a new episode, so we append its starting location to the "from" list
            episode_data_index["from"].append(idx)
            # If this is not the first episode, we append the ending location of the previous episode to the "to" list
            if current_episode is not None:
                episode_data_index["to"].append(idx)
            # Let's keep track of the current episode index
            current_episode = episode_idx
        else:
            # We are still in the same episode, so there is nothing for us to do here
            pass
    # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
    episode_data_index["to"].append(idx + 1)

    for k in ["from", "to"]:
        episode_data_index[k] = torch.tensor(episode_data_index[k])

    return episode_data_index
lerobot/common/datasets/sampler.py
ADDED
@@ -0,0 +1,61 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union

import torch


class EpisodeAwareSampler:
    def __init__(
        self,
        episode_data_index: dict,
        episode_indices_to_use: Union[list, None] = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
        shuffle: bool = False,
    ):
        """Sampler that optionally incorporates episode boundary information.

        Args:
            episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
            episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
                Assumes that episodes are indexed from 0 to N-1.
            drop_n_first_frames: Number of frames to drop from the start of each episode.
            drop_n_last_frames: Number of frames to drop from the end of each episode.
            shuffle: Whether to shuffle the indices.
        """
        indices = []
        for episode_idx, (start_index, end_index) in enumerate(
            zip(episode_data_index["from"], episode_data_index["to"], strict=True)
        ):
            if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
                indices.extend(
                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
                )

        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            for i in torch.randperm(len(self.indices)):
                yield self.indices[i]
        else:
            for i in self.indices:
                yield i

    def __len__(self) -> int:
        return len(self.indices)
lerobot/common/datasets/transforms.py
ADDED
@@ -0,0 +1,249 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from dataclasses import dataclass, field
from typing import Any, Callable, Sequence

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F  # noqa: N812


class RandomSubsetApply(Transform):
    """Apply a random subset of N transformations from a list of transformations.

    Args:
        transforms: list of transformations.
        p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
            If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
            have the same probability.
        n_subset: number of transformations to apply. If ``None``, all transforms are applied.
            Must be in [1, len(transforms)].
        random_order: apply transformations in a random order.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: list[float] | None = None,
        n_subset: int | None = None,
        random_order: bool = False,
    ) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(
                f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
            )

        if n_subset is None:
            n_subset = len(transforms)
        elif not isinstance(n_subset, int):
            raise TypeError("n_subset should be an int or None")
        elif not (1 <= n_subset <= len(transforms)):
            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")

        self.transforms = transforms
        total = sum(p)
        self.p = [prob / total for prob in p]
        self.n_subset = n_subset
        self.random_order = random_order

        self.selected_transforms = None

    def forward(self, *inputs: Any) -> Any:
        needs_unpacking = len(inputs) > 1

        selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
        if not self.random_order:
            selected_indices = selected_indices.sort().values

        self.selected_transforms = [self.transforms[i] for i in selected_indices]

        for transform in self.selected_transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)

        return outputs

    def extra_repr(self) -> str:
        return (
            f"transforms={self.transforms}, "
            f"p={self.p}, "
            f"n_subset={self.n_subset}, "
            f"random_order={self.random_order}"
        )


class SharpnessJitter(Transform):
    """Randomly change the sharpness of an image or video.

    Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
    While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
    SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
    augmentations as a result.

    A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
    by a factor of 2.

    If the input is a :class:`torch.Tensor`,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
            [max(0, 1 - sharpness), 1 + sharpness] or the given
            [min, max]. Should be non negative numbers.
    """

    def __init__(self, sharpness: float | Sequence[float]) -> None:
        super().__init__()
        self.sharpness = self._check_input(sharpness)

    def _check_input(self, sharpness):
        if isinstance(sharpness, (int, float)):
            if sharpness < 0:
                raise ValueError("If sharpness is a single number, it must be non negative.")
            sharpness = [1.0 - sharpness, 1.0 + sharpness]
            sharpness[0] = max(sharpness[0], 0.0)
        elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
            sharpness = [float(v) for v in sharpness]
        else:
            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")

        if not 0.0 <= sharpness[0] <= sharpness[1]:
            raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")

        return float(sharpness[0]), float(sharpness[1])

    def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
        sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
        return {"sharpness_factor": sharpness_factor}

    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        sharpness_factor = params["sharpness_factor"]
        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


@dataclass
class ImageTransformConfig:
    """
    For each transform, the following parameters are available:
      weight: This represents the multinomial probability (with no replacement)
              used for sampling the transform. If the sum of the weights is not 1,
              they will be normalized.
      type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
            custom transform defined here.
      kwargs: Lower & upper bound respectively used for sampling the transform's parameter
              (following uniform distribution) when it's applied.
    """

    weight: float = 1.0
    type: str = "Identity"
    kwargs: dict[str, Any] = field(default_factory=dict)


@dataclass
class ImageTransformsConfig:
    """
    These transforms are all using standard torchvision.transforms.v2
    You can find out how these transformations affect images here:
    https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
    We use a custom RandomSubsetApply container to sample them.
    """

    # Set this flag to `true` to enable transforms during training
    enable: bool = False
    # This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
    # It's an integer in the interval [1, number_of_available_transforms].
    max_num_transforms: int = 3
    # By default, transforms are applied in Torchvision's suggested order (shown below).
    # Set this to True to apply them in a random order.
    random_order: bool = False
    tfs: dict[str, ImageTransformConfig] = field(
        default_factory=lambda: {
            "brightness": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"brightness": (0.8, 1.2)},
            ),
            "contrast": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"contrast": (0.8, 1.2)},
            ),
            "saturation": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"saturation": (0.5, 1.5)},
            ),
            "hue": ImageTransformConfig(
                weight=1.0,
                type="ColorJitter",
                kwargs={"hue": (-0.05, 0.05)},
            ),
            "sharpness": ImageTransformConfig(
                weight=1.0,
                type="SharpnessJitter",
                kwargs={"sharpness": (0.5, 1.5)},
            ),
        }
    )


def make_transform_from_config(cfg: ImageTransformConfig):
    if cfg.type == "Identity":
        return v2.Identity(**cfg.kwargs)
    elif cfg.type == "ColorJitter":
        return v2.ColorJitter(**cfg.kwargs)
    elif cfg.type == "SharpnessJitter":
        return SharpnessJitter(**cfg.kwargs)
    else:
        raise ValueError(f"Transform '{cfg.type}' is not valid.")


class ImageTransforms(Transform):
    """A class to compose image transforms based on configuration."""

    def __init__(self, cfg: ImageTransformsConfig) -> None:
        super().__init__()
        self._cfg = cfg

        self.weights = []
        self.transforms = {}
        for tf_name, tf_cfg in cfg.tfs.items():
            if tf_cfg.weight <= 0.0:
                continue

            self.transforms[tf_name] = make_transform_from_config(tf_cfg)
            self.weights.append(tf_cfg.weight)

        n_subset = min(len(self.transforms), cfg.max_num_transforms)
        if n_subset == 0 or not cfg.enable:
            self.tf = v2.Identity()
        else:
            self.tf = RandomSubsetApply(
                transforms=list(self.transforms.values()),
                p=self.weights,
                n_subset=n_subset,
                random_order=cfg.random_order,
            )

    def forward(self, *inputs: Any) -> Any:
        return self.tf(*inputs)
lerobot/common/datasets/utils.py
ADDED
@@ -0,0 +1,813 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.resources
import json
import logging
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any

import datasets
import jsonlines
import numpy as np
import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms

from lerobot.common.datasets.backward_compatibility import (
    V21_MESSAGE,
    BackwardCompatibilityError,
    ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk

INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"

DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).

## {}

"""

DEFAULT_FEATURES = {
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
    "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
    "index": {"dtype": "int64", "shape": (1,), "names": None},
    "task_index": {"dtype": "int64", "shape": (1,), "names": None},
}


def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

    For example:
    ```
    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
    >>> print(flatten_dict(dct))
    {"a/b": 1, "a/c/d": 2, "e": 3}
    ```
    """
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    outdict = {}
    for key, value in d.items():
        parts = key.split(sep)
        d = outdict
        for part in parts[:-1]:
            if part not in d:
                d[part] = {}
            d = d[part]
        d[parts[-1]] = value
    return outdict
+
|
109 |
+
|
110 |
+
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
111 |
+
split_keys = flattened_key.split(sep)
|
112 |
+
getter = obj[split_keys[0]]
|
113 |
+
if len(split_keys) == 1:
|
114 |
+
return getter
|
115 |
+
|
116 |
+
for key in split_keys[1:]:
|
117 |
+
getter = getter[key]
|
118 |
+
|
119 |
+
return getter
|
120 |
+
|
121 |
+
|
122 |
+
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
123 |
+
serialized_dict = {}
|
124 |
+
for key, value in flatten_dict(stats).items():
|
125 |
+
if isinstance(value, (torch.Tensor, np.ndarray)):
|
126 |
+
serialized_dict[key] = value.tolist()
|
127 |
+
elif isinstance(value, np.generic):
|
128 |
+
serialized_dict[key] = value.item()
|
129 |
+
elif isinstance(value, (int, float)):
|
130 |
+
serialized_dict[key] = value
|
131 |
+
else:
|
132 |
+
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
133 |
+
return unflatten_dict(serialized_dict)
|
134 |
+
|
135 |
+
|
136 |
+
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
137 |
+
# Embed image bytes into the table before saving to parquet
|
138 |
+
format = dataset.format
|
139 |
+
dataset = dataset.with_format("arrow")
|
140 |
+
dataset = dataset.map(embed_table_storage, batched=False)
|
141 |
+
dataset = dataset.with_format(**format)
|
142 |
+
return dataset
|
143 |
+
|
144 |
+
|
145 |
+
def load_json(fpath: Path) -> Any:
|
146 |
+
with open(fpath) as f:
|
147 |
+
return json.load(f)
|
148 |
+
|
149 |
+
|
150 |
+
def write_json(data: dict, fpath: Path) -> None:
|
151 |
+
fpath.parent.mkdir(exist_ok=True, parents=True)
|
152 |
+
with open(fpath, "w") as f:
|
153 |
+
json.dump(data, f, indent=4, ensure_ascii=False)
|
154 |
+
|
155 |
+
|
156 |
+
def load_jsonlines(fpath: Path) -> list[Any]:
|
157 |
+
with jsonlines.open(fpath, "r") as reader:
|
158 |
+
return list(reader)
|
159 |
+
|
160 |
+
|
161 |
+
def write_jsonlines(data: dict, fpath: Path) -> None:
|
162 |
+
fpath.parent.mkdir(exist_ok=True, parents=True)
|
163 |
+
with jsonlines.open(fpath, "w") as writer:
|
164 |
+
writer.write_all(data)
|
165 |
+
|
166 |
+
|
167 |
+
def append_jsonlines(data: dict, fpath: Path) -> None:
|
168 |
+
fpath.parent.mkdir(exist_ok=True, parents=True)
|
169 |
+
with jsonlines.open(fpath, "a") as writer:
|
170 |
+
writer.write(data)
|
171 |
+
|
172 |
+
|
173 |
+
def write_info(info: dict, local_dir: Path):
|
174 |
+
write_json(info, local_dir / INFO_PATH)
|
175 |
+
|
176 |
+
|
177 |
+
def load_info(local_dir: Path) -> dict:
|
178 |
+
info = load_json(local_dir / INFO_PATH)
|
179 |
+
for ft in info["features"].values():
|
180 |
+
ft["shape"] = tuple(ft["shape"])
|
181 |
+
return info
|
182 |
+
|
183 |
+
|
184 |
+
def write_stats(stats: dict, local_dir: Path):
|
185 |
+
serialized_stats = serialize_dict(stats)
|
186 |
+
write_json(serialized_stats, local_dir / STATS_PATH)
|
187 |
+
|
188 |
+
|
189 |
+
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
190 |
+
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
191 |
+
return unflatten_dict(stats)
|
192 |
+
|
193 |
+
|
194 |
+
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
195 |
+
if not (local_dir / STATS_PATH).exists():
|
196 |
+
return None
|
197 |
+
stats = load_json(local_dir / STATS_PATH)
|
198 |
+
return cast_stats_to_numpy(stats)
|
199 |
+
|
200 |
+
|
201 |
+
def write_task(task_index: int, task: dict, local_dir: Path):
|
202 |
+
task_dict = {
|
203 |
+
"task_index": task_index,
|
204 |
+
"task": task,
|
205 |
+
}
|
206 |
+
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
207 |
+
|
208 |
+
|
209 |
+
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
210 |
+
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
211 |
+
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
212 |
+
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
213 |
+
return tasks, task_to_task_index
|
214 |
+
|
215 |
+
|
216 |
+
def write_episode(episode: dict, local_dir: Path):
|
217 |
+
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
218 |
+
|
219 |
+
|
220 |
+
def load_episodes(local_dir: Path) -> dict:
|
221 |
+
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
222 |
+
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
223 |
+
|
224 |
+
|
225 |
+
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
226 |
+
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
227 |
+
# is a dictionary of stats and not an integer.
|
228 |
+
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
229 |
+
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
230 |
+
|
231 |
+
|
232 |
+
def load_episodes_stats(local_dir: Path) -> dict:
|
233 |
+
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
234 |
+
return {
|
235 |
+
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
236 |
+
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
237 |
+
}
|
238 |
+
|
239 |
+
|
240 |
+
def backward_compatible_episodes_stats(
|
241 |
+
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
242 |
+
) -> dict[str, dict[str, np.ndarray]]:
|
243 |
+
return dict.fromkeys(episodes, stats)
|
244 |
+
|
245 |
+
|
246 |
+
def load_image_as_numpy(
|
247 |
+
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
248 |
+
) -> np.ndarray:
|
249 |
+
img = PILImage.open(fpath).convert("RGB")
|
250 |
+
img_array = np.array(img, dtype=dtype)
|
251 |
+
if channel_first: # (H, W, C) -> (C, H, W)
|
252 |
+
img_array = np.transpose(img_array, (2, 0, 1))
|
253 |
+
if np.issubdtype(dtype, np.floating):
|
254 |
+
img_array /= 255.0
|
255 |
+
return img_array
|
256 |
+
|
257 |
+
|
258 |
+
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
259 |
+
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
260 |
+
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
261 |
+
a channel last representation (h w c) of uint8 type, to a torch image representation
|
262 |
+
with channel first (c h w) of float32 type in range [0,1].
|
263 |
+
"""
|
264 |
+
for key in items_dict:
|
265 |
+
first_item = items_dict[key][0]
|
266 |
+
if isinstance(first_item, PILImage.Image):
|
267 |
+
to_tensor = transforms.ToTensor()
|
268 |
+
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
269 |
+
elif first_item is None:
|
270 |
+
pass
|
271 |
+
else:
|
272 |
+
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
273 |
+
return items_dict
|
274 |
+
|
275 |
+
|
276 |
+
def is_valid_version(version: str) -> bool:
|
277 |
+
try:
|
278 |
+
packaging.version.parse(version)
|
279 |
+
return True
|
280 |
+
except packaging.version.InvalidVersion:
|
281 |
+
return False
|
282 |
+
|
283 |
+
|
284 |
+
def check_version_compatibility(
|
285 |
+
repo_id: str,
|
286 |
+
version_to_check: str | packaging.version.Version,
|
287 |
+
current_version: str | packaging.version.Version,
|
288 |
+
enforce_breaking_major: bool = True,
|
289 |
+
) -> None:
|
290 |
+
v_check = (
|
291 |
+
packaging.version.parse(version_to_check)
|
292 |
+
if not isinstance(version_to_check, packaging.version.Version)
|
293 |
+
else version_to_check
|
294 |
+
)
|
295 |
+
v_current = (
|
296 |
+
packaging.version.parse(current_version)
|
297 |
+
if not isinstance(current_version, packaging.version.Version)
|
298 |
+
else current_version
|
299 |
+
)
|
300 |
+
if v_check.major < v_current.major and enforce_breaking_major:
|
301 |
+
raise BackwardCompatibilityError(repo_id, v_check)
|
302 |
+
elif v_check.minor < v_current.minor:
|
303 |
+
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
304 |
+
|
305 |
+
|
306 |
+
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
307 |
+
"""Returns available valid versions (branches and tags) on given repo."""
|
308 |
+
api = HfApi()
|
309 |
+
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
310 |
+
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
311 |
+
repo_versions = []
|
312 |
+
for ref in repo_refs:
|
313 |
+
with contextlib.suppress(packaging.version.InvalidVersion):
|
314 |
+
repo_versions.append(packaging.version.parse(ref))
|
315 |
+
|
316 |
+
return repo_versions
|
317 |
+
|
318 |
+
|
319 |
+
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
320 |
+
"""
|
321 |
+
Returns the version if available on repo or the latest compatible one.
|
322 |
+
Otherwise, will throw a `CompatibilityError`.
|
323 |
+
"""
|
324 |
+
target_version = (
|
325 |
+
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
326 |
+
)
|
327 |
+
hub_versions = get_repo_versions(repo_id)
|
328 |
+
|
329 |
+
if not hub_versions:
|
330 |
+
raise RevisionNotFoundError(
|
331 |
+
f"""Your dataset must be tagged with a codebase version.
|
332 |
+
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
333 |
+
```python
|
334 |
+
from huggingface_hub import HfApi
|
335 |
+
|
336 |
+
hub_api = HfApi()
|
337 |
+
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
338 |
+
```
|
339 |
+
"""
|
340 |
+
)
|
341 |
+
|
342 |
+
if target_version in hub_versions:
|
343 |
+
return f"v{target_version}"
|
344 |
+
|
345 |
+
compatibles = [
|
346 |
+
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
347 |
+
]
|
348 |
+
if compatibles:
|
349 |
+
return_version = max(compatibles)
|
350 |
+
if return_version < target_version:
|
351 |
+
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
352 |
+
return f"v{return_version}"
|
353 |
+
|
354 |
+
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
355 |
+
if lower_major:
|
356 |
+
raise BackwardCompatibilityError(repo_id, max(lower_major))
|
357 |
+
|
358 |
+
upper_versions = [v for v in hub_versions if v > target_version]
|
359 |
+
assert len(upper_versions) > 0
|
360 |
+
raise ForwardCompatibilityError(repo_id, min(upper_versions))
|
361 |
+
|
362 |
+
|
363 |
+
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
364 |
+
hf_features = {}
|
365 |
+
for key, ft in features.items():
|
366 |
+
if ft["dtype"] == "video":
|
367 |
+
continue
|
368 |
+
elif ft["dtype"] == "image":
|
369 |
+
hf_features[key] = datasets.Image()
|
370 |
+
elif ft["shape"] == (1,):
|
371 |
+
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
372 |
+
elif len(ft["shape"]) == 1:
|
373 |
+
hf_features[key] = datasets.Sequence(
|
374 |
+
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
375 |
+
)
|
376 |
+
elif len(ft["shape"]) == 2:
|
377 |
+
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
378 |
+
elif len(ft["shape"]) == 3:
|
379 |
+
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
380 |
+
elif len(ft["shape"]) == 4:
|
381 |
+
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
382 |
+
elif len(ft["shape"]) == 5:
|
383 |
+
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
384 |
+
else:
|
385 |
+
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
386 |
+
|
387 |
+
return datasets.Features(hf_features)
|
388 |
+
|
389 |
+
|
390 |
+
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
391 |
+
camera_ft = {}
|
392 |
+
if robot.cameras:
|
393 |
+
camera_ft = {
|
394 |
+
key: {"dtype": "video" if use_videos else "image", **ft}
|
395 |
+
for key, ft in robot.camera_features.items()
|
396 |
+
}
|
397 |
+
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
398 |
+
|
399 |
+
|
400 |
+
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
401 |
+
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
402 |
+
policy_features = {}
|
403 |
+
for key, ft in features.items():
|
404 |
+
shape = ft["shape"]
|
405 |
+
if ft["dtype"] in ["image", "video"]:
|
406 |
+
type = FeatureType.VISUAL
|
407 |
+
if len(shape) != 3:
|
408 |
+
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
409 |
+
|
410 |
+
names = ft["names"]
|
411 |
+
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
412 |
+
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
413 |
+
shape = (shape[2], shape[0], shape[1])
|
414 |
+
elif key == "observation.environment_state":
|
415 |
+
type = FeatureType.ENV
|
416 |
+
elif key.startswith("observation"):
|
417 |
+
type = FeatureType.STATE
|
418 |
+
elif key == "action":
|
419 |
+
type = FeatureType.ACTION
|
420 |
+
else:
|
421 |
+
continue
|
422 |
+
|
423 |
+
policy_features[key] = PolicyFeature(
|
424 |
+
type=type,
|
425 |
+
shape=shape,
|
426 |
+
)
|
427 |
+
|
428 |
+
return policy_features
|
429 |
+
|
430 |
+
|
431 |
+
def create_empty_dataset_info(
|
432 |
+
codebase_version: str,
|
433 |
+
fps: int,
|
434 |
+
robot_type: str,
|
435 |
+
features: dict,
|
436 |
+
use_videos: bool,
|
437 |
+
) -> dict:
|
438 |
+
return {
|
439 |
+
"codebase_version": codebase_version,
|
440 |
+
"robot_type": robot_type,
|
441 |
+
"total_episodes": 0,
|
442 |
+
"total_frames": 0,
|
443 |
+
"total_tasks": 0,
|
444 |
+
"total_videos": 0,
|
445 |
+
"total_chunks": 0,
|
446 |
+
"chunks_size": DEFAULT_CHUNK_SIZE,
|
447 |
+
"fps": fps,
|
448 |
+
"splits": {},
|
449 |
+
"data_path": DEFAULT_PARQUET_PATH,
|
450 |
+
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
451 |
+
"features": features,
|
452 |
+
}
|
453 |
+
|
454 |
+
|
455 |
+
def get_episode_data_index(
|
456 |
+
episode_dicts: dict[dict], episodes: list[int] | None = None
|
457 |
+
) -> dict[str, torch.Tensor]:
|
458 |
+
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
459 |
+
if episodes is not None:
|
460 |
+
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
461 |
+
|
462 |
+
cumulative_lengths = list(accumulate(episode_lengths.values()))
|
463 |
+
return {
|
464 |
+
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
|
465 |
+
"to": torch.LongTensor(cumulative_lengths),
|
466 |
+
}
|
467 |
+
|
468 |
+
|
469 |
+
def check_timestamps_sync(
|
470 |
+
timestamps: np.ndarray,
|
471 |
+
episode_indices: np.ndarray,
|
472 |
+
episode_data_index: dict[str, np.ndarray],
|
473 |
+
fps: int,
|
474 |
+
tolerance_s: float,
|
475 |
+
raise_value_error: bool = True,
|
476 |
+
) -> bool:
|
477 |
+
"""
|
478 |
+
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
479 |
+
to account for possible numerical error.
|
480 |
+
|
481 |
+
Args:
|
482 |
+
timestamps (np.ndarray): Array of timestamps in seconds.
|
483 |
+
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
484 |
+
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
485 |
+
which identifies indices for the end of each episode.
|
486 |
+
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
487 |
+
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
488 |
+
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
492 |
+
|
493 |
+
Raises:
|
494 |
+
ValueError: If the check fails and `raise_value_error` is True.
|
495 |
+
"""
|
496 |
+
if timestamps.shape != episode_indices.shape:
|
497 |
+
raise ValueError(
|
498 |
+
"timestamps and episode_indices should have the same shape. "
|
499 |
+
f"Found {timestamps.shape=} and {episode_indices.shape=}."
|
500 |
+
)
|
501 |
+
|
502 |
+
# Consecutive differences
|
503 |
+
diffs = np.diff(timestamps)
|
504 |
+
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
|
505 |
+
|
506 |
+
# Mask to ignore differences at the boundaries between episodes
|
507 |
+
mask = np.ones(len(diffs), dtype=bool)
|
508 |
+
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
509 |
+
mask[ignored_diffs] = False
|
510 |
+
filtered_within_tolerance = within_tolerance[mask]
|
511 |
+
|
512 |
+
# Check if all remaining diffs are within tolerance
|
513 |
+
if not np.all(filtered_within_tolerance):
|
514 |
+
# Track original indices before masking
|
515 |
+
original_indices = np.arange(len(diffs))
|
516 |
+
filtered_indices = original_indices[mask]
|
517 |
+
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
|
518 |
+
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
519 |
+
|
520 |
+
outside_tolerances = []
|
521 |
+
for idx in outside_tolerance_indices:
|
522 |
+
entry = {
|
523 |
+
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
524 |
+
"diff": diffs[idx],
|
525 |
+
"episode_index": episode_indices[idx].item()
|
526 |
+
if hasattr(episode_indices[idx], "item")
|
527 |
+
else episode_indices[idx],
|
528 |
+
}
|
529 |
+
outside_tolerances.append(entry)
|
530 |
+
|
531 |
+
if raise_value_error:
|
532 |
+
raise ValueError(
|
533 |
+
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
534 |
+
This might be due to synchronization issues during data collection.
|
535 |
+
\n{pformat(outside_tolerances)}"""
|
536 |
+
)
|
537 |
+
return False
|
538 |
+
|
539 |
+
return True
|
540 |
+
|
541 |
+
|
542 |
+
def check_delta_timestamps(
|
543 |
+
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
544 |
+
) -> bool:
|
545 |
+
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
546 |
+
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
547 |
+
actual timestamps from the dataset.
|
548 |
+
"""
|
549 |
+
outside_tolerance = {}
|
550 |
+
for key, delta_ts in delta_timestamps.items():
|
551 |
+
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
552 |
+
if not all(within_tolerance):
|
553 |
+
outside_tolerance[key] = [
|
554 |
+
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
555 |
+
]
|
556 |
+
|
557 |
+
if len(outside_tolerance) > 0:
|
558 |
+
if raise_value_error:
|
559 |
+
raise ValueError(
|
560 |
+
f"""
|
561 |
+
The following delta_timestamps are found outside of tolerance range.
|
562 |
+
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
563 |
+
their values accordingly.
|
564 |
+
\n{pformat(outside_tolerance)}
|
565 |
+
"""
|
566 |
+
)
|
567 |
+
return False
|
568 |
+
|
569 |
+
return True
|
570 |
+
|
571 |
+
|
572 |
+
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
573 |
+
delta_indices = {}
|
574 |
+
for key, delta_ts in delta_timestamps.items():
|
575 |
+
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
576 |
+
|
577 |
+
return delta_indices
|
578 |
+
|
579 |
+
|
580 |
+
def cycle(iterable):
|
581 |
+
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
582 |
+
|
583 |
+
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
584 |
+
"""
|
585 |
+
iterator = iter(iterable)
|
586 |
+
while True:
|
587 |
+
try:
|
588 |
+
yield next(iterator)
|
589 |
+
except StopIteration:
|
590 |
+
iterator = iter(iterable)
|
591 |
+
|
592 |
+
|
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
    """Create a branch on an existing Hugging Face repo. Delete the branch if it already
    exists before creating it.
    """
    api = HfApi()

    branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
    refs = [branch.ref for branch in branches]
    ref = f"refs/heads/{branch}"
    if ref in refs:
        api.delete_branch(repo_id, repo_type=repo_type, branch=branch)

    api.create_branch(repo_id, repo_type=repo_type, branch=branch)


def create_lerobot_dataset_card(
    tags: list | None = None,
    dataset_info: dict | None = None,
    **kwargs,
) -> DatasetCard:
    """
    Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
    Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
    """
    card_tags = ["LeRobot"]

    if tags:
        card_tags += tags
    if dataset_info:
        dataset_structure = "[meta/info.json](meta/info.json):\n"
        dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
        kwargs = {**kwargs, "dataset_structure": dataset_structure}
    card_data = DatasetCardData(
        license=kwargs.get("license"),
        tags=card_tags,
        task_categories=["robotics"],
        configs=[
            {
                "config_name": "default",
                "data_files": "data/*/*.parquet",
            }
        ],
    )

    card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()

    return DatasetCard.from_template(
        card_data=card_data,
        template_str=card_template,
        **kwargs,
    )

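For reference, a hedged usage sketch of the card helper (not in the committed file); the repo id, tags and info dict are placeholders:

# Hypothetical values for illustration only.
card = create_lerobot_dataset_card(
    tags=["aloha", "tutorial"],
    dataset_info={"codebase_version": "v2.1", "robot_type": "aloha", "fps": 30},
    license="apache-2.0",
)
# Optionally upload the rendered card to a dataset repo you own.
card.push_to_hub("your-user/your-dataset", repo_type="dataset")
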
class IterableNamespace(SimpleNamespace):
    """
    A namespace object that supports both dictionary-like iteration and dot notation access.
    Automatically converts nested dictionaries into IterableNamespaces.

    This class extends SimpleNamespace to provide:
    - Dictionary-style iteration over keys
    - Access to items via both dot notation (obj.key) and brackets (obj["key"])
    - Dictionary-like methods: items(), keys(), values()
    - Recursive conversion of nested dictionaries

    Args:
        dictionary: Optional dictionary to initialize the namespace
        **kwargs: Additional keyword arguments passed to SimpleNamespace

    Examples:
        >>> data = {"name": "Alice", "details": {"age": 25}}
        >>> ns = IterableNamespace(data)
        >>> ns.name
        'Alice'
        >>> ns.details.age
        25
        >>> list(ns.keys())
        ['name', 'details']
        >>> for key, value in ns.items():
        ...     print(f"{key}: {value}")
        name: Alice
        details: IterableNamespace(age=25)
    """

    def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
        super().__init__(**kwargs)
        if dictionary is not None:
            for key, value in dictionary.items():
                if isinstance(value, dict):
                    setattr(self, key, IterableNamespace(value))
                else:
                    setattr(self, key, value)

    def __iter__(self) -> Iterator[str]:
        return iter(vars(self))

    def __getitem__(self, key: str) -> Any:
        return vars(self)[key]

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def keys(self):
        return vars(self).keys()

def validate_frame(frame: dict, features: dict):
    optional_features = {"timestamp"}
    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
    actual_features = set(frame.keys())

    error_message = validate_features_presence(actual_features, expected_features, optional_features)

    if "task" in frame:
        error_message += validate_feature_string("task", frame["task"])

    common_features = actual_features & (expected_features | optional_features)
    for name in common_features - {"task"}:
        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])

    if error_message:
        raise ValueError(error_message)


def validate_features_presence(
    actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
    error_message = ""
    missing_features = expected_features - actual_features
    extra_features = actual_features - (expected_features | optional_features)

    if missing_features or extra_features:
        error_message += "Feature mismatch in `frame` dictionary:\n"
        if missing_features:
            error_message += f"Missing features: {missing_features}\n"
        if extra_features:
            error_message += f"Extra features: {extra_features}\n"

    return error_message


def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
    expected_dtype = feature["dtype"]
    expected_shape = feature["shape"]
    if is_valid_numpy_dtype_string(expected_dtype):
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    elif expected_dtype == "string":
        return validate_feature_string(name, value)
    else:
        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")


def validate_feature_numpy_array(
    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_dtype = value.dtype
        actual_shape = value.shape

        if actual_dtype != np.dtype(expected_dtype):
            error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"

        if actual_shape != expected_shape:
            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
    else:
        error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_shape = value.shape
        c, h, w = expected_shape
        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
    elif isinstance(value, PILImage.Image):
        pass
    else:
        error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_string(name: str, value: str):
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
    return ""


def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    if not buffer_keys == set(features):
        raise ValueError(
            f"Features from `episode_buffer` don't match the ones in `features`."
            f"In episode_buffer not in features: {buffer_keys - set(features)}"
            f"In features not in episode_buffer: {set(features) - buffer_keys}"
        )
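A small sketch of how the validators above fit together (illustrative only, not part of the diff); the feature spec and frame are invented, and shapes are given as tuples here so the exact-shape comparison passes:

import numpy as np

features = {
    "observation.state": {"dtype": "float32", "shape": (6,)},
    "observation.images.cam_high": {"dtype": "video", "shape": (3, 480, 640)},
}
frame = {
    "observation.state": np.zeros(6, dtype=np.float32),
    "observation.images.cam_high": np.zeros((480, 640, 3), dtype=np.uint8),  # channel-last is accepted
    "task": "Pick up the cube.",
}

# Raises ValueError with an aggregated message if names, dtypes or shapes are off.
validate_frame(frame, features)
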
lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
ADDED
@@ -0,0 +1,884 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.

Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""

import traceback
from pathlib import Path
from textwrap import dedent

from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig

LOCAL_DIR = Path("data/")

# spellchecker:off
ALOHA_MOBILE_INFO = {
    "robot_config": AlohaRobotConfig(),
    "license": "mit",
    "url": "https://mobile-aloha.github.io/",
    "paper": "https://arxiv.org/abs/2401.02117",
    "citation_bibtex": dedent(r"""
        @inproceedings{fu2024mobile,
            author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
            title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
            booktitle = {arXiv},
            year = {2024},
        }""").lstrip(),
}
ALOHA_STATIC_INFO = {
    "robot_config": AlohaRobotConfig(),
    "license": "mit",
    "url": "https://tonyzhaozh.github.io/aloha/",
    "paper": "https://arxiv.org/abs/2304.13705",
    "citation_bibtex": dedent(r"""
        @article{Zhao2023LearningFB,
            title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
            author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
            journal={RSS},
            year={2023},
            volume={abs/2304.13705},
            url={https://arxiv.org/abs/2304.13705}
        }""").lstrip(),
}
PUSHT_INFO = {
    "license": "mit",
    "url": "https://diffusion-policy.cs.columbia.edu/",
    "paper": "https://arxiv.org/abs/2303.04137v5",
    "citation_bibtex": dedent(r"""
        @article{chi2024diffusionpolicy,
            author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
            title = {Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
            journal = {The International Journal of Robotics Research},
            year = {2024},
        }""").lstrip(),
}
XARM_INFO = {
    "license": "mit",
    "url": "https://www.nicklashansen.com/td-mpc/",
    "paper": "https://arxiv.org/abs/2203.04955",
    "citation_bibtex": dedent(r"""
        @inproceedings{Hansen2022tdmpc,
            title={Temporal Difference Learning for Model Predictive Control},
            author={Nicklas Hansen and Xiaolong Wang and Hao Su},
            booktitle={ICML},
            year={2022}
        }
    """),
}
UNITREEH_INFO = {
    "license": "apache-2.0",
}

92 |
+
DATASETS = {
|
93 |
+
"aloha_mobile_cabinet": {
|
94 |
+
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
|
95 |
+
**ALOHA_MOBILE_INFO,
|
96 |
+
},
|
97 |
+
"aloha_mobile_chair": {
|
98 |
+
"single_task": "Push the chairs in front of the desk to place them against it.",
|
99 |
+
**ALOHA_MOBILE_INFO,
|
100 |
+
},
|
101 |
+
"aloha_mobile_elevator": {
|
102 |
+
"single_task": "Take the elevator to the 1st floor.",
|
103 |
+
**ALOHA_MOBILE_INFO,
|
104 |
+
},
|
105 |
+
"aloha_mobile_shrimp": {
|
106 |
+
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
|
107 |
+
**ALOHA_MOBILE_INFO,
|
108 |
+
},
|
109 |
+
"aloha_mobile_wash_pan": {
|
110 |
+
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
|
111 |
+
**ALOHA_MOBILE_INFO,
|
112 |
+
},
|
113 |
+
"aloha_mobile_wipe_wine": {
|
114 |
+
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
|
115 |
+
**ALOHA_MOBILE_INFO,
|
116 |
+
},
|
117 |
+
"aloha_static_battery": {
|
118 |
+
"single_task": "Place the battery into the slot of the remote controller.",
|
119 |
+
**ALOHA_STATIC_INFO,
|
120 |
+
},
|
121 |
+
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
122 |
+
"aloha_static_coffee": {
|
123 |
+
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
124 |
+
**ALOHA_STATIC_INFO,
|
125 |
+
},
|
126 |
+
"aloha_static_coffee_new": {
|
127 |
+
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
|
128 |
+
**ALOHA_STATIC_INFO,
|
129 |
+
},
|
130 |
+
"aloha_static_cups_open": {
|
131 |
+
"single_task": "Pick up the plastic cup and open its lid.",
|
132 |
+
**ALOHA_STATIC_INFO,
|
133 |
+
},
|
134 |
+
"aloha_static_fork_pick_up": {
|
135 |
+
"single_task": "Pick up the fork and place it on the plate.",
|
136 |
+
**ALOHA_STATIC_INFO,
|
137 |
+
},
|
138 |
+
"aloha_static_pingpong_test": {
|
139 |
+
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
|
140 |
+
**ALOHA_STATIC_INFO,
|
141 |
+
},
|
142 |
+
"aloha_static_pro_pencil": {
|
143 |
+
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
|
144 |
+
**ALOHA_STATIC_INFO,
|
145 |
+
},
|
146 |
+
"aloha_static_screw_driver": {
|
147 |
+
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
|
148 |
+
**ALOHA_STATIC_INFO,
|
149 |
+
},
|
150 |
+
"aloha_static_tape": {
|
151 |
+
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
|
152 |
+
**ALOHA_STATIC_INFO,
|
153 |
+
},
|
154 |
+
"aloha_static_thread_velcro": {
|
155 |
+
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
|
156 |
+
**ALOHA_STATIC_INFO,
|
157 |
+
},
|
158 |
+
"aloha_static_towel": {
|
159 |
+
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
|
160 |
+
**ALOHA_STATIC_INFO,
|
161 |
+
},
|
162 |
+
"aloha_static_vinh_cup": {
|
163 |
+
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
164 |
+
**ALOHA_STATIC_INFO,
|
165 |
+
},
|
166 |
+
"aloha_static_vinh_cup_left": {
|
167 |
+
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
168 |
+
**ALOHA_STATIC_INFO,
|
169 |
+
},
|
170 |
+
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
171 |
+
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
172 |
+
"aloha_sim_insertion_scripted_image": {
|
173 |
+
"single_task": "Insert the peg into the socket.",
|
174 |
+
**ALOHA_STATIC_INFO,
|
175 |
+
},
|
176 |
+
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
177 |
+
"aloha_sim_insertion_human_image": {
|
178 |
+
"single_task": "Insert the peg into the socket.",
|
179 |
+
**ALOHA_STATIC_INFO,
|
180 |
+
},
|
181 |
+
"aloha_sim_transfer_cube_scripted": {
|
182 |
+
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
183 |
+
**ALOHA_STATIC_INFO,
|
184 |
+
},
|
185 |
+
"aloha_sim_transfer_cube_scripted_image": {
|
186 |
+
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
187 |
+
**ALOHA_STATIC_INFO,
|
188 |
+
},
|
189 |
+
"aloha_sim_transfer_cube_human": {
|
190 |
+
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
191 |
+
**ALOHA_STATIC_INFO,
|
192 |
+
},
|
193 |
+
"aloha_sim_transfer_cube_human_image": {
|
194 |
+
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
195 |
+
**ALOHA_STATIC_INFO,
|
196 |
+
},
|
197 |
+
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
198 |
+
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
199 |
+
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
200 |
+
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
201 |
+
"unitreeh1_two_robot_greeting": {
|
202 |
+
"single_task": "Greet the other robot with a high five.",
|
203 |
+
**UNITREEH_INFO,
|
204 |
+
},
|
205 |
+
"unitreeh1_warehouse": {
|
206 |
+
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
|
207 |
+
**UNITREEH_INFO,
|
208 |
+
},
|
209 |
+
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
210 |
+
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
211 |
+
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
212 |
+
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
213 |
+
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
214 |
+
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
215 |
+
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
216 |
+
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
217 |
+
"umi_cup_in_the_wild": {
|
218 |
+
"single_task": "Put the cup on the plate.",
|
219 |
+
"license": "apache-2.0",
|
220 |
+
},
|
221 |
+
"asu_table_top": {
|
222 |
+
"tasks_col": "language_instruction",
|
223 |
+
"license": "mit",
|
224 |
+
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
|
225 |
+
"citation_bibtex": dedent(r"""
|
226 |
+
@inproceedings{zhou2023modularity,
|
227 |
+
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
|
228 |
+
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
|
229 |
+
booktitle={Conference on Robot Learning},
|
230 |
+
pages={1684--1695},
|
231 |
+
year={2023},
|
232 |
+
organization={PMLR}
|
233 |
+
}
|
234 |
+
@article{zhou2023learning,
|
235 |
+
title={Learning modular language-conditioned robot policies through attention},
|
236 |
+
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
|
237 |
+
journal={Autonomous Robots},
|
238 |
+
pages={1--21},
|
239 |
+
year={2023},
|
240 |
+
publisher={Springer}
|
241 |
+
}""").lstrip(),
|
242 |
+
},
|
243 |
+
"austin_buds_dataset": {
|
244 |
+
"tasks_col": "language_instruction",
|
245 |
+
"license": "mit",
|
246 |
+
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
|
247 |
+
"paper": "https://arxiv.org/abs/2109.13841",
|
248 |
+
"citation_bibtex": dedent(r"""
|
249 |
+
@article{zhu2022bottom,
|
250 |
+
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
|
251 |
+
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
|
252 |
+
journal={IEEE Robotics and Automation Letters},
|
253 |
+
volume={7},
|
254 |
+
number={2},
|
255 |
+
pages={4126--4133},
|
256 |
+
year={2022},
|
257 |
+
publisher={IEEE}
|
258 |
+
}""").lstrip(),
|
259 |
+
},
|
260 |
+
"austin_sailor_dataset": {
|
261 |
+
"tasks_col": "language_instruction",
|
262 |
+
"license": "mit",
|
263 |
+
"url": "https://ut-austin-rpl.github.io/sailor/",
|
264 |
+
"paper": "https://arxiv.org/abs/2210.11435",
|
265 |
+
"citation_bibtex": dedent(r"""
|
266 |
+
@inproceedings{nasiriany2022sailor,
|
267 |
+
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
|
268 |
+
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
|
269 |
+
booktitle={Conference on Robot Learning (CoRL)},
|
270 |
+
year={2022}
|
271 |
+
}""").lstrip(),
|
272 |
+
},
|
273 |
+
"austin_sirius_dataset": {
|
274 |
+
"tasks_col": "language_instruction",
|
275 |
+
"license": "mit",
|
276 |
+
"url": "https://ut-austin-rpl.github.io/sirius/",
|
277 |
+
"paper": "https://arxiv.org/abs/2211.08416",
|
278 |
+
"citation_bibtex": dedent(r"""
|
279 |
+
@inproceedings{liu2022robot,
|
280 |
+
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
|
281 |
+
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
|
282 |
+
booktitle = {Robotics: Science and Systems (RSS)},
|
283 |
+
year = {2023}
|
284 |
+
}""").lstrip(),
|
285 |
+
},
|
286 |
+
"berkeley_autolab_ur5": {
|
287 |
+
"tasks_col": "language_instruction",
|
288 |
+
"license": "cc-by-4.0",
|
289 |
+
"url": "https://sites.google.com/view/berkeley-ur5/home",
|
290 |
+
"citation_bibtex": dedent(r"""
|
291 |
+
@misc{BerkeleyUR5Website,
|
292 |
+
title = {Berkeley {UR5} Demonstration Dataset},
|
293 |
+
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
|
294 |
+
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
|
295 |
+
}""").lstrip(),
|
296 |
+
},
|
297 |
+
"berkeley_cable_routing": {
|
298 |
+
"tasks_col": "language_instruction",
|
299 |
+
"license": "cc-by-4.0",
|
300 |
+
"url": "https://sites.google.com/view/cablerouting/home",
|
301 |
+
"paper": "https://arxiv.org/abs/2307.08927",
|
302 |
+
"citation_bibtex": dedent(r"""
|
303 |
+
@article{luo2023multistage,
|
304 |
+
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
|
305 |
+
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
|
306 |
+
journal = {arXiv pre-print},
|
307 |
+
year = {2023},
|
308 |
+
url = {https://arxiv.org/abs/2307.08927},
|
309 |
+
}""").lstrip(),
|
310 |
+
},
|
311 |
+
"berkeley_fanuc_manipulation": {
|
312 |
+
"tasks_col": "language_instruction",
|
313 |
+
"license": "mit",
|
314 |
+
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
|
315 |
+
"citation_bibtex": dedent(r"""
|
316 |
+
@article{fanuc_manipulation2023,
|
317 |
+
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
|
318 |
+
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
|
319 |
+
year={2023},
|
320 |
+
}""").lstrip(),
|
321 |
+
},
|
322 |
+
"berkeley_gnm_cory_hall": {
|
323 |
+
"tasks_col": "language_instruction",
|
324 |
+
"license": "mit",
|
325 |
+
"paper": "https://arxiv.org/abs/1709.10489",
|
326 |
+
"citation_bibtex": dedent(r"""
|
327 |
+
@inproceedings{kahn2018self,
|
328 |
+
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
|
329 |
+
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
|
330 |
+
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
|
331 |
+
pages={5129--5136},
|
332 |
+
year={2018},
|
333 |
+
organization={IEEE}
|
334 |
+
}""").lstrip(),
|
335 |
+
},
|
336 |
+
"berkeley_gnm_recon": {
|
337 |
+
"tasks_col": "language_instruction",
|
338 |
+
"license": "mit",
|
339 |
+
"url": "https://sites.google.com/view/recon-robot",
|
340 |
+
"paper": "https://arxiv.org/abs/2104.05859",
|
341 |
+
"citation_bibtex": dedent(r"""
|
342 |
+
@inproceedings{shah2021rapid,
|
343 |
+
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
|
344 |
+
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
|
345 |
+
booktitle={5th Annual Conference on Robot Learning },
|
346 |
+
year={2021},
|
347 |
+
url={https://openreview.net/forum?id=d_SWJhyKfVw}
|
348 |
+
}""").lstrip(),
|
349 |
+
},
|
350 |
+
"berkeley_gnm_sac_son": {
|
351 |
+
"tasks_col": "language_instruction",
|
352 |
+
"license": "mit",
|
353 |
+
"url": "https://sites.google.com/view/SACSoN-review",
|
354 |
+
"paper": "https://arxiv.org/abs/2306.01874",
|
355 |
+
"citation_bibtex": dedent(r"""
|
356 |
+
@article{hirose2023sacson,
|
357 |
+
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
|
358 |
+
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
|
359 |
+
journal={arXiv preprint arXiv:2306.01874},
|
360 |
+
year={2023}
|
361 |
+
}""").lstrip(),
|
362 |
+
},
|
363 |
+
"berkeley_mvp": {
|
364 |
+
"tasks_col": "language_instruction",
|
365 |
+
"license": "mit",
|
366 |
+
"paper": "https://arxiv.org/abs/2203.06173",
|
367 |
+
"citation_bibtex": dedent(r"""
|
368 |
+
@InProceedings{Radosavovic2022,
|
369 |
+
title = {Real-World Robot Learning with Masked Visual Pre-training},
|
370 |
+
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
|
371 |
+
booktitle = {CoRL},
|
372 |
+
year = {2022}
|
373 |
+
}""").lstrip(),
|
374 |
+
},
|
375 |
+
"berkeley_rpt": {
|
376 |
+
"tasks_col": "language_instruction",
|
377 |
+
"license": "mit",
|
378 |
+
"paper": "https://arxiv.org/abs/2306.10007",
|
379 |
+
"citation_bibtex": dedent(r"""
|
380 |
+
@article{Radosavovic2023,
|
381 |
+
title={Robot Learning with Sensorimotor Pre-training},
|
382 |
+
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
|
383 |
+
year={2023},
|
384 |
+
journal={arXiv:2306.10007}
|
385 |
+
}""").lstrip(),
|
386 |
+
},
|
387 |
+
"cmu_franka_exploration_dataset": {
|
388 |
+
"tasks_col": "language_instruction",
|
389 |
+
"license": "mit",
|
390 |
+
"url": "https://human-world-model.github.io/",
|
391 |
+
"paper": "https://arxiv.org/abs/2308.10901",
|
392 |
+
"citation_bibtex": dedent(r"""
|
393 |
+
@inproceedings{mendonca2023structured,
|
394 |
+
title={Structured World Models from Human Videos},
|
395 |
+
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
396 |
+
journal={RSS},
|
397 |
+
year={2023}
|
398 |
+
}""").lstrip(),
|
399 |
+
},
|
400 |
+
"cmu_play_fusion": {
|
401 |
+
"tasks_col": "language_instruction",
|
402 |
+
"license": "mit",
|
403 |
+
"url": "https://play-fusion.github.io/",
|
404 |
+
"paper": "https://arxiv.org/abs/2312.04549",
|
405 |
+
"citation_bibtex": dedent(r"""
|
406 |
+
@inproceedings{chen2023playfusion,
|
407 |
+
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
|
408 |
+
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
|
409 |
+
booktitle={CoRL},
|
410 |
+
year={2023}
|
411 |
+
}""").lstrip(),
|
412 |
+
},
|
413 |
+
"cmu_stretch": {
|
414 |
+
"tasks_col": "language_instruction",
|
415 |
+
"license": "mit",
|
416 |
+
"url": "https://robo-affordances.github.io/",
|
417 |
+
"paper": "https://arxiv.org/abs/2304.08488",
|
418 |
+
"citation_bibtex": dedent(r"""
|
419 |
+
@inproceedings{bahl2023affordances,
|
420 |
+
title={Affordances from Human Videos as a Versatile Representation for Robotics},
|
421 |
+
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
|
422 |
+
booktitle={CVPR},
|
423 |
+
year={2023}
|
424 |
+
}
|
425 |
+
@article{mendonca2023structured,
|
426 |
+
title={Structured World Models from Human Videos},
|
427 |
+
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
428 |
+
journal={CoRL},
|
429 |
+
year={2023}
|
430 |
+
}""").lstrip(),
|
431 |
+
},
|
432 |
+
"columbia_cairlab_pusht_real": {
|
433 |
+
"tasks_col": "language_instruction",
|
434 |
+
"license": "mit",
|
435 |
+
"url": "https://diffusion-policy.cs.columbia.edu/",
|
436 |
+
"paper": "https://arxiv.org/abs/2303.04137v5",
|
437 |
+
"citation_bibtex": dedent(r"""
|
438 |
+
@inproceedings{chi2023diffusionpolicy,
|
439 |
+
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
440 |
+
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
|
441 |
+
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
|
442 |
+
year={2023}
|
443 |
+
}""").lstrip(),
|
444 |
+
},
|
445 |
+
"conq_hose_manipulation": {
|
446 |
+
"tasks_col": "language_instruction",
|
447 |
+
"license": "mit",
|
448 |
+
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
|
449 |
+
"citation_bibtex": dedent(r"""
|
450 |
+
@misc{ConqHoseManipData,
|
451 |
+
author={Peter Mitrano and Dmitry Berenson},
|
452 |
+
title={Conq Hose Manipulation Dataset, v1.15.0},
|
453 |
+
year={2024},
|
454 |
+
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
|
455 |
+
}""").lstrip(),
|
456 |
+
},
|
457 |
+
"dlr_edan_shared_control": {
|
458 |
+
"tasks_col": "language_instruction",
|
459 |
+
"license": "mit",
|
460 |
+
"paper": "https://ieeexplore.ieee.org/document/9341156",
|
461 |
+
"citation_bibtex": dedent(r"""
|
462 |
+
@inproceedings{vogel_edan_2020,
|
463 |
+
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
|
464 |
+
language = {en},
|
465 |
+
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
|
466 |
+
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
|
467 |
+
year = {2020}
|
468 |
+
}
|
469 |
+
@inproceedings{quere_shared_2020,
|
470 |
+
address = {Paris, France},
|
471 |
+
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
|
472 |
+
language = {en},
|
473 |
+
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
|
474 |
+
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
|
475 |
+
year = {2020},
|
476 |
+
pages = {7},
|
477 |
+
}""").lstrip(),
|
478 |
+
},
|
479 |
+
"dlr_sara_grid_clamp": {
|
480 |
+
"tasks_col": "language_instruction",
|
481 |
+
"license": "mit",
|
482 |
+
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
|
483 |
+
"citation_bibtex": dedent(r"""
|
484 |
+
@article{padalkar2023guided,
|
485 |
+
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
|
486 |
+
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
487 |
+
journal={Research square preprint rs-3289569/v1},
|
488 |
+
year={2023}
|
489 |
+
}""").lstrip(),
|
490 |
+
},
|
491 |
+
"dlr_sara_pour": {
|
492 |
+
"tasks_col": "language_instruction",
|
493 |
+
"license": "mit",
|
494 |
+
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
|
495 |
+
"citation_bibtex": dedent(r"""
|
496 |
+
@inproceedings{padalkar2023guiding,
|
497 |
+
title={Guiding Reinforcement Learning with Shared Control Templates},
|
498 |
+
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
499 |
+
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
|
500 |
+
year={2023},
|
501 |
+
organization={IEEE}
|
502 |
+
}""").lstrip(),
|
503 |
+
},
|
504 |
+
"droid_100": {
|
505 |
+
"tasks_col": "language_instruction",
|
506 |
+
"license": "mit",
|
507 |
+
"url": "https://droid-dataset.github.io/",
|
508 |
+
"paper": "https://arxiv.org/abs/2403.12945",
|
509 |
+
"citation_bibtex": dedent(r"""
|
510 |
+
@article{khazatsky2024droid,
|
511 |
+
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
|
512 |
+
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
|
513 |
+
year = {2024},
|
514 |
+
}""").lstrip(),
|
515 |
+
},
|
516 |
+
"fmb": {
|
517 |
+
"tasks_col": "language_instruction",
|
518 |
+
"license": "cc-by-4.0",
|
519 |
+
"url": "https://functional-manipulation-benchmark.github.io/",
|
520 |
+
"paper": "https://arxiv.org/abs/2401.08553",
|
521 |
+
"citation_bibtex": dedent(r"""
|
522 |
+
@article{luo2024fmb,
|
523 |
+
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
|
524 |
+
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
|
525 |
+
journal={arXiv preprint arXiv:2401.08553},
|
526 |
+
year={2024}
|
527 |
+
}""").lstrip(),
|
528 |
+
},
|
529 |
+
"iamlab_cmu_pickup_insert": {
|
530 |
+
"tasks_col": "language_instruction",
|
531 |
+
"license": "mit",
|
532 |
+
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
|
533 |
+
"paper": "https://arxiv.org/abs/2401.14502",
|
534 |
+
"citation_bibtex": dedent(r"""
|
535 |
+
@inproceedings{saxena2023multiresolution,
|
536 |
+
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
|
537 |
+
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
|
538 |
+
booktitle={7th Annual Conference on Robot Learning},
|
539 |
+
year={2023},
|
540 |
+
url={https://openreview.net/forum?id=WuBv9-IGDUA}
|
541 |
+
}""").lstrip(),
|
542 |
+
},
|
543 |
+
"imperialcollege_sawyer_wrist_cam": {
|
544 |
+
"tasks_col": "language_instruction",
|
545 |
+
"license": "mit",
|
546 |
+
},
|
547 |
+
"jaco_play": {
|
548 |
+
"tasks_col": "language_instruction",
|
549 |
+
"license": "cc-by-4.0",
|
550 |
+
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
|
551 |
+
"citation_bibtex": dedent(r"""
|
552 |
+
@software{dass2023jacoplay,
|
553 |
+
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
|
554 |
+
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
|
555 |
+
title = {CLVR Jaco Play Dataset},
|
556 |
+
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
|
557 |
+
version = {1.0.0},
|
558 |
+
year = {2023}
|
559 |
+
}""").lstrip(),
|
560 |
+
},
|
561 |
+
"kaist_nonprehensile": {
|
562 |
+
"tasks_col": "language_instruction",
|
563 |
+
"license": "cc-by-4.0",
|
564 |
+
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
|
565 |
+
"citation_bibtex": dedent(r"""
|
566 |
+
@article{kimpre,
|
567 |
+
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
|
568 |
+
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
|
569 |
+
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
570 |
+
year={2023},
|
571 |
+
organization={IEEE}
|
572 |
+
}""").lstrip(),
|
573 |
+
},
|
574 |
+
"nyu_door_opening_surprising_effectiveness": {
|
575 |
+
"tasks_col": "language_instruction",
|
576 |
+
"license": "mit",
|
577 |
+
"url": "https://jyopari.github.io/VINN/",
|
578 |
+
"paper": "https://arxiv.org/abs/2112.01511",
|
579 |
+
"citation_bibtex": dedent(r"""
|
580 |
+
@misc{pari2021surprising,
|
581 |
+
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
|
582 |
+
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
|
583 |
+
year={2021},
|
584 |
+
eprint={2112.01511},
|
585 |
+
archivePrefix={arXiv},
|
586 |
+
primaryClass={cs.RO}
|
587 |
+
}""").lstrip(),
|
588 |
+
},
|
589 |
+
"nyu_franka_play_dataset": {
|
590 |
+
"tasks_col": "language_instruction",
|
591 |
+
"license": "mit",
|
592 |
+
"url": "https://play-to-policy.github.io/",
|
593 |
+
"paper": "https://arxiv.org/abs/2210.10047",
|
594 |
+
"citation_bibtex": dedent(r"""
|
595 |
+
@article{cui2022play,
|
596 |
+
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
|
597 |
+
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
598 |
+
journal = {arXiv preprint arXiv:2210.10047},
|
599 |
+
year = {2022}
|
600 |
+
}""").lstrip(),
|
601 |
+
},
|
602 |
+
"nyu_rot_dataset": {
|
603 |
+
"tasks_col": "language_instruction",
|
604 |
+
"license": "mit",
|
605 |
+
"url": "https://rot-robot.github.io/",
|
606 |
+
"paper": "https://arxiv.org/abs/2206.15469",
|
607 |
+
"citation_bibtex": dedent(r"""
|
608 |
+
@inproceedings{haldar2023watch,
|
609 |
+
title={Watch and match: Supercharging imitation with regularized optimal transport},
|
610 |
+
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
|
611 |
+
booktitle={Conference on Robot Learning},
|
612 |
+
pages={32--43},
|
613 |
+
year={2023},
|
614 |
+
organization={PMLR}
|
615 |
+
}""").lstrip(),
|
616 |
+
},
|
617 |
+
"roboturk": {
|
618 |
+
"tasks_col": "language_instruction",
|
619 |
+
"license": "mit",
|
620 |
+
"url": "https://roboturk.stanford.edu/dataset_real.html",
|
621 |
+
"paper": "PAPER",
|
622 |
+
"citation_bibtex": dedent(r"""
|
623 |
+
@inproceedings{mandlekar2019scaling,
|
624 |
+
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
|
625 |
+
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
|
626 |
+
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
627 |
+
pages={1048--1055},
|
628 |
+
year={2019},
|
629 |
+
organization={IEEE}
|
630 |
+
}""").lstrip(),
|
631 |
+
},
|
632 |
+
"stanford_hydra_dataset": {
|
633 |
+
"tasks_col": "language_instruction",
|
634 |
+
"license": "mit",
|
635 |
+
"url": "https://sites.google.com/view/hydra-il-2023",
|
636 |
+
"paper": "https://arxiv.org/abs/2306.17237",
|
637 |
+
"citation_bibtex": dedent(r"""
|
638 |
+
@article{belkhale2023hydra,
|
639 |
+
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
|
640 |
+
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
|
641 |
+
journal={arxiv},
|
642 |
+
year={2023}
|
643 |
+
}""").lstrip(),
|
644 |
+
},
|
645 |
+
"stanford_kuka_multimodal_dataset": {
|
646 |
+
"tasks_col": "language_instruction",
|
647 |
+
"license": "mit",
|
648 |
+
"url": "https://sites.google.com/view/visionandtouch",
|
649 |
+
"paper": "https://arxiv.org/abs/1810.10191",
|
650 |
+
"citation_bibtex": dedent(r"""
|
651 |
+
@inproceedings{lee2019icra,
|
652 |
+
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
|
653 |
+
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
|
654 |
+
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
|
655 |
+
year={2019},
|
656 |
+
url={https://arxiv.org/abs/1810.10191}
|
657 |
+
}""").lstrip(),
|
658 |
+
},
|
659 |
+
"stanford_robocook": {
|
660 |
+
"tasks_col": "language_instruction",
|
661 |
+
"license": "mit",
|
662 |
+
"url": "https://hshi74.github.io/robocook/",
|
663 |
+
"paper": "https://arxiv.org/abs/2306.14447",
|
664 |
+
"citation_bibtex": dedent(r"""
|
665 |
+
@article{shi2023robocook,
|
666 |
+
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
|
667 |
+
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
|
668 |
+
journal={arXiv preprint arXiv:2306.14447},
|
669 |
+
year={2023}
|
670 |
+
}""").lstrip(),
|
671 |
+
},
|
672 |
+
"taco_play": {
|
673 |
+
"tasks_col": "language_instruction",
|
674 |
+
"license": "cc-by-4.0",
|
675 |
+
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
|
676 |
+
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
|
677 |
+
"citation_bibtex": dedent(r"""
|
678 |
+
@inproceedings{rosete2022tacorl,
|
679 |
+
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
|
680 |
+
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
|
681 |
+
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
|
682 |
+
year = {2022}
|
683 |
+
}
|
684 |
+
@inproceedings{mees23hulc2,
|
685 |
+
title={Grounding Language with Visual Affordances over Unstructured Data},
|
686 |
+
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
|
687 |
+
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
|
688 |
+
year={2023},
|
689 |
+
address = {London, UK}
|
690 |
+
}""").lstrip(),
|
691 |
+
},
|
692 |
+
"tokyo_u_lsmo": {
|
693 |
+
"tasks_col": "language_instruction",
|
694 |
+
"license": "mit",
|
695 |
+
"url": "URL",
|
696 |
+
"paper": "https://arxiv.org/abs/2107.05842",
|
697 |
+
"citation_bibtex": dedent(r"""
|
698 |
+
@Article{Osa22,
|
699 |
+
author = {Takayuki Osa},
|
700 |
+
journal = {The International Journal of Robotics Research},
|
701 |
+
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
|
702 |
+
year = {2022},
|
703 |
+
number = {3},
|
704 |
+
pages = {291--311},
|
705 |
+
volume = {41},
|
706 |
+
}""").lstrip(),
|
707 |
+
},
|
708 |
+
"toto": {
|
709 |
+
"tasks_col": "language_instruction",
|
710 |
+
"license": "mit",
|
711 |
+
"url": "https://toto-benchmark.org/",
|
712 |
+
"paper": "https://arxiv.org/abs/2306.00942",
|
713 |
+
"citation_bibtex": dedent(r"""
|
714 |
+
@inproceedings{zhou2023train,
|
715 |
+
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
|
716 |
+
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
|
717 |
+
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
|
718 |
+
year={2023},
|
719 |
+
}""").lstrip(),
|
720 |
+
},
|
721 |
+
"ucsd_kitchen_dataset": {
|
722 |
+
"tasks_col": "language_instruction",
|
723 |
+
"license": "mit",
|
724 |
+
"citation_bibtex": dedent(r"""
|
725 |
+
@ARTICLE{ucsd_kitchens,
|
726 |
+
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
|
727 |
+
title = {{ucsd kitchens Dataset}},
|
728 |
+
year = {2023},
|
729 |
+
month = {August}
|
730 |
+
}""").lstrip(),
|
731 |
+
},
|
732 |
+
"ucsd_pick_and_place_dataset": {
|
733 |
+
"tasks_col": "language_instruction",
|
734 |
+
"license": "mit",
|
735 |
+
"url": "https://owmcorl.github.io/#",
|
736 |
+
"paper": "https://arxiv.org/abs/2310.16029",
|
737 |
+
"citation_bibtex": dedent(r"""
|
738 |
+
@preprint{Feng2023Finetuning,
|
739 |
+
title={Finetuning Offline World Models in the Real World},
|
740 |
+
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
|
741 |
+
year={2023}
|
742 |
+
}""").lstrip(),
|
743 |
+
},
|
744 |
+
"uiuc_d3field": {
|
745 |
+
"tasks_col": "language_instruction",
|
746 |
+
"license": "mit",
|
747 |
+
"url": "https://robopil.github.io/d3fields/",
|
748 |
+
"paper": "https://arxiv.org/abs/2309.16118",
|
749 |
+
"citation_bibtex": dedent(r"""
|
750 |
+
@article{wang2023d3field,
|
751 |
+
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
|
752 |
+
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
|
753 |
+
journal={arXiv preprint arXiv:},
|
754 |
+
year={2023},
|
755 |
+
}""").lstrip(),
|
756 |
+
},
|
757 |
+
"usc_cloth_sim": {
|
758 |
+
"tasks_col": "language_instruction",
|
759 |
+
"license": "mit",
|
760 |
+
"url": "https://uscresl.github.io/dmfd/",
|
761 |
+
"paper": "https://arxiv.org/abs/2207.10148",
|
762 |
+
"citation_bibtex": dedent(r"""
|
763 |
+
@article{salhotra2022dmfd,
|
764 |
+
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
|
765 |
+
journal={IEEE Robotics and Automation Letters},
|
766 |
+
title={Learning Deformable Object Manipulation From Expert Demonstrations},
|
767 |
+
year={2022},
|
768 |
+
volume={7},
|
769 |
+
number={4},
|
770 |
+
pages={8775-8782},
|
771 |
+
doi={10.1109/LRA.2022.3187843}
|
772 |
+
}""").lstrip(),
|
773 |
+
},
|
774 |
+
"utaustin_mutex": {
|
775 |
+
"tasks_col": "language_instruction",
|
776 |
+
"license": "mit",
|
777 |
+
"url": "https://ut-austin-rpl.github.io/MUTEX/",
|
778 |
+
"paper": "https://arxiv.org/abs/2309.14320",
|
779 |
+
"citation_bibtex": dedent(r"""
|
780 |
+
@inproceedings{shah2023mutex,
|
781 |
+
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
|
782 |
+
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
|
783 |
+
booktitle={7th Annual Conference on Robot Learning},
|
784 |
+
year={2023},
|
785 |
+
url={https://openreview.net/forum?id=PwqiqaaEzJ}
|
786 |
+
}""").lstrip(),
|
787 |
+
},
|
788 |
+
"utokyo_pr2_opening_fridge": {
|
789 |
+
"tasks_col": "language_instruction",
|
790 |
+
"license": "mit",
|
791 |
+
"citation_bibtex": dedent(r"""
|
792 |
+
@misc{oh2023pr2utokyodatasets,
|
793 |
+
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
794 |
+
title={X-Embodiment U-Tokyo PR2 Datasets},
|
795 |
+
year={2023},
|
796 |
+
url={https://github.com/ojh6404/rlds_dataset_builder},
|
797 |
+
}""").lstrip(),
|
798 |
+
},
|
799 |
+
"utokyo_pr2_tabletop_manipulation": {
|
800 |
+
"tasks_col": "language_instruction",
|
801 |
+
"license": "mit",
|
802 |
+
"citation_bibtex": dedent(r"""
|
803 |
+
@misc{oh2023pr2utokyodatasets,
|
804 |
+
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
805 |
+
title={X-Embodiment U-Tokyo PR2 Datasets},
|
806 |
+
year={2023},
|
807 |
+
url={https://github.com/ojh6404/rlds_dataset_builder},
|
808 |
+
}""").lstrip(),
|
809 |
+
},
|
810 |
+
"utokyo_saytap": {
|
811 |
+
"tasks_col": "language_instruction",
|
812 |
+
"license": "mit",
|
813 |
+
"url": "https://saytap.github.io/",
|
814 |
+
"paper": "https://arxiv.org/abs/2306.07580",
|
815 |
+
"citation_bibtex": dedent(r"""
|
816 |
+
@article{saytap2023,
|
817 |
+
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
|
818 |
+
Tatsuya Harada},
|
819 |
+
title = {SayTap: Language to Quadrupedal Locomotion},
|
820 |
+
eprint = {arXiv:2306.07580},
|
821 |
+
url = {https://saytap.github.io},
|
822 |
+
note = {https://saytap.github.io},
|
823 |
+
year = {2023}
|
824 |
+
}""").lstrip(),
|
825 |
+
},
|
826 |
+
"utokyo_xarm_bimanual": {
|
827 |
+
"tasks_col": "language_instruction",
|
828 |
+
"license": "cc-by-4.0",
|
829 |
+
"citation_bibtex": dedent(r"""
|
830 |
+
@misc{matsushima2023weblab,
|
831 |
+
title={Weblab xArm Dataset},
|
832 |
+
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
833 |
+
year={2023},
|
834 |
+
}""").lstrip(),
|
835 |
+
},
|
836 |
+
"utokyo_xarm_pick_and_place": {
|
837 |
+
"tasks_col": "language_instruction",
|
838 |
+
"license": "cc-by-4.0",
|
839 |
+
"citation_bibtex": dedent(r"""
|
840 |
+
@misc{matsushima2023weblab,
|
841 |
+
title={Weblab xArm Dataset},
|
842 |
+
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
843 |
+
year={2023},
|
844 |
+
}""").lstrip(),
|
845 |
+
},
|
846 |
+
"viola": {
|
847 |
+
"tasks_col": "language_instruction",
|
848 |
+
"license": "mit",
|
849 |
+
"url": "https://ut-austin-rpl.github.io/VIOLA/",
|
850 |
+
"paper": "https://arxiv.org/abs/2210.11339",
|
851 |
+
"citation_bibtex": dedent(r"""
|
852 |
+
@article{zhu2022viola,
|
853 |
+
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
|
854 |
+
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
|
855 |
+
journal={6th Annual Conference on Robot Learning (CoRL)},
|
856 |
+
year={2022}
|
857 |
+
}""").lstrip(),
|
858 |
+
},
|
859 |
+
}
|
# spellchecker:on


def batch_convert():
    status = {}
    logfile = LOCAL_DIR / "conversion_log.txt"
    assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
    for num, (name, kwargs) in enumerate(DATASETS.items()):
        repo_id = f"lerobot/{name}"
        print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
        print("---------------------------------------------------------")
        try:
            convert_dataset(repo_id, LOCAL_DIR, **kwargs)
            status = f"{repo_id}: success."
            with open(logfile, "a") as file:
                file.write(status + "\n")
        except Exception:
            status = f"{repo_id}: failed\n {traceback.format_exc()}"
            with open(logfile, "a") as file:
                file.write(status + "\n")
            continue


if __name__ == "__main__":
    batch_convert()
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
ADDED
@@ -0,0 +1,664 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
"""
|
18 |
+
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
19 |
+
2.0. You will be required to provide the 'tasks', which are short but accurate descriptions in plain English
|
20 |
+
for each of the tasks performed in the dataset. This makes it easy to train models with task-conditioning.
|
21 |
+
|
22 |
+
We support 3 different scenarios for these tasks (see instructions below):
|
23 |
+
1. Single task dataset: all episodes of your dataset have the same single task.
|
24 |
+
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
|
25 |
+
one episode to the next.
|
26 |
+
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
|
27 |
+
|
28 |
+
|
29 |
+
You can also provide a robot config .yaml file (not mandatory) to this script via the option
|
30 |
+
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
|
31 |
+
recorded with. For now, only Aloha/Koch type robots are supported with this option.
|
32 |
+
|
33 |
+
|
34 |
+
# 1. Single task dataset
|
35 |
+
If your dataset contains a single task, you can simply provide it directly via the CLI with the
|
36 |
+
'--single-task' option.
|
37 |
+
|
38 |
+
Examples:
|
39 |
+
|
40 |
+
```bash
|
41 |
+
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
42 |
+
--repo-id lerobot/aloha_sim_insertion_human_image \
|
43 |
+
--single-task "Insert the peg into the socket." \
|
44 |
+
--robot-config lerobot/configs/robot/aloha.yaml \
|
45 |
+
--local-dir data
|
46 |
+
```
|
47 |
+
|
48 |
+
```bash
|
49 |
+
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
50 |
+
--repo-id aliberts/koch_tutorial \
|
51 |
+
--single-task "Pick the Lego block and drop it in the box on the right." \
|
52 |
+
--robot-config lerobot/configs/robot/koch.yaml \
|
53 |
+
--local-dir data
|
54 |
+
```
|
55 |
+
|
56 |
+
|
57 |
+
# 2. Single task episodes
|
58 |
+
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
|
59 |
+
|
60 |
+
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
|
61 |
+
this column's name with the '--tasks-col' arg.
|
62 |
+
|
63 |
+
Example:
|
64 |
+
|
65 |
+
```bash
|
66 |
+
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
67 |
+
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
68 |
+
--tasks-col "language_instruction" \
|
69 |
+
--local-dir data
|
70 |
+
```
|
71 |
+
|
72 |
+
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
|
73 |
+
'--tasks-path' arg. This file should have the following structure where keys correspond to each
|
74 |
+
episode_index in the dataset, and values are the language instruction for that episode.
|
75 |
+
|
76 |
+
Example:
|
77 |
+
|
78 |
+
```json
|
79 |
+
{
|
80 |
+
"0": "Do something",
|
81 |
+
"1": "Do something else",
|
82 |
+
"2": "Do something",
|
83 |
+
"3": "Go there",
|
84 |
+
...
|
85 |
+
}
|
86 |
+
```
|
87 |
+
|
88 |
+
# 3. Multi task episodes
|
89 |
+
If you have multiple tasks per episode, your dataset should contain a language instruction column in its
|
90 |
+
parquet file, and you must provide this column's name with the '--tasks-col' arg.
|
91 |
+
|
92 |
+
Example:
|
93 |
+
|
94 |
+
```bash
|
95 |
+
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
96 |
+
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
97 |
+
--tasks-col "language_instruction" \
|
98 |
+
--local-dir data
|
99 |
+
```
|
100 |
+
"""
|
101 |
+
|
102 |
+
import argparse
|
103 |
+
import contextlib
|
104 |
+
import filecmp
|
105 |
+
import json
|
106 |
+
import logging
|
107 |
+
import math
|
108 |
+
import shutil
|
109 |
+
import subprocess
|
110 |
+
import tempfile
|
111 |
+
from pathlib import Path
|
112 |
+
|
113 |
+
import datasets
|
114 |
+
import pyarrow.compute as pc
|
115 |
+
import pyarrow.parquet as pq
|
116 |
+
import torch
|
117 |
+
from datasets import Dataset
|
118 |
+
from huggingface_hub import HfApi
|
119 |
+
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
|
120 |
+
from safetensors.torch import load_file
|
121 |
+
|
122 |
+
from lerobot.common.datasets.utils import (
|
123 |
+
DEFAULT_CHUNK_SIZE,
|
124 |
+
DEFAULT_PARQUET_PATH,
|
125 |
+
DEFAULT_VIDEO_PATH,
|
126 |
+
EPISODES_PATH,
|
127 |
+
INFO_PATH,
|
128 |
+
STATS_PATH,
|
129 |
+
TASKS_PATH,
|
130 |
+
create_branch,
|
131 |
+
create_lerobot_dataset_card,
|
132 |
+
flatten_dict,
|
133 |
+
get_safe_version,
|
134 |
+
load_json,
|
135 |
+
unflatten_dict,
|
136 |
+
write_json,
|
137 |
+
write_jsonlines,
|
138 |
+
)
|
139 |
+
from lerobot.common.datasets.video_utils import (
|
140 |
+
VideoFrame, # noqa: F401
|
141 |
+
get_image_pixel_channels,
|
142 |
+
get_video_info,
|
143 |
+
)
|
144 |
+
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
145 |
+
from lerobot.common.robot_devices.robots.utils import make_robot_config
|
146 |
+
|
147 |
+
V16 = "v1.6"
|
148 |
+
V20 = "v2.0"
|
149 |
+
|
150 |
+
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
|
151 |
+
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
|
152 |
+
V1_INFO_PATH = "meta_data/info.json"
|
153 |
+
V1_STATS_PATH = "meta_data/stats.safetensors"
|
154 |
+
|
155 |
+
|
156 |
+
def parse_robot_config(robot_cfg: RobotConfig) -> dict:
|
157 |
+
if robot_cfg.type in ["aloha", "koch"]:
|
158 |
+
state_names = [
|
159 |
+
f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
|
160 |
+
for arm in robot_cfg.follower_arms
|
161 |
+
for motor in robot_cfg.follower_arms[arm].motors
|
162 |
+
]
|
163 |
+
action_names = [
|
164 |
+
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
165 |
+
f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
|
166 |
+
for arm in robot_cfg.leader_arms
|
167 |
+
for motor in robot_cfg.leader_arms[arm].motors
|
168 |
+
]
|
169 |
+
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
170 |
+
else:
|
171 |
+
raise NotImplementedError(
|
172 |
+
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
|
173 |
+
)
|
174 |
+
|
175 |
+
return {
|
176 |
+
"robot_type": robot_cfg.type,
|
177 |
+
"names": {
|
178 |
+
"observation.state": state_names,
|
179 |
+
"observation.effort": state_names,
|
180 |
+
"action": action_names,
|
181 |
+
},
|
182 |
+
}
|
183 |
+
|
184 |
+
|
185 |
+
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
186 |
+
safetensor_path = v1_dir / V1_STATS_PATH
|
187 |
+
stats = load_file(safetensor_path)
|
188 |
+
serialized_stats = {key: value.tolist() for key, value in stats.items()}
|
189 |
+
serialized_stats = unflatten_dict(serialized_stats)
|
190 |
+
|
191 |
+
json_path = v2_dir / STATS_PATH
|
192 |
+
json_path.parent.mkdir(exist_ok=True, parents=True)
|
193 |
+
with open(json_path, "w") as f:
|
194 |
+
json.dump(serialized_stats, f, indent=4)
|
195 |
+
|
196 |
+
# Sanity check
|
197 |
+
with open(json_path) as f:
|
198 |
+
stats_json = json.load(f)
|
199 |
+
|
200 |
+
stats_json = flatten_dict(stats_json)
|
201 |
+
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
|
202 |
+
for key in stats:
|
203 |
+
torch.testing.assert_close(stats_json[key], stats[key])
|
204 |
+
|
205 |
+
|
206 |
+
def get_features_from_hf_dataset(
|
207 |
+
dataset: Dataset, robot_config: RobotConfig | None = None
|
208 |
+
) -> dict[str, list]:
|
209 |
+
robot_config = parse_robot_config(robot_config)
|
210 |
+
features = {}
|
211 |
+
for key, ft in dataset.features.items():
|
212 |
+
if isinstance(ft, datasets.Value):
|
213 |
+
dtype = ft.dtype
|
214 |
+
shape = (1,)
|
215 |
+
names = None
|
216 |
+
if isinstance(ft, datasets.Sequence):
|
217 |
+
assert isinstance(ft.feature, datasets.Value)
|
218 |
+
dtype = ft.feature.dtype
|
219 |
+
shape = (ft.length,)
|
220 |
+
motor_names = (
|
221 |
+
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
222 |
+
)
|
223 |
+
assert len(motor_names) == shape[0]
|
224 |
+
names = {"motors": motor_names}
|
225 |
+
elif isinstance(ft, datasets.Image):
|
226 |
+
dtype = "image"
|
227 |
+
image = dataset[0][key] # Assuming first row
|
228 |
+
channels = get_image_pixel_channels(image)
|
229 |
+
shape = (image.height, image.width, channels)
|
230 |
+
names = ["height", "width", "channels"]
|
231 |
+
elif ft._type == "VideoFrame":
|
232 |
+
dtype = "video"
|
233 |
+
shape = None # Add shape later
|
234 |
+
names = ["height", "width", "channels"]
|
235 |
+
|
236 |
+
features[key] = {
|
237 |
+
"dtype": dtype,
|
238 |
+
"shape": shape,
|
239 |
+
"names": names,
|
240 |
+
}
|
241 |
+
|
242 |
+
return features
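To make the returned structure concrete, here is a sketch of what a single float32 Sequence column of length 6 maps to when no robot config is provided, so the generic `motor_{i}` fallback names from the code above are used (the feature key is a placeholder):

```python
# Illustrative output entry of get_features_from_hf_dataset (no robot config given).
features_example = {
    "observation.state": {
        "dtype": "float32",
        "shape": (6,),
        "names": {"motors": ["motor_0", "motor_1", "motor_2", "motor_3", "motor_4", "motor_5"]},
    }
}
```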
|
243 |
+
|
244 |
+
|
245 |
+
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
246 |
+
df = dataset.to_pandas()
|
247 |
+
tasks = list(set(tasks_by_episodes.values()))
|
248 |
+
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
249 |
+
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
250 |
+
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
251 |
+
|
252 |
+
features = dataset.features
|
253 |
+
features["task_index"] = datasets.Value(dtype="int64")
|
254 |
+
dataset = Dataset.from_pandas(df, features=features, split="train")
|
255 |
+
return dataset, tasks
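A small sketch of the episode-to-task-index mapping built above, using made-up task strings (the printed values are examples only, since the task order comes from a set):

```python
# Sketch only: how tasks_by_episodes is turned into per-frame task_index values.
tasks_by_episodes = {0: "pick the cube", 1: "pick the cube", 2: "open the drawer"}
tasks = list(set(tasks_by_episodes.values()))                        # e.g. ["pick the cube", "open the drawer"]
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
episodes_to_task_index = {ep: tasks_to_task_index[t] for ep, t in tasks_by_episodes.items()}
print(episodes_to_task_index)                                        # e.g. {0: 0, 1: 0, 2: 1}
```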
|
256 |
+
|
257 |
+
|
258 |
+
def add_task_index_from_tasks_col(
|
259 |
+
dataset: Dataset, tasks_col: str
|
260 |
+
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
|
261 |
+
df = dataset.to_pandas()
|
262 |
+
|
263 |
+
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
264 |
+
prefix_to_clean = "tf.Tensor(b'"
|
265 |
+
suffix_to_clean = "', shape=(), dtype=string)"
|
266 |
+
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
267 |
+
|
268 |
+
# Create task_index col
|
269 |
+
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
270 |
+
tasks = df[tasks_col].unique().tolist()
|
271 |
+
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
272 |
+
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
273 |
+
|
274 |
+
# Build the dataset back from df
|
275 |
+
features = dataset.features
|
276 |
+
features["task_index"] = datasets.Value(dtype="int64")
|
277 |
+
dataset = Dataset.from_pandas(df, features=features, split="train")
|
278 |
+
dataset = dataset.remove_columns(tasks_col)
|
279 |
+
|
280 |
+
return dataset, tasks, tasks_by_episode
|
281 |
+
|
282 |
+
|
283 |
+
def split_parquet_by_episodes(
|
284 |
+
dataset: Dataset,
|
285 |
+
total_episodes: int,
|
286 |
+
total_chunks: int,
|
287 |
+
output_dir: Path,
|
288 |
+
) -> list:
|
289 |
+
table = dataset.data.table
|
290 |
+
episode_lengths = []
|
291 |
+
for ep_chunk in range(total_chunks):
|
292 |
+
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
293 |
+
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
294 |
+
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
295 |
+
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
296 |
+
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
297 |
+
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
298 |
+
episode_lengths.insert(ep_idx, len(ep_table))
|
299 |
+
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
300 |
+
episode_chunk=ep_chunk, episode_index=ep_idx
|
301 |
+
)
|
302 |
+
pq.write_table(ep_table, output_file)
|
303 |
+
|
304 |
+
return episode_lengths
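As a concrete illustration of the chunked layout this produces, the sketch below hard-codes the chunk size and path template (both are assumptions here, since DEFAULT_CHUNK_SIZE and DEFAULT_PARQUET_PATH are imported from lerobot.common.datasets.utils):

```python
# Sketch only: the two constants below are assumptions mirroring the imported defaults.
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

for ep_idx in (0, 999, 1000, 1049):
    print(DEFAULT_PARQUET_PATH.format(episode_chunk=ep_idx // DEFAULT_CHUNK_SIZE, episode_index=ep_idx))
# data/chunk-000/episode_000000.parquet
# data/chunk-000/episode_000999.parquet
# data/chunk-001/episode_001000.parquet
# data/chunk-001/episode_001049.parquet
```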
|
305 |
+
|
306 |
+
|
307 |
+
def move_videos(
|
308 |
+
repo_id: str,
|
309 |
+
video_keys: list[str],
|
310 |
+
total_episodes: int,
|
311 |
+
total_chunks: int,
|
312 |
+
work_dir: Path,
|
313 |
+
clean_gittatributes: Path,
|
314 |
+
branch: str = "main",
|
315 |
+
) -> None:
|
316 |
+
"""
|
317 |
+
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
|
318 |
+
commands to fetch git lfs video files references to move them into subdirectories without having to
|
319 |
+
actually download them.
|
320 |
+
"""
|
321 |
+
_lfs_clone(repo_id, work_dir, branch)
|
322 |
+
|
323 |
+
videos_moved = False
|
324 |
+
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
325 |
+
if len(video_files) == 0:
|
326 |
+
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
327 |
+
videos_moved = True # Videos have already been moved
|
328 |
+
|
329 |
+
assert len(video_files) == total_episodes * len(video_keys)
|
330 |
+
|
331 |
+
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
|
332 |
+
|
333 |
+
current_gittatributes = work_dir / ".gitattributes"
|
334 |
+
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
|
335 |
+
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
|
336 |
+
|
337 |
+
if lfs_untracked_videos:
|
338 |
+
fix_lfs_video_files_tracking(work_dir, lfs_untracked_videos)
|
339 |
+
|
340 |
+
if videos_moved:
|
341 |
+
return
|
342 |
+
|
343 |
+
video_dirs = sorted(work_dir.glob("videos*/"))
|
344 |
+
for ep_chunk in range(total_chunks):
|
345 |
+
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
346 |
+
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
347 |
+
for vid_key in video_keys:
|
348 |
+
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
|
349 |
+
episode_chunk=ep_chunk, video_key=vid_key
|
350 |
+
)
|
351 |
+
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
352 |
+
|
353 |
+
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
354 |
+
target_path = DEFAULT_VIDEO_PATH.format(
|
355 |
+
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
356 |
+
)
|
357 |
+
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
358 |
+
if len(video_dirs) == 1:
|
359 |
+
video_path = video_dirs[0] / video_file
|
360 |
+
else:
|
361 |
+
for dir in video_dirs:
|
362 |
+
if (dir / video_file).is_file():
|
363 |
+
video_path = dir / video_file
|
364 |
+
break
|
365 |
+
|
366 |
+
video_path.rename(work_dir / target_path)
|
367 |
+
|
368 |
+
commit_message = "Move video files into chunk subdirectories"
|
369 |
+
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
|
370 |
+
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
371 |
+
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
372 |
+
|
373 |
+
|
374 |
+
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
375 |
+
"""
|
376 |
+
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
377 |
+
there's no other option than to download the actual files and reupload them with lfs tracking.
|
378 |
+
"""
|
379 |
+
for i in range(0, len(lfs_untracked_videos), 100):
|
380 |
+
files = lfs_untracked_videos[i : i + 100]
|
381 |
+
try:
|
382 |
+
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
383 |
+
except subprocess.CalledProcessError as e:
|
384 |
+
print("git rm --cached ERROR:")
|
385 |
+
print(e.stderr)
|
386 |
+
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
|
387 |
+
|
388 |
+
commit_message = "Track video files with git lfs"
|
389 |
+
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
390 |
+
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
391 |
+
|
392 |
+
|
393 |
+
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
394 |
+
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
395 |
+
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
396 |
+
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
397 |
+
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
398 |
+
|
399 |
+
|
400 |
+
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
401 |
+
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
|
402 |
+
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
403 |
+
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
404 |
+
subprocess.run(
|
405 |
+
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
406 |
+
check=True,
|
407 |
+
env=env,
|
408 |
+
)
|
409 |
+
|
410 |
+
|
411 |
+
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
412 |
+
lfs_tracked_files = subprocess.run(
|
413 |
+
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
414 |
+
)
|
415 |
+
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
416 |
+
return [f for f in video_files if f not in lfs_tracked_files]
|
417 |
+
|
418 |
+
|
419 |
+
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
420 |
+
# Assumes first episode
|
421 |
+
video_files = [
|
422 |
+
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
423 |
+
for vid_key in video_keys
|
424 |
+
]
|
425 |
+
hub_api = HfApi()
|
426 |
+
hub_api.snapshot_download(
|
427 |
+
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
428 |
+
)
|
429 |
+
videos_info_dict = {}
|
430 |
+
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
431 |
+
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
|
432 |
+
|
433 |
+
return videos_info_dict
|
434 |
+
|
435 |
+
|
436 |
+
def convert_dataset(
|
437 |
+
repo_id: str,
|
438 |
+
local_dir: Path,
|
439 |
+
single_task: str | None = None,
|
440 |
+
tasks_path: Path | None = None,
|
441 |
+
tasks_col: Path | None = None,
|
442 |
+
robot_config: RobotConfig | None = None,
|
443 |
+
test_branch: str | None = None,
|
444 |
+
**card_kwargs,
|
445 |
+
):
|
446 |
+
v1 = get_safe_version(repo_id, V16)
|
447 |
+
v1x_dir = local_dir / V16 / repo_id
|
448 |
+
v20_dir = local_dir / V20 / repo_id
|
449 |
+
v1x_dir.mkdir(parents=True, exist_ok=True)
|
450 |
+
v20_dir.mkdir(parents=True, exist_ok=True)
|
451 |
+
|
452 |
+
hub_api = HfApi()
|
453 |
+
hub_api.snapshot_download(
|
454 |
+
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
455 |
+
)
|
456 |
+
branch = "main"
|
457 |
+
if test_branch:
|
458 |
+
branch = test_branch
|
459 |
+
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
|
460 |
+
|
461 |
+
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
462 |
+
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
463 |
+
features = get_features_from_hf_dataset(dataset, robot_config)
|
464 |
+
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
465 |
+
|
466 |
+
if single_task and "language_instruction" in dataset.column_names:
|
467 |
+
logging.warning(
|
468 |
+
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
|
469 |
+
)
|
470 |
+
single_task = None
|
471 |
+
tasks_col = "language_instruction"
|
472 |
+
|
473 |
+
# Episodes & chunks
|
474 |
+
episode_indices = sorted(dataset.unique("episode_index"))
|
475 |
+
total_episodes = len(episode_indices)
|
476 |
+
assert episode_indices == list(range(total_episodes))
|
477 |
+
total_videos = total_episodes * len(video_keys)
|
478 |
+
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
|
479 |
+
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
|
480 |
+
total_chunks += 1
|
481 |
+
|
482 |
+
# Tasks
|
483 |
+
if single_task:
|
484 |
+
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
|
485 |
+
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
486 |
+
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
487 |
+
elif tasks_path:
|
488 |
+
tasks_by_episodes = load_json(tasks_path)
|
489 |
+
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
490 |
+
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
491 |
+
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
492 |
+
elif tasks_col:
|
493 |
+
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
494 |
+
else:
|
495 |
+
raise ValueError("One of 'single_task', 'tasks_path' or 'tasks_col' must be provided.")
|
496 |
+
|
497 |
+
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
498 |
+
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
499 |
+
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
500 |
+
features["task_index"] = {
|
501 |
+
"dtype": "int64",
|
502 |
+
"shape": (1,),
|
503 |
+
"names": None,
|
504 |
+
}
|
505 |
+
|
506 |
+
# Videos
|
507 |
+
if video_keys:
|
508 |
+
assert metadata_v1.get("video", False)
|
509 |
+
dataset = dataset.remove_columns(video_keys)
|
510 |
+
clean_gitattr = Path(
|
511 |
+
hub_api.hf_hub_download(
|
512 |
+
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
513 |
+
)
|
514 |
+
).absolute()
|
515 |
+
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
516 |
+
move_videos(
|
517 |
+
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
518 |
+
)
|
519 |
+
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
520 |
+
for key in video_keys:
|
521 |
+
features[key]["shape"] = (
|
522 |
+
videos_info[key].pop("video.height"),
|
523 |
+
videos_info[key].pop("video.width"),
|
524 |
+
videos_info[key].pop("video.channels"),
|
525 |
+
)
|
526 |
+
features[key]["video_info"] = videos_info[key]
|
527 |
+
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
528 |
+
if "encoding" in metadata_v1:
|
529 |
+
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
530 |
+
else:
|
531 |
+
assert metadata_v1.get("video", 0) == 0
|
532 |
+
videos_info = None
|
533 |
+
|
534 |
+
# Split data into 1 parquet file by episode
|
535 |
+
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
536 |
+
|
537 |
+
if robot_config is not None:
|
538 |
+
robot_type = robot_config.type
|
539 |
+
repo_tags = [robot_type]
|
540 |
+
else:
|
541 |
+
robot_type = "unknown"
|
542 |
+
repo_tags = None
|
543 |
+
|
544 |
+
# Episodes
|
545 |
+
episodes = [
|
546 |
+
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
547 |
+
for ep_idx in episode_indices
|
548 |
+
]
|
549 |
+
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
550 |
+
|
551 |
+
# Assemble metadata v2.0
|
552 |
+
metadata_v2_0 = {
|
553 |
+
"codebase_version": V20,
|
554 |
+
"robot_type": robot_type,
|
555 |
+
"total_episodes": total_episodes,
|
556 |
+
"total_frames": len(dataset),
|
557 |
+
"total_tasks": len(tasks),
|
558 |
+
"total_videos": total_videos,
|
559 |
+
"total_chunks": total_chunks,
|
560 |
+
"chunks_size": DEFAULT_CHUNK_SIZE,
|
561 |
+
"fps": metadata_v1["fps"],
|
562 |
+
"splits": {"train": f"0:{total_episodes}"},
|
563 |
+
"data_path": DEFAULT_PARQUET_PATH,
|
564 |
+
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
565 |
+
"features": features,
|
566 |
+
}
|
567 |
+
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
568 |
+
convert_stats_to_json(v1x_dir, v20_dir)
|
569 |
+
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
570 |
+
|
571 |
+
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
572 |
+
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
573 |
+
|
574 |
+
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
575 |
+
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
576 |
+
|
577 |
+
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
578 |
+
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
579 |
+
|
580 |
+
hub_api.upload_folder(
|
581 |
+
repo_id=repo_id,
|
582 |
+
path_in_repo="data",
|
583 |
+
folder_path=v20_dir / "data",
|
584 |
+
repo_type="dataset",
|
585 |
+
revision=branch,
|
586 |
+
)
|
587 |
+
hub_api.upload_folder(
|
588 |
+
repo_id=repo_id,
|
589 |
+
path_in_repo="meta",
|
590 |
+
folder_path=v20_dir / "meta",
|
591 |
+
repo_type="dataset",
|
592 |
+
revision=branch,
|
593 |
+
)
|
594 |
+
|
595 |
+
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
|
596 |
+
|
597 |
+
if not test_branch:
|
598 |
+
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
|
599 |
+
|
600 |
+
|
601 |
+
def main():
|
602 |
+
parser = argparse.ArgumentParser()
|
603 |
+
task_args = parser.add_mutually_exclusive_group(required=True)
|
604 |
+
|
605 |
+
parser.add_argument(
|
606 |
+
"--repo-id",
|
607 |
+
type=str,
|
608 |
+
required=True,
|
609 |
+
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
610 |
+
)
|
611 |
+
task_args.add_argument(
|
612 |
+
"--single-task",
|
613 |
+
type=str,
|
614 |
+
help="A short but accurate description of the single task performed in the dataset.",
|
615 |
+
)
|
616 |
+
task_args.add_argument(
|
617 |
+
"--tasks-col",
|
618 |
+
type=str,
|
619 |
+
help="The name of the column containing language instructions",
|
620 |
+
)
|
621 |
+
task_args.add_argument(
|
622 |
+
"--tasks-path",
|
623 |
+
type=Path,
|
624 |
+
help="The path to a .json file containing one language instruction for each episode_index",
|
625 |
+
)
|
626 |
+
parser.add_argument(
|
627 |
+
"--robot",
|
628 |
+
type=str,
|
629 |
+
default=None,
|
630 |
+
help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
|
631 |
+
)
|
632 |
+
parser.add_argument(
|
633 |
+
"--local-dir",
|
634 |
+
type=Path,
|
635 |
+
default=None,
|
636 |
+
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
|
637 |
+
)
|
638 |
+
parser.add_argument(
|
639 |
+
"--license",
|
640 |
+
type=str,
|
641 |
+
default="apache-2.0",
|
642 |
+
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
|
643 |
+
)
|
644 |
+
parser.add_argument(
|
645 |
+
"--test-branch",
|
646 |
+
type=str,
|
647 |
+
default=None,
|
648 |
+
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
|
649 |
+
)
|
650 |
+
|
651 |
+
args = parser.parse_args()
|
652 |
+
if not args.local_dir:
|
653 |
+
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
654 |
+
|
655 |
+
robot_config = make_robot_config(args.robot) if args.robot is not None else None
|
657 |
+
|
658 |
+
del args.robot
|
659 |
+
|
660 |
+
convert_dataset(**vars(args), robot_config=robot_config)
|
661 |
+
|
662 |
+
|
663 |
+
if __name__ == "__main__":
|
664 |
+
main()
|
lerobot/common/datasets/v21/_remove_language_instruction.py
ADDED
@@ -0,0 +1,87 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import logging
|
16 |
+
import traceback
|
17 |
+
from pathlib import Path
|
18 |
+
|
19 |
+
from datasets import get_dataset_config_info
|
20 |
+
from huggingface_hub import HfApi
|
21 |
+
|
22 |
+
from lerobot import available_datasets
|
23 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
24 |
+
from lerobot.common.datasets.utils import INFO_PATH, write_info
|
25 |
+
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
|
26 |
+
|
27 |
+
LOCAL_DIR = Path("data/")
|
28 |
+
|
29 |
+
hub_api = HfApi()
|
30 |
+
|
31 |
+
|
32 |
+
def fix_dataset(repo_id: str) -> str:
|
33 |
+
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
|
34 |
+
return f"{repo_id}: skipped (not in {V20})."
|
35 |
+
|
36 |
+
dataset_info = get_dataset_config_info(repo_id, "default")
|
37 |
+
with SuppressWarnings():
|
38 |
+
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
|
39 |
+
|
40 |
+
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
|
41 |
+
parquet_features = set(dataset_info.features)
|
42 |
+
|
43 |
+
diff_parquet_meta = parquet_features - meta_features
|
44 |
+
diff_meta_parquet = meta_features - parquet_features
|
45 |
+
|
46 |
+
if diff_parquet_meta:
|
47 |
+
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
|
48 |
+
|
49 |
+
if not diff_meta_parquet:
|
50 |
+
return f"{repo_id}: skipped (no diff)"
|
51 |
+
|
52 |
+
if diff_meta_parquet:
|
53 |
+
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
54 |
+
assert diff_meta_parquet == {"language_instruction"}
|
55 |
+
lerobot_metadata.features.pop("language_instruction")
|
56 |
+
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
57 |
+
commit_info = hub_api.upload_file(
|
58 |
+
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
|
59 |
+
path_in_repo=INFO_PATH,
|
60 |
+
repo_id=repo_id,
|
61 |
+
repo_type="dataset",
|
62 |
+
revision=V20,
|
63 |
+
commit_message="Remove 'language_instruction'",
|
64 |
+
create_pr=True,
|
65 |
+
)
|
66 |
+
return f"{repo_id}: success - PR: {commit_info.pr_url}"
|
67 |
+
|
68 |
+
|
69 |
+
def batch_fix():
|
70 |
+
status = {}
|
71 |
+
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
72 |
+
logfile = LOCAL_DIR / "fix_features_v20.txt"
|
73 |
+
for num, repo_id in enumerate(available_datasets):
|
74 |
+
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
75 |
+
print("---------------------------------------------------------")
|
76 |
+
try:
|
77 |
+
status = fix_dataset(repo_id)
|
78 |
+
except Exception:
|
79 |
+
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
80 |
+
|
81 |
+
logging.info(status)
|
82 |
+
with open(logfile, "a") as file:
|
83 |
+
file.write(status + "\n")
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
batch_fix()
|
lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
"""
|
18 |
+
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import traceback
|
22 |
+
from pathlib import Path
|
23 |
+
|
24 |
+
from huggingface_hub import HfApi
|
25 |
+
|
26 |
+
from lerobot import available_datasets
|
27 |
+
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
|
28 |
+
|
29 |
+
LOCAL_DIR = Path("data/")
|
30 |
+
|
31 |
+
|
32 |
+
def batch_convert():
|
33 |
+
status = {}
|
34 |
+
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
35 |
+
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
36 |
+
hub_api = HfApi()
|
37 |
+
for num, repo_id in enumerate(available_datasets):
|
38 |
+
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
39 |
+
print("---------------------------------------------------------")
|
40 |
+
try:
|
41 |
+
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
|
42 |
+
status = f"{repo_id}: success (already in {V21})."
|
43 |
+
else:
|
44 |
+
convert_dataset(repo_id)
|
45 |
+
status = f"{repo_id}: success."
|
46 |
+
except Exception:
|
47 |
+
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
48 |
+
|
49 |
+
with open(logfile, "a") as file:
|
50 |
+
file.write(status + "\n")
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
batch_convert()
|
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
ADDED
@@ -0,0 +1,114 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
17 |
+
2.1. It will:
|
18 |
+
|
19 |
+
- Generate per-episode stats and write them to `episodes_stats.jsonl`
|
20 |
+
- Check consistency between these new stats and the old ones.
|
21 |
+
- Remove the deprecated `stats.json`.
|
22 |
+
- Update codebase_version in `info.json`.
|
23 |
+
- Push this new version to the hub on the 'main' branch and tag it with "v2.1".
|
24 |
+
|
25 |
+
Usage:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
29 |
+
--repo-id=aliberts/koch_tutorial
|
30 |
+
```
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
import argparse
|
35 |
+
import logging
|
36 |
+
|
37 |
+
from huggingface_hub import HfApi
|
38 |
+
|
39 |
+
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
40 |
+
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
41 |
+
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
42 |
+
|
43 |
+
V20 = "v2.0"
|
44 |
+
V21 = "v2.1"
|
45 |
+
|
46 |
+
|
47 |
+
class SuppressWarnings:
|
48 |
+
def __enter__(self):
|
49 |
+
self.previous_level = logging.getLogger().getEffectiveLevel()
|
50 |
+
logging.getLogger().setLevel(logging.ERROR)
|
51 |
+
|
52 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
53 |
+
logging.getLogger().setLevel(self.previous_level)
|
54 |
+
|
55 |
+
|
56 |
+
def convert_dataset(
|
57 |
+
repo_id: str,
|
58 |
+
branch: str | None = None,
|
59 |
+
num_workers: int = 4,
|
60 |
+
):
|
61 |
+
with SuppressWarnings():
|
62 |
+
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
63 |
+
|
64 |
+
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
65 |
+
(dataset.root / EPISODES_STATS_PATH).unlink()
|
66 |
+
|
67 |
+
convert_stats(dataset, num_workers=num_workers)
|
68 |
+
ref_stats = load_stats(dataset.root)
|
69 |
+
check_aggregate_stats(dataset, ref_stats)
|
70 |
+
|
71 |
+
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
72 |
+
write_info(dataset.meta.info, dataset.root)
|
73 |
+
|
74 |
+
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
75 |
+
|
76 |
+
# delete old stats.json file
|
77 |
+
if (dataset.root / STATS_PATH).is_file():
|
78 |
+
(dataset.root / STATS_PATH).unlink()
|
79 |
+
|
80 |
+
hub_api = HfApi()
|
81 |
+
if hub_api.file_exists(
|
82 |
+
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
83 |
+
):
|
84 |
+
hub_api.delete_file(
|
85 |
+
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
86 |
+
)
|
87 |
+
|
88 |
+
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
parser = argparse.ArgumentParser()
|
93 |
+
parser.add_argument(
|
94 |
+
"--repo-id",
|
95 |
+
type=str,
|
96 |
+
required=True,
|
97 |
+
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
98 |
+
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--branch",
|
102 |
+
type=str,
|
103 |
+
default=None,
|
104 |
+
help="Repo branch to push your dataset. Defaults to the main branch.",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--num-workers",
|
108 |
+
type=int,
|
109 |
+
default=4,
|
110 |
+
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
111 |
+
)
|
112 |
+
|
113 |
+
args = parser.parse_args()
|
114 |
+
convert_dataset(**vars(args))
|
lerobot/common/datasets/v21/convert_stats.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
21 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
22 |
+
from lerobot.common.datasets.utils import write_episode_stats
|
23 |
+
|
24 |
+
|
25 |
+
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
26 |
+
ep_len = dataset.meta.episodes[episode_index]["length"]
|
27 |
+
sampled_indices = sample_indices(ep_len)
|
28 |
+
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
29 |
+
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
30 |
+
return video_frames[ft_key].numpy()
|
31 |
+
|
32 |
+
|
33 |
+
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
34 |
+
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
35 |
+
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
36 |
+
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
37 |
+
|
38 |
+
ep_stats = {}
|
39 |
+
for key, ft in dataset.features.items():
|
40 |
+
if ft["dtype"] == "video":
|
41 |
+
# We sample only for videos
|
42 |
+
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
43 |
+
else:
|
44 |
+
ep_ft_data = np.array(ep_data[key])
|
45 |
+
|
46 |
+
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
47 |
+
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
48 |
+
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
49 |
+
|
50 |
+
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
51 |
+
ep_stats[key] = {
|
52 |
+
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
53 |
+
}
|
54 |
+
|
55 |
+
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
56 |
+
|
57 |
+
|
58 |
+
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
59 |
+
assert dataset.episodes is None
|
60 |
+
print("Computing episodes stats")
|
61 |
+
total_episodes = dataset.meta.total_episodes
|
62 |
+
if num_workers > 0:
|
63 |
+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
64 |
+
futures = {
|
65 |
+
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
|
66 |
+
for ep_idx in range(total_episodes)
|
67 |
+
}
|
68 |
+
for future in tqdm(as_completed(futures), total=total_episodes):
|
69 |
+
future.result()
|
70 |
+
else:
|
71 |
+
for ep_idx in tqdm(range(total_episodes)):
|
72 |
+
convert_episode_stats(dataset, ep_idx)
|
73 |
+
|
74 |
+
for ep_idx in tqdm(range(total_episodes)):
|
75 |
+
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
76 |
+
|
77 |
+
|
78 |
+
def check_aggregate_stats(
|
79 |
+
dataset: LeRobotDataset,
|
80 |
+
reference_stats: dict[str, dict[str, np.ndarray]],
|
81 |
+
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
82 |
+
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
83 |
+
):
|
84 |
+
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
85 |
+
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
86 |
+
for key, ft in dataset.features.items():
|
87 |
+
# These values might need some fine-tuning
|
88 |
+
if ft["dtype"] == "video":
|
89 |
+
# to account for image sub-sampling
|
90 |
+
rtol, atol = video_rtol_atol
|
91 |
+
else:
|
92 |
+
rtol, atol = default_rtol_atol
|
93 |
+
|
94 |
+
for stat, val in agg_stats[key].items():
|
95 |
+
if key in reference_stats and stat in reference_stats[key]:
|
96 |
+
err_msg = f"feature='{key}' stats='{stat}'"
|
97 |
+
np.testing.assert_allclose(
|
98 |
+
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
99 |
+
)
|
lerobot/common/datasets/video_utils.py
ADDED
@@ -0,0 +1,432 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import importlib
|
17 |
+
import json
|
18 |
+
import logging
|
19 |
+
import subprocess
|
20 |
+
import warnings
|
21 |
+
from collections import OrderedDict
|
22 |
+
from dataclasses import dataclass, field
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, ClassVar
|
25 |
+
|
26 |
+
import pyarrow as pa
|
27 |
+
import torch
|
28 |
+
import torchvision
|
29 |
+
from datasets.features.features import register_feature
|
30 |
+
from PIL import Image
|
31 |
+
|
32 |
+
|
33 |
+
def get_safe_default_codec():
|
34 |
+
if importlib.util.find_spec("torchcodec"):
|
35 |
+
return "torchcodec"
|
36 |
+
else:
|
37 |
+
logging.warning(
|
38 |
+
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
39 |
+
)
|
40 |
+
return "pyav"
|
41 |
+
|
42 |
+
|
43 |
+
def decode_video_frames(
|
44 |
+
video_path: Path | str,
|
45 |
+
timestamps: list[float],
|
46 |
+
tolerance_s: float,
|
47 |
+
backend: str | None = None,
|
48 |
+
) -> torch.Tensor:
|
49 |
+
"""
|
50 |
+
Decodes video frames using the specified backend.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
video_path (Path): Path to the video file.
|
54 |
+
timestamps (list[float]): List of timestamps to extract frames.
|
55 |
+
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
56 |
+
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available on the platform; otherwise, falls back to "pyav".
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: Decoded frames.
|
60 |
+
|
61 |
+
Currently supports torchcodec on cpu and pyav.
|
62 |
+
"""
|
63 |
+
if backend is None:
|
64 |
+
backend = get_safe_default_codec()
|
65 |
+
if backend == "torchcodec":
|
66 |
+
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
67 |
+
elif backend in ["pyav", "video_reader"]:
|
68 |
+
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
69 |
+
else:
|
70 |
+
raise ValueError(f"Unsupported video backend: {backend}")
|
71 |
+
|
72 |
+
|
73 |
+
def decode_video_frames_torchvision(
|
74 |
+
video_path: Path | str,
|
75 |
+
timestamps: list[float],
|
76 |
+
tolerance_s: float,
|
77 |
+
backend: str = "pyav",
|
78 |
+
log_loaded_timestamps: bool = False,
|
79 |
+
) -> torch.Tensor:
|
80 |
+
"""Loads frames associated to the requested timestamps of a video
|
81 |
+
|
82 |
+
The backend can be either "pyav" (default) or "video_reader".
|
83 |
+
"video_reader" requires installing torchvision from source, see:
|
84 |
+
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
85 |
+
(note that you need to compile against ffmpeg<4.3)
|
86 |
+
|
87 |
+
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
88 |
+
For more info on video decoding, see `benchmark/video/README.md`
|
89 |
+
|
90 |
+
See torchvision doc for more info on these two backends:
|
91 |
+
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
92 |
+
|
93 |
+
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
94 |
+
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
95 |
+
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
96 |
+
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
97 |
+
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
98 |
+
"""
|
99 |
+
video_path = str(video_path)
|
100 |
+
|
101 |
+
# set backend
|
102 |
+
keyframes_only = False
|
103 |
+
torchvision.set_video_backend(backend)
|
104 |
+
if backend == "pyav":
|
105 |
+
keyframes_only = True  # pyav doesn't support accurate seek
|
106 |
+
|
107 |
+
# set a video stream reader
|
108 |
+
# TODO(rcadene): also load audio stream at the same time
|
109 |
+
reader = torchvision.io.VideoReader(video_path, "video")
|
110 |
+
|
111 |
+
# set the first and last requested timestamps
|
112 |
+
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
113 |
+
first_ts = min(timestamps)
|
114 |
+
last_ts = max(timestamps)
|
115 |
+
|
116 |
+
# access closest key frame of the first requested frame
|
117 |
+
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
118 |
+
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
119 |
+
reader.seek(first_ts, keyframes_only=keyframes_only)
|
120 |
+
|
121 |
+
# load all frames until last requested frame
|
122 |
+
loaded_frames = []
|
123 |
+
loaded_ts = []
|
124 |
+
for frame in reader:
|
125 |
+
current_ts = frame["pts"]
|
126 |
+
if log_loaded_timestamps:
|
127 |
+
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
128 |
+
loaded_frames.append(frame["data"])
|
129 |
+
loaded_ts.append(current_ts)
|
130 |
+
if current_ts >= last_ts:
|
131 |
+
break
|
132 |
+
|
133 |
+
if backend == "pyav":
|
134 |
+
reader.container.close()
|
135 |
+
|
136 |
+
reader = None
|
137 |
+
|
138 |
+
query_ts = torch.tensor(timestamps)
|
139 |
+
loaded_ts = torch.tensor(loaded_ts)
|
140 |
+
|
141 |
+
# compute distances between each query timestamp and timestamps of all loaded frames
|
142 |
+
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
143 |
+
min_, argmin_ = dist.min(1)
|
144 |
+
|
145 |
+
is_within_tol = min_ < tolerance_s
|
146 |
+
assert is_within_tol.all(), (
|
147 |
+
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
148 |
+
"It means that the closest frame that can be loaded from the video is too far away in time."
|
149 |
+
"This might be due to synchronization issues with timestamps during data collection."
|
150 |
+
"To be safe, we advise to ignore this item during training."
|
151 |
+
f"\nqueried timestamps: {query_ts}"
|
152 |
+
f"\nloaded timestamps: {loaded_ts}"
|
153 |
+
f"\nvideo: {video_path}"
|
154 |
+
f"\nbackend: {backend}"
|
155 |
+
)
|
156 |
+
|
157 |
+
# get closest frames to the query timestamps
|
158 |
+
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
159 |
+
closest_ts = loaded_ts[argmin_]
|
160 |
+
|
161 |
+
if log_loaded_timestamps:
|
162 |
+
logging.info(f"{closest_ts=}")
|
163 |
+
|
164 |
+
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
165 |
+
closest_frames = closest_frames.type(torch.float32) / 255
|
166 |
+
|
167 |
+
assert len(timestamps) == len(closest_frames)
|
168 |
+
return closest_frames
|
169 |
+
|
170 |
+
|
171 |
+
def decode_video_frames_torchcodec(
|
172 |
+
video_path: Path | str,
|
173 |
+
timestamps: list[float],
|
174 |
+
tolerance_s: float,
|
175 |
+
device: str = "cpu",
|
176 |
+
log_loaded_timestamps: bool = False,
|
177 |
+
) -> torch.Tensor:
|
178 |
+
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
179 |
+
|
180 |
+
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
181 |
+
|
182 |
+
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
183 |
+
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
184 |
+
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
185 |
+
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
186 |
+
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
187 |
+
"""
|
188 |
+
|
189 |
+
if importlib.util.find_spec("torchcodec"):
|
190 |
+
from torchcodec.decoders import VideoDecoder
|
191 |
+
else:
|
192 |
+
raise ImportError("torchcodec is required but not available.")
|
193 |
+
|
194 |
+
# initialize video decoder
|
195 |
+
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
196 |
+
loaded_frames = []
|
197 |
+
loaded_ts = []
|
198 |
+
# get metadata for frame information
|
199 |
+
metadata = decoder.metadata
|
200 |
+
average_fps = metadata.average_fps
|
201 |
+
|
202 |
+
# convert timestamps to frame indices
|
203 |
+
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
204 |
+
|
205 |
+
# retrieve frames based on indices
|
206 |
+
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
207 |
+
|
208 |
+
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
209 |
+
loaded_frames.append(frame)
|
210 |
+
loaded_ts.append(pts.item())
|
211 |
+
if log_loaded_timestamps:
|
212 |
+
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
213 |
+
|
214 |
+
query_ts = torch.tensor(timestamps)
|
215 |
+
loaded_ts = torch.tensor(loaded_ts)
|
216 |
+
|
217 |
+
# compute distances between each query timestamp and loaded timestamps
|
218 |
+
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
219 |
+
min_, argmin_ = dist.min(1)
|
220 |
+
|
221 |
+
is_within_tol = min_ < tolerance_s
|
222 |
+
assert is_within_tol.all(), (
|
223 |
+
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
224 |
+
"It means that the closest frame that can be loaded from the video is too far away in time."
|
225 |
+
"This might be due to synchronization issues with timestamps during data collection."
|
226 |
+
"To be safe, we advise to ignore this item during training."
|
227 |
+
f"\nqueried timestamps: {query_ts}"
|
228 |
+
f"\nloaded timestamps: {loaded_ts}"
|
229 |
+
f"\nvideo: {video_path}"
|
230 |
+
)
|
231 |
+
|
232 |
+
# get closest frames to the query timestamps
|
233 |
+
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
234 |
+
closest_ts = loaded_ts[argmin_]
|
235 |
+
|
236 |
+
if log_loaded_timestamps:
|
237 |
+
logging.info(f"{closest_ts=}")
|
238 |
+
|
239 |
+
# convert to float32 in [0,1] range (channel first)
|
240 |
+
closest_frames = closest_frames.type(torch.float32) / 255
|
241 |
+
|
242 |
+
assert len(timestamps) == len(closest_frames)
|
243 |
+
return closest_frames
|
244 |
+
|
245 |
+
|
246 |
+
def encode_video_frames(
|
247 |
+
imgs_dir: Path | str,
|
248 |
+
video_path: Path | str,
|
249 |
+
fps: int,
|
250 |
+
vcodec: str = "libsvtav1",
|
251 |
+
pix_fmt: str = "yuv420p",
|
252 |
+
g: int | None = 2,
|
253 |
+
crf: int | None = 30,
|
254 |
+
fast_decode: int = 0,
|
255 |
+
log_level: str | None = "error",
|
256 |
+
overwrite: bool = False,
|
257 |
+
) -> None:
|
258 |
+
"""More info on tuning ffmpeg arguments in `benchmark/video/README.md`"""
|
259 |
+
video_path = Path(video_path)
|
260 |
+
imgs_dir = Path(imgs_dir)
|
261 |
+
video_path.parent.mkdir(parents=True, exist_ok=True)
|
262 |
+
|
263 |
+
ffmpeg_args = OrderedDict(
|
264 |
+
[
|
265 |
+
("-f", "image2"),
|
266 |
+
("-r", str(fps)),
|
267 |
+
("-i", str(imgs_dir / "frame_%06d.png")),
|
268 |
+
("-vcodec", vcodec),
|
269 |
+
("-pix_fmt", pix_fmt),
|
270 |
+
]
|
271 |
+
)
|
272 |
+
|
273 |
+
if g is not None:
|
274 |
+
ffmpeg_args["-g"] = str(g)
|
275 |
+
|
276 |
+
if crf is not None:
|
277 |
+
ffmpeg_args["-crf"] = str(crf)
|
278 |
+
|
279 |
+
if fast_decode:
|
280 |
+
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
281 |
+
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
282 |
+
ffmpeg_args[key] = value
|
283 |
+
|
284 |
+
if log_level is not None:
|
285 |
+
ffmpeg_args["-loglevel"] = str(log_level)
|
286 |
+
|
287 |
+
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
288 |
+
if overwrite:
|
289 |
+
ffmpeg_args.append("-y")
|
290 |
+
|
291 |
+
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
292 |
+
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
293 |
+
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
294 |
+
|
295 |
+
if not video_path.exists():
|
296 |
+
raise OSError(
|
297 |
+
f"Video encoding did not work. File not found: {video_path}. "
|
298 |
+
f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
|
299 |
+
)
|
300 |
+
|
301 |
+
|
302 |
+
@dataclass
|
303 |
+
class VideoFrame:
|
304 |
+
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
|
305 |
+
"""
|
306 |
+
Provides a type for a dataset containing video frames.
|
307 |
+
|
308 |
+
Example:
|
309 |
+
|
310 |
+
```python
|
311 |
+
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
|
312 |
+
features = {"image": VideoFrame()}
|
313 |
+
Dataset.from_dict(data_dict, features=Features(features))
|
314 |
+
```
|
315 |
+
"""
|
316 |
+
|
317 |
+
pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
|
318 |
+
_type: str = field(default="VideoFrame", init=False, repr=False)
|
319 |
+
|
320 |
+
def __call__(self):
|
321 |
+
return self.pa_type
|
322 |
+
|
323 |
+
|
324 |
+
with warnings.catch_warnings():
|
325 |
+
warnings.filterwarnings(
|
326 |
+
"ignore",
|
327 |
+
"'register_feature' is experimental and might be subject to breaking changes in the future.",
|
328 |
+
category=UserWarning,
|
329 |
+
)
|
330 |
+
# to make VideoFrame available in HuggingFace `datasets`
|
331 |
+
register_feature(VideoFrame, "VideoFrame")
|
332 |
+
|
333 |
+
|
334 |
+
def get_audio_info(video_path: Path | str) -> dict:
|
335 |
+
ffprobe_audio_cmd = [
|
336 |
+
"ffprobe",
|
337 |
+
"-v",
|
338 |
+
"error",
|
339 |
+
"-select_streams",
|
340 |
+
"a:0",
|
341 |
+
"-show_entries",
|
342 |
+
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
343 |
+
"-of",
|
344 |
+
"json",
|
345 |
+
str(video_path),
|
346 |
+
]
|
347 |
+
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
348 |
+
if result.returncode != 0:
|
349 |
+
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
350 |
+
|
351 |
+
info = json.loads(result.stdout)
|
352 |
+
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
353 |
+
if audio_stream_info is None:
|
354 |
+
return {"has_audio": False}
|
355 |
+
|
356 |
+
# Return the information, defaulting to None if no audio stream is present
|
357 |
+
return {
|
358 |
+
"has_audio": True,
|
359 |
+
"audio.channels": audio_stream_info.get("channels", None),
|
360 |
+
"audio.codec": audio_stream_info.get("codec_name", None),
|
361 |
+
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
362 |
+
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
363 |
+
if audio_stream_info.get("sample_rate")
|
364 |
+
else None,
|
365 |
+
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
366 |
+
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
367 |
+
}
|
368 |
+
|
369 |
+
|
370 |
+
def get_video_info(video_path: Path | str) -> dict:
|
371 |
+
ffprobe_video_cmd = [
|
372 |
+
"ffprobe",
|
373 |
+
"-v",
|
374 |
+
"error",
|
375 |
+
"-select_streams",
|
376 |
+
"v:0",
|
377 |
+
"-show_entries",
|
378 |
+
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
379 |
+
"-of",
|
380 |
+
"json",
|
381 |
+
str(video_path),
|
382 |
+
]
|
383 |
+
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
384 |
+
if result.returncode != 0:
|
385 |
+
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
386 |
+
|
387 |
+
info = json.loads(result.stdout)
|
388 |
+
video_stream_info = info["streams"][0]
|
389 |
+
|
390 |
+
# Calculate fps from r_frame_rate
|
391 |
+
r_frame_rate = video_stream_info["r_frame_rate"]
|
392 |
+
num, denom = map(int, r_frame_rate.split("/"))
|
393 |
+
fps = num / denom
|
394 |
+
|
395 |
+
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
396 |
+
|
397 |
+
video_info = {
|
398 |
+
"video.fps": fps,
|
399 |
+
"video.height": video_stream_info["height"],
|
400 |
+
"video.width": video_stream_info["width"],
|
401 |
+
"video.channels": pixel_channels,
|
402 |
+
"video.codec": video_stream_info["codec_name"],
|
403 |
+
"video.pix_fmt": video_stream_info["pix_fmt"],
|
404 |
+
"video.is_depth_map": False,
|
405 |
+
**get_audio_info(video_path),
|
406 |
+
}
|
407 |
+
|
408 |
+
return video_info
|
409 |
+
|
410 |
+
|
411 |
+
def get_video_pixel_channels(pix_fmt: str) -> int:
|
412 |
+
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
413 |
+
return 1
|
414 |
+
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
415 |
+
return 4
|
416 |
+
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
417 |
+
return 3
|
418 |
+
else:
|
419 |
+
raise ValueError(f"Unknown pixel format: {pix_fmt}")
|
420 |
+
|
421 |
+
|
422 |
+
def get_image_pixel_channels(image: Image):
|
423 |
+
if image.mode == "L":
|
424 |
+
return 1 # Grayscale
|
425 |
+
elif image.mode == "LA":
|
426 |
+
return 2 # Grayscale + Alpha
|
427 |
+
elif image.mode == "RGB":
|
428 |
+
return 3 # RGB
|
429 |
+
elif image.mode == "RGBA":
|
430 |
+
return 4 # RGBA
|
431 |
+
else:
|
432 |
+
raise ValueError(f"Unknown image mode: {image.mode}")
|
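A minimal usage sketch of the video helpers added above (not part of the commit). It assumes `lerobot` is installed together with `torchcodec`, that ffmpeg with `libsvtav1` is available, and that `my_frames/` holds images named `frame_000000.png`, `frame_000001.png`, ...; the paths are hypothetical.

```python
from pathlib import Path

from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchcodec,
    encode_video_frames,
)

# Encode a directory of PNG frames into an AV1 video (hypothetical paths).
imgs_dir = Path("my_frames")        # must contain frame_%06d.png images
video_path = Path("episode_0.mp4")
encode_video_frames(imgs_dir, video_path, fps=30, overwrite=True)

# Decode the frames closest to a few query timestamps (in seconds).
# tolerance_s bounds how far the returned frame may be from each query.
frames = decode_video_frames_torchcodec(
    video_path, timestamps=[0.0, 0.5, 1.0], tolerance_s=0.04
)
print(frames.shape, frames.dtype)  # e.g. torch.Size([3, 3, H, W]) torch.float32, values in [0, 1]
```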
lerobot/common/envs/__init__.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
lerobot/common/envs/configs.py
ADDED
@@ -0,0 +1,156 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import abc
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
|
18 |
+
import draccus
|
19 |
+
|
20 |
+
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
21 |
+
from lerobot.configs.types import FeatureType, PolicyFeature
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
26 |
+
task: str | None = None
|
27 |
+
fps: int = 30
|
28 |
+
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
29 |
+
features_map: dict[str, str] = field(default_factory=dict)
|
30 |
+
|
31 |
+
@property
|
32 |
+
def type(self) -> str:
|
33 |
+
return self.get_choice_name(self.__class__)
|
34 |
+
|
35 |
+
@property
@abc.abstractmethod
|
36 |
+
def gym_kwargs(self) -> dict:
|
37 |
+
raise NotImplementedError()
|
38 |
+
|
39 |
+
|
40 |
+
@EnvConfig.register_subclass("aloha")
|
41 |
+
@dataclass
|
42 |
+
class AlohaEnv(EnvConfig):
|
43 |
+
task: str = "AlohaInsertion-v0"
|
44 |
+
fps: int = 50
|
45 |
+
episode_length: int = 400
|
46 |
+
obs_type: str = "pixels_agent_pos"
|
47 |
+
render_mode: str = "rgb_array"
|
48 |
+
features: dict[str, PolicyFeature] = field(
|
49 |
+
default_factory=lambda: {
|
50 |
+
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
51 |
+
}
|
52 |
+
)
|
53 |
+
features_map: dict[str, str] = field(
|
54 |
+
default_factory=lambda: {
|
55 |
+
"action": ACTION,
|
56 |
+
"agent_pos": OBS_ROBOT,
|
57 |
+
"top": f"{OBS_IMAGE}.top",
|
58 |
+
"pixels/top": f"{OBS_IMAGES}.top",
|
59 |
+
}
|
60 |
+
)
|
61 |
+
|
62 |
+
def __post_init__(self):
|
63 |
+
if self.obs_type == "pixels":
|
64 |
+
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
65 |
+
elif self.obs_type == "pixels_agent_pos":
|
66 |
+
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
67 |
+
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
68 |
+
|
69 |
+
@property
|
70 |
+
def gym_kwargs(self) -> dict:
|
71 |
+
return {
|
72 |
+
"obs_type": self.obs_type,
|
73 |
+
"render_mode": self.render_mode,
|
74 |
+
"max_episode_steps": self.episode_length,
|
75 |
+
}
|
76 |
+
|
77 |
+
|
78 |
+
@EnvConfig.register_subclass("pusht")
|
79 |
+
@dataclass
|
80 |
+
class PushtEnv(EnvConfig):
|
81 |
+
task: str = "PushT-v0"
|
82 |
+
fps: int = 10
|
83 |
+
episode_length: int = 300
|
84 |
+
obs_type: str = "pixels_agent_pos"
|
85 |
+
render_mode: str = "rgb_array"
|
86 |
+
visualization_width: int = 384
|
87 |
+
visualization_height: int = 384
|
88 |
+
features: dict[str, PolicyFeature] = field(
|
89 |
+
default_factory=lambda: {
|
90 |
+
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
91 |
+
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
92 |
+
}
|
93 |
+
)
|
94 |
+
features_map: dict[str, str] = field(
|
95 |
+
default_factory=lambda: {
|
96 |
+
"action": ACTION,
|
97 |
+
"agent_pos": OBS_ROBOT,
|
98 |
+
"environment_state": OBS_ENV,
|
99 |
+
"pixels": OBS_IMAGE,
|
100 |
+
}
|
101 |
+
)
|
102 |
+
|
103 |
+
def __post_init__(self):
|
104 |
+
if self.obs_type == "pixels_agent_pos":
|
105 |
+
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
106 |
+
elif self.obs_type == "environment_state_agent_pos":
|
107 |
+
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
108 |
+
|
109 |
+
@property
|
110 |
+
def gym_kwargs(self) -> dict:
|
111 |
+
return {
|
112 |
+
"obs_type": self.obs_type,
|
113 |
+
"render_mode": self.render_mode,
|
114 |
+
"visualization_width": self.visualization_width,
|
115 |
+
"visualization_height": self.visualization_height,
|
116 |
+
"max_episode_steps": self.episode_length,
|
117 |
+
}
|
118 |
+
|
119 |
+
|
120 |
+
@EnvConfig.register_subclass("xarm")
|
121 |
+
@dataclass
|
122 |
+
class XarmEnv(EnvConfig):
|
123 |
+
task: str = "XarmLift-v0"
|
124 |
+
fps: int = 15
|
125 |
+
episode_length: int = 200
|
126 |
+
obs_type: str = "pixels_agent_pos"
|
127 |
+
render_mode: str = "rgb_array"
|
128 |
+
visualization_width: int = 384
|
129 |
+
visualization_height: int = 384
|
130 |
+
features: dict[str, PolicyFeature] = field(
|
131 |
+
default_factory=lambda: {
|
132 |
+
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
133 |
+
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
134 |
+
}
|
135 |
+
)
|
136 |
+
features_map: dict[str, str] = field(
|
137 |
+
default_factory=lambda: {
|
138 |
+
"action": ACTION,
|
139 |
+
"agent_pos": OBS_ROBOT,
|
140 |
+
"pixels": OBS_IMAGE,
|
141 |
+
}
|
142 |
+
)
|
143 |
+
|
144 |
+
def __post_init__(self):
|
145 |
+
if self.obs_type == "pixels_agent_pos":
|
146 |
+
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
147 |
+
|
148 |
+
@property
|
149 |
+
def gym_kwargs(self) -> dict:
|
150 |
+
return {
|
151 |
+
"obs_type": self.obs_type,
|
152 |
+
"render_mode": self.render_mode,
|
153 |
+
"visualization_width": self.visualization_width,
|
154 |
+
"visualization_height": self.visualization_height,
|
155 |
+
"max_episode_steps": self.episode_length,
|
156 |
+
}
|
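A short sketch (not part of the commit) of instantiating these config dataclasses and inspecting what they expose, assuming `lerobot` is importable:

```python
from lerobot.common.envs.configs import AlohaEnv, PushtEnv

aloha = AlohaEnv()           # defaults: AlohaInsertion-v0, pixels_agent_pos observations
print(aloha.type)            # "aloha" (name registered on the EnvConfig choice registry)
print(aloha.gym_kwargs)      # kwargs forwarded to gym.make(...)
print(list(aloha.features))  # ["action", "agent_pos", "pixels/top"] after __post_init__

pusht = PushtEnv(obs_type="environment_state_agent_pos")
print(pusht.features["environment_state"].shape)  # (16,)
```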
lerobot/common/envs/factory.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import importlib
|
17 |
+
|
18 |
+
import gymnasium as gym
|
19 |
+
|
20 |
+
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
21 |
+
|
22 |
+
|
23 |
+
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
24 |
+
if env_type == "aloha":
|
25 |
+
return AlohaEnv(**kwargs)
|
26 |
+
elif env_type == "pusht":
|
27 |
+
return PushtEnv(**kwargs)
|
28 |
+
elif env_type == "xarm":
|
29 |
+
return XarmEnv(**kwargs)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Env type '{env_type}' is not available.")
|
32 |
+
|
33 |
+
|
34 |
+
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
35 |
+
"""Makes a gym vector environment according to the config.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
cfg (EnvConfig): the config of the environment to instantiate.
|
39 |
+
n_envs (int, optional): The number of parallel environments to return. Defaults to 1.
|
40 |
+
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
41 |
+
False.
|
42 |
+
|
43 |
+
Raises:
|
44 |
+
ValueError: if n_envs < 1
|
45 |
+
ModuleNotFoundError: If the requested env package is not installed
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
gym.vector.VectorEnv: The parallelized gym.env instance.
|
49 |
+
"""
|
50 |
+
if n_envs < 1:
|
51 |
+
raise ValueError("`n_envs` must be at least 1")
|
52 |
+
|
53 |
+
package_name = f"gym_{cfg.type}"
|
54 |
+
|
55 |
+
try:
|
56 |
+
importlib.import_module(package_name)
|
57 |
+
except ModuleNotFoundError as e:
|
58 |
+
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
59 |
+
raise e
|
60 |
+
|
61 |
+
gym_handle = f"{package_name}/{cfg.task}"
|
62 |
+
|
63 |
+
# batched version of the env that returns an observation of shape (b, c)
|
64 |
+
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
65 |
+
env = env_cls(
|
66 |
+
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
67 |
+
)
|
68 |
+
|
69 |
+
return env
|
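A usage sketch for the factory (not part of the commit); it assumes the matching simulation package, here `gym_pusht`, is installed:

```python
from lerobot.common.envs.factory import make_env, make_env_config

cfg = make_env_config("pusht", episode_length=300)
env = make_env(cfg, n_envs=2)  # SyncVectorEnv holding 2 copies of gym_pusht/PushT-v0

obs, info = env.reset(seed=0)
print(obs.keys())  # dict observation, e.g. "pixels" and "agent_pos", each batched over the 2 envs
env.close()
```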
lerobot/common/envs/utils.py
ADDED
@@ -0,0 +1,127 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import warnings
|
17 |
+
from typing import Any
|
18 |
+
|
19 |
+
import einops
|
20 |
+
import gymnasium as gym
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from torch import Tensor
|
24 |
+
|
25 |
+
from lerobot.common.envs.configs import EnvConfig
|
26 |
+
from lerobot.common.utils.utils import get_channel_first_image_shape
|
27 |
+
from lerobot.configs.types import FeatureType, PolicyFeature
|
28 |
+
|
29 |
+
|
30 |
+
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
31 |
+
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
32 |
+
"""Convert environment observation to LeRobot format observation.
|
33 |
+
Args:
|
34 |
+
observation: Dictionary of observation batches from a Gym vector environment.
|
35 |
+
Returns:
|
36 |
+
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
37 |
+
"""
|
38 |
+
# map to expected inputs for the policy
|
39 |
+
return_observations = {}
|
40 |
+
if "pixels" in observations:
|
41 |
+
if isinstance(observations["pixels"], dict):
|
42 |
+
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
43 |
+
else:
|
44 |
+
imgs = {"observation.image": observations["pixels"]}
|
45 |
+
|
46 |
+
for imgkey, img in imgs.items():
|
47 |
+
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
48 |
+
img = torch.from_numpy(img)
|
49 |
+
|
50 |
+
# sanity check that images are channel last
|
51 |
+
_, h, w, c = img.shape
|
52 |
+
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
53 |
+
|
54 |
+
# sanity check that images are uint8
|
55 |
+
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
56 |
+
|
57 |
+
# convert to channel first of type float32 in range [0,1]
|
58 |
+
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
59 |
+
img = img.type(torch.float32)
|
60 |
+
img /= 255
|
61 |
+
|
62 |
+
return_observations[imgkey] = img
|
63 |
+
|
64 |
+
if "environment_state" in observations:
|
65 |
+
return_observations["observation.environment_state"] = torch.from_numpy(
|
66 |
+
observations["environment_state"]
|
67 |
+
).float()
|
68 |
+
|
69 |
+
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
70 |
+
# requirement for "agent_pos"
|
71 |
+
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
72 |
+
return return_observations
|
73 |
+
|
74 |
+
|
75 |
+
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
76 |
+
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
77 |
+
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
78 |
+
policy_features = {}
|
79 |
+
for key, ft in env_cfg.features.items():
|
80 |
+
if ft.type is FeatureType.VISUAL:
|
81 |
+
if len(ft.shape) != 3:
|
82 |
+
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
83 |
+
|
84 |
+
shape = get_channel_first_image_shape(ft.shape)
|
85 |
+
feature = PolicyFeature(type=ft.type, shape=shape)
|
86 |
+
else:
|
87 |
+
feature = ft
|
88 |
+
|
89 |
+
policy_key = env_cfg.features_map[key]
|
90 |
+
policy_features[policy_key] = feature
|
91 |
+
|
92 |
+
return policy_features
|
93 |
+
|
94 |
+
|
95 |
+
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
|
96 |
+
first_type = type(env.envs[0]) # Get type of first env
|
97 |
+
return all(type(e) is first_type for e in env.envs) # Fast type check
|
98 |
+
|
99 |
+
|
100 |
+
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
101 |
+
with warnings.catch_warnings():
|
102 |
+
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
103 |
+
|
104 |
+
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
105 |
+
warnings.warn(
|
106 |
+
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
107 |
+
UserWarning,
|
108 |
+
stacklevel=2,
|
109 |
+
)
|
110 |
+
if not are_all_envs_same_type(env):
|
111 |
+
warnings.warn(
|
112 |
+
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
|
113 |
+
UserWarning,
|
114 |
+
stacklevel=2,
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
119 |
+
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
120 |
+
if hasattr(env.envs[0], "task_description"):
|
121 |
+
observation["task"] = env.call("task_description")
|
122 |
+
elif hasattr(env.envs[0], "task"):
|
123 |
+
observation["task"] = env.call("task")
|
124 |
+
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
|
125 |
+
num_envs = observation[list(observation.keys())[0]].shape[0]
|
126 |
+
observation["task"] = ["" for _ in range(num_envs)]
|
127 |
+
return observation
|
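A minimal sketch (not part of the commit) of `preprocess_observation` on a fake batched observation that matches the checks above (channel-last uint8 images, float agent positions):

```python
import numpy as np

from lerobot.common.envs.utils import preprocess_observation

fake_obs = {
    "pixels": np.zeros((2, 96, 96, 3), dtype=np.uint8),  # (batch, h, w, c)
    "agent_pos": np.zeros((2, 2), dtype=np.float64),
}
out = preprocess_observation(fake_obs)
print(out["observation.image"].shape)  # torch.Size([2, 3, 96, 96]), float32 in [0, 1]
print(out["observation.state"].shape)  # torch.Size([2, 2])
```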
lerobot/common/mocks/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
# Common mocks for robot devices and testing
|
lerobot/common/mocks/cameras/__init__.py
ADDED
File without changes
|
lerobot/common/mocks/cameras/mock_cv2.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from functools import cache
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
CAP_V4L2 = 200
|
19 |
+
CAP_DSHOW = 700
|
20 |
+
CAP_AVFOUNDATION = 1200
|
21 |
+
CAP_ANY = -1
|
22 |
+
|
23 |
+
CAP_PROP_FPS = 5
|
24 |
+
CAP_PROP_FRAME_WIDTH = 3
|
25 |
+
CAP_PROP_FRAME_HEIGHT = 4
|
26 |
+
COLOR_RGB2BGR = 4
|
27 |
+
COLOR_BGR2RGB = 4
|
28 |
+
|
29 |
+
ROTATE_90_COUNTERCLOCKWISE = 2
|
30 |
+
ROTATE_90_CLOCKWISE = 0
|
31 |
+
ROTATE_180 = 1
|
32 |
+
|
33 |
+
|
34 |
+
@cache
|
35 |
+
def _generate_image(width: int, height: int):
|
36 |
+
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
|
37 |
+
|
38 |
+
|
39 |
+
def cvtColor(color_image, color_conversion): # noqa: N802
|
40 |
+
if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
|
41 |
+
return color_image[:, :, [2, 1, 0]]
|
42 |
+
else:
|
43 |
+
raise NotImplementedError(color_conversion)
|
44 |
+
|
45 |
+
|
46 |
+
def rotate(color_image, rotation):
|
47 |
+
if rotation is None:
|
48 |
+
return color_image
|
49 |
+
elif rotation == ROTATE_90_CLOCKWISE:
|
50 |
+
return np.rot90(color_image, k=1)
|
51 |
+
elif rotation == ROTATE_180:
|
52 |
+
return np.rot90(color_image, k=2)
|
53 |
+
elif rotation == ROTATE_90_COUNTERCLOCKWISE:
|
54 |
+
return np.rot90(color_image, k=3)
|
55 |
+
else:
|
56 |
+
raise NotImplementedError(rotation)
|
57 |
+
|
58 |
+
|
59 |
+
class VideoCapture:
|
60 |
+
def __init__(self, *args, **kwargs):
|
61 |
+
self._mock_dict = {
|
62 |
+
CAP_PROP_FPS: 30,
|
63 |
+
CAP_PROP_FRAME_WIDTH: 640,
|
64 |
+
CAP_PROP_FRAME_HEIGHT: 480,
|
65 |
+
}
|
66 |
+
self._is_opened = True
|
67 |
+
|
68 |
+
def isOpened(self): # noqa: N802
|
69 |
+
return self._is_opened
|
70 |
+
|
71 |
+
def set(self, propId: int, value: float) -> bool: # noqa: N803
|
72 |
+
if not self._is_opened:
|
73 |
+
raise RuntimeError("Camera is not opened")
|
74 |
+
self._mock_dict[propId] = value
|
75 |
+
return True
|
76 |
+
|
77 |
+
def get(self, propId: int) -> float: # noqa: N803
|
78 |
+
if not self._is_opened:
|
79 |
+
raise RuntimeError("Camera is not opened")
|
80 |
+
value = self._mock_dict[propId]
|
81 |
+
if value == 0:
|
82 |
+
if propId == CAP_PROP_FRAME_HEIGHT:
|
83 |
+
value = 480
|
84 |
+
elif propId == CAP_PROP_FRAME_WIDTH:
|
85 |
+
value = 640
|
86 |
+
return value
|
87 |
+
|
88 |
+
def read(self):
|
89 |
+
if not self._is_opened:
|
90 |
+
raise RuntimeError("Camera is not opened")
|
91 |
+
h = self.get(CAP_PROP_FRAME_HEIGHT)
|
92 |
+
w = self.get(CAP_PROP_FRAME_WIDTH)
|
93 |
+
ret = True
|
94 |
+
return ret, _generate_image(width=w, height=h)
|
95 |
+
|
96 |
+
def release(self):
|
97 |
+
self._is_opened = False
|
98 |
+
|
99 |
+
def __del__(self):
|
100 |
+
if self._is_opened:
|
101 |
+
self.release()
|
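A short illustration (not part of the commit) of how the cv2 mock behaves; in tests it would typically be substituted for the real `cv2` module via monkeypatching:

```python
from lerobot.common.mocks.cameras import mock_cv2

cap = mock_cv2.VideoCapture(0)  # the index is ignored by the mock
cap.set(mock_cv2.CAP_PROP_FRAME_WIDTH, 320)
cap.set(mock_cv2.CAP_PROP_FRAME_HEIGHT, 240)
ret, frame = cap.read()
print(ret, frame.shape)  # True (240, 320, 3) -- a random uint8 image, cached per size
cap.release()
```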
lerobot/common/mocks/cameras/mock_pyrealsense2.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import enum
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
|
19 |
+
class stream(enum.Enum): # noqa: N801
|
20 |
+
color = 0
|
21 |
+
depth = 1
|
22 |
+
|
23 |
+
|
24 |
+
class format(enum.Enum): # noqa: N801
|
25 |
+
rgb8 = 0
|
26 |
+
z16 = 1
|
27 |
+
|
28 |
+
|
29 |
+
class config: # noqa: N801
|
30 |
+
def enable_device(self, device_id: str):
|
31 |
+
self.device_enabled = device_id
|
32 |
+
|
33 |
+
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
|
34 |
+
self.stream_type = stream_type
|
35 |
+
# Overwrite default values when possible
|
36 |
+
self.width = 848 if width is None else width
|
37 |
+
self.height = 480 if height is None else height
|
38 |
+
self.color_format = format.rgb8 if color_format is None else color_format
|
39 |
+
self.fps = 30 if fps is None else fps
|
40 |
+
|
41 |
+
|
42 |
+
class RSColorProfile:
|
43 |
+
def __init__(self, config):
|
44 |
+
self.config = config
|
45 |
+
|
46 |
+
def fps(self):
|
47 |
+
return self.config.fps
|
48 |
+
|
49 |
+
def width(self):
|
50 |
+
return self.config.width
|
51 |
+
|
52 |
+
def height(self):
|
53 |
+
return self.config.height
|
54 |
+
|
55 |
+
|
56 |
+
class RSColorStream:
|
57 |
+
def __init__(self, config):
|
58 |
+
self.config = config
|
59 |
+
|
60 |
+
def as_video_stream_profile(self):
|
61 |
+
return RSColorProfile(self.config)
|
62 |
+
|
63 |
+
|
64 |
+
class RSProfile:
|
65 |
+
def __init__(self, config):
|
66 |
+
self.config = config
|
67 |
+
|
68 |
+
def get_stream(self, color_format):
|
69 |
+
del color_format # unused
|
70 |
+
return RSColorStream(self.config)
|
71 |
+
|
72 |
+
|
73 |
+
class pipeline: # noqa: N801
|
74 |
+
def __init__(self):
|
75 |
+
self.started = False
|
76 |
+
self.config = None
|
77 |
+
|
78 |
+
def start(self, config):
|
79 |
+
self.started = True
|
80 |
+
self.config = config
|
81 |
+
return RSProfile(self.config)
|
82 |
+
|
83 |
+
def stop(self):
|
84 |
+
if not self.started:
|
85 |
+
raise RuntimeError("You need to start the camera before stopping it.")
|
86 |
+
self.started = False
|
87 |
+
self.config = None
|
88 |
+
|
89 |
+
def wait_for_frames(self, timeout_ms=50000):
|
90 |
+
del timeout_ms # unused
|
91 |
+
return RSFrames(self.config)
|
92 |
+
|
93 |
+
|
94 |
+
class RSFrames:
|
95 |
+
def __init__(self, config):
|
96 |
+
self.config = config
|
97 |
+
|
98 |
+
def get_color_frame(self):
|
99 |
+
return RSColorFrame(self.config)
|
100 |
+
|
101 |
+
def get_depth_frame(self):
|
102 |
+
return RSDepthFrame(self.config)
|
103 |
+
|
104 |
+
|
105 |
+
class RSColorFrame:
|
106 |
+
def __init__(self, config):
|
107 |
+
self.config = config
|
108 |
+
|
109 |
+
def get_data(self):
|
110 |
+
data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
|
111 |
+
# Create a difference between rgb and bgr
|
112 |
+
data[:, :, 0] = 2
|
113 |
+
return data
|
114 |
+
|
115 |
+
|
116 |
+
class RSDepthFrame:
|
117 |
+
def __init__(self, config):
|
118 |
+
self.config = config
|
119 |
+
|
120 |
+
def get_data(self):
|
121 |
+
return np.ones((self.config.height, self.config.width), dtype=np.uint16)
|
122 |
+
|
123 |
+
|
124 |
+
class RSDevice:
|
125 |
+
def __init__(self):
|
126 |
+
pass
|
127 |
+
|
128 |
+
def get_info(self, camera_info) -> str:
|
129 |
+
del camera_info # unused
|
130 |
+
# return fake serial number
|
131 |
+
return "123456789"
|
132 |
+
|
133 |
+
|
134 |
+
class context: # noqa: N801
|
135 |
+
def __init__(self):
|
136 |
+
pass
|
137 |
+
|
138 |
+
def query_devices(self):
|
139 |
+
return [RSDevice()]
|
140 |
+
|
141 |
+
|
142 |
+
class camera_info: # noqa: N801
|
143 |
+
# fake name
|
144 |
+
name = "Intel RealSense D435I"
|
145 |
+
|
146 |
+
def __init__(self, serial_number):
|
147 |
+
del serial_number
|
148 |
+
pass
|
lerobot/common/mocks/motors/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
# Mocks for motor modules
|
lerobot/common/mocks/motors/mock_dynamixel_sdk.py
ADDED
@@ -0,0 +1,107 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
15 |
+
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
16 |
+
|
17 |
+
Warning: These mocked versions are minimalist. They do not exactly mock every behavior
|
18 |
+
from the original classes and functions (e.g. return types might be None instead of boolean).
|
19 |
+
"""
|
20 |
+
|
21 |
+
# from dynamixel_sdk import COMM_SUCCESS
|
22 |
+
|
23 |
+
DEFAULT_BAUDRATE = 9_600
|
24 |
+
COMM_SUCCESS = 0 # tx or rx packet communication success
|
25 |
+
|
26 |
+
|
27 |
+
def convert_to_bytes(value, bytes):
|
28 |
+
# TODO(rcadene): remove need to mock `convert_to_bytes` by implementing the inverse transform
|
29 |
+
# `convert_bytes_to_value`
|
30 |
+
del bytes # unused
|
31 |
+
return value
|
32 |
+
|
33 |
+
|
34 |
+
def get_default_motor_values(motor_index):
|
35 |
+
return {
|
36 |
+
# Key (int) are from X_SERIES_CONTROL_TABLE
|
37 |
+
7: motor_index, # ID
|
38 |
+
8: DEFAULT_BAUDRATE, # Baud_rate
|
39 |
+
10: 0, # Drive_Mode
|
40 |
+
64: 0, # Torque_Enable
|
41 |
+
# Set 2560 since calibration values for the Aloha gripper are between start_pos=2499 and end_pos=3144
|
42 |
+
# For other joints, 2560 will be autocorrected to be in calibration range
|
43 |
+
132: 2560, # Present_Position
|
44 |
+
}
|
45 |
+
|
46 |
+
|
47 |
+
class PortHandler:
|
48 |
+
def __init__(self, port):
|
49 |
+
self.port = port
|
50 |
+
# factory default baudrate
|
51 |
+
self.baudrate = DEFAULT_BAUDRATE
|
52 |
+
|
53 |
+
def openPort(self): # noqa: N802
|
54 |
+
return True
|
55 |
+
|
56 |
+
def closePort(self): # noqa: N802
|
57 |
+
pass
|
58 |
+
|
59 |
+
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
|
60 |
+
del timeout_ms # unused
|
61 |
+
|
62 |
+
def getBaudRate(self): # noqa: N802
|
63 |
+
return self.baudrate
|
64 |
+
|
65 |
+
def setBaudRate(self, baudrate): # noqa: N802
|
66 |
+
self.baudrate = baudrate
|
67 |
+
|
68 |
+
|
69 |
+
class PacketHandler:
|
70 |
+
def __init__(self, protocol_version):
|
71 |
+
del protocol_version # unused
|
72 |
+
# Use packet_handler.data to communicate across Read and Write
|
73 |
+
self.data = {}
|
74 |
+
|
75 |
+
|
76 |
+
class GroupSyncRead:
|
77 |
+
def __init__(self, port_handler, packet_handler, address, bytes):
|
78 |
+
self.packet_handler = packet_handler
|
79 |
+
|
80 |
+
def addParam(self, motor_index): # noqa: N802
|
81 |
+
# Initialize motor default values
|
82 |
+
if motor_index not in self.packet_handler.data:
|
83 |
+
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
84 |
+
|
85 |
+
def txRxPacket(self): # noqa: N802
|
86 |
+
return COMM_SUCCESS
|
87 |
+
|
88 |
+
def getData(self, index, address, bytes): # noqa: N802
|
89 |
+
return self.packet_handler.data[index][address]
|
90 |
+
|
91 |
+
|
92 |
+
class GroupSyncWrite:
|
93 |
+
def __init__(self, port_handler, packet_handler, address, bytes):
|
94 |
+
self.packet_handler = packet_handler
|
95 |
+
self.address = address
|
96 |
+
|
97 |
+
def addParam(self, index, data): # noqa: N802
|
98 |
+
# Initialize motor default values
|
99 |
+
if index not in self.packet_handler.data:
|
100 |
+
self.packet_handler.data[index] = get_default_motor_values(index)
|
101 |
+
self.changeParam(index, data)
|
102 |
+
|
103 |
+
def txPacket(self): # noqa: N802
|
104 |
+
return COMM_SUCCESS
|
105 |
+
|
106 |
+
def changeParam(self, index, data): # noqa: N802
|
107 |
+
self.packet_handler.data[index][self.address] = data
|
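A small sketch (not part of the commit) showing how the mocked SDK keeps written values readable through the shared packet handler:

```python
from lerobot.common.mocks.motors import mock_dynamixel_sdk as dxl

port = dxl.PortHandler("/dev/ttyUSB0")        # the port name is irrelevant for the mock
packet = dxl.PacketHandler(protocol_version=2.0)

writer = dxl.GroupSyncWrite(port, packet, address=132, bytes=4)
writer.addParam(1, 2048)                      # write Present_Position for motor id 1
assert writer.txPacket() == dxl.COMM_SUCCESS

reader = dxl.GroupSyncRead(port, packet, address=132, bytes=4)
reader.addParam(1)
assert reader.txRxPacket() == dxl.COMM_SUCCESS
print(reader.getData(1, 132, 4))              # 2048, read back from the shared dict
```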
lerobot/common/mocks/motors/mock_scservo_sdk.py
ADDED
@@ -0,0 +1,125 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Mocked classes and functions from scservo_sdk to allow for continuous integration
|
15 |
+
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
16 |
+
|
17 |
+
Warning: These mocked versions are minimalist. They do not exactly mock every behavior
|
18 |
+
from the original classes and functions (e.g. return types might be None instead of boolean).
|
19 |
+
"""
|
20 |
+
|
21 |
+
# from scservo_sdk import COMM_SUCCESS
|
22 |
+
|
23 |
+
DEFAULT_BAUDRATE = 1_000_000
|
24 |
+
COMM_SUCCESS = 0 # tx or rx packet communication success
|
25 |
+
|
26 |
+
|
27 |
+
def convert_to_bytes(value, bytes):
|
28 |
+
# TODO(rcadene): remove need to mock `convert_to_bytes` by implementing the inverse transform
|
29 |
+
# `convert_bytes_to_value`
|
30 |
+
del bytes # unused
|
31 |
+
return value
|
32 |
+
|
33 |
+
|
34 |
+
def get_default_motor_values(motor_index):
|
35 |
+
return {
|
36 |
+
# Key (int) are from SCS_SERIES_CONTROL_TABLE
|
37 |
+
5: motor_index, # ID
|
38 |
+
6: DEFAULT_BAUDRATE, # Baud_rate
|
39 |
+
10: 0, # Drive_Mode
|
40 |
+
21: 32, # P_Coefficient
|
41 |
+
22: 32, # D_Coefficient
|
42 |
+
23: 0, # I_Coefficient
|
43 |
+
40: 0, # Torque_Enable
|
44 |
+
41: 254, # Acceleration
|
45 |
+
31: -2047, # Offset
|
46 |
+
33: 0, # Mode
|
47 |
+
55: 1, # Lock
|
48 |
+
# Set 2560 since calibration values for the Aloha gripper are between start_pos=2499 and end_pos=3144
|
49 |
+
# For other joints, 2560 will be autocorrected to be in calibration range
|
50 |
+
56: 2560, # Present_Position
|
51 |
+
58: 0, # Present_Speed
|
52 |
+
69: 0, # Present_Current
|
53 |
+
85: 150, # Maximum_Acceleration
|
54 |
+
}
|
55 |
+
|
56 |
+
|
57 |
+
class PortHandler:
|
58 |
+
def __init__(self, port):
|
59 |
+
self.port = port
|
60 |
+
# factory default baudrate
|
61 |
+
self.baudrate = DEFAULT_BAUDRATE
|
62 |
+
self.ser = SerialMock()
|
63 |
+
|
64 |
+
def openPort(self): # noqa: N802
|
65 |
+
return True
|
66 |
+
|
67 |
+
def closePort(self): # noqa: N802
|
68 |
+
pass
|
69 |
+
|
70 |
+
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
|
71 |
+
del timeout_ms # unused
|
72 |
+
|
73 |
+
def getBaudRate(self): # noqa: N802
|
74 |
+
return self.baudrate
|
75 |
+
|
76 |
+
def setBaudRate(self, baudrate): # noqa: N802
|
77 |
+
self.baudrate = baudrate
|
78 |
+
|
79 |
+
|
80 |
+
class PacketHandler:
|
81 |
+
def __init__(self, protocol_version):
|
82 |
+
del protocol_version # unused
|
83 |
+
# Use packet_handler.data to communicate across Read and Write
|
84 |
+
self.data = {}
|
85 |
+
|
86 |
+
|
87 |
+
class GroupSyncRead:
|
88 |
+
def __init__(self, port_handler, packet_handler, address, bytes):
|
89 |
+
self.packet_handler = packet_handler
|
90 |
+
|
91 |
+
def addParam(self, motor_index): # noqa: N802
|
92 |
+
# Initialize motor default values
|
93 |
+
if motor_index not in self.packet_handler.data:
|
94 |
+
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
95 |
+
|
96 |
+
def txRxPacket(self): # noqa: N802
|
97 |
+
return COMM_SUCCESS
|
98 |
+
|
99 |
+
def getData(self, index, address, bytes): # noqa: N802
|
100 |
+
return self.packet_handler.data[index][address]
|
101 |
+
|
102 |
+
|
103 |
+
class GroupSyncWrite:
|
104 |
+
def __init__(self, port_handler, packet_handler, address, bytes):
|
105 |
+
self.packet_handler = packet_handler
|
106 |
+
self.address = address
|
107 |
+
|
108 |
+
def addParam(self, index, data): # noqa: N802
|
109 |
+
if index not in self.packet_handler.data:
|
110 |
+
self.packet_handler.data[index] = get_default_motor_values(index)
|
111 |
+
self.changeParam(index, data)
|
112 |
+
|
113 |
+
def txPacket(self): # noqa: N802
|
114 |
+
return COMM_SUCCESS
|
115 |
+
|
116 |
+
def changeParam(self, index, data): # noqa: N802
|
117 |
+
self.packet_handler.data[index][self.address] = data
|
118 |
+
|
119 |
+
|
120 |
+
class SerialMock:
|
121 |
+
def reset_output_buffer(self):
|
122 |
+
pass
|
123 |
+
|
124 |
+
def reset_input_buffer(self):
|
125 |
+
pass
|
lerobot/common/optim/__init__.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .optimizers import OptimizerConfig as OptimizerConfig
|
lerobot/common/optim/factory.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
from torch.optim import Optimizer
|
19 |
+
from torch.optim.lr_scheduler import LRScheduler
|
20 |
+
|
21 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
22 |
+
from lerobot.configs.train import TrainPipelineConfig
|
23 |
+
|
24 |
+
|
25 |
+
def make_optimizer_and_scheduler(
|
26 |
+
cfg: TrainPipelineConfig, policy: PreTrainedPolicy
|
27 |
+
) -> tuple[Optimizer, LRScheduler | None]:
|
28 |
+
"""Generates the optimizer and scheduler based on configs.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
|
32 |
+
policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
36 |
+
"""
|
37 |
+
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
38 |
+
optimizer = cfg.optimizer.build(params)
|
39 |
+
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
40 |
+
return optimizer, lr_scheduler
|
lerobot/common/optim/optimizers.py
ADDED
@@ -0,0 +1,118 @@
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import abc
|
17 |
+
from dataclasses import asdict, dataclass
|
18 |
+
from pathlib import Path
|
19 |
+
|
20 |
+
import draccus
|
21 |
+
import torch
|
22 |
+
from safetensors.torch import load_file, save_file
|
23 |
+
|
24 |
+
from lerobot.common.constants import (
|
25 |
+
OPTIMIZER_PARAM_GROUPS,
|
26 |
+
OPTIMIZER_STATE,
|
27 |
+
)
|
28 |
+
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
|
29 |
+
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
34 |
+
lr: float
|
35 |
+
weight_decay: float
|
36 |
+
grad_clip_norm: float
|
37 |
+
|
38 |
+
@property
|
39 |
+
def type(self) -> str:
|
40 |
+
return self.get_choice_name(self.__class__)
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def default_choice_name(cls) -> str | None:
|
44 |
+
return "adam"
|
45 |
+
|
46 |
+
@abc.abstractmethod
|
47 |
+
def build(self) -> torch.optim.Optimizer:
|
48 |
+
raise NotImplementedError
|
49 |
+
|
50 |
+
|
51 |
+
@OptimizerConfig.register_subclass("adam")
|
52 |
+
@dataclass
|
53 |
+
class AdamConfig(OptimizerConfig):
|
54 |
+
lr: float = 1e-3
|
55 |
+
betas: tuple[float, float] = (0.9, 0.999)
|
56 |
+
eps: float = 1e-8
|
57 |
+
weight_decay: float = 0.0
|
58 |
+
grad_clip_norm: float = 10.0
|
59 |
+
|
60 |
+
def build(self, params: dict) -> torch.optim.Optimizer:
|
61 |
+
kwargs = asdict(self)
|
62 |
+
kwargs.pop("grad_clip_norm")
|
63 |
+
return torch.optim.Adam(params, **kwargs)
|
64 |
+
|
65 |
+
|
66 |
+
@OptimizerConfig.register_subclass("adamw")
|
67 |
+
@dataclass
|
68 |
+
class AdamWConfig(OptimizerConfig):
|
69 |
+
lr: float = 1e-3
|
70 |
+
betas: tuple[float, float] = (0.9, 0.999)
|
71 |
+
eps: float = 1e-8
|
72 |
+
weight_decay: float = 1e-2
|
73 |
+
grad_clip_norm: float = 10.0
|
74 |
+
|
75 |
+
def build(self, params: dict) -> torch.optim.Optimizer:
|
76 |
+
kwargs = asdict(self)
|
77 |
+
kwargs.pop("grad_clip_norm")
|
78 |
+
return torch.optim.AdamW(params, **kwargs)
|
79 |
+
|
80 |
+
|
81 |
+
@OptimizerConfig.register_subclass("sgd")
|
82 |
+
@dataclass
|
83 |
+
class SGDConfig(OptimizerConfig):
|
84 |
+
lr: float = 1e-3
|
85 |
+
momentum: float = 0.0
|
86 |
+
dampening: float = 0.0
|
87 |
+
nesterov: bool = False
|
88 |
+
weight_decay: float = 0.0
|
89 |
+
grad_clip_norm: float = 10.0
|
90 |
+
|
91 |
+
def build(self, params: dict) -> torch.optim.Optimizer:
|
92 |
+
kwargs = asdict(self)
|
93 |
+
kwargs.pop("grad_clip_norm")
|
94 |
+
return torch.optim.SGD(params, **kwargs)
|
95 |
+
|
96 |
+
|
97 |
+
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
98 |
+
state = optimizer.state_dict()
|
99 |
+
param_groups = state.pop("param_groups")
|
100 |
+
flat_state = flatten_dict(state)
|
101 |
+
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
102 |
+
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
103 |
+
|
104 |
+
|
105 |
+
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
106 |
+
current_state_dict = optimizer.state_dict()
|
107 |
+
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
108 |
+
state = unflatten_dict(flat_state)
|
109 |
+
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
110 |
+
|
111 |
+
if "param_groups" in current_state_dict:
|
112 |
+
param_groups = deserialize_json_into_object(
|
113 |
+
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
|
114 |
+
)
|
115 |
+
loaded_state_dict["param_groups"] = param_groups
|
116 |
+
|
117 |
+
optimizer.load_state_dict(loaded_state_dict)
|
118 |
+
return optimizer
|
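A brief sketch (not part of the commit) of building an optimizer from one of these configs; note that `grad_clip_norm` is stripped from the kwargs in `build()` and is meant to be applied separately:

```python
import torch

from lerobot.common.optim.optimizers import AdamWConfig

model = torch.nn.Linear(4, 2)
cfg = AdamWConfig(lr=1e-4, weight_decay=1e-2)
optimizer = cfg.build(model.parameters())

loss = model(torch.randn(8, 4)).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
optimizer.step()
```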
lerobot/common/optim/schedulers.py
ADDED
@@ -0,0 +1,122 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
from dataclasses import asdict, dataclass
from pathlib import Path

import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object


@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
    num_warmup_steps: int

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @abc.abstractmethod
    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
        raise NotImplementedError


@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
    name: str = "cosine"
    num_warmup_steps: int | None = None

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        from diffusers.optimization import get_scheduler

        kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
        return get_scheduler(**kwargs)


@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
    num_warmup_steps: int
    num_vqvae_training_steps: int
    num_cycles: float = 0.5

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        def lr_lambda(current_step):
            if current_step < self.num_vqvae_training_steps:
                return float(1)
            else:
                adjusted_step = current_step - self.num_vqvae_training_steps
                if adjusted_step < self.num_warmup_steps:
                    return float(adjusted_step) / float(max(1, self.num_warmup_steps))
                progress = float(adjusted_step - self.num_warmup_steps) / float(
                    max(1, num_training_steps - self.num_warmup_steps)
                )
                return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

        return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
    """Used by Physical Intelligence to train Pi0"""

    num_warmup_steps: int
    num_decay_steps: int
    peak_lr: float
    decay_lr: float

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        del num_training_steps

        def lr_lambda(current_step):
            def linear_warmup_schedule(current_step):
                if current_step <= 0:
                    return 1 / (self.num_warmup_steps + 1)
                frac = 1 - current_step / self.num_warmup_steps
                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1

            def cosine_decay_schedule(current_step):
                step = min(current_step, self.num_decay_steps)
                cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
                alpha = self.decay_lr / self.peak_lr
                decayed = (1 - alpha) * cosine_decay + alpha
                return decayed

            if current_step < self.num_warmup_steps:
                return linear_warmup_schedule(current_step)

            return cosine_decay_schedule(current_step)

        return LambdaLR(optimizer, lr_lambda, -1)


def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    state_dict = scheduler.state_dict()
    write_json(state_dict, save_dir / SCHEDULER_STATE)


def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
    scheduler.load_state_dict(state_dict)
    return scheduler
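Each scheduler config above is a draccus choice-registry entry whose `build` method returns a standard `torch.optim.lr_scheduler.LambdaLR`, so the learning rate is always the optimizer's base LR times a step-dependent multiplier. The sketch below (illustrative, not part of the diff) builds the Pi0-style cosine-decay-with-warmup schedule and checks its endpoints: roughly `peak_lr` once warmup ends, and `decay_lr` after the decay window.

```python
# Illustrative check of CosineDecayWithWarmupSchedulerConfig's warmup and decay endpoints.
import torch

from lerobot.common.optim.schedulers import CosineDecayWithWarmupSchedulerConfig

cfg = CosineDecayWithWarmupSchedulerConfig(
    num_warmup_steps=1_000, num_decay_steps=30_000, peak_lr=2.5e-5, decay_lr=2.5e-6
)
optimizer = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=cfg.peak_lr)
scheduler = cfg.build(optimizer, num_training_steps=30_000)  # num_training_steps is unused by this schedule

for _ in range(cfg.num_warmup_steps):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # ~2.5e-5: the multiplier is ~1.0 right after warmup

for _ in range(cfg.num_decay_steps):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # ~2.5e-6: the multiplier bottoms out at decay_lr / peak_lr
```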
lerobot/common/policies/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
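These `X as X` imports re-export each policy's config from the package root, so user code does not have to know the per-policy module paths. A small sketch of relying on that, assuming the parent `PreTrainedConfig` needs no arguments beyond its defaults:

```python
# Illustrative import via the package-level re-exports above.
from lerobot.common.policies import ACTConfig

config = ACTConfig(chunk_size=100, n_action_steps=50)
print(config.action_delta_indices)  # [0, 1, ..., 99]: ACT always predicts a full chunk per call
```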
lerobot/common/policies/act/configuration_act.py
ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python

# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode


@PreTrainedConfig.register_subclass("act")
@dataclass
class ACTConfig(PreTrainedConfig):
    """Configuration class for the Action Chunking Transformers policy.

    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

    Notes on the inputs and outputs:
        - Either:
            - At least one key starting with "observation.image" is required as an input.
              AND/OR
            - The key "observation.environment_state" is required as input.
        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
          views. Right now we only support all images having the same shape.
        - May optionally work without an "observation.state" key for the proprioceptive robot state.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        chunk_size: The size of the action prediction "chunks" in units of environment steps.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            This should be no greater than the chunk size. For example, if the chunk size is 100, you may
            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
            environment, and throws the other 50 out.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
            [-1, 1] range.
        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
            original scale. Note that this is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
            convolution.
        pre_norm: Whether to use "pre-norm" in the transformer blocks.
        dim_model: The transformer blocks' main hidden dimension.
        n_heads: The number of heads to use in the transformer blocks' multi-head attention.
        dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
            layers.
        feedforward_activation: The activation to use in the transformer block's feed-forward layers.
        n_encoder_layers: The number of transformer layers to use for the transformer encoder.
        n_decoder_layers: The number of transformer layers to use for the transformer decoder.
        use_vae: Whether to use a variational objective during training. This introduces another transformer
            which is used as the VAE's encoder (not to be confused with the transformer encoder - see
            documentation in the policy class).
        latent_dim: The VAE's latent dimension.
        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
        temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
            ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
            1 when using this feature, as inference needs to happen at every step to form an ensemble. For
            more information on how ensembling works, please see `ACTTemporalEnsembler`.
        dropout: Dropout to use in the transformer layers (see code for details).
        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 100
    n_action_steps: int = 100

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: bool = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
    # that means only the first layer is used. Here we match the original implementation by setting this to 1.
    # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
    n_decoder_layers: int = 1
    # VAE.
    use_vae: bool = True
    latent_dim: int = 32
    n_vae_encoder_layers: int = 4

    # Inference.
    # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
    temporal_ensemble_coeff: float | None = None

    # Training and loss computation.
    dropout: float = 0.1
    kl_weight: float = 10.0

    # Training preset
    optimizer_lr: float = 1e-5
    optimizer_weight_decay: float = 1e-4
    optimizer_lr_backbone: float = 1e-5

    def __post_init__(self):
        super().__post_init__()

        # Input validation (not exhaustive).
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
            raise NotImplementedError(
                "`n_action_steps` must be 1 when using temporal ensembling. This is "
                "because the policy needs to be queried every step to compute the ensembled action."
            )
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> None:
        return None

    def validate_features(self) -> None:
        if not self.image_features and not self.env_state_feature:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
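The interplay between `chunk_size`, `n_action_steps`, and `temporal_ensemble_coeff` is enforced by `__post_init__` above; the sketch below is illustrative only and simply exercises those checks.

```python
# Illustrative only: exercise the validation rules in ACTConfig.__post_init__.
from lerobot.common.policies.act.configuration_act import ACTConfig

# Valid: predict 100 steps per model call, execute the first 50 in the environment.
ACTConfig(chunk_size=100, n_action_steps=50)

# Invalid: n_action_steps may not exceed chunk_size.
try:
    ACTConfig(chunk_size=100, n_action_steps=200)
except ValueError as err:
    print(err)

# Invalid: temporal ensembling requires querying the policy at every step (n_action_steps == 1).
try:
    ACTConfig(temporal_ensemble_coeff=0.01, n_action_steps=100)
except NotImplementedError as err:
    print(err)
```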
lerobot/common/policies/act/modeling_act.py
ADDED
@@ -0,0 +1,765 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""Action Chunking Transformer Policy
|
17 |
+
|
18 |
+
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
19 |
+
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from collections import deque
|
24 |
+
from itertools import chain
|
25 |
+
from typing import Callable
|
26 |
+
|
27 |
+
import einops
|
28 |
+
import numpy as np
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F # noqa: N812
|
31 |
+
import torchvision
|
32 |
+
from torch import Tensor, nn
|
33 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
34 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
35 |
+
|
36 |
+
from lerobot.common.policies.act.configuration_act import ACTConfig
|
37 |
+
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
38 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
39 |
+
|
40 |
+
|
41 |
+
class ACTPolicy(PreTrainedPolicy):
|
42 |
+
"""
|
43 |
+
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
44 |
+
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
45 |
+
"""
|
46 |
+
|
47 |
+
config_class = ACTConfig
|
48 |
+
name = "act"
|
49 |
+
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
config: ACTConfig,
|
53 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
54 |
+
):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
config: Policy configuration class instance or None, in which case the default instantiation of
|
58 |
+
the configuration class is used.
|
59 |
+
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
60 |
+
that they will be passed with a call to `load_state_dict` before the policy is used.
|
61 |
+
"""
|
62 |
+
super().__init__(config)
|
63 |
+
config.validate_features()
|
64 |
+
self.config = config
|
65 |
+
|
66 |
+
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
67 |
+
self.normalize_targets = Normalize(
|
68 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
69 |
+
)
|
70 |
+
self.unnormalize_outputs = Unnormalize(
|
71 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
72 |
+
)
|
73 |
+
|
74 |
+
self.model = ACT(config)
|
75 |
+
|
76 |
+
if config.temporal_ensemble_coeff is not None:
|
77 |
+
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
78 |
+
|
79 |
+
self.reset()
|
80 |
+
|
81 |
+
def get_optim_params(self) -> dict:
|
82 |
+
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
|
83 |
+
# Should we remove this and just `return self.parameters()`?
|
84 |
+
return [
|
85 |
+
{
|
86 |
+
"params": [
|
87 |
+
p
|
88 |
+
for n, p in self.named_parameters()
|
89 |
+
if not n.startswith("model.backbone") and p.requires_grad
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"params": [
|
94 |
+
p
|
95 |
+
for n, p in self.named_parameters()
|
96 |
+
if n.startswith("model.backbone") and p.requires_grad
|
97 |
+
],
|
98 |
+
"lr": self.config.optimizer_lr_backbone,
|
99 |
+
},
|
100 |
+
]
|
101 |
+
|
102 |
+
def reset(self):
|
103 |
+
"""This should be called whenever the environment is reset."""
|
104 |
+
if self.config.temporal_ensemble_coeff is not None:
|
105 |
+
self.temporal_ensembler.reset()
|
106 |
+
else:
|
107 |
+
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
108 |
+
|
109 |
+
@torch.no_grad
|
110 |
+
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
111 |
+
"""Select a single action given environment observations.
|
112 |
+
|
113 |
+
This method wraps `select_actions` in order to return one action at a time for execution in the
|
114 |
+
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
115 |
+
queue is empty.
|
116 |
+
"""
|
117 |
+
self.eval()
|
118 |
+
|
119 |
+
batch = self.normalize_inputs(batch)
|
120 |
+
if self.config.image_features:
|
121 |
+
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
122 |
+
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
123 |
+
|
124 |
+
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
125 |
+
# we are ensembling over.
|
126 |
+
if self.config.temporal_ensemble_coeff is not None:
|
127 |
+
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
128 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
129 |
+
action = self.temporal_ensembler.update(actions)
|
130 |
+
return action
|
131 |
+
|
132 |
+
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
133 |
+
# querying the policy.
|
134 |
+
if len(self._action_queue) == 0:
|
135 |
+
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
136 |
+
|
137 |
+
# TODO(rcadene): make _forward return output dictionary?
|
138 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
139 |
+
|
140 |
+
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
141 |
+
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
142 |
+
self._action_queue.extend(actions.transpose(0, 1))
|
143 |
+
return self._action_queue.popleft()
|
144 |
+
|
145 |
+
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
146 |
+
"""Run the batch through the model and compute the loss for training or validation."""
|
147 |
+
batch = self.normalize_inputs(batch)
|
148 |
+
if self.config.image_features:
|
149 |
+
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
150 |
+
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
151 |
+
|
152 |
+
batch = self.normalize_targets(batch)
|
153 |
+
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
154 |
+
|
155 |
+
l1_loss = (
|
156 |
+
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
157 |
+
).mean()
|
158 |
+
|
159 |
+
loss_dict = {"l1_loss": l1_loss.item()}
|
160 |
+
if self.config.use_vae:
|
161 |
+
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
162 |
+
# each dimension independently, we sum over the latent dimension to get the total
|
163 |
+
# KL-divergence per batch element, then take the mean over the batch.
|
164 |
+
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
165 |
+
mean_kld = (
|
166 |
+
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
167 |
+
)
|
168 |
+
loss_dict["kld_loss"] = mean_kld.item()
|
169 |
+
loss = l1_loss + mean_kld * self.config.kl_weight
|
170 |
+
else:
|
171 |
+
loss = l1_loss
|
172 |
+
|
173 |
+
return loss, loss_dict
|
174 |
+
|
175 |
+
|
176 |
+
class ACTTemporalEnsembler:
|
177 |
+
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
178 |
+
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
|
179 |
+
|
180 |
+
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
|
181 |
+
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
|
182 |
+
coefficient works:
|
183 |
+
- Setting it to 0 uniformly weighs all actions.
|
184 |
+
- Setting it positive gives more weight to older actions.
|
185 |
+
- Setting it negative gives more weight to newer actions.
|
186 |
+
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
|
187 |
+
results in older actions being weighed more highly than newer actions (the experiments documented in
|
188 |
+
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
|
189 |
+
detrimental: doing so aggressively may diminish the benefits of action chunking).
|
190 |
+
|
191 |
+
Here we use an online method for computing the average rather than caching a history of actions in
|
192 |
+
order to compute the average offline. For a simple 1D sequence it looks something like:
|
193 |
+
|
194 |
+
```
|
195 |
+
import torch
|
196 |
+
|
197 |
+
seq = torch.linspace(8, 8.5, 100)
|
198 |
+
print(seq)
|
199 |
+
|
200 |
+
m = 0.01
|
201 |
+
exp_weights = torch.exp(-m * torch.arange(len(seq)))
|
202 |
+
print(exp_weights)
|
203 |
+
|
204 |
+
# Calculate offline
|
205 |
+
avg = (exp_weights * seq).sum() / exp_weights.sum()
|
206 |
+
print("offline", avg)
|
207 |
+
|
208 |
+
# Calculate online
|
209 |
+
for i, item in enumerate(seq):
|
210 |
+
if i == 0:
|
211 |
+
avg = item
|
212 |
+
continue
|
213 |
+
avg *= exp_weights[:i].sum()
|
214 |
+
avg += item * exp_weights[i]
|
215 |
+
avg /= exp_weights[:i+1].sum()
|
216 |
+
print("online", avg)
|
217 |
+
```
|
218 |
+
"""
|
219 |
+
self.chunk_size = chunk_size
|
220 |
+
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
221 |
+
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
222 |
+
self.reset()
|
223 |
+
|
224 |
+
def reset(self):
|
225 |
+
"""Resets the online computation variables."""
|
226 |
+
self.ensembled_actions = None
|
227 |
+
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
|
228 |
+
self.ensembled_actions_count = None
|
229 |
+
|
230 |
+
def update(self, actions: Tensor) -> Tensor:
|
231 |
+
"""
|
232 |
+
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
|
233 |
+
time steps, and pop/return the next batch of actions in the sequence.
|
234 |
+
"""
|
235 |
+
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
236 |
+
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
237 |
+
if self.ensembled_actions is None:
|
238 |
+
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
239 |
+
# time step of the episode.
|
240 |
+
self.ensembled_actions = actions.clone()
|
241 |
+
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
242 |
+
# operations later.
|
243 |
+
self.ensembled_actions_count = torch.ones(
|
244 |
+
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
245 |
+
)
|
246 |
+
else:
|
247 |
+
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
248 |
+
# the online update for those entries.
|
249 |
+
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
250 |
+
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
251 |
+
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
252 |
+
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
253 |
+
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
254 |
+
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
255 |
+
self.ensembled_actions_count = torch.cat(
|
256 |
+
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
257 |
+
)
|
258 |
+
# "Consume" the first action.
|
259 |
+
action, self.ensembled_actions, self.ensembled_actions_count = (
|
260 |
+
self.ensembled_actions[:, 0],
|
261 |
+
self.ensembled_actions[:, 1:],
|
262 |
+
self.ensembled_actions_count[1:],
|
263 |
+
)
|
264 |
+
return action
|
265 |
+
|
266 |
+
|
267 |
+
class ACT(nn.Module):
|
268 |
+
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
269 |
+
|
270 |
+
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
271 |
+
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
272 |
+
model that encodes the target data (a sequence of actions), and the condition (the robot
|
273 |
+
joint-space).
|
274 |
+
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
|
275 |
+
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
|
276 |
+
have an option to train this model without the variational objective (in which case we drop the
|
277 |
+
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
|
278 |
+
|
279 |
+
Transformer
|
280 |
+
Used alone for inference
|
281 |
+
(acts as VAE decoder
|
282 |
+
during training)
|
283 |
+
┌───────────────────────┐
|
284 |
+
│ Outputs │
|
285 |
+
│ ▲ │
|
286 |
+
│ ┌─────►┌───────┐ │
|
287 |
+
┌──────┐ │ │ │Transf.│ │
|
288 |
+
│ │ │ ├─────►│decoder│ │
|
289 |
+
┌────┴────┐ │ │ │ │ │ │
|
290 |
+
│ │ │ │ ┌───┴───┬─►│ │ │
|
291 |
+
│ VAE │ │ │ │ │ └───────┘ │
|
292 |
+
│ encoder │ │ │ │Transf.│ │
|
293 |
+
│ │ │ │ │encoder│ │
|
294 |
+
└───▲─────┘ │ │ │ │ │
|
295 |
+
│ │ │ └▲──▲─▲─┘ │
|
296 |
+
│ │ │ │ │ │ │
|
297 |
+
inputs └─────┼──┘ │ image emb. │
|
298 |
+
│ state emb. │
|
299 |
+
└───────────────────────┘
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(self, config: ACTConfig):
|
303 |
+
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
304 |
+
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
305 |
+
super().__init__()
|
306 |
+
self.config = config
|
307 |
+
|
308 |
+
if self.config.use_vae:
|
309 |
+
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
310 |
+
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
311 |
+
# Projection layer for joint-space configuration to hidden dimension.
|
312 |
+
if self.config.robot_state_feature:
|
313 |
+
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
314 |
+
self.config.robot_state_feature.shape[0], config.dim_model
|
315 |
+
)
|
316 |
+
# Projection layer for action (joint-space target) to hidden dimension.
|
317 |
+
self.vae_encoder_action_input_proj = nn.Linear(
|
318 |
+
self.config.action_feature.shape[0],
|
319 |
+
config.dim_model,
|
320 |
+
)
|
321 |
+
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
322 |
+
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
323 |
+
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
324 |
+
# dimension.
|
325 |
+
num_input_token_encoder = 1 + config.chunk_size
|
326 |
+
if self.config.robot_state_feature:
|
327 |
+
num_input_token_encoder += 1
|
328 |
+
self.register_buffer(
|
329 |
+
"vae_encoder_pos_enc",
|
330 |
+
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
331 |
+
)
|
332 |
+
|
333 |
+
# Backbone for image feature extraction.
|
334 |
+
if self.config.image_features:
|
335 |
+
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
336 |
+
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
337 |
+
weights=config.pretrained_backbone_weights,
|
338 |
+
norm_layer=FrozenBatchNorm2d,
|
339 |
+
)
|
340 |
+
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
341 |
+
# feature map).
|
342 |
+
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
343 |
+
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
344 |
+
|
345 |
+
# Transformer (acts as VAE decoder when training with the variational objective).
|
346 |
+
self.encoder = ACTEncoder(config)
|
347 |
+
self.decoder = ACTDecoder(config)
|
348 |
+
|
349 |
+
# Transformer encoder input projections. The tokens will be structured like
|
350 |
+
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
351 |
+
if self.config.robot_state_feature:
|
352 |
+
self.encoder_robot_state_input_proj = nn.Linear(
|
353 |
+
self.config.robot_state_feature.shape[0], config.dim_model
|
354 |
+
)
|
355 |
+
if self.config.env_state_feature:
|
356 |
+
self.encoder_env_state_input_proj = nn.Linear(
|
357 |
+
self.config.env_state_feature.shape[0], config.dim_model
|
358 |
+
)
|
359 |
+
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
360 |
+
if self.config.image_features:
|
361 |
+
self.encoder_img_feat_input_proj = nn.Conv2d(
|
362 |
+
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
363 |
+
)
|
364 |
+
# Transformer encoder positional embeddings.
|
365 |
+
n_1d_tokens = 1 # for the latent
|
366 |
+
if self.config.robot_state_feature:
|
367 |
+
n_1d_tokens += 1
|
368 |
+
if self.config.env_state_feature:
|
369 |
+
n_1d_tokens += 1
|
370 |
+
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
371 |
+
if self.config.image_features:
|
372 |
+
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
373 |
+
|
374 |
+
# Transformer decoder.
|
375 |
+
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
376 |
+
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
377 |
+
|
378 |
+
# Final action regression head on the output of the transformer's decoder.
|
379 |
+
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
|
380 |
+
|
381 |
+
self._reset_parameters()
|
382 |
+
|
383 |
+
def _reset_parameters(self):
|
384 |
+
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
385 |
+
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
386 |
+
if p.dim() > 1:
|
387 |
+
nn.init.xavier_uniform_(p)
|
388 |
+
|
389 |
+
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
390 |
+
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
391 |
+
|
392 |
+
`batch` should have the following structure:
|
393 |
+
{
|
394 |
+
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
|
395 |
+
|
396 |
+
[image_features]: (B, n_cameras, C, H, W) batch of images.
|
397 |
+
AND/OR
|
398 |
+
[env_state_feature]: (B, env_dim) batch of environment states.
|
399 |
+
|
400 |
+
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
401 |
+
}
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
(B, chunk_size, action_dim) batch of action sequences
|
405 |
+
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
406 |
+
latent dimension.
|
407 |
+
"""
|
408 |
+
if self.config.use_vae and self.training:
|
409 |
+
assert "action" in batch, (
|
410 |
+
"actions must be provided when using the variational objective in training mode."
|
411 |
+
)
|
412 |
+
|
413 |
+
if "observation.images" in batch:
|
414 |
+
batch_size = batch["observation.images"][0].shape[0]
|
415 |
+
else:
|
416 |
+
batch_size = batch["observation.environment_state"].shape[0]
|
417 |
+
|
418 |
+
# Prepare the latent for input to the transformer encoder.
|
419 |
+
if self.config.use_vae and "action" in batch:
|
420 |
+
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
421 |
+
cls_embed = einops.repeat(
|
422 |
+
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
423 |
+
) # (B, 1, D)
|
424 |
+
if self.config.robot_state_feature:
|
425 |
+
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
426 |
+
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
427 |
+
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
428 |
+
|
429 |
+
if self.config.robot_state_feature:
|
430 |
+
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
431 |
+
else:
|
432 |
+
vae_encoder_input = [cls_embed, action_embed]
|
433 |
+
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
434 |
+
|
435 |
+
# Prepare fixed positional embedding.
|
436 |
+
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
437 |
+
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
438 |
+
|
439 |
+
# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
|
440 |
+
# sequence depending whether we use the input states or not (cls and robot state)
|
441 |
+
# False means not a padding token.
|
442 |
+
cls_joint_is_pad = torch.full(
|
443 |
+
(batch_size, 2 if self.config.robot_state_feature else 1),
|
444 |
+
False,
|
445 |
+
device=batch["observation.state"].device,
|
446 |
+
)
|
447 |
+
key_padding_mask = torch.cat(
|
448 |
+
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
|
449 |
+
) # (bs, seq+1 or 2)
|
450 |
+
|
451 |
+
# Forward pass through VAE encoder to get the latent PDF parameters.
|
452 |
+
cls_token_out = self.vae_encoder(
|
453 |
+
vae_encoder_input.permute(1, 0, 2),
|
454 |
+
pos_embed=pos_embed.permute(1, 0, 2),
|
455 |
+
key_padding_mask=key_padding_mask,
|
456 |
+
)[0] # select the class token, with shape (B, D)
|
457 |
+
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
458 |
+
mu = latent_pdf_params[:, : self.config.latent_dim]
|
459 |
+
# This is 2log(sigma). Done this way to match the original implementation.
|
460 |
+
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
|
461 |
+
|
462 |
+
# Sample the latent with the reparameterization trick.
|
463 |
+
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
464 |
+
else:
|
465 |
+
# When not using the VAE encoder, we set the latent to be all zeros.
|
466 |
+
mu = log_sigma_x2 = None
|
467 |
+
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
468 |
+
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
469 |
+
batch["observation.state"].device
|
470 |
+
)
|
471 |
+
|
472 |
+
# Prepare transformer encoder inputs.
|
473 |
+
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
474 |
+
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
475 |
+
# Robot state token.
|
476 |
+
if self.config.robot_state_feature:
|
477 |
+
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
478 |
+
# Environment state token.
|
479 |
+
if self.config.env_state_feature:
|
480 |
+
encoder_in_tokens.append(
|
481 |
+
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
482 |
+
)
|
483 |
+
|
484 |
+
# Camera observation features and positional embeddings.
|
485 |
+
if self.config.image_features:
|
486 |
+
all_cam_features = []
|
487 |
+
all_cam_pos_embeds = []
|
488 |
+
|
489 |
+
# For a list of images, the H and W may vary but H*W is constant.
|
490 |
+
for img in batch["observation.images"]:
|
491 |
+
cam_features = self.backbone(img)["feature_map"]
|
492 |
+
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
493 |
+
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
494 |
+
|
495 |
+
# Rearrange features to (sequence, batch, dim).
|
496 |
+
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
|
497 |
+
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
|
498 |
+
|
499 |
+
all_cam_features.append(cam_features)
|
500 |
+
all_cam_pos_embeds.append(cam_pos_embed)
|
501 |
+
|
502 |
+
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
|
503 |
+
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
|
504 |
+
|
505 |
+
# Stack all tokens along the sequence dimension.
|
506 |
+
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
507 |
+
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
508 |
+
|
509 |
+
# Forward pass through the transformer modules.
|
510 |
+
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
|
511 |
+
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
512 |
+
decoder_in = torch.zeros(
|
513 |
+
(self.config.chunk_size, batch_size, self.config.dim_model),
|
514 |
+
dtype=encoder_in_pos_embed.dtype,
|
515 |
+
device=encoder_in_pos_embed.device,
|
516 |
+
)
|
517 |
+
decoder_out = self.decoder(
|
518 |
+
decoder_in,
|
519 |
+
encoder_out,
|
520 |
+
encoder_pos_embed=encoder_in_pos_embed,
|
521 |
+
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
|
522 |
+
)
|
523 |
+
|
524 |
+
# Move back to (B, S, C).
|
525 |
+
decoder_out = decoder_out.transpose(0, 1)
|
526 |
+
|
527 |
+
actions = self.action_head(decoder_out)
|
528 |
+
|
529 |
+
return actions, (mu, log_sigma_x2)
|
530 |
+
|
531 |
+
|
532 |
+
class ACTEncoder(nn.Module):
|
533 |
+
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
534 |
+
|
535 |
+
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
536 |
+
super().__init__()
|
537 |
+
self.is_vae_encoder = is_vae_encoder
|
538 |
+
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
539 |
+
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
540 |
+
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
541 |
+
|
542 |
+
def forward(
|
543 |
+
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
544 |
+
) -> Tensor:
|
545 |
+
for layer in self.layers:
|
546 |
+
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
547 |
+
x = self.norm(x)
|
548 |
+
return x
|
549 |
+
|
550 |
+
|
551 |
+
class ACTEncoderLayer(nn.Module):
|
552 |
+
def __init__(self, config: ACTConfig):
|
553 |
+
super().__init__()
|
554 |
+
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
555 |
+
|
556 |
+
# Feed forward layers.
|
557 |
+
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
558 |
+
self.dropout = nn.Dropout(config.dropout)
|
559 |
+
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
|
560 |
+
|
561 |
+
self.norm1 = nn.LayerNorm(config.dim_model)
|
562 |
+
self.norm2 = nn.LayerNorm(config.dim_model)
|
563 |
+
self.dropout1 = nn.Dropout(config.dropout)
|
564 |
+
self.dropout2 = nn.Dropout(config.dropout)
|
565 |
+
|
566 |
+
self.activation = get_activation_fn(config.feedforward_activation)
|
567 |
+
self.pre_norm = config.pre_norm
|
568 |
+
|
569 |
+
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
|
570 |
+
skip = x
|
571 |
+
if self.pre_norm:
|
572 |
+
x = self.norm1(x)
|
573 |
+
q = k = x if pos_embed is None else x + pos_embed
|
574 |
+
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
|
575 |
+
x = x[0] # note: [0] to select just the output, not the attention weights
|
576 |
+
x = skip + self.dropout1(x)
|
577 |
+
if self.pre_norm:
|
578 |
+
skip = x
|
579 |
+
x = self.norm2(x)
|
580 |
+
else:
|
581 |
+
x = self.norm1(x)
|
582 |
+
skip = x
|
583 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
584 |
+
x = skip + self.dropout2(x)
|
585 |
+
if not self.pre_norm:
|
586 |
+
x = self.norm2(x)
|
587 |
+
return x
|
588 |
+
|
589 |
+
|
590 |
+
class ACTDecoder(nn.Module):
|
591 |
+
def __init__(self, config: ACTConfig):
|
592 |
+
"""Convenience module for running multiple decoder layers followed by normalization."""
|
593 |
+
super().__init__()
|
594 |
+
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
595 |
+
self.norm = nn.LayerNorm(config.dim_model)
|
596 |
+
|
597 |
+
def forward(
|
598 |
+
self,
|
599 |
+
x: Tensor,
|
600 |
+
encoder_out: Tensor,
|
601 |
+
decoder_pos_embed: Tensor | None = None,
|
602 |
+
encoder_pos_embed: Tensor | None = None,
|
603 |
+
) -> Tensor:
|
604 |
+
for layer in self.layers:
|
605 |
+
x = layer(
|
606 |
+
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
607 |
+
)
|
608 |
+
if self.norm is not None:
|
609 |
+
x = self.norm(x)
|
610 |
+
return x
|
611 |
+
|
612 |
+
|
613 |
+
class ACTDecoderLayer(nn.Module):
|
614 |
+
def __init__(self, config: ACTConfig):
|
615 |
+
super().__init__()
|
616 |
+
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
617 |
+
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
618 |
+
|
619 |
+
# Feed forward layers.
|
620 |
+
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
621 |
+
self.dropout = nn.Dropout(config.dropout)
|
622 |
+
self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
|
623 |
+
|
624 |
+
self.norm1 = nn.LayerNorm(config.dim_model)
|
625 |
+
self.norm2 = nn.LayerNorm(config.dim_model)
|
626 |
+
self.norm3 = nn.LayerNorm(config.dim_model)
|
627 |
+
self.dropout1 = nn.Dropout(config.dropout)
|
628 |
+
self.dropout2 = nn.Dropout(config.dropout)
|
629 |
+
self.dropout3 = nn.Dropout(config.dropout)
|
630 |
+
|
631 |
+
self.activation = get_activation_fn(config.feedforward_activation)
|
632 |
+
self.pre_norm = config.pre_norm
|
633 |
+
|
634 |
+
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
635 |
+
return tensor if pos_embed is None else tensor + pos_embed
|
636 |
+
|
637 |
+
def forward(
|
638 |
+
self,
|
639 |
+
x: Tensor,
|
640 |
+
encoder_out: Tensor,
|
641 |
+
decoder_pos_embed: Tensor | None = None,
|
642 |
+
encoder_pos_embed: Tensor | None = None,
|
643 |
+
) -> Tensor:
|
644 |
+
"""
|
645 |
+
Args:
|
646 |
+
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
647 |
+
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
648 |
+
cross-attending with.
|
649 |
+
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
650 |
+
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
651 |
+
Returns:
|
652 |
+
(DS, B, C) tensor of decoder output features.
|
653 |
+
"""
|
654 |
+
skip = x
|
655 |
+
if self.pre_norm:
|
656 |
+
x = self.norm1(x)
|
657 |
+
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
658 |
+
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
659 |
+
x = skip + self.dropout1(x)
|
660 |
+
if self.pre_norm:
|
661 |
+
skip = x
|
662 |
+
x = self.norm2(x)
|
663 |
+
else:
|
664 |
+
x = self.norm1(x)
|
665 |
+
skip = x
|
666 |
+
x = self.multihead_attn(
|
667 |
+
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
668 |
+
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
669 |
+
value=encoder_out,
|
670 |
+
)[0] # select just the output, not the attention weights
|
671 |
+
x = skip + self.dropout2(x)
|
672 |
+
if self.pre_norm:
|
673 |
+
skip = x
|
674 |
+
x = self.norm3(x)
|
675 |
+
else:
|
676 |
+
x = self.norm2(x)
|
677 |
+
skip = x
|
678 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
679 |
+
x = skip + self.dropout3(x)
|
680 |
+
if not self.pre_norm:
|
681 |
+
x = self.norm3(x)
|
682 |
+
return x
|
683 |
+
|
684 |
+
|
685 |
+
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
|
686 |
+
"""1D sinusoidal positional embeddings as in Attention is All You Need.
|
687 |
+
|
688 |
+
Args:
|
689 |
+
num_positions: Number of token positions required.
|
690 |
+
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
|
691 |
+
|
692 |
+
"""
|
693 |
+
|
694 |
+
def get_position_angle_vec(position):
|
695 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
696 |
+
|
697 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
698 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
699 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
700 |
+
return torch.from_numpy(sinusoid_table).float()
|
701 |
+
|
702 |
+
|
703 |
+
class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
704 |
+
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
705 |
+
|
706 |
+
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
707 |
+
for the vertical direction, and 1/W for the horizontal direction.
|
708 |
+
"""
|
709 |
+
|
710 |
+
def __init__(self, dimension: int):
|
711 |
+
"""
|
712 |
+
Args:
|
713 |
+
dimension: The desired dimension of the embeddings.
|
714 |
+
"""
|
715 |
+
super().__init__()
|
716 |
+
self.dimension = dimension
|
717 |
+
self._two_pi = 2 * math.pi
|
718 |
+
self._eps = 1e-6
|
719 |
+
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
720 |
+
self._temperature = 10000
|
721 |
+
|
722 |
+
def forward(self, x: Tensor) -> Tensor:
|
723 |
+
"""
|
724 |
+
Args:
|
725 |
+
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
|
726 |
+
Returns:
|
727 |
+
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
728 |
+
"""
|
729 |
+
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
730 |
+
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
731 |
+
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
732 |
+
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
733 |
+
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
734 |
+
|
735 |
+
# "Normalize" the position index such that it ranges in [0, 2π].
|
736 |
+
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
737 |
+
# are non-zero by construction. This is an artifact of the original code.
|
738 |
+
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
739 |
+
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
740 |
+
|
741 |
+
inverse_frequency = self._temperature ** (
|
742 |
+
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
743 |
+
)
|
744 |
+
|
745 |
+
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
746 |
+
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
747 |
+
|
748 |
+
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
749 |
+
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
750 |
+
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
751 |
+
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
752 |
+
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
753 |
+
|
754 |
+
return pos_embed
|
755 |
+
|
756 |
+
|
757 |
+
def get_activation_fn(activation: str) -> Callable:
|
758 |
+
"""Return an activation function given a string."""
|
759 |
+
if activation == "relu":
|
760 |
+
return F.relu
|
761 |
+
if activation == "gelu":
|
762 |
+
return F.gelu
|
763 |
+
if activation == "glu":
|
764 |
+
return F.glu
|
765 |
+
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
lerobot/common/policies/diffusion/configuration_diffusion.py
ADDED
@@ -0,0 +1,237 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
4 |
+
# and The HuggingFace Inc. team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
|
19 |
+
from lerobot.common.optim.optimizers import AdamConfig
|
20 |
+
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
|
21 |
+
from lerobot.configs.policies import PreTrainedConfig
|
22 |
+
from lerobot.configs.types import NormalizationMode
|
23 |
+
|
24 |
+
|
25 |
+
@PreTrainedConfig.register_subclass("diffusion")
|
26 |
+
@dataclass
|
27 |
+
class DiffusionConfig(PreTrainedConfig):
|
28 |
+
"""Configuration class for DiffusionPolicy.
|
29 |
+
|
30 |
+
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
31 |
+
|
32 |
+
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
33 |
+
Those are: `input_shapes` and `output_shapes`.
|
34 |
+
|
35 |
+
Notes on the inputs and outputs:
|
36 |
+
- "observation.state" is required as an input key.
|
37 |
+
- Either:
|
38 |
+
- At least one key starting with "observation.image" is required as an input.
|
39 |
+
AND/OR
|
40 |
+
- The key "observation.environment_state" is required as input.
|
41 |
+
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
42 |
+
views. Right now we only support all images having the same shape.
|
43 |
+
- "action" is required as an output key.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
47 |
+
current step and additional steps going back).
|
48 |
+
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
49 |
+
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
50 |
+
See `DiffusionPolicy.select_action` for more details.
|
51 |
+
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
52 |
+
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
53 |
+
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
54 |
+
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
55 |
+
include batch dimension or temporal dimension.
|
56 |
+
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
57 |
+
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
58 |
+
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
59 |
+
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
60 |
+
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
61 |
+
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
62 |
+
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
63 |
+
[-1, 1] range.
|
64 |
+
output_normalization_modes: Similar dictionary to `input_normalization_modes`, but to unnormalize to the
|
65 |
+
original scale. Note that this is also used for normalizing the training targets.
|
66 |
+
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
67 |
+
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
68 |
+
within the image size. If None, no cropping is done.
|
69 |
+
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
70 |
+
mode).
|
71 |
+
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
72 |
+
`None` means no pretrained weights.
|
73 |
+
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
74 |
+
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
75 |
+
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
76 |
+
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
77 |
+
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
78 |
+
You may provide a variable number of dimensions, therefore also controlling the degree of
|
79 |
+
downsampling.
|
80 |
+
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
|
81 |
+
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
|
82 |
+
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
|
83 |
+
network. This is the output dimension of that network, i.e., the embedding dimension.
|
84 |
+
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
85 |
+
Bias modulation is used by default, while this parameter indicates whether to also use scale
|
86 |
+
modulation.
|
87 |
+
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
88 |
+
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
89 |
+
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
90 |
+
beta_start: Beta value for the first forward-diffusion step.
|
91 |
+
beta_end: Beta value for the last forward-diffusion step.
|
92 |
+
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
|
93 |
+
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
|
94 |
+
"epsilon" has been shown to work better in many deep neural network settings.
|
95 |
+
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
|
96 |
+
denoising step at inference time. WARNING: you will need to make sure your action-space is
|
97 |
+
normalized to fit within this range.
|
98 |
+
clip_sample_range: The magnitude of the clipping range as described above.
|
99 |
+
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
100 |
+
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
101 |
+
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
102 |
+
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
103 |
+
to False as the original Diffusion Policy implementation does the same.
|
104 |
+
"""
|
105 |
+
|
106 |
+
# Inputs / output structure.
|
107 |
+
n_obs_steps: int = 2
|
108 |
+
horizon: int = 16
|
109 |
+
n_action_steps: int = 8
|
110 |
+
|
111 |
+
normalization_mapping: dict[str, NormalizationMode] = field(
|
112 |
+
default_factory=lambda: {
|
113 |
+
"VISUAL": NormalizationMode.MEAN_STD,
|
114 |
+
"STATE": NormalizationMode.MIN_MAX,
|
115 |
+
"ACTION": NormalizationMode.MIN_MAX,
|
116 |
+
}
|
117 |
+
)
|
118 |
+
|
119 |
+
# The original implementation doesn't sample frames for the last 7 steps,
|
120 |
+
# which avoids excessive padding and leads to improved training results.
|
121 |
+
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
122 |
+
|
123 |
+
# Architecture / modeling.
|
124 |
+
# Vision backbone.
|
125 |
+
vision_backbone: str = "resnet18"
|
126 |
+
crop_shape: tuple[int, int] | None = (84, 84)
|
127 |
+
crop_is_random: bool = True
|
128 |
+
pretrained_backbone_weights: str | None = None
|
129 |
+
use_group_norm: bool = True
|
130 |
+
spatial_softmax_num_keypoints: int = 32
|
131 |
+
use_separate_rgb_encoder_per_camera: bool = False
|
132 |
+
# Unet.
|
133 |
+
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
134 |
+
kernel_size: int = 5
|
135 |
+
n_groups: int = 8
|
136 |
+
diffusion_step_embed_dim: int = 128
|
137 |
+
use_film_scale_modulation: bool = True
|
138 |
+
# Noise scheduler.
|
139 |
+
noise_scheduler_type: str = "DDPM"
|
140 |
+
num_train_timesteps: int = 100
|
141 |
+
beta_schedule: str = "squaredcos_cap_v2"
|
142 |
+
beta_start: float = 0.0001
|
143 |
+
beta_end: float = 0.02
|
144 |
+
prediction_type: str = "epsilon"
|
145 |
+
clip_sample: bool = True
|
146 |
+
clip_sample_range: float = 1.0
|
147 |
+
|
148 |
+
# Inference
|
149 |
+
num_inference_steps: int | None = None
|
150 |
+
|
151 |
+
# Loss computation
|
152 |
+
do_mask_loss_for_padding: bool = False
|
153 |
+
|
154 |
+
# Training presets
|
155 |
+
optimizer_lr: float = 1e-4
|
156 |
+
optimizer_betas: tuple = (0.95, 0.999)
|
157 |
+
optimizer_eps: float = 1e-8
|
158 |
+
optimizer_weight_decay: float = 1e-6
|
159 |
+
scheduler_name: str = "cosine"
|
160 |
+
scheduler_warmup_steps: int = 500
|
161 |
+
|
162 |
+
def __post_init__(self):
|
163 |
+
super().__post_init__()
|
164 |
+
|
165 |
+
"""Input validation (not exhaustive)."""
|
166 |
+
if not self.vision_backbone.startswith("resnet"):
|
167 |
+
raise ValueError(
|
168 |
+
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
169 |
+
)
|
170 |
+
|
171 |
+
supported_prediction_types = ["epsilon", "sample"]
|
172 |
+
if self.prediction_type not in supported_prediction_types:
|
173 |
+
raise ValueError(
|
174 |
+
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
175 |
+
)
|
176 |
+
supported_noise_schedulers = ["DDPM", "DDIM"]
|
177 |
+
if self.noise_scheduler_type not in supported_noise_schedulers:
|
178 |
+
raise ValueError(
|
179 |
+
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
180 |
+
f"Got {self.noise_scheduler_type}."
|
181 |
+
)
|
182 |
+
|
183 |
+
# Check that the horizon size and U-Net downsampling is compatible.
|
184 |
+
# U-Net downsamples by 2 with each stage.
|
185 |
+
downsampling_factor = 2 ** len(self.down_dims)
|
186 |
+
if self.horizon % downsampling_factor != 0:
|
187 |
+
raise ValueError(
|
188 |
+
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
189 |
+
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
190 |
+
)
|
191 |
+
|
192 |
+
def get_optimizer_preset(self) -> AdamConfig:
|
193 |
+
return AdamConfig(
|
194 |
+
lr=self.optimizer_lr,
|
195 |
+
betas=self.optimizer_betas,
|
196 |
+
eps=self.optimizer_eps,
|
197 |
+
weight_decay=self.optimizer_weight_decay,
|
198 |
+
)
|
199 |
+
|
200 |
+
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
201 |
+
return DiffuserSchedulerConfig(
|
202 |
+
name=self.scheduler_name,
|
203 |
+
num_warmup_steps=self.scheduler_warmup_steps,
|
204 |
+
)
|
205 |
+
|
206 |
+
def validate_features(self) -> None:
|
207 |
+
if len(self.image_features) == 0 and self.env_state_feature is None:
|
208 |
+
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
209 |
+
|
210 |
+
if self.crop_shape is not None:
|
211 |
+
for key, image_ft in self.image_features.items():
|
212 |
+
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
213 |
+
raise ValueError(
|
214 |
+
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
215 |
+
f"for `crop_shape` and {image_ft.shape} for "
|
216 |
+
f"`{key}`."
|
217 |
+
)
|
218 |
+
|
219 |
+
# Check that all input images have the same shape.
|
220 |
+
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
221 |
+
for key, image_ft in self.image_features.items():
|
222 |
+
if image_ft.shape != first_image_ft.shape:
|
223 |
+
raise ValueError(
|
224 |
+
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
225 |
+
)
|
226 |
+
|
227 |
+
@property
|
228 |
+
def observation_delta_indices(self) -> list:
|
229 |
+
return list(range(1 - self.n_obs_steps, 1))
|
230 |
+
|
231 |
+
@property
|
232 |
+
def action_delta_indices(self) -> list:
|
233 |
+
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
234 |
+
|
235 |
+
@property
|
236 |
+
def reward_delta_indices(self) -> None:
|
237 |
+
return None
|
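To make the validation in `__post_init__` and the delta-index properties concrete, here is a minimal sketch (illustrative only, assuming the defaults inherited from `PreTrainedConfig`, such as device auto-selection, resolve in your environment):

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig(n_obs_steps=2, horizon=16, n_action_steps=8)
# The U-Net halves the temporal dimension at each stage, hence the divisibility requirement.
assert cfg.horizon % (2 ** len(cfg.down_dims)) == 0   # 16 % 8 == 0 with the default down_dims
print(cfg.observation_delta_indices)  # [-1, 0]
print(cfg.action_delta_indices)       # [-1, 0, 1, ..., 14]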
lerobot/common/policies/diffusion/modeling_diffusion.py
ADDED
@@ -0,0 +1,765 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
4 |
+
# and The HuggingFace Inc. team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
18 |
+
|
19 |
+
TODO(alexander-soare):
|
20 |
+
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
21 |
+
"""
|
22 |
+
|
23 |
+
import math
|
24 |
+
from collections import deque
|
25 |
+
from typing import Callable
|
26 |
+
|
27 |
+
import einops
|
28 |
+
import numpy as np
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F # noqa: N812
|
31 |
+
import torchvision
|
32 |
+
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
33 |
+
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
34 |
+
from torch import Tensor, nn
|
35 |
+
|
36 |
+
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
37 |
+
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
38 |
+
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
39 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
40 |
+
from lerobot.common.policies.utils import (
|
41 |
+
get_device_from_parameters,
|
42 |
+
get_dtype_from_parameters,
|
43 |
+
get_output_shape,
|
44 |
+
populate_queues,
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
class DiffusionPolicy(PreTrainedPolicy):
|
49 |
+
"""
|
50 |
+
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
51 |
+
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
52 |
+
"""
|
53 |
+
|
54 |
+
config_class = DiffusionConfig
|
55 |
+
name = "diffusion"
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
config: DiffusionConfig,
|
60 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
61 |
+
):
|
62 |
+
"""
|
63 |
+
Args:
|
64 |
+
config: Policy configuration class instance or None, in which case the default instantiation of
|
65 |
+
the configuration class is used.
|
66 |
+
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
67 |
+
that they will be passed with a call to `load_state_dict` before the policy is used.
|
68 |
+
"""
|
69 |
+
super().__init__(config)
|
70 |
+
config.validate_features()
|
71 |
+
self.config = config
|
72 |
+
|
73 |
+
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
74 |
+
self.normalize_targets = Normalize(
|
75 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
76 |
+
)
|
77 |
+
self.unnormalize_outputs = Unnormalize(
|
78 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
79 |
+
)
|
80 |
+
|
81 |
+
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
82 |
+
self._queues = None
|
83 |
+
|
84 |
+
self.diffusion = DiffusionModel(config)
|
85 |
+
|
86 |
+
self.reset()
|
87 |
+
|
88 |
+
def get_optim_params(self) -> dict:
|
89 |
+
return self.diffusion.parameters()
|
90 |
+
|
91 |
+
def reset(self):
|
92 |
+
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
93 |
+
self._queues = {
|
94 |
+
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
95 |
+
"action": deque(maxlen=self.config.n_action_steps),
|
96 |
+
}
|
97 |
+
if self.config.image_features:
|
98 |
+
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
99 |
+
if self.config.env_state_feature:
|
100 |
+
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
101 |
+
|
102 |
+
@torch.no_grad
|
103 |
+
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
104 |
+
"""Select a single action given environment observations.
|
105 |
+
|
106 |
+
This method handles caching a history of observations and an action trajectory generated by the
|
107 |
+
underlying diffusion model. Here's how it works:
|
108 |
+
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
109 |
+
copied `n_obs_steps` times to fill the cache).
|
110 |
+
- The diffusion model generates `horizon` steps worth of actions.
|
111 |
+
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
112 |
+
Schematically this looks like:
|
113 |
+
----------------------------------------------------------------------------------------------
|
114 |
+
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
115 |
+
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
116 |
+
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
117 |
+
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
118 |
+
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
119 |
+
----------------------------------------------------------------------------------------------
|
120 |
+
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
121 |
+
"horizon" may not the best name to describe what the variable actually means, because this period is
|
122 |
+
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
123 |
+
"""
|
124 |
+
batch = self.normalize_inputs(batch)
|
125 |
+
if self.config.image_features:
|
126 |
+
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
127 |
+
batch["observation.images"] = torch.stack(
|
128 |
+
[batch[key] for key in self.config.image_features], dim=-4
|
129 |
+
)
|
130 |
+
# Note: It's important that this happens after stacking the images into a single key.
|
131 |
+
self._queues = populate_queues(self._queues, batch)
|
132 |
+
|
133 |
+
if len(self._queues["action"]) == 0:
|
134 |
+
# stack n latest observations from the queue
|
135 |
+
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
136 |
+
actions = self.diffusion.generate_actions(batch)
|
137 |
+
|
138 |
+
# TODO(rcadene): make above methods return output dictionary?
|
139 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
140 |
+
|
141 |
+
self._queues["action"].extend(actions.transpose(0, 1))
|
142 |
+
|
143 |
+
action = self._queues["action"].popleft()
|
144 |
+
return action
|
145 |
+
|
146 |
+
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
147 |
+
"""Run the batch through the model and compute the loss for training or validation."""
|
148 |
+
batch = self.normalize_inputs(batch)
|
149 |
+
if self.config.image_features:
|
150 |
+
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
151 |
+
batch["observation.images"] = torch.stack(
|
152 |
+
[batch[key] for key in self.config.image_features], dim=-4
|
153 |
+
)
|
154 |
+
batch = self.normalize_targets(batch)
|
155 |
+
loss = self.diffusion.compute_loss(batch)
|
156 |
+
# no output_dict so returning None
|
157 |
+
return loss, None
|
158 |
+
|
159 |
+
|
160 |
+
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
161 |
+
"""
|
162 |
+
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
163 |
+
to the scheduler.
|
164 |
+
"""
|
165 |
+
if name == "DDPM":
|
166 |
+
return DDPMScheduler(**kwargs)
|
167 |
+
elif name == "DDIM":
|
168 |
+
return DDIMScheduler(**kwargs)
|
169 |
+
else:
|
170 |
+
raise ValueError(f"Unsupported noise scheduler type {name}")
|
171 |
+
|
172 |
+
|
173 |
+
class DiffusionModel(nn.Module):
|
174 |
+
def __init__(self, config: DiffusionConfig):
|
175 |
+
super().__init__()
|
176 |
+
self.config = config
|
177 |
+
|
178 |
+
# Build observation encoders (depending on which observations are provided).
|
179 |
+
global_cond_dim = self.config.robot_state_feature.shape[0]
|
180 |
+
if self.config.image_features:
|
181 |
+
num_images = len(self.config.image_features)
|
182 |
+
if self.config.use_separate_rgb_encoder_per_camera:
|
183 |
+
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
184 |
+
self.rgb_encoder = nn.ModuleList(encoders)
|
185 |
+
global_cond_dim += encoders[0].feature_dim * num_images
|
186 |
+
else:
|
187 |
+
self.rgb_encoder = DiffusionRgbEncoder(config)
|
188 |
+
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
189 |
+
if self.config.env_state_feature:
|
190 |
+
global_cond_dim += self.config.env_state_feature.shape[0]
|
191 |
+
|
192 |
+
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
193 |
+
|
194 |
+
self.noise_scheduler = _make_noise_scheduler(
|
195 |
+
config.noise_scheduler_type,
|
196 |
+
num_train_timesteps=config.num_train_timesteps,
|
197 |
+
beta_start=config.beta_start,
|
198 |
+
beta_end=config.beta_end,
|
199 |
+
beta_schedule=config.beta_schedule,
|
200 |
+
clip_sample=config.clip_sample,
|
201 |
+
clip_sample_range=config.clip_sample_range,
|
202 |
+
prediction_type=config.prediction_type,
|
203 |
+
)
|
204 |
+
|
205 |
+
if config.num_inference_steps is None:
|
206 |
+
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
207 |
+
else:
|
208 |
+
self.num_inference_steps = config.num_inference_steps
|
209 |
+
|
210 |
+
# ========= inference ============
|
211 |
+
def conditional_sample(
|
212 |
+
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
213 |
+
) -> Tensor:
|
214 |
+
device = get_device_from_parameters(self)
|
215 |
+
dtype = get_dtype_from_parameters(self)
|
216 |
+
|
217 |
+
# Sample prior.
|
218 |
+
sample = torch.randn(
|
219 |
+
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
220 |
+
dtype=dtype,
|
221 |
+
device=device,
|
222 |
+
generator=generator,
|
223 |
+
)
|
224 |
+
|
225 |
+
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
226 |
+
|
227 |
+
for t in self.noise_scheduler.timesteps:
|
228 |
+
# Predict model output.
|
229 |
+
model_output = self.unet(
|
230 |
+
sample,
|
231 |
+
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
232 |
+
global_cond=global_cond,
|
233 |
+
)
|
234 |
+
# Compute previous image: x_t -> x_t-1
|
235 |
+
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
236 |
+
|
237 |
+
return sample
|
238 |
+
|
239 |
+
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
240 |
+
"""Encode image features and concatenate them all together along with the state vector."""
|
241 |
+
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
242 |
+
global_cond_feats = [batch[OBS_ROBOT]]
|
243 |
+
# Extract image features.
|
244 |
+
if self.config.image_features:
|
245 |
+
if self.config.use_separate_rgb_encoder_per_camera:
|
246 |
+
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
247 |
+
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
248 |
+
img_features_list = torch.cat(
|
249 |
+
[
|
250 |
+
encoder(images)
|
251 |
+
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
252 |
+
]
|
253 |
+
)
|
254 |
+
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
255 |
+
# feature dim (effectively concatenating the camera features).
|
256 |
+
img_features = einops.rearrange(
|
257 |
+
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
261 |
+
img_features = self.rgb_encoder(
|
262 |
+
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
263 |
+
)
|
264 |
+
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
265 |
+
# feature dim (effectively concatenating the camera features).
|
266 |
+
img_features = einops.rearrange(
|
267 |
+
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
268 |
+
)
|
269 |
+
global_cond_feats.append(img_features)
|
270 |
+
|
271 |
+
if self.config.env_state_feature:
|
272 |
+
global_cond_feats.append(batch[OBS_ENV])
|
273 |
+
|
274 |
+
# Concatenate features then flatten to (B, global_cond_dim).
|
275 |
+
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
276 |
+
|
277 |
+
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
278 |
+
"""
|
279 |
+
This function expects `batch` to have:
|
280 |
+
{
|
281 |
+
"observation.state": (B, n_obs_steps, state_dim)
|
282 |
+
|
283 |
+
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
284 |
+
AND/OR
|
285 |
+
"observation.environment_state": (B, environment_dim)
|
286 |
+
}
|
287 |
+
"""
|
288 |
+
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
289 |
+
assert n_obs_steps == self.config.n_obs_steps
|
290 |
+
|
291 |
+
# Encode image features and concatenate them all together along with the state vector.
|
292 |
+
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
293 |
+
|
294 |
+
# run sampling
|
295 |
+
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
296 |
+
|
297 |
+
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
298 |
+
start = n_obs_steps - 1
|
299 |
+
end = start + self.config.n_action_steps
|
300 |
+
actions = actions[:, start:end]
|
301 |
+
|
302 |
+
return actions
|
303 |
+
|
304 |
+
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
305 |
+
"""
|
306 |
+
This function expects `batch` to have (at least):
|
307 |
+
{
|
308 |
+
"observation.state": (B, n_obs_steps, state_dim)
|
309 |
+
|
310 |
+
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
311 |
+
AND/OR
|
312 |
+
"observation.environment_state": (B, environment_dim)
|
313 |
+
|
314 |
+
"action": (B, horizon, action_dim)
|
315 |
+
"action_is_pad": (B, horizon)
|
316 |
+
}
|
317 |
+
"""
|
318 |
+
# Input validation.
|
319 |
+
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
|
320 |
+
assert "observation.images" in batch or "observation.environment_state" in batch
|
321 |
+
n_obs_steps = batch["observation.state"].shape[1]
|
322 |
+
horizon = batch["action"].shape[1]
|
323 |
+
assert horizon == self.config.horizon
|
324 |
+
assert n_obs_steps == self.config.n_obs_steps
|
325 |
+
|
326 |
+
# Encode image features and concatenate them all together along with the state vector.
|
327 |
+
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
328 |
+
|
329 |
+
# Forward diffusion.
|
330 |
+
trajectory = batch["action"]
|
331 |
+
# Sample noise to add to the trajectory.
|
332 |
+
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
333 |
+
# Sample a random noising timestep for each item in the batch.
|
334 |
+
timesteps = torch.randint(
|
335 |
+
low=0,
|
336 |
+
high=self.noise_scheduler.config.num_train_timesteps,
|
337 |
+
size=(trajectory.shape[0],),
|
338 |
+
device=trajectory.device,
|
339 |
+
).long()
|
340 |
+
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
341 |
+
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
|
342 |
+
|
343 |
+
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
|
344 |
+
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
345 |
+
|
346 |
+
# Compute the loss.
|
347 |
+
# The target is either the original trajectory, or the noise.
|
348 |
+
if self.config.prediction_type == "epsilon":
|
349 |
+
target = eps
|
350 |
+
elif self.config.prediction_type == "sample":
|
351 |
+
target = batch["action"]
|
352 |
+
else:
|
353 |
+
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
354 |
+
|
355 |
+
loss = F.mse_loss(pred, target, reduction="none")
|
356 |
+
|
357 |
+
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
|
358 |
+
if self.config.do_mask_loss_for_padding:
|
359 |
+
if "action_is_pad" not in batch:
|
360 |
+
raise ValueError(
|
361 |
+
"You need to provide 'action_is_pad' in the batch when "
|
362 |
+
f"{self.config.do_mask_loss_for_padding=}."
|
363 |
+
)
|
364 |
+
in_episode_bound = ~batch["action_is_pad"]
|
365 |
+
loss = loss * in_episode_bound.unsqueeze(-1)
|
366 |
+
|
367 |
+
return loss.mean()
|
368 |
+
|
369 |
+
|
370 |
+
class SpatialSoftmax(nn.Module):
|
371 |
+
"""
|
372 |
+
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
373 |
+
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
374 |
+
|
375 |
+
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
376 |
+
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
377 |
+
|
378 |
+
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
379 |
+
-----------------------------------------------------
|
380 |
+
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
381 |
+
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
382 |
+
| ... | ... | ... | ... |
|
383 |
+
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
384 |
+
-----------------------------------------------------
|
385 |
+
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
386 |
+
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
387 |
+
|
388 |
+
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
389 |
+
provide num_kp != None to control the number of keypoints. This is achieved by first applying a learnable
|
390 |
+
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
391 |
+
"""
|
392 |
+
|
393 |
+
def __init__(self, input_shape, num_kp=None):
|
394 |
+
"""
|
395 |
+
Args:
|
396 |
+
input_shape (list): (C, H, W) input feature map shape.
|
397 |
+
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
398 |
+
"""
|
399 |
+
super().__init__()
|
400 |
+
|
401 |
+
assert len(input_shape) == 3
|
402 |
+
self._in_c, self._in_h, self._in_w = input_shape
|
403 |
+
|
404 |
+
if num_kp is not None:
|
405 |
+
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
406 |
+
self._out_c = num_kp
|
407 |
+
else:
|
408 |
+
self.nets = None
|
409 |
+
self._out_c = self._in_c
|
410 |
+
|
411 |
+
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
412 |
+
# and causes a small degradation in pc_success of pre-trained models.
|
413 |
+
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
414 |
+
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
415 |
+
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
416 |
+
# register as buffer so it's moved to the correct device.
|
417 |
+
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
418 |
+
|
419 |
+
def forward(self, features: Tensor) -> Tensor:
|
420 |
+
"""
|
421 |
+
Args:
|
422 |
+
features: (B, C, H, W) input feature maps.
|
423 |
+
Returns:
|
424 |
+
(B, K, 2) image-space coordinates of keypoints.
|
425 |
+
"""
|
426 |
+
if self.nets is not None:
|
427 |
+
features = self.nets(features)
|
428 |
+
|
429 |
+
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
430 |
+
features = features.reshape(-1, self._in_h * self._in_w)
|
431 |
+
# 2d softmax normalization
|
432 |
+
attention = F.softmax(features, dim=-1)
|
433 |
+
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
434 |
+
expected_xy = attention @ self.pos_grid
|
435 |
+
# reshape to [B, K, 2]
|
436 |
+
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
437 |
+
|
438 |
+
return feature_keypoints
|
439 |
+
|
440 |
+
|
441 |
+
class DiffusionRgbEncoder(nn.Module):
|
442 |
+
"""Encodes an RGB image into a 1D feature vector.
|
443 |
+
|
444 |
+
Includes the ability to normalize and crop the image first.
|
445 |
+
"""
|
446 |
+
|
447 |
+
def __init__(self, config: DiffusionConfig):
|
448 |
+
super().__init__()
|
449 |
+
# Set up optional preprocessing.
|
450 |
+
if config.crop_shape is not None:
|
451 |
+
self.do_crop = True
|
452 |
+
# Always use center crop for eval
|
453 |
+
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
454 |
+
if config.crop_is_random:
|
455 |
+
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
456 |
+
else:
|
457 |
+
self.maybe_random_crop = self.center_crop
|
458 |
+
else:
|
459 |
+
self.do_crop = False
|
460 |
+
|
461 |
+
# Set up backbone.
|
462 |
+
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
463 |
+
weights=config.pretrained_backbone_weights
|
464 |
+
)
|
465 |
+
# Note: This assumes that the layer4 feature map is children()[-3]
|
466 |
+
# TODO(alexander-soare): Use a safer alternative.
|
467 |
+
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
468 |
+
if config.use_group_norm:
|
469 |
+
if config.pretrained_backbone_weights:
|
470 |
+
raise ValueError(
|
471 |
+
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
472 |
+
)
|
473 |
+
self.backbone = _replace_submodules(
|
474 |
+
root_module=self.backbone,
|
475 |
+
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
476 |
+
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
477 |
+
)
|
478 |
+
|
479 |
+
# Set up pooling and final layers.
|
480 |
+
# Use a dry run to get the feature map shape.
|
481 |
+
# The dummy input should take the number of image channels from `config.image_features` and it should
|
482 |
+
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
483 |
+
# height and width from `config.image_features`.
|
484 |
+
|
485 |
+
# Note: we have a check in the config class to make sure all images have the same shape.
|
486 |
+
images_shape = next(iter(config.image_features.values())).shape
|
487 |
+
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
488 |
+
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
489 |
+
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
490 |
+
|
491 |
+
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
492 |
+
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
493 |
+
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
494 |
+
self.relu = nn.ReLU()
|
495 |
+
|
496 |
+
def forward(self, x: Tensor) -> Tensor:
|
497 |
+
"""
|
498 |
+
Args:
|
499 |
+
x: (B, C, H, W) image tensor with pixel values in [0, 1].
|
500 |
+
Returns:
|
501 |
+
(B, D) image feature.
|
502 |
+
"""
|
503 |
+
# Preprocess: maybe crop (if it was set up in the __init__).
|
504 |
+
if self.do_crop:
|
505 |
+
if self.training: # noqa: SIM108
|
506 |
+
x = self.maybe_random_crop(x)
|
507 |
+
else:
|
508 |
+
# Always use center crop for eval.
|
509 |
+
x = self.center_crop(x)
|
510 |
+
# Extract backbone feature.
|
511 |
+
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
512 |
+
# Final linear layer with non-linearity.
|
513 |
+
x = self.relu(self.out(x))
|
514 |
+
return x
|
515 |
+
|
516 |
+
|
517 |
+
def _replace_submodules(
|
518 |
+
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
519 |
+
) -> nn.Module:
|
520 |
+
"""
|
521 |
+
Args:
|
522 |
+
root_module: The module for which the submodules need to be replaced
|
523 |
+
predicate: Takes a module as an argument and must return True if that module is to be replaced.
|
524 |
+
func: Takes a module as an argument and returns a new module to replace it with.
|
525 |
+
Returns:
|
526 |
+
The root module with its submodules replaced.
|
527 |
+
"""
|
528 |
+
if predicate(root_module):
|
529 |
+
return func(root_module)
|
530 |
+
|
531 |
+
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
532 |
+
for *parents, k in replace_list:
|
533 |
+
parent_module = root_module
|
534 |
+
if len(parents) > 0:
|
535 |
+
parent_module = root_module.get_submodule(".".join(parents))
|
536 |
+
if isinstance(parent_module, nn.Sequential):
|
537 |
+
src_module = parent_module[int(k)]
|
538 |
+
else:
|
539 |
+
src_module = getattr(parent_module, k)
|
540 |
+
tgt_module = func(src_module)
|
541 |
+
if isinstance(parent_module, nn.Sequential):
|
542 |
+
parent_module[int(k)] = tgt_module
|
543 |
+
else:
|
544 |
+
setattr(parent_module, k, tgt_module)
|
545 |
+
# verify that all BN are replaced
|
546 |
+
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
547 |
+
return root_module
|
548 |
+
|
549 |
+
|
550 |
+
class DiffusionSinusoidalPosEmb(nn.Module):
|
551 |
+
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
552 |
+
|
553 |
+
def __init__(self, dim: int):
|
554 |
+
super().__init__()
|
555 |
+
self.dim = dim
|
556 |
+
|
557 |
+
def forward(self, x: Tensor) -> Tensor:
|
558 |
+
device = x.device
|
559 |
+
half_dim = self.dim // 2
|
560 |
+
emb = math.log(10000) / (half_dim - 1)
|
561 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
562 |
+
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
|
563 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
564 |
+
return emb
|
565 |
+
|
566 |
+
|
567 |
+
class DiffusionConv1dBlock(nn.Module):
|
568 |
+
"""Conv1d --> GroupNorm --> Mish"""
|
569 |
+
|
570 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
571 |
+
super().__init__()
|
572 |
+
|
573 |
+
self.block = nn.Sequential(
|
574 |
+
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
575 |
+
nn.GroupNorm(n_groups, out_channels),
|
576 |
+
nn.Mish(),
|
577 |
+
)
|
578 |
+
|
579 |
+
def forward(self, x):
|
580 |
+
return self.block(x)
|
581 |
+
|
582 |
+
|
583 |
+
class DiffusionConditionalUnet1d(nn.Module):
|
584 |
+
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
585 |
+
|
586 |
+
Note: this removes local conditioning as compared to the original diffusion policy code.
|
587 |
+
"""
|
588 |
+
|
589 |
+
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
|
590 |
+
super().__init__()
|
591 |
+
|
592 |
+
self.config = config
|
593 |
+
|
594 |
+
# Encoder for the diffusion timestep.
|
595 |
+
self.diffusion_step_encoder = nn.Sequential(
|
596 |
+
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
597 |
+
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
598 |
+
nn.Mish(),
|
599 |
+
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
600 |
+
)
|
601 |
+
|
602 |
+
# The FiLM conditioning dimension.
|
603 |
+
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
|
604 |
+
|
605 |
+
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
606 |
+
# just reverse these.
|
607 |
+
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
608 |
+
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
609 |
+
)
|
610 |
+
|
611 |
+
# Unet encoder.
|
612 |
+
common_res_block_kwargs = {
|
613 |
+
"cond_dim": cond_dim,
|
614 |
+
"kernel_size": config.kernel_size,
|
615 |
+
"n_groups": config.n_groups,
|
616 |
+
"use_film_scale_modulation": config.use_film_scale_modulation,
|
617 |
+
}
|
618 |
+
self.down_modules = nn.ModuleList([])
|
619 |
+
for ind, (dim_in, dim_out) in enumerate(in_out):
|
620 |
+
is_last = ind >= (len(in_out) - 1)
|
621 |
+
self.down_modules.append(
|
622 |
+
nn.ModuleList(
|
623 |
+
[
|
624 |
+
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
625 |
+
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
626 |
+
# Downsample as long as it is not the last block.
|
627 |
+
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
628 |
+
]
|
629 |
+
)
|
630 |
+
)
|
631 |
+
|
632 |
+
# Processing in the middle of the auto-encoder.
|
633 |
+
self.mid_modules = nn.ModuleList(
|
634 |
+
[
|
635 |
+
DiffusionConditionalResidualBlock1d(
|
636 |
+
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
637 |
+
),
|
638 |
+
DiffusionConditionalResidualBlock1d(
|
639 |
+
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
640 |
+
),
|
641 |
+
]
|
642 |
+
)
|
643 |
+
|
644 |
+
# Unet decoder.
|
645 |
+
self.up_modules = nn.ModuleList([])
|
646 |
+
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
|
647 |
+
is_last = ind >= (len(in_out) - 1)
|
648 |
+
self.up_modules.append(
|
649 |
+
nn.ModuleList(
|
650 |
+
[
|
651 |
+
# dim_in * 2, because it takes the encoder's skip connection as well
|
652 |
+
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
653 |
+
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
654 |
+
# Upsample as long as it is not the last block.
|
655 |
+
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
656 |
+
]
|
657 |
+
)
|
658 |
+
)
|
659 |
+
|
660 |
+
self.final_conv = nn.Sequential(
|
661 |
+
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
662 |
+
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
663 |
+
)
|
664 |
+
|
665 |
+
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
666 |
+
"""
|
667 |
+
Args:
|
668 |
+
x: (B, T, input_dim) tensor for input to the Unet.
|
669 |
+
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
670 |
+
global_cond: (B, global_cond_dim)
|
671 |
+
output: (B, T, input_dim)
|
672 |
+
Returns:
|
673 |
+
(B, T, input_dim) diffusion model prediction.
|
674 |
+
"""
|
675 |
+
# For 1D convolutions we'll need feature dimension first.
|
676 |
+
x = einops.rearrange(x, "b t d -> b d t")
|
677 |
+
|
678 |
+
timesteps_embed = self.diffusion_step_encoder(timestep)
|
679 |
+
|
680 |
+
# If there is a global conditioning feature, concatenate it to the timestep embedding.
|
681 |
+
if global_cond is not None:
|
682 |
+
global_feature = torch.cat([timesteps_embed, global_cond], dim=-1)
|
683 |
+
else:
|
684 |
+
global_feature = timesteps_embed
|
685 |
+
|
686 |
+
# Run encoder, keeping track of skip features to pass to the decoder.
|
687 |
+
encoder_skip_features: list[Tensor] = []
|
688 |
+
for resnet, resnet2, downsample in self.down_modules:
|
689 |
+
x = resnet(x, global_feature)
|
690 |
+
x = resnet2(x, global_feature)
|
691 |
+
encoder_skip_features.append(x)
|
692 |
+
x = downsample(x)
|
693 |
+
|
694 |
+
for mid_module in self.mid_modules:
|
695 |
+
x = mid_module(x, global_feature)
|
696 |
+
|
697 |
+
# Run decoder, using the skip features from the encoder.
|
698 |
+
for resnet, resnet2, upsample in self.up_modules:
|
699 |
+
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
700 |
+
x = resnet(x, global_feature)
|
701 |
+
x = resnet2(x, global_feature)
|
702 |
+
x = upsample(x)
|
703 |
+
|
704 |
+
x = self.final_conv(x)
|
705 |
+
|
706 |
+
x = einops.rearrange(x, "b d t -> b t d")
|
707 |
+
return x
|
708 |
+
|
709 |
+
|
710 |
+
class DiffusionConditionalResidualBlock1d(nn.Module):
|
711 |
+
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
712 |
+
|
713 |
+
def __init__(
|
714 |
+
self,
|
715 |
+
in_channels: int,
|
716 |
+
out_channels: int,
|
717 |
+
cond_dim: int,
|
718 |
+
kernel_size: int = 3,
|
719 |
+
n_groups: int = 8,
|
720 |
+
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
721 |
+
# FiLM just modulates bias).
|
722 |
+
use_film_scale_modulation: bool = False,
|
723 |
+
):
|
724 |
+
super().__init__()
|
725 |
+
|
726 |
+
self.use_film_scale_modulation = use_film_scale_modulation
|
727 |
+
self.out_channels = out_channels
|
728 |
+
|
729 |
+
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
730 |
+
|
731 |
+
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
732 |
+
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
733 |
+
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
734 |
+
|
735 |
+
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
736 |
+
|
737 |
+
# A final convolution for dimension matching the residual (if needed).
|
738 |
+
self.residual_conv = (
|
739 |
+
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
740 |
+
)
|
741 |
+
|
742 |
+
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
743 |
+
"""
|
744 |
+
Args:
|
745 |
+
x: (B, in_channels, T)
|
746 |
+
cond: (B, cond_dim)
|
747 |
+
Returns:
|
748 |
+
(B, out_channels, T)
|
749 |
+
"""
|
750 |
+
out = self.conv1(x)
|
751 |
+
|
752 |
+
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
753 |
+
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
754 |
+
if self.use_film_scale_modulation:
|
755 |
+
# Treat the embedding as a list of scales and biases.
|
756 |
+
scale = cond_embed[:, : self.out_channels]
|
757 |
+
bias = cond_embed[:, self.out_channels :]
|
758 |
+
out = scale * out + bias
|
759 |
+
else:
|
760 |
+
# Treat the embedding as biases.
|
761 |
+
out = out + cond_embed
|
762 |
+
|
763 |
+
out = self.conv2(out)
|
764 |
+
out = out + self.residual_conv(x)
|
765 |
+
return out
|
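To make the SpatialSoftmax pooling used by `DiffusionRgbEncoder` concrete, here is a small standalone sketch (illustrative only; the shapes follow the docstring example above):

import torch
from lerobot.common.policies.diffusion.modeling_diffusion import SpatialSoftmax

pool = SpatialSoftmax(input_shape=(512, 10, 12), num_kp=32)  # learnable 512 -> 32 keypoint projection
features = torch.randn(4, 512, 10, 12)                       # batch of dummy feature maps
keypoints = pool(features)                                   # expected (x, y) per keypoint, in [-1, 1]
print(keypoints.shape)  # torch.Size([4, 32, 2])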
lerobot/common/policies/factory.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
22 |
+
from lerobot.common.datasets.utils import dataset_to_policy_features
|
23 |
+
from lerobot.common.envs.configs import EnvConfig
|
24 |
+
from lerobot.common.envs.utils import env_to_policy_features
|
25 |
+
from lerobot.common.policies.act.configuration_act import ACTConfig
|
26 |
+
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
27 |
+
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
28 |
+
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
29 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
30 |
+
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
31 |
+
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
32 |
+
from lerobot.configs.policies import PreTrainedConfig
|
33 |
+
from lerobot.configs.types import FeatureType
|
34 |
+
|
35 |
+
|
36 |
+
def get_policy_class(name: str) -> PreTrainedPolicy:
|
37 |
+
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
38 |
+
if name == "tdmpc":
|
39 |
+
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
40 |
+
|
41 |
+
return TDMPCPolicy
|
42 |
+
elif name == "diffusion":
|
43 |
+
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
44 |
+
|
45 |
+
return DiffusionPolicy
|
46 |
+
elif name == "act":
|
47 |
+
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
48 |
+
|
49 |
+
return ACTPolicy
|
50 |
+
elif name == "vqbet":
|
51 |
+
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
52 |
+
|
53 |
+
return VQBeTPolicy
|
54 |
+
elif name == "pi0":
|
55 |
+
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
56 |
+
|
57 |
+
return PI0Policy
|
58 |
+
elif name == "pi0fast":
|
59 |
+
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
60 |
+
|
61 |
+
return PI0FASTPolicy
|
62 |
+
else:
|
63 |
+
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
64 |
+
|
65 |
+
|
66 |
+
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
67 |
+
if policy_type == "tdmpc":
|
68 |
+
return TDMPCConfig(**kwargs)
|
69 |
+
elif policy_type == "diffusion":
|
70 |
+
return DiffusionConfig(**kwargs)
|
71 |
+
elif policy_type == "act":
|
72 |
+
return ACTConfig(**kwargs)
|
73 |
+
elif policy_type == "vqbet":
|
74 |
+
return VQBeTConfig(**kwargs)
|
75 |
+
elif policy_type == "pi0":
|
76 |
+
return PI0Config(**kwargs)
|
77 |
+
elif policy_type == "pi0fast":
|
78 |
+
return PI0FASTConfig(**kwargs)
|
79 |
+
else:
|
80 |
+
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
81 |
+
|
82 |
+
|
83 |
+
def make_policy(
|
84 |
+
cfg: PreTrainedConfig,
|
85 |
+
ds_meta: LeRobotDatasetMetadata | None = None,
|
86 |
+
env_cfg: EnvConfig | None = None,
|
87 |
+
) -> PreTrainedPolicy:
|
88 |
+
"""Make an instance of a policy class.
|
89 |
+
|
90 |
+
This function exists because (for now) we need to parse features from either a dataset or an environment
|
91 |
+
in order to properly dimension and instantiate a policy for that dataset or environment.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
95 |
+
be loaded with the weights from that path.
|
96 |
+
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
97 |
+
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
98 |
+
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
99 |
+
provided if ds_meta is not. Defaults to None.
|
100 |
+
|
101 |
+
Raises:
|
102 |
+
ValueError: Either ds_meta or env and env_cfg must be provided.
|
103 |
+
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
PreTrainedPolicy: The instantiated policy, loaded from `cfg.pretrained_path` if set and moved to `cfg.device`.
|
107 |
+
"""
|
108 |
+
if bool(ds_meta) == bool(env_cfg):
|
109 |
+
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
110 |
+
|
111 |
+
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
|
112 |
+
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
|
113 |
+
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
|
114 |
+
# you want this op to be added in priority during the prototype phase of this feature, please comment on
|
115 |
+
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
116 |
+
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
117 |
+
# slower than running natively on MPS.
|
118 |
+
if cfg.type == "vqbet" and cfg.device == "mps":
|
119 |
+
raise NotImplementedError(
|
120 |
+
"Current implementation of VQBeT does not support `mps` backend. "
|
121 |
+
"Please use `cpu` or `cuda` backend."
|
122 |
+
)
|
123 |
+
|
124 |
+
policy_cls = get_policy_class(cfg.type)
|
125 |
+
|
126 |
+
kwargs = {}
|
127 |
+
if ds_meta is not None:
|
128 |
+
features = dataset_to_policy_features(ds_meta.features)
|
129 |
+
kwargs["dataset_stats"] = ds_meta.stats
|
130 |
+
else:
|
131 |
+
if not cfg.pretrained_path:
|
132 |
+
logging.warning(
|
133 |
+
"You are instantiating a policy from scratch and its features are parsed from an environment "
|
134 |
+
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
135 |
+
"by default without stats from a dataset."
|
136 |
+
)
|
137 |
+
features = env_to_policy_features(env_cfg)
|
138 |
+
|
139 |
+
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
140 |
+
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
141 |
+
kwargs["config"] = cfg
|
142 |
+
|
143 |
+
if cfg.pretrained_path:
|
144 |
+
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
145 |
+
# hyperparameters that we want to vary).
|
146 |
+
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
147 |
+
policy = policy_cls.from_pretrained(**kwargs)
|
148 |
+
else:
|
149 |
+
# Make a fresh policy.
|
150 |
+
policy = policy_cls(**kwargs)
|
151 |
+
|
152 |
+
policy.to(cfg.device)
|
153 |
+
assert isinstance(policy, nn.Module)
|
154 |
+
|
155 |
+
# policy = torch.compile(policy, mode="reduce-overhead")
|
156 |
+
|
157 |
+
return policy
|
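For orientation, here is a minimal usage sketch of `make_policy`, assuming the dataset and checkpoint ids below (borrowed from the pi0 scripts later in this diff) are available; it is an illustration, not part of the commit:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig

# Features and normalization stats are parsed from the dataset metadata,
# so the policy is dimensioned to match that dataset.
ds_meta = LeRobotDatasetMetadata("lerobot/aloha_sim_transfer_cube_human")

# Load a pretrained pi0 checkpoint and keep its path so weights are restored.
cfg = PreTrainedConfig.from_pretrained("lerobot/pi0")
cfg.pretrained_path = "lerobot/pi0"

policy = make_policy(cfg, ds_meta=ds_meta)  # returns an nn.Module placed on cfg.device
```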
lerobot/common/policies/normalize.py
ADDED
@@ -0,0 +1,254 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, nn
|
19 |
+
|
20 |
+
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
21 |
+
|
22 |
+
|
23 |
+
def create_stats_buffers(
|
24 |
+
features: dict[str, PolicyFeature],
|
25 |
+
norm_map: dict[str, NormalizationMode],
|
26 |
+
stats: dict[str, dict[str, Tensor]] | None = None,
|
27 |
+
) -> dict[str, dict[str, nn.ParameterDict]]:
|
28 |
+
"""
|
29 |
+
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
30 |
+
statistics.
|
31 |
+
|
32 |
+
Args: (see Normalize and Unnormalize)
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
36 |
+
`nn.Parameters` set to `requires_grad=False`, so that they are not updated during backpropagation.
|
37 |
+
"""
|
38 |
+
stats_buffers = {}
|
39 |
+
|
40 |
+
for key, ft in features.items():
|
41 |
+
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
42 |
+
if norm_mode is NormalizationMode.IDENTITY:
|
43 |
+
continue
|
44 |
+
|
45 |
+
assert isinstance(norm_mode, NormalizationMode)
|
46 |
+
|
47 |
+
shape = tuple(ft.shape)
|
48 |
+
|
49 |
+
if ft.type is FeatureType.VISUAL:
|
50 |
+
# sanity checks
|
51 |
+
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
|
52 |
+
c, h, w = shape
|
53 |
+
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
54 |
+
# override image shape to be invariant to height and width
|
55 |
+
shape = (c, 1, 1)
|
56 |
+
|
57 |
+
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
58 |
+
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
59 |
+
# we assert they are not infinity anymore.
|
60 |
+
|
61 |
+
buffer = {}
|
62 |
+
if norm_mode is NormalizationMode.MEAN_STD:
|
63 |
+
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
64 |
+
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
65 |
+
buffer = nn.ParameterDict(
|
66 |
+
{
|
67 |
+
"mean": nn.Parameter(mean, requires_grad=False),
|
68 |
+
"std": nn.Parameter(std, requires_grad=False),
|
69 |
+
}
|
70 |
+
)
|
71 |
+
elif norm_mode is NormalizationMode.MIN_MAX:
|
72 |
+
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
73 |
+
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
74 |
+
buffer = nn.ParameterDict(
|
75 |
+
{
|
76 |
+
"min": nn.Parameter(min, requires_grad=False),
|
77 |
+
"max": nn.Parameter(max, requires_grad=False),
|
78 |
+
}
|
79 |
+
)
|
80 |
+
|
81 |
+
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
82 |
+
if stats:
|
83 |
+
if isinstance(stats[key]["mean"], np.ndarray):
|
84 |
+
if norm_mode is NormalizationMode.MEAN_STD:
|
85 |
+
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
86 |
+
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
87 |
+
elif norm_mode is NormalizationMode.MIN_MAX:
|
88 |
+
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
89 |
+
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
90 |
+
elif isinstance(stats[key]["mean"], torch.Tensor):
|
91 |
+
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
92 |
+
# tensors anywhere (for example, when we use the same stats for normalization and
|
93 |
+
# unnormalization). See the logic here
|
94 |
+
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
95 |
+
if norm_mode is NormalizationMode.MEAN_STD:
|
96 |
+
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
97 |
+
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
98 |
+
elif norm_mode is NormalizationMode.MIN_MAX:
|
99 |
+
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
100 |
+
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
101 |
+
else:
|
102 |
+
type_ = type(stats[key]["mean"])
|
103 |
+
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
104 |
+
|
105 |
+
stats_buffers[key] = buffer
|
106 |
+
return stats_buffers
|
107 |
+
|
108 |
+
|
109 |
+
def _no_stats_error_str(name: str) -> str:
|
110 |
+
return (
|
111 |
+
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
112 |
+
"pretrained model."
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
class Normalize(nn.Module):
|
117 |
+
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
features: dict[str, PolicyFeature],
|
122 |
+
norm_map: dict[str, NormalizationMode],
|
123 |
+
stats: dict[str, dict[str, Tensor]] | None = None,
|
124 |
+
):
|
125 |
+
"""
|
126 |
+
Args:
|
127 |
+
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
128 |
+
are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
|
129 |
+
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
130 |
+
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
131 |
+
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
132 |
+
are their normalization modes among:
|
133 |
+
- "mean_std": subtract the mean and divide by standard deviation.
|
134 |
+
- "min_max": map to [-1, 1] range.
|
135 |
+
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
136 |
+
and values are dictionaries of statistic types and their values (e.g.
|
137 |
+
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
138 |
+
training the model for the first time, these statistics will overwrite the default buffers. If
|
139 |
+
not provided, as expected for finetuning or evaluation, the default buffers should be
|
140 |
+
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
141 |
+
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
142 |
+
"""
|
143 |
+
super().__init__()
|
144 |
+
self.features = features
|
145 |
+
self.norm_map = norm_map
|
146 |
+
self.stats = stats
|
147 |
+
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
148 |
+
for key, buffer in stats_buffers.items():
|
149 |
+
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
150 |
+
|
151 |
+
# TODO(rcadene): should we remove torch.no_grad?
|
152 |
+
@torch.no_grad
|
153 |
+
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
154 |
+
batch = dict(batch) # shallow copy avoids mutating the input batch
|
155 |
+
for key, ft in self.features.items():
|
156 |
+
if key not in batch:
|
157 |
+
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
158 |
+
continue
|
159 |
+
|
160 |
+
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
161 |
+
if norm_mode is NormalizationMode.IDENTITY:
|
162 |
+
continue
|
163 |
+
|
164 |
+
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
165 |
+
|
166 |
+
if norm_mode is NormalizationMode.MEAN_STD:
|
167 |
+
mean = buffer["mean"]
|
168 |
+
std = buffer["std"]
|
169 |
+
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
170 |
+
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
171 |
+
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
172 |
+
elif norm_mode is NormalizationMode.MIN_MAX:
|
173 |
+
min = buffer["min"]
|
174 |
+
max = buffer["max"]
|
175 |
+
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
176 |
+
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
177 |
+
# normalize to [0,1]
|
178 |
+
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
179 |
+
# normalize to [-1, 1]
|
180 |
+
batch[key] = batch[key] * 2 - 1
|
181 |
+
else:
|
182 |
+
raise ValueError(norm_mode)
|
183 |
+
return batch
|
184 |
+
|
185 |
+
|
186 |
+
class Unnormalize(nn.Module):
|
187 |
+
"""
|
188 |
+
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) back to the
|
189 |
+
original range used by the environment.
|
190 |
+
"""
|
191 |
+
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
features: dict[str, PolicyFeature],
|
195 |
+
norm_map: dict[str, NormalizationMode],
|
196 |
+
stats: dict[str, dict[str, Tensor]] | None = None,
|
197 |
+
):
|
198 |
+
"""
|
199 |
+
Args:
|
200 |
+
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
201 |
+
are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
|
202 |
+
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
203 |
+
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
204 |
+
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
205 |
+
are their normalization modes among:
|
206 |
+
- "mean_std": subtract the mean and divide by standard deviation.
|
207 |
+
- "min_max": map to [-1, 1] range.
|
208 |
+
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
209 |
+
and values are dictionaries of statistic types and their values (e.g.
|
210 |
+
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
211 |
+
training the model for the first time, these statistics will overwrite the default buffers. If
|
212 |
+
not provided, as expected for finetuning or evaluation, the default buffers should be
|
213 |
+
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
214 |
+
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
215 |
+
"""
|
216 |
+
super().__init__()
|
217 |
+
self.features = features
|
218 |
+
self.norm_map = norm_map
|
219 |
+
self.stats = stats
|
220 |
+
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
221 |
+
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
222 |
+
for key, buffer in stats_buffers.items():
|
223 |
+
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
224 |
+
|
225 |
+
# TODO(rcadene): should we remove torch.no_grad?
|
226 |
+
@torch.no_grad
|
227 |
+
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
228 |
+
batch = dict(batch) # shallow copy avoids mutating the input batch
|
229 |
+
for key, ft in self.features.items():
|
230 |
+
if key not in batch:
|
231 |
+
continue
|
232 |
+
|
233 |
+
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
234 |
+
if norm_mode is NormalizationMode.IDENTITY:
|
235 |
+
continue
|
236 |
+
|
237 |
+
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
238 |
+
|
239 |
+
if norm_mode is NormalizationMode.MEAN_STD:
|
240 |
+
mean = buffer["mean"]
|
241 |
+
std = buffer["std"]
|
242 |
+
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
243 |
+
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
244 |
+
batch[key] = batch[key] * std + mean
|
245 |
+
elif norm_mode is NormalizationMode.MIN_MAX:
|
246 |
+
min = buffer["min"]
|
247 |
+
max = buffer["max"]
|
248 |
+
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
249 |
+
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
250 |
+
batch[key] = (batch[key] + 1) / 2
|
251 |
+
batch[key] = batch[key] * (max - min) + min
|
252 |
+
else:
|
253 |
+
raise ValueError(norm_mode)
|
254 |
+
return batch
|
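As a minimal sketch of how `Normalize` and `Unnormalize` pair up (the feature key, shapes and stats below are made-up illustrative values, not taken from any dataset):

```python
import torch

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

# Hypothetical 14-dim state feature normalized with mean/std statistics.
features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}}

normalize = Normalize(features, norm_map, stats)
unnormalize = Unnormalize(features, norm_map, stats)

batch = {"observation.state": torch.randn(2, 14)}
normed = normalize(batch)        # (x - mean) / (std + 1e-8)
restored = unnormalize(normed)   # x * std + mean, recovers the input up to the epsilon
```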
lerobot/common/policies/pi0/configuration_pi0.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass, field
|
16 |
+
|
17 |
+
from lerobot.common.optim.optimizers import AdamWConfig
|
18 |
+
from lerobot.common.optim.schedulers import (
|
19 |
+
CosineDecayWithWarmupSchedulerConfig,
|
20 |
+
)
|
21 |
+
from lerobot.configs.policies import PreTrainedConfig
|
22 |
+
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
23 |
+
|
24 |
+
|
25 |
+
@PreTrainedConfig.register_subclass("pi0")
|
26 |
+
@dataclass
|
27 |
+
class PI0Config(PreTrainedConfig):
|
28 |
+
# Input / output structure.
|
29 |
+
n_obs_steps: int = 1
|
30 |
+
chunk_size: int = 50
|
31 |
+
n_action_steps: int = 50
|
32 |
+
|
33 |
+
normalization_mapping: dict[str, NormalizationMode] = field(
|
34 |
+
default_factory=lambda: {
|
35 |
+
"VISUAL": NormalizationMode.IDENTITY,
|
36 |
+
"STATE": NormalizationMode.MEAN_STD,
|
37 |
+
"ACTION": NormalizationMode.MEAN_STD,
|
38 |
+
}
|
39 |
+
)
|
40 |
+
|
41 |
+
# Shorter state and action vectors will be padded
|
42 |
+
max_state_dim: int = 32
|
43 |
+
max_action_dim: int = 32
|
44 |
+
|
45 |
+
# Image preprocessing
|
46 |
+
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
47 |
+
|
48 |
+
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
49 |
+
# left and right wrist cameras in addition to the top camera.
|
50 |
+
empty_cameras: int = 0
|
51 |
+
|
52 |
+
# Converts the joint and gripper values from the standard Aloha space to
|
53 |
+
# the space used by the pi internal runtime which was used to train the base model.
|
54 |
+
adapt_to_pi_aloha: bool = False
|
55 |
+
|
56 |
+
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
57 |
+
# Gripper dimensions will remain in absolute values.
|
58 |
+
use_delta_joint_actions_aloha: bool = False
|
59 |
+
|
60 |
+
# Tokenizer
|
61 |
+
tokenizer_max_length: int = 48
|
62 |
+
|
63 |
+
# Projector
|
64 |
+
proj_width: int = 1024
|
65 |
+
|
66 |
+
# Decoding
|
67 |
+
num_steps: int = 10
|
68 |
+
|
69 |
+
# Attention utils
|
70 |
+
use_cache: bool = True
|
71 |
+
attention_implementation: str = "eager" # or fa2, flex
|
72 |
+
|
73 |
+
# Finetuning settings
|
74 |
+
freeze_vision_encoder: bool = True
|
75 |
+
train_expert_only: bool = False
|
76 |
+
train_state_proj: bool = True
|
77 |
+
|
78 |
+
# Training presets
|
79 |
+
optimizer_lr: float = 2.5e-5
|
80 |
+
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
81 |
+
optimizer_eps: float = 1e-8
|
82 |
+
optimizer_weight_decay: float = 1e-10
|
83 |
+
|
84 |
+
scheduler_warmup_steps: int = 1_000
|
85 |
+
scheduler_decay_steps: int = 30_000
|
86 |
+
scheduler_decay_lr: float = 2.5e-6
|
87 |
+
|
88 |
+
# TODO: Add EMA
|
89 |
+
|
90 |
+
def __post_init__(self):
|
91 |
+
super().__post_init__()
|
92 |
+
|
93 |
+
# TODO(Steven): Validate device and amp? in all policy configs?
|
94 |
+
"""Input validation (not exhaustive)."""
|
95 |
+
if self.n_action_steps > self.chunk_size:
|
96 |
+
raise ValueError(
|
97 |
+
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
98 |
+
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
99 |
+
)
|
100 |
+
if self.n_obs_steps != 1:
|
101 |
+
raise ValueError(
|
102 |
+
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
103 |
+
)
|
104 |
+
|
105 |
+
if self.use_delta_joint_actions_aloha:
|
106 |
+
raise NotImplementedError(
|
107 |
+
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
108 |
+
)
|
109 |
+
|
110 |
+
def validate_features(self) -> None:
|
111 |
+
# TODO: implement value error
|
112 |
+
# if not self.image_features and not self.env_state_feature:
|
113 |
+
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
114 |
+
|
115 |
+
for i in range(self.empty_cameras):
|
116 |
+
key = f"observation.images.empty_camera_{i}"
|
117 |
+
empty_camera = PolicyFeature(
|
118 |
+
type=FeatureType.VISUAL,
|
119 |
+
shape=(3, 480, 640),
|
120 |
+
)
|
121 |
+
self.input_features[key] = empty_camera
|
122 |
+
|
123 |
+
def get_optimizer_preset(self) -> AdamWConfig:
|
124 |
+
return AdamWConfig(
|
125 |
+
lr=self.optimizer_lr,
|
126 |
+
betas=self.optimizer_betas,
|
127 |
+
eps=self.optimizer_eps,
|
128 |
+
weight_decay=self.optimizer_weight_decay,
|
129 |
+
)
|
130 |
+
|
131 |
+
def get_scheduler_preset(self):
|
132 |
+
return CosineDecayWithWarmupSchedulerConfig(
|
133 |
+
peak_lr=self.optimizer_lr,
|
134 |
+
decay_lr=self.scheduler_decay_lr,
|
135 |
+
num_warmup_steps=self.scheduler_warmup_steps,
|
136 |
+
num_decay_steps=self.scheduler_decay_steps,
|
137 |
+
)
|
138 |
+
|
139 |
+
@property
|
140 |
+
def observation_delta_indices(self) -> None:
|
141 |
+
return None
|
142 |
+
|
143 |
+
@property
|
144 |
+
def action_delta_indices(self) -> list:
|
145 |
+
return list(range(self.chunk_size))
|
146 |
+
|
147 |
+
@property
|
148 |
+
def reward_delta_indices(self) -> None:
|
149 |
+
return None
|
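A short sketch of instantiating this config; the overrides mirror the `pi0_aloha_sim` preset used by the conversion script further down in this diff:

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

# Two empty wrist cameras are added on top of the top camera, as in pi0_aloha_sim.
cfg = PI0Config(empty_cameras=2, adapt_to_pi_aloha=True, use_delta_joint_actions_aloha=False)

cfg.chunk_size            # 50 actions predicted per model invocation
cfg.action_delta_indices  # [0, 1, ..., 49]
```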
lerobot/common/policies/pi0/conversion_scripts/benchmark.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
18 |
+
from lerobot.common.policies.factory import make_policy
|
19 |
+
from lerobot.configs.policies import PreTrainedConfig
|
20 |
+
|
21 |
+
torch.backends.cudnn.benchmark = True
|
22 |
+
|
23 |
+
|
24 |
+
def main():
|
25 |
+
device = "cuda"
|
26 |
+
dataset_repo_id = "danaaubakirova/koch_test"
|
27 |
+
# model_name = "pi0_base"
|
28 |
+
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
29 |
+
ckpt_torch_dir = "lerobot/pi0"
|
30 |
+
|
31 |
+
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
32 |
+
|
33 |
+
dataloader = torch.utils.data.DataLoader(
|
34 |
+
dataset,
|
35 |
+
num_workers=0,
|
36 |
+
batch_size=1,
|
37 |
+
)
|
38 |
+
|
39 |
+
batch = next(iter(dataloader))
|
40 |
+
|
41 |
+
# To device
|
42 |
+
for k in batch:
|
43 |
+
if isinstance(batch[k], torch.Tensor):
|
44 |
+
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
45 |
+
|
46 |
+
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
47 |
+
cfg.pretrained_path = ckpt_torch_dir
|
48 |
+
policy = make_policy(cfg, ds_meta=dataset.meta)
|
49 |
+
|
50 |
+
# policy = torch.compile(policy, mode="reduce-overhead")
|
51 |
+
|
52 |
+
warmup_iters = 10
|
53 |
+
benchmark_iters = 30
|
54 |
+
|
55 |
+
# Warmup
|
56 |
+
for _ in range(warmup_iters):
|
57 |
+
torch.cuda.synchronize()
|
58 |
+
policy.select_action(batch)
|
59 |
+
policy.reset()
|
60 |
+
torch.cuda.synchronize()
|
61 |
+
|
62 |
+
# Benchmark
|
63 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
64 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
65 |
+
|
66 |
+
start_event.record()
|
67 |
+
for _ in range(benchmark_iters):
|
68 |
+
policy.select_action(batch)
|
69 |
+
policy.reset()
|
70 |
+
end_event.record()
|
71 |
+
|
72 |
+
# Synchronize and measure time
|
73 |
+
torch.cuda.synchronize()
|
74 |
+
elapsed_time_ms = start_event.elapsed_time(end_event)
|
75 |
+
|
76 |
+
avg_time_per_iter = elapsed_time_ms / benchmark_iters
|
77 |
+
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
with torch.inference_mode():
|
82 |
+
main()
|
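Presumably the benchmark is launched directly as a script from the repository root, e.g.:

```bash
python lerobot/common/policies/pi0/conversion_scripts/benchmark.py
```

It reports the average `select_action` latency in milliseconds over 30 timed iterations after 10 warmup iterations.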
lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import pickle
|
17 |
+
from pathlib import Path
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
22 |
+
from lerobot.common.policies.factory import make_policy
|
23 |
+
from lerobot.configs.policies import PreTrainedConfig
|
24 |
+
|
25 |
+
|
26 |
+
def display(tensor: torch.Tensor):
|
27 |
+
if tensor.dtype == torch.bool:
|
28 |
+
tensor = tensor.float()
|
29 |
+
print(f"Shape: {tensor.shape}")
|
30 |
+
print(f"Mean: {tensor.mean().item()}")
|
31 |
+
print(f"Std: {tensor.std().item()}")
|
32 |
+
print(f"Min: {tensor.min().item()}")
|
33 |
+
print(f"Max: {tensor.max().item()}")
|
34 |
+
|
35 |
+
|
36 |
+
def main():
|
37 |
+
num_motors = 14
|
38 |
+
device = "cuda"
|
39 |
+
# model_name = "pi0_aloha_towel"
|
40 |
+
model_name = "pi0_aloha_sim"
|
41 |
+
|
42 |
+
if model_name == "pi0_aloha_towel":
|
43 |
+
dataset_repo_id = "lerobot/aloha_static_towel"
|
44 |
+
else:
|
45 |
+
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
46 |
+
|
47 |
+
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
48 |
+
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
49 |
+
save_dir = Path(f"../openpi/data/{model_name}/save")
|
50 |
+
|
51 |
+
with open(save_dir / "example.pkl", "rb") as f:
|
52 |
+
example = pickle.load(f)
|
53 |
+
with open(save_dir / "outputs.pkl", "rb") as f:
|
54 |
+
outputs = pickle.load(f)
|
55 |
+
with open(save_dir / "noise.pkl", "rb") as f:
|
56 |
+
noise = pickle.load(f)
|
57 |
+
|
58 |
+
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
59 |
+
norm_stats = json.load(f)
|
60 |
+
|
61 |
+
# Override stats
|
62 |
+
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
63 |
+
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
64 |
+
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
65 |
+
)
|
66 |
+
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
67 |
+
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
68 |
+
)
|
69 |
+
|
70 |
+
# Create LeRobot batch from Jax
|
71 |
+
batch = {}
|
72 |
+
for cam_key, uint_chw_array in example["images"].items():
|
73 |
+
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
74 |
+
batch["observation.state"] = torch.from_numpy(example["state"])
|
75 |
+
batch["action"] = torch.from_numpy(outputs["actions"])
|
76 |
+
batch["task"] = example["prompt"]
|
77 |
+
|
78 |
+
if model_name == "pi0_aloha_towel":
|
79 |
+
del batch["observation.images.cam_low"]
|
80 |
+
elif model_name == "pi0_aloha_sim":
|
81 |
+
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
82 |
+
del batch["observation.images.cam_high"]
|
83 |
+
|
84 |
+
# Batchify
|
85 |
+
for key in batch:
|
86 |
+
if isinstance(batch[key], torch.Tensor):
|
87 |
+
batch[key] = batch[key].unsqueeze(0)
|
88 |
+
elif isinstance(batch[key], str):
|
89 |
+
batch[key] = [batch[key]]
|
90 |
+
else:
|
91 |
+
raise ValueError(f"{key}, {batch[key]}")
|
92 |
+
|
93 |
+
# To device
|
94 |
+
for k in batch:
|
95 |
+
if isinstance(batch[k], torch.Tensor):
|
96 |
+
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
97 |
+
|
98 |
+
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
99 |
+
|
100 |
+
from lerobot.common import policies # noqa
|
101 |
+
|
102 |
+
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
103 |
+
cfg.pretrained_path = ckpt_torch_dir
|
104 |
+
policy = make_policy(cfg, dataset_meta)
|
105 |
+
|
106 |
+
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
107 |
+
# loss_dict["loss"].backward()
|
108 |
+
# print("losses")
|
109 |
+
# display(loss_dict["losses_after_forward"])
|
110 |
+
# print("pi_losses")
|
111 |
+
# display(pi_losses)
|
112 |
+
|
113 |
+
actions = []
|
114 |
+
for _ in range(50):
|
115 |
+
action = policy.select_action(batch, noise=noise)
|
116 |
+
actions.append(action)
|
117 |
+
|
118 |
+
actions = torch.stack(actions, dim=1)
|
119 |
+
pi_actions = batch["action"]
|
120 |
+
print("actions")
|
121 |
+
display(actions)
|
122 |
+
print()
|
123 |
+
print("pi_actions")
|
124 |
+
display(pi_actions)
|
125 |
+
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
|
126 |
+
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
|
127 |
+
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
main()
|
lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
ADDED
@@ -0,0 +1,84 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from transformers import GemmaConfig, PaliGemmaConfig
|
16 |
+
|
17 |
+
|
18 |
+
def get_paligemma_config(precision: str):
|
19 |
+
config = {
|
20 |
+
"image_token_index": None,
|
21 |
+
"pad_token_id": 0,
|
22 |
+
"bos_token_id": 2,
|
23 |
+
"eos_token_id": 1,
|
24 |
+
}
|
25 |
+
|
26 |
+
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
|
27 |
+
|
28 |
+
image_size = 224 # image_sizes[variant]
|
29 |
+
patch_size = 14
|
30 |
+
num_image_tokens = (image_size**2) // (patch_size**2)
|
31 |
+
|
32 |
+
config["image_token_index"] = 257152
|
33 |
+
text_config = {
|
34 |
+
"vocab_size": 257152,
|
35 |
+
"num_hidden_layers": 18,
|
36 |
+
"num_key_value_heads": 1,
|
37 |
+
"head_dim": 256,
|
38 |
+
"torch_dtype": precision,
|
39 |
+
"hidden_size": 2048,
|
40 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
41 |
+
"num_attention_heads": 8,
|
42 |
+
"intermediate_size": 16384,
|
43 |
+
"is_encoder_decoder": False,
|
44 |
+
}
|
45 |
+
vision_config = {
|
46 |
+
"torch_dtype": precision,
|
47 |
+
"image_size": image_size,
|
48 |
+
"patch_size": patch_size,
|
49 |
+
"num_image_tokens": num_image_tokens,
|
50 |
+
"hidden_size": 1152,
|
51 |
+
"intermediate_size": 4304,
|
52 |
+
"num_hidden_layers": 27,
|
53 |
+
"num_attention_heads": 16,
|
54 |
+
"projector_hidden_act": "gelu_fast",
|
55 |
+
"vision_use_head": False,
|
56 |
+
}
|
57 |
+
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
58 |
+
return final_config
|
59 |
+
|
60 |
+
|
61 |
+
def get_gemma_config(precision: str):
|
62 |
+
config = {
|
63 |
+
"image_token_index": None,
|
64 |
+
"pad_token_id": 0,
|
65 |
+
"bos_token_id": 2,
|
66 |
+
"eos_token_id": 1,
|
67 |
+
}
|
68 |
+
|
69 |
+
config["image_token_index"] = 257152
|
70 |
+
text_config = {
|
71 |
+
"vocab_size": 257152,
|
72 |
+
"num_hidden_layers": 18,
|
73 |
+
"num_key_value_heads": 1,
|
74 |
+
"head_dim": 256,
|
75 |
+
"torch_dtype": precision,
|
76 |
+
"hidden_size": 1024,
|
77 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
78 |
+
"num_attention_heads": 8,
|
79 |
+
"intermediate_size": 4096,
|
80 |
+
"is_encoder_decoder": False,
|
81 |
+
}
|
82 |
+
final_config = GemmaConfig()
|
83 |
+
final_config.update(text_config)
|
84 |
+
return final_config
|
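A quick sketch of how these helpers are consumed (the same call pattern as the conversion script that follows):

```python
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
    get_gemma_config,
    get_paligemma_config,
)

# HF config for the PaliGemma backbone (SigLIP vision tower + 2B text decoder).
paligemma_config = get_paligemma_config(precision="float32")
paligemma_config.vision_config.num_hidden_layers  # 27

# Smaller Gemma config used for the action expert (hidden_size 1024, 18 layers).
gemma_config = get_gemma_config(precision="float32")
gemma_config.num_hidden_layers  # 18
```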
lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
ADDED
@@ -0,0 +1,437 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
Convert pi0 parameters from JAX to PyTorch
|
17 |
+
|
18 |
+
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
19 |
+
and install the required libraries.
|
20 |
+
|
21 |
+
```bash
|
22 |
+
cd ~/code/openpi
|
23 |
+
source .venv/bin/activate
|
24 |
+
```
|
25 |
+
|
26 |
+
Example downloading parameters:
|
27 |
+
```bash
|
28 |
+
python
|
29 |
+
>>> import openpi.shared.download as download
|
30 |
+
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
|
31 |
+
>>> download.maybe_download(path)
|
32 |
+
```
|
33 |
+
|
34 |
+
Converting pi0_base:
|
35 |
+
```python
|
36 |
+
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
37 |
+
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
|
38 |
+
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
|
39 |
+
```
|
40 |
+
|
41 |
+
```python
|
42 |
+
python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
|
43 |
+
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
|
44 |
+
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
45 |
+
```
|
46 |
+
"""
|
47 |
+
|
48 |
+
import argparse
|
49 |
+
import pathlib
|
50 |
+
|
51 |
+
import jax
|
52 |
+
import numpy as np
|
53 |
+
import orbax.checkpoint as ocp
|
54 |
+
import torch
|
55 |
+
from jax.sharding import SingleDeviceSharding
|
56 |
+
|
57 |
+
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
58 |
+
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
59 |
+
get_gemma_config,
|
60 |
+
get_paligemma_config,
|
61 |
+
)
|
62 |
+
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
63 |
+
|
64 |
+
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
65 |
+
|
66 |
+
|
67 |
+
def slice_paligemma_state_dict(state_dict, config):
|
68 |
+
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
69 |
+
|
70 |
+
# fmt: off
|
71 |
+
# patch embeddings
|
72 |
+
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
|
73 |
+
3, 2, 0, 1
|
74 |
+
)
|
75 |
+
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
|
76 |
+
# positional embeddings
|
77 |
+
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
|
78 |
+
-1, config.vision_config.hidden_size
|
79 |
+
)
|
80 |
+
|
81 |
+
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
82 |
+
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
83 |
+
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
84 |
+
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
85 |
+
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
86 |
+
|
87 |
+
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
88 |
+
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
89 |
+
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
90 |
+
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
91 |
+
|
92 |
+
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
|
93 |
+
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
|
94 |
+
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
|
95 |
+
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
|
96 |
+
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
|
97 |
+
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
|
98 |
+
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
|
99 |
+
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
|
100 |
+
|
101 |
+
for i in range(config.vision_config.num_hidden_layers):
|
102 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
|
103 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
|
104 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
|
105 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
|
106 |
+
|
107 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
|
108 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
|
109 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
|
110 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
|
111 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
112 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
113 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
114 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
115 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
116 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
117 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
118 |
+
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
119 |
+
|
120 |
+
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
|
121 |
+
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
|
122 |
+
|
123 |
+
# multimodal projector
|
124 |
+
|
125 |
+
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
|
126 |
+
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
|
127 |
+
|
128 |
+
# text decoder (gemma)
|
129 |
+
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
|
130 |
+
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
|
131 |
+
|
132 |
+
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
133 |
+
|
134 |
+
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
135 |
+
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
136 |
+
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
137 |
+
|
138 |
+
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
139 |
+
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
140 |
+
# TODO verify correctness of layer norm loading
|
141 |
+
|
142 |
+
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
143 |
+
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
144 |
+
|
145 |
+
for i in range(config.text_config.num_hidden_layers):
|
146 |
+
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
|
147 |
+
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
148 |
+
|
149 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
150 |
+
|
151 |
+
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
|
152 |
+
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
153 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
154 |
+
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
|
155 |
+
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
156 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
157 |
+
|
158 |
+
# output projection.
|
159 |
+
|
160 |
+
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
|
161 |
+
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
162 |
+
|
163 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
164 |
+
# mlp layers
|
165 |
+
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
166 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
167 |
+
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
168 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
169 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
170 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
171 |
+
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
172 |
+
|
173 |
+
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
|
174 |
+
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
|
175 |
+
|
176 |
+
# fmt: on
|
177 |
+
expert_dict = {}
|
178 |
+
final_state_dict = {}
|
179 |
+
for key, value in state_dict.items():
|
180 |
+
if key not in [
|
181 |
+
f"llm/final_norm_1/scale{suffix}",
|
182 |
+
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
183 |
+
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
184 |
+
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
185 |
+
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
186 |
+
f"llm/layers/mlp_1/linear{suffix}",
|
187 |
+
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
188 |
+
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
189 |
+
]:
|
190 |
+
final_state_dict[key] = torch.from_numpy(value)
|
191 |
+
else:
|
192 |
+
expert_dict[key] = value
|
193 |
+
|
194 |
+
return final_state_dict, expert_dict
|
195 |
+
|
196 |
+
|
197 |
+
def slice_gemma_state_dict(state_dict, config, num_expert=1):
|
198 |
+
# fmt: off
|
199 |
+
# text decoder (gemma)
|
200 |
+
# no embedding vector, the expert just has the decoder layers
|
201 |
+
|
202 |
+
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
|
203 |
+
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
|
204 |
+
|
205 |
+
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
206 |
+
|
207 |
+
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
208 |
+
|
209 |
+
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
210 |
+
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
211 |
+
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
212 |
+
|
213 |
+
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
214 |
+
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
215 |
+
# TODO verify correctness of layer norm loading
|
216 |
+
|
217 |
+
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
218 |
+
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
219 |
+
|
220 |
+
for i in range(config.num_hidden_layers):
|
221 |
+
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
222 |
+
|
223 |
+
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
224 |
+
|
225 |
+
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
226 |
+
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
227 |
+
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
228 |
+
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
229 |
+
|
230 |
+
# output projection.
|
231 |
+
|
232 |
+
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
|
233 |
+
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
|
234 |
+
|
235 |
+
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
236 |
+
# mlp layers
|
237 |
+
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
238 |
+
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
239 |
+
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
240 |
+
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
241 |
+
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
242 |
+
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
243 |
+
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
244 |
+
|
245 |
+
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
|
246 |
+
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
|
247 |
+
|
248 |
+
# fmt: on
|
249 |
+
final_state_dict = {}
|
250 |
+
for key, value in state_dict.items():
|
251 |
+
if not isinstance(value, torch.Tensor):
|
252 |
+
final_state_dict[key] = torch.from_numpy(value)
|
253 |
+
else:
|
254 |
+
final_state_dict[key] = value
|
255 |
+
return final_state_dict
|
256 |
+
|
257 |
+
|
258 |
+
def flatten_for_memory(tree, parent_key=""):
|
259 |
+
out = {}
|
260 |
+
for k, v in tree.items():
|
261 |
+
new_key = f"{parent_key}/{k}" if parent_key else k
|
262 |
+
if isinstance(v, dict):
|
263 |
+
out.update(flatten_for_memory(v, new_key))
|
264 |
+
else:
|
265 |
+
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
|
266 |
+
return out
|
267 |
+
|
268 |
+
|
269 |
+
def flatten_for_npz(tree, parent_key=""):
|
270 |
+
out = {}
|
271 |
+
for k, v in tree.items():
|
272 |
+
new_key = f"{parent_key}/{k}" if parent_key else k
|
273 |
+
if isinstance(v, dict):
|
274 |
+
out.update(flatten_for_npz(v, new_key))
|
275 |
+
else:
|
276 |
+
# bf16/f32 here?
|
277 |
+
out[new_key] = np.array(v)
|
278 |
+
return out
|
279 |
+
|
280 |
+
|
281 |
+
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
|
282 |
+
params_path = pathlib.Path(checkpoint_dir).resolve()
|
283 |
+
checkpointer = ocp.PyTreeCheckpointer()
|
284 |
+
|
285 |
+
metadata = checkpointer.metadata(params_path)
|
286 |
+
print("Metadata keys:", list(metadata.keys()))
|
287 |
+
|
288 |
+
params_name = "params"
|
289 |
+
|
290 |
+
item = {params_name: metadata[params_name]}
|
291 |
+
device = jax.local_devices()[0] # Use the first local device
|
292 |
+
sharding = SingleDeviceSharding(device)
|
293 |
+
restored = checkpointer.restore(
|
294 |
+
params_path,
|
295 |
+
ocp.args.PyTreeRestore(
|
296 |
+
item=item,
|
297 |
+
restore_args=jax.tree_util.tree_map(
|
298 |
+
lambda _: ocp.ArrayRestoreArgs(
|
299 |
+
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
|
300 |
+
sharding=sharding,
|
301 |
+
),
|
302 |
+
item,
|
303 |
+
),
|
304 |
+
transforms={},
|
305 |
+
),
|
306 |
+
)
|
307 |
+
params = restored[params_name]
|
308 |
+
|
309 |
+
# get params for PaliGemma
|
310 |
+
pali_params = params["PaliGemma"]
|
311 |
+
del params["PaliGemma"]
|
312 |
+
pali_params_flat = flatten_for_npz(pali_params)
|
313 |
+
return {"paligemma_params": pali_params_flat, "projection_params": params}
|
314 |
+
|
315 |
+
|
316 |
+
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
317 |
+
"""Update dictionary keys by adding a prefix."""
|
318 |
+
return {f"{prefix}{key}": value for key, value in d.items()}
|
319 |
+
|
320 |
+
|
321 |
+
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
322 |
+
# Break down orbax ckpts - they are in OCDBT
|
323 |
+
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
324 |
+
# process projection params
|
325 |
+
keys = [
|
326 |
+
"state_proj",
|
327 |
+
"action_in_proj",
|
328 |
+
"action_out_proj",
|
329 |
+
"action_time_mlp_in",
|
330 |
+
"action_time_mlp_out",
|
331 |
+
]
|
332 |
+
|
333 |
+
projection_params = {}
|
334 |
+
for key in keys:
|
335 |
+
kernel_params = initial_params["projection_params"][key]["kernel"]
|
336 |
+
bias_params = initial_params["projection_params"][key]["bias"]
|
337 |
+
if isinstance(kernel_params, dict):
|
338 |
+
weight = kernel_params["value"]
|
339 |
+
bias = bias_params["value"]
|
340 |
+
else:
|
341 |
+
weight = kernel_params
|
342 |
+
bias = bias_params
|
343 |
+
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
|
344 |
+
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
|
345 |
+
|
346 |
+
# Process PaliGemma weights
|
347 |
+
paligemma_config = get_paligemma_config(precision)
|
348 |
+
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
|
349 |
+
initial_params["paligemma_params"], paligemma_config
|
350 |
+
)
|
351 |
+
|
352 |
+
# Process Gemma weights (at this stage they are unused)
|
353 |
+
gemma_config = get_gemma_config(precision)
|
354 |
+
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
|
355 |
+
|
356 |
+
# Instantiate model from configs
|
357 |
+
|
358 |
+
if "pi0_aloha_sim" in checkpoint_dir:
|
359 |
+
pi0_config = PI0Config(
|
360 |
+
empty_cameras=2,
|
361 |
+
adapt_to_pi_aloha=True,
|
362 |
+
use_delta_joint_actions_aloha=False,
|
363 |
+
)
|
364 |
+
elif "pi0_aloha_towel" in checkpoint_dir:
|
365 |
+
pi0_config = PI0Config(
|
366 |
+
adapt_to_pi_aloha=True,
|
367 |
+
use_delta_joint_actions_aloha=True,
|
368 |
+
)
|
369 |
+
elif "pi0_base" in checkpoint_dir:
|
370 |
+
pi0_config = PI0Config(
|
371 |
+
empty_cameras=0,
|
372 |
+
adapt_to_pi_aloha=False,
|
373 |
+
use_delta_joint_actions_aloha=False,
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
raise ValueError(f"Unexpected checkpoint_dir: {checkpoint_dir}")
|
377 |
+
|
378 |
+
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
379 |
+
pi0_model = PI0Policy(pi0_config)
|
380 |
+
|
381 |
+
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
382 |
+
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
383 |
+
projection_params = update_keys_with_prefix(projection_params, "model.")
|
384 |
+
|
385 |
+
# load state dict
|
386 |
+
torch_dtype = PRECISIONS[precision]
|
387 |
+
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
|
388 |
+
pi0_model = pi0_model.to(torch_dtype)
|
389 |
+
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
390 |
+
|
391 |
+
pi0_model.save_pretrained(output_path, safe_serialization=True)
|
392 |
+
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
|
393 |
+
|
394 |
+
# Sanity check that the saved model loads back properly
|
395 |
+
del pi0_model
|
396 |
+
PI0Policy.from_pretrained(output_path)
|
397 |
+
|
398 |
+
|
399 |
+
if __name__ == "__main__":
|
400 |
+
parser = argparse.ArgumentParser()
|
401 |
+
parser.add_argument(
|
402 |
+
"--checkpoint_dir",
|
403 |
+
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
|
404 |
+
type=str,
|
405 |
+
help="Path to the ocdbt checkpoint",
|
406 |
+
)
|
407 |
+
|
408 |
+
parser.add_argument(
|
409 |
+
"--precision",
|
410 |
+
choices=["float32", "bfloat16", "float16"],
|
411 |
+
default="float32",
|
412 |
+
type=str,
|
413 |
+
help="Precision identifier for model conversion - should match the base checkpoint precision.",
|
414 |
+
)
|
415 |
+
# The tokenizer appears to be identical to PaliGemma's
|
416 |
+
|
417 |
+
parser.add_argument(
|
418 |
+
"--tokenizer_hub_id",
|
419 |
+
default="google/paligemma-3b-pt-224",
|
420 |
+
type=str,
|
421 |
+
help="Hub path to the tokenizer to save",
|
422 |
+
)
|
423 |
+
|
424 |
+
parser.add_argument(
|
425 |
+
"--output_path",
|
426 |
+
required=True,
|
427 |
+
type=str,
|
428 |
+
help="Path to save converted weights to",
|
429 |
+
)
|
430 |
+
|
431 |
+
args = parser.parse_args()
|
432 |
+
convert_pi0_checkpoint(
|
433 |
+
checkpoint_dir=args.checkpoint_dir,
|
434 |
+
precision=args.precision,
|
435 |
+
tokenizer_id=args.tokenizer_hub_id,
|
436 |
+
output_path=args.output_path,
|
437 |
+
)
|
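For reference, the conversion can also be driven from Python rather than the CLI. A minimal sketch (the paths below are placeholders, not real checkpoints):

```python
# Hypothetical paths; point them at a downloaded openpi checkpoint and a writable output dir.
convert_pi0_checkpoint(
    checkpoint_dir="/path/to/openpi/checkpoints/pi0_aloha_sim/params",
    precision="float32",
    tokenizer_id="google/paligemma-3b-pt-224",
    output_path="/path/to/output/pi0_aloha_sim",
)
```

Note that `checkpoint_dir` is also used to pick the `PI0Config` variant, so it must contain one of `pi0_aloha_sim`, `pi0_aloha_towel`, or `pi0_base`.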
lerobot/common/policies/pi0/flex_attention.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F # noqa: N812
|
17 |
+
from packaging.version import Version
|
18 |
+
|
19 |
+
if Version(torch.__version__) > Version("2.5.0"):
|
20 |
+
# Flex attention is only available from torch 2.5 onwards
|
21 |
+
from torch.nn.attention.flex_attention import (
|
22 |
+
_mask_mod_signature,
|
23 |
+
_round_up_to_multiple,
|
24 |
+
create_block_mask,
|
25 |
+
create_mask,
|
26 |
+
flex_attention,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
# @torch.compile(dynamic=False)
|
31 |
+
def flex_attention_forward(
|
32 |
+
attention_mask: torch.Tensor,
|
33 |
+
batch_size: int,
|
34 |
+
head_dim: int,
|
35 |
+
query_states: torch.Tensor,
|
36 |
+
key_states: torch.Tensor,
|
37 |
+
value_states: torch.Tensor,
|
38 |
+
scaling=None,
|
39 |
+
):
|
40 |
+
"""
|
41 |
+
This is defined out of classes to make compile happy.
|
42 |
+
"""
|
43 |
+
|
44 |
+
original_dtype = query_states.dtype
|
45 |
+
num_att_heads = 8
|
46 |
+
num_key_value_heads = 1
|
47 |
+
num_key_value_groups = num_att_heads // num_key_value_heads
|
48 |
+
|
49 |
+
key_states = key_states[:, :, :, None, :]
|
50 |
+
key_states = key_states.expand(
|
51 |
+
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
52 |
+
)
|
53 |
+
key_states = key_states.reshape(
|
54 |
+
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
55 |
+
)
|
56 |
+
|
57 |
+
value_states = value_states[:, :, :, None, :]
|
58 |
+
value_states = value_states.expand(
|
59 |
+
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
60 |
+
)
|
61 |
+
value_states = value_states.reshape(
|
62 |
+
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
63 |
+
)
|
64 |
+
|
65 |
+
query_states = query_states.transpose(1, 2)
|
66 |
+
key_states = key_states.transpose(1, 2)
|
67 |
+
value_states = value_states.transpose(1, 2)
|
68 |
+
|
69 |
+
query_states = query_states.to(torch.float32)
|
70 |
+
key_states = key_states.to(torch.float32)
|
71 |
+
value_states = value_states.to(torch.float32)
|
72 |
+
|
73 |
+
causal_mask = attention_mask
|
74 |
+
if causal_mask is not None:
|
75 |
+
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
76 |
+
|
77 |
+
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
78 |
+
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
79 |
+
|
80 |
+
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
81 |
+
def mask_mod(b, h, q_idx, kv_idx):
|
82 |
+
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
83 |
+
return precomputed_mask[b][h][q_idx][kv_idx]
|
84 |
+
|
85 |
+
return mask_mod
|
86 |
+
|
87 |
+
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
88 |
+
|
89 |
+
block_size = 128
|
90 |
+
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
91 |
+
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
92 |
+
|
93 |
+
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
94 |
+
|
95 |
+
pad_q = q_len_rounded - q_len
|
96 |
+
pad_k = kv_len_rounded - kv_len
|
97 |
+
|
98 |
+
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
99 |
+
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
100 |
+
|
101 |
+
mask_4d = create_mask(
|
102 |
+
mod_fn=mask_mod_fn_orig,
|
103 |
+
B=b_mask,
|
104 |
+
H=h_mask,
|
105 |
+
Q_LEN=q_len_rounded,
|
106 |
+
KV_LEN=kv_len_rounded,
|
107 |
+
device=causal_mask.device,
|
108 |
+
_compile=False,
|
109 |
+
)
|
110 |
+
|
111 |
+
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
112 |
+
block_mask = create_block_mask(
|
113 |
+
mask_mod=mask_mod_fn_padded,
|
114 |
+
B=b_mask,
|
115 |
+
H=h_mask,
|
116 |
+
Q_LEN=q_len_rounded,
|
117 |
+
KV_LEN=kv_len_rounded,
|
118 |
+
BLOCK_SIZE=block_size,
|
119 |
+
device=causal_mask.device,
|
120 |
+
_compile=False,
|
121 |
+
)
|
122 |
+
|
123 |
+
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
124 |
+
attn_output, attention_weights = flex_attention(
|
125 |
+
query_states,
|
126 |
+
key_states,
|
127 |
+
value_states,
|
128 |
+
block_mask=block_mask,
|
129 |
+
enable_gqa=True, # because we shaped query/key states for GQA
|
130 |
+
scale=head_dim**-0.5 if scaling is None else scaling,
|
131 |
+
return_lse=True,
|
132 |
+
)
|
133 |
+
|
134 |
+
attn_output = attn_output.to(dtype=original_dtype)
|
135 |
+
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
136 |
+
attn_output = attn_output.reshape(
|
137 |
+
batch_size,
|
138 |
+
-1,
|
139 |
+
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
140 |
+
)
|
141 |
+
return attn_output
|
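The head bookkeeping at the top of `flex_attention_forward` (one key/value head broadcast to eight query heads) is the step that is easiest to get wrong. The sketch below isolates just that expand-and-reshape step with toy tensors; the sizes are illustrative and not tied to any real checkpoint:

```python
import torch

batch_size, seq_len, head_dim = 2, 16, 32
num_att_heads, num_key_value_heads = 8, 1
num_key_value_groups = num_att_heads // num_key_value_heads

# (B, S, num_kv_heads, head_dim), as passed into flex_attention_forward
key_states = torch.randn(batch_size, seq_len, num_key_value_heads, head_dim)

# Insert a group axis, broadcast it, then fold it back into the head axis.
expanded = key_states[:, :, :, None, :].expand(
    batch_size, seq_len, num_key_value_heads, num_key_value_groups, head_dim
)
expanded = expanded.reshape(
    batch_size, seq_len, num_key_value_heads * num_key_value_groups, head_dim
)

assert expanded.shape == (batch_size, seq_len, num_att_heads, head_dim)
# Every query head attends to the same (single) key/value head.
assert torch.equal(expanded[:, :, 0], expanded[:, :, 7])
```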
lerobot/common/policies/pi0/modeling_pi0.py
ADDED
@@ -0,0 +1,732 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
"""
|
18 |
+
π0: A Vision-Language-Action Flow Model for General Robot Control
|
19 |
+
|
20 |
+
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
21 |
+
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
22 |
+
|
23 |
+
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
24 |
+
|
25 |
+
Install pi0 extra dependencies:
|
26 |
+
```bash
|
27 |
+
pip install -e ".[pi0]"
|
28 |
+
```
|
29 |
+
|
30 |
+
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
31 |
+
```bash
|
32 |
+
python lerobot/scripts/train.py \
|
33 |
+
--policy.path=lerobot/pi0 \
|
34 |
+
--dataset.repo_id=danaaubakirova/koch_test
|
35 |
+
```
|
36 |
+
|
37 |
+
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
38 |
+
pretrained with default VLM parameters (i.e. not yet finetuned for pi0):
|
39 |
+
```bash
|
40 |
+
python lerobot/scripts/train.py \
|
41 |
+
--policy.type=pi0 \
|
42 |
+
--dataset.repo_id=danaaubakirova/koch_test
|
43 |
+
```
|
44 |
+
|
45 |
+
Example of using the pi0 pretrained model outside LeRobot training framework:
|
46 |
+
```python
|
47 |
+
policy = PI0Policy.from_pretrained("lerobot/pi0")
|
48 |
+
```
|
49 |
+
|
50 |
+
"""
|
51 |
+
|
52 |
+
import math
|
53 |
+
from collections import deque
|
54 |
+
|
55 |
+
import torch
|
56 |
+
import torch.nn.functional as F # noqa: N812
|
57 |
+
from torch import Tensor, nn
|
58 |
+
from transformers import AutoTokenizer
|
59 |
+
|
60 |
+
from lerobot.common.constants import ACTION, OBS_ROBOT
|
61 |
+
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
62 |
+
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
63 |
+
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
64 |
+
PaliGemmaWithExpertConfig,
|
65 |
+
PaliGemmaWithExpertModel,
|
66 |
+
)
|
67 |
+
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
68 |
+
from lerobot.common.utils.utils import get_safe_dtype
|
69 |
+
|
70 |
+
|
71 |
+
def create_sinusoidal_pos_embedding(
|
72 |
+
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
73 |
+
) -> Tensor:
|
74 |
+
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
75 |
+
if dimension % 2 != 0:
|
76 |
+
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
77 |
+
|
78 |
+
if time.ndim != 1:
|
79 |
+
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
80 |
+
|
81 |
+
dtype = get_safe_dtype(torch.float64, device.type)
|
82 |
+
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
83 |
+
period = min_period * (max_period / min_period) ** fraction
|
84 |
+
|
85 |
+
# Compute the outer product
|
86 |
+
scaling_factor = 1.0 / period * 2 * math.pi
|
87 |
+
sin_input = scaling_factor[None, :] * time[:, None]
|
88 |
+
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
89 |
+
return pos_emb
|
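A quick shape check for the helper above, as a sketch. Note that the function reads `device.type`, so it expects a `torch.device` rather than a plain string (the `min_period`/`max_period` values below match the ones used later in `embed_suffix`):

```python
import torch

t = torch.tensor([0.0, 0.25, 0.5, 1.0])  # a batch of 4 scalar timesteps in [0, 1]
emb = create_sinusoidal_pos_embedding(
    t, dimension=64, min_period=4e-3, max_period=4.0, device=torch.device("cpu")
)
assert emb.shape == (4, 64)  # first 32 dims are sines, last 32 are cosines
```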
90 |
+
|
91 |
+
|
92 |
+
def sample_beta(alpha, beta, bsize, device):
|
93 |
+
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
94 |
+
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
95 |
+
return gamma1 / (gamma1 + gamma2)
|
96 |
+
|
97 |
+
|
98 |
+
def make_att_2d_masks(pad_masks, att_masks):
|
99 |
+
"""Copied from big_vision.
|
100 |
+
|
101 |
+
Tokens can attend to valid input tokens which have a cumulative mask_ar
|
102 |
+
smaller than or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
103 |
+
set up several types of attention, for example:
|
104 |
+
|
105 |
+
[[1 1 1 1 1 1]]: pure causal attention.
|
106 |
+
|
107 |
+
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
108 |
+
themselves and the last 3 tokens have a causal attention. The first
|
109 |
+
entry could also be a 1 without changing behaviour.
|
110 |
+
|
111 |
+
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
112 |
+
block can attend all previous blocks and all tokens on the same block.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
input_mask: bool[B, N] true if it is part of the input, false if padding.
|
116 |
+
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
117 |
+
it and 0 where it shares the same attention mask as the previous token.
|
118 |
+
"""
|
119 |
+
if att_masks.ndim != 2:
|
120 |
+
raise ValueError(att_masks.ndim)
|
121 |
+
if pad_masks.ndim != 2:
|
122 |
+
raise ValueError(pad_masks.ndim)
|
123 |
+
|
124 |
+
cumsum = torch.cumsum(att_masks, dim=1)
|
125 |
+
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
126 |
+
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
127 |
+
att_2d_masks = att_2d_masks & pad_2d_masks
|
128 |
+
return att_2d_masks
|
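To make the docstring's `mask_ar` semantics concrete, here is a small sketch of the prefix-lm case (three bidirectional prefix tokens followed by three causal tokens, no padding), run alongside the function above:

```python
import torch

pad_masks = torch.ones(1, 6, dtype=torch.bool)   # all six tokens are real inputs
att_masks = torch.tensor([[0, 0, 0, 1, 1, 1]])   # prefix-lm pattern from the docstring

att_2d = make_att_2d_masks(pad_masks, att_masks)
print(att_2d[0].long())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]])
```

Rows are query positions and columns are key positions: the prefix attends bidirectionally to itself only, while the suffix is causal and can also see the whole prefix.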
129 |
+
|
130 |
+
|
131 |
+
def resize_with_pad(img, width, height, pad_value=-1):
|
132 |
+
# Assumed to be a near no-op when the target width/height already matches.
|
133 |
+
if img.ndim != 4:
|
134 |
+
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
135 |
+
|
136 |
+
cur_height, cur_width = img.shape[2:]
|
137 |
+
|
138 |
+
ratio = max(cur_width / width, cur_height / height)
|
139 |
+
resized_height = int(cur_height / ratio)
|
140 |
+
resized_width = int(cur_width / ratio)
|
141 |
+
resized_img = F.interpolate(
|
142 |
+
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
143 |
+
)
|
144 |
+
|
145 |
+
pad_height = max(0, int(height - resized_height))
|
146 |
+
pad_width = max(0, int(width - resized_width))
|
147 |
+
|
148 |
+
# pad on left and top of image
|
149 |
+
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
150 |
+
return padded_img
|
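As a quick sketch of the letterboxing behaviour: the helper keeps the aspect ratio and pads only on the left/top, so a 640x480 frame resized to the 224x224 SigLIP input ends up padded along the top:

```python
import torch

img = torch.rand(1, 3, 480, 640)  # (b, c, h, w)
out = resize_with_pad(img, width=224, height=224, pad_value=-1)

assert out.shape == (1, 3, 224, 224)
# The resized content sits at the bottom-right; the top rows hold the pad value.
assert torch.all(out[:, :, :56, :] == -1)
```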
151 |
+
|
152 |
+
|
153 |
+
def pad_vector(vector, new_dim):
|
154 |
+
"""Can be (batch_size x sequence_length x features_dimension)
|
155 |
+
or (batch_size x features_dimension)
|
156 |
+
"""
|
157 |
+
if vector.shape[-1] == new_dim:
|
158 |
+
return vector
|
159 |
+
shape = list(vector.shape)
|
160 |
+
current_dim = shape[-1]
|
161 |
+
shape[-1] = new_dim
|
162 |
+
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
163 |
+
new_vector[..., :current_dim] = vector
|
164 |
+
return new_vector
|
165 |
+
|
166 |
+
|
167 |
+
def normalize(x, min_val, max_val):
|
168 |
+
return (x - min_val) / (max_val - min_val)
|
169 |
+
|
170 |
+
|
171 |
+
def unnormalize(x, min_val, max_val):
|
172 |
+
return x * (max_val - min_val) + min_val
|
173 |
+
|
174 |
+
|
175 |
+
def safe_arcsin(value):
|
176 |
+
# This ensures that the input stays within
|
177 |
+
# [−1,1] to avoid invalid values for arcsin
|
178 |
+
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
179 |
+
|
180 |
+
|
181 |
+
def aloha_gripper_to_angular(value):
|
182 |
+
# Aloha transforms the gripper positions into a linear space. The following code
|
183 |
+
# reverses this transformation to be consistent with pi0 which is pretrained in
|
184 |
+
# angular space.
|
185 |
+
#
|
186 |
+
# These values are coming from the Aloha code:
|
187 |
+
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
188 |
+
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
189 |
+
|
190 |
+
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
191 |
+
def linear_to_radian(linear_position, arm_length, horn_radius):
|
192 |
+
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
193 |
+
return safe_arcsin(value)
|
194 |
+
|
195 |
+
# The constants are taken from the Interbotix code.
|
196 |
+
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
197 |
+
|
198 |
+
# Normalize to [0, 1].
|
199 |
+
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
200 |
+
return normalize(value, min_val=0.4, max_val=1.5)
|
201 |
+
|
202 |
+
|
203 |
+
def aloha_gripper_from_angular(value):
|
204 |
+
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
205 |
+
# Note that the units are still angular but the range is different.
|
206 |
+
|
207 |
+
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
208 |
+
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
209 |
+
|
210 |
+
# These values are coming from the Aloha code:
|
211 |
+
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
212 |
+
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
213 |
+
|
214 |
+
|
215 |
+
def aloha_gripper_from_angular_inv(value):
|
216 |
+
# Directly inverts the gripper_from_angular function.
|
217 |
+
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
218 |
+
return normalize(value, min_val=0.4, max_val=1.5)
|
219 |
+
|
220 |
+
|
221 |
+
class PI0Policy(PreTrainedPolicy):
|
222 |
+
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
|
223 |
+
|
224 |
+
config_class = PI0Config
|
225 |
+
name = "pi0"
|
226 |
+
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
config: PI0Config,
|
230 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
231 |
+
):
|
232 |
+
"""
|
233 |
+
Args:
|
234 |
+
config: Policy configuration class instance or None, in which case the default instantiation of
|
235 |
+
the configuration class is used.
|
236 |
+
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
237 |
+
that they will be passed with a call to `load_state_dict` before the policy is used.
|
238 |
+
"""
|
239 |
+
|
240 |
+
super().__init__(config)
|
241 |
+
config.validate_features()
|
242 |
+
self.config = config
|
243 |
+
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
244 |
+
self.normalize_targets = Normalize(
|
245 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
246 |
+
)
|
247 |
+
self.unnormalize_outputs = Unnormalize(
|
248 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
249 |
+
)
|
250 |
+
|
251 |
+
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
252 |
+
self.model = PI0FlowMatching(config)
|
253 |
+
|
254 |
+
self.reset()
|
255 |
+
|
256 |
+
def reset(self):
|
257 |
+
"""This should be called whenever the environment is reset."""
|
258 |
+
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
259 |
+
|
260 |
+
def get_optim_params(self) -> dict:
|
261 |
+
return self.parameters()
|
262 |
+
|
263 |
+
@torch.no_grad
|
264 |
+
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
265 |
+
"""Select a single action given environment observations.
|
266 |
+
|
267 |
+
This method wraps the underlying `sample_actions` call in order to return one action at a time for execution in the
|
268 |
+
environment. It works by managing the actions in a queue and only querying the model when the
|
269 |
+
queue is empty.
|
270 |
+
"""
|
271 |
+
self.eval()
|
272 |
+
|
273 |
+
if self.config.adapt_to_pi_aloha:
|
274 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
275 |
+
|
276 |
+
batch = self.normalize_inputs(batch)
|
277 |
+
|
278 |
+
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
279 |
+
# querying the policy.
|
280 |
+
if len(self._action_queue) == 0:
|
281 |
+
images, img_masks = self.prepare_images(batch)
|
282 |
+
state = self.prepare_state(batch)
|
283 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
284 |
+
|
285 |
+
actions = self.model.sample_actions(
|
286 |
+
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
287 |
+
)
|
288 |
+
|
289 |
+
# Unpad actions
|
290 |
+
original_action_dim = self.config.action_feature.shape[0]
|
291 |
+
actions = actions[:, :, :original_action_dim]
|
292 |
+
|
293 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
294 |
+
|
295 |
+
if self.config.adapt_to_pi_aloha:
|
296 |
+
actions = self._pi_aloha_encode_actions(actions)
|
297 |
+
|
298 |
+
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
299 |
+
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
300 |
+
self._action_queue.extend(actions.transpose(0, 1))
|
301 |
+
return self._action_queue.popleft()
|
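The queueing above is easier to see in isolation. This sketch mimics the mechanism with a dummy action chunk: the model output of shape (batch_size, n_action_steps, action_dim) is transposed so that the deque yields one (batch_size, action_dim) action per call (sizes are illustrative):

```python
from collections import deque

import torch

n_action_steps, batch_size, action_dim = 5, 1, 14
queue = deque([], maxlen=n_action_steps)

chunk = torch.randn(batch_size, n_action_steps, action_dim)  # stand-in for the model output
queue.extend(chunk.transpose(0, 1))  # deque is now indexed by time step

first = queue.popleft()
assert first.shape == (batch_size, action_dim)
assert torch.equal(first, chunk[:, 0])
```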
302 |
+
|
303 |
+
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
304 |
+
"""Do a full training forward pass to compute the loss"""
|
305 |
+
if self.config.adapt_to_pi_aloha:
|
306 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
307 |
+
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
308 |
+
|
309 |
+
batch = self.normalize_inputs(batch)
|
310 |
+
batch = self.normalize_targets(batch)
|
311 |
+
|
312 |
+
images, img_masks = self.prepare_images(batch)
|
313 |
+
state = self.prepare_state(batch)
|
314 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
315 |
+
actions = self.prepare_action(batch)
|
316 |
+
actions_is_pad = batch.get("action_is_pad")
|
317 |
+
|
318 |
+
loss_dict = {}
|
319 |
+
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
320 |
+
loss_dict["losses_after_forward"] = losses.clone()
|
321 |
+
|
322 |
+
if actions_is_pad is not None:
|
323 |
+
in_episode_bound = ~actions_is_pad
|
324 |
+
losses = losses * in_episode_bound.unsqueeze(-1)
|
325 |
+
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
326 |
+
|
327 |
+
# Remove padding
|
328 |
+
losses = losses[:, :, : self.config.max_action_dim]
|
329 |
+
loss_dict["losses_after_rm_padding"] = losses.clone()
|
330 |
+
|
331 |
+
# For backward pass
|
332 |
+
loss = losses.mean()
|
333 |
+
# For logging
|
334 |
+
loss_dict["l2_loss"] = loss.item()
|
335 |
+
|
336 |
+
return loss, loss_dict
|
337 |
+
|
338 |
+
def prepare_images(self, batch):
|
339 |
+
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
340 |
+
converting the pixel range from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
|
341 |
+
"""
|
342 |
+
images = []
|
343 |
+
img_masks = []
|
344 |
+
|
345 |
+
present_img_keys = [key for key in self.config.image_features if key in batch]
|
346 |
+
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
347 |
+
|
348 |
+
if len(present_img_keys) == 0:
|
349 |
+
raise ValueError(
|
350 |
+
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
351 |
+
)
|
352 |
+
|
353 |
+
# Preprocess image features present in the batch
|
354 |
+
for key in present_img_keys:
|
355 |
+
img = batch[key]
|
356 |
+
|
357 |
+
if self.config.resize_imgs_with_padding is not None:
|
358 |
+
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
359 |
+
|
360 |
+
# Normalize from range [0, 1] to [-1, 1] as expected by SigLIP
|
361 |
+
img = img * 2.0 - 1.0
|
362 |
+
|
363 |
+
bsize = img.shape[0]
|
364 |
+
device = img.device
|
365 |
+
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
366 |
+
images.append(img)
|
367 |
+
img_masks.append(mask)
|
368 |
+
|
369 |
+
# Create image features not present in the batch
|
370 |
+
# as fully 0 padded images.
|
371 |
+
for num_empty_cameras in range(len(missing_img_keys)):
|
372 |
+
if num_empty_cameras >= self.config.empty_cameras:
|
373 |
+
break
|
374 |
+
img = torch.ones_like(img) * -1
|
375 |
+
mask = torch.zeros_like(mask)
|
376 |
+
images.append(img)
|
377 |
+
img_masks.append(mask)
|
378 |
+
|
379 |
+
return images, img_masks
|
380 |
+
|
381 |
+
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
382 |
+
"""Tokenize the text input"""
|
383 |
+
device = batch[OBS_ROBOT].device
|
384 |
+
tasks = batch["task"]
|
385 |
+
|
386 |
+
# PaliGemma prompt has to end with a new line
|
387 |
+
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
388 |
+
|
389 |
+
tokenized_prompt = self.language_tokenizer(
|
390 |
+
tasks,
|
391 |
+
padding="max_length",
|
392 |
+
padding_side="right",
|
393 |
+
max_length=self.config.tokenizer_max_length,
|
394 |
+
return_tensors="pt",
|
395 |
+
)
|
396 |
+
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
397 |
+
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
398 |
+
|
399 |
+
return lang_tokens, lang_masks
|
400 |
+
|
401 |
+
def _pi_aloha_decode_state(self, state):
|
402 |
+
# Flip the joints.
|
403 |
+
for motor_idx in [1, 2, 8, 9]:
|
404 |
+
state[:, motor_idx] *= -1
|
405 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
406 |
+
for motor_idx in [6, 13]:
|
407 |
+
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
408 |
+
return state
|
409 |
+
|
410 |
+
def _pi_aloha_encode_actions(self, actions):
|
411 |
+
# Flip the joints.
|
412 |
+
for motor_idx in [1, 2, 8, 9]:
|
413 |
+
actions[:, :, motor_idx] *= -1
|
414 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
415 |
+
for motor_idx in [6, 13]:
|
416 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
417 |
+
return actions
|
418 |
+
|
419 |
+
def _pi_aloha_encode_actions_inv(self, actions):
|
420 |
+
# Flip the joints again.
|
421 |
+
for motor_idx in [1, 2, 8, 9]:
|
422 |
+
actions[:, :, motor_idx] *= -1
|
423 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
424 |
+
for motor_idx in [6, 13]:
|
425 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
426 |
+
return actions
|
427 |
+
|
428 |
+
def prepare_state(self, batch):
|
429 |
+
"""Pad state"""
|
430 |
+
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
431 |
+
return state
|
432 |
+
|
433 |
+
def prepare_action(self, batch):
|
434 |
+
"""Pad action"""
|
435 |
+
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
436 |
+
return actions
|
437 |
+
|
438 |
+
|
439 |
+
class PI0FlowMatching(nn.Module):
|
440 |
+
"""
|
441 |
+
π0: A Vision-Language-Action Flow Model for General Robot Control
|
442 |
+
|
443 |
+
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
444 |
+
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
445 |
+
|
446 |
+
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
447 |
+
┌──────────────────────────────┐
|
448 |
+
│ actions │
|
449 |
+
│ ▲ │
|
450 |
+
│ ┌┴─────┐ │
|
451 |
+
│ kv cache │Gemma │ │
|
452 |
+
│ ┌──────────►│Expert│ │
|
453 |
+
│ │ │ │ │
|
454 |
+
│ ┌┴────────┐ │x 10 │ │
|
455 |
+
│ │ │ └▲──▲──┘ │
|
456 |
+
│ │PaliGemma│ │ │ │
|
457 |
+
│ │ │ │ robot state │
|
458 |
+
│ │ │ noise │
|
459 |
+
│ └▲──▲─────┘ │
|
460 |
+
│ │ │ │
|
461 |
+
│ │ image(s) │
|
462 |
+
│ language tokens │
|
463 |
+
└──────────────────────────────┘
|
464 |
+
"""
|
465 |
+
|
466 |
+
def __init__(self, config):
|
467 |
+
super().__init__()
|
468 |
+
self.config = config
|
469 |
+
|
470 |
+
paligemma_with_expert_config = PaliGemmaWithExpertConfig(
|
471 |
+
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
472 |
+
train_expert_only=self.config.train_expert_only,
|
473 |
+
attention_implementation=self.config.attention_implementation,
|
474 |
+
)
|
475 |
+
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)
|
476 |
+
|
477 |
+
# Projections are float32
|
478 |
+
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
479 |
+
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
|
480 |
+
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
|
481 |
+
|
482 |
+
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
|
483 |
+
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
|
484 |
+
|
485 |
+
self.set_requires_grad()
|
486 |
+
|
487 |
+
def set_requires_grad(self):
|
488 |
+
for params in self.state_proj.parameters():
|
489 |
+
params.requires_grad = self.config.train_state_proj
|
490 |
+
|
491 |
+
def sample_noise(self, shape, device):
|
492 |
+
noise = torch.normal(
|
493 |
+
mean=0.0,
|
494 |
+
std=1.0,
|
495 |
+
size=shape,
|
496 |
+
dtype=torch.float32,
|
497 |
+
device=device,
|
498 |
+
)
|
499 |
+
return noise
|
500 |
+
|
501 |
+
def sample_time(self, bsize, device):
|
502 |
+
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
503 |
+
time = time_beta * 0.999 + 0.001
|
504 |
+
return time.to(dtype=torch.float32, device=device)
|
505 |
+
|
506 |
+
def embed_prefix(
|
507 |
+
self, images, img_masks, lang_tokens, lang_masks
|
508 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
509 |
+
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
510 |
+
for PaliGemma transformer processing.
|
511 |
+
"""
|
512 |
+
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
513 |
+
embs = []
|
514 |
+
pad_masks = []
|
515 |
+
att_masks = []
|
516 |
+
|
517 |
+
# TODO: remove for loop
|
518 |
+
for (
|
519 |
+
img,
|
520 |
+
img_mask,
|
521 |
+
) in zip(images, img_masks, strict=False):
|
522 |
+
img_emb = self.paligemma_with_expert.embed_image(img)
|
523 |
+
img_emb = img_emb.to(dtype=torch.bfloat16)
|
524 |
+
|
525 |
+
# Normalize image embeddings
|
526 |
+
img_emb_dim = img_emb.shape[-1]
|
527 |
+
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
528 |
+
|
529 |
+
bsize, num_img_embs = img_emb.shape[:2]
|
530 |
+
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
531 |
+
|
532 |
+
embs.append(img_emb)
|
533 |
+
pad_masks.append(img_mask)
|
534 |
+
|
535 |
+
# Create attention masks so that image tokens attend to each other
|
536 |
+
att_masks += [0] * num_img_embs
|
537 |
+
|
538 |
+
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
539 |
+
|
540 |
+
# Normalize language embeddings
|
541 |
+
lang_emb_dim = lang_emb.shape[-1]
|
542 |
+
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
543 |
+
|
544 |
+
embs.append(lang_emb)
|
545 |
+
pad_masks.append(lang_masks)
|
546 |
+
|
547 |
+
# full attention between image and language inputs
|
548 |
+
num_lang_embs = lang_emb.shape[1]
|
549 |
+
att_masks += [0] * num_lang_embs
|
550 |
+
|
551 |
+
embs = torch.cat(embs, dim=1)
|
552 |
+
pad_masks = torch.cat(pad_masks, dim=1)
|
553 |
+
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
554 |
+
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
555 |
+
|
556 |
+
return embs, pad_masks, att_masks
|
557 |
+
|
558 |
+
def embed_suffix(self, state, noisy_actions, timestep):
|
559 |
+
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
560 |
+
embs = []
|
561 |
+
pad_masks = []
|
562 |
+
att_masks = []
|
563 |
+
|
564 |
+
# Embed state
|
565 |
+
state_emb = self.state_proj(state)
|
566 |
+
state_emb = state_emb.to(dtype=torch.bfloat16)
|
567 |
+
embs.append(state_emb[:, None, :])
|
568 |
+
bsize = state_emb.shape[0]
|
569 |
+
dtype = state_emb.dtype
|
570 |
+
device = state_emb.device
|
571 |
+
|
572 |
+
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
573 |
+
pad_masks.append(state_mask)
|
574 |
+
|
575 |
+
# Set attention masks so that image and language inputs do not attend to state or actions
|
576 |
+
att_masks += [1]
|
577 |
+
|
578 |
+
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
579 |
+
time_emb = create_sinusoidal_pos_embedding(
|
580 |
+
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
581 |
+
)
|
582 |
+
time_emb = time_emb.type(dtype=dtype)
|
583 |
+
|
584 |
+
# Fuse timestep + action information using an MLP
|
585 |
+
action_emb = self.action_in_proj(noisy_actions)
|
586 |
+
|
587 |
+
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
588 |
+
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
589 |
+
|
590 |
+
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
591 |
+
action_time_emb = F.silu(action_time_emb) # swish == silu
|
592 |
+
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
593 |
+
|
594 |
+
# Add to input tokens
|
595 |
+
embs.append(action_time_emb)
|
596 |
+
|
597 |
+
bsize, action_time_dim = action_time_emb.shape[:2]
|
598 |
+
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
599 |
+
pad_masks.append(action_time_mask)
|
600 |
+
|
601 |
+
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
602 |
+
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
603 |
+
|
604 |
+
embs = torch.cat(embs, dim=1)
|
605 |
+
pad_masks = torch.cat(pad_masks, dim=1)
|
606 |
+
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
607 |
+
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
608 |
+
|
609 |
+
return embs, pad_masks, att_masks
|
610 |
+
|
611 |
+
def forward(
|
612 |
+
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
613 |
+
) -> Tensor:
|
614 |
+
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
615 |
+
if noise is None:
|
616 |
+
noise = self.sample_noise(actions.shape, actions.device)
|
617 |
+
|
618 |
+
if time is None:
|
619 |
+
time = self.sample_time(actions.shape[0], actions.device)
|
620 |
+
|
621 |
+
time_expanded = time[:, None, None]
|
622 |
+
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
623 |
+
u_t = noise - actions
|
624 |
+
|
625 |
+
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
626 |
+
images, img_masks, lang_tokens, lang_masks
|
627 |
+
)
|
628 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
629 |
+
|
630 |
+
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
631 |
+
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
632 |
+
|
633 |
+
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
634 |
+
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
635 |
+
|
636 |
+
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
637 |
+
attention_mask=att_2d_masks,
|
638 |
+
position_ids=position_ids,
|
639 |
+
past_key_values=None,
|
640 |
+
inputs_embeds=[prefix_embs, suffix_embs],
|
641 |
+
use_cache=False,
|
642 |
+
fill_kv_cache=False,
|
643 |
+
)
|
644 |
+
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
645 |
+
# Original openpi code, upcast attention output
|
646 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
647 |
+
v_t = self.action_out_proj(suffix_out)
|
648 |
+
|
649 |
+
losses = F.mse_loss(u_t, v_t, reduction="none")
|
650 |
+
return losses
|
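In the notation of the forward pass above, the training objective is standard conditional flow matching: with action chunk $a$, noise $\epsilon \sim \mathcal{N}(0, I)$ and a timestep $t \in (0, 1]$ drawn from the shifted Beta in `sample_time`,

$$x_t = t\,\epsilon + (1 - t)\,a, \qquad u_t = \epsilon - a, \qquad \mathcal{L} = \lVert v_\theta(x_t, t, \text{prefix}) - u_t \rVert_2^2,$$

averaged over the unpadded action dimensions.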
651 |
+
|
652 |
+
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
653 |
+
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
654 |
+
bsize = state.shape[0]
|
655 |
+
device = state.device
|
656 |
+
|
657 |
+
if noise is None:
|
658 |
+
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
659 |
+
noise = self.sample_noise(actions_shape, device)
|
660 |
+
|
661 |
+
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
662 |
+
images, img_masks, lang_tokens, lang_masks
|
663 |
+
)
|
664 |
+
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
665 |
+
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
666 |
+
|
667 |
+
# Compute image and language key value cache
|
668 |
+
_, past_key_values = self.paligemma_with_expert.forward(
|
669 |
+
attention_mask=prefix_att_2d_masks,
|
670 |
+
position_ids=prefix_position_ids,
|
671 |
+
past_key_values=None,
|
672 |
+
inputs_embeds=[prefix_embs, None],
|
673 |
+
use_cache=self.config.use_cache,
|
674 |
+
fill_kv_cache=True,
|
675 |
+
)
|
676 |
+
|
677 |
+
dt = -1.0 / self.config.num_steps
|
678 |
+
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
679 |
+
|
680 |
+
x_t = noise
|
681 |
+
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
682 |
+
while time >= -dt / 2:
|
683 |
+
expanded_time = time.expand(bsize)
|
684 |
+
v_t = self.denoise_step(
|
685 |
+
state,
|
686 |
+
prefix_pad_masks,
|
687 |
+
past_key_values,
|
688 |
+
x_t,
|
689 |
+
expanded_time,
|
690 |
+
)
|
691 |
+
|
692 |
+
# Euler step
|
693 |
+
x_t += dt * v_t
|
694 |
+
time += dt
|
695 |
+
return x_t
|
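The loop above integrates the learned flow from t = 1 (pure noise) back to t = 0 with `num_steps` Euler steps of size dt = -1/num_steps. Since training targets the constant velocity u_t = noise - actions, plugging that exact velocity into the same loop recovers the actions, as this toy sketch shows (sizes are illustrative):

```python
import torch

num_steps = 10
actions = torch.randn(1, 50, 32)   # stand-in for a target action chunk
noise = torch.randn_like(actions)

dt = torch.tensor(-1.0 / num_steps)
x_t = noise.clone()
time = torch.tensor(1.0)
while time >= -dt / 2:
    v_t = noise - actions          # exact velocity field the network is trained to match
    x_t += dt * v_t                # Euler step, same update as sample_actions
    time += dt

assert torch.allclose(x_t, actions, atol=1e-5)
```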
696 |
+
|
697 |
+
def denoise_step(
|
698 |
+
self,
|
699 |
+
state,
|
700 |
+
prefix_pad_masks,
|
701 |
+
past_key_values,
|
702 |
+
x_t,
|
703 |
+
timestep,
|
704 |
+
):
|
705 |
+
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
706 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
|
707 |
+
|
708 |
+
suffix_len = suffix_pad_masks.shape[1]
|
709 |
+
batch_size = prefix_pad_masks.shape[0]
|
710 |
+
prefix_len = prefix_pad_masks.shape[1]
|
711 |
+
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
712 |
+
|
713 |
+
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
714 |
+
|
715 |
+
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
716 |
+
|
717 |
+
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
718 |
+
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
719 |
+
|
720 |
+
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
721 |
+
attention_mask=full_att_2d_masks,
|
722 |
+
position_ids=position_ids,
|
723 |
+
past_key_values=past_key_values,
|
724 |
+
inputs_embeds=[None, suffix_embs],
|
725 |
+
use_cache=self.config.use_cache,
|
726 |
+
fill_kv_cache=False,
|
727 |
+
)
|
728 |
+
suffix_out = outputs_embeds[1]
|
729 |
+
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
730 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
731 |
+
v_t = self.action_out_proj(suffix_out)
|
732 |
+
return v_t
|