lowering / cost analysis of @nnx.jit functions
Discussed in https://github.com/google/flax/discussions/4093
Originally posted by lzanini July 19, 2024: Hi, is there a way to use JAX's cost analysis API for nnx.jit-ed functions? Typically, a training step with a module, an optimizer, and a metric. It is unclear to me how to lower the inner jax.jit-ed function.
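For reference, with plain jax.jit the workflow being asked about looks like this (a minimal sketch using JAX's lowering API; f is a toy function chosen for illustration):

import jax
import jax.numpy as jnp

def f(x):
    return x @ x.T

# Lower, compile, then query XLA's cost estimates (flops, bytes accessed, ...).
lowered = jax.jit(f).lower(jnp.ones((8, 8)))
print(lowered.compile().cost_analysis())

The question is how to do the equivalent through nnx.jit.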
Currently the inner jitted function is stored in an .inner attribute:
f = nnx.jit(lambda x: x)
print(f.inner) # <PjitFunction of <function jit_fn at ...>>
However, there is no easy way to lower it with the exact same inputs that NNX will pass to it. We should provide a way to do this.
We have an internal implementation of this that adds a .lower method to the object returned by nnx.jit; it follows the same pattern as the existing inner jit_wrapper function to construct the arguments to pass to jitted_fn's .lower. Would a PR using that approach for this issue be entertained?
@gabbard sure! Happy to review a PR.
Hi, I was just wondering whether this PR ever landed.
Hi, I'm also wondering how I can approach this. Any news?
@cgarciae @gabbard
I've just upgraded to the latest flax (0.12.0), and it seems the nnx.jit-ed function now exposes .lower, so I guess this issue can be marked as resolved.
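For anyone landing here later, usage would look roughly like this (a minimal sketch, assuming the .lower on the nnx.jit wrapper mirrors jax.jit's .lower and takes the same arguments you would call the function with):

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

@nnx.jit
def forward(model, x):
    return model(x)

# Lower with the same arguments as a normal call, then compile
# and query the cost estimates.
lowered = forward.lower(model, jnp.ones((1, 4)))
print(lowered.compile().cost_analysis())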
We can also use pure JAX transforms with nnx.Modules since 0.11. Closing as resolved.
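For example (a minimal sketch, assuming nnx.Module instances are pytrees as of 0.11 and can be passed straight through jax.jit):

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

@jax.jit  # plain jax.jit, no nnx.jit wrapper needed
def forward(model, x):
    return model(x)

# The standard lowering / cost analysis API then applies directly.
lowered = forward.lower(model, jnp.ones((1, 4)))
print(lowered.compile().cost_analysis())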