
ENH: pymc.math.sum cannot be observed

Open louiszhuang opened this issue 3 weeks ago • 4 comments

Describe the issue:

The sum function breaks pm.sample when observed.

Reproducible code example:

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    # unregistered iid draws given x
    y = pm.Normal.dist(x, shape=(5,))
    y_sum = pm.Deterministic("y_sum", pm.math.sum(y))

# observing the sum raises a RuntimeError at sample time
with pm.observe(m, {"y_sum": 2.0}):
    trace = pm.sample()

Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 3
      1 #%%
      2 with pm.observe(m, {"y_sum": 2.0}):
----> 3     trace = pm.sample(nuts_sampler='nutpie')

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:782, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    779     msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
    780     _log.warning(msg)
--> 782 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
    783 exclusive_nuts = (
    784     # User provided an instantiated NUTS step, and nothing else is needed
    785     (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
   (...)    792     )
    793 )
    795 if nuts_sampler != "pymc":

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:245, in assign_step_methods(model, step, methods)
    243 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    244 selected_steps: dict[type[BlockedStep], list] = {}
--> 245 model_logp = model.logp()
    247 for var in model.value_vars:
    248     if var not in assigned_vars:
    249         # determine if a gradient can be computed

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\model\core.py:714, in Model.logp(self, vars, jacobian, sum)
    712 rv_logps: list[TensorVariable] = []
    713 if rvs:
--> 714     rv_logps = transformed_conditional_logp(
    715         rvs=rvs,
    716         rvs_to_values=self.rvs_to_values,
    717         rvs_to_transforms=self.rvs_to_transforms,
    718         jacobian=jacobian,
    719     )
    720     assert isinstance(rv_logps, list)
    722 # Replace random variables by their value variables in potential terms

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:574, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    571     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore[arg-type]
    573 kwargs.setdefault("warn_rvs", False)
--> 574 temp_logp_terms = conditional_logp(
    575     rvs_to_values,
    576     extra_rewrites=transform_rewrite,
    577     use_jacobian=jacobian,
    578     **kwargs,
    579 )
    581 # The function returns the logp for every single value term we provided to it.
    582 # This includes the extra values we plugged in above, so we filter those we
    583 # actually wanted in the same order they were given in.
    584 logp_terms = {}

File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:531, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    529 missing_value_terms = set(original_values) - set(values_to_logprobs)
    530 if missing_value_terms:
--> 531     raise RuntimeError(
    532         f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
    533     )
    535 # Ensure same order as input
    536 logprobs = cleanup_ir(tuple(values_to_logprobs[v] for v in original_values))

RuntimeError: The logprob terms of the following value variables could not be derived: {TensorConstant(TensorType(float64, shape=()), data=array(2.))}

PyMC version information:

5.26.1

Context for the issue:

Being able to observe sum/max etc. would be very helpful in many cases.

louiszhuang avatar Dec 09 '25 21:12 louiszhuang

You can observe the max of iid variables.

Sum isn't implemented because it's not general; only some distributions allow a closed form.

Would be nice to add the Normal case for sure.
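
For instance, mirroring the report but with max instead of sum should sample - a minimal, untested sketch (the max of iid variables gets a logprob rewrite in pymc/logprob/order.py):

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    y = pm.Normal.dist(x, shape=(5,))  # iid given x
    y_max = pm.Deterministic("y_max", pm.math.max(y))

with pm.observe(m, {"y_max": 2.0}):
    trace = pm.sample()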

ricardoV94 avatar Dec 09 '25 22:12 ricardoV94

basic.py:475:

# TODO: This seems too convoluted, can we just replace all RVs by their values,
# except for the fgraph outputs (for which we want to call _logprob on)?
for node in fgraph.toposort():
    if not isinstance(node.op, MeasurableOp):
        continue

louiszhuang avatar Dec 09 '25 22:12 louiszhuang

> You can observe the max of iid variables.
>
> Sum isn't implemented because it's not general; only some distributions allow a closed form.
>
> Would be nice to add the Normal case for sure.

You are exactly right - pymc.math.max actually works. I am new to this area - is a closed form mandatory for sum to be implemented, or could that be part of the HMC process?

louiszhuang avatar Dec 09 '25 22:12 louiszhuang

It's necessary for the way observed works: we still end up building a density function that is passed to the samplers.

We don't interact with the samplers to achieve the goal of "observe". So it's either a closed form or a numerical approximation within the density function.

Still, it would be nice to add the cases that have a closed form.
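
For the Normal case the closed form is simple: the sum of n iid Normal(mu, sigma) draws is Normal(n * mu, sqrt(n) * sigma). A hand-rolled workaround for the model in the report could look like this (a hypothetical sketch, not what observe would generate):

import numpy as np
import pymc as pm

n = 5
with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    # sum of n iid Normal(x, 1) draws is Normal(n * x, sqrt(n))
    pm.Normal("y_sum", mu=n * x, sigma=np.sqrt(n), observed=2.0)
    trace = pm.sample()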

ricardoV94 avatar Dec 09 '25 22:12 ricardoV94

I traced this through Model.logp, conditional_logp, and how PyMC treats Deterministic vs Potential variables.

From what I can see, wrapping pm.math.sum in pm.Deterministic registers a deterministic variable. By design, deterministics don't contribute to the model's log-probability, so they can't be observed directly. This matches how register_rv and the observed-logp machinery are implemented.

pm.math.max works only because it's rewritten into a MeasurableOp with a custom _logprob implementation (in pymc/logprob/order.py). Supporting sum would require a similar approach, and it would likely need distribution-specific log-probability rules (for example, a closed-form case for Normal variables).

The currently supported way to "observe" a sum is to encode it as a soft constraint via a Potential.
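
For example, a hypothetical soft-constraint encoding of the report's model (the noise scale sigma=0.1 is an assumption; this penalizes deviation from the observed value 2.0 rather than conditioning on the sum exactly):

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    y = pm.Normal("y", mu=x, sigma=1, shape=(5,))
    # soft constraint: a Normal likelihood term on the sum of y
    pm.Potential("y_sum_obs", pm.logp(pm.Normal.dist(mu=pm.math.sum(y), sigma=0.1), 2.0))
    trace = pm.sample()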

Jiya873 avatar Dec 16 '25 20:12 Jiya873

Please stop posting LLM answers all over the repo. They're not giving any new insight that isn't already in the discussion.

Even if it's not an LLM, consider changing your strategy, because it doesn't seem helpful.

ricardoV94 avatar Dec 17 '25 23:12 ricardoV94

Sorry @ricardoV94, I'm new to open-source contribution and used an LLM just to frame my understanding so that it looks professional.

Jiya873 avatar Dec 20 '25 08:12 Jiya873

Sure, but we can also do that ourselves. You need to use it (if you so choose) in a way that helps out.

Asking clarifying questions, trying some solutions (hopefully locally) to see where your understanding breaks apart, fixing things that you've tested locally. You need not even restrict yourself to working on open issues. Testing and opening new bug reports is even better.

But mostly, use your tools with judgement.

ricardoV94 avatar Dec 20 '25 08:12 ricardoV94