
GSPMD + PyTorch Compile + TPU crash

Open · agemagician opened this issue 1 year ago · 4 comments

Hi,

I am trying to combine GSPMD with PyTorch Compile (torch.compile), but it doesn't work. I copied the test script "test_train_spmd_imagenet.py" and tested it in Colab, and it started normally. However, after I added the compile line:

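```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# get_model_property is defined in test_train_spmd_imagenet.py
model = get_model_property('model_fn')().to(device)

# The added line that triggers the crash
model = torch.compile(model, backend='aot_torchxla_trace_once')
```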

It crashed.

Here is a Colab notebook that reproduces the crash: https://colab.research.google.com/drive/1KNcBydAfZXLATpSo-CXILxtHJkK8JD-2?usp=sharing

agemagician, Mar 28 '23 04:03