xla
A machine learning compiler for GPUs, CPUs, and ML accelerators
Add an algebraic simplification pattern for multiply(add(conv(input, filter), bias), broadcast(constant)) -> add(conv(input, multiply(filter, broadcast(constant))), multiply(bias, broadcast(constant)))
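The rewrite is valid because convolution is linear in the filter: a scale applied to the convolution output (here modeled as a per-output-channel constant, the common case when folding a following multiply into the conv) can instead be folded into the filter, with the bias scaled separately. A minimal JAX sketch checking the identity numerically; the tensor names, shapes, and layouts are illustrative assumptions, not taken from the XLA pass itself:

```python
import jax
import jax.numpy as jnp

kx, kw, kb, kc = jax.random.split(jax.random.PRNGKey(0), 4)

# NHWC input, HWIO filter, per-output-channel bias and scale (assumed shapes).
x = jax.random.normal(kx, (1, 8, 8, 3))
w = jax.random.normal(kw, (3, 3, 3, 4))
b = jax.random.normal(kb, (4,))
c = jax.random.normal(kc, (4,))

conv = lambda inp, flt: jax.lax.conv_general_dilated(
    inp, flt, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# Original form: multiply(add(conv(x, w), b), broadcast(c)).
before = (conv(x, w) + b) * c
# Rewritten form: add(conv(x, multiply(w, broadcast(c))), multiply(b, c)).
after = conv(x, w * c) + b * c

print(jnp.max(jnp.abs(before - after)))  # agrees up to float32 rounding
```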
Reverts 693ee2e13225331bebc946442af7e2d59355adea
This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op would be expanded into a deterministic form in `xla/service/ScatterExpander.cc`. However, since...
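For context on why the deterministic expansion exists at all: when a scatter-add has duplicate indices, the order in which the floating-point updates are accumulated changes the rounded result, so a parallel scatter (e.g. one using atomics) is not reproducible, and the expander serializes the updates instead. A small JAX sketch of that order sensitivity, with illustrative values and a hypothetical `serial_scatter_add` helper:

```python
import jax.numpy as jnp

operand = jnp.zeros(2, dtype=jnp.float32)
indices = jnp.array([0, 0, 0])

# The same three updates aimed at index 0, in two different orders.
u1 = jnp.array([1e8, -1e8, 1.0], dtype=jnp.float32)
u2 = jnp.array([1e8, 1.0, -1e8], dtype=jnp.float32)

def serial_scatter_add(out, idx, ups):
    # Apply the updates one at a time, mimicking a serialized expansion.
    for i, u in zip(idx.tolist(), ups.tolist()):
        out = out.at[i].add(u)
    return out

print(serial_scatter_add(operand, indices, u1))  # [1. 0.]
print(serial_scatter_add(operand, indices, u2))  # [0. 0.] -- order changed the result
```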
[JAX] Add PyClient::GetAllDevices() and expose it as an internal JAX backend API. The JAX backend forwards `xla::ifrt::Client::GetAllDevices()` to `xla::PyClient::GetAllDevices()`, which is accessible via JAX `backend.get_all_devices()`. This API is an internal JAX...
[IFRT] Add Client::GetAllDevices(). This defines `Client::GetAllDevices()`. It is similar to `Client::devices()`, but it enumerates all devices available on the client, regardless of the type/kind of devices. This multi-device behavior was...
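A hedged sketch of how the Python-visible side of the two entries above might be exercised; `backend.get_all_devices()` is an internal API and only present in sufficiently recent jaxlib builds, and going through `jax.lib.xla_bridge` to obtain the `PyClient` is itself an internal detail that may change:

```python
from jax.lib import xla_bridge

backend = xla_bridge.get_backend()  # the xla::PyClient for the default platform

# devices(): the devices the client normally exposes for running computations.
print([d.id for d in backend.devices()])

# get_all_devices() (per the change above): every device available on the
# client, regardless of device type/kind. Guarded because it is internal.
if hasattr(backend, "get_all_devices"):
    print([d.id for d in backend.get_all_devices()])
```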
Remove more GPU/CUDA/ROCm attribute guards from xla/service/gpu - This removes `if_gpu_is_configured` guards from targets that are only supposed to be built for GPU. (Also tags them as `gpu` so that...
[TEST] Debug linking of mlir_fusion_opt. For whatever reason it sometimes can't find cudnn. Let's find out why.
…ocblas_get_version_string
#sdy add JAX Shardy support for memories.