flax
Raise clearer Exception when calling method of unbound module
Using this minimal example:
import jax.numpy as np
from jax.numpy import log, exp
import jax.random as rand
import flax.linen as ln

class MultipleForw(ln.Module):
    def setup(self):
        self.s1 = self.param("s1", ln.initializers.ones, (1,))

    def __call__(self, X):
        return X * log(1 + exp(self.s1 - 1))

mf = MultipleForw()
X = np.arange(5)
mf.init(rand.PRNGKey(0), X)
mf(X)
Problem you have encountered:
The last line raised the rather opaque error message AttributeError: 'MultipleForw' object has no attribute 's1'
What you expected to happen:
The raised exception should contain a hint making clear that the correct way to call a linen Module is mf.apply(parameters, input). See Discussion #1013.
In #1072, I tried fixing this by creating a custom error class for Module AttributeErrors, but after a discussion with @avital we found that this is not a very natural solution, because users expect a normal AttributeError when they try to access an unknown attribute of a Module.
Solving this issue properly takes more work and is probably not our highest priority. For now I'm lowering the priority of this issue, since it seems we won't fix it soon; we can raise it again if it turns out that more users run into this problem.
Also unassigning myself since I don't plan to work on this soon.
I'd like to take this issue. A simple solution would be to customize the current error message to suggest calling apply if self.scope is None. I'll create a PR, since this is a rather simple fix; if we want to tackle it differently, we can discuss it there.
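The proposal can be sketched in plain Python, without Flax, to show the intended behavior. UnboundAwareModule, attribute_error_message, and the exact hint wording are hypothetical; this is not Flax's Module class, and only the self.scope is None check comes from the proposal. The sketch keeps raising a plain AttributeError, which addresses the concern from #1072: hasattr() and getattr(obj, name, default) keep working as users expect.

```python
class UnboundAwareModule:
    """Hypothetical stand-in for a linen Module (not Flax's real class)."""

    # In this sketch, scope is None until the module is bound to variables,
    # mirroring the `self.scope is None` check in the proposal.
    scope = None

    def __getattr__(self, name):
        # Python only calls __getattr__ after normal attribute lookup fails,
        # e.g. for attributes that setup() would have assigned.
        msg = f"{type(self).__name__!r} object has no attribute {name!r}"
        if self.scope is None:
            msg += (
                ". The module is unbound; call it with"
                " module.apply(variables, *args) instead of module(*args)."
            )
        # Still a plain AttributeError, so hasattr() and
        # getattr(obj, name, default) behave as usual.
        raise AttributeError(msg)


class MultipleForw(UnboundAwareModule):
    pass


def attribute_error_message(obj, name):
    """Return the AttributeError message raised when reading obj.<name>."""
    try:
        getattr(obj, name)
    except AttributeError as err:
        return str(err)
    return None


# Unbound module: the message carries the hint about apply().
unbound_msg = attribute_error_message(MultipleForw(), "s1")

# "Bound" module: any non-None scope suppresses the hint in this sketch.
bound = MultipleForw()
bound.scope = object()  # hypothetical placeholder for a real scope object
bound_msg = attribute_error_message(bound, "s1")

print(unbound_msg)
print(bound_msg)
```

Because the hint is only appended when the module is unbound, a genuine typo on a bound module still produces the usual message.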
This already sounds an order of magnitude better than the current situation.