pymc-examples

errors in `examples/time_series/bayesian_var_model.ipynb`

Open · ikuzmin404 opened this issue · 0 comments


Notebook URL: https://github.com/pymc-devs/pymc-examples/tree/main/examples/time_series/bayesian_var_model.ipynb

Issue description

Error in imports: replace `from pymc.sampling_jax import sample_blackjax_nuts` with `from pymc.sampling.jax import sample_blackjax_nuts` (the JAX samplers now live in the `pymc.sampling.jax` module).
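If the notebook has to keep working on both older and newer PyMC releases, a version-tolerant import is one option. The sketch below uses a hypothetical helper, `import_first`, built only on the standard library `importlib`. It tries each module path in order and returns the requested attribute from the first module that imports successfully.

```python
import importlib
from typing import Any, Sequence


def import_first(module_paths: Sequence[str], attr: str) -> Any:
    """Hypothetical helper: return `attr` from the first importable module in `module_paths`.

    List the new location first and older locations after it.
    """
    errors = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError as exc:  # module does not exist under this name
            errors.append(f"{path}: {exc}")
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
        errors.append(f"{path}: no attribute {attr!r}")
    raise ImportError(f"could not import {attr!r}; tried: " + "; ".join(errors))


# Usage in the notebook (new path first, old path as a fallback):
# sample_blackjax_nuts = import_first(
#     ["pymc.sampling.jax", "pymc.sampling_jax"], "sample_blackjax_nuts"
# )
```

A plain `from pymc.sampling.jax import sample_blackjax_nuts` is simpler if the notebook only targets current PyMC.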

Error when creating `betaX` in `make_model` and `make_hierarchical_model`:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[44], line 1
----> 1 make_model(n_lags, n_eqs, df, priors)

Cell In[43], line 44, in make_model(n_lags, n_eqs, df, priors, mv_norm, prior_checks)
     41 data_obs = pm.Data("data_obs", df.values[n_lags:], dims=["time", "equations"])
     43 betaX = calc_ar_step(lag_coefs, n_eqs, n_lags, df)
---> 44 betaX = pm.Deterministic(
     45     "betaX",
     46     betaX,
     47     dims=[
     48         "time",
     49     ],
     50 )
     51 mean = alpha + betaX
     53 if mv_norm:

File c:\Users\Ivan\anaconda3\envs\pymc_env\Lib\site-packages\pymc\model\core.py:2254, in Deterministic(name, var, model, dims)
   2252 var = var.copy(model.name_for(name))
   2253 model.deterministics.append(var)
-> 2254 model.add_named_variable(var, dims)
   2256 from pymc.printing import str_for_potential_or_deterministic
   2258 var.str_repr = types.MethodType(
   2259     functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
   2260 )

File c:\Users\Ivan\anaconda3\envs\pymc_env\Lib\site-packages\pymc\model\core.py:1472, in Model.add_named_variable(self, var, dims)
   1470     # This check implicitly states that only vars with .ndim attribute can have dims
   1471     if var.ndim != len(dims):
-> 1472         raise ValueError(
   1473             f"{var} has {var.ndim} dims but {len(dims)} dim labels were provided."
   1474         )
   1475     self.named_vars_to_dims[var.name] = dims
   1477 self.named_vars[var.name] = var

ValueError: betaX has 2 dims but 1 dim labels were provided.

Proposed solution

`calc_ar_step` returns a two-dimensional array (time steps by equations), so `betaX` needs two dim labels. Adding the missing `"equations"` dimension solves the problem:

betaX = pm.Deterministic(
    "betaX",
    betaX,
    dims=["time", "equations"],
)
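The traceback shows the rule being enforced: `Model.add_named_variable` raises when `var.ndim != len(dims)`. The NumPy sketch below illustrates why the VAR's AR term is two-dimensional. `ar_step_numpy` is an assumed, simplified stand-in for the notebook's `calc_ar_step`, not its actual code, and the coefficient layout `(n_eqs, n_lags, n_eqs)` is also an assumption. `check_dims` is a hypothetical helper that reproduces the failing condition.

```python
import numpy as np


def ar_step_numpy(lag_coefs: np.ndarray, y: np.ndarray, n_lags: int) -> np.ndarray:
    """Assumed NumPy analogue of the AR step (illustrative, not the notebook's code).

    lag_coefs[j, i, k] is the coefficient of variable k at lag i + 1 in equation j,
    so lag_coefs has shape (n_eqs, n_lags, n_eqs); y has shape (T, n_eqs).
    Returns shape (T - n_lags, n_eqs): one row per "time", one column per "equations".
    """
    T, n_eqs = y.shape
    out = np.zeros((T - n_lags, n_eqs))
    for j in range(n_eqs):
        for i in range(n_lags):
            # Row r of `lagged` is y at time (n_lags + r) - (i + 1)
            lagged = y[n_lags - (i + 1) : T - (i + 1)]
            out[:, j] += lagged @ lag_coefs[j, i]
    return out


def check_dims(arr: np.ndarray, dims: list) -> None:
    """Hypothetical helper mirroring the ndim check in Model.add_named_variable."""
    if arr.ndim != len(dims):
        raise ValueError(f"array has {arr.ndim} dims but {len(dims)} dim labels were provided.")


rng = np.random.default_rng(0)
y = rng.normal(size=(10, 3))            # 10 time steps, 3 equations
lag_coefs = rng.normal(size=(3, 2, 3))  # n_eqs=3, n_lags=2
betaX = ar_step_numpy(lag_coefs, y, n_lags=2)

check_dims(betaX, ["time", "equations"])  # passes: betaX is 2-D
# check_dims(betaX, ["time"]) raises the same kind of ValueError as the notebook
```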

Another issue

In `make_hierarchical_model`, this line fails: `idata.extend(sample_blackjax_nuts(2000, random_seed=120))`. The same error occurs with `sample_numpyro_nuts`. As far as I can tell, the error is specific to Windows (see here).

The first error is `RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32`. It can be fixed as described in this issue.
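The int64/int32 mismatch is consistent with NumPy's platform-dependent default integer. Before NumPy 2.0, the integer dtype NumPy inferred for plain Python ints on Windows was 32-bit, while on 64-bit Linux and macOS it was 64-bit. A quick check is sketched below, together with an explicit int64 cast. The cast is an assumed, generic way to avoid relying on the default; it is not the fix from the linked issue. The names `default_int_bits` and `as_int64` are hypothetical.

```python
import numpy as np


def default_int_bits() -> int:
    """Bit width of the integer dtype NumPy infers for plain Python ints."""
    return np.array([1]).dtype.itemsize * 8


def as_int64(values) -> np.ndarray:
    """Hypothetical workaround (an assumption): request int64 explicitly
    instead of relying on the platform default integer."""
    return np.asarray(values, dtype=np.int64)


# Typically 32 on Windows with NumPy < 2.0; 64 on 64-bit Linux/macOS
# and on 64-bit Windows with NumPy >= 2.0
print(np.__version__, default_int_bits())
```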

After that, another error appears:

TypeError: true_fun and false_fun output must have identical types, got
Proposal(state=IntegratorState(position=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', ..., 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])'], momentum=['ShapedArray(float64[])', ..., 'ShapedArray(float64[6])'], logdensity='ShapedArray(float64[])', logdensity_grad=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', ..., 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])']), energy='ShapedArray(float64[])', weight='ShapedArray(float64[])', sum_log_p_accept='ShapedArray(float64[])').

I have no idea how to solve this one.
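The `TypeError` is raised by JAX's `lax.cond`, which requires both branches to return outputs with identical structure, shapes and dtypes. In the message, every `position` and `logdensity_grad` entry is `float64` in one branch and `float32` in the other, which suggests a float32 value is being mixed into otherwise float64 sampler state. The sketch below is a NumPy-only debugging aid, not a fix and not JAX code. It shows how two lists of arrays that should match can be compared leaf by leaf, in the same "DIFFERENT ... vs. ..." style. `dtype_mismatches` and `cast_all` are hypothetical helpers.

```python
from typing import List, Sequence, Tuple

import numpy as np


def dtype_mismatches(a: Sequence[np.ndarray], b: Sequence[np.ndarray]) -> List[Tuple[int, str]]:
    """Hypothetical helper: list leaves whose dtype or shape differ between a and b."""
    if len(a) != len(b):
        raise ValueError(f"different number of leaves: {len(a)} vs {len(b)}")
    report = []
    for idx, (x, y) in enumerate(zip(a, b)):
        x, y = np.asarray(x), np.asarray(y)
        if x.dtype != y.dtype or x.shape != y.shape:
            report.append((idx, f"DIFFERENT {x.dtype}{list(x.shape)} vs. {y.dtype}{list(y.shape)}"))
    return report


def cast_all(arrays: Sequence[np.ndarray], dtype=np.float64) -> List[np.ndarray]:
    """Hypothetical helper: cast every leaf to one dtype so both sides agree."""
    return [np.asarray(x, dtype=dtype) for x in arrays]


# Toy leaves with the same shapes as some of those in the error message
state = [np.zeros(()), np.zeros(6), np.zeros((3, 2, 3))]  # all float64
proposal = [np.zeros((), np.float32), np.zeros(6), np.zeros((3, 2, 3), np.float32)]
print(dtype_mismatches(state, proposal))
```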

Possible solution

A workaround (assuming the problem does not occur on Linux) is to fall back to plain `pm.sample` instead of `sample_blackjax_nuts` when the code runs on Windows. This can be detected with, for example, `if os.name == 'nt'`.

This behavior was also fixed in NumPy 2.0 (link to release notes), so this workaround may only be needed temporarily.
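The platform check described above can be sketched as follows. This variant is a swapped-in alternative, not the notebook's code. Instead of switching between `pm.sample` and `sample_blackjax_nuts` by hand, it passes a backend name to the `nuts_sampler` argument that `pm.sample` accepts in recent PyMC 5 releases (`"pymc"`, `"nutpie"`, `"numpyro"` or `"blackjax"`). `pick_nuts_sampler` is a hypothetical helper, and it assumes the JAX samplers fail only on Windows.

```python
import os


def pick_nuts_sampler(os_name: str = os.name) -> str:
    """Hypothetical helper: choose a NUTS backend name for this platform.

    Assumption: the JAX-based samplers fail only on Windows (os.name == "nt"),
    so use PyMC's default NUTS there and BlackJAX everywhere else.
    """
    return "pymc" if os_name == "nt" else "blackjax"


# Possible usage inside make_hierarchical_model (requires a PyMC version with `nuts_sampler`):
# idata.extend(pm.sample(2000, random_seed=120, nuts_sampler=pick_nuts_sampler()))
```

Taking `os_name` as a parameter keeps the helper easy to test on any machine. Once NumPy 2.0 resolves the problem, the helper could simply return `"blackjax"` everywhere.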

ikuzmin404 · Jan 23 '25 16:01