ENH: pymc.math.sum cannot be observed
Describe the issue:
The sum function breaks sampling if observed.
Reproducible code example:
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    y = pm.Normal.dist(x, shape=(5,))
    y_sum = pm.Deterministic("y_sum", pm.math.sum(y))

with pm.observe(m, {"y_sum": 2.0}):
    trace = pm.sample()
Error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 3
1 #%%
2 with pm.observe(m, {"y_sum": 2.0}):
----> 3 trace = pm.sample(nuts_sampler='nutpie')
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:782, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
779 msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
780 _log.warning(msg)
--> 782 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
783 exclusive_nuts = (
784 # User provided an instantiated NUTS step, and nothing else is needed
785 (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
(...) 792 )
793 )
795 if nuts_sampler != "pymc":
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:245, in assign_step_methods(model, step, methods)
243 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
244 selected_steps: dict[type[BlockedStep], list] = {}
--> 245 model_logp = model.logp()
247 for var in model.value_vars:
248 if var not in assigned_vars:
249 # determine if a gradient can be computed
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\model\core.py:714, in Model.logp(self, vars, jacobian, sum)
712 rv_logps: list[TensorVariable] = []
713 if rvs:
--> 714 rv_logps = transformed_conditional_logp(
715 rvs=rvs,
716 rvs_to_values=self.rvs_to_values,
717 rvs_to_transforms=self.rvs_to_transforms,
718 jacobian=jacobian,
719 )
720 assert isinstance(rv_logps, list)
722 # Replace random variables by their value variables in potential terms
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:574, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
571 transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore[arg-type]
573 kwargs.setdefault("warn_rvs", False)
--> 574 temp_logp_terms = conditional_logp(
575 rvs_to_values,
576 extra_rewrites=transform_rewrite,
577 use_jacobian=jacobian,
578 **kwargs,
579 )
581 # The function returns the logp for every single value term we provided to it.
582 # This includes the extra values we plugged in above, so we filter those we
583 # actually wanted in the same order they were given in.
584 logp_terms = {}
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:531, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
529 missing_value_terms = set(original_values) - set(values_to_logprobs)
530 if missing_value_terms:
--> 531 raise RuntimeError(
532 f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
533 )
535 # Ensure same order as input
536 logprobs = cleanup_ir(tuple(values_to_logprobs[v] for v in original_values))
RuntimeError: The logprob terms of the following value variables could not be derived: {TensorConstant(TensorType(float64, shape=()), data=array(2.))}
PyMC version information:
5.26.1
Context for the issue:
Observing sum/max etc. would be very helpful for many cases.
You can observe max of iid variables.
Sum isn't implemented because it's not general; only some distributions admit a closed form.
Would be nice to add the Normal case for sure.
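For contrast, a minimal sketch of the max case, mirroring the repro above (pm.math.max is confirmed to work later in this thread):

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    y = pm.Normal.dist(x, shape=(5,))
    # Unlike sum, max is rewritten into a measurable op, so its logprob can be derived
    y_max = pm.Deterministic("y_max", pm.math.max(y))

with pm.observe(m, {"y_max": 2.0}):
    trace = pm.sample()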
basic.py:475
# TODO: This seems too convoluted, can we just replace all RVs by their values,
# except for the fgraph outputs (for which we want to call _logprob on)?
for node in fgraph.toposort():
    if not isinstance(node.op, MeasurableOp):
        continue
You are exactly right - pymc.math.max actually works. I am new to the area - is a closed form mandatory for sum to be implemented, or could that be part of the HMC process?
It's necessary for the way observe works. We still end up building a density function that is passed to the samplers.
We don't interact with the samplers to achieve the goal of "observe". So it's either a closed form or a numerical approximation within the density function.
Still, it would be nice to add the cases that have a closed form.
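As a sketch of what the closed-form Normal case looks like today (variable names here are illustrative): the sum of n iid Normal(x, 1) draws is Normal(n * x, sqrt(n)), so the observed sum can be folded into the density with a Potential.

import numpy as np
import pymc as pm

n = 5
observed_sum = 2.0

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    # Closed form: the sum of n iid Normal(x, 1) draws is Normal(n * x, sqrt(n))
    pm.Potential(
        "y_sum_obs",
        pm.logp(pm.Normal.dist(mu=n * x, sigma=np.sqrt(n)), observed_sum),
    )
    trace = pm.sample()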
I traced this through Model.logp, conditional_logp, and how PyMC treats Deterministic vs Potential variables.
From what I can see, pm.math.sum only builds a symbolic sum, which the repro then registers as a Deterministic. By design, deterministics don’t contribute to the model’s log-probability, so they can’t be observed. This matches how register_rv and observedlogp are implemented.
pm.math.max works only because it’s rewritten into a MeasurableOp with a custom _logprob implementation (in pymc/logprob/order.py). Supporting sum would require a similar approach, and would likely need distribution-specific log-probability rules (for example, a closed-form case for Normal variables).
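A quick way to see that rewrite in action, assuming pm.logp goes through the same machinery used for observed values:

import pymc as pm

y = pm.Normal.dist(mu=0, sigma=1, shape=(5,))
# The rewrite in pymc/logprob/order.py derives the closed-form density
# of the max of iid variables; no equivalent rewrite exists for sum
print(pm.logp(pm.math.max(y), 2.0).eval())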
The currently supported way to “observe” a sum is to encode it via a Potential.
Please stop posting LLM answers all over the repo. They're not giving any new insight that's not already in the discussion.
Even if it's not LLM consider changing your strategy because it doesn't seem helpful.
Sorry @ricardoV94, I’m new to open-source contribution and used an LLM just to frame my understanding so that it would look professional.
Sure, but we can also do that ourselves. You need to use it (if you so choose to) in a way that helps out.
Asking clarifying questions, trying some solutions (hopefully locally) to see where understanding breaks apart, fixing things that you've tested locally. You need not even restrict yourself to working on open issues. Testing and opening new bug reports is even better.
But mostly use your tools with judgement