DistributionsAD.jl

MappedDistribution

Status: Open · mohdibntarek opened this issue 6 years ago · 8 comments

Sometimes it is useful to define a multivariate distribution over independent (but not identically distributed) variables by generating a distribution on the fly for each variable, with each variable's parameters chosen according to a certain rule/function. Defining an efficient logpdf and adjoint for such a distribution can give significant computational savings. This is similar to the Soss For combinator.
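Because the variables are independent, the joint log-density splits into one term per variable, and each θᵢ appears only in its own term. Writing p_{f(θᵢ)} for the density of the distribution f(θᵢ) (notation assumed here, not taken from the thread), the log-density and its gradient are:

```latex
\log p(x \mid \theta) = \sum_{i=1}^{n} \log p_{f(\theta_i)}(x_i),
\qquad
\frac{\partial}{\partial \theta_i} \log p(x \mid \theta)
  = \frac{\partial}{\partial \theta_i} \log p_{f(\theta_i)}(x_i)
```

So the adjoint for each θᵢ needs the derivative of a single scalar term only, which is the structure a specialized logpdf and adjoint can exploit.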

mohdibntarek, Nov 01 '19 22:11

Here's my current setup:

struct For{F,T,D,X} 
    f :: F  
    θ :: T
end

where...

  • F and T are the types of the fields f and θ
  • D is the distribution returned
  • X is the eltype of that distribution (unfortunately, often not available directly from D)

Some example use cases:

# T = NTuple{2,Int}
x ~ For(10,3) do i,j
    Bernoulli(j/(i+j))  # j/i would exceed 1 whenever j > i; j/(i+j) stays in (0, 1)
end
# T = Base.Generator{Base.OneTo{Int64},Base.var"#174#175"{Array{Float64,2}}}
y ~ For(eachrow(X)) do xrow
    Normal(xrow' * β, 1)
end

We'll have different methods for rand, logpdf, etc., dispatching mostly on T.
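A minimal sketch of dispatching logpdf on T, assuming a stripped-down stand-in for For with only the fields f and θ (the D and X parameters are dropped for brevity; this is not Soss's actual code). When θ is a tuple of dimensions, f receives the indices splatted; otherwise f receives each element of θ. The Bernoulli rule 1/(i+j) and the small regression data in the usage lines are hypothetical, chosen only so the example runs.

```julia
using Distributions

# Hypothetical minimal stand-in for Soss's For (D and X omitted for brevity)
struct For{F,T}
    f::F
    θ::T
end

# T is a tuple of dimensions such as (10, 3): f receives the indices splatted
function Distributions.logpdf(d::For{F,NTuple{N,Int}}, x::AbstractArray{<:Any,N}) where {F,N}
    size(x) == d.θ || throw(DimensionMismatch("x must have size $(d.θ)"))
    return sum(I -> logpdf(d.f(Tuple(I)...), x[I]), CartesianIndices(d.θ))
end

# T is any other iterable (a vector, eachrow(X), ...): f receives each element
# (zip stops at the shorter input, so lengths are not checked here)
function Distributions.logpdf(d::For, x::AbstractVector)
    return sum(logpdf(d.f(θi), xi) for (θi, xi) in zip(d.θ, x))
end

# Usage: one model per kind of T
d1 = For((i, j) -> Bernoulli(1 / (i + j)), (10, 3))
x1 = falses(10, 3)

X = [1.0 2.0; 3.0 4.0]
β = [0.5, -0.5]
d2 = For(xrow -> Normal(xrow' * β, 1), eachrow(X))
y = [0.0, 1.0]

logpdf(d1, x1), logpdf(d2, y)
```

Keeping the index-based and element-based cases in separate methods lets each one use the natural iteration for its T without runtime branching.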

Also, I currently have the following restrictions:

  1. D is consistent across indices
  2. support(d::D) is consistent across indices

Currently this targets "array-like" results, but in principle T can be anything, for example an iterator or Real (for function spaces, GPs, etc).

cscherrer, Nov 01 '19 22:11

I don't think we need a restriction on D being the same. The logpdf can be something like this:

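using Distributions, Random

function Distributions.logpdf(dist::For, x::AbstractArray)
    @assert size(dist.θ) == size(x)
    # Sum the per-element log-densities; f may build a different distribution per index
    return sum(eachindex(dist.θ, x)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end
end
# One draw per element: build each distribution from θ, then sample it
Random.rand(dist::For) = rand.(dist.f.(dist.θ))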

Whether or not f returns the same distribution type at every index, the result type should be inferable by the Julia compiler.
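A small sketch of the generic logpdf with an f that returns different distribution types at different indices. The For stand-in and the rule `mixed` (Normal at odd indices, Laplace at even ones) are hypothetical. Julia should infer `mixed`'s return type as a small Union of the two concrete types, which the compiler can usually union-split.

```julia
using Distributions

# Hypothetical minimal stand-in for For (D and X omitted)
struct For{F,T}
    f::F
    θ::T
end

# Generic logpdf: sum of per-element log-densities, no restriction on D
function Distributions.logpdf(d::For, x::AbstractArray)
    size(d.θ) == size(x) || throw(DimensionMismatch("θ and x must have the same size"))
    return sum(i -> logpdf(d.f(d.θ[i]), x[i]), eachindex(d.θ, x))
end

# Hypothetical rule: Normal at odd indices, Laplace at even ones
mixed(i) = isodd(i) ? Normal(0.0, i) : Laplace(0.0, i)

d = For(mixed, 1:4)
x = [0.5, -1.0, 2.0, 0.0]
lp = logpdf(d, x)
```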

mohdibntarek, Nov 01 '19 23:11

Base.eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

mohdibntarek, Nov 01 '19 23:11

Note that the above is a dynamically sized distribution. We can also get free specialization and inlining for small, fixed-size distributions when using θ::StaticArray.
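A minimal sketch of the fixed-size case, assuming StaticArrays.jl and per-element Bernoulli distributions (the values are hypothetical). Because the length is part of the SVector type, broadcasting over θ produces another SVector, and the compiler can specialize the whole computation for that size.

```julia
using Distributions, StaticArrays

# Fixed-size parameters and observations: the length 3 is part of the type
θ = SVector(0.1, 0.5, 0.9)
x = SVector(true, false, true)

dists = Bernoulli.(θ)          # an SVector of three Bernoulli{Float64}
lp = sum(logpdf.(dists, x))    # log(0.1) + log(1 - 0.5) + log(0.9)
```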

mohdibntarek, Nov 01 '19 23:11

I think that, for Tracker,

sum(logpdf.(dist.f.(dist.θ), x))

will be faster than

sum(1:length(dist.θ)) do i
    logpdf(dist.f(dist.θ[i]), x[i])
end

This is because, if either θ or x is a TrackedArray, the broadcast version keeps all intermediates as TrackedArrays rather than individual TrackedReals.
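A quick sketch, without Tracker loaded, showing that the two forms compute the same value; the difference matters only for how a reverse-mode AD package represents the intermediates. The per-element constructor `f` and the data are hypothetical.

```julia
using Distributions

f(μ) = Normal(μ, 1.0)   # hypothetical per-element distribution constructor
θ = [0.0, 1.0, 2.0]
x = [0.3, 0.7, 2.5]

# Broadcast form: array-level operations (per the comment above, these stay TrackedArrays under Tracker)
lp_broadcast = sum(logpdf.(f.(θ), x))

# Scalar-loop form: one scalar term per index (TrackedReals under Tracker)
lp_loop = sum(i -> logpdf(f(θ[i]), x[i]), eachindex(θ, x))
```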

mohdibntarek, Nov 01 '19 23:11

I don't think we need a restriction on D being the same.

The most obvious reason for this is type stability, though there may be ways around that. In addition, the vast majority of models will satisfy this anyway, and it often opens up opportunities for optimization. For example, in cases where d.f maps to continuous distributions, how can we determine the bijection to ℝⁿ? Parameterizing by D makes this trivial.
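A minimal sketch of why knowing D helps, assuming hypothetical helpers `to_real` and `to_real_all` (not part of Soss or Distributions.jl). The unconstraining bijection is picked by dispatch on the distribution type alone, so when every element shares one D, a single choice covers the whole array.

```julia
using Distributions

# Hypothetical helpers: choose a bijection to ℝ from the distribution *type* only
to_real(::Type{<:Gamma}, x) = log(x)           # (0, ∞) → ℝ
to_real(::Type{<:Beta}, x) = log(x / (1 - x))  # (0, 1) → ℝ (logit)
to_real(::Type{<:Normal}, x) = x               # ℝ → ℝ (identity)

# With one shared D (e.g. a type parameter of For), one dispatch serves every element;
# types broadcast as scalars, so D is passed unchanged to each call
to_real_all(::Type{D}, xs::AbstractArray) where {D<:Distribution} = to_real.(D, xs)
```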

One thing I've found a bit tricky is making useful type information available without much computational cost. Unfortunately, in Julia we can't just ask a function about its codomain, so instantiating a For requires some computation to determine the types. So far, I've been keeping construction cheap by assuming that distributions and supports are consistent, and computing them for a single index at construction time. Your eltype suggestion,

Base.eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

is appealing, but would require O(n) instantiation cost.
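A minimal sketch of the O(1) alternative, assuming a hypothetical `TypedFor` that stores D and X as type parameters and relies on the consistency restrictions listed earlier. Only the first index is instantiated; X is taken from a single draw, which is itself an assumption (it sidesteps eltype(D) being unreliable, at the cost of one sample). After construction, eltype is read straight from the type with no pass over θ.

```julia
using Distributions

# Hypothetical variant of For with D and X as type parameters
struct TypedFor{F,T,D,X}
    f::F
    θ::T
end

# O(1) construction: instantiate only the first index and *assume*
# every other index yields the same distribution type and sample type
function TypedFor(f::F, θ::T) where {F,T}
    d1 = f(first(θ))
    D = typeof(d1)
    X = typeof(rand(d1))   # one draw to learn the sample type (assumption)
    return TypedFor{F,T,D,X}(f, θ)
end

# eltype comes from the type parameters, with no pass over θ
Base.eltype(::TypedFor{F,T,D,X}) where {F,T,D,X} = X
disttype(::TypedFor{F,T,D,X}) where {F,T,D,X} = D   # hypothetical helper

d = TypedFor(i -> Normal(i, 1.0), 1:5)
b = TypedFor(p -> Bernoulli(p), [0.2, 0.8])
```

The trade-off is the one described above: construction stays cheap, but nothing checks that later indices actually agree with the first.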

Above I suggested For might also be useful for building distributions over function spaces. I may disagree with myself on this point, because it drops the conditional independence assumption of other For instances, and would require adding some way to specify covariance.

Finally, we had some recent discussion on Discourse about the best approach for parallelism, which will be important for many cases.

cscherrer, Nov 02 '19 16:11

Cleaning this up a bit in Soss, here's the current state: https://github.com/cscherrer/Soss.jl/blob/dev/src/for.jl Should be able to get a PR submitted today.

There's also iid, which is like For but without the distributional dependence on indices. I have a curried form, which I usually use like this:

x ~ Normal() |> iid(N)
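A minimal sketch of how a curried iid can support the `Normal() |> iid(N)` form, assuming a hypothetical `IID` struct; this is an illustration, not Soss's actual definition. `iid(n)` returns a function of the base distribution, which is exactly what `|>` expects.

```julia
using Distributions
import Random

# Hypothetical stand-in for Soss's iid (not its actual implementation)
struct IID{D}
    dist::D
    n::Int
end

# Curried constructor: iid(n) returns a function, so `Normal() |> iid(4)` works
iid(n::Int) = d -> IID(d, n)

# Every element shares the same distribution, so logpdf is a plain sum
function Distributions.logpdf(d::IID, x::AbstractVector)
    length(x) == d.n || throw(DimensionMismatch("expected $(d.n) values"))
    return sum(xi -> logpdf(d.dist, xi), x)
end

Random.rand(d::IID) = rand(d.dist, d.n)

d = Normal() |> iid(4)
```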

cscherrer, Nov 02 '19 20:11

Thanks for the PR @cscherrer and sorry for the late review; I was busy the last few weeks. I will review your PR asap.

mohdibntarek, Nov 20 '19 01:11