torch-mlir

[TORCH][MLIR] Add E2E support for `aten.native_layer_norm_backward` op.

Open • gprateek93 opened this pull request 3 years ago • 5 comments

This PR contains two commits:

  1. Adds support for the `aten.native_layer_norm_backward` operation. It also adds support for matching constant bools stored in a boolean list.
  2. Fixes `aten.native_layer_norm`. Previously, this operation computed incorrect shapes for the mean and inverted standard deviation (STD) outputs; this commit corrects them. It also adds helper functions that compute the inverted STD and that broadcast a given input according to a broadcast mask.
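To make the shape fix in the second commit concrete, here is a minimal pure-Python sketch (no torch-mlir or PyTorch code involved) of the statistics `native_layer_norm` is expected to return. It assumes the PyTorch convention, as we understand it, that mean and inverted STD keep the normalized dimensions as size-1 dims: an input of shape [2, 3, 4] normalized over [3, 4] yields statistics of shape [2, 1, 1], not [2]. The names `stats_shape`, `broadcast_mask`, and `layer_norm_rows` are hypothetical and are not the helpers added in this PR. The row-wise version flattens the normalized dims into one row and omits weight and bias.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def stats_shape(input_shape: Sequence[int], normalized_shape: Sequence[int]) -> List[int]:
    """Expected shape of mean/rstd: outer dims kept, normalized dims set to 1."""
    axis = len(input_shape) - len(normalized_shape)
    if list(input_shape[axis:]) != list(normalized_shape):
        raise ValueError("normalized_shape must match the trailing dims of the input")
    return list(input_shape[:axis]) + [1] * len(normalized_shape)


def broadcast_mask(input_shape: Sequence[int], normalized_shape: Sequence[int]) -> List[bool]:
    """Hypothetical helper: True for each dim along which mean/rstd are broadcast back."""
    axis = len(input_shape) - len(normalized_shape)
    return [dim >= axis for dim in range(len(input_shape))]


def layer_norm_rows(rows: Matrix, eps: float = 1e-5) -> Tuple[Matrix, Matrix, Matrix]:
    """Normalize each row (the flattened normalized dims); weight and bias omitted."""
    out, mean, rstd = [], [], []
    for row in rows:
        n = len(row)
        mu = sum(row) / n
        var = sum((v - mu) ** 2 for v in row) / n  # biased variance, as in layer norm
        r = 1.0 / math.sqrt(var + eps)             # inverted standard deviation
        out.append([(v - mu) * r for v in row])    # mu and r broadcast across the row
        mean.append([mu])                          # keepdim: shape [M, 1], not [M]
        rstd.append([r])
    return out, mean, rstd
```

For example, `stats_shape([2, 3, 4], [4])` gives `[2, 3, 1]`, and `broadcast_mask([2, 3, 4], [3, 4])` marks the last two dims as the ones to broadcast over.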

Signed-off-by: Prateek Gupta

gprateek93 • Feb 09 '22 16:02

@cathyzhyi I have opened a PR that adds this decomposition to functorch: https://github.com/pytorch/functorch/pull/525. Having the decomposition here as well would therefore duplicate it.
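As a rough illustration of what such a decomposition computes (this is not the functorch code from pytorch/functorch#525), the standard layer-norm gradient formulas can be written in pure Python over row-major data. With `x_hat = (x - mean) * rstd`, `g = grad_out * weight`, and `N` the number of normalized elements, `grad_input = rstd / N * (N * g - sum(g) - x_hat * sum(g * x_hat))`, while the weight and bias gradients reduce `grad_out * x_hat` and `grad_out` over the outer dims. The `output_mask` triple mirrors the `bool[3]` argument of `aten.native_layer_norm_backward` that selects which gradients to compute, presumably the constant boolean list mentioned in the first commit. The function name and row layout are our own assumptions.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def layer_norm_backward_rows(
    grad_out: Matrix,
    x: Matrix,
    mean: Matrix,                   # [M, 1], as saved by the forward pass
    rstd: Matrix,                   # [M, 1], inverted standard deviation
    weight: Optional[List[float]],  # [N] or None
    output_mask: Tuple[bool, bool, bool],
) -> Tuple[Optional[Matrix], Optional[List[float]], Optional[List[float]]]:
    """Gradients of y = x_hat * weight + bias, where each row of x is normalized."""
    m, n = len(x), len(x[0])
    w = weight if weight is not None else [1.0] * n
    grad_input: Optional[Matrix] = [] if output_mask[0] else None
    grad_weight = [0.0] * n if output_mask[1] and weight is not None else None
    grad_bias = [0.0] * n if output_mask[2] else None

    for i in range(m):
        r = rstd[i][0]
        x_hat = [(v - mean[i][0]) * r for v in x[i]]
        g = [grad_out[i][j] * w[j] for j in range(n)]  # gradient w.r.t. x_hat
        if grad_input is not None:
            sum_g = sum(g)
            sum_g_xhat = sum(gj * xh for gj, xh in zip(g, x_hat))
            grad_input.append(
                [r / n * (n * g[j] - sum_g - x_hat[j] * sum_g_xhat) for j in range(n)]
            )
        for j in range(n):
            # weight and bias gradients reduce over the outer (row) dimension
            if grad_weight is not None:
                grad_weight[j] += grad_out[i][j] * x_hat[j]
            if grad_bias is not None:
                grad_bias[j] += grad_out[i][j]
    return grad_input, grad_weight, grad_bias
```

Comparing `grad_input` against central finite differences of the forward pass is a cheap sanity check for a decomposition like this before relying on an end-to-end test.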

gprateek93 • Mar 10 '22 10:03


@gprateek93 Once the decomposition is merged, could you add an e2e test to confirm that it makes `aten::native_layer_norm_backward` work?

cathyzhyi • Mar 15 '22 16:03

The functorch PR for `aten.native_layer_norm_backward` has been merged. We are now waiting for the latest functorch to be integrated into torch-mlir; once that is done, this PR can safely be closed.

gprateek93 • Mar 30 '22 16:03

Hey, are there any updates on this PR? We seem to have support for `native_layer_norm`, but it would be great to have `native_layer_norm_backward` working as well.

Ideally, we'd like `native_layer_norm_backward` to stay a high-level op rather than be decomposed.

cc: @antoniojkim @ke1337

henrytwo • May 18 '22 16:05

FYI: I'm working on a new PR that will borrow some code from this one (with credit and a reference), since I need support for `aten.native_layer_norm_backward` soon: https://github.com/llvm/torch-mlir/pull/888

henrytwo • May 30 '22 19:05

This PR can be closed: https://github.com/llvm/torch-mlir/pull/888, which adds the `aten.native_layer_norm_backward` op, appears to have been merged.

xgupta • Nov 26 '22 18:11