Refactor `logp` in BG/BB to remove Scan
The `logp` of the `BetaGeoBetaBinom` distribution contains a loop currently implemented with a PyTensor `Scan`. It's possible to refactor this so that `Scan` is no longer needed:
```python
import numpy as np
import pytensor.tensor as pt
from pymc.distributions.dist_math import betaln
from pytensor.graph.replace import vectorize_graph

# x, t_x, T and the parameters alpha, beta, gamma, delta
# come from the surrounding logp scope.
i = pt.scalar("i", dtype="int64")
died = pt.lt(t_x + i, T)

unnorm_logp_died_at_tx_plus_i = pt.where(
    pt.ge(t_x, i),
    (
        betaln(alpha + x, beta + t_x - x + i)
        + betaln(gamma + died, delta + t_x + i)
    ),
    -np.inf,
)

# Maximum prevents invalid T - t_x values from crashing the logp
max_range = pt.maximum(pt.max(T - t_x), 0)
i_vec = pt.arange(max_range + 1)
unnorm_logp_died_at_tx_plus_i_vec = vectorize_graph(
    unnorm_logp_died_at_tx_plus_i,
    replace={i: i_vec},
)

unnorm_logp = pt.logsumexp(unnorm_logp_died_at_tx_plus_i_vec, axis=0)
```
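The key trick is `vectorize_graph`, which rewrites a graph built on a scalar so that it operates on a vector instead. A minimal toy sketch of the pattern (not the BG/BB code itself):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Build a graph on a scalar input...
s = pt.scalar("s")
y = s**2 + 1.0

# ...then broadcast the whole graph over a vector by swapping the input.
s_vec = pt.vector("s_vec")
y_vec = vectorize_graph(y, replace={s: s_vec})

f = pytensor.function([s_vec], y_vec)
print(f(np.array([0.0, 1.0, 2.0])))  # [1. 2. 5.]
```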
I compared both approaches in a dev notebook, and the Scan-free version is about 3x faster:
```
# w/ Scan
267 ms ± 6.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# w/o Scan
85.2 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
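A timing comparison along these lines can be reproduced with something like the sketch below, where `scan_fn`, `no_scan_fn`, and `test_inputs` are placeholders for the two compiled logp functions and the benchmark data:

```python
import timeit

# scan_fn / no_scan_fn: placeholder names for the two compiled logp functions
# test_inputs: placeholder for the benchmark arrays
for name, fn in [("w/ Scan", scan_fn), ("w/o Scan", no_scan_fn)]:
    best = min(timeit.repeat(lambda: fn(*test_inputs), number=1, repeat=7))
    print(f"{name}: {best * 1e3:.1f} ms")
```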
However, the above code still requires modification: tests are failing on the returned logp values.
Scan may be plenty fast in other backends: numba and jax. The former will become the default sometime in the future, and it's what nutpie uses; jax is what numpyro and blackjax use. I would benchmark on those backends before bothering to get rid of it.
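Compiling the same graph under each backend is a one-liner in pytensor, so a sketch of such a benchmark could look like this (`inputs` and `logp_graph` stand in for the actual BG/BB logp inputs and output):

```python
import pytensor

# Placeholders: `inputs` is the list of input variables, `logp_graph` the output.
c_fn = pytensor.function(inputs, logp_graph)                    # default C backend
numba_fn = pytensor.function(inputs, logp_graph, mode="NUMBA")  # numba backend
jax_fn = pytensor.function(inputs, logp_graph, mode="JAX")      # jax backend
# Time each compiled function on representative data before deciding.
```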
Also, for varied datasets (t_x very different across subjects) the non-Scan version will probably be slower, since it does a lot of useless computation. The dense/non-Scan approach evaluates the worst-case scenario (the biggest gap between T and t_x) for every row, even if that worst case is only needed for 1 row out of 10,000.
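To make that concrete, here is a back-of-the-envelope example with made-up numbers, where one row with a large T - t_x gap forces the dense version to pad every other row to that worst case:

```python
import numpy as np

# Made-up recency data: 10,000 customers, one straggler with a huge gap.
T = np.full(10_000, 1_000)
t_x = np.full(10_000, 995)  # most rows: gap of 5
t_x[0] = 0                  # one row: gap of 1,000

dense_terms = T.size * (np.max(T - t_x) + 1)  # every row padded to the worst case
needed_terms = np.sum(T - t_x + 1)            # what a Scan-style loop would evaluate
print(dense_terms / needed_terms)             # ~164x more terms in the dense version
```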
Ok, thanks for the input! I took the PR because I always wanted to play with Scan, but we can close it and run other benchmarks. We can always come back and change it, since we already have the code in a branch.
@ricardoV94 do you have a time estimate on when numba will become the new default backend? I'm working on the BG/BB model right now, and currently NUTS is taking over an hour on my MacBook M2 Pro with a dataset of 11.2k rows.
You can select other backends manually; no need to wait for the default to change.
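For example, with pymc's `nuts_sampler` argument (a sketch, assuming an already-built `model` context):

```python
import pymc as pm

with model:  # `model` is your BG/BB model
    idata = pm.sample(nuts_sampler="nutpie")    # numba-based
    # or, for the JAX-based samplers:
    # idata = pm.sample(nuts_sampler="numpyro")
    # idata = pm.sample(nuts_sampler="blackjax")
```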
Rescuing key commits from https://github.com/pymc-labs/pymc-marketing/pull/707
- https://github.com/pymc-labs/pymc-marketing/pull/707/commits/605310f23004357dbe83eac9d83ef8562e5c0d23
- https://github.com/pymc-labs/pymc-marketing/pull/707/commits/da6568c846858723cc31ea931e51f3a6c0d7c375
- https://github.com/pymc-labs/pymc-marketing/pull/707/commits/92cd5ce342d87a36b83392050ea58eaf6ae69914