
Support for Flax NNX API

Open · scott-yj-yang opened this issue 11 months ago · 1 comment

Hi Brax development team,

We are currently using the brax.training.ppo training scripts in our project and have been very happy with Brax's performance.

We noticed that the brax.training module uses the older flax.linen API for neural network definitions. Flax recently introduced the flax.nnx API, which offers a more Pythonic and streamlined approach to model development.

Are there any plans to transition the brax repository to use the nnx API in the future? This would help improve model flexibility and maintainability for projects that depend on brax.training.

Thank you for your time and consideration.

Scott

scott-yj-yang, Dec 04 '24 22:12

Hi @scott-yj-yang, there aren't immediate plans to switch to nnx, as flax.linen isn't being deprecated, but we'd be happy to review clean, minimal PRs. nnx does generally look like a cleaner API.

@erikfrey any thoughts?

btaba, Dec 04 '24 22:12