
Best practice of dealing with sporadic FrozenDict conversions?

Open TimSchneider42 opened this issue 1 year ago • 2 comments

Hi,

in my project, I have multiple instances of modules that look approximately like this:

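```python
from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp

class TreeEncoder(nn.Module):
    leaf_encoders: Any  # pytree of nn.Modules, one per leaf of the input

    @nn.compact
    def __call__(self, data: Any):  # data is a pytree of jax.Array
        encodings = jax.tree_util.tree_map(lambda d, enc: enc(d), data, self.leaf_encoders)
        return jnp.stack(jax.tree_util.tree_leaves(encodings), axis=-1)
```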

So essentially, this module has one encoder for every leaf of the input pytree and uses them to obtain a vector encoding of the entire tree. Usage could look like this:

```python
encoder = TreeEncoder({"img": ResNet(...), "vector": DenseNN(...)})
output = encoder({"img": jnp.zeros((480, 640, 3)), "vector": jnp.zeros(5)})
```

The beauty of JAX is that a pytree can have an arbitrary structure; it is not limited to dicts. I make heavy use of this and sometimes define my own dataclasses.

But here comes my problem: if I use a dictionary, Flax sometimes converts it into a FrozenDict, and then the tree_map call fails with

```
Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: {'img': ..., 'vector': ...}
```

It took me a while to understand when these conversions happen, but I am now fairly certain that Flax behaves as follows:

  1. The inputs to __call__ are never converted (data is always a dict)
  2. The leaf_encoders are converted iff the TreeEncoder instance is created inside a @nn.compact call

How do I deal with this? I cannot call self.leaf_encoders.unfreeze() because it might not be a FrozenDict (or even a dictionary). Is there some way to disable the FrozenDict conversion in general? Or is it possible to make FrozenDicts and dicts compatible as arguments to jax.tree_map?

Thanks a lot in advance!

Best, Tim

TimSchneider42 · Jun 13 '24 13:06

Hmm, are you using the latest Flax? The latest Flax should have flax.config.flax_return_frozendict set to False and return plain dicts by default.

Otherwise, you can do an explicit check like hasattr(x, 'unfreeze') and call x.unfreeze() only when it is true. But please let me know if you are using the latest Flax and still run into this problem.

cc @chiamp

IvyZX · Jun 21 '24 01:06

Hi,

I have to double-check which Flax version I am using, but I installed it from PyPI around 2 months ago, so it should be fairly recent.

To be clear, my problem is not that modules return FrozenDicts, but rather that fields of modules sometimes get converted to FrozenDicts. I think I will go with your suggestion for now, but an option to turn that behavior off fully would be nice for the future.

Best, Tim

TimSchneider42 · Jun 21 '24 06:06