Turing.jl
Broadcasting, @addlogprob! and PPL not functioning with PG and SMC samplers
Hello, I was trying to use the macro Turing.@addlogprob! in a model sampled with the PG sampler, but it did not work. I know this macro is meant for advanced users, but while trying to understand the problem I found that PG and SMC also fail with broadcasting (the model called demo_array in the code below). I implemented the coinflip model from your quickstart (here called demo) in four ways of varying 'complexity', and I tried four samplers: HMC, MH, SMC and PG. HMC and MH work with all four implementations. With PG and SMC, only the original model works:
using Turing
using Pkg
Pkg.status("Turing")
@model function demo(y)
    # Our prior belief about the probability of heads in a coin.
    p ~ Beta(1, 1)
    # The number of observations.
    N = length(y)
    for n in 1:N
        # Heads or tails of a coin are drawn from a Bernoulli distribution.
        y[n] ~ Bernoulli(p)
    end
end
@model function demo_array(x)
    # 'array'
    p ~ Beta(1, 1)
    x .~ Bernoulli(p)
    return
end
@model function demo_addprob(x)
    # 'logprob'
    p ~ Beta(1, 1)
    loglik = loglikelihood(Bernoulli(p), x)
    Turing.@addlogprob!(loglik)
    return
end
function function_demo_PPL(model, varinfo, context, x)
    p, varinfo = DynamicPPL.tilde_assume!!(
        context,
        Beta(1, 1),
        Turing.@varname(p),
        varinfo,
    )
    DynamicPPL.dot_tilde_observe!!(context, Bernoulli(p), x, Turing.@varname(x), varinfo)
end
demo_PPL(x) = Turing.Model(function_demo_PPL, (; x))
data = [true for _ in 1:20]
p_model = Dict(
    "demo" => demo,
    "array" => demo_array,
    "add_prob" => demo_addprob,
    "PPL" => demo_PPL,
)
p_sampler = Dict("HMC" => HMC(0.05, 10), "SMC" => SMC(), "MH" => MH(), "PG" => PG(20))
results = []
for model_name in ["demo", "array", "add_prob", "PPL"]
    for (sampler_name, sampler) in p_sampler
        c = sample(p_model[model_name](data), sampler, 1000)
        push!(results, [model_name, sampler_name, mean(c[:p])])
    end
end
println("Expected value p=", round(21/22, digits=2))
for (model_name, sampler, p) in results
    println("Model name: ", model_name, ", sampler: ", sampler, ", p=", round(p, digits=2))
end
This returns:
Status `~/alignExp/Project.toml`
⌃ [fce5fe82] Turing v0.21.13
Info Packages marked with ⌃ have new versions available and may be upgradable.
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:08
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:03
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:02
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:02
Expected value p=0.95
Model name: demo, sampler: HMC, p=0.95
Model name: demo, sampler: SMC, p=0.93
Model name: demo, sampler: MH, p=0.95
Model name: demo, sampler: PG, p=0.96
Model name: array, sampler: HMC, p=0.95
Model name: array, sampler: SMC, p=0.5
Model name: array, sampler: MH, p=0.96
Model name: array, sampler: PG, p=0.51
Model name: add_prob, sampler: HMC, p=0.96
Model name: add_prob, sampler: SMC, p=0.5
Model name: add_prob, sampler: MH, p=0.95
Model name: add_prob, sampler: PG, p=0.5
Model name: PPL, sampler: HMC, p=0.96
Model name: PPL, sampler: SMC, p=0.5
Model name: PPL, sampler: MH, p=0.95
Model name: PPL, sampler: PG, p=0.5
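For reference, the expected value printed by the script (21/22) can be derived from Beta-Bernoulli conjugacy. The short derivation below is a sketch that assumes the setup in the code above: a Beta(1, 1) prior on p and 20 observations that are all heads. The symbols y_{1:20} and h (the number of heads) are introduced here only for illustration.

```latex
% Prior: p ~ Beta(1, 1); data: 20 Bernoulli(p) draws with h = 20 heads.
\begin{aligned}
p \mid y_{1:20} &\sim \mathrm{Beta}(1 + h,\; 1 + 20 - h) = \mathrm{Beta}(21,\, 1), \\
\mathbb{E}[p \mid y_{1:20}] &= \frac{21}{21 + 1} = \frac{21}{22} \approx 0.955, \\
\mathbb{E}[p] &= \frac{1}{1 + 1} = 0.5 \quad \text{(prior mean, no data)}.
\end{aligned}
```

The estimates of 0.5 and 0.51 reported for SMC and PG on the array, add_prob and PPL models coincide with the prior mean rather than the posterior mean. That would be consistent with the likelihood term not being taken into account by these samplers in those three cases, although this is only an observation from the numbers and not a confirmed diagnosis.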