
Should Flax return FrozenDicts or regular dicts?

marcvanzee opened this issue • 23 comments

This topic is discussed regularly internally, and I feel we haven't reached a consensus yet. Below are some arguments collected from users for both positions; feel free to add more.

Arguments in favor of FrozenDict

  • @avital: If you use normal dicts, it is easy to mutate them, which means the behavior may differ depending on whether the function in which the modification is made is jitted or not. Example:
```python
def f(params):
  params['conv1']['weight'] = ...
  return ...some computation over params

params = load_from_checkpoint()
print(f(params))
# now what is the value of params['conv1']['weight']?
# depending on whether f is jitted or not, you'd get different results
```
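A minimal runnable version of that example, assuming a toy `params` dict in place of the hypothetical `load_from_checkpoint()`. Under `jax.jit` the function receives a freshly rebuilt copy of the input pytree, so the in-place write never reaches the caller's dict, while the eager call mutates it:

```python
import jax
import jax.numpy as jnp


def f(params):
    # In-place write into the caller's nested dict: a hidden side effect.
    params['conv1']['weight'] = params['conv1']['weight'] * 2.0
    return jnp.sum(params['conv1']['weight'])


def make_params():
    # Toy stand-in for the hypothetical load_from_checkpoint().
    return {'conv1': {'weight': jnp.ones(3)}}


eager_params = make_params()
f(eager_params)
print(eager_params['conv1']['weight'])  # mutated by the eager call

jit_params = make_params()
jax.jit(f)(jit_params)
print(jit_params['conv1']['weight'])    # untouched: jit traced a rebuilt copy
```

With a `FrozenDict`, the eager write raises an error instead of silently diverging from the jitted behavior, which is the argument being made here.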

Arguments in favor of regular dicts

  • @lucasb-eyer: When Flax tells me "here are these precious weights, please hold them for me and give them back to me later on, but DON'T TOUCH", it begs the question: why give them to me in the first place if I'm not supposed to do anything with them?

  • @avital: I also think it'd be better for Flax to return normal Python dicts, but still use FrozenDict within modules (via the mutable argument to apply).

marcvanzee, Apr 08 '21

I think the Python saying "We're all consenting adults here" is pretty fitting. In my view, trading some safety for convenience is reasonable here, because JAX users should know (or will quickly come to learn) that under the JAX transformations they should not mutate state. Since FrozenDicts are not as ergonomic as normal dicts, I tend to unfreeze them as soon as they're returned from init anyway.

Though I would prefer the user-facing API to just use dict, I wouldn't mind FrozenDict if the behaviour was closer to dict for non-mutating cases. In particular,

  • The ability to pass FrozenDict to flax.traverse_util.flatten_dict.
  • The views FrozenDict.keys() and FrozenDict.values() should have a __repr__ in the same style as those of normal dicts, so they are easy to inspect interactively in notebooks (right now they just show the whole FrozenDict)
  • The ability to update/merge with normal dicts
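The first and last requests amount to accepting any read-only mapping wherever a `dict` is expected. A minimal sketch of that idea, written as a hypothetical standalone `flatten_mapping` (not Flax's `flatten_dict`), with the standard library's `types.MappingProxyType` standing in for a frozen dict:

```python
from collections.abc import Mapping
from types import MappingProxyType


def flatten_mapping(xs, prefix=()):
    """Flatten nested mappings into {path_tuple: leaf}.

    Checks for the abstract Mapping interface rather than `dict`,
    so read-only mappings are accepted as well.
    """
    flat = {}
    for key, value in xs.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            flat.update(flatten_mapping(value, path))
        else:
            flat[path] = value
    return flat


frozen = MappingProxyType({'params': MappingProxyType({'w': 1.0, 'b': 0.0})})
print(flatten_mapping(frozen))  # {('params', 'w'): 1.0, ('params', 'b'): 0.0}

# Merging with a normal dict works through the same Mapping interface.
merged = {**frozen, 'batch_stats': {'mean': 0.0}}
print(sorted(merged))  # ['batch_stats', 'params']
```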

Explicit state management is one of my favourite aspects of Flax, as it gives me the ability to transparently manipulate modules/parameters without worrying about hidden side effects. I totally agree with @lucasb-eyer's point that it's counterproductive to provide explicit state without allowing the user to fully control it.

n2cholas, Apr 12 '21

I think the Python saying "We're all consenting adults here" is pretty fitting

Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.

That said, I don't think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, while it doesn't prevent the more common issues of closing over mutable state (typically created by the user) or using things like np.random in a jitted function.
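A sketch of the two failure modes named above, neither of which a `FrozenDict` return value would catch. In both cases the value is read once while tracing and baked into the compiled function as a constant (the `config` dict is a made-up example):

```python
import jax
import jax.numpy as jnp
import numpy as np

config = {'scale': 2.0}  # mutable state created and owned by the user


@jax.jit
def scaled(x):
    # config['scale'] is read at trace time and baked in as a constant.
    return x * config['scale']


first = scaled(jnp.array(1.0))
config['scale'] = 10.0
second = scaled(jnp.array(1.0))  # same shape/dtype: cached trace, change ignored


@jax.jit
def noisy(x):
    # np.random draws one sample during tracing; it is not resampled per call.
    return x + np.random.normal()


a = noisy(jnp.array(0.0))
b = noisy(jnp.array(0.0))
print(first, second, a == b)
```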

I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:

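```python
import jax

def clone_pytree(xs):
  # cloning is just an identity mapping: tree_map rebuilds every container
  # (dict, list, tuple, ...) while reusing the immutable leaves
  return jax.tree_util.tree_map(lambda x: x, xs)

def some_nested_transformation(variables):
  # clone_pytree is the proposed helper, e.g. as flax.traverse_util.clone_pytree
  my_copy = clone_pytree(variables)
  my_copy['batch_stats']['x'] += 2.
  return my_copy
```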
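A quick check of the `clone_pytree` idea, assuming a toy `variables` dict with plain float leaves: the in-place update lands on the copy only, because `tree_map` returns new containers.

```python
import jax


def clone_pytree(xs):
    # Identity tree_map: new containers, same leaf objects.
    return jax.tree_util.tree_map(lambda x: x, xs)


variables = {'params': {'w': 1.0}, 'batch_stats': {'x': 0.0}}  # toy example
my_copy = clone_pytree(variables)
my_copy['batch_stats']['x'] += 2.

print(variables['batch_stats']['x'], my_copy['batch_stats']['x'])  # 0.0 2.0
print(my_copy['params'] is variables['params'])  # False: a new inner dict
```

Leaves are shared rather than copied. That is fine for immutable JAX arrays, but it would not protect a mutable NumPy array that is modified in place.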

Also, we want to merge the chex and flax dataclass implementations. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent, so ideally we would make these changes together.

jheek, Apr 13 '21

Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.

This is a good point; especially as a codebase grows, it can sneak past you. Personally, if FrozenDict better matched the ergonomics of dict outside of mutation, I would not see it as a burden, at least for my own use cases.

Also, we want to merge the chex and flax dataclass implementations. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent, so ideally we would make these changes together.

I actually quite like the immutable dataclasses, since the .replace(...) API is similar to that of namedtuples. In my view, the inconveniences that arise with FrozenDict don't happen here, since dataclasses don't have arbitrary structure and you don't generally manipulate that structure.
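For comparison, a sketch of the two `replace`-style APIs using only the standard library: a frozen dataclass updated through `dataclasses.replace` (the same shortcut Flax's dataclass `replace` method is described as wrapping later in this thread) and a namedtuple's `_replace`. The `TrainState` fields are made up for the example:

```python
import dataclasses
from collections import namedtuple


@dataclasses.dataclass(frozen=True)
class TrainState:
    step: int
    learning_rate: float

    def replace(self, **updates):
        # Shortcut to dataclasses.replace: returns a new instance.
        return dataclasses.replace(self, **updates)


TrainStateTuple = namedtuple('TrainStateTuple', ['step', 'learning_rate'])

s = TrainState(step=0, learning_rate=1e-3)
s2 = s.replace(step=s.step + 1)
t = TrainStateTuple(step=0, learning_rate=1e-3)
t2 = t._replace(step=t.step + 1)
print(s.step, s2.step, t.step, t2.step)  # 0 1 0 1

try:
    s.step = 5  # frozen: direct attribute assignment is rejected
    frozen_rejected = False
except dataclasses.FrozenInstanceError:
    frozen_rejected = True
```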

n2cholas, Apr 13 '21

Thanks for the input @n2cholas! After chatting with @jheek offline, the consensus is that it is indeed useful to return regular dicts, but that we block implementing this on merging the chex and flax dataclasses.

marcvanzee, Apr 14 '21

Great, very happy about this decision. I'd just like to add that

See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link).

Is a complete red herring. This is about hidden global state, whereas this discussion is specifically about explicit, non-global state. It's actually more about rng design than anything else, and what we are talking about doing here is already the "better" rng design where the user explicitly is given, and trusted to correctly handle, the state.
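As an illustration of the explicit-state RNG design referred to here, a minimal JAX sketch: the PRNG key is an ordinary value that the user holds and threads through, so the same key always reproduces the same samples and nothing is hidden in global state.

```python
import jax

key = jax.random.PRNGKey(0)          # explicit RNG state, owned by the caller
key, subkey = jax.random.split(key)  # derive a fresh key instead of mutating one

a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # same key: identical samples
key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,)) # new key: new samples

print(bool((a == b).all()), bool((a == c).all()))
```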

lucasb-eyer, Apr 16 '21

Is there any further development on this?

PhilipVinc, Aug 12 '21

Sorry for the delay -- I was on parental leave.

@jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?

marcvanzee, Sep 06 '21

What does merging the dataclasses consist of? Are flax dataclasses going to be inheriting the mapping interface?

NeilGirdhar, Dec 06 '21

The merging of dataclasses is taking much longer than originally anticipated. I'll bring this up in our next sync meeting, because I think we should start moving towards allowing mutability independently of actually merging the implementations with chex.

jheek, Dec 07 '21

I think we should start to move towards allowing mutability

Sorry, but why would you do that?

The merging of dataclasses is taking much longer than originally anticipated.

Also, I still don't understand what this merge will consist of. Flax's dataclasses are well-designed: they are just frozen dataclasses that register as pytrees, have a field function that conveniently supports marking static fields, and add a replace method. Besides the replace method (which is just a shortcut to dataclasses.replace), this is a minimal interface.
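A minimal sketch of the interface described above, built from the standard library and `jax.tree_util` rather than Flax's own implementation: a frozen dataclass registered as a pytree, one field treated as static (kept in the treedef instead of the leaves), and a `replace` shortcut. The `Dense` class and the choice of which field is static are assumptions for illustration:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Dense:
    kernel: jnp.ndarray
    bias: jnp.ndarray
    activation: str = 'relu'  # static: kept in the treedef, never traced

    def replace(self, **updates):
        # Shortcut to dataclasses.replace, returning a new instance.
        return dataclasses.replace(self, **updates)


def _flatten(layer):
    # Dynamic fields become leaves; static fields become auxiliary data.
    return (layer.kernel, layer.bias), (layer.activation,)


def _unflatten(aux, children):
    kernel, bias = children
    (activation,) = aux
    return Dense(kernel, bias, activation)


jax.tree_util.register_pytree_node(Dense, _flatten, _unflatten)


@jax.jit
def apply(layer, x):
    y = x @ layer.kernel + layer.bias
    # Branching on the static field is fine under jit: it is a plain str here.
    return jnp.tanh(y) if layer.activation == 'tanh' else jax.nn.relu(y)


layer = Dense(jnp.ones((2, 2)), jnp.zeros(2))
doubled = jax.tree_util.tree_map(lambda p: 2 * p, layer)  # leaves only
out = apply(layer.replace(activation='tanh'), jnp.ones(2))
print(len(jax.tree_util.tree_leaves(layer)), doubled.activation)  # 2 relu
```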

Chex dataclasses are badly designed: they are not frozen, they can't mark static fields, and they unnecessarily expose the whole mapping interface, which means you can access fields as attributes or keys. They also expose a to_tuple method that is inferior to dataclasses.astuple, which supports nested dataclasses. The from_tuple method is also somewhat flimsy, since it won't work with Python 3.10's new keyword-only arguments. This is not a minimal interface.

I was hoping to ditch tjax's dataclasses in favor of flax's, but if you're merging in any of chex's behavior, I won't be able to.

NeilGirdhar, Dec 07 '21

We won't be removing features like frozen, static fields, and replace. We do however want to be less strict about enforcing functional patterns. Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.

As for the mapping interface: this is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types. This is also why chex cannot easily add static fields: tf.nest doesn't support them. We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.

jheek, Dec 07 '21

Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.

As a data point: in my 5500-line Jax project, I call replace 9 times. It may be slightly more awkward than writing to attributes, but I don't think it's worth giving up the safety of all of the methods on my dataclasses being verified to be pure.

This is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types.

Instead of having a gigantic interface and passing the dataclass d to tf.nest, can't users pass dataclasses.asdict(d)?

This is also why chex cannot easily add static fields because tf.nest doesn't support it.

I see. Why not create an asdict function that removes the keys corresponding to static fields? Or, more conveniently, convince TensorFlow to check for an as_dynamic_dict method and call it in tf.nest?
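A minimal sketch of the suggested `asdict` variant, assuming static fields are marked through `dataclasses.field(metadata=...)` under a made-up `'static'` key. Both `static_field` and `as_dynamic_dict` are hypothetical helpers for illustration, not existing TensorFlow or Flax functions, and this version is shallow (it does not recurse into nested dataclasses):

```python
import dataclasses


def static_field(**kwargs):
    # Hypothetical marker: record "static" in the field's metadata.
    return dataclasses.field(metadata={'static': True}, **kwargs)


def as_dynamic_dict(obj):
    # Like a shallow dataclasses.asdict, but skipping fields marked static.
    return {
        f.name: getattr(obj, f.name)
        for f in dataclasses.fields(obj)
        if not f.metadata.get('static', False)
    }


@dataclasses.dataclass(frozen=True)
class Layer:
    weight: float
    bias: float
    name: str = static_field(default='dense')


print(as_dynamic_dict(Layer(1.0, 0.0)))  # {'weight': 1.0, 'bias': 0.0}
```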

We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.

Yes! Thank you!

NeilGirdhar, Dec 07 '21

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a good idea.

Yes, this is the tradeoff we have to think about, and we will discuss it further before making a final decision.

jheek, Dec 07 '21

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.

What is a particular form of this problem? Typically the code in the main training loop isn't pure anyway (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc). I understand the need to ensure frozen data structures within modules (and we're not proposing to change this -- module.apply will still have a mutable argument and use FrozenDicts based on it). The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.

avital, Dec 07 '21

By the way, I also think that Flax returning frozen dictionaries is extremely annoying. Changing this behaviour would also address https://github.com/deepmind/optax/issues/160

Moreover, our (NetKet) users and students learning Jax/Flax often find it confusing that they keep getting this object that they have to "melt" before they can edit it.

PhilipVinc, Dec 07 '21

The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.

Sorry, I'm not actually discussing the topic of the issue. I just noticed a comment about merging chex.dataclass, and I wanted some clarification on that.

What is a particular form of this problem?

I can't find the example, but I saw one with treex (which doesn't enforce frozen dataclasses) where someone was doing

```python
def f(x):
    x.some_member = some_value
    return x

@jit
def g(...):
    ...
    x = f(x)  # if you forget to assign to x, you will get different behavior for the jitted and unjitted function.
```

Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).

Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.

NeilGirdhar, Dec 07 '21

Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).

Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.

@NeilGirdhar I just mean things like updating state and params and reporting metrics -- it's totally fine to directly manipulate the variables dict in the main training loop, and people have to jump through (IMHO unnecessary) hoops to achieve this: https://github.com/google/flax/issues/1729#issuecomment-995839207.

avital, Dec 16 '21

@avital Fair enough. I need to learn Flax better before I can really suggest something. A couple of other options:

An at operator that does this under the covers, so that you can write:

```python
embedding = params['params']['Embed_0']['embedding']
norm = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
new_params = params.at['params']['Embed_0']['embedding'].divide(norm + 1e-10)
state = state.replace(params=new_params)
```

The at operator would return a handle like the one in jax.index_ops.

Or maybe a context manager that provides the handle and automatically rolls it back in when it ends:

```python
with state.unfreeze() as unfrozen_state:
    unfrozen_state.params['params']['Embed_0']['embedding'] /= (jnp.linalg.norm(state.params['params']['Embed_0']['embedding'], axis=-1, keepdims=True) + 1e-10)
```

You'd still be jumping through hoops, but it's just one hoop.
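For comparison, with the tools that exist today the same normalization can be written as a small functional update instead of either proposed API (neither `at` nor an unfreezing context manager exists on `FrozenDict`; both are proposals here). A sketch, where `update_at` is a hypothetical helper and `params` is a toy dict shaped like the variables above; it rebuilds only the dicts along the path, so the input is never mutated:

```python
import jax.numpy as jnp


def update_at(tree, path, fn):
    """Return a copy of the nested dict `tree` with `fn` applied at `path`."""
    key, *rest = path
    new_value = update_at(tree[key], rest, fn) if rest else fn(tree[key])
    # Copy only this level; sibling subtrees are shared with the input.
    return {**tree, key: new_value}


def normalize(embedding):
    return embedding / (jnp.linalg.norm(embedding, axis=-1, keepdims=True) + 1e-10)


params = {'params': {'Embed_0': {'embedding': jnp.array([[3.0, 4.0]])}}}
new_params = update_at(params, ('params', 'Embed_0', 'embedding'), normalize)

print(new_params['params']['Embed_0']['embedding'])  # approximately [[0.6 0.8]]
print(params['params']['Embed_0']['embedding'])      # input left as [[3. 4.]]
```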

NeilGirdhar, Dec 16 '21

The problem with any hoop isn't its complexity -- it's that it's something you have to learn suddenly, when you "just wanted to try this one thing". So any hoop should be justified by the benefit it gives you (hopefully a lot). Maybe I'm just misunderstanding this, but I never understood the benefit of having module.apply and module.init return FrozenDicts. (I've always been strongly in support of FrozenDicts inside modules, which happens internally as a function of the mutable argument to module.apply.)

avital, Dec 16 '21

I guess another way to put it -- if someone really wants immutable data structures, they can always do, e.g. FrozenDict(module.init(...)). So the question is: which default serves the users best?

avital, Dec 16 '21

And the answer is just plain dict, at least for this user here :)

lucasb-eyer, Dec 17 '21

+1 for this! I have a lot of code that immediately calls .unfreeze() right after init and apply.

cgarciae, Dec 17 '21

Hey @NeilGirdhar! I believe you're looking for this example from Treex's User Guide.

cgarciae, Dec 17 '21

Since this would be a breaking change, we should bump Flax's version to avoid breaking open-source users who rely on semantic versioning.

cgarciae, Dec 12 '22

FYI: @chiamp is going to look into this

marcvanzee, Jan 24 '23

Closing after #3193 landed.

chiamp, Aug 31 '23