Turing.jl

Gibbs sampler does not carry through log-prob from an external sampler

Open ruarai opened this issue 7 months ago • 8 comments

Minimal working example

using Turing, StatsPlots

@model function test_model(y)
    a ~ Normal(0, 1)

    y ~ Normal(a, 1)
end

y = rand(Normal(0.5, 1), 100)

model = test_model(y)

using AdvancedMH
rwmh = AdvancedMH.RWMH(1)

chain_1 = sample(model, externalsampler(rwmh), 100);
chain_2 = sample(model, Gibbs(:a => MH()), 100);
chain_3 = sample(model, Gibbs(:a => externalsampler(rwmh)), 100);


plot(chain_1, [:a, :lp], seriestype = :traceplot)
plot(chain_2, [:a, :lp], seriestype = :traceplot)
plot(chain_3, [:a, :lp], seriestype = :traceplot)

Description

I have noticed that the Gibbs sampler does not seem to carry through the log-probability when used with an ExternalSampler.

In the MWE, chain_1 and chain_2 both have valid (varying) lp values, while chain_3 has the same lp value at every iteration.

I can get chain_3 to work with the following hack, though it depends on my transition type having an lp field:

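# Workaround: override how Gibbs rebuilds a VarInfo from an external sampler's state.
# Assumes the external sampler's transition stores its log density in an `lp` field.
function Turing.Inference.varinfo(state::Turing.Inference.TuringState)
    θ = Turing.Inference.getparams(state.ldf.model, state.state)
    # unflatten only sets the parameter values; it does not update logp
    vi = DynamicPPL.unflatten(state.ldf.varinfo, θ)

    # copy the sampler's own log density into the VarInfo
    vi = DynamicPPL.setlogp!!(vi, state.state.transition.lp)

    return vi
end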

Julia version info

versioninfo()
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD EPYC 7702 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 64 default, 0 interactive, 32 GC (on 32 virtual cores)
Environment:
  JULIA_VSCODE_REPL = 1
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 64

Manifest

(seromix) pkg> st --manifest
Status `/pvol/source/seromix/Manifest.toml`
  [47edcb42] ADTypes v1.14.0
  [621f4979] AbstractFFTs v1.5.0
  [80f14c24] AbstractMCMC v5.6.2
  [7a57a42e] AbstractPPL v0.11.0
  [1520ce14] AbstractTrees v0.4.5
  [7d9f7c33] Accessors v0.1.42
  [79e6a3ab] Adapt v4.3.0
⌅ [0bf59076] AdvancedHMC v0.7.1
  [5b7e9947] AdvancedMH v0.8.7
  [576499cb] AdvancedPS v0.6.2
⌅ [b5ca4192] AdvancedVI v0.2.12
  [66dad0bd] AliasTables v1.1.3
  [dce04be8] ArgCheck v2.5.0
  [ec485272] ArnoldiMethod v0.4.0
  [7d9fca2a] Arpack v0.5.4
  [4fba245c] ArrayInterface v7.19.0
  [a9b6321e] Atomix v1.1.1
  [13072b0f] AxisAlgorithms v1.1.0
  [39de3d68] AxisArrays v0.4.7
  [198e06fe] BangBang v0.4.4
  [9718e550] Baselet v0.1.1
  [6e4b80f9] BenchmarkTools v1.6.0
  [76274a88] Bijectors v0.15.7
  [d1d4a3ce] BitFlags v0.1.9
  [c3b6d118] BitIntegers v0.3.5
  [fa961155] CEnum v0.5.0
  [324d7699] CategoricalArrays v0.10.8
⌃ [082447d4] ChainRules v1.72.3
  [d360d2e6] ChainRulesCore v1.25.1
  [0ca39b1e] Chairmarks v1.3.1
  [9e997f8a] ChangesOfVariables v0.1.10
  [aaaa29a8] Clustering v0.15.8
  [944b1d66] CodecZlib v0.7.8
  [6b39b394] CodecZstd v0.8.6
  [35d6a980] ColorSchemes v3.29.0
  [3da002f7] ColorTypes v0.12.1
  [c3611d14] ColorVectorSpace v0.11.0
  [5ae59095] Colors v0.13.1
  [861a8166] Combinatorics v1.0.3
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.1
  [34da2185] Compat v4.16.0
  [a33af91c] CompositionsBase v0.1.2
  [f0e56b4a] ConcurrentUtilities v2.5.0
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.8
  [d38c429a] Contour v0.6.3
  [a8cc5b0e] Crayons v4.1.1
  [a10d1c49] DBInterface v2.6.1
  [9a962f9c] DataAPI v1.16.0
  [a93c6f00] DataFrames v1.7.0
  [864edb3b] DataStructures v0.18.22
  [e2d170a0] DataValueInterfaces v1.0.0
⌅ [abce61dc] Decimals v0.4.1
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [b429d917] DensityInterface v0.4.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
⌅ [a0c0ee7d] DifferentiationInterface v0.6.54
  [b4f34e82] Distances v0.10.12
  [31c24e10] Distributions v0.25.120
  [ced4e74d] DistributionsAD v0.6.58
  [ffbed154] DocStringExtensions v0.9.4
  [d2f5444f] DuckDB v1.3.0
  [366bfd00] DynamicPPL v0.36.10
  [cad2338a] EllipticalSliceSampling v2.0.0
  [4e289a0a] EnumX v1.0.5
  [460bff9d] ExceptionUnwrapping v0.1.11
  [e2ba6199] ExprTools v0.1.10
  [55351af7] ExproniconLite v0.10.14
  [c87230d0] FFMPEG v0.4.2
⌃ [7a1cc6ca] FFTW v1.8.1
  [9aa1b823] FastClosures v0.3.2
  [5789e2e9] FileIO v1.17.0
  [1a297f60] FillArrays v1.13.0
  [6a86dc24] FiniteDiff v2.27.0
  [fb4d412d] FixedPointDecimals v0.6.3
  [53c48c17] FixedPointNumbers v0.8.5
  [1fa38f19] Format v1.3.7
⌅ [f6369f11] ForwardDiff v0.10.38
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.5.2
  [46192b85] GPUArraysCore v0.2.0
  [28b8d3ca] GR v0.73.16
  [86223c79] Graphs v1.12.1
  [42e2da0e] Grisu v1.0.2
  [f67ccb44] HDF5 v0.17.2
  [cd3eb016] HTTP v1.10.16
  [076d061b] HashArrayMappedTries v0.2.0
  [34004b35] HypergeometricFunctions v0.3.28
  [d25df0c9] Inflate v0.1.5
  [22cec73e] InitialValues v0.3.1
  [842dd82b] InlineStrings v1.4.3
⌅ [a98d9a8b] Interpolations v0.15.1
  [8197267c] IntervalSets v0.7.11
  [3587e190] InverseFunctions v0.1.17
  [41ab1584] InvertedIndices v1.3.1
  [92d709cd] IrrationalConstants v0.2.4
  [c8e1da08] IterTools v1.10.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [033835bb] JLD2 v0.5.13
  [1019f520] JLFzf v0.1.11
  [692b3bcd] JLLWrappers v1.7.0
  [682c06a0] JSON v0.21.4
  [ae98c720] Jieko v0.2.1
  [63c18a36] KernelAbstractions v0.9.34
  [5ab0869b] KernelDensity v0.6.9
  [5be7bae1] LBFGSB v0.4.1
  [8ac3fa9e] LRUCache v1.6.2
  [b964fa9f] LaTeXStrings v1.4.0
  [23fbe1c1] Latexify v0.16.8
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
⌅ [6f1fad26] Libtask v0.8.8
  [d3d80556] LineSearches v7.3.0
  [1724a1d5] LittleEndianBase128 v0.3.0
  [6fdf6af0] LogDensityProblems v2.1.2
  [996a588d] LogDensityProblemsAD v1.13.0
  [2ab3a3ac] LogExpFunctions v0.3.29
  [e6f89c97] LoggingExtras v1.1.0
⌃ [c7f686f2] MCMCChains v6.0.7
  [be115224] MCMCDiagnosticTools v0.3.14
  [e80e1ace] MLJModelInterface v1.11.1
  [3da0fdf6] MPIPreferences v0.1.11
  [1914dd2f] MacroTools v0.5.16
  [dbb5928d] MappedArrays v0.4.2
  [739be429] MbedTLS v1.1.9
  [442fdcdd] Measures v0.3.2
  [128add7d] MicroCollections v0.2.0
  [e1d29d7a] Missings v1.2.0
  [dbe65cb8] MistyClosures v2.0.0
  [da2b9cff] Mooncake v0.4.121
  [2e0e35c7] Moshi v0.3.5
  [6f286f6a] MultivariateStats v0.10.3
  [d41bc354] NLSolversBase v7.9.1
  [872c559c] NNlib v0.9.30
  [77ba4419] NaNMath v1.1.3
  [86f7a689] NamedArrays v0.10.4
  [c020b1a1] NaturalSort v1.0.0
  [b8a86587] NearestNeighbors v0.4.21
  [510215fc] Observables v0.5.5
  [6fe1bfb0] OffsetArrays v1.17.0
  [4d8831e6] OpenSSL v1.5.0
  [429524aa] Optim v1.12.0
  [3bd65402] Optimisers v0.4.6
  [7f7a1694] Optimization v4.3.0
⌃ [bca83a33] OptimizationBase v2.6.0
  [36348300] OptimizationOptimJL v0.4.3
  [bac558e1] OrderedCollections v1.8.1
  [90014a1f] PDMats v0.11.35
  [d96e819e] Parameters v0.12.3
  [626c502c] Parquet v0.8.5
  [69de0a69] Parsers v2.8.3
  [ccf2f8ad] PlotThemes v3.3.0
  [995b91a9] PlotUtils v1.4.3
  [91a5bcdd] Plots v1.40.13
  [2dfb63ee] PooledArrays v1.4.3
  [85a6dd25] PositiveFactorizations v0.2.4
⌅ [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.4.3
  [08abe8d2] PrettyTables v2.4.0
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.10.4
  [43287f4e] PtrArrays v1.3.0
  [7b8617ff] QuackIO v0.1.5
  [1fd47b50] QuadGK v2.11.2
  [74087812] Random123 v1.7.1
  [e6cf234a] RandomNumbers v1.6.0
  [b3c3ace0] RangeArrays v0.3.2
  [c84ed2f1] Ratios v0.4.5
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [01d81517] RecipesPipeline v0.6.12
  [731186ca] RecursiveArrayTools v3.33.0
  [189a3867] Reexport v1.2.2
  [05181044] RelocatableFolders v1.0.1
  [ae029012] Requires v1.3.1
  [79098fc4] Rmath v0.8.0
  [f2b01f46] Roots v2.2.7
  [7e49a35a] RuntimeGeneratedFunctions v0.5.15
  [26aad666] SSMProblems v0.5.0
  [0bca4576] SciMLBase v2.96.0
⌃ [c0aeaf25] SciMLOperators v0.4.0
  [53ae85a6] SciMLStructures v1.7.0
  [30f210dd] ScientificTypesBase v3.0.0
  [7e506255] ScopedValues v1.3.0
  [6c6a2e73] Scratch v1.2.1
  [91c51154] SentinelArrays v1.4.8
  [efcf1570] Setfield v1.1.2
  [992d4aef] Showoff v1.0.3
  [777ac1f9] SimpleBufferStream v1.2.0
  [699a6c99] SimpleTraits v0.9.4
  [59d4ed8c] Snappy v0.4.3
  [a2af1166] SortingAlgorithms v1.2.1
  [9f842d2f] SparseConnectivityTracer v0.6.18
  [dc90abb0] SparseInverseSubset v0.1.2
  [0a514795] SparseMatrixColorings v0.4.20
  [276daf66] SpecialFunctions v2.5.1
  [171d559e] SplittablesBase v0.1.15
  [860ef19b] StableRNGs v1.0.3
  [90137ffa] StaticArrays v1.9.13
  [1e83bf80] StaticArraysCore v1.4.3
  [64bff920] StatisticalTraits v3.4.0
  [10745b16] Statistics v1.11.1
  [82ae8749] StatsAPI v1.7.1
  [2913bbd2] StatsBase v0.34.5
  [4c63d2b9] StatsFuns v1.5.0
  [f3b207a7] StatsPlots v0.15.7
  [892a3eda] StringManipulation v0.4.1
  [09ab397b] StructArrays v0.7.1
  [2efcf032] SymbolicIndexingInterface v0.3.40
  [ab02a1b2] TableOperations v1.2.0
  [3783bdb8] TableTraits v1.0.1
⌃ [bd369af6] Tables v1.12.0
  [62fd8b95] TensorCore v0.1.1
  [5d786b92] TerminalLoggers v0.1.7
  [8d9c9c80] Thrift v0.8.5
  [9f7883ad] Tracker v0.2.38
  [3bb67fe8] TranscodingStreams v0.11.3
  [28d57a85] Transducers v0.4.84
⌃ [fce5fe82] Turing v0.38.3
  [5c2747f8] URIs v1.5.2
  [3a884ed6] UnPack v1.0.2
  [1cfade01] UnicodeFun v0.4.1
  [1986cc42] Unitful v1.22.1
⌃ [45397f5d] UnitfulLatexify v1.6.4
  [013be700] UnsafeAtomics v0.3.0
  [41fe7b60] Unzip v0.2.0
  [ea10d353] WeakRefStrings v1.4.2
  [cc8bc4a8] Widgets v0.6.7
  [efce3f68] WoodburyMatrices v1.0.0
  [700de1a5] ZygoteRules v0.2.7
⌅ [68821587] Arpack_jll v3.5.1+1
  [6e34b625] Bzip2_jll v1.0.9+0
  [83423d85] Cairo_jll v1.18.5+0
  [ee1fde0b] Dbus_jll v1.16.2+0
  [2cbbab25] DuckDB_jll v1.3.0+0
  [2702e6a9] EpollShim_jll v0.0.20230411+1
  [2e619515] Expat_jll v2.6.5+0
⌅ [b22a6f82] FFMPEG_jll v4.4.4+1
  [f5851436] FFTW_jll v3.3.11+0
  [a3f928ae] Fontconfig_jll v2.16.0+0
  [d7e528f0] FreeType2_jll v2.13.4+0
  [559328eb] FriBidi_jll v1.0.17+0
  [0656b61e] GLFW_jll v3.4.0+2
  [d2c73de3] GR_jll v0.73.16+0
  [78b55507] Gettext_jll v0.21.0+0
  [7746bdde] Glib_jll v2.84.0+0
  [3b182d85] Graphite2_jll v1.3.15+0
  [0234f1f7] HDF5_jll v1.14.6+0
  [2e76f6c2] HarfBuzz_jll v8.5.1+0
  [e33a78d0] Hwloc_jll v2.12.1+0
  [1d5cc7b8] IntelOpenMP_jll v2025.0.4+0
  [aacddb02] JpegTurbo_jll v3.1.1+0
  [c1c5ebd0] LAME_jll v3.100.2+0
  [88015f11] LERC_jll v4.0.1+0
  [1d63c593] LLVMOpenMP_jll v18.1.8+0
  [dd4b983a] LZO_jll v2.10.3+0
  [81d17ec3] L_BFGS_B_jll v3.0.1+0
  [e9f186c6] Libffi_jll v3.4.7+0
  [7e76a0d4] Libglvnd_jll v1.7.1+1
  [94ce4f54] Libiconv_jll v1.18.0+0
  [4b2f31a3] Libmount_jll v2.41.0+0
  [89763e89] Libtiff_jll v4.7.1+0
  [38a345b3] Libuuid_jll v2.41.0+0
  [856f044c] MKL_jll v2025.0.1+1
  [7cb0a576] MPICH_jll v4.3.0+1
  [f1f71cc9] MPItrampoline_jll v5.5.3+0
  [9237b28f] MicrosoftMPI_jll v10.1.4+3
  [e7412a2a] Ogg_jll v1.3.5+1
  [fe0851c0] OpenMPI_jll v5.0.7+2
  [458c3c95] OpenSSL_jll v3.5.0+0
  [efe28fd5] OpenSpecFun_jll v0.5.6+0
  [91d4177d] Opus_jll v1.3.3+0
  [36c8627f] Pango_jll v1.56.3+0
⌅ [30392449] Pixman_jll v0.44.2+0
  [c0090381] Qt6Base_jll v6.8.2+1
  [629bc702] Qt6Declarative_jll v6.8.2+1
  [ce943373] Qt6ShaderTools_jll v6.8.2+1
  [e99dba38] Qt6Wayland_jll v6.8.2+0
  [f50d1b31] Rmath_jll v0.5.1+0
  [815b9798] ThriftJuliaCompiler_jll v0.12.1+0
  [a44049a8] Vulkan_Loader_jll v1.3.243+0
  [a2964d1f] Wayland_jll v1.23.1+0
  [2381bf8a] Wayland_protocols_jll v1.36.0+0
⌅ [02c8fc9c] XML2_jll v2.13.6+1
  [ffd25f8a] XZ_jll v5.8.1+0
  [f67eecfb] Xorg_libICE_jll v1.1.2+0
  [c834827a] Xorg_libSM_jll v1.2.6+0
  [4f6342f7] Xorg_libX11_jll v1.8.12+0
  [0c0b7dd1] Xorg_libXau_jll v1.0.13+0
  [935fb764] Xorg_libXcursor_jll v1.2.4+0
  [a3789734] Xorg_libXdmcp_jll v1.1.6+0
  [1082639a] Xorg_libXext_jll v1.3.7+0
  [d091e8ba] Xorg_libXfixes_jll v6.0.1+0
  [a51aa0fd] Xorg_libXi_jll v1.8.3+0
  [d1454406] Xorg_libXinerama_jll v1.1.6+0
  [ec84b674] Xorg_libXrandr_jll v1.5.5+0
  [ea2f1a96] Xorg_libXrender_jll v0.9.12+0
  [c7cfdc94] Xorg_libxcb_jll v1.17.1+0
  [cc61e674] Xorg_libxkbfile_jll v1.1.3+0
  [e920d4aa] Xorg_xcb_util_cursor_jll v0.1.4+0
  [12413925] Xorg_xcb_util_image_jll v0.4.0+1
  [2def613f] Xorg_xcb_util_jll v0.4.0+1
  [975044d2] Xorg_xcb_util_keysyms_jll v0.4.0+1
  [0d47668e] Xorg_xcb_util_renderutil_jll v0.3.9+1
  [c22f9ab0] Xorg_xcb_util_wm_jll v0.4.1+1
  [35661453] Xorg_xkbcomp_jll v1.4.7+0
  [33bec58e] Xorg_xkeyboard_config_jll v2.44.0+0
  [c5fb5394] Xorg_xtrans_jll v1.6.0+0
  [3161d3a3] Zstd_jll v1.5.7+1
  [35ca27e7] eudev_jll v3.2.9+0
  [214eeab7] fzf_jll v0.61.1+0
  [1a1c6b14] gperf_jll v3.3.0+0
  [477f73a3] libaec_jll v1.1.3+0
  [a4ae2306] libaom_jll v3.11.0+0
  [0ac62f75] libass_jll v0.15.2+0
  [1183f4f0] libdecor_jll v0.2.2+0
  [2db6ffa8] libevdev_jll v1.11.0+0
  [f638f0a6] libfdk_aac_jll v2.0.3+0
  [36db933b] libinput_jll v1.18.0+0
  [b53b4c65] libpng_jll v1.6.48+0
  [f27f6e37] libvorbis_jll v1.3.7+2
  [009596ad] mtdev_jll v1.1.6+0
  [1317d2d5] oneTBB_jll v2022.0.0+0
  [fe1e1685] snappy_jll v1.2.3+0
⌅ [1270edf5] x264_jll v2021.5.5+0
⌅ [dfaa095f] x265_jll v3.5.0+0
  [d8fb68d0] xkbcommon_jll v1.8.1+0
  [0dad84c5] ArgTools v1.1.2
  [56f22d72] Artifacts v1.11.0
  [2a0f44e3] Base64 v1.11.0
  [ade2ca70] Dates v1.11.0
  [8ba89e20] Distributed v1.11.0
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching v1.11.0
  [9fa8497b] Future v1.11.0
  [b77e0a4c] InteractiveUtils v1.11.0
  [4af54fe1] LazyArtifacts v1.11.0
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2 v1.11.0
  [8f399da3] Libdl v1.11.0
  [37e2e46d] LinearAlgebra v1.11.0
  [56ddb016] Logging v1.11.0
  [d6f4376e] Markdown v1.11.0
  [a63ad114] Mmap v1.11.0
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.11.0
  [de0858da] Printf v1.11.0
  [9abbd945] Profile v1.11.0
  [3fa0cd96] REPL v1.11.0
  [9a3f8284] Random v1.11.0
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization v1.11.0
  [1a1011a3] SharedArrays v1.11.0
  [6462fe0b] Sockets v1.11.0
  [2f01184e] SparseArrays v1.11.0
  [f489334b] StyledStrings v1.11.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test v1.11.0
  [cf7118a7] UUIDs v1.11.0
  [4ec0a83e] Unicode v1.11.0
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [deac9b47] LibCURL_jll v8.6.0+0
  [e37daf67] LibGit2_jll v1.7.2+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.6+0
  [14a3606d] MozillaCACerts_jll v2023.12.12
  [4536629a] OpenBLAS_jll v0.3.27+1
  [05823500] OpenLibm_jll v0.8.1+4
  [efcefdf7] PCRE2_jll v10.42.0+1
  [bea87d4a] SuiteSparse_jll v7.7.0+0
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.11.0+0
  [8e850ede] nghttp2_jll v1.59.0+0
  [3f19e933] p7zip_jll v17.4.0+2

ruarai, Jun 05 '25 03:06

Thanks for reporting, @ruarai, and for the excellent writeup!

This looks bad -- I'll have a look today.

penelopeysm, Jun 05 '25 11:06

I'm actually quite surprised that chain_1 works for you. When I run your MWE, my chain_1 lp doesn't change, and that's traceable to this:

https://github.com/TuringLang/Turing.jl/blob/58232e4785050b58e23b36b4b65f0eedf957ef71/src/mcmc/external_sampler.jl#L74-L81

unflatten doesn't update the logp; it only sets the values. When Transition(...) is called, it calls values_as_in_model, which does evaluate the model and therefore recalculates logp, but there's no way to extract that value because values_as_in_model doesn't return it.

This can be fixed through various (rather ugly) means, for example after the call to unflatten, adding

  varinfo = DynamicPPL.unflatten(f.varinfo, θ)
+ # need to recalculate logp
+ _, varinfo = DynamicPPL.evaluate!!(f.model, varinfo, DefaultContext())
  return Transition(f.model, varinfo, transition)

to recalculate logp before constructing the Transition.
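To make the failure mode concrete, here is a toy model of it in plain Julia. `ToyVarInfo`, `toy_unflatten` and `toy_evaluate` are hypothetical stand-ins, not DynamicPPL's API: the point is only that overwriting the values of a structure that also caches a log density leaves that cache stale until the density is recomputed, which is what a constant lp trace looks like.

```julia
# Toy stand-in for a VarInfo: parameter values plus a cached log density.
# Hypothetical; DynamicPPL's real VarInfo is much richer.
struct ToyVarInfo
    θ::Vector{Float64}
    logp::Float64
end

# Toy target: standard normal log density, up to an additive constant
toy_logdensity(θ) = -sum(abs2, θ) / 2

# Analogue of `unflatten` as described above: overwrite values, keep the old logp
toy_unflatten(vi::ToyVarInfo, θ) = ToyVarInfo(copy(θ), vi.logp)

# Analogue of re-evaluating the model: recompute logp from the current values
toy_evaluate(vi::ToyVarInfo) = ToyVarInfo(vi.θ, toy_logdensity(vi.θ))

vi0 = toy_evaluate(ToyVarInfo([1.0], NaN))  # logp == -0.5
vi1 = toy_unflatten(vi0, [2.0])             # θ == [2.0], but logp is still -0.5 (stale)
vi2 = toy_evaluate(vi1)                     # logp == -2.0
```

Every new draw that goes only through the `toy_unflatten` step keeps reporting the logp of whatever state was last evaluated.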

A more proper fix would be to modify the interface of values_as_in_model to somehow let us get the logp field, so that we don't have to unnecessarily evaluate the model.


A similar fix inside varinfo(state::TuringState), the original function you mentioned, would fix the Gibbs case in chain_3 without assuming that the external sampler's transition has an lp field.

The main way to avoid this extra evaluation would be to implement a getlogp(transition) interface. If an external sampler's transition object has an lp field, this could simply be defined as getlogp(t) = t.lp. We could make it return nothing by default, in which case we would fall back to evaluating the model again.
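A minimal sketch of that opt-in pattern in plain Julia, under stated assumptions: the transition types and `logp_for` are hypothetical, and `toy_logdensity` stands in for re-evaluating the model. The generic method returns `nothing`, samplers that store a log density add a method, and callers only pay for a recomputation when nothing is available.

```julia
# Hypothetical transition types standing in for an external sampler's output
struct TransitionWithLp
    θ::Vector{Float64}
    lp::Float64
end

struct TransitionWithoutLp
    θ::Vector{Float64}
end

# Default: the sampler does not expose a log density
getlogp_transition(t) = nothing
# Opt-in method for transitions that carry an `lp` field
getlogp_transition(t::TransitionWithLp) = t.lp

# Reuse the sampler's log density if it has one; otherwise recompute it
function logp_for(t, recompute)
    lp = getlogp_transition(t)
    return lp === nothing ? recompute(t.θ) : lp
end

# Toy target: standard normal log density, up to an additive constant
toy_logdensity(θ) = -sum(abs2, θ) / 2

logp_for(TransitionWithLp([0.3], -0.045), toy_logdensity)   # -0.045, no re-evaluation
logp_for(TransitionWithoutLp([1.0, 2.0]), toy_logdensity)   # -2.5, recomputed
```

Checking `=== nothing` rather than `ismissing` matters here, because the fallback returns `nothing`, not `missing`.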

penelopeysm, Jun 05 '25 11:06

I'm actually quite surprised that chain_1 works for you

I think I may have tried a similar hack to what you described and forgotten that it was still in my environment! I had gotten a bit confused by the two layers of sampling abstraction between Gibbs and externalsampler.

Your proposed solution sounds sensible to me.

ruarai, Jun 05 '25 12:06

Would you be interested in doing a PR? :)

penelopeysm, Jun 05 '25 12:06

Possibly! I think I have a grasp of what would be required to do the getlogp interface.

Something like the following in external_sampler.jl:

getlogp(transition) = nothing
getlogp(transition::AdvancedMH.Transition) = transition.lp

(I don't think AdvancedHMC has a similar field we can use)

And then:

function varinfo(state::TuringState)
    θ = getparams(state.ldf.model, state.state)

    # TODO: Do we need to link here first?
    vi = DynamicPPL.unflatten(state.ldf.varinfo, θ)

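    logp = getlogp(state.state.transition)

    if logp !== nothing
        # the sampler reported its own log density, so reuse it
        vi = setlogp!!(vi, logp)
    else
        # no log density available: re-evaluate the model to recompute logp
        _, vi = DynamicPPL.evaluate!!(state.ldf.model, vi, DefaultContext())
    end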

    return vi
end

?

ruarai, Jun 06 '25 04:06

Yup that looks about right! Feel free to ping me for a review and I'm happy to help work through any necessary details 😄

Just some initial thoughts:

  1. I was thinking it'd be easier on future me if we called it getlogp_transition instead, just to differentiate it from the existing getlogp in DynamicPPL, haha.

  2. For AdvancedHMC, there actually is a logp field, it's t.stat.log_density where t is the Transition. I discovered this from https://github.com/TuringLang/Turing.jl/blob/9cc5be9453cc76c00646c464a600550272fce63b/src/mcmc/hmc.jl#L291-L294 These interfaces are not very well coded 😅 but it's okay for now, it will probably be cleaned up in the near future.

  3. If you add some tests too with the original MWE you had, that would be great! I didn't think about it super in depth, but I thought one way would be to call sample as you did above and then check isapprox(logjoint(model, chain), chain[:lp])?
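A minimal sketch of the check suggested in point 3, assuming Turing, AdvancedMH and the `chain[:lp]` / `logjoint(model, chain)` calls already used in this thread. `lp_matches_logjoint` is a hypothetical helper name, and writing the likelihood as a loop over `y[i]` is an assumption made so each observation is explicit. With the bug described above one would expect `false` for this Gibbs + external-sampler chain; after a fix, `true`.

```julia
using Turing, AdvancedMH, Random

@model function test_model(y)
    a ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(a, 1)
    end
end

# Compare the lp stored in the chain with the log joint recomputed from the sampled values
function lp_matches_logjoint(model, chain)
    stored = vec(chain[:lp])
    recomputed = vec(logjoint(model, chain))
    return isapprox(stored, recomputed)
end

Random.seed!(1)
model = test_model(rand(Normal(0.5, 1), 100))
chain = sample(model, Gibbs(:a => externalsampler(AdvancedMH.RWMH(1))), 100; progress = false)
lp_matches_logjoint(model, chain)
```

Using `isapprox` rather than `==` keeps the check robust to harmless floating-point differences between the two computations.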

penelopeysm, Jun 06 '25 12:06

Thanks for the help!

I am a bit in the weeds on the state/transition interface. I realise now that many samplers don't actually have a state.transition field, and some have a logprob directly in state.

So presumably we need something like getlogp_state here instead?

(This may correspond better with future plans e.g. https://github.com/TuringLang/AbstractMCMC.jl/issues/156?)

ruarai, Jun 12 '25 11:06

I haven't looked too closely but that makes sense -- we might need two different functions, one for the transition and one for the state. 😅

Yeah, I think eventually we want to roll transition and state into a single thing. I reckon that might be a long way away though.
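One way the "two functions" idea could look, as a plain-Julia sketch with hypothetical state and transition types (none of these are AbstractMCMC or Turing types): a state-level accessor that either reads the log density directly or defers to the transition-level one, with `nothing` meaning "recompute".

```julia
# Hypothetical sampler states, mimicking the layouts mentioned above:
# some keep the log density in the state, some only inside `state.transition`.
struct LpTransition
    lp::Float64
end

struct StateWithLogp
    θ::Vector{Float64}
    logp::Float64
end

struct StateWithTransition
    θ::Vector{Float64}
    transition::LpTransition
end

struct OpaqueState
    θ::Vector{Float64}
end

# Transition-level accessor (defaults to `nothing`)
getlogp_transition(t) = nothing
getlogp_transition(t::LpTransition) = t.lp

# State-level accessor: read it directly, or defer to the transition
getlogp_state(s) = nothing
getlogp_state(s::StateWithLogp) = s.logp
getlogp_state(s::StateWithTransition) = getlogp_transition(s.transition)

# Use whichever log density the state exposes, else recompute from θ
function logp_from_state(s, recompute)
    lp = getlogp_state(s)
    return lp === nothing ? recompute(s.θ) : lp
end

# Toy target: standard normal log density, up to an additive constant
toy_logdensity(θ) = -sum(abs2, θ) / 2
```

Routing everything through the state-level function means callers would not need to know whether a given sampler has a `transition` field at all.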

penelopeysm, Jun 12 '25 12:06

Okay I've had a few goes at this and not sure I will have the time to work it out properly, sorry (and the logjoint function has otherwise solved the problem for me).

I ran into another problem not accounted for in the above hacks - the log-probability returned by the sampler will not be the one we want if the sampler is working with a bijection/in unconstrained space. I assume there's some way to account for this but it's not obvious to me 😅
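A likely contributor (an assumption; the thread does not pin down the cause) is the change-of-variables term. If a positive parameter such as `b ~ LogNormal(0, 1)` is sampled as `y = log(b)`, the density over `y` includes `log|db/dy| = y`, whereas a constrained-space log joint does not. The sketch below writes the log densities out by hand (hypothetical helpers, not Distributions.jl calls) to show the size of that gap for one draw.

```julia
# Log density of LogNormal(μ, σ) at x > 0, i.e. in the constrained space of `b`
lognormal_logpdf(x; μ = 0.0, σ = 1.0) =
    -log(x) - log(σ) - 0.5 * log(2π) - (log(x) - μ)^2 / (2σ^2)

# Log density of Normal(μ, σ) at y
normal_logpdf(y; μ = 0.0, σ = 1.0) =
    -log(σ) - 0.5 * log(2π) - (y - μ)^2 / (2σ^2)

# Unconstrained parameterisation y = log(b), so b = exp(y) and |db/dy| = exp(y).
# A density over y must include the log-Jacobian term log|db/dy| = y.
unconstrained_logpdf(y) = lognormal_logpdf(exp(y)) + y

b = 2.5
y = log(b)
constrained = lognormal_logpdf(b)           # the term a constrained-space logjoint uses
unconstrained = unconstrained_logpdf(y)     # the term a sampler working in y would see
jacobian_gap = unconstrained - constrained  # equals log(b), up to rounding
```

For this prior the unconstrained density is exactly `Normal(0, 1)` in `y`, so the sampler's lp and the constrained log joint would differ by `log(b)` at every draw.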

ruarai, Jul 14 '25 04:07

Okay I've had a few goes at this and not sure I will have the time to work it out properly, sorry (and the logjoint function has otherwise solved the problem for me).

That's ok! I can take it on. Glad to hear that the immediate problem is OK 🙂

the log-probability returned by the sampler will not be the one we want if the sampler is working with a bijection/in unconstrained space

😬 Do you have a code example I can work with?

penelopeysm, Jul 14 '25 09:07

Do you have a code example I can work with?

This should demonstrate it:

using AdvancedMH

@model function test_model_1()
    a ~ Normal(0.0, 1.0)
    b ~ Normal(0.0, 1.0)
    1.5 ~ Normal(a + b, 0.5)
    2.0 ~ Normal(a + b, 0.5)
end

@model function test_model_2()
    a ~ Normal(0.0, 1.0)
    b ~ LogNormal(0.0, 1.0)
    1.5 ~ Normal(a + b, 0.5)
    2.0 ~ Normal(a + b, 0.5)
end

model_1 = test_model_1()
chain_ext_1 = sample(model_1, externalsampler(AdvancedMH.RWMH(2); unconstrained = false), 10);
all(chain_ext_1[:,:lp,:] .== logjoint(model_1, chain_ext_1))

model_2 = test_model_2()
chain_ext_2 = sample(model_2, externalsampler(AdvancedMH.RWMH(2); unconstrained = true), 10);
all(chain_ext_2[:,:lp,:] .== logjoint(model_2, chain_ext_2))

where we have changed the following to pass the logp through:


function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
    # TODO: We should probably rename this `getparams` since it returns something
    # very different from `Turing.Inference.getparams`.
    θ = getparams(f.model, transition)
    varinfo = DynamicPPL.unflatten(f.varinfo, θ)

    logp = transition.lp # should be getlogp_transition or similar

    return Transition(getparams(f.model, varinfo), logp, getstats(transition))
end

ruarai, Jul 15 '25 02:07

I just tried the example above on your PR branch and, unfortunately, it trips up on the second chain as well.

ruarai, Jul 15 '25 03:07

Thanks! That's super helpful. Quite ugly, too, as I don't know where it's coming from.

If it's okay, I'll open a separate issue to track this. I'm not sure if it's AdvancedMH, DynamicPPL, or Turing that's being weird here.

penelopeysm, Jul 15 '25 10:07

Yeah of course! Thanks very much for the help.

ruarai, Jul 15 '25 10:07

By the way, if you don't mind me asking: what are you using externalsampler for? Is there something in AdvancedMH that the built-in Turing MH doesn't have (or is it actually a different sampler that you're using)? (Purely out of curiosity, not trying to imply that you should or shouldn't :))

penelopeysm, Jul 15 '25 11:07

Ah yep I am using my own sampler(s). One to deal with a gross discrete parameter space and one to reproduce the sampler used in an earlier paper, then combining these with the Gibbs sampler. Turing has made this all much easier (and probably was the only option other than doing everything myself, given the particularities of the problem!)

A bit limited on what else I can share at the moment but hopefully will have something eventually.

ruarai, Jul 15 '25 22:07