Return key from `traverse_util.Traversal.iterate()`

Open andsteing opened this issue 3 years ago • 6 comments

Currently traverse_util.Traversal.iterate() only returns the traversed values. In some cases we need access to the traversed keys as well, for example when checking that two ModelParamTraversal instances do not overlap in their keyspace (#1135).

Thus, the interface should be updated to

  @abc.abstractmethod
  def iterate(self, inputs):
    """Iterate over the values selected by this `Traversal`.

    Args:
      inputs: the object that should be traversed.
    Returns:
      An iterator over the traversed ``(key, value)`` pairs.
    """
    pass

All uses of traverse_util.Traversal.iterate() should then be updated accordingly.
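
As a rough illustration of the idea (not Flax code), here is a minimal self-contained sketch of an `iterate()` that yields `(key, value)` pairs and of the overlap check it would enable. The names `PrefixTraversal`, `_flatten` and `overlapping_keys` are hypothetical, and the sketch assumes keys are '/'-joined paths into a nested dict of parameters.

```python
import abc


def _flatten(tree, prefix=""):
    """Yield ('/a/b', leaf) pairs for every leaf of a nested dict (hypothetical helper)."""
    for name, value in tree.items():
        path = f"{prefix}/{name}"
        if isinstance(value, dict):
            yield from _flatten(value, path)
        else:
            yield path, value


class Traversal(abc.ABC):
    """Stand-in for the abstract interface discussed above, not the Flax class."""

    @abc.abstractmethod
    def iterate(self, inputs):
        """Return an iterator over the traversed (key, value) pairs."""


class PrefixTraversal(Traversal):
    """Hypothetical traversal selecting every leaf whose path starts with `prefix`."""

    def __init__(self, prefix):
        self._prefix = prefix

    def iterate(self, inputs):
        for path, value in _flatten(inputs):
            if path.startswith(self._prefix):
                yield path, value


def overlapping_keys(a, b, inputs):
    """Return the set of keys selected by both traversals."""
    keys_a = {key for key, _ in a.iterate(inputs)}
    keys_b = {key for key, _ in b.iterate(inputs)}
    return keys_a & keys_b


params = {"encoder": {"kernel": 1.0, "bias": 0.0}, "decoder": {"kernel": 2.0}}
enc = PrefixTraversal("/encoder")
dec = PrefixTraversal("/decoder")
disjoint = overlapping_keys(enc, dec, params)  # no shared keys
shared = overlapping_keys(enc, enc, params)    # every encoder key is shared
```

With values alone, two traversals that happen to select equal-valued leaves could not be told apart; comparing keys avoids that ambiguity.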

andsteing avatar Mar 19 '21 20:03 andsteing

Can I work on this?

Dsantra92 avatar Dec 04 '21 19:12 Dsantra92

@Dsantra92 sure, feel free to give it a try. We would have to run it against all our internal tests as well, but we can do that once all our public tests pass.

marcvanzee avatar Dec 06 '21 09:12 marcvanzee

@marcvanzee To the best of my knowledge, the use of flax.traverse_util.Traversal is now deprecated for flax.optim. So is there still a need to update the flax.traverse_util API? If so, I would like to work on it.

Also, would it be fine if my fix changed the signature of iterate to:

  @abc.abstractmethod
  def iterate(self, inputs, return_paths=False):
    """Iterate over the values selected by this `Traversal`.

    Args:
      inputs: the object that should be traversed.
      return_paths: bool. If True, yield ``(key, value)`` pairs instead of
        bare values. Defaults to False.
    Returns:
      An iterator over the traversed values, or over ``(key, value)`` pairs
      if ``return_paths`` is True.
    """
    pass

This will ensure that all existing implementations, if any, still work after the fix (an attempt to avoid a backwards-incompatible change). The default value of return_paths in every overriding method will be False.
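
To make the backwards-compatibility argument concrete, here is a small self-contained sketch, assuming a flat dict as input. `TopLevelTraversal` is a hypothetical subclass used only for illustration, and `Traversal` below is a stand-in, not the Flax class.

```python
import abc


class Traversal(abc.ABC):
    """Stand-in for the proposed interface with an opt-in flag."""

    @abc.abstractmethod
    def iterate(self, inputs, return_paths=False):
        """Iterate over values, or over (key, value) pairs if return_paths is True."""


class TopLevelTraversal(Traversal):
    """Hypothetical traversal over the top-level entries of a dict."""

    def iterate(self, inputs, return_paths=False):
        for key, value in inputs.items():
            # Existing callers never pass return_paths, so they keep getting bare values.
            yield (key, value) if return_paths else value


params = {"kernel": 1.0, "bias": 0.0}
traversal = TopLevelTraversal()
values_only = list(traversal.iterate(params))
with_paths = list(traversal.iterate(params, return_paths=True))
```

One trade-off of this design is that the element type of the iterator depends on a flag, so callers that need keys must remember to pass `return_paths=True`.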

mayureshagashe2105 avatar Aug 21 '22 18:08 mayureshagashe2105

@cgarciae, is this issue still relevant?

SauravMaheshkar avatar Aug 02 '23 07:08 SauravMaheshkar

I've never used that API so it's hard to comment. @andsteing WDYT?

cgarciae avatar Aug 04 '23 15:08 cgarciae