Carlos Martin
I also forgot to mention that the current key-sorting approach raises a `TypeError` for keys that cannot be compared with each other, such as a `str` and an `int`:

```sh
$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 1: None}))"
TypeError: '
```
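To illustrate the failure mode in plain Python (no JAX required), here is a simplified, one-level model of dict flattening. `sorted_flatten`, `ordered_flatten`, and `ordered_unflatten` are hypothetical names written for this comment, not JAX APIs, and values are treated directly as leaves. The point is only that sorting needs `<` between every pair of keys, while recording the keys in insertion order needs nothing beyond hashability.

```python
from collections.abc import Hashable
from typing import Any, List, Tuple


def sorted_flatten(d: dict) -> Tuple[List[Any], List[Hashable]]:
    """Sort-based flatten (simplified): every pair of keys must support '<'."""
    keys = sorted(d)  # raises TypeError for mixed keys such as 'b' and 1
    return [d[k] for k in keys], keys


def ordered_flatten(d: dict) -> Tuple[List[Any], List[Hashable]]:
    """Hypothetical alternative: record keys in insertion order, never compare them."""
    keys = list(d)
    return [d[k] for k in keys], keys


def ordered_unflatten(keys: List[Hashable], leaves: List[Any]) -> dict:
    """Rebuild the dict using the recorded key order."""
    return dict(zip(keys, leaves))


mixed = {'b': None, 1: None}
try:
    sorted_flatten(mixed)
except TypeError as err:
    print("sort-based flatten failed:", err)

leaves, keys = ordered_flatten(mixed)
print(keys)                             # ['b', 1]
print(ordered_unflatten(keys, leaves))  # {'b': None, 1: None}
```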
@yashk2810 Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?
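One possible direction, sketched under the assumption that tree structures are compared by key *set* rather than by key *sequence*: keep each dict's insertion order, and when several trees are mapped together, look up the other trees' values by key and emit the result in the first tree's order. `ordered_tree_map` is a hypothetical pure-Python illustration over nested dicts, not JAX's actual implementation.

```python
from typing import Any, Callable


def ordered_tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Hypothetical multi-tree map over nested dicts.

    Output dicts follow the insertion order of `tree` (the first argument);
    the other trees are matched by key lookup, so keys are never sorted.
    Any non-dict value is treated as a leaf (a simplification).
    """
    if isinstance(tree, dict):
        for other in rest:
            # dict_keys views compare as sets, so insertion order is ignored here.
            if not isinstance(other, dict) or other.keys() != tree.keys():
                raise ValueError("tree structures do not match")
        return {
            key: ordered_tree_map(fn, tree[key], *(other[key] for other in rest))
            for key in tree
        }
    return fn(tree, *rest)


x = {'b': 1, 'a': 2, 3: 10}
y = {3: 20, 'a': 4, 'b': 5}  # same keys, different insertion order
print(ordered_tree_map(lambda u, v: u + v, x, y))  # {'b': 6, 'a': 6, 3: 30}
```

Taking the first tree's order is just one convention; a real implementation would also have to decide whether two dicts with the same keys in different orders count as the same structure, for example for `jax.jit` caching.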
I agree with @anntzer's comments [here](https://github.com/matplotlib/matplotlib/pull/16221#issuecomment-575565228). The current behavior is surprising and not user-friendly.
Any update on this? I have a similar issue with [jax.jit](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html): a dict returned from a jitted function comes back with its keys sorted instead of in insertion order.

```sh
$ python3 -c "import jax; print(jax.jit(lambda: {'b': None, 'a': None})())"
{'a': None, 'b': None}
```
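Until the behavior changes, one user-side workaround is to reorder the returned dict against a reference dict that has the desired key order. `restore_key_order` is a hypothetical helper written for this comment; it is shown on the output printed above rather than by calling JAX directly.

```python
from typing import Any


def restore_key_order(template: Any, value: Any) -> Any:
    """Hypothetical helper: reorder nested dicts in `value` to match `template`.

    `template` only supplies the key order; its leaf values are ignored.
    Non-dict nodes are returned unchanged.
    """
    if isinstance(template, dict) and isinstance(value, dict):
        return {k: restore_key_order(template[k], value[k]) for k in template}
    return value


sorted_result = {'a': None, 'b': None}  # what the jitted call printed
wanted_order = {'b': None, 'a': None}   # the dict the lambda built
print(restore_key_order(wanted_order, sorted_result))  # {'b': None, 'a': None}
```

Because it only walks the template's keys, the helper assumes the result has exactly those keys: a key missing from the result raises `KeyError`, and extra keys are silently dropped.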