flax
We broke the 1:1 correspondence with attribute names and variable dict names
from @levskaya:
class Foo(nn.Module):
  def setup(self):
    self.foo = nn.Dense(1, name="bar")
    self.qup = self.param('baz', lambda k: jnp.zeros((1,)))

  def __call__(self, x):
    return self.foo(x) + self.qup

Foo().init(jax.random.PRNGKey(0), jnp.zeros((1,)))
FrozenDict({
    params: {
        baz: DeviceArray([0.], dtype=float32),
        bar: {
            kernel: DeviceArray([[-0.58376074]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})
The above used to break loudly, and it should!
Initial investigation by @avital:
Over the past 1.5 years (I ran a test against every single commit), we actually never had any commit where the following code raised an exception:
def test_setattr_name_var_agreement_in_setup(self):
  class Foo(nn.Module):
    def setup(self):
      self.qup = self.param('baz', lambda k: 0)
    def __call__(self):
      pass
  Foo(parent=None).init(jax.random.PRNGKey(0))
But we did, in the past, entirely disallow the use of `name=` for submodules defined within `setup`, which would have prevented setting the wrong name for a submodule in `setup`. We lost that guard with https://github.com/google/flax/pull/976/files
I don't think we ever had tests for the variable attribute correspondence. We do have tests that you can't define two variables with the same name in different collections but not that the name aligns with the attribute being assigned to.
Suggestion from @jheek:
I think you could do something like this to disallow giving different names in `setup`.
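def __setattr__(self, name, value):
  # `variables` stands for the module's variable dict: collection -> {name: value}.
  for col in variables:
    if name in variables[col]:
      assert variables[col][name] is value, f"A variable named {name} already exists. We don't allow variables and fields to have overlapping names"
  super().__setattr__(name, value)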
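To make the suggestion concrete, here is a minimal sketch on a toy class. `ToyModule`, its `_variables` dict, and its `param` helper are hypothetical stand-ins for illustration only; they are not the Flax API, and the real `nn.Module.__setattr__` has more to handle.

```python
class ToyModule:
    """Hypothetical stand-in for a module; this is not the Flax API."""

    def __init__(self):
        # collection name -> {variable name -> value}; bypass the guard while bootstrapping.
        object.__setattr__(self, "_variables", {"params": {}})

    def param(self, name, value):
        # Register a variable under `name` and hand the same object back.
        self._variables["params"][name] = value
        return value

    def __setattr__(self, name, value):
        # Reject an attribute whose name is already taken by a different variable.
        for col, entries in self._variables.items():
            if name in entries and entries[name] is not value:
                raise ValueError(
                    f"A variable named {name!r} already exists in collection {col!r}; "
                    "variables and fields may not have overlapping names."
                )
        object.__setattr__(self, name, value)


m = ToyModule()
m.kernel = m.param("kernel", [0.0])  # allowed: attribute and variable name agree
m.other = m.param("bias", [1.0])     # also allowed: no variable is called "other"
try:
    m.bias = [2.0]                   # rejected: "bias" already names a different value
    clash_detected = False
except ValueError:
    clash_detected = True
```

Note that a check of this shape only catches overlapping names. Assigning a variable to an attribute with an unrelated name (`m.other` above) still passes, so it does not by itself restore a 1:1 correspondence.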
Looking into this :eyes:
I'm honestly afraid that this cat is already out of the bag. Many users' models (and checkpoints!) now exploit the current freedom to set the name apart from python attribute name. If we tried forcing it at this point we'd probably piss a lot of people off.
Based on @jheek's initial proposal I came up with this logic:
def _is_valid_field_value(name, val, variables) -> bool:
  value_found = False
  for collection in variables.values():
    for field, existing_value in collection.items():
      if val is existing_value:
        value_found = True
        if name == field:
          return True
  return not value_found
It looks for `val` in all collections: if it is found and `name` matches the key it is stored under, it's good; if it's found but no name matches, you get a runtime error. However, there are easy ways to get into trouble:
class Foo(nn.Module):
  def setup(self):
    self.bar = self.param("bar", lambda key: jnp.array(1))
    self.baz = self.bar  # error: value was found but no `baz` key exists
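The three outcomes can be reproduced without Flax by calling the check directly on a plain dict. This is only a sketch: the `variables` dict and the list used as the parameter value are stand-ins chosen for illustration, not what a real module holds.

```python
def _is_valid_field_value(name, val, variables) -> bool:
    value_found = False
    for collection in variables.values():
        for field, existing_value in collection.items():
            if val is existing_value:
                value_found = True
                if name == field:
                    return True
    return not value_found


bar = [1.0]  # stand-in for a parameter array
variables = {"params": {"bar": bar}}

# Same object, same name: accepted.
ok_same_name = _is_valid_field_value("bar", bar, variables)
# An object that is not a variable at all: accepted.
ok_unrelated = _is_valid_field_value("config", object(), variables)
# Same object under a different attribute name (the aliasing pattern): rejected.
ok_alias = _is_valid_field_value("baz", bar, variables)

print(ok_same_name, ok_unrelated, ok_alias)
```

The last call is the `self.baz = self.bar` case: the check cannot tell a deliberate alias from a misnamed variable, because both look like a known value under an unknown name.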
What to do?
Seems like there is no general way to solve the issue, however if users don't use "value types" (ints and floats) and avoid the pattern above it should be good.
@cgarciae is this issue still active? I see @jheek reviewed your PR #2102 but then it went stale. Maybe you two can try to work together to get this in, or if it turns out to be unfeasible we close the issue?
@marcvanzee can you check the internal CL? Maybe it was breaking some internal tests, which is why we stopped? The PR is also tricky because it has a ton of edge cases we can't cover.