
`For` transports


I got this working, sort of:

julia> d = For(j -> Normal(j, 2.0), 1:3)
For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))

julia> test_transport(d, Normal() ^ 3)
Test Summary:                                                                                             | Pass  Total  Time
transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,)) |    8      8  0.0s
DefaultTestSet("transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))", Any[], 8, false, false, true, 1.66725e9, 1.66725e9)

To do this, I added for_constructor. It's like For, but a little smarter: when the marginals all turn out to be instances of the same singleton type, it collapses to a power measure:

for_constructor(f, x) = for_constructor(f, (x,))

# Infer the return type of f on the index eltypes. If it's a singleton type,
# every marginal is the same instance and we can collapse to a power measure.
@generated function for_constructor(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}
    quote
        T = Core.Compiler.return_type(f, $eltypes)
        _for(T, f, inds, static(Base.issingletontype(T)))
    end
end

# Singleton return type: build a power measure from the single instance.
function _for(::Type{T}, f::F, inds::I, ::True) where {T,F,I}
    instance(T) ^ size(first(inds))
end

# Otherwise build a genuine For measure.
function _for(::Type{T}, f::F, inds::I, ::False) where {T,F,I}
    For{T,F,I}(f, inds)
end
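
For intuition, here's roughly how I'd expect for_constructor to behave (a sketch; the constant-function case below is hypothetical and untested):

# Distinct marginals: the return type of f is not a singleton, so we get a genuine For
d1 = for_constructor(j -> Normal(j, 2.0), 1:3)

# The return type StdNormal is a singleton, so this should collapse to
# StdNormal() ^ (3,), i.e. a power measure, instead of building a For
d2 = for_constructor(_ -> StdNormal(), 1:3)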

Then we just need the standard stuff:

function MeasureBase.transport_origin(d::AbstractProductMeasure)
    for_constructor(MeasureBase.transport_origin, marginals(d))
end

function MeasureBase.to_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.to_origin, marginals(d), x)
end

function MeasureBase.from_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.from_origin, marginals(d), x)
end
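
With those definitions, the end-to-end usage I'm after looks something like this (a sketch of the intended behavior, not a verified run):

d = For(j -> Normal(j, 2.0), 1:3)

# transport_to returns a function mapping points from the support of its
# second argument to the support of its first, routing through the shared
# standard origin found by transport_origin.
f = transport_to(d, Normal() ^ 3)
x = rand(Normal() ^ 3)
y = f(x)    # a point in the support of d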

Well, almost. There's also this bug:

julia> MeasureBase._origin_depth(Normal() ^ 3)
ERROR: MethodError: no method matching ^(::MeasureBase.NoTransportOrigin{StdNormal}, ::Tuple{Int64})
Closest candidates are:
  ^(::AbstractMeasure, ::Tuple) at ~/git/MeasureBase.jl/src/combinators/power.jl:55
  ^(::AbstractMeasure, ::Any) at ~/git/MeasureBase.jl/src/combinators/power.jl:56
Stacktrace:
 [1] _for(#unused#::Type{MeasureBase.NoTransportOrigin{StdNormal}}, f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}}, #unused#::Static.True)
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:37
 [2] macro expansion
   @ ~/git/MeasureTheory.jl/src/combinators/for.jl:32 [inlined]
 [3] for_constructor(f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:28
 [4] for_constructor(f::Function, x::FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:26
 [5] transport_origin(d::PowerMeasure{StdNormal, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:305
 [6] _origin_depth(ν::PowerMeasure{Normal{(), Tuple{}}, Tuple{Base.OneTo{Int64}}})
   @ MeasureBase ~/git/MeasureBase.jl/src/transport.jl:130
 [7] top-level scope
   @ REPL[60]:1

We end up taking a power of a NoTransportOrigin, which makes no sense. As a quick fix, I temporarily changed MeasureBase._origin_depth to

@inline function _origin_depth(ν::NU) where {NU}
    ν_0 = ν
    Base.Cartesian.@nexprs 10 i -> begin  # 10 is just some "big enough" number
        ν_{i} = transport_origin(ν_{i - 1})
        if ν_{i} isa PowerMeasure
            # unwrap the power measure and keep descending through its parent
            ν_{i} = ν_{i}.parent
        elseif ν_{i} isa NoTransportOrigin
            return static(i - 1)
        end
    end
    return static(10)
end
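
For the failing example above, my mental trace of the patched loop (not a verified run) is:

# ν_0 = Normal() ^ 3
# i = 1: ν_1 = transport_origin(ν_0)           # StdNormal() ^ (3,) via for_constructor
#        ν_1 isa PowerMeasure, so unwrap:        ν_1 = StdNormal()
# i = 2: ν_2 = transport_origin(StdNormal())    # NoTransportOrigin, so return static(1)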

This last part feels kind of hacky. There's also the problem that map forces an allocation; it would be nice to use mappedarray instead, but that doesn't infer properly. Maybe a modified version of it could?
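
Something along these lines is the direction I have in mind (just a sketch using MappedArrays; the method signature is hypothetical, and as noted the inference problem is the real blocker):

using MappedArrays: mappedarray

# Lazily map to_origin over array-shaped marginals instead of materializing a
# new array with map; the same idea would apply to from_origin.
function MeasureBase.to_origin(d::AbstractProductMeasure, x::AbstractArray)
    mappedarray(MeasureBase.to_origin, marginals(d), x)
end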

Also, it seems like a problem if we have a product with different "origin depths". A fixpoint approach would handle this, but I think the current approach will break. Any ideas for this, @oschulz?

cscherrer · Oct 31 '22

a problem if we have a product with different "origin depths"

Well, if it's a tuple-based product, the transport for each marginal should generate separate code and everything should infer. And if it's array-based and the marginals have different depths, then they also have different types, so type inference is probably hopeless anyway, right?
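
For example (hypothetical, untested):

# A tuple-based product with heterogeneous marginals: the tuple type records
# each marginal's type, so the transport code can specialize per component.
d = productmeasure((Normal(1.0, 2.0), StdUniform()))
f = transport_to(StdNormal() ^ 2, d)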

oschulz · Oct 31 '22