tensordict
Any plans to make it compatible with torch.jit?
Currently, this fails:

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.jit import script, trace  # type: ignore


class MyModule(nn.Module):
    def forward(self, x: TensorDict) -> TensorDict:
        return x


model = MyModule()
input_ = TensorDict.from_dict({"input": torch.randn(1, 1, 28, 28)})
model(input_)
trace(model, input_)(input_)  # type: ignore
script(model)(input_)  # type: ignore
```
Any plans to make TensorDict compatible with torch.jit?
Hey! We're actually looking into this, thanks for showing interest. Let's hope we can make it work!
Thank you very much!
Any updates on this? The scenario I am looking into right now is deploying trained models easily on edge devices. Tracing and jitting would be super helpful, as it avoids having to install many dependencies on the device.
Same here. I want to deploy a trained policy to a quadruped. Currently none of torch.jit.script, torch.jit.trace, and torch.compile work with TensorDictSequential. Does the prototyping feature symbolic_trace do the job? If yes, it would be great if an example could be added to the tutorial.
Hello! We don't have plans to make tensordict compatible with jit as of now: torchscript is deprecated and unmaintained. Since I'm alone maintaining this lib along with RL and other stuff in pytorch core, my bandwidth is very limited (we welcome contributions, as always!). At first glance, making the tensordict stack compatible with torchscript would take me at least two weeks (assuming I dedicate 100% of my time to it), and would likely compromise any further work to make it compatible with torch.compile.
Which brings us to the second point: the blessed way to export Python code now is through torch.compile + torch.export.
Making the whole code stack compatible with compile is very high priority, and I will make sure this is done in a timely manner!
Since I work closely with the compile team, that shouldn't take too long (hopefully).
TensorDict now supports compile for many of its operations. Let us know if you encounter any problem, we'll be quick in fixing these!
Thanks for your effort! Now I can compile my modules. But I still have no clue how to export them for use on my robot: compiled modules are currently not serializable, and it seems torch.export does not work either.
Currently I have to torch.save the TensorDictModule object directly and set up the whole environment on my robot so that torch.load correctly finds the class information, which is very tedious.
Is there a good solution so far?
Agree with @btx0424: I have the same use case, and while torch.compile works, torch.export does not. I'd highly appreciate any suggestions for workarounds or updates on future development.
Do these two docs help?
- https://pytorch.org/rl/main/tutorials/export.html
- https://pytorch.org/tensordict/main/tutorials/export.html
Otherwise lmk what I should add to make things clearer!