pymc icon indicating copy to clipboard operation
pymc copied to clipboard

BUG: ZeroSumTransform fails with initvalues

Open velochy opened this issue 7 months ago • 1 comment

Describe the issue:

Passing initvals for variables that use ZeroSumTransform raises an AttributeError during a type cast ('int' object has no attribute 'astype').

It seems to be caused by the input being a NumPy array rather than a PyTensor variable: array.shape[axis] is then a plain Python int, which has no .astype method.

The fix seems simple; I will post a PR for it next.

Reproducible code example:

import pymc as pm, numpy as np

with pm.Model() as model:
    pm.ZeroSumNormal('zsn',shape=(10,))
    pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
    mp = pm.find_MAP()

    pm.sample(initvals=mp)

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8
      5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
      6 mp = pm.find_MAP()
----> 8 pm.sample(initvals=mp)

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    830         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    831     with joined_blas_limiter():
--> 832         initial_points, step = init_nuts(
    833             init=init,
    834             chains=chains,
    835             n_init=n_init,
    836             model=model,
    837             random_seed=random_seed_list,
    838             progressbar=progress_bool,
    839             jitter_max_retries=jitter_max_retries,
    840             tune=tune,
    841             initvals=initvals,
    842             compile_kwargs=compile_kwargs,
    843             **kwargs,
    844         )
    845 else:
    846     # Get initial points
    847     ipfns = make_initial_point_fns_per_chain(
    848         model=model,
    849         overrides=initvals,
    850         jitter_rvs=set(),
    851         chains=chains,
    852     )

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
   1602     q, _ = DictToArrayBijection.map(ip)
   1603     return logp_dlogp_func([q], extra_vars={})[0]
-> 1605 initial_points = _init_jitter(
   1606     model,
   1607     initvals,
   1608     seeds=random_seed_list,
   1609     jitter="jitter" in init,
   1610     jitter_max_retries=jitter_max_retries,
   1611     logp_fn=model_logp_fn,
   1612 )
   1614 apoints = [DictToArrayBijection.map(point) for point in initial_points]
   1615 apoints_data = [apoint.data for apoint in apoints]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn)
   1432 def _init_jitter(
   1433     model: Model,
   1434     initvals: StartDict | Sequence[StartDict | None] | None,
   (...)
   1438     logp_fn: Callable[[PointType], np.ndarray] | None = None,
   1439 ) -> list[PointType]:
   1440     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
   1441 
   1442     ``model.check_start_vals`` is used to test whether the jittered starting
   (...)
   1460         List of starting points for the sampler
   1461     """
-> 1462     ipfns = make_initial_point_fns_per_chain(
   1463         model=model,
   1464         overrides=initvals,
   1465         jitter_rvs=set(model.free_RVs) if jitter else set(),
   1466         chains=len(seeds),
   1467     )
   1469     if not jitter:
   1470         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains)
     72 """Create an initial point function for each chain, as defined by initvals.
     73 
     74 If a single initval dictionary is passed, the function is replicated for each
   (...)
     95 
     96 """
     97 if isinstance(overrides, dict) or overrides is None:
     98     # One strategy for all chains
     99     # Only one function compilation is needed.
    100     ipfns = [
--> 101         make_initial_point_fn(
    102             model=model,
    103             overrides=overrides,
    104             jitter_rvs=jitter_rvs,
    105             return_transformed=True,
    106         )
    107     ] * chains
    108 elif len(overrides) == chains:
    109     ipfns = [
    110         make_initial_point_fn(
    111             model=model,
   (...)
    116         for chain_overrides in overrides
    117     ]

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed)
    126 def make_initial_point_fn(
    127     *,
    128     model,
   (...)
    132     return_transformed: bool = True,
    133 ) -> Callable[[SeedSequenceSeed], PointType]:
    134     """Create seeded function that computes initial values for all free model variables.
    135 
    136     Parameters
   (...)
    150     initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
    151     """
--> 152     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
    153     initval_strats = {
    154         **model.rvs_to_initial_values,
    155         **sdict_overrides,
    156     }
    158     initial_values = make_initial_point_expression(
    159         free_rvs=model.free_RVs,
    160         rvs_to_transforms=model.rvs_to_transforms,
   (...)
    164         return_transformed=return_transformed,
    165     )

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start)
     55 if is_transformed_name(key):
     56     rv = model[get_untransformed_name(key)]
---> 57     initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
     58 else:
     59     initvals[model[key]] = initval

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs)
    307 def backward(self, value, *rv_inputs):
    308     for axis in self.zerosum_axes:
--> 309         value = self.extend_axis(value, axis=axis)
    310     return value

File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis)
    279 @staticmethod
    280 def extend_axis(array, axis):
--> 281     n = (array.shape[axis] + 1).astype("floatX")
    282     sum_vals = array.sum(axis, keepdims=True)
    283     norm = sum_vals / (pt.sqrt(n) + n)

AttributeError: 'int' object has no attribute 'astype'

PyMC version information:

pymc 5.22.0

Context for the issue:

I wanted to experiment with setting initvals from MAP and pathfinder, and ran into this issue.

velochy avatar May 02 '25 07:05 velochy

One-line fix at https://github.com/pymc-devs/pymc/pull/7773

velochy avatar May 02 '25 07:05 velochy