pymc icon indicating copy to clipboard operation
pymc copied to clipboard

BUG: Setting None coord is problematic

Open ferrine opened this issue 2 years ago • 9 comments

Describe the issue:

When a coord is set to None (e.g. `dims=(None, "d2")`), something breaks in pymc 5.0.2.

Reproducible code example:

# This is good
import pymc as pm
import numpy as np
with pm.Model(coords=dict(d1=range(2), d2=range(6))) as model:
    pm.Data("a", np.random.randn(2, 6), dims=("d1", "d2"), mutable=True)
    pm.Normal("b", 10)
    t = pm.sample(1, tune=1)

# this is broken
import pymc as pm
import numpy as np
with pm.Model(coords=dict(d1=range(2), d2=range(6))) as model:
    pm.Data("a", np.random.randn(2, 6), dims=(None, "d2"), mutable=True)
    pm.Normal("b", 10)
    t = pm.sample(1, tune=1)

Error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[14], line 6
      4 pm.Data("a", np.random.randn(2, 6), dims=(None, "d2"), mutable=True)
      5 pm.Normal("b", 10)
----> 6 t = pm.sample(1, tune=1)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/sampling/mcmc.py:612, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, keep_warning_stat, idata_kwargs, mp_ctx, **kwargs)
    610 if idata_kwargs:
    611     ikwargs.update(idata_kwargs)
--> 612 idata = pm.to_inference_data(mtrace, **ikwargs)
    614 if compute_convergence_checks:
    615     if draws - tune < 100:

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/backends/arviz.py:485, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, sample_dims, model, save_warmup, include_transformed)
    482 if isinstance(trace, InferenceData):
    483     return trace
--> 485 return InferenceDataConverter(
    486     trace=trace,
    487     prior=prior,
    488     posterior_predictive=posterior_predictive,
    489     log_likelihood=log_likelihood,
    490     coords=coords,
    491     dims=dims,
    492     sample_dims=sample_dims,
    493     model=model,
    494     save_warmup=save_warmup,
    495     include_transformed=include_transformed,
    496 ).to_inference_data()

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/backends/arviz.py:414, in InferenceDataConverter.to_inference_data(self)
    412     id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
    413 else:
--> 414     id_dict["constant_data"] = self.constant_data_to_xarray()
    415 idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
    416 if self.log_likelihood:

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:65, in requires.__call__.<locals>.wrapped(cls)
     63     if all((getattr(cls, prop_i) is None for prop_i in prop)):
     64         return None
---> 65 return func(cls)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/backends/arviz.py:388, in InferenceDataConverter.constant_data_to_xarray(self)
    385 if not constant_data:
    386     return None
--> 388 return dict_to_dataset(
    389     constant_data,
    390     library=pymc,
    391     coords=self.coords,
    392     dims=self.dims,
    393     default_dims=[],
    394 )

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:306, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    303 if dims is None:
    304     dims = {}
--> 306 data_vars = {
    307     key: numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316     for key, values in data.items()
    317 }
    318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:307, in <dictcomp>(.0)
    303 if dims is None:
    304     dims = {}
    306 data_vars = {
--> 307     key: numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316     for key, values in data.items()
    317 }
    318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:255, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    253 # filter coords based on the dims
    254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 255 return xr.DataArray(ary, coords=coords, dims=dims)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/xarray/core/dataarray.py:419, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
    417 data = _check_data_shape(data, coords, dims)
    418 data = as_compatible_data(data)
--> 419 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
    420 variable = Variable(dims, data, attrs, fastpath=True)
    421 indexes, coords = _create_indexes_from_coords(coords)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/xarray/core/dataarray.py:164, in _infer_coords_and_dims(shape, coords, dims)
    162 for d, s in zip(v.dims, v.shape):
    163     if s != sizes[d]:
--> 164         raise ValueError(
    165             f"conflicting sizes for dimension {d!r}: "
    166             f"length {sizes[d]} on the data but length {s} on "
    167             f"coordinate {k!r}"
    168         )
    170 if k in sizes and v.shape != (sizes[k],):
    171     raise ValueError(
    172         f"coordinate {k!r} is a DataArray dimension, but "
    173         f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} "
    174         "matching the dimension size"
    175     )

ValueError: conflicting sizes for dimension 'd2': length 2 on the data but length 6 on coordinate 'd2'

PyMC version information:

pymc 5.0.2

Context for the issue:

No response

ferrine avatar Jan 27 '23 13:01 ferrine

None dims were disabled in https://github.com/pymc-devs/pymc/pull/6470 exactly because of this. They only worked when they were in the rightmost position.

ricardoV94 avatar Jan 27 '23 13:01 ricardoV94

Seems like the error is misleading, then. It should prevent me from using a None coord in the first place. However, I think it is a valid case. Why not just replace the None coord with an anonymous coord?

ferrine avatar Jan 27 '23 13:01 ferrine

If you try it on main, it should raise a TypeError directly.

ricardoV94 avatar Jan 27 '23 13:01 ricardoV94

Thank you!

ferrine avatar Jan 27 '23 13:01 ferrine

Ah that's for Data, not distributions. I don't know about that. In main it doesn't raise an error but I am not sure what it does.

> Why not just replace the None coord with an anonymous coord?

Someone needs to implement it and make sure it works with everything we have: Data, Deterministics, Variables. I made it raise an error for RVs in that linked PR because the infrastructure was not there yet.

ricardoV94 avatar Jan 27 '23 13:01 ricardoV94

Maybe the best approach would be to assign dummy dims (without coordinates) as soon as we register a variable in the model.

ricardoV94 avatar Jan 27 '23 13:01 ricardoV94

Is there a reason this is postponed?

ferrine avatar Jan 27 '23 13:01 ferrine

I don't see what is gained by using None; you can just give it any name. Why not use `pm.Data("a", np.random.randn(2, 6), dims=("unamed_dim", "d2"), mutable=True)`, or call the dimension "none", instead of using None and adding code to convert that into a random dimension name? In pm.Data instances, dims don't need to be present in the coords argument of the model; they will get registered automatically.

OriolAbril avatar Jan 27 '23 19:01 OriolAbril

Although I'm not a fan, None has the advantage that it can't cause a name clash, so you don't need to worry about whether a dim with the same name was already given to something else. This showed up in #7390.

ricardoV94 avatar Jun 27 '24 11:06 ricardoV94