
lowering / cost analysis of @nnx.jit functions

Open · cgarciae opened this issue 1 year ago

Discussed in https://github.com/google/flax/discussions/4093

Originally posted by lzanini on July 19, 2024: Hi, is there a way to use JAX's cost analysis API for nnx.jit-ted functions? Typically, a training step involving a module, an optimizer, and a metric. It is unclear to me how to lower the inner jax.jit-ted function.

cgarciae avatar Jul 19 '24 14:07 cgarciae
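For context, JAX's cost-analysis API on a plain jax.jit-ted function (without NNX) is reached by lowering and compiling; a minimal sketch, with an illustrative function and input shapes:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) @ x.T

x = jnp.ones((128, 128))

# Lower for concrete input shapes, compile, then query XLA's cost analysis.
lowered = jax.jit(f).lower(x)
compiled = lowered.compile()
print(compiled.cost_analysis())  # reports flops, bytes accessed, etc.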

Currently, the inner jitted function is stored in the .inner attribute:

from flax import nnx

f = nnx.jit(lambda x: x)
print(f.inner)  # <PjitFunction of <function jit_fn at ...>>

However, there is no easy way to lower it with the exact same inputs that NNX will pass to it. We should provide some way to do this.

cgarciae avatar Jul 19 '24 14:07 cgarciae
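Until a first-class API exists, one workaround (an editorial sketch, not the library's own mechanism; the model and shapes are illustrative) is to split the NNX objects into pytrees yourself and lower a pure jax.jit-ted function over them:

from flax import nnx
import jax
import jax.numpy as jnp

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

# Split the module into a static graph definition and a pytree of arrays,
# so that a pure function over (state, x) can be jitted and lowered.
graphdef, state = nnx.split(model)

def pure_step(state, x):
    model = nnx.merge(graphdef, state)
    return model(x)

x = jnp.ones((2, 4))
lowered = jax.jit(pure_step).lower(state, x)
print(lowered.compile().cost_analysis())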

We have an internal implementation of this. It works by adding a .lower method to the object returned by nnx.jit, which follows the same pattern as the existing inner jit_wrapper function to construct the arguments passed to the jitted_fn's .lower. Would a PR using that approach for this issue be entertained?

gabbard avatar Nov 27 '24 16:11 gabbard

@gabbard sure! Happy to review a PR.

cgarciae avatar Nov 27 '24 17:11 cgarciae

Hi, I was just wondering whether this PR ever landed.

jkyl avatar Sep 28 '25 18:09 jkyl

Hi, I'm also wondering how I can approach this. Any news?

@cgarciae @gabbard

qGentry avatar Nov 13 '25 17:11 qGentry

I've just upgraded to the latest Flax (0.12.0), and it seems that nnx.jit-ted functions now have .lower exposed, so I guess this issue can be marked as resolved.

qGentry avatar Nov 18 '25 11:11 qGentry
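Assuming the newly exposed .lower mirrors jax.jit's lowering API and takes the same arguments the wrapped function would be called with (the exact signature is not confirmed here), usage would look roughly like this sketch:

from flax import nnx
import jax.numpy as jnp

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

@nnx.jit
def forward(model, x):
    return model(x)

x = jnp.ones((2, 4))
# Hypothetical usage of the exposed .lower; argument handling may differ.
lowered = forward.lower(model, x)
print(lowered.compile().cost_analysis())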

Since 0.11 we can also use pure JAX transforms with nnx.Module objects. Closing as resolved.

vfdev-5 avatar Nov 18 '25 11:11 vfdev-5