MeasureBase.jl

Hitting `NoTransportOrigin` from weighted measures

Open · cscherrer opened this issue 2 years ago · 3 comments

The "transport origin" of a `WeightedMeasure` needs to be weighted. I had thought this might be easy:

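```julia
# Proposed: the origin of a weighted measure is the weighted origin of its base
transport_origin(ν::WeightedMeasure) = weightedmeasure(ν.logweight, transport_origin(ν.base))

# The weight does not change how points are mapped, so defer to the base measure
to_origin(w::WeightedMeasure, y) = to_origin(w.base, y)
from_origin(w::WeightedMeasure, x) = from_origin(w.base, x)
```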

But this leads to:

```julia
julia> MeasureBase.transport_origin(2.2 * StdNormal())
2.2 * MeasureBase.NoTransportOrigin{StdNormal}()
```
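To see why dispatch produces that result, here is a minimal sketch that uses hypothetical mock types (`Measure`, `StdNormalMock`, `NoOrigin`, `Weighted`, `origin_of`) in place of MeasureBase's real ones. It assumes only what the REPL output above shows: a measure with no origin yields a sentinel, and the naive weighted rule wraps whatever comes back.

```julia
# Hypothetical mock types standing in for MeasureBase's; not the real definitions.
abstract type Measure end

struct StdNormalMock <: Measure end

# Sentinel meaning "this measure has no transport origin"
# (plays the role of MeasureBase.NoTransportOrigin{M}).
struct NoOrigin{M} end

# A measure scaled by exp(logweight); the field type is unconstrained,
# so nothing stops it from wrapping the sentinel.
struct Weighted{M} <: Measure
    logweight::Float64
    base::M
end

# Fallback: a measure without a specialized method has no origin.
origin_of(μ::Measure) = NoOrigin{typeof(μ)}()

# The naive rule from the issue: weight whatever the base's origin is.
origin_of(ν::Weighted) = Weighted(ν.logweight, origin_of(ν.base))

o = origin_of(Weighted(log(2.2), StdNormalMock()))
# `o` is a Weighted wrapping NoOrigin{StdNormalMock}(), mirroring
# `2.2 * MeasureBase.NoTransportOrigin{StdNormal}()` in the REPL output.
```

Nothing in the naive method inspects what the base returned, so the sentinel is treated like an ordinary measure and gets weighted.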

I understand that this is just how dispatch is set up for transports, but it's bitten me a few times, and I think it will be at least as confusing for new users as it has been for me. I think we can keep the current high-level functionality but change the dispatch patterns under the hood so it works more like `basemeasure`. That would also help new devs, because getting used to one (basemeasures or transports) will make the other easier.

@oschulz I could try to prototype this, but I know you're also working on improving type stability for transports, and I don't want us to collide more than we have to :)

cscherrer · Sep 29 '22 13:09

> but change the dispatch patterns under the hood so it works more like basemeasure

The thing is, transport needs both `transport_def` and `transport_origin`; it's not quite the same as `basemeasure`. But for the case above, maybe we can do this:

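```julia
transport_origin(ν::WeightedMeasure) = _weighted_origin(ν.logweight, transport_origin(ν.base))

# Weight the base's origin if it has one...
_weighted_origin(logweight, m) = weightedmeasure(logweight, m)
# ...but pass the NoTransportOrigin sentinel through unweighted
_weighted_origin(logweight, no_origin::NoTransportOrigin) = no_origin
```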
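A sketch of how that dispatch trick behaves, again on hypothetical mock types rather than MeasureBase itself (`ShiftedNormal` is an invented measure that does have an origin). The sentinel short-circuits through `_weighted_origin`, while a real origin gets weighted.

```julia
# Hypothetical mock types standing in for MeasureBase's; not the real definitions.
abstract type Measure end

struct StdNormalMock <: Measure end

# A measure that does have an origin, for contrast (hypothetical).
struct ShiftedNormal <: Measure
    μ::Float64
end

# Sentinel meaning "no transport origin" (role of NoTransportOrigin{M}).
struct NoOrigin{M} end

struct Weighted{M} <: Measure
    logweight::Float64
    base::M
end

origin_of(μ::Measure) = NoOrigin{typeof(μ)}()   # default: no origin
origin_of(::ShiftedNormal) = StdNormalMock()

# Same shape as the proposed _weighted_origin: dispatch on the base's
# origin, weighting a real origin but passing the sentinel through.
origin_of(ν::Weighted) = _weighted_origin(ν.logweight, origin_of(ν.base))
_weighted_origin(logweight, m) = Weighted(logweight, m)
_weighted_origin(logweight, no::NoOrigin) = no

# Base without an origin: the sentinel itself comes back, unweighted.
a = origin_of(Weighted(log(2.2), StdNormalMock()))
# Base with an origin: that origin is weighted.
b = origin_of(Weighted(log(2.2), ShiftedNormal(1.0)))
```

Because `_weighted_origin(logweight, no::NoOrigin)` is strictly more specific than the untyped method, Julia picks it without ambiguity whenever the base has no origin.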

oschulz · Sep 29 '22 20:09

> The thing is, transport needs both transport_def and transport_origin, it's not quite the same as basemeasure.

`transport_def` is like the 3-arg `logdensity_def`, and `transport_origin` is like `basemeasure`. I think abstractly it's the exact same problem.

With `basemeasure`, the setup is that for a given measure `m` we can consider "all measures having the same `rootmeasure` as `m`". This forms a tree, with each measure's `basemeasure` as its parent and the `rootmeasure` at the root. For comparisons between two branches of the tree, we can either let the computation proceed automatically or add optimized methods. In some cases, we can also jump between trees.

If you replace `basemeasure` with `transport_origin`, how are these different?
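To make the analogy concrete, here is a sketch on hypothetical mock types, with `base_of` standing in for `basemeasure` and `origin_of` for `transport_origin` (neither is MeasureBase's implementation). Both act as parent pointers, so one walker can follow either, and a common-ancestor search is the kind of branch comparison described for `basemeasure`.

```julia
# Hypothetical mock types; not MeasureBase's definitions.
abstract type Measure end
struct Lebesgue <: Measure end          # a root: its own base
struct StdNormalMock <: Measure end
struct NormalMock <: Measure
    μ::Float64
    σ::Float64
end
struct Weighted{M} <: Measure
    logweight::Float64
    base::M
end
struct NoOrigin end                     # sentinel: "no origin"

# basemeasure-style parent pointer: the root maps to itself.
base_of(::Lebesgue) = Lebesgue()
base_of(::StdNormalMock) = Lebesgue()
base_of(::NormalMock) = Lebesgue()
base_of(ν::Weighted) = ν.base

# transport_origin-style parent pointer: the end is marked by a sentinel.
origin_of(::Measure) = NoOrigin()
origin_of(::NormalMock) = StdNormalMock()
function origin_of(ν::Weighted)
    o = origin_of(ν.base)
    return o isa NoOrigin ? o : Weighted(ν.logweight, o)
end

# One generic walker for both: follow `next_of` until it reaches a fixed
# point or the sentinel, returning the path (one branch of the tree).
function chain(next_of, μ)
    path = Any[μ]
    while true
        m = next_of(path[end])
        (m isa NoOrigin || m == path[end]) && return path
        push!(path, m)
    end
end

# Comparing two branches: the first measure on μ's path that is also on ν's.
function common_ancestor(next_of, μ, ν)
    νpath = chain(next_of, ν)
    for m in chain(next_of, μ)
        m in νpath && return m
    end
    return nothing
end

w = Weighted(0.5, NormalMock(1.0, 2.0))
base_path = chain(base_of, w)        # w, its NormalMock, then Lebesgue()
origin_path = chain(origin_of, w)    # w, then Weighted(0.5, StdNormalMock())
```

The only difference the walker has to accommodate is how a chain ends: `base_of` stops at a measure that is its own base, while `origin_of` stops at a sentinel.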

cscherrer · Sep 29 '22 20:09

I made a little proof of concept here: https://github.com/cscherrer/MeasureBase.jl/pull/94

The advantage is that the approach closely mirrors what we do with `basemeasure`. There might be good reasons to treat the two differently (and if there are, we should), but I'm not seeing them yet.

The flip side is that we might be better off dropping the current `basemeasure` approach altogether and instead changing it to match the transport implementation in `master`.

cscherrer · Sep 30 '22 02:09