
Deep Learning Extensions

Open · szha opened this issue on Apr 07 '21 · 7 comments

Hi,

Now that we have a good foundation for the core array API in the standard, it's probably a good time to start thinking about the neural network extensions. As an initial step, here are the neural network operations from a few selected deep learning frameworks:

  • TF: https://www.tensorflow.org/api_docs/python/tf/nn
  • PyTorch: https://pytorch.org/docs/stable/nn.functional.html
  • MXNet: https://mxnet.apache.org/versions/master/api/python/docs/api/npx/index.html
  • flax (jax): https://flax.readthedocs.io/en/latest/flax.linen.html#linear-modules
  • haiku (jax): https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules

In addition, I think the array API standard can benefit from the model exchange format defined by ONNX. Here are the operators currently in the ONNX opsets:

  • ONNX: https://github.com/onnx/onnx/blob/master/docs/Operators.md

The next step would be to figure out a good set of operators in the intersection and then iterate through the design choices for each of them.

szha avatar Apr 07 '21 19:04 szha

Not sure if relevant -- mlir-npcomp (https://github.com/llvm/mlir-npcomp) does the conversion of NumPy -> MLIR, but it doesn't feel very mature yet.

learning-chip avatar Apr 12 '21 06:04 learning-chip

One question I have is whether there's any real commonality across these APIs today. Among the functions I checked that I expected to be the "simplest", such as softmax and avg_pool, softmax seems to be the same everywhere except in MXNet. Overall, though, it's not easy to find many functions that overlap well.

rgommers avatar Apr 12 '21 18:04 rgommers

@rgommers valid question. I think there is enough commonality for the basic use cases to start standardizing. The existence of ONNX, and the fact that operators from TF/PT/MX can be mapped to it, is to some extent evidence of that. In fact, ONNX has so far focused on the intersection of the operator sets of different frameworks, so it should provide a good starting point. That said, because deep learning is newer, frameworks are more likely to have operators that are semantically equivalent than operators with identical definitions.

We will likely need some more analysis and comparison to tell. I'm hoping to contribute some as soon as I have free time.

szha avatar Apr 12 '21 20:04 szha

Activation Functions

Here I summarize a few activation functions in ONNX, PyTorch, Flax (JAX), Tensorflow, and MXNet.

celu

kwargs    ONNX  PyTorch  Flax (JAX)
alpha     Y     Y        Y
inplace   N     Y        N
  • Not implemented in Tensorflow or MXNet.
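
For reference, here's a minimal NumPy sketch of the usual celu definition (the `inplace` flag in PyTorch is a memory optimization and doesn't change the math):

```python
import numpy as np

def celu(x, alpha=1.0):
    # celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    # `alpha` is the kwarg shared by ONNX, PyTorch, and Flax above.
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))
```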

elu

kwargs    ONNX  PyTorch  Flax (JAX)  Tensorflow  MXNet
alpha     Y     Y        Y           N           Y
inplace   N     Y        N           N           N
  • Implemented as part of leaky_relu in MXNet.
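
Similarly for elu, in plain NumPy; Tensorflow exposes no `alpha` kwarg (so effectively alpha = 1, consistent with the N in the table):

```python
import numpy as np

def elu(x, alpha=1.0):
    # elu(x) = x for x > 0, alpha * (exp(x) - 1) otherwise.
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
```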

gelu

kwargs       PyTorch  Flax (JAX)  Tensorflow  MXNet
approximate  N        Y           Y           N
  • Not defined in ONNX. MXNet hasn't implemented the approximate version yet.
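
To make the `approximate` column concrete, here's a sketch of the two common gelu formulations; the tanh form is what an `approximate` flag typically selects:

```python
import numpy as np
from scipy.special import erf  # only needed for the exact form

def gelu_exact(x):
    # gelu(x) = x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation from the original GELU paper.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))
```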

log_softmax

kwargs    ONNX  PyTorch  Flax (JAX)  Tensorflow  MXNet
axis/dim  Y     Y        Y           Y           Y
length    N     N        N           N           Y

relu

kwargs    ONNX  PyTorch  Flax (JAX)  Tensorflow  MXNet
inplace   N     Y        N           N           N

sigmoid

All libraries have a consistent definition.

soft_sign

All libraries have a consistent definition except MXNet.

  • Implemented as part of the activation op in MXNet.

softmax

kwargs    ONNX  PyTorch  Flax (JAX)  Tensorflow  MXNet
axis/dim  Y     Y        Y           Y           Y
length    N     N        N           N           Y
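
For reference, a numerically stable softmax over an axis looks like this in NumPy; as I understand it, MXNet's extra `length` kwarg (here and in log_softmax above) is a masking parameter for variable-length inputs rather than a change to the core definition:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the per-axis max for numerical stability, then normalize.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)
```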

silu

kwargs  PyTorch  Flax (JAX)  Tensorflow  MXNet
beta    N        N           N           Y
  • Not defined in ONNX.
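
And silu, where MXNet's extra `beta` generalizes it to swish (silu being the beta = 1 case):

```python
import numpy as np

def silu(x, beta=1.0):
    # silu(x) = x * sigmoid(x); scaling the argument by beta gives "swish",
    # which appears to be what MXNet's extra kwarg exposes.
    return x / (1.0 + np.exp(-beta * x))
```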

szha avatar Apr 15 '21 16:04 szha

We mentioned a few open questions in the 4/15 meeting:

  • How large do we expect the API surface to be?
  • What criteria should we use for including operators as part of the standard?
    • What's the half-life of these activation functions and how many useful ones stick around?
  • Do we require the operators to be differentiable? If so, should we standardize the gradient definition (e.g. approximate gelu)?

szha avatar Apr 15 '21 18:04 szha

  • Do we require the operators to be differentiable? If so, should we standardize the gradient definition (e.g. approximate gelu)?

I think it's good to have them differentiable, but then all differentiable functions should be grouped in a separate module, say, array_api.dl, so that we can distinguish them from the non-differentiable functions in the main namespace (and make this module optional for, e.g., NumPy/CuPy).
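
A rough, purely illustrative sketch of how such an optional module might be consumed (the module and function names are hypothetical, not proposed spellings):

```python
# Consumers would probe for the optional extension, the same way optional
# extensions such as `linalg` are handled in the standard.
import numpy.array_api as xp  # any conforming namespace would do

x = xp.asarray([-1.0, 0.0, 1.0])

if hasattr(xp, "dl"):
    y = xp.dl.gelu(x)  # hypothetical differentiable op in the extension
else:
    # Fall back to composing the function from core ops (tanh-approximate gelu).
    y = 0.5 * x * (1.0 + xp.tanh((2.0 / xp.pi) ** 0.5 * (x + 0.044715 * x**3)))
```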

leofang avatar Apr 15 '21 18:04 leofang

Not sure if this is the place to comment, but sigmoid might be a bad name. "Sigmoid" just means s-shaped. The function should really be called logistic: https://en.wikipedia.org/wiki/Logistic_function

NeilGirdhar avatar Jan 15 '22 10:01 NeilGirdhar

As this proposal is without a champion, I'll go ahead and close it. Should we see more ecosystem consensus, we can revisit/reopen and consider it as a future specification extension.

kgryte avatar Jun 29 '23 08:06 kgryte