[JAX] Add PyClient::GetAllDevices() and expose it as an internal JAX backend API
xla::PyClient::GetAllDevices() forwards to xla::ifrt::Client::GetAllDevices(),
and the JAX backend exposes it to Python as backend.get_all_devices(). This is
an internal JAX API, used to build an experimental mesh utils API (for finding
colocated CPU devices), and it should not be used by user code.
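A minimal sketch of how internal tooling might call the new method. It assumes
that the backend client object returned by `get_backend()` exposes
`get_all_devices()` as described above; because this is an internal API whose
presence depends on the jaxlib version, the helper `all_backend_devices` (a
hypothetical name) falls back to the public `devices()` list when the method is
missing. The import chain for `get_backend` is also an assumption that covers
both newer (`jax.extend.backend`) and older (`jax.lib.xla_bridge`) JAX releases.

```python
import jax

try:
    # Newer JAX releases expose the backend accessor here.
    from jax.extend.backend import get_backend
except ImportError:
    # Older releases; deprecated and removed in recent versions.
    from jax.lib.xla_bridge import get_backend


def all_backend_devices(platform=None):
    """Return every device known to the backend client for `platform`.

    Uses the internal, experimental `get_all_devices()` method when the
    installed jaxlib provides it; otherwise returns the public `devices()`.
    Hypothetical helper for illustration, not part of JAX.
    """
    backend = get_backend(platform)
    get_all = getattr(backend, "get_all_devices", None)
    if get_all is not None:
        return list(get_all())
    return list(backend.devices())


if __name__ == "__main__":
    # Filter to CPU devices, the use case the mesh utils work targets.
    cpu_devices = [d for d in all_backend_devices("cpu") if d.platform == "cpu"]
    print(cpu_devices)
```

User code should keep using the public `jax.devices()`; a helper like this only
makes sense inside JAX-internal utilities that can tolerate the API changing.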