
Fix handling of `OrderedDict` with optree and related documentation.

Open hertschuh opened this issue 1 year ago • 3 comments

The tree API had specific but contradictory documentation about the handling of OrderedDicts. Moreover, the optree implementation of flatten did not honor this documentation (it used the insertion order of the keys rather than the sorted key order), although the optree implementation of pack_sequence_as did. As a result, flatten behaved differently with optree and dm-tree, and pack_sequence_as(flatten(...)) also failed to reproduce the original structure. The optree implementation already had all the machinery needed to handle OrderedDicts per the spec. pack_sequence_as used it, but flatten did not. This change also fixes the discrepancy in behavior for namedtuples.
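The round-trip failure can be illustrated with a minimal pure-Python sketch. The helpers below are hypothetical stand-ins, not keras.tree functions. They assume, as described above, that flatten walked the OrderedDict in key-insertion order while pack_sequence_as assigned leaves in sorted-key order, and they only handle a single, non-nested OrderedDict.

```python
from collections import OrderedDict


def flatten_insertion_order(d):
    """Hypothetical flatten: emits OrderedDict values in insertion order."""
    return list(d.values())


def flatten_sorted_order(d):
    """Hypothetical flatten: emits values in sorted-key order."""
    return [d[k] for k in sorted(d)]


def pack_sorted_order(template, flat):
    """Hypothetical pack_sequence_as: assigns leaves to keys in sorted-key
    order, then rebuilds the mapping in the template's own key order."""
    assigned = dict(zip(sorted(template), flat))
    return OrderedDict((k, assigned[k]) for k in template)


od = OrderedDict([("c", 3), ("a", 1), ("b", 2)])

# Mismatched conventions: the round trip scrambles the values.
broken = pack_sorted_order(od, flatten_insertion_order(od))
print(list(broken.items()))  # [('c', 2), ('a', 3), ('b', 1)]

# Matching conventions: the round trip restores the original.
fixed = pack_sorted_order(od, flatten_sorted_order(od))
print(fixed == od)  # True
```

The point of the sketch is that either ordering convention is self-consistent; the bug only appears when flatten and pack_sequence_as disagree.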

  • Fixed contradictory documentation in flatten and pack_sequence_as related to the handling of OrderedDicts.
  • Fixed references to unflatten_as, which doesn't exist.
  • Removed most of the optree-specific tests in tree_test.py; such tests should not exist, since optree and dm-tree must behave consistently.
  • Fixed unit tests which were incorrectly flattening the result of flatten_with_path.
  • Fixed unintended use of tree instead of keras.tree in a unit test.
  • Ran unit tests for all backends with dm-tree uninstalled.
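To see why flattening the result of flatten_with_path is a mistake, here is a small pure-Python sketch. The helpers are hypothetical, not the keras.tree implementation; they assume that flatten_with_path returns a list of (path, leaf) pairs where each path is a tuple of dict keys or sequence indices, and that dict keys are visited in sorted order.

```python
def flatten_with_path(structure, path=()):
    """Hypothetical flatten_with_path: returns (path, leaf) pairs, where
    path is a tuple of dict keys / sequence indices (dict keys sorted)."""
    if isinstance(structure, dict):
        items = [(k, structure[k]) for k in sorted(structure)]
    elif isinstance(structure, (list, tuple)):
        items = list(enumerate(structure))
    else:
        return [(path, structure)]
    result = []
    for key, child in items:
        result.extend(flatten_with_path(child, path + (key,)))
    return result


def flatten(structure):
    """Hypothetical flatten: leaves only, same traversal order."""
    return [leaf for _, leaf in flatten_with_path(structure)]


nested = {"b": {"c": 2}, "a": 1}
pairs = flatten_with_path(nested)
print(pairs)  # [(('a',), 1), (('b', 'c'), 2)]

# Flattening the (path, leaf) pairs again treats the path tuples as
# structure too, so path keys leak into the "leaves".
print(flatten(pairs))  # ['a', 1, 'b', 'c', 2]
```

A test that wants the leaves should read them from the pairs (`[leaf for _, leaf in pairs]`) instead of flattening the pairs again.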

hertschuh, Nov 11 '24

Codecov Report

Attention: Patch coverage is 93.25397% with 17 lines in your changes missing coverage. Please review.

Project coverage is 82.30%. Comparing base (461fbf3) to head (012e336).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/tree/dmtree_impl.py | 96.92% | 3 Missing and 3 partials |
| keras/src/tree/optree_impl.py | 81.81% | 3 Missing and 3 partials |
| keras/src/tree/tree_api.py | 66.66% | 3 Missing and 1 partial |
| keras/api/_tf_keras/keras/tree/__init__.py | 0.00% | 1 Missing |
@@            Coverage Diff             @@
##           master   #20481      +/-   ##
==========================================
+ Coverage   82.22%   82.30%   +0.07%     
==========================================
  Files         515      515              
  Lines       48166    48222      +56     
  Branches     7527     7540      +13     
==========================================
+ Hits        39604    39688      +84     
+ Misses       6744     6720      -24     
+ Partials     1818     1814       -4     
| Flag | Coverage | Relative | Δ |
|---|---|---|---|
| keras | 82.14% | 92.85% | +0.07% |
| keras-jax | 65.20% | 87.69% | +0.07% |
| keras-numpy | 60.21% | 87.30% | +0.08% |
| keras-tensorflow | 66.19% | 92.85% | +0.09% |
| keras-torch | 65.14% | 87.69% | +0.08% |
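As a sanity check, the headline project coverage can be recomputed from the Hits, Misses and Partials rows of the diff. This sketch assumes Codecov's usual definition, coverage = hits / (hits + misses + partials), and its default of rounding down to two decimal places; the function name is hypothetical.

```python
def coverage_bp(hits, misses, partials):
    """Project coverage in basis points (hundredths of a percent), assuming
    coverage = hits / (hits + misses + partials), rounded down."""
    lines = hits + misses + partials
    # Integer floor division implements "round down" without float error.
    return hits * 10_000 // lines


base = coverage_bp(39604, 6744, 1818)  # master: 48166 lines
head = coverage_bp(39688, 6720, 1814)  # PR head: 48222 lines
print(base / 100, head / 100)  # 82.22 82.3
```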



codecov-commenter, Nov 11 '24

Wouldn't it be better to wrap OrderedDict rather than re-implement flatten, which is written in C++ in optree?

e.g.:

```python
from collections import OrderedDict

import optree
from keras import tree


class WrappedOrderedDict(OrderedDict):
  """Marker subclass that gets its own pytree registration."""


def flatten(d):
  # Emit values in sorted-key order; keep the original key order as
  # (hashable) metadata so that unflatten can restore it.
  keys = sorted(d.keys())
  values = [d[key] for key in keys]
  return values, tuple(d.keys()), keys


def unflatten(metadata, children):
  # `children` arrive in sorted-key order; `metadata` is the original order.
  index = {key: i for i, key in enumerate(sorted(metadata))}
  return WrappedOrderedDict((key, children[index[key]]) for key in metadata)


optree.register_pytree_node(
    WrappedOrderedDict,
    flatten,
    unflatten,
    namespace="keras",
)


def ordereddict_pytree_test():
  # Create an OrderedDict with deliberately unsorted keys.
  ordered_d = OrderedDict([('c', 3), ('a', 1), ('b', 2)])

  def wrap(s):
    if isinstance(s, OrderedDict):
      return WrappedOrderedDict(s)
    return None

  def unwrap(s):
    if isinstance(s, WrappedOrderedDict):
      return OrderedDict(s)
    return None

  # Swap every OrderedDict for the wrapped type, then flatten it.
  d = tree.traverse(wrap, ordered_d, top_down=False)
  flat_d = tree.flatten(d)
  flat_d_paths = tree.flatten_with_path(d)
  assert flat_d == [1, 2, 3]
  assert [p[0] for p, v in flat_d_paths] == ["a", "b", "c"]

  # Round trip: pack the leaves back, then unwrap to plain OrderedDicts.
  wrapped_d = tree.pack_sequence_as(d, flat_d)
  orig_d = tree.traverse(unwrap, wrapped_d, top_down=False)
  assert isinstance(orig_d, OrderedDict)
  assert list(orig_d.keys()) == list(ordered_d.keys())
  assert list(orig_d.values()) == list(ordered_d.values())


ordereddict_pytree_test()
```

nicolaspi, Nov 18 '24

> Wouldn't it be better to wrap OrderedDict rather than re-implement flatten, which is written in C++ in optree?

Hi Nicolas,

Thank you for the suggestion. I actually completely scratched this PR and decided to use a different approach. The optree behavior will be the reference behavior. The goal is indeed to maximize the use of the C++ implementation of optree since it is the default and dm-tree is only a fallback.

hertschuh, Nov 18 '24
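The reply above makes the optree behavior the reference. The following pure-Python sketch models that ordering rule, assuming optree's documented default: plain dict values are visited in sorted-key order and OrderedDict values in insertion order. All helpers are hypothetical stand-ins, not keras.tree or optree APIs. Only dicts, OrderedDicts, lists and tuples are treated as structure; namedtuples, defaultdicts and None get no special handling here.

```python
from collections import OrderedDict


def ordered_items(node):
    """(key, value) pairs in the assumed optree-style order."""
    if isinstance(node, OrderedDict):  # check first: OrderedDict subclasses dict
        return list(node.items())  # insertion order
    return [(k, node[k]) for k in sorted(node)]  # plain dict: sorted keys


def flatten(structure):
    """Leaves of nested dicts/lists/tuples, in ordered_items() order."""
    if isinstance(structure, dict):
        return [leaf for _, v in ordered_items(structure) for leaf in flatten(v)]
    if isinstance(structure, (list, tuple)):
        return [leaf for v in structure for leaf in flatten(v)]
    return [structure]


def pack_sequence_as(template, flat):
    """Rebuild template's structure, consuming leaves in flatten() order."""
    leaves = iter(flat)

    def build(node):
        if isinstance(node, dict):
            # Fill values in flatten order...
            filled = {k: build(v) for k, v in ordered_items(node)}
            # ...then restore the template's own key order and mapping type.
            pairs = [(k, filled[k]) for k in node]
            return OrderedDict(pairs) if isinstance(node, OrderedDict) else dict(pairs)
        if isinstance(node, list):
            return [build(v) for v in node]
        if isinstance(node, tuple):
            return tuple(build(v) for v in node)
        return next(leaves)

    return build(template)


plain = {"b": 2, "a": 1}
ordered = OrderedDict([("c", 3), ("a", 1), ("b", 2)])
print(flatten(plain))    # [1, 2]
print(flatten(ordered))  # [3, 1, 2]

nested = {"x": ordered, "w": [plain, (5, 6)]}
rebuilt = pack_sequence_as(nested, flatten(nested))
print(rebuilt == nested)  # True
```

Because pack_sequence_as consumes leaves through the same ordered_items() rule that flatten uses, the round trip holds by construction, which is the property the original PR was trying to restore.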