neural-tangents
Question: nt.stax vs flax (or other JAX-based NN libs)
Hello neural-tangents (NT) authors,
Thanks for creating and maintaining such a great product.
I have a few questions related to usage:
- Is nt.stax the best way to interact with the NT library at the moment? I observed in issue #99 that it may be possible to use the empirical portion with a FLAX-NT bridge. However, if I want infinite-width networks, then I have to use nt.stax. Is my understanding correct?
- I am looking at using FLAX and/or other higher-level libraries (Common Loop Utils, for instance) to simplify writing training loops, but I do not want to lose my ability to interact with NT. Do you have any advice on whether there are drawbacks to using NT if I went down this path?
- Another naive question: is CLU a viable tool for use with nt.stax for regular NN training?
Any help with the above is very much appreciated!
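(For context on the "empirical portion" mentioned above: the empirical NTK of a finite network is just the Gram matrix of parameter Jacobians, so it can be computed for any differentiable model, not only nt.stax ones. A tiny NumPy sketch with a hypothetical linear model f(x) = x @ w; in practice one would use NT's empirical utilities with autodiff instead of finite differences:)

```python
import numpy as np

def f(params, x):
    # hypothetical tiny linear model: scalar output per input row
    return x @ params

def jacobian(params, x, eps=1e-6):
    # finite-difference Jacobian of f w.r.t. params; one row per input
    base = f(params, x)
    J = np.zeros((x.shape[0], params.shape[0]))
    for i in range(params.shape[0]):
        p = params.copy()
        p[i] += eps
        J[:, i] = (f(p, x) - base) / eps
    return J

params = np.array([0.5, -1.0, 2.0])
x1 = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
x2 = np.array([[1.0, 1.0, 1.0]])

# empirical NTK between two batches: J(x1) @ J(x2)^T
ntk = jacobian(params, x1) @ jacobian(params, x2).T
# for a linear model the Jacobian w.r.t. w is x itself,
# so the empirical NTK reduces to x1 @ x2.T
```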
Hi Vimal,
- Yes, I'm afraid your understanding is correct: currently NT only computes exact infinite-width kernels of networks defined in nt.stax. As mentioned in the thread you linked, @sschoenholz is looking into relaxing this constraint, but this is a challenging task and we don't have a precise timeline. So in the meantime you may need to either write a converter between nt.stax and FLAX models, or adapt the FLAX utilities to nt.stax (either can be quite laborious...).
- Hard to say; I'm not familiar with it, and it seems like they don't have documentation yet. I imagine some things that don't depend heavily on flax might work (maybe https://github.com/google/CommonLoopUtils/blob/master/clu/metrics.py?), but others more specific to flax may not (perhaps https://github.com/google/CommonLoopUtils/blob/master/clu/checkpoint.py?). So you may need to try it out in practice / read through their code to figure it out. Sorry I can't be more helpful here! (If someone is more familiar with flax/CLU, please let us know if you have any ideas; one other place to ask could be https://github.com/google/jax regarding whether CLU would work with jax.experimental.stax, which is very similar to nt.stax.)