
feat(useless-rewrite): Useless rewrite from `log1mexp(log1mexp(x))` to `x`

Open lmmx opened this issue 2 years ago • 3 comments

Motivation for these changes

Picking up this ticket on the PyData Global 2023 OSS sprint :runner:

  • #471

The motivating issue concerned a problem with PyMC's `Censored` functionality, which was traced back to the need for a "useless rewrite":

  • replacing a $\log(1 - \exp(x))$ nested inside another $\log(1 - \exp(\cdot))$
    • i.e. $\log(1 - \exp(\log(1 - \exp(x))))$, or `log1mexp(log1mexp(x))`
  • with just $x$.
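Written out, the simplification is a direct cancellation. The short derivation below assumes $\operatorname{log1mexp}(x) = \log(1 - e^{x})$ as above, together with the IEEE conventions $\log 0 = -\infty$ and $e^{-\infty} = 0$:

$$
\begin{aligned}
\operatorname{log1mexp}\bigl(\operatorname{log1mexp}(x)\bigr)
  &= \log\bigl(1 - \exp(\log(1 - e^{x}))\bigr) \\
  &= \log\bigl(1 - (1 - e^{x})\bigr) \qquad \text{requires } 1 - e^{x} \ge 0 \iff x \le 0 \\
  &= \log e^{x} = x.
\end{aligned}
$$

For $x > 0$ the inner argument $1 - e^{x}$ is negative, so the inner logarithm (and therefore the whole expression) is nan. At $x = 0$ the inner call gives $-\infty$ and the outer call maps it back to $\log 1 = 0$.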

Implementation details

  1. :broom: refactor(nan switch): DRY out nan switch rewrite function, easier to follow important parts :broom:
  • The first change prepares for adding a new case by simplifying the case handling (avoiding repetition).
  • This is achieved with an "inner function" (a function defined within a function) that captures `x` and `node` from the enclosing function's scope. They no longer need to be passed as parameters, so the body of each case becomes simpler.
  • Every case has a "nan switch", so this trick lets us avoid repeating ourselves while retaining clarity about the variables we're using.
  2. :writing_hand: Add new useless rewrite case :writing_hand:
  • The previous operation and the node operation are both going to be `log1mexp` for this case
  • The condition for the case is that x >= 0 (confirm?)
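The refactor and the new case can be sketched in plain Python. This is a hypothetical illustration, not the PR's code or PyTensor's rewrite API: `simplify_composition` and `nan_switch` are made-up names, NumPy arrays stand in for graph nodes, and the `exp(log(x))` entry is only an example of another nan-switched case. It shows an inner function capturing `x` from the enclosing scope and a dict of cases replacing repeated `if` blocks. The `log1mexp` entry uses x <= 0, the domain on which the inner `log1mexp` is defined.

```python
import numpy as np


def simplify_composition(outer: str, inner: str, x) -> np.ndarray:
    """Return outer(inner(x)) in simplified, nan-switched form.

    Hypothetical illustration only: this is not PyTensor's rewrite API.
    NumPy arrays stand in for graph nodes, and op names are plain strings.
    """
    x = np.asarray(x, dtype=float)

    def nan_switch(valid: np.ndarray) -> np.ndarray:
        # `x` comes from the enclosing scope (a closure), so each case only
        # has to state the condition under which the simplification holds.
        return np.where(valid, x, np.nan)

    # Map (outer op, inner op) to a zero-argument callable so that only the
    # matching case's condition is evaluated.
    cases = {
        # exp(log(x)) == x, defined for x >= 0
        ("exp", "log"): lambda: nan_switch(x >= 0),
        # log1mexp(log1mexp(x)) == x, defined for x <= 0
        ("log1mexp", "log1mexp"): lambda: nan_switch(x <= 0),
    }
    if (outer, inner) not in cases:
        raise ValueError(f"no simplification registered for {outer}({inner}(x))")
    return cases[(outer, inner)]()


print(simplify_composition("log1mexp", "log1mexp", [-2.0, 0.0, 1.0]))
```

Keying the dict on the (outer, inner) pair keeps each case to a single line, so adding a new useless rewrite amounts to adding one entry.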

Checklist

Major / Breaking Changes

  • N/A

New features

  • "Useless rewrite" to optimise $\log(1 - \exp(\log(1 - \exp(x))))$ into just $x$

Bugfixes

  • Would resolve #471

Documentation

  • ...

Maintenance

  • Took the opportunity to tidy up the 'nan switch' cases into a dict, which is clearer to read than the repetitive if blocks

lmmx · Dec 06 '23 22:12

I think to finish this I need to add a test here

https://github.com/pymc-devs/pytensor/blob/c10c376b173402eb06bed9aff56d0c77fd21dd79/tests/tensor/rewriting/test_math.py#L1945-L1971

lmmx · Dec 07 '23 13:12

> I think to finish this I need to add a test here
>
> https://github.com/pymc-devs/pytensor/blob/c10c376b173402eb06bed9aff56d0c77fd21dd79/tests/tensor/rewriting/test_math.py#L1945-L1971

Sounds about right

ricardoV94 · Dec 08 '23 18:12

> The condition for the case is that x >= 0 (confirm?)

The valid condition is x <= 0, for which the inner log1mexp is defined. x > 0 should yield nan, as that would lead to taking the log of a negative number.

ricardoV94 · Dec 08 '23 18:12
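The domain argument above can be checked numerically. This is a minimal NumPy sketch, assuming a plain `log1p(-exp(x))` reference for `log1mexp`; PyTensor's own implementation, which may choose between formulations for numerical stability, is not reproduced here. The array `rewritten` mirrors a nan switch of the form "x if x <= 0 else nan".

```python
import numpy as np


def log1mexp(x):
    # Reference log(1 - exp(x)) via log1p; accurate enough for this check.
    # Not PyTensor's implementation.
    return np.log1p(-np.exp(x))


x = np.array([-5.0, -0.5, 0.0, 0.5, 2.0])

# Division by zero (x == 0 gives log1p(-1) = -inf) and invalid values
# (x > 0 gives log1p of a number below -1, i.e. nan) are expected here.
with np.errstate(divide="ignore", invalid="ignore"):
    nested = log1mexp(log1mexp(x))

# The rewritten graph: keep x where the inner log1mexp is defined, else nan.
rewritten = np.where(x <= 0, x, np.nan)

# Both agree elementwise; assert_allclose treats nan as equal to nan by default.
np.testing.assert_allclose(nested, rewritten)
print(nested)
print(rewritten)
```

The explicit nan branch is what lets the rewrite preserve the original expression's nan for x > 0 instead of silently returning x.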