xla
[Core ATen Opset] Lower aten_pixel_shuffle
For PyTorch/XLA to support the PyTorch core ATen opset, each core ATen op must be lowered in PyTorch/XLA. This issue is used to track the PyTorch/XLA lowering for aten_pixel_shuffle.
Here are some general guidelines for lowering this op:
- Uncomment @unittest.skip or @unittest.expectFailure and run the unit test at test_core_aten_ops.py. E.g.: pytest test/test_core_aten_ops.py -k test_aten_pixel_shuffle_0
- Make code changes until the test passes. Read and follow fix_lowering_for_core_aten_ops.md for ideas to fix.
  - There may be multiple unit tests for a single op. For this op, the corresponding unit tests are:
    - test_aten_pixel_shuffle_0
    - test_aten_pixel_shuffle_1
    - test_aten_pixel_shuffle_2
  - Please also uncomment the skips for all these tests and ensure all tests are fixed.
  - Note that sometimes the fix may be to fix the unit tests themselves. Please take a look at the corresponding unit tests to make sure the tests are valid.
- Submit the PR!
For any questions, feel free to leave a comment in this issue.
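Since pytest's -k option accepts a boolean expression over test names, all three tests listed above can be selected in one run. A minimal sketch, assuming it is run from the root of the xla repo; build_pytest_args and PIXEL_SHUFFLE_TESTS are hypothetical helpers, not part of the test suite:

```python
# Test names listed in this issue for aten_pixel_shuffle.
PIXEL_SHUFFLE_TESTS = [
    "test_aten_pixel_shuffle_0",
    "test_aten_pixel_shuffle_1",
    "test_aten_pixel_shuffle_2",
]


def build_pytest_args(test_file, test_names):
    """Build a pytest argument list that selects all given tests in one run.

    pytest's -k option takes a boolean expression over test names, so the
    names are joined with "or".
    """
    return [test_file, "-k", " or ".join(test_names)]


print(build_pytest_args("test/test_core_aten_ops.py", PIXEL_SHUFFLE_TESTS))
```

The resulting list can be passed to pytest.main(...) or typed on the command line. Because -k matches substrings, -k test_aten_pixel_shuffle would also select these three tests, plus any other test whose name contains that string.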
As @zpcore mentioned in an offline sync, the behavior of this op seems a bit odd:
For the ATen op pixel_shuffle, when dispatched to XLA, the op behaves differently when called through C++ (https://github.com/pytorch/xla/blob/657b6925c65b7531a0560488df63fba90399a778/test/cpp/test_aten_xla_tensor_4.cpp#L1224) and through Python. If we call it through Python, the output is a tensor of all zeros. The C++ test, however, doesn't show the issue.
Below is the sample code I use for testing pixel_shuffle in Python:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met  # needed for met.metrics_report()

a = torch.rand([5,18,4,4])
device = xm.xla_device()
d = a.to(device)
f = torch.pixel_shuffle(d, 3)
xm.master_print(met.metrics_report())
print(f)
The output is always a tensor of all zeros, which differs from running it through C++.
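As a sanity check for this repro: pixel_shuffle rearranges an input of shape (*, C * r**2, H, W) into (*, C, H * r, W * r), so the [5,18,4,4] input with upscale factor 3 should give a [5,2,12,12] tensor whose values are a rearrangement of the random input, never all zeros. A small pure-Python sketch of the shape rule (the helper name is illustrative only):

```python
def pixel_shuffle_output_shape(input_shape, upscale_factor):
    """Output shape of pixel_shuffle for an input of shape (*, C * r**2, H, W)."""
    if len(input_shape) < 3:
        raise ValueError("pixel_shuffle expects at least 3 dims: (*, C * r**2, H, W)")
    *batch, channels, height, width = input_shape
    r = upscale_factor
    if channels % (r * r) != 0:
        raise ValueError(f"channels={channels} is not divisible by upscale_factor**2={r * r}")
    # Channels shrink by r**2 while each spatial dim grows by r.
    return (*batch, channels // (r * r), height * r, width * r)


print(pixel_shuffle_output_shape((5, 18, 4, 4), 3))  # (5, 2, 12, 12)
```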
- In C++ I ran the following test program:
// Assumed context: `device` is supplied by the surrounding test (e.g. a ForEachDevice lambda).
int64_t upscale_factor = 3;
torch::Tensor input = torch::rand({5, 18, 4, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::pixel_shuffle(xla_input, upscale_factor);
std::cout << GetTensorsHloGraph({xla_output}, EmitMode::kHloReadable) << std::endl;
We got the following HLO:
HloModule IrToHlo.6, entry_computation_layout={(f32[5,18,4,4]{3,2,1,0})->(f32[5,2,12,12]{3,2,1,0})}
ENTRY %IrToHlo.6 (p0.1: f32[5,18,4,4]) -> (f32[5,2,12,12]) {
%p0.1 = f32[5,18,4,4]{3,2,1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data"}
%reshape.2 = f32[5,2,3,3,4,4]{5,4,3,2,1,0} reshape(f32[5,18,4,4]{3,2,1,0} %p0.1), metadata={op_type="aten__view" op_name="aten__view"}
%transpose.3 = f32[5,2,4,3,4,3]{4,2,5,3,1,0} transpose(f32[5,2,3,3,4,4]{5,4,3,2,1,0} %reshape.2), dimensions={0,1,4,2,5,3}, metadata={op_type="aten__permute" op_name="aten__permute"}
%reshape.4 = f32[5,2,12,12]{3,2,1,0} reshape(f32[5,2,4,3,4,3]{4,2,5,3,1,0} %transpose.3), metadata={op_type="aten__view" op_name="aten__view"}
ROOT %tuple.5 = (f32[5,2,12,12]{3,2,1,0}) tuple(f32[5,2,12,12]{3,2,1,0} %reshape.4)
}
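To make the C++ graph above easier to check by hand, here is a pure-Python sketch (no torch dependency) that replays the same reshape -> transpose(dimensions={0,1,4,2,5,3}) -> reshape on a flat row-major buffer and compares it with the direct index definition of pixel_shuffle. All helper names are hypothetical; the sketch only assumes row-major layout, which is what the {3,2,1,0} layouts in the HLO describe for the reshapes.

```python
from itertools import product


def row_major_strides(shape):
    """Element strides of a C-contiguous (row-major) buffer with this shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def permute_flat(data, shape, dims):
    """Materialize a permutation of a flat row-major buffer (like permute_copy)."""
    in_strides = row_major_strides(shape)
    out_shape = [shape[d] for d in dims]
    out = []
    # product() walks the output in row-major order; output dim k reads input dim dims[k].
    for idx in product(*(range(s) for s in out_shape)):
        out.append(data[sum(i * in_strides[d] for i, d in zip(idx, dims))])
    return out, out_shape


def pixel_shuffle_decomposed(data, shape, r):
    """reshape -> transpose -> reshape, mirroring the C++ HLO above."""
    n, c_rr, h, w = shape
    c = c_rr // (r * r)
    # reshape.2: [N, C*r*r, H, W] -> [N, C, r, r, H, W]; no data movement in row-major order.
    six_d = [n, c, r, r, h, w]
    # transpose.3: dimensions={0,1,4,2,5,3} -> [N, C, H, r, W, r]
    permuted, _ = permute_flat(data, six_d, [0, 1, 4, 2, 5, 3])
    # reshape.4: [N, C, H, r, W, r] -> [N, C, H*r, W*r]; again no data movement.
    return permuted, [n, c, h * r, w * r]


def pixel_shuffle_reference(data, shape, r):
    """Direct definition: out[n, c, y*r + i, x*r + j] = in[n, c*r*r + i*r + j, y, x]."""
    n, c_rr, h, w = shape
    c = c_rr // (r * r)
    s = row_major_strides(shape)
    out = []
    for nn, cc, oy, ox in product(range(n), range(c), range(h * r), range(w * r)):
        y, i = divmod(oy, r)
        x, j = divmod(ox, r)
        channel = cc * r * r + i * r + j
        out.append(data[nn * s[0] + channel * s[1] + y * s[2] + x * s[3]])
    return out, [n, c, h * r, w * r]


# Tiny example: 4 channels, 1x2 spatial, upscale factor 2.
print(pixel_shuffle_decomposed(list(range(8)), [1, 4, 1, 2], 2))
# ([0, 2, 1, 3, 4, 6, 5, 7], [1, 1, 2, 4])
```

Using distinct input values (e.g. list(range(...))) makes any misplaced element visible, and a correct result for the [5,18,4,4] case is a permutation of the input rather than a constant.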
- In Python I ran the following test program:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.rand([5,18,4,4])
device = xm.xla_device()
d = a.to(device)
f = torch.pixel_shuffle(d, 3)
print(torch_xla._XLAC._get_xla_tensors_hlo([f]))
We got the following HLO:
HloModule IrToHlo.7, entry_computation_layout={()->(f32[5,2,12,12]{3,2,1,0})}
ENTRY %IrToHlo.7 () -> (f32[5,2,12,12]) {
%constant.1 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/mnt/disks/data/pytorch/torch/_meta_registrations.py" source_line=5933}
%reshape.2 = f32[1,1,1,1]{3,2,1,0} reshape(f32[] %constant.1), metadata={op_type="aten__expand" op_name="aten__expand" source_file="/mnt/disks/data/pytorch/torch/_meta_registrations.py" source_line=5933}
%broadcast.3 = f32[1,1,1,1]{3,2,1,0} broadcast(f32[1,1,1,1]{3,2,1,0} %reshape.2), dimensions={0,1,2,3}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="/mnt/disks/data/pytorch/torch/_meta_registrations.py" source_line=5933}
%reshape.4 = f32[] reshape(f32[1,1,1,1]{3,2,1,0} %broadcast.3), metadata={op_type="aten__expand" op_name="aten__expand" source_file="/mnt/disks/data/pytorch/torch/_meta_registrations.py" source_line=5933}
%broadcast.5 = f32[5,2,12,12]{3,2,1,0} broadcast(f32[] %reshape.4), dimensions={}, metadata={op_type="aten__expand" op_name="aten__expand" source_file="/mnt/disks/data/pytorch/torch/_meta_registrations.py" source_line=5933}
ROOT %tuple.6 = (f32[5,2,12,12]{3,2,1,0}) tuple(f32[5,2,12,12]{3,2,1,0} %broadcast.5)
}
As we can see, C++ and Python generate different HLO graphs. The aten__permute op doesn't exist in the Python graph; instead, the graph just broadcasts the constant 0 to f32[5,2,12,12], which is why Python always produces an all-zero tensor. I will look into the details of the lowering trace.
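One quick way to compare the two dumps is to diff the op_type metadata they carry. A minimal sketch, assuming the HLO text comes from the C++ dump or from torch_xla._XLAC._get_xla_tensors_hlo as in the snippets above; it relies only on the op_type="..." metadata format visible in those dumps, and the helper names are hypothetical:

```python
import re

# The dumps above record the originating ATen/XLA op in each instruction's metadata.
_OP_TYPE_RE = re.compile(r'op_type="([^"]+)"')


def hlo_op_types(hlo_text):
    """Set of op_type values found in HLO instruction metadata."""
    return set(_OP_TYPE_RE.findall(hlo_text))


def missing_op_types(expected_hlo, actual_hlo):
    """op_types present in expected_hlo but absent from actual_hlo."""
    return hlo_op_types(expected_hlo) - hlo_op_types(actual_hlo)


cpp_line = '%transpose.3 = ... metadata={op_type="aten__permute" op_name="aten__permute"}'
py_line = '%broadcast.5 = ... metadata={op_type="aten__expand" op_name="aten__expand"}'
print(missing_op_types(cpp_line, py_line))  # {'aten__permute'}
```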
This has something to do with the at::functionalization::functionalize_aten_op<ATEN_OP(pixel_shuffle)>::call(self, upscale_factor) call made here.
pixel_shuffle has its own decomposition declared in torch here. If we remove the pixel_shuffle op from xla_native_functions.yaml and let torch do the decomposition into other ops, the result is correct when using the Python API in torch_xla.
Functionalization has its own purpose of handling mutations, as discussed in https://github.com/pytorch/xla/pull/3646 and https://dev-discuss.pytorch.org/t/functionalization-in-pytorch-everything-you-wanted-to-know/965.
@bdhirsh, could you help us see why functionalization incorrectly decomposes pixel_shuffle through the Python API? Can we simply remove it?
Update 03/19: I forgot to mention that, for some reason, the input tensor is dropped, as shown in ENTRY %IrToHlo.7 () -> (f32[5,2,12,12]). We need to figure out at which stage the input tensor gets lost when we call functionalization.
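The dropped input can be detected mechanically as well: the C++ dump's ENTRY signature lists p0.1 and contains a parameter(0) instruction, while the Python dump has an empty signature and no parameter instructions. A sketch under the assumption that the text follows the HLO format shown above (no nested parentheses in the ENTRY parameter list); the regexes and helper names are illustrative:

```python
import re

# Captures the ENTRY parameter list, e.g. "p0.1: f32[5,18,4,4]" or "" for "()".
# Assumes no nested parentheses in that list, which holds for the dumps above.
_ENTRY_RE = re.compile(r"^ENTRY\s+%\S+\s+\(([^()]*)\)\s*->", re.MULTILINE)
_PARAM_RE = re.compile(r"\bparameter\((\d+)\)")


def entry_signature_params(hlo_text):
    """Raw parameter list of the ENTRY computation's signature."""
    match = _ENTRY_RE.search(hlo_text)
    if match is None:
        raise ValueError("no ENTRY computation found in HLO text")
    return match.group(1).strip()


def parameter_indices(hlo_text):
    """Sorted indices of all parameter(i) instructions in the module."""
    return sorted(int(i) for i in _PARAM_RE.findall(hlo_text))


def input_was_dropped(hlo_text):
    """True if the ENTRY computation takes no inputs at all."""
    return entry_signature_params(hlo_text) == "" and not parameter_indices(hlo_text)


py_hlo = "ENTRY %IrToHlo.7 () -> (f32[5,2,12,12]) {\n  %constant.1 = f32[] constant(0)\n}"
print(input_was_dropped(py_hlo))  # True
```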
any updates on this issue?
No, let me follow up with @bdhirsh.
cc @jiawenliu64, as this op appears to fail tests when the functionalization flag is enabled.
Hey!
The API called at https://github.com/pytorch/xla/blob/5a113aff98ce42420891c724843ccb30691dc24a/torch_xla/csrc/aten_xla_type.cpp#L3639-L3645 just re-invokes the dispatcher and attempts to run whatever decomposition for pixel_shuffle is registered there.
My guess for the Python vs. C++ difference is that there's a C++ decomposition registered in core here (that secretly runs some view ops like permute(), which is why the linked API above "re-runs" functionalization to be able to convert permute to permute_copy).
And when you use Python, import torch probably ends up registering this Python decomposition to the dispatcher, so that one is run instead.
In theory, both of those decomps should be correct. Does one of the IRs look wrong?
Alternatively, you can definitely remove the return at::functionalization::functionalize_aten_op<ATEN_OP(pixel_shuffle)>::call(self, upscale_factor); code and insert your own lowering instead, if you want to lower the op directly without using a decomposition from core.
Thanks for the response, @bdhirsh! So the original reason why we wanted to lower this was that it was marked as a "core" op in pytorch native_functions.yaml and we wanted torch_xla to support lowerings for all core ops.
My initial understanding was that a "core" op would not be decomposed into any further ops; is this the case? According to https://github.com/pytorch/pytorch/blob/58047205ed098c04ec045e66fc39dcc70b60600b/torch/_refs/nn/functional/__init__.py#L1169, it appears to have some decompositions.
@wonjoolee95 the way I would categorize the decomps we have is that:
(1) there are a ton of decomps, for many (most?) ops in ATen. A compiler backend to torch.compile can specify which of those decomps they do/don't want to run, and lower the remaining primitive ops that show up in the graph directly. For example, inductor has its own set of decomps that it uses (basically core aten + a few other decomps: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/decomposition.py#L75)
(2) there is a canonical "core ATen opset" that backends can choose to target, and you can get out a graph of core ATen IR by specifying that you only want to run core ATen decomps.
So if you e.g. use torch.export and run with core ATen decomps set, you'll get a graph of core ATen IR. But the eager-mode XLA integration doesn't necessarily run the same set of decomps as core ATen (although you can change which ops you choose to decompose vs lower directly)
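The distinction above can be illustrated with a toy model in which a backend picks which ops it lowers directly and lets every other op expand through a decomposition table. The table, op names and lower() helper are hypothetical and are not the torch._decomp API; the pixel_shuffle entry only mirrors the view/permute/view ops seen in the C++ HLO earlier in this thread.

```python
# Hypothetical decomposition table: op -> ops it expands into.
DECOMPOSITIONS = {
    "pixel_shuffle": ["view", "permute", "view"],
}


def lower(graph, decompositions, lowered_directly):
    """Recursively expand ops the backend does not lower directly.

    Ops that are in lowered_directly, or that have no decomposition, are kept
    as-is; a real backend would need its own lowering (or a fallback) for them.
    """
    result = []
    for op in graph:
        if op in lowered_directly or op not in decompositions:
            result.append(op)
        else:
            result.extend(lower(decompositions[op], decompositions, lowered_directly))
    return result


# A backend that relies on the decomposition vs. one with its own pixel_shuffle lowering.
print(lower(["pixel_shuffle"], DECOMPOSITIONS, set()))              # ['view', 'permute', 'view']
print(lower(["pixel_shuffle"], DECOMPOSITIONS, {"pixel_shuffle"}))  # ['pixel_shuffle']
```

In this picture, keeping pixel_shuffle in xla_native_functions.yaml roughly corresponds to putting it in lowered_directly, while removing it lets torch's decomposition run, as described earlier in the thread.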
Thanks for the explanation, that makes a lot of sense.
@zpcore, I'll remove this issue from the scope of the "core aten opset". From torch.export's point of view, this op can be further decomposed, so we don't have to keep it within this project's scope.
Hi @bdhirsh, @wonjoolee95, thanks for following up. I checked the decomposition trace, and it turns out that if we move the tensor to the XLA device, a different decomposition is used. Below is the example code. If we use the following code:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.rand([5,18,4,4])
a = a.to(device)
f = torch.pixel_shuffle(a, 3)
it enters _meta_registrations.py.
If we remove a = a.to(device), the code ends up in the C++ decomposition of pixel_shuffle.
Either way, I didn't see the Python registration being called.