More than 1 `input_shape` when initializing `flax_module`

Open UmarJ opened this issue 2 years ago • 3 comments

Some modules require more than 1 input when initializing, which can be passed through kwargs. But this doesn't work in some cases. For example:

import functools
import flax.linen as nn

class RNN(nn.Module):
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, state, x):
        return RNNCell()(state, x)

I tried to declare this with the following statement:

rnn = flax_module(
    'rnn',
    RNN(),
    input_shape=(num_hiddens,),
    x=jnp.ones((10, 10))
)

But I can't use kwargs because nn.transforms.scan does not support them:

RuntimeWarning: kwargs are not supported in scan, so "x" is(are) ignored

I worked around this by wrapping my RNN with another class, after which I could pass x as a kwarg. However, I think input_shape should allow passing dimensions for more than one input.

https://github.com/pyro-ppl/numpyro/blob/0bff074a4a54a593a7fab7e68b5c10f85dd332a6/numpyro/contrib/module.py#L83

UmarJ • Apr 13 '22 11:04

Hi @UmarJ, I think you can ignore that argument and use state=..., x=... directly (docs of kwargs: "optional keyword arguments to initialize flax neural network as an alternative to input_shape")

fehiepsi • Apr 13 '22 14:04

> Hi @UmarJ, I think you can ignore that argument and use state=..., x=... directly (docs of kwargs: "optional keyword arguments to initialize flax neural network as an alternative to input_shape")

Like I said, kwargs don't work when __call__ is wrapped in nn.transforms.scan, for example the RNN model above.

RuntimeWarning: kwargs are not supported in scan, so "x" is(are) ignored

It's a minor issue, just thought I'd raise it in case someone else was having a similar problem.

UmarJ • Apr 13 '22 19:04

Thanks for explaining! I misinterpreted the issue.

I think input_shape is a bad pattern here. Originally we wanted to keep the same signature as jax stax. Maybe it is time to support *args? (I made the wrong comment here against args; it seems useful for this use case.)
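
As a rough illustration of what positional initialization inputs could look like, the sketch below builds one dummy array per shape. `dummy_inputs` is a hypothetical helper, not part of numpyro, and NumPy stands in for `jax.numpy` to keep the example dependency-light; how `flax_module` would actually expose `*args` is an open design question.

```python
import numpy as np


def dummy_inputs(*input_shapes):
    """Hypothetical helper: return one array of ones per requested shape."""
    return tuple(np.ones(shape) for shape in input_shapes)


# For the RNN in this issue: a state of size num_hiddens and a (10, 10) input.
num_hiddens = 4  # illustrative value
state, x = dummy_inputs((num_hiddens,), (10, 10))
```

Arrays like these could then be passed positionally when the module is initialized, which avoids the kwargs path that scan ignores.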

fehiepsi • Apr 13 '22 20:04