
[PJRT] Enable PJRT C API option

Open. will-cromar opened this issue 2 years ago.

Adds an option to use the PjRtCApiClient from upstream TensorFlow.

The new PjRt C API will enable the new TPU runtime that we will support long-term. Note: PjRtCApiClient currently requires a Google-internal build of libtpu. At the time of writing, the TPU_C_API option does not work with the public libtpu-nightly builds.

  • Create a PjRtCApiClient when PJRT_DEVICE=TPU_C_API
    • Still return "TPU" from pjrt.device_type(), because the C API client is meant to be interchangeable with PjRtTpuClient.
  • Avoid calling executable->GetHloModules() except for SPMD-sharded executables, because the C API does not support it yet. As a result, we can't test SPMD with the C API until GetHloModules is supported upstream.
  • Use xla::PjRtClient::HostBufferSemantics::kZeroCopy with the C API, because the C API does not yet support kImmutableUntilTransferCompletes.
  • Avoid calling AcquireExternalReference to get an OpaqueHandle. AcquireExternalReference is for sharing memory with an external framework, which we don't need here: we only need a unique int representing the underlying buffer (analogous to a "handle" in XRT) to use in RunPostOrder. Use the buffer address directly instead, since PjRt has no equivalent of an XRT "handle" or unique buffer ID, and getting aliased PjRtBuffers is not trivial. The C API does not support AcquireExternalReference, but we didn't need it anyway.
  • Don't modify TPU_LIBRARY_PATH in __init__.py, since it must point to the custom internal libtpu build to test this PR. Instead, set TPU_LOAD_LIBRARY=0 to keep the TF TPU runtime from initializing the TPU when we're not using XRT. This also expresses our intent (don't load libtpu yet) more clearly.
  • Patch https://github.com/tensorflow/tensorflow/commit/9a4502ab470177c259865b5b58598be690bdd86e until we update our TF pin past the fix.
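The device selection in the first bullet amounts to a small dispatch on PJRT_DEVICE. The sketch below is a hypothetical Python illustration, not the PR's actual code: ClientChoice and choose_client are made-up names, and the class names it returns are only labels for the two clients discussed above. The point it shows is that both TPU options report the same device type.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class ClientChoice:
    client_class: str  # label for the PJRT client that would be constructed
    device_type: str   # value a device_type() query would report


def choose_client(env: Optional[Mapping[str, str]] = None) -> ClientChoice:
    """Map PJRT_DEVICE to a client label and the reported device type.

    Hypothetical sketch: only the two TPU options from this PR are covered.
    """
    env = os.environ if env is None else env
    device = env.get("PJRT_DEVICE", "")
    if device == "TPU_C_API":
        # The C API client should be interchangeable with PjRtTpuClient,
        # so it still reports "TPU" as its device type.
        return ClientChoice(client_class="PjRtCApiClient", device_type="TPU")
    if device == "TPU":
        return ClientChoice(client_class="PjRtTpuClient", device_type="TPU")
    raise ValueError(f"PJRT_DEVICE={device!r} is not handled by this sketch")
```

Keeping the reported type identical means user code that branches on "TPU" works unchanged whichever client is selected.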
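The two C API workarounds (host buffer semantics and when to fetch HLO modules) reduce to simple capability checks. This is a hypothetical Python sketch: the enum values reuse the xla::PjRtClient::HostBufferSemantics member names mentioned above, but the Python enum and both helper functions are illustrative only. It also assumes that clients other than the C API keep using kImmutableUntilTransferCompletes.

```python
from enum import Enum


class HostBufferSemantics(Enum):
    # Illustrative mirror of two xla::PjRtClient::HostBufferSemantics members.
    IMMUTABLE_UNTIL_TRANSFER_COMPLETES = "kImmutableUntilTransferCompletes"
    ZERO_COPY = "kZeroCopy"


def host_buffer_semantics(using_c_api: bool) -> HostBufferSemantics:
    # The C API does not yet support kImmutableUntilTransferCompletes,
    # so fall back to kZeroCopy there (assumed default otherwise).
    if using_c_api:
        return HostBufferSemantics.ZERO_COPY
    return HostBufferSemantics.IMMUTABLE_UNTIL_TRANSFER_COMPLETES


def should_fetch_hlo_modules(is_spmd_sharded: bool) -> bool:
    # GetHloModules() is only called for SPMD-sharded executables. With the
    # C API that call is unsupported, so SPMD cannot be exercised there yet.
    return is_spmd_sharded
```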
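Using the buffer address as the handle works because a post-order traversal only needs an int that is unique per live buffer, so aliased references to the same buffer can be recognized. The Python sketch below is hypothetical: DeviceBuffer, Node and run_post_order are made-up stand-ins, and id() plays the role of the buffer address. It assumes the handle's job in RunPostOrder is to collect each distinct buffer once.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(eq=False)
class DeviceBuffer:
    """Hypothetical stand-in for a device buffer; not a PJRT type."""
    name: str

    def opaque_handle(self) -> int:
        # Object identity stands in for the buffer address: it is unique
        # among live objects, which is all the handle needs to be here.
        return id(self)


@dataclass(eq=False)
class Node:
    """Hypothetical IR node that may reference a device buffer."""
    name: str
    operands: List["Node"] = field(default_factory=list)
    buffer: Optional[DeviceBuffer] = None


def run_post_order(roots: List[Node]) -> Tuple[List[str], List[str]]:
    """Visit nodes in post-order and collect each distinct buffer once, keyed by handle."""
    order: List[str] = []
    buffers: List[str] = []
    visited = set()       # ids of nodes already emitted
    seen_handles = set()  # handles of buffers already collected

    def visit(node: Node) -> None:
        if id(node) in visited:
            return
        visited.add(id(node))
        for operand in node.operands:
            visit(operand)
        order.append(node.name)
        if node.buffer is not None:
            handle = node.buffer.opaque_handle()
            if handle not in seen_handles:
                seen_handles.add(handle)
                buffers.append(node.buffer.name)

    for root in roots:
        visit(root)
    return order, buffers


# Two nodes that alias the same buffer contribute it only once.
weights = DeviceBuffer("weights")
out = Node("out", operands=[Node("a", buffer=weights), Node("b", buffer=weights)])
order, buffers = run_post_order([out])
```

Like a raw address, id() can be reused after an object is freed, so the handle is only meaningful while the buffer is alive, which is the window in which the traversal runs.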
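The environment change for __init__.py can be sketched as follows. This is a hypothetical Python illustration, not the file's actual contents: configure_tpu_env and its using_xrt flag are made-up names, and using setdefault (so an explicit user setting wins) is an assumption rather than something the PR states.

```python
import os
from typing import MutableMapping, Optional


def configure_tpu_env(
    using_xrt: bool, env: Optional[MutableMapping[str, str]] = None
) -> MutableMapping[str, str]:
    """Hypothetical sketch of the import-time environment setup described above."""
    env = os.environ if env is None else env
    # TPU_LIBRARY_PATH is deliberately left untouched so it can point at a
    # custom libtpu build.
    if not using_xrt:
        # Keep the TF TPU runtime from initializing the TPU when XRT is not used.
        # setdefault is an assumption: it lets an explicit user setting win.
        env.setdefault("TPU_LOAD_LIBRARY", "0")
    return env
```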

Tested manually with ResNet50 on a v4-8. Performance was within ~2% of the PjRtTpuClient baseline.

CC @yeounoh @skye

will-cromar, Oct 07 '22 17:10