GeneralizedGenerated.jl
Closures with Fix1 for type stability
Hi @thautwarm,
In a new version of Soss, I'm trying to really nail down type stability. Using GG, I was having lots of trouble getting constructions like
For(n -> Normal(α + β * x[n], σ), 1:N)
to be type-stable.
After lots of exploring, I realized this could be solved by putting this function outside of GG:
f(ctx, n) = Normal(ctx.α + ctx.β * ctx.x[n], ctx.σ)
f(ctx::NamedTuple) = Base.Fix1(f, ctx)
Then within GG, it could be called as
For(f((;α,β,σ,x)), 1:N)
That works great, but it's tricky to automate, since the f needs to be made available separately within the module, and I'd rather not use eval to put it there.
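Here is a dependency-free sketch of the pattern above. It assumes a hypothetical `MyNormal` struct in place of MeasureTheory's `Normal`, so it runs without Soss or MeasureTheory. The last line uses `Core.Compiler.return_type` to check that calling the `Fix1` object infers to a concrete type.

```julia
# Hypothetical stand-in for `Normal`, so the sketch has no package dependencies
struct MyNormal{T}
    μ::T
    σ::T
end

# Top-level kernel: every former free variable is read from `ctx`
f(ctx, n) = MyNormal(ctx.α + ctx.β * ctx.x[n], ctx.σ)

# Partially apply `ctx`; the Fix1's type records every field type of the NamedTuple
f(ctx::NamedTuple) = Base.Fix1(f, ctx)

α, β, σ = 1.0, 2.0, 0.5
x = [1.0, 2.0, 3.0]
N = length(x)

g = f((; α, β, σ, x))
dists = map(g, 1:N)   # means 3.0, 5.0, 7.0; σ = 0.5 for all

# Inference sees through the Fix1 wrapper
Core.Compiler.return_type(g, Tuple{Int})   # MyNormal{Float64}
```

Because `ctx` is an ordinary argument rather than a captured variable, nothing can be boxed; the cost is that `f` has to exist as a named function somewhere.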
I eventually found this to work:
For(let f = ctx -> Base.Fix1(ctx) do ctx, n
Normal(ctx.α + ctx.β * ctx.x[n], ctx.σ)
end
f((; α, β, σ, x))
end, 1:N)
So now, I think I can just rewrite things this way before calling GG. But I wanted to check with you before doing this, since it seems plausible the current Closure approach could be updated to something similar, and maybe it would help other cases with type stability. What do you think?
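The `let` form needs no separately defined helper. A minimal sketch, assuming a plain tuple `(μ, σ)` stands in for `Normal` and that `build` is a hypothetical wrapper playing the role of the GG-generated function body:

```julia
# Hypothetical wrapper standing in for the code GG would generate
function build(α, β, σ, x)
    # The do-block lambda shadows `ctx`, and neither lambda refers to any
    # local variable of `build`, so there is nothing to capture or box
    let f = ctx -> Base.Fix1(ctx) do ctx, n
            (ctx.α + ctx.β * ctx.x[n], ctx.σ)   # stand-in for Normal(μ, σ)
        end
        f((; α, β, σ, x))
    end
end

kernel = build(1.0, 2.0, 0.5, [1.0, 2.0, 3.0])
kernel(2)                        # (5.0, 0.5)
isconcretetype(typeof(kernel))   # true
```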
Hi @cscherrer, could you give an MWE for why the following Soss model is not type-stable?
For(n -> Normal(α + β * x[n], σ), 1:N)
> since it seems plausible the current Closure approach could be updated to something similar, and maybe it would help other cases with type stability. What do you think?
Since you mentioned GG's Closure implementation, I'd guess the type-stability issue is caused by the use of Core.Box.
Core.Box allows capturing type-unstable free variables, so that you can reassign heterogeneous data to a free variable. However, it seems to break type stability.
https://github.com/JuliaStaging/GeneralizedGenerated.jl/blob/6ebfe69887acb6fc2b6b4e1a434918c8f67e2cf7/src/closure_conv.jl#L18-L22
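For reference, here is the general Core.Box effect in plain Julia (no GG involved), adapted from the captured-variables example in the Julia manual's performance tips: a captured variable that is reassigned gets boxed, while rebinding it with `let` keeps the closure's field concretely typed. Whether GG's boxing is what breaks the Soss case is the hypothesis above; this snippet doesn't confirm it.

```julia
# Reassigning `r` before capturing it makes lowering store it in a Core.Box
function abmult(r::Int)
    if r < 0
        r = -r
    end
    return x -> x * r
end

# Rebinding with `let` gives the closure a fresh, never-reassigned variable
function abmult_let(r::Int)
    if r < 0
        r = -r
    end
    return let r = r
        x -> x * r
    end
end

boxed = abmult(3)
unboxed = abmult_let(3)

fieldtypes(typeof(boxed))     # typically (Core.Box,); lowering details vary by Julia version
fieldtypes(typeof(unboxed))   # (Int64,)

# Typically Any for the boxed closure, Int64 for the let version
Core.Compiler.return_type(boxed, Tuple{Int})
Core.Compiler.return_type(unboxed, Tuple{Int})
```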
I now have time to work on GG and Soss.
Great! Yesterday's covid booster is hitting me pretty hard, but I'll send you an update when I'm feeling better. Lots going on 🙂
Sorry to hear that. I hope you get better soon, and we will work together again!
> Hi @cscherrer, could you give an MWE for why the following Soss model is not type-stable?
> For(n -> Normal(α + β * x[n], σ), 1:N)
Using MeasureBase#dev, MeasureTheory#dev2, and Soss#dev, I get
julia> using Soss
julia> using MeasureTheory
julia> m = @model n begin
y ~ For(n) do j
Bernoulli(1/j)
end
end;
julia> s = rand(m(10))
(y = Bool[1, 0, 0, 0, 0, 0, 0, 0, 1, 0],)
julia> using JET
julia> @test_opt rand(m(10))
JET-test failed at REPL[44]:1
Expression: #= REPL[44]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m(10))
═════ 7 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:24 Soss.rand(Soss.GLOBAL_RNG, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:20 Soss._rand(_, m, Soss.argvals(c))(rng)
││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 Base.getproperty(Main, :For)(function = (j;) -> begin
(Main).Bernoulli((Main).:/(1, j))
end, n)
││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:199 MeasureBase.For(f, Base.OneTo(n))
│││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:198 MeasureBase.For(f, inds)
││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:18 For(::ggfunc-function, ::Tuple{Base.OneTo{Int64}})
│││││││││ failed to optimize: For(::ggfunc-function, ::Tuple{Base.OneTo{Int64}})
││││││││└───────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:198 For(::ggfunc-function, ::Base.OneTo{Int64})
││││││││ failed to optimize: For(::ggfunc-function, ::Base.OneTo{Int64})
│││││││└────────────────────────────────────────────────────────────
││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:199 For(::ggfunc-function, ::Int64)
│││││││ failed to optimize: For(::ggfunc-function, ::Int64)
││││││└────────────────────────────────────────────────────────────
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 y = Base.getproperty(Main, :rand)(_rng, Base.getproperty(Main, :For)(function = (j;) -> begin
(Main).Bernoulli((Main).:/(1, j))
end, n))
││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 MeasureBase.rand(rng, MeasureBase.Float64, d)
│││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:212 MeasureBase._rand_product(rng, _, MeasureBase.marginals(d), _)
││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:33 MeasureBase.map(#28, mar)
│││││││││┌ @ abstractarray.jl:2849 Base.collect_similar(A, Base.Generator(f, A))
││││││││││┌ @ array.jl:653 Base._collect(cont, itr, Base.IteratorEltype(itr), Base.IteratorSize(itr))
│││││││││││┌ @ array.jl:701 et = Base.promote_typejoin_union(T)
││││││││││││┌ @ promotion.jl:170 Base.promote_typejoin_union(Base.getproperty(_, :a))
│││││││││││││┌ @ promotion.jl:190 unwrapva(%35)
││││││││││││││ runtime dispatch detected: unwrapva(%35::Any)
│││││││││││││└────────────────────
│││││││││││││┌ @ promotion.jl:194 Base.typejoin(%57, %59)
││││││││││││││ runtime dispatch detected: Base.typejoin(%57::Any, %59::Any)
│││││││││││││└────────────────────
│││││││││││││┌ @ promotion.jl:198 Base.promote_typejoin_union(%50)
││││││││││││││ runtime dispatch detected: Base.promote_typejoin_union(%50::Any)
│││││││││││││└────────────────────
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Int64, ::Random._GLOBAL_RNG)
││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Int64, ::Random._GLOBAL_RNG)
│││││└───────────────────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
I think this comes down to
@generated function For(f::F, inds::I) where {F,I<:Tuple}
eltypes = Tuple{eltype.(I.types)...}
quote
$(Expr(:meta, :inline))
T = Core.Compiler.return_type(f, $eltypes)
For{T,F,I}(f, inds)
end
end
I've had some comments that using return_type is a bad idea, but I really need as much of this as possible to happen at compile time, and so far I don't see a good alternative. Also worth noting that MappedArrays.jl takes a very similar approach, see e.g. https://github.com/JuliaArrays/MappedArrays.jl/blob/46bf47f3388d011419fe43404214c1b7a44a49cc/src/MappedArrays.jl#L61
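A stripped-down, non-`@generated` sketch of the same idea, assuming a hypothetical `MyFor` type rather than MeasureBase's actual `For`: the element type `T` is whatever inference predicts `f` returns for one element of each index collection.

```julia
# Hypothetical, simplified analogue of MeasureBase's `For`
struct MyFor{T,F,I<:Tuple}
    f::F
    inds::I
end

function MyFor(f::F, inds::I) where {F,I<:Tuple}
    # Ask inference what `f` returns for one element of each index collection.
    # `return_type` is internal and may return a wider type (even Any)
    # when inference gives up, so T is only as precise as inference is.
    T = Core.Compiler.return_type(f, Tuple{map(eltype, inds)...})
    return MyFor{T,F,I}(f, inds)
end

elementtype(::MyFor{T}) where {T} = T

# Materialize all marginals (1-D case only, for brevity)
marginals(d::MyFor{T}) where {T} = T[d.f(i) for i in only(d.inds)]

d = MyFor(j -> 1 / j, (1:4,))
elementtype(d)   # Float64
marginals(d)
```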
Interestingly, this fixes it:
julia> f(ctx) = Base.Fix1(ctx) do ctx, j
Bernoulli(1/j)
end
f (generic function with 1 method)
julia> m = @model n begin
y ~ For(f(NamedTuple()), n)
end;
julia> s = rand(m(10))
(y = Bool[1, 0, 0, 1, 0, 0, 0, 0, 0, 0],)
julia> using JET
julia> @test_opt rand(m(10))
Test Passed
Expression: #= REPL[7]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m(10))
In general, ctx can be set to a named tuple of local variables, and f could be a runtime generated function. In addition to type stability, this gives a way of breaking up a GG function into smaller pieces, which may also solve the problem of source code size limitations in GG :)
Looking at this a little closer, I think my approach is kind of clumsy. Do you think it could be possible to instead have an option to just avoid boxing in the first place? Maybe something like an opaque closure, or a faked version of that? Even something throwing an error on type instability would be useful.
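On the last point: `Test.@inferred` already throws when a call's inferred return type doesn't match the actual one. A minimal sketch, assuming a hypothetical `FakeClosure` callable struct (not GG's actual `Closure`) that keeps its free variables in a concretely typed tuple instead of a `Core.Box`:

```julia
using Test

# Hypothetical "faked closure": the function plus a concretely typed tuple of free variables
struct FakeClosure{F,Frees<:Tuple}
    f::F
    frees::Frees
end

# Free variables are passed first, then the call's own arguments
(c::FakeClosure)(args...) = c.f(c.frees..., args...)

c = FakeClosure((p, j) -> p / j, (0.5,))
@inferred c(2)   # returns 0.25; would throw if the result type were not inferred

# A deliberately unstable function: @inferred raises an ErrorException
unstable(x) = x > 0 ? 1 : 1.0
@test_throws ErrorException @inferred unstable(1)
```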
Just a little more information on this, then I'm going to try to set it aside for a while...
Say you start with
using Soss, JET
m1 = @model N begin
p ~ Uniform()
x ~ For(N) do j
Bernoulli(p / j)
end
end
Then JET.@test_opt rand(m1(10)) fails. This is possibly because of boxing the state for the closure, though I don't understand the GG implementation well enough to confirm this.
**Here's the test for m1**
julia> @test_opt rand(m1(10))
JET-test failed at REPL[38]:1
Expression: #= REPL[38]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m1(10))
═════ 16 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#49(Base.pairs(Core.NamedTuple()), #self#, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#51(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
│││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:24 tilde_rand(:x, Base.getproperty(Main, :For)(Core.apply_type(Closure, function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Base.typeof(freevars))(freevars), N), _cfg, _ctx, targs)
││││││││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:43 x = Soss.rand(Base.getproperty(cfg, :rng), d)
│││││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 MeasureBase.rand(rng, MeasureBase.Float64, d)
││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:220 MeasureBase._rand_product(rng, _, MeasureBase.marginals(d), _)
│││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:33 MeasureBase.map(#28, mar)
││││││││││││┌ @ abstractarray.jl:2849 Base.collect_similar(A, Base.Generator(f, A))
│││││││││││││┌ @ array.jl:653 Base._collect(cont, itr, Base.IteratorEltype(itr), Base.IteratorSize(itr))
││││││││││││││┌ @ array.jl:744 y = Base.iterate(itr)
│││││││││││││││┌ @ generator.jl:44 y = Base.iterate(Core.tuple(Base.getproperty(g, :iter)), s...)
││││││││││││││││┌ @ abstractarray.jl:1142 #self#(A, Core.tuple(Base.eachindex(A)))
│││││││││││││││││┌ @ abstractarray.jl:1144 Base.getindex(A, Base.getindex(y, 1))
││││││││││││││││││┌ @ /home/chad/.julia/packages/MappedArrays/bS6Yp/src/MappedArrays.jl:166 Base.getproperty(A, :f)(Base.getindex(Core.tuple(Base.getproperty(A, :data)), i...))
│││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
││││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 Base.merge(Base.NamedTuple(), kwargs)
│││││││││││││││││││││┌ @ namedtuple.jl:303 Core.apply_type(Base.NamedTuple, Core.tuple(names...))(Core.tuple(vals...))
││││││││││││││││││││││┌ @ boot.jl:601 Core.apply_type(Core.NamedTuple, _, Core.typeof(args))(args)
│││││││││││││││││││││││┌ @ namedtuple.jl:96 _(args)
││││││││││││││││││││││││┌ @ tuple.jl:312 Base.convert(_, x)
│││││││││││││││││││││││││┌ @ essentials.jl:344 Base.Val(_)
││││││││││││││││││││││││││┌ @ essentials.jl:701 %1()
│││││││││││││││││││││││││││ runtime dispatch detected: %1::Type{Val{_A}} where _A()
││││││││││││││││││││││││││└─────────────────────
│││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:5 (::GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}})(::Int64)
││││││││││││││││││││ failed to optimize: (::GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}})(::Int64)
│││││││││││││││││││└──────────────────────────────────────────────────────────────────────────
││││││││││││││││││┌ @ /home/chad/.julia/packages/MappedArrays/bS6Yp/src/MappedArrays.jl:164 getindex(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Int64)
│││││││││││││││││││ failed to optimize: getindex(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Int64)
││││││││││││││││││└─────────────────────────────────────────────────────────────────────────
│││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}})
││││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}})
│││││││││││││││││└─────────────────────────
││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}})
│││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}})
││││││││││││││││└─────────────────────────
│││││││││││││││┌ @ generator.jl:42 iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
││││││││││││││││ failed to optimize: iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
│││││││││││││││└───────────────────
││││││││││││││┌ @ array.jl:754 Base.collect_to_with_first!(dest, v1, itr, st)
│││││││││││││││┌ @ array.jl:760 Base.collect_to!(dest, itr, Base.+(i1, 1), st)
││││││││││││││││┌ @ array.jl:782 y = Base.iterate(itr, st)
│││││││││││││││││┌ @ generator.jl:44 y = Base.iterate(Core.tuple(Base.getproperty(g, :iter)), s...)
││││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}, Int64})
│││││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}, Int64})
││││││││││││││││││└─────────────────────────
│││││││││││││││││┌ @ generator.jl:42 iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Tuple{Base.OneTo{Int64}, Int64})
││││││││││││││││││ failed to optimize: iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Tuple{Base.OneTo{Int64}, Int64})
│││││││││││││││││└───────────────────
││││││││││││││┌ @ array.jl:741 Base._collect(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Base.EltypeUnknown, ::Base.HasShape{1})
│││││││││││││││ failed to optimize: Base._collect(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Base.EltypeUnknown, ::Base.HasShape{1})
││││││││││││││└────────────────
│││││││││││││┌ @ array.jl:653 Base.collect_similar(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
││││││││││││││ failed to optimize: Base.collect_similar(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
│││││││││││││└────────────────
│││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:32 MeasureBase._rand_product(::Random._GLOBAL_RNG, ::Type{Float64}, ::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Type{Bernoulli{(:p,), Tuple{Float64}}})
││││││││││││ failed to optimize: MeasureBase._rand_product(::Random._GLOBAL_RNG, ::Type{Float64}, ::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Type{Bernoulli{(:p,), Tuple{Float64}}})
│││││││││││└───────────────────────────────────────────────────────────────
││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:219 rand(::Random._GLOBAL_RNG, ::Type{Float64}, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
│││││││││││ failed to optimize: rand(::Random._GLOBAL_RNG, ::Type{Float64}, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
││││││││││└────────────────────────────────────────────────────────────
│││││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 rand(::Random._GLOBAL_RNG, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
││││││││││ failed to optimize: rand(::Random._GLOBAL_RNG, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
│││││││││└───────────────────────────────────────────────
││││││││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:42 Soss.tilde_rand(::Symbol, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}}, ::NamedTuple{(:rng, :_args, :_obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:N,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, ::NamedTuple{(:p,), Tuple{Float64}}, ::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}, Static.False, Static.False})
│││││││││ failed to optimize: Soss.tilde_rand(::Symbol, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}}, ::NamedTuple{(:rng, :_args, :_obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:N,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, ::NamedTuple{(:p,), Tuple{Float64}}, ::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}, Static.False, Static.False})
││││││││└────────────────────────────────────────────────────
│││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
││││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
│││││││└───────────────────────────────────────────────────────────────────────────────────
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:5 GeneralizedGenerated.var"#_#4"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::GeneralizedGenerated.Closure{function = (_mc, _cfg, _ctx;) -> begin
begin
$(Expr(:meta, :((Main).inline)))
local _retn
_args = (Main).Soss.argvals(_mc)
_obs = (Main).Soss.observations(_mc)
_cfg = (Main).merge(_cfg, (_args = _args, _obs = _obs))
let
begin
N = _args.N
(p, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:p, (Main).Uniform(), _cfg, _ctx, targs)
end
end
(x, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:x, (Main).For(begin
let freevars = (p,)
(GeneralizedGenerated.Closure){function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Base.typeof(freevars)}(freevars)
end
end, N), _cfg, _ctx, targs)
end
end
end
end
_retn
end
end, Tuple{Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
││││││ failed to optimize: GeneralizedGenerated.var"#_#4"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::GeneralizedGenerated.Closure{function = (_mc, _cfg, _ctx;) -> begin
begin
$(Expr(:meta, :((Main).inline)))
local _retn
_args = (Main).Soss.argvals(_mc)
_obs = (Main).Soss.observations(_mc)
_cfg = (Main).merge(_cfg, (_args = _args, _obs = _obs))
let
begin
N = _args.N
(p, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:p, (Main).Uniform(), _cfg, _ctx, targs)
end
end
(x, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:x, (Main).For(begin
let freevars = (p,)
(GeneralizedGenerated.Closure){function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Base.typeof(freevars)}(freevars)
end
end, N), _cfg, _ctx, targs)
end
end
end
end
_retn
end
end, Tuple{Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
│││││└──────────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
If instead, you do it like this:
f(ctx) = Base.Fix1(ctx) do ctx, j
Bernoulli(ctx.p / j)
end
m2 = @model N begin
p ~ Uniform()
x ~ For(f((p=p,)), N)
end
Then @test_opt rand(m2(10)) passes just fine. So my next thought was to have f(ctx) generated at runtime, using GG. But I'm not sure it's possible, since it seems it would rely on the same mechanism that failed for the first case.
That got me wondering whether it makes more sense to change the GG implementation to take a similar approach. But again this isn't at all clear to me. For example, this attempt also fails:
m3 = @model N begin
p ~ Uniform()
f(ctx) = Base.Fix1(ctx) do ctx, j
Bernoulli(ctx.p / j)
end
x ~ For(f((p=p,)), N)
end
**Here's the test for m3**
julia> @test_opt rand(m3(10))
JET-test failed at REPL[37]:1
Expression: #= REPL[37]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m3(10))
═════ 2 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#49(Base.pairs(Core.NamedTuple()), #self#, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#51(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
│││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││││┌ @ /home/chad/git/Soss.jl/src/primitives/interpret.jl:25 f(Core.apply_type(Core.NamedTuple, (:p,))(Core.tuple(p)))
││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 (::ggfunc-f)(::NamedTuple{(:p,), Tuple{Float64}})
│││││││││ failed to optimize: (::ggfunc-f)(::NamedTuple{(:p,), Tuple{Float64}})
││││││││└───────────────────────────────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{118}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
││││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{118}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
│││││││└───────────────────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
A couple of things are surprising:
- Though m1 and m3 both fail, there seem to be far fewer issues with m3.
- Dynamic dispatch usually leads to big slowdowns in benchmark times, but here I don't see a difference at all.
It seems likely to me that dynamic dispatch could become a problem for some models, even if it's not for this one. But I'm feeling kind of stuck anyway, so I'm going to set it aside for a while. Let me know if you have any ideas, and we can pick back up.
Sorry for the delay, I'll get started working on this.
Right now I've cloned Soss#dev, but I'm having issues running instantiate.
ERROR: Unsatisfiable requirements detected for package TransformVariables [84d833dd]:
TransformVariables [84d833dd] log:
├─possible versions are: 0.1.0-0.5.0 or uninstalled
├─restricted to versions 0.5 by Soss [8ce77f84], leaving only versions 0.5.0
│ └─Soss [8ce77f84] log:
│ ├─possible versions are: 0.20.9 or uninstalled
│ └─Soss [8ce77f84] is fixed to version 0.20.9
└─restricted by compatibility requirements with MeasureTheory [eadaa1a4] to versions: 0.4.0-0.4.1 — no versions left
└─MeasureTheory [eadaa1a4] log:
├─possible versions are: 0.2.1-0.13.2 or uninstalled
└─restricted to versions 0.13 by Soss [8ce77f84], leaving only versions 0.13.0-0.13.2
└─Soss [8ce77f84] log: see above
> I think this comes down to
> @generated function For(f::F, inds::I) where {F,I<:Tuple}
>     eltypes = Tuple{eltype.(I.types)...}
>     quote
>         $(Expr(:meta, :inline))
>         T = Core.Compiler.return_type(f, $eltypes)
>         For{T,F,I}(f, inds)
>     end
> end
Some rough ideas: is the issue possibly related to a run-time invocation of return_type? Would there be any problem with lifting the inference to code-generation time, as follows?
@generated function For(f::F, inds::I) where {F,I<:Tuple}
eltypes = Tuple{eltype.(I.types)...}
T = Core.Compiler.return_type(f, eltypes)
quote
$(Expr(:meta, :inline))
For{$T,F,I}(f, inds)
end
end
Or is f expected to be extended later (with new methods), so that Core.Compiler.return_type might return different results?
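The concern can be seen without any generated functions: `Core.Compiler.return_type` answers relative to the methods that exist when it is called, so adding a method can change its answer. The hypothetical `h` below only illustrates that. Inside a `@generated` body, the result would be computed when the generator runs and baked into the returned expression, so presumably it would not be recomputed if `f` gained methods afterwards.

```julia
# Hypothetical function used only to show that return_type depends on the method table
h(x) = x
before = Core.Compiler.return_type(h, Tuple{Int})   # Int64

# Adding a more specific method changes what h(::Int) returns
h(x::Int) = float(x)
after = Core.Compiler.return_type(h, Tuple{Int})    # Float64
```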
> Right now I cloned Soss#dev but have issues in executing instantiate.
Sorry, I've been changing more things. I'll get to something stable and send you the Manifest.
> is the issue possibly related to a run-time invocation of return_type?
I originally had this outside the quote; I forget why I moved it. But anyway, I had to add this for it to work at all:
@generated function MeasureTheory.For(f::GG.Closure{F,Free}, inds::I) where {F,Free,I<:Tuple}
freetypes = Free.types
eltypes = eltype.(I.types)
T = Core.Compiler.return_type(F, Tuple{freetypes..., eltypes...})
quote
$(Expr(:meta, :inline))
For{$T,GG.Closure{F,Free},I}(f, inds)
end
end
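The same computation, sketched outside of `@generated` with a hypothetical `MyClosure` standing in for `GG.Closure{F,Free}` (in GG the function lives in the type parameter `F`; here it's an ordinary field): inference is asked about the underlying function, with the free-variable types prepended to the index element types.

```julia
# Hypothetical stand-in for GG.Closure: underlying function + typed free variables
struct MyClosure{F,Free<:Tuple}
    f::F
    frees::Free
end
(c::MyClosure)(args...) = c.f(c.frees..., args...)

# Mirrors `return_type(F, Tuple{freetypes..., eltypes...})` from the method above
function closure_eltype(c::MyClosure{F,Free}, inds::Tuple) where {F,Free}
    return Core.Compiler.return_type(c.f, Tuple{fieldtypes(Free)..., map(eltype, inds)...})
end

c = MyClosure((p, j) -> p / sqrt(j), (0.25,))
closure_eltype(c, (1:3,))   # Float64
c(4)                        # 0.125
```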
> Sorry, I've been changing more things. I'll get to something stable and send you the Manifest.
Thanks, just send me them when done and I can work on this.
I will need some time to address the concrete issue, and thanks a lot for providing the cases above.
OK, I sent you a Zulip message with lots of details.
This might be helpful:
julia> g = mk_function(Main, :(ctx -> (j -> Bernoulli(ctx.p/j))))
function = (ctx;) -> begin
begin
begin
let freevars = (ctx,)
(GeneralizedGenerated.Closure){function = (ctx, j;) -> begin
begin
(Main).Bernoulli((Main).:/(ctx.p, j))
end
end, Base.typeof(freevars)}(freevars)
end
end
end
end
julia> m3 = @model n begin
p ~ Uniform()
y ~ For(g((p=p,)), n)
end;
julia> rand(m3(3))
(p = 0.9464608557346492, y = Bool[1, 1, 0])
julia> @test_opt rand(m3(3))
JET-test failed at REPL[162]:1
Expression: #= REPL[162]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m3(3))
═════ 3 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#43(Base.pairs(Core.NamedTuple()), #self#, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#45(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
│││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││││┌ @ /home/chad/git/Soss.jl/src/primitives/interpret.jl:22 %48(%49)
││││││││ runtime dispatch detected: %48::Any(%49::NamedTuple{(:p,), Tuple{Float64}})
│││││││└─────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/git/MeasureTheory.jl/src/combinators/for.jl:263 MeasureTheory.For(%50, %53)
││││││││ runtime dispatch detected: MeasureTheory.For(%50::Any, %53::Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Int64}})
│││││││└──────────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/ngg/runtime_fns.jl:21 tilde_rand(:y, %54, %3, %45, %47)
││││││││ runtime dispatch detected: tilde_rand(:y::Symbol, %54::Any, %3::NamedTuple{(:rng, :args, :obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:n,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, %45::NamedTuple{(:p,), Tuple{Float64}}, %47::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}})
│││││││└────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
It seems to get pretty close, but then there's this:
│││││││┌ @ /home/chad/git/MeasureTheory.jl/src/combinators/for.jl:263 MeasureTheory.For(%50, %53)
││││││││ runtime dispatch detected: MeasureTheory.For(%50::Any, %53::Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Int64}})
The first argument passed to For is g((p=p,)). The compiler can figure this out just fine for a given p:
julia> Core.Compiler.return_type(g, Tuple{typeof((p = 0.2,))})
GeneralizedGenerated.Closure{function = (ctx, j;) -> begin
begin
(Main).Bernoulli((Main).:/(ctx.p, j))
end
end, Tuple{NamedTuple{(:p,), Tuple{Float64}}}}
But when it's inside another GG function, it falls back to Any.
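For comparison, the same nesting in plain Julia (no GG) infers concretely, which suggests the `Any` comes from how the enclosing GG function is compiled rather than from the kernel-building pattern itself; this is a guess, not something the thread confirms. A minimal sketch; `make_kernel` and `outer` are hypothetical names:

```julia
# Hypothetical analogue of `g`: build a kernel from a context NamedTuple
make_kernel(ctx) = Base.Fix1((ctx, j) -> ctx.p / j, ctx)

# Hypothetical analogue of the enclosing model body
outer(p, n) = map(make_kernel((p = p,)), 1:n)

Core.Compiler.return_type(make_kernel, Tuple{typeof((p = 0.2,))})  # a concrete Base.Fix1{...} type
Core.Compiler.return_type(outer, Tuple{Float64, Int})              # Vector{Float64}
outer(0.5, 2)                                                      # [0.5, 0.25]
```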