#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. copy from https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/visualize_dataset_html.py Note: The last frame of the episode doesnt always correspond to a final state. That's because our datasets are composed of transition from state to state up to the antepenultimate state associated to the ultimate action to arrive in the final state. However, there might not be a transition from a final state to another state. Note: This script aims to visualize the data used to train the neural networks. ~What you see is what you get~. When visualizing image modality, it is often expected to observe lossly compression artifacts since these images have been decoded from compressed mp4 videos to save disk space. The compression factor applied has been tuned to not affect success rate. Example of usage: - Visualize data stored on a local machine: ```bash local$ python lerobot/scripts/visualize_dataset_html.py \ --repo-id lerobot/pusht local$ open http://localhost:9090 ``` - Visualize data stored on a distant machine with a local viewer: ```bash distant$ python lerobot/scripts/visualize_dataset_html.py \ --repo-id lerobot/pusht local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel local$ open http://localhost:9090 ``` - Select episodes to visualize: ```bash python lerobot/scripts/visualize_dataset_html.py \ --repo-id lerobot/pusht \ --episodes 7 3 5 1 4 ``` """ import argparse import csv import json import logging import re import shutil import tempfile from io import StringIO from pathlib import Path import numpy as np import pandas as pd import requests from flask import Flask, redirect, render_template, request, url_for from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import IterableNamespace from lerobot.common.utils.utils import init_logging def available_datasets(): api = HfApi() datasets = api.list_datasets(author="IPEC-COMMUNITY", tags=["LeRobot"], filter="modality:video") datasets = [dataset.id for dataset in datasets] return datasets def run_server( dataset: LeRobotDataset | IterableNamespace | None, episodes: list[int] | None, host: str, port: str, static_folder: Path, template_folder: Path, ): app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache @app.route("/") def hommepage(dataset=dataset): if dataset: dataset_namespace, dataset_name = dataset.repo_id.split("/") return redirect( url_for( "show_episode", dataset_namespace=dataset_namespace, dataset_name=dataset_name, episode_id=0, ) ) dataset_param, episode_param = None, None all_params = request.args if "dataset" in all_params: dataset_param = all_params["dataset"] if "episode" in all_params: episode_param = int(all_params["episode"]) if dataset_param: dataset_namespace, dataset_name = dataset_param.split("/") return redirect( url_for( "show_episode", dataset_namespace=dataset_namespace, dataset_name=dataset_name, episode_id=episode_param if episode_param is not None else 0, ) ) featured_datasets = [ "IPEC-COMMUNITY/roboturk_lerobot", "IPEC-COMMUNITY/cmu_play_fusion_lerobot", "IPEC-COMMUNITY/fractal20220817_data_lerobot", ] return render_template( "visualize_dataset_homepage.html", featured_datasets=featured_datasets, lerobot_datasets=available_datasets(), ) @app.route("//") def show_first_episode(dataset_namespace, dataset_name): first_episode_id = 0 return redirect( url_for( "show_episode", dataset_namespace=dataset_namespace, dataset_name=dataset_name, episode_id=first_episode_id, ) ) @app.route("///episode_") def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): repo_id = f"{dataset_namespace}/{dataset_name}" try: if dataset is None: dataset = get_dataset_info(repo_id) except FileNotFoundError: return ( "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461", 400, ) dataset_version = dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version match = re.search(r"v(\d+)\.", dataset_version) if match: major_version = int(match.group(1)) if major_version < 2: return "Make sure to convert your LeRobotDataset to v2 & above." episode_data_csv_str, columns = get_episode_data(dataset, episode_id) dataset_info = { "repo_id": f"{dataset_namespace}/{dataset_name}", "num_samples": dataset.num_frames if isinstance(dataset, LeRobotDataset) else dataset.total_frames, "num_episodes": dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes, "fps": dataset.fps, } if isinstance(dataset, LeRobotDataset): video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys] videos_info = [ {"url": url_for("static", filename=video_path), "filename": video_path.parent.name} for video_path in video_paths ] tasks = dataset.meta.episodes[episode_id]["tasks"] else: video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] videos_info = [ { "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.video_path.format( episode_chunk=int(episode_id) // dataset.chunks_size, video_key=video_key, episode_index=episode_id, ), "filename": video_key, } for video_key in video_keys ] response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl") response.raise_for_status() # Split into lines and parse each line as JSON tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id] tasks = filtered_tasks_jsonl[0]["tasks"] videos_info[0]["language_instruction"] = tasks if episodes is None: episodes = list( range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) ) return render_template( "visualize_dataset_template.html", episode_id=episode_id, episodes=episodes, dataset_info=dataset_info, videos_info=videos_info, episode_data_csv_str=episode_data_csv_str, columns=columns, ) app.run(host=host, port=port) def get_ep_csv_fname(episode_id: int): ep_csv_fname = f"episode_{episode_id}.csv" return ep_csv_fname def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index): """Get a csv str containing timeseries data of an episode (e.g. state and action). This file will be loaded by Dygraph javascript to plot data in real time.""" columns = [] selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"] selected_columns.remove("timestamp") # init header of csv with state and action names header = ["timestamp"] for column_name in selected_columns: dim_state = ( dataset.meta.shapes[column_name][0] if isinstance(dataset, LeRobotDataset) else dataset.features[column_name].shape[0] ) header += [f"{column_name}_{i}" for i in range(dim_state)] if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: column_names = dataset.features[column_name]["names"] while not isinstance(column_names, list): column_names = list(column_names.values())[0] else: column_names = [f"motor_{i}" for i in range(dim_state)] columns.append({"key": column_name, "value": column_names}) selected_columns.insert(0, "timestamp") if isinstance(dataset, LeRobotDataset): from_idx = dataset.episode_data_index["from"][episode_index] to_idx = dataset.episode_data_index["to"][episode_index] data = dataset.hf_dataset.select(range(from_idx, to_idx)).select_columns(selected_columns).with_format("pandas") else: repo_id = dataset.repo_id url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index ) df = pd.read_parquet(url) data = df[selected_columns] # Select specific columns rows = np.hstack( ( np.expand_dims(data["timestamp"], axis=1), *[np.vstack(data[col]) for col in selected_columns[1:]], ) ).tolist() # Convert data to CSV string csv_buffer = StringIO() csv_writer = csv.writer(csv_buffer) # Write header csv_writer.writerow(header) # Write data rows csv_writer.writerows(rows) csv_string = csv_buffer.getvalue() return csv_string, columns def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]: # get first frame of episode (hack to get video_path of the episode) first_frame_idx = dataset.episode_data_index["from"][ep_index].item() return [dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.meta.video_keys] def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: # check if the dataset has language instructions if "language_instruction" not in dataset.features: return None # get first frame index first_frame_idx = dataset.episode_data_index["from"][ep_index].item() language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored # with the tf.tensor appearing in the string return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") def get_dataset_info(repo_id: str) -> IterableNamespace: response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json") response.raise_for_status() # Raises an HTTPError for bad responses dataset_info = response.json() dataset_info["repo_id"] = repo_id return IterableNamespace(dataset_info) def visualize_dataset_html( dataset: LeRobotDataset | None, episodes: list[int] | None = None, output_dir: Path | None = None, serve: bool = True, host: str = "127.0.0.1", port: int = 9090, force_override: bool = False, ) -> Path | None: init_logging() template_dir = Path(__file__).resolve().parent / "templates" if output_dir is None: # Create a temporary directory that will be automatically cleaned up output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_") output_dir = Path(output_dir) if output_dir.exists(): if force_override: shutil.rmtree(output_dir) else: logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") output_dir.mkdir(parents=True, exist_ok=True) static_dir = output_dir / "static" static_dir.mkdir(parents=True, exist_ok=True) if dataset is None: if serve: run_server( dataset=None, episodes=None, host=host, port=port, static_folder=static_dir, template_folder=template_dir, ) else: # Create a simlink from the dataset video folder containg mp4 files to the output directory # so that the http server can get access to the mp4 files. if isinstance(dataset, LeRobotDataset): ln_videos_dir = static_dir / "videos" if not ln_videos_dir.exists(): ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) if serve: run_server(dataset, episodes, host, port, static_dir, template_dir) def main(): parser = argparse.ArgumentParser() parser.add_argument( "--repo-id", type=str, default=None, help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", ) parser.add_argument( "--local-files-only", type=int, default=0, help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.", ) parser.add_argument( "--root", type=Path, default=None, help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", ) parser.add_argument( "--load-from-hf-hub", type=int, default=0, help="Load videos and parquet files from HF Hub rather than local system.", ) parser.add_argument( "--episodes", type=int, nargs="*", default=None, help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.", ) parser.add_argument( "--output-dir", type=Path, default=None, help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.", ) parser.add_argument( "--serve", type=int, default=1, help="Launch web server.", ) parser.add_argument( "--host", type=str, default="127.0.0.1", help="Web host used by the http server.", ) parser.add_argument( "--port", type=int, default=9090, help="Web port used by the http server.", ) parser.add_argument( "--force-override", type=int, default=0, help="Delete the output directory if it exists already.", ) args = parser.parse_args() kwargs = vars(args) repo_id = kwargs.pop("repo_id") load_from_hf_hub = kwargs.pop("load_from_hf_hub") root = kwargs.pop("root") local_files_only = kwargs.pop("local_files_only") dataset = None if repo_id: dataset = ( LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) if not load_from_hf_hub else get_dataset_info(repo_id) ) visualize_dataset_html(dataset, **vars(args)) if __name__ == "__main__": main()