[BUG] `torchrl.objectives.SACLoss` is broken when there is more than one `qvalue_network`
Describe the bug
The `torchrl.objectives.SACLoss` module is currently broken when `qvalue_network` is passed as a `List[TensorDictModule]`.
Note also the discrepancy between the docstring, which types the argument as `TensorDictModule`, and the constructor signature, which accepts the union type `TensorDictModule | List[TensorDictModule]`.
The bug occurs because the internal method `_set_in_keys` cannot extract the `in_keys` of a `List[TensorDictModule]`.
NOTE: I do not know whether this is the only place where the module breaks down when there are multiple `qvalue_network`s. A rough sketch of how the in-keys could be collected for both the single-module and list cases is given below.
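For illustration only, here is a minimal sketch of how the in-keys could be gathered uniformly for a single `TensorDictModule` or a list of them. This is not torchrl's actual implementation; the helper name `_gather_in_keys` is hypothetical, and the only assumption is that every `TensorDictModule` exposes an `in_keys` attribute:

```python
from __future__ import annotations

from tensordict.nn import TensorDictModule


def _gather_in_keys(
    qvalue_network: TensorDictModule | list[TensorDictModule],
) -> list:
    """Hypothetical helper: collect in_keys from one module or a list of modules."""
    # Normalize the argument to a list so the single-module and
    # multi-module cases are handled the same way.
    modules = (
        list(qvalue_network)
        if isinstance(qvalue_network, (list, tuple))
        else [qvalue_network]
    )
    in_keys = []
    for module in modules:
        # Every TensorDictModule exposes its input keys via the .in_keys attribute.
        for key in module.in_keys:
            if key not in in_keys:
                in_keys.append(key)
    return in_keys
```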
To Reproduce
This is the same example given in the `SACLoss` docstring, but with two q-value networks:
```python
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.data import Bounded
from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.sac import SACLoss

_ = torch.manual_seed(42)
n_act, n_obs = 4, 3
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
module = SafeModule(
    module=nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    module=module,
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)


class ValueClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))


qvalue = ValueOperator(
    module=ValueClass(),
    in_keys=["observation", "action"],
)
value = ValueOperator(
    module=nn.Linear(n_obs, 1),
    in_keys=["observation"],
)
loss = SACLoss(actor, [qvalue, qvalue], num_qvalue_nets=2)
batch = [2]
action = spec.rand(batch)
data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": action,
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    },
    batch,
)
loss(data)
```
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)