DistributionsAD.jl

Faster `arraydist` with LazyArrays.jl

Open torfjelde opened this issue 2 years ago • 4 comments

This PR is basically an accumulation of the discussion in https://github.com/TuringLang/Turing.jl/issues/1934 and https://github.com/TuringLang/DistributionsAD.jl/pull/230.

It's a hack to make reverse-mode AD packages that use ForwardDiff for broadcasting much faster when used in combination with LazyArrays.jl.

Unfortunately, this requires a rather ugly hack in the form of make_closure (maybe there's a more elegant solution? @devmotion pls halp!), but it does buy us a whole lot of runtime.

torfjelde • Jan 16 '23 19:01

> I had a rather quick look. There are many things that should not be here (but maybe some of them could be moved or fixed somewhere else). Specializations for LazyArrays could maybe be added to Distributions as weak dependencies (or maybe we just fix the use of Broadcasted). Nevertheless it might be OK for some time to have fixes here, but we should make sure to test everything.

I 100% agree with you, but realistically tackling these issues is going to take a lot of time and effort. In the meantime, I think this hacky approach will have to do 😕 This solves so many Slack and Discourse threads of people going "why is this simple model so slow in Turing.jl?"...

torfjelde • Jan 19 '23 00:01

I was wondering what is stopping this PR from being merged? I've had this page in my open tabs for a few weeks now, anxiously awaiting the purple merged badge 😄

I was hoping to re-do my logistic regression benchmark with this PR once it's merged, but I jumped the gun here.

In short, the benefits are incredible and the best part is that it would be easy even for new users (it just needs an update in the tutorials).

svilupp • Feb 06 '23 09:02

> I was wondering what is stopping this PR from being merged?
>
> I've had this page in my open tabs for a few weeks now, anxiously awaiting the purple merged badge 😄

The approach taken here is quite hacky, and, tbh, it's unfortunate that we even have to do this. It really just works around type-inference issues that arise when calling a UnionAll.

So if we accept, we have to be on high alert in case something breaks (due to its hackiness). Ideally this should be fixed elsewhere, so maybe we should spend some more time thinking about whether we can address this in a broader manner.
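
A minimal sketch of the kind of inference problem meant here, using only Base Julia. `Box`, `Naive`, and `Precise` are hypothetical stand-ins and not names from DistributionsAD.jl or this PR: `Box` plays the role of a parametric distribution type, and the two wrappers differ only in how they record the type of the stored callable. It illustrates the general mechanism under those assumptions, not the code path the AD packages actually take.

```julia
# Hypothetical stand-in for a parametric (UnionAll) type such as a distribution.
struct Box{T}
    x::T
end

# Stores the callable with the default type parameter, i.e. typeof(f).
# For a type argument like `Box`, that parameter is just `UnionAll`.
struct Naive{F}
    f::F
end
(w::Naive)(x) = w.f(x)

# Same idea, but the parameter is computed with Core.Typeof, which gives
# Type{Box} for a type argument and so keeps the identity of `Box`.
struct Precise{F}
    f::F
    Precise(f) = new{Core.Typeof(f)}(f)
end
(w::Precise)(x) = w.f(x)

naive_type = typeof(Naive(Box))      # Naive{UnionAll}
precise_type = typeof(Precise(Box))  # Precise{Type{Box}}

# Inferred return types when each wrapper is called with a Float64.
naive_ret = first(Base.return_types(Naive(Box), (Float64,)))      # Any
precise_ret = first(Base.return_types(Precise(Box), (Float64,)))  # Box{Float64}
```

In the `Naive` case the compiler only knows that the field holds "some UnionAll", so the call cannot be inferred; in the `Precise` case the field type pins down which type is being constructed.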

> I was hoping to re-do my logistic regression benchmark with this PR once it's merged, but I jumped the gun here.
>
> In short, the benefits are incredible and the best part is that it would be easy even for new users (it just needs an update in the tutorials).

Glad to see it's working though! :)

torfjelde • Feb 14 '23 10:02

Since this PR is unlikely to be merged, do we have tips/snippets that intermediate users could opt into in their code?

How would I recognize that the UnionAll is slowing me down (besides the slowness itself)? I can see that you identified it with this call by running @code_warntype f.makeargs.f and looking for Body::Any.
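
One way to check a suspicious callable without a debugger is to let Julia compare the inferred and the actual return type. The sketch below uses `Test.@inferred` on hypothetical stand-ins (`Box`, `Stored`, and `passes_inferred` are illustration names, not part of Turing.jl or DistributionsAD.jl). It does not say where in a real model the offending object lives, so that still has to be located first; for the failing case here, `@code_warntype` reports `Body::Any` for the same call.

```julia
using Test  # provides @inferred

# Hypothetical stand-in for a parametric distribution type.
struct Box{T}
    x::T
end

# Hypothetical container that stores a callable in a field and calls it later.
struct Stored{F}
    f::F
end
(s::Stored)(x) = s.f(x)

# Returns true if the call passes @inferred, false if @inferred throws
# because the inferred return type does not match the actual one.
function passes_inferred(s, x)
    try
        @inferred s(x)
        return true
    catch err
        err isa ErrorException || rethrow()
        return false
    end
end

ctor_ok = passes_inferred(Stored(Box), 1.0)              # constructor stored as UnionAll
closure_ok = passes_inferred(Stored(x -> Box(x)), 1.0)   # constructor hidden in a closure
```

Under these assumptions `ctor_ok` is `false` and `closure_ok` is `true`.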

However, it's not clear to me "where" to break into. How would I set a breakpoint for it with Debugger?

using Debugger, Turing
# as per your example
@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(theta[p] - beta[i]), y))

    return (; theta, beta)
end
model = irt(y, i, p)
# what is the `some_func` at which we should break to see the context
@bp some_func
@run chn = sample(model, NUTS(), 100)

In the absence of this PR, what's the best way to keep AD from taking the slow path? My understanding of your thread was that the slowdown comes from broadcasting over structs, so we want the compiler to remove them.

E.g., define a wrapper:

BernoulliLogitF(x) = BernoulliLogit(x)

# to be used like
Turing.@addlogprob! sum(logpdf.(BernoulliLogitF.(theta[p] - beta[i]), y))

# instead of this in Turing.@addlogprob! 
Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(theta[p] - beta[i]), y))

and then check, as described above, whether you get Body::Any or a concrete type.

Or would you go about it differently?

EDIT: I found that a better/simpler workaround is to use BernoulliLogitF(x, c) = BernoulliLogit(x, c) as per the above
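
The wrapper idea can at least be checked on a toy example. In the sketch below, `Box`, `BoxF`, and `Stored` are hypothetical stand-ins (`BoxF` mirrors the role of BernoulliLogitF); whether the same improvement carries over to a given AD package's broadcasting path is an assumption that would need to be verified on the real model.

```julia
# Hypothetical stand-in for a parametric distribution type.
struct Box{T}
    x::T
end

# Plain function wrapper around the constructor, analogous to BernoulliLogitF.
BoxF(x) = Box(x)

# Hypothetical container that stores a callable in a field and calls it later.
struct Stored{F}
    f::F
end
(s::Stored)(x) = s.f(x)

# Type parameter recorded for the stored callable in each case.
ctor_param = typeof(Stored(Box)).parameters[1]   # UnionAll: identity of Box is lost
func_param = typeof(Stored(BoxF)).parameters[1]  # typeof(BoxF): a unique function type

# Inferred return types when each container is called with a Float64.
ctor_ret = first(Base.return_types(Stored(Box), (Float64,)))   # Any
func_ret = first(Base.return_types(Stored(BoxF), (Float64,)))  # Box{Float64}
```

The function wrapper helps here because every Julia function has its own singleton type, so storing `BoxF` in a field still tells the compiler exactly what will be called.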

svilupp • Mar 09 '23 09:03

Ideally, this should be fixed at the autodiff level, e.g. by Tapir and Enzyme.

yebai • Apr 17 '24 14:04