Renaming functions in tree_utils
Hi, I was curious if there’s a specific reason why all the functions in tree_utils start with "tree". Are there any plans to simplify the naming, similar to how JAX has done? Thanks, Ata
Adding the tree_ prefix makes the function names a little more explicit. JAX's jax.tree module is just an alias module.
We currently don't have plans to support this, but if you'd like, we'd love a PR! We would need a file similar to https://github.com/jax-ml/jax/blob/main/jax/tree.py
Would you be interested in contributing?
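To illustrate what an alias module amounts to, here is a minimal, self-contained sketch. The two helper functions are hypothetical stand-ins that operate on flat dicts, not the real optax implementations, and the namespace is built with types.ModuleType only so that the example runs without optax installed. An actual tree.py would more likely re-export the existing functions under shorter names.

```python
import types


# Hypothetical stand-ins for prefixed helpers. They only handle flat dicts
# of numbers and are NOT the optax implementations.
def tree_add(tree_a, tree_b):
    """Adds two flat dicts key by key."""
    return {key: tree_a[key] + tree_b[key] for key in tree_a}


def tree_scalar_mul(scalar, tree):
    """Multiplies every value of a flat dict by a scalar."""
    return {key: scalar * value for key, value in tree.items()}


# The alias namespace: the same function objects under unprefixed names.
tree = types.ModuleType("tree")
tree.add = tree_add
tree.scalar_mul = tree_scalar_mul

# Both spellings call the same function.
result = tree.add({"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0})
scaled = tree.scalar_mul(2.0, {"w": 1.5})
```

Because the aliases point at the same function objects, the prefixed names keep working and nothing is duplicated.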
Sure! But I am using a Mac. Would that be a problem?
Not a problem at all
The usual workflow works: fork the repo, make your changes in a branch of your fork, and then open a pull request on GitHub.
See https://github.com/google-deepmind/optax/blob/main/CONTRIBUTING.md for how to sign the Contributor License Agreement.
Let me know if you have any questions!
I created the file "tree.py" just like in JAX and renamed the functions. However, I am not sure where to put the file. Is it OK to just leave it as "./optax/tree.py"? Also, should I be worried about the following messages I received after running test.sh?

************* Module optax.tree
optax/tree.py:3:0: C0301: Line too long (90/80) (line-too-long)
optax/tree.py:14:0: W0622: Redefining built-in 'set' (redefined-builtin)
optax/tree.py:21:0: W0622: Redefining built-in 'max' (redefined-builtin)
optax/tree.py:21:0: W0622: Redefining built-in 'sum' (redefined-builtin)
optax/tree.py:14:0: C0414: Import alias does not rename original package (useless-import-alias)
optax/tree.py:5:0: W0611: Unused tree_cast imported from optax.tree_utils._casting as cast (unused-import)
optax/tree.py:5:0: W0611: Unused tree_dtype imported from optax.tree_utils._casting as dtype (unused-import)
optax/tree.py:9:0: W0611: Unused tree_random_like imported from optax.tree_utils._random as random_like (unused-import)
optax/tree.py:9:0: W0611: Unused tree_split_key_like imported from optax.tree_utils._random as split_key_like (unused-import)
optax/tree.py:14:0: W0611: Unused NamedTupleKey imported from optax.tree_utils._state_utils as NamedTupleKey (unused-import)
optax/tree.py:14:0: W0611: Unused tree_get imported from optax.tree_utils._state_utils as get (unused-import)
optax/tree.py:14:0: W0611: Unused tree_get_all_with_path imported from optax.tree_utils._state_utils as get_all_with_path (unused-import)
optax/tree.py:14:0: W0611: Unused tree_map_params imported from optax.tree_utils._state_utils as map_params (unused-import)
optax/tree.py:14:0: W0611: Unused tree_set imported from optax.tree_utils._state_utils as set (unused-import)
optax/tree.py:21:0: W0611: Unused tree_add imported from optax.tree_utils._tree_math as add (unused-import)
optax/tree.py:21:0: W0611: Unused tree_add_scalar_mul imported from optax.tree_utils._tree_math as add_scalar_mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_bias_correction imported from optax.tree_utils._tree_math as bias_correction (unused-import)
optax/tree.py:21:0: W0611: Unused tree_clip imported from optax.tree_utils._tree_math as clip (unused-import)
optax/tree.py:21:0: W0611: Unused tree_conj imported from optax.tree_utils._tree_math as conj (unused-import)
optax/tree.py:21:0: W0611: Unused tree_div imported from optax.tree_utils._tree_math as div (unused-import)
optax/tree.py:21:0: W0611: Unused tree_full_like imported from optax.tree_utils._tree_math as full_like (unused-import)
optax/tree.py:21:0: W0611: Unused tree_l1_norm imported from optax.tree_utils._tree_math as l1_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_l2_norm imported from optax.tree_utils._tree_math as l2_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_linf_norm imported from optax.tree_utils._tree_math as linf_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_max imported from optax.tree_utils._tree_math as max (unused-import)
optax/tree.py:21:0: W0611: Unused tree_mul imported from optax.tree_utils._tree_math as mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_ones_like imported from optax.tree_utils._tree_math as ones_like (unused-import)
optax/tree.py:21:0: W0611: Unused tree_real imported from optax.tree_utils._tree_math as real (unused-import)
optax/tree.py:21:0: W0611: Unused tree_scalar_mul imported from optax.tree_utils._tree_math as scalar_mul (unused-import)
optax/tree.py:21:0: W0611: Unused tree_sub imported from optax.tree_utils._tree_math as sub (unused-import)
optax/tree.py:21:0: W0611: Unused tree_sum imported from optax.tree_utils._tree_math as sum (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_infinity_moment imported from optax.tree_utils._tree_math as update_infinity_moment (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_moment imported from optax.tree_utils._tree_math as update_moment (unused-import)
optax/tree.py:21:0: W0611: Unused tree_update_moment_per_elem_norm imported from optax.tree_utils._tree_math as update_moment_per_elem_norm (unused-import)
optax/tree.py:21:0: W0611: Unused tree_vdot imported from optax.tree_utils._tree_math as vdot (unused-import)
optax/tree.py:21:0: W0611: Unused tree_where imported from optax.tree_utils._tree_math as where (unused-import)
optax/tree.py:21:0: W0611: Unused tree_zeros_like imported from optax.tree_utils._tree_math as zeros_like (unused-import)

Your code has been rated at 9.91/10
The following messages were raised:
- warning message issued
- convention message issued
Fatal messages detected. Failing...
You'll need to disable the unused-import warning for the whole file, and disable the other warnings (redefined-builtin, useless-import-alias, line-too-long) on the specific lines that trigger them.
Place the tree.py file in optax/_src and import it in optax/__init__.py like so: from optax._src import tree
A quick reference: https://stackoverflow.com/questions/28829236/is-it-possible-to-ignore-one-single-specific-line-with-pylint
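As a rough sketch of where those directives go, the snippet below uses standard-library imports as hypothetical stand-ins for the optax.tree_utils imports, so it runs without optax installed. A standalone "# pylint: disable=..." comment at module level applies from that line to the end of the file, while a trailing comment applies only to its own line.

```python
"""Hypothetical alias module showing where the pylint directives go."""

# File-level: silences W0611 (unused-import) for every import below.
# pylint: disable=unused-import

# Line-level: silences W0622 only for the alias that shadows a built-in.
from math import fsum as sum  # pylint: disable=redefined-builtin

# Line-level: silences C0414 for an alias that keeps the original name.
from math import floor as floor  # pylint: disable=useless-import-alias

# Inside this module, `sum` now refers to math.fsum.
total = sum([0.5, 0.25, 0.25])
```

Keeping redefined-builtin and useless-import-alias as per-line disables means pylint still reports them anywhere else in the file where they are not intended.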
Thanks for the guidance. I have created the pull request.
Done in #1306