Turing.jl

Allow more AdvancedHMC options in HMC types

sethaxen opened this pull request • 2 comments

As discussed on Slack, it would be very useful to expose more AdvancedHMC configuration options through the NUTS, HMC, and HMCDA convenience constructors in Turing. In particular, it would be nice to be able to customize the integrator, the metric (including its values, not only its type), the adaptor, or the termination criterion.

As far as I can tell, the main challenge is that the initial step size and the metric are model-dependent, while NUTS, HMC, and HMCDA should be model-independent. This PR introduces default metric types that, unless replaced by a user-provided object, are constructed internally from the model.
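The deferred-construction idea can be illustrated with a small, self-contained Julia sketch. Everything below (AbstractMetric, DiagMetric, DenseMetric, DefaultMetric, MySampler, resolve_metric) is a hypothetical placeholder, not this PR's code and not AdvancedHMC's API. The sampler stores a sentinel type that records only which kind of metric to build, and the concrete metric is created later, once the model's dimension is known:

```julia
using LinearAlgebra  # for the identity operator `I`

# Hypothetical placeholder types; not Turing.jl or AdvancedHMC API.
abstract type AbstractMetric end

struct DiagMetric <: AbstractMetric
    diag::Vector{Float64}
end

struct DenseMetric <: AbstractMetric
    M::Matrix{Float64}
end

# Sentinel meaning "build a metric of type T once the model is known".
struct DefaultMetric{T<:AbstractMetric} end
DefaultMetric() = DefaultMetric{DiagMetric}()

# Model-independent sampler configuration, in the spirit of NUTS().
struct MySampler{M}
    metric::M
end
MySampler(; metric=DefaultMetric()) = MySampler(metric)

# Called at sampling time, when the model's dimension `dim` is available.
resolve_metric(m::AbstractMetric, dim::Int) = m  # user-provided: used as-is
resolve_metric(::DefaultMetric{DiagMetric}, dim::Int) = DiagMetric(ones(dim))
resolve_metric(::DefaultMetric{DenseMetric}, dim::Int) =
    DenseMetric(Matrix{Float64}(I, dim, dim))
```

In this sketch MySampler() carries no model information at all; resolve_metric dispatches on the sentinel to build an identity metric of the right size, while a concrete metric supplied by the user passes through unchanged.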

Currently this is only set up for NUTS. If other devs think this is a useful general approach, I'll extend it to the HMC and HMCDA samplers.

Example:
julia> using Turing, AdvancedHMC

julia> @model function foo()
           x ~ filldist(Normal() * 1000, 10)
       end
foo (generic function with 2 methods)

julia> chns = sample(foo(), Turing.NUTS(), 1000; save_state=true)  # same defaults as before this PR
┌ Info: Found initial step size
└   ϵ = 1638.4
Sampling 100%|████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.49 seconds
Compute duration  = 0.49 seconds
parameters        = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean         std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol    Float64     Float64    Float64   Float64     Float64   Float64       Float64 

        x[1]   -47.6449   1016.1548    32.1336   25.7694   1412.7727    0.9990     2895.0259
        x[2]    -3.2566   1022.3188    32.3286   21.1134   2216.4932    0.9990     4541.9943
        x[3]    14.5863   1031.0343    32.6042   27.1184   1719.4190    0.9993     3523.3996
        x[4]   -15.5568   1003.9893    31.7489   22.8348   1227.8352    0.9992     2516.0558
        x[5]    10.5619   1012.4916    32.0178   29.7603   1304.7482    0.9990     2673.6643
        x[6]    51.3310   1054.6483    33.3509   23.1870   1410.7269    0.9992     2890.8337
        x[7]   -11.5721    969.3929    30.6549   25.1998   1729.1101    0.9992     3543.2585
        x[8]   -18.5535   1077.2640    34.0661   31.5005   1439.2385    1.0008     2949.2593
        x[9]   -10.6400    999.5810    31.6095   26.3672   1299.9401    0.9990     2663.8117
       x[10]   -15.7143    999.1795    31.5968   26.6365   1511.8314    0.9991     3098.0152

Quantiles
  parameters         2.5%       25.0%      50.0%      75.0%       97.5% 
      Symbol      Float64     Float64    Float64    Float64     Float64 

        x[1]   -2045.2328   -737.6996   -58.0066   601.9530   1963.7858
        x[2]   -1934.4399   -719.0586   -11.8093   700.2544   1897.8098
        x[3]   -2058.7987   -686.3523   -27.1321   742.4192   2113.6651
        x[4]   -2038.9502   -687.3642     3.0949   637.3003   1958.2605
        x[5]   -1970.9327   -673.6777    25.0485   707.0809   2029.6881
        x[6]   -1958.0350   -663.9977    93.2271   798.4898   1994.8763
        x[7]   -1959.5376   -644.9694     7.7956   617.2869   1993.9105
        x[8]   -2101.6400   -778.0808   -37.4446   742.2334   2099.6595
        x[9]   -2111.1562   -687.0762     1.1653   615.2214   1982.0659
       x[10]   -2003.7428   -671.0353   -28.7991   654.2518   1913.1991


julia> chns.info.samplerstate.hamiltonian.metric  # default metric is adapted
DiagEuclideanMetric([1.0351872421358126e6, 1.11 ...])

julia> chns.info.samplerstate.kernel.τ.integrator  # default integrator's step size is adapted
Leapfrog(ϵ=0.933)

julia> integrator = JitteredLeapfrog(0.2, 0.1)
JitteredLeapfrog(ϵ0=0.2, jitter=0.1, ϵ=0.2)

julia> chns = sample(foo(), Turing.NUTS(; integrator, metricT=AdvancedHMC.DenseEuclideanMetric, adaptor=StepSizeAdaptor(0.8, integrator)), 1000; save_state=true)
Sampling 100%|████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.13 seconds
Compute duration  = 1.13 seconds
parameters        = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean         std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol    Float64     Float64    Float64   Float64     Float64   Float64       Float64 

        x[1]   -24.8085    956.4862    30.2468   29.3441   1077.4406    0.9996      953.4872
        x[2]   -11.3046   1060.5540    33.5377   26.0390   1026.2949    0.9990      908.2256
        x[3]   -19.6913   1009.4485    31.9216   34.7410   1026.4851    0.9992      908.3939
        x[4]    10.9409   1044.2046    33.0206   31.3643   1201.9143    0.9991     1063.6410
        x[5]   -49.8740    999.2622    31.5994   32.9384   1165.3464    0.9992     1031.2800
        x[6]     4.5491    961.3624    30.4009   30.6247   1021.8930    0.9990      904.3301
        x[7]    67.1029    998.4159    31.5727   35.1677    937.8898    0.9991      829.9909
        x[8]     1.9122   1002.4672    31.7008   30.6341   1078.1651    0.9990      954.1284
        x[9]   -34.8320   1008.0000    31.8758   35.5399    830.5491    1.0031      734.9992
       x[10]   -22.7870    975.3529    30.8434   31.6282    873.5094    0.9991      773.0171

Quantiles
  parameters         2.5%       25.0%      50.0%      75.0%       97.5% 
      Symbol      Float64     Float64    Float64    Float64     Float64 

        x[1]   -1859.0547   -665.3047    -1.1679   625.7305   1812.4707
        x[2]   -2045.1465   -693.2438    -3.2864   696.4378   2066.3285
        x[3]   -1951.5782   -696.4558   -48.6351   661.1009   2024.9583
        x[4]   -2047.1058   -718.1381    37.8368   747.6024   1979.9771
        x[5]   -1947.0055   -709.0552   -45.4802   623.8888   1929.2428
        x[6]   -1964.3656   -648.1510    22.1216   665.4752   1815.8038
        x[7]   -1929.9386   -597.4158    62.1090   784.5026   1924.7035
        x[8]   -2064.1684   -652.1909    43.2797   671.0167   1795.4633
        x[9]   -2006.4426   -703.5645   -15.7514   624.2890   1962.0789
       x[10]   -2082.6195   -646.1086   -13.8338   622.7712   1896.3478


julia> chns.info.samplerstate.hamiltonian.metric  # specified metric is not adapted
DenseEuclideanMetric(diag=[1.0, 1.0, 1.0, 1.0, 1.0, 1 ...])

julia> chns.info.samplerstate.kernel.τ.integrator  # specified integrator's step size is adapted
JitteredLeapfrog(ϵ0=837.0, jitter=0.1, ϵ=0.2)

cc @cpfiffer @yebai

sethaxen, Apr 16 '22 12:04

Some things I don't like about the current interface:

  • If I customize the adaptor by creating the StepSizeAdaptor myself, I also need to create the integrator myself, so I lose Turing's mechanism for finding a good initial step size.
  • If I provide the integrator or metric, I may want to disable the step size or metric adaptation, respectively. Currently this requires creating the adaptor myself, but I'd prefer a simpler way to disable one or both adaptation schemes.
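One hypothetical shape for the simpler switch asked for in the second point is a pair of boolean keywords that select which adaptation components get built. The keyword names adapt_stepsize and adapt_metric, the function build_adaptor, and the adaptor structs below are invented placeholders standing in for AdvancedHMC's adaptors; they are not part of this PR, Turing.jl, or AdvancedHMC:

```julia
# Hypothetical placeholder adaptor types; not AdvancedHMC's.
struct NoAdaptation end
struct StepSizeOnly
    target_accept::Float64
end
struct MetricOnly end
struct StepSizeAndMetric
    target_accept::Float64
end

# Pick which adaptation schemes to run from two independent flags,
# so users never have to construct an adaptor by hand just to turn one off.
function build_adaptor(; adapt_stepsize::Bool=true, adapt_metric::Bool=true,
                       target_accept::Float64=0.8)
    if adapt_stepsize && adapt_metric
        return StepSizeAndMetric(target_accept)
    elseif adapt_stepsize
        return StepSizeOnly(target_accept)
    elseif adapt_metric
        return MetricOnly()
    else
        return NoAdaptation()
    end
end
```

In an actual implementation each branch would construct the corresponding AdvancedHMC adaptor around the (possibly internally constructed) integrator and metric, which would also keep Turing's initial step-size search available when only the adaptation is changed.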

sethaxen, Apr 16 '22 21:04

@sethaxen apologies for the slow response; I'll take a look later this week.

yebai, Apr 26 '22 12:04

Closed in favour of https://github.com/TuringLang/Turing.jl/pull/1997

yebai, Jun 06 '23 09:06