Return key from `traverse_util.Traversal.iterate()`
Currently, `traverse_util.Traversal.iterate()` only returns the traversed values. In some cases we need access to the traversed keys as well, for example when checking that two `ModelParamTraversal` instances do not overlap in their keyspace (#1135).
Thus, the interface should be updated to:
```python
import abc


class Traversal(abc.ABC):

  @abc.abstractmethod
  def iterate(self, inputs):
    """Iterate over the values selected by this `Traversal`.

    Args:
      inputs: the object that should be traversed.

    Returns:
      An iterator over the traversed ``(key, value)`` pairs.
    """
    pass
```
All uses of `traverse_util.Traversal.iterate()` should be updated accordingly.
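To illustrate why returning keys helps with the overlap check from #1135, here is a minimal sketch in plain Python. It does not use Flax: `Traversal` is a stand-in for the proposed interface, and `FilterTraversal`, `_flatten`, and `overlapping_keys` are hypothetical helpers. Keys are assumed to be `/`-joined path strings over a nested dict of parameters.

```python
import abc
from typing import Any, Callable, Dict, Iterator, Tuple


class Traversal(abc.ABC):
  """Stand-in for the proposed interface (not Flax's actual class)."""

  @abc.abstractmethod
  def iterate(self, inputs) -> Iterator[Tuple[str, Any]]:
    """Yields the ``(key, value)`` pairs selected by this traversal."""


def _flatten(tree: Dict[str, Any], prefix: str = "") -> Iterator[Tuple[str, Any]]:
  """Yields ('/'-joined path, leaf) pairs of a nested dict, in insertion order."""
  for name, sub in tree.items():
    path = f"{prefix}/{name}"
    if isinstance(sub, dict):
      yield from _flatten(sub, path)
    else:
      yield path, sub


class FilterTraversal(Traversal):
  """Hypothetical traversal selecting leaves whose (path, value) match a predicate."""

  def __init__(self, filter_fn: Callable[[str, Any], bool]):
    self._filter_fn = filter_fn

  def iterate(self, inputs):
    for path, value in _flatten(inputs):
      if self._filter_fn(path, value):
        yield path, value


def overlapping_keys(t1: Traversal, t2: Traversal, inputs) -> set:
  """Returns the keys selected by both traversals (an empty set if disjoint)."""
  keys1 = {key for key, _ in t1.iterate(inputs)}
  keys2 = {key for key, _ in t2.iterate(inputs)}
  return keys1 & keys2


# Usage: kernels and biases are disjoint; "everything" overlaps with kernels.
params = {"Dense_0": {"kernel": 1.0, "bias": 0.0}, "Dense_1": {"kernel": 2.0}}
kernels = FilterTraversal(lambda path, _: path.endswith("kernel"))
biases = FilterTraversal(lambda path, _: path.endswith("bias"))
everything = FilterTraversal(lambda path, _: True)
```

With values alone, two traversals that happen to select equal values could not be told apart; the keys make the disjointness check exact.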
Can I work on this?
@Dsantra92 sure, feel free to give it a try. We would have to run it against all our internal tests as well, but we can do that once all our public tests pass.
@marcvanzee To the best of my knowledge, the use of `flax.traverse_util.Traversal` for `flax.optim` is now deprecated. So is there still a need to update the `flax.traverse_util` API? If so, I would like to work on it.
Also, would it be fine if my fix changed the signature of `iterate` to:
```python
import abc


class Traversal(abc.ABC):

  @abc.abstractmethod
  def iterate(self, inputs, return_paths=False):
    """Iterate over the values selected by this `Traversal`.

    Args:
      inputs: the object that should be traversed.
      return_paths: bool. Whether to yield ``(key, value)`` pairs instead of
        bare values.

    Returns:
      An iterator over the traversed values, or over ``(key, value)`` pairs
      if ``return_paths`` is True.
    """
    pass
```
This ensures that all existing implementations, if any, will still work after the fix (an attempt to avoid breaking backwards compatibility). The default value of `return_paths` for every overridden method will be `False`.
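A minimal sketch of how the `return_paths` flag could keep old call sites working, assuming a hypothetical `DictLeaves` subclass that traverses a flat dict (plain Python, not Flax's actual implementation):

```python
import abc
from typing import Any, Dict, Iterator


class Traversal(abc.ABC):
  """Stand-in for the proposed interface with a backwards-compatible flag."""

  @abc.abstractmethod
  def iterate(self, inputs, return_paths=False) -> Iterator[Any]:
    """Yields values, or ``(key, value)`` pairs if ``return_paths`` is True."""


class DictLeaves(Traversal):
  """Hypothetical traversal over the entries of a flat dict, in insertion order."""

  def iterate(self, inputs: Dict[str, Any], return_paths=False):
    for key, value in inputs.items():
      # Old call sites, iterate(inputs), keep receiving bare values.
      yield (key, value) if return_paths else value


params = {"kernel": 1.0, "bias": 0.0}
leaves = DictLeaves()
values = list(leaves.iterate(params))                    # old behavior
pairs = list(leaves.iterate(params, return_paths=True))  # new behavior
```

One trade-off of this design is that the element type of the iterator depends on a runtime flag, so a static type checker cannot distinguish the two cases.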
@cgarciae, is this issue still relevant?
I've never used that API so it's hard to comment. @andsteing WDYT?