neural-tangents
Does neural-tangents work with custom layers?
I have built a custom layer (`KerasLayer`) using a Python class (say, `class NewLayer`). Can I use something like `stax.NewLayer` to apply neural-tangents to this custom layer?
I'm afraid not. You would need to write your own `stax` layer that defines `init_fn`, `apply_fn`, and `kernel_fn`, e.g. as in https://github.com/google/neural-tangents/blob/9f21e6e4f21a279ebbb033ff924e1ebc4723e077/neural_tangents/_src/stax/linear.py#L749
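A minimal sketch of what those three functions do, assuming a hypothetical non-parametric layer `Scale(c)` that computes `c * x`. The real `kernel_fn` receives the library's `Kernel` object, which carries more fields than the two used here; the `ToyKernel` container below is a hypothetical stand-in so the example runs with plain JAX, and the linked `linear.py` shows how the actual layers are wrapped.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax import random


class ToyKernel(NamedTuple):
    """Hypothetical stand-in for nt's Kernel object: only NNGP and NTK matrices."""
    nngp: jnp.ndarray
    ntk: jnp.ndarray


def Scale(c: float):
    """Hypothetical non-parametric layer computing f(x) = c * x."""

    def init_fn(rng, input_shape):
        # No trainable parameters; the output shape equals the input shape.
        return input_shape, ()

    def apply_fn(params, inputs, **kwargs):
        # Finite-width forward pass.
        return c * inputs

    def kernel_fn(k: ToyKernel, **kwargs) -> ToyKernel:
        # Scaling activations by c scales every covariance by c**2.
        # With no parameters, the layer adds no NTK term of its own; it only
        # rescales the gradients of earlier layers, i.e. the NTK, by c**2.
        return k._replace(nngp=c**2 * k.nngp, ntk=c**2 * k.ntk)

    return init_fn, apply_fn, kernel_fn


# Usage: compare the analytic NNGP with the empirical Gram matrix of outputs.
x = random.normal(random.PRNGKey(0), (3, 5))
init_fn, apply_fn, kernel_fn = Scale(2.0)
out_shape, params = init_fn(random.PRNGKey(1), x.shape)
y = apply_fn(params, x)

k_in = ToyKernel(nngp=x @ x.T / x.shape[1], ntk=jnp.zeros((3, 3)))
k_out = kernel_fn(k_in)
empirical_nngp = y @ y.T / y.shape[1]
print(jnp.allclose(k_out.nngp, empirical_nngp))
```

A parametric layer would additionally add its own gradient contribution to `ntk`, which is where most of the work in a real `kernel_fn` goes.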
To what extent you'll be able to reuse your existing code will depend on the specifics.
We have tools that make implementing some layers easier than writing them from scratch, such as pointwise nonlinearities (https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Elementwise.html). If your layer is an affine non-parametric transformation (similar to https://neural-tangents.readthedocs.io/en/latest/stax.html#linear-nonparametric), it would also be easy to translate it automatically into a `stax` layer (I just haven't gotten around to implementing that yet). In general, if you tell us what your layer does, we may be able to help implement it.
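For a pointwise nonlinearity φ, the analytic work reduces to Gaussian expectations: given the pre-activation covariance, compute E[φ(u)φ(v)] (for the NNGP) and E[φ'(u)φ'(v)] (the factor φ contributes to the NTK). The sketch below does this for ReLU in closed form (the arc-cosine kernel) and checks both against Monte Carlo estimates. The function names and the `(cov12, var1, var2)` argument order are assumptions of this sketch; consult the `Elementwise` docs linked above for the exact interface it expects.

```python
import jax.numpy as jnp
from jax import random


def relu_nngp(cov12, var1, var2):
    """Closed-form E[relu(u) relu(v)] for (u, v) ~ N(0, [[var1, cov12], [cov12, var2]])."""
    prod = jnp.sqrt(var1 * var2)
    theta = jnp.arccos(jnp.clip(cov12 / prod, -1.0, 1.0))
    return prod / (2 * jnp.pi) * (jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta))


def relu_d_nngp(cov12, var1, var2):
    """Closed-form E[relu'(u) relu'(v)] for the same Gaussian pair."""
    theta = jnp.arccos(jnp.clip(cov12 / jnp.sqrt(var1 * var2), -1.0, 1.0))
    return (jnp.pi - theta) / (2 * jnp.pi)


def monte_carlo(cov12, var1, var2, n=200_000, seed=0):
    """Estimate both expectations by sampling correlated Gaussian pairs."""
    sigma = jnp.array([[var1, cov12], [cov12, var2]])
    chol = jnp.linalg.cholesky(sigma)
    # Rows of z have covariance chol @ chol.T == sigma.
    z = random.normal(random.PRNGKey(seed), (n, 2)) @ chol.T
    u, v = z[:, 0], z[:, 1]
    nngp = jnp.mean(jnp.maximum(u, 0) * jnp.maximum(v, 0))
    d_nngp = jnp.mean((u > 0) & (v > 0))  # relu'(x) is the indicator x > 0
    return nngp, d_nngp


cov12, var1, var2 = 0.8, 1.0, 2.0
mc_nngp, mc_d_nngp = monte_carlo(cov12, var1, var2)
print(relu_nngp(cov12, var1, var2), mc_nngp)      # should agree to about 2 decimals
print(relu_d_nngp(cov12, var1, var2), mc_d_nngp)  # likewise
```

Sanity check: with `cov12 = var1 = var2 = 1`, `relu_nngp` gives approximately 0.5, which is E[relu(u)^2] for a standard normal u.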
Finally, note that empirical kernels (https://neural-tangents.readthedocs.io/en/latest/empirical.html) work with any JAX function and don't require it to be written in `stax`.
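To illustrate why no `stax` rewrite is needed for empirical kernels, here is the empirical NTK computed by hand for an arbitrary JAX function. This assumes the custom layer's forward pass can be written as a JAX function `f(params, x)` (a Keras layer would first need to be ported to JAX); `new_layer_apply` is a hypothetical gated dense map standing in for `NewLayer`. The empirical kernel functions on the linked page do the same job with more options, so treat this as a sketch of the underlying computation rather than a replacement for them.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree


def new_layer_apply(params, x):
    """Hypothetical custom layer: a gated dense map written as plain JAX."""
    h = x @ params["w"] + params["b"]
    return h * jax.nn.sigmoid(x @ params["g"])


def empirical_ntk(f, params, x1, x2):
    """Theta[i, j] = sum over outputs o and parameters p of df_o(x1_i)/dp * df_o(x2_j)/dp."""
    flat, unravel = ravel_pytree(params)  # flatten the parameter pytree to one vector

    def f_flat(p, x):
        return f(unravel(p), x)

    j1 = jax.jacobian(f_flat)(flat, x1)  # shape (n1, out, num_params)
    j2 = jax.jacobian(f_flat)(flat, x2)  # shape (n2, out, num_params)
    # Contract the parameter axis and sum over output units.
    return jnp.einsum("iop,jop->ij", j1, j2)


k1, k2, k3 = random.split(random.PRNGKey(0), 3)
params = {
    "w": random.normal(k1, (4, 3)),
    "b": jnp.zeros(3),
    "g": random.normal(k2, (4, 3)),
}
x = random.normal(k3, (5, 4))
ntk = empirical_ntk(new_layer_apply, params, x, x)
print(ntk.shape)  # (5, 5)
```

Materializing full Jacobians is the simplest approach but scales poorly with parameter count, which is one reason to prefer the library's implementations for real models.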