[PJRT] Enable PJRT C API option `PjRtCApiClient` in upstream TensorFlow

The new PJRT C API will be used to enable the new TPU runtime that we will support long-term. Note: `PjRtCApiClient` currently requires a Google-internal build of libtpu. The `TPU_C_API` option will not work with the public `libtpu-nightly` builds at the time of writing.
- Create a `PjRtCApiClient` when `PJRT_DEVICE=TPU_C_API`.
- Still return `"TPU"` from `pjrt.device_type()` because this should be interchangeable with `PjRtTpuClient`.
- Avoid calling `executable->GetHloModules()` except for SPMD-sharded executables, because it's not yet supported in the C API. This also means we won't be able to test SPMD with the C API until `GetHloModules` is supported upstream.
- Use `xla::PjRtClient::HostBufferSemantics::kZeroCopy` with the C API because it does not yet support `kImmutableUntilTransferCompletes`.
- Avoid calling `AcquireExternalReference` to get an `OpaqueHandle`. `AcquireExternalReference` is for sharing memory with an external framework and is not necessary in this case. We only need a unique int to represent the underlying buffer (analogous to a "handle" in XRT) to use in `RunPostOrder`. Use the buffer address directly instead, because it's not trivial to get aliased `PjRtBuffer`s and PJRT doesn't have the same notion of a "handle" or a unique buffer ID. The C API does not support `AcquireExternalReference`, but we didn't actually need it anyway.
- Don't modify `TPU_LIBRARY_PATH` in `__init__.py`, since that needs to be set to the custom internal `libtpu` build for testing this PR. Instead, use `TPU_LOAD_LIBRARY=0` to prevent the TF TPU runtime from initializing the TPU when we're not using XRT. This expresses our intent (i.e. don't load libtpu yet) more clearly anyway.
- Patch https://github.com/tensorflow/tensorflow/commit/9a4502ab470177c259865b5b58598be690bdd86e until we update our TF pin past the fix.
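The client-selection and device-type behavior above can be sketched as follows. This is illustrative only: `_CLIENT_FOR_DEVICE`, `create_client`, and `device_type` are hypothetical names standing in for the actual torch_xla internals, and the `CPU` mapping is an assumption.

```python
import os

# Hypothetical sketch of PJRT_DEVICE dispatch: TPU_C_API selects the new
# C API client, but both TPU variants still report "TPU" so the two client
# implementations stay interchangeable from the caller's point of view.
_CLIENT_FOR_DEVICE = {
    'TPU': 'PjRtTpuClient',
    'TPU_C_API': 'PjRtCApiClient',
    'CPU': 'PjRtCpuClient',  # placeholder entry, not from the PR
}

def create_client() -> str:
    """Pick a client implementation based on the PJRT_DEVICE env var."""
    device = os.environ.get('PJRT_DEVICE', 'CPU')
    return _CLIENT_FOR_DEVICE[device]

def device_type() -> str:
    """Report the hardware type; TPU_C_API is still just a TPU."""
    device = os.environ.get('PJRT_DEVICE', 'CPU')
    return 'TPU' if device.startswith('TPU') else device
```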
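The `GetHloModules` restriction amounts to a small guard: skip the call unless the executable is SPMD-sharded, and fail loudly on the unsupported SPMD-plus-C-API combination. A minimal sketch, with all names (`hlo_modules_if_needed`, `get_hlo_modules`) hypothetical:

```python
# Sketch of the GetHloModules guard described above. Only SPMD-sharded
# executables need the HLO modules, and the C API doesn't support the call
# yet, so that combination raises rather than failing silently.
def hlo_modules_if_needed(executable, is_spmd: bool, using_c_api: bool):
    if not is_spmd:
        return None  # skip the call entirely for non-SPMD executables
    if using_c_api:
        raise NotImplementedError(
            'GetHloModules is not yet supported by the PJRT C API')
    return executable.get_hlo_modules()
```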
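The "buffer address as handle" idea can be sketched in Python with `id()` playing the role of the `PjRtBuffer` pointer value. `FakeBuffer` and `opaque_handle` are stand-ins for illustration, not real torch_xla types:

```python
# Sketch of the handle scheme: instead of AcquireExternalReference, use the
# buffer's address as a unique integer key (analogous to an XRT handle). In
# the C++ code this is the raw PjRtBuffer pointer; id() approximates that.
class FakeBuffer:  # illustrative stand-in for PjRtBuffer
    def __init__(self, data):
        self.data = data

def opaque_handle(buffer) -> int:
    # Unique for each live buffer object, like a raw pointer address.
    return id(buffer)

# Each live buffer gets a distinct integer handle, usable as a dict key
# the way RunPostOrder needs a unique ID per underlying buffer.
buffers = [FakeBuffer(i) for i in range(4)]
handles = {opaque_handle(b): b for b in buffers}
```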
Tested manually with ResNet50 on a v4-8. Performance was within ~2% of the `PjRtTpuClient` baseline.
CC @yeounoh @skye