Question: nt.stax vs flax (or other JAX-based NN libs)

Open vimalthilak opened this issue 3 years ago • 1 comments

Hello neural-tangents (NT) authors,

Thanks for creating and maintaining such a great product.

I have a few questions related to usage:

  • Is nt.stax the best way to interact with the NT library at the moment? I saw in issue #99 that it may be possible to use the empirical portion through a FLAX-NT bridge. However, if I want infinite-width networks, then I have to use nt.stax. Is my understanding correct?
  • I am looking at using FLAX and/or other higher-level libraries (Common Loop Utils, for instance) to simplify writing training loops, but I do not want to lose my ability to interact with NT. Do you have any advice on whether there are drawbacks in terms of using NT if I went down this path?
  • Another naive question: is CLU a viable tool for use with nt.stax for regular NN training?

Any help with the above is very much appreciated!

vimalthilak avatar Jul 02 '21 17:07 vimalthilak

Hi Vimal,

  1. Yes, I'm afraid your understanding is correct: currently NT only computes exact infinite-width kernels of networks defined in nt.stax. As mentioned in the thread you linked, @sschoenholz is looking into relaxing this constraint, but this is a challenging task and we don't have a precise timeline. So in the meantime you may need to either write a converter between nt.stax and FLAX models, or adapt the FLAX utilities to nt.stax (either can be quite laborious...).

  2. Hard to say. I'm not familiar with it, and it seems they don't have documentation yet. I imagine some things that don't depend heavily on flax might work (maybe https://github.com/google/CommonLoopUtils/blob/master/clu/metrics.py?), but others more specific to flax may not (perhaps https://github.com/google/CommonLoopUtils/blob/master/clu/checkpoint.py?). So you may need to try it out in practice or read through their code to figure it out. Sorry I can't be more helpful here! (If someone is more familiar with flax/CLU, please let us know if you have any ideas; one other place to ask could be https://github.com/google/jax regarding whether CLU would work with jax.experimental.stax, which is very similar to nt.stax).

romanngg avatar Jul 05 '21 17:07 romanngg