Using LeRobotDataset 'episodes' argument causes error on __getitem__
System Info
At some point since https://github.com/huggingface/lerobot/pull/724, LeRobotDataset (with lerobot/pusht, in my case) started failing for me when I pass episode indices like so:
delta_timestamps = {
    "observation.image": [0.0],
    "observation.state": [0.0],
    "action": [0.0 + 0.1 * i for i in range(15)],
}
train_dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps=delta_timestamps,
    episodes=list(range(200)),
)
val_dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps=delta_timestamps,
    episodes=list(range(200, 205)),
)
When I iterate over the validation dataset, I get this error:
[rank0]: File "/workspaces/robots/policy/vla_dataset.py", line 602, in __getitem__
[rank0]: sample = self.dataset[idx]
[rank0]: File "/opt/lerobot/lerobot/lerobot/common/datasets/lerobot_dataset.py", line 796, in __getitem__
[rank0]: query_indices, padding = self._get_query_indices(idx, ep_idx)
[rank0]: File "/opt/lerobot/lerobot/lerobot/common/datasets/lerobot_dataset.py", line 722, in _get_query_indices
[rank0]: ep_start = self.episode_data_index["from"][ep_idx]
[rank0]: IndexError: index 200 is out of bounds for dimension 0 with size 5
I believe that `self.episode_data_index["from"]` is an array of length 5 (equal to the number of episodes I specified in the validation set), but we are trying to index into it with the raw episode indices. In this case, I believe that instead of indexing with 200, we should be indexing with 0.
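To make the suspected mismatch concrete, here is a minimal sketch that does not use LeRobot at all. It assumes plain Python lists in place of the tensors in `episode_data_index`, and the frame boundaries are made up for illustration:

```python
# Toy stand-in for the dataset's bookkeeping (plain lists, not the real tensors).
episodes = list(range(200, 205))  # raw episode indices requested for the val set

# One entry per *selected* episode; the frame boundaries are invented.
episode_data_index = {
    "from": [0, 100, 200, 300, 400],
    "to": [100, 200, 300, 400, 500],
}

ep_idx = 200  # raw episode index stored alongside a frame

# Indexing with the raw episode index goes out of bounds, as in the traceback.
try:
    episode_data_index["from"][ep_idx]
    raw_lookup_failed = False
except IndexError:
    raw_lookup_failed = True

# Translating to the position within the selected subset gives a valid index.
relative_ep_idx = episodes.index(ep_idx)
ep_start = episode_data_index["from"][relative_ep_idx]
ep_end = episode_data_index["to"][relative_ep_idx]
```

In this toy setup the raw lookup fails while the translated lookup resolves to position 0, which matches what I would expect for the first validation episode.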
This is what is at head:
def _get_query_indices(
    self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int | bool]]]:
    ep_start = self.episode_data_index["from"][ep_idx]
    ep_end = self.episode_data_index["to"][ep_idx]
This works for me:
def _get_query_indices(
    self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int | bool]]]:
    relative_ep_idx = self.episodes.index(ep_idx)
    ep_start = self.episode_data_index["from"][relative_ep_idx]
    ep_end = self.episode_data_index["to"][relative_ep_idx]
However, looking at the blame, _get_query_indices hasn't changed in four months, so I'm not sure if there's some other solution to this problem.
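For what it's worth, the translation step could also be factored out so it handles the case where no `episodes` subset was passed and avoids a linear `list.index` scan per item. This is only a sketch: `build_episode_lookup` and `to_relative_ep_idx` are hypothetical names of mine, not part of LeRobot, and I am assuming `episodes` is `None` when the full dataset is loaded.

```python
from typing import Optional, Sequence


def build_episode_lookup(episodes: Optional[Sequence[int]]) -> Optional[dict]:
    """Map raw episode index -> position within the selected subset.

    Hypothetical helper, not LeRobot API. Returns None when no subset
    was requested (assumed to mean all episodes are loaded).
    """
    if episodes is None:
        return None
    return {raw: pos for pos, raw in enumerate(episodes)}


def to_relative_ep_idx(lookup: Optional[dict], ep_idx: int) -> int:
    """Translate a raw episode index into an index usable with episode_data_index."""
    if lookup is None:
        # No subset: raw and relative indices coincide.
        return ep_idx
    return lookup[ep_idx]


# Build once (e.g. at construction time), then reuse for every item lookup.
lookup = build_episode_lookup(list(range(200, 205)))
first = to_relative_ep_idx(lookup, 200)
last = to_relative_ep_idx(lookup, 204)
passthrough = to_relative_ep_idx(build_episode_lookup(None), 7)
```

The dict is built once, so each lookup is constant time regardless of how many episodes are selected.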
Information
- [ ] One of the scripts in the examples/ folder of LeRobot
- [x] My own task or dataset (give details below)
Reproduction
Create the following datasets:
delta_timestamps = {
    "observation.image": [0.0],
    "observation.state": [0.0],
    "action": [0.0 + 0.1 * i for i in range(15)],
}
train_dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps=delta_timestamps,
    episodes=list(range(200)),
)
val_dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps=delta_timestamps,
    episodes=list(range(200, 205)),
)
Then iterate through the val dataset and observe the error.
Expected behavior
Iterating through the val dataset should complete without raising an IndexError.
I am also encountering this error. Here is a small snippet that reproduces it:
from pathlib import Path
from pprint import pprint

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

repo_id = "lerobot/pusht"

dataset = LeRobotDataset(
    repo_id,
    episodes=[10],
    delta_timestamps={
        'observation.image': [0, 0.1, 0.2],
        'observation.state': [0, 0.1, 0.2],
        'action': [0, 0.1, 0.2],
        'episode_index': [0, 0.1, 0.2],
        'frame_index': [0, 0.1, 0.2],
        'next.reward': [0, 0.1, 0.2],
        'next.done': [0, 0.1, 0.2],
        'next.success': [0, 0.1, 0.2],
        'index': [0, 0.1, 0.2]
    }
)

dataset[187]
This leads to the following error and stack trace:
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[29], line 25
      7 repo_id = "lerobot/pusht"
      9 dataset = LeRobotDataset(
     10     repo_id,
     11     episodes=[10],
   (...)
     22     }
     23 )
---> 25 dataset[187]

File ~/lerobot/lerobot/common/datasets/lerobot_dataset.py:731, in LeRobotDataset.__getitem__(self, idx)
    729 query_indices = None
    730 if self.delta_indices is not None:
--> 731     query_indices, padding = self._get_query_indices(idx, ep_idx)
    732     query_result = self._query_hf_dataset(query_indices)
    733     item = {**item, **padding}

File ~/lerobot/lerobot/common/datasets/lerobot_dataset.py:665, in LeRobotDataset._get_query_indices(self, idx, ep_idx)
    664 def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
--> 665     ep_start = self.episode_data_index["from"][ep_idx]
    666     ep_end = self.episode_data_index["to"][ep_idx]
    667     query_indices = {
    668         key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
    669         for key, delta_idx in self.delta_indices.items()
    670     }

IndexError: index 10 is out of bounds for dimension 0 with size 1
Compared with the previous version, the line `current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx` is missing from `def __getitem__(self, idx) -> dict`.
I guess #1062 will fix this.