torch-mlir

Can TorchScript dump the backward graph?

Open • yxd886 opened this issue 3 years ago • 1 comment

When using TorchScript, can the backward graph be dumped and compiled with MLIR passes, or is only the forward graph supported?

yxd886 • Jul 04 '22 07:07

Take a look at the BERT training recipe here.

It combines two techniques from PyTorch core:

  • torch.nn.utils.stateless.functional_call to transform model.forward(inputs) to forward(model, inputs)
  • torch.fx.experimental.proxy_tensor.make_fx to transform an eager train step function into a torch.fx.GraphModule
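To make the combination concrete, here is a minimal sketch of the two techniques working together. It is not the BERT recipe itself: the toy Linear model, the train_step function, the manual SGD update, and the learning rate are all hypothetical stand-ins chosen for illustration. It also assumes that newer PyTorch releases expose functional_call under torch.func, falling back to the torch.nn.utils.stateless location named above.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

try:
    # Assumption: PyTorch >= 2.0 exposes functional_call here.
    from torch.func import functional_call
except ImportError:
    # Location named in this thread (PyTorch 1.12-era).
    from torch.nn.utils.stateless import functional_call

# Hypothetical toy model standing in for BERT.
model = torch.nn.Linear(4, 2)


def train_step(weight, bias, x, y):
    """One functional training step: forward, backward, and SGD update."""
    # functional_call runs model.forward with the given tensors
    # substituted for the module's own parameters.
    params = {"weight": weight, "bias": bias}
    pred = functional_call(model, params, (x,))
    loss = torch.nn.functional.mse_loss(pred, y)
    # The backward pass runs inside the traced function, so its
    # operations are recorded alongside the forward ones.
    grad_w, grad_b = torch.autograd.grad(loss, (weight, bias))
    lr = 0.1  # hypothetical learning rate
    return loss, weight - lr * grad_w, bias - lr * grad_b


# Example inputs; parameters are detached copies that require grad.
weight = model.weight.detach().clone().requires_grad_(True)
bias = model.bias.detach().clone().requires_grad_(True)
x = torch.randn(8, 4)
y = torch.randn(8, 2)

# make_fx executes train_step once and records every tensor operation,
# forward and backward, into a torch.fx.GraphModule.
graph_module = make_fx(train_step)(weight, bias, x, y)
print(graph_module.graph)
```

The printed graph should contain both forward and backward operations, which is the form a compiler backend would consume. Passing the parameters as explicit arguments rather than reading them from the module is what lets the whole step be captured as a pure function.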

One piece is still missing: guaranteeing that the same train function is applied at every training step, which requires explicit annotation from the user here. For more information, I recommend this post, which summarizes the status quo of PyTorch graph capture (courtesy of @penguinwu).

byronyi • Jul 04 '22 08:07