
Unexpected behavior for @nn.compact_name_scope

Open: marcvanzee opened this issue 1 year ago • 1 comment

import jax
from flax import linen as nn
from jax import random
import jax.numpy as jnp

class Foo(nn.Module):
  @nn.compact_name_scope
  def up(self, x):
    return self._embed(x) + nn.Dense(4)(x)

  def _embed(self, x):
    return nn.Dense(4)(x)

Foo().init(random.PRNGKey(0), jnp.zeros((3, 4)), method=Foo.up)

This raises the following error: flax.errors.AssignSubModuleError: Submodule Dense must be defined in setup() or in a method wrapped in @compact (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)

However, when I replace @nn.compact_name_scope with @nn.compact, it works.

So it seems that methods wrapped in @nn.compact_name_scope can only call other methods that are also wrapped in it, which differs from the behavior of @nn.compact.

Is this behavior intended? If so, it should probably be documented. Otherwise, would it be possible to fix it?

marcvanzee, Jan 22 '24 08:01

An easy workaround is to decorate _embed with @nn.nowrap (which seems like what you should do anyway).

nova77, Jun 17 '24 18:06