torch-mlir
[TORCH][MLIR] Add E2E support for `aten.native_layer_norm_backward` op.
This PR adds two commits:
- Adds support for the `aten.native_layer_norm_backward` operation. It also adds support for matching constant bools stored in a boolean list.
- Fixes `aten.native_layer_norm`. Previously this operation did not calculate the correct shapes for the mean and the inverted STD; this commit corrects that. Some new helper functions are added to calculate the inverted STD and to broadcast a given input with the help of a broadcast mask.
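For readers who want to see the shape fix and the backward computation concretely, here is a minimal pure-Python sketch. It is not the lowering in this PR: the helper names (`stat_shape`, `layer_norm_rows`, `layer_norm_backward_rows`) are hypothetical, the input is restricted to a 2-D list normalized over its last dimension, and it assumes the PyTorch convention that the mean and inverted STD keep the leading dims and have size 1 in every normalized dim. The `output_mask` argument mirrors the boolean list that selects which of the three gradients are produced.

```python
import math


def stat_shape(input_shape, normalized_shape):
    """Shape of mean / inverted STD: leading dims kept, normalized dims set to 1."""
    axis = len(input_shape) - len(normalized_shape)
    if axis < 0 or list(input_shape[axis:]) != list(normalized_shape):
        raise ValueError("normalized_shape must match the trailing input dims")
    return list(input_shape[:axis]) + [1] * len(normalized_shape)


def layer_norm_rows(x, eps=1e-5):
    """Layer norm over the last dim of a 2-D list (weight and bias omitted)."""
    out, means, rstds = [], [], []
    for row in x:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n  # biased variance
        rstd = 1.0 / math.sqrt(var + eps)            # inverted STD
        out.append([(v - mean) * rstd for v in row])
        means.append([mean])  # stats have shape [N, 1], not [N]
        rstds.append([rstd])
    return out, means, rstds


def layer_norm_backward_rows(grad_out, x, means, rstds, weight, output_mask):
    """Return (grad_input, grad_weight, grad_bias); masked-off entries are None."""
    n = len(weight)
    # Recompute the normalized input from the saved statistics.
    x_hat = [[(v - m[0]) * r[0] for v in row]
             for row, m, r in zip(x, means, rstds)]
    grad_input = grad_weight = grad_bias = None
    if output_mask[0]:
        grad_input = []
        for g_row, xh_row, r in zip(grad_out, x_hat, rstds):
            gw = [g * w for g, w in zip(g_row, weight)]
            mean_gw = sum(gw) / n
            mean_gw_xh = sum(a * b for a, b in zip(gw, xh_row)) / n
            grad_input.append([r[0] * (a - mean_gw - b * mean_gw_xh)
                               for a, b in zip(gw, xh_row)])
    if output_mask[1]:
        # Reduce over the batch dim for each normalized element.
        grad_weight = [sum(g[j] * xh[j] for g, xh in zip(grad_out, x_hat))
                       for j in range(n)]
    if output_mask[2]:
        grad_bias = [sum(g[j] for g in grad_out) for j in range(n)]
    return grad_input, grad_weight, grad_bias
```

With an input of shape `[2, 5, 2, 2, 3]` and `normalized_shape = [2, 2, 3]`, `stat_shape` gives `[2, 5, 1, 1, 1]`, which is the shape the mean and inverted STD need in order to broadcast against the input.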
Signed-Off-By: Prateek Gupta [email protected]
@cathyzhyi I have opened a PR for this decomposition in functorch: https://github.com/pytorch/functorch/pull/525. So I guess having the decomposition here as well would duplicate it.
@gprateek93 After the decomposition is merged, can you add an e2e test to make sure `aten::native_layer_norm_backward` works?
The functorch PR for `aten.native_layer_norm_backward` is merged. Waiting for the latest functorch to be integrated into torch-mlir. Once that is done, we can safely close this PR.
Hey, are there any updates on this PR? We seem to have support for `native_layer_norm`, but it would be great to have `native_layer_norm_backward` working too.
Ideally, we'd like `native_layer_norm_backward` to stay as a high-level op rather than be decomposed.
cc: @antoniojkim @ke1337
FYI: I'm working on a new PR that will borrow some code from this one (with credit and a reference), since I need support for `aten.native_layer_norm_backward` soon: https://github.com/llvm/torch-mlir/pull/888
This PR can be closed; it seems https://github.com/llvm/torch-mlir/pull/888 has been merged and adds the `aten.native_layer_norm_backward` op.