
Gaussian funsor variable elimination

Open: fritzo opened this issue 3 years ago • 7 comments

Addresses https://github.com/pyro-ppl/pyro/pull/2929. See design doc.

This issue tracks changes needed to efficiently perform variable elimination in Gaussian graphical models with plates. While funsor.sum_product.sum_product() is a partial solution, we'd like to generalize to a complete solution.

Tasks

  • [x] Introduce a new Funsor ConditionalGaussian(info_vec, precision, conditional, inputs) representing the batched conditional distribution of the rightmost real input variable, conditioned on other real input variables. This could be (i) a new Funsor in addition to Gaussian, (ii) a replacement or generalization of Gaussian, or (iii) a special case of Gaussian where the input info_vec and precision are structured (requires #556). This may allow cheaper linear algebra.

    Alternatively, see #567. Temporary workaround: naively scatter the three parameters (info_vec, precision, conditional) into a dense Gaussian, which can be much more computationally expensive.

  • [ ] Handle collider variables, where a latent variable outside a plate depends on an upstream latent variable inside a plate, thereby coupling the upstream variables via moralization. Currently such problems cannot even be specified in the plated-einsum DSL. Temporary workaround: globally break all plates out of which any arrow leads (equivalent to .to_event()).

  • [ ] Handle complete bipartite graphs resulting from the RBM motif (x_i --> y_ij <-- z_j). Currently sum_product() and the TVE algorithm give up in this case with "intractable!". Temporary workaround: none known.
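
For the dense-scatter workaround in the first task, here is a minimal NumPy sketch of the simplest case. It assumes the conditional is affine-Gaussian, y | x ~ N(A x + b, inv(prec)) (an AffineNormal-like factor; this parameterization is an assumption, not the ConditionalGaussian signature), and `scatter_affine_conditional` is a hypothetical helper rather than funsor API. The dense precision is (n + m) x (n + m) even though its rank is only m, which is where the extra cost comes from.

```python
import numpy as np

def scatter_affine_conditional(A, b, prec):
    """Hypothetical helper: dense (info_vec, precision) of the factor
    log N(y; A @ x + b, inv(prec)) over the stacked input z = [x, y],
    dropping the constant -0.5 * b @ prec @ b and the normalizer."""
    m, n = A.shape                               # y has size m, x has size n
    M = np.concatenate([-A, np.eye(m)], axis=1)  # residual y - A @ x == M @ z
    precision = M.T @ prec @ M                   # (n + m) x (n + m), rank m
    info_vec = M.T @ prec @ b                    # (n + m,)
    return info_vec, precision

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
b = rng.standard_normal(2)
prec = np.array([[2.0, 0.3], [0.3, 1.0]])        # positive definite
info_vec, precision = scatter_affine_conditional(A, b, prec)

# Compare the dense information form with the conditional log density
x, y = rng.standard_normal(3), rng.standard_normal(2)
z = np.concatenate([x, y])
r = y - A @ x - b
direct = -0.5 * r @ prec @ r
dense = -0.5 * z @ precision @ z + info_vec @ z
assert np.isclose(dense - direct, 0.5 * b @ prec @ b)
```

The two forms differ only by a constant, so the dense Gaussian is a valid (if wasteful) stand-in for the conditional.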

fritzo, Sep 30 '21 17:09

@eb8680 it looks like AutoGaussian(pyrocov_model) runs out of GPU memory when constructing low-rank precision matrices via precision = sqrt @ sqrt.T. One possible solution is to use a sqrt(precision) representation in funsor's Gaussian. I guess the crux is whether we can implement a cheap Gaussian tensordot without materializing the intermediate low-rank precision matrices. @fehiepsi already worked out most of the sqrt representation in Pyro PR #2019, where ops.add becomes mere concatenation.

@fehiepsi how much effort do you think it would take for us to port your Pyro PR #2019 to funsor (where it would also be available in NumPyro 😉)?
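
A minimal NumPy sketch of why ops.add becomes concatenation in a square-root representation, assuming a factor is stored as (S, info_vec) with log density -0.5 * ||S.T @ x||^2 + info_vec @ x, i.e. precision = S @ S.T. This illustrates the idea only; it is not the data structure from Pyro PR #2019, and `add_sqrt` is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6  # dimension of the real inputs shared by both factors

# Two factors in square-root information form: precision_k = S_k @ S_k.T
S1, i1 = rng.standard_normal((n, 2)), rng.standard_normal(n)  # rank 2
S2, i2 = rng.standard_normal((n, 3)), rng.standard_normal(n)  # rank 3

def add_sqrt(S_a, i_a, S_b, i_b):
    # Adding log densities adds precisions; in sqrt form that is concatenation
    # along the rank axis, so no n x n matrix is ever formed.
    return np.concatenate([S_a, S_b], axis=1), i_a + i_b

def log_density(S, i, x):
    # Unnormalized Gaussian log density in sqrt information form
    return -0.5 * np.sum((S.T @ x) ** 2) + i @ x

S, i = add_sqrt(S1, i1, S2, i2)
x = rng.standard_normal(n)
assert S.shape == (n, 5)
assert np.allclose(S @ S.T, S1 @ S1.T + S2 @ S2.T)
assert np.isclose(log_density(S, i, x), log_density(S1, i1, x) + log_density(S2, i2, x))
```

When each factor's rank r is much smaller than n, S stays n x r instead of n x n; the rank grows with each concatenation, which is why some periodic re-triangularization is needed.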

fritzo, Oct 05 '21 18:10

Here is the optimized GFVE schedule for my pyro-cov model. It fits in main memory but runs out of GPU memory.

Contraction(ops.null, ops.add,
 frozenset(),
 (Contraction(ops.logaddexp, ops.add,
   frozenset({Variable('rate_loc_scale__BOUND_13', Real)}),
   (Gaussian(
   │ torch.tensor(...1..., dtype=torch.float32),
   │ torch.tensor(...1 x 1..., dtype=torch.float32),
   │ (('rate_loc_scale__BOUND_13', Real),)),
   │Contraction(ops.logaddexp, ops.add,
   │ frozenset({Variable('rate_scale__BOUND_14', Real)}),
   │ (Gaussian(
   │   torch.tensor(...1..., dtype=torch.float32),
   │   torch.tensor(...1 x 1..., dtype=torch.float32),
   │   (('rate_scale__BOUND_14', Real),)),
   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('coef__BOUND_12', Reals[2367])}),
   │   (Gaussian(
   │   │ torch.tensor(...2367..., dtype=torch.float32),
   │   │ torch.tensor(...2367 x 2367..., dtype=torch.float32),
   │   │ (('coef__BOUND_12', Reals[2367]),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('strain__BOUND_11', Bint[1343])}),
   │   │ (Contraction(ops.logaddexp, ops.add,
   │   │   frozenset({Variable('rate_loc__BOUND_10', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...1343 x 2369..., dtype=torch.float32),
   │   │   │ torch.tensor(...1343 x 2369 x 2369..., dtype=torch.float32),
   │   │   │ (('strain__BOUND_11', Bint[1343]),
   │   │   │  ('rate_loc__BOUND_10', Real),
   │   │   │  ('rate_loc_scale__BOUND_13', Real),
   │   │   │  ('coef__BOUND_12', Reals[2367]),)),
   │   │   │Contraction(ops.add, ops.null,
   │   │   │ frozenset({Variable('place__BOUND_4', Bint[1372])}),
   │   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   │   frozenset({Variable('rate__BOUND_3', Real)}),
   │   │   │   (Gaussian(
   │   │   │   │ torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
   │   │   │   │ torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
   │   │   │   │ (('place__BOUND_4', Bint[1372]),
   │   │   │   │  ('strain__BOUND_11', Bint[1343]),
   │   │   │   │  ('rate__BOUND_3', Real),
   │   │   │   │  ('rate_scale__BOUND_14', Real),
   │   │   │   │  ('rate_loc__BOUND_10', Real),)),)),)),)),)),)),)),)),
  Contraction(ops.null, ops.add,
   frozenset(),
   (Contraction(ops.logaddexp, ops.add,
   │ frozenset({Variable('pois_loc__BOUND_16', Real)}),
   │ (Gaussian(
   │   torch.tensor(...1..., dtype=torch.float32),
   │   torch.tensor(...1 x 1..., dtype=torch.float32),
   │   (('pois_loc__BOUND_16', Real),)),
   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('pois_scale__BOUND_15', Real)}),
   │   (Gaussian(
   │   │ torch.tensor(...1..., dtype=torch.float32),
   │   │ torch.tensor(...1 x 1..., dtype=torch.float32),
   │   │ (('pois_scale__BOUND_15', Real),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('place__BOUND_6', Bint[1372]), Variable('time__BOUND_7', Bint[49])}),
   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   frozenset({Variable('pois__BOUND_5', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...49 x 1372 x 3..., dtype=torch.float32),
   │   │   │ torch.tensor(...49 x 1372 x 3 x 3..., dtype=torch.float32),
   │   │   │ (('time__BOUND_7', Bint[49]),
   │   │   │  ('place__BOUND_6', Bint[1372]),
   │   │   │  ('pois__BOUND_5', Real),
   │   │   │  ('pois_loc__BOUND_16', Real),
   │   │   │  ('pois_scale__BOUND_15', Real),)),)),)),)),)),
   │Contraction(ops.logaddexp, ops.add,
   │ frozenset({Variable('init_loc_scale__BOUND_17', Real)}),
   │ (Gaussian(
   │   torch.tensor(...1..., dtype=torch.float32),
   │   torch.tensor(...1 x 1..., dtype=torch.float32),
   │   (('init_loc_scale__BOUND_17', Real),)),
   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('init_scale__BOUND_18', Real)}),
   │   (Gaussian(
   │   │ torch.tensor(...1..., dtype=torch.float32),
   │   │ torch.tensor(...1 x 1..., dtype=torch.float32),
   │   │ (('init_scale__BOUND_18', Real),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('strain__BOUND_9', Bint[1343])}),
   │   │ (Contraction(ops.logaddexp, ops.add,
   │   │   frozenset({Variable('init_loc__BOUND_8', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...1343 x 2..., dtype=torch.float32),
   │   │   │ torch.tensor(...1343 x 2 x 2..., dtype=torch.float32),
   │   │   │ (('strain__BOUND_9', Bint[1343]),
   │   │   │  ('init_loc__BOUND_8', Real),
   │   │   │  ('init_loc_scale__BOUND_17', Real),)),
   │   │   │Contraction(ops.add, ops.null,
   │   │   │ frozenset({Variable('place__BOUND_2', Bint[1372])}),
   │   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   │   frozenset({Variable('init__BOUND_1', Real)}),
   │   │   │   (Gaussian(
   │   │   │   │ torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
   │   │   │   │ torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
   │   │   │   │ (('place__BOUND_2', Bint[1372]),
   │   │   │   │  ('strain__BOUND_9', Bint[1343]),
   │   │   │   │  ('init__BOUND_1', Real),
   │   │   │   │  ('init_scale__BOUND_18', Real),
   │   │   │   │  ('init_loc__BOUND_8', Real),)),)),)),)),)),)),)),)),))
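
Each Contraction with ops.logaddexp over a real Variable in this schedule integrates that variable out of a Gaussian. A minimal NumPy sketch of that single step in information form (info_vec, precision), using the textbook Schur-complement formula and assuming the eliminated block is positive definite; `eliminate` is a hypothetical helper, not funsor's internal code.

```python
import numpy as np

def eliminate(info_vec, precision, keep, drop):
    """Integrate the real coordinates `drop` out of exp(-0.5 z @ P @ z + i @ z).

    Returns (info_vec, precision) over `keep` and the log of the integral's
    normalizing factor. Assumes the `drop` block of P is positive definite."""
    P_kk = precision[np.ix_(keep, keep)]
    P_kd = precision[np.ix_(keep, drop)]
    P_dd = precision[np.ix_(drop, drop)]
    i_k, i_d = info_vec[keep], info_vec[drop]
    L = np.linalg.cholesky(P_dd)  # P_dd = L @ L.T

    def solve(B):
        # Computes inv(P_dd) @ B via two solves against the Cholesky factor
        return np.linalg.solve(L.T, np.linalg.solve(L, B))

    new_precision = P_kk - P_kd @ solve(P_kd.T)  # Schur complement
    new_info_vec = i_k - P_kd @ solve(i_d)
    log_norm = (0.5 * i_d @ solve(i_d)
                + 0.5 * len(drop) * np.log(2 * np.pi)
                - np.log(np.diag(L)).sum())  # last term equals -0.5 * logdet(P_dd)
    return new_info_vec, new_precision, log_norm

rng = np.random.default_rng(0)
B = rng.standard_normal((5, 5))
P = B @ B.T + 5 * np.eye(5)  # a positive-definite joint precision
i = rng.standard_normal(5)
keep, drop = [0, 1], [2, 3, 4]
new_info, new_prec, log_norm = eliminate(i, P, keep, drop)

# Cross-check against moment form: marginalizing keeps the sub-block of cov
cov = np.linalg.inv(P)
mean = cov @ i
assert np.allclose(np.linalg.inv(new_prec), cov[np.ix_(keep, keep)])
assert np.allclose(np.linalg.solve(new_prec, new_info), mean[keep])
```

The cost is dominated by the Cholesky of the eliminated block and by the dense keep x keep result, which is why large real inputs such as coef (Reals[2367]) are expensive to carry through this step.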

The crux is this pair of Gaussian contractions with over 1e9 elements:

   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('coef__BOUND_12', Reals[2367])}),
   │   (Gaussian(
   │   │ torch.tensor(...2367..., dtype=torch.float32),
   │   │ torch.tensor(...2367 x 2367..., dtype=torch.float32),
   │   │ (('coef__BOUND_12', Reals[2367]),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('strain__BOUND_11', Bint[1343])}),
   │   │ (Contraction(ops.logaddexp, ops.add,
   │   │   frozenset({Variable('rate_loc__BOUND_10', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...1343 x 2369..., dtype=torch.float32),
   │   │   │ torch.tensor(...1343 x 2369 x 2369..., dtype=torch.float32),  # <-------- OOM here
   │   │   │ (('strain__BOUND_11', Bint[1343]),
   │   │   │  ('rate_loc__BOUND_10', Real),
   │   │   │  ('rate_loc_scale__BOUND_13', Real),
   │   │   │  ('coef__BOUND_12', Reals[2367]),)),
   │   │   │Contraction(ops.add, ops.null,
   │   │   │ frozenset({Variable('place__BOUND_4', Bint[1372])}),
   │   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   │   frozenset({Variable('rate__BOUND_3', Real)}),
   │   │   │   (Gaussian(
   │   │   │   │ torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
   │   │   │   │ torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
   │   │   │   │ (('place__BOUND_4', Bint[1372]),
   │   │   │   │  ('strain__BOUND_11', Bint[1343]),
   │   │   │   │  ('rate__BOUND_3', Real),
   │   │   │   │  ('rate_scale__BOUND_14', Real),
   │   │   │   │  ('rate_loc__BOUND_10', Real),)),)),)),)),)),)),)),)),
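
For scale, a back-of-the-envelope count of the tensor marked "OOM here" (batch 1343, event size 2369 = 2367 for coef plus one each for rate_loc and rate_loc_scale), assuming float32 storage as in the repr. The square-root comparison uses a hypothetical rank of 3 chosen only for illustration; it is not a measured property of this model.

```python
strains, dim = 1343, 2369         # batch and event sizes of the "OOM here" tensor
bytes_per_float32 = 4

dense = strains * dim * dim       # elements in the dense batched precision
print(f"{dense:,} elements, {dense * bytes_per_float32 / 2**30:.1f} GiB")
# 7,537,132,223 elements, 28.1 GiB

rank = 3                          # hypothetical rank, for illustration only
sqrt_form = strains * dim * rank  # elements in a batched (dim x rank) sqrt factor
print(f"{sqrt_form:,} elements, {sqrt_form * bytes_per_float32 / 2**20:.1f} MiB")
# 9,544,701 elements, 36.4 MiB
```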

I believe we can work around this using a combination of @fehiepsi's prec_sqrt representation https://github.com/pyro-ppl/pyro/pull/2019 and a ConditionalGaussian that generalizes AffineNormal. Happy to discuss.

fritzo, Oct 06 '21 00:10

My impression is that most of the details can be preserved (e.g. block vector, block matrix, align gaussian). Back then, one issue was that batched QR was very slow on GPU, but torch.linalg seems to have improved a lot since then.

fehiepsi, Oct 07 '21 00:10

@fehiepsi do you recall whether Cholesky was sufficient instead of QR? IIRC there was a PyTorch discussion about cheaply testing for positive definiteness or condition number using torch.linalg.cholesky_ex().

fritzo, Oct 07 '21 01:10

Looking at the code, I guess we needed to triangularize a non-positive-definite precision matrix (e.g. a zero matrix), but I can't recall when such triangularization is needed. :( Probably it is unnecessary. (Anyway, we can switch to QR if we run into positive-definiteness issues.)
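
The QR-versus-Cholesky question can be made concrete with a minimal NumPy sketch. The helper names are hypothetical and this is not the code from Pyro PR #2019: both functions compress a square-root factor S (with precision = S @ S.T) into a triangular factor, but Cholesky requires S @ S.T to be positive definite, while QR of S.T does not. In PyTorch, torch.linalg.cholesky_ex returns an info tensor instead of raising, which is one cheap way to detect that failure.

```python
import numpy as np

def triangularize_qr(S):
    """Hypothetical helper: compress a square-root factor S (n x r, with
    precision = S @ S.T) to at most n columns. Works even if S @ S.T is singular."""
    # S.T = Q @ R implies S @ S.T = R.T @ Q.T @ Q @ R = R.T @ R
    R = np.linalg.qr(S.T, mode="r")  # upper-triangular (trapezoidal if r < n), shape (min(r, n), n)
    return R.T                       # shape (n, min(r, n))

def triangularize_cholesky(S):
    """Hypothetical helper: same compression via Cholesky of the dense n x n
    precision; np.linalg.cholesky raises LinAlgError unless it is positive definite."""
    return np.linalg.cholesky(S @ S.T)  # lower-triangular, shape (n, n)

rng = np.random.default_rng(0)
S = rng.standard_normal((3, 10))  # n = 3, rank-10 factor, e.g. after many concatenations
for T in (triangularize_qr(S), triangularize_cholesky(S)):
    assert T.shape == (3, 3)
    assert np.allclose(T @ T.T, S @ S.T)

Z = np.zeros((3, 1))  # a zero precision, as in the comment above
assert np.allclose(triangularize_qr(Z), 0.0)
try:
    triangularize_cholesky(Z)
except np.linalg.LinAlgError:
    print("Cholesky failed on a singular precision; QR did not")
```

QR never forms the n x n precision, so it also fits the sqrt representation better; Cholesky is cheaper when the precision is known to be positive definite.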

fehiepsi, Oct 07 '21 03:10

@eb8680 want to pair code next week on the high-level algorithm for variable elimination, continuing our work from https://github.com/pyro-ppl/funsor/compare/tractable-for-gaussians ?

fritzo, Feb 19 '22 14:02

Sure!

eb8680, Feb 19 '22 16:02