AdvancedVI.jl
add minibatch subsampling (doubly stochastic) objective
This is a draft of the subsampling variational objective, which addresses #38. Any perspectives/concerns/comments are welcome! The current plan is to implement only random reshuffling, since I recently showed that there is no point in implementing independent subsampling. Although importance sampling of datapoints could be an interesting addition, it would require custom DynamicPPL.Contexts.
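For reference, random reshuffling means each epoch visits every data point exactly once, in a freshly shuffled order. A minimal sketch in plain Julia (the function name `reshuffling_batches` is illustrative, not part of the proposed interface):

```julia
using Random

# Random reshuffling: shuffle the full index set once per epoch,
# then split it into consecutive minibatches.
function reshuffling_batches(rng::AbstractRNG, n_data::Int, batchsize::Int)
    perm = Random.shuffle(rng, 1:n_data)
    return Iterators.partition(perm, batchsize)
end

rng = Random.MersenneTwister(1)
batches = collect(reshuffling_batches(rng, 10, 3))
# Every index appears exactly once per epoch.
@assert sort(reduce(vcat, collect.(batches))) == 1:10
```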
The key design decision is the following function:
```julia
# This function/signature will be moved to src/AdvancedVI.jl
"""
    subsample(model, batch)

# Arguments
- `model`: Model subject to subsampling. Could be the target model or the variational approximation.
- `batch`: Data points or indices corresponding to the subsampled "batch."

# Returns
- `sub`: Subsampled model.
"""
prob_sub = subsample(prob, batch)
q_sub = subsample(q, batch)
```
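As a toy illustration of the intended contract (not Turing code; `ToyModel` is hypothetical), `subsample` takes a model and a batch of data points or indices and returns a model restricted to that batch:

```julia
# Hypothetical stand-in for a probabilistic model holding the full dataset.
struct ToyModel
    data::Vector{Float64}
end

# The subsampled model carries only the selected data points.
subsample(m::ToyModel, batch) = ToyModel(m.data[batch])

m = ToyModel([1.0, 2.0, 3.0, 4.0])
m_sub = subsample(m, [2, 4])
@assert m_sub.data == [2.0, 4.0]
```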
Based on a previous Slack DM thread with @yebai, this interface could be straightforwardly implemented for Turing models as
```julia
@model function pPCA(X::AbstractMatrix{<:Real}, k::Int; data_or_indices = 1:size(X, 1))
    N, D = size(X)
    idx = data_or_indices
    N_sub = length(idx)
    W ~ filldist(Normal(), D, k)
    Z ~ filldist(Normal(), k, N)
    # Subsampling
    # Some AD backends are not happy about `view`.
    # In that case, this step will incur a copy and, therefore, shall not be considered free.
    Z_sub = view(Z, :, idx)
    X_sub = view(X, idx, :)
    genes_mean = W * Z_sub
    return X_sub ~ arraydist([MvNormal(m, Eye(N_sub)) for m in eachcol(genes_mean')])
end;
```
where `data_or_indices` could be made a reserved keyword argument for Turing models. Then, I think
```julia
using Accessors

function subsample(m::DynamicPPL.Model, batch)
    n, b = length(m.defaults.data_or_indices), length(batch)
    m = @set m.defaults.data_or_indices = batch
    m = @set m.context = MiniBatchContext(m.context; batch_size=b, npoints=n)
    m
end
```
should generally work?
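For intuition, the role of `MiniBatchContext` is to rescale the minibatch log-likelihood by `n/b` so that the subsampled estimate is unbiased for the full-data log-likelihood. A toy sketch of that scaling (plain Julia, not DynamicPPL internals; `loglike` is a stand-in likelihood):

```julia
# Stand-in full-data log-likelihood.
loglike(x) = -0.5 * sum(abs2, x)

# Minibatch estimate scaled by n/b, mirroring what MiniBatchContext does.
function scaled_minibatch_loglike(x, batch, n)
    b = length(batch)
    return (n / b) * loglike(view(x, batch))
end

x = ones(4)
# Averaging the scaled estimate over disjoint batches recovers the full value.
est = (scaled_minibatch_loglike(x, 1:2, 4) + scaled_minibatch_loglike(x, 3:4, 4)) / 2
@assert est ≈ loglike(x)
```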
My current guess is that `subsample(m::DynamicPPL.Model, batch)` would have to end up in the main Turing repository.