numpyro
Performance enhancements for `init_strategy` may lead to unexpected behavior.
There is some dedicated logic to improve performance for specific init strategies. This logic requires `init_strategy` to be either a `functools.partial` or a function that returns a partial when called with no arguments. For example, I expected, perhaps naively, that the following init strategy would work.
>>> import jax
>>> import numpyro
>>>
>>>
>>> def model():
...     numpyro.sample("x", numpyro.distributions.Normal())
...     numpyro.sample("y", numpyro.distributions.Normal())
>>>
>>>
>>> def init_and_get_auto_loc(init_strategy):
...     guide = numpyro.infer.autoguide.AutoDiagonalNormal(model, init_loc_fn=init_strategy)
...
...     svi = numpyro.infer.SVI(model, guide, numpyro.optim.Adam(0.1), numpyro.infer.Trace_ELBO())
...     state = svi.init(jax.random.key(9))
...     return svi.get_params(state)["auto_loc"]
>>>
>>>
>>> init_strategy = lambda site: 3.0 if site["name"] == "x" else 7.0
>>> init_and_get_auto_loc(init_strategy)
TypeError: <lambda>() missing 1 required positional argument: 'site'
But wrapping in a partial works.
>>> from functools import partial
>>>
>>> init_and_get_auto_loc(partial(init_strategy))
Array([3., 7.], dtype=float32)
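To see why the bare lambda fails while the partial works, here is a simplified, pure-Python sketch of the normalization step in the `numpyro/infer/util.py` code referenced in this issue. It assumes that step only checks `isinstance(init_strategy, partial)` and otherwise calls `init_strategy()` with no arguments, which is consistent with the `TypeError` above. `normalize_init_strategy` and `init_to_constant` are hypothetical names for illustration; neither is part of numpyro.

```python
from functools import partial


def normalize_init_strategy(init_strategy):
    # Simplified sketch (hypothetical name): a partial is used as-is;
    # anything else is called with no arguments and is expected to
    # return a partial.
    return init_strategy if isinstance(init_strategy, partial) else init_strategy()


def init_to_constant(site=None, value=0.0):
    # Hypothetical strategy following the convention of numpyro's built-ins:
    # called without a site, it returns a partial carrying its configuration.
    if site is None:
        return partial(init_to_constant, value=value)
    return value


bare = lambda site: 3.0 if site["name"] == "x" else 7.0

# The bare lambda gets called with no arguments, which raises TypeError.
try:
    normalize_init_strategy(bare)
    bare_failed = False
except TypeError:
    bare_failed = True

# Wrapping in partial skips the no-argument call; the lambda is used directly.
wrapped = normalize_init_strategy(partial(bare))

# A function that returns a partial when called with no arguments also works.
conventional = normalize_init_strategy(init_to_constant)
```

Note that `wrapped.func` is the lambda itself, so any fast path that dispatches on `init_strategy.func` would not recognize it and would fall back to the general code path.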
I came across this while trying to write an init strategy that initializes some sites to given values and the remaining sites uniformly, but with a radius other than the default of 2. Is this the intended behavior?
The relevant logic is here.
https://github.com/pyro-ppl/numpyro/blob/d6ba5685bb57e87ef9d7af17e975128bc1ed16d6/numpyro/infer/util.py#L373-L384
https://github.com/pyro-ppl/numpyro/blob/d6ba5685bb57e87ef9d7af17e975128bc1ed16d6/numpyro/infer/util.py#L742-L748