PyTorch/XLA should not crash runtime during int64 dot product on TPU
Issue description
Currently, PyTorch/XLA crashes the entire runtime (a failed ret check?) when you compute an int64 dot product on a TPU. That seems like poor UX to me: it would be preferable to raise a proper Python exception rather than crash the whole runtime. This is particularly problematic for people using PyTorch/XLA in a Jupyter notebook, because the crash takes down the whole notebook.
Code example
```python
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
# Python ints produce torch.int64 tensors by default.
t1 = torch.tensor([1, 2, 3], device=dev)
t2 = torch.tensor([3, 4, 5], device=dev)
t1 @ t2  # int64 dot product: on TPU the runtime aborts here
```

Running this on a TPU aborts the process with:

```
Non-OK-status: status.status() status: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %dot.3 = s64[] dot(s64[3]{0} %p1.2, s64[3]{0} %p0.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
```
System Info
- reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: 2.1.0
Thanks @sagelywizard for raising the issue. I think what we could do is:
- verify whether this issue still reproduces on nightly
- fall back to CPU if the int64 dot product is still not implemented.
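The CPU-fallback idea amounts to a routing decision: check whether the device supports an (op, dtype) pair, and run the op on CPU if it does not. The toy sketch below illustrates only that decision. `UNSUPPORTED_ON_TPU` and `run_op` are hypothetical names, the table encodes just the one gap reported in this issue, and this is not PyTorch/XLA's actual fallback machinery.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical support table: (op, dtype) pairs with no TPU lowering.
# Only the gap reported in this issue (int64 dot) is listed.
UNSUPPORTED_ON_TPU = {("dot", "int64")}


def _dot(a: List[int], b: List[int]) -> int:
    """Plain Python dot product used as the reference implementation."""
    if len(a) != len(b):
        raise ValueError("dot: length mismatch")
    return sum(x * y for x, y in zip(a, b))


def run_op(op: str, dtype: str, a: List[int], b: List[int]) -> Tuple[int, str]:
    """Return (result, device_used), routing unsupported (op, dtype) pairs to CPU."""
    impls: Dict[str, Callable[[List[int], List[int]], int]] = {"dot": _dot}
    device = "cpu" if (op, dtype) in UNSUPPORTED_ON_TPU else "tpu"
    # In this toy both "devices" share one Python implementation; only the
    # routing decision is being illustrated, not real device execution.
    return impls[op](a, b), device
```

With this table, `run_op("dot", "int64", [1, 2, 3], [3, 4, 5])` is routed to CPU, while the same call with `"int32"` stays on the device; the user gets a result either way instead of a crash.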
@bhavya01 do you have cycles to help with this?
Looking at this now
Crashing on this operation is one thing. Falling back to CPU to avoid the crash seems reasonable to me in this case.
The way we crash is the real issue, IMO. Most of the time we're running through `xmp.spawn`, so when we hit a fatal error, only the child process dies (usually after dumping an error message) and the parent raises a real exception.
But in notebooks, if you run an operation that fails a status check, it crashes the whole notebook process and resets your state. That's a pretty bad user experience.
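The asymmetry can be reproduced with the standard library alone. In this sketch, `run_isolated` is a hypothetical helper, only loosely analogous to the parent side of `xmp.spawn` and not its implementation: the failing work runs in a child interpreter that calls `os.abort()`, standing in for a failed `XLA_CHECK_OK`, and the parent turns the non-zero exit status into an ordinary `RuntimeError`. When the same abort happens inside the notebook's own process, nothing is left to catch it.

```python
import subprocess
import sys

# Stand-in for a worker hitting a failed XLA_CHECK_OK: os.abort() kills the
# interpreter immediately, so no Python exception can be caught inside it.
ABORTING_WORKER = "import os; os.abort()"


def run_isolated(code: str) -> int:
    """Run `code` in a child interpreter; raise in the parent if the child dies.

    Hypothetical helper for illustration only.
    """
    result = subprocess.run([sys.executable, "-c", code], capture_output=True)
    if result.returncode != 0:
        # On POSIX a signal death shows up as a negative return code.
        raise RuntimeError(f"worker exited with status {result.returncode}")
    return result.returncode


if __name__ == "__main__":
    try:
        run_isolated(ABORTING_WORKER)
    except RuntimeError as exc:
        # The parent is still alive and can report the failure or retry.
        print(f"parent caught: {exc}")
```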
Hmm, do we know why?
I reproduced the error to get the line number where we terminate the process:
https://github.com/pytorch/xla/blob/6cc5c3819c09c7b1b4ca4927d7fa65133f95b41c/torch_xla/csrc/runtime/debug_macros.h#L20
So this has to do with our implementation of `XLA_CHECK_OK`. Ideally we want Python to raise an exception in this case instead of terminating the whole process. I'll dig in and see what `XLA_CHECK_OK` is actually doing.
Here's what I found:
- Abseil's `CHECK_OK` and friends terminate the program by design, because Google bans exceptions internally.
- Exceptions are what we want in this case, because pybind magic will turn them into a Python `RuntimeError`.
- If we blindly call `.value()` on an invalid `StatusOr`, we should get an exception :confetti_ball:
We've gotten lazy about calling `XLA_CHECK_OK` on everything, when we really should just be calling `StatusOr::value()` in most cases. We should only call `XLA_CHECK_OK` when an error is truly fatal and we want to kill the whole process.
This case should be an easy fix: compilation errors don't need to be fatal. We should also review our other uses of `XLA_CHECK_OK`, and transitively `ConsumeValue`, and decide what makes sense for each.
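The two error paths discussed above can be modeled in a few lines of Python. This is a toy under stated assumptions: `StatusOr`, `BadStatusOrAccess` and `check_ok` below are hypothetical Python stand-ins for `absl::StatusOr`, `absl::BadStatusOrAccess` (which `StatusOr::value()` throws when Abseil is built with exceptions enabled) and `XLA_CHECK_OK`, not the torch_xla source. The stand-in exception subclasses `RuntimeError` to mirror the pybind translation mentioned in the thread.

```python
import os
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


class BadStatusOrAccess(RuntimeError):
    """Toy stand-in for absl::BadStatusOrAccess as it would surface in Python."""


@dataclass
class StatusOr(Generic[T]):
    """Toy StatusOr: holds either a payload or an error message."""
    payload: Optional[T] = None
    error: Optional[str] = None

    def ok(self) -> bool:
        return self.error is None

    def value(self) -> T:
        # Recoverable path: like StatusOr::value() in an exception-enabled
        # build, an error status raises instead of killing the process.
        if not self.ok():
            raise BadStatusOrAccess(self.error)
        return self.payload


def check_ok(s: StatusOr[T]) -> T:
    # Fatal path: like CHECK_OK / XLA_CHECK_OK, an error status aborts the
    # whole process. Reserve this for errors that really are unrecoverable.
    if not s.ok():
        os.abort()
    return s.payload


if __name__ == "__main__":
    compiled = StatusOr(error="UNIMPLEMENTED: s64 dot")  # shortened stand-in message
    try:
        compiled.value()
    except RuntimeError as exc:
        print(f"recoverable: {exc}")  # the interpreter keeps running
```

Calling `value()` on the error status raises an exception the caller can handle, and the process survives; passing the same status to `check_ok` would abort, which is the notebook-killing behavior this issue is about.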
I'll wait for the new nightly release and confirm that this issue is fixed before closing.