
Use PjRt GPU client

Open jonb377 opened this issue 2 years ago • 5 comments

  • Add support for single-host single-GPU
  • Add a simple unit test

jonb377 avatar Sep 23 '22 18:09 jonb377

@will-cromar Yep! I was able to run the MNIST example to convergence using the XLA device, and verified GPU utilization with nvidia-smi.

Currently looking into adding a CI test, I'll update the PR once that's ready.

jonb377 avatar Sep 23 '22 21:09 jonb377

A rebase should fix the build issue.

JackCaoG avatar Sep 27 '22 05:09 JackCaoG

@cicirori FYI

JackCaoG avatar Sep 28 '22 23:09 JackCaoG

The CI failure is from test_index_select_0dim. I'm able to replicate locally, and it looks like the tensor's size isn't being preserved after index_select when on the XLA device, e.g.:

torch.index_select(tensor([], size=(0, 1, 2, 0)), 0, tensor([], dtype=torch.int64)) = tensor([], size=(0, 1, 2, 0))
torch.index_select(tensor([], device='xla:0', size=(0, 1, 2, 0)), 0, tensor([], device='xla:0', dtype=torch.int64)) = tensor([], device='xla:0')

This only appears to happen when using the PJRT GPU client.
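For reference, the expected semantics here can be illustrated outside of PyTorch: NumPy's np.take along an axis behaves analogously to torch.index_select, and shows that selecting with an empty index tensor should preserve the remaining dimensions rather than collapse the shape. This is a sketch of the expected behavior, not the failing XLA path itself:

```python
import numpy as np

# Analogue of torch.index_select(src, 0, idx):
# the result's dim 0 takes the length of idx (here 0),
# while the remaining dims (1, 2, 0) must be preserved.
src = np.empty((0, 1, 2, 0))
idx = np.array([], dtype=np.int64)

out = np.take(src, idx, axis=0)
print(out.shape)  # (0, 1, 2, 0)
```

The bug described above is that the XLA-device result came back as a bare tensor([]) instead of keeping the (0, 1, 2, 0) shape.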

jonb377 avatar Oct 03 '22 22:10 jonb377

Not sure why pjrt:gpu would have this issue. I would suggest trying some other index_select cases; index_select with tensor([], size=(0, 1, 2, 0)) seems to be a corner case.

JackCaoG avatar Oct 10 '22 22:10 JackCaoG

Rebase after the TF pin update seems to have fixed it.

jonb377 avatar Oct 14 '22 02:10 jonb377

@ymwangg FYI, we have not tested the speed with the PJRT GPU client, and it currently only supports 1 GPU. If you want to give it a try and let us know how it works, that would be great.
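For anyone trying this out, runtime selection in torch_xla's PJRT work is typically done through the PJRT_DEVICE environment variable; the snippet below is a minimal sketch under that assumption (the script name train_mnist.py is a hypothetical placeholder, and exact flags may differ at this stage of development):

```shell
# Select the PJRT runtime with the GPU client (single host, single GPU at this stage).
PJRT_DEVICE=GPU python train_mnist.py
```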

JackCaoG avatar Oct 14 '22 16:10 JackCaoG