Francesco Capuano commited on
Commit
529ed6b
·
1 Parent(s): efd04f3

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. lerobot/__init__.py +217 -0
  2. lerobot/__version__.py +23 -0
  3. lerobot/common/constants.py +45 -0
  4. lerobot/common/datasets/backward_compatibility.py +68 -0
  5. lerobot/common/datasets/card_template.md +27 -0
  6. lerobot/common/datasets/compute_stats.py +176 -0
  7. lerobot/common/datasets/factory.py +118 -0
  8. lerobot/common/datasets/image_writer.py +178 -0
  9. lerobot/common/datasets/lerobot_dataset.py +1217 -0
  10. lerobot/common/datasets/online_buffer.py +384 -0
  11. lerobot/common/datasets/push_dataset_to_hub/utils.py +131 -0
  12. lerobot/common/datasets/sampler.py +61 -0
  13. lerobot/common/datasets/transforms.py +249 -0
  14. lerobot/common/datasets/utils.py +813 -0
  15. lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py +884 -0
  16. lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +664 -0
  17. lerobot/common/datasets/v21/_remove_language_instruction.py +87 -0
  18. lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py +54 -0
  19. lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +114 -0
  20. lerobot/common/datasets/v21/convert_stats.py +99 -0
  21. lerobot/common/datasets/video_utils.py +432 -0
  22. lerobot/common/envs/__init__.py +15 -0
  23. lerobot/common/envs/configs.py +156 -0
  24. lerobot/common/envs/factory.py +69 -0
  25. lerobot/common/envs/utils.py +127 -0
  26. lerobot/common/mocks/__init__.py +1 -0
  27. lerobot/common/mocks/cameras/__init__.py +0 -0
  28. lerobot/common/mocks/cameras/mock_cv2.py +101 -0
  29. lerobot/common/mocks/cameras/mock_pyrealsense2.py +148 -0
  30. lerobot/common/mocks/motors/__init__.py +1 -0
  31. lerobot/common/mocks/motors/mock_dynamixel_sdk.py +107 -0
  32. lerobot/common/mocks/motors/mock_scservo_sdk.py +125 -0
  33. lerobot/common/optim/__init__.py +15 -0
  34. lerobot/common/optim/factory.py +40 -0
  35. lerobot/common/optim/optimizers.py +118 -0
  36. lerobot/common/optim/schedulers.py +122 -0
  37. lerobot/common/policies/__init__.py +19 -0
  38. lerobot/common/policies/act/configuration_act.py +186 -0
  39. lerobot/common/policies/act/modeling_act.py +765 -0
  40. lerobot/common/policies/diffusion/configuration_diffusion.py +237 -0
  41. lerobot/common/policies/diffusion/modeling_diffusion.py +765 -0
  42. lerobot/common/policies/factory.py +157 -0
  43. lerobot/common/policies/normalize.py +254 -0
  44. lerobot/common/policies/pi0/configuration_pi0.py +149 -0
  45. lerobot/common/policies/pi0/conversion_scripts/benchmark.py +82 -0
  46. lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +131 -0
  47. lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py +84 -0
  48. lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py +437 -0
  49. lerobot/common/policies/pi0/flex_attention.py +141 -0
  50. lerobot/common/policies/pi0/modeling_pi0.py +732 -0
lerobot/__init__.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
18
+ We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
19
+
20
+ Example:
21
+ ```python
22
+ import lerobot
23
+ print(lerobot.available_envs)
24
+ print(lerobot.available_tasks_per_env)
25
+ print(lerobot.available_datasets)
26
+ print(lerobot.available_datasets_per_env)
27
+ print(lerobot.available_real_world_datasets)
28
+ print(lerobot.available_policies)
29
+ print(lerobot.available_policies_per_env)
30
+ print(lerobot.available_robots)
31
+ print(lerobot.available_cameras)
32
+ print(lerobot.available_motors)
33
+ ```
34
+
35
+ When implementing a new dataset loadable with LeRobotDataset follow these steps:
36
+ - Update `available_datasets_per_env` in `lerobot/__init__.py`
37
+
38
+ When implementing a new environment (e.g. `gym_aloha`), follow these steps:
39
+ - Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
40
+
41
+ When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
42
+ - Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
43
+ - Set the required `name` class attribute.
44
+ - Update variables in `tests/test_available.py` by importing your new Policy class
45
+ """
46
+
47
+ import itertools
48
+
49
+ from lerobot.__version__ import __version__ # noqa: F401
50
+
51
+ # TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
52
+ # refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
53
+ # a yaml file AND a environment name. The difference should be more obvious.
54
+ available_tasks_per_env = {
55
+ "aloha": [
56
+ "AlohaInsertion-v0",
57
+ "AlohaTransferCube-v0",
58
+ ],
59
+ "pusht": ["PushT-v0"],
60
+ "xarm": ["XarmLift-v0"],
61
+ }
62
+ available_envs = list(available_tasks_per_env.keys())
63
+
64
+ available_datasets_per_env = {
65
+ "aloha": [
66
+ "lerobot/aloha_sim_insertion_human",
67
+ "lerobot/aloha_sim_insertion_scripted",
68
+ "lerobot/aloha_sim_transfer_cube_human",
69
+ "lerobot/aloha_sim_transfer_cube_scripted",
70
+ "lerobot/aloha_sim_insertion_human_image",
71
+ "lerobot/aloha_sim_insertion_scripted_image",
72
+ "lerobot/aloha_sim_transfer_cube_human_image",
73
+ "lerobot/aloha_sim_transfer_cube_scripted_image",
74
+ ],
75
+ # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
76
+ # coupled with tests.
77
+ "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
78
+ "xarm": [
79
+ "lerobot/xarm_lift_medium",
80
+ "lerobot/xarm_lift_medium_replay",
81
+ "lerobot/xarm_push_medium",
82
+ "lerobot/xarm_push_medium_replay",
83
+ "lerobot/xarm_lift_medium_image",
84
+ "lerobot/xarm_lift_medium_replay_image",
85
+ "lerobot/xarm_push_medium_image",
86
+ "lerobot/xarm_push_medium_replay_image",
87
+ ],
88
+ }
89
+
90
+ available_real_world_datasets = [
91
+ "lerobot/aloha_mobile_cabinet",
92
+ "lerobot/aloha_mobile_chair",
93
+ "lerobot/aloha_mobile_elevator",
94
+ "lerobot/aloha_mobile_shrimp",
95
+ "lerobot/aloha_mobile_wash_pan",
96
+ "lerobot/aloha_mobile_wipe_wine",
97
+ "lerobot/aloha_static_battery",
98
+ "lerobot/aloha_static_candy",
99
+ "lerobot/aloha_static_coffee",
100
+ "lerobot/aloha_static_coffee_new",
101
+ "lerobot/aloha_static_cups_open",
102
+ "lerobot/aloha_static_fork_pick_up",
103
+ "lerobot/aloha_static_pingpong_test",
104
+ "lerobot/aloha_static_pro_pencil",
105
+ "lerobot/aloha_static_screw_driver",
106
+ "lerobot/aloha_static_tape",
107
+ "lerobot/aloha_static_thread_velcro",
108
+ "lerobot/aloha_static_towel",
109
+ "lerobot/aloha_static_vinh_cup",
110
+ "lerobot/aloha_static_vinh_cup_left",
111
+ "lerobot/aloha_static_ziploc_slide",
112
+ "lerobot/umi_cup_in_the_wild",
113
+ "lerobot/unitreeh1_fold_clothes",
114
+ "lerobot/unitreeh1_rearrange_objects",
115
+ "lerobot/unitreeh1_two_robot_greeting",
116
+ "lerobot/unitreeh1_warehouse",
117
+ "lerobot/nyu_rot_dataset",
118
+ "lerobot/utokyo_saytap",
119
+ "lerobot/imperialcollege_sawyer_wrist_cam",
120
+ "lerobot/utokyo_xarm_bimanual",
121
+ "lerobot/tokyo_u_lsmo",
122
+ "lerobot/utokyo_pr2_opening_fridge",
123
+ "lerobot/cmu_franka_exploration_dataset",
124
+ "lerobot/cmu_stretch",
125
+ "lerobot/asu_table_top",
126
+ "lerobot/utokyo_pr2_tabletop_manipulation",
127
+ "lerobot/utokyo_xarm_pick_and_place",
128
+ "lerobot/ucsd_kitchen_dataset",
129
+ "lerobot/austin_buds_dataset",
130
+ "lerobot/dlr_sara_grid_clamp",
131
+ "lerobot/conq_hose_manipulation",
132
+ "lerobot/columbia_cairlab_pusht_real",
133
+ "lerobot/dlr_sara_pour",
134
+ "lerobot/dlr_edan_shared_control",
135
+ "lerobot/ucsd_pick_and_place_dataset",
136
+ "lerobot/berkeley_cable_routing",
137
+ "lerobot/nyu_franka_play_dataset",
138
+ "lerobot/austin_sirius_dataset",
139
+ "lerobot/cmu_play_fusion",
140
+ "lerobot/berkeley_gnm_sac_son",
141
+ "lerobot/nyu_door_opening_surprising_effectiveness",
142
+ "lerobot/berkeley_fanuc_manipulation",
143
+ "lerobot/jaco_play",
144
+ "lerobot/viola",
145
+ "lerobot/kaist_nonprehensile",
146
+ "lerobot/berkeley_mvp",
147
+ "lerobot/uiuc_d3field",
148
+ "lerobot/berkeley_gnm_recon",
149
+ "lerobot/austin_sailor_dataset",
150
+ "lerobot/utaustin_mutex",
151
+ "lerobot/roboturk",
152
+ "lerobot/stanford_hydra_dataset",
153
+ "lerobot/berkeley_autolab_ur5",
154
+ "lerobot/stanford_robocook",
155
+ "lerobot/toto",
156
+ "lerobot/fmb",
157
+ "lerobot/droid_100",
158
+ "lerobot/berkeley_rpt",
159
+ "lerobot/stanford_kuka_multimodal_dataset",
160
+ "lerobot/iamlab_cmu_pickup_insert",
161
+ "lerobot/taco_play",
162
+ "lerobot/berkeley_gnm_cory_hall",
163
+ "lerobot/usc_cloth_sim",
164
+ ]
165
+
166
+ available_datasets = sorted(
167
+ set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
168
+ )
169
+
170
+ # lists all available policies from `lerobot/common/policies`
171
+ available_policies = [
172
+ "act",
173
+ "diffusion",
174
+ "tdmpc",
175
+ "vqbet",
176
+ ]
177
+
178
+ # lists all available robots from `lerobot/common/robot_devices/robots`
179
+ available_robots = [
180
+ "koch",
181
+ "koch_bimanual",
182
+ "aloha",
183
+ "so100",
184
+ "moss",
185
+ ]
186
+
187
+ # lists all available cameras from `lerobot/common/robot_devices/cameras`
188
+ available_cameras = [
189
+ "opencv",
190
+ "intelrealsense",
191
+ ]
192
+
193
+ # lists all available motors from `lerobot/common/robot_devices/motors`
194
+ available_motors = [
195
+ "dynamixel",
196
+ "feetech",
197
+ ]
198
+
199
+ # keys and values refer to yaml files
200
+ available_policies_per_env = {
201
+ "aloha": ["act"],
202
+ "pusht": ["diffusion", "vqbet"],
203
+ "xarm": ["tdmpc"],
204
+ "koch_real": ["act_koch_real"],
205
+ "aloha_real": ["act_aloha_real"],
206
+ }
207
+
208
+ env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
209
+ env_dataset_pairs = [
210
+ (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
211
+ ]
212
+ env_dataset_policy_triplets = [
213
+ (env, dataset, policy)
214
+ for env, datasets in available_datasets_per_env.items()
215
+ for dataset in datasets
216
+ for policy in available_policies_per_env[env]
217
+ ]
lerobot/__version__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """To enable `lerobot.__version__`"""
17
+
18
+ from importlib.metadata import PackageNotFoundError, version
19
+
20
+ try:
21
+ __version__ = version("lerobot")
22
+ except PackageNotFoundError:
23
+ __version__ = "unknown"
lerobot/common/constants.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # keys
15
+ import os
16
+ from pathlib import Path
17
+
18
+ from huggingface_hub.constants import HF_HOME
19
+
20
+ OBS_ENV = "observation.environment_state"
21
+ OBS_ROBOT = "observation.state"
22
+ OBS_IMAGE = "observation.image"
23
+ OBS_IMAGES = "observation.images"
24
+ ACTION = "action"
25
+
26
+ # files & directories
27
+ CHECKPOINTS_DIR = "checkpoints"
28
+ LAST_CHECKPOINT_LINK = "last"
29
+ PRETRAINED_MODEL_DIR = "pretrained_model"
30
+ TRAINING_STATE_DIR = "training_state"
31
+ RNG_STATE = "rng_state.safetensors"
32
+ TRAINING_STEP = "training_step.json"
33
+ OPTIMIZER_STATE = "optimizer_state.safetensors"
34
+ OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
35
+ SCHEDULER_STATE = "scheduler_state.json"
36
+
37
+ # cache dir
38
+ default_cache_path = Path(HF_HOME) / "lerobot"
39
+ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
40
+
41
+ if "LEROBOT_HOME" in os.environ:
42
+ raise ValueError(
43
+ f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
44
+ "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
45
+ )
lerobot/common/datasets/backward_compatibility.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import packaging.version
16
+
17
+ V2_MESSAGE = """
18
+ The dataset you requested ({repo_id}) is in {version} format.
19
+
20
+ We introduced a new format since v2.0 which is not backward compatible with v1.x.
21
+ Please, use our conversion script. Modify the following command with your own task description:
22
+ ```
23
+ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
24
+ --repo-id {repo_id} \\
25
+ --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
26
+ ```
27
+
28
+ A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
29
+ peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
30
+ cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
31
+ target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
32
+ sweatshirt.", ...
33
+
34
+ If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
35
+ or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
36
+ """
37
+
38
+ V21_MESSAGE = """
39
+ The dataset you requested ({repo_id}) is in {version} format.
40
+ While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
41
+ stats instead of per-episode stats. Update your dataset stats to the new format using this command:
42
+ ```
43
+ python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
44
+ ```
45
+
46
+ If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
47
+ or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
48
+ """
49
+
50
+ FUTURE_MESSAGE = """
51
+ The dataset you requested ({repo_id}) is only available in {version} format.
52
+ As we cannot ensure forward compatibility with it, please update your current version of lerobot.
53
+ """
54
+
55
+
56
+ class CompatibilityError(Exception): ...
57
+
58
+
59
+ class BackwardCompatibilityError(CompatibilityError):
60
+ def __init__(self, repo_id: str, version: packaging.version.Version):
61
+ message = V2_MESSAGE.format(repo_id=repo_id, version=version)
62
+ super().__init__(message)
63
+
64
+
65
+ class ForwardCompatibilityError(CompatibilityError):
66
+ def __init__(self, repo_id: str, version: packaging.version.Version):
67
+ message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
68
+ super().__init__(message)
lerobot/common/datasets/card_template.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
3
+ # Doc / guide: https://huggingface.co/docs/hub/datasets-cards
4
+ {{ card_data }}
5
+ ---
6
+
7
+ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
8
+
9
+ ## Dataset Description
10
+
11
+ {{ dataset_description | default("", true) }}
12
+
13
+ - **Homepage:** {{ url | default("[More Information Needed]", true)}}
14
+ - **Paper:** {{ paper | default("[More Information Needed]", true)}}
15
+ - **License:** {{ license | default("[More Information Needed]", true)}}
16
+
17
+ ## Dataset Structure
18
+
19
+ {{ dataset_structure | default("[More Information Needed]", true)}}
20
+
21
+ ## Citation
22
+
23
+ **BibTeX:**
24
+
25
+ ```bibtex
26
+ {{ citation_bibtex | default("[More Information Needed]", true)}}
27
+ ```
lerobot/common/datasets/compute_stats.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import numpy as np
17
+
18
+ from lerobot.common.datasets.utils import load_image_as_numpy
19
+
20
+
21
+ def estimate_num_samples(
22
+ dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
23
+ ) -> int:
24
+ """Heuristic to estimate the number of samples based on dataset size.
25
+ The power controls the sample growth relative to dataset size.
26
+ Lower the power for less number of samples.
27
+
28
+ For default arguments, we have:
29
+ - from 1 to ~500, num_samples=100
30
+ - at 1000, num_samples=177
31
+ - at 2000, num_samples=299
32
+ - at 5000, num_samples=594
33
+ - at 10000, num_samples=1000
34
+ - at 20000, num_samples=1681
35
+ """
36
+ if dataset_len < min_num_samples:
37
+ min_num_samples = dataset_len
38
+ return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
39
+
40
+
41
+ def sample_indices(data_len: int) -> list[int]:
42
+ num_samples = estimate_num_samples(data_len)
43
+ return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
44
+
45
+
46
+ def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
47
+ _, height, width = img.shape
48
+
49
+ if max(width, height) < max_size_threshold:
50
+ # no downsampling needed
51
+ return img
52
+
53
+ downsample_factor = int(width / target_size) if width > height else int(height / target_size)
54
+ return img[:, ::downsample_factor, ::downsample_factor]
55
+
56
+
57
+ def sample_images(image_paths: list[str]) -> np.ndarray:
58
+ sampled_indices = sample_indices(len(image_paths))
59
+
60
+ images = None
61
+ for i, idx in enumerate(sampled_indices):
62
+ path = image_paths[idx]
63
+ # we load as uint8 to reduce memory usage
64
+ img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
65
+ img = auto_downsample_height_width(img)
66
+
67
+ if images is None:
68
+ images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
69
+
70
+ images[i] = img
71
+
72
+ return images
73
+
74
+
75
+ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
76
+ return {
77
+ "min": np.min(array, axis=axis, keepdims=keepdims),
78
+ "max": np.max(array, axis=axis, keepdims=keepdims),
79
+ "mean": np.mean(array, axis=axis, keepdims=keepdims),
80
+ "std": np.std(array, axis=axis, keepdims=keepdims),
81
+ "count": np.array([len(array)]),
82
+ }
83
+
84
+
85
+ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
86
+ ep_stats = {}
87
+ for key, data in episode_data.items():
88
+ if features[key]["dtype"] == "string":
89
+ continue # HACK: we should receive np.arrays of strings
90
+ elif features[key]["dtype"] in ["image", "video"]:
91
+ ep_ft_array = sample_images(data) # data is a list of image paths
92
+ axes_to_reduce = (0, 2, 3) # keep channel dim
93
+ keepdims = True
94
+ else:
95
+ ep_ft_array = data # data is already a np.ndarray
96
+ axes_to_reduce = 0 # compute stats over the first axis
97
+ keepdims = data.ndim == 1 # keep as np.array
98
+
99
+ ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
100
+
101
+ # finally, we normalize and remove batch dim for images
102
+ if features[key]["dtype"] in ["image", "video"]:
103
+ ep_stats[key] = {
104
+ k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
105
+ }
106
+
107
+ return ep_stats
108
+
109
+
110
+ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
111
+ for i in range(len(stats_list)):
112
+ for fkey in stats_list[i]:
113
+ for k, v in stats_list[i][fkey].items():
114
+ if not isinstance(v, np.ndarray):
115
+ raise ValueError(
116
+ f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
117
+ )
118
+ if v.ndim == 0:
119
+ raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
120
+ if k == "count" and v.shape != (1,):
121
+ raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
122
+ if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
123
+ raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
124
+
125
+
126
+ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
127
+ """Aggregates stats for a single feature."""
128
+ means = np.stack([s["mean"] for s in stats_ft_list])
129
+ variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
130
+ counts = np.stack([s["count"] for s in stats_ft_list])
131
+ total_count = counts.sum(axis=0)
132
+
133
+ # Prepare weighted mean by matching number of dimensions
134
+ while counts.ndim < means.ndim:
135
+ counts = np.expand_dims(counts, axis=-1)
136
+
137
+ # Compute the weighted mean
138
+ weighted_means = means * counts
139
+ total_mean = weighted_means.sum(axis=0) / total_count
140
+
141
+ # Compute the variance using the parallel algorithm
142
+ delta_means = means - total_mean
143
+ weighted_variances = (variances + delta_means**2) * counts
144
+ total_variance = weighted_variances.sum(axis=0) / total_count
145
+
146
+ return {
147
+ "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
148
+ "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
149
+ "mean": total_mean,
150
+ "std": np.sqrt(total_variance),
151
+ "count": total_count,
152
+ }
153
+
154
+
155
+ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
156
+ """Aggregate stats from multiple compute_stats outputs into a single set of stats.
157
+
158
+ The final stats will have the union of all data keys from each of the stats dicts.
159
+
160
+ For instance:
161
+ - new_min = min(min_dataset_0, min_dataset_1, ...)
162
+ - new_max = max(max_dataset_0, max_dataset_1, ...)
163
+ - new_mean = (mean of all data, weighted by counts)
164
+ - new_std = (std of all data)
165
+ """
166
+
167
+ _assert_type_and_shape(stats_list)
168
+
169
+ data_keys = {key for stats in stats_list for key in stats}
170
+ aggregated_stats = {key: {} for key in data_keys}
171
+
172
+ for key in data_keys:
173
+ stats_with_key = [stats[key] for stats in stats_list if key in stats]
174
+ aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
175
+
176
+ return aggregated_stats
lerobot/common/datasets/factory.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ from pprint import pformat
18
+
19
+ import torch
20
+
21
+ from lerobot.common.datasets.lerobot_dataset import (
22
+ LeRobotDataset,
23
+ LeRobotDatasetMetadata,
24
+ MultiLeRobotDataset,
25
+ )
26
+ from lerobot.common.datasets.transforms import ImageTransforms
27
+ from lerobot.configs.policies import PreTrainedConfig
28
+ from lerobot.configs.train import TrainPipelineConfig
29
+
30
+ IMAGENET_STATS = {
31
+ "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
32
+ "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
33
+ }
34
+
35
+
36
+ def resolve_delta_timestamps(
37
+ cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
38
+ ) -> dict[str, list] | None:
39
+ """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
40
+
41
+ Args:
42
+ cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
43
+ ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
44
+ delta_timestamps against.
45
+
46
+ Returns:
47
+ dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
48
+ {
49
+ "observation.state": [-0.04, -0.02, 0]
50
+ "observation.action": [-0.02, 0, 0.02]
51
+ }
52
+ returns `None` if the the resulting dict is empty.
53
+ """
54
+ delta_timestamps = {}
55
+ for key in ds_meta.features:
56
+ if key == "next.reward" and cfg.reward_delta_indices is not None:
57
+ delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
58
+ if key == "action" and cfg.action_delta_indices is not None:
59
+ delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
60
+ if key.startswith("observation.") and cfg.observation_delta_indices is not None:
61
+ delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
62
+
63
+ if len(delta_timestamps) == 0:
64
+ delta_timestamps = None
65
+
66
+ return delta_timestamps
67
+
68
+
69
+ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
70
+ """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
71
+
72
+ Args:
73
+ cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
74
+
75
+ Raises:
76
+ NotImplementedError: The MultiLeRobotDataset is currently deactivated.
77
+
78
+ Returns:
79
+ LeRobotDataset | MultiLeRobotDataset
80
+ """
81
+ image_transforms = (
82
+ ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
83
+ )
84
+
85
+ if isinstance(cfg.dataset.repo_id, str):
86
+ ds_meta = LeRobotDatasetMetadata(
87
+ cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
88
+ )
89
+ delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
90
+ dataset = LeRobotDataset(
91
+ cfg.dataset.repo_id,
92
+ root=cfg.dataset.root,
93
+ episodes=cfg.dataset.episodes,
94
+ delta_timestamps=delta_timestamps,
95
+ image_transforms=image_transforms,
96
+ revision=cfg.dataset.revision,
97
+ video_backend=cfg.dataset.video_backend,
98
+ )
99
+ else:
100
+ raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
101
+ dataset = MultiLeRobotDataset(
102
+ cfg.dataset.repo_id,
103
+ # TODO(aliberts): add proper support for multi dataset
104
+ # delta_timestamps=delta_timestamps,
105
+ image_transforms=image_transforms,
106
+ video_backend=cfg.dataset.video_backend,
107
+ )
108
+ logging.info(
109
+ "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
110
+ f"{pformat(dataset.repo_id_to_index, indent=2)}"
111
+ )
112
+
113
+ if cfg.dataset.use_imagenet_stats:
114
+ for key in dataset.meta.camera_keys:
115
+ for stats_type, stats in IMAGENET_STATS.items():
116
+ dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
117
+
118
+ return dataset
lerobot/common/datasets/image_writer.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import multiprocessing
17
+ import queue
18
+ import threading
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+
25
+
26
+ def safe_stop_image_writer(func):
27
+ def wrapper(*args, **kwargs):
28
+ try:
29
+ return func(*args, **kwargs)
30
+ except Exception as e:
31
+ dataset = kwargs.get("dataset")
32
+ image_writer = getattr(dataset, "image_writer", None) if dataset else None
33
+ if image_writer is not None:
34
+ print("Waiting for image writer to terminate...")
35
+ image_writer.stop()
36
+ raise e
37
+
38
+ return wrapper
39
+
40
+
41
+ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
42
+ # TODO(aliberts): handle 1 channel and 4 for depth images
43
+ if image_array.ndim != 3:
44
+ raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
45
+
46
+ if image_array.shape[0] == 3:
47
+ # Transpose from pytorch convention (C, H, W) to (H, W, C)
48
+ image_array = image_array.transpose(1, 2, 0)
49
+
50
+ elif image_array.shape[-1] != 3:
51
+ raise NotImplementedError(
52
+ f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
53
+ )
54
+
55
+ if image_array.dtype != np.uint8:
56
+ if range_check:
57
+ max_ = image_array.max().item()
58
+ min_ = image_array.min().item()
59
+ if max_ > 1.0 or min_ < 0.0:
60
+ raise ValueError(
61
+ "The image data type is float, which requires values in the range [0.0, 1.0]. "
62
+ f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
63
+ "provide a uint8 image with values in the range [0, 255]."
64
+ )
65
+
66
+ image_array = (image_array * 255).astype(np.uint8)
67
+
68
+ return PIL.Image.fromarray(image_array)
69
+
70
+
71
+ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
72
+ try:
73
+ if isinstance(image, np.ndarray):
74
+ img = image_array_to_pil_image(image)
75
+ elif isinstance(image, PIL.Image.Image):
76
+ img = image
77
+ else:
78
+ raise TypeError(f"Unsupported image type: {type(image)}")
79
+ img.save(fpath)
80
+ except Exception as e:
81
+ print(f"Error writing image {fpath}: {e}")
82
+
83
+
84
+ def worker_thread_loop(queue: queue.Queue):
85
+ while True:
86
+ item = queue.get()
87
+ if item is None:
88
+ queue.task_done()
89
+ break
90
+ image_array, fpath = item
91
+ write_image(image_array, fpath)
92
+ queue.task_done()
93
+
94
+
95
+ def worker_process(queue: queue.Queue, num_threads: int):
96
+ threads = []
97
+ for _ in range(num_threads):
98
+ t = threading.Thread(target=worker_thread_loop, args=(queue,))
99
+ t.daemon = True
100
+ t.start()
101
+ threads.append(t)
102
+ for t in threads:
103
+ t.join()
104
+
105
+
106
+ class AsyncImageWriter:
107
+ """
108
+ This class abstract away the initialisation of processes or/and threads to
109
+ save images on disk asynchrounously, which is critical to control a robot and record data
110
+ at a high frame rate.
111
+
112
+ When `num_processes=0`, it creates a threads pool of size `num_threads`.
113
+ When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
114
+ their own threads pool of size `num_threads`.
115
+
116
+ The optimal number of processes and threads depends on your computer capabilities.
117
+ We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
118
+ the number of threads. If it is still not stable, try to use 1 subprocess, or more.
119
+ """
120
+
121
+ def __init__(self, num_processes: int = 0, num_threads: int = 1):
122
+ self.num_processes = num_processes
123
+ self.num_threads = num_threads
124
+ self.queue = None
125
+ self.threads = []
126
+ self.processes = []
127
+ self._stopped = False
128
+
129
+ if num_threads <= 0 and num_processes <= 0:
130
+ raise ValueError("Number of threads and processes must be greater than zero.")
131
+
132
+ if self.num_processes == 0:
133
+ # Use threading
134
+ self.queue = queue.Queue()
135
+ for _ in range(self.num_threads):
136
+ t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
137
+ t.daemon = True
138
+ t.start()
139
+ self.threads.append(t)
140
+ else:
141
+ # Use multiprocessing
142
+ self.queue = multiprocessing.JoinableQueue()
143
+ for _ in range(self.num_processes):
144
+ p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
145
+ p.daemon = True
146
+ p.start()
147
+ self.processes.append(p)
148
+
149
+ def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
150
+ if isinstance(image, torch.Tensor):
151
+ # Convert tensor to numpy array to minimize main process time
152
+ image = image.cpu().numpy()
153
+ self.queue.put((image, fpath))
154
+
155
+ def wait_until_done(self):
156
+ self.queue.join()
157
+
158
+ def stop(self):
159
+ if self._stopped:
160
+ return
161
+
162
+ if self.num_processes == 0:
163
+ for _ in self.threads:
164
+ self.queue.put(None)
165
+ for t in self.threads:
166
+ t.join()
167
+ else:
168
+ num_nones = self.num_processes * self.num_threads
169
+ for _ in range(num_nones):
170
+ self.queue.put(None)
171
+ for p in self.processes:
172
+ p.join()
173
+ if p.is_alive():
174
+ p.terminate()
175
+ self.queue.close()
176
+ self.queue.join_thread()
177
+
178
+ self._stopped = True
lerobot/common/datasets/lerobot_dataset.py ADDED
@@ -0,0 +1,1217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import logging
18
+ import shutil
19
+ from pathlib import Path
20
+ from typing import Callable
21
+
22
+ import datasets
23
+ import numpy as np
24
+ import packaging.version
25
+ import PIL.Image
26
+ import torch
27
+ import torch.utils
28
+ from datasets import concatenate_datasets, load_dataset
29
+ from huggingface_hub import HfApi, snapshot_download
30
+ from huggingface_hub.constants import REPOCARD_NAME
31
+ from huggingface_hub.errors import RevisionNotFoundError
32
+
33
+ from lerobot.common.constants import HF_LEROBOT_HOME
34
+ from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
35
+ from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
36
+ from lerobot.common.datasets.utils import (
37
+ DEFAULT_FEATURES,
38
+ DEFAULT_IMAGE_PATH,
39
+ INFO_PATH,
40
+ TASKS_PATH,
41
+ append_jsonlines,
42
+ backward_compatible_episodes_stats,
43
+ check_delta_timestamps,
44
+ check_timestamps_sync,
45
+ check_version_compatibility,
46
+ create_empty_dataset_info,
47
+ create_lerobot_dataset_card,
48
+ embed_images,
49
+ get_delta_indices,
50
+ get_episode_data_index,
51
+ get_features_from_robot,
52
+ get_hf_features_from_features,
53
+ get_safe_version,
54
+ hf_transform_to_torch,
55
+ is_valid_version,
56
+ load_episodes,
57
+ load_episodes_stats,
58
+ load_info,
59
+ load_stats,
60
+ load_tasks,
61
+ validate_episode_buffer,
62
+ validate_frame,
63
+ write_episode,
64
+ write_episode_stats,
65
+ write_info,
66
+ write_json,
67
+ )
68
+ from lerobot.common.datasets.video_utils import (
69
+ VideoFrame,
70
+ decode_video_frames,
71
+ encode_video_frames,
72
+ get_safe_default_codec,
73
+ get_video_info,
74
+ )
75
+ from lerobot.common.robot_devices.robots.utils import Robot
76
+
77
+ CODEBASE_VERSION = "v2.1"
78
+
79
+
80
+ class LeRobotDatasetMetadata:
81
+ def __init__(
82
+ self,
83
+ repo_id: str,
84
+ root: str | Path | None = None,
85
+ revision: str | None = None,
86
+ force_cache_sync: bool = False,
87
+ ):
88
+ self.repo_id = repo_id
89
+ self.revision = revision if revision else CODEBASE_VERSION
90
+ self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
91
+
92
+ try:
93
+ if force_cache_sync:
94
+ raise FileNotFoundError
95
+ self.load_metadata()
96
+ except (FileNotFoundError, NotADirectoryError):
97
+ if is_valid_version(self.revision):
98
+ self.revision = get_safe_version(self.repo_id, self.revision)
99
+
100
+ (self.root / "meta").mkdir(exist_ok=True, parents=True)
101
+ self.pull_from_repo(allow_patterns="meta/")
102
+ self.load_metadata()
103
+
104
+ def load_metadata(self):
105
+ self.info = load_info(self.root)
106
+ check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
107
+ self.tasks, self.task_to_task_index = load_tasks(self.root)
108
+ self.episodes = load_episodes(self.root)
109
+ if self._version < packaging.version.parse("v2.1"):
110
+ self.stats = load_stats(self.root)
111
+ self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
112
+ else:
113
+ self.episodes_stats = load_episodes_stats(self.root)
114
+ self.stats = aggregate_stats(list(self.episodes_stats.values()))
115
+
116
+ def pull_from_repo(
117
+ self,
118
+ allow_patterns: list[str] | str | None = None,
119
+ ignore_patterns: list[str] | str | None = None,
120
+ ) -> None:
121
+ snapshot_download(
122
+ self.repo_id,
123
+ repo_type="dataset",
124
+ revision=self.revision,
125
+ local_dir=self.root,
126
+ allow_patterns=allow_patterns,
127
+ ignore_patterns=ignore_patterns,
128
+ )
129
+
130
+ @property
131
+ def _version(self) -> packaging.version.Version:
132
+ """Codebase version used to create this dataset."""
133
+ return packaging.version.parse(self.info["codebase_version"])
134
+
135
+ def get_data_file_path(self, ep_index: int) -> Path:
136
+ ep_chunk = self.get_episode_chunk(ep_index)
137
+ fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
138
+ return Path(fpath)
139
+
140
+ def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
141
+ ep_chunk = self.get_episode_chunk(ep_index)
142
+ fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
143
+ return Path(fpath)
144
+
145
+ def get_episode_chunk(self, ep_index: int) -> int:
146
+ return ep_index // self.chunks_size
147
+
148
+ @property
149
+ def data_path(self) -> str:
150
+ """Formattable string for the parquet files."""
151
+ return self.info["data_path"]
152
+
153
+ @property
154
+ def video_path(self) -> str | None:
155
+ """Formattable string for the video files."""
156
+ return self.info["video_path"]
157
+
158
+ @property
159
+ def robot_type(self) -> str | None:
160
+ """Robot type used in recording this dataset."""
161
+ return self.info["robot_type"]
162
+
163
+ @property
164
+ def fps(self) -> int:
165
+ """Frames per second used during data collection."""
166
+ return self.info["fps"]
167
+
168
+ @property
169
+ def features(self) -> dict[str, dict]:
170
+ """All features contained in the dataset."""
171
+ return self.info["features"]
172
+
173
+ @property
174
+ def image_keys(self) -> list[str]:
175
+ """Keys to access visual modalities stored as images."""
176
+ return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
177
+
178
+ @property
179
+ def video_keys(self) -> list[str]:
180
+ """Keys to access visual modalities stored as videos."""
181
+ return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
182
+
183
+ @property
184
+ def camera_keys(self) -> list[str]:
185
+ """Keys to access visual modalities (regardless of their storage method)."""
186
+ return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
187
+
188
+ @property
189
+ def names(self) -> dict[str, list | dict]:
190
+ """Names of the various dimensions of vector modalities."""
191
+ return {key: ft["names"] for key, ft in self.features.items()}
192
+
193
+ @property
194
+ def shapes(self) -> dict:
195
+ """Shapes for the different features."""
196
+ return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
197
+
198
+ @property
199
+ def total_episodes(self) -> int:
200
+ """Total number of episodes available."""
201
+ return self.info["total_episodes"]
202
+
203
+ @property
204
+ def total_frames(self) -> int:
205
+ """Total number of frames saved in this dataset."""
206
+ return self.info["total_frames"]
207
+
208
+ @property
209
+ def total_tasks(self) -> int:
210
+ """Total number of different tasks performed in this dataset."""
211
+ return self.info["total_tasks"]
212
+
213
+ @property
214
+ def total_chunks(self) -> int:
215
+ """Total number of chunks (groups of episodes)."""
216
+ return self.info["total_chunks"]
217
+
218
+ @property
219
+ def chunks_size(self) -> int:
220
+ """Max number of episodes per chunk."""
221
+ return self.info["chunks_size"]
222
+
223
+ def get_task_index(self, task: str) -> int | None:
224
+ """
225
+ Given a task in natural language, returns its task_index if the task already exists in the dataset,
226
+ otherwise return None.
227
+ """
228
+ return self.task_to_task_index.get(task, None)
229
+
230
+ def add_task(self, task: str):
231
+ """
232
+ Given a task in natural language, add it to the dictionary of tasks.
233
+ """
234
+ if task in self.task_to_task_index:
235
+ raise ValueError(f"The task '{task}' already exists and can't be added twice.")
236
+
237
+ task_index = self.info["total_tasks"]
238
+ self.task_to_task_index[task] = task_index
239
+ self.tasks[task_index] = task
240
+ self.info["total_tasks"] += 1
241
+
242
+ task_dict = {
243
+ "task_index": task_index,
244
+ "task": task,
245
+ }
246
+ append_jsonlines(task_dict, self.root / TASKS_PATH)
247
+
248
+ def save_episode(
249
+ self,
250
+ episode_index: int,
251
+ episode_length: int,
252
+ episode_tasks: list[str],
253
+ episode_stats: dict[str, dict],
254
+ ) -> None:
255
+ self.info["total_episodes"] += 1
256
+ self.info["total_frames"] += episode_length
257
+
258
+ chunk = self.get_episode_chunk(episode_index)
259
+ if chunk >= self.total_chunks:
260
+ self.info["total_chunks"] += 1
261
+
262
+ self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
263
+ self.info["total_videos"] += len(self.video_keys)
264
+ if len(self.video_keys) > 0:
265
+ self.update_video_info()
266
+
267
+ write_info(self.info, self.root)
268
+
269
+ episode_dict = {
270
+ "episode_index": episode_index,
271
+ "tasks": episode_tasks,
272
+ "length": episode_length,
273
+ }
274
+ self.episodes[episode_index] = episode_dict
275
+ write_episode(episode_dict, self.root)
276
+
277
+ self.episodes_stats[episode_index] = episode_stats
278
+ self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
279
+ write_episode_stats(episode_index, episode_stats, self.root)
280
+
281
+ def update_video_info(self) -> None:
282
+ """
283
+ Warning: this function writes info from first episode videos, implicitly assuming that all videos have
284
+ been encoded the same way. Also, this means it assumes the first episode exists.
285
+ """
286
+ for key in self.video_keys:
287
+ if not self.features[key].get("info", None):
288
+ video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
289
+ self.info["features"][key]["info"] = get_video_info(video_path)
290
+
291
+ def __repr__(self):
292
+ feature_keys = list(self.features)
293
+ return (
294
+ f"{self.__class__.__name__}({{\n"
295
+ f" Repository ID: '{self.repo_id}',\n"
296
+ f" Total episodes: '{self.total_episodes}',\n"
297
+ f" Total frames: '{self.total_frames}',\n"
298
+ f" Features: '{feature_keys}',\n"
299
+ "})',\n"
300
+ )
301
+
302
+ @classmethod
303
+ def create(
304
+ cls,
305
+ repo_id: str,
306
+ fps: int,
307
+ root: str | Path | None = None,
308
+ robot: Robot | None = None,
309
+ robot_type: str | None = None,
310
+ features: dict | None = None,
311
+ use_videos: bool = True,
312
+ ) -> "LeRobotDatasetMetadata":
313
+ """Creates metadata for a LeRobotDataset."""
314
+ obj = cls.__new__(cls)
315
+ obj.repo_id = repo_id
316
+ obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
317
+
318
+ obj.root.mkdir(parents=True, exist_ok=False)
319
+
320
+ if robot is not None:
321
+ features = get_features_from_robot(robot, use_videos)
322
+ robot_type = robot.robot_type
323
+ if not all(cam.fps == fps for cam in robot.cameras.values()):
324
+ logging.warning(
325
+ f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
326
+ "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
327
+ )
328
+ elif features is None:
329
+ raise ValueError(
330
+ "Dataset features must either come from a Robot or explicitly passed upon creation."
331
+ )
332
+ else:
333
+ # TODO(aliberts, rcadene): implement sanity check for features
334
+ features = {**features, **DEFAULT_FEATURES}
335
+
336
+ # check if none of the features contains a "/" in their names,
337
+ # as this would break the dict flattening in the stats computation, which uses '/' as separator
338
+ for key in features:
339
+ if "/" in key:
340
+ raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
341
+
342
+ features = {**features, **DEFAULT_FEATURES}
343
+
344
+ obj.tasks, obj.task_to_task_index = {}, {}
345
+ obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
346
+ obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
347
+ if len(obj.video_keys) > 0 and not use_videos:
348
+ raise ValueError()
349
+ write_json(obj.info, obj.root / INFO_PATH)
350
+ obj.revision = None
351
+ return obj
352
+
353
+
354
+ class LeRobotDataset(torch.utils.data.Dataset):
355
+ def __init__(
356
+ self,
357
+ repo_id: str,
358
+ root: str | Path | None = None,
359
+ episodes: list[int] | None = None,
360
+ image_transforms: Callable | None = None,
361
+ delta_timestamps: dict[list[float]] | None = None,
362
+ tolerance_s: float = 1e-4,
363
+ revision: str | None = None,
364
+ force_cache_sync: bool = False,
365
+ download_videos: bool = True,
366
+ video_backend: str | None = None,
367
+ ):
368
+ """
369
+ 2 modes are available for instantiating this class, depending on 2 different use cases:
370
+
371
+ 1. Your dataset already exists:
372
+ - On your local disk in the 'root' folder. This is typically the case when you recorded your
373
+ dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
374
+ with 'root' will load your dataset directly from disk. This can happen while you're offline (no
375
+ internet connection).
376
+
377
+ - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
378
+ your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
379
+ the dataset from that address and load it, pending your dataset is compliant with
380
+ codebase_version v2.0. If your dataset has been created before this new format, you will be
381
+ prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at
382
+ lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
383
+
384
+
385
+ 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty
386
+ LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or port an
387
+ existing dataset to the LeRobotDataset format.
388
+
389
+
390
+ In terms of files, LeRobotDataset encapsulates 3 main things:
391
+ - metadata:
392
+ - info contains various information about the dataset like shapes, keys, fps etc.
393
+ - stats stores the dataset statistics of the different modalities for normalization
394
+ - tasks contains the prompts for each task of the dataset, which can be used for
395
+ task-conditioned training.
396
+ - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
397
+ - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
398
+
399
+ A typical LeRobotDataset looks like this from its root path:
400
+ .
401
+ ├── data
402
+ │ ├── chunk-000
403
+ │ │ ├── episode_000000.parquet
404
+ │ │ ├── episode_000001.parquet
405
+ │ │ ├── episode_000002.parquet
406
+ │ │ └── ...
407
+ │ ├── chunk-001
408
+ │ │ ├── episode_001000.parquet
409
+ │ │ ├── episode_001001.parquet
410
+ │ │ ├── episode_001002.parquet
411
+ │ │ └── ...
412
+ │ └── ...
413
+ ├── meta
414
+ │ ├── episodes.jsonl
415
+ │ ├── info.json
416
+ │ ├── stats.json
417
+ │ └── tasks.jsonl
418
+ └── videos
419
+ ├── chunk-000
420
+ │ ├── observation.images.laptop
421
+ │ │ ├── episode_000000.mp4
422
+ │ │ ├── episode_000001.mp4
423
+ │ │ ├── episode_000002.mp4
424
+ │ │ └── ...
425
+ │ ├── observation.images.phone
426
+ │ │ ├── episode_000000.mp4
427
+ │ │ ├── episode_000001.mp4
428
+ │ │ ├── episode_000002.mp4
429
+ │ │ └── ...
430
+ ├── chunk-001
431
+ └── ...
432
+
433
+ Note that this file-based structure is designed to be as versatile as possible. The files are split by
434
+ episodes which allows a more granular control over which episodes one wants to use and download. The
435
+ structure of the dataset is entirely described in the info.json file, which can be easily downloaded
436
+ or viewed directly on the hub before downloading any actual data. The type of files used are very
437
+ simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md
438
+ for the README).
439
+
440
+ Args:
441
+ repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
442
+ will be stored under root/repo_id.
443
+ root (Path | None, optional): Local directory to use for downloading/writing files. You can also
444
+ set the LEROBOT_HOME environment variable to point to a different location. Defaults to
445
+ '~/.cache/huggingface/lerobot'.
446
+ episodes (list[int] | None, optional): If specified, this will only load episodes specified by
447
+ their episode_index in this list. Defaults to None.
448
+ image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
449
+ torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
450
+ from videos or images). Defaults to None.
451
+ delta_timestamps (dict[list[float]] | None, optional): Mapping from feature key to a list of relative timestamps (in seconds) at which extra past/future frames are queried for that key in addition to the current frame; queried frames falling outside the episode are padded and flagged in the corresponding '{key}_is_pad' entry. Defaults to None.
452
+ tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
453
+ sync with the fps value. It is used at the init of the dataset to make sure that each
454
+ timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
455
+ decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
456
+ multiples of 1/fps. Defaults to 1e-4.
457
+ revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
458
+ commit hash. Defaults to current codebase version tag.
459
+ force_cache_sync (bool, optional): Flag to sync and refresh local files first. If True and files
460
+ are already present in the local cache, this will be faster. However, files loaded might not
461
+ be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
462
+ False.
463
+ download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
464
+ video files are already present on local disk, they won't be downloaded again. Defaults to
465
+ True.
466
+ video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
467
+ You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader', another Torchvision decoder.
468
+ """
469
+ super().__init__()
470
+ self.repo_id = repo_id
471
+ self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
472
+ self.image_transforms = image_transforms
473
+ self.delta_timestamps = delta_timestamps
474
+ self.episodes = episodes
475
+ self.tolerance_s = tolerance_s
476
+ self.revision = revision if revision else CODEBASE_VERSION
477
+ self.video_backend = video_backend if video_backend else get_safe_default_codec()
478
+ self.delta_indices = None
479
+
480
+ # Unused attributes
481
+ self.image_writer = None
482
+ self.episode_buffer = None
483
+
484
+ self.root.mkdir(exist_ok=True, parents=True)
485
+
486
+ # Load metadata
487
+ self.meta = LeRobotDatasetMetadata(
488
+ self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
489
+ )
490
+ if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
491
+ episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
492
+ self.stats = aggregate_stats(episodes_stats)
493
+
494
+ # Load actual data
495
+ try:
496
+ if force_cache_sync:
497
+ raise FileNotFoundError
498
+ assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
499
+ self.hf_dataset = self.load_hf_dataset()
500
+ except (AssertionError, FileNotFoundError, NotADirectoryError):
501
+ self.revision = get_safe_version(self.repo_id, self.revision)
502
+ self.download_episodes(download_videos)
503
+ self.hf_dataset = self.load_hf_dataset()
504
+
505
+ self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
506
+
507
+ # Check timestamps
508
+ timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
509
+ episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
510
+ ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
511
+ check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
512
+
513
+ # Setup delta_indices
514
+ if self.delta_timestamps is not None:
515
+ check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
516
+ self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
517
+
518
+ def push_to_hub(
519
+ self,
520
+ branch: str | None = None,
521
+ tags: list | None = None,
522
+ license: str | None = "apache-2.0",
523
+ tag_version: bool = True,
524
+ push_videos: bool = True,
525
+ private: bool = False,
526
+ allow_patterns: list[str] | str | None = None,
527
+ upload_large_folder: bool = False,
528
+ **card_kwargs,
529
+ ) -> None:
530
+ ignore_patterns = ["images/"]
531
+ if not push_videos:
532
+ ignore_patterns.append("videos/")
533
+
534
+ hub_api = HfApi()
535
+ hub_api.create_repo(
536
+ repo_id=self.repo_id,
537
+ private=private,
538
+ repo_type="dataset",
539
+ exist_ok=True,
540
+ )
541
+ if branch:
542
+ hub_api.create_branch(
543
+ repo_id=self.repo_id,
544
+ branch=branch,
545
+ revision=self.revision,
546
+ repo_type="dataset",
547
+ exist_ok=True,
548
+ )
549
+
550
+ upload_kwargs = {
551
+ "repo_id": self.repo_id,
552
+ "folder_path": self.root,
553
+ "repo_type": "dataset",
554
+ "revision": branch,
555
+ "allow_patterns": allow_patterns,
556
+ "ignore_patterns": ignore_patterns,
557
+ }
558
+ if upload_large_folder:
559
+ hub_api.upload_large_folder(**upload_kwargs)
560
+ else:
561
+ hub_api.upload_folder(**upload_kwargs)
562
+
563
+ if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
564
+ card = create_lerobot_dataset_card(
565
+ tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
566
+ )
567
+ card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
568
+
569
+ if tag_version:
570
+ with contextlib.suppress(RevisionNotFoundError):
571
+ hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
572
+ hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
573
+
574
+ def pull_from_repo(
575
+ self,
576
+ allow_patterns: list[str] | str | None = None,
577
+ ignore_patterns: list[str] | str | None = None,
578
+ ) -> None:
579
+ snapshot_download(
580
+ self.repo_id,
581
+ repo_type="dataset",
582
+ revision=self.revision,
583
+ local_dir=self.root,
584
+ allow_patterns=allow_patterns,
585
+ ignore_patterns=ignore_patterns,
586
+ )
587
+
588
+ def download_episodes(self, download_videos: bool = True) -> None:
589
+ """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
590
+ will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
591
+ dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
592
+ in 'local_dir', they won't be downloaded again.
593
+ """
594
+ # TODO(rcadene, aliberts): implement faster transfer
595
+ # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
596
+ files = None
597
+ ignore_patterns = None if download_videos else "videos/"
598
+ if self.episodes is not None:
599
+ files = self.get_episodes_file_paths()
600
+
601
+ self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
602
+
603
+ def get_episodes_file_paths(self) -> list[Path]:
604
+ episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
605
+ fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
606
+ if len(self.meta.video_keys) > 0:
607
+ video_files = [
608
+ str(self.meta.get_video_file_path(ep_idx, vid_key))
609
+ for vid_key in self.meta.video_keys
610
+ for ep_idx in episodes
611
+ ]
612
+ fpaths += video_files
613
+
614
+ return fpaths
615
+
616
+ def load_hf_dataset(self) -> datasets.Dataset:
617
+ """hf_dataset contains all the observations, states, actions, rewards, etc."""
618
+ if self.episodes is None:
619
+ path = str(self.root / "data")
620
+ hf_dataset = load_dataset("parquet", data_dir=path, split="train")
621
+ else:
622
+ files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
623
+ hf_dataset = load_dataset("parquet", data_files=files, split="train")
624
+
625
+ # TODO(aliberts): hf_dataset.set_format("torch")
626
+ hf_dataset.set_transform(hf_transform_to_torch)
627
+ return hf_dataset
628
+
629
+ def create_hf_dataset(self) -> datasets.Dataset:
630
+ features = get_hf_features_from_features(self.features)
631
+ ft_dict = {col: [] for col in features}
632
+ hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
633
+
634
+ # TODO(aliberts): hf_dataset.set_format("torch")
635
+ hf_dataset.set_transform(hf_transform_to_torch)
636
+ return hf_dataset
637
+
638
+ @property
639
+ def fps(self) -> int:
640
+ """Frames per second used during data collection."""
641
+ return self.meta.fps
642
+
643
+ @property
644
+ def num_frames(self) -> int:
645
+ """Number of frames in selected episodes."""
646
+ return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
647
+
648
+ @property
649
+ def num_episodes(self) -> int:
650
+ """Number of episodes selected."""
651
+ return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
652
+
653
+ @property
654
+ def features(self) -> dict[str, dict]:
655
+ return self.meta.features
656
+
657
+ @property
658
+ def hf_features(self) -> datasets.Features:
659
+ """Features of the hf_dataset."""
660
+ if self.hf_dataset is not None:
661
+ return self.hf_dataset.features
662
+ else:
663
+ return get_hf_features_from_features(self.features)
664
+
665
+ def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
666
+ ep_start = self.episode_data_index["from"][ep_idx]
667
+ ep_end = self.episode_data_index["to"][ep_idx]
668
+ query_indices = {
669
+ key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
670
+ for key, delta_idx in self.delta_indices.items()
671
+ }
672
+ padding = { # Pad values outside of current episode range
673
+ f"{key}_is_pad": torch.BoolTensor(
674
+ [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
675
+ )
676
+ for key, delta_idx in self.delta_indices.items()
677
+ }
678
+ return query_indices, padding
679
+
680
+ def _get_query_timestamps(
681
+ self,
682
+ current_ts: float,
683
+ query_indices: dict[str, list[int]] | None = None,
684
+ ) -> dict[str, list[float]]:
685
+ query_timestamps = {}
686
+ for key in self.meta.video_keys:
687
+ if query_indices is not None and key in query_indices:
688
+ timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
689
+ query_timestamps[key] = torch.stack(timestamps).tolist()
690
+ else:
691
+ query_timestamps[key] = [current_ts]
692
+
693
+ return query_timestamps
694
+
695
+ def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
696
+ return {
697
+ key: torch.stack(self.hf_dataset.select(q_idx)[key])
698
+ for key, q_idx in query_indices.items()
699
+ if key not in self.meta.video_keys
700
+ }
701
+
702
+ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
703
+ """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
704
+ in the main process (e.g. by using a second DataLoader with num_workers=0). It will result in a
705
+ segmentation fault. This probably happens because a memory reference to the video loader is created in
706
+ the main process and a subprocess fails to access it.
707
+ """
708
+ item = {}
709
+ for vid_key, query_ts in query_timestamps.items():
710
+ video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
711
+ frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
712
+ item[vid_key] = frames.squeeze(0)
713
+
714
+ return item
715
+
716
+ def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
717
+ for key, val in padding.items():
718
+ item[key] = torch.BoolTensor(val)
719
+ return item
720
+
721
+ def __len__(self):
722
+ return self.num_frames
723
+
724
+ def __getitem__(self, idx) -> dict:
725
+ item = self.hf_dataset[idx]
726
+ ep_idx = item["episode_index"].item()
727
+
728
+ query_indices = None
729
+ if self.delta_indices is not None:
730
+ query_indices, padding = self._get_query_indices(idx, ep_idx)
731
+ query_result = self._query_hf_dataset(query_indices)
732
+ item = {**item, **padding}
733
+ for key, val in query_result.items():
734
+ item[key] = val
735
+
736
+ if len(self.meta.video_keys) > 0:
737
+ current_ts = item["timestamp"].item()
738
+ query_timestamps = self._get_query_timestamps(current_ts, query_indices)
739
+ video_frames = self._query_videos(query_timestamps, ep_idx)
740
+ item = {**video_frames, **item}
741
+
742
+ if self.image_transforms is not None:
743
+ image_keys = self.meta.camera_keys
744
+ for cam in image_keys:
745
+ item[cam] = self.image_transforms(item[cam])
746
+
747
+ # Add task as a string
748
+ task_idx = item["task_index"].item()
749
+ item["task"] = self.meta.tasks[task_idx]
750
+
751
+ return item
752
+
753
+ def __repr__(self):
754
+ feature_keys = list(self.features)
755
+ return (
756
+ f"{self.__class__.__name__}({{\n"
757
+ f" Repository ID: '{self.repo_id}',\n"
758
+ f" Number of selected episodes: '{self.num_episodes}',\n"
759
+ f" Number of selected samples: '{self.num_frames}',\n"
760
+ f" Features: '{feature_keys}',\n"
761
+ "})"
762
+ )
763
+
764
+ def create_episode_buffer(self, episode_index: int | None = None) -> dict:
765
+ current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
766
+ ep_buffer = {}
767
+ # size and task are special cases that are not in self.features
768
+ ep_buffer["size"] = 0
769
+ ep_buffer["task"] = []
770
+ for key in self.features:
771
+ ep_buffer[key] = current_ep_idx if key == "episode_index" else []
772
+ return ep_buffer
773
+
774
+ def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
775
+ fpath = DEFAULT_IMAGE_PATH.format(
776
+ image_key=image_key, episode_index=episode_index, frame_index=frame_index
777
+ )
778
+ return self.root / fpath
779
+
780
+ def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
781
+ if self.image_writer is None:
782
+ if isinstance(image, torch.Tensor):
783
+ image = image.cpu().numpy()
784
+ write_image(image, fpath)
785
+ else:
786
+ self.image_writer.save_image(image=image, fpath=fpath)
787
+
788
+ def add_frame(self, frame: dict) -> None:
789
+ """
790
+ This function only adds the frame to the episode_buffer. Apart from images — which are written in a
791
+ temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
792
+ then needs to be called.
793
+ """
794
+ # Convert torch to numpy if needed
795
+ for name in frame:
796
+ if isinstance(frame[name], torch.Tensor):
797
+ frame[name] = frame[name].numpy()
798
+
799
+ validate_frame(frame, self.features)
800
+
801
+ if self.episode_buffer is None:
802
+ self.episode_buffer = self.create_episode_buffer()
803
+
804
+ # Automatically add frame_index and timestamp to episode buffer
805
+ frame_index = self.episode_buffer["size"]
806
+ timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
807
+ self.episode_buffer["frame_index"].append(frame_index)
808
+ self.episode_buffer["timestamp"].append(timestamp)
809
+
810
+ # Add frame features to episode_buffer
811
+ for key in frame:
812
+ if key == "task":
813
+ # Note: we associate the task in natural language to its task index during `save_episode`
814
+ self.episode_buffer["task"].append(frame["task"])
815
+ continue
816
+
817
+ if key not in self.features:
818
+ raise ValueError(
819
+ f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
820
+ )
821
+
822
+ if self.features[key]["dtype"] in ["image", "video"]:
823
+ img_path = self._get_image_file_path(
824
+ episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
825
+ )
826
+ if frame_index == 0:
827
+ img_path.parent.mkdir(parents=True, exist_ok=True)
828
+ self._save_image(frame[key], img_path)
829
+ self.episode_buffer[key].append(str(img_path))
830
+ else:
831
+ self.episode_buffer[key].append(frame[key])
832
+
833
+ self.episode_buffer["size"] += 1
834
+
835
+ def save_episode(self, episode_data: dict | None = None) -> None:
836
+ """
837
+ This will save to disk the current episode in self.episode_buffer.
838
+
839
+ Args:
840
+ episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
841
+ save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
842
+ None.
843
+ """
844
+ if not episode_data:
845
+ episode_buffer = self.episode_buffer
846
+
847
+ validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
848
+
849
+ # size and task are special cases that won't be added to hf_dataset
850
+ episode_length = episode_buffer.pop("size")
851
+ tasks = episode_buffer.pop("task")
852
+ episode_tasks = list(set(tasks))
853
+ episode_index = episode_buffer["episode_index"]
854
+
855
+ episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
856
+ episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
857
+
858
+ # Add new tasks to the tasks dictionary
859
+ for task in episode_tasks:
860
+ task_index = self.meta.get_task_index(task)
861
+ if task_index is None:
862
+ self.meta.add_task(task)
863
+
864
+ # Given tasks in natural language, find their corresponding task indices
865
+ episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
866
+
867
+ for key, ft in self.features.items():
868
+ # index, episode_index, task_index are already processed above, and image and video
869
+ # are processed separately by storing image path and frame info as meta data
870
+ if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
871
+ continue
872
+ episode_buffer[key] = np.stack(episode_buffer[key])
873
+
874
+ self._wait_image_writer()
875
+ self._save_episode_table(episode_buffer, episode_index)
876
+ ep_stats = compute_episode_stats(episode_buffer, self.features)
877
+
878
+ if len(self.meta.video_keys) > 0:
879
+ video_paths = self.encode_episode_videos(episode_index)
880
+ for key in self.meta.video_keys:
881
+ episode_buffer[key] = video_paths[key]
882
+
883
+ # `meta.save_episode` should be executed after encoding the videos
884
+ self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
885
+
886
+ ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
887
+ ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
888
+ check_timestamps_sync(
889
+ episode_buffer["timestamp"],
890
+ episode_buffer["episode_index"],
891
+ ep_data_index_np,
892
+ self.fps,
893
+ self.tolerance_s,
894
+ )
895
+
896
+ video_files = list(self.root.rglob("*.mp4"))
897
+ assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
898
+
899
+ parquet_files = list(self.root.rglob("*.parquet"))
900
+ assert len(parquet_files) == self.num_episodes
901
+
902
+ # delete images
903
+ img_dir = self.root / "images"
904
+ if img_dir.is_dir():
905
+ shutil.rmtree(self.root / "images")
906
+
907
+ if not episode_data: # Reset the buffer
908
+ self.episode_buffer = self.create_episode_buffer()
909
+
910
+ def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
911
+ episode_dict = {key: episode_buffer[key] for key in self.hf_features}
912
+ ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
913
+ ep_dataset = embed_images(ep_dataset)
914
+ self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
915
+ self.hf_dataset.set_transform(hf_transform_to_torch)
916
+ ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
917
+ ep_data_path.parent.mkdir(parents=True, exist_ok=True)
918
+ ep_dataset.to_parquet(ep_data_path)
919
+
920
+ def clear_episode_buffer(self) -> None:
921
+ episode_index = self.episode_buffer["episode_index"]
922
+ if self.image_writer is not None:
923
+ for cam_key in self.meta.camera_keys:
924
+ img_dir = self._get_image_file_path(
925
+ episode_index=episode_index, image_key=cam_key, frame_index=0
926
+ ).parent
927
+ if img_dir.is_dir():
928
+ shutil.rmtree(img_dir)
929
+
930
+ # Reset the buffer
931
+ self.episode_buffer = self.create_episode_buffer()
932
+
933
+ def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
934
+ if isinstance(self.image_writer, AsyncImageWriter):
935
+ logging.warning(
936
+ "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
937
+ )
938
+
939
+ self.image_writer = AsyncImageWriter(
940
+ num_processes=num_processes,
941
+ num_threads=num_threads,
942
+ )
943
+
944
+ def stop_image_writer(self) -> None:
945
+ """
946
+ Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
947
+ remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
948
+ """
949
+ if self.image_writer is not None:
950
+ self.image_writer.stop()
951
+ self.image_writer = None
952
+
953
+ def _wait_image_writer(self) -> None:
954
+ """Wait for asynchronous image writer to finish."""
955
+ if self.image_writer is not None:
956
+ self.image_writer.wait_until_done()
957
+
958
+ def encode_videos(self) -> None:
959
+ """
960
+ Use ffmpeg to convert frames stored as png into mp4 videos.
961
+ Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
962
+ since video encoding with ffmpeg is already using multithreading.
963
+ """
964
+ for ep_idx in range(self.meta.total_episodes):
965
+ self.encode_episode_videos(ep_idx)
966
+
967
+ def encode_episode_videos(self, episode_index: int) -> dict:
968
+ """
969
+ Use ffmpeg to convert frames stored as png into mp4 videos.
970
+ Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
971
+ since video encoding with ffmpeg is already using multithreading.
972
+ """
973
+ video_paths = {}
974
+ for key in self.meta.video_keys:
975
+ video_path = self.root / self.meta.get_video_file_path(episode_index, key)
976
+ video_paths[key] = str(video_path)
977
+ if video_path.is_file():
978
+ # Skip if video is already encoded. Could be the case when resuming data recording.
979
+ continue
980
+ img_dir = self._get_image_file_path(
981
+ episode_index=episode_index, image_key=key, frame_index=0
982
+ ).parent
983
+ encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
984
+
985
+ return video_paths
986
+
987
+ @classmethod
988
+ def create(
989
+ cls,
990
+ repo_id: str,
991
+ fps: int,
992
+ root: str | Path | None = None,
993
+ robot: Robot | None = None,
994
+ robot_type: str | None = None,
995
+ features: dict | None = None,
996
+ use_videos: bool = True,
997
+ tolerance_s: float = 1e-4,
998
+ image_writer_processes: int = 0,
999
+ image_writer_threads: int = 0,
1000
+ video_backend: str | None = None,
1001
+ ) -> "LeRobotDataset":
1002
+ """Create a LeRobot Dataset from scratch in order to record data."""
1003
+ obj = cls.__new__(cls)
1004
+ obj.meta = LeRobotDatasetMetadata.create(
1005
+ repo_id=repo_id,
1006
+ fps=fps,
1007
+ root=root,
1008
+ robot=robot,
1009
+ robot_type=robot_type,
1010
+ features=features,
1011
+ use_videos=use_videos,
1012
+ )
1013
+ obj.repo_id = obj.meta.repo_id
1014
+ obj.root = obj.meta.root
1015
+ obj.revision = None
1016
+ obj.tolerance_s = tolerance_s
1017
+ obj.image_writer = None
1018
+
1019
+ if image_writer_processes or image_writer_threads:
1020
+ obj.start_image_writer(image_writer_processes, image_writer_threads)
1021
+
1022
+ # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
1023
+ obj.episode_buffer = obj.create_episode_buffer()
1024
+
1025
+ obj.episodes = None
1026
+ obj.hf_dataset = obj.create_hf_dataset()
1027
+ obj.image_transforms = None
1028
+ obj.delta_timestamps = None
1029
+ obj.delta_indices = None
1030
+ obj.episode_data_index = None
1031
+ obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
1032
+ return obj
1033
+
1034
+
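As a rough illustration of the recording flow defined above (create -> add_frame -> save_episode), here is a minimal sketch. The repo id, the feature spec format and the frame values are hypothetical and only meant to show the expected call sequence, not an official recipe.

import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical feature spec: a 6-dim state and a 6-dim action, no cameras.
features = {
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
}

dataset = LeRobotDataset.create(
    repo_id="my_user/my_dataset",  # hypothetical repo id
    fps=30,
    features=features,
    use_videos=False,
)

for _ in range(100):  # record one episode of 100 frames
    dataset.add_frame(
        {
            "observation.state": np.zeros(6, dtype=np.float32),
            "action": np.zeros(6, dtype=np.float32),
            "task": "pick up the cube",  # natural-language task prompt
        }
    )

dataset.save_episode()   # writes the episode parquet file and updates the metadata
# dataset.push_to_hub()  # optionally upload the dataset folder to the Hub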
1035
+ class MultiLeRobotDataset(torch.utils.data.Dataset):
1036
+ """A dataset consisting of multiple underlying `LeRobotDataset`s.
1037
+
1038
+ The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
1039
+ structure of `LeRobotDataset`.
1040
+ """
1041
+
1042
+ def __init__(
1043
+ self,
1044
+ repo_ids: list[str],
1045
+ root: str | Path | None = None,
1046
+ episodes: dict | None = None,
1047
+ image_transforms: Callable | None = None,
1048
+ delta_timestamps: dict[list[float]] | None = None,
1049
+ tolerances_s: dict | None = None,
1050
+ download_videos: bool = True,
1051
+ video_backend: str | None = None,
1052
+ ):
1053
+ super().__init__()
1054
+ self.repo_ids = repo_ids
1055
+ self.root = Path(root) if root else HF_LEROBOT_HOME
1056
+ self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
1057
+ # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
1058
+ # are handled by this class.
1059
+ self._datasets = [
1060
+ LeRobotDataset(
1061
+ repo_id,
1062
+ root=self.root / repo_id,
1063
+ episodes=episodes[repo_id] if episodes else None,
1064
+ image_transforms=image_transforms,
1065
+ delta_timestamps=delta_timestamps,
1066
+ tolerance_s=self.tolerances_s[repo_id],
1067
+ download_videos=download_videos,
1068
+ video_backend=video_backend,
1069
+ )
1070
+ for repo_id in repo_ids
1071
+ ]
1072
+
1073
+ # Disable any data keys that are not common across all of the datasets. Note: we may relax this
1074
+ # restriction in future iterations of this class. For now, this is necessary at least for being able
1075
+ # to use PyTorch's default DataLoader collate function.
1076
+ self.disabled_features = set()
1077
+ intersection_features = set(self._datasets[0].features)
1078
+ for ds in self._datasets:
1079
+ intersection_features.intersection_update(ds.features)
1080
+ if len(intersection_features) == 0:
1081
+ raise RuntimeError(
1082
+ "Multiple datasets were provided but they had no keys common to all of them. "
1083
+ "The multi-dataset functionality currently only keeps common keys."
1084
+ )
1085
+ for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
1086
+ extra_keys = set(ds.features).difference(intersection_features)
1087
+ logging.warning(
1088
+ f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
1089
+ "other datasets."
1090
+ )
1091
+ self.disabled_features.update(extra_keys)
1092
+
1093
+ self.image_transforms = image_transforms
1094
+ self.delta_timestamps = delta_timestamps
1095
+ # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
1096
+ # with multiple robots of different ranges. Instead we should have one normalization
1097
+ # per robot.
1098
+ self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
1099
+
1100
+ @property
1101
+ def repo_id_to_index(self):
1102
+ """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
1103
+
1104
+ This index is incorporated as a data key in the dictionary returned by `__getitem__`.
1105
+ """
1106
+ return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
1107
+
1108
+ @property
1109
+ def repo_index_to_id(self):
1110
+ """Return the inverse mapping of repo_id_to_index."""
1111
+ return {v: k for k, v in self.repo_id_to_index.items()}
1112
+
1113
+ @property
1114
+ def fps(self) -> int:
1115
+ """Frames per second used during data collection.
1116
+
1117
+ NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
1118
+ """
1119
+ return self._datasets[0].meta.info["fps"]
1120
+
1121
+ @property
1122
+ def video(self) -> bool:
1123
+ """Returns True if this dataset loads video frames from mp4 files.
1124
+
1125
+ Returns False if it only loads images from png files.
1126
+
1127
+ NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
1128
+ """
1129
+ return self._datasets[0].meta.info.get("video", False)
1130
+
1131
+ @property
1132
+ def features(self) -> datasets.Features:
1133
+ features = {}
1134
+ for dataset in self._datasets:
1135
+ features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
1136
+ return features
1137
+
1138
+ @property
1139
+ def camera_keys(self) -> list[str]:
1140
+ """Keys to access image and video stream from cameras."""
1141
+ keys = []
1142
+ for key, feats in self.features.items():
1143
+ if isinstance(feats, (datasets.Image, VideoFrame)):
1144
+ keys.append(key)
1145
+ return keys
1146
+
1147
+ @property
1148
+ def video_frame_keys(self) -> list[str]:
1149
+ """Keys to access video frames that require decoding into images.
1150
+
1151
+ Note: It is empty if the dataset contains images only,
1152
+ or equal to `self.cameras` if the dataset contains videos only,
1153
+ or can even be a subset of `self.cameras` in the case of a mixed image/video dataset.
1154
+ """
1155
+ video_frame_keys = []
1156
+ for key, feats in self.features.items():
1157
+ if isinstance(feats, VideoFrame):
1158
+ video_frame_keys.append(key)
1159
+ return video_frame_keys
1160
+
1161
+ @property
1162
+ def num_frames(self) -> int:
1163
+ """Number of samples/frames."""
1164
+ return sum(d.num_frames for d in self._datasets)
1165
+
1166
+ @property
1167
+ def num_episodes(self) -> int:
1168
+ """Number of episodes."""
1169
+ return sum(d.num_episodes for d in self._datasets)
1170
+
1171
+ @property
1172
+ def tolerance_s(self) -> float:
1173
+ """Tolerance in seconds used to discard loaded frames when their timestamps
1174
+ are not close enough to the requested frames. It is only used when `delta_timestamps`
1175
+ is provided or when loading video frames from mp4 files.
1176
+ """
1177
+ # 1e-4 to account for possible numerical error
1178
+ return 1 / self.fps - 1e-4
1179
+
1180
+ def __len__(self):
1181
+ return self.num_frames
1182
+
1183
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
1184
+ if idx >= len(self):
1185
+ raise IndexError(f"Index {idx} out of bounds.")
1186
+ # Determine which dataset to get an item from based on the index.
1187
+ start_idx = 0
1188
+ dataset_idx = 0
1189
+ for dataset in self._datasets:
1190
+ if idx >= start_idx + dataset.num_frames:
1191
+ start_idx += dataset.num_frames
1192
+ dataset_idx += 1
1193
+ continue
1194
+ break
1195
+ else:
1196
+ raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
1197
+ item = self._datasets[dataset_idx][idx - start_idx]
1198
+ item["dataset_index"] = torch.tensor(dataset_idx)
1199
+ for data_key in self.disabled_features:
1200
+ if data_key in item:
1201
+ del item[data_key]
1202
+
1203
+ return item
1204
+
1205
+ def __repr__(self):
1206
+ return (
1207
+ f"{self.__class__.__name__}(\n"
1208
+ f" Repository IDs: '{self.repo_ids}',\n"
1209
+ f" Number of Samples: {self.num_frames},\n"
1210
+ f" Number of Episodes: {self.num_episodes},\n"
1211
+ f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
1212
+ f" Recorded Frames per Second: {self.fps},\n"
1213
+ f" Camera Keys: {self.camera_keys},\n"
1214
+ f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
1215
+ f" Transformations: {self.image_transforms},\n"
1216
+ f")"
1217
+ )
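For the loading path described in the LeRobotDataset docstring above, a typical consumption pattern with a PyTorch DataLoader might look like the following sketch. The repo id and the delta_timestamps values are illustrative (offsets must be multiples of 1/fps within tolerance_s; lerobot/pusht is recorded at 10 fps).

import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Ask for the current state plus the two previous ones (0.1 s steps at 10 fps).
delta_timestamps = {"observation.state": [-0.2, -0.1, 0.0]}

dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

batch = next(iter(loader))
# batch["observation.state"] has shape (32, 3, state_dim);
# batch["observation.state_is_pad"] flags offsets that fell outside the episode.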
lerobot/common/datasets/online_buffer.py ADDED
@@ -0,0 +1,384 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """An online buffer for the online training loop in train.py
17
+
18
+ Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
19
+ consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
20
+ faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
21
+ supports in-place slicing and mutation which is very handy for a dynamic buffer.
22
+ """
23
+
24
+ import os
25
+ from pathlib import Path
26
+ from typing import Any
27
+
28
+ import numpy as np
29
+ import torch
30
+
31
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
32
+
33
+
34
+ def _make_memmap_safe(**kwargs) -> np.memmap:
35
+ """Make a numpy memmap with checks on available disk space first.
36
+
37
+ Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape"
38
+
39
+ For information on dtypes:
40
+ https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
41
+ """
42
+ if kwargs["mode"].startswith("w"):
43
+ required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
44
+ stats = os.statvfs(Path(kwargs["filename"]).parent)
45
+ available_space = stats.f_bavail * stats.f_frsize # bytes
46
+ if required_space >= available_space * 0.8:
47
+ raise RuntimeError(
48
+ f"You're about to take up {required_space} of {available_space} bytes available."
49
+ )
50
+ return np.memmap(**kwargs)
51
+
52
+
53
+ class OnlineBuffer(torch.utils.data.Dataset):
54
+ """FIFO data buffer for the online training loop in train.py.
55
+
56
+ Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
57
+ loop in the same way that a LeRobotDataset would be used.
58
+
59
+ The underlying data structure will have data inserted in a circular fashion. Always insert after the
60
+ last index, and when you reach the end, wrap around to the start.
61
+
62
+ The data is stored in a numpy memmap.
63
+ """
64
+
65
+ NEXT_INDEX_KEY = "_next_index"
66
+ OCCUPANCY_MASK_KEY = "_occupancy_mask"
67
+ INDEX_KEY = "index"
68
+ FRAME_INDEX_KEY = "frame_index"
69
+ EPISODE_INDEX_KEY = "episode_index"
70
+ TIMESTAMP_KEY = "timestamp"
71
+ IS_PAD_POSTFIX = "_is_pad"
72
+
73
+ def __init__(
74
+ self,
75
+ write_dir: str | Path,
76
+ data_spec: dict[str, Any] | None,
77
+ buffer_capacity: int | None,
78
+ fps: float | None = None,
79
+ delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
80
+ ):
81
+ """
82
+ The online buffer can be created from scratch, or you can load an existing online buffer by passing
83
+ a `write_dir` associated with an existing buffer.
84
+
85
+ Args:
86
+ write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
87
+ Note that if the files already exist, they are opened in read-write mode (used for training
88
+ resumption).
89
+ data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
90
+ "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
91
+ but note that "index", "frame_index" and "episode_index" are already accounted for by this
92
+ class, so you don't need to include them.
93
+ buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
94
+ system's available disk space when choosing this.
95
+ fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
96
+ delta_timestamps logic. You can pass None if you are not using delta_timestamps.
97
+ delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
98
+ converted to dict[str, np.ndarray] for optimization purposes.
99
+
100
+ """
101
+ self.set_delta_timestamps(delta_timestamps)
102
+ self._fps = fps
103
+ # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough to
104
+ # the requested frames. It is only used when `delta_timestamps` is provided.
105
+ # minus 1e-4 to account for possible numerical error
106
+ self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
107
+ self._buffer_capacity = buffer_capacity
108
+ data_spec = self._make_data_spec(data_spec, buffer_capacity)
109
+ Path(write_dir).mkdir(parents=True, exist_ok=True)
110
+ self._data = {}
111
+ for k, v in data_spec.items():
112
+ self._data[k] = _make_memmap_safe(
113
+ filename=Path(write_dir) / k,
114
+ dtype=v["dtype"] if v is not None else None,
115
+ mode="r+" if (Path(write_dir) / k).exists() else "w+",
116
+ shape=tuple(v["shape"]) if v is not None else None,
117
+ )
118
+
119
+ @property
120
+ def delta_timestamps(self) -> dict[str, np.ndarray] | None:
121
+ return self._delta_timestamps
122
+
123
+ def set_delta_timestamps(self, value: dict[str, list[float]] | None):
124
+ """Set delta_timestamps converting the values to numpy arrays.
125
+
126
+ The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
127
+ need to be converted into numpy arrays.
128
+ """
129
+ if value is not None:
130
+ self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
131
+ else:
132
+ self._delta_timestamps = None
133
+
134
+ def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
135
+ """Makes the data spec for np.memmap."""
136
+ if any(k.startswith("_") for k in data_spec):
137
+ raise ValueError(
138
+ "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
139
+ )
140
+ preset_keys = {
141
+ OnlineBuffer.INDEX_KEY,
142
+ OnlineBuffer.FRAME_INDEX_KEY,
143
+ OnlineBuffer.EPISODE_INDEX_KEY,
144
+ OnlineBuffer.TIMESTAMP_KEY,
145
+ }
146
+ if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
147
+ raise ValueError(
148
+ f"data_spec should not contain any of {preset_keys} as these are handled internally. "
149
+ f"The provided data_spec has {intersection}."
150
+ )
151
+ complete_data_spec = {
152
+ # _next_index will be a pointer to the next index that we should start filling from when we add
153
+ # more data.
154
+ OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
155
+ # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
156
+ # with real data rather than the dummy initialization.
157
+ OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
158
+ OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
159
+ OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
160
+ OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
161
+ OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
162
+ }
163
+ for k, v in data_spec.items():
164
+ complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
165
+ return complete_data_spec
166
+
167
+ def add_data(self, data: dict[str, np.ndarray]):
168
+ """Add new data to the buffer, which could potentially mean shifting old data out.
169
+
170
+ The new data should contain all the frames (in order) of any number of episodes. The indices should
171
+ start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
172
+ `eval_policy` functions in `eval.py` for more information on how the data is constructed.
173
+
174
+ Shift the incoming data index and episode_index to continue on from the last frame. Note that this
175
+ will be done in place!
176
+ """
177
+ if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
178
+ raise ValueError(f"Missing data keys: {missing_keys}")
179
+ new_data_length = len(data[self.data_keys[0]])
180
+ if not all(len(data[k]) == new_data_length for k in self.data_keys):
181
+ raise ValueError("All data items should have the same length")
182
+
183
+ next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
184
+
185
+ # Sanity check to make sure that the new data indices start from 0.
186
+ assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
187
+ assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
188
+
189
+ # Shift the incoming indices if necessary.
190
+ if self.num_frames > 0:
191
+ last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
192
+ last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
193
+ data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
194
+ data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
195
+
196
+ # Insert the new data starting from next_index. It may be necessary to wrap around to the start.
197
+ n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
198
+ for k in self.data_keys:
199
+ if n_surplus == 0:
200
+ slc = slice(next_index, next_index + new_data_length)
201
+ self._data[k][slc] = data[k]
202
+ self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
203
+ else:
204
+ self._data[k][next_index:] = data[k][:-n_surplus]
205
+ self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
206
+ self._data[k][:n_surplus] = data[k][-n_surplus:]
207
+ if n_surplus == 0:
208
+ self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
209
+ else:
210
+ self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
211
+
212
+ @property
213
+ def data_keys(self) -> list[str]:
214
+ keys = set(self._data)
215
+ keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
216
+ keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
217
+ return sorted(keys)
218
+
219
+ @property
220
+ def fps(self) -> float | None:
221
+ return self._fps
222
+
223
+ @property
224
+ def num_episodes(self) -> int:
225
+ return len(
226
+ np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
227
+ )
228
+
229
+ @property
230
+ def num_frames(self) -> int:
231
+ return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
232
+
233
+ def __len__(self):
234
+ return self.num_frames
235
+
236
+ def _item_to_tensors(self, item: dict) -> dict:
237
+ item_ = {}
238
+ for k, v in item.items():
239
+ if isinstance(v, torch.Tensor):
240
+ item_[k] = v
241
+ elif isinstance(v, np.ndarray):
242
+ item_[k] = torch.from_numpy(v)
243
+ else:
244
+ item_[k] = torch.tensor(v)
245
+ return item_
246
+
247
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
248
+ if idx >= len(self) or idx < -len(self):
249
+ raise IndexError
250
+
251
+ item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
252
+
253
+ if self.delta_timestamps is None:
254
+ return self._item_to_tensors(item)
255
+
256
+ episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
257
+ current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
258
+ episode_data_indices = np.where(
259
+ np.bitwise_and(
260
+ self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
261
+ self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
262
+ )
263
+ )[0]
264
+ episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
265
+
266
+ for data_key in self.delta_timestamps:
267
+ # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
268
+ # Get timestamps used as query to retrieve data of previous/future frames.
269
+ query_ts = current_ts + self.delta_timestamps[data_key]
270
+
271
+ # Compute distances between each query timestamp and all timestamps of all the frames belonging to
272
+ # the episode.
273
+ dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
274
+ argmin_ = np.argmin(dist, axis=1)
275
+ min_ = dist[np.arange(dist.shape[0]), argmin_]
276
+
277
+ is_pad = min_ > self.tolerance_s
278
+
279
+ # Check violated query timestamps are all outside the episode range.
280
+ assert (
281
+ (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
282
+ ).all(), (
283
+ f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
284
+ ") inside the episode range."
285
+ )
286
+
287
+ # Load frames for this data key.
288
+ item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
289
+
290
+ item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
291
+
292
+ return self._item_to_tensors(item)
293
+
294
+ def get_data_by_key(self, key: str) -> torch.Tensor:
295
+ """Returns all data for a given data key as a Tensor."""
296
+ return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
297
+
298
+
299
+ def compute_sampler_weights(
300
+ offline_dataset: LeRobotDataset,
301
+ offline_drop_n_last_frames: int = 0,
302
+ online_dataset: OnlineBuffer | None = None,
303
+ online_sampling_ratio: float | None = None,
304
+ online_drop_n_last_frames: int = 0,
305
+ ) -> torch.Tensor:
306
+ """Compute the sampling weights for the online training dataloader in train.py.
307
+
308
+ Args:
309
+ offline_dataset: The LeRobotDataset used for offline pre-training.
310
+ offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
311
+ online_dataset: The OnlineBuffer used in online training.
312
+ online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
313
+ online dataset is provided, this value must also be provided.
314
+ online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
315
+ dataset.
316
+ Returns:
317
+ Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
318
+
319
+ Notes to maintainers:
320
+ - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
321
+ - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
322
+ `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
323
+ is the ability to turn shuffling off.
324
+ - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
325
+ included here to avoid adding complexity.
326
+ """
327
+ if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
328
+ raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.")
329
+ if (online_dataset is None) ^ (online_sampling_ratio is None):
330
+ raise ValueError(
331
+ "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
332
+ )
333
+ offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
334
+
335
+ weights = []
336
+
337
+ if len(offline_dataset) > 0:
338
+ offline_data_mask_indices = []
339
+ for start_index, end_index in zip(
340
+ offline_dataset.episode_data_index["from"],
341
+ offline_dataset.episode_data_index["to"],
342
+ strict=True,
343
+ ):
344
+ offline_data_mask_indices.extend(
345
+ range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
346
+ )
347
+ offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
348
+ offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
349
+ weights.append(
350
+ torch.full(
351
+ size=(len(offline_dataset),),
352
+ fill_value=offline_sampling_ratio / offline_data_mask.sum(),
353
+ )
354
+ * offline_data_mask
355
+ )
356
+
357
+ if online_dataset is not None and len(online_dataset) > 0:
358
+ online_data_mask_indices = []
359
+ episode_indices = online_dataset.get_data_by_key("episode_index")
360
+ for episode_idx in torch.unique(episode_indices):
361
+ where_episode = torch.where(episode_indices == episode_idx)
362
+ start_index = where_episode[0][0]
363
+ end_index = where_episode[0][-1] + 1
364
+ online_data_mask_indices.extend(
365
+ range(start_index.item(), end_index.item() - online_drop_n_last_frames)
366
+ )
367
+ online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
368
+ online_data_mask[torch.tensor(online_data_mask_indices)] = True
369
+ weights.append(
370
+ torch.full(
371
+ size=(len(online_dataset),),
372
+ fill_value=online_sampling_ratio / online_data_mask.sum(),
373
+ )
374
+ * online_data_mask
375
+ )
376
+
377
+ weights = torch.cat(weights)
378
+
379
+ if weights.sum() == 0:
380
+ weights += 1 / len(weights)
381
+ else:
382
+ weights /= weights.sum()
383
+
384
+ return weights
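To tie the pieces of this module together, here is a sketch of how an OnlineBuffer could be combined with compute_sampler_weights in an online training loop. The write directory, data_spec shapes and sampling ratio are assumptions made for the example, not values taken from this file.

import numpy as np
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights

offline_dataset = LeRobotDataset("lerobot/pusht")  # offline pre-training data

buffer = OnlineBuffer(
    write_dir="/tmp/online_buffer",  # hypothetical location for the memmap files
    data_spec={"observation.state": {"dtype": np.dtype("float32"), "shape": (2,)}},
    buffer_capacity=10_000,
    fps=10,
)

# Add one 5-frame episode coming from a rollout; indices must start at 0.
n = 5
buffer.add_data(
    {
        "index": np.arange(n),
        "frame_index": np.arange(n),
        "episode_index": np.zeros(n, dtype=np.int64),
        "timestamp": np.arange(n) / 10,
        "observation.state": np.zeros((n, 2), dtype=np.float32),
    }
)

weights = compute_sampler_weights(offline_dataset, online_dataset=buffer, online_sampling_ratio=0.5)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=64, replacement=True)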
lerobot/common/datasets/push_dataset_to_hub/utils.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import inspect
17
+ from concurrent.futures import ThreadPoolExecutor
18
+ from pathlib import Path
19
+ from typing import Dict
20
+
21
+ import datasets
22
+ import numpy
23
+ import PIL
24
+ import torch
25
+
26
+ from lerobot.common.datasets.video_utils import encode_video_frames
27
+
28
+
29
+ def concatenate_episodes(ep_dicts):
30
+ data_dict = {}
31
+
32
+ keys = ep_dicts[0].keys()
33
+ for key in keys:
34
+ if torch.is_tensor(ep_dicts[0][key][0]):
35
+ data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
36
+ else:
37
+ if key not in data_dict:
38
+ data_dict[key] = []
39
+ for ep_dict in ep_dicts:
40
+ for x in ep_dict[key]:
41
+ data_dict[key].append(x)
42
+
43
+ total_frames = data_dict["frame_index"].shape[0]
44
+ data_dict["index"] = torch.arange(0, total_frames, 1)
45
+ return data_dict
46
+
47
+
48
+ def save_images_concurrently(imgs_array: numpy.ndarray, out_dir: Path, max_workers: int = 4):
49
+ out_dir = Path(out_dir)
50
+ out_dir.mkdir(parents=True, exist_ok=True)
51
+
52
+ def save_image(img_array, i, out_dir):
53
+ img = PIL.Image.fromarray(img_array)
54
+ img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
55
+
56
+ num_images = len(imgs_array)
57
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
58
+ [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
59
+
60
+
61
+ def get_default_encoding() -> dict:
62
+ """Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
63
+ signature = inspect.signature(encode_video_frames)
64
+ return {
65
+ k: v.default
66
+ for k, v in signature.parameters.items()
67
+ if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
68
+ }
69
+
70
+
71
+ def check_repo_id(repo_id: str) -> None:
72
+ if len(repo_id.split("/")) != 2:
73
+ raise ValueError(
74
+ f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
75
+ (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
76
+ )
77
+
78
+
79
+ # TODO(aliberts): remove
80
+ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
81
+ """
82
+ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
83
+
84
+ Parameters:
85
+ - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
86
+
87
+ Returns:
88
+ - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
89
+ - "from": A tensor containing the starting index of each episode.
90
+ - "to": A tensor containing the ending index of each episode.
91
+ """
92
+ episode_data_index = {"from": [], "to": []}
93
+
94
+ current_episode = None
95
+ """
96
+ The episode_index is a list of integers, each representing the episode index of the corresponding example.
97
+ For instance, the following is a valid episode_index:
98
+ [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
99
+
100
+ Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
101
+ ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
102
+ {
103
+ "from": [0, 3, 7],
104
+ "to": [3, 7, 12]
105
+ }
106
+ """
107
+ if len(hf_dataset) == 0:
108
+ episode_data_index = {
109
+ "from": torch.tensor([]),
110
+ "to": torch.tensor([]),
111
+ }
112
+ return episode_data_index
113
+ for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
114
+ if episode_idx != current_episode:
115
+ # We encountered a new episode, so we append its starting location to the "from" list
116
+ episode_data_index["from"].append(idx)
117
+ # If this is not the first episode, we append the ending location of the previous episode to the "to" list
118
+ if current_episode is not None:
119
+ episode_data_index["to"].append(idx)
120
+ # Let's keep track of the current episode index
121
+ current_episode = episode_idx
122
+ else:
123
+ # We are still in the same episode, so there is nothing for us to do here
124
+ pass
125
+ # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
126
+ episode_data_index["to"].append(idx + 1)
127
+
128
+ for k in ["from", "to"]:
129
+ episode_data_index[k] = torch.tensor(episode_data_index[k])
130
+
131
+ return episode_data_index
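As a small worked example of calculate_episode_data_index, mirroring the episode_index example given in its docstring (the toy dataset below is constructed only for illustration):

import datasets
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index

hf_dataset = datasets.Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]})
index = calculate_episode_data_index(hf_dataset)
print(index["from"])  # tensor([0, 3, 7])
print(index["to"])    # tensor([3, 7, 12])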
lerobot/common/datasets/sampler.py ADDED
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Iterator, Union
17
+
18
+ import torch
19
+
20
+
21
+ class EpisodeAwareSampler:
22
+ def __init__(
23
+ self,
24
+ episode_data_index: dict,
25
+ episode_indices_to_use: Union[list, None] = None,
26
+ drop_n_first_frames: int = 0,
27
+ drop_n_last_frames: int = 0,
28
+ shuffle: bool = False,
29
+ ):
30
+ """Sampler that optionally incorporates episode boundary information.
31
+
32
+ Args:
33
+ episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
34
+ episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
35
+ Assumes that episodes are indexed from 0 to N-1.
36
+ drop_n_first_frames: Number of frames to drop from the start of each episode.
37
+ drop_n_last_frames: Number of frames to drop from the end of each episode.
38
+ shuffle: Whether to shuffle the indices.
39
+ """
40
+ indices = []
41
+ for episode_idx, (start_index, end_index) in enumerate(
42
+ zip(episode_data_index["from"], episode_data_index["to"], strict=True)
43
+ ):
44
+ if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
45
+ indices.extend(
46
+ range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
47
+ )
48
+
49
+ self.indices = indices
50
+ self.shuffle = shuffle
51
+
52
+ def __iter__(self) -> Iterator[int]:
53
+ if self.shuffle:
54
+ for i in torch.randperm(len(self.indices)):
55
+ yield self.indices[i]
56
+ else:
57
+ for i in self.indices:
58
+ yield i
59
+
60
+ def __len__(self) -> int:
61
+ return len(self.indices)
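A short usage sketch (illustrative, not part of the commit), assuming the class is imported from `lerobot.common.datasets.sampler` and plugged into a standard PyTorch `DataLoader`; the `TensorDataset` is a stand-in for a real dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lerobot.common.datasets.sampler import EpisodeAwareSampler

# Two episodes covering frames [0, 10) and [10, 25).
episode_data_index = {"from": torch.tensor([0, 10]), "to": torch.tensor([10, 25])}

sampler = EpisodeAwareSampler(
    episode_data_index,
    drop_n_last_frames=2,  # e.g. when training targets need future frames
    shuffle=True,
)

dataset = TensorDataset(torch.arange(25))  # stand-in for a LeRobotDataset
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for (frames,) in loader:
    pass  # indices 8, 9, 23 and 24 are never sampled because of drop_n_last_frames
```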
lerobot/common/datasets/transforms.py ADDED
@@ -0,0 +1,249 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import collections
17
+ from dataclasses import dataclass, field
18
+ from typing import Any, Callable, Sequence
19
+
20
+ import torch
21
+ from torchvision.transforms import v2
22
+ from torchvision.transforms.v2 import Transform
23
+ from torchvision.transforms.v2 import functional as F # noqa: N812
24
+
25
+
26
+ class RandomSubsetApply(Transform):
27
+ """Apply a random subset of N transformations from a list of transformations.
28
+
29
+ Args:
30
+ transforms: list of transformations.
31
+ p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
32
+ If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
33
+ have the same probability.
34
+ n_subset: number of transformations to apply. If ``None``, all transforms are applied.
35
+ Must be in [1, len(transforms)].
36
+ random_order: apply transformations in a random order.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ transforms: Sequence[Callable],
42
+ p: list[float] | None = None,
43
+ n_subset: int | None = None,
44
+ random_order: bool = False,
45
+ ) -> None:
46
+ super().__init__()
47
+ if not isinstance(transforms, Sequence):
48
+ raise TypeError("Argument transforms should be a sequence of callables")
49
+ if p is None:
50
+ p = [1] * len(transforms)
51
+ elif len(p) != len(transforms):
52
+ raise ValueError(
53
+ f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
54
+ )
55
+
56
+ if n_subset is None:
57
+ n_subset = len(transforms)
58
+ elif not isinstance(n_subset, int):
59
+ raise TypeError("n_subset should be an int or None")
60
+ elif not (1 <= n_subset <= len(transforms)):
61
+ raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
62
+
63
+ self.transforms = transforms
64
+ total = sum(p)
65
+ self.p = [prob / total for prob in p]
66
+ self.n_subset = n_subset
67
+ self.random_order = random_order
68
+
69
+ self.selected_transforms = None
70
+
71
+ def forward(self, *inputs: Any) -> Any:
72
+ needs_unpacking = len(inputs) > 1
73
+
74
+ selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
75
+ if not self.random_order:
76
+ selected_indices = selected_indices.sort().values
77
+
78
+ self.selected_transforms = [self.transforms[i] for i in selected_indices]
79
+
80
+ for transform in self.selected_transforms:
81
+ outputs = transform(*inputs)
82
+ inputs = outputs if needs_unpacking else (outputs,)
83
+
84
+ return outputs
85
+
86
+ def extra_repr(self) -> str:
87
+ return (
88
+ f"transforms={self.transforms}, "
89
+ f"p={self.p}, "
90
+ f"n_subset={self.n_subset}, "
91
+ f"random_order={self.random_order}"
92
+ )
93
+
94
+
95
+ class SharpnessJitter(Transform):
96
+ """Randomly change the sharpness of an image or video.
97
+
98
+ Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
99
+ While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
100
+ SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
101
+ augmentations as a result.
102
+
103
+ A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
104
+ by a factor of 2.
105
+
106
+ If the input is a :class:`torch.Tensor`,
107
+ it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
108
+
109
+ Args:
110
+ sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
111
+ [max(0, 1 - sharpness), 1 + sharpness] or the given
112
+ [min, max]. Should be non negative numbers.
113
+ """
114
+
115
+ def __init__(self, sharpness: float | Sequence[float]) -> None:
116
+ super().__init__()
117
+ self.sharpness = self._check_input(sharpness)
118
+
119
+ def _check_input(self, sharpness):
120
+ if isinstance(sharpness, (int, float)):
121
+ if sharpness < 0:
122
+ raise ValueError("If sharpness is a single number, it must be non negative.")
123
+ sharpness = [1.0 - sharpness, 1.0 + sharpness]
124
+ sharpness[0] = max(sharpness[0], 0.0)
125
+ elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
126
+ sharpness = [float(v) for v in sharpness]
127
+ else:
128
+ raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
129
+
130
+ if not 0.0 <= sharpness[0] <= sharpness[1]:
131
+ raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")
132
+
133
+ return float(sharpness[0]), float(sharpness[1])
134
+
135
+ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
136
+ sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
137
+ return {"sharpness_factor": sharpness_factor}
138
+
139
+ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
140
+ sharpness_factor = params["sharpness_factor"]
141
+ return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
142
+
143
+
144
+ @dataclass
145
+ class ImageTransformConfig:
146
+ """
147
+ For each transform, the following parameters are available:
148
+ weight: This represents the multinomial probability (with no replacement)
149
+ used for sampling the transform. If the sum of the weights is not 1,
150
+ they will be normalized.
151
+ type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
152
+ custom transform defined here.
153
+ kwargs: Lower & upper bound respectively used for sampling the transform's parameter
154
+ (following uniform distribution) when it's applied.
155
+ """
156
+
157
+ weight: float = 1.0
158
+ type: str = "Identity"
159
+ kwargs: dict[str, Any] = field(default_factory=dict)
160
+
161
+
162
+ @dataclass
163
+ class ImageTransformsConfig:
164
+ """
165
+ These transforms are all using standard torchvision.transforms.v2
166
+ You can find out how these transformations affect images here:
167
+ https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
168
+ We use a custom RandomSubsetApply container to sample them.
169
+ """
170
+
171
+ # Set this flag to `true` to enable transforms during training
172
+ enable: bool = False
173
+ # This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
174
+ # It's an integer in the interval [1, number_of_available_transforms].
175
+ max_num_transforms: int = 3
176
+ # By default, transforms are applied in Torchvision's suggested order (shown below).
177
+ # Set this to True to apply them in a random order.
178
+ random_order: bool = False
179
+ tfs: dict[str, ImageTransformConfig] = field(
180
+ default_factory=lambda: {
181
+ "brightness": ImageTransformConfig(
182
+ weight=1.0,
183
+ type="ColorJitter",
184
+ kwargs={"brightness": (0.8, 1.2)},
185
+ ),
186
+ "contrast": ImageTransformConfig(
187
+ weight=1.0,
188
+ type="ColorJitter",
189
+ kwargs={"contrast": (0.8, 1.2)},
190
+ ),
191
+ "saturation": ImageTransformConfig(
192
+ weight=1.0,
193
+ type="ColorJitter",
194
+ kwargs={"saturation": (0.5, 1.5)},
195
+ ),
196
+ "hue": ImageTransformConfig(
197
+ weight=1.0,
198
+ type="ColorJitter",
199
+ kwargs={"hue": (-0.05, 0.05)},
200
+ ),
201
+ "sharpness": ImageTransformConfig(
202
+ weight=1.0,
203
+ type="SharpnessJitter",
204
+ kwargs={"sharpness": (0.5, 1.5)},
205
+ ),
206
+ }
207
+ )
208
+
209
+
210
+ def make_transform_from_config(cfg: ImageTransformConfig):
211
+ if cfg.type == "Identity":
212
+ return v2.Identity(**cfg.kwargs)
213
+ elif cfg.type == "ColorJitter":
214
+ return v2.ColorJitter(**cfg.kwargs)
215
+ elif cfg.type == "SharpnessJitter":
216
+ return SharpnessJitter(**cfg.kwargs)
217
+ else:
218
+ raise ValueError(f"Transform '{cfg.type}' is not valid.")
219
+
220
+
221
+ class ImageTransforms(Transform):
222
+ """A class to compose image transforms based on configuration."""
223
+
224
+ def __init__(self, cfg: ImageTransformsConfig) -> None:
225
+ super().__init__()
226
+ self._cfg = cfg
227
+
228
+ self.weights = []
229
+ self.transforms = {}
230
+ for tf_name, tf_cfg in cfg.tfs.items():
231
+ if tf_cfg.weight <= 0.0:
232
+ continue
233
+
234
+ self.transforms[tf_name] = make_transform_from_config(tf_cfg)
235
+ self.weights.append(tf_cfg.weight)
236
+
237
+ n_subset = min(len(self.transforms), cfg.max_num_transforms)
238
+ if n_subset == 0 or not cfg.enable:
239
+ self.tf = v2.Identity()
240
+ else:
241
+ self.tf = RandomSubsetApply(
242
+ transforms=list(self.transforms.values()),
243
+ p=self.weights,
244
+ n_subset=n_subset,
245
+ random_order=cfg.random_order,
246
+ )
247
+
248
+ def forward(self, *inputs: Any) -> Any:
249
+ return self.tf(*inputs)
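A quick end-to-end sketch (not part of the commit) of how the config and the composed transform fit together, assuming imports from `lerobot.common.datasets.transforms`:

```python
import torch

from lerobot.common.datasets.transforms import ImageTransforms, ImageTransformsConfig

# Enable augmentations and apply at most 2 of the 5 default transforms per call.
cfg = ImageTransformsConfig(enable=True, max_num_transforms=2, random_order=True)
tf = ImageTransforms(cfg)

frame = torch.rand(3, 224, 224)  # float image in [0, 1], channel-first
augmented = tf(frame)
assert augmented.shape == frame.shape
```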
lerobot/common/datasets/utils.py ADDED
@@ -0,0 +1,813 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import importlib.resources
18
+ import json
19
+ import logging
20
+ from collections.abc import Iterator
21
+ from itertools import accumulate
22
+ from pathlib import Path
23
+ from pprint import pformat
24
+ from types import SimpleNamespace
25
+ from typing import Any
26
+
27
+ import datasets
28
+ import jsonlines
29
+ import numpy as np
30
+ import packaging.version
31
+ import torch
32
+ from datasets.table import embed_table_storage
33
+ from huggingface_hub import DatasetCard, DatasetCardData, HfApi
34
+ from huggingface_hub.errors import RevisionNotFoundError
35
+ from PIL import Image as PILImage
36
+ from torchvision import transforms
37
+
38
+ from lerobot.common.datasets.backward_compatibility import (
39
+ V21_MESSAGE,
40
+ BackwardCompatibilityError,
41
+ ForwardCompatibilityError,
42
+ )
43
+ from lerobot.common.robot_devices.robots.utils import Robot
44
+ from lerobot.common.utils.utils import is_valid_numpy_dtype_string
45
+ from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
46
+
47
+ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
48
+
49
+ INFO_PATH = "meta/info.json"
50
+ EPISODES_PATH = "meta/episodes.jsonl"
51
+ STATS_PATH = "meta/stats.json"
52
+ EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
53
+ TASKS_PATH = "meta/tasks.jsonl"
54
+
55
+ DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
56
+ DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
57
+ DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
58
+
59
+ DATASET_CARD_TEMPLATE = """
60
+ ---
61
+ # Metadata will go there
62
+ ---
63
+ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
64
+
65
+ ## {}
66
+
67
+ """
68
+
69
+ DEFAULT_FEATURES = {
70
+ "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
71
+ "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
72
+ "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
73
+ "index": {"dtype": "int64", "shape": (1,), "names": None},
74
+ "task_index": {"dtype": "int64", "shape": (1,), "names": None},
75
+ }
76
+
77
+
78
+ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
79
+ """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
80
+
81
+ For example:
82
+ ```
83
+ >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
84
+ >>> print(flatten_dict(dct))
85
+ {"a/b": 1, "a/c/d": 2, "e": 3}
86
+ """
87
+ items = []
88
+ for k, v in d.items():
89
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
90
+ if isinstance(v, dict):
91
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
92
+ else:
93
+ items.append((new_key, v))
94
+ return dict(items)
95
+
96
+
97
+ def unflatten_dict(d: dict, sep: str = "/") -> dict:
98
+ outdict = {}
99
+ for key, value in d.items():
100
+ parts = key.split(sep)
101
+ d = outdict
102
+ for part in parts[:-1]:
103
+ if part not in d:
104
+ d[part] = {}
105
+ d = d[part]
106
+ d[parts[-1]] = value
107
+ return outdict
108
+
109
+
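As a quick illustration (not part of the commit), the two helpers above are inverses of each other as long as keys do not already contain the separator:

```python
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict

nested = {"observation": {"state": {"mean": 0.0, "std": 1.0}}, "fps": 30}
flat = flatten_dict(nested)
# {'observation/state/mean': 0.0, 'observation/state/std': 1.0, 'fps': 30}
assert unflatten_dict(flat) == nested
```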
110
+ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
111
+ split_keys = flattened_key.split(sep)
112
+ getter = obj[split_keys[0]]
113
+ if len(split_keys) == 1:
114
+ return getter
115
+
116
+ for key in split_keys[1:]:
117
+ getter = getter[key]
118
+
119
+ return getter
120
+
121
+
122
+ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
123
+ serialized_dict = {}
124
+ for key, value in flatten_dict(stats).items():
125
+ if isinstance(value, (torch.Tensor, np.ndarray)):
126
+ serialized_dict[key] = value.tolist()
127
+ elif isinstance(value, np.generic):
128
+ serialized_dict[key] = value.item()
129
+ elif isinstance(value, (int, float)):
130
+ serialized_dict[key] = value
131
+ else:
132
+ raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
133
+ return unflatten_dict(serialized_dict)
134
+
135
+
136
+ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
137
+ # Embed image bytes into the table before saving to parquet
138
+ format = dataset.format
139
+ dataset = dataset.with_format("arrow")
140
+ dataset = dataset.map(embed_table_storage, batched=False)
141
+ dataset = dataset.with_format(**format)
142
+ return dataset
143
+
144
+
145
+ def load_json(fpath: Path) -> Any:
146
+ with open(fpath) as f:
147
+ return json.load(f)
148
+
149
+
150
+ def write_json(data: dict, fpath: Path) -> None:
151
+ fpath.parent.mkdir(exist_ok=True, parents=True)
152
+ with open(fpath, "w") as f:
153
+ json.dump(data, f, indent=4, ensure_ascii=False)
154
+
155
+
156
+ def load_jsonlines(fpath: Path) -> list[Any]:
157
+ with jsonlines.open(fpath, "r") as reader:
158
+ return list(reader)
159
+
160
+
161
+ def write_jsonlines(data: dict, fpath: Path) -> None:
162
+ fpath.parent.mkdir(exist_ok=True, parents=True)
163
+ with jsonlines.open(fpath, "w") as writer:
164
+ writer.write_all(data)
165
+
166
+
167
+ def append_jsonlines(data: dict, fpath: Path) -> None:
168
+ fpath.parent.mkdir(exist_ok=True, parents=True)
169
+ with jsonlines.open(fpath, "a") as writer:
170
+ writer.write(data)
171
+
172
+
173
+ def write_info(info: dict, local_dir: Path):
174
+ write_json(info, local_dir / INFO_PATH)
175
+
176
+
177
+ def load_info(local_dir: Path) -> dict:
178
+ info = load_json(local_dir / INFO_PATH)
179
+ for ft in info["features"].values():
180
+ ft["shape"] = tuple(ft["shape"])
181
+ return info
182
+
183
+
184
+ def write_stats(stats: dict, local_dir: Path):
185
+ serialized_stats = serialize_dict(stats)
186
+ write_json(serialized_stats, local_dir / STATS_PATH)
187
+
188
+
189
+ def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
190
+ stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
191
+ return unflatten_dict(stats)
192
+
193
+
194
+ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
195
+ if not (local_dir / STATS_PATH).exists():
196
+ return None
197
+ stats = load_json(local_dir / STATS_PATH)
198
+ return cast_stats_to_numpy(stats)
199
+
200
+
201
+ def write_task(task_index: int, task: dict, local_dir: Path):
202
+ task_dict = {
203
+ "task_index": task_index,
204
+ "task": task,
205
+ }
206
+ append_jsonlines(task_dict, local_dir / TASKS_PATH)
207
+
208
+
209
+ def load_tasks(local_dir: Path) -> tuple[dict, dict]:
210
+ tasks = load_jsonlines(local_dir / TASKS_PATH)
211
+ tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
212
+ task_to_task_index = {task: task_index for task_index, task in tasks.items()}
213
+ return tasks, task_to_task_index
214
+
215
+
216
+ def write_episode(episode: dict, local_dir: Path):
217
+ append_jsonlines(episode, local_dir / EPISODES_PATH)
218
+
219
+
220
+ def load_episodes(local_dir: Path) -> dict:
221
+ episodes = load_jsonlines(local_dir / EPISODES_PATH)
222
+ return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
223
+
224
+
225
+ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
226
+ # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
227
+ # is a dictionary of stats and not an integer.
228
+ episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
229
+ append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
230
+
231
+
232
+ def load_episodes_stats(local_dir: Path) -> dict:
233
+ episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
234
+ return {
235
+ item["episode_index"]: cast_stats_to_numpy(item["stats"])
236
+ for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
237
+ }
238
+
239
+
240
+ def backward_compatible_episodes_stats(
241
+ stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
242
+ ) -> dict[str, dict[str, np.ndarray]]:
243
+ return dict.fromkeys(episodes, stats)
244
+
245
+
246
+ def load_image_as_numpy(
247
+ fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
248
+ ) -> np.ndarray:
249
+ img = PILImage.open(fpath).convert("RGB")
250
+ img_array = np.array(img, dtype=dtype)
251
+ if channel_first: # (H, W, C) -> (C, H, W)
252
+ img_array = np.transpose(img_array, (2, 0, 1))
253
+ if np.issubdtype(dtype, np.floating):
254
+ img_array /= 255.0
255
+ return img_array
256
+
257
+
258
+ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
259
+ """Get a transform function that converts items from a Hugging Face dataset (pyarrow)
260
+ to torch tensors. Importantly, images are converted from PIL, which corresponds to
261
+ a channel last representation (h w c) of uint8 type, to a torch image representation
262
+ with channel first (c h w) of float32 type in range [0,1].
263
+ """
264
+ for key in items_dict:
265
+ first_item = items_dict[key][0]
266
+ if isinstance(first_item, PILImage.Image):
267
+ to_tensor = transforms.ToTensor()
268
+ items_dict[key] = [to_tensor(img) for img in items_dict[key]]
269
+ elif first_item is None:
270
+ pass
271
+ else:
272
+ items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
273
+ return items_dict
274
+
275
+
276
+ def is_valid_version(version: str) -> bool:
277
+ try:
278
+ packaging.version.parse(version)
279
+ return True
280
+ except packaging.version.InvalidVersion:
281
+ return False
282
+
283
+
284
+ def check_version_compatibility(
285
+ repo_id: str,
286
+ version_to_check: str | packaging.version.Version,
287
+ current_version: str | packaging.version.Version,
288
+ enforce_breaking_major: bool = True,
289
+ ) -> None:
290
+ v_check = (
291
+ packaging.version.parse(version_to_check)
292
+ if not isinstance(version_to_check, packaging.version.Version)
293
+ else version_to_check
294
+ )
295
+ v_current = (
296
+ packaging.version.parse(current_version)
297
+ if not isinstance(current_version, packaging.version.Version)
298
+ else current_version
299
+ )
300
+ if v_check.major < v_current.major and enforce_breaking_major:
301
+ raise BackwardCompatibilityError(repo_id, v_check)
302
+ elif v_check.minor < v_current.minor:
303
+ logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
304
+
305
+
306
+ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
307
+ """Returns available valid versions (branches and tags) on given repo."""
308
+ api = HfApi()
309
+ repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
310
+ repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
311
+ repo_versions = []
312
+ for ref in repo_refs:
313
+ with contextlib.suppress(packaging.version.InvalidVersion):
314
+ repo_versions.append(packaging.version.parse(ref))
315
+
316
+ return repo_versions
317
+
318
+
319
+ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
320
+ """
321
+ Returns the version if available on repo or the latest compatible one.
322
+ Otherwise, will throw a `CompatibilityError`.
323
+ """
324
+ target_version = (
325
+ packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
326
+ )
327
+ hub_versions = get_repo_versions(repo_id)
328
+
329
+ if not hub_versions:
330
+ raise RevisionNotFoundError(
331
+ f"""Your dataset must be tagged with a codebase version.
332
+ Assuming _version_ is the codebase_version value in the info.json, you can run this:
333
+ ```python
334
+ from huggingface_hub import HfApi
335
+
336
+ hub_api = HfApi()
337
+ hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
338
+ ```
339
+ """
340
+ )
341
+
342
+ if target_version in hub_versions:
343
+ return f"v{target_version}"
344
+
345
+ compatibles = [
346
+ v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
347
+ ]
348
+ if compatibles:
349
+ return_version = max(compatibles)
350
+ if return_version < target_version:
351
+ logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
352
+ return f"v{return_version}"
353
+
354
+ lower_major = [v for v in hub_versions if v.major < target_version.major]
355
+ if lower_major:
356
+ raise BackwardCompatibilityError(repo_id, max(lower_major))
357
+
358
+ upper_versions = [v for v in hub_versions if v > target_version]
359
+ assert len(upper_versions) > 0
360
+ raise ForwardCompatibilityError(repo_id, min(upper_versions))
361
+
362
+
363
+ def get_hf_features_from_features(features: dict) -> datasets.Features:
364
+ hf_features = {}
365
+ for key, ft in features.items():
366
+ if ft["dtype"] == "video":
367
+ continue
368
+ elif ft["dtype"] == "image":
369
+ hf_features[key] = datasets.Image()
370
+ elif ft["shape"] == (1,):
371
+ hf_features[key] = datasets.Value(dtype=ft["dtype"])
372
+ elif len(ft["shape"]) == 1:
373
+ hf_features[key] = datasets.Sequence(
374
+ length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
375
+ )
376
+ elif len(ft["shape"]) == 2:
377
+ hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
378
+ elif len(ft["shape"]) == 3:
379
+ hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
380
+ elif len(ft["shape"]) == 4:
381
+ hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
382
+ elif len(ft["shape"]) == 5:
383
+ hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
384
+ else:
385
+ raise ValueError(f"Corresponding feature is not valid: {ft}")
386
+
387
+ return datasets.Features(hf_features)
388
+
389
+
390
+ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
391
+ camera_ft = {}
392
+ if robot.cameras:
393
+ camera_ft = {
394
+ key: {"dtype": "video" if use_videos else "image", **ft}
395
+ for key, ft in robot.camera_features.items()
396
+ }
397
+ return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
398
+
399
+
400
+ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
401
+ # TODO(aliberts): Implement "type" in dataset features and simplify this
402
+ policy_features = {}
403
+ for key, ft in features.items():
404
+ shape = ft["shape"]
405
+ if ft["dtype"] in ["image", "video"]:
406
+ type = FeatureType.VISUAL
407
+ if len(shape) != 3:
408
+ raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
409
+
410
+ names = ft["names"]
411
+ # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
412
+ if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
413
+ shape = (shape[2], shape[0], shape[1])
414
+ elif key == "observation.environment_state":
415
+ type = FeatureType.ENV
416
+ elif key.startswith("observation"):
417
+ type = FeatureType.STATE
418
+ elif key == "action":
419
+ type = FeatureType.ACTION
420
+ else:
421
+ continue
422
+
423
+ policy_features[key] = PolicyFeature(
424
+ type=type,
425
+ shape=shape,
426
+ )
427
+
428
+ return policy_features
429
+
430
+
431
+ def create_empty_dataset_info(
432
+ codebase_version: str,
433
+ fps: int,
434
+ robot_type: str,
435
+ features: dict,
436
+ use_videos: bool,
437
+ ) -> dict:
438
+ return {
439
+ "codebase_version": codebase_version,
440
+ "robot_type": robot_type,
441
+ "total_episodes": 0,
442
+ "total_frames": 0,
443
+ "total_tasks": 0,
444
+ "total_videos": 0,
445
+ "total_chunks": 0,
446
+ "chunks_size": DEFAULT_CHUNK_SIZE,
447
+ "fps": fps,
448
+ "splits": {},
449
+ "data_path": DEFAULT_PARQUET_PATH,
450
+ "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
451
+ "features": features,
452
+ }
453
+
454
+
455
+ def get_episode_data_index(
456
+ episode_dicts: dict[dict], episodes: list[int] | None = None
457
+ ) -> dict[str, torch.Tensor]:
458
+ episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
459
+ if episodes is not None:
460
+ episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
461
+
462
+ cumulative_lengths = list(accumulate(episode_lengths.values()))
463
+ return {
464
+ "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
465
+ "to": torch.LongTensor(cumulative_lengths),
466
+ }
467
+
468
+
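A small worked example (illustrative only) of how the cumulative episode lengths turn into `from`/`to` ranges, including the subset case where indices are re-based at 0:

```python
from lerobot.common.datasets.utils import get_episode_data_index

episode_dicts = {0: {"length": 3}, 1: {"length": 4}, 2: {"length": 5}}

print(get_episode_data_index(episode_dicts))
# {'from': tensor([0, 3, 7]), 'to': tensor([3, 7, 12])}

print(get_episode_data_index(episode_dicts, episodes=[1, 2]))
# {'from': tensor([0, 4]), 'to': tensor([4, 9])}
```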
469
+ def check_timestamps_sync(
470
+ timestamps: np.ndarray,
471
+ episode_indices: np.ndarray,
472
+ episode_data_index: dict[str, np.ndarray],
473
+ fps: int,
474
+ tolerance_s: float,
475
+ raise_value_error: bool = True,
476
+ ) -> bool:
477
+ """
478
+ This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
479
+ to account for possible numerical error.
480
+
481
+ Args:
482
+ timestamps (np.ndarray): Array of timestamps in seconds.
483
+ episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
484
+ episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
485
+ which identifies indices for the end of each episode.
486
+ fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
487
+ tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
488
+ raise_value_error (bool): Whether to raise a ValueError if the check fails.
489
+
490
+ Returns:
491
+ bool: True if all checked timestamp differences lie within tolerance, False otherwise.
492
+
493
+ Raises:
494
+ ValueError: If the check fails and `raise_value_error` is True.
495
+ """
496
+ if timestamps.shape != episode_indices.shape:
497
+ raise ValueError(
498
+ "timestamps and episode_indices should have the same shape. "
499
+ f"Found {timestamps.shape=} and {episode_indices.shape=}."
500
+ )
501
+
502
+ # Consecutive differences
503
+ diffs = np.diff(timestamps)
504
+ within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
505
+
506
+ # Mask to ignore differences at the boundaries between episodes
507
+ mask = np.ones(len(diffs), dtype=bool)
508
+ ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
509
+ mask[ignored_diffs] = False
510
+ filtered_within_tolerance = within_tolerance[mask]
511
+
512
+ # Check if all remaining diffs are within tolerance
513
+ if not np.all(filtered_within_tolerance):
514
+ # Track original indices before masking
515
+ original_indices = np.arange(len(diffs))
516
+ filtered_indices = original_indices[mask]
517
+ outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
518
+ outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
519
+
520
+ outside_tolerances = []
521
+ for idx in outside_tolerance_indices:
522
+ entry = {
523
+ "timestamps": [timestamps[idx], timestamps[idx + 1]],
524
+ "diff": diffs[idx],
525
+ "episode_index": episode_indices[idx].item()
526
+ if hasattr(episode_indices[idx], "item")
527
+ else episode_indices[idx],
528
+ }
529
+ outside_tolerances.append(entry)
530
+
531
+ if raise_value_error:
532
+ raise ValueError(
533
+ f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
534
+ This might be due to synchronization issues during data collection.
535
+ \n{pformat(outside_tolerances)}"""
536
+ )
537
+ return False
538
+
539
+ return True
540
+
541
+
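A compact illustration (not part of the commit) of both the passing case and the failure case; note how the jump at the episode boundary is masked out while a dropped frame inside an episode is reported:

```python
import numpy as np

from lerobot.common.datasets.utils import check_timestamps_sync

fps = 10
timestamps = np.array([0.0, 0.1, 0.2, 0.0, 0.1, 0.2])  # two 3-frame episodes
episode_indices = np.array([0, 0, 0, 1, 1, 1])
episode_data_index = {"from": np.array([0, 3]), "to": np.array([3, 6])}

# The 0.2 -> 0.0 jump at the episode boundary is ignored by the mask.
assert check_timestamps_sync(timestamps, episode_indices, episode_data_index, fps, tolerance_s=1e-4)

# A dropped frame inside an episode (0.1 -> 0.3) raises a ValueError
# listing the offending timestamp pair and its episode_index.
timestamps[2] = 0.3
check_timestamps_sync(timestamps, episode_indices, episode_data_index, fps, tolerance_s=1e-4)
```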
542
+ def check_delta_timestamps(
543
+ delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
544
+ ) -> bool:
545
+ """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
546
+ This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
547
+ actual timestamps from the dataset.
548
+ """
549
+ outside_tolerance = {}
550
+ for key, delta_ts in delta_timestamps.items():
551
+ within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
552
+ if not all(within_tolerance):
553
+ outside_tolerance[key] = [
554
+ ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
555
+ ]
556
+
557
+ if len(outside_tolerance) > 0:
558
+ if raise_value_error:
559
+ raise ValueError(
560
+ f"""
561
+ The following delta_timestamps are found outside of tolerance range.
562
+ Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
563
+ their values accordingly.
564
+ \n{pformat(outside_tolerance)}
565
+ """
566
+ )
567
+ return False
568
+
569
+ return True
570
+
571
+
572
+ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
573
+ delta_indices = {}
574
+ for key, delta_ts in delta_timestamps.items():
575
+ delta_indices[key] = [round(d * fps) for d in delta_ts]
576
+
577
+ return delta_indices
578
+
579
+
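A worked example (illustrative only) tying the two helpers above together for a 30 fps dataset:

```python
from lerobot.common.datasets.utils import check_delta_timestamps, get_delta_indices

fps = 30
delta_timestamps = {
    "observation.image": [-1 / fps, 0.0],   # previous frame + current frame
    "action": [t / fps for t in range(8)],  # an 8-step action chunk
}

# Every offset is a multiple of 1/fps, so the check passes...
assert check_delta_timestamps(delta_timestamps, fps, tolerance_s=1e-4)

# ...and each offset maps to an integer frame delta.
print(get_delta_indices(delta_timestamps, fps))
# {'observation.image': [-1, 0], 'action': [0, 1, 2, 3, 4, 5, 6, 7]}
```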
580
+ def cycle(iterable):
581
+ """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
582
+
583
+ See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
584
+ """
585
+ iterator = iter(iterable)
586
+ while True:
587
+ try:
588
+ yield next(iterator)
589
+ except StopIteration:
590
+ iterator = iter(iterable)
591
+
592
+
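A minimal sketch (not part of the commit) of why this is used instead of `itertools.cycle`: a fresh iterator, and therefore a fresh set of dataloader workers, is created every time the underlying iterable is exhausted, rather than replaying items cached from the first pass.

```python
from torch.utils.data import DataLoader

from lerobot.common.datasets.utils import cycle

loader = DataLoader(range(10), batch_size=4, shuffle=True)
infinite_loader = cycle(loader)
for _ in range(7):  # keep drawing batches past the end of one epoch
    batch = next(infinite_loader)
```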
593
+ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
594
+ """Create a branch on an existing Hugging Face repo. Delete the branch if it already
595
+ exists before creating it.
596
+ """
597
+ api = HfApi()
598
+
599
+ branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
600
+ refs = [branch.ref for branch in branches]
601
+ ref = f"refs/heads/{branch}"
602
+ if ref in refs:
603
+ api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
604
+
605
+ api.create_branch(repo_id, repo_type=repo_type, branch=branch)
606
+
607
+
608
+ def create_lerobot_dataset_card(
609
+ tags: list | None = None,
610
+ dataset_info: dict | None = None,
611
+ **kwargs,
612
+ ) -> DatasetCard:
613
+ """
614
+ Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
615
+ Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
616
+ """
617
+ card_tags = ["LeRobot"]
618
+
619
+ if tags:
620
+ card_tags += tags
621
+ if dataset_info:
622
+ dataset_structure = "[meta/info.json](meta/info.json):\n"
623
+ dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
624
+ kwargs = {**kwargs, "dataset_structure": dataset_structure}
625
+ card_data = DatasetCardData(
626
+ license=kwargs.get("license"),
627
+ tags=card_tags,
628
+ task_categories=["robotics"],
629
+ configs=[
630
+ {
631
+ "config_name": "default",
632
+ "data_files": "data/*/*.parquet",
633
+ }
634
+ ],
635
+ )
636
+
637
+ card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
638
+
639
+ return DatasetCard.from_template(
640
+ card_data=card_data,
641
+ template_str=card_template,
642
+ **kwargs,
643
+ )
644
+
645
+
646
+ class IterableNamespace(SimpleNamespace):
647
+ """
648
+ A namespace object that supports both dictionary-like iteration and dot notation access.
649
+ Automatically converts nested dictionaries into IterableNamespaces.
650
+
651
+ This class extends SimpleNamespace to provide:
652
+ - Dictionary-style iteration over keys
653
+ - Access to items via both dot notation (obj.key) and brackets (obj["key"])
654
+ - Dictionary-like methods: items(), keys(), values()
655
+ - Recursive conversion of nested dictionaries
656
+
657
+ Args:
658
+ dictionary: Optional dictionary to initialize the namespace
659
+ **kwargs: Additional keyword arguments passed to SimpleNamespace
660
+
661
+ Examples:
662
+ >>> data = {"name": "Alice", "details": {"age": 25}}
663
+ >>> ns = IterableNamespace(data)
664
+ >>> ns.name
665
+ 'Alice'
666
+ >>> ns.details.age
667
+ 25
668
+ >>> list(ns.keys())
669
+ ['name', 'details']
670
+ >>> for key, value in ns.items():
671
+ ... print(f"{key}: {value}")
672
+ name: Alice
673
+ details: IterableNamespace(age=25)
674
+ """
675
+
676
+ def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
677
+ super().__init__(**kwargs)
678
+ if dictionary is not None:
679
+ for key, value in dictionary.items():
680
+ if isinstance(value, dict):
681
+ setattr(self, key, IterableNamespace(value))
682
+ else:
683
+ setattr(self, key, value)
684
+
685
+ def __iter__(self) -> Iterator[str]:
686
+ return iter(vars(self))
687
+
688
+ def __getitem__(self, key: str) -> Any:
689
+ return vars(self)[key]
690
+
691
+ def items(self):
692
+ return vars(self).items()
693
+
694
+ def values(self):
695
+ return vars(self).values()
696
+
697
+ def keys(self):
698
+ return vars(self).keys()
699
+
700
+
701
+ def validate_frame(frame: dict, features: dict):
702
+ optional_features = {"timestamp"}
703
+ expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
704
+ actual_features = set(frame.keys())
705
+
706
+ error_message = validate_features_presence(actual_features, expected_features, optional_features)
707
+
708
+ if "task" in frame:
709
+ error_message += validate_feature_string("task", frame["task"])
710
+
711
+ common_features = actual_features & (expected_features | optional_features)
712
+ for name in common_features - {"task"}:
713
+ error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
714
+
715
+ if error_message:
716
+ raise ValueError(error_message)
717
+
718
+
719
+ def validate_features_presence(
720
+ actual_features: set[str], expected_features: set[str], optional_features: set[str]
721
+ ):
722
+ error_message = ""
723
+ missing_features = expected_features - actual_features
724
+ extra_features = actual_features - (expected_features | optional_features)
725
+
726
+ if missing_features or extra_features:
727
+ error_message += "Feature mismatch in `frame` dictionary:\n"
728
+ if missing_features:
729
+ error_message += f"Missing features: {missing_features}\n"
730
+ if extra_features:
731
+ error_message += f"Extra features: {extra_features}\n"
732
+
733
+ return error_message
734
+
735
+
736
+ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
737
+ expected_dtype = feature["dtype"]
738
+ expected_shape = feature["shape"]
739
+ if is_valid_numpy_dtype_string(expected_dtype):
740
+ return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
741
+ elif expected_dtype in ["image", "video"]:
742
+ return validate_feature_image_or_video(name, expected_shape, value)
743
+ elif expected_dtype == "string":
744
+ return validate_feature_string(name, value)
745
+ else:
746
+ raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
747
+
748
+
749
+ def validate_feature_numpy_array(
750
+ name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
751
+ ):
752
+ error_message = ""
753
+ if isinstance(value, np.ndarray):
754
+ actual_dtype = value.dtype
755
+ actual_shape = value.shape
756
+
757
+ if actual_dtype != np.dtype(expected_dtype):
758
+ error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
759
+
760
+ if actual_shape != expected_shape:
761
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
762
+ else:
763
+ error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
764
+
765
+ return error_message
766
+
767
+
768
+ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
769
+ # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
770
+ error_message = ""
771
+ if isinstance(value, np.ndarray):
772
+ actual_shape = value.shape
773
+ c, h, w = expected_shape
774
+ if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
775
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
776
+ elif isinstance(value, PILImage.Image):
777
+ pass
778
+ else:
779
+ error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
780
+
781
+ return error_message
782
+
783
+
784
+ def validate_feature_string(name: str, value: str):
785
+ if not isinstance(value, str):
786
+ return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
787
+ return ""
788
+
789
+
790
+ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
791
+ if "size" not in episode_buffer:
792
+ raise ValueError("size key not found in episode_buffer")
793
+
794
+ if "task" not in episode_buffer:
795
+ raise ValueError("task key not found in episode_buffer")
796
+
797
+ if episode_buffer["episode_index"] != total_episodes:
798
+ # TODO(aliberts): Add option to use existing episode_index
799
+ raise NotImplementedError(
800
+ "You might have manually provided the episode_buffer with an episode_index that doesn't "
801
+ "match the total number of episodes already in the dataset. This is not supported for now."
802
+ )
803
+
804
+ if episode_buffer["size"] == 0:
805
+ raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
806
+
807
+ buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
808
+ if not buffer_keys == set(features):
809
+ raise ValueError(
810
+ f"Features from `episode_buffer` don't match the ones in `features`.\n"
811
+ f"In episode_buffer not in features: {buffer_keys - set(features)}\n"
812
+ f"In features not in episode_buffer: {set(features) - buffer_keys}"
813
+ )
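To make the validation contract above concrete, here is a hedged usage sketch (not part of the commit); the feature definitions are made up for illustration:

```python
import numpy as np

from lerobot.common.datasets.utils import DEFAULT_FEATURES, validate_frame

features = {
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
    **DEFAULT_FEATURES,
}
frame = {
    "observation.state": np.zeros(6, dtype=np.float32),
    "action": np.zeros(6, dtype=np.float32),
    "task": "Pick up the cube.",
}
validate_frame(frame, features)  # passes silently

frame["action"] = np.zeros(7, dtype=np.float32)
validate_frame(frame, features)  # raises ValueError (shape mismatch on 'action')
```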
lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py ADDED
@@ -0,0 +1,884 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
19
+
20
+ Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
21
+ lerobot/configs/robot/aloha.yaml before running this script.
22
+ """
23
+
24
+ import traceback
25
+ from pathlib import Path
26
+ from textwrap import dedent
27
+
28
+ from lerobot import available_datasets
29
+ from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
30
+ from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
31
+
32
+ LOCAL_DIR = Path("data/")
33
+
34
+ # spellchecker:off
35
+ ALOHA_MOBILE_INFO = {
36
+ "robot_config": AlohaRobotConfig(),
37
+ "license": "mit",
38
+ "url": "https://mobile-aloha.github.io/",
39
+ "paper": "https://arxiv.org/abs/2401.02117",
40
+ "citation_bibtex": dedent(r"""
41
+ @inproceedings{fu2024mobile,
42
+ author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
43
+ title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
44
+ booktitle = {arXiv},
45
+ year = {2024},
46
+ }""").lstrip(),
47
+ }
48
+ ALOHA_STATIC_INFO = {
49
+ "robot_config": AlohaRobotConfig(),
50
+ "license": "mit",
51
+ "url": "https://tonyzhaozh.github.io/aloha/",
52
+ "paper": "https://arxiv.org/abs/2304.13705",
53
+ "citation_bibtex": dedent(r"""
54
+ @article{Zhao2023LearningFB,
55
+ title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
56
+ author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
57
+ journal={RSS},
58
+ year={2023},
59
+ volume={abs/2304.13705},
60
+ url={https://arxiv.org/abs/2304.13705}
61
+ }""").lstrip(),
62
+ }
63
+ PUSHT_INFO = {
64
+ "license": "mit",
65
+ "url": "https://diffusion-policy.cs.columbia.edu/",
66
+ "paper": "https://arxiv.org/abs/2303.04137v5",
67
+ "citation_bibtex": dedent(r"""
68
+ @article{chi2024diffusionpolicy,
69
+ author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
70
+ title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
71
+ journal = {The International Journal of Robotics Research},
72
+ year = {2024},
73
+ }""").lstrip(),
74
+ }
75
+ XARM_INFO = {
76
+ "license": "mit",
77
+ "url": "https://www.nicklashansen.com/td-mpc/",
78
+ "paper": "https://arxiv.org/abs/2203.04955",
79
+ "citation_bibtex": dedent(r"""
80
+ @inproceedings{Hansen2022tdmpc,
81
+ title={Temporal Difference Learning for Model Predictive Control},
82
+ author={Nicklas Hansen and Xiaolong Wang and Hao Su},
83
+ booktitle={ICML},
84
+ year={2022}
85
+ }
86
+ """),
87
+ }
88
+ UNITREEH_INFO = {
89
+ "license": "apache-2.0",
90
+ }
91
+
92
+ DATASETS = {
93
+ "aloha_mobile_cabinet": {
94
+ "single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
95
+ **ALOHA_MOBILE_INFO,
96
+ },
97
+ "aloha_mobile_chair": {
98
+ "single_task": "Push the chairs in front of the desk to place them against it.",
99
+ **ALOHA_MOBILE_INFO,
100
+ },
101
+ "aloha_mobile_elevator": {
102
+ "single_task": "Take the elevator to the 1st floor.",
103
+ **ALOHA_MOBILE_INFO,
104
+ },
105
+ "aloha_mobile_shrimp": {
106
+ "single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
107
+ **ALOHA_MOBILE_INFO,
108
+ },
109
+ "aloha_mobile_wash_pan": {
110
+ "single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
111
+ **ALOHA_MOBILE_INFO,
112
+ },
113
+ "aloha_mobile_wipe_wine": {
114
+ "single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
115
+ **ALOHA_MOBILE_INFO,
116
+ },
117
+ "aloha_static_battery": {
118
+ "single_task": "Place the battery into the slot of the remote controller.",
119
+ **ALOHA_STATIC_INFO,
120
+ },
121
+ "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
122
+ "aloha_static_coffee": {
123
+ "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
124
+ **ALOHA_STATIC_INFO,
125
+ },
126
+ "aloha_static_coffee_new": {
127
+ "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
128
+ **ALOHA_STATIC_INFO,
129
+ },
130
+ "aloha_static_cups_open": {
131
+ "single_task": "Pick up the plastic cup and open its lid.",
132
+ **ALOHA_STATIC_INFO,
133
+ },
134
+ "aloha_static_fork_pick_up": {
135
+ "single_task": "Pick up the fork and place it on the plate.",
136
+ **ALOHA_STATIC_INFO,
137
+ },
138
+ "aloha_static_pingpong_test": {
139
+ "single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
140
+ **ALOHA_STATIC_INFO,
141
+ },
142
+ "aloha_static_pro_pencil": {
143
+ "single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
144
+ **ALOHA_STATIC_INFO,
145
+ },
146
+ "aloha_static_screw_driver": {
147
+ "single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
148
+ **ALOHA_STATIC_INFO,
149
+ },
150
+ "aloha_static_tape": {
151
+ "single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
152
+ **ALOHA_STATIC_INFO,
153
+ },
154
+ "aloha_static_thread_velcro": {
155
+ "single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
156
+ **ALOHA_STATIC_INFO,
157
+ },
158
+ "aloha_static_towel": {
159
+ "single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
160
+ **ALOHA_STATIC_INFO,
161
+ },
162
+ "aloha_static_vinh_cup": {
163
+ "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
164
+ **ALOHA_STATIC_INFO,
165
+ },
166
+ "aloha_static_vinh_cup_left": {
167
+ "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
168
+ **ALOHA_STATIC_INFO,
169
+ },
170
+ "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
171
+ "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
172
+ "aloha_sim_insertion_scripted_image": {
173
+ "single_task": "Insert the peg into the socket.",
174
+ **ALOHA_STATIC_INFO,
175
+ },
176
+ "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
177
+ "aloha_sim_insertion_human_image": {
178
+ "single_task": "Insert the peg into the socket.",
179
+ **ALOHA_STATIC_INFO,
180
+ },
181
+ "aloha_sim_transfer_cube_scripted": {
182
+ "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
183
+ **ALOHA_STATIC_INFO,
184
+ },
185
+ "aloha_sim_transfer_cube_scripted_image": {
186
+ "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
187
+ **ALOHA_STATIC_INFO,
188
+ },
189
+ "aloha_sim_transfer_cube_human": {
190
+ "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
191
+ **ALOHA_STATIC_INFO,
192
+ },
193
+ "aloha_sim_transfer_cube_human_image": {
194
+ "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
195
+ **ALOHA_STATIC_INFO,
196
+ },
197
+ "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
198
+ "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
199
+ "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
200
+ "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
201
+ "unitreeh1_two_robot_greeting": {
202
+ "single_task": "Greet the other robot with a high five.",
203
+ **UNITREEH_INFO,
204
+ },
205
+ "unitreeh1_warehouse": {
206
+ "single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
207
+ **UNITREEH_INFO,
208
+ },
209
+ "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
210
+ "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
211
+ "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
212
+ "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
213
+ "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
214
+ "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
215
+ "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
216
+ "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
217
+ "umi_cup_in_the_wild": {
218
+ "single_task": "Put the cup on the plate.",
219
+ "license": "apache-2.0",
220
+ },
221
+ "asu_table_top": {
222
+ "tasks_col": "language_instruction",
223
+ "license": "mit",
224
+ "paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
225
+ "citation_bibtex": dedent(r"""
226
+ @inproceedings{zhou2023modularity,
227
+ title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
228
+ author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
229
+ booktitle={Conference on Robot Learning},
230
+ pages={1684--1695},
231
+ year={2023},
232
+ organization={PMLR}
233
+ }
234
+ @article{zhou2023learning,
235
+ title={Learning modular language-conditioned robot policies through attention},
236
+ author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
237
+ journal={Autonomous Robots},
238
+ pages={1--21},
239
+ year={2023},
240
+ publisher={Springer}
241
+ }""").lstrip(),
242
+ },
243
+ "austin_buds_dataset": {
244
+ "tasks_col": "language_instruction",
245
+ "license": "mit",
246
+ "url": "https://ut-austin-rpl.github.io/BUDS-website/",
247
+ "paper": "https://arxiv.org/abs/2109.13841",
248
+ "citation_bibtex": dedent(r"""
249
+ @article{zhu2022bottom,
250
+ title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
251
+ author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
252
+ journal={IEEE Robotics and Automation Letters},
253
+ volume={7},
254
+ number={2},
255
+ pages={4126--4133},
256
+ year={2022},
257
+ publisher={IEEE}
258
+ }""").lstrip(),
259
+ },
260
+ "austin_sailor_dataset": {
261
+ "tasks_col": "language_instruction",
262
+ "license": "mit",
263
+ "url": "https://ut-austin-rpl.github.io/sailor/",
264
+ "paper": "https://arxiv.org/abs/2210.11435",
265
+ "citation_bibtex": dedent(r"""
266
+ @inproceedings{nasiriany2022sailor,
267
+ title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
268
+ author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
269
+ booktitle={Conference on Robot Learning (CoRL)},
270
+ year={2022}
271
+ }""").lstrip(),
272
+ },
273
+ "austin_sirius_dataset": {
274
+ "tasks_col": "language_instruction",
275
+ "license": "mit",
276
+ "url": "https://ut-austin-rpl.github.io/sirius/",
277
+ "paper": "https://arxiv.org/abs/2211.08416",
278
+ "citation_bibtex": dedent(r"""
279
+ @inproceedings{liu2022robot,
280
+ title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
281
+ author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
282
+ booktitle = {Robotics: Science and Systems (RSS)},
283
+ year = {2023}
284
+ }""").lstrip(),
285
+ },
286
+ "berkeley_autolab_ur5": {
287
+ "tasks_col": "language_instruction",
288
+ "license": "cc-by-4.0",
289
+ "url": "https://sites.google.com/view/berkeley-ur5/home",
290
+ "citation_bibtex": dedent(r"""
291
+ @misc{BerkeleyUR5Website,
292
+ title = {Berkeley {UR5} Demonstration Dataset},
293
+ author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
294
+ howpublished = {https://sites.google.com/view/berkeley-ur5/home},
295
+ }""").lstrip(),
296
+ },
297
+ "berkeley_cable_routing": {
298
+ "tasks_col": "language_instruction",
299
+ "license": "cc-by-4.0",
300
+ "url": "https://sites.google.com/view/cablerouting/home",
301
+ "paper": "https://arxiv.org/abs/2307.08927",
302
+ "citation_bibtex": dedent(r"""
303
+ @article{luo2023multistage,
304
+ author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
305
+ title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
306
+ journal = {arXiv pre-print},
307
+ year = {2023},
308
+ url = {https://arxiv.org/abs/2307.08927},
309
+ }""").lstrip(),
310
+ },
311
+ "berkeley_fanuc_manipulation": {
312
+ "tasks_col": "language_instruction",
313
+ "license": "mit",
314
+ "url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
315
+ "citation_bibtex": dedent(r"""
316
+ @article{fanuc_manipulation2023,
317
+ title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
318
+ author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
319
+ year={2023},
320
+ }""").lstrip(),
321
+ },
322
+ "berkeley_gnm_cory_hall": {
323
+ "tasks_col": "language_instruction",
324
+ "license": "mit",
325
+ "paper": "https://arxiv.org/abs/1709.10489",
326
+ "citation_bibtex": dedent(r"""
327
+ @inproceedings{kahn2018self,
328
+ title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
329
+ author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
330
+ booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
331
+ pages={5129--5136},
332
+ year={2018},
333
+ organization={IEEE}
334
+ }""").lstrip(),
335
+ },
336
+ "berkeley_gnm_recon": {
337
+ "tasks_col": "language_instruction",
338
+ "license": "mit",
339
+ "url": "https://sites.google.com/view/recon-robot",
340
+ "paper": "https://arxiv.org/abs/2104.05859",
341
+ "citation_bibtex": dedent(r"""
342
+ @inproceedings{shah2021rapid,
343
+ title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
344
+ author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
345
+ booktitle={5th Annual Conference on Robot Learning },
346
+ year={2021},
347
+ url={https://openreview.net/forum?id=d_SWJhyKfVw}
348
+ }""").lstrip(),
349
+ },
350
+ "berkeley_gnm_sac_son": {
351
+ "tasks_col": "language_instruction",
352
+ "license": "mit",
353
+ "url": "https://sites.google.com/view/SACSoN-review",
354
+ "paper": "https://arxiv.org/abs/2306.01874",
355
+ "citation_bibtex": dedent(r"""
356
+ @article{hirose2023sacson,
357
+ title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
358
+ author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
359
+ journal={arXiv preprint arXiv:2306.01874},
360
+ year={2023}
361
+ }""").lstrip(),
362
+ },
363
+ "berkeley_mvp": {
364
+ "tasks_col": "language_instruction",
365
+ "license": "mit",
366
+ "paper": "https://arxiv.org/abs/2203.06173",
367
+ "citation_bibtex": dedent(r"""
368
+ @InProceedings{Radosavovic2022,
369
+ title = {Real-World Robot Learning with Masked Visual Pre-training},
370
+ author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
371
+ booktitle = {CoRL},
372
+ year = {2022}
373
+ }""").lstrip(),
374
+ },
375
+ "berkeley_rpt": {
376
+ "tasks_col": "language_instruction",
377
+ "license": "mit",
378
+ "paper": "https://arxiv.org/abs/2306.10007",
379
+ "citation_bibtex": dedent(r"""
380
+ @article{Radosavovic2023,
381
+ title={Robot Learning with Sensorimotor Pre-training},
382
+ author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
383
+ year={2023},
384
+ journal={arXiv:2306.10007}
385
+ }""").lstrip(),
386
+ },
387
+ "cmu_franka_exploration_dataset": {
388
+ "tasks_col": "language_instruction",
389
+ "license": "mit",
390
+ "url": "https://human-world-model.github.io/",
391
+ "paper": "https://arxiv.org/abs/2308.10901",
392
+ "citation_bibtex": dedent(r"""
393
+ @inproceedings{mendonca2023structured,
394
+ title={Structured World Models from Human Videos},
395
+ author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
396
+ journal={RSS},
397
+ year={2023}
398
+ }""").lstrip(),
399
+ },
400
+ "cmu_play_fusion": {
401
+ "tasks_col": "language_instruction",
402
+ "license": "mit",
403
+ "url": "https://play-fusion.github.io/",
404
+ "paper": "https://arxiv.org/abs/2312.04549",
405
+ "citation_bibtex": dedent(r"""
406
+ @inproceedings{chen2023playfusion,
407
+ title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
408
+ author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
409
+ booktitle={CoRL},
410
+ year={2023}
411
+ }""").lstrip(),
412
+ },
413
+ "cmu_stretch": {
414
+ "tasks_col": "language_instruction",
415
+ "license": "mit",
416
+ "url": "https://robo-affordances.github.io/",
417
+ "paper": "https://arxiv.org/abs/2304.08488",
418
+ "citation_bibtex": dedent(r"""
419
+ @inproceedings{bahl2023affordances,
420
+ title={Affordances from Human Videos as a Versatile Representation for Robotics},
421
+ author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
422
+ booktitle={CVPR},
423
+ year={2023}
424
+ }
425
+ @article{mendonca2023structured,
426
+ title={Structured World Models from Human Videos},
427
+ author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
428
+ journal={CoRL},
429
+ year={2023}
430
+ }""").lstrip(),
431
+ },
432
+ "columbia_cairlab_pusht_real": {
433
+ "tasks_col": "language_instruction",
434
+ "license": "mit",
435
+ "url": "https://diffusion-policy.cs.columbia.edu/",
436
+ "paper": "https://arxiv.org/abs/2303.04137v5",
437
+ "citation_bibtex": dedent(r"""
438
+ @inproceedings{chi2023diffusionpolicy,
439
+ title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
440
+ author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
441
+ booktitle={Proceedings of Robotics: Science and Systems (RSS)},
442
+ year={2023}
443
+ }""").lstrip(),
444
+ },
445
+ "conq_hose_manipulation": {
446
+ "tasks_col": "language_instruction",
447
+ "license": "mit",
448
+ "url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
449
+ "citation_bibtex": dedent(r"""
450
+ @misc{ConqHoseManipData,
451
+ author={Peter Mitrano and Dmitry Berenson},
452
+ title={Conq Hose Manipulation Dataset, v1.15.0},
453
+ year={2024},
454
+ howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
455
+ }""").lstrip(),
456
+ },
457
+ "dlr_edan_shared_control": {
458
+ "tasks_col": "language_instruction",
459
+ "license": "mit",
460
+ "paper": "https://ieeexplore.ieee.org/document/9341156",
461
+ "citation_bibtex": dedent(r"""
462
+ @inproceedings{vogel_edan_2020,
463
+ title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
464
+ language = {en},
465
+ booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
466
+ author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
467
+ year = {2020}
468
+ }
469
+ @inproceedings{quere_shared_2020,
470
+ address = {Paris, France},
471
+ title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
472
+ language = {en},
473
+ booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
474
+ author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
475
+ year = {2020},
476
+ pages = {7},
477
+ }""").lstrip(),
478
+ },
479
+ "dlr_sara_grid_clamp": {
480
+ "tasks_col": "language_instruction",
481
+ "license": "mit",
482
+ "paper": "https://www.researchsquare.com/article/rs-3289569/v1",
483
+ "citation_bibtex": dedent(r"""
484
+ @article{padalkar2023guided,
485
+ title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
486
+ author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
487
+ journal={Research square preprint rs-3289569/v1},
488
+ year={2023}
489
+ }""").lstrip(),
490
+ },
491
+ "dlr_sara_pour": {
492
+ "tasks_col": "language_instruction",
493
+ "license": "mit",
494
+ "paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
495
+ "citation_bibtex": dedent(r"""
496
+ @inproceedings{padalkar2023guiding,
497
+ title={Guiding Reinforcement Learning with Shared Control Templates},
498
+ author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
499
+ booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
500
+ year={2023},
501
+ organization={IEEE}
502
+ }""").lstrip(),
503
+ },
504
+ "droid_100": {
505
+ "tasks_col": "language_instruction",
506
+ "license": "mit",
507
+ "url": "https://droid-dataset.github.io/",
508
+ "paper": "https://arxiv.org/abs/2403.12945",
509
+ "citation_bibtex": dedent(r"""
510
+ @article{khazatsky2024droid,
511
+ title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
512
+ author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
513
+ year = {2024},
514
+ }""").lstrip(),
515
+ },
516
+ "fmb": {
517
+ "tasks_col": "language_instruction",
518
+ "license": "cc-by-4.0",
519
+ "url": "https://functional-manipulation-benchmark.github.io/",
520
+ "paper": "https://arxiv.org/abs/2401.08553",
521
+ "citation_bibtex": dedent(r"""
522
+ @article{luo2024fmb,
523
+ title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
524
+ author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
525
+ journal={arXiv preprint arXiv:2401.08553},
526
+ year={2024}
527
+ }""").lstrip(),
528
+ },
529
+ "iamlab_cmu_pickup_insert": {
530
+ "tasks_col": "language_instruction",
531
+ "license": "mit",
532
+ "url": "https://openreview.net/forum?id=WuBv9-IGDUA",
533
+ "paper": "https://arxiv.org/abs/2401.14502",
534
+ "citation_bibtex": dedent(r"""
535
+ @inproceedings{saxena2023multiresolution,
536
+ title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
537
+ author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
538
+ booktitle={7th Annual Conference on Robot Learning},
539
+ year={2023},
540
+ url={https://openreview.net/forum?id=WuBv9-IGDUA}
541
+ }""").lstrip(),
542
+ },
543
+ "imperialcollege_sawyer_wrist_cam": {
544
+ "tasks_col": "language_instruction",
545
+ "license": "mit",
546
+ },
547
+ "jaco_play": {
548
+ "tasks_col": "language_instruction",
549
+ "license": "cc-by-4.0",
550
+ "url": "https://github.com/clvrai/clvr_jaco_play_dataset",
551
+ "citation_bibtex": dedent(r"""
552
+ @software{dass2023jacoplay,
553
+ author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
554
+ and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
555
+ title = {CLVR Jaco Play Dataset},
556
+ url = {https://github.com/clvrai/clvr_jaco_play_dataset},
557
+ version = {1.0.0},
558
+ year = {2023}
559
+ }""").lstrip(),
560
+ },
561
+ "kaist_nonprehensile": {
562
+ "tasks_col": "language_instruction",
563
+ "license": "cc-by-4.0",
564
+ "url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
565
+ "citation_bibtex": dedent(r"""
566
+ @article{kimpre,
567
+ title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
568
+ author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
569
+ booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
570
+ year={2023},
571
+ organization={IEEE}
572
+ }""").lstrip(),
573
+ },
574
+ "nyu_door_opening_surprising_effectiveness": {
575
+ "tasks_col": "language_instruction",
576
+ "license": "mit",
577
+ "url": "https://jyopari.github.io/VINN/",
578
+ "paper": "https://arxiv.org/abs/2112.01511",
579
+ "citation_bibtex": dedent(r"""
580
+ @misc{pari2021surprising,
581
+ title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
582
+ author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
583
+ year={2021},
584
+ eprint={2112.01511},
585
+ archivePrefix={arXiv},
586
+ primaryClass={cs.RO}
587
+ }""").lstrip(),
588
+ },
589
+ "nyu_franka_play_dataset": {
590
+ "tasks_col": "language_instruction",
591
+ "license": "mit",
592
+ "url": "https://play-to-policy.github.io/",
593
+ "paper": "https://arxiv.org/abs/2210.10047",
594
+ "citation_bibtex": dedent(r"""
595
+ @article{cui2022play,
596
+ title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
597
+ author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
598
+ journal = {arXiv preprint arXiv:2210.10047},
599
+ year = {2022}
600
+ }""").lstrip(),
601
+ },
602
+ "nyu_rot_dataset": {
603
+ "tasks_col": "language_instruction",
604
+ "license": "mit",
605
+ "url": "https://rot-robot.github.io/",
606
+ "paper": "https://arxiv.org/abs/2206.15469",
607
+ "citation_bibtex": dedent(r"""
608
+ @inproceedings{haldar2023watch,
609
+ title={Watch and match: Supercharging imitation with regularized optimal transport},
610
+ author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
611
+ booktitle={Conference on Robot Learning},
612
+ pages={32--43},
613
+ year={2023},
614
+ organization={PMLR}
615
+ }""").lstrip(),
616
+ },
617
+ "roboturk": {
618
+ "tasks_col": "language_instruction",
619
+ "license": "mit",
620
+ "url": "https://roboturk.stanford.edu/dataset_real.html",
621
+ "paper": "PAPER",
622
+ "citation_bibtex": dedent(r"""
623
+ @inproceedings{mandlekar2019scaling,
624
+ title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
625
+ author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
626
+ booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
627
+ pages={1048--1055},
628
+ year={2019},
629
+ organization={IEEE}
630
+ }""").lstrip(),
631
+ },
632
+ "stanford_hydra_dataset": {
633
+ "tasks_col": "language_instruction",
634
+ "license": "mit",
635
+ "url": "https://sites.google.com/view/hydra-il-2023",
636
+ "paper": "https://arxiv.org/abs/2306.17237",
637
+ "citation_bibtex": dedent(r"""
638
+ @article{belkhale2023hydra,
639
+ title={HYDRA: Hybrid Robot Actions for Imitation Learning},
640
+ author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
641
+ journal={arxiv},
642
+ year={2023}
643
+ }""").lstrip(),
644
+ },
645
+ "stanford_kuka_multimodal_dataset": {
646
+ "tasks_col": "language_instruction",
647
+ "license": "mit",
648
+ "url": "https://sites.google.com/view/visionandtouch",
649
+ "paper": "https://arxiv.org/abs/1810.10191",
650
+ "citation_bibtex": dedent(r"""
651
+ @inproceedings{lee2019icra,
652
+ title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
653
+ author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
654
+ booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
655
+ year={2019},
656
+ url={https://arxiv.org/abs/1810.10191}
657
+ }""").lstrip(),
658
+ },
659
+ "stanford_robocook": {
660
+ "tasks_col": "language_instruction",
661
+ "license": "mit",
662
+ "url": "https://hshi74.github.io/robocook/",
663
+ "paper": "https://arxiv.org/abs/2306.14447",
664
+ "citation_bibtex": dedent(r"""
665
+ @article{shi2023robocook,
666
+ title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
667
+ author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
668
+ journal={arXiv preprint arXiv:2306.14447},
669
+ year={2023}
670
+ }""").lstrip(),
671
+ },
672
+ "taco_play": {
673
+ "tasks_col": "language_instruction",
674
+ "license": "cc-by-4.0",
675
+ "url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
676
+ "paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
677
+ "citation_bibtex": dedent(r"""
678
+ @inproceedings{rosete2022tacorl,
679
+ author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
680
+ title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
681
+ journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
682
+ year = {2022}
683
+ }
684
+ @inproceedings{mees23hulc2,
685
+ title={Grounding Language with Visual Affordances over Unstructured Data},
686
+ author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
687
+ booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
688
+ year={2023},
689
+ address = {London, UK}
690
+ }""").lstrip(),
691
+ },
692
+ "tokyo_u_lsmo": {
693
+ "tasks_col": "language_instruction",
694
+ "license": "mit",
695
+ "url": "URL",
696
+ "paper": "https://arxiv.org/abs/2107.05842",
697
+ "citation_bibtex": dedent(r"""
698
+ @Article{Osa22,
699
+ author = {Takayuki Osa},
700
+ journal = {The International Journal of Robotics Research},
701
+ title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
702
+ year = {2022},
703
+ number = {3},
704
+ pages = {291--311},
705
+ volume = {41},
706
+ }""").lstrip(),
707
+ },
708
+ "toto": {
709
+ "tasks_col": "language_instruction",
710
+ "license": "mit",
711
+ "url": "https://toto-benchmark.org/",
712
+ "paper": "https://arxiv.org/abs/2306.00942",
713
+ "citation_bibtex": dedent(r"""
714
+ @inproceedings{zhou2023train,
715
+ author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
716
+ booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
717
+ title={Train Offline, Test Online: A Real Robot Learning Benchmark},
718
+ year={2023},
719
+ }""").lstrip(),
720
+ },
721
+ "ucsd_kitchen_dataset": {
722
+ "tasks_col": "language_instruction",
723
+ "license": "mit",
724
+ "citation_bibtex": dedent(r"""
725
+ @ARTICLE{ucsd_kitchens,
726
+ author = {Ge Yan, Kris Wu, and Xiaolong Wang},
727
+ title = {{ucsd kitchens Dataset}},
728
+ year = {2023},
729
+ month = {August}
730
+ }""").lstrip(),
731
+ },
732
+ "ucsd_pick_and_place_dataset": {
733
+ "tasks_col": "language_instruction",
734
+ "license": "mit",
735
+ "url": "https://owmcorl.github.io/#",
736
+ "paper": "https://arxiv.org/abs/2310.16029",
737
+ "citation_bibtex": dedent(r"""
738
+ @preprint{Feng2023Finetuning,
739
+ title={Finetuning Offline World Models in the Real World},
740
+ author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
741
+ year={2023}
742
+ }""").lstrip(),
743
+ },
744
+ "uiuc_d3field": {
745
+ "tasks_col": "language_instruction",
746
+ "license": "mit",
747
+ "url": "https://robopil.github.io/d3fields/",
748
+ "paper": "https://arxiv.org/abs/2309.16118",
749
+ "citation_bibtex": dedent(r"""
750
+ @article{wang2023d3field,
751
+ title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
752
+ author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
753
+ journal={arXiv preprint arXiv:},
754
+ year={2023},
755
+ }""").lstrip(),
756
+ },
757
+ "usc_cloth_sim": {
758
+ "tasks_col": "language_instruction",
759
+ "license": "mit",
760
+ "url": "https://uscresl.github.io/dmfd/",
761
+ "paper": "https://arxiv.org/abs/2207.10148",
762
+ "citation_bibtex": dedent(r"""
763
+ @article{salhotra2022dmfd,
764
+ author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
765
+ journal={IEEE Robotics and Automation Letters},
766
+ title={Learning Deformable Object Manipulation From Expert Demonstrations},
767
+ year={2022},
768
+ volume={7},
769
+ number={4},
770
+ pages={8775-8782},
771
+ doi={10.1109/LRA.2022.3187843}
772
+ }""").lstrip(),
773
+ },
774
+ "utaustin_mutex": {
775
+ "tasks_col": "language_instruction",
776
+ "license": "mit",
777
+ "url": "https://ut-austin-rpl.github.io/MUTEX/",
778
+ "paper": "https://arxiv.org/abs/2309.14320",
779
+ "citation_bibtex": dedent(r"""
780
+ @inproceedings{shah2023mutex,
781
+ title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
782
+ author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
783
+ booktitle={7th Annual Conference on Robot Learning},
784
+ year={2023},
785
+ url={https://openreview.net/forum?id=PwqiqaaEzJ}
786
+ }""").lstrip(),
787
+ },
788
+ "utokyo_pr2_opening_fridge": {
789
+ "tasks_col": "language_instruction",
790
+ "license": "mit",
791
+ "citation_bibtex": dedent(r"""
792
+ @misc{oh2023pr2utokyodatasets,
793
+ author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
794
+ title={X-Embodiment U-Tokyo PR2 Datasets},
795
+ year={2023},
796
+ url={https://github.com/ojh6404/rlds_dataset_builder},
797
+ }""").lstrip(),
798
+ },
799
+ "utokyo_pr2_tabletop_manipulation": {
800
+ "tasks_col": "language_instruction",
801
+ "license": "mit",
802
+ "citation_bibtex": dedent(r"""
803
+ @misc{oh2023pr2utokyodatasets,
804
+ author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
805
+ title={X-Embodiment U-Tokyo PR2 Datasets},
806
+ year={2023},
807
+ url={https://github.com/ojh6404/rlds_dataset_builder},
808
+ }""").lstrip(),
809
+ },
810
+ "utokyo_saytap": {
811
+ "tasks_col": "language_instruction",
812
+ "license": "mit",
813
+ "url": "https://saytap.github.io/",
814
+ "paper": "https://arxiv.org/abs/2306.07580",
815
+ "citation_bibtex": dedent(r"""
816
+ @article{saytap2023,
817
+ author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
818
+ Tatsuya Harada},
819
+ title = {SayTap: Language to Quadrupedal Locomotion},
820
+ eprint = {arXiv:2306.07580},
821
+ url = {https://saytap.github.io},
822
+ note = {https://saytap.github.io},
823
+ year = {2023}
824
+ }""").lstrip(),
825
+ },
826
+ "utokyo_xarm_bimanual": {
827
+ "tasks_col": "language_instruction",
828
+ "license": "cc-by-4.0",
829
+ "citation_bibtex": dedent(r"""
830
+ @misc{matsushima2023weblab,
831
+ title={Weblab xArm Dataset},
832
+ author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
833
+ year={2023},
834
+ }""").lstrip(),
835
+ },
836
+ "utokyo_xarm_pick_and_place": {
837
+ "tasks_col": "language_instruction",
838
+ "license": "cc-by-4.0",
839
+ "citation_bibtex": dedent(r"""
840
+ @misc{matsushima2023weblab,
841
+ title={Weblab xArm Dataset},
842
+ author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
843
+ year={2023},
844
+ }""").lstrip(),
845
+ },
846
+ "viola": {
847
+ "tasks_col": "language_instruction",
848
+ "license": "mit",
849
+ "url": "https://ut-austin-rpl.github.io/VIOLA/",
850
+ "paper": "https://arxiv.org/abs/2210.11339",
851
+ "citation_bibtex": dedent(r"""
852
+ @article{zhu2022viola,
853
+ title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
854
+ author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
855
+ journal={6th Annual Conference on Robot Learning (CoRL)},
856
+ year={2022}
857
+ }""").lstrip(),
858
+ },
859
+ }
860
+ # spellchecker:on
861
+
862
+
863
+ def batch_convert():
864
+ status = {}
865
+ logfile = LOCAL_DIR / "conversion_log.txt"
866
+ assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
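+ # Convert every dataset in DATASETS, logging per-repo success or failure and continuing on errors.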
867
+ for num, (name, kwargs) in enumerate(DATASETS.items()):
868
+ repo_id = f"lerobot/{name}"
869
+ print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
870
+ print("---------------------------------------------------------")
871
+ try:
872
+ convert_dataset(repo_id, LOCAL_DIR, **kwargs)
873
+ status = f"{repo_id}: success."
874
+ with open(logfile, "a") as file:
875
+ file.write(status + "\n")
876
+ except Exception:
877
+ status = f"{repo_id}: failed\n {traceback.format_exc()}"
878
+ with open(logfile, "a") as file:
879
+ file.write(status + "\n")
880
+ continue
881
+
882
+
883
+ if __name__ == "__main__":
884
+ batch_convert()
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py ADDED
@@ -0,0 +1,664 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
19
+ 2.0. You will be required to provide the 'tasks', which are short but accurate descriptions in plain English
20
+ for each of the tasks performed in the dataset. This will allow you to easily train models with task-conditioning.
21
+
22
+ We support 3 different scenarios for these tasks (see instructions below):
23
+ 1. Single task dataset: all episodes of your dataset have the same single task.
24
+ 2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
25
+ one episode to the next.
26
+ 3. Multi task episodes: episodes of your dataset may each contain several different tasks.
27
+
28
+
29
+ You can also provide a robot config .yaml file (not mandatory) to this script via the option
30
+ '--robot-config' so that it writes information about the robot (robot type, motor names) this dataset was
31
+ recorded with. For now, only Aloha/Koch type robots are supported with this option.
32
+
33
+
34
+ # 1. Single task dataset
35
+ If your dataset contains a single task, you can simply provide it directly via the CLI with the
36
+ '--single-task' option.
37
+
38
+ Examples:
39
+
40
+ ```bash
41
+ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
42
+ --repo-id lerobot/aloha_sim_insertion_human_image \
43
+ --single-task "Insert the peg into the socket." \
44
+ --robot-config lerobot/configs/robot/aloha.yaml \
45
+ --local-dir data
46
+ ```
47
+
48
+ ```bash
49
+ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
50
+ --repo-id aliberts/koch_tutorial \
51
+ --single-task "Pick the Lego block and drop it in the box on the right." \
52
+ --robot-config lerobot/configs/robot/koch.yaml \
53
+ --local-dir data
54
+ ```
55
+
56
+
57
+ # 2. Single task episodes
58
+ If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
59
+
60
+ - If your dataset already contains a language instruction column in its parquet file, you can simply provide
61
+ this column's name with the '--tasks-col' arg.
62
+
63
+ Example:
64
+
65
+ ```bash
66
+ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
67
+ --repo-id lerobot/stanford_kuka_multimodal_dataset \
68
+ --tasks-col "language_instruction" \
69
+ --local-dir data
70
+ ```
71
+
72
+ - If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
73
+ '--tasks-path' arg. This file should have the following structure where keys correspond to each
74
+ episode_index in the dataset, and values are the language instruction for that episode.
75
+
76
+ Example:
77
+
78
+ ```json
79
+ {
80
+ "0": "Do something",
81
+ "1": "Do something else",
82
+ "2": "Do something",
83
+ "3": "Go there",
84
+ ...
85
+ }
86
+ ```
87
+
88
+ # 3. Multi task episodes
89
+ If you have multiple tasks per episode, your dataset should contain a language instruction column in its
90
+ parquet file, and you must provide this column's name with the '--tasks-col' arg.
91
+
92
+ Example:
93
+
94
+ ```bash
95
+ python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
96
+ --repo-id lerobot/stanford_kuka_multimodal_dataset \
97
+ --tasks-col "language_instruction" \
98
+ --local-dir data
99
+ ```
100
+ """
101
+
102
+ import argparse
103
+ import contextlib
104
+ import filecmp
105
+ import json
106
+ import logging
107
+ import math
108
+ import shutil
109
+ import subprocess
110
+ import tempfile
111
+ from pathlib import Path
112
+
113
+ import datasets
114
+ import pyarrow.compute as pc
115
+ import pyarrow.parquet as pq
116
+ import torch
117
+ from datasets import Dataset
118
+ from huggingface_hub import HfApi
119
+ from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
120
+ from safetensors.torch import load_file
121
+
122
+ from lerobot.common.datasets.utils import (
123
+ DEFAULT_CHUNK_SIZE,
124
+ DEFAULT_PARQUET_PATH,
125
+ DEFAULT_VIDEO_PATH,
126
+ EPISODES_PATH,
127
+ INFO_PATH,
128
+ STATS_PATH,
129
+ TASKS_PATH,
130
+ create_branch,
131
+ create_lerobot_dataset_card,
132
+ flatten_dict,
133
+ get_safe_version,
134
+ load_json,
135
+ unflatten_dict,
136
+ write_json,
137
+ write_jsonlines,
138
+ )
139
+ from lerobot.common.datasets.video_utils import (
140
+ VideoFrame, # noqa: F401
141
+ get_image_pixel_channels,
142
+ get_video_info,
143
+ )
144
+ from lerobot.common.robot_devices.robots.configs import RobotConfig
145
+ from lerobot.common.robot_devices.robots.utils import make_robot_config
146
+
147
+ V16 = "v1.6"
148
+ V20 = "v2.0"
149
+
150
+ GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
151
+ V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
152
+ V1_INFO_PATH = "meta_data/info.json"
153
+ V1_STATS_PATH = "meta_data/stats.safetensors"
154
+
155
+
156
+ def parse_robot_config(robot_cfg: RobotConfig) -> dict:
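+ # Build per-motor state/action names for the features metadata; only Aloha/Koch-style configs are handled here.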
157
+ if robot_cfg.type in ["aloha", "koch"]:
158
+ state_names = [
159
+ f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
160
+ for arm in robot_cfg.follower_arms
161
+ for motor in robot_cfg.follower_arms[arm].motors
162
+ ]
163
+ action_names = [
164
+ # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
165
+ f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
166
+ for arm in robot_cfg.leader_arms
167
+ for motor in robot_cfg.leader_arms[arm].motors
168
+ ]
169
+ # elif robot_cfg["robot_type"] == "stretch3": TODO
170
+ else:
171
+ raise NotImplementedError(
172
+ "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
173
+ )
174
+
175
+ return {
176
+ "robot_type": robot_cfg.type,
177
+ "names": {
178
+ "observation.state": state_names,
179
+ "observation.effort": state_names,
180
+ "action": action_names,
181
+ },
182
+ }
183
+
184
+
185
+ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
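+ # Convert v1 stats from safetensors to the v2 JSON format, then reload and compare them as a sanity check.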
186
+ safetensor_path = v1_dir / V1_STATS_PATH
187
+ stats = load_file(safetensor_path)
188
+ serialized_stats = {key: value.tolist() for key, value in stats.items()}
189
+ serialized_stats = unflatten_dict(serialized_stats)
190
+
191
+ json_path = v2_dir / STATS_PATH
192
+ json_path.parent.mkdir(exist_ok=True, parents=True)
193
+ with open(json_path, "w") as f:
194
+ json.dump(serialized_stats, f, indent=4)
195
+
196
+ # Sanity check
197
+ with open(json_path) as f:
198
+ stats_json = json.load(f)
199
+
200
+ stats_json = flatten_dict(stats_json)
201
+ stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
202
+ for key in stats:
203
+ torch.testing.assert_close(stats_json[key], stats[key])
204
+
205
+
206
+ def get_features_from_hf_dataset(
207
+ dataset: Dataset, robot_config: RobotConfig | None = None
208
+ ) -> dict[str, list]:
209
+ robot_config = parse_robot_config(robot_config) if robot_config else None
210
+ features = {}
211
+ for key, ft in dataset.features.items():
212
+ if isinstance(ft, datasets.Value):
213
+ dtype = ft.dtype
214
+ shape = (1,)
215
+ names = None
216
+ if isinstance(ft, datasets.Sequence):
217
+ assert isinstance(ft.feature, datasets.Value)
218
+ dtype = ft.feature.dtype
219
+ shape = (ft.length,)
220
+ motor_names = (
221
+ robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
222
+ )
223
+ assert len(motor_names) == shape[0]
224
+ names = {"motors": motor_names}
225
+ elif isinstance(ft, datasets.Image):
226
+ dtype = "image"
227
+ image = dataset[0][key] # Assuming first row
228
+ channels = get_image_pixel_channels(image)
229
+ shape = (image.height, image.width, channels)
230
+ names = ["height", "width", "channels"]
231
+ elif ft._type == "VideoFrame":
232
+ dtype = "video"
233
+ shape = None # Add shape later
234
+ names = ["height", "width", "channels"]
235
+
236
+ features[key] = {
237
+ "dtype": dtype,
238
+ "shape": shape,
239
+ "names": names,
240
+ }
241
+
242
+ return features
243
+
244
+
245
+ def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
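+ # Add a 'task_index' column by mapping each episode_index to its task, and return the list of unique tasks.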
246
+ df = dataset.to_pandas()
247
+ tasks = list(set(tasks_by_episodes.values()))
248
+ tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
249
+ episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
250
+ df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
251
+
252
+ features = dataset.features
253
+ features["task_index"] = datasets.Value(dtype="int64")
254
+ dataset = Dataset.from_pandas(df, features=features, split="train")
255
+ return dataset, tasks
256
+
257
+
258
+ def add_task_index_from_tasks_col(
259
+ dataset: Dataset, tasks_col: str
260
+ ) -> tuple[Dataset, list[str], dict]:
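+ # Build 'task_index' from an existing language-instruction column; also return the tasks grouped by episode.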
261
+ df = dataset.to_pandas()
262
+
263
+ # HACK: This is to clean some of the instructions in our version of Open X datasets
264
+ prefix_to_clean = "tf.Tensor(b'"
265
+ suffix_to_clean = "', shape=(), dtype=string)"
266
+ df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
267
+
268
+ # Create task_index col
269
+ tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
270
+ tasks = df[tasks_col].unique().tolist()
271
+ tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
272
+ df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
273
+
274
+ # Build the dataset back from df
275
+ features = dataset.features
276
+ features["task_index"] = datasets.Value(dtype="int64")
277
+ dataset = Dataset.from_pandas(df, features=features, split="train")
278
+ dataset = dataset.remove_columns(tasks_col)
279
+
280
+ return dataset, tasks, tasks_by_episode
281
+
282
+
283
+ def split_parquet_by_episodes(
284
+ dataset: Dataset,
285
+ total_episodes: int,
286
+ total_chunks: int,
287
+ output_dir: Path,
288
+ ) -> list:
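+ # Write one parquet file per episode into its chunk subdirectory and return the length of each episode.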
289
+ table = dataset.data.table
290
+ episode_lengths = []
291
+ for ep_chunk in range(total_chunks):
292
+ ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
293
+ ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
294
+ chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
295
+ (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
296
+ for ep_idx in range(ep_chunk_start, ep_chunk_end):
297
+ ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
298
+ episode_lengths.insert(ep_idx, len(ep_table))
299
+ output_file = output_dir / DEFAULT_PARQUET_PATH.format(
300
+ episode_chunk=ep_chunk, episode_index=ep_idx
301
+ )
302
+ pq.write_table(ep_table, output_file)
303
+
304
+ return episode_lengths
305
+
306
+
307
+ def move_videos(
308
+ repo_id: str,
309
+ video_keys: list[str],
310
+ total_episodes: int,
311
+ total_chunks: int,
312
+ work_dir: Path,
313
+ clean_gittatributes: Path,
314
+ branch: str = "main",
315
+ ) -> None:
316
+ """
317
+ HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
318
+ commands to fetch the git lfs video file references and move them into subdirectories without having to
319
+ actually download them.
320
+ """
321
+ _lfs_clone(repo_id, work_dir, branch)
322
+
323
+ videos_moved = False
324
+ video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
325
+ if len(video_files) == 0:
326
+ video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
327
+ videos_moved = True # Videos have already been moved
328
+
329
+ assert len(video_files) == total_episodes * len(video_keys)
330
+
331
+ lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
332
+
333
+ current_gittatributes = work_dir / ".gitattributes"
334
+ if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
335
+ fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
336
+
337
+ if lfs_untracked_videos:
338
+ fix_lfs_video_files_tracking(work_dir, video_files)
339
+
340
+ if videos_moved:
341
+ return
342
+
343
+ video_dirs = sorted(work_dir.glob("videos*/"))
344
+ for ep_chunk in range(total_chunks):
345
+ ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
346
+ ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
347
+ for vid_key in video_keys:
348
+ chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
349
+ episode_chunk=ep_chunk, video_key=vid_key
350
+ )
351
+ (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
352
+
353
+ for ep_idx in range(ep_chunk_start, ep_chunk_end):
354
+ target_path = DEFAULT_VIDEO_PATH.format(
355
+ episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
356
+ )
357
+ video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
358
+ if len(video_dirs) == 1:
359
+ video_path = video_dirs[0] / video_file
360
+ else:
361
+ for dir in video_dirs:
362
+ if (dir / video_file).is_file():
363
+ video_path = dir / video_file
364
+ break
365
+
366
+ video_path.rename(work_dir / target_path)
367
+
368
+ commit_message = "Move video files into chunk subdirectories"
369
+ subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
370
+ subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
371
+ subprocess.run(["git", "push"], cwd=work_dir, check=True)
372
+
373
+
374
+ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
375
+ """
376
+ HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
377
+ there's no other option than to download the actual files and reupload them with lfs tracking.
378
+ """
379
+ for i in range(0, len(lfs_untracked_videos), 100):
380
+ files = lfs_untracked_videos[i : i + 100]
381
+ try:
382
+ subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
383
+ except subprocess.CalledProcessError as e:
384
+ print("git rm --cached ERROR:")
385
+ print(e.stderr)
386
+ subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
387
+
388
+ commit_message = "Track video files with git lfs"
389
+ subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
390
+ subprocess.run(["git", "push"], cwd=work_dir, check=True)
391
+
392
+
393
+ def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
394
+ shutil.copyfile(clean_gittatributes, current_gittatributes)
395
+ subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
396
+ subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
397
+ subprocess.run(["git", "push"], cwd=work_dir, check=True)
398
+
399
+
400
+ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
401
+ subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
402
+ repo_url = f"https://huggingface.co/datasets/{repo_id}"
403
+ env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
404
+ subprocess.run(
405
+ ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
406
+ check=True,
407
+ env=env,
408
+ )
409
+
410
+
411
+ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
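+ # Compare against 'git lfs ls-files' to find video files that are not lfs-tracked in the local clone.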
412
+ lfs_tracked_files = subprocess.run(
413
+ ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
414
+ )
415
+ lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
416
+ return [f for f in video_files if f not in lfs_tracked_files]
417
+
418
+
419
+ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
420
+ # Assumes first episode
421
+ video_files = [
422
+ DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
423
+ for vid_key in video_keys
424
+ ]
425
+ hub_api = HfApi()
426
+ hub_api.snapshot_download(
427
+ repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
428
+ )
429
+ videos_info_dict = {}
430
+ for vid_key, vid_path in zip(video_keys, video_files, strict=True):
431
+ videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
432
+
433
+ return videos_info_dict
434
+
435
+
436
+ def convert_dataset(
437
+ repo_id: str,
438
+ local_dir: Path,
439
+ single_task: str | None = None,
440
+ tasks_path: Path | None = None,
441
+ tasks_col: Path | None = None,
442
+ robot_config: RobotConfig | None = None,
443
+ test_branch: str | None = None,
444
+ **card_kwargs,
445
+ ):
446
+ v1 = get_safe_version(repo_id, V16)
447
+ v1x_dir = local_dir / V16 / repo_id
448
+ v20_dir = local_dir / V20 / repo_id
449
+ v1x_dir.mkdir(parents=True, exist_ok=True)
450
+ v20_dir.mkdir(parents=True, exist_ok=True)
451
+
452
+ hub_api = HfApi()
453
+ hub_api.snapshot_download(
454
+ repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
455
+ )
456
+ branch = "main"
457
+ if test_branch:
458
+ branch = test_branch
459
+ create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
460
+
461
+ metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
462
+ dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
463
+ features = get_features_from_hf_dataset(dataset, robot_config)
464
+ video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
465
+
466
+ if single_task and "language_instruction" in dataset.column_names:
467
+ logging.warning(
468
+ "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
469
+ )
470
+ single_task = None
471
+ tasks_col = "language_instruction"
472
+
473
+ # Episodes & chunks
474
+ episode_indices = sorted(dataset.unique("episode_index"))
475
+ total_episodes = len(episode_indices)
476
+ assert episode_indices == list(range(total_episodes))
477
+ total_videos = total_episodes * len(video_keys)
478
+ total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
479
+ if total_episodes % DEFAULT_CHUNK_SIZE != 0:
480
+ total_chunks += 1
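+ # i.e. ceil(total_episodes / DEFAULT_CHUNK_SIZE): episodes are grouped into fixed-size chunks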
481
+
482
+ # Tasks
483
+ if single_task:
484
+ tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
485
+ dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
486
+ tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
487
+ elif tasks_path:
488
+ tasks_by_episodes = load_json(tasks_path)
489
+ tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
490
+ dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
491
+ tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
492
+ elif tasks_col:
493
+ dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
494
+ else:
495
+ raise ValueError("One of 'single_task', 'tasks_path' or 'tasks_col' must be provided.")
496
+
497
+ assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
498
+ tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
499
+ write_jsonlines(tasks, v20_dir / TASKS_PATH)
500
+ features["task_index"] = {
501
+ "dtype": "int64",
502
+ "shape": (1,),
503
+ "names": None,
504
+ }
505
+
506
+ # Videos
507
+ if video_keys:
508
+ assert metadata_v1.get("video", False)
509
+ dataset = dataset.remove_columns(video_keys)
510
+ clean_gitattr = Path(
511
+ hub_api.hf_hub_download(
512
+ repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
513
+ )
514
+ ).absolute()
515
+ with tempfile.TemporaryDirectory() as tmp_video_dir:
516
+ move_videos(
517
+ repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
518
+ )
519
+ videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
520
+ for key in video_keys:
521
+ features[key]["shape"] = (
522
+ videos_info[key].pop("video.height"),
523
+ videos_info[key].pop("video.width"),
524
+ videos_info[key].pop("video.channels"),
525
+ )
526
+ features[key]["video_info"] = videos_info[key]
527
+ assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
528
+ if "encoding" in metadata_v1:
529
+ assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
530
+ else:
531
+ assert metadata_v1.get("video", 0) == 0
532
+ videos_info = None
533
+
534
+ # Split data into 1 parquet file by episode
535
+ episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
536
+
537
+ if robot_config is not None:
538
+ robot_type = robot_config.type
539
+ repo_tags = [robot_type]
540
+ else:
541
+ robot_type = "unknown"
542
+ repo_tags = None
543
+
544
+ # Episodes
545
+ episodes = [
546
+ {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
547
+ for ep_idx in episode_indices
548
+ ]
549
+ write_jsonlines(episodes, v20_dir / EPISODES_PATH)
550
+
551
+ # Assemble metadata v2.0
552
+ metadata_v2_0 = {
553
+ "codebase_version": V20,
554
+ "robot_type": robot_type,
555
+ "total_episodes": total_episodes,
556
+ "total_frames": len(dataset),
557
+ "total_tasks": len(tasks),
558
+ "total_videos": total_videos,
559
+ "total_chunks": total_chunks,
560
+ "chunks_size": DEFAULT_CHUNK_SIZE,
561
+ "fps": metadata_v1["fps"],
562
+ "splits": {"train": f"0:{total_episodes}"},
563
+ "data_path": DEFAULT_PARQUET_PATH,
564
+ "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
565
+ "features": features,
566
+ }
567
+ write_json(metadata_v2_0, v20_dir / INFO_PATH)
568
+ convert_stats_to_json(v1x_dir, v20_dir)
569
+ card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
570
+
571
+ with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
572
+ hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
573
+
574
+ with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
575
+ hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
576
+
577
+ with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
578
+ hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
579
+
580
+ hub_api.upload_folder(
581
+ repo_id=repo_id,
582
+ path_in_repo="data",
583
+ folder_path=v20_dir / "data",
584
+ repo_type="dataset",
585
+ revision=branch,
586
+ )
587
+ hub_api.upload_folder(
588
+ repo_id=repo_id,
589
+ path_in_repo="meta",
590
+ folder_path=v20_dir / "meta",
591
+ repo_type="dataset",
592
+ revision=branch,
593
+ )
594
+
595
+ card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
596
+
597
+ if not test_branch:
598
+ create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
599
+
600
+
601
+ def main():
602
+ parser = argparse.ArgumentParser()
603
+ task_args = parser.add_mutually_exclusive_group(required=True)
604
+
605
+ parser.add_argument(
606
+ "--repo-id",
607
+ type=str,
608
+ required=True,
609
+ help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
610
+ )
611
+ task_args.add_argument(
612
+ "--single-task",
613
+ type=str,
614
+ help="A short but accurate description of the single task performed in the dataset.",
615
+ )
616
+ task_args.add_argument(
617
+ "--tasks-col",
618
+ type=str,
619
+ help="The name of the column containing language instructions",
620
+ )
621
+ task_args.add_argument(
622
+ "--tasks-path",
623
+ type=Path,
624
+ help="The path to a .json file containing one language instruction for each episode_index",
625
+ )
626
+ parser.add_argument(
627
+ "--robot",
628
+ type=str,
629
+ default=None,
630
+ help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
631
+ )
632
+ parser.add_argument(
633
+ "--local-dir",
634
+ type=Path,
635
+ default=None,
636
+ help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
637
+ )
638
+ parser.add_argument(
639
+ "--license",
640
+ type=str,
641
+ default="apache-2.0",
642
+ help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
643
+ )
644
+ parser.add_argument(
645
+ "--test-branch",
646
+ type=str,
647
+ default=None,
648
+ help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
649
+ )
650
+
651
+ args = parser.parse_args()
652
+ if not args.local_dir:
653
+ args.local_dir = Path("/tmp/lerobot_dataset_v2")
654
+
655
+ robot_config = make_robot_config(args.robot) if args.robot is not None else None
656
+ # default to None so convert_dataset() still works when --robot is not provided
657
+
658
+ del args.robot
659
+
660
+ convert_dataset(**vars(args), robot_config=robot_config)
661
+
662
+
663
+ if __name__ == "__main__":
664
+ main()
lerobot/common/datasets/v21/_remove_language_instruction.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import traceback
17
+ from pathlib import Path
18
+
19
+ from datasets import get_dataset_config_info
20
+ from huggingface_hub import HfApi
21
+
22
+ from lerobot import available_datasets
23
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
24
+ from lerobot.common.datasets.utils import INFO_PATH, write_info
25
+ from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
26
+
27
+ LOCAL_DIR = Path("data/")
28
+
29
+ hub_api = HfApi()
30
+
31
+
32
+ def fix_dataset(repo_id: str) -> str:
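+ # Remove a stale 'language_instruction' entry from the dataset's info.json and open a PR with the fix.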
33
+ if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
34
+ return f"{repo_id}: skipped (not in {V20})."
35
+
36
+ dataset_info = get_dataset_config_info(repo_id, "default")
37
+ with SuppressWarnings():
38
+ lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
39
+
40
+ meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
41
+ parquet_features = set(dataset_info.features)
42
+
43
+ diff_parquet_meta = parquet_features - meta_features
44
+ diff_meta_parquet = meta_features - parquet_features
45
+
46
+ if diff_parquet_meta:
47
+ raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
48
+
49
+ if not diff_meta_parquet:
50
+ return f"{repo_id}: skipped (no diff)"
51
+
52
+ if diff_meta_parquet:
53
+ logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
54
+ assert diff_meta_parquet == {"language_instruction"}
55
+ lerobot_metadata.features.pop("language_instruction")
56
+ write_info(lerobot_metadata.info, lerobot_metadata.root)
57
+ commit_info = hub_api.upload_file(
58
+ path_or_fileobj=lerobot_metadata.root / INFO_PATH,
59
+ path_in_repo=INFO_PATH,
60
+ repo_id=repo_id,
61
+ repo_type="dataset",
62
+ revision=V20,
63
+ commit_message="Remove 'language_instruction'",
64
+ create_pr=True,
65
+ )
66
+ return f"{repo_id}: success - PR: {commit_info.pr_url}"
67
+
68
+
69
+ def batch_fix():
70
+ status = {}
71
+ LOCAL_DIR.mkdir(parents=True, exist_ok=True)
72
+ logfile = LOCAL_DIR / "fix_features_v20.txt"
73
+ for num, repo_id in enumerate(available_datasets):
74
+ print(f"\nFixing {repo_id} ({num}/{len(available_datasets)})")
75
+ print("---------------------------------------------------------")
76
+ try:
77
+ status = fix_dataset(repo_id)
78
+ except Exception:
79
+ status = f"{repo_id}: failed\n {traceback.format_exc()}"
80
+
81
+ logging.info(status)
82
+ with open(logfile, "a") as file:
83
+ file.write(status + "\n")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ batch_fix()
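The heart of `fix_dataset` above is a plain set comparison between the features declared in `info.json` and the columns actually stored in the parquet files. A minimal, self-contained sketch of that check (the feature names below are hypothetical placeholders, not taken from a real dataset):

```python
# Sketch of the feature-consistency check performed by fix_dataset.
meta_features = {"action", "observation.state", "language_instruction"}
parquet_features = {"action", "observation.state"}

in_parquet_not_in_meta = parquet_features - meta_features  # must be empty
in_meta_not_in_parquet = meta_features - parquet_features  # expected: {"language_instruction"}

if in_parquet_not_in_meta:
    raise ValueError(f"In parquet not in info.json: {in_parquet_not_in_meta}")
if in_meta_not_in_parquet == {"language_instruction"}:
    # Only the stale key is present: safe to drop it from info.json and open a PR.
    print("Dropping 'language_instruction' from the metadata.")
```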
lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py ADDED
@@ -0,0 +1,54 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
19
+ """
20
+
21
+ import traceback
22
+ from pathlib import Path
23
+
24
+ from huggingface_hub import HfApi
25
+
26
+ from lerobot import available_datasets
27
+ from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
28
+
29
+ LOCAL_DIR = Path("data/")
30
+
31
+
32
+ def batch_convert():
33
+ status = {}
34
+ LOCAL_DIR.mkdir(parents=True, exist_ok=True)
35
+ logfile = LOCAL_DIR / "conversion_log_v21.txt"
36
+ hub_api = HfApi()
37
+ for num, repo_id in enumerate(available_datasets):
38
+ print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
39
+ print("---------------------------------------------------------")
40
+ try:
41
+ if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
42
+ status = f"{repo_id}: success (already in {V21})."
43
+ else:
44
+ convert_dataset(repo_id)
45
+ status = f"{repo_id}: success."
46
+ except Exception:
47
+ status = f"{repo_id}: failed\n {traceback.format_exc()}"
48
+
49
+ with open(logfile, "a") as file:
50
+ file.write(status + "\n")
51
+
52
+
53
+ if __name__ == "__main__":
54
+ batch_convert()
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
17
+ 2.1. It will:
18
+
19
+ - Generate per-episode stats and write them to `episodes_stats.jsonl`.
20
+ - Check consistency between these new stats and the old ones.
21
+ - Remove the deprecated `stats.json`.
22
+ - Update codebase_version in `info.json`.
23
+ - Push this new version to the hub on the 'main' branch and tag it with "v2.1".
24
+
25
+ Usage:
26
+
27
+ ```bash
28
+ python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
29
+ --repo-id=aliberts/koch_tutorial
30
+ ```
31
+
32
+ """
33
+
34
+ import argparse
35
+ import logging
36
+
37
+ from huggingface_hub import HfApi
38
+
39
+ from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
40
+ from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
41
+ from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
42
+
43
+ V20 = "v2.0"
44
+ V21 = "v2.1"
45
+
46
+
47
+ class SuppressWarnings:
48
+ def __enter__(self):
49
+ self.previous_level = logging.getLogger().getEffectiveLevel()
50
+ logging.getLogger().setLevel(logging.ERROR)
51
+
52
+ def __exit__(self, exc_type, exc_val, exc_tb):
53
+ logging.getLogger().setLevel(self.previous_level)
54
+
55
+
56
+ def convert_dataset(
57
+ repo_id: str,
58
+ branch: str | None = None,
59
+ num_workers: int = 4,
60
+ ):
61
+ with SuppressWarnings():
62
+ dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
63
+
64
+ if (dataset.root / EPISODES_STATS_PATH).is_file():
65
+ (dataset.root / EPISODES_STATS_PATH).unlink()
66
+
67
+ convert_stats(dataset, num_workers=num_workers)
68
+ ref_stats = load_stats(dataset.root)
69
+ check_aggregate_stats(dataset, ref_stats)
70
+
71
+ dataset.meta.info["codebase_version"] = CODEBASE_VERSION
72
+ write_info(dataset.meta.info, dataset.root)
73
+
74
+ dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
75
+
76
+ # delete old stats.json file
77
+ if (dataset.root / STATS_PATH).is_file():
78
+ (dataset.root / STATS_PATH).unlink()
79
+
80
+ hub_api = HfApi()
81
+ if hub_api.file_exists(
82
+ repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
83
+ ):
84
+ hub_api.delete_file(
85
+ path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
86
+ )
87
+
88
+ hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
89
+
90
+
91
+ if __name__ == "__main__":
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument(
94
+ "--repo-id",
95
+ type=str,
96
+ required=True,
97
+ help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
98
+ "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
99
+ )
100
+ parser.add_argument(
101
+ "--branch",
102
+ type=str,
103
+ default=None,
104
+ help="Repo branch to push your dataset. Defaults to the main branch.",
105
+ )
106
+ parser.add_argument(
107
+ "--num-workers",
108
+ type=int,
109
+ default=4,
110
+ help="Number of workers for parallelizing stats compute. Defaults to 4.",
111
+ )
112
+
113
+ args = parser.parse_args()
114
+ convert_dataset(**vars(args))
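Since the `__main__` block only forwards the parsed CLI arguments to `convert_dataset`, the same conversion can be driven from Python. A hedged usage sketch (the repo id below is just the example id used in the CLI help; replace it with your own):

```python
# Hypothetical programmatic use of the v2.0 -> v2.1 converter defined above.
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

convert_dataset(
    repo_id="lerobot/pusht",  # dataset to upgrade from v2.0 to v2.1
    branch=None,              # None pushes to the main branch
    num_workers=8,            # threads used to parallelize per-episode stats
)
```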
lerobot/common/datasets/v21/convert_stats.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed
16
+
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+
20
+ from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
21
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
22
+ from lerobot.common.datasets.utils import write_episode_stats
23
+
24
+
25
+ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
26
+ ep_len = dataset.meta.episodes[episode_index]["length"]
27
+ sampled_indices = sample_indices(ep_len)
28
+ query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
29
+ video_frames = dataset._query_videos(query_timestamps, episode_index)
30
+ return video_frames[ft_key].numpy()
31
+
32
+
33
+ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
34
+ ep_start_idx = dataset.episode_data_index["from"][ep_idx]
35
+ ep_end_idx = dataset.episode_data_index["to"][ep_idx]
36
+ ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
37
+
38
+ ep_stats = {}
39
+ for key, ft in dataset.features.items():
40
+ if ft["dtype"] == "video":
41
+ # We sample only for videos
42
+ ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
43
+ else:
44
+ ep_ft_data = np.array(ep_data[key])
45
+
46
+ axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
47
+ keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
48
+ ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
49
+
50
+ if ft["dtype"] in ["image", "video"]: # remove batch dim
51
+ ep_stats[key] = {
52
+ k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
53
+ }
54
+
55
+ dataset.meta.episodes_stats[ep_idx] = ep_stats
56
+
57
+
58
+ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
59
+ assert dataset.episodes is None
60
+ print("Computing episodes stats")
61
+ total_episodes = dataset.meta.total_episodes
62
+ if num_workers > 0:
63
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
64
+ futures = {
65
+ executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
66
+ for ep_idx in range(total_episodes)
67
+ }
68
+ for future in tqdm(as_completed(futures), total=total_episodes):
69
+ future.result()
70
+ else:
71
+ for ep_idx in tqdm(range(total_episodes)):
72
+ convert_episode_stats(dataset, ep_idx)
73
+
74
+ for ep_idx in tqdm(range(total_episodes)):
75
+ write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
76
+
77
+
78
+ def check_aggregate_stats(
79
+ dataset: LeRobotDataset,
80
+ reference_stats: dict[str, dict[str, np.ndarray]],
81
+ video_rtol_atol: tuple[float] = (1e-2, 1e-2),
82
+ default_rtol_atol: tuple[float] = (5e-6, 6e-5),
83
+ ):
84
+ """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
85
+ agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
86
+ for key, ft in dataset.features.items():
87
+ # These values might need some fine-tuning
88
+ if ft["dtype"] == "video":
89
+ # to account for image sub-sampling
90
+ rtol, atol = video_rtol_atol
91
+ else:
92
+ rtol, atol = default_rtol_atol
93
+
94
+ for stat, val in agg_stats[key].items():
95
+ if key in reference_stats and stat in reference_stats[key]:
96
+ err_msg = f"feature='{key}' stats='{stat}'"
97
+ np.testing.assert_allclose(
98
+ val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
99
+ )
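`convert_episode_stats` reduces image/video features over the batch and spatial axes (keeping per-channel statistics) and other features over the batch axis only. A NumPy-only sketch of those reductions on synthetic data (lerobot's `get_feature_stats` also returns min/max/std/count, omitted here):

```python
# Synthetic illustration of the axis/keepdims choices used in convert_episode_stats.
import numpy as np

frames = np.random.rand(50, 3, 96, 96)  # sampled video frames: (N, C, H, W)
state = np.random.rand(200, 14)         # a 2D feature over one episode: (N, D)

# image/video: reduce over batch + spatial dims with keepdims, then drop the batch dim
img_mean = frames.mean(axis=(0, 2, 3), keepdims=True)  # (1, 3, 1, 1)
img_mean = np.squeeze(img_mean, axis=0)                # (3, 1, 1), per-channel stats

# other features: reduce over the batch axis; keepdims only applies when ndim == 1
state_mean = state.mean(axis=0)                        # (14,)

print(img_mean.shape, state_mean.shape)
```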
lerobot/common/datasets/video_utils.py ADDED
@@ -0,0 +1,432 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import importlib
17
+ import json
18
+ import logging
19
+ import subprocess
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass, field
23
+ from pathlib import Path
24
+ from typing import Any, ClassVar
25
+
26
+ import pyarrow as pa
27
+ import torch
28
+ import torchvision
29
+ from datasets.features.features import register_feature
30
+ from PIL import Image
31
+
32
+
33
+ def get_safe_default_codec():
34
+ if importlib.util.find_spec("torchcodec"):
35
+ return "torchcodec"
36
+ else:
37
+ logging.warning(
38
+ "'torchcodec' is not available on your platform, falling back to 'pyav' as the default decoder"
39
+ )
40
+ return "pyav"
41
+
42
+
43
+ def decode_video_frames(
44
+ video_path: Path | str,
45
+ timestamps: list[float],
46
+ tolerance_s: float,
47
+ backend: str | None = None,
48
+ ) -> torch.Tensor:
49
+ """
50
+ Decodes video frames using the specified backend.
51
+
52
+ Args:
53
+ video_path (Path): Path to the video file.
54
+ timestamps (list[float]): List of timestamps to extract frames.
55
+ tolerance_s (float): Allowed deviation in seconds for frame retrieval.
56
+ backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available on the platform; otherwise, defaults to "pyav".
57
+
58
+ Returns:
59
+ torch.Tensor: Decoded frames.
60
+
61
+ Currently supports torchcodec on cpu and pyav.
62
+ """
63
+ if backend is None:
64
+ backend = get_safe_default_codec()
65
+ if backend == "torchcodec":
66
+ return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
67
+ elif backend in ["pyav", "video_reader"]:
68
+ return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
69
+ else:
70
+ raise ValueError(f"Unsupported video backend: {backend}")
71
+
72
+
73
+ def decode_video_frames_torchvision(
74
+ video_path: Path | str,
75
+ timestamps: list[float],
76
+ tolerance_s: float,
77
+ backend: str = "pyav",
78
+ log_loaded_timestamps: bool = False,
79
+ ) -> torch.Tensor:
80
+ """Loads frames associated to the requested timestamps of a video
81
+
82
+ The backend can be either "pyav" (default) or "video_reader".
83
+ "video_reader" requires installing torchvision from source, see:
84
+ https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
85
+ (note that you need to compile against ffmpeg<4.3)
86
+
87
+ While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
88
+ For more info on video decoding, see `benchmark/video/README.md`
89
+
90
+ See torchvision doc for more info on these two backends:
91
+ https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
92
+
93
+ Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
94
+ the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
95
+ that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
96
+ and all subsequent frames until reaching the requested frame. The number of key frames in a video
97
+ can be adjusted during encoding to take into account decoding time and video size in bytes.
98
+ """
99
+ video_path = str(video_path)
100
+
101
+ # set backend
102
+ keyframes_only = False
103
+ torchvision.set_video_backend(backend)
104
+ if backend == "pyav":
105
+ keyframes_only = True # pyav doesn't support accurate seek
106
+
107
+ # set a video stream reader
108
+ # TODO(rcadene): also load audio stream at the same time
109
+ reader = torchvision.io.VideoReader(video_path, "video")
110
+
111
+ # set the first and last requested timestamps
112
+ # Note: previous timestamps are usually loaded, since we need to access the previous key frame
113
+ first_ts = min(timestamps)
114
+ last_ts = max(timestamps)
115
+
116
+ # access closest key frame of the first requested frame
117
+ # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
118
+ # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
119
+ reader.seek(first_ts, keyframes_only=keyframes_only)
120
+
121
+ # load all frames until last requested frame
122
+ loaded_frames = []
123
+ loaded_ts = []
124
+ for frame in reader:
125
+ current_ts = frame["pts"]
126
+ if log_loaded_timestamps:
127
+ logging.info(f"frame loaded at timestamp={current_ts:.4f}")
128
+ loaded_frames.append(frame["data"])
129
+ loaded_ts.append(current_ts)
130
+ if current_ts >= last_ts:
131
+ break
132
+
133
+ if backend == "pyav":
134
+ reader.container.close()
135
+
136
+ reader = None
137
+
138
+ query_ts = torch.tensor(timestamps)
139
+ loaded_ts = torch.tensor(loaded_ts)
140
+
141
+ # compute distances between each query timestamp and timestamps of all loaded frames
142
+ dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
143
+ min_, argmin_ = dist.min(1)
144
+
145
+ is_within_tol = min_ < tolerance_s
146
+ assert is_within_tol.all(), (
147
+ f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
148
+ "It means that the closest frame that can be loaded from the video is too far away in time."
149
+ "This might be due to synchronization issues with timestamps during data collection."
150
+ "To be safe, we advise to ignore this item during training."
151
+ f"\nqueried timestamps: {query_ts}"
152
+ f"\nloaded timestamps: {loaded_ts}"
153
+ f"\nvideo: {video_path}"
154
+ f"\nbackend: {backend}"
155
+ )
156
+
157
+ # get closest frames to the query timestamps
158
+ closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
159
+ closest_ts = loaded_ts[argmin_]
160
+
161
+ if log_loaded_timestamps:
162
+ logging.info(f"{closest_ts=}")
163
+
164
+ # convert to the pytorch format which is float32 in [0,1] range (and channel first)
165
+ closest_frames = closest_frames.type(torch.float32) / 255
166
+
167
+ assert len(timestamps) == len(closest_frames)
168
+ return closest_frames
169
+
170
+
171
+ def decode_video_frames_torchcodec(
172
+ video_path: Path | str,
173
+ timestamps: list[float],
174
+ tolerance_s: float,
175
+ device: str = "cpu",
176
+ log_loaded_timestamps: bool = False,
177
+ ) -> torch.Tensor:
178
+ """Loads frames associated with the requested timestamps of a video using torchcodec.
179
+
180
+ Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
181
+
182
+ Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
183
+ the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
184
+ that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
185
+ and all subsequent frames until reaching the requested frame. The number of key frames in a video
186
+ can be adjusted during encoding to take into account decoding time and video size in bytes.
187
+ """
188
+
189
+ if importlib.util.find_spec("torchcodec"):
190
+ from torchcodec.decoders import VideoDecoder
191
+ else:
192
+ raise ImportError("torchcodec is required but not available.")
193
+
194
+ # initialize video decoder
195
+ decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
196
+ loaded_frames = []
197
+ loaded_ts = []
198
+ # get metadata for frame information
199
+ metadata = decoder.metadata
200
+ average_fps = metadata.average_fps
201
+
202
+ # convert timestamps to frame indices
203
+ frame_indices = [round(ts * average_fps) for ts in timestamps]
204
+
205
+ # retrieve frames based on indices
206
+ frames_batch = decoder.get_frames_at(indices=frame_indices)
207
+
208
+ for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
209
+ loaded_frames.append(frame)
210
+ loaded_ts.append(pts.item())
211
+ if log_loaded_timestamps:
212
+ logging.info(f"Frame loaded at timestamp={pts:.4f}")
213
+
214
+ query_ts = torch.tensor(timestamps)
215
+ loaded_ts = torch.tensor(loaded_ts)
216
+
217
+ # compute distances between each query timestamp and loaded timestamps
218
+ dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
219
+ min_, argmin_ = dist.min(1)
220
+
221
+ is_within_tol = min_ < tolerance_s
222
+ assert is_within_tol.all(), (
223
+ f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
224
+ "It means that the closest frame that can be loaded from the video is too far away in time."
225
+ "This might be due to synchronization issues with timestamps during data collection."
226
+ "To be safe, we advise to ignore this item during training."
227
+ f"\nqueried timestamps: {query_ts}"
228
+ f"\nloaded timestamps: {loaded_ts}"
229
+ f"\nvideo: {video_path}"
230
+ )
231
+
232
+ # get closest frames to the query timestamps
233
+ closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
234
+ closest_ts = loaded_ts[argmin_]
235
+
236
+ if log_loaded_timestamps:
237
+ logging.info(f"{closest_ts=}")
238
+
239
+ # convert to float32 in [0,1] range (channel first)
240
+ closest_frames = closest_frames.type(torch.float32) / 255
241
+
242
+ assert len(timestamps) == len(closest_frames)
243
+ return closest_frames
244
+
245
+
246
+ def encode_video_frames(
247
+ imgs_dir: Path | str,
248
+ video_path: Path | str,
249
+ fps: int,
250
+ vcodec: str = "libsvtav1",
251
+ pix_fmt: str = "yuv420p",
252
+ g: int | None = 2,
253
+ crf: int | None = 30,
254
+ fast_decode: int = 0,
255
+ log_level: str | None = "error",
256
+ overwrite: bool = False,
257
+ ) -> None:
258
+ """More info on tuning the ffmpeg arguments in `benchmark/video/README.md`"""
259
+ video_path = Path(video_path)
260
+ imgs_dir = Path(imgs_dir)
261
+ video_path.parent.mkdir(parents=True, exist_ok=True)
262
+
263
+ ffmpeg_args = OrderedDict(
264
+ [
265
+ ("-f", "image2"),
266
+ ("-r", str(fps)),
267
+ ("-i", str(imgs_dir / "frame_%06d.png")),
268
+ ("-vcodec", vcodec),
269
+ ("-pix_fmt", pix_fmt),
270
+ ]
271
+ )
272
+
273
+ if g is not None:
274
+ ffmpeg_args["-g"] = str(g)
275
+
276
+ if crf is not None:
277
+ ffmpeg_args["-crf"] = str(crf)
278
+
279
+ if fast_decode:
280
+ key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
281
+ value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
282
+ ffmpeg_args[key] = value
283
+
284
+ if log_level is not None:
285
+ ffmpeg_args["-loglevel"] = str(log_level)
286
+
287
+ ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
288
+ if overwrite:
289
+ ffmpeg_args.append("-y")
290
+
291
+ ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
292
+ # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
293
+ subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
294
+
295
+ if not video_path.exists():
296
+ raise OSError(
297
+ f"Video encoding did not work. File not found: {video_path}. "
298
+ f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
299
+ )
300
+
301
+
302
+ @dataclass
303
+ class VideoFrame:
304
+ # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
305
+ """
306
+ Provides a type for a dataset containing video frames.
307
+
308
+ Example:
309
+
310
+ ```python
311
+ data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
312
+ features = {"image": VideoFrame()}
313
+ Dataset.from_dict(data_dict, features=Features(features))
314
+ ```
315
+ """
316
+
317
+ pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
318
+ _type: str = field(default="VideoFrame", init=False, repr=False)
319
+
320
+ def __call__(self):
321
+ return self.pa_type
322
+
323
+
324
+ with warnings.catch_warnings():
325
+ warnings.filterwarnings(
326
+ "ignore",
327
+ "'register_feature' is experimental and might be subject to breaking changes in the future.",
328
+ category=UserWarning,
329
+ )
330
+ # to make VideoFrame available in HuggingFace `datasets`
331
+ register_feature(VideoFrame, "VideoFrame")
332
+
333
+
334
+ def get_audio_info(video_path: Path | str) -> dict:
335
+ ffprobe_audio_cmd = [
336
+ "ffprobe",
337
+ "-v",
338
+ "error",
339
+ "-select_streams",
340
+ "a:0",
341
+ "-show_entries",
342
+ "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
343
+ "-of",
344
+ "json",
345
+ str(video_path),
346
+ ]
347
+ result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
348
+ if result.returncode != 0:
349
+ raise RuntimeError(f"Error running ffprobe: {result.stderr}")
350
+
351
+ info = json.loads(result.stdout)
352
+ audio_stream_info = info["streams"][0] if info.get("streams") else None
353
+ if audio_stream_info is None:
354
+ return {"has_audio": False}
355
+
356
+ # Return the information, defaulting to None if no audio stream is present
357
+ return {
358
+ "has_audio": True,
359
+ "audio.channels": audio_stream_info.get("channels", None),
360
+ "audio.codec": audio_stream_info.get("codec_name", None),
361
+ "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
362
+ "audio.sample_rate": int(audio_stream_info["sample_rate"])
363
+ if audio_stream_info.get("sample_rate")
364
+ else None,
365
+ "audio.bit_depth": audio_stream_info.get("bit_depth", None),
366
+ "audio.channel_layout": audio_stream_info.get("channel_layout", None),
367
+ }
368
+
369
+
370
+ def get_video_info(video_path: Path | str) -> dict:
371
+ ffprobe_video_cmd = [
372
+ "ffprobe",
373
+ "-v",
374
+ "error",
375
+ "-select_streams",
376
+ "v:0",
377
+ "-show_entries",
378
+ "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
379
+ "-of",
380
+ "json",
381
+ str(video_path),
382
+ ]
383
+ result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
384
+ if result.returncode != 0:
385
+ raise RuntimeError(f"Error running ffprobe: {result.stderr}")
386
+
387
+ info = json.loads(result.stdout)
388
+ video_stream_info = info["streams"][0]
389
+
390
+ # Calculate fps from r_frame_rate
391
+ r_frame_rate = video_stream_info["r_frame_rate"]
392
+ num, denom = map(int, r_frame_rate.split("/"))
393
+ fps = num / denom
394
+
395
+ pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
396
+
397
+ video_info = {
398
+ "video.fps": fps,
399
+ "video.height": video_stream_info["height"],
400
+ "video.width": video_stream_info["width"],
401
+ "video.channels": pixel_channels,
402
+ "video.codec": video_stream_info["codec_name"],
403
+ "video.pix_fmt": video_stream_info["pix_fmt"],
404
+ "video.is_depth_map": False,
405
+ **get_audio_info(video_path),
406
+ }
407
+
408
+ return video_info
409
+
410
+
411
+ def get_video_pixel_channels(pix_fmt: str) -> int:
412
+ if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
413
+ return 1
414
+ elif "rgba" in pix_fmt or "yuva" in pix_fmt:
415
+ return 4
416
+ elif "rgb" in pix_fmt or "yuv" in pix_fmt:
417
+ return 3
418
+ else:
419
+ raise ValueError("Unknown format")
420
+
421
+
422
+ def get_image_pixel_channels(image: Image):
423
+ if image.mode == "L":
424
+ return 1 # Grayscale
425
+ elif image.mode == "LA":
426
+ return 2 # Grayscale + Alpha
427
+ elif image.mode == "RGB":
428
+ return 3 # RGB
429
+ elif image.mode == "RGBA":
430
+ return 4 # RGBA
431
+ else:
432
+ raise ValueError("Unknown format")
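Both decoders above resolve the requested timestamps to the nearest decoded frame with an L1 `torch.cdist` between query and loaded timestamps, then check the tolerance. A minimal sketch with synthetic timestamps (the tolerance value is arbitrary):

```python
# Nearest-frame matching as used by decode_video_frames_* above, on fake timestamps.
import torch

query_ts = torch.tensor([0.21, 0.39])               # timestamps requested by the sampler
loaded_ts = torch.tensor([0.00, 0.20, 0.40, 0.60])  # timestamps of frames actually decoded
tolerance_s = 0.05

dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)  # pairwise |t_query - t_loaded|
min_, argmin_ = dist.min(1)

assert (min_ < tolerance_s).all(), "a requested timestamp has no decoded frame within tolerance"
print(argmin_.tolist())  # indices of the closest decoded frames, here [1, 2]
```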
lerobot/common/envs/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
lerobot/common/envs/configs.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from dataclasses import dataclass, field
17
+
18
+ import draccus
19
+
20
+ from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
21
+ from lerobot.configs.types import FeatureType, PolicyFeature
22
+
23
+
24
+ @dataclass
25
+ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
26
+ task: str | None = None
27
+ fps: int = 30
28
+ features: dict[str, PolicyFeature] = field(default_factory=dict)
29
+ features_map: dict[str, str] = field(default_factory=dict)
30
+
31
+ @property
32
+ def type(self) -> str:
33
+ return self.get_choice_name(self.__class__)
34
+
35
+ @abc.abstractproperty
36
+ def gym_kwargs(self) -> dict:
37
+ raise NotImplementedError()
38
+
39
+
40
+ @EnvConfig.register_subclass("aloha")
41
+ @dataclass
42
+ class AlohaEnv(EnvConfig):
43
+ task: str = "AlohaInsertion-v0"
44
+ fps: int = 50
45
+ episode_length: int = 400
46
+ obs_type: str = "pixels_agent_pos"
47
+ render_mode: str = "rgb_array"
48
+ features: dict[str, PolicyFeature] = field(
49
+ default_factory=lambda: {
50
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
51
+ }
52
+ )
53
+ features_map: dict[str, str] = field(
54
+ default_factory=lambda: {
55
+ "action": ACTION,
56
+ "agent_pos": OBS_ROBOT,
57
+ "top": f"{OBS_IMAGE}.top",
58
+ "pixels/top": f"{OBS_IMAGES}.top",
59
+ }
60
+ )
61
+
62
+ def __post_init__(self):
63
+ if self.obs_type == "pixels":
64
+ self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
65
+ elif self.obs_type == "pixels_agent_pos":
66
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
67
+ self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
68
+
69
+ @property
70
+ def gym_kwargs(self) -> dict:
71
+ return {
72
+ "obs_type": self.obs_type,
73
+ "render_mode": self.render_mode,
74
+ "max_episode_steps": self.episode_length,
75
+ }
76
+
77
+
78
+ @EnvConfig.register_subclass("pusht")
79
+ @dataclass
80
+ class PushtEnv(EnvConfig):
81
+ task: str = "PushT-v0"
82
+ fps: int = 10
83
+ episode_length: int = 300
84
+ obs_type: str = "pixels_agent_pos"
85
+ render_mode: str = "rgb_array"
86
+ visualization_width: int = 384
87
+ visualization_height: int = 384
88
+ features: dict[str, PolicyFeature] = field(
89
+ default_factory=lambda: {
90
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
91
+ "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
92
+ }
93
+ )
94
+ features_map: dict[str, str] = field(
95
+ default_factory=lambda: {
96
+ "action": ACTION,
97
+ "agent_pos": OBS_ROBOT,
98
+ "environment_state": OBS_ENV,
99
+ "pixels": OBS_IMAGE,
100
+ }
101
+ )
102
+
103
+ def __post_init__(self):
104
+ if self.obs_type == "pixels_agent_pos":
105
+ self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
106
+ elif self.obs_type == "environment_state_agent_pos":
107
+ self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
108
+
109
+ @property
110
+ def gym_kwargs(self) -> dict:
111
+ return {
112
+ "obs_type": self.obs_type,
113
+ "render_mode": self.render_mode,
114
+ "visualization_width": self.visualization_width,
115
+ "visualization_height": self.visualization_height,
116
+ "max_episode_steps": self.episode_length,
117
+ }
118
+
119
+
120
+ @EnvConfig.register_subclass("xarm")
121
+ @dataclass
122
+ class XarmEnv(EnvConfig):
123
+ task: str = "XarmLift-v0"
124
+ fps: int = 15
125
+ episode_length: int = 200
126
+ obs_type: str = "pixels_agent_pos"
127
+ render_mode: str = "rgb_array"
128
+ visualization_width: int = 384
129
+ visualization_height: int = 384
130
+ features: dict[str, PolicyFeature] = field(
131
+ default_factory=lambda: {
132
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
133
+ "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
134
+ }
135
+ )
136
+ features_map: dict[str, str] = field(
137
+ default_factory=lambda: {
138
+ "action": ACTION,
139
+ "agent_pos": OBS_ROBOT,
140
+ "pixels": OBS_IMAGE,
141
+ }
142
+ )
143
+
144
+ def __post_init__(self):
145
+ if self.obs_type == "pixels_agent_pos":
146
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
147
+
148
+ @property
149
+ def gym_kwargs(self) -> dict:
150
+ return {
151
+ "obs_type": self.obs_type,
152
+ "render_mode": self.render_mode,
153
+ "visualization_width": self.visualization_width,
154
+ "visualization_height": self.visualization_height,
155
+ "max_episode_steps": self.episode_length,
156
+ }
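A short usage sketch for these config classes, assuming lerobot is importable; it only reads attributes defined above:

```python
from lerobot.common.envs.configs import PushtEnv

cfg = PushtEnv(obs_type="pixels_agent_pos")
print(cfg.type)        # "pusht", the name registered via EnvConfig.register_subclass
print(cfg.features)    # "pixels" was added by __post_init__ for this obs_type
print(cfg.gym_kwargs)  # kwargs forwarded to gym.make by the env factory
```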
lerobot/common/envs/factory.py ADDED
@@ -0,0 +1,69 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import importlib
17
+
18
+ import gymnasium as gym
19
+
20
+ from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
21
+
22
+
23
+ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
24
+ if env_type == "aloha":
25
+ return AlohaEnv(**kwargs)
26
+ elif env_type == "pusht":
27
+ return PushtEnv(**kwargs)
28
+ elif env_type == "xarm":
29
+ return XarmEnv(**kwargs)
30
+ else:
31
+ raise ValueError(f"Env type '{env_type}' is not available.")
32
+
33
+
34
+ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
35
+ """Makes a gym vector environment according to the config.
36
+
37
+ Args:
38
+ cfg (EnvConfig): the config of the environment to instantiate.
39
+ n_envs (int, optional): The number of parallelized envs to return. Defaults to 1.
40
+ use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
41
+ False.
42
+
43
+ Raises:
44
+ ValueError: if n_envs < 1
45
+ ModuleNotFoundError: If the requested env package is not installed
46
+
47
+ Returns:
48
+ gym.vector.VectorEnv: The parallelized gym.env instance.
49
+ """
50
+ if n_envs < 1:
51
+ raise ValueError("`n_envs` must be at least 1")
52
+
53
+ package_name = f"gym_{cfg.type}"
54
+
55
+ try:
56
+ importlib.import_module(package_name)
57
+ except ModuleNotFoundError as e:
58
+ print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
59
+ raise e
60
+
61
+ gym_handle = f"{package_name}/{cfg.task}"
62
+
63
+ # batched version of the env that returns an observation of shape (b, c)
64
+ env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
65
+ env = env_cls(
66
+ [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
67
+ )
68
+
69
+ return env
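A hedged usage sketch for the factory, assuming the matching `gym_pusht` package is installed (otherwise `make_env` raises `ModuleNotFoundError`, as documented above):

```python
from lerobot.common.envs.factory import make_env, make_env_config

cfg = make_env_config("pusht")
env = make_env(cfg, n_envs=2, use_async_envs=False)  # SyncVectorEnv with 2 copies
obs, info = env.reset(seed=0)
env.close()
```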
lerobot/common/envs/utils.py ADDED
@@ -0,0 +1,127 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import warnings
17
+ from typing import Any
18
+
19
+ import einops
20
+ import gymnasium as gym
21
+ import numpy as np
22
+ import torch
23
+ from torch import Tensor
24
+
25
+ from lerobot.common.envs.configs import EnvConfig
26
+ from lerobot.common.utils.utils import get_channel_first_image_shape
27
+ from lerobot.configs.types import FeatureType, PolicyFeature
28
+
29
+
30
+ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
31
+ # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
32
+ """Convert environment observation to LeRobot format observation.
33
+ Args:
34
+ observations: Dictionary of observation batches from a Gym vector environment.
35
+ Returns:
36
+ Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
37
+ """
38
+ # map to expected inputs for the policy
39
+ return_observations = {}
40
+ if "pixels" in observations:
41
+ if isinstance(observations["pixels"], dict):
42
+ imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
43
+ else:
44
+ imgs = {"observation.image": observations["pixels"]}
45
+
46
+ for imgkey, img in imgs.items():
47
+ # TODO(aliberts, rcadene): use transforms.ToTensor()?
48
+ img = torch.from_numpy(img)
49
+
50
+ # sanity check that images are channel last
51
+ _, h, w, c = img.shape
52
+ assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
53
+
54
+ # sanity check that images are uint8
55
+ assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
56
+
57
+ # convert to channel first of type float32 in range [0,1]
58
+ img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
59
+ img = img.type(torch.float32)
60
+ img /= 255
61
+
62
+ return_observations[imgkey] = img
63
+
64
+ if "environment_state" in observations:
65
+ return_observations["observation.environment_state"] = torch.from_numpy(
66
+ observations["environment_state"]
67
+ ).float()
68
+
69
+ # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
70
+ # requirement for "agent_pos"
71
+ return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
72
+ return return_observations
73
+
74
+
75
+ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
76
+ # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
77
+ # (need to also refactor preprocess_observation and externalize normalization from policies)
78
+ policy_features = {}
79
+ for key, ft in env_cfg.features.items():
80
+ if ft.type is FeatureType.VISUAL:
81
+ if len(ft.shape) != 3:
82
+ raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
83
+
84
+ shape = get_channel_first_image_shape(ft.shape)
85
+ feature = PolicyFeature(type=ft.type, shape=shape)
86
+ else:
87
+ feature = ft
88
+
89
+ policy_key = env_cfg.features_map[key]
90
+ policy_features[policy_key] = feature
91
+
92
+ return policy_features
93
+
94
+
95
+ def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
96
+ first_type = type(env.envs[0]) # Get type of first env
97
+ return all(type(e) is first_type for e in env.envs) # Fast type check
98
+
99
+
100
+ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
101
+ with warnings.catch_warnings():
102
+ warnings.simplefilter("once", UserWarning) # Apply filter only in this function
103
+
104
+ if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
105
+ warnings.warn(
106
+ "The environment does not have 'task_description' and 'task'. Some policies require these features.",
107
+ UserWarning,
108
+ stacklevel=2,
109
+ )
110
+ if not are_all_envs_same_type(env):
111
+ warnings.warn(
112
+ "The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
113
+ UserWarning,
114
+ stacklevel=2,
115
+ )
116
+
117
+
118
+ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
119
+ """Adds task feature to the observation dict with respect to the first environment attribute."""
120
+ if hasattr(env.envs[0], "task_description"):
121
+ observation["task"] = env.call("task_description")
122
+ elif hasattr(env.envs[0], "task"):
123
+ observation["task"] = env.call("task")
124
+ else: # For envs without language instructions, e.g. aloha transfer cube, etc.
125
+ num_envs = observation[list(observation.keys())[0]].shape[0]
126
+ observation["task"] = ["" for _ in range(num_envs)]
127
+ return observation
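`preprocess_observation` can be exercised with a synthetic vector-env batch; a minimal sketch (the shapes are arbitrary):

```python
import numpy as np

from lerobot.common.envs.utils import preprocess_observation

batch = {
    "pixels": np.zeros((2, 480, 640, 3), dtype=np.uint8),  # (b, h, w, c), uint8, channel-last
    "agent_pos": np.zeros((2, 14), dtype=np.float32),
}
out = preprocess_observation(batch)
print(out["observation.image"].shape)  # torch.Size([2, 3, 480, 640]), float32 in [0, 1]
print(out["observation.state"].shape)  # torch.Size([2, 14])
```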
lerobot/common/mocks/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Common mocks for robot devices and testing
lerobot/common/mocks/cameras/__init__.py ADDED
File without changes
lerobot/common/mocks/cameras/mock_cv2.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import cache
15
+
16
+ import numpy as np
17
+
18
+ CAP_V4L2 = 200
19
+ CAP_DSHOW = 700
20
+ CAP_AVFOUNDATION = 1200
21
+ CAP_ANY = -1
22
+
23
+ CAP_PROP_FPS = 5
24
+ CAP_PROP_FRAME_WIDTH = 3
25
+ CAP_PROP_FRAME_HEIGHT = 4
26
+ COLOR_RGB2BGR = 4
27
+ COLOR_BGR2RGB = 4
28
+
29
+ ROTATE_90_COUNTERCLOCKWISE = 2
30
+ ROTATE_90_CLOCKWISE = 0
31
+ ROTATE_180 = 1
32
+
33
+
34
+ @cache
35
+ def _generate_image(width: int, height: int):
36
+ return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
37
+
38
+
39
+ def cvtColor(color_image, color_conversion): # noqa: N802
40
+ if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
41
+ return color_image[:, :, [2, 1, 0]]
42
+ else:
43
+ raise NotImplementedError(color_conversion)
44
+
45
+
46
+ def rotate(color_image, rotation):
47
+ if rotation is None:
48
+ return color_image
49
+ elif rotation == ROTATE_90_CLOCKWISE:
50
+ return np.rot90(color_image, k=1)
51
+ elif rotation == ROTATE_180:
52
+ return np.rot90(color_image, k=2)
53
+ elif rotation == ROTATE_90_COUNTERCLOCKWISE:
54
+ return np.rot90(color_image, k=3)
55
+ else:
56
+ raise NotImplementedError(rotation)
57
+
58
+
59
+ class VideoCapture:
60
+ def __init__(self, *args, **kwargs):
61
+ self._mock_dict = {
62
+ CAP_PROP_FPS: 30,
63
+ CAP_PROP_FRAME_WIDTH: 640,
64
+ CAP_PROP_FRAME_HEIGHT: 480,
65
+ }
66
+ self._is_opened = True
67
+
68
+ def isOpened(self): # noqa: N802
69
+ return self._is_opened
70
+
71
+ def set(self, propId: int, value: float) -> bool: # noqa: N803
72
+ if not self._is_opened:
73
+ raise RuntimeError("Camera is not opened")
74
+ self._mock_dict[propId] = value
75
+ return True
76
+
77
+ def get(self, propId: int) -> float: # noqa: N803
78
+ if not self._is_opened:
79
+ raise RuntimeError("Camera is not opened")
80
+ value = self._mock_dict[propId]
81
+ if value == 0:
82
+ if propId == CAP_PROP_FRAME_HEIGHT:
83
+ value = 480
84
+ elif propId == CAP_PROP_FRAME_WIDTH:
85
+ value = 640
86
+ return value
87
+
88
+ def read(self):
89
+ if not self._is_opened:
90
+ raise RuntimeError("Camera is not opened")
91
+ h = self.get(CAP_PROP_FRAME_HEIGHT)
92
+ w = self.get(CAP_PROP_FRAME_WIDTH)
93
+ ret = True
94
+ return ret, _generate_image(width=w, height=h)
95
+
96
+ def release(self):
97
+ self._is_opened = False
98
+
99
+ def __del__(self):
100
+ if self._is_opened:
101
+ self.release()
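A sketch of how tests can use this mock as a stand-in for `cv2` (the camera index is a placeholder):

```python
from lerobot.common.mocks.cameras import mock_cv2

cap = mock_cv2.VideoCapture(0)
cap.set(mock_cv2.CAP_PROP_FRAME_WIDTH, 320)
cap.set(mock_cv2.CAP_PROP_FRAME_HEIGHT, 240)
ret, frame = cap.read()
print(ret, frame.shape)  # True (240, 320, 3), random uint8 pixels
cap.release()
```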
lerobot/common/mocks/cameras/mock_pyrealsense2.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import enum
15
+
16
+ import numpy as np
17
+
18
+
19
+ class stream(enum.Enum): # noqa: N801
20
+ color = 0
21
+ depth = 1
22
+
23
+
24
+ class format(enum.Enum): # noqa: N801
25
+ rgb8 = 0
26
+ z16 = 1
27
+
28
+
29
+ class config: # noqa: N801
30
+ def enable_device(self, device_id: str):
31
+ self.device_enabled = device_id
32
+
33
+ def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
34
+ self.stream_type = stream_type
35
+ # Overwrite default values when possible
36
+ self.width = 848 if width is None else width
37
+ self.height = 480 if height is None else height
38
+ self.color_format = format.rgb8 if color_format is None else color_format
39
+ self.fps = 30 if fps is None else fps
40
+
41
+
42
+ class RSColorProfile:
43
+ def __init__(self, config):
44
+ self.config = config
45
+
46
+ def fps(self):
47
+ return self.config.fps
48
+
49
+ def width(self):
50
+ return self.config.width
51
+
52
+ def height(self):
53
+ return self.config.height
54
+
55
+
56
+ class RSColorStream:
57
+ def __init__(self, config):
58
+ self.config = config
59
+
60
+ def as_video_stream_profile(self):
61
+ return RSColorProfile(self.config)
62
+
63
+
64
+ class RSProfile:
65
+ def __init__(self, config):
66
+ self.config = config
67
+
68
+ def get_stream(self, color_format):
69
+ del color_format # unused
70
+ return RSColorStream(self.config)
71
+
72
+
73
+ class pipeline: # noqa: N801
74
+ def __init__(self):
75
+ self.started = False
76
+ self.config = None
77
+
78
+ def start(self, config):
79
+ self.started = True
80
+ self.config = config
81
+ return RSProfile(self.config)
82
+
83
+ def stop(self):
84
+ if not self.started:
85
+ raise RuntimeError("You need to start the camera before stop.")
86
+ self.started = False
87
+ self.config = None
88
+
89
+ def wait_for_frames(self, timeout_ms=50000):
90
+ del timeout_ms # unused
91
+ return RSFrames(self.config)
92
+
93
+
94
+ class RSFrames:
95
+ def __init__(self, config):
96
+ self.config = config
97
+
98
+ def get_color_frame(self):
99
+ return RSColorFrame(self.config)
100
+
101
+ def get_depth_frame(self):
102
+ return RSDepthFrame(self.config)
103
+
104
+
105
+ class RSColorFrame:
106
+ def __init__(self, config):
107
+ self.config = config
108
+
109
+ def get_data(self):
110
+ data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
111
+ # Create a difference between rgb and bgr
112
+ data[:, :, 0] = 2
113
+ return data
114
+
115
+
116
+ class RSDepthFrame:
117
+ def __init__(self, config):
118
+ self.config = config
119
+
120
+ def get_data(self):
121
+ return np.ones((self.config.height, self.config.width), dtype=np.uint16)
122
+
123
+
124
+ class RSDevice:
125
+ def __init__(self):
126
+ pass
127
+
128
+ def get_info(self, camera_info) -> str:
129
+ del camera_info # unused
130
+ # return fake serial number
131
+ return "123456789"
132
+
133
+
134
+ class context: # noqa: N801
135
+ def __init__(self):
136
+ pass
137
+
138
+ def query_devices(self):
139
+ return [RSDevice()]
140
+
141
+
142
+ class camera_info: # noqa: N801
143
+ # fake name
144
+ name = "Intel RealSense D435I"
145
+
146
+ def __init__(self, serial_number):
147
+ del serial_number
148
+ pass
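A sketch of driving the pyrealsense2 mock the way a RealSense camera wrapper would (the serial number is the fake one returned by `RSDevice.get_info`):

```python
from lerobot.common.mocks.cameras import mock_pyrealsense2 as rs

cfg = rs.config()
cfg.enable_device("123456789")  # fake serial number exposed by the mock RSDevice
cfg.enable_stream(rs.stream.color, width=640, height=480, color_format=rs.format.rgb8, fps=30)

pipe = rs.pipeline()
profile = pipe.start(cfg)
color = pipe.wait_for_frames().get_color_frame().get_data()
print(color.shape)  # (480, 640, 3)
pipe.stop()
```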
lerobot/common/mocks/motors/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Mocks for motor modules
lerobot/common/mocks/motors/mock_dynamixel_sdk.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Mocked classes and functions from dynamixel_sdk to allow for continuous integration
15
+ and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
16
+
17
+ Warning: These mocked versions are minimalist. They do not exactly mock every behavior
18
+ from the original classes and functions (e.g. return types might be None instead of boolean).
19
+ """
20
+
21
+ # from dynamixel_sdk import COMM_SUCCESS
22
+
23
+ DEFAULT_BAUDRATE = 9_600
24
+ COMM_SUCCESS = 0 # tx or rx packet communication success
25
+
26
+
27
+ def convert_to_bytes(value, bytes):
28
+ # TODO(rcadene): remove need to mock `convert_to_bytes` by implementing the inverse transform
29
+ # `convert_bytes_to_value`
30
+ del bytes # unused
31
+ return value
32
+
33
+
34
+ def get_default_motor_values(motor_index):
35
+ return {
36
+ # Key (int) are from X_SERIES_CONTROL_TABLE
37
+ 7: motor_index, # ID
38
+ 8: DEFAULT_BAUDRATE, # Baud_rate
39
+ 10: 0, # Drive_Mode
40
+ 64: 0, # Torque_Enable
41
+ # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
42
+ # For other joints, 2560 will be autocorrected to be in calibration range
43
+ 132: 2560, # Present_Position
44
+ }
45
+
46
+
47
+ class PortHandler:
48
+ def __init__(self, port):
49
+ self.port = port
50
+ # factory default baudrate
51
+ self.baudrate = DEFAULT_BAUDRATE
52
+
53
+ def openPort(self): # noqa: N802
54
+ return True
55
+
56
+ def closePort(self): # noqa: N802
57
+ pass
58
+
59
+ def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
60
+ del timeout_ms # unused
61
+
62
+ def getBaudRate(self): # noqa: N802
63
+ return self.baudrate
64
+
65
+ def setBaudRate(self, baudrate): # noqa: N802
66
+ self.baudrate = baudrate
67
+
68
+
69
+ class PacketHandler:
70
+ def __init__(self, protocol_version):
71
+ del protocol_version # unused
72
+ # Use packet_handler.data to communicate across Read and Write
73
+ self.data = {}
74
+
75
+
76
+ class GroupSyncRead:
77
+ def __init__(self, port_handler, packet_handler, address, bytes):
78
+ self.packet_handler = packet_handler
79
+
80
+ def addParam(self, motor_index): # noqa: N802
81
+ # Initialize motor default values
82
+ if motor_index not in self.packet_handler.data:
83
+ self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
84
+
85
+ def txRxPacket(self): # noqa: N802
86
+ return COMM_SUCCESS
87
+
88
+ def getData(self, index, address, bytes): # noqa: N802
89
+ return self.packet_handler.data[index][address]
90
+
91
+
92
+ class GroupSyncWrite:
93
+ def __init__(self, port_handler, packet_handler, address, bytes):
94
+ self.packet_handler = packet_handler
95
+ self.address = address
96
+
97
+ def addParam(self, index, data): # noqa: N802
98
+ # Initialize motor default values
99
+ if index not in self.packet_handler.data:
100
+ self.packet_handler.data[index] = get_default_motor_values(index)
101
+ self.changeParam(index, data)
102
+
103
+ def txPacket(self): # noqa: N802
104
+ return COMM_SUCCESS
105
+
106
+ def changeParam(self, index, data): # noqa: N802
107
+ self.packet_handler.data[index][self.address] = data
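The read and write mocks communicate through the shared `PacketHandler.data` dict, so a value written with `GroupSyncWrite` can be read back with `GroupSyncRead`. A sketch (the port name and register address are illustrative):

```python
from lerobot.common.mocks.motors import mock_dynamixel_sdk as dxl

port = dxl.PortHandler("/dev/ttyUSB0")            # placeholder port name
packet = dxl.PacketHandler(protocol_version=2.0)

ADDR_GOAL_POSITION = 116                          # illustrative address, 4 bytes
writer = dxl.GroupSyncWrite(port, packet, ADDR_GOAL_POSITION, 4)
reader = dxl.GroupSyncRead(port, packet, ADDR_GOAL_POSITION, 4)

writer.addParam(1, 2048)                          # motor id 1 -> goal position 2048
assert writer.txPacket() == dxl.COMM_SUCCESS

reader.addParam(1)
assert reader.txRxPacket() == dxl.COMM_SUCCESS
print(reader.getData(1, ADDR_GOAL_POSITION, 4))   # 2048, read back from the shared dict
```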
lerobot/common/mocks/motors/mock_scservo_sdk.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Mocked classes and functions from dynamixel_sdk to allow for continuous integration
15
+ and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
16
+
17
+ Warning: These mocked versions are minimalist. They do not exactly mock every behavior
18
+ of the original classes and functions (e.g. return types might be None instead of boolean).
19
+ """
20
+
21
+ # from scservo_sdk import COMM_SUCCESS
22
+
23
+ DEFAULT_BAUDRATE = 1_000_000
24
+ COMM_SUCCESS = 0 # tx or rx packet communication success
25
+
26
+
27
+ def convert_to_bytes(value, bytes):
28
+ # TODO(rcadene): remove the need to mock `convert_to_bytes` by implementing the inverse transform
29
+ # `convert_bytes_to_value`
30
+ del bytes # unused
31
+ return value
32
+
33
+
34
+ def get_default_motor_values(motor_index):
35
+ return {
36
+ # Key (int) are from SCS_SERIES_CONTROL_TABLE
37
+ 5: motor_index, # ID
38
+ 6: DEFAULT_BAUDRATE, # Baud_rate
39
+ 10: 0, # Drive_Mode
40
+ 21: 32, # P_Coefficient
41
+ 22: 32, # D_Coefficient
42
+ 23: 0, # I_Coefficient
43
+ 40: 0, # Torque_Enable
44
+ 41: 254, # Acceleration
45
+ 31: -2047, # Offset
46
+ 33: 0, # Mode
47
+ 55: 1, # Lock
48
+ # Set 2560 since the calibration values for the Aloha gripper are between start_pos=2499 and end_pos=3144
49
+ # For other joints, 2560 will be autocorrected to be in calibration range
50
+ 56: 2560, # Present_Position
51
+ 58: 0, # Present_Speed
52
+ 69: 0, # Present_Current
53
+ 85: 150, # Maximum_Acceleration
54
+ }
55
+
56
+
57
+ class PortHandler:
58
+ def __init__(self, port):
59
+ self.port = port
60
+ # factory default baudrate
61
+ self.baudrate = DEFAULT_BAUDRATE
62
+ self.ser = SerialMock()
63
+
64
+ def openPort(self): # noqa: N802
65
+ return True
66
+
67
+ def closePort(self): # noqa: N802
68
+ pass
69
+
70
+ def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
71
+ del timeout_ms # unused
72
+
73
+ def getBaudRate(self): # noqa: N802
74
+ return self.baudrate
75
+
76
+ def setBaudRate(self, baudrate): # noqa: N802
77
+ self.baudrate = baudrate
78
+
79
+
80
+ class PacketHandler:
81
+ def __init__(self, protocol_version):
82
+ del protocol_version # unused
83
+ # Use packet_handler.data to communicate across Read and Write
84
+ self.data = {}
85
+
86
+
87
+ class GroupSyncRead:
88
+ def __init__(self, port_handler, packet_handler, address, bytes):
89
+ self.packet_handler = packet_handler
90
+
91
+ def addParam(self, motor_index): # noqa: N802
92
+ # Initialize motor default values
93
+ if motor_index not in self.packet_handler.data:
94
+ self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
95
+
96
+ def txRxPacket(self): # noqa: N802
97
+ return COMM_SUCCESS
98
+
99
+ def getData(self, index, address, bytes): # noqa: N802
100
+ return self.packet_handler.data[index][address]
101
+
102
+
103
+ class GroupSyncWrite:
104
+ def __init__(self, port_handler, packet_handler, address, bytes):
105
+ self.packet_handler = packet_handler
106
+ self.address = address
107
+
108
+ def addParam(self, index, data): # noqa: N802
109
+ if index not in self.packet_handler.data:
110
+ self.packet_handler.data[index] = get_default_motor_values(index)
111
+ self.changeParam(index, data)
112
+
113
+ def txPacket(self): # noqa: N802
114
+ return COMM_SUCCESS
115
+
116
+ def changeParam(self, index, data): # noqa: N802
117
+ self.packet_handler.data[index][self.address] = data
118
+
119
+
120
+ class SerialMock:
121
+ def reset_output_buffer(self):
122
+ pass
123
+
124
+ def reset_input_buffer(self):
125
+ pass
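Since `mock_scservo_sdk` mirrors the subset of the real `scservo_sdk` API used by the motor code, CI tests can substitute it for the real module. A hedged pytest-style sketch (not part of the diff; the test name and port string are illustrative, and it assumes the code under test imports `scservo_sdk` by name):
```
import sys

from lerobot.common.mocks.motors import mock_scservo_sdk


def test_feetech_bus_without_hardware(monkeypatch):
    # Any later `import scservo_sdk` now resolves to the mock, so no serial device is needed.
    monkeypatch.setitem(sys.modules, "scservo_sdk", mock_scservo_sdk)

    port_handler = mock_scservo_sdk.PortHandler("/dev/ttyACM0")
    assert port_handler.openPort()
    assert port_handler.getBaudRate() == mock_scservo_sdk.DEFAULT_BAUDRATE
```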
lerobot/common/optim/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .optimizers import OptimizerConfig as OptimizerConfig
lerobot/common/optim/factory.py ADDED
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from torch.optim import Optimizer
19
+ from torch.optim.lr_scheduler import LRScheduler
20
+
21
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
22
+ from lerobot.configs.train import TrainPipelineConfig
23
+
24
+
25
+ def make_optimizer_and_scheduler(
26
+ cfg: TrainPipelineConfig, policy: PreTrainedPolicy
27
+ ) -> tuple[Optimizer, LRScheduler | None]:
28
+ """Generates the optimizer and scheduler based on configs.
29
+
30
+ Args:
31
+ cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
32
+ policy (PreTrainedPolicy): The policy from which parameters and training presets are taken.
33
+
34
+ Returns:
35
+ tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
36
+ """
37
+ params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
38
+ optimizer = cfg.optimizer.build(params)
39
+ lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
40
+ return optimizer, lr_scheduler
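For reference, the two branches of `make_optimizer_and_scheduler` boil down to the sketch below (not from the diff; the `nn.Linear` stands in for a real `PreTrainedPolicy`, and `AdamConfig` is defined in `lerobot/common/optim/optimizers.py` further down):
```
import torch
from torch import nn

from lerobot.common.optim.optimizers import AdamConfig

policy = nn.Linear(4, 2)                    # stand-in for a PreTrainedPolicy
optimizer_cfg = AdamConfig(lr=1e-4)         # what cfg.optimizer would hold

# use_policy_training_preset=False path: plain parameters; otherwise policy.get_optim_params() is used.
optimizer = optimizer_cfg.build(policy.parameters())
lr_scheduler = None                         # cfg.scheduler is None -> no scheduler is built
```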
lerobot/common/optim/optimizers.py ADDED
@@ -0,0 +1,118 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import abc
17
+ from dataclasses import asdict, dataclass
18
+ from pathlib import Path
19
+
20
+ import draccus
21
+ import torch
22
+ from safetensors.torch import load_file, save_file
23
+
24
+ from lerobot.common.constants import (
25
+ OPTIMIZER_PARAM_GROUPS,
26
+ OPTIMIZER_STATE,
27
+ )
28
+ from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
29
+ from lerobot.common.utils.io_utils import deserialize_json_into_object
30
+
31
+
32
+ @dataclass
33
+ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
34
+ lr: float
35
+ weight_decay: float
36
+ grad_clip_norm: float
37
+
38
+ @property
39
+ def type(self) -> str:
40
+ return self.get_choice_name(self.__class__)
41
+
42
+ @classmethod
43
+ def default_choice_name(cls) -> str | None:
44
+ return "adam"
45
+
46
+ @abc.abstractmethod
47
+ def build(self, params: dict) -> torch.optim.Optimizer:
48
+ raise NotImplementedError
49
+
50
+
51
+ @OptimizerConfig.register_subclass("adam")
52
+ @dataclass
53
+ class AdamConfig(OptimizerConfig):
54
+ lr: float = 1e-3
55
+ betas: tuple[float, float] = (0.9, 0.999)
56
+ eps: float = 1e-8
57
+ weight_decay: float = 0.0
58
+ grad_clip_norm: float = 10.0
59
+
60
+ def build(self, params: dict) -> torch.optim.Optimizer:
61
+ kwargs = asdict(self)
62
+ kwargs.pop("grad_clip_norm")
63
+ return torch.optim.Adam(params, **kwargs)
64
+
65
+
66
+ @OptimizerConfig.register_subclass("adamw")
67
+ @dataclass
68
+ class AdamWConfig(OptimizerConfig):
69
+ lr: float = 1e-3
70
+ betas: tuple[float, float] = (0.9, 0.999)
71
+ eps: float = 1e-8
72
+ weight_decay: float = 1e-2
73
+ grad_clip_norm: float = 10.0
74
+
75
+ def build(self, params: dict) -> torch.optim.Optimizer:
76
+ kwargs = asdict(self)
77
+ kwargs.pop("grad_clip_norm")
78
+ return torch.optim.AdamW(params, **kwargs)
79
+
80
+
81
+ @OptimizerConfig.register_subclass("sgd")
82
+ @dataclass
83
+ class SGDConfig(OptimizerConfig):
84
+ lr: float = 1e-3
85
+ momentum: float = 0.0
86
+ dampening: float = 0.0
87
+ nesterov: bool = False
88
+ weight_decay: float = 0.0
89
+ grad_clip_norm: float = 10.0
90
+
91
+ def build(self, params: dict) -> torch.optim.Optimizer:
92
+ kwargs = asdict(self)
93
+ kwargs.pop("grad_clip_norm")
94
+ return torch.optim.SGD(params, **kwargs)
95
+
96
+
97
+ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
98
+ state = optimizer.state_dict()
99
+ param_groups = state.pop("param_groups")
100
+ flat_state = flatten_dict(state)
101
+ save_file(flat_state, save_dir / OPTIMIZER_STATE)
102
+ write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
103
+
104
+
105
+ def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
106
+ current_state_dict = optimizer.state_dict()
107
+ flat_state = load_file(save_dir / OPTIMIZER_STATE)
108
+ state = unflatten_dict(flat_state)
109
+ loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
110
+
111
+ if "param_groups" in current_state_dict:
112
+ param_groups = deserialize_json_into_object(
113
+ save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
114
+ )
115
+ loaded_state_dict["param_groups"] = param_groups
116
+
117
+ optimizer.load_state_dict(loaded_state_dict)
118
+ return optimizer
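Each `OptimizerConfig` subclass pops `grad_clip_norm` from the kwargs passed to torch, because gradient clipping happens in the training loop rather than inside the optimizer. A hedged usage sketch (not part of the diff; the model and values are illustrative):
```
import torch
from torch import nn

from lerobot.common.optim.optimizers import AdamWConfig

cfg = AdamWConfig(lr=3e-4, weight_decay=1e-2)
model = nn.Linear(8, 8)
optimizer = cfg.build(model.parameters())   # lr/betas/eps/weight_decay go to torch.optim.AdamW

loss = model(torch.randn(2, 8)).sum()
loss.backward()
# grad_clip_norm is consumed by the caller, not by build():
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
optimizer.step()
```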
lerobot/common/optim/schedulers.py ADDED
@@ -0,0 +1,122 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import abc
17
+ import math
18
+ from dataclasses import asdict, dataclass
19
+ from pathlib import Path
20
+
21
+ import draccus
22
+ from torch.optim import Optimizer
23
+ from torch.optim.lr_scheduler import LambdaLR, LRScheduler
24
+
25
+ from lerobot.common.constants import SCHEDULER_STATE
26
+ from lerobot.common.datasets.utils import write_json
27
+ from lerobot.common.utils.io_utils import deserialize_json_into_object
28
+
29
+
30
+ @dataclass
31
+ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
32
+ num_warmup_steps: int
33
+
34
+ @property
35
+ def type(self) -> str:
36
+ return self.get_choice_name(self.__class__)
37
+
38
+ @abc.abstractmethod
39
+ def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
40
+ raise NotImplementedError
41
+
42
+
43
+ @LRSchedulerConfig.register_subclass("diffuser")
44
+ @dataclass
45
+ class DiffuserSchedulerConfig(LRSchedulerConfig):
46
+ name: str = "cosine"
47
+ num_warmup_steps: int | None = None
48
+
49
+ def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
50
+ from diffusers.optimization import get_scheduler
51
+
52
+ kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
53
+ return get_scheduler(**kwargs)
54
+
55
+
56
+ @LRSchedulerConfig.register_subclass("vqbet")
57
+ @dataclass
58
+ class VQBeTSchedulerConfig(LRSchedulerConfig):
59
+ num_warmup_steps: int
60
+ num_vqvae_training_steps: int
61
+ num_cycles: float = 0.5
62
+
63
+ def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
64
+ def lr_lambda(current_step):
65
+ if current_step < self.num_vqvae_training_steps:
66
+ return float(1)
67
+ else:
68
+ adjusted_step = current_step - self.num_vqvae_training_steps
69
+ if adjusted_step < self.num_warmup_steps:
70
+ return float(adjusted_step) / float(max(1, self.num_warmup_steps))
71
+ progress = float(adjusted_step - self.num_warmup_steps) / float(
72
+ max(1, num_training_steps - self.num_warmup_steps)
73
+ )
74
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
75
+
76
+ return LambdaLR(optimizer, lr_lambda, -1)
77
+
78
+
79
+ @LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
80
+ @dataclass
81
+ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
82
+ """Used by Physical Intelligence to train Pi0"""
83
+
84
+ num_warmup_steps: int
85
+ num_decay_steps: int
86
+ peak_lr: float
87
+ decay_lr: float
88
+
89
+ def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
90
+ del num_training_steps
91
+
92
+ def lr_lambda(current_step):
93
+ def linear_warmup_schedule(current_step):
94
+ if current_step <= 0:
95
+ return 1 / (self.num_warmup_steps + 1)
96
+ frac = 1 - current_step / self.num_warmup_steps
97
+ return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
98
+
99
+ def cosine_decay_schedule(current_step):
100
+ step = min(current_step, self.num_decay_steps)
101
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
102
+ alpha = self.decay_lr / self.peak_lr
103
+ decayed = (1 - alpha) * cosine_decay + alpha
104
+ return decayed
105
+
106
+ if current_step < self.num_warmup_steps:
107
+ return linear_warmup_schedule(current_step)
108
+
109
+ return cosine_decay_schedule(current_step)
110
+
111
+ return LambdaLR(optimizer, lr_lambda, -1)
112
+
113
+
114
+ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
115
+ state_dict = scheduler.state_dict()
116
+ write_json(state_dict, save_dir / SCHEDULER_STATE)
117
+
118
+
119
+ def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
120
+ state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
121
+ scheduler.load_state_dict(state_dict)
122
+ return scheduler
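The `cosine_decay_with_warmup` config returns a `LambdaLR`, so the optimizer's base learning rate (set to `peak_lr`) is multiplied by the warmup/decay factor at every step. A hedged sketch of how it would be wired up (not part of the diff; the step counts and learning rates are illustrative):
```
import torch

from lerobot.common.optim.schedulers import CosineDecayWithWarmupSchedulerConfig

cfg = CosineDecayWithWarmupSchedulerConfig(
    num_warmup_steps=1_000, num_decay_steps=30_000, peak_lr=2.5e-5, decay_lr=2.5e-6
)
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=cfg.peak_lr)
scheduler = cfg.build(optimizer, num_training_steps=30_000)

for _ in range(3):
    optimizer.step()
    scheduler.step()
# The learning rate ramps linearly from ~peak_lr / num_warmup_steps up to peak_lr, then decays to decay_lr.
print(scheduler.get_last_lr())
```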
lerobot/common/policies/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .act.configuration_act import ACTConfig as ACTConfig
16
+ from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
17
+ from .pi0.configuration_pi0 import PI0Config as PI0Config
18
+ from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
19
+ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
lerobot/common/policies/act/configuration_act.py ADDED
@@ -0,0 +1,186 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass, field
17
+
18
+ from lerobot.common.optim.optimizers import AdamWConfig
19
+ from lerobot.configs.policies import PreTrainedConfig
20
+ from lerobot.configs.types import NormalizationMode
21
+
22
+
23
+ @PreTrainedConfig.register_subclass("act")
24
+ @dataclass
25
+ class ACTConfig(PreTrainedConfig):
26
+ """Configuration class for the Action Chunking Transformers policy.
27
+
28
+ Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
29
+
30
+ The parameters you will most likely need to change are the ones which depend on the environment / sensors.
31
+ Those are: `input_shapes` and `output_shapes`.
32
+
33
+ Notes on the inputs and outputs:
34
+ - Either:
35
+ - At least one key starting with "observation.image" is required as an input.
36
+ AND/OR
37
+ - The key "observation.environment_state" is required as input.
38
+ - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
39
+ views. Right now we only support all images having the same shape.
40
+ - May optionally work without an "observation.state" key for the proprioceptive robot state.
41
+ - "action" is required as an output key.
42
+
43
+ Args:
44
+ n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
45
+ current step and additional steps going back).
46
+ chunk_size: The size of the action prediction "chunks" in units of environment steps.
47
+ n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
48
+ This should be no greater than the chunk size. For example, if the chunk size is 100, you may
49
+ set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
50
+ environment, and throws the other 50 out.
51
+ input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
52
+ the input data name, and the value is a list indicating the dimensions of the corresponding data.
53
+ For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
54
+ indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
55
+ include batch dimension or temporal dimension.
56
+ output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
57
+ the output data name, and the value is a list indicating the dimensions of the corresponding data.
58
+ For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
59
+ Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
60
+ input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
61
+ and the value specifies the normalization mode to apply. The two available modes are "mean_std"
62
+ which subtracts the mean and divides by the standard deviation, and "min_max" which rescales to a
63
+ [-1, 1] range.
64
+ output_normalization_modes: Similar dictionary to `input_normalization_modes`, but used to unnormalize to the
65
+ original scale. Note that this is also used for normalizing the training targets.
66
+ vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
67
+ pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
68
+ `None` means no pretrained weights.
69
+ replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
70
+ convolution.
71
+ pre_norm: Whether to use "pre-norm" in the transformer blocks.
72
+ dim_model: The transformer blocks' main hidden dimension.
73
+ n_heads: The number of heads to use in the transformer blocks' multi-head attention.
74
+ dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
75
+ layers.
76
+ feedforward_activation: The activation to use in the transformer block's feed-forward layers.
77
+ n_encoder_layers: The number of transformer layers to use for the transformer encoder.
78
+ n_decoder_layers: The number of transformer layers to use for the transformer decoder.
79
+ use_vae: Whether to use a variational objective during training. This introduces another transformer
80
+ which is used as the VAE's encoder (not to be confused with the transformer encoder - see
81
+ documentation in the policy class).
82
+ latent_dim: The VAE's latent dimension.
83
+ n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
84
+ temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
85
+ ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
86
+ 1 when using this feature, as inference needs to happen at every step to form an ensemble. For
87
+ more information on how ensembling works, please see `ACTTemporalEnsembler`.
88
+ dropout: Dropout to use in the transformer layers (see code for details).
89
+ kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
90
+ is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
91
+ """
92
+
93
+ # Input / output structure.
94
+ n_obs_steps: int = 1
95
+ chunk_size: int = 100
96
+ n_action_steps: int = 100
97
+
98
+ normalization_mapping: dict[str, NormalizationMode] = field(
99
+ default_factory=lambda: {
100
+ "VISUAL": NormalizationMode.MEAN_STD,
101
+ "STATE": NormalizationMode.MEAN_STD,
102
+ "ACTION": NormalizationMode.MEAN_STD,
103
+ }
104
+ )
105
+
106
+ # Architecture.
107
+ # Vision backbone.
108
+ vision_backbone: str = "resnet18"
109
+ pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
110
+ replace_final_stride_with_dilation: bool = False
111
+ # Transformer layers.
112
+ pre_norm: bool = False
113
+ dim_model: int = 512
114
+ n_heads: int = 8
115
+ dim_feedforward: int = 3200
116
+ feedforward_activation: str = "relu"
117
+ n_encoder_layers: int = 4
118
+ # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
119
+ # that means only the first layer is used. Here we match the original implementation by setting this to 1.
120
+ # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
121
+ n_decoder_layers: int = 1
122
+ # VAE.
123
+ use_vae: bool = True
124
+ latent_dim: int = 32
125
+ n_vae_encoder_layers: int = 4
126
+
127
+ # Inference.
128
+ # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
129
+ temporal_ensemble_coeff: float | None = None
130
+
131
+ # Training and loss computation.
132
+ dropout: float = 0.1
133
+ kl_weight: float = 10.0
134
+
135
+ # Training preset
136
+ optimizer_lr: float = 1e-5
137
+ optimizer_weight_decay: float = 1e-4
138
+ optimizer_lr_backbone: float = 1e-5
139
+
140
+ def __post_init__(self):
141
+ super().__post_init__()
142
+
143
+ """Input validation (not exhaustive)."""
144
+ if not self.vision_backbone.startswith("resnet"):
145
+ raise ValueError(
146
+ f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
147
+ )
148
+ if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
149
+ raise NotImplementedError(
150
+ "`n_action_steps` must be 1 when using temporal ensembling. This is "
151
+ "because the policy needs to be queried every step to compute the ensembled action."
152
+ )
153
+ if self.n_action_steps > self.chunk_size:
154
+ raise ValueError(
155
+ f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
156
+ f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
157
+ )
158
+ if self.n_obs_steps != 1:
159
+ raise ValueError(
160
+ f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
161
+ )
162
+
163
+ def get_optimizer_preset(self) -> AdamWConfig:
164
+ return AdamWConfig(
165
+ lr=self.optimizer_lr,
166
+ weight_decay=self.optimizer_weight_decay,
167
+ )
168
+
169
+ def get_scheduler_preset(self) -> None:
170
+ return None
171
+
172
+ def validate_features(self) -> None:
173
+ if not self.image_features and not self.env_state_feature:
174
+ raise ValueError("You must provide at least one image or the environment state among the inputs.")
175
+
176
+ @property
177
+ def observation_delta_indices(self) -> None:
178
+ return None
179
+
180
+ @property
181
+ def action_delta_indices(self) -> list:
182
+ return list(range(self.chunk_size))
183
+
184
+ @property
185
+ def reward_delta_indices(self) -> None:
186
+ return None
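The `chunk_size` / `n_action_steps` relationship described in the docstring is enforced in `__post_init__`. A hedged sketch (not part of the diff; it assumes `ACTConfig` constructs with its defaults):
```
from lerobot.common.policies.act.configuration_act import ACTConfig

# Predict 100 steps per model call, execute the first 50 in the environment, discard the rest.
cfg = ACTConfig(chunk_size=100, n_action_steps=50)
assert cfg.action_delta_indices == list(range(100))   # training targets cover the full chunk

try:
    ACTConfig(chunk_size=100, n_action_steps=200)      # more action steps than the chunk provides
except ValueError as err:
    print(err)
```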
lerobot/common/policies/act/modeling_act.py ADDED
@@ -0,0 +1,765 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Action Chunking Transformer Policy
17
+
18
+ As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
19
+ The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
20
+ """
21
+
22
+ import math
23
+ from collections import deque
24
+ from itertools import chain
25
+ from typing import Callable
26
+
27
+ import einops
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F # noqa: N812
31
+ import torchvision
32
+ from torch import Tensor, nn
33
+ from torchvision.models._utils import IntermediateLayerGetter
34
+ from torchvision.ops.misc import FrozenBatchNorm2d
35
+
36
+ from lerobot.common.policies.act.configuration_act import ACTConfig
37
+ from lerobot.common.policies.normalize import Normalize, Unnormalize
38
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
39
+
40
+
41
+ class ACTPolicy(PreTrainedPolicy):
42
+ """
43
+ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
44
+ Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
45
+ """
46
+
47
+ config_class = ACTConfig
48
+ name = "act"
49
+
50
+ def __init__(
51
+ self,
52
+ config: ACTConfig,
53
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
54
+ ):
55
+ """
56
+ Args:
57
+ config: Policy configuration class instance or None, in which case the default instantiation of
58
+ the configuration class is used.
59
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
60
+ that they will be passed with a call to `load_state_dict` before the policy is used.
61
+ """
62
+ super().__init__(config)
63
+ config.validate_features()
64
+ self.config = config
65
+
66
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
67
+ self.normalize_targets = Normalize(
68
+ config.output_features, config.normalization_mapping, dataset_stats
69
+ )
70
+ self.unnormalize_outputs = Unnormalize(
71
+ config.output_features, config.normalization_mapping, dataset_stats
72
+ )
73
+
74
+ self.model = ACT(config)
75
+
76
+ if config.temporal_ensemble_coeff is not None:
77
+ self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
78
+
79
+ self.reset()
80
+
81
+ def get_optim_params(self) -> dict:
82
+ # TODO(aliberts, rcadene): As of now, lr_backbone == lr
83
+ # Should we remove this and just `return self.parameters()`?
84
+ return [
85
+ {
86
+ "params": [
87
+ p
88
+ for n, p in self.named_parameters()
89
+ if not n.startswith("model.backbone") and p.requires_grad
90
+ ]
91
+ },
92
+ {
93
+ "params": [
94
+ p
95
+ for n, p in self.named_parameters()
96
+ if n.startswith("model.backbone") and p.requires_grad
97
+ ],
98
+ "lr": self.config.optimizer_lr_backbone,
99
+ },
100
+ ]
101
+
102
+ def reset(self):
103
+ """This should be called whenever the environment is reset."""
104
+ if self.config.temporal_ensemble_coeff is not None:
105
+ self.temporal_ensembler.reset()
106
+ else:
107
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
108
+
109
+ @torch.no_grad
110
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
111
+ """Select a single action given environment observations.
112
+
113
+ This method wraps `select_actions` in order to return one action at a time for execution in the
114
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
115
+ queue is empty.
116
+ """
117
+ self.eval()
118
+
119
+ batch = self.normalize_inputs(batch)
120
+ if self.config.image_features:
121
+ batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
122
+ batch["observation.images"] = [batch[key] for key in self.config.image_features]
123
+
124
+ # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
125
+ # we are ensembling over.
126
+ if self.config.temporal_ensemble_coeff is not None:
127
+ actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
128
+ actions = self.unnormalize_outputs({"action": actions})["action"]
129
+ action = self.temporal_ensembler.update(actions)
130
+ return action
131
+
132
+ # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
133
+ # querying the policy.
134
+ if len(self._action_queue) == 0:
135
+ actions = self.model(batch)[0][:, : self.config.n_action_steps]
136
+
137
+ # TODO(rcadene): make _forward return output dictionary?
138
+ actions = self.unnormalize_outputs({"action": actions})["action"]
139
+
140
+ # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
141
+ # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
142
+ self._action_queue.extend(actions.transpose(0, 1))
143
+ return self._action_queue.popleft()
144
+
145
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
146
+ """Run the batch through the model and compute the loss for training or validation."""
147
+ batch = self.normalize_inputs(batch)
148
+ if self.config.image_features:
149
+ batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
150
+ batch["observation.images"] = [batch[key] for key in self.config.image_features]
151
+
152
+ batch = self.normalize_targets(batch)
153
+ actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
154
+
155
+ l1_loss = (
156
+ F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
157
+ ).mean()
158
+
159
+ loss_dict = {"l1_loss": l1_loss.item()}
160
+ if self.config.use_vae:
161
+ # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
162
+ # each dimension independently, we sum over the latent dimension to get the total
163
+ # KL-divergence per batch element, then take the mean over the batch.
164
+ # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
165
+ mean_kld = (
166
+ (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
167
+ )
168
+ loss_dict["kld_loss"] = mean_kld.item()
169
+ loss = l1_loss + mean_kld * self.config.kl_weight
170
+ else:
171
+ loss = l1_loss
172
+
173
+ return loss, loss_dict
174
+
175
+
176
+ class ACTTemporalEnsembler:
177
+ def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
178
+ """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
179
+
180
+ The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
181
+ They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
182
+ coefficient works:
183
+ - Setting it to 0 uniformly weighs all actions.
184
+ - Setting it positive gives more weight to older actions.
185
+ - Setting it negative gives more weight to newer actions.
186
+ NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
187
+ results in older actions being weighed more highly than newer actions (the experiments documented in
188
+ https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
189
+ detrimental: doing so aggressively may diminish the benefits of action chunking).
190
+
191
+ Here we use an online method for computing the average rather than caching a history of actions in
192
+ order to compute the average offline. For a simple 1D sequence it looks something like:
193
+
194
+ ```
195
+ import torch
196
+
197
+ seq = torch.linspace(8, 8.5, 100)
198
+ print(seq)
199
+
200
+ m = 0.01
201
+ exp_weights = torch.exp(-m * torch.arange(len(seq)))
202
+ print(exp_weights)
203
+
204
+ # Calculate offline
205
+ avg = (exp_weights * seq).sum() / exp_weights.sum()
206
+ print("offline", avg)
207
+
208
+ # Calculate online
209
+ for i, item in enumerate(seq):
210
+ if i == 0:
211
+ avg = item
212
+ continue
213
+ avg *= exp_weights[:i].sum()
214
+ avg += item * exp_weights[i]
215
+ avg /= exp_weights[:i+1].sum()
216
+ print("online", avg)
217
+ ```
218
+ """
219
+ self.chunk_size = chunk_size
220
+ self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
221
+ self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
222
+ self.reset()
223
+
224
+ def reset(self):
225
+ """Resets the online computation variables."""
226
+ self.ensembled_actions = None
227
+ # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
228
+ self.ensembled_actions_count = None
229
+
230
+ def update(self, actions: Tensor) -> Tensor:
231
+ """
232
+ Takes a (batch, chunk_size, action_dim) sequence of actions, updates the temporal ensemble for all
233
+ time steps, and pops/returns the next batch of actions in the sequence.
234
+ """
235
+ self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
236
+ self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
237
+ if self.ensembled_actions is None:
238
+ # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
239
+ # time step of the episode.
240
+ self.ensembled_actions = actions.clone()
241
+ # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
242
+ # operations later.
243
+ self.ensembled_actions_count = torch.ones(
244
+ (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
245
+ )
246
+ else:
247
+ # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
248
+ # the online update for those entries.
249
+ self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
250
+ self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
251
+ self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
252
+ self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
253
+ # The last action, which has no prior online average, needs to get concatenated onto the end.
254
+ self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
255
+ self.ensembled_actions_count = torch.cat(
256
+ [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
257
+ )
258
+ # "Consume" the first action.
259
+ action, self.ensembled_actions, self.ensembled_actions_count = (
260
+ self.ensembled_actions[:, 0],
261
+ self.ensembled_actions[:, 1:],
262
+ self.ensembled_actions_count[1:],
263
+ )
264
+ return action
265
+
266
+
267
+ class ACT(nn.Module):
268
+ """Action Chunking Transformer: The underlying neural network for ACTPolicy.
269
+
270
+ Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
271
+ - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
272
+ model that encodes the target data (a sequence of actions), and the condition (the robot
273
+ joint-space).
274
+ - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
275
+ cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
276
+ have an option to train this model without the variational objective (in which case we drop the
277
+ `vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
278
+
279
+ Transformer
280
+ Used alone for inference
281
+ (acts as VAE decoder
282
+ during training)
283
+ ┌───────────────────────┐
284
+ │ Outputs │
285
+ │ ▲ │
286
+ │ ┌─────►┌───────┐ │
287
+ ┌──────┐ │ │ │Transf.│ │
288
+ │ │ │ ├─────►│decoder│ │
289
+ ┌────┴────┐ │ │ │ │ │ │
290
+ │ │ │ │ ┌───┴───┬─►│ │ │
291
+ │ VAE │ │ │ │ │ └───────┘ │
292
+ │ encoder │ │ │ │Transf.│ │
293
+ │ │ │ │ │encoder│ │
294
+ └───▲─────┘ │ │ │ │ │
295
+ │ │ │ └▲──▲─▲─┘ │
296
+ │ │ │ │ │ │ │
297
+ inputs └─────┼──┘ │ image emb. │
298
+ │ state emb. │
299
+ └───────────────────────┘
300
+ """
301
+
302
+ def __init__(self, config: ACTConfig):
303
+ # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
304
+ # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
305
+ super().__init__()
306
+ self.config = config
307
+
308
+ if self.config.use_vae:
309
+ self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
310
+ self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
311
+ # Projection layer for joint-space configuration to hidden dimension.
312
+ if self.config.robot_state_feature:
313
+ self.vae_encoder_robot_state_input_proj = nn.Linear(
314
+ self.config.robot_state_feature.shape[0], config.dim_model
315
+ )
316
+ # Projection layer for action (joint-space target) to hidden dimension.
317
+ self.vae_encoder_action_input_proj = nn.Linear(
318
+ self.config.action_feature.shape[0],
319
+ config.dim_model,
320
+ )
321
+ # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
322
+ self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
323
+ # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
324
+ # dimension.
325
+ num_input_token_encoder = 1 + config.chunk_size
326
+ if self.config.robot_state_feature:
327
+ num_input_token_encoder += 1
328
+ self.register_buffer(
329
+ "vae_encoder_pos_enc",
330
+ create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
331
+ )
332
+
333
+ # Backbone for image feature extraction.
334
+ if self.config.image_features:
335
+ backbone_model = getattr(torchvision.models, config.vision_backbone)(
336
+ replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
337
+ weights=config.pretrained_backbone_weights,
338
+ norm_layer=FrozenBatchNorm2d,
339
+ )
340
+ # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
341
+ # feature map).
342
+ # Note: The forward method of this returns a dict: {"feature_map": output}.
343
+ self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
344
+
345
+ # Transformer (acts as VAE decoder when training with the variational objective).
346
+ self.encoder = ACTEncoder(config)
347
+ self.decoder = ACTDecoder(config)
348
+
349
+ # Transformer encoder input projections. The tokens will be structured like
350
+ # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
351
+ if self.config.robot_state_feature:
352
+ self.encoder_robot_state_input_proj = nn.Linear(
353
+ self.config.robot_state_feature.shape[0], config.dim_model
354
+ )
355
+ if self.config.env_state_feature:
356
+ self.encoder_env_state_input_proj = nn.Linear(
357
+ self.config.env_state_feature.shape[0], config.dim_model
358
+ )
359
+ self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
360
+ if self.config.image_features:
361
+ self.encoder_img_feat_input_proj = nn.Conv2d(
362
+ backbone_model.fc.in_features, config.dim_model, kernel_size=1
363
+ )
364
+ # Transformer encoder positional embeddings.
365
+ n_1d_tokens = 1 # for the latent
366
+ if self.config.robot_state_feature:
367
+ n_1d_tokens += 1
368
+ if self.config.env_state_feature:
369
+ n_1d_tokens += 1
370
+ self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
371
+ if self.config.image_features:
372
+ self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
373
+
374
+ # Transformer decoder.
375
+ # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
376
+ self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
377
+
378
+ # Final action regression head on the output of the transformer's decoder.
379
+ self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
380
+
381
+ self._reset_parameters()
382
+
383
+ def _reset_parameters(self):
384
+ """Xavier-uniform initialization of the transformer parameters as in the original code."""
385
+ for p in chain(self.encoder.parameters(), self.decoder.parameters()):
386
+ if p.dim() > 1:
387
+ nn.init.xavier_uniform_(p)
388
+
389
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
390
+ """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
391
+
392
+ `batch` should have the following structure:
393
+ {
394
+ [robot_state_feature] (optional): (B, state_dim) batch of robot states.
395
+
396
+ [image_features]: (B, n_cameras, C, H, W) batch of images.
397
+ AND/OR
398
+ [env_state_feature]: (B, env_dim) batch of environment states.
399
+
400
+ [action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
401
+ }
402
+
403
+ Returns:
404
+ (B, chunk_size, action_dim) batch of action sequences
405
+ Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
406
+ latent dimension.
407
+ """
408
+ if self.config.use_vae and self.training:
409
+ assert "action" in batch, (
410
+ "actions must be provided when using the variational objective in training mode."
411
+ )
412
+
413
+ if "observation.images" in batch:
414
+ batch_size = batch["observation.images"][0].shape[0]
415
+ else:
416
+ batch_size = batch["observation.environment_state"].shape[0]
417
+
418
+ # Prepare the latent for input to the transformer encoder.
419
+ if self.config.use_vae and "action" in batch:
420
+ # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
421
+ cls_embed = einops.repeat(
422
+ self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
423
+ ) # (B, 1, D)
424
+ if self.config.robot_state_feature:
425
+ robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
426
+ robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
427
+ action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
428
+
429
+ if self.config.robot_state_feature:
430
+ vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
431
+ else:
432
+ vae_encoder_input = [cls_embed, action_embed]
433
+ vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
434
+
435
+ # Prepare fixed positional embedding.
436
+ # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
437
+ pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
438
+
439
+ # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
440
+ # sequence depending on whether we use the input states or not (cls and robot state).
441
+ # False means not a padding token.
442
+ cls_joint_is_pad = torch.full(
443
+ (batch_size, 2 if self.config.robot_state_feature else 1),
444
+ False,
445
+ device=batch["observation.state"].device,
446
+ )
447
+ key_padding_mask = torch.cat(
448
+ [cls_joint_is_pad, batch["action_is_pad"]], axis=1
449
+ ) # (bs, seq+1 or 2)
450
+
451
+ # Forward pass through VAE encoder to get the latent PDF parameters.
452
+ cls_token_out = self.vae_encoder(
453
+ vae_encoder_input.permute(1, 0, 2),
454
+ pos_embed=pos_embed.permute(1, 0, 2),
455
+ key_padding_mask=key_padding_mask,
456
+ )[0] # select the class token, with shape (B, D)
457
+ latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
458
+ mu = latent_pdf_params[:, : self.config.latent_dim]
459
+ # This is 2log(sigma). Done this way to match the original implementation.
460
+ log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
461
+
462
+ # Sample the latent with the reparameterization trick.
463
+ latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
464
+ else:
465
+ # When not using the VAE encoder, we set the latent to be all zeros.
466
+ mu = log_sigma_x2 = None
467
+ # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
468
+ latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
469
+ batch["observation.state"].device
470
+ )
471
+
472
+ # Prepare transformer encoder inputs.
473
+ encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
474
+ encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
475
+ # Robot state token.
476
+ if self.config.robot_state_feature:
477
+ encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
478
+ # Environment state token.
479
+ if self.config.env_state_feature:
480
+ encoder_in_tokens.append(
481
+ self.encoder_env_state_input_proj(batch["observation.environment_state"])
482
+ )
483
+
484
+ # Camera observation features and positional embeddings.
485
+ if self.config.image_features:
486
+ all_cam_features = []
487
+ all_cam_pos_embeds = []
488
+
489
+ # For a list of images, the H and W may vary but H*W is constant.
490
+ for img in batch["observation.images"]:
491
+ cam_features = self.backbone(img)["feature_map"]
492
+ cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
493
+ cam_features = self.encoder_img_feat_input_proj(cam_features)
494
+
495
+ # Rearrange features to (sequence, batch, dim).
496
+ cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
497
+ cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
498
+
499
+ all_cam_features.append(cam_features)
500
+ all_cam_pos_embeds.append(cam_pos_embed)
501
+
502
+ encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
503
+ encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
504
+
505
+ # Stack all tokens along the sequence dimension.
506
+ encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
507
+ encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
508
+
509
+ # Forward pass through the transformer modules.
510
+ encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
511
+ # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
512
+ decoder_in = torch.zeros(
513
+ (self.config.chunk_size, batch_size, self.config.dim_model),
514
+ dtype=encoder_in_pos_embed.dtype,
515
+ device=encoder_in_pos_embed.device,
516
+ )
517
+ decoder_out = self.decoder(
518
+ decoder_in,
519
+ encoder_out,
520
+ encoder_pos_embed=encoder_in_pos_embed,
521
+ decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
522
+ )
523
+
524
+ # Move back to (B, S, C).
525
+ decoder_out = decoder_out.transpose(0, 1)
526
+
527
+ actions = self.action_head(decoder_out)
528
+
529
+ return actions, (mu, log_sigma_x2)
530
+
531
+
532
+ class ACTEncoder(nn.Module):
533
+ """Convenience module for running multiple encoder layers, maybe followed by normalization."""
534
+
535
+ def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
536
+ super().__init__()
537
+ self.is_vae_encoder = is_vae_encoder
538
+ num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
539
+ self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
540
+ self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
541
+
542
+ def forward(
543
+ self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
544
+ ) -> Tensor:
545
+ for layer in self.layers:
546
+ x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
547
+ x = self.norm(x)
548
+ return x
549
+
550
+
551
+ class ACTEncoderLayer(nn.Module):
552
+ def __init__(self, config: ACTConfig):
553
+ super().__init__()
554
+ self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
555
+
556
+ # Feed forward layers.
557
+ self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
558
+ self.dropout = nn.Dropout(config.dropout)
559
+ self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
560
+
561
+ self.norm1 = nn.LayerNorm(config.dim_model)
562
+ self.norm2 = nn.LayerNorm(config.dim_model)
563
+ self.dropout1 = nn.Dropout(config.dropout)
564
+ self.dropout2 = nn.Dropout(config.dropout)
565
+
566
+ self.activation = get_activation_fn(config.feedforward_activation)
567
+ self.pre_norm = config.pre_norm
568
+
569
+ def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
570
+ skip = x
571
+ if self.pre_norm:
572
+ x = self.norm1(x)
573
+ q = k = x if pos_embed is None else x + pos_embed
574
+ x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
575
+ x = x[0] # note: [0] to select just the output, not the attention weights
576
+ x = skip + self.dropout1(x)
577
+ if self.pre_norm:
578
+ skip = x
579
+ x = self.norm2(x)
580
+ else:
581
+ x = self.norm1(x)
582
+ skip = x
583
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
584
+ x = skip + self.dropout2(x)
585
+ if not self.pre_norm:
586
+ x = self.norm2(x)
587
+ return x
588
+
589
+
590
+ class ACTDecoder(nn.Module):
591
+ def __init__(self, config: ACTConfig):
592
+ """Convenience module for running multiple decoder layers followed by normalization."""
593
+ super().__init__()
594
+ self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
595
+ self.norm = nn.LayerNorm(config.dim_model)
596
+
597
+ def forward(
598
+ self,
599
+ x: Tensor,
600
+ encoder_out: Tensor,
601
+ decoder_pos_embed: Tensor | None = None,
602
+ encoder_pos_embed: Tensor | None = None,
603
+ ) -> Tensor:
604
+ for layer in self.layers:
605
+ x = layer(
606
+ x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
607
+ )
608
+ if self.norm is not None:
609
+ x = self.norm(x)
610
+ return x
611
+
612
+
613
+ class ACTDecoderLayer(nn.Module):
614
+ def __init__(self, config: ACTConfig):
615
+ super().__init__()
616
+ self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
617
+ self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
618
+
619
+ # Feed forward layers.
620
+ self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
621
+ self.dropout = nn.Dropout(config.dropout)
622
+ self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)
623
+
624
+ self.norm1 = nn.LayerNorm(config.dim_model)
625
+ self.norm2 = nn.LayerNorm(config.dim_model)
626
+ self.norm3 = nn.LayerNorm(config.dim_model)
627
+ self.dropout1 = nn.Dropout(config.dropout)
628
+ self.dropout2 = nn.Dropout(config.dropout)
629
+ self.dropout3 = nn.Dropout(config.dropout)
630
+
631
+ self.activation = get_activation_fn(config.feedforward_activation)
632
+ self.pre_norm = config.pre_norm
633
+
634
+ def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
635
+ return tensor if pos_embed is None else tensor + pos_embed
636
+
637
+ def forward(
638
+ self,
639
+ x: Tensor,
640
+ encoder_out: Tensor,
641
+ decoder_pos_embed: Tensor | None = None,
642
+ encoder_pos_embed: Tensor | None = None,
643
+ ) -> Tensor:
644
+ """
645
+ Args:
646
+ x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
647
+ encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
648
+ cross-attending with.
649
+ decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
650
+ encoder_pos_embed: (ES, 1, C) positional embedding for the keys (from the encoder).
651
+ Returns:
652
+ (DS, B, C) tensor of decoder output features.
653
+ """
654
+ skip = x
655
+ if self.pre_norm:
656
+ x = self.norm1(x)
657
+ q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
658
+ x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
659
+ x = skip + self.dropout1(x)
660
+ if self.pre_norm:
661
+ skip = x
662
+ x = self.norm2(x)
663
+ else:
664
+ x = self.norm1(x)
665
+ skip = x
666
+ x = self.multihead_attn(
667
+ query=self.maybe_add_pos_embed(x, decoder_pos_embed),
668
+ key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
669
+ value=encoder_out,
670
+ )[0] # select just the output, not the attention weights
671
+ x = skip + self.dropout2(x)
672
+ if self.pre_norm:
673
+ skip = x
674
+ x = self.norm3(x)
675
+ else:
676
+ x = self.norm2(x)
677
+ skip = x
678
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
679
+ x = skip + self.dropout3(x)
680
+ if not self.pre_norm:
681
+ x = self.norm3(x)
682
+ return x
683
+
684
+
685
+ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
686
+ """1D sinusoidal positional embeddings as in Attention is All You Need.
687
+
688
+ Args:
689
+ num_positions: Number of token positions required.
+ dimension: Dimension of each positional embedding vector.
690
+ Returns: (num_positions, dimension) position embeddings (the first dimension is the position dimension).
691
+
692
+ """
693
+
694
+ def get_position_angle_vec(position):
695
+ return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
696
+
697
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
698
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
699
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
700
+ return torch.from_numpy(sinusoid_table).float()
701
+
702
+
703
+ class ACTSinusoidalPositionEmbedding2d(nn.Module):
704
+ """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
705
+
706
+ The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
707
+ for the vertical direction, and 1/W for the horizontal direction).
708
+ """
709
+
710
+ def __init__(self, dimension: int):
711
+ """
712
+ Args:
713
+ dimension: The desired dimension of the embeddings.
714
+ """
715
+ super().__init__()
716
+ self.dimension = dimension
717
+ self._two_pi = 2 * math.pi
718
+ self._eps = 1e-6
719
+ # Inverse "common ratio" for the geometric progression in sinusoid frequencies.
720
+ self._temperature = 10000
721
+
722
+ def forward(self, x: Tensor) -> Tensor:
723
+ """
724
+ Args:
725
+ x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
726
+ Returns:
727
+ A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
728
+ """
729
+ not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
730
+ # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
731
+ # they would be range(0, H) and range(0, W). Keeping it as is to match the original code.
732
+ y_range = not_mask.cumsum(1, dtype=torch.float32)
733
+ x_range = not_mask.cumsum(2, dtype=torch.float32)
734
+
735
+ # "Normalize" the position index such that it ranges in [0, 2π].
736
+ # Note: Adding epsilon on the denominator should not be needed as all values of y_range and x_range
737
+ # are non-zero by construction. This is an artifact of the original code.
738
+ y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
739
+ x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
740
+
741
+ inverse_frequency = self._temperature ** (
742
+ 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
743
+ )
744
+
745
+ x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
746
+ y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
747
+
748
+ # Note: this stack then flatten operation results in interleaved sine and cosine terms.
749
+ # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
750
+ pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
751
+ pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
752
+ pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
753
+
754
+ return pos_embed
755
+
756
+
757
+ def get_activation_fn(activation: str) -> Callable:
758
+ """Return an activation function given a string."""
759
+ if activation == "relu":
760
+ return F.relu
761
+ if activation == "gelu":
762
+ return F.gelu
763
+ if activation == "glu":
764
+ return F.glu
765
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
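A minimal sanity-check sketch (not part of the committed file) for the 1D sinusoidal table built by create_sinusoidal_pos_embedding above; it assumes the function is importable from lerobot.common.policies.act.modeling_act:

import torch
from lerobot.common.policies.act.modeling_act import create_sinusoidal_pos_embedding

table = create_sinusoidal_pos_embedding(num_positions=4, dimension=8)
print(table.shape)  # torch.Size([4, 8])
# Even channels hold sines, odd channels hold cosines; at position 0 they are 0 and 1 respectively.
assert torch.allclose(table[0, 0::2], torch.zeros(4))
assert torch.allclose(table[0, 1::2], torch.ones(4))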
lerobot/common/policies/diffusion/configuration_diffusion.py ADDED
@@ -0,0 +1,237 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
4
+ # and The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from dataclasses import dataclass, field
18
+
19
+ from lerobot.common.optim.optimizers import AdamConfig
20
+ from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
21
+ from lerobot.configs.policies import PreTrainedConfig
22
+ from lerobot.configs.types import NormalizationMode
23
+
24
+
25
+ @PreTrainedConfig.register_subclass("diffusion")
26
+ @dataclass
27
+ class DiffusionConfig(PreTrainedConfig):
28
+ """Configuration class for DiffusionPolicy.
29
+
30
+ Defaults are configured for training on PushT, which provides proprioceptive and single-camera observations.
31
+
32
+ The parameters you will most likely need to change are the ones which depend on the environment / sensors.
33
+ Those are: `input_shapes` and `output_shapes`.
34
+
35
+ Notes on the inputs and outputs:
36
+ - "observation.state" is required as an input key.
37
+ - Either:
38
+ - At least one key starting with "observation.image" is required as an input.
39
+ AND/OR
40
+ - The key "observation.environment_state" is required as input.
41
+ - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
42
+ views. Right now we only support all images having the same shape.
43
+ - "action" is required as an output key.
44
+
45
+ Args:
46
+ n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
47
+ current step and additional steps going back).
48
+ horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
49
+ n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
50
+ See `DiffusionPolicy.select_action` for more details.
51
+ input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
52
+ the input data name, and the value is a list indicating the dimensions of the corresponding data.
53
+ For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
54
+ indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
55
+ include batch dimension or temporal dimension.
56
+ output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
57
+ the output data name, and the value is a list indicating the dimensions of the corresponding data.
58
+ For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
59
+ Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
60
+ input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
61
+ and the value specifies the normalization mode to apply. The two available modes are "mean_std"
62
+ which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
63
+ [-1, 1] range.
64
+ output_normalization_modes: Similar dictionary to `input_normalization_modes`, but used to unnormalize to the
65
+ original scale. Note that this is also used for normalizing the training targets.
66
+ vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
67
+ crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
68
+ within the image size. If None, no cropping is done.
69
+ crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
70
+ mode).
71
+ pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
72
+ `None` means no pretrained weights.
73
+ use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
74
+ The group sizes are set to be about 16 (to be precise, feature_dim // 16).
75
+ spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
76
+ use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
77
+ down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
78
+ You may provide a variable number of dimensions, therefore also controlling the degree of
79
+ downsampling.
80
+ kernel_size: The convolutional kernel size of the diffusion modeling Unet.
81
+ n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
82
+ diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
83
+ network. This is the output dimension of that network, i.e., the embedding dimension.
84
+ use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
85
+ Bias modulation is used by default, while this parameter indicates whether to also use scale
86
+ modulation.
87
+ noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
88
+ num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
89
+ beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
90
+ beta_start: Beta value for the first forward-diffusion step.
91
+ beta_end: Beta value for the last forward-diffusion step.
92
+ prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
93
+ or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
94
+ "epsilon" has been shown to work better in many deep neural network settings.
95
+ clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
96
+ denoising step at inference time. WARNING: you will need to make sure your action-space is
97
+ normalized to fit within this range.
98
+ clip_sample_range: The magnitude of the clipping range as described above.
99
+ num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
100
+ spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
101
+ do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
102
+ `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
103
+ to False as the original Diffusion Policy implementation does the same.
104
+ """
105
+
106
+ # Inputs / output structure.
107
+ n_obs_steps: int = 2
108
+ horizon: int = 16
109
+ n_action_steps: int = 8
110
+
111
+ normalization_mapping: dict[str, NormalizationMode] = field(
112
+ default_factory=lambda: {
113
+ "VISUAL": NormalizationMode.MEAN_STD,
114
+ "STATE": NormalizationMode.MIN_MAX,
115
+ "ACTION": NormalizationMode.MIN_MAX,
116
+ }
117
+ )
118
+
119
+ # The original implementation doesn't sample frames for the last 7 steps,
120
+ # which avoids excessive padding and leads to improved training results.
121
+ drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
122
+
123
+ # Architecture / modeling.
124
+ # Vision backbone.
125
+ vision_backbone: str = "resnet18"
126
+ crop_shape: tuple[int, int] | None = (84, 84)
127
+ crop_is_random: bool = True
128
+ pretrained_backbone_weights: str | None = None
129
+ use_group_norm: bool = True
130
+ spatial_softmax_num_keypoints: int = 32
131
+ use_separate_rgb_encoder_per_camera: bool = False
132
+ # Unet.
133
+ down_dims: tuple[int, ...] = (512, 1024, 2048)
134
+ kernel_size: int = 5
135
+ n_groups: int = 8
136
+ diffusion_step_embed_dim: int = 128
137
+ use_film_scale_modulation: bool = True
138
+ # Noise scheduler.
139
+ noise_scheduler_type: str = "DDPM"
140
+ num_train_timesteps: int = 100
141
+ beta_schedule: str = "squaredcos_cap_v2"
142
+ beta_start: float = 0.0001
143
+ beta_end: float = 0.02
144
+ prediction_type: str = "epsilon"
145
+ clip_sample: bool = True
146
+ clip_sample_range: float = 1.0
147
+
148
+ # Inference
149
+ num_inference_steps: int | None = None
150
+
151
+ # Loss computation
152
+ do_mask_loss_for_padding: bool = False
153
+
154
+ # Training presets
155
+ optimizer_lr: float = 1e-4
156
+ optimizer_betas: tuple = (0.95, 0.999)
157
+ optimizer_eps: float = 1e-8
158
+ optimizer_weight_decay: float = 1e-6
159
+ scheduler_name: str = "cosine"
160
+ scheduler_warmup_steps: int = 500
161
+
162
+ def __post_init__(self):
163
+ super().__post_init__()
164
+
165
+ # Input validation (not exhaustive).
166
+ if not self.vision_backbone.startswith("resnet"):
167
+ raise ValueError(
168
+ f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
169
+ )
170
+
171
+ supported_prediction_types = ["epsilon", "sample"]
172
+ if self.prediction_type not in supported_prediction_types:
173
+ raise ValueError(
174
+ f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
175
+ )
176
+ supported_noise_schedulers = ["DDPM", "DDIM"]
177
+ if self.noise_scheduler_type not in supported_noise_schedulers:
178
+ raise ValueError(
179
+ f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
180
+ f"Got {self.noise_scheduler_type}."
181
+ )
182
+
183
+ # Check that the horizon size and U-Net downsampling is compatible.
184
+ # U-Net downsamples by 2 with each stage.
185
+ downsampling_factor = 2 ** len(self.down_dims)
186
+ if self.horizon % downsampling_factor != 0:
187
+ raise ValueError(
188
+ "The horizon should be an integer multiple of the downsampling factor (which is determined "
189
+ f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
190
+ )
191
+
192
+ def get_optimizer_preset(self) -> AdamConfig:
193
+ return AdamConfig(
194
+ lr=self.optimizer_lr,
195
+ betas=self.optimizer_betas,
196
+ eps=self.optimizer_eps,
197
+ weight_decay=self.optimizer_weight_decay,
198
+ )
199
+
200
+ def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
201
+ return DiffuserSchedulerConfig(
202
+ name=self.scheduler_name,
203
+ num_warmup_steps=self.scheduler_warmup_steps,
204
+ )
205
+
206
+ def validate_features(self) -> None:
207
+ if len(self.image_features) == 0 and self.env_state_feature is None:
208
+ raise ValueError("You must provide at least one image or the environment state among the inputs.")
209
+
210
+ if self.crop_shape is not None:
211
+ for key, image_ft in self.image_features.items():
212
+ if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
213
+ raise ValueError(
214
+ f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
215
+ f"for `crop_shape` and {image_ft.shape} for "
216
+ f"`{key}`."
217
+ )
218
+
219
+ # Check that all input images have the same shape.
220
+ first_image_key, first_image_ft = next(iter(self.image_features.items()))
221
+ for key, image_ft in self.image_features.items():
222
+ if image_ft.shape != first_image_ft.shape:
223
+ raise ValueError(
224
+ f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
225
+ )
226
+
227
+ @property
228
+ def observation_delta_indices(self) -> list:
229
+ return list(range(1 - self.n_obs_steps, 1))
230
+
231
+ @property
232
+ def action_delta_indices(self) -> list:
233
+ return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
234
+
235
+ @property
236
+ def reward_delta_indices(self) -> None:
237
+ return None
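A short, hedged sketch (not part of the committed file) of what the delta-index properties above evaluate to with the PushT defaults (n_obs_steps=2, horizon=16, n_action_steps=8, which also satisfies n_action_steps <= horizon - n_obs_steps + 1 since 8 <= 15). It assumes, as make_policy_config("diffusion") does, that all config fields have defaults:

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig()
print(cfg.observation_delta_indices)  # [-1, 0]: previous and current observation
print(cfg.action_delta_indices)       # [-1, 0, ..., 14]: 16 (= horizon) action steps starting one step back
print(cfg.reward_delta_indices)       # None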
lerobot/common/policies/diffusion/modeling_diffusion.py ADDED
@@ -0,0 +1,765 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
4
+ # and The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
18
+
19
+ TODO(alexander-soare):
20
+ - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
21
+ """
22
+
23
+ import math
24
+ from collections import deque
25
+ from typing import Callable
26
+
27
+ import einops
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F # noqa: N812
31
+ import torchvision
32
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
33
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
34
+ from torch import Tensor, nn
35
+
36
+ from lerobot.common.constants import OBS_ENV, OBS_ROBOT
37
+ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
38
+ from lerobot.common.policies.normalize import Normalize, Unnormalize
39
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
40
+ from lerobot.common.policies.utils import (
41
+ get_device_from_parameters,
42
+ get_dtype_from_parameters,
43
+ get_output_shape,
44
+ populate_queues,
45
+ )
46
+
47
+
48
+ class DiffusionPolicy(PreTrainedPolicy):
49
+ """
50
+ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
51
+ (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
52
+ """
53
+
54
+ config_class = DiffusionConfig
55
+ name = "diffusion"
56
+
57
+ def __init__(
58
+ self,
59
+ config: DiffusionConfig,
60
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
61
+ ):
62
+ """
63
+ Args:
64
+ config: Policy configuration class instance or None, in which case the default instantiation of
65
+ the configuration class is used.
66
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
67
+ that they will be passed with a call to `load_state_dict` before the policy is used.
68
+ """
69
+ super().__init__(config)
70
+ config.validate_features()
71
+ self.config = config
72
+
73
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
74
+ self.normalize_targets = Normalize(
75
+ config.output_features, config.normalization_mapping, dataset_stats
76
+ )
77
+ self.unnormalize_outputs = Unnormalize(
78
+ config.output_features, config.normalization_mapping, dataset_stats
79
+ )
80
+
81
+ # queues are populated during rollout of the policy, they contain the n latest observations and actions
82
+ self._queues = None
83
+
84
+ self.diffusion = DiffusionModel(config)
85
+
86
+ self.reset()
87
+
88
+ def get_optim_params(self) -> dict:
89
+ return self.diffusion.parameters()
90
+
91
+ def reset(self):
92
+ """Clear observation and action queues. Should be called on `env.reset()`"""
93
+ self._queues = {
94
+ "observation.state": deque(maxlen=self.config.n_obs_steps),
95
+ "action": deque(maxlen=self.config.n_action_steps),
96
+ }
97
+ if self.config.image_features:
98
+ self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
99
+ if self.config.env_state_feature:
100
+ self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
101
+
102
+ @torch.no_grad
103
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
104
+ """Select a single action given environment observations.
105
+
106
+ This method handles caching a history of observations and an action trajectory generated by the
107
+ underlying diffusion model. Here's how it works:
108
+ - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
109
+ copied `n_obs_steps` times to fill the cache).
110
+ - The diffusion model generates `horizon` steps worth of actions.
111
+ - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
112
+ Schematically this looks like:
113
+ ----------------------------------------------------------------------------------------------
114
+ (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
115
+ |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
116
+ |observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
117
+ |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
118
+ |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
119
+ ----------------------------------------------------------------------------------------------
120
+ Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
121
+ "horizon" may not the best name to describe what the variable actually means, because this period is
122
+ "horizon" may not be the best name to describe what the variable actually means, because this period is
123
+ """
124
+ batch = self.normalize_inputs(batch)
125
+ if self.config.image_features:
126
+ batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
127
+ batch["observation.images"] = torch.stack(
128
+ [batch[key] for key in self.config.image_features], dim=-4
129
+ )
130
+ # Note: It's important that this happens after stacking the images into a single key.
131
+ self._queues = populate_queues(self._queues, batch)
132
+
133
+ if len(self._queues["action"]) == 0:
134
+ # stack n latest observations from the queue
135
+ batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
136
+ actions = self.diffusion.generate_actions(batch)
137
+
138
+ # TODO(rcadene): make above methods return output dictionary?
139
+ actions = self.unnormalize_outputs({"action": actions})["action"]
140
+
141
+ self._queues["action"].extend(actions.transpose(0, 1))
142
+
143
+ action = self._queues["action"].popleft()
144
+ return action
145
+
146
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
147
+ """Run the batch through the model and compute the loss for training or validation."""
148
+ batch = self.normalize_inputs(batch)
149
+ if self.config.image_features:
150
+ batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
151
+ batch["observation.images"] = torch.stack(
152
+ [batch[key] for key in self.config.image_features], dim=-4
153
+ )
154
+ batch = self.normalize_targets(batch)
155
+ loss = self.diffusion.compute_loss(batch)
156
+ # no output_dict so returning None
157
+ return loss, None
158
+
159
+
160
+ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
161
+ """
162
+ Factory for noise scheduler instances of the requested type. All kwargs are passed
163
+ to the scheduler.
164
+ """
165
+ if name == "DDPM":
166
+ return DDPMScheduler(**kwargs)
167
+ elif name == "DDIM":
168
+ return DDIMScheduler(**kwargs)
169
+ else:
170
+ raise ValueError(f"Unsupported noise scheduler type {name}")
171
+
172
+
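A hedged usage sketch (not part of the committed file): the factory simply forwards its kwargs to the corresponding diffusers scheduler class, mirroring how DiffusionModel.__init__ below calls it. It assumes _make_noise_scheduler is in scope (same module or imported):

scheduler = _make_noise_scheduler(
    "DDPM",
    num_train_timesteps=100,
    beta_schedule="squaredcos_cap_v2",
    prediction_type="epsilon",
)
scheduler.set_timesteps(10)  # 10 evenly spaced reverse-diffusion steps for inference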
173
+ class DiffusionModel(nn.Module):
174
+ def __init__(self, config: DiffusionConfig):
175
+ super().__init__()
176
+ self.config = config
177
+
178
+ # Build observation encoders (depending on which observations are provided).
179
+ global_cond_dim = self.config.robot_state_feature.shape[0]
180
+ if self.config.image_features:
181
+ num_images = len(self.config.image_features)
182
+ if self.config.use_separate_rgb_encoder_per_camera:
183
+ encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
184
+ self.rgb_encoder = nn.ModuleList(encoders)
185
+ global_cond_dim += encoders[0].feature_dim * num_images
186
+ else:
187
+ self.rgb_encoder = DiffusionRgbEncoder(config)
188
+ global_cond_dim += self.rgb_encoder.feature_dim * num_images
189
+ if self.config.env_state_feature:
190
+ global_cond_dim += self.config.env_state_feature.shape[0]
191
+
192
+ self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
193
+
194
+ self.noise_scheduler = _make_noise_scheduler(
195
+ config.noise_scheduler_type,
196
+ num_train_timesteps=config.num_train_timesteps,
197
+ beta_start=config.beta_start,
198
+ beta_end=config.beta_end,
199
+ beta_schedule=config.beta_schedule,
200
+ clip_sample=config.clip_sample,
201
+ clip_sample_range=config.clip_sample_range,
202
+ prediction_type=config.prediction_type,
203
+ )
204
+
205
+ if config.num_inference_steps is None:
206
+ self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
207
+ else:
208
+ self.num_inference_steps = config.num_inference_steps
209
+
210
+ # ========= inference ============
211
+ def conditional_sample(
212
+ self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
213
+ ) -> Tensor:
214
+ device = get_device_from_parameters(self)
215
+ dtype = get_dtype_from_parameters(self)
216
+
217
+ # Sample prior.
218
+ sample = torch.randn(
219
+ size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
220
+ dtype=dtype,
221
+ device=device,
222
+ generator=generator,
223
+ )
224
+
225
+ self.noise_scheduler.set_timesteps(self.num_inference_steps)
226
+
227
+ for t in self.noise_scheduler.timesteps:
228
+ # Predict model output.
229
+ model_output = self.unet(
230
+ sample,
231
+ torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
232
+ global_cond=global_cond,
233
+ )
234
+ # Compute previous image: x_t -> x_t-1
235
+ sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
236
+
237
+ return sample
238
+
239
+ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
240
+ """Encode image features and concatenate them all together along with the state vector."""
241
+ batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
242
+ global_cond_feats = [batch[OBS_ROBOT]]
243
+ # Extract image features.
244
+ if self.config.image_features:
245
+ if self.config.use_separate_rgb_encoder_per_camera:
246
+ # Combine batch and sequence dims while rearranging to make the camera index dimension first.
247
+ images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
248
+ img_features_list = torch.cat(
249
+ [
250
+ encoder(images)
251
+ for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
252
+ ]
253
+ )
254
+ # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
255
+ # feature dim (effectively concatenating the camera features).
256
+ img_features = einops.rearrange(
257
+ img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
258
+ )
259
+ else:
260
+ # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
261
+ img_features = self.rgb_encoder(
262
+ einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
263
+ )
264
+ # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
265
+ # feature dim (effectively concatenating the camera features).
266
+ img_features = einops.rearrange(
267
+ img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
268
+ )
269
+ global_cond_feats.append(img_features)
270
+
271
+ if self.config.env_state_feature:
272
+ global_cond_feats.append(batch[OBS_ENV])
273
+
274
+ # Concatenate features then flatten to (B, global_cond_dim).
275
+ return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
276
+
277
+ def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
278
+ """
279
+ This function expects `batch` to have:
280
+ {
281
+ "observation.state": (B, n_obs_steps, state_dim)
282
+
283
+ "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
284
+ AND/OR
285
+ "observation.environment_state": (B, environment_dim)
286
+ }
287
+ """
288
+ batch_size, n_obs_steps = batch["observation.state"].shape[:2]
289
+ assert n_obs_steps == self.config.n_obs_steps
290
+
291
+ # Encode image features and concatenate them all together along with the state vector.
292
+ global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
293
+
294
+ # run sampling
295
+ actions = self.conditional_sample(batch_size, global_cond=global_cond)
296
+
297
+ # Extract `n_action_steps` steps worth of actions (from the current observation).
298
+ start = n_obs_steps - 1
299
+ end = start + self.config.n_action_steps
300
+ actions = actions[:, start:end]
301
+
302
+ return actions
303
+
304
+ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
305
+ """
306
+ This function expects `batch` to have (at least):
307
+ {
308
+ "observation.state": (B, n_obs_steps, state_dim)
309
+
310
+ "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
311
+ AND/OR
312
+ "observation.environment_state": (B, environment_dim)
313
+
314
+ "action": (B, horizon, action_dim)
315
+ "action_is_pad": (B, horizon)
316
+ }
317
+ """
318
+ # Input validation.
319
+ assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
320
+ assert "observation.images" in batch or "observation.environment_state" in batch
321
+ n_obs_steps = batch["observation.state"].shape[1]
322
+ horizon = batch["action"].shape[1]
323
+ assert horizon == self.config.horizon
324
+ assert n_obs_steps == self.config.n_obs_steps
325
+
326
+ # Encode image features and concatenate them all together along with the state vector.
327
+ global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
328
+
329
+ # Forward diffusion.
330
+ trajectory = batch["action"]
331
+ # Sample noise to add to the trajectory.
332
+ eps = torch.randn(trajectory.shape, device=trajectory.device)
333
+ # Sample a random noising timestep for each item in the batch.
334
+ timesteps = torch.randint(
335
+ low=0,
336
+ high=self.noise_scheduler.config.num_train_timesteps,
337
+ size=(trajectory.shape[0],),
338
+ device=trajectory.device,
339
+ ).long()
340
+ # Add noise to the clean trajectories according to the noise magnitude at each timestep.
341
+ noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
342
+
343
+ # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
344
+ pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
345
+
346
+ # Compute the loss.
347
+ # The target is either the original trajectory, or the noise.
348
+ if self.config.prediction_type == "epsilon":
349
+ target = eps
350
+ elif self.config.prediction_type == "sample":
351
+ target = batch["action"]
352
+ else:
353
+ raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
354
+
355
+ loss = F.mse_loss(pred, target, reduction="none")
356
+
357
+ # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
358
+ if self.config.do_mask_loss_for_padding:
359
+ if "action_is_pad" not in batch:
360
+ raise ValueError(
361
+ "You need to provide 'action_is_pad' in the batch when "
362
+ f"{self.config.do_mask_loss_for_padding=}."
363
+ )
364
+ in_episode_bound = ~batch["action_is_pad"]
365
+ loss = loss * in_episode_bound.unsqueeze(-1)
366
+
367
+ return loss.mean()
368
+
369
+
370
+ class SpatialSoftmax(nn.Module):
371
+ """
372
+ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
373
+ (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
374
+
375
+ At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
376
+ of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
377
+
378
+ Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
379
+ -----------------------------------------------------
380
+ | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
381
+ | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
382
+ | ... | ... | ... | ... |
383
+ | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
384
+ -----------------------------------------------------
385
+ This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
386
+ product with the coordinates (120x2) to get expected points of maximal activation (512x2).
387
+
388
+ The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
389
+ provide num_kp != None to control the number of keypoints. This is achieved by first applying a learnable
390
+ linear mapping (in_channels, H, W) -> (num_kp, H, W).
391
+ """
392
+
393
+ def __init__(self, input_shape, num_kp=None):
394
+ """
395
+ Args:
396
+ input_shape (list): (C, H, W) input feature map shape.
397
+ num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
398
+ """
399
+ super().__init__()
400
+
401
+ assert len(input_shape) == 3
402
+ self._in_c, self._in_h, self._in_w = input_shape
403
+
404
+ if num_kp is not None:
405
+ self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
406
+ self._out_c = num_kp
407
+ else:
408
+ self.nets = None
409
+ self._out_c = self._in_c
410
+
411
+ # we could use torch.linspace directly but that seems to behave slightly differently than numpy
412
+ # and causes a small degradation in pc_success of pre-trained models.
413
+ pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
414
+ pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
415
+ pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
416
+ # register as buffer so it's moved to the correct device.
417
+ self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
418
+
419
+ def forward(self, features: Tensor) -> Tensor:
420
+ """
421
+ Args:
422
+ features: (B, C, H, W) input feature maps.
423
+ Returns:
424
+ (B, K, 2) image-space coordinates of keypoints.
425
+ """
426
+ if self.nets is not None:
427
+ features = self.nets(features)
428
+
429
+ # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
430
+ features = features.reshape(-1, self._in_h * self._in_w)
431
+ # 2d softmax normalization
432
+ attention = F.softmax(features, dim=-1)
433
+ # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
434
+ expected_xy = attention @ self.pos_grid
435
+ # reshape to [B, K, 2]
436
+ feature_keypoints = expected_xy.view(-1, self._out_c, 2)
437
+
438
+ return feature_keypoints
439
+
440
+
441
+ class DiffusionRgbEncoder(nn.Module):
442
+ """Encodes an RGB image into a 1D feature vector.
443
+
444
+ Includes the ability to normalize and crop the image first.
445
+ """
446
+
447
+ def __init__(self, config: DiffusionConfig):
448
+ super().__init__()
449
+ # Set up optional preprocessing.
450
+ if config.crop_shape is not None:
451
+ self.do_crop = True
452
+ # Always use center crop for eval
453
+ self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
454
+ if config.crop_is_random:
455
+ self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
456
+ else:
457
+ self.maybe_random_crop = self.center_crop
458
+ else:
459
+ self.do_crop = False
460
+
461
+ # Set up backbone.
462
+ backbone_model = getattr(torchvision.models, config.vision_backbone)(
463
+ weights=config.pretrained_backbone_weights
464
+ )
465
+ # Note: This assumes that the layer4 feature map is children()[-3]
466
+ # TODO(alexander-soare): Use a safer alternative.
467
+ self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
468
+ if config.use_group_norm:
469
+ if config.pretrained_backbone_weights:
470
+ raise ValueError(
471
+ "You can't replace BatchNorm in a pretrained model without ruining the weights!"
472
+ )
473
+ self.backbone = _replace_submodules(
474
+ root_module=self.backbone,
475
+ predicate=lambda x: isinstance(x, nn.BatchNorm2d),
476
+ func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
477
+ )
478
+
479
+ # Set up pooling and final layers.
480
+ # Use a dry run to get the feature map shape.
481
+ # The dummy input should take the number of image channels from `config.image_features` and it should
482
+ # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
483
+ # height and width from `config.image_features`.
484
+
485
+ # Note: we have a check in the config class to make sure all images have the same shape.
486
+ images_shape = next(iter(config.image_features.values())).shape
487
+ dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
488
+ dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
489
+ feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
490
+
491
+ self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
492
+ self.feature_dim = config.spatial_softmax_num_keypoints * 2
493
+ self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
494
+ self.relu = nn.ReLU()
495
+
496
+ def forward(self, x: Tensor) -> Tensor:
497
+ """
498
+ Args:
499
+ x: (B, C, H, W) image tensor with pixel values in [0, 1].
500
+ Returns:
501
+ (B, D) image feature.
502
+ """
503
+ # Preprocess: maybe crop (if it was set up in the __init__).
504
+ if self.do_crop:
505
+ if self.training: # noqa: SIM108
506
+ x = self.maybe_random_crop(x)
507
+ else:
508
+ # Always use center crop for eval.
509
+ x = self.center_crop(x)
510
+ # Extract backbone feature.
511
+ x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
512
+ # Final linear layer with non-linearity.
513
+ x = self.relu(self.out(x))
514
+ return x
515
+
516
+
517
+ def _replace_submodules(
518
+ root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
519
+ ) -> nn.Module:
520
+ """
521
+ Args:
522
+ root_module: The module for which the submodules need to be replaced
523
+ predicate: Takes a module as an argument and must return True if that module is to be replaced.
524
+ func: Takes a module as an argument and returns a new module to replace it with.
525
+ Returns:
526
+ The root module with its submodules replaced.
527
+ """
528
+ if predicate(root_module):
529
+ return func(root_module)
530
+
531
+ replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
532
+ for *parents, k in replace_list:
533
+ parent_module = root_module
534
+ if len(parents) > 0:
535
+ parent_module = root_module.get_submodule(".".join(parents))
536
+ if isinstance(parent_module, nn.Sequential):
537
+ src_module = parent_module[int(k)]
538
+ else:
539
+ src_module = getattr(parent_module, k)
540
+ tgt_module = func(src_module)
541
+ if isinstance(parent_module, nn.Sequential):
542
+ parent_module[int(k)] = tgt_module
543
+ else:
544
+ setattr(parent_module, k, tgt_module)
545
+ # verify that all BN are replaced
546
+ assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
547
+ return root_module
548
+
549
+
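A hedged sketch (not part of the committed file) of how _replace_submodules is typically used: swapping every BatchNorm2d in a torchvision ResNet for GroupNorm, which is the substitution DiffusionRgbEncoder performs when use_group_norm=True. It assumes _replace_submodules is in scope:

import torchvision
from torch import nn

backbone = torchvision.models.resnet18(weights=None)
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())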
550
+ class DiffusionSinusoidalPosEmb(nn.Module):
551
+ """1D sinusoidal positional embeddings as in Attention is All You Need."""
552
+
553
+ def __init__(self, dim: int):
554
+ super().__init__()
555
+ self.dim = dim
556
+
557
+ def forward(self, x: Tensor) -> Tensor:
558
+ device = x.device
559
+ half_dim = self.dim // 2
560
+ emb = math.log(10000) / (half_dim - 1)
561
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
562
+ emb = x.unsqueeze(-1) * emb.unsqueeze(0)
563
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
564
+ return emb
565
+
566
+
567
+ class DiffusionConv1dBlock(nn.Module):
568
+ """Conv1d --> GroupNorm --> Mish"""
569
+
570
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
571
+ super().__init__()
572
+
573
+ self.block = nn.Sequential(
574
+ nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
575
+ nn.GroupNorm(n_groups, out_channels),
576
+ nn.Mish(),
577
+ )
578
+
579
+ def forward(self, x):
580
+ return self.block(x)
581
+
582
+
583
+ class DiffusionConditionalUnet1d(nn.Module):
584
+ """A 1D convolutional UNet with FiLM modulation for conditioning.
585
+
586
+ Note: this removes local conditioning as compared to the original diffusion policy code.
587
+ """
588
+
589
+ def __init__(self, config: DiffusionConfig, global_cond_dim: int):
590
+ super().__init__()
591
+
592
+ self.config = config
593
+
594
+ # Encoder for the diffusion timestep.
595
+ self.diffusion_step_encoder = nn.Sequential(
596
+ DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
597
+ nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
598
+ nn.Mish(),
599
+ nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
600
+ )
601
+
602
+ # The FiLM conditioning dimension.
603
+ cond_dim = config.diffusion_step_embed_dim + global_cond_dim
604
+
605
+ # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
606
+ # just reverse these.
607
+ in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
608
+ zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
609
+ )
610
+
611
+ # Unet encoder.
612
+ common_res_block_kwargs = {
613
+ "cond_dim": cond_dim,
614
+ "kernel_size": config.kernel_size,
615
+ "n_groups": config.n_groups,
616
+ "use_film_scale_modulation": config.use_film_scale_modulation,
617
+ }
618
+ self.down_modules = nn.ModuleList([])
619
+ for ind, (dim_in, dim_out) in enumerate(in_out):
620
+ is_last = ind >= (len(in_out) - 1)
621
+ self.down_modules.append(
622
+ nn.ModuleList(
623
+ [
624
+ DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
625
+ DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
626
+ # Downsample as long as it is not the last block.
627
+ nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
628
+ ]
629
+ )
630
+ )
631
+
632
+ # Processing in the middle of the auto-encoder.
633
+ self.mid_modules = nn.ModuleList(
634
+ [
635
+ DiffusionConditionalResidualBlock1d(
636
+ config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
637
+ ),
638
+ DiffusionConditionalResidualBlock1d(
639
+ config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
640
+ ),
641
+ ]
642
+ )
643
+
644
+ # Unet decoder.
645
+ self.up_modules = nn.ModuleList([])
646
+ for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
647
+ is_last = ind >= (len(in_out) - 1)
648
+ self.up_modules.append(
649
+ nn.ModuleList(
650
+ [
651
+ # dim_in * 2, because it takes the encoder's skip connection as well
652
+ DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
653
+ DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
654
+ # Upsample as long as it is not the last block.
655
+ nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
656
+ ]
657
+ )
658
+ )
659
+
660
+ self.final_conv = nn.Sequential(
661
+ DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
662
+ nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
663
+ )
664
+
665
+ def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
666
+ """
667
+ Args:
668
+ x: (B, T, input_dim) tensor for input to the Unet.
669
+ timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
670
+ global_cond: (B, global_cond_dim)
671
+ output: (B, T, input_dim)
672
+ Returns:
673
+ (B, T, input_dim) diffusion model prediction.
674
+ """
675
+ # For 1D convolutions we'll need feature dimension first.
676
+ x = einops.rearrange(x, "b t d -> b d t")
677
+
678
+ timesteps_embed = self.diffusion_step_encoder(timestep)
679
+
680
+ # If there is a global conditioning feature, concatenate it to the timestep embedding.
681
+ if global_cond is not None:
682
+ global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
683
+ else:
684
+ global_feature = timesteps_embed
685
+
686
+ # Run encoder, keeping track of skip features to pass to the decoder.
687
+ encoder_skip_features: list[Tensor] = []
688
+ for resnet, resnet2, downsample in self.down_modules:
689
+ x = resnet(x, global_feature)
690
+ x = resnet2(x, global_feature)
691
+ encoder_skip_features.append(x)
692
+ x = downsample(x)
693
+
694
+ for mid_module in self.mid_modules:
695
+ x = mid_module(x, global_feature)
696
+
697
+ # Run decoder, using the skip features from the encoder.
698
+ for resnet, resnet2, upsample in self.up_modules:
699
+ x = torch.cat((x, encoder_skip_features.pop()), dim=1)
700
+ x = resnet(x, global_feature)
701
+ x = resnet2(x, global_feature)
702
+ x = upsample(x)
703
+
704
+ x = self.final_conv(x)
705
+
706
+ x = einops.rearrange(x, "b d t -> b t d")
707
+ return x
708
+
709
+
710
+ class DiffusionConditionalResidualBlock1d(nn.Module):
711
+ """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
712
+
713
+ def __init__(
714
+ self,
715
+ in_channels: int,
716
+ out_channels: int,
717
+ cond_dim: int,
718
+ kernel_size: int = 3,
719
+ n_groups: int = 8,
720
+ # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
721
+ # FiLM just modulates bias).
722
+ use_film_scale_modulation: bool = False,
723
+ ):
724
+ super().__init__()
725
+
726
+ self.use_film_scale_modulation = use_film_scale_modulation
727
+ self.out_channels = out_channels
728
+
729
+ self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
730
+
731
+ # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
732
+ cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
733
+ self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
734
+
735
+ self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
736
+
737
+ # A final convolution for dimension matching the residual (if needed).
738
+ self.residual_conv = (
739
+ nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
740
+ )
741
+
742
+ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
743
+ """
744
+ Args:
745
+ x: (B, in_channels, T)
746
+ cond: (B, cond_dim)
747
+ Returns:
748
+ (B, out_channels, T)
749
+ """
750
+ out = self.conv1(x)
751
+
752
+ # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
753
+ cond_embed = self.cond_encoder(cond).unsqueeze(-1)
754
+ if self.use_film_scale_modulation:
755
+ # Treat the embedding as a list of scales and biases.
756
+ scale = cond_embed[:, : self.out_channels]
757
+ bias = cond_embed[:, self.out_channels :]
758
+ out = scale * out + bias
759
+ else:
760
+ # Treat the embedding as biases.
761
+ out = out + cond_embed
762
+
763
+ out = self.conv2(out)
764
+ out = out + self.residual_conv(x)
765
+ return out
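A hedged sketch (not part of the committed file) illustrating the SpatialSoftmax module defined above: a single strongly peaked activation map yields a keypoint at the corresponding normalized image coordinate. It assumes the class is importable from lerobot.common.policies.diffusion.modeling_diffusion:

import torch
from lerobot.common.policies.diffusion.modeling_diffusion import SpatialSoftmax

pool = SpatialSoftmax(input_shape=(1, 5, 5))  # one channel, 5x5 feature map, num_kp=None
feat = torch.full((1, 1, 5, 5), -10.0)
feat[0, 0, 0, 4] = 10.0                       # strong activation in the top-right corner
print(pool(feat))                             # approx. [[[1.0, -1.0]]]: x = +1 (right), y = -1 (top)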
lerobot/common/policies/factory.py ADDED
@@ -0,0 +1,157 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+
19
+ from torch import nn
20
+
21
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
22
+ from lerobot.common.datasets.utils import dataset_to_policy_features
23
+ from lerobot.common.envs.configs import EnvConfig
24
+ from lerobot.common.envs.utils import env_to_policy_features
25
+ from lerobot.common.policies.act.configuration_act import ACTConfig
26
+ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
27
+ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
28
+ from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
29
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
30
+ from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
31
+ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
32
+ from lerobot.configs.policies import PreTrainedConfig
33
+ from lerobot.configs.types import FeatureType
34
+
35
+
36
+ def get_policy_class(name: str) -> PreTrainedPolicy:
37
+ """Get the policy class given a name (matching the policy class' `name` attribute)."""
38
+ if name == "tdmpc":
39
+ from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
40
+
41
+ return TDMPCPolicy
42
+ elif name == "diffusion":
43
+ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
44
+
45
+ return DiffusionPolicy
46
+ elif name == "act":
47
+ from lerobot.common.policies.act.modeling_act import ACTPolicy
48
+
49
+ return ACTPolicy
50
+ elif name == "vqbet":
51
+ from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
52
+
53
+ return VQBeTPolicy
54
+ elif name == "pi0":
55
+ from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
56
+
57
+ return PI0Policy
58
+ elif name == "pi0fast":
59
+ from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
60
+
61
+ return PI0FASTPolicy
62
+ else:
63
+ raise NotImplementedError(f"Policy with name {name} is not implemented.")
64
+
65
+
66
+ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
67
+ if policy_type == "tdmpc":
68
+ return TDMPCConfig(**kwargs)
69
+ elif policy_type == "diffusion":
70
+ return DiffusionConfig(**kwargs)
71
+ elif policy_type == "act":
72
+ return ACTConfig(**kwargs)
73
+ elif policy_type == "vqbet":
74
+ return VQBeTConfig(**kwargs)
75
+ elif policy_type == "pi0":
76
+ return PI0Config(**kwargs)
77
+ elif policy_type == "pi0fast":
78
+ return PI0FASTConfig(**kwargs)
79
+ else:
80
+ raise ValueError(f"Policy type '{policy_type}' is not available.")
81
+
82
+
83
+ def make_policy(
84
+ cfg: PreTrainedConfig,
85
+ ds_meta: LeRobotDatasetMetadata | None = None,
86
+ env_cfg: EnvConfig | None = None,
87
+ ) -> PreTrainedPolicy:
88
+ """Make an instance of a policy class.
89
+
90
+ This function exists because (for now) we need to parse features from either a dataset or an environment
91
+ in order to properly dimension and instantiate a policy for that dataset or environment.
92
+
93
+ Args:
94
+ cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
95
+ be loaded with the weights from that path.
96
+ ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
97
+ statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
98
+ env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
99
+ provided if ds_meta is not. Defaults to None.
100
+
101
+ Raises:
102
+ ValueError: Exactly one of ds_meta or env_cfg must be provided.
103
+ NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
104
+
105
+ Returns:
106
+ PreTrainedPolicy: An instance of the requested policy, with pretrained weights loaded if `cfg.pretrained_path` is set.
107
+ """
108
+ if bool(ds_meta) == bool(env_cfg):
109
+ raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
110
+
111
+ # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
112
+ # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
113
+ # NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
114
+ # you want this op to be added in priority during the prototype phase of this feature, please comment on
115
+ # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
116
+ # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
117
+ # slower than running natively on MPS.
118
+ if cfg.type == "vqbet" and cfg.device == "mps":
119
+ raise NotImplementedError(
120
+ "Current implementation of VQBeT does not support `mps` backend. "
121
+ "Please use `cpu` or `cuda` backend."
122
+ )
123
+
124
+ policy_cls = get_policy_class(cfg.type)
125
+
126
+ kwargs = {}
127
+ if ds_meta is not None:
128
+ features = dataset_to_policy_features(ds_meta.features)
129
+ kwargs["dataset_stats"] = ds_meta.stats
130
+ else:
131
+ if not cfg.pretrained_path:
132
+ logging.warning(
133
+ "You are instantiating a policy from scratch and its features are parsed from an environment "
134
+ "rather than a dataset. Normalization modules inside the policy will have infinite values "
135
+ "by default without stats from a dataset."
136
+ )
137
+ features = env_to_policy_features(env_cfg)
138
+
139
+ cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
140
+ cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
141
+ kwargs["config"] = cfg
142
+
143
+ if cfg.pretrained_path:
144
+ # Load a pretrained policy and override the config if needed (for example, if there are inference-time
145
+ # hyperparameters that we want to vary).
146
+ kwargs["pretrained_name_or_path"] = cfg.pretrained_path
147
+ policy = policy_cls.from_pretrained(**kwargs)
148
+ else:
149
+ # Make a fresh policy.
150
+ policy = policy_cls(**kwargs)
151
+
152
+ policy.to(cfg.device)
153
+ assert isinstance(policy, nn.Module)
154
+
155
+ # policy = torch.compile(policy, mode="reduce-overhead")
156
+
157
+ return policy
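
For reference, a minimal usage sketch of `make_policy` together with dataset metadata (it mirrors the benchmark script later in this diff; the dataset repo id and checkpoint id are placeholders you would swap for your own):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig

# Placeholder dataset and pretrained checkpoint identifiers.
dataset = LeRobotDataset("lerobot/aloha_sim_transfer_cube_human", episodes=[0])

cfg = PreTrainedConfig.from_pretrained("lerobot/pi0")
cfg.pretrained_path = "lerobot/pi0"

# Input/output features and normalization stats are taken from the dataset metadata.
policy = make_policy(cfg, ds_meta=dataset.meta)
```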
lerobot/common/policies/normalize.py ADDED
@@ -0,0 +1,254 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import numpy as np
17
+ import torch
18
+ from torch import Tensor, nn
19
+
20
+ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
21
+
22
+
23
+ def create_stats_buffers(
24
+ features: dict[str, PolicyFeature],
25
+ norm_map: dict[str, NormalizationMode],
26
+ stats: dict[str, dict[str, Tensor]] | None = None,
27
+ ) -> dict[str, dict[str, nn.ParameterDict]]:
28
+ """
29
+ Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
30
+ statistics.
31
+
32
+ Args: (see Normalize and Unnormalize)
33
+
34
+ Returns:
35
+ dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
36
+ `nn.Parameter` tensors set to `requires_grad=False`, so they are not updated during backpropagation.
37
+ """
38
+ stats_buffers = {}
39
+
40
+ for key, ft in features.items():
41
+ norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
42
+ if norm_mode is NormalizationMode.IDENTITY:
43
+ continue
44
+
45
+ assert isinstance(norm_mode, NormalizationMode)
46
+
47
+ shape = tuple(ft.shape)
48
+
49
+ if ft.type is FeatureType.VISUAL:
50
+ # sanity checks
51
+ assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
52
+ c, h, w = shape
53
+ assert c < h and c < w, f"{key} is not channel first ({shape=})"
54
+ # override image shape to be invariant to height and width
55
+ shape = (c, 1, 1)
56
+
57
+ # Note: we initialize mean, std, min, max to infinity. They should be overwritten
58
+ # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
59
+ # we assert they are not infinity anymore.
60
+
61
+ buffer = {}
62
+ if norm_mode is NormalizationMode.MEAN_STD:
63
+ mean = torch.ones(shape, dtype=torch.float32) * torch.inf
64
+ std = torch.ones(shape, dtype=torch.float32) * torch.inf
65
+ buffer = nn.ParameterDict(
66
+ {
67
+ "mean": nn.Parameter(mean, requires_grad=False),
68
+ "std": nn.Parameter(std, requires_grad=False),
69
+ }
70
+ )
71
+ elif norm_mode is NormalizationMode.MIN_MAX:
72
+ min = torch.ones(shape, dtype=torch.float32) * torch.inf
73
+ max = torch.ones(shape, dtype=torch.float32) * torch.inf
74
+ buffer = nn.ParameterDict(
75
+ {
76
+ "min": nn.Parameter(min, requires_grad=False),
77
+ "max": nn.Parameter(max, requires_grad=False),
78
+ }
79
+ )
80
+
81
+ # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
82
+ if stats:
83
+ if isinstance(stats[key]["mean"], np.ndarray):
84
+ if norm_mode is NormalizationMode.MEAN_STD:
85
+ buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
86
+ buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
87
+ elif norm_mode is NormalizationMode.MIN_MAX:
88
+ buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
89
+ buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
90
+ elif isinstance(stats[key]["mean"], torch.Tensor):
91
+ # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
92
+ # tensors anywhere (for example, when we use the same stats for normalization and
93
+ # unnormalization). See the logic here
94
+ # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
95
+ if norm_mode is NormalizationMode.MEAN_STD:
96
+ buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
97
+ buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
98
+ elif norm_mode is NormalizationMode.MIN_MAX:
99
+ buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
100
+ buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
101
+ else:
102
+ type_ = type(stats[key]["mean"])
103
+ raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
104
+
105
+ stats_buffers[key] = buffer
106
+ return stats_buffers
107
+
108
+
109
+ def _no_stats_error_str(name: str) -> str:
110
+ return (
111
+ f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
112
+ "pretrained model."
113
+ )
114
+
115
+
116
+ class Normalize(nn.Module):
117
+ """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
118
+
119
+ def __init__(
120
+ self,
121
+ features: dict[str, PolicyFeature],
122
+ norm_map: dict[str, NormalizationMode],
123
+ stats: dict[str, dict[str, Tensor]] | None = None,
124
+ ):
125
+ """
126
+ Args:
127
+ features (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
128
+ are `PolicyFeature` objects holding their type and shape (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
129
+ mean, std, min, max statistics. If the provided `features` contain image keys, the shape
130
+ is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
131
+ norm_map (dict): A dictionary where keys are feature types (e.g. `FeatureType.VISUAL`) and values
132
+ are their normalization modes among:
133
+ - "mean_std": subtract the mean and divide by standard deviation.
134
+ - "min_max": map to [-1, 1] range.
135
+ stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
136
+ and values are dictionaries of statistic types and their values (e.g.
137
+ `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
138
+ training the model for the first time, these statistics will overwrite the default buffers. If
139
+ not provided, as expected for finetuning or evaluation, the default buffers should be
140
+ overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
141
+ dataset is not needed to get the stats, since they are already in the policy state_dict.
142
+ """
143
+ super().__init__()
144
+ self.features = features
145
+ self.norm_map = norm_map
146
+ self.stats = stats
147
+ stats_buffers = create_stats_buffers(features, norm_map, stats)
148
+ for key, buffer in stats_buffers.items():
149
+ setattr(self, "buffer_" + key.replace(".", "_"), buffer)
150
+
151
+ # TODO(rcadene): should we remove torch.no_grad?
152
+ @torch.no_grad
153
+ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
154
+ batch = dict(batch) # shallow copy avoids mutating the input batch
155
+ for key, ft in self.features.items():
156
+ if key not in batch:
157
+ # FIXME(aliberts, rcadene): This might lead to silent fail!
158
+ continue
159
+
160
+ norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
161
+ if norm_mode is NormalizationMode.IDENTITY:
162
+ continue
163
+
164
+ buffer = getattr(self, "buffer_" + key.replace(".", "_"))
165
+
166
+ if norm_mode is NormalizationMode.MEAN_STD:
167
+ mean = buffer["mean"]
168
+ std = buffer["std"]
169
+ assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
170
+ assert not torch.isinf(std).any(), _no_stats_error_str("std")
171
+ batch[key] = (batch[key] - mean) / (std + 1e-8)
172
+ elif norm_mode is NormalizationMode.MIN_MAX:
173
+ min = buffer["min"]
174
+ max = buffer["max"]
175
+ assert not torch.isinf(min).any(), _no_stats_error_str("min")
176
+ assert not torch.isinf(max).any(), _no_stats_error_str("max")
177
+ # normalize to [0,1]
178
+ batch[key] = (batch[key] - min) / (max - min + 1e-8)
179
+ # normalize to [-1, 1]
180
+ batch[key] = batch[key] * 2 - 1
181
+ else:
182
+ raise ValueError(norm_mode)
183
+ return batch
184
+
185
+
186
+ class Unnormalize(nn.Module):
187
+ """
188
+ Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
189
+ original range used by the environment.
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ features: dict[str, PolicyFeature],
195
+ norm_map: dict[str, NormalizationMode],
196
+ stats: dict[str, dict[str, Tensor]] | None = None,
197
+ ):
198
+ """
199
+ Args:
200
+ shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
201
+ are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
202
+ mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
203
+ is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
204
+ modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
205
+ are their normalization modes among:
206
+ - "mean_std": subtract the mean and divide by standard deviation.
207
+ - "min_max": map to [-1, 1] range.
208
+ stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
209
+ and values are dictionaries of statistic types and their values (e.g.
210
+ `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
211
+ training the model for the first time, these statistics will overwrite the default buffers. If
212
+ not provided, as expected for finetuning or evaluation, the default buffers should be
213
+ overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
214
+ dataset is not needed to get the stats, since they are already in the policy state_dict.
215
+ """
216
+ super().__init__()
217
+ self.features = features
218
+ self.norm_map = norm_map
219
+ self.stats = stats
220
+ # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
221
+ stats_buffers = create_stats_buffers(features, norm_map, stats)
222
+ for key, buffer in stats_buffers.items():
223
+ setattr(self, "buffer_" + key.replace(".", "_"), buffer)
224
+
225
+ # TODO(rcadene): should we remove torch.no_grad?
226
+ @torch.no_grad
227
+ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
228
+ batch = dict(batch) # shallow copy avoids mutating the input batch
229
+ for key, ft in self.features.items():
230
+ if key not in batch:
231
+ continue
232
+
233
+ norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
234
+ if norm_mode is NormalizationMode.IDENTITY:
235
+ continue
236
+
237
+ buffer = getattr(self, "buffer_" + key.replace(".", "_"))
238
+
239
+ if norm_mode is NormalizationMode.MEAN_STD:
240
+ mean = buffer["mean"]
241
+ std = buffer["std"]
242
+ assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
243
+ assert not torch.isinf(std).any(), _no_stats_error_str("std")
244
+ batch[key] = batch[key] * std + mean
245
+ elif norm_mode is NormalizationMode.MIN_MAX:
246
+ min = buffer["min"]
247
+ max = buffer["max"]
248
+ assert not torch.isinf(min).any(), _no_stats_error_str("min")
249
+ assert not torch.isinf(max).any(), _no_stats_error_str("max")
250
+ batch[key] = (batch[key] + 1) / 2
251
+ batch[key] = batch[key] * (max - min) + min
252
+ else:
253
+ raise ValueError(norm_mode)
254
+ return batch
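
A small usage sketch of `Normalize`/`Unnormalize`; the state feature, its shape, and the stats below are made up for illustration:

```python
import torch

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

# Hypothetical 14-dim proprioceptive state normalized with mean/std stats.
features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}}

normalize = Normalize(features, norm_map, stats)
unnormalize = Unnormalize(features, norm_map, stats)

batch = {"observation.state": torch.randn(2, 14)}
restored = unnormalize(normalize(batch))  # round-trips up to the 1e-8 epsilon
```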
lerobot/common/policies/pi0/configuration_pi0.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from lerobot.common.optim.optimizers import AdamWConfig
18
+ from lerobot.common.optim.schedulers import (
19
+ CosineDecayWithWarmupSchedulerConfig,
20
+ )
21
+ from lerobot.configs.policies import PreTrainedConfig
22
+ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
23
+
24
+
25
+ @PreTrainedConfig.register_subclass("pi0")
26
+ @dataclass
27
+ class PI0Config(PreTrainedConfig):
28
+ # Input / output structure.
29
+ n_obs_steps: int = 1
30
+ chunk_size: int = 50
31
+ n_action_steps: int = 50
32
+
33
+ normalization_mapping: dict[str, NormalizationMode] = field(
34
+ default_factory=lambda: {
35
+ "VISUAL": NormalizationMode.IDENTITY,
36
+ "STATE": NormalizationMode.MEAN_STD,
37
+ "ACTION": NormalizationMode.MEAN_STD,
38
+ }
39
+ )
40
+
41
+ # Shorter state and action vectors will be padded
42
+ max_state_dim: int = 32
43
+ max_action_dim: int = 32
44
+
45
+ # Image preprocessing
46
+ resize_imgs_with_padding: tuple[int, int] = (224, 224)
47
+
48
+ # Add empty images. Used by pi0_aloha_sim which adds the empty
49
+ # left and right wrist cameras in addition to the top camera.
50
+ empty_cameras: int = 0
51
+
52
+ # Converts the joint and gripper values from the standard Aloha space to
53
+ # the space used by the pi internal runtime which was used to train the base model.
54
+ adapt_to_pi_aloha: bool = False
55
+
56
+ # Converts joint dimensions to deltas with respect to the current state before passing to the model.
57
+ # Gripper dimensions will remain in absolute values.
58
+ use_delta_joint_actions_aloha: bool = False
59
+
60
+ # Tokenizer
61
+ tokenizer_max_length: int = 48
62
+
63
+ # Projector
64
+ proj_width: int = 1024
65
+
66
+ # Decoding
67
+ num_steps: int = 10
68
+
69
+ # Attention utils
70
+ use_cache: bool = True
71
+ attention_implementation: str = "eager" # or fa2, flex
72
+
73
+ # Finetuning settings
74
+ freeze_vision_encoder: bool = True
75
+ train_expert_only: bool = False
76
+ train_state_proj: bool = True
77
+
78
+ # Training presets
79
+ optimizer_lr: float = 2.5e-5
80
+ optimizer_betas: tuple[float, float] = (0.9, 0.95)
81
+ optimizer_eps: float = 1e-8
82
+ optimizer_weight_decay: float = 1e-10
83
+
84
+ scheduler_warmup_steps: int = 1_000
85
+ scheduler_decay_steps: int = 30_000
86
+ scheduler_decay_lr: float = 2.5e-6
87
+
88
+ # TODO: Add EMA
89
+
90
+ def __post_init__(self):
91
+ super().__post_init__()
92
+
93
+ # TODO(Steven): Validate device and amp? in all policy configs?
94
+ # Input validation (not exhaustive).
95
+ if self.n_action_steps > self.chunk_size:
96
+ raise ValueError(
97
+ f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
98
+ f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
99
+ )
100
+ if self.n_obs_steps != 1:
101
+ raise ValueError(
102
+ f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
103
+ )
104
+
105
+ if self.use_delta_joint_actions_aloha:
106
+ raise NotImplementedError(
107
+ "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
108
+ )
109
+
110
+ def validate_features(self) -> None:
111
+ # TODO: implement value error
112
+ # if not self.image_features and not self.env_state_feature:
113
+ # raise ValueError("You must provide at least one image or the environment state among the inputs.")
114
+
115
+ for i in range(self.empty_cameras):
116
+ key = f"observation.images.empty_camera_{i}"
117
+ empty_camera = PolicyFeature(
118
+ type=FeatureType.VISUAL,
119
+ shape=(3, 480, 640),
120
+ )
121
+ self.input_features[key] = empty_camera
122
+
123
+ def get_optimizer_preset(self) -> AdamWConfig:
124
+ return AdamWConfig(
125
+ lr=self.optimizer_lr,
126
+ betas=self.optimizer_betas,
127
+ eps=self.optimizer_eps,
128
+ weight_decay=self.optimizer_weight_decay,
129
+ )
130
+
131
+ def get_scheduler_preset(self):
132
+ return CosineDecayWithWarmupSchedulerConfig(
133
+ peak_lr=self.optimizer_lr,
134
+ decay_lr=self.scheduler_decay_lr,
135
+ num_warmup_steps=self.scheduler_warmup_steps,
136
+ num_decay_steps=self.scheduler_decay_steps,
137
+ )
138
+
139
+ @property
140
+ def observation_delta_indices(self) -> None:
141
+ return None
142
+
143
+ @property
144
+ def action_delta_indices(self) -> list:
145
+ return list(range(self.chunk_size))
146
+
147
+ @property
148
+ def reward_delta_indices(self) -> None:
149
+ return None
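
As a quick sketch of how the presets above behave with the default field values:

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

config = PI0Config()
print(config.chunk_size, config.n_action_steps)  # 50 50
print(config.get_optimizer_preset())             # AdamWConfig(lr=2.5e-05, ...)
print(config.get_scheduler_preset())             # CosineDecayWithWarmupSchedulerConfig(...)
print(config.action_delta_indices[:3])           # [0, 1, 2]
```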
lerobot/common/policies/pi0/conversion_scripts/benchmark.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
18
+ from lerobot.common.policies.factory import make_policy
19
+ from lerobot.configs.policies import PreTrainedConfig
20
+
21
+ torch.backends.cudnn.benchmark = True
22
+
23
+
24
+ def main():
25
+ device = "cuda"
26
+ dataset_repo_id = "danaaubakirova/koch_test"
27
+ # model_name = "pi0_base"
28
+ # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
29
+ ckpt_torch_dir = "lerobot/pi0"
30
+
31
+ dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
32
+
33
+ dataloader = torch.utils.data.DataLoader(
34
+ dataset,
35
+ num_workers=0,
36
+ batch_size=1,
37
+ )
38
+
39
+ batch = next(iter(dataloader))
40
+
41
+ # To device
42
+ for k in batch:
43
+ if isinstance(batch[k], torch.Tensor):
44
+ batch[k] = batch[k].to(device=device, dtype=torch.float32)
45
+
46
+ cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
47
+ cfg.pretrained_path = ckpt_torch_dir
48
+ policy = make_policy(cfg, ds_meta=dataset.meta)
49
+
50
+ # policy = torch.compile(policy, mode="reduce-overhead")
51
+
52
+ warmup_iters = 10
53
+ benchmark_iters = 30
54
+
55
+ # Warmup
56
+ for _ in range(warmup_iters):
57
+ torch.cuda.synchronize()
58
+ policy.select_action(batch)
59
+ policy.reset()
60
+ torch.cuda.synchronize()
61
+
62
+ # Benchmark
63
+ start_event = torch.cuda.Event(enable_timing=True)
64
+ end_event = torch.cuda.Event(enable_timing=True)
65
+
66
+ start_event.record()
67
+ for _ in range(benchmark_iters):
68
+ policy.select_action(batch)
69
+ policy.reset()
70
+ end_event.record()
71
+
72
+ # Synchronize and measure time
73
+ torch.cuda.synchronize()
74
+ elapsed_time_ms = start_event.elapsed_time(end_event)
75
+
76
+ avg_time_per_iter = elapsed_time_ms / benchmark_iters
77
+ print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
78
+
79
+
80
+ if __name__ == "__main__":
81
+ with torch.inference_mode():
82
+ main()
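
The CUDA-event timing pattern used above is general; here is a standalone sketch with an arbitrary stand-in workload and sizes (assumes a CUDA device):

```python
import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")  # arbitrary workload stand-in
iters = 30

start_event.record()
for _ in range(iters):
    y = x @ x
end_event.record()

torch.cuda.synchronize()  # ensure both events have completed before reading the timer
print(f"{start_event.elapsed_time(end_event) / iters:.3f} ms per iteration")
```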
lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import pickle
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
22
+ from lerobot.common.policies.factory import make_policy
23
+ from lerobot.configs.policies import PreTrainedConfig
24
+
25
+
26
+ def display(tensor: torch.Tensor):
27
+ if tensor.dtype == torch.bool:
28
+ tensor = tensor.float()
29
+ print(f"Shape: {tensor.shape}")
30
+ print(f"Mean: {tensor.mean().item()}")
31
+ print(f"Std: {tensor.std().item()}")
32
+ print(f"Min: {tensor.min().item()}")
33
+ print(f"Max: {tensor.max().item()}")
34
+
35
+
36
+ def main():
37
+ num_motors = 14
38
+ device = "cuda"
39
+ # model_name = "pi0_aloha_towel"
40
+ model_name = "pi0_aloha_sim"
41
+
42
+ if model_name == "pi0_aloha_towel":
43
+ dataset_repo_id = "lerobot/aloha_static_towel"
44
+ else:
45
+ dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
46
+
47
+ ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
48
+ ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
49
+ save_dir = Path(f"../openpi/data/{model_name}/save")
50
+
51
+ with open(save_dir / "example.pkl", "rb") as f:
52
+ example = pickle.load(f)
53
+ with open(save_dir / "outputs.pkl", "rb") as f:
54
+ outputs = pickle.load(f)
55
+ with open(save_dir / "noise.pkl", "rb") as f:
56
+ noise = pickle.load(f)
57
+
58
+ with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
59
+ norm_stats = json.load(f)
60
+
61
+ # Override stats
62
+ dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
63
+ dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
64
+ norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
65
+ )
66
+ dataset_meta.stats["observation.state"]["std"] = torch.tensor(
67
+ norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
68
+ )
69
+
70
+ # Create LeRobot batch from Jax
71
+ batch = {}
72
+ for cam_key, uint_chw_array in example["images"].items():
73
+ batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
74
+ batch["observation.state"] = torch.from_numpy(example["state"])
75
+ batch["action"] = torch.from_numpy(outputs["actions"])
76
+ batch["task"] = example["prompt"]
77
+
78
+ if model_name == "pi0_aloha_towel":
79
+ del batch["observation.images.cam_low"]
80
+ elif model_name == "pi0_aloha_sim":
81
+ batch["observation.images.top"] = batch["observation.images.cam_high"]
82
+ del batch["observation.images.cam_high"]
83
+
84
+ # Batchify
85
+ for key in batch:
86
+ if isinstance(batch[key], torch.Tensor):
87
+ batch[key] = batch[key].unsqueeze(0)
88
+ elif isinstance(batch[key], str):
89
+ batch[key] = [batch[key]]
90
+ else:
91
+ raise ValueError(f"{key}, {batch[key]}")
92
+
93
+ # To device
94
+ for k in batch:
95
+ if isinstance(batch[k], torch.Tensor):
96
+ batch[k] = batch[k].to(device=device, dtype=torch.float32)
97
+
98
+ noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
99
+
100
+ from lerobot.common import policies # noqa
101
+
102
+ cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
103
+ cfg.pretrained_path = ckpt_torch_dir
104
+ policy = make_policy(cfg, dataset_meta)
105
+
106
+ # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
107
+ # loss_dict["loss"].backward()
108
+ # print("losses")
109
+ # display(loss_dict["losses_after_forward"])
110
+ # print("pi_losses")
111
+ # display(pi_losses)
112
+
113
+ actions = []
114
+ for _ in range(50):
115
+ action = policy.select_action(batch, noise=noise)
116
+ actions.append(action)
117
+
118
+ actions = torch.stack(actions, dim=1)
119
+ pi_actions = batch["action"]
120
+ print("actions")
121
+ display(actions)
122
+ print()
123
+ print("pi_actions")
124
+ display(pi_actions)
125
+ print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
126
+ print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
127
+ print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from transformers import GemmaConfig, PaliGemmaConfig
16
+
17
+
18
+ def get_paligemma_config(precision: str):
19
+ config = {
20
+ "image_token_index": None,
21
+ "pad_token_id": 0,
22
+ "bos_token_id": 2,
23
+ "eos_token_id": 1,
24
+ }
25
+
26
+ # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
27
+
28
+ image_size = 224 # image_sizes[variant]
29
+ patch_size = 14
30
+ num_image_tokens = (image_size**2) // (patch_size**2)
31
+
32
+ config["image_token_index"] = 257152
33
+ text_config = {
34
+ "vocab_size": 257152,
35
+ "num_hidden_layers": 18,
36
+ "num_key_value_heads": 1,
37
+ "head_dim": 256,
38
+ "torch_dtype": precision,
39
+ "hidden_size": 2048,
40
+ "hidden_activation": "gelu_pytorch_tanh",
41
+ "num_attention_heads": 8,
42
+ "intermediate_size": 16384,
43
+ "is_encoder_decoder": False,
44
+ }
45
+ vision_config = {
46
+ "torch_dtype": precision,
47
+ "image_size": image_size,
48
+ "patch_size": patch_size,
49
+ "num_image_tokens": num_image_tokens,
50
+ "hidden_size": 1152,
51
+ "intermediate_size": 4304,
52
+ "num_hidden_layers": 27,
53
+ "num_attention_heads": 16,
54
+ "projector_hidden_act": "gelu_fast",
55
+ "vision_use_head": False,
56
+ }
57
+ final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
58
+ return final_config
59
+
60
+
61
+ def get_gemma_config(precision: str):
62
+ config = {
63
+ "image_token_index": None,
64
+ "pad_token_id": 0,
65
+ "bos_token_id": 2,
66
+ "eos_token_id": 1,
67
+ }
68
+
69
+ config["image_token_index"] = 257152
70
+ text_config = {
71
+ "vocab_size": 257152,
72
+ "num_hidden_layers": 18,
73
+ "num_key_value_heads": 1,
74
+ "head_dim": 256,
75
+ "torch_dtype": precision,
76
+ "hidden_size": 1024,
77
+ "hidden_activation": "gelu_pytorch_tanh",
78
+ "num_attention_heads": 8,
79
+ "intermediate_size": 4096,
80
+ "is_encoder_decoder": False,
81
+ }
82
+ final_config = GemmaConfig()
83
+ final_config.update(text_config)
84
+ return final_config
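
A minimal sketch of how these config builders are consumed by the conversion script below ("float32" is one of the supported precision strings):

```python
from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
    get_gemma_config,
    get_paligemma_config,
)

paligemma_config = get_paligemma_config("float32")
gemma_config = get_gemma_config("float32")

print(paligemma_config.vision_config.num_hidden_layers)  # 27 SigLIP encoder blocks
print(paligemma_config.text_config.num_hidden_layers)    # 18 Gemma decoder layers
print(gemma_config.hidden_size)                          # 1024 (the smaller action-expert Gemma)
```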
lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Convert pi0 parameters from Jax to Pytorch
17
+
18
+ Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
19
+ and install the required libraries.
20
+
21
+ ```bash
22
+ cd ~/code/openpi
23
+ source .venv/bin/activate
24
+ ```
25
+
26
+ Example downloading parameters:
27
+ ```bash
28
+ python
29
+ >>> import openpi.shared.download as download
30
+ >>> path='s3://openpi-assets/checkpoints/pi0_base/params'
31
+ >>> download.maybe_download(path)
32
+ ```
33
+
34
+ Converting pi0_base:
35
+ ```python
36
+ python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
37
+ --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
38
+ --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
39
+ ```
40
+
41
+ ```python
42
+ python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \
43
+ --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
44
+ --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
45
+ ```
46
+ """
47
+
48
+ import argparse
49
+ import pathlib
50
+
51
+ import jax
52
+ import numpy as np
53
+ import orbax.checkpoint as ocp
54
+ import torch
55
+ from jax.sharding import SingleDeviceSharding
56
+
57
+ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
58
+ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
59
+ get_gemma_config,
60
+ get_paligemma_config,
61
+ )
62
+ from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
63
+
64
+ PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
65
+
66
+
67
+ def slice_paligemma_state_dict(state_dict, config):
68
+ suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
69
+
70
+ # fmt: off
71
+ # patch embeddings
72
+ state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
73
+ 3, 2, 0, 1
74
+ )
75
+ state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
76
+ # positional embeddings
77
+ state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
78
+ -1, config.vision_config.hidden_size
79
+ )
80
+
81
+ # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
82
+ encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
83
+ encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
84
+ encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
85
+ encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
86
+
87
+ encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
88
+ encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
89
+ encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
90
+ encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
91
+
92
+ encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
93
+ encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
94
+ encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
95
+ encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
96
+ encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
97
+ encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
98
+ encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
99
+ encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
100
+
101
+ for i in range(config.vision_config.num_hidden_layers):
102
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
103
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
104
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
105
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
106
+
107
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
108
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
109
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
110
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
111
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
112
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
113
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
114
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
115
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
116
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
117
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
118
+ state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
119
+
120
+ state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
121
+ state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
122
+
123
+ # multimodal projector
124
+
125
+ state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
126
+ state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
127
+
128
+ # text decoder (gemma)
129
+ embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
130
+ state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
131
+
132
+ # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
133
+
134
+ llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
135
+ llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
136
+ llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
137
+
138
+ llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
139
+ llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
140
+ # TODO verify correctness of layer norm loading
141
+
142
+ llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
143
+ llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
144
+
145
+ for i in range(config.text_config.num_hidden_layers):
146
+ # llm_attention_q_einsum[i].shape = (8, 2048, 256)
147
+ q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
148
+
149
+ state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
150
+
151
+ # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
152
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
153
+ state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
154
+ # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
155
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
156
+ state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
157
+
158
+ # output projection.
159
+
160
+ # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
161
+ o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
162
+
163
+ state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
164
+ # mlp layers
165
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
166
+ state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
167
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
168
+ state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
169
+ state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
170
+ state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
171
+ state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
172
+
173
+ state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
174
+ state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
175
+
176
+ # fmt: on
177
+ expert_dict = {}
178
+ final_state_dict = {}
179
+ for key, value in state_dict.items():
180
+ if key not in [
181
+ f"llm/final_norm_1/scale{suffix}",
182
+ f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
183
+ f"llm/layers/attn/kv_einsum_1/w{suffix}",
184
+ f"llm/layers/attn/q_einsum_1/w{suffix}",
185
+ f"llm/layers/mlp_1/gating_einsum{suffix}",
186
+ f"llm/layers/mlp_1/linear{suffix}",
187
+ f"llm/layers/pre_attention_norm_1/scale{suffix}",
188
+ f"llm/layers/pre_ffw_norm_1/scale{suffix}",
189
+ ]:
190
+ final_state_dict[key] = torch.from_numpy(value)
191
+ else:
192
+ expert_dict[key] = value
193
+
194
+ return final_state_dict, expert_dict
195
+
196
+
197
+ def slice_gemma_state_dict(state_dict, config, num_expert=1):
198
+ # fmt: off
199
+ # text decoder (gemma)
200
+ # no embedding vector, the expert just has the decoder layers
201
+
202
+ embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
203
+ state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
204
+
205
+ # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
206
+
207
+ suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
208
+
209
+ llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
210
+ llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
211
+ llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
212
+
213
+ llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
214
+ llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
215
+ # TODO verify correctness of layer norm loading
216
+
217
+ llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
218
+ llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
219
+
220
+ for i in range(config.num_hidden_layers):
221
+ q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
222
+
223
+ state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
224
+
225
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
226
+ state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
227
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
228
+ state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
229
+
230
+ # output projection.
231
+
232
+ # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
233
+ o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
234
+
235
+ state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
236
+ # mlp layers
237
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
238
+ state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
239
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
240
+ state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
241
+ state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
242
+ state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
243
+ state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
244
+
245
+ state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
246
+ state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
247
+
248
+ # fmt: on
249
+ final_state_dict = {}
250
+ for key, value in state_dict.items():
251
+ if not isinstance(value, torch.Tensor):
252
+ final_state_dict[key] = torch.from_numpy(value)
253
+ else:
254
+ final_state_dict[key] = value
255
+ return final_state_dict
256
+
257
+
258
+ def flatten_for_memory(tree, parent_key=""):
259
+ out = {}
260
+ for k, v in tree.items():
261
+ new_key = f"{parent_key}/{k}" if parent_key else k
262
+ if isinstance(v, dict):
263
+ out.update(flatten_for_memory(v, new_key))
264
+ else:
265
+ out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
266
+ return out
267
+
268
+
269
+ def flatten_for_npz(tree, parent_key=""):
270
+ out = {}
271
+ for k, v in tree.items():
272
+ new_key = f"{parent_key}/{k}" if parent_key else k
273
+ if isinstance(v, dict):
274
+ out.update(flatten_for_npz(v, new_key))
275
+ else:
276
+ # bf16/f32 here?
277
+ out[new_key] = np.array(v)
278
+ return out
279
+
280
+
281
+ def slice_initial_orbax_checkpoint(checkpoint_dir: str):
282
+ params_path = pathlib.Path(checkpoint_dir).resolve()
283
+ checkpointer = ocp.PyTreeCheckpointer()
284
+
285
+ metadata = checkpointer.metadata(params_path)
286
+ print("Metadata keys:", list(metadata.keys()))
287
+
288
+ params_name = "params"
289
+
290
+ item = {params_name: metadata[params_name]}
291
+ device = jax.local_devices()[0] # Use the first local device
292
+ sharding = SingleDeviceSharding(device)
293
+ restored = checkpointer.restore(
294
+ params_path,
295
+ ocp.args.PyTreeRestore(
296
+ item=item,
297
+ restore_args=jax.tree_util.tree_map(
298
+ lambda _: ocp.ArrayRestoreArgs(
299
+ restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
300
+ sharding=sharding,
301
+ ),
302
+ item,
303
+ ),
304
+ transforms={},
305
+ ),
306
+ )
307
+ params = restored[params_name]
308
+
309
+ # get params for PaliGemma
310
+ pali_params = params["PaliGemma"]
311
+ del params["PaliGemma"]
312
+ pali_params_flat = flatten_for_npz(pali_params)
313
+ return {"paligemma_params": pali_params_flat, "projection_params": params}
314
+
315
+
316
+ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
317
+ """Update dictionary keys by adding a prefix."""
318
+ return {f"{prefix}{key}": value for key, value in d.items()}
319
+
320
+
321
+ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
322
+ # Break down orbax ckpts - they are in OCDBT
323
+ initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
324
+ # process projection params
325
+ keys = [
326
+ "state_proj",
327
+ "action_in_proj",
328
+ "action_out_proj",
329
+ "action_time_mlp_in",
330
+ "action_time_mlp_out",
331
+ ]
332
+
333
+ projection_params = {}
334
+ for key in keys:
335
+ kernel_params = initial_params["projection_params"][key]["kernel"]
336
+ bias_params = initial_params["projection_params"][key]["bias"]
337
+ if isinstance(kernel_params, dict):
338
+ weight = kernel_params["value"]
339
+ bias = bias_params["value"]
340
+ else:
341
+ weight = kernel_params
342
+ bias = bias_params
343
+ projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
344
+ projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
345
+
346
+ # Process PaliGemma weights
347
+ paligemma_config = get_paligemma_config(precision)
348
+ paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
349
+ initial_params["paligemma_params"], paligemma_config
350
+ )
351
+
352
+ # Process Gemma weights (at this stage they are unused)
353
+ gemma_config = get_gemma_config(precision)
354
+ gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
355
+
356
+ # Instantiate model from configs
357
+
358
+ if "pi0_aloha_sim" in checkpoint_dir:
359
+ pi0_config = PI0Config(
360
+ empty_cameras=2,
361
+ adapt_to_pi_aloha=True,
362
+ use_delta_joint_actions_aloha=False,
363
+ )
364
+ elif "pi0_aloha_towel" in checkpoint_dir:
365
+ pi0_config = PI0Config(
366
+ adapt_to_pi_aloha=True,
367
+ use_delta_joint_actions_aloha=True,
368
+ )
369
+ elif "pi0_base" in checkpoint_dir:
370
+ pi0_config = PI0Config(
371
+ empty_cameras=0,
372
+ adapt_to_pi_aloha=False,
373
+ use_delta_joint_actions_aloha=False,
374
+ )
375
+ else:
376
+ raise ValueError(f"Could not infer a PI0Config from checkpoint_dir: {checkpoint_dir}")
377
+
378
+ # gemma_config=gemma_config, paligemma_config=paligemma_config)
379
+ pi0_model = PI0Policy(pi0_config)
380
+
381
+ paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
382
+ gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
383
+ projection_params = update_keys_with_prefix(projection_params, "model.")
384
+
385
+ # load state dict
386
+ torch_dtype = PRECISIONS[precision]
387
+ pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
388
+ pi0_model = pi0_model.to(torch_dtype)
389
+ # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
390
+
391
+ pi0_model.save_pretrained(output_path, safe_serialization=True)
392
+ # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
393
+
394
+ # assert that model loads properly
395
+ del pi0_model
396
+ PI0Policy.from_pretrained(output_path)
397
+
398
+
399
+ if __name__ == "__main__":
400
+ parser = argparse.ArgumentParser()
401
+ parser.add_argument(
402
+ "--checkpoint_dir",
403
+ default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
404
+ type=str,
405
+ help="Path to the ocdbt checkpoint",
406
+ )
407
+
408
+ parser.add_argument(
409
+ "--precision",
410
+ choices=["float32", "bfloat16", "float16"],
411
+ default="float32",
412
+ type=str,
413
+ help="Precision identifier for model conversion - should match the base checkpoint precision.",
414
+ )
415
+ # tokenizer is identical to paligemma, it appears
416
+
417
+ parser.add_argument(
418
+ "--tokenizer_hub_id",
419
+ default="google/paligemma-3b-pt-224",
420
+ type=str,
421
+ help="Hub path to the tokenizer to save",
422
+ )
423
+
424
+ parser.add_argument(
425
+ "--output_path",
426
+ required=True,
427
+ type=str,
428
+ help="Path to save converted weights to",
429
+ )
430
+
431
+ args = parser.parse_args()
432
+ convert_pi0_checkpoint(
433
+ checkpoint_dir=args.checkpoint_dir,
434
+ precision=args.precision,
435
+ tokenizer_id=args.tokenizer_hub_id,
436
+ output_path=args.output_path,
437
+ )
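
Once converted, the checkpoint loads back like any other LeRobot policy, which is exactly what the script's final sanity check does; the path below is a placeholder for whatever `--output_path` pointed to:

```python
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

# Placeholder output directory from the conversion run above.
policy = PI0Policy.from_pretrained("/path/to/pi0_base_pytorch")
```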
lerobot/common/policies/pi0/flex_attention.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F # noqa: N812
17
+ from packaging.version import Version
18
+
19
+ if Version(torch.__version__) > Version("2.5.0"):
20
+ # Flex attention is only available from torch 2.5 onwards
21
+ from torch.nn.attention.flex_attention import (
22
+ _mask_mod_signature,
23
+ _round_up_to_multiple,
24
+ create_block_mask,
25
+ create_mask,
26
+ flex_attention,
27
+ )
28
+
29
+
30
+ # @torch.compile(dynamic=False)
31
+ def flex_attention_forward(
32
+ attention_mask: torch.Tensor,
33
+ batch_size: int,
34
+ head_dim: int,
35
+ query_states: torch.Tensor,
36
+ key_states: torch.Tensor,
37
+ value_states: torch.Tensor,
38
+ scaling=None,
39
+ ):
40
+ """
41
+ This is defined out of classes to make compile happy.
42
+ """
43
+
44
+ original_dtype = query_states.dtype
45
+ num_att_heads = 8
46
+ num_key_value_heads = 1
47
+ num_key_value_groups = num_att_heads // num_key_value_heads
48
+
49
+ key_states = key_states[:, :, :, None, :]
50
+ key_states = key_states.expand(
51
+ batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
52
+ )
53
+ key_states = key_states.reshape(
54
+ batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
55
+ )
56
+
57
+ value_states = value_states[:, :, :, None, :]
58
+ value_states = value_states.expand(
59
+ batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
60
+ )
61
+ value_states = value_states.reshape(
62
+ batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
63
+ )
64
+
65
+ query_states = query_states.transpose(1, 2)
66
+ key_states = key_states.transpose(1, 2)
67
+ value_states = value_states.transpose(1, 2)
68
+
69
+ query_states = query_states.to(torch.float32)
70
+ key_states = key_states.to(torch.float32)
71
+ value_states = value_states.to(torch.float32)
72
+
73
+ causal_mask = attention_mask
74
+ if causal_mask is not None:
75
+ causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
76
+
77
+ if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
78
+ causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
79
+
80
+ def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
81
+ def mask_mod(b, h, q_idx, kv_idx):
82
+ # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
83
+ return precomputed_mask[b][h][q_idx][kv_idx]
84
+
85
+ return mask_mod
86
+
87
+ b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
88
+
89
+ block_size = 128
90
+ q_len_rounded = _round_up_to_multiple(q_len, block_size)
91
+ kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
92
+
93
+ # *CRITICAL* we do need to expand here, else we get a CUDA index error
94
+
95
+ pad_q = q_len_rounded - q_len
96
+ pad_k = kv_len_rounded - kv_len
97
+
98
+ padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
99
+ mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
100
+
101
+ mask_4d = create_mask(
102
+ mod_fn=mask_mod_fn_orig,
103
+ B=b_mask,
104
+ H=h_mask,
105
+ Q_LEN=q_len_rounded,
106
+ KV_LEN=kv_len_rounded,
107
+ device=causal_mask.device,
108
+ _compile=False,
109
+ )
110
+
111
+ mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
112
+ block_mask = create_block_mask(
113
+ mask_mod=mask_mod_fn_padded,
114
+ B=b_mask,
115
+ H=h_mask,
116
+ Q_LEN=q_len_rounded,
117
+ KV_LEN=kv_len_rounded,
118
+ BLOCK_SIZE=block_size,
119
+ device=causal_mask.device,
120
+ _compile=False,
121
+ )
122
+
123
+ # mask is applied inside the kernel, ideally more efficiently than score_mod.
124
+ attn_output, attention_weights = flex_attention(
125
+ query_states,
126
+ key_states,
127
+ value_states,
128
+ block_mask=block_mask,
129
+ enable_gqa=True, # because we shaped query/key states for GQA
130
+ scale=head_dim**-0.5 if scaling is None else scaling,
131
+ return_lse=True,
132
+ )
133
+
134
+ attn_output = attn_output.to(dtype=original_dtype)
135
+ attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
136
+ attn_output = attn_output.reshape(
137
+ batch_size,
138
+ -1,
139
+ attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
140
+ )
141
+ return attn_output
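
A minimal, self-contained sketch of how `flex_attention_forward` above could be exercised in isolation may help when porting or debugging it. The shapes below are illustrative assumptions (the helper itself hard-codes 8 query heads sharing 1 key/value head); it assumes a PyTorch build ≥ 2.5 where flex attention is available, and depending on the build it may need to run on a GPU. This is not taken from the LeRobot codebase:

```python
import torch

from lerobot.common.policies.pi0.flex_attention import flex_attention_forward

# Illustrative shapes (assumptions): batch of 2, 16 tokens, 8 query heads
# sharing a single key/value head (GQA), head_dim of 32.
batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 32

query_states = torch.randn(batch_size, seq_len, num_heads, head_dim)
key_states = torch.randn(batch_size, seq_len, 1, head_dim)
value_states = torch.randn(batch_size, seq_len, 1, head_dim)

# Boolean mask of shape (batch, q_len, kv_len); a plain causal mask here.
attention_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
attention_mask = attention_mask[None].expand(batch_size, -1, -1)

out = flex_attention_forward(
    attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
print(out.shape)  # (batch_size, seq_len, num_heads * head_dim)
```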
lerobot/common/policies/pi0/modeling_pi0.py ADDED
@@ -0,0 +1,732 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ π0: A Vision-Language-Action Flow Model for General Robot Control
19
+
20
+ [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
21
+ [Jax code](https://github.com/Physical-Intelligence/openpi)
22
+
23
+ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
24
+
25
+ Install pi0 extra dependencies:
26
+ ```bash
27
+ pip install -e ".[pi0]"
28
+ ```
29
+
30
+ Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
31
+ ```bash
32
+ python lerobot/scripts/train.py \
33
+ --policy.path=lerobot/pi0 \
34
+ --dataset.repo_id=danaaubakirova/koch_test
35
+ ```
36
+
37
+ Example of finetuning the pi0 architecture where PaliGemma and the expert Gemma
38
+ start from the default VLM pretrained parameters rather than from pi0 pretraining:
39
+ ```bash
40
+ python lerobot/scripts/train.py \
41
+ --policy.type=pi0 \
42
+ --dataset.repo_id=danaaubakirova/koch_test
43
+ ```
44
+
45
+ Example of using the pi0 pretrained model outside LeRobot training framework:
46
+ ```python
47
+ policy = Pi0Policy.from_pretrained("lerobot/pi0")
48
+ ```
49
+
50
+ """
51
+
52
+ import math
53
+ from collections import deque
54
+
55
+ import torch
56
+ import torch.nn.functional as F # noqa: N812
57
+ from torch import Tensor, nn
58
+ from transformers import AutoTokenizer
59
+
60
+ from lerobot.common.constants import ACTION, OBS_ROBOT
61
+ from lerobot.common.policies.normalize import Normalize, Unnormalize
62
+ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
63
+ from lerobot.common.policies.pi0.paligemma_with_expert import (
64
+ PaliGemmaWithExpertConfig,
65
+ PaliGemmaWithExpertModel,
66
+ )
67
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
68
+ from lerobot.common.utils.utils import get_safe_dtype
69
+
70
+
71
+ def create_sinusoidal_pos_embedding(
72
+ time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
73
+ ) -> Tensor:
74
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
75
+ if dimension % 2 != 0:
76
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
77
+
78
+ if time.ndim != 1:
79
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
80
+
81
+ dtype = get_safe_dtype(torch.float64, device.type)
82
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
83
+ period = min_period * (max_period / min_period) ** fraction
84
+
85
+ # Compute the outer product
86
+ scaling_factor = 1.0 / period * 2 * math.pi
87
+ sin_input = scaling_factor[None, :] * time[:, None]
88
+ pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
89
+ return pos_emb
90
+
91
+
92
+ def sample_beta(alpha, beta, bsize, device):
93
+ gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
94
+ gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
95
+ return gamma1 / (gamma1 + gamma2)
96
+
97
+
98
+ def make_att_2d_masks(pad_masks, att_masks):
99
+ """Copied from big_vision.
100
+
101
+ Tokens can attend to valid input tokens which have a cumulative mask_ar
102
+ smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
103
+ set up several types of attention, for example:
104
+
105
+ [[1 1 1 1 1 1]]: pure causal attention.
106
+
107
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
108
+ themselves and the last 3 tokens have a causal attention. The first
109
+ entry could also be a 1 without changing behaviour.
110
+
111
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
112
+ block can attend all previous blocks and all tokens on the same block.
113
+
114
+ Args:
115
+ input_mask: bool[B, N] true if it's part of the input, false if padding.
116
+ mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
117
+ it and 0 where it shares the same attention mask as the previous token.
118
+ """
119
+ if att_masks.ndim != 2:
120
+ raise ValueError(att_masks.ndim)
121
+ if pad_masks.ndim != 2:
122
+ raise ValueError(pad_masks.ndim)
123
+
124
+ cumsum = torch.cumsum(att_masks, dim=1)
125
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
126
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
127
+ att_2d_masks = att_2d_masks & pad_2d_masks
128
+ return att_2d_masks
129
+
130
+
131
+ def resize_with_pad(img, width, height, pad_value=-1):
132
+ # Effectively a no-op when the image already matches the target width and height
133
+ if img.ndim != 4:
134
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
135
+
136
+ cur_height, cur_width = img.shape[2:]
137
+
138
+ ratio = max(cur_width / width, cur_height / height)
139
+ resized_height = int(cur_height / ratio)
140
+ resized_width = int(cur_width / ratio)
141
+ resized_img = F.interpolate(
142
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
143
+ )
144
+
145
+ pad_height = max(0, int(height - resized_height))
146
+ pad_width = max(0, int(width - resized_width))
147
+
148
+ # pad on left and top of image
149
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
150
+ return padded_img
151
+
152
+
153
+ def pad_vector(vector, new_dim):
154
+ """Can be (batch_size x sequence_length x features_dimension)
155
+ or (batch_size x features_dimension)
156
+ """
157
+ if vector.shape[-1] == new_dim:
158
+ return vector
159
+ shape = list(vector.shape)
160
+ current_dim = shape[-1]
161
+ shape[-1] = new_dim
162
+ new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
163
+ new_vector[..., :current_dim] = vector
164
+ return new_vector
165
+
166
+
167
+ def normalize(x, min_val, max_val):
168
+ return (x - min_val) / (max_val - min_val)
169
+
170
+
171
+ def unnormalize(x, min_val, max_val):
172
+ return x * (max_val - min_val) + min_val
173
+
174
+
175
+ def safe_arcsin(value):
176
+ # This ensures that the input stays within
177
+ # [−1,1] to avoid invalid values for arcsin
178
+ return torch.arcsin(torch.clamp(value, -1.0, 1.0))
179
+
180
+
181
+ def aloha_gripper_to_angular(value):
182
+ # Aloha transforms the gripper positions into a linear space. The following code
183
+ # reverses this transformation to be consistent with pi0 which is pretrained in
184
+ # angular space.
185
+ #
186
+ # These values are coming from the Aloha code:
187
+ # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
188
+ value = unnormalize(value, min_val=0.01844, max_val=0.05800)
189
+
190
+ # This is the inverse of the angular to linear transformation inside the Interbotix code.
191
+ def linear_to_radian(linear_position, arm_length, horn_radius):
192
+ value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
193
+ return safe_arcsin(value)
194
+
195
+ # The constants are taken from the Interbotix code.
196
+ value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
197
+
198
+ # Normalize to [0, 1].
199
+ # The values 0.4 and 1.5 were measured on an actual Trossen robot.
200
+ return normalize(value, min_val=0.4, max_val=1.5)
201
+
202
+
203
+ def aloha_gripper_from_angular(value):
204
+ # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
205
+ # Note that the units are still angular but the range is different.
206
+
207
+ # The values 0.4 and 1.5 were measured on an actual Trossen robot.
208
+ value = unnormalize(value, min_val=0.4, max_val=1.5)
209
+
210
+ # These values are coming from the Aloha code:
211
+ # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
212
+ return normalize(value, min_val=-0.6213, max_val=1.4910)
213
+
214
+
215
+ def aloha_gripper_from_angular_inv(value):
216
+ # Directly inverts the gripper_from_angular function.
217
+ value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
218
+ return normalize(value, min_val=0.4, max_val=1.5)
219
+
220
+
221
+ class PI0Policy(PreTrainedPolicy):
222
+ """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
223
+
224
+ config_class = PI0Config
225
+ name = "pi0"
226
+
227
+ def __init__(
228
+ self,
229
+ config: PI0Config,
230
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
231
+ ):
232
+ """
233
+ Args:
234
+ config: Policy configuration class instance or None, in which case the default instantiation of
235
+ the configuration class is used.
236
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
237
+ that they will be passed with a call to `load_state_dict` before the policy is used.
238
+ """
239
+
240
+ super().__init__(config)
241
+ config.validate_features()
242
+ self.config = config
243
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
244
+ self.normalize_targets = Normalize(
245
+ config.output_features, config.normalization_mapping, dataset_stats
246
+ )
247
+ self.unnormalize_outputs = Unnormalize(
248
+ config.output_features, config.normalization_mapping, dataset_stats
249
+ )
250
+
251
+ self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
252
+ self.model = PI0FlowMatching(config)
253
+
254
+ self.reset()
255
+
256
+ def reset(self):
257
+ """This should be called whenever the environment is reset."""
258
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
259
+
260
+ def get_optim_params(self) -> dict:
261
+ return self.parameters()
262
+
263
+ @torch.no_grad
264
+ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
265
+ """Select a single action given environment observations.
266
+
267
+ This method wraps `select_actions` in order to return one action at a time for execution in the
268
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
269
+ queue is empty.
270
+ """
271
+ self.eval()
272
+
273
+ if self.config.adapt_to_pi_aloha:
274
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
275
+
276
+ batch = self.normalize_inputs(batch)
277
+
278
+ # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
279
+ # querying the policy.
280
+ if len(self._action_queue) == 0:
281
+ images, img_masks = self.prepare_images(batch)
282
+ state = self.prepare_state(batch)
283
+ lang_tokens, lang_masks = self.prepare_language(batch)
284
+
285
+ actions = self.model.sample_actions(
286
+ images, img_masks, lang_tokens, lang_masks, state, noise=noise
287
+ )
288
+
289
+ # Unpad actions
290
+ original_action_dim = self.config.action_feature.shape[0]
291
+ actions = actions[:, :, :original_action_dim]
292
+
293
+ actions = self.unnormalize_outputs({"action": actions})["action"]
294
+
295
+ if self.config.adapt_to_pi_aloha:
296
+ actions = self._pi_aloha_encode_actions(actions)
297
+
298
+ # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
299
+ # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
300
+ self._action_queue.extend(actions.transpose(0, 1))
301
+ return self._action_queue.popleft()
302
+
303
+ def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
304
+ """Do a full training forward pass to compute the loss"""
305
+ if self.config.adapt_to_pi_aloha:
306
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
307
+ batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
308
+
309
+ batch = self.normalize_inputs(batch)
310
+ batch = self.normalize_targets(batch)
311
+
312
+ images, img_masks = self.prepare_images(batch)
313
+ state = self.prepare_state(batch)
314
+ lang_tokens, lang_masks = self.prepare_language(batch)
315
+ actions = self.prepare_action(batch)
316
+ actions_is_pad = batch.get("action_is_pad")
317
+
318
+ loss_dict = {}
319
+ losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
320
+ loss_dict["losses_after_forward"] = losses.clone()
321
+
322
+ if actions_is_pad is not None:
323
+ in_episode_bound = ~actions_is_pad
324
+ losses = losses * in_episode_bound.unsqueeze(-1)
325
+ loss_dict["losses_after_in_ep_bound"] = losses.clone()
326
+
327
+ # Remove padding
328
+ losses = losses[:, :, : self.config.max_action_dim]
329
+ loss_dict["losses_after_rm_padding"] = losses.clone()
330
+
331
+ # For backward pass
332
+ loss = losses.mean()
333
+ # For logging
334
+ loss_dict["l2_loss"] = loss.item()
335
+
336
+ return loss, loss_dict
337
+
338
+ def prepare_images(self, batch):
339
+ """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
340
+ convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
341
+ """
342
+ images = []
343
+ img_masks = []
344
+
345
+ present_img_keys = [key for key in self.config.image_features if key in batch]
346
+ missing_img_keys = [key for key in self.config.image_features if key not in batch]
347
+
348
+ if len(present_img_keys) == 0:
349
+ raise ValueError(
350
+ f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
351
+ )
352
+
353
+ # Preprocess image features present in the batch
354
+ for key in present_img_keys:
355
+ img = batch[key]
356
+
357
+ if self.config.resize_imgs_with_padding is not None:
358
+ img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
359
+
360
+ # Normalize from range [0, 1] to [-1, 1] as expected by SigLIP
361
+ img = img * 2.0 - 1.0
362
+
363
+ bsize = img.shape[0]
364
+ device = img.device
365
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
366
+ images.append(img)
367
+ img_masks.append(mask)
368
+
369
+ # Create image features not present in the batch
370
+ # as fully 0 padded images.
371
+ for num_empty_cameras in range(len(missing_img_keys)):
372
+ if num_empty_cameras >= self.config.empty_cameras:
373
+ break
374
+ img = torch.ones_like(img) * -1
375
+ mask = torch.zeros_like(mask)
376
+ images.append(img)
377
+ img_masks.append(mask)
378
+
379
+ return images, img_masks
380
+
381
+ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
382
+ """Tokenize the text input"""
383
+ device = batch[OBS_ROBOT].device
384
+ tasks = batch["task"]
385
+
386
+ # PaliGemma prompt has to end with a new line
387
+ tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
388
+
389
+ tokenized_prompt = self.language_tokenizer(
390
+ tasks,
391
+ padding="max_length",
392
+ padding_side="right",
393
+ max_length=self.config.tokenizer_max_length,
394
+ return_tensors="pt",
395
+ )
396
+ lang_tokens = tokenized_prompt["input_ids"].to(device=device)
397
+ lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
398
+
399
+ return lang_tokens, lang_masks
400
+
401
+ def _pi_aloha_decode_state(self, state):
402
+ # Flip the joints.
403
+ for motor_idx in [1, 2, 8, 9]:
404
+ state[:, motor_idx] *= -1
405
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
406
+ for motor_idx in [6, 13]:
407
+ state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
408
+ return state
409
+
410
+ def _pi_aloha_encode_actions(self, actions):
411
+ # Flip the joints.
412
+ for motor_idx in [1, 2, 8, 9]:
413
+ actions[:, :, motor_idx] *= -1
414
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
415
+ for motor_idx in [6, 13]:
416
+ actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
417
+ return actions
418
+
419
+ def _pi_aloha_encode_actions_inv(self, actions):
420
+ # Flip the joints again.
421
+ for motor_idx in [1, 2, 8, 9]:
422
+ actions[:, :, motor_idx] *= -1
423
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
424
+ for motor_idx in [6, 13]:
425
+ actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
426
+ return actions
427
+
428
+ def prepare_state(self, batch):
429
+ """Pad state"""
430
+ state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
431
+ return state
432
+
433
+ def prepare_action(self, batch):
434
+ """Pad action"""
435
+ actions = pad_vector(batch[ACTION], self.config.max_action_dim)
436
+ return actions
437
+
438
+
439
+ class PI0FlowMatching(nn.Module):
440
+ """
441
+ π0: A Vision-Language-Action Flow Model for General Robot Control
442
+
443
+ [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
444
+ [Jax code](https://github.com/Physical-Intelligence/openpi)
445
+
446
+ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
447
+ ┌──────────────────────────────┐
448
+ │ actions │
449
+ │ ▲ │
450
+ │ ┌┴─────┐ │
451
+ │ kv cache │Gemma │ │
452
+ │ ┌──────────►│Expert│ │
453
+ │ │ │ │ │
454
+ │ ┌┴────────┐ │x 10 │ │
455
+ │ │ │ └▲──▲──┘ │
456
+ │ │PaliGemma│ │ │ │
457
+ │ │ │ │ robot state │
458
+ │ │ │ noise │
459
+ │ └▲──▲─────┘ │
460
+ │ │ │ │
461
+ │ │ image(s) │
462
+ │ language tokens │
463
+ └──────────────────────────────┘
464
+ """
465
+
466
+ def __init__(self, config):
467
+ super().__init__()
468
+ self.config = config
469
+
470
+ paligemma_with_expert_config = PaliGemmaWithExpertConfig(
471
+ freeze_vision_encoder=self.config.freeze_vision_encoder,
472
+ train_expert_only=self.config.train_expert_only,
473
+ attention_implementation=self.config.attention_implementation,
474
+ )
475
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)
476
+
477
+ # Projections are float32
478
+ self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
479
+ self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
480
+ self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
481
+
482
+ self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
483
+ self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
484
+
485
+ self.set_requires_grad()
486
+
487
+ def set_requires_grad(self):
488
+ for params in self.state_proj.parameters():
489
+ params.requires_grad = self.config.train_state_proj
490
+
491
+ def sample_noise(self, shape, device):
492
+ noise = torch.normal(
493
+ mean=0.0,
494
+ std=1.0,
495
+ size=shape,
496
+ dtype=torch.float32,
497
+ device=device,
498
+ )
499
+ return noise
500
+
501
+ def sample_time(self, bsize, device):
502
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
503
+ time = time_beta * 0.999 + 0.001
504
+ return time.to(dtype=torch.float32, device=device)
505
+
506
+ def embed_prefix(
507
+ self, images, img_masks, lang_tokens, lang_masks
508
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
509
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
510
+ for PaliGemma transformer processing.
511
+ """
512
+ # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
513
+ embs = []
514
+ pad_masks = []
515
+ att_masks = []
516
+
517
+ # TODO: remove for loop
518
+ for (
519
+ img,
520
+ img_mask,
521
+ ) in zip(images, img_masks, strict=False):
522
+ img_emb = self.paligemma_with_expert.embed_image(img)
523
+ img_emb = img_emb.to(dtype=torch.bfloat16)
524
+
525
+ # Normalize image embeddings
526
+ img_emb_dim = img_emb.shape[-1]
527
+ img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
528
+
529
+ bsize, num_img_embs = img_emb.shape[:2]
530
+ img_mask = img_mask[:, None].expand(bsize, num_img_embs)
531
+
532
+ embs.append(img_emb)
533
+ pad_masks.append(img_mask)
534
+
535
+ # Create attention masks so that image tokens attend to each other
536
+ att_masks += [0] * num_img_embs
537
+
538
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
539
+
540
+ # Normalize language embeddings
541
+ lang_emb_dim = lang_emb.shape[-1]
542
+ lang_emb = lang_emb * math.sqrt(lang_emb_dim)
543
+
544
+ embs.append(lang_emb)
545
+ pad_masks.append(lang_masks)
546
+
547
+ # full attention between image and language inputs
548
+ num_lang_embs = lang_emb.shape[1]
549
+ att_masks += [0] * num_lang_embs
550
+
551
+ embs = torch.cat(embs, dim=1)
552
+ pad_masks = torch.cat(pad_masks, dim=1)
553
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
554
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
555
+
556
+ return embs, pad_masks, att_masks
557
+
558
+ def embed_suffix(self, state, noisy_actions, timestep):
559
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
560
+ embs = []
561
+ pad_masks = []
562
+ att_masks = []
563
+
564
+ # Embed state
565
+ state_emb = self.state_proj(state)
566
+ state_emb = state_emb.to(dtype=torch.bfloat16)
567
+ embs.append(state_emb[:, None, :])
568
+ bsize = state_emb.shape[0]
569
+ dtype = state_emb.dtype
570
+ device = state_emb.device
571
+
572
+ state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
573
+ pad_masks.append(state_mask)
574
+
575
+ # Set attention masks so that image and language inputs do not attend to state or actions
576
+ att_masks += [1]
577
+
578
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
579
+ time_emb = create_sinusoidal_pos_embedding(
580
+ timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
581
+ )
582
+ time_emb = time_emb.type(dtype=dtype)
583
+
584
+ # Fuse timestep + action information using an MLP
585
+ action_emb = self.action_in_proj(noisy_actions)
586
+
587
+ time_emb = time_emb[:, None, :].expand_as(action_emb)
588
+ action_time_emb = torch.cat([action_emb, time_emb], dim=2)
589
+
590
+ action_time_emb = self.action_time_mlp_in(action_time_emb)
591
+ action_time_emb = F.silu(action_time_emb) # swish == silu
592
+ action_time_emb = self.action_time_mlp_out(action_time_emb)
593
+
594
+ # Add to input tokens
595
+ embs.append(action_time_emb)
596
+
597
+ bsize, action_time_dim = action_time_emb.shape[:2]
598
+ action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
599
+ pad_masks.append(action_time_mask)
600
+
601
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
602
+ att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
603
+
604
+ embs = torch.cat(embs, dim=1)
605
+ pad_masks = torch.cat(pad_masks, dim=1)
606
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
607
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
608
+
609
+ return embs, pad_masks, att_masks
610
+
611
+ def forward(
612
+ self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
613
+ ) -> Tensor:
614
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
615
+ if noise is None:
616
+ noise = self.sample_noise(actions.shape, actions.device)
617
+
618
+ if time is None:
619
+ time = self.sample_time(actions.shape[0], actions.device)
620
+
621
+ time_expanded = time[:, None, None]
622
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
623
+ u_t = noise - actions
624
+
625
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
626
+ images, img_masks, lang_tokens, lang_masks
627
+ )
628
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
629
+
630
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
631
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
632
+
633
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
634
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
635
+
636
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
637
+ attention_mask=att_2d_masks,
638
+ position_ids=position_ids,
639
+ past_key_values=None,
640
+ inputs_embeds=[prefix_embs, suffix_embs],
641
+ use_cache=False,
642
+ fill_kv_cache=False,
643
+ )
644
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
645
+ # Original openpi code, upcast attention output
646
+ suffix_out = suffix_out.to(dtype=torch.float32)
647
+ v_t = self.action_out_proj(suffix_out)
648
+
649
+ losses = F.mse_loss(u_t, v_t, reduction="none")
650
+ return losses
651
+
652
+ def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
653
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
654
+ bsize = state.shape[0]
655
+ device = state.device
656
+
657
+ if noise is None:
658
+ actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
659
+ noise = self.sample_noise(actions_shape, device)
660
+
661
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
662
+ images, img_masks, lang_tokens, lang_masks
663
+ )
664
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
665
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
666
+
667
+ # Compute image and language key value cache
668
+ _, past_key_values = self.paligemma_with_expert.forward(
669
+ attention_mask=prefix_att_2d_masks,
670
+ position_ids=prefix_position_ids,
671
+ past_key_values=None,
672
+ inputs_embeds=[prefix_embs, None],
673
+ use_cache=self.config.use_cache,
674
+ fill_kv_cache=True,
675
+ )
676
+
677
+ dt = -1.0 / self.config.num_steps
678
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
679
+
680
+ x_t = noise
681
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
682
+ while time >= -dt / 2:
683
+ expanded_time = time.expand(bsize)
684
+ v_t = self.denoise_step(
685
+ state,
686
+ prefix_pad_masks,
687
+ past_key_values,
688
+ x_t,
689
+ expanded_time,
690
+ )
691
+
692
+ # Euler step
693
+ x_t += dt * v_t
694
+ time += dt
695
+ return x_t
696
+
697
+ def denoise_step(
698
+ self,
699
+ state,
700
+ prefix_pad_masks,
701
+ past_key_values,
702
+ x_t,
703
+ timestep,
704
+ ):
705
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
706
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
707
+
708
+ suffix_len = suffix_pad_masks.shape[1]
709
+ batch_size = prefix_pad_masks.shape[0]
710
+ prefix_len = prefix_pad_masks.shape[1]
711
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
712
+
713
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
714
+
715
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
716
+
717
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
718
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
719
+
720
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
721
+ attention_mask=full_att_2d_masks,
722
+ position_ids=position_ids,
723
+ past_key_values=past_key_values,
724
+ inputs_embeds=[None, suffix_embs],
725
+ use_cache=self.config.use_cache,
726
+ fill_kv_cache=False,
727
+ )
728
+ suffix_out = outputs_embeds[1]
729
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
730
+ suffix_out = suffix_out.to(dtype=torch.float32)
731
+ v_t = self.action_out_proj(suffix_out)
732
+ return v_t
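
For readers checking the flow-matching conventions used in `PI0FlowMatching` (training interpolation x_t = t · noise + (1 − t) · actions, regression target u_t = noise − actions, and Euler integration from t = 1 down to t = 0 with dt = −1/num_steps at inference), here is a small self-contained sanity-check sketch. The toy shapes and the oracle velocity (standing in for the network's `denoise_step` prediction) are assumptions for illustration only:

```python
import torch

# Toy shapes (assumptions): batch of 2, action horizon of 50, 32 motors.
num_steps = 10
actions = torch.randn(2, 50, 32)
noise = torch.randn_like(actions)

# Training-time quantities, following the same convention as PI0FlowMatching.forward.
time = torch.rand(2)[:, None, None]
x_t = time * noise + (1 - time) * actions
u_t = noise - actions  # regression target for the predicted velocity v_t

# Inference-time Euler integration, following the same convention as sample_actions,
# but with the exact conditional velocity in place of the network prediction.
dt = -1.0 / num_steps
x = noise.clone()
t = torch.tensor(1.0)
while t >= -dt / 2:
    v = noise - actions  # oracle standing in for denoise_step(...)
    x += dt * v
    t += dt

# Because the interpolation is linear, integrating the exact velocity from t=1 to t=0
# recovers the original actions (up to floating point error).
print(torch.allclose(x, actions, atol=1e-5))  # True
```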