AdvancedVI.jl

add minibatch subsampling (doubly stochastic) objective

Open. Red-Portal opened this issue 6 months ago • 5 comments

This is a draft of the subsampling variational objective, which addresses #38. Any perspectives, concerns, or comments are welcome! The current plan is to implement only random reshuffling, since I recently showed that there is no point in implementing independent subsampling. Importance sampling of datapoints could be an interesting addition, but it would require custom DynamicPPL contexts.
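A minimal sketch of what random reshuffling means for the minibatch schedule, using only the `Random` stdlib. The helper `reshuffled_batches` is hypothetical (not part of AdvancedVI): each epoch draws a fresh permutation of the indices and cuts it into consecutive batches, so every data point is visited exactly once per epoch, whereas independent subsampling redraws each batch from scratch.

```julia
using Random

# Hypothetical helper: one epoch of random reshuffling.
# Draw a fresh permutation of 1:n and split it into consecutive batches of size b.
# The last batch is smaller when b does not divide n.
function reshuffled_batches(rng::AbstractRNG, n::Integer, b::Integer)
    perm = randperm(rng, n)
    return [perm[i:min(i + b - 1, n)] for i in 1:b:n]
end

rng = MersenneTwister(1)
for epoch in 1:2
    # A new permutation is drawn at the start of every epoch.
    for batch in reshuffled_batches(rng, 10, 4)
        println("epoch $epoch: ", batch)  # one optimization step per batch
    end
end
```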

The key design decision is the following function:

# This function/signature will be moved to src/AdvancedVI.jl
"""
    subsample(model, batch)

Return a version of `model` restricted to the data points in `batch`.

# Arguments
- `model`: Model subject to subsampling. Could be the target model or the variational approximation.
- `batch`: Data points or indices corresponding to the subsampled "batch."

# Returns
- `sub`: Subsampled model.
"""
function subsample end

# Usage: the target and the variational approximation are subsampled with the same batch.
prob_sub = subsample(prob, batch)
q_sub = subsample(q, batch)
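To make the contract of `subsample` concrete, here is a toy sketch for a hypothetical target type `ToyGaussianMean` (not an AdvancedVI or Turing type), assuming batches are passed as indices. The subsampled model keeps only the batch and multiplies the log-likelihood by n/b, which is the role `MiniBatchContext` plays for the Turing models discussed below. Only the target is subsampled here; a variational approximation over a single global parameter would have nothing to subsample.

```julia
using Random

# Hypothetical toy target: x_i ~ Normal(μ, 1), prior μ ~ Normal(0, 1).
# `scale` multiplies the log-likelihood; it is 1.0 for the full-data model.
struct ToyGaussianMean
    data::Vector{Float64}
    scale::Float64
end
ToyGaussianMean(data) = ToyGaussianMean(data, 1.0)

# Keep only the points in `batch` and rescale the likelihood by n / b.
# Multiplying the existing scale keeps repeated subsampling consistent.
function subsample(model::ToyGaussianMean, batch)
    n, b = length(model.data), length(batch)
    return ToyGaussianMean(model.data[batch], model.scale * n / b)
end

# Log-density of Normal(μ, 1) evaluated at x.
unit_normal_logpdf(x, μ) = -0.5 * (x - μ)^2 - 0.5 * log(2π)

# Unnormalized log joint: log prior + scale * log-likelihood of the stored data.
function logdensity(model::ToyGaussianMean, μ::Real)
    logprior = unit_normal_logpdf(μ, 0.0)
    loglik = sum(x -> unit_normal_logpdf(x, μ), model.data)
    return logprior + model.scale * loglik
end

rng = MersenneTwister(42)
prob = ToyGaussianMean(randn(rng, 12))
prob_sub = subsample(prob, 1:4)
println(logdensity(prob_sub, 0.3))
```

With equal-sized batches, averaging the subsampled log densities over one reshuffled epoch recovers the full-data log density exactly, which is a cheap sanity check for any implementation of this interface.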

Following a previous Slack DM thread with @yebai, this interface could be implemented straightforwardly by Turing models as

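using Turing, FillArrays  # `Eye` comes from FillArrays

@model function pPCA(X::AbstractMatrix{<:Real}, k::Int; data_or_indices = 1:size(X, 1))
    N, D = size(X)           # rows of X are data points
    idx = data_or_indices    # here: indices of the rows in the current batch
    N_sub = length(idx)

    W ~ filldist(Normal(), D, k)
    Z ~ filldist(Normal(), k, N)

    # Subsampling
    # Some AD backends are not happy about `view`.
    # In that case, this step will make a copy and, therefore, should not be considered free.
    Z_sub = view(Z, :, idx)

    genes_mean = W * Z_sub
    # Observe through an indexing expression on the model argument `X`:
    # `X_sub ~ ...` with a local `X_sub` would be treated as a latent variable, not as data.
    return X[idx, :] ~ arraydist([MvNormal(m, Eye(N_sub)) for m in eachcol(genes_mean')])
end;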

where `data_or_indices` could be made a reserved keyword argument for Turing models. Then, I think

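using Accessors
using DynamicPPL

# Assumes the model exposes the proposed reserved keyword `data_or_indices`
# and that `m` is the full-data model (its default holds all indices).
function subsample(m::DynamicPPL.Model, batch)
    n, b = length(m.defaults.data_or_indices), length(batch)
    # Point the model at the minibatch instead of overwriting all defaults.
    m = @set m.defaults.data_or_indices = batch
    # Scale the log-likelihood by n / b (`npoints / batch_size`).
    m = @set m.context = DynamicPPL.MiniBatchContext(m.context; batch_size = b, npoints = n)
    return m
end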

should generally work?
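For reference, here is the quantity such a minibatch-scaled model targets, assuming global parameters θ, data conditionally independent given θ, N data points (`npoints`), and a batch 𝓑 of size b (`batch_size`) in which each point appears with probability b/N (true for a uniformly random batch, and marginally for a batch drawn from a reshuffled epoch when b divides N). The second line is why the N/b factor makes the estimate unbiased; the third is the resulting doubly stochastic ELBO estimate, random in both θ and 𝓑.

```latex
\begin{aligned}
\log \hat{p}_{\mathcal{B}}(x, \theta) &= \log p(\theta) + \frac{N}{b} \sum_{i \in \mathcal{B}} \log p(x_i \mid \theta), \\
\mathbb{E}_{\mathcal{B}}\big[\log \hat{p}_{\mathcal{B}}(x, \theta)\big] &= \log p(\theta) + \sum_{i=1}^{N} \log p(x_i \mid \theta) = \log p(x, \theta), \\
\widehat{\mathrm{ELBO}} &= \log \hat{p}_{\mathcal{B}}(x, \theta) - \log q(\theta), \qquad \theta \sim q .
\end{aligned}
```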

My current guess is that `subsample(m::DynamicPPL.Model, batch)` would have to live in the main Turing repository.
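The index bookkeeping in the pPCA model above is easy to get wrong, so here is a plain-Julia sketch (no Turing; the helper `ppca_loglik` is hypothetical) of its observation log-likelihood restricted to a batch of rows. It assumes the same shapes as the model: `X` is N×D with data points as rows, `W` is D×k, `Z` is k×N with one latent column per data point, and unit noise variance, so the `MvNormal(m, Eye(N_sub))` terms reduce to a sum of independent univariate Gaussian log-densities.

```julia
using Random

# Hypothetical helper: pPCA observation log-likelihood for the rows in `idx`.
function ppca_loglik(X::AbstractMatrix, W::AbstractMatrix, Z::AbstractMatrix, idx)
    Z_sub = view(Z, :, idx)   # k × N_sub latents for the batch
    X_sub = view(X, idx, :)   # N_sub × D observations for the batch
    means = (W * Z_sub)'      # N_sub × D, same layout as X_sub
    resid = X_sub .- means
    # Sum of independent unit-variance Gaussian log-densities.
    return -0.5 * sum(abs2, resid) - 0.5 * length(resid) * log(2π)
end

rng = MersenneTwister(0)
N, D, k = 12, 5, 2
X = randn(rng, N, D)
W = randn(rng, D, k)
Z = randn(rng, k, N)

full_ll = ppca_loglik(X, W, Z, 1:N)
batch_ll = ppca_loglik(X, W, Z, 1:4)
```

Because the per-row terms add up, the batch log-likelihoods over a partition of `1:N` sum to the full-data value; the N / N_sub rescaling is what turns a single batch into an unbiased estimate of it.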

Red-Portal, Aug 11 '24 21:08