
[BUG] `torchrl.objectives.SACLoss` is broken when there is more than one `qvalue_network`


Describe the bug

The `torchrl.objectives.SACLoss` module is currently broken when `qvalue_network` is passed as a `List[TensorDictModule]`.

Note also the discrepancy between the type stated in the docstring (`TensorDictModule`) and the constructor's annotated union type (`TensorDictModule | List[TensorDictModule]`).

The bug occurs because the internal method `_set_in_keys` cannot extract the `in_keys` from a `List[TensorDictModule]`.

NOTE: I do not know whether this is the only place where the module breaks down when there are multiple q-value networks.
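
For what it's worth, whatever the fix looks like, it presumably has to flatten the key collection over the list of modules. Below is a minimal, hypothetical sketch of that logic; the helper name `collect_in_keys` and the way the networks are passed in are my own illustration, not the actual torchrl implementation of `_set_in_keys`.

from typing import List, Union

from tensordict.nn import TensorDictModule


def collect_in_keys(
    actor_network: TensorDictModule,
    qvalue_network: Union[TensorDictModule, List[TensorDictModule]],
):
    # Normalize the single-module case to a list so both accepted
    # input types are handled uniformly.
    qvalue_nets = (
        list(qvalue_network)
        if isinstance(qvalue_network, (list, tuple))
        else [qvalue_network]
    )
    # Gather the union of in_keys across the actor and every q-value net.
    keys = set(actor_network.in_keys)
    for net in qvalue_nets:
        keys.update(net.in_keys)
    return sorted(keys, key=str)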

To Reproduce

This is the same example as the one given in the docstring, but with two q-value networks passed as a list:

import torch
from torch import nn
from torchrl.data import Bounded
from tensordict import TensorDict
from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.sac import SACLoss

_ = torch.manual_seed(42)
n_act, n_obs = 4, 3
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
module = SafeModule(
    module=nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    module=module,
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)


class ValueClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))


qvalue = ValueOperator(
    module=ValueClass(),
    in_keys=["observation", "action"],
)
value = ValueOperator(
    module=nn.Linear(n_obs, 1),
    in_keys=["observation"],
)
loss = SACLoss(actor, [qvalue, qvalue], num_qvalue_nets=2)
batch = [2]
action = spec.rand(batch)
data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": action,
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    },
    batch,
)

loss(data)
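
As a point of comparison, the unmodified docstring form with a single `qvalue_network` (letting `num_qvalue_nets` duplicate it internally) does not go through the `List[TensorDictModule]` code path and, as far as I can tell, runs fine:

# Docstring form: a single q-value network, duplicated internally
# via num_qvalue_nets, avoids the List[TensorDictModule] code path.
loss = SACLoss(actor, qvalue, num_qvalue_nets=2)
loss(data)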

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

fmeirinhos · Nov 20 '24