
Port VI to use RandomVariable

Open twiecki opened this issue 4 years ago • 14 comments

The current use of the MRG sampler in ADVI does not support JAX (https://github.com/pymc-devs/aesara/issues/322). It should be fairly easy to switch it to the new RandomVariable Op (https://github.com/pymc-devs/aesara/pull/296), which would give ADVI JAX support.
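To illustrate the difference, here is a minimal NumPy sketch (not Aesara code) of the functional RNG style that a RandomVariable node follows: the generator state is an explicit input, and an updated state is returned alongside the draw instead of being hidden shared state. The helper names `normal_rv` and `advi_sample` are hypothetical; the second one shows the reparameterized mean-field draw `mu + softplus(rho) * eps`, assuming the softplus parameterization of the standard deviation.

```python
import copy

import numpy as np


def normal_rv(rng, size):
    """Functional normal draw: return (next_rng, sample) without mutating `rng`.

    Hypothetical helper mimicking the (rng, value) outputs of a
    non-inplace RandomVariable node.
    """
    next_rng = copy.deepcopy(rng)  # leave the caller's generator untouched
    sample = next_rng.standard_normal(size)
    return next_rng, sample


def advi_sample(rng, mu, rho):
    """Reparameterized mean-field draw: mu + softplus(rho) * eps (hypothetical)."""
    next_rng, eps = normal_rv(rng, np.shape(mu))
    sigma = np.log1p(np.exp(rho))  # softplus keeps the scale positive
    return next_rng, mu + sigma * eps


rng0 = np.random.default_rng(42)
mu = np.zeros(3)
rho = np.zeros(3)

rng1, z_a = advi_sample(rng0, mu, rho)
_, z_b = advi_sample(rng0, mu, rho)  # same input state -> same draw
_, z_c = advi_sample(rng1, mu, rho)  # threaded state -> fresh draw

print(np.allclose(z_a, z_b), np.allclose(z_a, z_c))  # True False
```

The design point is that nothing outside the function's inputs and outputs changes, which is the same property JAX's key-based random number generation relies on.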

CC @ferrine

twiecki, Mar 10 '21 06:03

@brandonwillard, should I use the master branch or v4 for development?

ferrine, Mar 13 '21 19:03

We've been putting in PRs to the v4 branch.

brandonwillard, Mar 13 '21 19:03

Then I'll open a PR against the v4 branch :tada:

ferrine, Mar 13 '21 19:03

Oh, wait, the change requested here could go toward v3 or v4.

I would prioritize v4, but we need to port the VI code to v4 first and foremost, and that work is independent of this exact issue. Both could be done simultaneously, though.

brandonwillard, Mar 13 '21 19:03

@brandonwillard, running pymc3/tests/test_variational_inference.py I get errors unrelated to VI. Is that expected?

ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises0-grouping0]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises1-grouping1]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises2-grouping2]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises3-grouping3]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises4-grouping4]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises5-grouping5]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: None]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[FullRankGroup: None, MeanFieldGroup: ['one']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], FullRankGroup: ['two'], NormalizingFlowGroup: ['three']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], FullRankGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], EmpiricalGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises0-mean_field-MeanFieldGroup-kw0]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises1-mf-MeanFieldGroup-kw1]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises2-full_rank-FullRankGroup-kw2]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises3-fr-FullRankGroup-kw3]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises4-FR-FullRankGroup-kw4]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises5-loc-NormalizingFlowGroup-kw5]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises6-scale-NormalizingFlowGroup-kw6]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises7-hh-NormalizingFlowGroup-kw7]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises8-planar-NormalizingFlowGroup-kw8]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises9-radial-NormalizingFlowGroup-kw9]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises10-scale-loc-NormalizingFlowGroup-kw10]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises11-empirical-EmpiricalGroup-kw11]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises12-empirical-EmpiricalGroup-kw12]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises0-params0-MeanFieldGroup-kw0-None]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises1-params1-FullRankGroup-kw1-None]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises2-params2-NormalizingFlowGroup-kw2-loc]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises3-params3-NormalizingFlowGroup-kw3-scale]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises4-params4-NormalizingFlowGroup-kw4-hh]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises5-params5-NormalizingFlowGroup-kw5-planar]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises6-params6-NormalizingFlowGroup-kw6-radial]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises7-params7-NormalizingFlowGroup-kw7-scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises8-params8-EmpiricalGroup-kw8-None]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[MeanFieldGroup-MeanField-kw0]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[FullRankGroup-FullRank-kw1]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[EmpiricalGroup-Empirical-kw2]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw3]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw4]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw5]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[NFVI=scale-loc-mini]
ERROR pymc3/tests/test_variational_inference.py::test_profile[NFVI=scale-loc-mini]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[NFVI=scale-loc-full]
ERROR pymc3/tests/test_variational_inference.py::test_profile[NFVI=scale-loc-full]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ADVI-full] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ADVI-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ADVI-mini] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ADVI-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[ADVI] - Deprecatio...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[ADVI] - Ty...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[ADVI]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[FullRankADVI-full]
ERROR pymc3/tests/test_variational_inference.py::test_profile[FullRankADVI-full]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[FullRankADVI-mini]
ERROR pymc3/tests/test_variational_inference.py::test_profile[FullRankADVI-mini]
ERROR pymc3/tests/test_variational_inference.py::test_aevb[FullRankADVI] - De...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[FullRankADVI]
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[FullRankADVI]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[SVGD-full] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[SVGD-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[SVGD-mini] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[SVGD-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[SVGD] - Deprecatio...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[SVGD] - Ty...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[SVGD]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ASVGD-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ASVGD-full] - A...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ASVGD-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ASVGD-mini] - A...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[ASVGD] - Deprecati...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[ASVGD] - T...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[ASVGD]
ERROR pymc3/tests/test_variational_inference.py::test_aevb[NFVI=scale-loc] - ...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[NFVI=scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[NFVI=scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: None]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[FullRankGroup: None, MeanFieldGroup: ['one']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], FullRankGroup: ['two'], NormalizingFlowGroup: ['three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], FullRankGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], EmpiricalGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_single_group - D...
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]

ferrine, Mar 14 '21 07:03

I have Aesara v2.0.2 installed.

ferrine, Mar 14 '21 07:03

Using Aesara master gives the same result.

ferrine, Mar 14 '21 07:03

The errors look like this:

self = <pymc3.step_methods.metropolis.Metropolis object at 0x7f44b6c0efd0>, vars = [x], S = None, proposal_dist = None, scaling = 1.0, tune = True, tune_interval = 100
model = <pymc3.model.Model object at 0x7f44b64c5430>, mode = None, kwargs = {}

    def __init__(
        self,
        vars=None,
        S=None,
        proposal_dist=None,
        scaling=1.0,
        tune=True,
        tune_interval=100,
        model=None,
        mode=None,
        **kwargs
    ):
        """Create an instance of a Metropolis stepper

        Parameters
        ----------
        vars: list
            List of variables for sampler
        S: standard deviation or covariance matrix
            Some measure of variance to parameterize proposal distribution
        proposal_dist: function
            Function that returns zero-mean deviates when parameterized with
            S (and n). Defaults to normal.
        scaling: scalar or array
            Initial scale factor for proposal. Defaults to 1.
        tune: bool
            Flag for tuning. Defaults to True.
        tune_interval: int
            The frequency of tuning. Defaults to 100 iterations.
        model: PyMC Model
            Optional model for sampling step. Defaults to None (taken from context).
        mode: string or `Mode` instance.
            compilation mode passed to Aesara functions
        """

        model = pm.modelcontext(model)

        if vars is None:
            vars = model.vars
        vars = pm.inputvars(vars)

        if S is None:
            # XXX: This needs to be refactored
            S = None  # np.ones(sum(v.dsize for v in vars))

        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
>       elif S.ndim == 1:
E       AttributeError: 'NoneType' object has no attribute 'ndim'
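The failure comes from S being left as None and then queried for .ndim. The commented-out line suggests the intended default was a vector of ones sized to the flattened free variables. A minimal sketch of that fallback, using a hypothetical `default_proposal` helper that takes the per-variable sizes directly (an illustration, not the actual PyMC code):

```python
import numpy as np


def default_proposal(var_sizes, S=None):
    """Pick a proposal kind from the rank of S (hypothetical helper).

    var_sizes: flattened size of each sampled variable.
    """
    if S is None:
        # Unit scale per flattened dimension, as the commented-out line intended
        S = np.ones(sum(var_sizes))
    S = np.asarray(S)
    if S.ndim == 1:
        return "normal", S  # diagonal (per-dimension) scales
    if S.ndim == 2:
        return "multivariate_normal", S  # full covariance matrix
    raise ValueError(f"Invalid rank for variance: {S.ndim}")


kind, S = default_proposal([1, 3])
print(kind, S.shape)  # normal (4,)
```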

ferrine, Mar 14 '21 07:03

Yes, you should expect all sorts of errors.

brandonwillard, Mar 14 '21 14:03

Does any of this change with the Numba backend?

fonnesbeck, Jul 02 '21 14:07

> Does any of this change with the Numba backend?

It shouldn't; in general, if we use Aesara, its backends should handle everything without requiring backend-specific logic at this level.

brandonwillard, Jul 28 '21 17:07

It does not work with PyMC 4:

import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    pm.fit()

WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 5, in <cell line: 3>
    pm.fit()
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/inference.py", line 744, in fit
    return inference.fit(n, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/inference.py", line 138, in fit
    step_func = self.objective.step_function(score=score, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 355, in step_function
    updates = self.updates(
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 244, in updates
    self.add_obj_updates(
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 289, in add_obj_updates
    obj_target = self(
  File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 407, in __call__
    a = self.approx.set_size_and_deterministic(a, nmc, 0, kwargs.get("more_replacements"))
  File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 1359, in set_size_and_deterministic
    flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 1333, in make_size_and_deterministic_replacements
    flat2rand.update(g.make_size_and_deterministic_replacements(s, d, more_replacements))
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 1067, in make_size_and_deterministic_replacements
    initial = self._new_initial(s, d, more_replacements)
  File "/usr/local/lib/python3.10/site-packages/pymc/variational/opvi.py", line 978, in _new_initial
    return getattr(self._rng, dist_name)(size=shape)
  File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 1184, in normal
    uniform = self.uniform(
  File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 914, in uniform
    rstates = self.get_substream_rstates(nstreams, dtype)
  File "/usr/local/lib/python3.10/site-packages/aesara/configparser.py", line 47, in res
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 818, in get_substream_rstates
    multMatVect(rval[0], A1p72, M1, A2p72, M2)
  File "/usr/local/lib/python3.10/site-packages/aesara/sandbox/rng_mrg.py", line 66, in multMatVect
    multMatVect.dot_modulo = function(
  File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/__init__.py", line 317, in function
    fn = pfunc(
  File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/pfunc.py", line 374, in pfunc
    return orig_function(
  File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1763, in orig_function
    fn = m.create(defaults)
  File "/usr/local/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1656, in create
    _fn, _i, _o = self.linker.make_thunk(
  File "/usr/local/lib/python3.10/site-packages/aesara/link/basic.py", line 254, in make_thunk
    return self.make_all(
  File "/usr/local/lib/python3.10/site-packages/aesara/link/basic.py", line 697, in make_all
    thunks, nodes, jit_fn = self.create_jitable_thunk(
  File "/usr/local/lib/python3.10/site-packages/aesara/link/basic.py", line 646, in create_jitable_thunk
    converted_fgraph = self.fgraph_convert(
  File "/usr/local/lib/python3.10/site-packages/aesara/link/jax/linker.py", line 13, in fgraph_convert
    return jax_funcify(fgraph, **kwargs)
  File "/usr/local/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 670, in jax_funcify_FunctionGraph
    return fgraph_to_python(
  File "/usr/local/lib/python3.10/site-packages/aesara/link/utils.py", line 741, in fgraph_to_python
    compiled_func = op_conversion_fn(
  File "/usr/local/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 143, in jax_funcify
    raise NotImplementedError(f"No JAX conversion for the given Op: {op}")
NotImplementedError: No JAX conversion for the given Op: DotModulo

bukson, Aug 26 '22 14:08
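The last frames of that traceback show the mechanism behind the failure: jax_funcify is dispatched on the Op's class through functools.singledispatch, and the generic fallback raises NotImplementedError when no conversion is registered. DotModulo, an Op used internally by the MRG sampler, has none, which is why this issue proposes moving VI to RandomVariable. A toy sketch of that dispatch pattern, where `Op`, `Add`, `DotModulo` and `funcify` are hypothetical stand-ins rather than Aesara's classes:

```python
from functools import singledispatch


class Op:
    """Hypothetical base class standing in for an Aesara Op."""


class Add(Op):
    pass


class DotModulo(Op):
    """Stand-in for the MRG helper Op that has no JAX conversion."""


@singledispatch
def funcify(op):
    # Fallback used when no conversion is registered for type(op)
    raise NotImplementedError(f"No conversion for the given Op: {type(op).__name__}")


@funcify.register(Add)
def _(op):
    # Return a plain Python callable implementing the Op
    return lambda x, y: x + y


print(funcify(Add())(2, 3))  # 5
try:
    funcify(DotModulo())
except NotImplementedError as err:
    print(err)  # No conversion for the given Op: DotModulo
```

Registering a converter per Op class is the only backend-specific step in this pattern; any graph built solely from Ops with registered conversions compiles without further changes.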

@bukson, please open a thread on our Discourse forum: https://discourse.pymc.io/

ricardoV94, Aug 26 '22 14:08

No, this is a legitimate issue; we should port VI to use RandomVariable: https://github.com/aesara-devs/aesara/issues/322#issuecomment-1247731089. CC @ferrine

twiecki, Sep 15 '22 08:09