MeasureBase.jl
Use `Sampler` interface?
In a recent Zulip discussion (https://julialang.zulipchat.com/#narrow/stream/137791-general/topic/Static.20code.20blocks/near/289768224), @Seelengrab reminded me of the `Random.Sampler` interface, with this example:
```julia
using Random

struct BoundedFloat64
    min::Float64
    max::Float64
end

Base.eltype(::Type{BoundedFloat64}) = Float64

# Draw via SamplerTrivial, passing the rng through to the inner call so custom RNGs are respected
Base.rand(rng::AbstractRNG, b::Random.SamplerTrivial{BoundedFloat64}) = b[].min + (b[].max - b[].min) * rand(rng, Float64)

Random.Sampler(::Type{<:AbstractRNG}, bf64::BoundedFloat64, r::Random.Repetition) = Random.SamplerTrivial(bf64)
```
This makes it so you can call e.g.
```julia
julia> rand(BoundedFloat64(0.5, 10.0), 3)
3-element Vector{Float64}:
 0.8715965799597065
 4.250901702717998
 0.9709465785383822
```
Our current implementation ignores this interface, instead working entirely in terms of dispatching to `rand`. Our `rand` takes three arguments, with these defaults:
```julia
Base.rand(d::AbstractMeasure) = rand(Random.GLOBAL_RNG, Float64, d)
Base.rand(T::Type, μ::AbstractMeasure) = rand(Random.GLOBAL_RNG, T, μ)
Base.rand(rng::AbstractRNG, d::AbstractMeasure) = rand(rng, Float64, d)
```
Here the type argument in the second position is passed to the "inner" `rand` call. For example,
```julia
julia> using MeasureBase

julia> rand(Float64, StdExponential())
1.8034180730760465

julia> rand(Float32, StdExponential())
2.1787617f0

julia> rand(Float16, StdExponential())
Float16(0.1595)
```
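For concreteness, here is a minimal sketch of what such an "inner" three-argument method can look like; `MyStdExponential` is a hypothetical stand-in, not the actual MeasureBase implementation:

```julia
using Random

struct MyStdExponential end  # hypothetical stand-in for StdExponential

# The type argument in the second position selects the element type of the draw;
# randexp already supports Float16/Float32/Float64.
Base.rand(rng::AbstractRNG, ::Type{T}, ::MyStdExponential) where {T<:Base.IEEEFloat} = randexp(rng, T)
```

With that in place, `rand(Random.default_rng(), Float32, MyStdExponential())` returns a `Float32` draw, mirroring the REPL session above.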
I think it's very useful to keep this kind of flexibility, but it could help to also have things set up in a way that takes advantage of the `Sampler` interface. It's very new to me and doesn't seem to be well documented, but maybe we can go from examples or ask for help if we get stuck.
One thing to consider is that we can easily build measures over spaces more complex than a simple scalar or array. To minimize allocations in cases like this, I've previously explored having every `rand` proceed in two steps: an allocation, followed by a call to `rand!` to fill in the values. It's not clear to me whether this is the "right" way to do things. As usual in this ecosystem, we need to think a lot about composability and its impact on performance.
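As a sketch of that two-step idea (the names `StdUniformVector` and `sample_container` are made up for illustration and are not part of MeasureBase):

```julia
using Random

struct StdUniformVector  # hypothetical measure over length-n vectors
    n::Int
end

# Step 1: allocate a container matching the measure's sample space
sample_container(::Type{T}, d::StdUniformVector) where {T} = Vector{T}(undef, d.n)

# Step 2: fill the preallocated container in place
Random.rand!(rng::AbstractRNG, x::AbstractVector, ::StdUniformVector) = rand!(rng, x)

# The outer rand composes the two steps; callers who want to avoid the allocation
# can preallocate once and call rand! repeatedly.
Base.rand(rng::AbstractRNG, ::Type{T}, d::StdUniformVector) where {T} = rand!(rng, sample_container(T, d), d)
```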
I don't know who originally came up with the `Sampler` interface, and while it seems to be a perfect fit, its documentation really is a bit terse. It would be really good to get some input from them here. I had to piece together what I got from how the code is used in Base, and I have no idea how the API would/could be used for your differing `Float64`/`Float32`/`Float16` uses :shrug:
I played around with the interface some more, and while I don't know whether it's proper, the following also works (at least for naive `rand(foo, 10)` calls):
```julia
using Random

struct BoundedFloat64 <: Random.Sampler{Float64}
    min::Float64
    max::Float64
end

Base.eltype(::Type{BoundedFloat64}) = Float64
Base.rand(r::AbstractRNG, b::BoundedFloat64) = b.min + (b.max - b.min) * rand(r, Float64)
```
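If I'm reading the Random internals correctly (my own reading, not documented behavior), this works because the object is itself a `Sampler{Float64}`, so the array-generation path can use it directly instead of constructing a sampler for it:

```julia
b = BoundedFloat64(0.5, 10.0)
b isa Random.Sampler{Float64}  # true: the object doubles as its own sampler
x = rand(b, 10)                # Vector{Float64} with entries in [0.5, 10.0)
```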
But I suspect this is not proper...
Looking back at where the various generators are defined (the generation.jl file in the Random stdlib), I really don't think the version above is good/proper. That file at least has some more comments about how samplers work; maybe they're of help as well?
Could be nice. I remember running into dispatch ambiguities with sampler interface extensions in the past though, so we'll probably need to do this carefully.
A way to define `rand` for `BoundedFloat` would be:
```julia
using Random
using Random: SamplerTrivial

struct BoundedFloat{T<:AbstractFloat}
    min::T
    max::T
end

Base.eltype(::Type{BoundedFloat{T}}) where {T} = T

Base.rand(r::AbstractRNG, sp::SamplerTrivial{<:BoundedFloat}) = let b = sp[]
    b.min + (b.max - b.min) * rand(r, eltype(b))
end
```
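A quick usage check of that definition (this relies on `SamplerTrivial` being Random's default fallback sampler for plain values, so no explicit `Random.Sampler` method is needed here):

```julia
b = BoundedFloat(0.5f0, 10.0f0)
x = rand(b, 10)  # each draw goes through the SamplerTrivial method above
@assert eltype(x) == Float32 && all(v -> 0.5f0 <= v < 10.0f0, x)
```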
Actually, typically when we are calling `rand` in our definition of a custom `rand` method like here, it's generally recommended (for potential performance gains) to cache a sampler for the inner call inside the sampler for the outer call (see the uses of `SamplerSimple` in the generation.jl file mentioned above), but in practice, when the inner call is just on a scalar type like `Float64` or `Int`, it's most of the time unnecessary.
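To make that caching pattern concrete, here is a hedged sketch (my own illustration, not code from Random or MeasureBase) of how the inner floating-point sampler could be precomputed with `SamplerSimple`, replacing the `SamplerTrivial`-based method above:

```julia
using Random
using Random: SamplerSimple, Repetition

# Cache a sampler for the inner floating-point draw in the `data` field of the outer sampler
Random.Sampler(::Type{RNG}, b::BoundedFloat{T}, r::Repetition) where {RNG<:AbstractRNG,T} =
    SamplerSimple(b, Random.Sampler(RNG, T, r))

# sp[] is the wrapped BoundedFloat; sp.data is the cached inner sampler, reused on every draw
Base.rand(rng::AbstractRNG, sp::SamplerSimple{<:BoundedFloat}) =
    sp[].min + (sp[].max - sp[].min) * rand(rng, sp.data)
```

As noted above, the cached sampler mostly pays off when the inner draw has real setup cost; for a bare `Float64` it's usually not worth the extra code.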
(I'm happy to try to answer questions here, but tag me to get my attention)