
Support of AOT compilation (refine #6992)

zpcore opened this issue 8 months ago · 0 comments

This is a follow-up PR that refines https://github.com/pytorch/xla/pull/6992.

In this PR, I created PjRtCompilationClient to serve ahead-of-time (AOT) compilation. With it, we no longer need to create CompileOnlyPjRtClient, CompileOnlyPjRtDevice, etc., which makes OpenXLA pin updates easier.

The instructions for running AOT compilation have been updated. When running on a CPU host, two extra environment variables must be set: PJRT_DEVICE and XLA_PERSISTENT_CACHE_PATH. Note that PJRT_DEVICE is set to TPU even though the script runs on a CPU host, so that the graph is compiled for the virtual TPU topology defined below:
----------------------- ON CPU -----------------------

PJRT_DEVICE=TPU XLA_PERSISTENT_CACHE_PATH=./ python aot_encode.py

aot_encode.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._xla_set_virtual_topology("v4:2x2x1")  # <----- Define the virtual topology here (TPU v4, 2x2x1 mesh). Must be specified before any device declaration.

a = torch.rand([2,3])
b = torch.rand([2,3])
device = xm.xla_device()
a = a.to(device)
b = b.to(device)
f = torch.hstack([a,b])

torch_xla._XLAC._xla_warm_up_cache([f],[])  # <----- Call this to compile and cache the graph without running the real computation.

This will generate a hash-named cache file such as 229013763457648799216243727807636414712, which can be deserialized by running the same graph code on a TPU:
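
For example, a quick way to confirm that the cache file was written (a sketch, assuming XLA_PERSISTENT_CACHE_PATH is the current directory, as in the command above):

import os

# List entries in the persistent cache directory (./ here) whose names
# are long decimal hashes, i.e. the serialized compilation artifacts.
print([name for name in os.listdir('.') if name.isdigit()])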
----------------------- ON TPU v4-8 -----------------------

PJRT_DEVICE=TPU XLA_PERSISTENT_CACHE_PATH=./ python aot_decode.py

aot_decode.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.rand([2,3])
b = torch.rand([2,3])
device = xm.xla_device()
a = a.to(device)
b = b.to(device)
f = torch.hstack([a,b])
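
# With XLA_PERSISTENT_CACHE_PATH pointing at the cache written by
# aot_encode.py, the precompiled executable is deserialized from the
# cache instead of being recompiled on the TPU.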

print(f)
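
As an optional sanity check (not part of the instructions above), the metrics report can be printed after the graph runs; recent torch_xla builds include persistent-cache activity there, though the exact counter names vary by version:

import torch_xla.debug.metrics as met

# Sketch: dump the full metrics report and look for persistent-cache
# counters (counter names vary across torch_xla versions).
print(met.metrics_report())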

zpcore · Jun 26 '24 20:06