serialization.from_bytes not checking shapes of loaded arrays w.r.t. target
Currently, serialization.from_bytes checks that all keys in target are present in the loaded structure and filters out additional keys not present in target. However, shapes are not checked for consistency between the loaded structure and target. Such a check could be useful for detecting misconfigured Modules early. Could address #2110.
Example:
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax import serialization
x = jnp.ones((1, 28, 28, 1))
# save
module = nn.Dense(10)
variables = module.init(jax.random.PRNGKey(1), x)
variable_bytes = serialization.to_bytes(variables)
# load
new_module = nn.Dense(20) # different number of features
wrong_shape_target = new_module.init(jax.random.PRNGKey(1), x)
new_variables = serialization.from_bytes(wrong_shape_target, variable_bytes)  # currently no error, although shapes differ from the target
y = new_module.apply(new_variables, x)  # error: the shape mismatch only surfaces here, at apply time
One proposal is to add a check_shapes: bool = False argument to from_bytes. When enabled, from_bytes would fail with a list of the paths whose shapes don't match. The default stays False to avoid breaking changes.
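To make the proposed check concrete, here is a minimal sketch of the comparison it could perform. It is not Flax's implementation: find_shape_mismatches and FakeArray are hypothetical names introduced for illustration, and the sketch assumes both structures are nested dicts whose leaves expose a .shape attribute (in practice such dicts could come from serialization.to_state_dict(target) and serialization.msgpack_restore(variable_bytes)).

```python
from typing import Any, List, Tuple

Mismatch = Tuple[str, Tuple[int, ...], Tuple[int, ...]]


def find_shape_mismatches(target: Any, loaded: Any, path: Tuple[str, ...] = ()) -> List[Mismatch]:
    """Hypothetical helper: list (path, target_shape, loaded_shape) for differing leaves."""
    mismatches: List[Mismatch] = []
    if isinstance(target, dict):
        for key, target_value in target.items():
            # Missing keys are skipped here; from_bytes already reports them.
            if not isinstance(loaded, dict) or key not in loaded:
                continue
            mismatches.extend(
                find_shape_mismatches(target_value, loaded[key], path + (str(key),))
            )
        return mismatches
    # Leaf: compare shapes only when both sides expose one.
    target_shape = getattr(target, "shape", None)
    loaded_shape = getattr(loaded, "shape", None)
    if target_shape is not None and loaded_shape is not None:
        if tuple(target_shape) != tuple(loaded_shape):
            mismatches.append(("/".join(path), tuple(target_shape), tuple(loaded_shape)))
    return mismatches


class FakeArray:
    """Stand-in for an array leaf (assumption: only .shape matters for the check)."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


# Shapes mirror the example above: Dense(20) target vs. saved Dense(10) parameters.
target = {"params": {"kernel": FakeArray((1, 20)), "bias": FakeArray((20,))}}
loaded = {"params": {"kernel": FakeArray((1, 10)), "bias": FakeArray((10,))}}

mismatches = find_shape_mismatches(target, loaded)
for leaf_path, target_shape, loaded_shape in mismatches:
    print(f"{leaf_path}: target {target_shape} vs loaded {loaded_shape}")
```

With check_shapes=True, from_bytes could raise an error listing these paths instead of returning silently; with the default False it would behave as it does today.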
Reassigning this to @chiamp since he is working on a fix.
@IvyZX does Orbax check shapes when loading? Maybe we could request this from the Orbax team and close this issue.
I don't think they check shapes, but I am not sure it's a good idea to require that check. It would mean that before loading large arrays, users need to pre-create a very large target array, which doubles the memory cost at restoration time and often OOMs.
I am thinking we should close this issue if there are technical reasons not to do it and Orbax will handle serialization from now on.
If someone deems this issue important, feel free to comment and we might revisit it in the future.