
Does neural-tangents work for custom layer?

Open Shuhul24 opened this issue 1 year ago • 1 comments

I have built a custom layer (a KerasLayer) as a Python class (say, class NewLayer). Can I use something like stax.NewLayer to apply neural-tangents to this custom layer?

Shuhul24, Sep 03 '22 14:09

I'm afraid not. You would need to write your own stax layer, defining init_fn, apply_fn, and kernel_fn, e.g. as in https://github.com/google/neural-tangents/blob/9f21e6e4f21a279ebbb033ff924e1ebc4723e077/neural_tangents/_src/stax/linear.py#L749

To what extent you'll be able to reuse your existing code will depend on the specifics.

We have tools that make some layers easier to implement than writing them from scratch, such as pointwise nonlinearities: https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Elementwise.html. If your layer is an affine non-parametric transformation (similar to https://neural-tangents.readthedocs.io/en/latest/stax.html#linear-nonparametric), it is also easy to automatically translate it into a stax layer (something I just haven't gotten to doing yet). In general, if you could tell us what your layer does, we may be able to help implement it.

Finally, note that empirical kernels (https://neural-tangents.readthedocs.io/en/latest/empirical.html) work with any JAX functions, and don't require them to be written in stax.

romanngg, Sep 04 '22 20:09