flax
Unexpected behavior for @nn.compact_name_scope
import jax
from flax import linen as nn
from jax import random
import jax.numpy as jnp

class Foo(nn.Module):
    @nn.compact_name_scope
    def up(self, x):
        return self._embed(x) + nn.Dense(4)(x)

    def _embed(self, x):
        return nn.Dense(4)(x)

Foo().init(random.PRNGKey(0), jnp.zeros((3, 4)), method=Foo.up)
This throws an error:
flax.errors.AssignSubModuleError: Submodule Dense must be defined in setup() or in a method wrapped in @compact (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)
However, when I replace @nn.compact_name_scope with @nn.compact, it does work.
So it seems that methods wrapped in @nn.compact_name_scope can only call other methods that are wrapped in it as well, which differs from the behavior of @nn.compact.
Is this behavior intended? If so, it should probably be documented. Otherwise would it be possible to fix it?
An easy way out is to decorate _embed with @nn.nowrap (which seems like what you should do nonetheless).