Support for AOT compilation (refine #6992)
This is a follow-up PR that refines https://github.com/pytorch/xla/pull/6992. In this PR, I created a PjRtCompilationClient to serve ahead-of-time (AOT) compilation. This way we no longer need to create CompileOnlyPjRtClient, CompileOnlyPjRtDevice, etc., which makes OpenXLA pin updates easier.
The instructions on how to run AOT compilation have been updated. We need to specify two extra environment variables when running on a CPU-only host: PJRT_DEVICE=TPU (the compilation target) and XLA_PERSISTENT_CACHE_PATH (where the serialized compilation cache is written), as follows:
----------------------- ON CPU -----------------------
PJRT_DEVICE=TPU XLA_PERSISTENT_CACHE_PATH=./ python aot_encode.py
aot_encode.py:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Define the virtual TPU topology (v4, 2x2x1 chip mesh) here.
# This must be specified before any device declaration.
torch_xla._XLAC._xla_set_virtual_topology("v4:2x2x1")

a = torch.rand([2, 3])
b = torch.rand([2, 3])
device = xm.xla_device()
a = a.to(device)
b = b.to(device)
f = torch.hstack([a, b])

# Compile and serialize the graph without running the real computation.
torch_xla._XLAC._xla_warm_up_cache([f], [])
This will generate a cache file whose name is the graph hash, e.g. 229013763457648799216243727807636414712, which can then be deserialized by running the same graph code on a TPU:
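To sanity-check the encode step, one can list the cache directory after running the script; a minimal sketch, assuming XLA_PERSISTENT_CACHE_PATH=./ as in the command above:

import os

# The persistent cache directory should now contain a file named by the
# graph hash (e.g. the long number shown above).
for name in os.listdir("./"):
    print(name)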
----------------------- ON TPU v4-8 -----------------------
PJRT_DEVICE=TPU XLA_PERSISTENT_CACHE_PATH=./ python aot_decode.py
aot_decode.py:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.rand([2, 3])
b = torch.rand([2, 3])
device = xm.xla_device()
a = a.to(device)
b = b.to(device)
f = torch.hstack([a, b])
print(f)
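To confirm that the TPU run loaded the serialized executable rather than recompiling the graph, the runtime counters can be inspected; a minimal sketch, assuming the PersistentCacheHit counter exposed by the persistent caching feature:

import torch_xla.debug.metrics as met

# A non-zero value indicates the compiled executable was deserialized from
# XLA_PERSISTENT_CACHE_PATH instead of being recompiled on the TPU.
print(met.counter_value("PersistentCacheHit"))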