Modifying an array that is part of a struct argument and passing it to softmax throws a conversion MethodError with ForwardDiff

PTWaade opened this issue 9 months ago • 5 comments

Minimal working example (updated)

using Turing, LogExpFunctions

mutable struct Bar
    values::Vector
end

values = zeros(Real, 4)
state = Bar(values)

inputs = [1,2,3,4,2,3,2,1,2,3,2,1]
actions = [1,2,3,4,2,3,2,1,2,3,2,1]

@model function m(inputs::Vector{Int64}, actions::Vector{Int64}, state)

    α ~ LogitNormal(0, 1) #learning rate
    β ~ LogNormal(0, 1)   #inverse temperature

    for (input, action) in zip(inputs, actions)

        x ~ to_submodel(single_step(input, action, state, α, β))
    end
end

@model function single_step(input, action, state, α, β)

    values = state.values

    @show values

    action_probs = softmax(values * β)

    values[input] += α 

    state.values = values

    action ~ Categorical(action_probs)
end

model = m(inputs, actions, state)

sample(model, NUTS(), 1000)

import Mooncake
values = zeros(Real, 4)
state = Bar(values)
model = m(inputs, actions, state)
sample(model, NUTS(; adtype = AutoMooncake(; config = nothing)), 1000)

Description

Dear Turing,

Again, thank you for your incredible work. Here is another error I hit while developing a Turing-reliant package.

Above is a fairly minimal example that errors with a MethodError when an attempt is made to convert a Dual number to a Float64 somewhere within DynamicPPL.

It seems to happen when I update a vector in a struct that is passed as an argument to the Turing model.

Mooncake does not seem to error, but since the error happens within DynamicPPL I thought I would post the issue here.

Let me know if there is anything further I can do!

Julia version info

versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 

Manifest

]st --manifest
  [47edcb42] ADTypes v1.14.0
  [621f4979] AbstractFFTs v1.5.0
  [80f14c24] AbstractMCMC v5.6.0
⌅ [7a57a42e] AbstractPPL v0.10.1
  [1520ce14] AbstractTrees v0.4.5
  [7d9f7c33] Accessors v0.1.42
  [79e6a3ab] Adapt v4.3.0
  [0bf59076] AdvancedHMC v0.7.0
  [5b7e9947] AdvancedMH v0.8.6
  [576499cb] AdvancedPS v0.6.1
⌅ [b5ca4192] AdvancedVI v0.2.11
  [66dad0bd] AliasTables v1.1.3
  [dce04be8] ArgCheck v2.5.0
  [ec485272] ArnoldiMethod v0.4.0
  [4fba245c] ArrayInterface v7.18.0
  [a9b6321e] Atomix v1.1.1
  [13072b0f] AxisAlgorithms v1.1.0
  [39de3d68] AxisArrays v0.4.7
  [198e06fe] BangBang v0.4.4
  [9718e550] Baselet v0.1.1
  [76274a88] Bijectors v0.15.6
  [082447d4] ChainRules v1.72.3
  [d360d2e6] ChainRulesCore v1.25.1
  [9e997f8a] ChangesOfVariables v0.1.9
  [861a8166] Combinatorics v1.0.2
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.1
  [34da2185] Compat v4.16.0
  [a33af91c] CompositionsBase v0.1.2
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.8
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.16.0
  [864edb3b] DataStructures v0.18.22
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [b429d917] DensityInterface v0.4.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [de460e47] DiffTests v0.1.2
  [a0c0ee7d] DifferentiationInterface v0.6.48
  [31c24e10] Distributions v0.25.118
  [ced4e74d] DistributionsAD v0.6.57
  [ffbed154] DocStringExtensions v0.9.3
  [366bfd00] DynamicPPL v0.35.5
  [cad2338a] EllipticalSliceSampling v2.0.0
  [4e289a0a] EnumX v1.0.4
  [e2ba6199] ExprTools v0.1.10
  [55351af7] ExproniconLite v0.10.14
  [7a1cc6ca] FFTW v1.8.1
  [9aa1b823] FastClosures v0.3.2
  [1a297f60] FillArrays v1.13.0
  [6a86dc24] FiniteDiff v2.27.0
  [f6369f11] ForwardDiff v0.10.38
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.5.2
  [46192b85] GPUArraysCore v0.2.0
  [86223c79] Graphs v1.12.0
  [076d061b] HashArrayMappedTries v0.2.0
  [34004b35] HypergeometricFunctions v0.3.28
  [d25df0c9] Inflate v0.1.5
  [22cec73e] InitialValues v0.3.1
  [a98d9a8b] Interpolations v0.15.1
  [8197267c] IntervalSets v0.7.10
  [3587e190] InverseFunctions v0.1.17
  [41ab1584] InvertedIndices v1.3.1
  [92d709cd] IrrationalConstants v0.2.4
  [c8e1da08] IterTools v1.10.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.7.0
  [682c06a0] JSON v0.21.4
  [ae98c720] Jieko v0.2.1
  [63c18a36] KernelAbstractions v0.9.34
  [5ab0869b] KernelDensity v0.6.9
  [5be7bae1] LBFGSB v0.4.1
  [8ac3fa9e] LRUCache v1.6.2
  [b964fa9f] LaTeXStrings v1.4.0
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [6f1fad26] Libtask v0.8.8
  [d3d80556] LineSearches v7.3.0
  [6fdf6af0] LogDensityProblems v2.1.2
  [996a588d] LogDensityProblemsAD v1.13.0
  [2ab3a3ac] LogExpFunctions v0.3.29
  [e6f89c97] LoggingExtras v1.1.0
  [c7f686f2] MCMCChains v6.0.7
  [be115224] MCMCDiagnosticTools v0.3.14
  [e80e1ace] MLJModelInterface v1.11.0
  [1914dd2f] MacroTools v0.5.15
  [dbb5928d] MappedArrays v0.4.2
  [128add7d] MicroCollections v0.2.0
  [e1d29d7a] Missings v1.2.0
  [dbe65cb8] MistyClosures v2.0.0
  [da2b9cff] Mooncake v0.4.109
  [2e0e35c7] Moshi v0.3.5
  [d41bc354] NLSolversBase v7.9.0
  [872c559c] NNlib v0.9.29
  [77ba4419] NaNMath v1.1.2
  [86f7a689] NamedArrays v0.10.3
  [c020b1a1] NaturalSort v1.0.0
  [6fe1bfb0] OffsetArrays v1.16.0
  [429524aa] Optim v1.11.0
  [3bd65402] Optimisers v0.4.5
  [7f7a1694] Optimization v4.1.2
  [bca83a33] OptimizationBase v2.5.0
  [36348300] OptimizationOptimJL v0.4.1
  [bac558e1] OrderedCollections v1.8.0
  [90014a1f] PDMats v0.11.32
  [d96e819e] Parameters v0.12.3
  [69de0a69] Parsers v2.8.1
  [85a6dd25] PositiveFactorizations v0.2.4
  [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.4.3
  [08abe8d2] PrettyTables v2.4.0
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.10.2
  [43287f4e] PtrArrays v1.3.0
  [1fd47b50] QuadGK v2.11.2
  [74087812] Random123 v1.7.0
  [e6cf234a] RandomNumbers v1.6.0
  [b3c3ace0] RangeArrays v0.3.2
  [c84ed2f1] Ratios v0.4.5
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v3.31.1
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.1
  [79098fc4] Rmath v0.8.0
  [f2b01f46] Roots v2.2.6
  [7e49a35a] RuntimeGeneratedFunctions v0.5.13
⌅ [26aad666] SSMProblems v0.1.1
  [0bca4576] SciMLBase v2.79.0
  [c0aeaf25] SciMLOperators v0.3.13
  [53ae85a6] SciMLStructures v1.7.0
  [30f210dd] ScientificTypesBase v3.0.0
  [7e506255] ScopedValues v1.3.0
  [efcf1570] Setfield v1.1.2
  [699a6c99] SimpleTraits v0.9.4
  [ce78b400] SimpleUnPack v1.1.0
  [a2af1166] SortingAlgorithms v1.2.1
  [9f842d2f] SparseConnectivityTracer v0.6.15
  [dc90abb0] SparseInverseSubset v0.1.2
  [0a514795] SparseMatrixColorings v0.4.14
  [276daf66] SpecialFunctions v2.5.0
  [171d559e] SplittablesBase v0.1.15
  [90137ffa] StaticArrays v1.9.13
  [1e83bf80] StaticArraysCore v1.4.3
  [64bff920] StatisticalTraits v3.4.0
  [10745b16] Statistics v1.11.1
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.4
  [4c63d2b9] StatsFuns v1.3.2
  [892a3eda] StringManipulation v0.4.1
  [09ab397b] StructArrays v0.7.0
  [2efcf032] SymbolicIndexingInterface v0.3.38
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.12.0
  [5d786b92] TerminalLoggers v0.1.7
  [9f7883ad] Tracker v0.2.37
  [28d57a85] Transducers v0.4.84
  [fce5fe82] Turing v0.37.0
  [3a884ed6] UnPack v1.0.2
  [013be700] UnsafeAtomics v0.3.0
  [efce3f68] WoodburyMatrices v1.0.0
  [700de1a5] ZygoteRules v0.2.7
  [f5851436] FFTW_jll v3.3.10+3
  [1d5cc7b8] IntelOpenMP_jll v2025.0.4+0
  [81d17ec3] L_BFGS_B_jll v3.0.1+0
  [856f044c] MKL_jll v2025.0.1+1
  [efe28fd5] OpenSpecFun_jll v0.5.6+0
  [f50d1b31] Rmath_jll v0.5.1+0
  [1317d2d5] oneTBB_jll v2022.0.0+0
  [0dad84c5] ArgTools v1.1.2
  [56f22d72] Artifacts v1.11.0
  [2a0f44e3] Base64 v1.11.0
  [ade2ca70] Dates v1.11.0
  [8ba89e20] Distributed v1.11.0
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching v1.11.0
  [9fa8497b] Future v1.11.0
  [b77e0a4c] InteractiveUtils v1.11.0
  [4af54fe1] LazyArtifacts v1.11.0
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2 v1.11.0
  [8f399da3] Libdl v1.11.0
  [37e2e46d] LinearAlgebra v1.11.0
  [56ddb016] Logging v1.11.0
  [d6f4376e] Markdown v1.11.0
  [a63ad114] Mmap v1.11.0
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.11.0
  [de0858da] Printf v1.11.0
  [9a3f8284] Random v1.11.0
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization v1.11.0
  [1a1011a3] SharedArrays v1.11.0
  [6462fe0b] Sockets v1.11.0
  [2f01184e] SparseArrays v1.11.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test v1.11.0
  [cf7118a7] UUIDs v1.11.0
  [4ec0a83e] Unicode v1.11.0
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [deac9b47] LibCURL_jll v8.6.0+0
  [e37daf67] LibGit2_jll v1.7.2+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.6+0
  [14a3606d] MozillaCACerts_jll v2023.12.12
  [4536629a] OpenBLAS_jll v0.3.27+1
  [05823500] OpenLibm_jll v0.8.1+2
  [bea87d4a] SuiteSparse_jll v7.7.0+0
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.11.0+0
  [8e850ede] nghttp2_jll v1.59.0+0
  [3f19e933] p7zip_jll v17.4.0+2

PTWaade · Mar 25 '25

Hi @PTWaade! The use of Dual by ForwardDiff indeed sometimes causes trouble with element types of collections (see e.g. here for one common situation). Mooncake avoids this by not using tracker types like Dual.
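
To illustrate (a minimal standalone snippet of mine, not taken from your model), the clash can be reproduced outside Turing by writing a Dual into a concretely typed array:

using ForwardDiff

function f(α)
    values = zeros(Float64, 4)  # element type fixed to Float64
    values[1] += α              # under ForwardDiff, α is a Dual, so storing it requires convert(Float64, ::Dual)
    return sum(values)
end

f(0.5)                            # fine with a plain Float64
# ForwardDiff.derivative(f, 0.5)  # fails with a MethodError along the lines of "no method matching Float64(::ForwardDiff.Dual{...})"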

However, I think in your case there may be a simpler issue to deal with first. Namely, every time the model is executed, the values array keeps accumulating updates. I don't fully understand what your model is for, but I doubt this is desirable. Observe, for instance, the following:

model = m(inputs, actions, state)

println("After 0 evaluations: $state")
model()
println("After 1 evaluation: $state")
for _ in 1:100
    model()
end
println("After 101 evaluations: $state")

prints out

After 0 evaluations: Main.MWE.Bar(Real[0, 0, 0, 0])
After 1 evaluation: Main.MWE.Bar(Real[0.5416161697738975, 0.9026936162898291, 0.5416161697738975, 0.18053872325796583])
After 101 evaluations: Main.MWE.Bar(Real[160.78127600912504, 267.968793348542, 160.78127600912504, 53.593758669708414])

I would assume this would throw off model behaviour, and make it depend on things like how many model evaluations the sampler happens to do per sample.

Might this make sense instead?

struct Bar{E}
    values::Vector{E}
end

inputs = [1,2,3,4,2,3,2,1,2,3,2,1]
actions = [1,2,3,4,2,3,2,1,2,3,2,1]

@model function m(inputs::Vector{Int64}, actions::Vector{Int64}, ::Type{T}=Float64) where {T}

    values = zeros(T, 4)
    state = Bar(values)

    α ~ LogitNormal(0, 1) #learning rate
    β ~ LogNormal(0, 1)   #inverse temperature

    for (input, action) in zip(inputs, actions)

        x ~ to_submodel(single_step(input, action, state, α, β))
    end
end

@model function single_step(input, action, state, α, β)

    values = state.values

    action_probs = softmax(values * β)

    values[input] += α 

    # values = [
    #     values[i] + α * (input == i) for i = 1:4
    # ]

    action ~ Categorical(action_probs)
end

I've moved the initialisation of state inside the model, and further made its element type a type parameter T (see the link above for why to do the latter).

I've also changed the Bar type a bit. This isn't necessary to fix the Turing.jl issue, but just good Julia practice otherwise:

  • I've given it a type parameter E so that the element type of bar.values can be inferred from the type of bar::Bar, which helps type stability and hence performance.
  • I've made Bar immutable. Since the values[input] += α operation already modifies state.values in-place, there's no need to reassign state.values = values, and hence no need for Bar to be mutable (see the short sketch below).
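
For example, using the Bar{E} definition above (a minimal sketch, nothing Turing-specific):

b = Bar(zeros(Float64, 4))  # typeof(b) is Bar{Float64}, so the element type of b.values is inferrable from b alone
b.values[1] += 0.5          # mutating the vector's contents is fine even though Bar itself is immutable
# b.values = ones(4)        # only this kind of field reassignment is forbidden on an immutable struct, and it isn't needed here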

mhauru · Apr 01 '25

Hey mhauru!

Thank you for the quick and comprehensive answer :)

I'm aware that initialising Bar inside the model would both avoid the object being updated every time the model is run and allow the type to be set in the model function header (instead of just setting the element type to Real, which is given as the other option in the Turing documentation you linked to).

Just for context: ActionModels, which is certainly still a work in progress, is a package for cognitive modelling. It can be used for agent-based simulation and for fitting models to experimental behavioural data (I'm using Turing for the latter).

I've had two reasons for initializing the Bar struct outside the function (in ActionModels, it's an Agent structure which creates behaviour given inputs). One is that, depending on the model, the Agent struct can be somewhat large (for example containing a predictive coding network, which can be a little complex to initialize properly), and I thought it would be unnecessarily computationally heavy to initialise it on every sample. Secondly, and less importantly, it is convenient that the struct can be created and manipulated by the user outside the model before being passed to it (although I can create a different API for this). In the package, I have also made sure that a copy of the Bar struct is created and passed to the model inside a function, and that the struct is reset at the beginning of the forward run, so that the multiple-evaluation problem doesn't arise (roughly as sketched below).
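
Roughly, that copy-and-reset pattern looks like this (a simplified sketch with made-up helper names, using the original m(inputs, actions, state) signature, not the actual ActionModels code):

# Wrap model construction so the user's struct is never mutated directly
function make_model(inputs, actions, agent::Bar)
    agent_copy = deepcopy(agent)
    return m(inputs, actions, agent_copy)
end

# Called at the beginning of each forward run
function reset!(agent::Bar)
    agent.values .= 0
    return agent
end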

However, initializing the struct outside the model of course prevents me from using a type parameter (like E) in the struct and from using the function header type parameter, as you do here.

One question: is it correct that it is more efficient to use the type parameter and initialize vectors/structs inside the function than to give them Real element types (which, it seems, leads to type instability) in order to allow for Dual numbers? Both are suggested in the Turing documentation, but it was unclear to me whether one is preferable.

I have been considering options for instantiating as much as possible outside the function (including, for example, a network with node types that controls update functions, etc.), and having a separate container, initialized inside the model, specifically for the values that inference touches. Do I understand correctly that this would be the optimal / only solution here?
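
Schematically, I mean something like this (hypothetical names, not actual package code): the heavy configuration lives outside the model, and only a small, T-typed container is created inside it.

struct AgentConfig              # heavy, AD-agnostic part, built once outside the model
    n_options::Int
end

@model function m2(inputs, actions, config::AgentConfig, ::Type{T}=Float64) where {T}

    values = zeros(T, config.n_options)  # small AD-touched container, created per model evaluation

    α ~ LogitNormal(0, 1) #learning rate
    β ~ LogNormal(0, 1)   #inverse temperature

    for i in eachindex(inputs)
        action_probs = softmax(values * β)
        values[inputs[i]] += α
        actions[i] ~ Categorical(action_probs)
    end
end

model = m2(inputs, actions, AgentConfig(4))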

PTWaade · Apr 03 '25

It seems I read too much into your minimised example. :)

One question: is it correct that it is more efficient to use the type parameter and initialize vectors/structs inside the function than to give them Real element types (which, it seems, leads to type instability) in order to allow for Dual numbers?

Depends of course on how heavy the initialisation is, but if I had to guess I would guess yes. Type instability can hit performance really hard. You could try both approaches, though.
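
A rough way to see the effect outside of Turing (an ad-hoc comparison of mine, assuming BenchmarkTools is available):

using BenchmarkTools

# Same loop; only the element type of v differs between the two benchmarks
function accumulate_alpha!(v, α)
    for i in eachindex(v)
        v[i] += α
    end
    return v
end

@btime accumulate_alpha!(v, 0.1) setup=(v = zeros(Real, 1_000))     # abstract eltype: boxed elements, dynamic dispatch
@btime accumulate_alpha!(v, 0.1) setup=(v = zeros(Float64, 1_000))  # concrete eltype: tight, specialised loop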

I have been considering options for instantiating as much as possible outside the function (including, for example, a network with node types that controls update functions, etc.), and having a separate container, initialized inside the model, specifically for the values that inference touches. Do I understand correctly that this would be the optimal / only solution here?

This sounds like a good approach to me. The other option would be to rely fully on Mooncake. Would that work for you?

mhauru · Apr 03 '25

Hey @mhauru, my apologies for the long response time here, and thanks again for your responses. I'll adapt my package to initialize the structures in the model, but keep that part as minimal as I can, so that I won't have to depend only on Mooncake.

One thought from my side: it might be good if the documentation you linked to made it clearer that, of the two options presented, using a container of Real is discouraged (due to the loss of type stability). Naively, I just picked the first of the two when reading that section because it was simpler to implement in my case, but it seems to me that it is not usually an advantage. It might also be nice to make clear that this doesn't apply when using Mooncake.

Finally, in the code example above (I simplified it a little), I get an error even though I have given the container a Real element type, although I thought from the documentation that it would run (even if it wasn't type-stable). Did I make some obvious mistake?

These are minor points by now, as I am proceeding with my work in any case, so you can close this issue if you would like :)

And again - thank you!

PTWaade · Apr 16 '25

One thought from my side: it might be good if the documentation you linked to made it clearer that, of the two options presented, using a container of Real is discouraged (due to the loss of type stability). Naively, I just picked the first of the two when reading that section because it was simpler to implement in my case, but it seems to me that it is not usually an advantage. It might also be nice to make clear that this doesn't apply when using Mooncake.

Fully agree on this, we should improve that part of the documentation.

I'll come back to the code example a bit later.

mhauru · Apr 16 '25