Introduce CUDA OpenXLA fallback.

Open ysiraichi opened this issue 1 year ago • 13 comments

This PR introduces an OpenXLA fallback to PyTorch GPU eager mode. Instead of always running fallback operations (i.e. operations that have no XLA lowering implemented) on CPU, we now make it possible to run them on GPU. This is especially useful when using XLA:CUDA devices.

In summary, this PR introduces the following changes:

  • Rename xla_cpu_fallback to xla_fallback
    • Updates every call site that manually invokes the fallback
  • Implement a cuda_fallback function
    • A version of at::native::cpu_fallback, but with a few changes (called out before each function)
    • Ideally, the at::native::cpu_fallback implementation would be generalized inside PyTorch instead
  • Add an XLA_FALLBACK_CUDA flag for enabling this feature (see the sketch after this list)
  • Add tests for fallback operations that show up in torchbench
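
As a rough illustration, here is a minimal sketch of how the flag might be used. This is not code from the PR; it assumes the flag is enabled with the value 1, that the process targets an XLA:CUDA device (e.g. via PJRT_DEVICE=CUDA), and that torch.randn stands in for any workload hitting an op without an XLA lowering:

import os

# Assumption: the flag is read at startup, so set it before importing torch_xla.
os.environ["XLA_FALLBACK_CUDA"] = "1"  # run fallback ops on CUDA instead of CPU

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # an XLA:CUDA device, assuming PJRT_DEVICE=CUDA
x = torch.randn(4, 4, device=device)
# Any operation on x that has no XLA lowering is now dispatched to the CUDA
# eager kernel rather than the CPU one.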

cc @miladm @JackCaoG @vanbasten23

ysiraichi avatar Jun 20 '24 13:06 ysiraichi

I'm still running torchbench. Will report back when it is over.

ysiraichi avatar Jun 20 '24 13:06 ysiraichi

The CI error is a bit tricky to solve.

Problem: I'm using some CUDA functions defined inside PyTorch, which requires linking libc10_cuda.so into the test binaries. However, since PyTorch isn't compiled with CUDA support in CI, that won't work.

While I could guard that code with C++ macros (e.g. an XLA_CUDA definition), that would mean we never compile it in CI, since in that specific CI action PyTorch/XLA is compiled without that flag set and PyTorch is compiled without CUDA support.

Possible Solution: create a phony implementation of the CUDA functions I'm using and compile it into a separate library. Then, if libc10_cuda.so is not found, we link against this other library instead.

Notice that this is only needed for the test binaries.

@JackCaoG @vanbasten23 @lezcano What do you think?

ysiraichi avatar Jun 21 '24 14:06 ysiraichi

We could also always compile PyTorch with CUDA support in CI.

lezcano avatar Jun 21 '24 15:06 lezcano

If it's only the test binaries that require PyTorch built with CUDA, there is a way to achieve that. In our CI, there is a workflow that builds PyTorch with CUDA, builds torch_xla with CUDA, and runs only the tests that require PyTorch with CUDA. You can add your tests to https://github.com/pytorch/xla/blob/bb27cb2a49ee1413cb824af5c9a5e9bacabd0b04/.github/workflows/_test_requiring_torch_cuda.yml#L104.

vanbasten23 avatar Jun 21 '24 19:06 vanbasten23

Regarding problem 1 ("Problem 1: C++ test binaries need all references to be resolved"), you mentioned the "Solution: Create a fallback implementation of the CUDA functions". Could you point me to where the fallback implementation of the CUDA functions is?

vanbasten23 avatar Jun 24 '24 17:06 vanbasten23

@zpcore to update the XLA:GPU benchmarking to adopt the CUDA fallback setting once this PR lands.

cc @will-cromar for visibility re: comment https://github.com/pytorch/xla/pull/7318#issuecomment-2187034843

miladm avatar Jun 24 '24 17:06 miladm

@will-cromar I'm having a hard time figuring out how to make this PR work with CI. Specifically: compiling and running the fallback operations test (in test_ops.py).

Context: I'm calling a few PyTorch CUDA functions inside a function in aten_cpu_fallback.cpp. The implementation of these functions lives in libc10_cuda.so.

Problem: in the CI action where we compile PyTorch/XLA, both PyTorch and PyTorch/XLA are actually compiled without CUDA support. In other words, libc10_cuda.so is never created.

  • When we try to import torch_xla in that same CI action, it fails because _XLAC has undefined references to the CUDA functions
  • We can't conditionally compile (i.e. #ifdef XLA_CUDA) the CUDA functions, since that would mean the CUDA OpenXLA fallback never gets compiled in CI

Proposed Solution: have two libraries: _XLAC_cpu (no CUDA OpenXLA fallback) and _XLAC_cuda.

  • Conditionally import either of them, depending on whether PyTorch was compiled with support for CUDA
  • Create an alias like so: import _XLAC_cpu as _XLAC for backwards compatibility (a sketch of this follows below)
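
A minimal sketch of what that conditional import could look like, assuming the two module names above and using torch.cuda.is_available() as a proxy for "PyTorch was built with CUDA" (this two-library split is only the proposal here; it is not what ultimately landed later in the thread):

import torch

if torch.cuda.is_available():
    import _XLAC_cuda as _XLAC  # variant built with the CUDA OpenXLA fallback
else:
    import _XLAC_cpu as _XLAC   # variant built without it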

I know this is not a pretty solution, so do you have any suggestions?

ysiraichi avatar Jun 24 '24 23:06 ysiraichi

Hey @ysiraichi, I'll spend some more time going over this PR tomorrow to try to understand it better.

We were just preparing to remove the separate GPU variant of the main torch_xla package by moving the GPU runtime implementation to a PJRT plugin. PyPI doesn't support any sort of platform tag that would let us release separate stable TPU and GPU variants of the main package. We need to figure out how to build one variant of the torch_xla package so everyone can just pip install torch_xla.

Realistically, most of the team that builds from source does so on TPUs, so it's a nice convenience not to have to build the CUDA version of PyTorch first. Adding the CUDA torch build to the critical path of the CI would obviously be a significant overhead as well. But if we can use a pre-built PyTorch package somehow, I don't actually mind using the regular CUDA torch package as a build dependency, since my main concern is how slow that build is. cc @JackCaoG since we've talked about this possibility a few times but never had a good enough reason to add the option.

I don't fully understand after skimming the PR why we need libc10_cuda at build time. Can that be dynamically loaded as needed?

will-cromar avatar Jun 25 '24 00:06 will-cromar

I don't fully understand after skimming the PR why we need libc10_cuda at build time. Can that be dynamically loaded as needed?

It can be loaded at runtime. However, it can't be loaded conditionally. At least, not like this.

Loading conditionally ("as needed") was, in fact, the solution that I was proposing. We could have a separate library with a phony implementation of these CUDA functions. Then, import it only if we are in an environment where PyTorch has no CUDA support.

Let me put together a first implementation. We can remove it if that's not what we want.

ysiraichi avatar Jun 25 '24 14:06 ysiraichi

I have been working on this for a while now, trying a bunch of things. Unfortunately, none of them worked. Here's the current state of things:

What I tried:

  • Created a new Python library _XLAC_cuda_functions.so that holds definitions for the c10::cuda functions I need
    • Idea: provide definitions for those c10::cuda functions whenever PyTorch doesn't have CUDA support
  • Modified torch_xla/__init__.py so that we import this library if not torch.cuda.is_available()
    • If CUDA is available, we rely on import torch to load libc10_cuda.so, which provides the definitions of the c10::cuda functions

What is happening:

  • Even though I'm able to import the new _XLAC_cuda_functions.so library, I'm still getting undefined references to c10::cuda functions when import torch_xla is called

I'm not sure why this is not working given that:

$ nm -CD _XLAC.cpython-310-x86_64-linux-gnu.so | grep c10::cuda
                 U c10::cuda::set_device(signed char)
                 U c10::cuda::current_device()
                 U c10::cuda::device_synchronize()

$ nm -CD _XLAC_cuda_functions.cpython-310-x86_64-linux-gnu.so | grep c10::cuda
000000000002c0b1 T c10::cuda::set_device(signed char)
000000000002c09e T c10::cuda::current_device()
000000000002c0cd T c10::cuda::device_synchronize()
# This works!
$ LD_PRELOAD=./_XLAC_cuda_functions.cpython-310-x86_64-linux-gnu.so python -c "import torch_xla"

# This doesn't work...
$ python -c "import _XLAC_cuda_functions; import torch_xla"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "xla/torch_xla/__init__.py", line 11, in <module>
    import _XLAC
ImportError: xla/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda14current_deviceEv

@JackCaoG @vanbasten23 @lezcano @will-cromar Any thoughts?

ysiraichi avatar Jun 27 '24 13:06 ysiraichi

Python imports extension libraries with RTLD_LOCAL, which means the symbols from _XLAC_cuda_functions are not added to the global symbol table. You need to set RTLD_GLOBAL before importing _XLAC_cuda_functions:

import sys, os
prev = sys.getdlopenflags()
sys.setdlopenflags(prev | os.RTLD_GLOBAL)
import _XLAC_cuda_functions
sys.setdlopenflags(prev)

import torch_xla

isuruf avatar Jun 27 '24 14:06 isuruf

@vanbasten23 @JackCaoG @will-cromar I think I finally found a solution to the CUDA functions problem, thanks to @isuruf. Here's a summary of the recent changes in this PR:

  • aten_cuda_functions.cpp implements the CUDA functions used by aten_cpu_fallback.cpp that, unless PyTorch is compiled with CUDA, aren't defined anywhere
  • In addition to _XLAC.so, we build _XLAC_cuda_functions.so, which is composed only of the aforementioned aten_cuda_functions.cpp
  • Inside torch_xla/__init__.py, we conditionally load the newly added _XLAC_cuda_functions.so library, just after importing torch
    • We only load it if not torch.cuda.is_available(), fulfilling the references in _XLAC.so
    • Note: we load it using the os.RTLD_GLOBAL flag, which makes the library's symbols globally visible

With these modifications, we are able to load the dummy implementation of the CUDA functions conditionally. If we are using a PyTorch build with CUDA support, we don't load it and instead rely on libtorch_python.so, which depends on libc10_cuda.so.
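
For reference, here is a minimal sketch of how that conditional, global loading can look in torch_xla/__init__.py; this is an illustration based on the description above and on @isuruf's suggestion, not the exact code in the PR:

import os
import sys

import torch

if not torch.cuda.is_available():
    # Load the stub library with RTLD_GLOBAL so its c10::cuda symbols become
    # globally visible and can satisfy the undefined references in _XLAC.so.
    prev = sys.getdlopenflags()
    sys.setdlopenflags(prev | os.RTLD_GLOBAL)
    import _XLAC_cuda_functions
    sys.setdlopenflags(prev)

import _XLAC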

Let me know what you think. Whenever you have some time, could you give this PR another look?

ysiraichi avatar Jun 27 '24 17:06 ysiraichi

Mostly LGTM with minor comments.

Amazing work!

vanbasten23 avatar Jul 02 '24 02:07 vanbasten23

Running TorchBench with the --verify flag showed no new accuracy problems on inference (the flag doesn't work with training). Therefore, I will go ahead and merge this PR.

ysiraichi avatar Jul 03 '24 21:07 ysiraichi