
mirror `dataclasses.replace` behaviour?

Open · epignatelli opened this issue 2 years ago • 12 comments

Hi Patrick, thanks for equinox -- it's amazing, as it is really easy to jump in without much prior knowledge. I found some patterns very clear, but others harder to pick up without prior knowledge.

equinox.tree_at is one of these. Would you consider mirroring the canonical dataclass interface?

For example,

    # Proposed method on an eqx.Module subclass (assumes `import equinox as eqx`).
    def replace(self, **kwargs):
        values = [kwargs[k] for k in kwargs]
        return eqx.tree_at(lambda x: [getattr(x, k) for k in kwargs], self, values)

I am sure there can be something more efficient than this, but it is enough to explain the idea.

epignatelli, Jul 17 '23

Right! Whilst the lambda approach might be a completely general way to navigate pytrees, in practice 99% of use-cases just involve navigating via item-getting and attr-getting, and folks seem to find the lambda approach a bit confusing.

I'm wondering about generalising what you've written to something like:

import equinox as eqx


class tree_modify:
    """Records a chain of item/attr lookups, then applies `<<` via eqx.tree_at."""

    def __init__(self, tree, path=()):
        self.tree = tree
        self.path = path  # tuple of (kind, item) pairs

    def __getitem__(self, item):
        return tree_modify(self.tree, self.path + (("getitem", item),))

    def __getattr__(self, item):
        return tree_modify(self.tree, self.path + (("getattr", item),))

    def __lshift__(self, value):
        def _get(tree):
            for kind, item in self.path:
                if kind == "getitem":
                    tree = tree[item]
                elif kind == "getattr":
                    tree = getattr(tree, item)
                else:
                    raise ValueError(f"Unknown path kind: {kind!r}")
            return tree
        return eqx.tree_at(_get, self.tree, value)

which you can then use as:

mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
new_mlp = eqx.tree_modify(mlp).layers[-1] << new_linear

Although maybe the << syntax is too magic?

Some kind of nicer interface seems to make sense, anyway.

patrick-kidger, Jul 17 '23

I actually like the lambda approach, once I understood how it works! It just takes longer to grasp than replace does (for me at least), and since equinox's strength is its low learning curve, I think replace might help.

Yeah, I think lshift is probably not a familiar pattern for most ML Python users. What about resorting to jax/python indexing to generalise the above?

import equinox as eqx
from jax import Array
import jax.numpy as jnp


class Foo(eqx.Module):
    bar: Array

    def replace(self, **kwargs):
        values = [kwargs[k] for k in kwargs]
        return eqx.tree_at(lambda x: [getattr(x, k) for k in kwargs], self, values)


bar = jnp.zeros((2,))
foo = Foo(bar=bar)

baz = jnp.asarray(1)
baz = bar.at[0].set(baz)  # JAX's own indexing helper builds the new array
foo = foo.replace(bar=baz)
print(foo.bar)
# [1. 0.]

We can either leave the item-selection strategy to the user, or generalise replace to use jax vs python indexing depending on whether the field is dynamic or static.

epignatelli, Jul 18 '23

Another idea might be to mirror JAX's .at indexing helper and extend it to pytrees.

E.g.,

mlp.at[mlp.layers[-1]].set(new_linear)

or even, since I don't like repeating mlp (and I'm not sure how to recover the index path from mlp.layers[-1] anyway),

mlp.at['layers', -1].set(new_linear)

But I'd find it cumbersome to combine indices of different types.

Perhaps this could even reuse eqx.tree_at, with its logic baked into the IndexingHelper:

class Module:
  @property
  def at(self):
    return eqx.IndexingHelper(self)  # hypothetical helper wrapping eqx.tree_at

What do you think?

epignatelli, Jul 18 '23

Check here for a generalised concept based on lenses, and here for how to integrate it with flax/equinox (or any registered pytree).

ASEM000, Jul 18 '23

Haha, I actually remembered and took a look at your work @ASEM000 when I was writing this. IIUC, there's no way to distinguish between x["foo"] and x.foo: a string always refers to the latter, but this makes it impossible to look up values in a dict[str, Any]. I think @epignatelli's mlp.at['layers', -1].set(new_linear) has the same issue.

I would say the main two criteria are:

  • this should be a free function, not a method. Right now Modules are never special-cased relative to any other pytree -- in particular, they have no extra methods.
  • this should offer a general way to mutate nested leaves, not just at depth-1 like the replace function.

patrick-kidger, Jul 18 '23

The example below should answer the questions, notably:

  1. It works on both dicts and classes: the distinction between a dict key and a class attribute is handled at the path-entry level, defined through JAX registration and accessed using tree_map_with_path.
  2. You don't need a method: you can use pytc.AtIndexer(pytree, where=...).
  3. It allows nesting.
  4. It's not special-cased to any pytree type: it works with any pytree defined using JAX key paths (like an equinox Module). You can check the source.
  5. You can do much more, as outlined in the links above.

If you have any questions let me know.


import pytreeclass as pytc
import equinox as eqx
from typing import Any


class Tree(eqx.Module):
    some_dict: dict[str, Any]
    some_list: list[int]


tree = Tree(some_dict={"a": 1, "b": [1, 2, 3]}, some_list=[1, 2, 3])

print(pytc.AtIndexer(tree)["some_dict"]["a"].get())
# Tree(some_dict={'a': 1, 'b': [None, None, None]}, some_list=[None, None, None])

print(pytc.AtIndexer(tree)["some_dict"]["b"][0].get())
# Tree(some_dict={'a': None, 'b': [1, None, None]}, some_list=[None, None, None])

ASEM000, Jul 18 '23

As a user, I think I'd prefer the lambda approach to the chain of .at calls, because it is more concise and direct.

The original proposal for replace was not to replace the existing tree_at, but to expose the dataclasses interface on equinox.Module instead.

The aim is to remove the need to understand the lambda when a user first approaches equinox, and to let them scale up to the other approach at their own pace.

Also, replace and the lambda would be tailored to different use cases, even though their functionality overlaps. I'd use replace for quick, intuitive, and large updates, and tree_at for more general and more surgical updates.

epignatelli, Jul 19 '23

> As a user, I think I'd prefer the lambda approach to the chain of at because it is more concise and direct.

If you don't like chained .at, you can try this instead; .at is just syntactic sugar.


import pytreeclass as pytc
import equinox as eqx
from typing import Any

class Tree(eqx.Module):
    some_dict: dict[str, Any]
    some_list: list[int]


tree = Tree(some_dict={"a": 1, "b": [1, 2, 3]}, some_list=[1, 2, 3])

print(pytc.AtIndexer(tree, where=["some_dict", "a"]).get())
# Tree(some_dict={'a': 1, 'b': [None, None, None]}, some_list=[None, None, None])

print(pytc.AtIndexer(tree, where=["some_dict", "b", 0]).get())
# Tree(some_dict={'a': None, 'b': [1, None, None]}, some_list=[None, None, None])

  • One of PyTreeClass's objectives is to combine OOP and FP concepts for intuitive yet safe manipulation of nested structures (pytrees). The .at/AtIndexer interface is similar to Haskell lenses, a popular way to do such operations functionally.

  • I think a good way to compare the two approaches is to try to replicate the example in the README here using tree_at.

If you have any questions let me know.

ASEM000, Jul 19 '23

Thanks for the example, @ASEM000. The feature request had different objectives, though -- see my comment above.

epignatelli, Jul 19 '23

Hey @epignatelli, I have had similar thoughts to you about the tree_at syntax. My main issues were how to deal with highly nested classes, how to manipulate multiple nested parameters simultaneously, and how to avoid lambda functions entirely, since my target users are scientists who don't necessarily have a programming background and I'm trying to keep the API as simple as possible.

Anyway, I ended up creating Zodiax as an extension to equinox, which has a similar syntax but is geared towards scientific models. Updating is done like so: foo = foo.set('x', value), foo = foo.multiply('x', value). It has its own downsides, but it does provide a simplified syntactic interface for pytree manipulation if you're curious!

LouisDesdoigts, Aug 04 '23

Thanks @LouisDesdoigts !

I honestly think equinox's approach is the best one for deep and complex updates.

This feature request had a different purpose: to expose the dataclasses.replace method for A) a gentler transition to equinox (which, if I understand correctly, is the main idea behind its philosophy), and B) nicer syntax, intended only for simpler and shallower updates.

The proposal was for the two approaches to coexist, not for one to replace the other.

@patrick-kidger feel free to close this if it's not happening and it helps you keep things tidy!

epignatelli, Aug 09 '23

Right now I don't have any strong feelings either way! I generally prefer not to have simple-use-case-only versions of things, just because it helps upskill novice users into the more general way of doing things. But that's not a hard-and-fast rule.

patrick-kidger, Aug 09 '23