Fix handling of `OrderedDict` with optree and related documentation.
The tree API had specific but contradictory documentation about the handling of `OrderedDict`s. In addition, the optree implementation of `flatten` did not honor this documentation (it used the key order, not the sequence order), although its implementation of `pack_sequence_as` did. As a result, not only did `flatten` behave differently with optree and dm-tree, but `pack_sequence_as(structure, flatten(structure))` also did not round-trip back to the original structure. The optree implementation already had all the machinery needed to handle `OrderedDict`s per spec; it was used for `pack_sequence_as`, but not for `flatten`. This also fixes the discrepancy in behavior for namedtuples.
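Here is a minimal stdlib-only sketch of why the round trip breaks when `flatten` and `pack_sequence_as` disagree on ordering. The helpers `flatten_sorted`, `flatten_insertion`, and `pack_sorted` are hypothetical. They are not the optree, dm-tree, or Keras internals, and they make no claim about which library uses which order. They only model "sorted key order" versus "insertion (sequence) order" for a single flat `OrderedDict`.

```python
from collections import OrderedDict


def flatten_sorted(d):
    # Hypothetical helper: values in sorted-key order.
    return [d[k] for k in sorted(d)]


def flatten_insertion(d):
    # Hypothetical helper: values in insertion (sequence) order.
    return list(d.values())


def pack_sorted(structure, flat):
    # Hypothetical helper: assign `flat` to keys in sorted-key order, then
    # rebuild the mapping in the structure's own key order.
    by_key = dict(zip(sorted(structure), flat))
    return OrderedDict((k, by_key[k]) for k in structure)


od = OrderedDict([("c", 3), ("a", 1), ("b", 2)])

# Same order on both sides: packing the flattened values restores `od`.
roundtrip = pack_sorted(od, flatten_sorted(od))

# Different orders: the keys survive but the values are silently permuted.
broken = pack_sorted(od, flatten_insertion(od))

print(roundtrip == od)  # True
print(list(broken.items()))  # [('c', 2), ('a', 3), ('b', 1)]
```

No error is raised in the mismatched case, which is why this kind of discrepancy is easy to miss without a round-trip test.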
- Fixed contradictory documentation in `flatten` and `pack_sequence_as` related to the handling of `OrderedDict`s.
- Fixed references to `unflatten_as`, which doesn't exist.
- Removed most `if optree` tests in `tree_test.py`, which should not exist for consistency between `optree` and `dm-tree`.
- Fixed unit tests which were incorrectly flattening the result of `flatten_with_path`.
- Fixed unintended use of `tree` instead of `keras.tree` in a unit test.
- Ran unit tests for all backends with `dm-tree` uninstalled.
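One of the test fixes above concerns flattening the output of `flatten_with_path`, which returns a list of `(path, leaf)` pairs rather than a nested structure. The stdlib-only sketch below uses hypothetical mini versions of `flatten_with_path` and `flatten`. They are not the Keras implementations, and dicts are visited in sorted-key order as an assumption. The sketch shows why flattening that output again is wrong: the path tuples are treated as structure, so path components end up mixed in with the leaves.

```python
def flatten_with_path(structure, path=()):
    # Hypothetical mini version: returns a list of (path, leaf) pairs,
    # visiting dict keys in sorted order.
    if isinstance(structure, dict):
        pairs = []
        for key in sorted(structure):
            pairs.extend(flatten_with_path(structure[key], path + (key,)))
        return pairs
    if isinstance(structure, (list, tuple)):
        pairs = []
        for i, item in enumerate(structure):
            pairs.extend(flatten_with_path(item, path + (i,)))
        return pairs
    return [(path, structure)]


def flatten(structure):
    # Hypothetical mini version: leaves only, same traversal order.
    return [leaf for _, leaf in flatten_with_path(structure)]


nested = {"b": [2, 3], "a": 1}
pairs = flatten_with_path(nested)
print(pairs)  # [(('a',), 1), (('b', 0), 2), (('b', 1), 3)]

# Flattening the (path, leaf) pairs again treats the path tuples as
# structure, so path components leak into the "leaves".
print(flatten(pairs))  # ['a', 1, 'b', 0, 2, 'b', 1, 3]
```

A test that wants the leaves should take the second element of each pair, or call `flatten` on the original structure.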
Codecov Report
Attention: Patch coverage is 93.25397% with 17 lines in your changes missing coverage. Please review.
Project coverage is 82.30%. Comparing base (461fbf3) to head (012e336).
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master   #20481      +/-   ##
==========================================
+ Coverage   82.22%   82.30%   +0.07%
==========================================
  Files         515      515
  Lines       48166    48222      +56
  Branches     7527     7540      +13
==========================================
+ Hits        39604    39688      +84
+ Misses       6744     6720      -24
+ Partials     1818     1814       -4
```
| Flag | Coverage Δ | |
|---|---|---|
| keras | 82.14% <92.85%> (+0.07%) | :arrow_up: |
| keras-jax | 65.20% <87.69%> (+0.07%) | :arrow_up: |
| keras-numpy | 60.21% <87.30%> (+0.08%) | :arrow_up: |
| keras-tensorflow | 66.19% <92.85%> (+0.09%) | :arrow_up: |
| keras-torch | 65.14% <87.69%> (+0.08%) | :arrow_up: |
Flags with carried forward coverage won't be shown.
Wouldn't it be better to wrap `OrderedDict` rather than re-implement `flatten`, which is written in C++ in optree?
e.g.:
```python
from collections import OrderedDict

import optree
from keras import tree


class WrappedOrderedDict(OrderedDict):
    """Marker subclass so the custom flatten/unflatten below is used."""


def flatten(d):
    # Children are emitted in sorted-key order; the original key order is
    # kept as metadata so that unflatten can restore it.
    values = []
    keys = []
    for key in sorted(d.keys()):
        values.append(d[key])
        keys.append(key)
    # (children, metadata, path entries)
    return values, list(d.keys()), keys


def unflatten(metadata, children):
    # `children` arrive in sorted-key order; rebuild the mapping in the
    # original key order stored in `metadata`, as the registered type.
    index = {key: i for i, key in enumerate(sorted(metadata))}
    return WrappedOrderedDict((key, children[index[key]]) for key in metadata)


optree.register_pytree_node(
    WrappedOrderedDict,
    flatten,
    unflatten,
    namespace="keras",
)


def ordereddict_pytree_test():
    # Create an OrderedDict with deliberately unsorted keys
    ordered_d = OrderedDict([("c", 3), ("a", 1), ("b", 2)])

    def wrap(s):
        if isinstance(s, OrderedDict):
            return WrappedOrderedDict(s)
        return None  # keep the node unchanged

    def unwrap(s):
        if isinstance(s, WrappedOrderedDict):
            return OrderedDict(s)
        return None  # keep the node unchanged

    d = tree.traverse(wrap, ordered_d, top_down=False)
    flat_d = tree.flatten(d)
    flat_d_paths = tree.flatten_with_path(d)
    assert flat_d == [1, 2, 3]
    assert [p[0] for p, v in flat_d_paths] == ["a", "b", "c"]

    tree_struct = tree.traverse(wrap, ordered_d, top_down=False)
    wrapped_d = tree.pack_sequence_as(tree_struct, flat_d)
    orig_d = tree.traverse(unwrap, wrapped_d, top_down=False)
    assert isinstance(orig_d, OrderedDict)
    assert list(orig_d.keys()) == list(ordered_d.keys())
    assert list(orig_d.values()) == list(ordered_d.values())


ordereddict_pytree_test()
```
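The custom node functions can also be exercised on their own, without optree or Keras installed. This is a minimal sketch that assumes optree's convention that the flatten function returns `(children, metadata, entries)` and that the unflatten function is called as `unflatten(metadata, children)`. The names `flatten_ordered` and `unflatten_ordered` are hypothetical renamings of the `flatten`/`unflatten` pair in the suggestion. The sketch checks that the children come out in sorted-key order while the rebuilt mapping keeps the original key order.

```python
from collections import OrderedDict


def flatten_ordered(d):
    # Mirrors the suggested `flatten`: children in sorted-key order,
    # original key order as metadata, sorted keys as path entries.
    keys = sorted(d)
    return [d[k] for k in keys], list(d), keys


def unflatten_ordered(metadata, children):
    # Mirrors the suggested `unflatten`: map sorted keys back to children,
    # then rebuild in the original key order stored in `metadata`.
    index = {key: i for i, key in enumerate(sorted(metadata))}
    return OrderedDict((key, children[index[key]]) for key in metadata)


od = OrderedDict([("c", 3), ("a", 1), ("b", 2)])
children, metadata, entries = flatten_ordered(od)
print(children)  # [1, 2, 3]
print(entries)  # ['a', 'b', 'c']

rebuilt = unflatten_ordered(metadata, children)
print(rebuilt == od)  # True: same keys, values, and key order
```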
> Wouldn't it be better to wrap `OrderedDict` rather than re-implement `flatten`, which is written in C++ in optree?
Hi Nicolas,
Thank you for the suggestion. I have actually scrapped this PR entirely and decided to take a different approach: the optree behavior will be the reference behavior. The goal is indeed to maximize the use of optree's C++ implementation, since optree is the default and dm-tree is only a fallback.