[BUG] SliceSampler should return unique IDs when sampling multiple times from the same trajectory
Describe the bug
When using SliceSampler with strict_length=False, the documentation recommends using split_trajectories to re-batch the result. However, if two samples from the same episode are placed next to each other in the sampled batch, this produces the wrong output, because adjacent samples may share the same trajectory_key despite being logically independent.
To Reproduce
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler
rb = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSampler(slice_len=5, traj_key="episode", strict_length=False),
)
ep_1 = TensorDict(
    {"obs": torch.arange(100), "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4), "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)
s = rb.sample(50)
t = split_trajectories(s, trajectory_key="episode")
split_trajectories returns nonsense results when trajectory_key contains non-contiguous duplicates.
Even if that weren't the case, there would still be a bug:
When SliceSampler is drawing from relatively few trajectories, there will be situations where multiple slices of the same trajectory are returned next to each other:
episode   0  0  0  0  0   0  0  0  0  0 ...
obs       2  3  4  5  6  41 42 43 44 45 ...
         |--1st slice--| |--2nd slice--|
However, split_trajectories will see that episode is the same for both slices, and incorrectly combine them into one longer slice.
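For concreteness, here is a minimal sketch of that failure mode, using only the calls from the reproduction above; the hand-built batch stands in for what SliceSampler can return:

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# Two logically independent 5-step slices of episode 0, placed adjacently,
# exactly as SliceSampler can return them.
batch = TensorDict(
    {"obs": torch.tensor([2, 3, 4, 5, 6, 41, 42, 43, 44, 45]),
     "episode": torch.zeros(10)},
    batch_size=[10],
)
t = split_trajectories(batch, trajectory_key="episode")
print(t["obs"])  # one merged length-10 trajectory instead of two length-5 slices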
Expected behavior
SliceSampler should add an additional key to its returned dict to distinguish samples, at least when strict_length=False:
episode   0  0  0  0  0   0  0  0  0  0 ...
obs       2  3  4  5  6  41 42 43 44 45 ...
slice     0  0  0  0  0   1  1  1  1  1
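With such a key, the downstream fix would be a one-liner. Note that "slice" below is the proposed key, not part of the current API:

# Hypothetical: "slice" is the proposed per-sample id, not an existing key.
s = rb.sample(50)
t = split_trajectories(s, trajectory_key="slice")  # slices would never be merged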
System info
M1 Mac, macOS 15.1
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.6.0+7bf320c 1.26.4 3.11.9 (main, Apr 19 2024, 11:44:45) [Clang 14.0.6 ] darwin
Both torchrl and tensordict were installed from source.
Checklist
- [X] I have checked that there is no similar issue in the repo (required)
- [X] I have read the documentation (required)
- [X] I have provided a minimal working example to reproduce the bug (required)
Taking a fresh look at this again, it seems that a workaround may be to do something like:
sample, info = rb.sample(minibatch_size, return_info=True)
sample["next", "end_of_slice"] = (
    info["next", "truncated"]
    | info["next", "done"]
    | info["next", "terminated"]
)
sample = split_trajectories(sample, done_key="end_of_slice")
But this is hardly ergonomic; at the very least it should be shown as an example in the documentation.
Hey! Thanks for reporting this.
- One option is to use SliceSamplerWithoutReplacement:
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSamplerWithoutReplacement

rb = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSamplerWithoutReplacement(slice_len=5, traj_key="episode", strict_length=False),
)
ep_1 = TensorDict(
    {"obs": torch.arange(100), "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4), "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)
s = rb.sample(50)
t = split_trajectories(s, trajectory_key="episode")
print(t["obs"])
print(t["episode"])
That will ensure that you don't sample the same item twice.
- Another is to use TensorDictReplayBuffer with the slice sampler. That will update the ("next", "truncated") key in the sampled data, and split_trajectories can understand that:
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSampler(slice_len=5, traj_key="episode", strict_length=False),
)
ep_1 = TensorDict(
    {"obs": torch.arange(100), "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4), "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)
s = rb.sample(50)
print(s)
t = split_trajectories(s, done_key="truncated")
print(t["obs"])
print(t["episode"])
- Finally, there is your solution (doing it manually), but as you mention it's clunky. If you go that route, you could also just do:
s, info = rb.sample(50, return_info=True)
print(s)
s["next", "truncated"] = info[("next", "truncated")]
t = split_trajectories(s, done_key="truncated")
But in general I do agree that we need better docs. Aside from the docstrings of the slice sampler, where would you look for that info?
Thanks for responding so quickly!
In my particular case, I am collecting a few episodes (of wildly varying length), training on a few large-ish batches of short-ish slices, and then clearing the replay buffer, so unfortunately SliceSamplerWithoutReplacement wouldn't work (though the documentation should clarify whether "without replacement" means never sampling two different slices of the same episode, or allowing the same episode to be sampled multiple times on non-overlapping slices).
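For concreteness, a sketch of that loop, assuming the TensorDictReplayBuffer + SliceSampler setup from the previous comment; collect_episode() and train_step() are hypothetical placeholders, and I'm assuming ReplayBuffer.empty() is the right way to clear the buffer:

for epoch in range(num_epochs):
    for _ in range(episodes_per_epoch):
        rb.extend(collect_episode())       # episodes of wildly varying length
    for _ in range(batches_per_epoch):
        batch = rb.sample(minibatch_size)  # short slices, possibly several per episode
        train_step(split_trajectories(batch, done_key="truncated"))
    rb.empty()                             # clear the buffer between epochs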
I first looked at the SliceSampler docs, for the strict_length parameter. That led me to the docs for split_trajectories, which show the basic usage but should probably include a warning about its input assumptions (no duplicate trajectory_key values across different slices, even non-contiguous ones).
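Such a warning could even come with a runnable sanity check along these lines (a sketch; this is not an existing torchrl helper):

# Illustrative check of split_trajectories' input assumption: every
# trajectory id must form a single contiguous run in the batch.
episode = s["episode"]
n_runs = int((episode[1:] != episode[:-1]).sum()) + 1
assert n_runs == episode.unique().numel(), (
    "non-contiguous duplicate trajectory ids: split_trajectories output is unreliable"
)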
I then looked at the SliceSampler docs again, for the truncated_key parameter, which led me to discover the return_info=True option. The docs also seem to imply that ("next", "truncated") is False when the last step of a slice happens to be the genuinely done last step of an episode.
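A quick way to check what the sampler actually reports, using only the calls already shown in this thread:

s, info = rb.sample(50, return_info=True)
print(list(info.keys()))            # which done/truncated flags are returned?
print(info[("next", "truncated")])  # True at slice boundaries? at natural episode ends?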
Is this a good edit? https://github.com/pytorch/rl/pull/2607 Would you add anything? Is there anything you think is broken in the API?
This seems like a major improvement to the documentation! Thanks for updating that.