torch-mlir
Can torchscript dump backward graph?
When using TorchScript, can the backward graph be dumped and compiled using MLIR passes? Or is only the forward graph supported?
Take a look at the BERT training recipe here.
It combines 2 different techniques from PyTorch core:
- `torch.nn.utils.stateless.functional_call` to transform `model.forward(inputs)` into `forward(model, inputs)`
- `torch.fx.experimental.proxy_tensor.make_fx` to transform an eager train step function into a `torch.fx.GraphModule`
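A rough sketch of how the two pieces fit together, assuming PyTorch 1.12 or later. The toy `torch.nn.Linear` model, the MSE loss, and the `train_step` function are hypothetical stand-ins, not the BERT recipe itself. The forward pass is made functional with `functional_call`, and the train step computes gradients with `torch.autograd.grad`; because `make_fx` traces below autograd, the captured graph should contain the backward ops as well as the forward ones. On PyTorch 2.x, `functional_call` is exposed as `torch.func.functional_call` and the `stateless` version is deprecated, so the import falls back between the two.

```python
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

try:
    from torch.func import functional_call  # PyTorch >= 2.0
except ImportError:
    from torch.nn.utils.stateless import functional_call  # PyTorch 1.12 / 1.13

# Hypothetical toy model standing in for BERT.
model = torch.nn.Linear(4, 2)
param_names = [name for name, _ in model.named_parameters()]


def train_step(params, inputs, targets):
    # Rebind the module's parameters to the explicit `params` tensors, so the
    # forward pass becomes a pure function forward(model, params, inputs).
    outputs = functional_call(model, dict(zip(param_names, params)), (inputs,))
    loss = F.mse_loss(outputs, targets)
    # Gradients w.r.t. the explicit params; make_fx records these backward ops too.
    grads = torch.autograd.grad(loss, params)
    return loss, grads


# Explicit copies of the parameters that require grad.
params = tuple(p.detach().clone().requires_grad_(True) for p in model.parameters())
inputs = torch.randn(8, 4)
targets = torch.randn(8, 2)

# Trace the whole eager train step (forward + backward) into one fx graph.
graph_module = make_fx(train_step)(params, inputs, targets)
print(graph_module.graph)
```

Whether torch-mlir can then lower every op in that graph depends on the torch-mlir version in use, so the BERT recipe remains the reference for the actual import path.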
The one remaining piece is making sure that the same train step function is applied at every training step, which currently requires explicit annotation from the user here. For more information on this, I would recommend this awesome post that summarizes the status quo of PyTorch graph capture (courtesy @penguinwu).