MeasureBase.jl

Use `Sampler` interface?

Open cscherrer opened this issue 2 years ago • 4 comments

In a recent Zulip discussion (https://julialang.zulipchat.com/#narrow/stream/137791-general/topic/Static.20code.20blocks/near/289768224), @Seelengrab reminded me of the Random.Sampler interface, with this example:

using Random

struct BoundedFloat64
    min::Float64
    max::Float64
end
Base.eltype(::Type{BoundedFloat64}) = Float64
# Draw from the rng that was passed in, not from the global one
Base.rand(rng::AbstractRNG, b::Random.SamplerTrivial{BoundedFloat64}) = b[].min + (b[].max - b[].min) * rand(rng, Float64)
Random.Sampler(::Type{<:AbstractRNG}, bf64::BoundedFloat64, ::Random.Repetition) = Random.SamplerTrivial(bf64)

This makes it so you can call e.g.

julia> rand(BoundedFloat64(0.5, 10.0), 3)
3-element Vector{Float64}:
 0.8715965799597065
 4.250901702717998
 0.9709465785383822

Our current implementation ignores this interface, instead working entirely in terms of dispatching to rand. Our rand takes three arguments, with these defaults:

Base.rand(d::AbstractMeasure) = rand(Random.GLOBAL_RNG, Float64, d)
Base.rand(T::Type, μ::AbstractMeasure) = rand(Random.GLOBAL_RNG, T, μ)
Base.rand(rng::AbstractRNG, d::AbstractMeasure) = rand(rng, Float64, d)

Here the type argument in the second position is passed to the "inner" rand call. For example,

julia> using MeasureBase

julia> rand(Float64, StdExponential())
1.8034180730760465

julia> rand(Float32, StdExponential())
2.1787617f0

julia> rand(Float16, StdExponential())
Float16(0.1595)

I think it's very useful to keep this kind of flexibility, but it could help to also have things set up in a way that takes advantage of the Sampler interface. It's very new to me and doesn't seem to be well-documented. But maybe we can go from examples or ask for help if we get stuck.

One thing to consider is that we can easily build measures over spaces more complex than a simple scalar or array. To minimize allocations in cases like this, I've previously explored having every rand have two steps: an allocation, followed by a call to rand! to fill in the values. It's not clear to me if this is the "right" way to do things. As usual in this ecosystem, we need to think a lot about composability and its impact on performance.
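A minimal sketch of that two-step idea, using only the Random stdlib. The names `UniformBox`, `allocate_sample`, `fill_sample!`, and `draw` are hypothetical stand-ins (they are not part of MeasureBase), and a plain vector stands in for a more complex sample space:

```julia
using Random

# Hypothetical measure: independent uniform coordinates on [lo[i], hi[i]).
struct UniformBox
    lo::Vector{Float64}
    hi::Vector{Float64}
end

# Step 1: allocate storage for one sample without filling it.
allocate_sample(d::UniformBox) = Vector{Float64}(undef, length(d.lo))

# Step 2: fill existing storage in place (plays the role of rand!).
function fill_sample!(rng::AbstractRNG, x::AbstractVector{Float64}, d::UniformBox)
    rand!(rng, x)                      # uniform draws on [0, 1)
    @. x = d.lo + (d.hi - d.lo) * x    # rescale each coordinate in place
    return x
end

# The allocating entry point is just the composition of the two steps.
draw(rng::AbstractRNG, d::UniformBox) = fill_sample!(rng, allocate_sample(d), d)

d = UniformBox([0.0, -1.0], [1.0, 1.0])
rng = MersenneTwister(1)
x = draw(rng, d)          # allocates the buffer once
fill_sample!(rng, x, d)   # later draws can reuse the same buffer
```

Whether the allocation step composes well for nested or structured sample spaces is exactly the open question here; the sketch only shows the flat case.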

cscherrer, Jul 17 '22 18:07

I don't know who originally came up with the Sampler interface, and while it seems to be a perfect fit, its documentation really is a bit terse. It would be really good to get some input from them here. I had to piece together what I got from how the code is used in Base, and I have no idea how the API would/could be used for your differing Float64/Float32/Float16 uses :shrug:

I played around with the interface some more, and while I don't know whether it's proper, the following also works (at least for naive rand(foo, 10) calls):

using Random

struct BoundedFloat64 <: Random.Sampler{Float64}
    min::Float64
    max::Float64
end
Base.eltype(::Type{BoundedFloat64}) = Float64
Base.rand(r::AbstractRNG, b::BoundedFloat64) = b.min + (b.max - b.min) * rand(r, Float64)

But I suspect this is not proper...

Seelengrab, Jul 17 '22 19:07

Looking back at where the various generators are defined (generation.jl), I really don't think the version above is good/proper. That file at least has some more comments about how samplers work; maybe they're of help as well?

Seelengrab, Jul 17 '22 19:07

Could be nice. I remember running into dispatch ambiguities with Sampler interface extensions in the past, though, so we'll probably need to do this carefully.

oschulz, Jul 19 '22 07:07

A way to define rand for BoundedFloat would be:

using Random

struct BoundedFloat{T<:AbstractFloat}
    min::T
    max::T
end
Base.eltype(::Type{BoundedFloat{T}}) where {T} = T

Base.rand(r::AbstractRNG, sp::Random.SamplerTrivial{<:BoundedFloat}) = let b = sp[]
    b.min + (b.max - b.min) * rand(r, eltype(b))
end
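For the Float64/Float32/Float16 question raised earlier, here is a sketch of how this parametric version might be used. It repeats the definition so it runs on its own, and the specific bounds are just illustrative values:

```julia
using Random

struct BoundedFloat{T<:AbstractFloat}
    min::T
    max::T
end
Base.eltype(::Type{BoundedFloat{T}}) where {T} = T

function Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{<:BoundedFloat})
    b = sp[]   # unwrap the BoundedFloat stored in the trivial sampler
    return b.min + (b.max - b.min) * rand(rng, eltype(b))
end

# The element type follows the type parameter of the BoundedFloat.
x64 = rand(BoundedFloat(0.5, 10.0))                     # a Float64
x32 = rand(BoundedFloat(0.5f0, 10.0f0))                 # a Float32
x16 = rand(BoundedFloat(Float16(0.5), Float16(10)), 3)  # a Vector{Float16}
```

Note that the output type is fixed by the object being sampled rather than chosen at the call site, so this is not a drop-in equivalent of a `rand(Float32, μ)` style call.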

Typically, when a custom rand method itself calls rand, as here, it's recommended (for potential performance gains) to cache a sampler for the inner call inside the sampler for the outer call (see the uses of SamplerSimple in the generation.jl file mentioned above). In practice, when the inner call is just on a scalar type like Float64 or Int, this is unnecessary most of the time.
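A sketch of what that caching could look like, modeled on the `Die` example in the Random stdlib documentation. The type name `CachedBoundedFloat` is hypothetical, chosen only to keep it separate from the `BoundedFloat` definition above:

```julia
using Random

# Hypothetical type; same fields as BoundedFloat, different sampler setup.
struct CachedBoundedFloat{T<:AbstractFloat}
    min::T
    max::T
end
Base.eltype(::Type{CachedBoundedFloat{T}}) where {T} = T

# Build the outer sampler once, caching a sampler for the inner
# rand(rng, T) call in the `data` field of SamplerSimple.
function Random.Sampler(RNG::Type{<:AbstractRNG}, b::CachedBoundedFloat, r::Random.Repetition)
    return Random.SamplerSimple(b, Random.Sampler(RNG, eltype(b), r))
end

# Each draw reuses the cached inner sampler instead of rebuilding it.
function Base.rand(rng::AbstractRNG, sp::Random.SamplerSimple{<:CachedBoundedFloat})
    b = sp[]
    return b.min + (b.max - b.min) * rand(rng, sp.data)
end

xs = rand(MersenneTwister(42), CachedBoundedFloat(0.5, 10.0), 5)
```

For a scalar float the cached inner sampler carries essentially no state, so this mainly illustrates the pattern; it would matter more when the inner sampler is expensive to set up (for example, sampling from a range or a collection).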

(I'm happy to try to answer questions here, but tag me to get my attention.)

rfourquet, Jun 06 '23 13:06