numpyro

Performance enhancements for `init_strategy` may lead to unexpected behavior.

Open • tillahoffmann opened this issue 8 months ago • 0 comments

There is dedicated logic that improves performance for specific `init_strategy`s. This logic requires that `init_strategy` is either a `functools.partial` or a function that returns a `partial` when called without arguments. For example, I expected, perhaps naively, that the following init strategy would work:

>>> import jax
>>> import numpyro
>>> 
>>> 
>>> def model():
...     numpyro.sample("x", numpyro.distributions.Normal())
...     numpyro.sample("y", numpyro.distributions.Normal())
>>> 
>>> 
>>> def init_and_get_auto_loc(init_strategy):
...     guide = numpyro.infer.autoguide.AutoDiagonalNormal(model, init_loc_fn=init_strategy)
...     
...     svi = numpyro.infer.SVI(model, guide, numpyro.optim.Adam(0.1), numpyro.infer.Trace_ELBO())
...     state = svi.init(jax.random.key(9))
...     return svi.get_params(state)["auto_loc"]
>>>
>>>
>>> init_strategy = lambda site: 3.0 if site["name"] == "x" else 7.0
>>> init_and_get_auto_loc(init_strategy)
TypeError: <lambda>() missing 1 required positional argument: 'site'

Wrapping it in a `partial`, however, works:

>>> from functools import partial
>>>
>>> init_and_get_auto_loc(partial(init_strategy))
Array([3., 7.], dtype=float32)

I came across this while trying to write an init strategy that initializes some sites to given values and the remaining sites uniformly, but with a radius other than the default of 2. Is this the intended behavior?
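One way to get that behavior without the bare-lambda form is to write the strategy in the same `site=None` convention the built-in strategies use, so that calling it without a site returns a `partial`. The sketch below is hypothetical: `init_to_value_or_fallback` is not a NumPyro function, the site dicts are reduced to the keys it reads, and `fake_uniform` stands in for the fallback so the snippet runs without NumPyro installed.

```python
from functools import partial


def init_to_value_or_fallback(site=None, values=None, fallback=None):
    # Hypothetical strategy (not part of NumPyro): use `values` for the named
    # unobserved sample sites and defer every other site to `fallback`.
    values = {} if values is None else values
    if site is None:
        # Same convention as the built-in strategies: return a partial.
        return partial(init_to_value_or_fallback, values=values, fallback=fallback)
    if site["type"] == "sample" and not site["is_observed"] and site["name"] in values:
        return values[site["name"]]
    if fallback is None:
        raise ValueError("a fallback strategy is required for sites not in `values`")
    return fallback(site)


def fake_uniform(site, radius=2):
    # Stand-in for a uniform initializer; returns a tag instead of sampling.
    return ("uniform", radius)


strategy = init_to_value_or_fallback(
    values={"x": 3.0}, fallback=partial(fake_uniform, radius=0.5)
)
# Simplified site dicts containing only the keys the strategy reads.
x_site = {"type": "sample", "name": "x", "is_observed": False}
y_site = {"type": "sample", "name": "y", "is_observed": False}

print(isinstance(strategy, partial))  # True
print(strategy(x_site))  # 3.0
print(strategy(y_site))  # ('uniform', 0.5)
```

With NumPyro, the stand-in would presumably be replaced by `fallback=init_to_uniform(radius=0.5)` (from `numpyro.infer`) and the result passed as `init_loc_fn`; this is untested against NumPyro itself. Because the returned partial's `.func` is not one of the built-in strategies, it would go through the same general path as `partial(init_strategy)` in the example above.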

The relevant logic is here.

https://github.com/pyro-ppl/numpyro/blob/d6ba5685bb57e87ef9d7af17e975128bc1ed16d6/numpyro/infer/util.py#L373-L384

https://github.com/pyro-ppl/numpyro/blob/d6ba5685bb57e87ef9d7af17e975128bc1ed16d6/numpyro/infer/util.py#L742-L748
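To make the failure mode concrete, here is a minimal, self-contained sketch of the kind of dispatch those lines appear to perform. It is an assumption inferred from the `TypeError` above, not a copy of NumPyro's code: `normalize_init_strategy` is a hypothetical name, and `init_to_uniform` below is a local stand-in that only mimics the `site=None` convention of the real strategy.

```python
from functools import partial


def init_to_uniform(site=None, radius=2):
    # Local stand-in (not NumPyro's implementation): like the built-in
    # strategies, calling it without a site returns a functools.partial.
    if site is None:
        return partial(init_to_uniform, radius=radius)
    return ("uniform", radius)  # placeholder instead of an actual sample


def normalize_init_strategy(init_strategy):
    # Hypothetical simplification of the dispatch: partials are used as-is,
    # anything else is called with NO arguments and must return a partial.
    strategy = init_strategy if isinstance(init_strategy, partial) else init_strategy()
    # Assumed fast path keyed on the wrapped function, e.g. to read the radius.
    if strategy.func is init_to_uniform:
        return strategy, strategy.keywords.get("radius")
    return strategy, None


init_strategy = lambda site: 3.0 if site["name"] == "x" else 7.0

# 1. A bare lambda is called without a site and raises TypeError.
try:
    normalize_init_strategy(init_strategy)
    lambda_error = None
except TypeError as err:
    lambda_error = str(err)

# 2. Wrapping it in partial skips the no-argument call; no fast path applies.
wrapped, wrapped_radius = normalize_init_strategy(partial(init_strategy))

# 3. A strategy following the site=None convention works with or without a call.
_, bare_radius = normalize_init_strategy(init_to_uniform)
_, custom_radius = normalize_init_strategy(init_to_uniform(radius=0.5))

print(lambda_error)  # message mentions the missing 'site' argument
print(wrapped({"name": "x"}), wrapped_radius)  # 3.0 None
print(bare_radius, custom_radius)  # 2 0.5
```

Under this reading, `partial(init_strategy)` succeeds only because it bypasses the no-argument call, and the radius-aware fast path is then skipped because `.func` is the lambda rather than `init_to_uniform`.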

tillahoffmann, Feb 04 '25 17:02