Allow more AdvancedHMC options in HMC types
As discussed on Slack, it would be very useful to expose more AdvancedHMC configuration options through Turing's `NUTS`, `HMC`, and `HMCDA` convenience constructors. In particular, it would be nice to be able to customize the integrator, the metric (values), the adaptor, or the termination criterion.

As far as I can tell, the main challenge is that the initial step size and metric are model-dependent, while `NUTS` and friends should be model-independent. The approach this PR takes is to introduce default metric types that, unless replaced by a specific user-provided object, are internally constructed from the model.

Currently this is only set up for `NUTS`. If this seems like a useful general approach to other devs, I'll continue with the `HMC` and `HMCDA` samplers.
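To make the mechanism concrete, here is a minimal sketch of the idea, assuming only that AdvancedHMC's Euclidean metric types can be constructed from a dimension; `make_metric` is a hypothetical helper for illustration, not the exact code in this PR:

```julia
using AdvancedHMC

# Hypothetical helper: the sampler stores a metric *type* (model-independent),
# and the concrete metric is built only once the model, and hence the
# dimension `d`, is known. AdvancedHMC's Euclidean metrics accept a dimension
# and default to a unit metric of that size.
make_metric(metricT::Type{<:AdvancedHMC.AbstractMetric}, d::Int) = metricT(d)

make_metric(DiagEuclideanMetric, 10)   # unit diagonal metric, dimension 10
make_metric(DenseEuclideanMetric, 10)  # unit dense metric, dimension 10
```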
## Example

```julia
julia> using Turing, AdvancedHMC
julia> @model function foo()
           x ~ filldist(Normal() * 1000, 10)
       end
foo (generic function with 2 methods)
julia> chns = sample(foo(), Turing.NUTS(), 1000; save_state=true) # same defaults
┌ Info: Found initial step size
└ ϵ = 1638.4
Sampling 100%|████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.49 seconds
Compute duration  = 0.49 seconds
parameters        = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean         std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol    Float64     Float64    Float64   Float64     Float64   Float64       Float64

        x[1]   -47.6449   1016.1548    32.1336   25.7694   1412.7727    0.9990     2895.0259
        x[2]    -3.2566   1022.3188    32.3286   21.1134   2216.4932    0.9990     4541.9943
        x[3]    14.5863   1031.0343    32.6042   27.1184   1719.4190    0.9993     3523.3996
        x[4]   -15.5568   1003.9893    31.7489   22.8348   1227.8352    0.9992     2516.0558
        x[5]    10.5619   1012.4916    32.0178   29.7603   1304.7482    0.9990     2673.6643
        x[6]    51.3310   1054.6483    33.3509   23.1870   1410.7269    0.9992     2890.8337
        x[7]   -11.5721    969.3929    30.6549   25.1998   1729.1101    0.9992     3543.2585
        x[8]   -18.5535   1077.2640    34.0661   31.5005   1439.2385    1.0008     2949.2593
        x[9]   -10.6400    999.5810    31.6095   26.3672   1299.9401    0.9990     2663.8117
       x[10]   -15.7143    999.1795    31.5968   26.6365   1511.8314    0.9991     3098.0152

Quantiles
  parameters         2.5%       25.0%      50.0%      75.0%       97.5%
      Symbol      Float64     Float64    Float64    Float64     Float64

        x[1]   -2045.2328   -737.6996   -58.0066   601.9530   1963.7858
        x[2]   -1934.4399   -719.0586   -11.8093   700.2544   1897.8098
        x[3]   -2058.7987   -686.3523   -27.1321   742.4192   2113.6651
        x[4]   -2038.9502   -687.3642     3.0949   637.3003   1958.2605
        x[5]   -1970.9327   -673.6777    25.0485   707.0809   2029.6881
        x[6]   -1958.0350   -663.9977    93.2271   798.4898   1994.8763
        x[7]   -1959.5376   -644.9694     7.7956   617.2869   1993.9105
        x[8]   -2101.6400   -778.0808   -37.4446   742.2334   2099.6595
        x[9]   -2111.1562   -687.0762     1.1653   615.2214   1982.0659
       x[10]   -2003.7428   -671.0353   -28.7991   654.2518   1913.1991

julia> chns.info.samplerstate.hamiltonian.metric # default metric is adapted
DiagEuclideanMetric([1.0351872421358126e6, 1.11 ...])
julia> chns.info.samplerstate.kernel.τ.integrator # default integrator's step size is adapted
Leapfrog(ϵ=0.933)
julia> integrator = JitteredLeapfrog(0.2, 0.1)
JitteredLeapfrog(ϵ0=0.2, jitter=0.1, ϵ=0.2)
julia> chns = sample(foo(), Turing.NUTS(; integrator, metricT=AdvancedHMC.DenseEuclideanMetric, adaptor=StepSizeAdaptor(0.8, integrator)), 1000; save_state=true)
Sampling 100%|████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.13 seconds
Compute duration  = 1.13 seconds
parameters        = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean         std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol    Float64     Float64    Float64   Float64     Float64   Float64       Float64

        x[1]   -24.8085    956.4862    30.2468   29.3441   1077.4406    0.9996      953.4872
        x[2]   -11.3046   1060.5540    33.5377   26.0390   1026.2949    0.9990      908.2256
        x[3]   -19.6913   1009.4485    31.9216   34.7410   1026.4851    0.9992      908.3939
        x[4]    10.9409   1044.2046    33.0206   31.3643   1201.9143    0.9991     1063.6410
        x[5]   -49.8740    999.2622    31.5994   32.9384   1165.3464    0.9992     1031.2800
        x[6]     4.5491    961.3624    30.4009   30.6247   1021.8930    0.9990      904.3301
        x[7]    67.1029    998.4159    31.5727   35.1677    937.8898    0.9991      829.9909
        x[8]     1.9122   1002.4672    31.7008   30.6341   1078.1651    0.9990      954.1284
        x[9]   -34.8320   1008.0000    31.8758   35.5399    830.5491    1.0031      734.9992
       x[10]   -22.7870    975.3529    30.8434   31.6282    873.5094    0.9991      773.0171

Quantiles
  parameters         2.5%       25.0%      50.0%      75.0%       97.5%
      Symbol      Float64     Float64    Float64    Float64     Float64

        x[1]   -1859.0547   -665.3047    -1.1679   625.7305   1812.4707
        x[2]   -2045.1465   -693.2438    -3.2864   696.4378   2066.3285
        x[3]   -1951.5782   -696.4558   -48.6351   661.1009   2024.9583
        x[4]   -2047.1058   -718.1381    37.8368   747.6024   1979.9771
        x[5]   -1947.0055   -709.0552   -45.4802   623.8888   1929.2428
        x[6]   -1964.3656   -648.1510    22.1216   665.4752   1815.8038
        x[7]   -1929.9386   -597.4158    62.1090   784.5026   1924.7035
        x[8]   -2064.1684   -652.1909    43.2797   671.0167   1795.4633
        x[9]   -2006.4426   -703.5645   -15.7514   624.2890   1962.0789
       x[10]   -2082.6195   -646.1086   -13.8338   622.7712   1896.3478

julia> chns.info.samplerstate.hamiltonian.metric # specified metric is not adapted
DenseEuclideanMetric(diag=[1.0, 1.0, 1.0, 1.0, 1.0, 1 ...])
julia> chns.info.samplerstate.kernel.τ.integrator # specified integrator's step size is adapted
JitteredLeapfrog(ϵ0=837.0, jitter=0.1, ϵ=0.2)
```

cc @cpfiffer @yebai
Some things I don't like about the current interface:
- To customize the adaptor, if I end up creating the `StepSizeAdaptor` myself, then I also need to create the integrator myself, and I don't benefit from Turing's mechanism for finding a good initial step size (see the sketch after this list).
- If I provide the integrator or metric, I may want to disable the step size or metric adaptation, respectively. Currently this requires creating the adaptor myself, but I'd prefer a simpler way to disable one or both adaptation schemes.
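For the first point, here is roughly the dance the current interface requires, as a sketch reusing the constructors from the example above (the step size 0.1 is a stand-in for the value Turing's initial-step-size search would otherwise find):

```julia
using Turing, AdvancedHMC

# Customizing the adaptor forces me to pick a step size by hand, since the
# adaptor needs a concrete integrator at construction time:
integrator = Leapfrog(0.1)  # 0.1 is a guess; no initial-step-size search
adaptor = StepSizeAdaptor(0.8, integrator)
sampler = Turing.NUTS(; integrator, adaptor)

# For the second point, a hypothetical flag like NUTS(; adapt_stepsize=false)
# would be simpler than hand-building an adaptor just to disable adaptation.
```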
@sethaxen apologies for the slow response; I'll take a look later this week.
Closed in favour of https://github.com/TuringLang/Turing.jl/pull/1997