PyTorch/XLA should not crash runtime during int64 dot product on TPU

Open sagelywizard opened this issue 11 months ago • 6 comments

Issue description

Currently, PyTorch/XLA will crash the entire runtime (ret check failure?) if you do an int64 dot product on a TPU. Seems like poor UX to me. It'd be preferable to raise a proper Python exception rather than crashing the whole runtime. This is particularly problematic for folks using PyTorch/XLA in a Jupyter notebook, because it'll crash the whole notebook.

Code example

>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> dev = xm.xla_device()
>>> t1 = torch.tensor([1, 2, 3], device=dev)
>>> t2 = torch.tensor([3, 4, 5], device=dev)
>>> t1 @ t2
Non-OK-status: status.status() status: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %dot.3 = s64[] dot(s64[3]{0} %p1.2, s64[3]{0} %p0.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
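
For reference, the operation that triggers the crash is an ordinary contraction over one axis; the expected result can be computed in plain Python (no torch_xla involved):

```python
# The s64 dot product that crashes on TPU is just sum(t1[i] * t2[i]).
# Computed in plain Python for reference, outside of torch_xla entirely.
t1 = [1, 2, 3]
t2 = [3, 4, 5]
dot = sum(a * b for a, b in zip(t1, t2))
print(dot)  # 26
```

As an untested workaround (an assumption, not something confirmed in this thread), casting both tensors to int32 before the matmul (e.g. `t1.int() @ t2.int()`), or computing the product on CPU tensors, may sidestep the unimplemented s64 path.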

System Info

  • reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.1.0

sagelywizard avatar Mar 08 '24 19:03 sagelywizard

Thanks @sagelywizard for raising the issue. I think what we could do is

  1. verify whether this is still an issue on nightly
  2. fall back to CPU if the int64 dot product is still not implemented.

@bhavya01 do you have cycles to help with this?
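
The CPU-fallback idea in step 2 could look roughly like the sketch below. This is a hypothetical helper, not actual torch_xla code; the error-string match and the function names are assumptions for illustration:

```python
# Hypothetical sketch of a CPU fallback: try the XLA op and, if the backend
# reports it as unimplemented, retry with a fallback callable instead of
# letting the whole process die. Not part of the torch_xla API.
def with_cpu_fallback(op, fallback):
    try:
        return op()
    except RuntimeError as e:
        if "UNIMPLEMENTED" in str(e):
            return fallback()
        raise

# Simulate the failing TPU op and its CPU fallback:
def tpu_op():
    raise RuntimeError("UNIMPLEMENTED: X64 dot rewriting not implemented")

def cpu_op():
    return sum(a * b for a, b in zip([1, 2, 3], [3, 4, 5]))

print(with_cpu_fallback(tpu_op, cpu_op))  # 26
```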

JackCaoG avatar Mar 08 '24 19:03 JackCaoG

Looking at this now

bhavya01 avatar Mar 08 '24 21:03 bhavya01

Crashing on this operation is one thing, and falling back to CPU to avoid the crash seems reasonable to me in this case.

The way we crash is the real issue, IMO. Most of the time we're running through xmp.spawn, so when we hit a fatal error, only the child process dies (usually after dumping an error message) and the parent raises a real exception.

But in notebooks, if you run an operation that fails a status check, it crashes the whole notebook process and resets your state. That's a pretty bad user experience.
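
The process model described above can be mimicked with the stdlib (a sketch using plain multiprocessing, not xmp.spawn itself, which adds XLA-specific setup): the fatal error kills only the child, and the parent survives with an exit code it can turn into an exception.

```python
# Sketch of why xmp.spawn-style crashes are survivable: the fatal error kills
# only the child process; the parent stays alive and sees a nonzero exit code.
# In a notebook there is no such parent/child split, so the crash takes down
# the interpreter itself.
import multiprocessing as mp
import os

def worker():
    os._exit(1)  # stand-in for a fatal status-check failure in the child

ctx = mp.get_context("fork")  # fork used here for a self-contained sketch
p = ctx.Process(target=worker)
p.start()
p.join()
if p.exitcode != 0:
    print("child crashed with exit code", p.exitcode)
```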

will-cromar avatar Mar 08 '24 22:03 will-cromar

hmm do we know why?

JackCaoG avatar Mar 08 '24 22:03 JackCaoG

I reproduced the error to get the line number where we terminate the process:

https://github.com/pytorch/xla/blob/6cc5c3819c09c7b1b4ca4927d7fa65133f95b41c/torch_xla/csrc/runtime/debug_macros.h#L20

So this has to do with our implementation of XLA_CHECK_OK. Ideally we want to make Python throw an exception in this case instead of terminating the whole process. I'll dig in and see what XLA_CHECK_OK is actually doing.

will-cromar avatar Mar 08 '24 22:03 will-cromar

Here's what I found:

We've gotten lazy with calling XLA_CHECK_OK on everything, when we should really be calling StatusOr::value() in most cases. We should only call XLA_CHECK_OK when an error is truly fatal and we want to kill the whole process.

This case should be an easy fix -- compilation errors don't need to be fatal. We should review our other uses of XLA_CHECK_OK (and, transitively, ConsumeValue) and decide what makes sense.
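
The distinction can be illustrated with a Python analogy (class and function names here are made up for the sketch, not the actual C++ types): a check-ok-style macro aborts the whole process, while a value() accessor raises an error the caller can catch and translate into a Python exception.

```python
# Illustrative Python analogy of the C++ StatusOr pattern discussed above.
# In C++, XLA_CHECK_OK on a bad status terminates the process; calling
# StatusOr::value() instead can surface a recoverable error.
import os

class StatusOr:
    def __init__(self, value=None, error=None):
        self._value, self._error = value, error

    def ok(self):
        return self._error is None

    def value(self):
        # Recoverable path: raise an exception the caller (e.g. the Python
        # bindings) can translate, instead of killing the interpreter.
        if not self.ok():
            raise RuntimeError(self._error)
        return self._value

def check_ok(status_or):
    # Fatal path, analogous to XLA_CHECK_OK: aborts the whole process.
    if not status_or.ok():
        os.abort()

result = StatusOr(error="UNIMPLEMENTED: s64 dot")
try:
    result.value()  # raises; a notebook would survive this
except RuntimeError as e:
    print("caught:", e)
```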

will-cromar avatar Mar 08 '24 22:03 will-cromar

I'll wait for the new nightly release and confirm that this issue is fixed before closing.

bhavya01 avatar Mar 19 '24 17:03 bhavya01