Dropout + `nn.jit`
Hi there,
I was following this guide, flax.linen.Dropout. Then I decided to add `nn.jit` and started getting `MyModel.__call__() missing 1 required positional argument: 'training'`, even though I passed it.
System information
- OS Platform: macOS 12.4
- Flax, jax, jaxlib versions:
Name: flax
Version: 0.6.11
---
Name: jax
Version: 0.4.13
---
Name: jaxlib
Version: 0.4.13
- Python version: 3.10
- No hardware acceleration was used
Problem you have encountered:
Traceback (most recent call last):
File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 34, in <module>
main()
File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 25, in main
variables = my_model.init(params_key, x, training=False)
TypeError: MyModel.__call__() missing 1 required positional argument: 'training'
What you expected to happen:
I expect the jitted and non-jitted versions to work the same. Or am I missing something?
Steps to reproduce:
import jax
import jax.numpy as jnp
import flax.linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel)(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, training=False)


if __name__ == "__main__":
    main()
Question
I saw the warning
/Users/artemsereda/miniconda3/envs/py310/lib/python3.10/site-packages/flax/core/lift.py:111: RuntimeWarning: kwargs are not supported in jit, so "training" is(are) ignored
warnings.warn(msg.format(name, ', '.join(kwargs.keys())), RuntimeWarning)
so I decided to change my code to
variables = my_model.init(params_key, x, False)
which then resulted in
Traceback (most recent call last):
File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 25, in <module>
main()
File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 21, in main
variables = my_model.init(params_key, x, False)
File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 12, in __call__
x = nn.Dropout(rate=0.5, deterministic=not training)(x)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function.
The error occurred while tracing the function core_fn at /Users/artemsereda/miniconda3/envs/py310/lib/python3.10/site-packages/flax/linen/transforms.py:305 for jit. This concrete value was not available in Python because it depends on the value of the argument args[2].
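As far as I understand, `not training` calls `bool()` on a value that is being traced. A minimal plain-jax sketch of the same failure, using a hypothetical function `f` unrelated to the model above:

import jax


@jax.jit
def f(flag):
    # `not flag` forces the traced boolean to a concrete Python bool,
    # which fails during tracing.
    return 1.0 if not flag else 0.0


f(False)  # raises a ConcretizationTypeError, same as above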
My next idea was to add the bool argument to static_argnums, as follows
def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel, static_argnums=2)(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)
This one worked, but as per the documentation, "Calling the jitted function with different values for these constants will trigger recompilation." The above-mentioned guide suggests using training=True for training steps and training=False for validation steps, which means I would have to re-compile the full model twice in each training epoch.
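(For reference, a minimal plain-jax sketch of this static-argument behaviour, using a hypothetical `scale_fn` unrelated to the model above:)

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def scale_fn(x, training: bool):
    # `training` is static, so ordinary Python control flow on it is allowed,
    # but each distinct value of `training` gets its own compiled variant.
    return x * 0.5 if training else x


x = jnp.ones((3,))
scale_fn(x, True)   # compiles the training=True variant
scale_fn(x, False)  # compiles the training=False variant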
Is there any way to address this?
Hey @aaarrti, currently nn.jit only supports positional arguments, which is why you are getting this error. However, after you pass training positionally you get another error, because you must specify that training is a static argument. Here is the working example:
import jax
import jax.numpy as jnp
import flax.linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel, static_argnums=(2,))(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)


if __name__ == "__main__":
    main()
That said, it's pretty rare to use `nn.jit`. Usually you just use `jax.jit` over the `train_step`, as shown in the Quick Start.
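For illustration, here is a minimal sketch of that pattern with the model from this thread; the labels `y`, the squared-error loss, and the plain-SGD update are placeholders of mine, not something from the Quick Start (which uses optax):

import jax
import jax.numpy as jnp
import flax.linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


model = MyModel(num_neurons=3)  # plain module, no nn.jit


@jax.jit
def train_step(params, x, y, dropout_key):
    # `training=True` is an ordinary Python constant inside the jitted
    # function, so it is baked in at trace time and never traced.
    def loss_fn(params):
        preds = model.apply(
            {"params": params}, x, training=True, rngs={"dropout": dropout_key}
        )
        return jnp.mean((preds - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Placeholder SGD update; the Quick Start uses optax for this step.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss


root_key = jax.random.PRNGKey(0)
params_key, dropout_key = jax.random.split(root_key)
x = jnp.ones((3, 4, 4))
y = jnp.ones((3, 4, 3))
params = model.init(params_key, x, training=False)["params"]
params, loss = train_step(params, x, y, dropout_key)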
Hi @cgarciae,
it looks like I had a misconception about `nn.jit`. I thought it plays a role similar to tf.keras.Model.compile().
Am I wrong? What is the intended use case for nn.jit?
The example I provided is a basic one, but my original intention was to use nn.jit to compile a bigger model to speed up training. The workaround I came up with so far is:
import jax
import jax.numpy as jnp
import flax.linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.jit(nn.Dense)(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = MyModel(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)


if __name__ == "__main__":
    main()
It looks redundant for this toy example, but the performance improvement becomes easily noticeable as model size/complexity increases. I also jax.jit my training step, though.
To sum up:
- This is not a bug, but rather my misunderstanding of the intended `nn.jit` usage
- There is nothing about the intended usage of `nn.jit` in the documentation