flax
Should Flax return FrozenDicts or regular dicts?
This topic is discussed regularly internally, and I feel we haven't reached a consensus here. Below are some arguments collected from users for both positions, feel free to add.
Arguments in favor of FrozenDict
- @avital: If you use normal dicts, it is easy to mutate them, which means the behavior may differ depending on whether the function in which the modification is made is jitted or not. Example:
def f(params):
    params['conv1']['weight'] = ...
    return ...  # some computation over params

params = load_from_checkpoint()
print(f(params))
# now what is the value of params['conv1']['weight']?
# depending on whether f is jitted or not, you'd get different results
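A minimal, self-contained version of the example above might look like the following. The array shapes and the `run` helper are assumptions made for illustration; the point is only that `jax.jit` rebuilds the argument containers while tracing, so the in-place write never reaches the caller's dict.

```python
import jax
import jax.numpy as jnp


def f(params):
    # In-place write into the nested dict that was passed in.
    params['conv1']['weight'] = jnp.zeros(3)
    return params['conv1']['weight'].sum()


def run(fn):
    # Hypothetical helper: build fresh params, call fn, and report what
    # the caller's dict holds afterwards.
    params = {'conv1': {'weight': jnp.ones(3)}}
    fn(params)
    return float(params['conv1']['weight'].sum())


eager_sum = run(f)            # caller's dict was overwritten with zeros
jitted_sum = run(jax.jit(f))  # caller's dict still holds the ones
```

With a FrozenDict the assignment inside `f` would raise in both cases, which is the consistency argument being made here.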
Arguments in favor of regular dicts
- @lucasb-eyer: Flax tells me "here's these precious weights, please hold them for me and give them back to me later on, but DON'T TOUCH". It begs the question: why give them to me in the first place, if I'm not supposed to do anything with them?
- @avital: I also think it'd be better for Flax to return normal Python dicts, but still use FrozenDict within modules (via the mutable argument to apply).
I think the Python saying "We're all consenting adults here" is pretty fitting. In my view, trading convenience for safety is reasonable here because JAX users should know (or will quickly come to learn) that under the JAX transformations, they should not mutate state. Since FrozenDicts are not as ergonomic as normal dicts, I tend to unfreeze them as soon as they're returned from init anyway.
Though I would prefer the user-facing API to just use dict, I wouldn't mind FrozenDict if the behaviour was closer to dict for non-mutating cases. In particular,
- The ability to pass FrozenDict to flax.traverse_util.flatten_dict.
- The views FrozenDict.keys() and FrozenDict.values() should have a similar style __repr__ to normal dicts so they are easy to inspect interactively in notebooks (right now they just show the whole FrozenDict).
- The ability to update/merge with normal dicts.
Explicit state management is one of my favourite aspects of Flax, as it gives me the ability to transparently manipulate modules/parameters without worrying about hidden side effects. I totally agree with @lucasb-eyer's point that it's counterproductive to provide explicit state without allowing the user to fully control it.
I think the Python saying "We're all consenting adults here" is pretty fitting
Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.
That said, I don't think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, while it doesn't avoid the more common issue of closing over mutable state (typically created by the user) or using things like np.random in a jitted function.
I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:
def clone_pytree(xs):
    # cloning is just an identity mapping
    return jax.tree_util.tree_map(lambda x: x, xs)

def some_nested_transformation():
    my_copy = flax.traverse_util.clone_pytree(variables)
    my_copy['batch_stats']['x'] += 2.
    return my_copy
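A runnable sketch of the cloning idea, assuming plain nested dicts of JAX arrays. `clone_pytree` here is a local function, not an existing `flax.traverse_util` API. The identity map rebuilds every container but reuses the leaf arrays, so it is a structural copy rather than a deep copy of the data.

```python
import jax
import jax.numpy as jnp


def clone_pytree(xs):
    # Identity map over the leaves: new containers, same leaf arrays.
    return jax.tree_util.tree_map(lambda x: x, xs)


variables = {'batch_stats': {'x': jnp.zeros(2)}}
my_copy = clone_pytree(variables)

# JAX arrays are immutable, so `+=` rebinds the dict entry in my_copy
# to a new array and leaves the original variables untouched.
my_copy['batch_stats']['x'] += 2.0
```

Because the leaves are immutable arrays, sharing them between the original and the clone is safe; only the containers need to be fresh.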
Also we want to merge the chex and flax dataclass implementation. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent so ideally we would make these changes together.
Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.
This is a good point, especially as a codebase grows it can sneak past you. Personally, if FrozenDict better matched the ergonomics of dict outside of mutation, I would not see it as a burden, at least for my own use cases.
Also we want to merge the chex and flax dataclass implementation. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent so ideally we would make these changes together.
I actually quite like the immutable dataclasses, since the .replace(...) API is similar to namedtuples. In my view, the inconveniences that arise with FrozenDict don't happen here since dataclasses don't have arbitrary structure and you don't generally manipulate that structure.
Thanks for the input @n2cholas! After chatting with @jheek offline, the consensus is that it is indeed useful to return regular dicts, but that we block implementing this on merging the chex and flax dataclasses.
Great, very happy about this decision. I'd just like to add that
See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link).
Is a complete red herring. This is about hidden global state, whereas this discussion is specifically about explicit, non-global state. It's actually more about rng design than anything else, and what we are talking about doing here is already the "better" rng design where the user explicitly is given, and trusted to correctly handle, the state.
Is there any further development on this?
Sorry for the delay -- I was on parental leave.
@jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?
What does merging the dataclasses consist of? Are flax dataclasses going to be inheriting the mapping interface?
The merging of dataclasses is taking much longer than originally anticipated. I'll bring this up in our next sync meeting because I think we should start to move towards allowing mutability independently of actually merging the implementations with chex.
I think we should start to move towards allowing mutability
Sorry, but why would you do that?
The merging of dataclasses is taking much longer than originally anticipated.
Also, I still don't understand what this merge will consist of. Flax's dataclasses are well-designed: They are just frozen dataclasses that register as pytrees, have a field function that conveniently supports marking static fields, and add a replace method. Besides the replace method (which is just a shortcut to dataclasses.replace), this is a minimal interface.
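To make the description concrete, here is a rough sketch of a frozen dataclass that registers as a pytree, carries one static field, and adds a `replace` shortcut. It uses only the standard library and `jax.tree_util`; the `Layer` class and its fields are invented for illustration, and this is not Flax's actual implementation.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Layer:
    weight: Any          # dynamic field: a pytree leaf
    name: str = 'layer'  # static field: carried as auxiliary data

    def replace(self, **updates):
        # Shortcut to dataclasses.replace, returning a new instance.
        return dataclasses.replace(self, **updates)


def _flatten(layer):
    # Children are traced by JAX; aux data is treated as static.
    return (layer.weight,), layer.name


def _unflatten(name, children):
    return Layer(weight=children[0], name=name)


jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)

layer = Layer(weight=jnp.ones(2))
doubled = jax.tree_util.tree_map(lambda w: w * 2, layer)
renamed = layer.replace(name='other')
```

Attribute assignment such as `layer.name = 'x'` raises `dataclasses.FrozenInstanceError`, so every update has to go through `replace` or a tree map.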
Chex dataclasses are badly designed: they are not frozen, they can't mark static fields, and they unnecessarily expose the whole mapping interface, which means you can access fields as attributes or keys. They also expose a to_tuple method that is inferior to dataclasses.astuple, which supports nested dataclasses. The from_tuple method is also somewhat flimsy since it won't work with Python 3.10's new keyword-only arguments. This is not a minimal interface.
I was hoping to ditch tjax's dataclasses in favor of flax's, but if you're merging in any of chex's behavior, I won't be able to.
We won't be removing features like frozen, static fields, and replace. We do however want to be less strict about enforcing functional patterns. Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.
As for the mapping interface, this is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types. This is also why chex cannot easily add static fields because tf.nest doesn't support it. We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.
Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.
I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.
For statistics, in my 5500 line Jax project, I call replace 9 times. It may be slightly more awkward than writing to attributes, but I don't think it's worth giving up the safety of all of the methods on my dataclasses being verified to be pure.
This is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types.
Instead of having a gigantic interface and passing the dataclass d to tf.nest, can't users pass dataclasses.asdict(d)?
This is also why chex cannot easily add static fields because tf.nest doesn't support it.
I see. Why not create an asdict function that removes the keys corresponding to static fields? Or more conveniently, convince TensorFlow to check for an as_dynamic_dict method and call it in tf.nest?
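The asdict suggestion could be sketched roughly as follows, using only the standard library. Marking static fields through `metadata={'static': True}`, the `static_field` helper, and the free function `as_dynamic_dict` are all assumptions for illustration; neither chex, Flax, nor tf.nest is known to define them this way.

```python
import dataclasses


def static_field(**kwargs):
    # Hypothetical helper: tag a field as static via dataclass metadata.
    return dataclasses.field(metadata={'static': True}, **kwargs)


@dataclasses.dataclass(frozen=True)
class Config:
    weight: float
    bias: float
    name: str = static_field(default='dense')


def as_dynamic_dict(obj):
    # Shallow conversion that drops fields tagged as static.
    return {
        f.name: getattr(obj, f.name)
        for f in dataclasses.fields(obj)
        if not f.metadata.get('static', False)
    }


d = Config(weight=1.0, bias=0.5)
full = dataclasses.asdict(d)   # every field, static ones included
dynamic = as_dynamic_dict(d)   # only the non-static fields
```

Note that this version does not recurse into nested dataclasses the way `dataclasses.asdict` does; a fuller version would need to handle that.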
We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.
Yes! Thank you!
I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a good idea.
Yes, this is the tradeoff we have to think about, and we will discuss this further before making a final decision.
I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.
What is a particular form of this problem? Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc). I understand the need to ensure frozen data structures within modules (and we're not proposing to change this -- module.apply will still have a mutable argument and use FrozenDicts based on that). The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.
By the way, I also think that Flax returning frozen dictionaries is extremely annoying. Changing this behaviour would also address https://github.com/deepmind/optax/issues/160
Moreover, our (NetKet) users and students learning Jax/Flax often find it confusing that they keep getting this object that they have to melt before they can edit it.
The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.
Sorry, I'm not actually discussing the topic of the issue. I just noticed a comment about merging chex.dataclass, and I wanted some clarification on that.
What is a particular form of this problem?
I can't find the example, but I saw one with treex (which doesn't enforce frozen dataclasses) where someone was doing
def f(x):
    x.some_member = some_value
    return x

@jit
def g(...):
    ...
    x = f(x)  # if you forget to assign to x, you will get different behavior for the jitted and unjitted function.
Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).
Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.
@NeilGirdhar I just mean things like updating state and params and reporting metrics -- it's totally fine to directly manipulate the variables dict in the main training loop, and people have to jump through (IMHO unnecessary) hoops to achieve this: https://github.com/google/flax/issues/1729#issuecomment-995839207.
@avital Fair enough. I need to learn Flax better before I can really suggest something. A couple other options:
An at operator that does this under the covers, so that you can write:
embedding = params['params']['Embed_0']['embedding']
norm = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
new_params = params.at['params']['Embed_0']['embedding'].divide(norm + 1e-10)
state = state.replace(params=new_params)
The at operator would return a handle like the one in jax.index_ops.
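No such at operator exists on FrozenDict in the discussion above, so the following is only a sketch of the underlying idea: a functional update that returns a new nested mapping with one path replaced, leaving the input untouched. The helper name `set_in` and the scalar stand-ins for arrays are assumptions for illustration.

```python
def set_in(tree, path, value):
    """Return a copy of the nested mapping `tree` with `path` set to `value`."""
    if not path:
        return value
    key, *rest = path
    # Shallow-copy this level only; untouched subtrees are shared.
    new = dict(tree)
    new[key] = set_in(tree[key], rest, value)
    return new


params = {
    'params': {
        'Embed_0': {'embedding': 1.0},
        'Dense_0': {'kernel': 2.0},
    }
}
new_params = set_in(params, ('params', 'Embed_0', 'embedding'), 0.5)
```

Only the containers along the path are copied, so the cost is proportional to the depth of the path rather than the size of the tree. The result here is made of plain dicts; a real at handle would presumably return the same container type it was given.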
Or maybe a context manager that provides the handle and automatically rolls it back in when it ends:
with state.unfreeze() as unfrozen_state:
    unfrozen_state.params['params']['Embed_0']['embedding'] /= (jnp.linalg.norm(state.params['params']['Embed_0']['embedding'], axis=-1, keepdims=True) + 1e-10)
You'd still be jumping through hoops, but it's just one hoop.
The problem with any hoop isn't its complexity -- it's that it's something you have to learn suddenly, when you "just wanted to try this one thing". So any hoop should be justified by the benefit it gives you (hopefully a lot). Maybe I'm just misunderstanding this but I never understood the benefit of having module.apply and module.init return FrozenDicts. (I've always been strongly in support of FrozenDicts inside modules, which happens internally as a function of the mutable argument to module.apply.)
I guess another way to put it -- if someone really wants immutable data structures, they can always do, e.g. FrozenDict(module.init(...)). So the question is: which default serves the users best?
And the answer is just plain dict, at least for this user here :)
+1 for this! I have a lot of code that immediately calls .unfreeze() right after init and apply.
Hey @NeilGirdhar! I believe you're looking for this example from Treex's User Guide.
Since this would be a breaking change, we should bump Flax's version to avoid breaking OS users who rely on semantic versioning.
FYI: @chiamp is going to look into this
Closing after #3193 landed.