Deep Learning Extensions
Hi,
Now that we have a good foundation on the core array API in the standard, it's probably a good time to start thinking about the neural network extensions. As an initial step, here are the neural network operations from a few selected deep learning frameworks:
- TF: https://www.tensorflow.org/api_docs/python/tf/nn
- PyTorch: https://pytorch.org/docs/stable/nn.functional.html
- MXNet: https://mxnet.apache.org/versions/master/api/python/docs/api/npx/index.html
- flax (jax): https://flax.readthedocs.io/en/latest/flax.linen.html#linear-modules
- haiku (jax): https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules
In addition, I think the array API standard can benefit from ONNX's model exchange format definition. Here are the operators currently in the ONNX opsets: https://github.com/onnx/onnx/blob/master/docs/Operators.md
The next step would be to figure out a good set of operators in the intersection and iterate through the design choices for them.
Not sure if relevant -- mlir-npcomp (https://github.com/llvm/mlir-npcomp) does the conversion of NumPy -> MLIR, but it doesn't feel very mature yet.
One question I have is whether there's any kind of commonality between the APIs today. Some of the ones I checked that I expected to be "simplest" are functions like `softmax` and `avg_pool`. `softmax` seems to be the same everywhere except in MXNet. Overall, it's not easy to find many functions that overlap well though.
@rgommers valid question. I think there is enough commonality for basic use cases to start standardizing. The existence of ONNX, and the possibility of mapping operators from TF/PT/MX to ONNX, is to some extent evidence of that. In fact, ONNX has been focusing on intersections of operator sets from different frameworks so far, so it should provide a good starting point. That said, because deep learning is newer, it's more likely that operators across frameworks have semantic equivalence than identical definitions.
We will likely need some more analysis and comparison to tell. I'm hoping to contribute some as soon as I have free time.
Activation Functions
Here I summarize a few activation functions in ONNX, PyTorch, Flax (JAX), Tensorflow, and MXNet.
celu
kwargs | ONNX | PyTorch | Flax (JAX) |
---|---|---|---|
alpha | Y | Y | Y |
inplace | N | Y | N |
- Not implemented in TensorFlow or MXNet.
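For reference, a minimal NumPy sketch of CELU covering the `alpha` kwarg from the table (this follows the standard CELU definition; the `inplace` kwarg is omitted since it's a mutation detail, not a semantic one):

```python
import numpy as np

def celu(x, alpha=1.0):
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))
```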
elu
kwargs | ONNX | PyTorch | Flax (JAX) | TensorFlow | MXNet |
---|---|---|---|---|---|
alpha | Y | Y | Y | N | Y |
inplace | N | Y | N | N | N |
- Implemented as part of `leaky_relu` in MXNet.
gelu
kwargs | PyTorch | Flax (JAX) | TensorFlow | MXNet |
---|---|---|---|---|
approximate | N | Y | Y | N |
- Not defined in ONNX. MXNet hasn't implemented the approximate version yet.
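To make the `approximate` kwarg concrete, here is a sketch of the two GELU variants it selects between: the exact form uses the standard normal CDF, and the approximate form uses the well-known tanh approximation (scalar `math` versions for clarity):

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation, selected by approximate=True where supported.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
```

The two variants have slightly different gradients, which is why standardizing the gradient definition comes up as an open question below.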
log_softmax
kwargs | ONNX | PyTorch | Flax (JAX) | TensorFlow | MXNet |
---|---|---|---|---|---|
axis/dim | Y | Y | Y | Y | Y |
length | N | N | N | N | Y |
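All five libraries agree on the `axis`/`dim` kwarg; a minimal numerically stable NumPy sketch of that shared behavior (MXNet's extra `length` kwarg, used for masking padded positions, is not modeled here):

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Subtract the max before exponentiating so large inputs don't overflow.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
```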
relu
kwargs | ONNX | PyTorch | Flax (JAX) | TensorFlow | MXNet |
---|---|---|---|---|---|
inplace | N | Y | N | N | N |
sigmoid
All libraries share a consistent definition.
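The shared definition is the logistic function (which is also relevant to the naming comment further down); as a one-line NumPy sketch:

```python
import numpy as np

def sigmoid(x):
    # Logistic function: 1 / (1 + exp(-x)).
    return 1.0 / (1.0 + np.exp(-x))
```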
soft_sign
All libraries share a consistent definition, except MXNet.
- Implemented as part of the activation op in MXNet.
softmax
kwargs | ONNX | PyTorch | Flax (JAX) | TensorFlow | MXNet |
---|---|---|---|---|---|
axis/dim | Y | Y | Y | Y | Y |
length | N | N | N | N | Y |
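As with `log_softmax`, the shared `axis`/`dim` behavior can be sketched in a few lines of NumPy (again omitting MXNet's `length` masking kwarg):

```python
import numpy as np

def softmax(x, axis=-1):
    # Max-shifted exponentials for numerical stability; rows sum to 1.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)
```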
silu
kwargs | PyTorch | Flax (JAX) | TensorFlow | MXNet |
---|---|---|---|---|
beta | N | N | N | Y |
- Not defined in ONNX.
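A sketch showing what MXNet's extra `beta` kwarg does: SiLU is `x * sigmoid(beta * x)`, and `beta=1` recovers the definition the other libraries use (the parameterized form is often called Swish):

```python
import numpy as np

def silu(x, beta=1.0):
    # SiLU / Swish: x * sigmoid(beta * x); beta=1 matches PyTorch/Flax/TF.
    return x / (1.0 + np.exp(-beta * x))
```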
We mentioned a few open questions in the 4/15 meeting:
- How large do we expect the API surface to be?
- What criteria should we use for including operators as part of the standard?
- What's the half-life of these activation functions and how many useful ones stick around?
- Do we require the operators to be differentiable? If so, should we standardize the gradient definition (e.g. approximate gelu)?
> Do we require the operators to be differentiable? If so, should we standardize the gradient definition (e.g. approximate gelu)?

I think it's good to have them differentiable, but then all differentiable functions should be grouped in a separate module, say `array_api.dl`, so that we can distinguish them from the non-differentiable functions in the main namespace (and make this module optional for e.g. NumPy/CuPy).
Not sure if this is the place to comment, but `sigmoid` might be a bad name. "Sigmoid" just means s-shaped; the function should really be called `logistic`: https://en.wikipedia.org/wiki/Logistic_function
As this proposal is without a champion, I'll go ahead and close. Should we see more ecosystem consensus, we can revisit/reopen and consider as a future specification extension.