Port of PureJaxRL to equinox

Open · dorjeduck opened this issue 1 year ago · 5 comments

I’ve been tinkering with porting PureJaxRL (which is Flax/Linen-based) to Equinox. I’m still new to JAX, so my first attempt is pretty raw. If you’re curious, it’s here: https://github.com/dorjeduck/eqxrl.

I really like the philosophy behind Equinox, so thanks for making such a great library! That said, my port seems slower than the Flax/Linen version. Is this common when comparing the two libraries, or is it most likely due to my beginner way of implementing Equinox-based tools? I tried to add warm-up rounds and keep compilation out of the timed periods, but I may well be missing something.

I haven’t tried NNX yet; its pragmatism seems to take it too far from the JAX way of thinking for my taste.

Sorry for posting this as an issue. I didn’t see a better place to ask, and others seem to do this here too. I’d appreciate any advice!

dorjeduck, Jan 06 '25 04:01

At first glance, the most surprising thing to me is that a trivial 3-layer MLP is almost 100% slower with Equinox than with Flax. Even setting the other comparisons aside, this one is small and direct (just a few lines of Equinox code versus a few lines of Flax), and the result is surprising. I would also recommend isolating at least that part in a minimal, complete, verifiable example. I will take a closer look tomorrow.

I am also working on an Equinox-based RL library (https://github.com/lockwo/NARLL; I know it's private right now, but the link will work once I open it up), so this is definitely of interest to me.

lockwo, Jan 06 '25 06:01

Great to hear about NARLL; I'm looking forward to it.

I also find the benchmarks surprising, to the extent that I think there must be a mistake in how I implemented it, but I haven't found it yet.

I didn't really intend to frame my question around my project; I am a JAX beginner and don't want to ask people to correct my code. What I am most interested in is others' experience comparing the performance of Equinox and Flax/Linen, as this must be something people have looked into.

dorjeduck, Jan 06 '25 06:01

They should usually both get the exact same performance!

The reason for this is that they usually end up expressing pretty much the same JAX-level computation graph, and at that point it's all in the hands of the jit compiler.

Or put another way, Equinox and Flax both just help you to organise your code -- not what it compiles to.

On benchmarking: there are a couple of common mistakes, such as measuring the cost of compilation (run your program once to compile it before timing it) or omitting jax.block_until_ready on the output.

patrick-kidger, Jan 06 '25 06:01

Thanks, Patrick, for the clarification. I already use jax.block_until_ready and precompile, but there are other aspects I will have to look into. Looking forward to working with your library.

dorjeduck, Jan 06 '25 07:01

I isolated some of the MLP specific code here: https://github.com/patrick-kidger/equinox/issues/928

lockwo, Jan 06 '25 23:01