Add custom RNN kernels
Currently, XLA's GPU `while` thunk runs the loop on the CPU, which is inefficient for training RNNs. TensorFlow implements a CuDNNRNN op that takes advantage of specialized GPU RNN implementations in supported scenarios. We should offer something similar.
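For context, a recurrent layer written purely in `defn` expresses its time loop with `while`, which is exactly the construct XLA lowers to the `while` thunk mentioned above. The sketch below is illustrative only (a bare Elman-style update, not Axon's actual RNN code):

```elixir
defmodule LoopSketch do
  import Nx.Defn

  # Repeatedly applies a simple tanh recurrence to the hidden state.
  # The `while` below is what XLA compiles to its `while` thunk; per the
  # issue above, the loop control currently runs host-side on GPU.
  defn iterate(h0, wh, steps) do
    {h, _i, _steps} =
      while {h = h0, i = 0, steps}, i < steps do
        {Nx.tanh(Nx.dot(h, wh)), i + 1, steps}
      end

    h
  end
end
```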
This would require the ability to use `defn` to defer to a custom/specialized implementation. See the discussion in https://github.com/elixir-nx/nx/issues/362.
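As a rough sketch of what that deferral could look like (all names here are hypothetical; neither `MyApp.CuDNNKernel` nor the `:impl` option exists in Axon or Nx today), a recurrent cell could route to a specialized kernel when one is available and fall back to the generic `defn` implementation otherwise:

```elixir
defmodule MyApp.RNNDispatch do
  import Nx.Defn

  # Hypothetical dispatch point: use a specialized kernel (e.g. a
  # CuDNN-backed binding) when requested, otherwise fall back to a
  # plain-Nx cell that goes through the generic XLA lowering.
  def rnn_cell(input, hidden, params, opts \\ []) do
    case Keyword.get(opts, :impl) do
      # MyApp.CuDNNKernel is a placeholder for a CuDNNRNN-style binding;
      # nothing like it ships with Axon or Nx today.
      :cudnn -> MyApp.CuDNNKernel.rnn_cell(input, hidden, params)
      _ -> generic_rnn_cell(input, hidden, params)
    end
  end

  # Generic Elman-style cell written with plain Nx operations.
  defnp generic_rnn_cell(input, hidden, params) do
    {wi, wh, b} = params
    Nx.tanh(Nx.dot(input, wi) + Nx.dot(hidden, wh) + b)
  end
end
```

The real mechanism would likely live at the Nx/EXLA level (see the linked issue) rather than as a per-layer option; the sketch only shows the shape of the fallback path.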