feat(useless-rewrite): Useless rewrite from `log1mexp(log1mexp(x))` to `x`
Motivation for these changes
Picking up this ticket on the PyData Global 2023 OSS sprint :runner:
- #471
The motivating issue concerned a problem with PyMC's `Censored` functionality, which was traced back to the need for a "useless rewrite":
- replacing a $\log(1 - \exp(x))$ nested within another $\log(1 - \exp(\cdot))$,
- i.e. $\log(1 - \exp(\log(1 - \exp(x))))$,
- by just $x$.
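The identity behind the rewrite follows from a line of algebra. This is a worked derivation added for clarity, not text from the PR:

$$
\begin{aligned}
\log\bigl(1 - \exp(\log(1 - \exp(x)))\bigr) &= \log\bigl(1 - (1 - \exp(x))\bigr) \quad \text{since } \exp(\log y) = y \text{ for } y = 1 - \exp(x) > 0 \iff x < 0 \\
&= \log(\exp(x)) \\
&= x
\end{aligned}
$$

For $x > 0$ the inner $\log(1 - \exp(x))$ takes the log of a negative number, so the original expression is nan there. At $x = 0$ the identity holds only in the extended reals: $\log(1 - \exp(0)) = \log 0 = -\infty$, and then $\log(1 - \exp(-\infty)) = \log 1 = 0$.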
Implementation details
- :broom: refactor(nan switch): DRY out the nan switch rewrite function so the important parts are easier to follow :broom:
- The first change prepares for adding a new case by simplifying the case handling (avoiding repetition).
- This is achieved by defining an "inner function" (a function within a function) that captures `x` and `node` from the enclosing function body in its scope, so we don't need to pass them as parameters and the body of each case becomes simpler.
- Every case has a "nan switch", so this trick lets us avoid repeating ourselves while retaining clarity about the variables we're using.
- :writing_hand: Add new useless rewrite case :writing_hand:
- The previous operation and the node operation are both going to be `log1mexp` for this case.
- The condition for the case is that x >= 0 (confirm?)
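A minimal sketch of the refactor pattern described above, written against a toy tuple-based expression format rather than PyTensor's real graph API. All names here (`CASES`, `rewrite_nan_switch`, `evaluate`, the `"add1"` op) are hypothetical, and the case table is illustrative, not the PR's actual list of cases. Each entry maps an (outer op, inner op) pair to the condition on `x` under which the composition simplifies, plus a builder for the simplified expression. The inner function `nan_switch` captures `x` from the enclosing scope (the PR's version also captures `node`, which this toy does not need). The `log1mexp(log1mexp(x))` entry uses the x <= 0 condition settled on in the review discussion.

```python
import math

# Toy expression format (hypothetical, NOT PyTensor's graph API):
#   "x"                                      -> the input variable
#   ("const", value)                         -> a constant
#   (op_name, child)                         -> a unary op applied to child
#   ("switch", (cmp, lhs, bound), then_expr, else_expr)

# (outer_op, inner_op) -> ((cmp, bound) condition on x, builder for simplified expr)
CASES = {
    ("exp", "log"): ((">=", 0.0), lambda x: x),
    ("exp", "log1p"): ((">=", -1.0), lambda x: ("add1", x)),
    ("log1mexp", "log1mexp"): (("<=", 0.0), lambda x: x),
}


def rewrite_nan_switch(node):
    """Rewrite outer(inner(x)) as switch(cond(x), simplified, nan); None if no case applies."""
    if not (isinstance(node, tuple) and len(node) == 2):
        return None
    outer_op, inner = node
    if not (isinstance(inner, tuple) and len(inner) == 2):
        return None
    inner_op, x = inner

    def nan_switch(cmp, bound, simplified):
        # Inner function: `x` comes from the enclosing scope, so each case
        # only supplies its condition and its simplified expression.
        return ("switch", (cmp, x, bound), simplified, ("const", math.nan))

    case = CASES.get((outer_op, inner_op))
    if case is None:
        return None
    (cmp, bound), build = case
    return nan_switch(cmp, bound, build(x))


def evaluate(expr, value):
    """Evaluate a toy expression with x = value."""
    if expr == "x":
        return value
    head = expr[0]
    if head == "const":
        return expr[1]
    if head == "add1":
        return 1.0 + evaluate(expr[1], value)
    if head == "switch":
        _, (cmp, lhs, bound), then_expr, else_expr = expr
        lhs_value = evaluate(lhs, value)
        holds = lhs_value >= bound if cmp == ">=" else lhs_value <= bound
        return evaluate(then_expr if holds else else_expr, value)
    raise ValueError(f"unknown op {head!r}")
```

For example, `rewrite_nan_switch(("log1mexp", ("log1mexp", "x")))` returns a switch that evaluates to `x` for x <= 0 and to nan otherwise. With the cases kept in a table, adding a new one is a one-line change rather than another repeated `if` block.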
Checklist
- [x] Explain motivation and implementation 👆
- [x] Make sure that the pre-commit linting/style checks pass.
- [x] Link relevant issues, preferably in nice commit messages.
- [x] The commits correspond to relevant logical changes. Note that if they don't, we will rewrite/rebase/squash the git history before merging.
- [ ] Are the changes covered by tests and docstrings?
- [ ] Fill out the short summary sections 👇
Major / Breaking Changes
- N/A
New features
- "Useless rewrite" to optimise $\log(1 - \exp(\log(1 - \exp(x))))$ into just $x$
Bugfixes
- Would resolve #471
Documentation
- ...
Maintenance
- Took the opportunity to tidy up the 'nan switch' cases into a dict, which is clearer to read than the repetitive `if` blocks
I think to finish this I need to add a test here
https://github.com/pymc-devs/pytensor/blob/c10c376b173402eb06bed9aff56d0c77fd21dd79/tests/tensor/rewriting/test_math.py#L1945-L1971
Sounds about right
> The condition for the case is that x >= 0 (confirm?)
The valid condition is x <= 0, for which the inner `log1mexp` is defined. x > 0 should yield nan, since that would mean taking the log of a negative number.
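The domain argument above can be checked numerically with a small plain-Python `log1mexp`. This is a stand-in written for illustration, not PyTensor's implementation: the double composition should return `x` for x <= 0 and nan for x > 0.

```python
import math


def log1mexp(x: float) -> float:
    """Plain-Python stand-in for log(1 - exp(x)) (illustrative, not PyTensor's)."""
    if x > 0:
        # 1 - exp(x) < 0: the log of a negative number is undefined
        return math.nan
    if x == 0:
        # 1 - exp(0) = 0, and log(0) = -inf
        return -math.inf
    # -expm1(x) equals 1 - exp(x) but stays accurate for x close to 0
    return math.log(-math.expm1(x))


for x in [-5.0, -1.0, -1e-3, 0.0, 0.5]:
    # Expect x back for x <= 0 and nan for x > 0
    print(x, log1mexp(log1mexp(x)))
```

The same check hints that the rewrite helps numerically too: for very negative inputs 1 - exp(x) rounds to 1 in float64, so at x = -50 the inner call returns 0 and the outer call returns -inf instead of -50.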