
We broke the 1:1 correspondence with attribute names and variable dict names

Status: open. marcvanzee opened this issue 2 years ago (3 comments).

from @levskaya:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  def setup(self):
    # Attribute `foo`, but the submodule is registered as "bar".
    self.foo = nn.Dense(1, name="bar")
    # Attribute `qup`, but the parameter is named "baz".
    self.qup = self.param('baz', lambda k: jnp.zeros((1,)))
  def __call__(self, x):
    return self.foo(x) + self.qup

Foo().init(jax.random.PRNGKey(0), jnp.zeros((1,)))
```

```
FrozenDict({
    params: {
        baz: DeviceArray([0.], dtype=float32),
        bar: {
            kernel: DeviceArray([[-0.58376074]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})
```

The above used to break loudly, and it should!

Initial investigation by @avital:

Over the past 1.5 years (I ran a test against every single commit), we actually never had any commit where the following code raised an exception:

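```python
import jax
import flax.linen as nn
from absl.testing import absltest

class ModuleTest(absltest.TestCase):  # hypothetical wrapper; the original is a method of a test class
  def test_setattr_name_var_agreement_in_setup(self):
    class Foo(nn.Module):
      def setup(self):
        # Attribute `qup` holds a variable named 'baz'; this never raised.
        self.qup = self.param('baz', lambda k: 0)
      def __call__(self):
        pass

    Foo(parent=None).init(jax.random.PRNGKey(0))
```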

But in the past we did entirely disallow the use of `name=` for submodules defined within `setup`, which prevented giving a submodule in `setup` the wrong name. We lost that guard with https://github.com/google/flax/pull/976/files

I don't think we ever had tests for the correspondence between variable names and attribute names. We do have tests that you can't define two variables with the same name in different collections, but none checking that a variable's name matches the attribute it is assigned to.

Suggestion from @jheek:

I think you could do something like this to disallow giving different names in setup:

```python
def __setattr__(self, name, value):
  variables = self.variables  # assumed: this module's variable collections
  for col in variables:
    if name in variables[col]:
      assert variables[col][name] is value, f"A variable named {name} already exists. We don't allow variables and fields to have overlapping names"
  super().__setattr__(name, value)
```
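To make the suggestion concrete without depending on Flax internals, here is a minimal sketch on a hypothetical `GuardedModule` class that stands in for a module, assuming its variables are available as a plain dict of collections. It illustrates the idea above; it is not Flax's implementation.

```python
class GuardedModule:
  """Hypothetical stand-in for a Flax Module (not Flax's real class).

  `variables` maps collection name -> {variable name: value}.
  """

  def __init__(self, variables):
    # Store the collections directly, bypassing the guard below.
    object.__setattr__(self, "variables", variables)

  def __setattr__(self, name, value):
    # The suggested guard: if `name` is already a variable name in any
    # collection, the value being assigned must be that same variable.
    for collection in self.variables.values():
      if name in collection:
        assert collection[name] is value, (
            f"A variable named {name} already exists. We don't allow "
            "variables and fields to have overlapping names")
    object.__setattr__(self, name, value)


baz = [0.0]  # hypothetical stand-in for a parameter array
m = GuardedModule({"params": {"baz": baz}})
m.baz = baz  # passes: the attribute holds the variable of the same name
m.qup = baz  # also passes: "qup" is not a variable name, so nothing is checked

rejected = False
try:
  m.baz = [1.0]  # a different object under an existing variable name
except AssertionError:
  rejected = True
```

In this toy version `m.qup = baz` passes: the guard rejects reusing a variable's name for a different value, but on its own it does not require an attribute to be named after the variable it holds.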

marcvanzee, May 04 '22 14:05

Looking into this :eyes:

cgarciae, May 04 '22 16:05

I'm honestly afraid that this cat is already out of the bag. Many users' models (and checkpoints!) now rely on the current freedom to name a variable differently from its Python attribute. If we tried to enforce the correspondence at this point, we'd probably piss a lot of people off.

levskaya, May 04 '22 19:05

Based on @jheek's initial proposal I came up with this logic:

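```python
def _is_valid_field_value(name, val, variables) -> bool:
  value_found = False

  # Search every collection for the exact object being assigned.
  for collection in variables.values():
    for field, existing_value in collection.items():
      if val is existing_value:
        value_found = True
        if name == field:
          # Found under a key matching the attribute name: valid.
          return True

  # Otherwise valid only if `val` is not a variable at all.
  return not value_found
```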

It looks for `val` in all collections. If it is found under a key equal to `name`, the assignment is fine; if it is found but under no key matching `name`, you get a runtime error. However, there are easy ways to get into trouble:

```python
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):

  def setup(self):
    self.bar = self.param("bar", lambda key: jnp.array(1))
    self.baz = self.bar  # error: value was found but no `baz` key exists
```
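The three outcomes of the check can be exercised with plain dicts standing in for Flax's variable collections (an assumption made so the sketch runs without Flax); `bar` is a hypothetical list standing in for a parameter array.

```python
def _is_valid_field_value(name, val, variables) -> bool:
  # Same logic as the function above, repeated so this block runs alone.
  value_found = False
  for collection in variables.values():
    for field, existing_value in collection.items():
      if val is existing_value:
        value_found = True
        if name == field:
          return True
  return not value_found


bar = [1.0]  # hypothetical stand-in for the "bar" parameter
variables = {"params": {"bar": bar}}

matches_name = _is_valid_field_value("bar", bar, variables)    # True: found under "bar"
not_a_variable = _is_valid_field_value("x", [2.0], variables)  # True: not found anywhere
alias = _is_valid_field_value("baz", bar, variables)           # False: found, but not under "baz"
```

The last call mirrors `self.baz = self.bar`: the object exists only under the key `"bar"`, so assigning it to `baz` is rejected even though the aliasing itself is harmless.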

What to do?

It seems there is no general way to solve the issue. However, as long as users avoid "value types" (ints and floats), for which identity (`is`) comparisons are unreliable, and avoid the pattern above, it should be fine.
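One way the value-type problem shows up, sketched without Flax and assuming CPython (which caches small integers): because equal ints can be the very same object, an `is`-based check can mistake an unrelated field for an alias of a parameter. The scalar parameter and the field name `count` are hypothetical.

```python
def _is_valid_field_value(name, val, variables) -> bool:
  # Same logic as the function above, repeated so this block runs alone.
  value_found = False
  for collection in variables.values():
    for field, existing_value in collection.items():
      if val is existing_value:
        value_found = True
        if name == field:
          return True
  return not value_found


# A scalar parameter, as in `self.param('baz', lambda k: 0)`.
int_vars = {"params": {"baz": 0}}
# An unrelated field that also holds 0: CPython reuses one cached object for
# small ints, so this 0 *is* the parameter and the check rejects it.
unrelated_int = _is_valid_field_value("count", 0, int_vars)

# With lists, equal but distinct objects are told apart correctly.
list_vars = {"params": {"baz": [0]}}
unrelated_list = _is_valid_field_value("count", [0], list_vars)
```

Arrays returned by `self.param` are normally fresh objects, which is why the check behaves as intended for them.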

cgarciae, May 04 '22 20:05

@cgarciae, is this issue still active? I see @jheek reviewed your PR #2102, but then it went stale. Maybe you two can work together to get this in, or, if it turns out to be infeasible, we close the issue?

marcvanzee, Dec 13 '22 07:12

@marcvanzee, can you check the internal CL? Maybe it was breaking some internal tests, which is why we stopped? The PR is also tricky because there are a ton of edge cases we can't cover.

cgarciae, Dec 13 '22 18:12