[BUG] Failing to export multi-agent models

Open matteobettini opened this issue 10 months ago • 3 comments

torch.export does not seem to work on multi-agent models.

I have distilled the issue from https://github.com/facebookresearch/BenchMARL/issues/188 into this minimal reproducing script:

import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn

from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal
from torchrl.envs.utils import ExplorationType, set_exploration_type

n_actions = 3
n_obs = 5
n_agents = 2
batch = 4


policy = TensorDictModule(
    MultiAgentMLP(
        n_agent_inputs=n_obs,
        n_agent_outputs=2 * n_actions,
        n_agents=n_agents,
        centralised=False,
        share_params=True,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    ),
    in_keys=[("agents", "observation")],
    out_keys=[
        ("agents", "out"),
    ],
)

obs = TensorDict(
    {
        "agents": TensorDict(
            {"observation": torch.randn((batch, n_agents, n_obs))},
            batch_size=[batch, n_agents],
        )
    },
    batch_size=[batch],
)
print(policy(obs))  # Success
with set_exploration_type(ExplorationType.DETERMINISTIC):
    exported_policy = torch.export.export(
        policy.select_out_keys(("agents", "out")),
        args=(),
        kwargs={"agents_observation": obs["agents", "observation"]},
        strict=True,
    )  # Fail

torch._dynamo.exc.Unsupported: isinstance(GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__), BuiltinVariable(dict)): can't determine type of GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__)
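A note on the kwargs in the export call above: the module's input key is the nested key ("agents", "observation"), while the export call passes it as the keyword agents_observation. A minimal sketch of that naming rule, assuming nested keys are joined with an underscore (which is what the script relies on). The helper kwarg_name is hypothetical and is not part of tensordict:

```python
from typing import Tuple, Union

# A key is either a plain string or a tuple of strings (a nested key).
NestedKey = Union[str, Tuple[str, ...]]


def kwarg_name(key: NestedKey, separator: str = "_") -> str:
    """Hypothetical helper: join a nested key into one keyword-argument name.

    Assumption: nested keys are flattened by joining their parts with "_",
    matching the `agents_observation` keyword used in the script above.
    """
    if isinstance(key, str):
        return key
    return separator.join(key)


print(kwarg_name(("agents", "observation")))  # agents_observation
print(kwarg_name("obs"))  # obs
```

If the keyword name does not match what the module expects, the export call would fail before reaching Dynamo, so this is worth ruling out first when adapting the script to other keys.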

matteobettini avatar Apr 11 '25 10:04 matteobettini

This is a TensorDict bug, addressed in https://github.com/pytorch/tensordict/pull/1285

vmoens avatar Apr 11 '25 11:04 vmoens

Thanks @vmoens for looking into this.

I tried the latest tensordict (https://github.com/pytorch/tensordict/commit/c61d045aaadf6c0625706a3670fc6a741f31f1b0) and TorchRL (https://github.com/pytorch/rl/commit/382430db3c457312366fce4ea42330a656337419) from source, as the nightlies are not built for macOS. I also installed the latest nightly of PyTorch:

>>> torch.__version__, tensordict.__version__, torchrl.__version__
('2.8.0.dev20250421', '0.8.0+c61d045', '0.7.0+382430d')

The code posted by @matteobettini still fails with this configuration, raising the following exception:

Unsupported: builtin isinstance() cannot determine type of argument
  Explanation: Dynamo doesn't have a rule to determine the type of argument GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__)
  Hint: This is likely to be a Dynamo bug. Please report an issue to PyTorch.

  Developer debug context: isinstance(GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__), BuiltinVariable(dict))

jeguzzi avatar Apr 22 '25 08:04 jeguzzi

Still seeing this in v0.10.0 (both torchrl and tensordict).

A few comments on my findings: there seem to be several distinct causes of failure.

select_out_keys:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

n_obs = 5
n_actions = 3

module = TensorDictModule(
    nn.Linear(n_obs, n_actions),
    in_keys=["obs"],
    out_keys=["out"],
)

obs = TensorDict({"obs": torch.randn(4, n_obs)}, batch_size=[4])
print(module(obs))

exported = torch.export.export(
    module,
    args=(),
    kwargs={"obs": obs["obs"]},
    strict=True,
) # Success
print("No select success")
exported = torch.export.export(
    module.select_out_keys("out"),
    args=(),
    kwargs={"obs": obs["obs"]},
    strict=True,
) # Fail

The from_module failure is a bit trickier and could be due to the device move and uninitialised weights, possibly originating at https://github.com/pytorch/rl/blob/80bfa6e957fed6bc37c6d73db0680198119c4f0b/torchrl/modules/models/multiagent.py#L152

Here is a simpler repro.

import torch
from torch import nn
from tensordict import TensorDict

class Repro(nn.Module):
    def __init__(self):
        super().__init__()
        self.skel = nn.Linear(5, 3)
        # convert skel parameters to meta tensors to mimic uninitialised weights
        meta_td = TensorDict.from_module(self.skel).clone().to("meta")
        meta_td.to_module(self.skel)
        # real parameters stored separately
        self.params = TensorDict.from_module(nn.Linear(5, 3))

    def forward(self, x):
        with self.params.to_module(self.skel):
            return self.skel(x)

torch.export.export(Repro(), (torch.randn(2, 5),), strict=True)

When share_params=False, the vmap in https://github.com/pytorch/rl/blob/80bfa6e957fed6bc37c6d73db0680198119c4f0b/torchrl/modules/models/multiagent.py#L127 seems to be the problem.
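Since the causes above look independent, it may help to run each repro separately and collect the outcomes instead of stopping at the first exception. A minimal sketch of such a harness, using only the standard library. The function run_cases and the stand-in cases are hypothetical and only illustrate the pattern:

```python
from typing import Callable, Dict


def run_cases(cases: Dict[str, Callable[[], object]]) -> Dict[str, str]:
    """Run each zero-argument callable; record "ok" or a one-line error summary."""
    results: Dict[str, str] = {}
    for name, case in cases.items():
        try:
            case()
        except Exception as exc:  # export failures surface as ordinary exceptions
            lines = str(exc).splitlines()
            first_line = lines[0] if lines else ""
            results[name] = f"{type(exc).__name__}: {first_line}"
        else:
            results[name] = "ok"
    return results


def _failing_case() -> None:
    # Stand-in for an export call that raises.
    raise RuntimeError("stand-in for a failing export")


# Stand-in cases; in practice each lambda would wrap one of the export calls
# from the repros above, e.g.
#   lambda: torch.export.export(module, args=(), kwargs={"obs": obs["obs"]}, strict=True)
report = run_cases({"passing_case": lambda: None, "failing_case": _failing_case})
for name, status in report.items():
    print(f"{name}: {status}")
```

Keeping only the first line of each exception message makes the report short enough to paste into the issue, while the full traceback can still be obtained by rerunning a single case.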

Xmaster6y avatar Oct 15 '25 15:10 Xmaster6y