Implement unconstraining transform for LKJCorr
I've ported this bijector from tensorflow and added it to LKJCorr. This ensures that initial samples drawn from LKJCorr are positive definite, which fixes #7101. Sampling now completes successfully with no divergences.
There are several parts I'm not comfortable with:
- https://github.com/johncant/pymc/blob/fix_lkjcorr_positive_definiteness/pymc/distributions/multivariate.py#L1583 - it seems like a bad idea to run a pytensor graph here. Is there any way to get the LKJCorr `n` parameter from `op` or `rv` without `eval`ing any pytensors?
- https://github.com/johncant/pymc/blob/fix_lkjcorr_positive_definiteness/pymc/distributions/transforms.py#L176-L177 - not sure whether or not this is the right way to create a constant pytensor tensor here.
@fonnesbeck @twiecki @jessegrabowski @velochy - please could you take a look? I would like to make sure that this fix makes sense before adding tests and making the linters pass.
Notes:
- Tests not yet written, linters not yet run
- The original tensorflow bijector is defined in the opposite sense to pymc transforms, i.e. `forward` in `tensorflow_probability` is `backward` in `pymc`
- The original tensorflow bijector produces Cholesky factors, not actual correlation matrices, so in this implementation, we have to do a Cholesky decomposition in the forward transform.
- In the tensorflow bijector, the triangular elements of a matrix are filled in a clockwise spiral, as opposed to numpy, which defines indices in row-major order.
Description
Backward method
- Start with identity matrix and fill lower triangular elements with unconstrained real numbers.
- Normalize each row so that its L2 norm is 1
- This is now a Cholesky factor that will always result in positive definite correlation matrices
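The steps above can be sketched in plain NumPy. This is an illustration only, not the code in this PR: the function name `backward_sketch` is hypothetical, and it fills the lower triangle in NumPy's row-major order rather than the spiral order used by the tensorflow bijector.

```python
import numpy as np


def backward_sketch(x, n):
    """Map n*(n-1)/2 unconstrained reals to an n x n correlation matrix.

    Hypothetical helper for illustration; assumes row-major fill order.
    """
    x = np.asarray(x, dtype=float)
    # Start with the identity matrix.
    L = np.eye(n)
    # Fill the strictly lower-triangular elements with the unconstrained values.
    rows, cols = np.tril_indices(n, k=-1)
    L[rows, cols] = x
    # Normalize each row so that its L2 norm is 1.
    L /= np.linalg.norm(L, axis=1, keepdims=True)
    # L is now a Cholesky factor; L @ L.T has a unit diagonal.
    return L @ L.T


corr = backward_sketch([0.5, -1.2, 2.0], n=3)
```

Because every row of `L` has unit norm, the diagonal of `L @ L.T` is exactly 1, and because `L` is lower triangular with a strictly positive diagonal, the product is positive definite for any real input.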
Forward method
- Reconstruct the correlation matrix from its upper triangular elements
- Perform cholesky decomposition to obtain L
- The diagonal elements of L are multipliers we used to normalize the other elements.
- Extract those diagonal elements and divide to undo the backward method
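A minimal NumPy sketch of this inverse direction follows. It is not the PR's implementation: `forward_sketch` is a hypothetical name, it takes the full correlation matrix as input instead of rebuilding it from packed triangular elements, and it returns values in row-major order.

```python
import numpy as np


def forward_sketch(corr):
    """Map a correlation matrix to n*(n-1)/2 unconstrained reals.

    Hypothetical helper for illustration; assumes row-major output order.
    """
    corr = np.asarray(corr, dtype=float)
    n = corr.shape[0]
    # Cholesky decomposition recovers the row-normalized factor L.
    L = np.linalg.cholesky(corr)
    rows, cols = np.tril_indices(n, k=-1)
    # Before normalization the diagonal was 1, so L[i, i] is the multiplier
    # applied to row i. Dividing by it undoes the normalization.
    return L[rows, cols] / L[rows, rows]


example_corr = np.array(
    [
        [1.0, 0.3, -0.2],
        [0.3, 1.0, 0.5],
        [-0.2, 0.5, 1.0],
    ]
)
unconstrained = forward_sketch(example_corr)
```

The division is well defined because the Cholesky factor of a positive definite matrix has a strictly positive diagonal.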
log_jac_det
This was quite complicated to implement, so I used the symbolic jacobian.
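One way to sanity-check any implementation of the log-Jacobian determinant is a brute-force numerical reference. The sketch below is not the symbolic approach used in this PR: it approximates the Jacobian of the backward map with central finite differences in NumPy, and the names `backward_tril_sketch` and `numeric_log_jac_det` are hypothetical. It assumes row-major ordering of the triangular elements, which only permutes rows and columns and so does not change the absolute determinant.

```python
import numpy as np


def backward_tril_sketch(x, n):
    """Unconstrained vector -> strictly lower-triangular correlation entries."""
    L = np.eye(n)
    rows, cols = np.tril_indices(n, k=-1)
    L[rows, cols] = x
    L /= np.linalg.norm(L, axis=1, keepdims=True)
    return (L @ L.T)[rows, cols]


def numeric_log_jac_det(x, n, eps=1e-6):
    """Approximate log|det J| of the backward map by central differences."""
    x = np.asarray(x, dtype=float)
    m = x.size
    J = np.empty((m, m))
    for k in range(m):
        step = np.zeros(m)
        step[k] = eps
        # Column k holds the derivative of every output w.r.t. x[k].
        J[:, k] = (
            backward_tril_sketch(x + step, n) - backward_tril_sketch(x - step, n)
        ) / (2 * eps)
    _, logabsdet = np.linalg.slogdet(J)
    return logabsdet


value = numeric_log_jac_det([0.7], n=2)
```

For `n=2` the backward map reduces to `x / sqrt(1 + x**2)`, whose derivative is `(1 + x**2) ** -1.5`, so the numerical value can be compared against `-1.5 * log(1 + x**2)`.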
Related Issue
- [ ] Closes #7101
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [x] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7380.org.readthedocs.build/en/7380/
:sparkling_heart: Thanks for opening this pull request! :sparkling_heart: The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.
Hi, it's unlikely I'm going to have any time to work on this for the next 6 months. The hardest part is coming up with a closed-form solution for log_jac_det, which I don't think I'm very close to doing.
Thanks for the update @johncant and for pushing this as far as you did.