
Refactor `logp` in BG/BB to remove Scan

Open · ColtAllen opened this issue 1 year ago · 5 comments

The `logp` of the `BetaGeoBetaBinom` distribution contains an iteration that is currently implemented with a PyTensor `Scan`. It's possible to refactor it so that `Scan` is no longer needed:

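```python
import numpy as np
import pytensor.tensor as pt
from pymc.distributions.dist_math import betaln
from pytensor.graph.replace import vectorize_graph

# Symbolic inputs: one entry per customer for t_x, x and T; scalar parameters
t_x, x, T = pt.vector("t_x"), pt.vector("x"), pt.vector("T")
alpha, beta, gamma, delta = (pt.scalar(n) for n in ("alpha", "beta", "gamma", "delta"))

# Core graph for a single dropout offset i (customer dropped out at t_x + i)
i = pt.scalar("i", dtype="int64")
died = pt.lt(t_x + i, T)  # False only for the "still alive at T" term (t_x + i == T)

unnorm_logp_died_at_tx_plus_i = pt.where(
    # Keep only offsets with t_x + i <= T, i.e. i <= T - t_x for each customer
    pt.le(t_x + i, T),
    (
        betaln(alpha + x, beta + t_x - x + i)
        + betaln(gamma + died, delta + t_x + i)
    ),
    -np.inf,
)

# Maximum prevents invalid (negative) T - t_x values from crashing logp
max_range = pt.maximum(pt.max(T - t_x), 0).astype("int64")
i_vec = pt.arange(max_range + 1)

# Swap the scalar i for the vector of offsets: result has shape (len(i_vec), n_customers)
unnorm_logp_died_at_tx_plus_i_vec = vectorize_graph(
    unnorm_logp_died_at_tx_plus_i,
    replace={i: i_vec},
)

# Sum over offsets in log space, giving one value per customer
unnorm_logp = pt.logsumexp(unnorm_logp_died_at_tx_plus_i_vec, axis=0)
```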
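For reference, this is the per-customer quantity the snippet is meant to compute, read off the code (the `unnorm_` prefix suggests the normalizing terms, presumably ln B(α, β) and ln B(γ, δ), are subtracted elsewhere). Offsets i < T - t_x correspond to dropping out at t_x + i (d_i = 1), and i = T - t_x to still being alive at T (d_i = 0):

```math
\text{unnorm\_logp} = \log \sum_{i=0}^{T - t_x} \exp\Big[\ln B(\alpha + x,\ \beta + t_x - x + i) + \ln B(\gamma + d_i,\ \delta + t_x + i)\Big],
\qquad d_i = \mathbb{1}\left[t_x + i < T\right]
```

The Scan-free version evaluates this sum for every i up to max(T - t_x) across all customers, so the mask must discard, per customer, every i > T - t_x.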

I compared both approaches in a dev notebook, and the version without Scan is about 3x faster:

```
# w/ Scan
267 ms ± 6.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# w/o Scan
85.2 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

However, the code above required modification: tests were failing on the logp values it returned.
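One way to track down wrong logp values is to check the dense computation against a per-customer loop, which is the role Scan plays in the current implementation. The sketch below is a NumPy stand-in for the PyTensor graph, under stated assumptions: `betaln` and `logsumexp` are re-implemented with `math.lgamma` so only NumPy is needed, and the function names and the `buggy_mask` flag are hypothetical. It also illustrates one plausible cause of the failures: a mask of the form `t_x >= i` (as in `pt.ge(t_x, i)`) compares the offset with t_x instead of with each customer's own gap T - t_x, so it drops valid terms for some customers and keeps invalid ones for others.

```python
import math

import numpy as np

# betaln(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b); math.lgamma vectorized (no SciPy needed)
_lgamma = np.vectorize(math.lgamma, otypes=[float])


def betaln(a, b):
    return _lgamma(a) + _lgamma(b) - _lgamma(np.add(a, b))


def logsumexp(a, axis=None):
    # Numerically stable log(sum(exp(a))); assumes at least one finite entry per slice
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def unnorm_logp_dense(t_x, x, T, alpha, beta, gamma, delta, buggy_mask=False):
    """Scan-free version (hypothetical name): evaluate every offset up to max(T - t_x), then mask."""
    t_x, x, T = (np.asarray(v, dtype=float) for v in (t_x, x, T))
    max_range = int(max(np.max(T - t_x), 0))
    i = np.arange(max_range + 1)[:, None]  # shape (K, 1), broadcasts to (K, n_customers)
    died = (t_x + i < T).astype(float)
    terms = betaln(alpha + x, beta + t_x - x + i) + betaln(gamma + died, delta + t_x + i)
    keep = (t_x >= i) if buggy_mask else (t_x + i <= T)
    return logsumexp(np.where(keep, terms, -np.inf), axis=0)


def unnorm_logp_loop(t_x, x, T, alpha, beta, gamma, delta):
    """Reference (hypothetical name): one customer at a time, only that customer's offsets."""
    out = []
    for tx_j, x_j, T_j in zip(t_x, x, T):
        terms = [
            betaln(alpha + x_j, beta + tx_j - x_j + i)
            + betaln(gamma + float(tx_j + i < T_j), delta + tx_j + i)
            for i in range(int(T_j - tx_j) + 1)
        ]
        out.append(logsumexp(np.array(terms, dtype=float)))
    return np.array(out)


# Arbitrary illustrative data (x <= t_x <= T) and parameter values
t_x, x, T = np.array([2, 5, 0]), np.array([1, 3, 0]), np.array([6, 5, 3])
params = dict(alpha=1.2, beta=0.8, gamma=0.6, delta=2.5)
reference = unnorm_logp_loop(t_x, x, T, **params)
print(np.allclose(unnorm_logp_dense(t_x, x, T, **params), reference))  # True
print(np.allclose(unnorm_logp_dense(t_x, x, T, **params, buggy_mask=True), reference))  # False
```

The per-customer loop only touches offsets that exist for that customer, which makes it a convenient oracle when writing tests for the vectorized graph.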

ColtAllen commented on May 26 '24

Scan may be plenty fast in other backends, such as Numba and JAX. Numba will become the default sometime in the future, and it's what nutpie uses; JAX is used by numpyro and blackjax. I would benchmark on those backends before bothering to get rid of Scan.

Also, for heterogeneous datasets (t_x very different across subjects), the non-Scan version will probably be slower because it does a lot of useless computation: the dense, non-Scan approach evaluates the worst-case scenario (the biggest gap between T and t_x) for everyone, even if it's only needed for 1 row out of 10,000.
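A back-of-the-envelope count makes the point concrete. The sketch below assumes a hypothetical dataset of 10,000 customers where only one has T - t_x = 100 and everyone else has T - t_x = 0; the helper names are made up for illustration. It counts only the per-offset likelihood terms and ignores Scan's per-iteration overhead, which is what made the dense version faster in the benchmark above.

```python
def dense_terms(gaps):
    """Terms evaluated by the Scan-free version: every customer pays for the largest gap."""
    return len(gaps) * (max(max(gaps), 0) + 1)


def ragged_terms(gaps):
    """Terms evaluated by a per-customer loop such as Scan: each customer pays for its own gap."""
    return sum(max(g, 0) + 1 for g in gaps)


# Hypothetical gaps T - t_x: one outlier among 10,000 customers
gaps = [0] * 9_999 + [100]
print(dense_terms(gaps))   # 1010000
print(ragged_terms(gaps))  # 10100
```

Here the dense version evaluates 100x more terms; when all gaps are similar the two counts coincide, which is the case where dropping Scan pays off most.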

ricardoV94 commented on May 31 '24

OK, thanks for the input! I took the PR because I always want to play with Scan, but we can close it and run other benchmarks. We can always come back and change it, since we already have the code in a branch.

juanitorduz commented on May 31 '24

@ricardoV94 do you have a time estimate for when Numba will become the new default backend? I'm working on the BG/BB model right now, and NUTS currently takes over an hour on my MacBook M2 Pro with a dataset of 11.2k rows.

ColtAllen commented on Jul 22 '24

You can select other backends manually; you don't need to wait for the default to change.
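A minimal sketch of selecting a backend manually, assuming a recent PyMC: `pm.sample` takes a `nuts_sampler` argument (`"pymc"`, `"nutpie"`, `"numpyro"`, `"blackjax"`), where nutpie compiles the model with Numba by default and numpyro/blackjax go through JAX; a standalone PyTensor graph can be compiled with `pytensor.function(..., mode="NUMBA")` or `mode="JAX"`. The helpers `fit` and `compile_logp` are hypothetical wrappers, not part of pymc-marketing.

```python
SUPPORTED_NUTS_SAMPLERS = {"pymc", "nutpie", "numpyro", "blackjax"}


def fit(model, nuts_sampler="nutpie", **sample_kwargs):
    """Sample `model` with a non-default NUTS implementation (hypothetical helper).

    "nutpie" compiles the model through Numba; "numpyro" and "blackjax" go through JAX.
    The corresponding package must be installed separately.
    """
    if nuts_sampler not in SUPPORTED_NUTS_SAMPLERS:
        raise ValueError(f"unknown nuts_sampler: {nuts_sampler!r}")
    import pymc as pm  # imported lazily so this module can be loaded without PyMC

    with model:
        return pm.sample(nuts_sampler=nuts_sampler, **sample_kwargs)


def compile_logp(inputs, output, mode="NUMBA"):
    """Compile a standalone PyTensor graph with a chosen backend (hypothetical helper)."""
    import pytensor  # lazy import, as above

    return pytensor.function(inputs, output, mode=mode)
```

For example, `fit(model, nuts_sampler="numpyro", draws=1000)` samples through JAX, and compiling the Scan and non-Scan `unnorm_logp` graphs with `compile_logp(..., mode="NUMBA")` and `mode="JAX"` would give the per-backend benchmark suggested earlier in the thread.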

ricardoV94 commented on Jul 22 '24

Rescuing key commits from https://github.com/pymc-labs/pymc-marketing/pull/707

  • https://github.com/pymc-labs/pymc-marketing/pull/707/commits/605310f23004357dbe83eac9d83ef8562e5c0d23
  • https://github.com/pymc-labs/pymc-marketing/pull/707/commits/da6568c846858723cc31ea931e51f3a6c0d7c375
  • https://github.com/pymc-labs/pymc-marketing/pull/707/commits/92cd5ce342d87a36b83392050ea58eaf6ae69914

juanitorduz commented on Sep 14 '24