
RFC: Evolving PyTorch/XLA for a more native experience on TPU

Open qcc4cp opened this issue 2 months ago • 16 comments

Motivation

For many years, torch_xla has been the primary way for the community to run PyTorch programs on Cloud TPUs. It has successfully enabled the training of massive models by bringing the power of the XLA compiler to the PyTorch ecosystem.

The current implementation, while powerful, presents a developer experience that can sometimes feel distinct from "native" PyTorch. The reliance on a lazy tensor model and explicit graph tracing (xm.mark_step) creates a separation from PyTorch's eager-first philosophy. This can introduce challenges in debugging, complicates integration with the broader PyTorch ecosystem, and requires users to learn a torch_xla-specific set of APIs and concepts.

We believe we can deliver a more seamless and native experience for PyTorch users on TPUs. The goal is to provide the best of both worlds: the interactive, flexible development experience of PyTorch's eager mode and the world-class performance of the XLA compiler for scaled-out workloads.


Proposal: A Native TPU Backend

We propose a TPU backend for PyTorch that is designed to align with modern PyTorch architecture and eager-first design. The goal is to make a "native" device in PyTorch, where tensor.to('tpu') feels just as natural and intuitive as tensor.to('cuda'). This new direction aims to fully embrace PyTorch's eager mode while still leveraging the powerful XLA compiler for performance-critical code paths.

The core principles of this new stack are:

  1. XLA: As with torch_xla, this proposal assumes we continue to rely on XLA as the underlying compiler infrastructure. However, we would call it in a profoundly different way, one that enables new techniques and a better user experience. Note that on TPU, compilation is required for the best performance, but it should be possible to hide the compile times.
  2. Eager Mode with Deferred Execution: As in standard PyTorch eager mode, ops are dispatched one at a time. The new stack can then choose to compile and execute individual ops, shorter or longer sequences of ops, or potential candidates for fusion clusters, all the way up to a full compile of a forward or backward pass (a minimal sketch of this idea follows the list).
    Compilation would happen asynchronously, which means compilation of graphs and their execution could overlap, and compilation results would be cached. We would work with the XLA team to further reduce overall compile-time overhead with techniques such as persistent deduping and by limiting inlining and unrolling. As a result, the compile-time overhead would be drastically reduced even for larger incrementally compiled graphs.
  3. JIT: This approach would enable a true just-in-time compilation engine with recompilation, feedback-directed optimizations, autotuning, and active memory management to avoid OOMs. With this, users would get the eager experience but with compiled performance after just a few inferences or training steps.
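
As a purely illustrative sketch of the deferred-execution idea in item 2 (not the real backend, just the shape of the mechanism): ops are recorded instead of run, and the accumulated graph is compiled and executed only when a value is actually needed.

import torch

# Illustrative only: a toy "deferred tensor" that records ops as thunks and
# runs them on materialization. The real stack would hand the recorded graph
# to XLA for cached, asynchronous compilation instead.
class DeferredTensor:
    def __init__(self, thunk):
        self._thunk = thunk    # produces the real torch.Tensor on demand
        self._value = None     # cached result after materialization

    def __add__(self, other):
        # Record the op; nothing executes yet.
        return DeferredTensor(lambda: self.materialize() + other.materialize())

    def __pow__(self, exp):
        return DeferredTensor(lambda: self.materialize() ** exp)

    def materialize(self):
        # This is the point where compilation and execution would happen.
        if self._value is None:
            self._value = self._thunk()
        return self._value

def defer(t):
    return DeferredTensor(lambda: t)

a = defer(torch.tensor([1.0, 2.0]))
b = defer(torch.tensor([3.0, 4.0]))
c = (a + b) ** 2         # no computation has happened yet
print(c.materialize())   # the whole expression runs only here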

With these principles in mind, we could deliver on the following features:

  1. Eager Execution by Default: As described above, operations will appear to execute eagerly, just as they do on CPU or GPU, even though they are being compiled in the background with minimal, and mostly hidden, compile-time overhead. This would provide a familiar, intuitive, and much easier-to-debug workflow where users can inspect tensors and use standard Python tooling.
  2. Integration with torch.compile: To maximize performance, TPU would integrate as a first-class backend for torch.compile. This would allow users to get the performance benefits of XLA compilation and TPUs at scale on their performance-critical code with a simple @torch.compile decorator (see the sketch after this list).
  3. Distributed Training via DTensor: The new backend would natively support PyTorch's distributed APIs. This would allow users to leverage advanced, large-scale distributed training strategies like Fully Sharded Data Parallel (FSDP) and other model parallelism techniques out of the box, making it much simpler to scale up models.
  4. A More "PyTorch Native" Feel: The end goal is to abstract away the complexities of the underlying compiler. Developing for a TPU should not require a fundamentally different programming model. This means moving away from torch_xla-specific APIs and toward the standard PyTorch API surface, so that users keep the interactive, flexible development experience of eager mode while getting the performance of the XLA compiler for scaled-out workloads.
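
To make the intended experience concrete, here is a hedged sketch of what user code could look like under this proposal. The "tpu" device string and the torch.compile behavior on TPU shown here are what we are proposing, not a shipped API, so this does not run today.

import torch
import torch.nn as nn

# Sketch of the proposed workflow; "tpu" is the device string we envision,
# intended to feel just like "cuda".
device = torch.device("tpu")

model = nn.Linear(1024, 1024).to(device)
x = torch.randn(8, 1024, device=device)   # eager by default

y = model(x)                              # appears eager; compiled and cached
loss = y.sum()                            # behind the scenes
loss.backward()

fast_model = torch.compile(model)         # explicit opt-in for peak performance
y_fast = fast_model(x)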

We Want Your Feedback!

We're excited for this direction, and to bring together PyTorch's eager mode and the XLA compiler in a way that helps the community achieve new levels of performance and scale. This is a significant undertaking, and we want to build it with the community. We're open to feedback on this direction.

  • Does this proposal address the pain points you've experienced with torch_xla?
  • Are there specific workflows or PyTorch features whose support is critical for your work?
  • What would be the most important factors for you when considering a migration from torch_xla to this new stack or from PyTorch on GPU?

Thank you for being a part of the PyTorch/XLA community. We're excited to build this next chapter with you.

qcc4cp avatar Oct 20 '25 22:10 qcc4cp

This would be awesome and the natural way to use TPU hardware with PyTorch.

tchaton avatar Oct 23 '25 09:10 tchaton

Excited to see this direction!

jigarsavla avatar Oct 23 '25 23:10 jigarsavla

This is super cool!

bdubey avatar Oct 24 '25 03:10 bdubey

Super excited about it!

WilliamZhang20 avatar Oct 26 '25 22:10 WilliamZhang20

It's too cooool

Ffffffffchopin avatar Oct 27 '25 16:10 Ffffffffchopin

Pretty cool idea Robert. Great to see. Would have been even better if started 7 years back IIRC:-).

rama-govindaraju avatar Oct 29 '25 14:10 rama-govindaraju

Very excited for this! Was wondering what the plan would be to support SPMD flows? Currently we use Torch-XLA's xs.mark_sharding() APIs to shard input tensors, and later use Shardy propagation to solve for activation and output tensor sharding within the MLIR graph. Can we expect something similar with this new backend?
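
For reference, this is roughly the pattern we rely on today in Torch-XLA's SPMD mode (the mesh shape and axis names below are just illustrative):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Build a device mesh and annotate the input; Shardy then propagates the
# sharding to activations and outputs inside the compiled graph.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ("data", "model"))

x = torch.randn(8, 128).to(xm.xla_device())
xs.mark_sharding(x, mesh, ("data", None))   # shard the batch dim across "data"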

hshahTT avatar Nov 05 '25 17:11 hshahTT

We're adopting a more PyTorch-native approach and will not be using mark_sharding. Initially we're focused on PyTorch's native distributed APIs (torch.distributed). Once our focus shifts to peak distributed performance on TPU, SPMD-style sharding annotations are something we intend to explore.

cjonesy20 avatar Nov 05 '25 17:11 cjonesy20

Gotcha, thx for the response @cjonesy20. So I guess SPMD isn't going to be supported in the first release of TorchTPU (since IIRC torch.distributed doesn't explicitly support SPMD)? Also, is there a high-level document available that will detail how the torch.distributed APIs will work with the OpenXLA compiler stack?

hshahTT avatar Nov 05 '25 17:11 hshahTT

Correct, SPMD is not planned for the first release. We're implementing the core collectives (scatter, gather, broadcast, etc.) and then using those to implement specific APIs like torch.distributed.fsdp. From a user perspective, the torch.distributed APIs should work just like they do on GPU.
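
For illustration, the kind of user code we have in mind is the standard torch.distributed pattern below. The "tpu" device string is the naming we're proposing, not something that exists today, so treat this as a sketch.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def train_step():
    # Standard PyTorch distributed setup; nothing TPU-specific in user code.
    dist.init_process_group()                      # backend chosen by the runtime
    model = FSDP(nn.Linear(4096, 4096).to("tpu"))  # hypothetical "tpu" device
    x = torch.randn(8, 4096, device="tpu")
    model(x).sum().backward()                      # collectives issued underneath
    dist.destroy_process_group()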

cjonesy20 avatar Nov 06 '25 21:11 cjonesy20

We're open to feedback on this direction.

Overall this looks like a good direction FWIW.

  • Does this proposal address the pain points you've experienced with torch_xla?

My outside perspective is that the PyTorch community seems torn between two often-conflicting use cases:

  1. those who want to design new models and highly value PyTorch's imperative design in eager mode; they typically use GPU(s), for which eager mode is not optimal but good enough;
  2. and those who want to squeeze out maximum performance, be it on GPUs or ASICs, and therefore need the whole computation to be broken up into a fixed number of computational graphs.

Due to the prevalence of GPUs, the first camp is the larger of the two.

torch_xla was clearly aimed at the latter camp. torch.compile looked like a way to have your cake and eat it too, but in practice a big refactoring is needed to avoid performance-killing graph recompilations (along with a comprehensive mental model of how TorchDynamo and torch_xla work under the hood).
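
As a small, concrete sketch of the kind of trap I mean (marking shapes dynamic is one mitigation, not a silver bullet):

import torch

@torch.compile(dynamic=True)   # symbolic shapes to curb shape-driven recompiles
def step(x):
    return torch.relu(x) * 2

step(torch.randn(8, 128))      # first call triggers compilation
step(torch.randn(16, 128))     # without dynamic shapes, a new batch size
                               # would force another compilation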

Making eager mode faster means that the performance cliff is not as steep and that migrating code from the first camp to the second can be done more gradually.

But this is more about the path taken than the end destination: AFAICT, if one wants optimum performance, be it with GPUs or ASICs, the code needs to be refactored to avoid eager.

  • Are there specific workflows or PyTorch features whose support is critical for your work?

Being able to achieve good performance on vLLM inference -> PyTorch -> XLA -> GPUs/ASICs

  • What would be the most important factors for you when considering a migration from torch_xla to this new stack or from PyTorch on GPU?
  • that torch.compile with this new stack continues to work just as well as it did with torch_xla
  • that all the time needed to pre-compile the graphs required to execute eager operations can be hidden (through pre-warmed offline caches) or opted out of
  • that metrics are provided to know when eager operations were executed (and therefore when there are opportunities to optimize by fusing those operations into torch.compile graphs); see the sketch below
  • Pallas -> GPU support (Mosaic or not), or first-class support for Triton kernels
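
For reference, the kind of visibility I mean already exists in today's torch_xla as a metrics/counters report; a minimal sketch (assuming the new stack exposes something comparable):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t = torch.randn(4, 4, device=xm.xla_device())
print((t @ t).sum().item())   # forces compilation and execution
print(met.metrics_report())   # compile/execute counters, aten:: fallbacks, etc.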

Thanks.

jrfonseca avatar Dec 02 '25 11:12 jrfonseca

A key advantage of PyTorch/XLA has always been the performance gains achieved by optimizing the computation graph as a whole. With Eager Mode becoming the default alongside true just-in-time compilation, I am curious how the new stack achieves performance parity with the current implementation. Is there an existing example using JIT compilation that I could review?

I strongly agree with Jose that there are two distinct user groups with fundamentally different needs. Researchers designing new models benefit from the simplicity of Eager Mode, while engineers focusing on fine-tuning or inference prioritize performance and require low-level visibility to maximize hardware utilization.

While this proposal clearly benefits the research camp, I don't believe the user experience is the primary factor deterring people from using TPUs—accessibility is. GPUs are simply much easier for individual and institutional researchers to access. Even if PyTorch/XLA matches CUDA's simplicity via Eager Mode, researchers may still default to GPUs due to their prevalence. Is enabling eager execution a positive step? Absolutely—provided it does not alienate the second group (who likely make up the majority of current PyTorch/XLA users) by significantly increasing complexity or introducing too much "magic."

iwknow avatar Dec 04 '25 07:12 iwknow

  • What would be the most important factors for you when considering a migration from torch_xla to this new stack or from PyTorch on GPU?
  • that torch.compile with this new stack continues to work just as well as it did with torch_xla

Our goal is to mirror as closely as possible how torch.compile works on GPU.

  • that all the time needed to pre-compile the graphs required to execute eager operations can be hidden (through pre-warmed offline caches) or opted out of

Pre-warmed offline caches are something we were discussing yesterday. It's a reasonable ask.

  • that metrics are provided to know when eager operations were executed (and therefore when there are opportunities to optimize by fusing those operations into torch.compile graphs)

Deferred execution in eager mode means that if you don't materialize a tensor, its computation is deferred, and we send the largest graph possible to XLA to maximize our opportunity for things like fusion.

  • Pallas -> GPU support (Mosaic or not), or first-class support for Triton kernels

As with torch_xla, Pallas will be fully supported through Mosaic. Separately, we're investigating how best to help users migrate CUDA and Triton GPU kernels to Pallas.

Thanks.

chrishjones20 avatar Dec 04 '25 19:12 chrishjones20

A key advantage of PyTorch/XLA has always been the performance gains achieved by optimizing the computation graph as a whole. With Eager Mode becoming the default alongside true just-in-time compilation, I am curious how the new stack achieves performance parity with the current implementation. Is there an existing example using JIT compilation that I could review?

I believe that key advantage remains: if you don't materialize any tensors in eager mode, in most cases the stack will defer everything and send an entire computational graph to XLA. In effect you're getting an implicit compilation, versus torch.compile (through Dynamo) being an explicit full-graph compilation.

chrishjones20 avatar Dec 04 '25 19:12 chrishjones20

if you don't materialize any tensors in eager mode, in most cases the stack will defer everything and send an entire computational graph to XLA.

How can I avoid materializing any tensors in eager mode? Any operation will result in a compilation and an execution in eager mode, right? Also, how does it optimize the entire computation graph? My understanding is that eager mode will break the entire computation graph into smaller graphs. Do you have more details about the mechanism for the "graph fusion"?

iwknow avatar Dec 04 '25 21:12 iwknow

if you don't materialize any tensors in eager mode, in most cases the stack will defer everything and send an entire computational graph to XLA.

How can I avoid materializing any tensors in eager mode? Any operation will result in a compilation and an execution in eager mode, right? Also, how does it optimize the entire computation graph? My understanding is that eager mode will break the entire computation graph into smaller graphs. Do you have more details about the mechanism for the "graph fusion"?

In PyTorch on CUDA or CPU today if you do:

import torch

A = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
B = torch.tensor([[5.0, 6.0],
                  [7.0, 8.0]])
C = A + B
# print(C) would cause a materialization and break the graph
D_reshaped = C.view(-1)
E_squared = D_reshaped ** 2
print(E_squared)

It will immediately execute each line as soon as the Python interpreter reaches it. With deferred execution, it will wait until we materialize the final output (E_squared). This allows us to maximize the opportunity for fusions. However, if you put a breakpoint or print on, for example, C, that would take the graph for A, B, and C and send it to the compiler. Then D_reshaped and E_squared would go into a second graph sent to the compiler on the last line. cc: @qcc4cp, fact-check me here if I have this wrong.

chrishjones20 avatar Dec 04 '25 22:12 chrishjones20

What's the evolution plan of Torch-XLA/TorchAX after this RFC? Will Torch-XLA be deprecated later?

I ask this because Torch-XLA provides the PJRT API and can be extended with custom SW/HW. Will the PJRT API exist in the new native TPU backend, or become something else?

Zantares avatar Dec 19 '25 03:12 Zantares

TorchTPU will use the PJRT API, similar to PyTorch/XLA and JAX.

chrishjones20 avatar Dec 19 '25 03:12 chrishjones20