Port VI to use RandomVariable
The current usage of the mrng sampler in ADVI does not support JAX (https://github.com/pymc-devs/aesara/issues/322). It should be fairly easy to make it use the new RandomVariable Op (https://github.com/pymc-devs/aesara/pull/296) instead, which would give ADVI JAX support.
CC @ferrine
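To illustrate the idea behind the request, here is a minimal sketch in plain Python. It is not Aesara or PyMC code, and the names `draw_standard_normal` and `reparameterized_sample` are hypothetical. It assumes the relevant property of a RandomVariable-style design is that the random generator is an explicit input to the sampling step, so the noise source is not hidden inside a sampler object and a backend can substitute its own implementation.

```python
import random


def draw_standard_normal(rng, size):
    """Draw `size` standard-normal values from an explicitly passed generator."""
    return [rng.gauss(0.0, 1.0) for _ in range(size)]


def reparameterized_sample(rng, mu, sigma):
    """Mean-field style draw: z = mu + sigma * eps, with eps ~ N(0, 1)."""
    eps = draw_standard_normal(rng, len(mu))
    return [m + s * e for m, s, e in zip(mu, sigma, eps)]


# Two generators with the same seed give the same draw, because all
# randomness flows through the `rng` argument.
rng_a = random.Random(42)
rng_b = random.Random(42)
sample_a = reparameterized_sample(rng_a, [0.0, 1.0], [1.0, 0.5])
sample_b = reparameterized_sample(rng_b, [0.0, 1.0], [1.0, 0.5])
```

Because the generator is a plain argument, the sampling function itself holds no state, which is the part that would need to carry over to the VI code.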
@brandonwillard should I use the master branch or V4 for development?
We've been putting in PRs to the v4 branch.
So yes, I'll open a PR to the v4 branch :tada:
Oh, wait, the change requested here could go toward v3 or v4.
I would prioritize v4, but we need to port the VI code to v4 first and foremost, and that work is independent of this exact issue. Both could be done simultaneously, though.
@brandonwillard, when running pymc3/tests/test_variational_inference.py I get errors unrelated to VI. Is that expected?
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises0-grouping0]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises1-grouping1]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises2-grouping2]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises3-grouping3]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises4-grouping4]
ERROR pymc3/tests/test_variational_inference.py::test_init_groups[raises5-grouping5]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_sample_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_replacements_in_sample_node_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_1_sample_1_var[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_2_sample_2_var[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_mini_sample_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: None]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[FullRankGroup: None, MeanFieldGroup: ['one']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], FullRankGroup: ['two'], NormalizingFlowGroup: ['three']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], FullRankGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_logq_globals[MeanFieldGroup: ['one'], EmpiricalGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises0-mean_field-MeanFieldGroup-kw0]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises1-mf-MeanFieldGroup-kw1]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises2-full_rank-FullRankGroup-kw2]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises3-fr-FullRankGroup-kw3]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises4-FR-FullRankGroup-kw4]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises5-loc-NormalizingFlowGroup-kw5]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises6-scale-NormalizingFlowGroup-kw6]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises7-hh-NormalizingFlowGroup-kw7]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises8-planar-NormalizingFlowGroup-kw8]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises9-radial-NormalizingFlowGroup-kw9]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises10-scale-loc-NormalizingFlowGroup-kw10]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises11-empirical-EmpiricalGroup-kw11]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_vfam[raises12-empirical-EmpiricalGroup-kw12]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises0-params0-MeanFieldGroup-kw0-None]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises1-params1-FullRankGroup-kw1-None]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises2-params2-NormalizingFlowGroup-kw2-loc]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises3-params3-NormalizingFlowGroup-kw3-scale]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises4-params4-NormalizingFlowGroup-kw4-hh]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises5-params5-NormalizingFlowGroup-kw5-planar]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises6-params6-NormalizingFlowGroup-kw6-radial]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises7-params7-NormalizingFlowGroup-kw7-scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_group_api_params[raises8-params8-EmpiricalGroup-kw8-None]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[MeanFieldGroup-MeanField-kw0]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[FullRankGroup-FullRank-kw1]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[EmpiricalGroup-Empirical-kw2]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw3]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw4]
ERROR pymc3/tests/test_variational_inference.py::test_single_group_shortcuts[NormalizingFlowGroup-NormalizingFlow-kw5]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[NFVI=scale-loc-mini]
ERROR pymc3/tests/test_variational_inference.py::test_profile[NFVI=scale-loc-mini]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[NFVI=scale-loc-full]
ERROR pymc3/tests/test_variational_inference.py::test_profile[NFVI=scale-loc-full]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ADVI-full] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ADVI-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ADVI-mini] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ADVI-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[ADVI] - Deprecatio...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[ADVI] - Ty...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[ADVI]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[FullRankADVI-full]
ERROR pymc3/tests/test_variational_inference.py::test_profile[FullRankADVI-full]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[FullRankADVI-mini]
ERROR pymc3/tests/test_variational_inference.py::test_profile[FullRankADVI-mini]
ERROR pymc3/tests/test_variational_inference.py::test_aevb[FullRankADVI] - De...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[FullRankADVI]
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[FullRankADVI]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[SVGD-full] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[SVGD-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[SVGD-mini] - Att...
ERROR pymc3/tests/test_variational_inference.py::test_profile[SVGD-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[SVGD] - Deprecatio...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[SVGD] - Ty...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[SVGD]
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ASVGD-full] - At...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ASVGD-full] - A...
ERROR pymc3/tests/test_variational_inference.py::test_fit_oo[ASVGD-mini] - At...
ERROR pymc3/tests/test_variational_inference.py::test_profile[ASVGD-mini] - A...
ERROR pymc3/tests/test_variational_inference.py::test_aevb[ASVGD] - Deprecati...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[ASVGD] - T...
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[ASVGD]
ERROR pymc3/tests/test_variational_inference.py::test_aevb[NFVI=scale-loc] - ...
ERROR pymc3/tests/test_variational_inference.py::test_replacements[NFVI=scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_sample_replacements[NFVI=scale-loc]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_rowwise_approx[NormalizingFlowGroup: {'flow': 'radial-loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: None]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[FullRankGroup: None, MeanFieldGroup: ['one']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], FullRankGroup: ['two'], NormalizingFlowGroup: ['three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], FullRankGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx[MeanFieldGroup: ['one'], EmpiricalGroup: ['two', 'three']]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_single_group - D...
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[MeanFieldGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[FullRankGroup: {}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'scale'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'loc'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'hh'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'planar'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'radial'}]
ERROR pymc3/tests/test_variational_inference.py::test_pickle_approx_aevb[NormalizingFlowGroup: {'flow': 'radial-loc'}]
I have aesara v2.0.2 installed there.
Using aesara master gives the same result.
I get errors like this:
self = <pymc3.step_methods.metropolis.Metropolis object at 0x7f44b6c0efd0>, vars = [x], S = None, proposal_dist = None, scaling = 1.0, tune = True, tune_interval = 100
model = <pymc3.model.Model object at 0x7f44b64c5430>, mode = None, kwargs = {}
    def __init__(
        self,
        vars=None,
        S=None,
        proposal_dist=None,
        scaling=1.0,
        tune=True,
        tune_interval=100,
        model=None,
        mode=None,
        **kwargs
    ):
        """Create an instance of a Metropolis stepper
        Parameters
        ----------
        vars: list
            List of variables for sampler
        S: standard deviation or covariance matrix
            Some measure of variance to parameterize proposal distribution
        proposal_dist: function
            Function that returns zero-mean deviates when parameterized with
            S (and n). Defaults to normal.
        scaling: scalar or array
            Initial scale factor for proposal. Defaults to 1.
        tune: bool
            Flag for tuning. Defaults to True.
        tune_interval: int
            The frequency of tuning. Defaults to 100 iterations.
        model: PyMC Model
            Optional model for sampling step. Defaults to None (taken from context).
        mode: string or `Mode` instance.
            compilation mode passed to Aesara functions
        """
        model = pm.modelcontext(model)
        if vars is None:
            vars = model.vars
        vars = pm.inputvars(vars)
        if S is None:
            # XXX: This needs to be refactored
            S = None  # np.ones(sum(v.dsize for v in vars))
        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
>       elif S.ndim == 1:
E       AttributeError: 'NoneType' object has no attribute 'ndim'
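The failure above comes from `S` staying `None` after the placeholder branch, so the later `S.ndim` lookup has nothing to work on. A minimal sketch of the failure and of one possible guard, in plain Python: `default_proposal_scale` is a hypothetical helper, and the list of ones only mirrors the commented-out `np.ones(...)` default visible in the traceback. It is not the project's actual fix.

```python
def default_proposal_scale(S, n_dims):
    """Return S unchanged, or a vector of ones when no scale was given."""
    if S is None:
        # Assumed default: unit scale for each of the n_dims dimensions.
        return [1.0] * n_dims
    return S


# Reproduce the error from the traceback: attribute access on None.
try:
    None.ndim
    failure = None
except AttributeError as exc:
    failure = str(exc)

# With the guard, a missing S becomes a usable default before it is inspected.
scale = default_proposal_scale(None, 3)
```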
Yes, you should expect all sorts of errors.
Does any of this change with the Numba backend?
It shouldn't; in general, if we use Aesara, its backends should handle everything without requiring backend-specific logic at this level.
This does not work with PyMC 4:
import pymc as pm
with pm.Model():
    x = pm.Normal("x")
    pm.fit()
WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "Op: {op}")
NotImplementedError: No JAX conversion for the given Op: DotModulo
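An error of this shape is what a type-based conversion registry produces when it meets an operation nobody has registered. The sketch below shows that mechanism with `functools.singledispatch`. The classes `Op`, `Dot`, and `DotModulo` and the function `to_backend` are stand-ins written for this example, and it assumes the backend looks up conversions by Op type. It is not Aesara's code.

```python
from functools import singledispatch


class Op:
    """Stand-in base class for a graph operation (hypothetical)."""


class Dot(Op):
    """Stand-in for an operation that has a registered conversion."""


class DotModulo(Op):
    """Stand-in for an operation with no registered conversion."""

    def __repr__(self):
        return "DotModulo"


@singledispatch
def to_backend(op):
    # Fallback used for any type without a registered conversion.
    raise NotImplementedError(f"No conversion for the given Op: {op}")


@to_backend.register(Dot)
def _(op):
    # Registered conversion: return a plain-Python dot product.
    return lambda a, b: sum(x * y for x, y in zip(a, b))


dot_fn = to_backend(Dot())

try:
    to_backend(DotModulo())
    error_message = None
except NotImplementedError as exc:
    error_message = str(exc)
```

Under this assumption there are two ways out: register a conversion for the missing operation, or stop emitting that operation in the graph. The issue proposes the second, by moving VI onto RandomVariable.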
@bukson please open an issue on our discourse at: https://discourse.pymc.io/
No, this is a legitimate issue; we should port VI to use RandomVariable: https://github.com/aesara-devs/aesara/issues/322#issuecomment-1247731089. CC @ferrine