
Any plans to make it compatible with torch.jit?

LucaBonfiglioli opened this issue 2 years ago • 5 comments

Currently, this fails:

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.jit import script, trace  # type: ignore


class MyModule(nn.Module):
    def forward(self, x: TensorDict) -> TensorDict:
        return x


model = MyModule()
input_ = TensorDict.from_dict({"input": torch.randn(1, 1, 28, 28)})
model(input_)
trace(model, input_)(input_)  # type: ignore  # fails: trace cannot handle TensorDict inputs
script(model)(input_)  # type: ignore  # fails: TensorDict is not a scriptable type

Any plans to make TensorDict compatible with torch.jit?

LucaBonfiglioli avatar Sep 25 '23 07:09 LucaBonfiglioli

Hey! We're actually looking into this, thanks for showing interest. Let's hope we can make it work!

vmoens avatar Sep 25 '23 07:09 vmoens

Thank you very much! šŸ™šŸ»

LucaBonfiglioli avatar Sep 25 '23 07:09 LucaBonfiglioli

Any updates on this? The scenario I am looking into right now is easily deploying trained models on edge devices. Tracing and jitting would be super helpful, as they avoid having to install many dependencies on the device.

janblumenkamp avatar Apr 24 '24 16:04 janblumenkamp

Same here. I want to deploy a trained policy to a quadruped. Currently, none of torch.jit.script, torch.jit.trace, or torch.compile works with TensorDictSequential. Does the prototype feature torch.fx.symbolic_trace do the job? If so, it would be great if an example could be added to the tutorial.

btx0424 avatar May 07 '24 15:05 btx0424

Hello! We don't have plans to make tensordict compatible with jit as of now: torchscript is deprecated and unmaintained. Since I'm alone maintaining this lib along with RL and other stuff in pytorch core, my bandwidth is very limited (we welcome contributions, as always!). At first glance, making the tensordict stack compatible with torchscript would take me at least two weeks (assuming I dedicated 100% of my time to it), and it would likely compromise any further work to make the stack compatible with torch.compile.

Which brings us to the second point: the blessed way to export Python code is now torch.compile + torch.export. Making the whole code stack compatible with compile is very high priority, and I will make sure this is done in a timely manner! Since I work closely with the compile team, that shouldn't take too long (hopefully).

vmoens avatar May 08 '24 06:05 vmoens

TensorDict now supports compile for many of its operations. Let us know if you encounter any problem, we'll be quick in fixing these!

vmoens avatar Jul 15 '24 16:07 vmoens

Thanks for your effort! I can now compile my modules. But I still have no clue how to export them for use on my robot: compiled modules are currently not serializable, and it seems torch.export does not work either.

Currently I have to torch.save the TensorDictModule object directly and set up the whole environment on my robot so that torch.load can find the class information, which is very tedious.

Is there a good solution so far?

btx0424 avatar Jul 21 '24 07:07 btx0424

Agree with @btx0424: I have the same use case, and while torch.compile works, torch.export does not. I’d highly appreciate any suggestions for workarounds or updates on future development.

koritsky avatar Nov 28 '24 14:11 koritsky

Do these two docs help?

- https://pytorch.org/rl/main/tutorials/export.html
- https://pytorch.org/tensordict/main/tutorials/export.html

Otherwise lmk what I should add to make things clearer!

vmoens avatar Nov 28 '24 15:11 vmoens