pymc
ENH: Option to adjust the uniform distribution bounds for jitter/initialisation
Before
def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_transforms: dict[TensorVariable, Transform],
    initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None],
    jitter_rvs: set[TensorVariable] | None = None,
    default_strategy: str = "support_point",
    return_transformed: bool = False,
) -> list[TensorVariable]:
    # ... existing code ...
    if variable in jitter_rvs:
        jitter = pt.random.uniform(-1, 1, size=value.shape)
        jitter.name = f"{variable.name}_jitter"
        value = value + jitter
After
def make_initial_point_fn(
    *,
    model,
    overrides: StartDict | None = None,
    jitter_rvs: set[TensorVariable] | None = None,
    jitter_bounds: tuple[float, float] = (-1, 1),  # <---- new
    default_strategy: str = "support_point",
    return_transformed: bool = True,
) -> Callable:
    # ... existing code ...
    initial_values = make_initial_point_expression(
        free_rvs=model.free_RVs,
        rvs_to_transforms=model.rvs_to_transforms,
        initval_strategies=initval_strats,
        jitter_rvs=jitter_rvs,
        jitter_bounds=jitter_bounds,  # <---- new
        default_strategy=default_strategy,
        return_transformed=return_transformed,
    )
    # ... rest of existing code ...

def make_initial_point_fns_per_chain(
    *,
    model,
    overrides: StartDict | Sequence[StartDict | None] | None,
    jitter_rvs: set[TensorVariable] | None = None,
    jitter_bounds: tuple[float, float] = (-1, 1),  # <---- new
    chains: int,
) -> list[Callable]:
    if isinstance(overrides, dict) or overrides is None:
        ipfns = [
            make_initial_point_fn(
                model=model,
                overrides=overrides,
                jitter_rvs=jitter_rvs,
                jitter_bounds=jitter_bounds,  # <---- new
                return_transformed=True,
            )
        ] * chains
    elif len(overrides) == chains:
        ipfns = [
            make_initial_point_fn(
                model=model,
                jitter_rvs=jitter_rvs,
                jitter_bounds=jitter_bounds,  # <---- new
                overrides=chain_overrides,
                return_transformed=True,
            )
            for chain_overrides in overrides
        ]
    # ... rest of existing code ...

def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_transforms: dict[TensorVariable, Transform],
    initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None],
    jitter_rvs: set[TensorVariable] | None = None,
    jitter_bounds: tuple[float, float] = (-1, 1),  # <---- new
    default_strategy: str = "support_point",
    return_transformed: bool = False,
) -> list[TensorVariable]:
    # ... existing code ...
    if variable in jitter_rvs:
        jitter = pt.random.uniform(
            jitter_bounds[0],  # <---- new
            jitter_bounds[1],
            size=value.shape,
        )
        jitter.name = f"{variable.name}_jitter"
        value = value + jitter
    # ... existing code ...
Context for the issue:
To assist multi-path Pathfinder in exploring complicated posteriors (e.g., multimodal posteriors, flat or saddle-point regions, or posteriors with several local modes in which the optimisation gets stuck), each single-path Pathfinder needs to be initialised over a broader region. This requires random initialisation points drawn from a wider distribution than Uniform(-1, 1). Attached is an image from Figure 11 of the paper comparing random initialisations of different widths.
Making the uniform distribution bounds input parameters would let users of the multi-path Pathfinder algorithm tailor the initialisations to their problem.
I'm happy to work on this feature :) Any suggestions on how you'd like the changes to be made?
References: Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1–49.
If you pass default_strategy="prior" you'll get much more variation. Does that suffice?
Otherwise would a single jitter_scale argument suffice? Do you need to be able to specify both bounds? And will you need to vary this per variable?
Depending on the scenario, setting default_strategy="prior" gives too much variation. For some problems jitter ~ U(-2, 2) may already be too much variation, while for others jitter ~ U(-20, 20) would be better. The option to set the jitter would be useful.
What's demonstrated in the Pathfinder paper is a jitter with a single bound, the same for all variables. For now, a single bound applied to all variables should be okay.
I have implemented a workaround in this PR https://github.com/pymc-devs/pymc-experimental/pull/386, which provides the initialisation needed for Pathfinder:
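For a sense of scale: with a symmetric jitter ε ~ U(-s, s) added in the unconstrained space, the spread grows linearly in the bound s:

```latex
\epsilon \sim \mathrm{U}(-s, s): \quad
\mathbb{E}[\epsilon] = 0, \qquad
\operatorname{Var}[\epsilon] = \frac{(2s)^2}{12} = \frac{s^2}{3}, \qquad
\operatorname{sd}[\epsilon] = \frac{s}{\sqrt{3}}
```

So U(-2, 2) has a standard deviation of about 1.15, while U(-20, 20) has about 11.5, a tenfold wider spread around the starting point.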
import numpy as np

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model.core import Point
from pymc.util import RandomSeed


def make_initial_pathfinder_point(
    model,
    jitter: float = 2.0,
    random_seed: RandomSeed | None = None,
) -> RaveledVars:
    """
    Create a jittered initial point for Pathfinder.

    Parameters
    ----------
    model : Model
        PyMC model
    jitter : float
        Initial values in the unconstrained space are jittered by the uniform
        distribution U(-jitter, jitter). Set jitter to 0 for no jitter.
    random_seed : RandomSeed | None
        Random seed for reproducibility

    Returns
    -------
    RaveledVars
        Raveled (flattened) jittered initial point
    """
    ipfn = make_initial_point_fn(
        model=model,
    )
    ip = Point(ipfn(random_seed), model=model)
    # Flatten the initial point dict into a single vector (RaveledVars)
    ip_map = DictToArrayBijection.map(ip)

    # Add U(-jitter, jitter) noise to every element of the flattened vector
    rng = np.random.default_rng(random_seed)
    jitter_value = rng.uniform(-jitter, jitter, size=ip_map.data.shape)
    ip_map = ip_map._replace(data=ip_map.data + jitter_value)
    return ip_map
I'm happy for this request to be closed now, or to remain open if the jitter ~ U(-n, n) option should also be available in the pymc library.
Single jitter for all variables sounds fine, feel free to open a PR. Perhaps call it jitter_scale?
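A single `jitter_scale` would map onto the existing bounds as U(-jitter_scale, jitter_scale). A minimal sketch of that mapping, with `jitter_bounds_from_scale` as a hypothetical helper name (not part of PyMC):

```python
def jitter_bounds_from_scale(jitter_scale: float) -> tuple[float, float]:
    """Map a single non-negative scale s onto symmetric bounds (-s, s)."""
    if jitter_scale < 0:
        raise ValueError(f"jitter_scale must be non-negative, got {jitter_scale}")
    return (-float(jitter_scale), float(jitter_scale))


print(jitter_bounds_from_scale(1))   # (-1.0, 1.0): today's default
print(jitter_bounds_from_scale(20))  # (-20.0, 20.0)
```

A single scale keeps the jitter centred on the initial point and rules out inverted bounds, at the cost of not supporting asymmetric jitter, which the Pathfinder paper does not use.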
Hi @ricardoV94, can we change this feature request to something like: ENH: Option to adjust start and jitter_scale for individual parameters?
This is motivated by the findings shown in Figure 15 of the paper:
I initially thought that a single jitter for all parameters would be an okay starting point. However, now that Pathfinder in pmx.fit has been tested more, I think it would be good to extend this feature request to allow more control over the initialisations, which has also been mentioned in blackjax#763.
ADVI offers this control via the start and start_sigma arguments in pm.fit(), and Pathfinder is quite sensitive to initialisations, so it makes sense to have the same control available. I'm thinking of giving the arguments type hints similar to ADVI's.
Happy to work on this one (you can assign me to this). And of course, any implementation suggestions are welcome!
Sounds useful @aphc14, let's see what it looks like :)