
Renaming functions in tree_utils

Open ataa1312 opened this issue 11 months ago • 6 comments

Hi, I was curious if there’s a specific reason why all the functions in tree_utils start with "tree". Are there any plans to simplify the naming, similar to how JAX has done? Thanks, Ata

ataa1312 avatar Dec 29 '24 23:12 ataa1312

Adding the tree_ prefix makes the function names a little more explicit. JAX's jax.tree module is just an alias module.

We currently don't have plans to support this, but if you'd like, we'd love a PR! We would need a file similar to https://github.com/jax-ml/jax/blob/main/jax/tree.py
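To make the alias idea concrete, here is a minimal, self-contained sketch. It is not the optax or JAX code: `tree_utils` below is a hypothetical stand-in namespace with list-based toy functions, and `make_alias_module` is a hypothetical helper that builds the aliases dynamically, whereas the file discussed in this thread would list each renamed import explicitly.

```python
import types

# Hypothetical stand-in for optax.tree_utils. The real functions operate on
# pytrees; these toy versions operate on flat lists for illustration only.
tree_utils = types.SimpleNamespace(
    tree_add=lambda a, b: [x + y for x, y in zip(a, b)],
    tree_scalar_mul=lambda s, a: [s * x for x in a],
)


def make_alias_module(source, name, prefix="tree_"):
    """Return a module exposing `source` callables with `prefix` stripped."""
    alias = types.ModuleType(name)
    for attr, value in vars(source).items():
        if attr.startswith(prefix) and callable(value):
            # e.g. tree_add becomes add; the function object itself is shared.
            setattr(alias, attr[len(prefix):], value)
    return alias


tree = make_alias_module(tree_utils, "tree")
print(tree.add([1, 2], [3, 4]))
print(tree.scalar_mul(2, [1, 2]))
```

Because the alias module only rebinds names, `tree.add` and `tree_utils.tree_add` are the same object, so existing callers of the prefixed names keep working.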

Would you be interested in contributing?

rdyro avatar Jan 02 '25 10:01 rdyro

Sure! But I am using a Mac. Would that be a problem?

ataa1312 avatar Jan 02 '25 15:01 ataa1312

Not a problem at all

The usual workflow applies: fork the repo, make your changes in your own branch, and then create a pull request on GitHub.

Take a look here: https://github.com/google-deepmind/optax/blob/main/CONTRIBUTING.md for signing the Contributor License Agreement

Let me know if you have any questions!

rdyro avatar Jan 02 '25 16:01 rdyro

I created the file "tree.py" just like in JAX and renamed the functions. However, I am not sure where to put the file. Is it OK to just leave it as "./optax/tree.py"? Also, should I be worried about the following message I received after running test.sh?

************* Module optax.tree
optax/tree.py:3:0: C0301: Line too long (90/80) (line-too-long)
optax/tree.py:14:0: W0622: Redefining built-in 'set' (redefined-builtin)
optax/tree.py:21:0: W0622: Redefining built-in 'max' (redefined-builtin)
optax/tree.py:21:0: W0622: Redefining built-in 'sum' (redefined-builtin)
optax/tree.py:14:0: C0414: Import alias does not rename original package (useless-import-alias)
optax/tree.py:5:0: W0611: Unused tree_cast imported from optax.tree_utils._casting as cast (unused-import)
optax/tree.py:5:0: W0611: Unused tree_dtype imported from optax.tree_utils._casting as dtype (unused-import)
optax/tree.py:9:0: W0611: Unused tree_random_like imported from optax.tree_utils._random as random_like (unused-import)
optax/tree.py:9:0: W0611: Unused tree_split_key_like imported from optax.tree_utils._random as split_key_like (unused-import)
optax/tree.py:14:0: W0611: Unused NamedTupleKey imported from optax.tree_utils._state_utils as NamedTupleKey (unused-import)
optax/tree.py:14:0: W0611: Unused tree_get imported from optax.tree_utils._state_utils as get (unused-import)
optax/tree.py:14:0: W0611: Unused tree_get_all_with_path imported from optax.tree_utils._state_utils as get_all_with_path (unused-import)
optax/tree.py:14:0: W0611: Unused tree_map_params imported from optax.tree_utils._state_utils as map_params (unused-import)
optax/tree.py:14:0: W0611: Unused tree_set imported from optax.tree_utils._state_utils as set (unused-import)
optax/tree.py:21:0: W0611: Unused tree_add imported from optax.tree_utils._tree_math as add (unused-import)
optax/tree.py:21:0: W0611: Unused tree_add_scalar_mul imported from optax.tree_utils._tree_math as add_scalar_mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_bias_correction imported from optax.tree_utils._tree_math as bias_correction (unused-import)
optax/tree.py:21:0: W0611: Unused tree_clip imported from optax.tree_utils._tree_math as clip (unused-import)
optax/tree.py:21:0: W0611: Unused tree_conj imported from optax.tree_utils._tree_math as conj (unused-import)
optax/tree.py:21:0: W0611: Unused tree_div imported from optax.tree_utils._tree_math as div (unused-import)
optax/tree.py:21:0: W0611: Unused tree_full_like imported from optax.tree_utils._tree_math as full_like (unused-import)
optax/tree.py:21:0: W0611: Unused tree_l1_norm imported from optax.tree_utils._tree_math as l1_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_l2_norm imported from optax.tree_utils._tree_math as l2_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_linf_norm imported from optax.tree_utils._tree_math as linf_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_max imported from optax.tree_utils._tree_math as max (unused-import)
optax/tree.py:21:0: W0611: Unused tree_mul imported from optax.tree_utils._tree_math as mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_ones_like imported from optax.tree_utils._tree_math as ones_like (unused-import)
optax/tree.py:21:0: W0611: Unused tree_real imported from optax.tree_utils._tree_math as real (unused-import)
optax/tree.py:21:0: W0611: Unused tree_scalar_mul imported from optax.tree_utils._tree_math as scalar_mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_sub imported from optax.tree_utils._tree_math as sub (unused-import)
optax/tree.py:21:0: W0611: Unused tree_sum imported from optax.tree_utils._tree_math as sum (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_infinity_moment imported from optax.tree_utils._tree_math as update_infinity_moment (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_moment imported from optax.tree_utils._tree_math as update_moment (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_moment_per_elem_norm imported from optax.tree_utils._tree_math as update_moment_per_elem_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_vdot imported from optax.tree_utils._tree_math as vdot (unused-import)
optax/tree.py:21:0: W0611: Unused tree_where imported from optax.tree_utils._tree_math as where (unused-import)
optax/tree.py:21:0: W0611: Unused tree_zeros_like imported from optax.tree_utils._tree_math as zeros_like (unused-import)

Your code has been rated at 9.91/10

The following messages were raised:

  • warning message issued
  • convention message issued

Fatal messages detected. Failing...

ataa1312 avatar Jan 02 '25 18:01 ataa1312

You'll need to disable the unused-import warning for the whole file, and disable the remaining warnings on the specific lines that raise them.

Place the tree.py file in optax/_src and import it in optax/__init__.py like so: from optax._src import tree

A quick reference: https://stackoverflow.com/questions/28829236/is-it-possible-to-ignore-one-single-specific-line-with-pylint
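As a rough sketch of what those pylint directives can look like, the snippet below uses standard-library imports as stand-ins for the optax ones (an assumption made so it runs on its own). A `# pylint: disable=...` comment on its own line at the top of the file applies to the rest of the module, while a trailing comment applies only to that line.

```python
# Module-level directive: silences unused-import for the rest of this file.
# pylint: disable=unused-import

# Stand-in for an alias that shadows a built-in (like `tree_sum as sum`);
# the trailing directive only affects this one line.
from math import fsum as sum  # pylint: disable=redefined-builtin

# Stand-in for an alias that keeps the original name (like
# `NamedTupleKey as NamedTupleKey`), which triggers useless-import-alias.
from operator import add as add  # pylint: disable=useless-import-alias

print(add(2, 3))
```

The exact set of message names to disable depends on what pylint reports for the real file; the names used here are taken from the output quoted earlier in this thread.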

rdyro avatar Jan 03 '25 11:01 rdyro

Thanks for the guidance. I have created the pull request.

ataa1312 avatar Jan 03 '25 21:01 ataa1312

Done in #1306

vroulet avatar Jun 18 '25 18:06 vroulet