Fails to export policy
I'm new to TorchRL and BenchMARL. First of all, thank you for the impressive work!
I'm trying to export BenchMARL policies following the TorchRL documentation. The following code
```python
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import PettingZooTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig
from torchrl.envs import set_exploration_type
import torch

experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.loggers = []
task = PettingZooTask.SIMPLE_SPEAKER_LISTENER.get_from_yaml()
algorithm_config = MappoConfig.get_from_yaml()
model_config = MlpConfig.get_from_yaml()
critic_model_config = MlpConfig.get_from_yaml()

experiment = Experiment(
    task=task,
    algorithm_config=algorithm_config,
    model_config=model_config,
    critic_model_config=critic_model_config,
    seed=0,
    config=experiment_config,
)

policy, _ = experiment.policy
env = experiment.env_func()
x = env.base_env.fake_tensordict()
obs = x['speaker']['observation']

with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        policy.select_out_keys(('speaker', 'action')),
        args=(),
        kwargs={'speaker_observation': obs},
        strict=True,
    )
```
raises a

```
NotImplementedError: GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__) has no type
```
I attach the full log (generated with `TORCH_LOGS="+dynamo"` and `TORCHDYNAMO_VERBOSE=1`).
I do not understand which parts of the model, if any, are causing the issue. Is it possible to export such policies? Should I be doing something differently?
Versions:
- System: macOS 18.4 arm64
- Python: 3.12
- torch: 2.6.0
- torchrl: 0.7.2
- benchmarl: 1.4.0 (from head)
Thanks for this! I'll try to reproduce and let you know
I have distilled this into this torchrl issue: https://github.com/pytorch/rl/issues/2902
There seem to be some issues related to `vmap`.
In the meantime, the way I export the policy is to dig inside it until you find the parameters, then export those and rebuild the model you want.
Here is an example of how we traced and exported a GNN model (just an example of the things you can do):
```python
from copy import deepcopy

import torch
from torch import distributions as D


class PostEncNN(torch.nn.Module):
    """Everything downstream of the GNN message encoder, as a plain nn.Module."""

    def __init__(self, policy):
        super().__init__()
        # Dig into the policy stack to reach the actor network
        actor = policy.module[0].module[0].module[0]
        mlp = actor.mlp_local_and_comms
        # Rebuild a stateless copy of the MLP and load the functional params into it
        self.mlp_nn = deepcopy(mlp._empty_net)
        mlp.params.to_module(self.mlp_nn)
        self.loc_scale = policy.module[0].module[0].module[1].module
        # Recover the action bounds from the distribution module
        dist_module = policy.module[0].module[1]
        min_act = dist_module.distribution_kwargs['min'][0]
        max_act = dist_module.distribution_kwargs['max'][0]
        self.transform = D.AffineTransform(
            loc=(max_act + min_act) / 2, scale=(max_act - min_act) / 2
        )

    def forward(self, x):
        logits = self.mlp_nn(x)
        loc, scale = self.loc_scale(logits)
        # Deterministic action: squash the mean and rescale to the action bounds
        return self.transform(loc.tanh())


def jit_models(policy):
    actor = policy.module[0].module[0].module[0]
    # Trace the GNN message encoder on its own
    enc = actor.gnns[0].message_encoder
    enc_traced = torch.jit.trace(enc, example_inputs=torch.zeros(1, 9))
    # Trace everything after the encoder
    post = PostEncNN(policy)
    post_traced = torch.jit.trace(post, example_inputs=torch.zeros(1, 132))
    return enc_traced, post_traced
```
Not that nice, but a viable workaround until the tensordict/torchrl export issues get fixed. Plus, you won't have any torchrl/tensordict/benchmarl dependencies in your exported net.
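Once traced, the modules can be serialized and reloaded anywhere with only `torch` installed. A minimal sketch, using a plain `nn.Linear` as a stand-in for the traced encoder (the real one would come from `jit_models(policy)` above):

```python
import io

import torch

# Stand-in for enc_traced from jit_models() above; any traced module works the same
enc_traced = torch.jit.trace(torch.nn.Linear(9, 16), example_inputs=torch.zeros(1, 9))

# Serialize to a buffer (a file path string works the same way)
buf = io.BytesIO()
torch.jit.save(enc_traced, buf)
buf.seek(0)

# Reload in a process with no torchrl/tensordict/benchmarl dependency
enc_loaded = torch.jit.load(buf)
out = enc_loaded(torch.zeros(2, 9))
print(tuple(out.shape))  # → (2, 16)
```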