[BUG] Failing to export multi-agent models
`torch.export` does not seem to work on multi-agent models.
I have distilled the issue from https://github.com/facebookresearch/BenchMARL/issues/188 into this minimal reproducing script:
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal
from torchrl.envs.utils import ExplorationType, set_exploration_type
# Problem sizes for the repro.
n_actions = 3
n_obs = 5
n_agents = 2
batch = 4

# Decentralised (centralised=False) multi-agent MLP with shared parameters
# (share_params=True): per agent, n_obs inputs -> 2 * n_actions outputs.
mlp = MultiAgentMLP(
    n_agent_inputs=n_obs,
    n_agent_outputs=2 * n_actions,
    n_agents=n_agents,
    centralised=False,
    share_params=True,
    device="cpu",
    depth=2,
    num_cells=256,
    activation_class=nn.Tanh,
)
policy = TensorDictModule(
    mlp,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "out")],
)

# Nested input: the "agents" sub-TensorDict carries an extra agent dimension
# in its batch size, on top of the outer [batch] batch size.
agents_td = TensorDict(
    {"observation": torch.randn((batch, n_agents, n_obs))},
    batch_size=[batch, n_agents],
)
obs = TensorDict({"agents": agents_td}, batch_size=[batch])

# The eager forward pass works.
print(policy(obs))  # Success

with set_exploration_type(ExplorationType.DETERMINISTIC):
    # Strict export of the same policy is what fails in the reported setup.
    exported_policy = torch.export.export(
        policy.select_out_keys(("agents", "out")),
        args=(),
        # NOTE(review): presumably tensordict maps the nested in_key
        # ("agents", "observation") to this "_"-joined keyword name — confirm.
        kwargs={"agents_observation": obs["agents", "observation"]},
        strict=True,
    )  # Fail
torch._dynamo.exc.Unsupported: isinstance(GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__), BuiltinVariable(dict)): can't determine type of GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__)
This is a TensorDict bug, addressed in https://github.com/pytorch/tensordict/pull/1285
Thanks @vmoens for looking into this.
I tried the latest tensordict (https://github.com/pytorch/tensordict/commit/c61d045aaadf6c0625706a3670fc6a741f31f1b0) and TorchRL (https://github.com/pytorch/rl/commit/382430db3c457312366fce4ea42330a656337419), both built from source because nightly builds are not available for macOS. I also installed the latest nightly of PyTorch:
>>> torch.__version__, tensordict.__version__, torchrl.__version__
('2.8.0.dev20250421', '0.8.0+c61d045', '0.7.0+382430d')
The code posted by @matteobettini still fails with this configuration, raising the following exception:
Unsupported: builtin isinstance() cannot determine type of argument
Explanation: Dynamo doesn't have a rule to determine the type of argument GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__)
Hint: This is likely to be a Dynamo bug. Please report an issue to PyTorch.
Developer debug context: isinstance(GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__), BuiltinVariable(dict))
Still seeing this in v0.10.0 (both torchrl and tensordict).
A few comments on my findings: there seem to be several distinct causes of failure.
select_out_keys:
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
n_obs = 5
n_actions = 3

# A single linear layer wrapped as a TensorDictModule: "obs" -> "out".
module = TensorDictModule(
    nn.Linear(n_obs, n_actions),
    in_keys=["obs"],
    out_keys=["out"],
)
obs = TensorDict({"obs": torch.randn(4, n_obs)}, batch_size=[4])
print(module(obs))


def _strict_export(mod):
    """Strictly export ``mod``, feeding the observation as the ``obs`` kwarg."""
    return torch.export.export(
        mod,
        args=(),
        kwargs={"obs": obs["obs"]},
        strict=True,
    )


# Exporting the module as-is works.
exported = _strict_export(module)  # Success
print("No select success")
# Exporting after select_out_keys is the failing case.
exported = _strict_export(module.select_out_keys("out"))  # Fail
The `from_module` failure is a bit trickier: it could be caused by the device move and the uninitialised weights, possibly introduced at https://github.com/pytorch/rl/blob/80bfa6e957fed6bc37c6d73db0680198119c4f0b/torchrl/modules/models/multiagent.py#L152
Here is a simpler repro.
import torch
from torch import nn
from tensordict import TensorDict
class Repro(nn.Module):
    """Minimal module that swaps real parameters into a meta-device skeleton.

    The ``skel`` Linear layer holds meta-device parameters, standing in for
    uninitialised weights. The real parameters are kept in a separate
    TensorDict and are only installed into ``skel`` for the duration of
    ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.skel = nn.Linear(5, 3)
        # Overwrite the skeleton's parameters with meta-device copies.
        TensorDict.from_module(self.skel).clone().to("meta").to_module(self.skel)
        # Real parameters, taken from a freshly built Linear of the same shape.
        self.params = TensorDict.from_module(nn.Linear(5, 3))

    def forward(self, x):
        # to_module is used as a context manager here, so the real params are
        # only installed in skel while the forward call runs.
        with self.params.to_module(self.skel):
            result = self.skel(x)
        return result
# Build the module first, then the example input (same order as a single call).
model = Repro()
example_args = (torch.randn(2, 5),)
torch.export.export(model, example_args, strict=True)
When share_params=False, the vmap call seems to be the problem; see https://github.com/pytorch/rl/blob/80bfa6e957fed6bc37c6d73db0680198119c4f0b/torchrl/modules/models/multiagent.py#L127