[BUG] info['_weight'] device for Importance Sampling in PER
Describe the bug
The device of info['_weight'] doesn't match the storage device.
To Reproduce
# From documentation
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)
# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
sample device: cuda:0
info['_weight'] device: cpu
Expected behavior
Both should be on the same device defined in storage(..., device) as these weights are later used to compute the loss.
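To illustrate why the device matters, here is a minimal sketch of how importance-sampling weights typically enter a loss. The name weighted_td_loss and the squared-error form are assumptions made for illustration; this is not TorchRL's loss code. The explicit cast is the workaround the current behavior forces on the user.

```python
import torch


def weighted_td_loss(td_error: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of TorchRL.
    # Without this cast, multiplying a non-scalar CPU weight by a CUDA
    # td_error fails with a device-mismatch RuntimeError.
    weight = weight.to(td_error.device)
    return (weight * td_error.pow(2)).mean()


# Toy values; in practice td_error lives on the storage device.
td_error = torch.tensor([1.0, 2.0])
weight = torch.tensor([0.5, 0.25])
loss = weighted_td_loss(td_error, weight)
```

If info['_weight'] already came back on the storage device, the cast inside the helper would be a no-op.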
System info
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.10.23 1.26.4 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] linux
Reason and Possible fixes
Specify the device argument in samplers.py (L508):
weight = torch.as_tensor(self._sum_tree[index], device=storage.device)
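As a minimal sketch of what the proposed change does: torch.as_tensor keeps the result on CPU when given CPU data and no device, and places it on the requested device when device is passed. The NumPy array below is only a stand-in for self._sum_tree[index] (an assumption about its return type), and the device falls back to CPU when CUDA is not available.

```python
import numpy as np
import torch

# Use CUDA when present so the snippet also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for self._sum_tree[index]; assumed here to be a CPU NumPy array.
priorities = np.array([0.5, 2.0], dtype=np.float32)

weight_default = torch.as_tensor(priorities)                   # stays on CPU
weight_on_device = torch.as_tensor(priorities, device=device)  # follows `device`

print(weight_default.device, weight_on_device.device)
```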
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
That, and we should also be able to execute this directly on device. I'll push some changes.
Just FYI you could do this instead:
# From documentation
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler, TensorDictReplayBuffer
from tensordict import TensorDict
import torch
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')),
                            sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample = rb.sample(10)
# Check devices
print(f"sample device: {sample.device}\n"
      f"sample['_weight'] device: {sample['_weight'].device}")
which will put your weights on cuda.
There are two issues in patching the PRB to account for the device of the storage:
1. The issue you're having is caused by the fact that, for the ReplayBuffer class, the device of the storage is not necessarily known: it could be None. Also, the sampler is unaware of what the storage is. You could have multiple storages, for instance. So in practice, if we want to cast the content of the info dict to the storage device, we would need to pass the storage device to the sampler and do that transfer. Another option could be for the buffer (and not the sampler) to do the casting if and only if the info dict is required (that would avoid useless H2D transfers when the info dict isn't asked for), but then we would still face issue (2) below.
2. If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict:
# From documentation
import functools
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
import torch
device = "cuda"
# patch
sample = PrioritizedSampler.sample
@functools.wraps(sample)
def new_sample(self, *args, **kwargs):
    out = sample(self, *args, **kwargs)
    out = torch.utils._pytree.tree_map(lambda x: x.to(device), out)
    return out
PrioritizedSampler.sample = new_sample
rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device(device)), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
# map back content on cpu
rb.append_transform(lambda x: x.to("cpu"))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)
# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
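Until the buffer handles this itself, the cast can be done in user code after sampling. The sketch below uses a hypothetical helper, move_info_to_device, which is not part of TorchRL; it assumes the info dict is a flat mapping and only moves values that are tensors.

```python
import torch


def move_info_to_device(info: dict, device) -> dict:
    # Hypothetical helper, not a TorchRL API: returns a new dict in which
    # every tensor value has been moved to `device`; other values are kept.
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in info.items()
    }


# Toy info dict standing in for the one returned by rb.sample(..., return_info=True).
info = {"_weight": torch.ones(3), "index": torch.arange(3), "note": "kept as-is"}
moved = move_info_to_device(info, torch.device("cpu"))
```

With the buffer above, this would be called as info = move_info_to_device(info, sample.device), applied after any transforms so that the info dict follows the final device of the sample.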
So to recap:
- PRB is currently only hosted on CPU. It's the only part of the lib that relies on C++ code. The fact that the computation is done on CPU is why you're getting the info dict on CPU.
- Mapping to the storage device could be done.
- We could do the sumtree and mintree on CUDA; that shouldn't be too hard. In the meantime we can send the info dict content to the storage device (see #2527), but that will only be an incomplete patch if you're not using TensorDictReplayBuffer.
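For intuition on what on-device sampling could look like, here is a rough sketch that uses only tensor ops, so it runs on whatever device the priorities live on. It is not a sum tree or min tree: it swaps in a cumulative sum with torch.searchsorted, which costs O(N) per call rather than O(log N). The function name and the weight formula (priority over minimum priority, raised to -beta) are assumptions for illustration, not TorchRL's implementation.

```python
import torch


def sample_prioritized(priorities: torch.Tensor, batch_size: int, beta: float = 1.0):
    # Hypothetical sketch, not TorchRL code.
    # priorities: 1-D tensor of strictly positive priorities, on any device.
    cumsum = torch.cumsum(priorities, dim=0)
    # Draw uniform values in [0, total priority) on the same device.
    u = torch.rand(batch_size, device=priorities.device) * cumsum[-1]
    # First position whose cumulative priority exceeds u; the clamp guards
    # against floating-point round-off at the upper edge.
    index = torch.searchsorted(cumsum, u, right=True).clamp_(max=priorities.numel() - 1)
    # Importance weights relative to the smallest priority (assumed form).
    weight = (priorities[index] / priorities.min()).pow(-beta)
    return index, weight


priorities = torch.tensor([1.0, 3.0])
index, weight = sample_prioritized(priorities, batch_size=8)
```

Because every intermediate tensor is created from priorities, both index and weight come back on the device of priorities with no host round trip.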
> Also, the sampler is unaware of what the storage is. You could have multiple storages for instance.
Maybe I'm missing something, but def sample(self, storage: Storage, batch_size: int) accepts the storage as an argument, so we can query storage.device, which would also cover the multiple-storages case.
> If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict:
That's a valid point. I wanted to suggest adding info to the data, but preallocating the memory might not be that trivial. On the other hand, I can't think of any reason (besides mapping to a device) why one would need to transform the info dict.