BUG: Posterior predictive sampling with mutable coords fails when the coord length changes
Describe the issue:
Using mutable coords in a model and then calling `set_data` to replace the coordinates with values of a different length than the original causes `sample_posterior_predictive` to fail. Details below.
Reproducible code example:
import numpy as np
import pandas as pd
import pymc as pm

# Reproduction script. The two MutableData containers are declared WITHOUT
# dims=("player_names",). After pm.set_data shrinks the "player_names" coord,
# sample_posterior_predictive fails while building the InferenceData with
# "conflicting sizes for dimension 'player_names'".
# Declaring the dims on both MutableData calls avoids the error, as suggested
# in the discussion of this issue.

data = pd.read_csv(pm.get_data("efron-morris-75-data.tsv"), sep="\t")
N = len(data)
player_names = data["FirstName"] + " " + data["LastName"]

with pm.Model() as baseball_model:
    # NOTE: no dims= on these containers -- this is what triggers the failure.
    at_bats = pm.MutableData("at_bats", data["At-Bats"].to_numpy())
    n_hits = pm.MutableData("n_hits", data["Hits"].to_numpy())
    baseball_model.add_coord("player_names", player_names, mutable=True)

    phi = pm.Uniform("phi", lower=0.0, upper=1.0)
    kappa_log = pm.Exponential("kappa_log", lam=1.5)
    kappa = pm.Deterministic("kappa", pm.math.exp(kappa_log))
    theta = pm.Beta(
        "theta",
        alpha=phi * kappa,
        beta=(1.0 - phi) * kappa,
        dims="player_names",
    )
    y = pm.Binomial("y", n=at_bats, p=theta, observed=n_hits, dims="player_names")

    trace = pm.sample()

with baseball_model:
    # Replace the data and the "player_names" coord (18 players in the
    # original data, per the error message) with a single new player.
    pm.set_data(
        {
            "at_bats": np.array([4], dtype=np.int32),
            "n_hits": np.zeros(1, dtype=np.int32),
        },
        coords={"player_names": np.array(["new_guy"])},
    )
    # Raises: ValueError: conflicting sizes for dimension 'player_names':
    # length 1 on the data but length 18 on coordinate 'player_names'
    new_batter = pm.sample_posterior_predictive(trace)
Error message (captured in a notebook where the trace variable was named `idata` rather than `trace`):
{
"name": "ValueError",
"message": "conflicting sizes for dimension 'player_names': length 1 on the data but length 18 on coordinate 'player_names'",
"stack": "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[16], line 7
1 with baseball_model:
2 pm.set_data(
3 {\"at_bats\": np.array([4], dtype=np.int32),
4 \"n_hits\": np.zeros(1, dtype=np.int32)},
5 coords={'player_names': np.array([\"new_guy\"])}
6 )
----> 7 new_batter = pm.sample_posterior_predictive(idata)
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/pymc/sampling/forward.py:673, in sample_posterior_predictive(trace, model, var_names, sample_dims, random_seed, progressbar, return_inferencedata, extend_inferencedata, predictions, idata_kwargs, compile_kwargs)
671 ikwargs.setdefault(\"inplace\", True)
672 return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
--> 673 idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
675 if extend_inferencedata and idata is not None:
676 idata.extend(idata_pp)
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/pymc/backends/arviz.py:523, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, sample_dims, model, save_warmup, include_transformed)
509 if isinstance(trace, InferenceData):
510 return trace
512 return InferenceDataConverter(
513 trace=trace,
514 prior=prior,
515 posterior_predictive=posterior_predictive,
516 log_likelihood=log_likelihood,
517 coords=coords,
518 dims=dims,
519 sample_dims=sample_dims,
520 model=model,
521 save_warmup=save_warmup,
522 include_transformed=include_transformed,
--> 523 ).to_inference_data()
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/pymc/backends/arviz.py:436, in InferenceDataConverter.to_inference_data(self)
423 def to_inference_data(self):
424 \"\"\"Convert all available data to an InferenceData object.
425
426 Note that if groups can not be created (e.g., there is no `trace`, so
427 the `posterior` and `sample_stats` can not be extracted), then the InferenceData
428 will not have those groups.
429 \"\"\"
430 id_dict = {
431 \"posterior\": self.posterior_to_xarray(),
432 \"sample_stats\": self.sample_stats_to_xarray(),
433 \"posterior_predictive\": self.posterior_predictive_to_xarray(),
434 \"predictions\": self.predictions_to_xarray(),
435 **self.priors_to_xarray(),
--> 436 \"observed_data\": self.observed_data_to_xarray(),
437 }
438 if self.predictions:
439 id_dict[\"predictions_constant_data\"] = self.constant_data_to_xarray()
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/arviz/data/base.py:65, in requires.__call__.<locals>.wrapped(cls)
63 if all((getattr(cls, prop_i) is None for prop_i in prop)):
64 return None
---> 65 return func(cls)
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/arviz/data/base.py:65, in requires.__call__.<locals>.wrapped(cls)
63 if all((getattr(cls, prop_i) is None for prop_i in prop)):
64 return None
---> 65 return func(cls)
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/pymc/backends/arviz.py:390, in InferenceDataConverter.observed_data_to_xarray(self)
388 if self.predictions:
389 return None
--> 390 return dict_to_dataset(
391 self.observations,
392 library=pymc,
393 coords=self.coords,
394 dims=self.dims,
395 default_dims=[],
396 )
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/arviz/data/base.py:306, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
303 if dims is None:
304 dims = {}
--> 306 data_vars = {
307 key: numpy_to_data_array(
308 values,
309 var_name=key,
310 coords=coords,
311 dims=dims.get(key),
312 default_dims=default_dims,
313 index_origin=index_origin,
314 skip_event_dims=skip_event_dims,
315 )
316 for key, values in data.items()
317 }
318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/arviz/data/base.py:307, in <dictcomp>(.0)
303 if dims is None:
304 dims = {}
306 data_vars = {
--> 307 key: numpy_to_data_array(
308 values,
309 var_name=key,
310 coords=coords,
311 dims=dims.get(key),
312 default_dims=default_dims,
313 index_origin=index_origin,
314 skip_event_dims=skip_event_dims,
315 )
316 for key, values in data.items()
317 }
318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/arviz/data/base.py:255, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
253 # filter coords based on the dims
254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 255 return xr.DataArray(ary, coords=coords, dims=dims)
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/xarray/core/dataarray.py:445, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
443 data = _check_data_shape(data, coords, dims)
444 data = as_compatible_data(data)
--> 445 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
446 variable = Variable(dims, data, attrs, fastpath=True)
448 if not isinstance(coords, Coordinates):
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/xarray/core/dataarray.py:192, in _infer_coords_and_dims(shape, coords, dims)
189 var.dims = (dim,)
190 new_coords[dim] = var.to_index_variable()
--> 192 _check_coords_dims(shape, new_coords, dims)
194 return new_coords, dims
File ~/miniforge3/envs/pymc/lib/python3.11/site-packages/xarray/core/dataarray.py:130, in _check_coords_dims(shape, coords, dims)
128 for d, s in v.sizes.items():
129 if s != sizes[d]:
--> 130 raise ValueError(
131 f\"conflicting sizes for dimension {d!r}: \"
132 f\"length {sizes[d]} on the data but length {s} on \"
133 f\"coordinate {k!r}\"
134 )
ValueError: conflicting sizes for dimension 'player_names': length 1 on the data but length 18 on coordinate 'player_names'"
}
PyMC version information:
PyMC version 5.10.1 PyTensor version 2.18.1
Context for the issue:
No response
I think the problem is that the `MutableData` variables don't declare the dims you are trying to resize in `set_data`.
If you define them like this:
# Suggested fix: give both containers the "player_names" dim so they are tied
# to the mutable coord that pm.set_data resizes.
at_bats = pm.MutableData("at_bats", data["At-Bats"].to_numpy(), dims=("player_names",))
n_hits = pm.MutableData("n_hits", data["Hits"].to_numpy(), dims=("player_names",))
it should work, and you won't even need the line
baseball_model.add_coord("player_names", player_names, mutable=True)
Also, if you want, you can pass mutable coords directly to `pm.Model` via the `coords_mutable` kwarg.
Just encountered the same situation. It is indeed solved with what @ricardoV94 wrote above.
So this doesn't seem to be a bug, but the behavior is arguably surprising. Do you think there is a way to infer the dims under the hood, or do we need the user to be that explicit with the dims of MutableData?
> or do we need the user to be that explicit with the dims of MutableData?

This.
OK, so it's not an open bug.