jax
jax.jit changes the key order of returned dictionaries
Description
jax.jit changes the key order of returned dictionaries:
$ python3 -c "import jax; print(jax.jit(lambda: {'b': None, 'a': None})())"
{'a': None, 'b': None}
Dictionaries are guaranteed to be ordered since Python 3.7.
Potentially related:
- https://github.com/jax-ml/jax/issues/4085
- https://github.com/jax-ml/jax/issues/8419
- https://github.com/jax-ml/jax/pull/11871
Parenthetically, this is also true for jax.tree.map, as pointed out in the first issue above.
$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 'a': None}))"
{'a': None, 'b': None}
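Until this is resolved, one possible user-level workaround is to reorder the returned dictionary against a reference dictionary that still has the intended key order. This is only a sketch: `restore_key_order` is a hypothetical helper, not a JAX API, it assumes both dictionaries have the same keys, and the `returned` dict below just stands in for what jax.jit gives back.

```python
def restore_key_order(result, reference):
    """Reorder `result` (recursively) to follow the key order of `reference`.

    Hypothetical helper, not part of JAX. Assumes both trees have the same keys.
    """
    if isinstance(result, dict) and isinstance(reference, dict):
        return {k: restore_key_order(result[k], reference[k]) for k in reference}
    # Non-dict values are returned unchanged.
    return result


reference = {'b': None, 'a': None}   # the order the user wrote
returned = {'a': 1, 'b': 2}          # stand-in for a key-sorted result
reordered = restore_key_order(returned, reference)
print(list(reordered))
```

The values are untouched; only the iteration order of the keys changes.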
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.12.7 (main, Oct 1 2024, 02:05:46) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:46 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6031', machine='arm64')
I think this is working as expected, mainly for two reasons:
- We need to sort the dictionary keys; otherwise we get cache misses. For example, jit(f)({'a': 1, 'b': 2}) vs. jit(f)({'b': 2, 'a': 1}): both calls should hit the same cache entry, but they won't if we don't sort.
- Second, and probably most important: in multi-controller JAX, if the key orders differ by mistake, you can get hangs, which are very bad and much harder to debug. If dictionaries are sorted, this problem cannot occur by construction.
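To make the cache argument concrete, here is a toy model of a cache keyed on the dictionary's key layout. It is not how JAX's cache is actually implemented; `structure_key` and `count_cache_misses` are made-up names for illustration only.

```python
# Toy model of a compilation cache keyed on dictionary structure.
# Nothing here is JAX's real implementation; all names are hypothetical.

def structure_key(d, sort_keys):
    """Return a hashable description of the dict's key layout."""
    keys = sorted(d) if sort_keys else list(d)
    return tuple(keys)


def count_cache_misses(calls, sort_keys):
    """Count how many calls would need a fresh cache entry."""
    cache = set()
    misses = 0
    for args in calls:
        key = structure_key(args, sort_keys)
        if key not in cache:
            cache.add(key)
            misses += 1
    return misses


calls = [{'a': 1, 'b': 2}, {'b': 2, 'a': 1}]
misses_sorted = count_cache_misses(calls, sort_keys=True)
misses_unsorted = count_cache_misses(calls, sort_keys=False)
print(misses_sorted, misses_unsorted)  # 1 2
```

With sorting, both calls map to the same key ('a', 'b'); without it, the second call looks like a new structure and misses.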
Also forgot to add that the current key-sorting approach causes errors for incomparable keys:
$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 1: None}))"
TypeError: '<' not supported between instances of 'int' and 'str'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 342, in tree_map
leaves, treedef = tree_flatten(tree, is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 79, in tree_flatten
return default_registry.flatten(tree, is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Comparator raised exception while sorting pytree dictionary keys.
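The root cause can be reproduced in plain Python without JAX: sorting a list that mixes int and str keys raises the same TypeError. The sketch below also shows one conceivable fallback, sorting by type name first so that only keys of the same type are compared. `fallback_sort_key` is a hypothetical helper, not something JAX provides.

```python
# Plain-Python illustration; `fallback_sort_key` is a hypothetical helper,
# not part of JAX.

keys = ['b', 1]

try:
    sorted(keys)
    direct_sort_ok = True
except TypeError:
    # '<' is not defined between int and str, which is the root error
    # shown in the traceback.
    direct_sort_ok = False


def fallback_sort_key(key):
    """Group keys by type name first, so only like types are compared."""
    return (type(key).__qualname__, key)


ordered = sorted(keys, key=fallback_sort_key)
print(direct_sort_ok, ordered)  # False [1, 'b']
```

This still yields a deterministic order for mixed-type keys, though keys of a single type that are not comparable with each other would fail in the same way.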
xref #15358 for the unorderable keys issue.
@yashk2810 Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?
> Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?
FYI:
- https://github.com/jax-ml/jax/issues/4085#issuecomment-1475788704
One solution is to store the input dict keys in insertion order in Node during flatten, and update the PyTreeDef.unflatten method to respect the key order while reconstructing the output pytree.
leaves, treedef = jax.tree_util.tree_flatten({'b': 2, 'a': 1})
leaves  # [1, 2]
treedef  # PyTreeDef({'a': *, 'b': *})
treedef.unflatten([11, 22])  # {'b': 22, 'a': 11}  # respect original key order
Ref:
- metaopt/optree#45
- metaopt/optree#46
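The proposal above can be sketched in pure Python for a single-level dict. This is a toy model under stated assumptions, not JAX or optree code: `flatten`, `unflatten`, and the treedef layout are all hypothetical.

```python
# Toy flatten/unflatten for a flat dict. A sketch of the idea only;
# the function names and treedef layout are hypothetical.

def flatten(d):
    """Return leaves in sorted-key order plus a treedef remembering both orders."""
    sorted_keys = sorted(d)
    leaves = [d[k] for k in sorted_keys]
    treedef = {'sorted_keys': sorted_keys, 'original_keys': list(d)}
    return leaves, treedef


def unflatten(treedef, leaves):
    """Rebuild the dict, emitting keys in their original insertion order."""
    by_key = dict(zip(treedef['sorted_keys'], leaves))
    return {k: by_key[k] for k in treedef['original_keys']}


leaves, treedef = flatten({'b': 2, 'a': 1})
rebuilt = unflatten(treedef, [11, 22])
print(leaves)                 # [1, 2]
print(list(rebuilt.items()))  # [('b', 22), ('a', 11)]
```

The leaves are still produced in sorted-key order, so anything keyed on `sorted_keys` alone would treat {'a': 1, 'b': 2} and {'b': 2, 'a': 1} as the same structure; only the reconstruction step consults the insertion order.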
+1 for this issue. All other tree manipulation libraries keep the key order (tf.nest, dm-tree, optree, ...).