[PT2.1] SIGSEGV seen with view + sgn operator inside torch.compile
🐛 Describe the bug
When the view operator is used together with sgn inside torch.compile, the program crashes with a segmentation fault (SIGSEGV). The code below reproduces the issue:
```python
import torch

def fn(a):
    b = a.view((2, 2))
    return b.sgn()

x_cpu = torch.tensor([[2.0, 2], [-2, -2]], requires_grad=True)
compiled_fn = torch.compile(fn)
y_cpu = compiled_fn(x_cpu)
print("y_cpu", y_cpu)
```
Error logs
```
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007fffeb281440 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
(gdb) bt
#0  0x00007fffeb281440 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#1  0x00007fffeb2a9999 in at::FunctionalTensorWrapper::replace_(at::Tensor const&) () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#2  0x00007fffeb2aa48c in at::FunctionalTensorWrapper::regenerate_from_base() () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#3  0x00007ffff6710e3b in torch::autograd::THPVariable__sync(_object*, _object*, _object*) () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#4  0x00000000005f6939 in PyCFunction_Call ()
#5  0x00000000005f7506 in _PyObject_MakeTpCall ()
#6  0x0000000000570b8e in _PyEval_EvalFrameDefault ()
#7  0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#8  0x00000000005f6ce6 in _PyFunction_Vectorcall ()
#9  0x000000000056b4ed in _PyEval_EvalFrameDefault ()
#10 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#11 0x00000000005697da in _PyEval_EvalCodeWithName ()
#12 0x00000000005f6ec3 in _PyFunction_Vectorcall ()
#13 0x000000000056b4ed in _PyEval_EvalFrameDefault ()
#14 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#15 0x00000000005697da in _PyEval_EvalCodeWithName ()
#16 0x00000000005f6ec3 in _PyFunction_Vectorcall ()
#17 0x0000000000570556 in _PyEval_EvalFrameDefault ()
#18 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#19 0x00000000005697da in _PyEval_EvalCodeWithName ()
#20 0x00000000005f6ec3 in _PyFunction_Vectorcall ()
```
Minified repro
No response
Versions
```
[pip3] numpy==1.24.4
[pip3] torch==2.1.0
[pip3] torchaudio==2.0.1
[pip3] torchdata==0.6.1
[pip3] torchmetrics==1.2.0
[pip3] torchtext==0.15.2a0
[pip3] torchvision==0.15.1a0
```
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305
Verified reproducible with pytorch 2.1, but not on the nightly or my local build (hash f2d476843ee).
Hey @jay746, just confirming, but did you file this as a regression? I also tried installing torch 2.0.0, and confirmed that I see the same segfault. So looks like a hi-pri bug, although it doesn't seem to be a regression.
sgn also has a backward formula that uses efficientzerotensor: https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L578
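For context, a minimal eager-mode example of why sgn's backward produces zeros (and therefore reaches for `_efficientzerotensor`): for real inputs, sign is flat everywhere away from 0, so its gradient is identically zero.

```python
import torch

# sgn's derivative is 0 for real inputs away from 0, so autograd can
# represent the gradient as an (efficient) zero tensor internally.
x = torch.tensor([2.0, -3.0], requires_grad=True)
y = x.sgn()
y.sum().backward()
print(x.grad)  # tensor([0., 0.])
```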
This looks like a bad interaction between efficientzerotensor and functionalization. Here's a minimal repro:
```python
import torch

def f():
    torch._enable_functionalization(reapply_views=True)
    x = torch._efficientzerotensor(4)
    y = x.reshape(-1)
    torch._sync(y)
    return y

out = f()
```
When I run under a debug build with TORCH_SHOW_DISPATCH_TRACE=1 (which gets me print statements on every dispatcher call), I see:
```
[call] op=[aten::_efficientzerotensor], key=[PythonTLSSnapshot]
[redispatchBoxed] op=[aten::_efficientzerotensor], key=[Functionalize]
[callBoxed] op=[aten::_efficientzerotensor], key=[PythonTLSSnapshot]
[redispatchBoxed] op=[aten::_efficientzerotensor], key=[Python]
[callBoxed] op=[aten::_efficientzerotensor], key=[Python]
[callBoxed] op=[aten::_efficientzerotensor], key=[BackendSelect]
[redispatch] op=[aten::_efficientzerotensor], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::detach], key=[Python]
[callBoxed] op=[aten::detach], key=[Meta]
[call] op=[aten::detach], key=[CPU]
[call] op=[aten::reshape], key=[PythonTLSSnapshot]
[redispatchBoxed] op=[aten::reshape], key=[AutogradCPU]
[call] op=[aten::view], key=[PythonTLSSnapshot]
[redispatchBoxed] op=[aten::view], key=[AutogradCPU]
[redispatch] op=[aten::view], key=[ADInplaceOrView]
[redispatch] op=[aten::view], key=[ZeroTensor]
```
You can see that `op=[aten::view], key=[ZeroTensor]` runs, but it short-circuits and doesn't give functionalization a chance to run, leaving the FunctionalTensorWrapper in a bad state.
The right thing to do here is probably not to copy the ZeroTensor dispatch key onto the FunctionalTensorWrapper. That way, functionalization will have a chance to run first.
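As a sanity check on that hypothesis, you can inspect the dispatch key set of a zero tensor with the internal `torch._C._dispatch_keys` helper (an internal API whose exact output format varies across versions; this is just an illustrative sketch):

```python
import torch

# Internal API: prints the dispatch keys a tensor carries.
z = torch._efficientzerotensor(4)
print(torch._C._dispatch_keys(z))   # includes ZeroTensor, so its kernels run first
w = torch.ones(4)
print(torch._C._dispatch_keys(w))   # a plain dense CPU tensor has no ZeroTensor key
```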
> Hey @jay746, just confirming, but did you file this as a regression? I also tried installing torch 2.0.0, and confirmed that I see the same segfault. So looks like a hi-pri bug, although it doesn't seem to be a regression.
Hi @bdhirsh, no, it's not a regression. We see this issue on PT2.0 as well; I had just missed filing the issue earlier.
I'm tentatively removing hi-pri, since this is "fixed" on tip of main (you cannot repro the segfault). Since this is not a regression, I'm not sure that this is a candidate for a fix in the 2.1 release. The main issue is that the segfault was incidentally "fixed" by updating AOTAutograd to use python functionalization, and this was a pretty large change (that would be risky to put into a patch release). Alternatively we could try to fix the segfault directly in C++ functionalization, although this would mean adding an entirely new set of changes to the 2.1 branch that are not in main, which I'm not sure that we want either.
I wanted to leave this issue open and mark it triage review though, because the state of ZeroTensor with torch.compile (and make_fx in particular) seems suboptimal. For example, the below code errors at tracetime (you can repro the same error with torch.compile, I'm just showing the error with make_fx since there are fewer moving parts):
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher

def fn(a):
    b = torch.mul(a, 2)
    out = b.sgn()
    a_grad = torch.autograd.grad([out], [a], grad_outputs=[torch.ones_like(out)])
    return a_grad

x = torch.ones(2, 2, requires_grad=True, dtype=torch.float32)
with enable_python_dispatcher():
    fx_g = make_fx(fn, decomposition_table=torch._decomp.core_aten_decompositions())(x)
    print(fx_g.code)
```
errors with:
```
File "/data/users/hirsheybar/c/pytorch/torch/_decomp/decompositions.py", line 1869, in _to_copy
    if device is not None and device != x.device:
AttributeError: 'int' object has no attribute 'device'
```
What's going on? It's a bit clearer with this slightly larger repro:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher

def fn(a):
    b = torch.add(a, a)
    out = b.sgn()
    a_grad = torch.autograd.grad([out], [a], grad_outputs=[torch.ones_like(out)])
    return a_grad

x = torch.ones(2, 2, requires_grad=True, dtype=torch.float32)
with enable_python_dispatcher():
    fx_g = make_fx(fn, decomposition_table=torch._decomp.core_aten_decompositions())(x)
    print(fx_g.code)
```
This "runs" and prints the following FX graph (I've annotated the forward and backward sections myself):
```python
def forward(self, a_1):
    # forward
    add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
    sign = torch.ops.aten.sign.default(add); add = None
    # backward
    alias = torch.ops.aten.alias.default(sign)
    full_like = torch.ops.aten.full_like.default(sign, 1, pin_memory = False, memory_format = torch.preserve_format)
    is_same_size = torch.ops.aten.is_same_size.default(sign, full_like); sign = full_like = None
    alias_1 = torch.ops.aten.alias.default(alias); alias = None
    _efficientzerotensor = torch.ops.aten._efficientzerotensor.default([2, 2], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    _to_copy = torch.ops.aten._to_copy.default(_efficientzerotensor, device = device(type='meta'))
    _to_copy_1 = torch.ops.aten._to_copy.default(_efficientzerotensor, device = device(type='meta')); _efficientzerotensor = None
    add_1 = torch.ops.prims.add.default(_to_copy_1, _to_copy); _to_copy_1 = _to_copy = None
    _efficientzerotensor_1 = torch.ops.aten._efficientzerotensor.default([2, 2], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    return (_efficientzerotensor_1,)
```
The backward graph is particularly weird: there are calls to `_to_copy(..., "meta")`, which will show up directly in the graph that inductor sees! This appears to happen because:
(1) sgn's backward formula uses _efficientzerotensor
(2) make_fx will directly trace the _to_copy("meta") calls from the efficientzerotensor's implementation here
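For reference, here is what the ZeroTensor short-circuit looks like in plain eager mode (a small sketch; `_is_zerotensor` is an internal Tensor method and may differ across versions):

```python
import torch

# The eager ZeroTensor kernel for mul short-circuits: multiplying by an
# efficient zero tensor just returns another zero tensor, without
# allocating or reading any real data.
z = torch._efficientzerotensor(3)
x = torch.ones(3)
out = torch.mul(x, z)
print(out._is_zerotensor())  # the result is itself a ZeroTensor
```

It is exactly this kind of fast path that the tracing machinery has to either replay faithfully or desugar into ordinary ops.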
I'm marking this with triage review, because I'd like to understand what we actually want to have happen with efficientzerotensor and PT2. At a minimum:
(a) Should we fix the compile-time errors with zerotensor that I mentioned above? Probably yes.
(b) Should we try to avoid directly tracing the meta-tensor calls from ZeroTensor into the graph, which will force inductor to handle them? Probably yes
(c) What do we actually want inductor to see in the graph when we're tracing zerotensor code from eager mode? Today, inductor has a fallback for _efficientzerotensor here, which seems suboptimal. One option would be to have efficientzerotensor code desugar into "ordinary" ops, and let inductor optimize the zeros away itself (and e.g. decompose aten._efficientzerotensor into aten.zeros).
By the same logic, we probably want aten._efficientzerotensor to decompose into aten.zeros anyway in our core aten decompositions, unless we want to have to consider _efficientzerotensor to be a core aten op.
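A minimal sketch of what such a decomposition could look like, using make_fx's `decomposition_table` hook (the table entry here is hypothetical, not something that exists in core):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical decomposition: desugar aten._efficientzerotensor into
# ordinary aten.zeros, so the traced graph contains only "normal" ops.
decomps = {
    torch.ops.aten._efficientzerotensor.default:
        lambda size, **kwargs: torch.zeros(size, **kwargs),
}

def f(x):
    z = torch.ops.aten._efficientzerotensor.default([2, 2])
    return x + z

gm = make_fx(f, decomposition_table=decomps)(torch.ones(2, 2))
print(gm.code)  # _efficientzerotensor no longer appears in the graph
```

With a decomposition like this registered in the core decomposition table, inductor would be free to constant-fold the zeros away itself rather than needing a dedicated `_efficientzerotensor` fallback.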
cc @ezyang, @zou3519 since "tracing ZeroTensor" came up a few times in conversation when I added python functionalization.
Decompose is 100% right. If you want to be fancy, we might want to sometimes return zero tensors as outputs when we know some outputs must be zero, but that should be done as an add-on.
@ezyang I don't see this issue after PT2.2 upgrade, did we fix it?
idk, maybe @bdhirsh knows!
Hi @bdhirsh, can you please check and let me know if it's fixed? Then we can close it.
I believe this issue is already fixed in PT2.2. Closing.