ENH: Measurable `dot` and `matmul`

Open jessegrabowski opened this issue 2 years ago • 4 comments

Before

def scaled_mv_normal(mu, cov, R):
    return pm.MvNormal.dist(mu=R @ mu, cov=R @ cov @ R.T)

After

def scaled_mv_normal(mu, cov, R):
    return R @ pm.MvNormal.dist(mu=mu, cov=cov)

Context for the issue:

Currently, the automatic logp inference isn't able to handle dot products or matrix multiplication of a multivariate normal with a deterministic vector/matrix. Since there are closed form results for these cases (it's essentially the same as repeated shifting/scaling/convolution of normals), it seems like it should be possible to have a measurable_dot and measurable_matmul, at least in the MvNormal case?

jessegrabowski, Oct 12 '23 16:10
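
The closed-form result the issue relies on is that if x ~ MvNormal(mu, cov), then R @ x ~ MvNormal(R @ mu, R @ cov @ R.T). The following is a minimal sketch that checks this numerically for a 2x2 case using only the standard library. The helper names (`matmul`, `transformed_params`, `sample_moments`) are made up for illustration and are not part of PyMC; the sketch also assumes a square, invertible `R` and does not consider degenerate cases.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transformed_params(mu, cov, R):
    """Closed-form parameters of R @ x when x ~ MvNormal(mu, cov)."""
    new_mu = [row[0] for row in matmul(R, [[m] for m in mu])]
    R_T = [list(col) for col in zip(*R)]
    new_cov = matmul(matmul(R, cov), R_T)
    return new_mu, new_cov


def sample_moments(mu, cov, R, n=50_000, seed=0):
    """Empirical mean and covariance of R @ x for a 2-dimensional x."""
    rng = random.Random(seed)
    # Cholesky factor of the 2x2 covariance, written out by hand.
    l11 = math.sqrt(cov[0][0])
    l21 = cov[1][0] / l11
    l22 = math.sqrt(cov[1][1] - l21**2)
    ys = []
    for _ in range(n):
        z1, z2 = rng.gauss(0, 1), rng.gauss(0, 1)
        x = [mu[0] + l11 * z1, mu[1] + l21 * z1 + l22 * z2]
        ys.append([sum(r * xi for r, xi in zip(row, x)) for row in R])
    d = len(R)
    mean = [sum(y[i] for y in ys) / n for i in range(d)]
    emp_cov = [
        [sum((y[i] - mean[i]) * (y[j] - mean[j]) for y in ys) / (n - 1) for j in range(d)]
        for i in range(d)
    ]
    return mean, emp_cov


mu = [1.0, -2.0]
cov = [[2.0, 0.6], [0.6, 1.0]]
R = [[1.0, 2.0], [0.0, -1.0]]

exact_mu, exact_cov = transformed_params(mu, cov, R)
emp_mu, emp_cov = sample_moments(mu, cov, R)
```

Up to sampling noise, `emp_mu` and `emp_cov` should agree with `exact_mu` and `exact_cov`, which is what makes the "After" form above equivalent to the "Before" form.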

Sounds good, we should only need a rewrite that converts the second form to the first?

ricardoV94, Oct 14 '23 15:10

Yes I think so. Is there a rewrite to look at that I could pattern a PR after?

jessegrabowski, Oct 14 '23 16:10

> Yes I think so. Is there a rewrite to look at that I could pattern a PR after?

It should be like some of these standalone rewrites: https://github.com/pymc-devs/pymc/blob/827918b42720e4108cd86af88a56229f9af85fcf/pymc/logprob/transforms.py#L540-L544

They are only called when the variable in question (in this case MV @ x) has to be measured. So you only need to check that the dot involves an MvNormal and return the new equivalent MvNormal.

ricardoV94, Oct 16 '23 08:10
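
To illustrate the shape of the rewrite described in the last comment (check that the dot involves an MvNormal, then return the equivalent MvNormal), here is a toy sketch on a hand-rolled expression tree. It deliberately does not use PyMC or PyTensor: the `MvNormal` and `Dot` classes and the `rewrite_dot_mvnormal` function are hypothetical stand-ins, and returning `None` for a non-matching node is a choice made for this sketch only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MvNormal:
    """Toy stand-in for a multivariate normal random variable."""
    mu: list
    cov: list


@dataclass(frozen=True)
class Dot:
    """Toy stand-in for a `left @ right` node in an expression graph."""
    left: object
    right: object


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def rewrite_dot_mvnormal(node):
    """Turn `R @ MvNormal(mu, cov)` into `MvNormal(R @ mu, R @ cov @ R.T)`.

    Returns None when the node does not match the pattern.
    """
    if not isinstance(node, Dot):
        return None
    R, rv = node.left, node.right
    # Only handle a deterministic matrix on the left and an MvNormal on the right.
    if isinstance(R, MvNormal) or not isinstance(rv, MvNormal):
        return None
    new_mu = [row[0] for row in matmul(R, [[m] for m in rv.mu])]
    R_T = [list(col) for col in zip(*R)]
    new_cov = matmul(matmul(R, rv.cov), R_T)
    return MvNormal(mu=new_mu, cov=new_cov)


R = [[1.0, 2.0], [0.0, -1.0]]
rv = MvNormal(mu=[1.0, -2.0], cov=[[2.0, 0.6], [0.6, 1.0]])

rewritten = rewrite_dot_mvnormal(Dot(R, rv))   # matches the pattern
not_matched = rewrite_dot_mvnormal(Dot(R, R))  # no MvNormal involved
```

A real implementation would operate on the library's graph nodes and would also need to decide how to treat the transposed case (an MvNormal on the left) and non-square or singular `R`; none of that is covered here.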