Turing.jl

Meta-Bayesian inference with Turing.jl

Open PTWaade opened this issue 6 months ago • 6 comments

Hey Turing,

Here is, as promised (@mhauru, @penelopeysm, @nsiccha), a minimal example of cognitive-modeling-style meta-Bayesian modeling with Turing. Wonderfully, it actually runs fine when I just use an MH() sampler in the outer inference call! Autodiff in the outer model breaks, however, and somewhat bloodily: ForwardDiff's dual numbers don't work, Mooncake throws an error, and Enzyme crashes Julia entirely on my computer.

Looking forward to hearing what you think!

Explanation of meta-Bayesian modelling

In what is sometimes called meta-Bayesian modeling, we model (human) subjects as doing Bayesian inference to understand their environment, while we ourselves use Bayesian inference to fit that model to their behavioural data. This leads to a split between an "inner" or "subjective" model, which is the agent's model of the environment, and an "outer" or "objective" model, which is our model of the agent. In meta-Bayesian modeling, the inversion of the inner model is (part of) the outer generative model. Researchers in the field often use hand-coded exact or variational inference schemes for the inner models (there are different schools, some preferring variational Bayes and some not). However, it would be wonderful to use tools like Turing for both the inner and outer models, so that cognitive models are not limited to what is convenient or possible to hand-write. There are quite a few so-called Bayesian mind models out there (using various types of Bayesian inference, usually exact Bayes where possible), and this should expand the set of possible models considerably, if it works.
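To make the inner/outer split concrete, here is a minimal plain-Julia sketch (no Turing) of the kind of hand-coded exact inner model mentioned above. It assumes, purely for illustration, a Beta-Bernoulli observer: the subject's exact posterior over the reward probability has a closed form, so inverting the inner model is just a function call inside the outer log-likelihood. The names `inner_posterior` and `outer_loglik`, and the logistic response rule, are hypothetical choices for this sketch.

```julia
# Minimal sketch of a meta-Bayesian likelihood with an exact, hand-coded inner model.
# All names are hypothetical; plain Julia, no packages needed.

logistic(x) = 1 / (1 + exp(-x))

# Inner ("subjective") model: Beta(a, b) belief about the reward probability,
# updated exactly (by conjugacy) with 0/1 observations.
function inner_posterior(a, b, observations)
    n_reward = count(==(1), observations)
    return a + n_reward, b + (length(observations) - n_reward)
end

# Outer ("objective") model: log-likelihood of the subject's actions given the
# subject's prior pseudo-counts (a0, b0) and an inverse temperature β.
function outer_loglik(a0, b0, β, trials, actions)
    a, b = a0, b0
    ll = 0.0
    for (obs, act) in zip(trials, actions)
        a, b = inner_posterior(a, b, obs)        # the inner inversion, once per trial
        p_reward = a / (a + b)                   # subject's posterior mean
        p_engage = logistic(β * (p_reward - 0.5))
        ll += act == 1 ? log(p_engage) : log1p(-p_engage)
    end
    return ll
end

trials = [[0, 1, 1, 1, 0], [0, 1, 1, 0, 0]]
ll = outer_loglik(1.0, 1.0, 5.0, trials, [1, 0])
```

Using Turing for the inner model amounts to replacing `inner_posterior` with a `sample` call, which is what opens up inner models without a closed-form posterior.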

Explanation of the "experimental setup"

For this (contrived toy) example, imagine a human subject in an experiment where they are shown a machine producing short series of either rewards or punishments, and afterwards have to decide whether they want to use the machine. Essentially, the participant's task is to estimate the machine's probability of giving rewards and to choose accordingly. Assuming that a reward and a punishment are of equal size (+1 and -1), it is rational to engage with the machine if the probability of reward is above 0.5. This is repeated for n trials with the same machine, so the participant updates their beliefs over time. Here we simulate just three trials with some arbitrary data, so the posteriors are not that interesting. I have chosen a few exemplary parameters to estimate, just to give a sense of the type of parameters often estimated: the participant's prior belief, their bias, their action noise, and how autocorrelated they assume their environment to be. There are often hypotheses about, for example, some psychiatric conditions being correlated with these parameters; we might imagine that in some clinical cases people assume the world to be highly autocorrelated (and therefore learn less about the true probabilities).
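With a reward worth +1 and a punishment worth -1, the expected payoff of engaging is p·(+1) + (1 - p)·(-1) = 2p - 1, which is positive exactly when p > 0.5. A tiny Julia check of that decision rule (the function names are illustrative only):

```julia
# Expected payoff of engaging with the machine when a reward is worth +1,
# a punishment is worth -1, and p is the probability of reward.
expected_payoff(p) = p * 1 + (1 - p) * (-1)   # simplifies to 2p - 1

# The rational (noise-free) choice: engage iff the expected payoff is positive.
should_engage(p) = expected_payoff(p) > 0

for p in (0.3, 0.5, 0.7)
    println(p, " => payoff ", expected_payoff(p), ", engage: ", should_engage(p))
end
```

In the MWE this rule is softened by the response model `logistic(β * x + b)`, where `x` is the subject's estimate in logit space. Since logit(p) > 0 exactly when p > 0.5, that response model approaches this rule as β grows with b = 0.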

MWE

using Turing
#For the logistic (inverse-logit) transforms in the model
using LogExpFunctions

#For different AD types
using ADTypes
import Enzyme: set_runtime_activity, Forward, Reverse
import Mooncake

#### CREATING INNER AND OUTER MODELS ####

# Inner model function: estimating probability of reward
@model function inner_model(observations, prior_μ = 0, prior_σ = 1, ρ = 0.1)

    # The inner model's prior over the baseline reward probability (in logit space)
    prob ~ Normal(prior_μ, prior_σ)

    #For each observation
    for i in 1:length(observations)
        
        ## Map from logit space to a probability
        if i == 1
            probᵢ = logistic(prob)
        else
            #On all but the first observation, shift the logit by the previous observation (the assumed autocorrelation)
            probᵢ = logistic(prob + ρ * (observations[i-1] - 0.5))
        end

        #The observation is a Bernoulli draw
        observations[i] ~ Bernoulli(probᵢ)
    end
end


# Outer model function: simulating a Bayesian-minded subject's behavior
@model function outer_model(subj_observations, subj_actions, inner_sampler = NUTS(), inner_n_samples = 1000)
    
    ### Subject-level parameters to estimate ###

    #We may want to estimate parameters of the participant's "perceptual model" (i.e. their Bayesian inference)
    #This can include estimating the subject's prior
    subj_prior_μ ~ Normal(0, 1)
    subj_prior_σ ~ Exponential(1)
    #But can also include estimating other parameters of the subject's generative model
    ρ ~ LogitNormal(-2, 1) #A parameter that controls the amount of observational autocorrelation assumed by the subject

    #We may want to estimate parameters of the subject's "response model" (i.e. after the inference)
    β ~ Exponential(1) #A classic inverse temperature parameter which controls the stochasticity of the subject's choices
    b ~ Normal(0, 1) #A bias parameter that shifts the subject's action probability

    beliefs_μ = Vector{Float64}(undef, length(subj_observations))
    beliefs_σ = Vector{Float64}(undef, length(subj_observations))
    #For each trial t in the experiment
    for (t,subj_observationsₜ) in enumerate(subj_observations)


        ### The subject's "perceptual model" ###

        #On the first trial
        if t == 1
            #The subject uses their prior
            subj_μ, subj_σ = subj_prior_μ, subj_prior_σ
        else
            #On other trials, the subject uses their posterior from the previous trial as prior
            subj_μ = beliefs_μ[t-1]
            subj_σ = beliefs_σ[t-1]
        end
       
        #Condition the subject's model on this trial's observations, using her current prior
        inner_m = inner_model(subj_observationsₜ, subj_μ, subj_σ, ρ)

        #Simulate the subject's Bayesian inference
        chns = sample(inner_m, inner_sampler, inner_n_samples, progress = false)

        #Extract the subject's point estimate of the reward probability (in logit space)
        subj_mean_expectationₜ = mean(chns[:prob])


        ### The subject's "response model" ###

        #Transform the point estimate, with noise and bias, into the action probability at this trial
        action_probabilityₜ = logistic(subj_mean_expectationₜ * β + b)

        #The subject's action on this trial is a Bernoulli draw from the action probability
        subj_actions[t] ~ Bernoulli(action_probabilityₜ)

        ### Store values for next trial ###
        beliefs_μ[t] = subj_mean_expectationₜ
        beliefs_σ[t] = std(chns[:prob])
    end

end


### SPECIFYING DATA AND MODEL ARGUMENTS ###

## Subject-level data ##
#The subject's observations on each trial (1 for "reward", 0 for "punishment")
subj_observations = [
    [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0],
    [0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1],
    [0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
] 
#The subject's action on that trial (1 for "engage", 0 for "not engage")
subj_action = [1, 1, 0]

## Model specifications ##
## Which AD to use in the inner model (if the sampler uses it)
# inner_adtype = AutoForwardDiff()
# inner_adtype = AutoMooncake(; config = nothing)
# inner_adtype = AutoEnzyme(; mode = set_runtime_activity(Reverse, true))
# inner_adtype = AutoEnzyme(; mode = set_runtime_activity(Forward, true))

## The sampler in the inner model
inner_sampler = MH()
#inner_sampler = NUTS(; adtype = inner_adtype)

## The number of samples to draw from the inner model
inner_n_samples = 100

### SPECIFY SAMPLING ARGUMENTS AND SAMPLE ###

#The AD to use in the outer model
# outer_adtype = AutoForwardDiff()
# outer_adtype = AutoMooncake(; config = nothing) #Throws an error
# outer_adtype = AutoEnzyme(; mode = set_runtime_activity(Reverse, true)) #Julia crashes
# outer_adtype = AutoEnzyme(; mode = set_runtime_activity(Forward, true))

#The sampler in the outer model
outer_sampler = MH()
# outer_sampler = NUTS(; adtype = outer_adtype)


## The final model ##
m = outer_model(subj_observations, subj_action, inner_sampler, inner_n_samples)

## Sample ##
chns = sample(m, outer_sampler, 1000)
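The outer model above carries the subject's beliefs across trials by summarising each inner posterior with the sample mean and standard deviation of `prob` and reusing them as the next trial's Normal prior. The plain-Julia sketch below (function names are hypothetical) checks the one case where this summary is exact: a conjugate Normal-Normal inner model with known observation noise, like the one in the FiniteDifferences.jl example further down, where updating trial by trial matches a single update on all observations. For the MWE's logit-Bernoulli inner model, the same step is a Gaussian approximation, on top of the Monte Carlo error of the inner sampler.

```julia
# Conjugate Normal-Normal update: prior Normal(μ0, σ0) on the latent mean,
# observations yᵢ ~ Normal(latent, σ_obs). Returns the posterior (mean, std).
function normal_update(μ0, σ0, ys; σ_obs = 1.0)
    prec = 1 / σ0^2 + length(ys) / σ_obs^2
    μ = (μ0 / σ0^2 + sum(ys) / σ_obs^2) / prec
    return μ, sqrt(1 / prec)
end

# Trial by trial: each posterior (mean, std) becomes the next trial's prior,
# mirroring what beliefs_μ / beliefs_σ do in the outer model.
function chain_updates(μ, σ, trials)
    for ys in trials
        μ, σ = normal_update(μ, σ, ys)
    end
    return μ, σ
end

trials = [[0.2, -0.1, 0.4], [0.5, 0.3], [-0.2, 0.1, 0.0, 0.6]]
μ_seq, σ_seq = chain_updates(0.0, 1.0, trials)
μ_batch, σ_batch = normal_update(0.0, 1.0, reduce(vcat, trials))  # all data at once
```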

More variations and other considerations

Things I haven't tried out, but which would be interesting too: using variational inference for the inner or outer loop with AdvancedVI, or using other PPLs (JuliaBUGS, Gen.jl, RxInfer.jl) for the inner or outer loops. Note that many experiments have multi-trial structures like this, so being able to reuse the posterior from a previous inference over the inner model is useful; I suppose people often use SMC for this, however. Also note the somewhat subtle difference between the subject estimating a parameter of their generative model, such as the autocorrelation ρ here, based on their observations, and us estimating their (fixed) value of that parameter based on their subsequent actions.
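A minimal sketch of the idea behind reusing a previous posterior, assuming plain importance reweighting (the reweighting step of SMC, without the resampling and move steps a real SMC sampler adds): keep one set of particles drawn from the subject's initial prior and multiply their weights by each new trial's likelihood, instead of rerunning MCMC from scratch. It is plain Julia; the Uniform(0, 1) prior over the reward probability and the function names are illustrative assumptions. With a uniform prior the exact posterior is a Beta distribution, which gives a check.

```julia
using Random

# Bernoulli log-likelihood of 0/1 observations given reward probability p.
bernoulli_loglik(p, obs) = sum(o == 1 ? log(p) : log1p(-p) for o in obs)

# Reuse one set of particles across trials by accumulating importance weights.
# Returns the posterior mean after each trial and the final effective sample size.
function reweight_over_trials(trials; n_particles = 100_000, rng = Xoshiro(1))
    particles = rand(rng, n_particles)             # draws from a Uniform(0, 1) prior
    logw = zeros(n_particles)
    w = fill(1 / n_particles, n_particles)
    means = Float64[]
    for obs in trials                              # weighted particles: posterior so far
        logw .+= bernoulli_loglik.(particles, Ref(obs))
        w = exp.(logw .- maximum(logw))            # subtract max for numerical stability
        w ./= sum(w)
        push!(means, sum(w .* particles))
    end
    return means, 1 / sum(abs2, w)
end

trials = [[0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0],
          [0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1],
          [0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0]]
means, ess = reweight_over_trials(trials)
# Exact check: 18 rewards and 21 punishments in total give Beta(19, 22), mean 19/41.
```

As the posterior narrows, the effective sample size shrinks; SMC counters this by resampling and moving particles, which this sketch leaves out.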

Let me know if you'd like more information from me; I'm curious to know whether Enzyme breaks on your computers too!

PTWaade, Jul 11 '25 14:07

Also, I assume that supporting dual numbers inside the sample() call would be very tiresome, and I have no idea whether ForwardDiff and ReverseDiff could differentiate through it even if it were supported. If they could, though, it would be nice not to have to rely on Mooncake and Enzyme for this (assuming they can fix their own errors here).

PTWaade, Jul 11 '25 14:07

I suppose I don't have a specific question here, so treat this issue as an open discussion and statement of intent and feel free to close it when you want to :)

PTWaade, Jul 11 '25 14:07

(Also, I think the autocorrelation part is not implemented well here; I mostly wanted to give an example of inferring a parameter of the subject's generative model.)

PTWaade, Jul 11 '25 14:07

Finally, I think we might add (a minimal version of) this to the ADTests.jl suite; I'll happily help with that if I can :)

PTWaade, Jul 13 '25 08:07

Note that @yebai commented here that autodiff over a Turing sample call shouldn't be applied in general, but that gradient-free sampling should work fine.
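The reason gradient-free sampling sidesteps the problem is that it only ever evaluates the outer log density, never its gradient, so it does not matter that the density contains a whole inner inference. A minimal plain-Julia sketch under illustrative assumptions: a hand-written random-walk Metropolis sampler (not Turing's `MH()` implementation) over log β, with an exact Beta-Bernoulli inner observer standing in for the inner `sample` call so that the example runs without Turing. The data are the MWE's three trials and actions; all function names are hypothetical.

```julia
using Random

logistic(x) = 1 / (1 + exp(-x))

# Generic random-walk Metropolis: needs only evaluations of a log density,
# never its gradient, so the log density may contain an entire inner inference.
function rw_metropolis(logdensity, x0; n = 5_000, step = 0.5, rng = Xoshiro(2))
    x, lp = x0, logdensity(x0)
    draws = Vector{Float64}(undef, n)
    accepted = 0
    for i in 1:n
        x_prop = x + step * randn(rng)
        lp_prop = logdensity(x_prop)
        if log(rand(rng)) < lp_prop - lp    # symmetric proposal: no correction term
            x, lp = x_prop, lp_prop
            accepted += 1
        end
        draws[i] = x
    end
    return draws, accepted / n
end

# Illustrative outer log posterior over log β: a Normal(0, 1) prior, plus the
# likelihood of the actions under an exact Beta(1, 1)-Bernoulli inner observer.
function outer_logpost(logβ, trials, actions)
    β = exp(logβ)
    a, b, lp = 1.0, 1.0, -logβ^2 / 2
    for (obs, act) in zip(trials, actions)
        a += count(==(1), obs)              # exact inner update (conjugacy)
        b += count(==(0), obs)
        p_engage = logistic(β * (a / (a + b) - 0.5))
        lp += act == 1 ? log(p_engage) : log1p(-p_engage)
    end
    return lp
end

trials = [[0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0],
          [0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1],
          [0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0]]
actions = [1, 1, 0]

draws, acc_rate = rw_metropolis(x -> outer_logpost(x, trials, actions), 0.0)
```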

PTWaade, Jul 14 '25 07:07

I'll add that FiniteDifferences.jl is able to differentiate through the model (although extremely slowly). The code is below; feel free to ignore it.

using Turing
using LogExpFunctions
using ADTypes: AutoFiniteDifferences
import FiniteDifferences: central_fdm

#### CREATING INNER AND OUTER MODELS ####

# Inner model function
@model function inner_model(observation, prior_μ = 0, prior_σ = 1)
    
    # The inner model's prior
    mean ~ Normal(prior_μ, prior_σ)

    #The inner model's likelihood
    observation ~ Normal(mean, 1)
end

# Outer model function
@model function outer_model(observation, action, inner_sampler = NUTS(), inner_n_samples = 1000)
    
    ### Sample parameters for the inner inference and response ###

    #Parameters of the inner model's prior
    subj_prior_μ ~ Normal(0, 1)
    # subj_prior_σ ~ Exponential(1)
    subj_prior_σ = 1.0 
    
    # #Action noise (standard deviation of the response distribution)
    # β ~ Exponential(1) 
    β = 1.0


    ### "Perceptual inference": running the inner model ###

    #Condition the inner model
    inner_m = inner_model(observation, subj_prior_μ, subj_prior_σ)

    #Run the inner Bayesian inference
    chns = sample(inner_m, inner_sampler, inner_n_samples, progress = false)

    #Extract the subject's point estimate of the mean
    subj_mean_expectationₜ = mean(chns[:mean])


    ### "Response model": picking an action ###

    #Sample the action
    action ~ Normal(subj_mean_expectationₜ, β)

end


### SPECIFYING DATA AND MODEL ARGUMENTS ###

## Data ##
observation = 0.0
action = 1.0

## Inner model specifications ##

#The number of samples to draw from the inner model
inner_n_samples = 10

#With Metropolis-Hastings as inner sampler
m1 = outer_model(observation, action, MH(), inner_n_samples)
#With NUTS ForwardDiff inner sampler
m2 = outer_model(observation, action, NUTS(; adtype = AutoForwardDiff()), inner_n_samples)
#With NUTS FiniteDifferences.jl inner sampler
m3 = outer_model(observation, action, NUTS(; adtype = AutoFiniteDifferences(; fdm = central_fdm(5, 1))), inner_n_samples)

## Outer model specifications ##

#With NUTS FiniteDifferences.jl as outer sampler
chns = sample(m1, NUTS(; adtype = AutoFiniteDifferences(; fdm = central_fdm(5, 1))), 100) #this works
chns = sample(m2, NUTS(; adtype = AutoFiniteDifferences(; fdm = central_fdm(5, 1))), 100) #this works
chns = sample(m3, NUTS(; adtype = AutoFiniteDifferences(; fdm = central_fdm(5, 1))), 100) #this works
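A minimal sketch of what a finite-difference gradient does, and why it is so slow here: every partial derivative needs several extra evaluations of the whole outer log density, and each of those reruns the inner sampler. The five-point central stencil below is a standard formula with a fixed step `h`, chosen for illustration; FiniteDifferences.jl's `central_fdm(5, 1)` is also a five-point method for first derivatives, but it picks the step size itself, so its numbers will differ slightly. The function names and the quadratic stand-in for the outer log density are hypothetical.

```julia
# Five-point central difference for f'(x) with a fixed step h.
# The error is O(h^4) for smooth f, at the cost of 4 evaluations of f.
function central5(f, x; h = 1e-3)
    return (-f(x + 2h) + 8 * f(x + h) - 8 * f(x - h) + f(x - 2h)) / (12 * h)
end

# Gradient of a function of a vector: 4 evaluations per coordinate.
function fd_gradient(f, x::AbstractVector; h = 1e-3)
    g = similar(x, Float64)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = 1.0
        g[i] = central5(t -> f(x .+ t .* e), 0.0; h = h)
    end
    return g
end

# Stand-in for the outer log density; counts evaluations to make the cost visible.
n_evals = Ref(0)
toy_logdensity(x) = (n_evals[] += 1; sum(abs2, x) / 2)   # its gradient is x itself

g = fd_gradient(toy_logdensity, [1.0, -2.0, 0.5])         # ≈ [1.0, -2.0, 0.5], 12 evaluations
```

One caveat: if the inner sampler returns a different Monte Carlo estimate on every call, that noise is divided by the small step in the difference quotient, so the resulting gradient can be much noisier than the log density itself.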

PTWaade, Jul 14 '25 12:07