numpyro
More than 1 `input_shape` when initializing `flax_module`
Some modules require more than one input when initializing, which can be passed through kwargs. But this doesn't work in some cases. For example:
    import functools
    from flax import linen as nn

    # RNNCell is assumed to be defined elsewhere.
    class RNN(nn.Module):
        @functools.partial(
            nn.transforms.scan,
            variable_broadcast='params',
            split_rngs={'params': False})
        @nn.compact
        def __call__(self, state, x):
            return RNNCell()(state, x)
I tried to declare this with the following statement:
    rnn = flax_module(
        'rnn',
        RNN(),
        input_shape=(num_hiddens,),
        x=jnp.ones((10, 10))
    )
But I can't use kwargs because `nn.transforms.scan` does not support them:
RuntimeWarning: kwargs are not supported in scan, so "x" is(are) ignored
I worked around this by wrapping my `RNN` with another class, after which I could pass `x` as a kwarg. However, I think `input_shape` should allow passing dimensions for more than one input.
https://github.com/pyro-ppl/numpyro/blob/0bff074a4a54a593a7fab7e68b5c10f85dd332a6/numpyro/contrib/module.py#L83
Hi @UmarJ, I think you can ignore that argument and use `state=..., x=...` directly (docs of kwargs: "optional keyword arguments to initialize flax neural network as an alternative to `input_shape`").
Like I said, kwargs don't work when `__call__` is wrapped in `nn.transforms.scan`, for example in the RNN model above.
RuntimeWarning: kwargs are not supported in scan, so "x" is(are) ignored
It's a minor issue, just thought I'd raise it in case someone else was having a similar problem.
Thanks for explaining! I misinterpreted the issue.
I think `input_shape` is a bad pattern here. Originally we wanted to keep the same signature as jax stax. Maybe it is time to support `*args`? (I made the wrong comment here against args; it seems useful for this use case.)