Add custom RNN kernels

Open • seanmor5 opened this issue 2 years ago • 0 comments

Currently, XLA's GPU while thunk runs the loop on the CPU, which makes it inefficient for training RNNs. TensorFlow implements a CuDNNRNN op that takes advantage of specialized GPU RNN implementations in supported scenarios. We should offer something similar.
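
To make the cost concrete, here is a minimal sketch of an RNN recurrence in plain Python (not Nx or XLA code; `rnn_forward` and its scalar weights are hypothetical names chosen for illustration). Each time step depends on the previous hidden state, so the loop body runs once per step, and any per-iteration overhead in the loop mechanism is paid for every element of the sequence.

```python
import math


def rnn_forward(xs, h0, w_x, w_h, b):
    """Run a scalar Elman-style RNN over a sequence, one step at a time.

    Hypothetical illustration only: real RNN layers use matrices, not scalars.
    """
    h = h0
    hs = []
    for x in xs:
        # Step t needs the hidden state from step t-1, so the steps
        # cannot be reordered or run independently of each other.
        h = math.tanh(w_x * x + w_h * h + b)
        hs.append(h)
    return hs


hidden_states = rnn_forward([1.0, 0.5, -0.5], h0=0.0, w_x=0.8, w_h=0.5, b=0.0)
```

A specialized kernel would compute the same sequence of hidden states, but without returning to a generic loop construct between steps.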

This would require the ability to use `defn` to defer to a custom/specialized implementation. See discussion in: https://github.com/elixir-nx/nx/issues/362
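
The idea of deferring to a specialized implementation can be sketched as a simple dispatch with a fallback. This is plain Python rather than `defn`, and every name here (`generic_rnn`, `fused_rnn`, the `backend` argument) is an assumption made for illustration, not an existing Nx or Axon API. The toy recurrence is linear so that the "specialized" path can be written in closed form and checked against the loop.

```python
def generic_rnn(xs, h0):
    """Fallback path: portable step-by-step loop (stand-in for a compiled while loop)."""
    h = h0
    for x in xs:
        h = 0.5 * h + x  # toy linear recurrence, for illustration only
    return h


def fused_rnn(xs, h0):
    """Hypothetical stand-in for a specialized kernel: same result, no step loop."""
    n = len(xs)
    return 0.5 ** n * h0 + sum(0.5 ** (n - 1 - i) * x for i, x in enumerate(xs))


def rnn(xs, h0, backend):
    # Defer to the specialized implementation only in supported scenarios
    # (here: a made-up rule of "GPU backend and non-empty sequence").
    if backend == "gpu" and len(xs) > 0:
        return fused_rnn(xs, h0)
    return generic_rnn(xs, h0)
```

Callers use `rnn` the same way on every backend; only the implementation chosen behind it changes.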

seanmor5, Aug 27 '21 20:08