WIP: Implement a `denormalize` custom Jaxpr operator simplifying MCX logpdfs

Open · balancap opened this issue 3 years ago · 2 comments

Overview

This PR implements a generic `denormalize` decorator that removes normalizing constants from a logpdf, following the call for contributions in #65.
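What "removing normalizing constants" buys can be checked by hand. The sketch below is a hand-written illustration, not output of the decorator. It assumes a Normal(0, σ) logpdf with σ held fixed, so `-log σ - ½ log 2π` does not depend on `x`. Dropping those terms shifts the logpdf by a constant, so gradients with respect to `x` (which gradient-based samplers consume) are unchanged.

```python
import jax
import jax.numpy as jnp

SIGMA = 2.0  # fixed scale: a constant, not a variable of the logpdf


def normal_logpdf(x):
    # Full Normal(0, SIGMA) log-density, normalizing constants included.
    return -0.5 * (x / SIGMA) ** 2 - jnp.log(SIGMA) - 0.5 * jnp.log(2 * jnp.pi)


def normal_logpdf_denormalized(x):
    # Same density with every term that does not depend on x removed.
    return -0.5 * (x / SIGMA) ** 2


xs = jnp.linspace(-3.0, 3.0, 7)
# The difference is the same constant at every point...
offset = normal_logpdf(xs) - normal_logpdf_denormalized(xs)
# ...so the gradients with respect to x coincide.
grad_full = jax.vmap(jax.grad(normal_logpdf))(xs)
grad_denorm = jax.vmap(jax.grad(normal_logpdf_denormalized))(xs)
```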

Implementation

The current implementation is a two-pass algorithm:

  • Forward pass: find all constant variables in the Jaxpr graph;
  • Backward pass: find all simplifying assignments, i.e. `add` and `sub` operations where one of the inputs is a constant and which can therefore be skipped.
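The two passes can be sketched on a plain Jaxpr as follows. This is a minimal sketch under stated assumptions, not the PR's code. The names `forward_constants`, `backward_simplify` and `logpdf` are hypothetical. Only top-level equations are inspected, with no recursion into sub-Jaxprs. Literals are recognized by their `.val` attribute. Two extra guards are assumptions of this sketch. First, `c - x` is left alone, because dropping `c` would require a negation. Second, only single-use variables are simplified, because substituting a variable that also feeds a non-additive operation would change the logpdf by more than a constant.

```python
from collections import Counter

import jax
import jax.numpy as jnp
from jax import lax


def _is_literal(v):
    # Jaxpr literals carry their value in `.val`; Jaxpr variables do not.
    return hasattr(v, "val")


def forward_constants(jaxpr, variable_argnums):
    """Forward pass: collect the variables that do not depend on the inputs
    whose positions are listed in `variable_argnums`."""
    constants = set(jaxpr.constvars)
    constants.update(v for i, v in enumerate(jaxpr.invars) if i not in variable_argnums)
    for eqn in jaxpr.eqns:
        # An equation whose inputs are all constant produces constants.
        if all(_is_literal(v) or v in constants for v in eqn.invars):
            constants.update(eqn.outvars)
    return constants


def backward_simplify(jaxpr, constants):
    """Backward pass: map each removable add/sub output to the operand it
    reduces to once its constant operand is dropped."""
    producer = {v: eqn for eqn in jaxpr.eqns for v in eqn.outvars}
    uses = Counter(v for eqn in jaxpr.eqns for v in eqn.invars if not _is_literal(v))
    uses.update(v for v in jaxpr.outvars if not _is_literal(v))

    def is_const(v):
        return _is_literal(v) or v in constants

    mapping = {}
    stack = [v for v in jaxpr.outvars if not _is_literal(v)]
    while stack:
        var = stack.pop()
        eqn = producer.get(var)
        # Only walk single-use, non-constant add/sub outputs.
        if eqn is None or var in constants or uses[var] != 1:
            continue
        if eqn.primitive.name not in ("add", "sub"):
            continue
        lhs, rhs = eqn.invars
        if is_const(rhs):
            mapping[var] = lhs  # x + c, x - c  ->  x
            stack.append(lhs)
        elif is_const(lhs) and eqn.primitive.name == "add":
            mapping[var] = rhs  # c + x  ->  x
            stack.append(rhs)
        else:
            # Keep the op but keep simplifying its non-constant operands.
            stack.extend(v for v in (lhs, rhs) if not is_const(v))
    return mapping


def logpdf(x, sigma):
    # Normal(0, sigma) logpdf written with lax primitives.
    z = lax.div(x, sigma)
    quad = lax.mul(jnp.float32(-0.5), lax.mul(z, z))
    return lax.sub(lax.sub(quad, lax.log(sigma)), jnp.float32(0.9189385))


closed = jax.make_jaxpr(logpdf)(jnp.float32(2.0), jnp.float32(2.0))
# Only x (argument 0) is a variable; sigma is treated as a constant.
constants = forward_constants(closed.jaxpr, variable_argnums={0})
mapping = backward_simplify(closed.jaxpr, constants)
```

With only `x` variable, both subtractions are removable and the output reduces to `quad`. If `sigma` is also marked variable, only the `0.9189385` subtraction is dropped.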

Once this simplifying mapping is found, the rest of the decorator code is a simple execution pass over the Jaxpr that skips every operation for which a simplifying mapping exists.
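Such an execution pass can be sketched as a small Jaxpr interpreter in the style of JAX's own `eval_jaxpr`. This is an assumption-laden sketch, not the PR's decorator. `eval_with_skips` and `f` are hypothetical names. In the usage below the mapping is written by hand so that the block stands alone. A skipped equation is never run: its output simply takes the value of the variable it maps to.

```python
import jax
import jax.numpy as jnp
from jax import lax


def eval_with_skips(closed_jaxpr, mapping, *args):
    """Interpret `closed_jaxpr` on `args`; an equation whose single output is a
    key of `mapping` is skipped and its output takes the value of mapping[out]."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals hold their value inline; variables are looked up in `env`.
        return v.val if hasattr(v, "val") else env[v]

    env.update(zip(jaxpr.constvars, closed_jaxpr.consts))
    env.update(zip(jaxpr.invars, args))

    for eqn in jaxpr.eqns:
        if len(eqn.outvars) == 1 and eqn.outvars[0] in mapping:
            # Simplified equation: forward the replacement value, do not run it.
            env[eqn.outvars[0]] = read(mapping[eqn.outvars[0]])
            continue
        subfuns, params = eqn.primitive.get_bind_params(eqn.params)
        outs = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **params)
        if not eqn.primitive.multiple_results:
            outs = [outs]
        env.update(zip(eqn.outvars, outs))
    return [read(v) for v in jaxpr.outvars]


def f(x):
    return lax.add(lax.mul(x, x), jnp.float32(3.0))


closed = jax.make_jaxpr(f)(jnp.float32(2.0))
add_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "add")
# Hand-written mapping: skip `add` and reuse its non-constant operand x * x.
mapping = {add_eqn.outvars[0]: add_eqn.invars[0]}
(full,) = eval_with_skips(closed, {}, jnp.float32(2.0))
(simplified,) = eval_with_skips(closed, mapping, jnp.float32(2.0))
```

With an empty mapping the interpreter reproduces `f` exactly, which is a convenient sanity check before any simplification is applied.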

Limitations

Even though we aim for a fairly generic implementation, some simplifications are not supported at the moment. For instance, we do not propagate constant simplification through `concat` or `mul` operations. These cases could be supported in the future if they turn out to be a performance bottleneck in MCX.
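A small identity shows why `mul` is left out. Assume $c$ and $k$ are constants and $f(x)$ depends on the variables:

```latex
\[
  c\,\bigl(f(x) + k\bigr) = c\,f(x) + c\,k
\]
```

Dropping $k$ shifts the logpdf by the constant $c\,k$, so it would be a valid simplification. However, a backward pass that only walks through `add` and `sub` stops at the `mul` and never reaches $k$. The identity holds only when the multiplier $c$ is itself constant.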

balancap · Jan 31 '21 15:01

@rlouf As we discussed on Slack, there is quite a bit of additional complexity to add to this PR to properly handle the support `select` condition that appears in the logpdf of many distributions.
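The difficulty shows up in a hypothetical half-normal logpdf (σ = 1) with an explicit support check written with `lax.select`. This is not MCX's distribution code; `halfnormal_logpdf` is an illustrative name. The check becomes a select equation (`select_n` in recent JAX releases) at the root of the Jaxpr.

```python
import jax
import jax.numpy as jnp
from jax import lax


def halfnormal_logpdf(x):
    # Hypothetical HalfNormal(1) logpdf: -x**2 / 2 + log(sqrt(2 / pi)) on x >= 0.
    logp = lax.sub(lax.mul(jnp.float32(-0.5), lax.mul(x, x)), jnp.float32(0.2257914))
    in_support = lax.ge(x, jnp.float32(0.0))
    # Outside the support the density is zero, i.e. the logpdf is -inf.
    return lax.select(in_support, logp, jnp.float32(-jnp.inf))


closed = jax.make_jaxpr(halfnormal_logpdf)(jnp.float32(1.0))
out = closed.jaxpr.outvars[0]
# Name of the equation that produces the logpdf output.
root = next(e for e in closed.jaxpr.eqns if out in e.outvars)
print(root.primitive.name)
```

A backward pass limited to `add` and `sub` stops at this root, so the `0.2257914` constant inside `logp` survives. Shifting `logp` by a constant shifts the selected output by the same constant wherever it is finite. One naive rule could therefore look through the in-support branch, but the select would then have to be re-executed with the simplified branch rather than skipped.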

I'll start with a fairly naive implementation and get it working. Then I think we can iterate on it to make it less naive and make proper use of symbolic programming concepts (I've started looking at the Oryx codebase for that).

balancap · Feb 03 '21 19:02

That sounds like a very good plan to me! I'll have a closer look too when my big PR is merged.

rlouf · Feb 03 '21 20:02