Error when calling `Module.tabulate` on normalization wrappers like `WeightNorm` and `SpectralNorm`
Follow-up from #3735. Partial fix in #3772.
Minimum repro:
```python
import jax, jax.numpy as jnp
from flax import linen as nn

model = nn.WeightNorm(nn.Dense(3))
x = jnp.ones((1, 2))
key = jax.random.key(0)

print(model.tabulate(key,
    x,
    compute_flops=True,
    compute_vjp_flops=True,
))
```
Error message:
```
Traceback (most recent call last):
  File "/Users/marcuschiam/Desktop/asdf.py", line 20, in <module>
    print(model.tabulate(key,
  File "/Users/marcuschiam/flax/flax/linen/module.py", line 2843, in tabulate
    return tabulate_fn(*args, **kwargs)
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 315, in _tabulate_fn
    table = table_fn(rngs, *fn_args, **fn_kwargs, **kwargs)
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 490, in _get_table_fn
    *_get_call_flops(c, compute_flops, compute_vjp_flops),
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 400, in _get_call_flops
    variables = jax.eval_shape(init, rngs, dynamic_leaves)
  File "/Users/marcuschiam/flax/flax/linen/summary.py", line 392, in init
    return c.module.init(
AttributeError: "Dense" object has no attribute "layer_forward". If "layer_forward" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.
```
When calling `tabulate()`, `calls` contains `_CallInfo` objects whose `method` does not exist on the `module`, which causes the `AttributeError` seen in #3735.
This is illustrated by inspecting the `_CallInfo` objects in the debugger:
```python
for c in calls:
    print(f'Index: {c.index}\tModule: {type(c.module)}\tPath: {c.path}\tMethod: {c.method}')
```

```
Index: 0  Module: <class 'flax.linen.normalization.WeightNorm'>  Path: ()                   Method: __call__
Index: 1  Module: <class 'flax.linen.linear.Dense'>              Path: ()                   Method: <lambda>
Index: 2  Module: <class 'flax.linen.linear.Dense'>              Path: ('layer_instance',)  Method: layer_forward
Index: 3  Module: <class 'flax.linen.linear.Dense'>              Path: ('layer_instance',)  Method: __call__
Index: 4  Module: <class 'flax.linen.normalization.WeightNorm'>  Path: ()                   Method: _l2_normalize
Index: 5  Module: <class 'flax.linen.normalization.WeightNorm'>  Path: ()                   Method: _l2_normalize
Index: 6  Module: <class 'flax.linen.linear.Dense'>              Path: ('layer_instance',)  Method: layer_forward
Index: 7  Module: <class 'flax.linen.linear.Dense'>              Path: ('layer_instance',)  Method: __call__
```
For the `_CallInfo` object with index 2, `Dense` does not have a `layer_forward` method or attribute, so an error is thrown.
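The failing lookup for index 2 can be mimicked in plain Python. This is a minimal sketch, not Flax code: `Dense` below is a hypothetical stand-in class rather than `nn.Dense`, `CallInfo` only mirrors the fields printed above, and the sketch assumes the recorded method is resolved by name on the module (consistent with the traceback, but not checked against `summary.py`).

```python
from dataclasses import dataclass
from typing import Any


class Dense:
    """Hypothetical stand-in for nn.Dense: only __call__ is defined."""

    def __call__(self, x):
        return x


@dataclass
class CallInfo:
    """Hypothetical mirror of the fields printed from flax's _CallInfo."""

    index: int
    module: Any
    path: tuple
    method: str


def is_resolvable(call: CallInfo) -> bool:
    # Assumption: the recorded method name is looked up on the module.
    # "layer_forward" is presumably a helper local to WeightNorm, so it is
    # not an attribute of the Dense instance.
    return hasattr(call.module, call.method)


def resolve(call: CallInfo):
    # getattr raises AttributeError for names the module does not define,
    # the same failure mode as in the traceback above.
    return getattr(call.module, call.method)


dense = Dense()
bad = CallInfo(2, dense, ("layer_instance",), "layer_forward")
good = CallInfo(3, dense, ("layer_instance",), "__call__")
print(is_resolvable(bad), is_resolvable(good))  # False True
```

Calling `resolve(bad)` raises `AttributeError`, while `resolve(good)` returns the bound `__call__`.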
For the `_CallInfo` object with index 1, the `<lambda>` function is skipped because the path `()` has already been visited by index 0.
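The path-based skipping can be modelled the same way. The sketch below transcribes the eight calls from the debugger output and keeps only the first call per path, mirroring the behavior described above. `select_calls`, `DEFINED_METHODS`, and the `skip_unresolvable` flag are hypothetical, and the guard is only one possible mitigation, not the change made in #3772.

```python
from dataclasses import dataclass


@dataclass
class CallInfo:
    # Hypothetical mirror of the _CallInfo fields printed in the debugger.
    index: int
    module_type: str
    path: tuple
    method: str


# The call list from the debugger output, transcribed by hand.
CALLS = [
    CallInfo(0, "WeightNorm", (), "__call__"),
    CallInfo(1, "Dense", (), "<lambda>"),
    CallInfo(2, "Dense", ("layer_instance",), "layer_forward"),
    CallInfo(3, "Dense", ("layer_instance",), "__call__"),
    CallInfo(4, "WeightNorm", (), "_l2_normalize"),
    CallInfo(5, "WeightNorm", (), "_l2_normalize"),
    CallInfo(6, "Dense", ("layer_instance",), "layer_forward"),
    CallInfo(7, "Dense", ("layer_instance",), "__call__"),
]

# Hypothetical table of method names each module class actually defines.
DEFINED_METHODS = {
    "WeightNorm": {"__call__", "_l2_normalize"},
    "Dense": {"__call__"},
}


def select_calls(calls, skip_unresolvable=False):
    """Return the indices of calls kept, at most one per path.

    skip_unresolvable=False models the behavior described in the issue;
    skip_unresolvable=True sketches a hypothetical guard.
    """
    visited = set()
    selected = []
    for c in calls:
        if c.path in visited:
            continue  # path already claimed by an earlier call
        if skip_unresolvable and c.method not in DEFINED_METHODS[c.module_type]:
            continue  # method missing on the module; leave the path unclaimed
        visited.add(c.path)
        selected.append(c.index)
    return selected


print(select_calls(CALLS))                          # [0, 2]
print(select_calls(CALLS, skip_unresolvable=True))  # [0, 3]
```

In this model, index 2 claims the `('layer_instance',)` path before the valid `__call__` at index 3 is reached; with the guard, index 3 is used for that path instead.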