serialization.from_bytes not checking shapes of loaded arrays w.r.t. target

Open cgarciae opened this issue 3 years ago • 1 comments

Currently, serialization.from_bytes checks that all keys in target are present in the loaded structure and filters out additional keys not present in target. However, shapes are not checked for consistency between the loaded structure and target. Such a check could be useful for detecting misconfigured Modules early. Could address #2110.

Example:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax import serialization

x = jnp.ones((1, 28, 28, 1))

# save
module = nn.Dense(10)
variables = module.init(jax.random.PRNGKey(1), x)
variable_bytes = serialization.to_bytes(variables)

# load
new_module = nn.Dense(20)  # different number of features
wrong_shape_target = new_module.init(jax.random.PRNGKey(1), x)
new_variables = serialization.from_bytes(wrong_shape_target, variable_bytes)  # currently no error

y = new_module.apply(new_variables, x)  # error
```

cgarciae avatar Sep 05 '22 22:09 cgarciae

One proposal is to add a check_shapes: bool = False argument to from_bytes. When enabled, the check would fail with a list of the paths whose shapes don't match. To avoid breaking changes, the default would be False.
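As a rough sketch of what such a check might do (not an existing Flax API; `flatten` and `find_shape_mismatches` are hypothetical helper names, and plain NumPy arrays stand in for the Dense(20) target and the Dense(10) checkpoint from the example), the comparison could walk both nested dicts and collect the paths whose shapes differ:

```python
import numpy as np


def flatten(tree, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf}. Hypothetical helper."""
    if isinstance(tree, dict):
        flat = {}
        for key, value in tree.items():
            flat.update(flatten(value, prefix + (key,)))
        return flat
    return {prefix: tree}


def find_shape_mismatches(target_state, loaded_state):
    """Return (path, target_shape, loaded_shape) for leaves whose shapes differ."""
    target_flat = flatten(target_state)
    loaded_flat = flatten(loaded_state)
    mismatches = []
    for path, target_leaf in target_flat.items():
        if path not in loaded_flat:
            # Missing keys are a separate problem; this sketch only compares shapes.
            continue
        target_shape = np.shape(target_leaf)
        loaded_shape = np.shape(loaded_flat[path])
        if target_shape != loaded_shape:
            mismatches.append(("/".join(path), target_shape, loaded_shape))
    return mismatches


# Stand-ins: shapes of a Dense(20) target vs. a Dense(10) checkpoint, input dim 1.
target = {"params": {"kernel": np.zeros((1, 20)), "bias": np.zeros((20,))}}
loaded = {"params": {"kernel": np.zeros((1, 10)), "bias": np.zeros((10,))}}

mismatches = find_shape_mismatches(target, loaded)
for path, target_shape, loaded_shape in mismatches:
    print(f"{path}: target {target_shape} vs loaded {loaded_shape}")
```

With Flax, the two inputs could presumably be obtained from serialization.to_state_dict(wrong_shape_target) and serialization.msgpack_restore(variable_bytes), though that wiring is an assumption and is not tested here.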

cgarciae avatar Sep 05 '22 22:09 cgarciae

Reassigning this to @chiamp since he is working on a fix.

marcvanzee avatar Dec 13 '22 12:12 marcvanzee

@IvyZX does Orbax check shapes when loading? Maybe we could request this from the Orbax team and close this issue.

cgarciae avatar Apr 25 '23 19:04 cgarciae

I don't think they check shapes, but I am not sure it's a good idea to require that check. It would mean that before loading large arrays, users need to pre-create a very large target array, which doubles the memory cost at restoration time and often OOMs.
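For reference, expected shapes can be described without materializing the arrays. The sketch below uses jax.eval_shape on a hypothetical `init_params` function (not from this issue) to get shape and dtype metadata only. Whether from_bytes or Orbax could accept such an abstract target is an open question here, not something this snippet establishes.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    """Hypothetical initializer for a large layer; name and sizes are illustrative."""
    return {
        "kernel": jax.random.normal(key, (4096, 4096)),
        "bias": jnp.zeros((4096,)),
    }


# eval_shape traces the function abstractly: no 4096x4096 buffer is allocated,
# only jax.ShapeDtypeStruct leaves carrying shape and dtype are returned.
abstract = jax.eval_shape(init_params, jax.random.PRNGKey(0))
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, abstract)
print(shapes)
```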

IvyZX avatar Apr 26 '23 01:04 IvyZX

I am thinking we should close this issue if there are technical reasons not to do it and Orbax will handle serialization from now on.

cgarciae avatar Apr 26 '23 14:04 cgarciae

If someone deems this issue important, feel free to comment and we might revisit it in the future.

cgarciae avatar Apr 26 '23 14:04 cgarciae