Use PjRt GPU client
- Add support for single-host single-GPU
- Add a simple unit test
@will-cromar Yep! I was able to run the MNIST example to convergence using the XLA device, and verified GPU utilization with `nvidia-smi`.
Currently looking into adding a CI test; I'll update the PR once that's ready.
A rebase should fix the build issue.
@cicirori FYI
The CI failure is from `test_index_select_0dim`. I'm able to replicate it locally, and it looks like the tensor's size isn't being preserved after `index_select` when on the XLA device, e.g.:
```
torch.index_select(tensor([], size=(0, 1, 2, 0)), 0, tensor([], dtype=torch.int64)) = tensor([], size=(0, 1, 2, 0))
torch.index_select(tensor([], device='xla:0', size=(0, 1, 2, 0)), 0, tensor([], device='xla:0', dtype=torch.int64)) = tensor([], device='xla:0')
```
This only appears to happen when using the PJRT GPU client.
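For reference on what the expected shape-preserving behavior looks like, here's a minimal sketch using NumPy's `take` (an analogue of `index_select` along one axis); the array shapes mirror the failing case, and NumPy stands in for the CPU-side semantics described above:

```python
import numpy as np

# A zero-size array with shape (0, 1, 2, 0), mirroring the failing test case.
x = np.empty((0, 1, 2, 0))
idx = np.array([], dtype=np.int64)

# Selecting along axis 0 with an empty index replaces dim 0 with the index's
# length (0) and keeps the remaining dimensions, matching what
# torch.index_select does on CPU in the example above.
result = np.take(x, idx, axis=0)
print(result.shape)  # (0, 1, 2, 0)
```

The PJRT GPU path instead collapses the result to a bare zero-size tensor, dropping the trailing `(1, 2, 0)` dimensions.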
Not sure why pjrt:gpu has this issue; I would suggest trying some other `index_select` cases. Index select with `tensor([], size=(0, 1, 2, 0))` seems to be a corner case.
Rebase after the TF pin update seems to have fixed it.
@ymwangg FYI, we have not tested the speed with the PJRT GPU client, and it currently only supports 1 GPU. If you want to give it a try and let us know how it works, that would be great.
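For anyone giving it a try, a hedged sketch of how the PjRt runtime is typically selected via environment variable (the exact variable name and supported values should be confirmed against this PR and the repo docs; `mnist_example.py` is a placeholder for your own script):

```shell
# Assumed usage: select the PjRt GPU client via the PJRT_DEVICE
# environment variable before launching (single host, single GPU,
# per the scope of this PR).
PJRT_DEVICE=GPU python mnist_example.py
```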