BenchMARL

Fails to export policy

Open jeguzzi opened this issue 11 months ago • 3 comments

I'm new to TorchRL and BenchMARL. First of all, thank you for the impressive work!

I'm trying to export BenchMARL policies following the TorchRL documentation. The following code

from benchmarl.algorithms import MappoConfig
from benchmarl.environments import PettingZooTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig
from torchrl.envs import set_exploration_type
import torch


experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.loggers = []
task = PettingZooTask.SIMPLE_SPEAKER_LISTENER.get_from_yaml()
algorithm_config = MappoConfig.get_from_yaml()
model_config = MlpConfig.get_from_yaml()
critic_model_config = MlpConfig.get_from_yaml()
experiment = Experiment(
    task=task,
    algorithm_config=algorithm_config,
    model_config=model_config,
    critic_model_config=critic_model_config,
    seed=0,
    config=experiment_config,
)
policy, _ = experiment.policy
env = experiment.env_func()
x = env.base_env.fake_tensordict()
obs = x['speaker']['observation']

with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        policy.select_out_keys(('speaker', 'action')),
        args=(),
        kwargs={'speaker_observation': obs},
        strict=True,
    )

raises a

NotImplementedError: GetAttrVariable(UnspecializedBuiltinNNModuleVariable(Linear), __dict__) has no type

I attach the full log (with TORCH_LOGS="+dynamo", TORCHDYNAMO_VERBOSE=1)

I do not understand which parts of the model are causing the issue. Is it possible to export such policies, or should I do something differently?


Versions:

  • System: macOS 18.4 arm64
  • Python: 3.12
  • torch: 2.6.0
  • torchrl: 0.7.2
  • benchmarl: 1.4.0 (from head)

export_fails.log

jeguzzi avatar Apr 10 '25 08:04 jeguzzi

Thanks for this! I'll try to reproduce and let you know

matteobettini avatar Apr 10 '25 13:04 matteobettini

I have distilled this into this torchrl issue https://github.com/pytorch/rl/issues/2902

matteobettini avatar Apr 11 '25 10:04 matteobettini

There seem to be some issues related to vmap.

In the meantime, the workaround I use to export the policy is to dig inside it until you find the parameters, then export those and rebuild the model you want.

Here is an example of how we traced and exported a gnn model (just an example of the things you can do)

from copy import deepcopy

import torch
import torch.distributions as D


class PostEncNN(torch.nn.Module):
    def __init__(self, policy):
        super().__init__()

        actor = policy.module[0].module[0].module[0]
        mlp = actor.mlp_local_and_comms
        self.mlp_nn = deepcopy(mlp._empty_net)
        mlp.params.to_module(self.mlp_nn)

        self.loc_scale = policy.module[0].module[0].module[1].module
        dist_module = policy.module[0].module[1]
        min_act = dist_module.distribution_kwargs['min'][0]
        max_act = dist_module.distribution_kwargs['max'][0]
        self.transform = D.AffineTransform(loc=(max_act + min_act) / 2, scale=(max_act - min_act) / 2)

    def forward(self, x):
        logits = self.mlp_nn(x)
        loc, scale = self.loc_scale(logits)
        return self.transform(loc.tanh())

def jit_models(policy):
    actor = policy.module[0].module[0].module[0]
    enc = actor.gnns[0].message_encoder
    enc_traced = torch.jit.trace(enc, example_inputs=torch.zeros(1, 9))

    post = PostEncNN(policy)    
    post_traced = torch.jit.trace(post, example_inputs=torch.zeros(1, 132))
    
    return enc_traced, post_traced

Not that nice, but a viable solution until the tensordict/torchrl exports get fixed. As a bonus, you won't have any torchrl/tensordict/benchmarl dependencies in your exported net.
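For completeness, here is a minimal self-contained sketch of the trace/save/load round trip used above, with a toy `nn.Linear` standing in for the extracted policy sub-network (shapes and names are illustrative):

```python
import io

import torch

# Toy module standing in for an extracted policy sub-network
# (e.g. the message encoder traced in jit_models above).
net = torch.nn.Linear(9, 4)
traced = torch.jit.trace(net, example_inputs=torch.zeros(1, 9))

# Save/load the TorchScript artifact; the loaded module has no
# torchrl/tensordict/benchmarl dependencies.
buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)
loaded = torch.jit.load(buf)

# The traced module reproduces the original's outputs.
x = torch.randn(3, 9)
assert torch.allclose(loaded(x), net(x))
```

The same pattern works with a file path instead of a `BytesIO` buffer if you want to ship the artifact to another process.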

matteobettini avatar Apr 17 '25 09:04 matteobettini