flax
Multiple inheritance: subclass not recognized as a Module, throws `ValueError: parent must be None, Module or Scope`
Discussed in https://github.com/google/flax/discussions/1390
Originally posted by SauravMaheshkar, June 26, 2021

I'm working on a Flax implementation of ProteinBERT: A universal deep-learning model of protein sequence and function. My work so far is in SauravMaheshkar/ProteinBERT.
I've made a simple `test.py` to check instantiation using the `.init()` function. My test script is as follows:
```python
from proteinbert import ProteinBERT
import jax
from jax import random


def test():
    seq = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=21, shape=(2, 2048)
    )
    # maxval is exclusive, so with maxval=1 every annotation value is 0
    annotation = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=1, shape=(2, 8943)
    )
    init_rngs = {"params": random.PRNGKey(0), "layers": random.PRNGKey(1)}
    ProteinBERT().init(init_rngs, seq, annotation)


if __name__ == "__main__":
    test()
```
And I've been getting this error message:
```
Traceback (most recent call last):
  File "/Users/sauravmaheshkar/github/protein_bert/test.py", line 21, in <module>
    test()
  File "/Users/sauravmaheshkar/github/protein_bert/test.py", line 17, in test
    ProteinBERT().init(init_rngs, seq, annotation)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 1000, in init
    method=method, mutable=mutable, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 969, in init_with_output
    {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 939, in apply
    )(variables, *args, **kwargs, rngs=rngs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 266, in wrapped_module_method
    self._try_setup()
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 679, in _try_setup
    self.setup()
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/sauravmaheshkar/github/protein_bert/proteinbert/model.py", line 82, in setup
    Reduce("b n d -> b d", "mean"),
  File "<string>", line 5, in __init__
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 599, in __post_init__
    raise ValueError("parent must be None, Module or Scope")
ValueError: parent must be None, Module or Scope
```
The problem lies in the `Reduce` class in `proteinbert/utils.py`, which is defined as follows:
```python
import flax.linen as nn
from einops.layers import ReduceMixin


class Reduce(ReduceMixin, nn.Module):
    """
    Flax Module to act as a Reduce layer (from einops)
    """

    def __call__(self, input):
        return self._apply_recipe(input)
```
The idea is to create a `Reduce` layer/Module for Flax that performs the `reduce` operation from einops. Although the class inherits from `flax.linen.Module`, it still throws a `ValueError`.
Any help would be much appreciated 😊.
@marcvanzee will be investigating this.
A few guiding questions:
1. Why would the einops reduction be implemented as a `Module` instead of just a function?
2. Why is multiple inheritance needed here?
3. Regardless, this error shouldn't happen. So even if we answer (1) and (2) in a way that gives a workaround, we should still fix this bug.
Just noticing this issue for the first time... I've seen similarly weird issues with mixins and Flax resolved in the past by simply changing the order of the base classes, e.g. `class Reduce(nn.Module, ReduceMixin):` to put `nn.Module` first. I'm not 100% sure this is the same kind of issue I've seen before with mixins, but I'd certainly be curious whether that would have fixed it...
I was playing around with mixins to see how they interact with `Module`.

Experiments:
The following case works:

```python
import flax.linen as nn
import jax.numpy as jnp
import jax


class Mixin:
    def __call__(self, x):
        return self.dense(x)


class MyModule(nn.Module, Mixin):
    def setup(self):
        self.dense = nn.Dense(2)


module_a = MyModule()
variables = module_a.init(jax.random.PRNGKey(0), jnp.ones((1, 1)))
```
However, moving `setup` into `Mixin` fails:

```python
class Mixin:
    def setup(self):
        self.dense = nn.Dense(2)

    def __call__(self, x):
        return self.dense(x)


class MyModule(nn.Module, Mixin):
    pass

# AttributeError: "MyModule" object has no attribute "dense"
```
This is again fixed if `Mixin` is set as the first parent:

```python
class Mixin:
    def setup(self):
        self.dense = nn.Dense(2)

    def __call__(self, x):
        return self.dense(x)


class MyModule(Mixin, nn.Module):
    pass
```
Also, you cannot define `compact` methods on mixins (this is probably expected?):

```python
class Mixin:
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)


class MyModule(nn.Module, Mixin):  # swapping doesn't help
    pass
```
Discussion
Based on these experiments, the only insight I see is: don't define scope-dependent operations (`compact`, `self.param`/`self.variable`) inside mixins, as their methods will not be wrapped appropriately. I'm not sure whether there is a way to properly wrap mixin methods in `__init_subclass__`; either `_get_local_method_names` is not detecting them, or they are not available when `__init_subclass__` is called.