
jax.jit changes the key order of returned dictionaries

Open carlosgmartin opened this issue 1 year ago • 5 comments

Description

jax.jit changes the key order of returned dictionaries:

$ python3 -c "import jax; print(jax.jit(lambda: {'b': None, 'a': None})())"
{'a': None, 'b': None}

Since Python 3.7, dictionaries have been guaranteed to preserve insertion order.
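Here is a quick pure-Python check of that guarantee. It does not involve JAX. Two dicts with the same items compare equal regardless of order, but iteration (and therefore printing) follows insertion order. That difference is what the jit output above does not preserve.

```python
# Python 3.7+ guarantees that dicts iterate in insertion order.
d = {'b': None, 'a': None}

print(list(d))    # ['b', 'a']  (insertion order)
print(sorted(d))  # ['a', 'b']  (the order seen in the jit output above)

# Equality ignores order, so the reordering is invisible to ==,
# but visible to anything that iterates over or prints the dict.
print(d == {'a': None, 'b': None})              # True
print(list(d) == list({'a': None, 'b': None}))  # False
```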

Potentially related:

  • https://github.com/jax-ml/jax/issues/4085
  • https://github.com/jax-ml/jax/issues/8419
  • https://github.com/jax-ml/jax/pull/11871

Incidentally, the same happens with jax.tree.map, as pointed out in the first issue above:

$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 'a': None}))"
{'a': None, 'b': None}
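In cases like tree.map, where the input dict is available as a reference, one user-level workaround is to rebuild the result in the reference's key order. This is only a sketch. `reorder_like` is a hypothetical helper, not a JAX API, and it assumes the result and the reference have the same keys at every nesting level.

```python
from typing import Any

def reorder_like(out: Any, ref: Any) -> Any:
    """Hypothetical helper: rebuild dicts in `out` using the key order of `ref`.

    Recurses into nested dicts; non-dict values are returned unchanged.
    Assumes `out` and `ref` have the same dict keys at every level.
    Note: rebuilt dicts are plain `dict`s, so dict subclasses are not preserved.
    """
    if isinstance(out, dict) and isinstance(ref, dict):
        return {k: reorder_like(out[k], ref[k]) for k in ref}
    return out

ref = {'b': None, 'a': {'y': 1, 'x': 2}}
# Stand-in for a result whose keys came back sorted:
out = {'a': {'x': 2, 'y': 1}, 'b': None}

fixed = reorder_like(out, ref)
print(list(fixed), list(fixed['a']))  # ['b', 'a'] ['y', 'x']
```

With JAX, one would presumably call it as `reorder_like(jax.tree.map(f, d), d)`. This only restores the order at the user level. It does not change how JAX flattens the dict.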

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.12.7 (main, Oct  1 2024, 02:05:46) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:46 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6031', machine='arm64')

carlosgmartin, Oct 18 '24 23:10

I think this is working as expected.

Mainly for two reasons:

  • First, we need to sort the dictionary keys; otherwise we get cache misses. For example, jit(f)({'a': 1, 'b': 2}) and jit(f)({'b': 2, 'a': 1}) should both hit the same cache entry, but without sorting they would not.

  • Second, and probably most important: in multi-controller JAX, if the key orders accidentally differ (for example, between processes), you can get hangs, which are very bad and much harder to debug. If dictionary keys are always sorted, this problem cannot occur by construction.
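Both points can be illustrated with a toy model in plain Python. `structure_key`, `leaf_order`, and `count_compilations` are hypothetical stand-ins for a pytree structure used as a cache key, for the flattened leaf order, and for a compilation cache. They are not JAX internals. The thread does not spell out the exact hang mechanism, so the second half only shows that the leaf order diverges when keys are not sorted.

```python
from typing import Any

def structure_key(tree: dict[str, Any], sort_keys: bool) -> tuple[str, ...]:
    # Toy stand-in for a pytree structure used as a compilation-cache key.
    return tuple(sorted(tree)) if sort_keys else tuple(tree)

def leaf_order(tree: dict[str, Any], sort_keys: bool) -> list[Any]:
    # Toy stand-in for the order in which leaves are handed to the program.
    return [tree[k] for k in structure_key(tree, sort_keys)]

def count_compilations(calls: list[dict[str, Any]], sort_keys: bool) -> int:
    seen: set[tuple[str, ...]] = set()
    for args in calls:
        seen.add(structure_key(args, sort_keys))  # each new key = one "compile"
    return len(seen)

x = {'a': 1, 'b': 2}
y = {'b': 2, 'a': 1}  # same contents, different insertion order

# Point 1: cache hits.
print(count_compilations([x, y], sort_keys=False))  # 2 (cache miss on y)
print(count_compilations([x, y], sort_keys=True))   # 1

# Point 2: if two processes built x and y respectively, their leaf orders
# only agree when keys are sorted.
print(leaf_order(x, False), leaf_order(y, False))  # [1, 2] [2, 1]
print(leaf_order(x, True), leaf_order(y, True))    # [1, 2] [1, 2]
```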

yashk2810, Oct 18 '24 23:10

I also forgot to mention that the current key-sorting approach fails for incomparable keys:

$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 1: None}))"
TypeError: '<' not supported between instances of 'int' and 'str'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 342, in tree_map
    leaves, treedef = tree_flatten(tree, is_leaf)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 79, in tree_flatten
    return default_registry.flatten(tree, is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Comparator raised exception while sorting pytree dictionary keys.
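The root cause is plain Python 3 behavior. No ordering is defined between int and str, so any sort over these keys fails. Iterating in insertion order requires no comparisons at all. The snippet below is a pure-Python illustration and does not use JAX.

```python
# Python 3 defines no ordering between int and str, so sorting mixed keys fails.
keys = ['b', 1]
try:
    sorted(keys)
except TypeError as exc:
    print(exc)  # '<' not supported between instances of 'int' and 'str'

# Insertion-order iteration needs no comparisons, so the same keys are fine:
d = {'b': None, 1: None}
print(list(d))  # ['b', 1]
```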

carlosgmartin, Oct 18 '24 23:10

xref #15358 for the unorderable keys issue.

jakevdp, Oct 19 '24 02:10

@yashk2810 Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?

carlosgmartin, Oct 19 '24 06:10

> Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?

FYI:

  • https://github.com/jax-ml/jax/issues/4085#issuecomment-1475788704

One solution is to record the input dict's keys in insertion order in the Node during flattening, and to update the PyTreeDef.unflatten method to respect that order when reconstructing the output pytree.

import jax
leaves, treedef = jax.tree_util.tree_flatten({'b': 2, 'a': 1})
leaves   # [1, 2]
treedef  # PyTreeDef({'a': *, 'b': *})
treedef.unflatten([11, 22])  # proposed: {'b': 22, 'a': 11}, respecting the original key order (currently {'a': 11, 'b': 22})
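Below is a toy, one-level version of that proposal in plain Python. `ToyTreeDef` and `toy_flatten` are hypothetical names and only handle flat dicts. This is not JAX's PyTreeDef or optree's implementation. Leaves are still laid out in sorted key order, and equality and hashing use that order, so the cache-hit property from the earlier comment is kept. `unflatten` rebuilds the dict using the recorded insertion order.

```python
from typing import Any

class ToyTreeDef:
    """Hypothetical one-level treedef for flat dicts (not JAX's PyTreeDef)."""

    def __init__(self, keys_in_insertion_order: tuple[Any, ...]):
        self.original_keys = keys_in_insertion_order
        # Sorted order decides leaf positions, so two dicts with the same keys
        # flatten identically whatever order they were built in.
        self.sorted_keys = tuple(sorted(keys_in_insertion_order))

    def unflatten(self, leaves: list[Any]) -> dict[Any, Any]:
        by_key = dict(zip(self.sorted_keys, leaves))
        # Rebuild using the recorded insertion order instead of sorted order.
        return {k: by_key[k] for k in self.original_keys}

    def __eq__(self, other: object) -> bool:
        # Equality (and hashing, i.e. cache lookups) ignores insertion order.
        return isinstance(other, ToyTreeDef) and self.sorted_keys == other.sorted_keys

    def __hash__(self) -> int:
        return hash(self.sorted_keys)

def toy_flatten(tree: dict[Any, Any]) -> tuple[list[Any], ToyTreeDef]:
    treedef = ToyTreeDef(tuple(tree))
    return [tree[k] for k in treedef.sorted_keys], treedef

leaves, treedef = toy_flatten({'b': 2, 'a': 1})
print(leaves)                                       # [1, 2]
print(treedef.unflatten([11, 22]))                  # {'b': 22, 'a': 11}
print(treedef == toy_flatten({'a': 1, 'b': 2})[1])  # True
```

The sketch leaves two questions open. It still sorts, so it does not address the incomparable-key error. Also, if equal treedefs share a cache entry, a cached output structure would carry the insertion order of whichever call populated the cache.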

Ref:

  • metaopt/optree#45
  • metaopt/optree#46

XuehaiPan, Oct 19 '24 17:10

+1 for this issue; all other tree manipulation libraries keep the key order (tf.nest, dm-tree, optree, ...).

Conchylicultor, Oct 28 '24 10:10