Gen.jl
Help with split-merge involutive MCMC for a tree-structured change-point model
Hi,
I'm trying to do involutive MCMC for a 2d non-parametric change-point model (the same one as in #388). The model samples a tree structure that recursively divides an image into rectangles, and then samples pixel values for the whole image given the structure. Model code here, the relevant bits are
# Recursively sample a random binary tree describing a 2d change-point
# structure. Each node is a leaf with probability 0.5 (traced at :isleaf).
# Leaves carry per-rectangle pixel-distribution parameters (:mean, :variance);
# branches carry a split fraction (:frac, a normal mapped into (0,1) via
# logistic) and an axis flag (:ishorizontal), plus recursive subtrees whose
# choices are nested under the :a and :b addresses.
@gen function grow_tree()
if @trace(bernoulli(0.5), :isleaf)
# Leaf: sample the pixel-value distribution for this rectangle.
mean = @trace(normal(0, 1), :mean)
variance = @trace(gamma(1, 1), :variance)
return LeafNode(mean, variance)
else
# Branch: where to split and along which axis.
frac = logistic(@trace(normal(0, 1), :frac))
ishorizontal = @trace(bernoulli(0.5), :ishorizontal)
# Recursive calls place their choices under hierarchical addresses :a / :b.
a = @trace(grow_tree(), :a)
b = @trace(grow_tree(), :b)
return BranchNode(a,b,frac,ishorizontal)
end
end;
# Model: sample a tree partition of an nrows×ncols image under :tree, then
# sample all pixel values jointly (as a flat vector) at :img from the
# per-rectangle normal distributions.
@gen function screen_model(size::Tuple{Int64,Int64})
nrows, ncols = size
# Allocated only so get_index_matrices sees the right shape; the undef
# contents are never read, and the name is rebound at the @trace below.
screenshot = Array{Float64}(undef, nrows, ncols)
tree = @trace(grow_tree(), :tree)
is,js = get_index_matrices(screenshot)
# Per-pixel (mean, variance) obtained by descending the tree; the initial
# bounding box (1, 1, nrows, ncols) spans the whole image.
img_params = map((i,j) -> get_value_at(i,j,tree,(1.,1.,Float64(nrows),Float64(ncols))), is, js)
img_mean = getindex.(img_params,1)
img_variance = getindex.(img_params,2)
# One multivariate trace over all pixels, flattened to vectors.
screenshot = @trace(broadcasted_normal([img_mean...],[img_variance...]), :img)
return reshape(screenshot,size)
end
I wrote an involution inspired by the GP structure search in GenExamples.jl, where the proposal picks a random node in the tree, and samples a subtree that replaces whatever was at that node. Full code here.
# Proposal: pick a random node of the current tree (its path of :a/:b steps
# is traced under :choose_subtree_root) and sample a fresh replacement
# subtree from the prior under :subtree.
@gen function subtree_proposal(prev_trace)
prev_subtree_node::Node = prev_trace[:tree]
(path::Vector{Symbol}) = @trace(pick_random_node_path(prev_subtree_node, Symbol[]), :choose_subtree_root)
# An empty path means the root itself is being replaced.
subtree_addr = isempty(path) ? :tree : (:tree => foldr(=>, path))
new_subtree_node::Node = @trace(grow_tree(), :subtree) # mixed discrete / continuous
# The return value is consumed by the transform via @read(aux_in[], :discrete).
(path, new_subtree_node)
end
# Involution for subtree_proposal: swaps the proposed subtree with whatever
# the model currently has at the chosen path, and preserves the choice of
# root so the move is its own inverse.
@transform subtree_involution_tree_transform (model_in, aux_in) to (model_out, aux_out) begin
# Read the proposal's return value (path to the chosen node, new subtree).
(path::Vector{Symbol}, new_subtree_node) = @read(aux_in[], :discrete)
# populate backward assignment with choice of root
@copy(aux_in[:choose_subtree_root], aux_out[:choose_subtree_root])
# swap subtrees
model_subtree_addr = isempty(path) ? :tree : (:tree => foldr(=>, path))
@copy(aux_in[:subtree], model_out[model_subtree_addr])
@copy(model_in[model_subtree_addr], aux_out[:subtree])
end
This runs without errors, but few traces are accepted, probably because randomly proposed trees aren't a good fit, and even a well-fitting tree usually has bad values (:mean and :variance) at the leaf nodes.
So I want to write a split-merge involution which proposes to split a leaf node, or merge sibling leaf nodes, along with :mean and :variance parameters derived from the current trace, similar to the mixture example here. To begin with, I skip the parameter proposals and just propose a random split/merge, like this:
# Split/merge proposal (first attempt): flip a coin at :split, then either
# split a random leaf into a fresh branch, or merge a random pair of
# sibling leaves back into a fresh leaf.
@gen function split_merge_proposal(prev_trace)
tree = prev_trace[:tree]
n_leaf_nodes = count_leaves(tree)
random_split = @trace(bernoulli(0.5), :split)
# A single-leaf tree has nothing to merge, so a split is forced; note the
# traced value at :split can then disagree with the action actually taken.
split = (n_leaf_nodes == 1) ? true : random_split
if split
# select random leaf for splitting
leaf_path = @trace(pick_random_leaf(tree), :leaf_path)
new_node = @trace(make_branch(), :new_node)
else
# select random leaf that will merge with sibling
# sibling needs to be a leaf as well
leaf_path = @trace(pick_random_leaf_parent(tree), :leaf_path)
new_node = @trace(make_leaf(), :new_node)
end
(leaf_path, new_node)
end
# Intended involution for the split/merge proposal.  NOTE(review): the
# round-trip failure described in the surrounding text stems from
# :leaf_path being an entire generative-function subtrace rather than a
# single random choice — @read/@write only operate on individual choices.
@transform split_merge_transform (model_in, aux_in) to (model_out, aux_out) begin
#(leaf_path::Array{Symbol,1}, new_node) = @read(aux_in[], :discrete)
leaf_path = @read(aux_in[:leaf_path], :discrete)
new_node = @read(aux_in[:new_node], :discrete)
tree = @read(model_in[:tree], :discrete)
n_leaf_nodes = count_leaves(tree)
random_split = @read(aux_in[:split], :discrete)
# Mirror the proposal's forced-split rule so forward and backward agree.
split = (n_leaf_nodes == 1) ? true : random_split
new_node_addr = isempty(leaf_path) ? :tree : (:tree => foldr(=>, leaf_path))
#leaf_path = @read(aux_in[:leaf_path], :discrete)
#@copy(aux_in[:leaf_path], aux_out[:leaf_path])
@write(aux_out[:leaf_path], leaf_path, :discrete)
# The reverse move must flip split <-> merge.
@write(aux_out[:split], !random_split, :discrete)
# Swap the proposed node with the model's node at the chosen address.
@copy(aux_in[:new_node], model_out[new_node_addr])
@copy(model_in[new_node_addr], aux_out[:new_node])
end
Full code here
Just like :choose_subtree_root in the subtree involution, to perform the correct reversal, aux_out needs the same value at :leaf_path that aux_in had, so I did exactly the same
@copy(aux_in[:leaf_path], aux_out[:leaf_path])
But here, it runs for a few iterations and then crashes on
ERROR: transform round trip check failed
because the copied :leaf_path would, for some reason, differ from the original one, and so the new node that was proposed was spliced back at the wrong address, resulting in different model traces before and after the round trip.
I tried with
@write(aux_out[:leaf_path], leaf_path, :discrete)
instead, and now I get
ERROR: KeyError: key :leaf_path not found
for either
leaf_path = @trace(pick_random_leaf(tree), :leaf_path)
or
leaf_path = @trace(pick_random_leaf_parent(tree), :leaf_path)
in the proposal.
I hope this makes sense, let me know if I should clarify anything. Ultimately I want to do split-merge jumps that derive the leaf node values from the previous state (e.g. take the means of the mean and variance parameters of two merging leaf nodes), and then do MAP optimization on all the continuous parameters. But maybe there's a better way to do inference?
Thanks in advance for any help.
I think the issue might be that `pick_random_leaf` is not sampling a single random choice, but is itself a generative function that samples multiple random choices. If that's the case, you can either make `pick_random_leaf` into a `Gen.Distribution` (https://www.gen.dev/dev/ref/extending/#custom_distributions-1), or you can extend the transform with a call to a recursive transform (invoked with `@tcall`; see https://www.gen.dev/dev/ref/trace_translators/#Trace-Transform-DSL-1) that walks the tree and populates the values of all the choices made by `pick_random_leaf` (see https://github.com/probcomp/GenExamples.jl/blob/main/gp_structure/involution_mh.jl#L91-L155 for an example of that approach).
Note that while you can `@copy` entire choice maps from one trace to another, you can only `@read` and `@write` individual random choices.
Thank you for the answer! I'm not sure if it's exactly what you meant, but I tried the recursive transform approach like this:
# Walk down the tree with fair coin flips — one individually-addressable
# bernoulli per visited internal node, at :recurse_a => cur — until a leaf
# is reached; return its node index and depth.  (`cur` appears to be a
# binary-heap-style index via get_child — confirm against get_child's
# definition.)
@gen function pick_random_leaf(node::Node, cur::Int, depth::Int)
if isa(node, LeafNode)
(cur, depth)
elseif @trace(bernoulli(0.5), :recurse_a => cur)
@trace(pick_random_leaf(node.a, get_child(cur, 1, 2), depth+1))
else
@trace(pick_random_leaf(node.b, get_child(cur, 2, 2), depth+1))
end
end
# Walk down the tree until a node whose children are BOTH leaves is found
# (a merge candidate); return its node index and depth.  At each step the
# recursion is biased toward children that can still contain such a node:
# a leaf child cannot, so it gets probability 0.
@gen function pick_random_leaf_parent(node::Node, cur::Int, depth::Int)
if isa(node.a, LeafNode) && isa(node.b, LeafNode)
return (cur, depth)
else
if isa(node.a, BranchNode) && isa(node.b, BranchNode)
recurse_a_prob = 0.5
elseif isa(node.a, BranchNode) && !isa(node.b, BranchNode)
recurse_a_prob = 1.
elseif isa(node.b, BranchNode) && !isa(node.a, BranchNode)
recurse_a_prob = 0.
else
# Unreachable: the leaf/leaf case already returned above.
error(node)
end
end
# Even with probability 0 or 1 the choice is still traced, so a reverse
# transform must account for the address :recurse_a => cur.
if @trace(bernoulli(recurse_a_prob), :recurse_a => cur)
@trace(pick_random_leaf_parent(node.a, get_child(cur, 1, 2), depth+1))
else
@trace(pick_random_leaf_parent(node.b, get_child(cur, 2, 2), depth+1))
end
end
# Split/merge proposal, recursive-walk version: the leaf walks now trace
# one bernoulli per visited internal node (each individually readable by
# the transform) and return the chosen node's index and depth instead of
# a symbol path.
@gen function split_merge_proposal(prev_trace)
tree = prev_trace[:tree]
n_leaf_nodes = count_leaves(tree)
random_split = @trace(bernoulli(0.5), :split)
# A single-leaf tree cannot be merged, so force a split in that case.
split = (n_leaf_nodes == 1) ? true : random_split
if split
# select random leaf for splitting
leaf_path = @trace(pick_random_leaf(tree, 1, 0), :leaf_path)
new_node = @trace(make_branch(), :new_node)
else
# select random leaf that will merge with sibling
# sibling needs to be a leaf as well
leaf_path = @trace(pick_random_leaf_parent(tree, 1, 0), :leaf_path)
new_node = @trace(make_leaf(), :new_node)
end
(leaf_path, new_node)
end
# Recursive transform: replays the :recurse_a coin flips traced by the
# proposal's leaf walk, accumulating the corresponding :a/:b path, until
# it reaches the chosen node (matching index), where it swaps the proposed
# node with the model's node at that address.
@transform walk_tree(cur::Int, leaf_path::Array{Symbol,1}) (model_in, aux_in) to (model_out, aux_out) begin
# NOTE(review): this appears to read the *return value* of the
# pick_random_leaf(_parent) call traced at :leaf_path — confirm.
(leaf_number, leaf_depth) = @read(aux_in[:leaf_path], :discrete)
if leaf_number == cur
new_node_addr = isempty(leaf_path) ? :tree : (:tree => foldr(=>, leaf_path))
@copy(aux_in[:new_node], model_out[new_node_addr])
@copy(model_in[new_node_addr], aux_out[:new_node])
else
# Read the individual coin at this node and descend accordingly,
# extending the model-address path with :a or :b.
recurse_a = @read(aux_in[:leaf_path => :recurse_a => cur], :discrete)
if recurse_a
push!(leaf_path, :a)
@tcall(walk_tree(get_child(cur, 1, 2), leaf_path))
else
push!(leaf_path, :b)
@tcall(walk_tree(get_child(cur, 2, 2), leaf_path))
end
end
end
# Top-level involution for the recursive-walk version: copies the whole
# :leaf_path subtrace to the reverse proposal, flips :split, and delegates
# the subtree swap to walk_tree starting at the root (index 1).
@transform split_merge_transform (model_in, aux_in) to (model_out, aux_out) begin
(leaf_number, leaf_depth) = @read(aux_in[:leaf_path], :discrete)
new_node = @read(aux_in[:new_node], :discrete)
tree = @read(model_in[:tree], :discrete)
n_leaf_nodes = count_leaves(tree)
random_split = @read(aux_in[:split], :discrete)
# Mirrors the proposal's forced-split rule; `split`, `new_node` and
# `leaf_depth` are not used further here.  NOTE(review): whether these
# reads suffice to mark all constraints as visited is exactly the open
# "Did not visit all constraints" question in this thread.
split = (n_leaf_nodes == 1) ? true : random_split
@copy(aux_in[:leaf_path], aux_out[:leaf_path])
@write(aux_out[:split], !random_split, :discrete)
@tcall(walk_tree(1,Symbol[]))
end
It does pass the round trip checks, and usually gets many iterations in, but eventually crashes on
ERROR: Did not visit all constraints
Any idea what might be going wrong, or how I can inspect what constraints aren't visited and why?
I've also been trying to do map_optimize on a selection of the continuous choices in the trace, like this:
# Build a Gen selection of all continuous choices in the tree trace:
# :mean and :variance at every leaf, and :frac at every branch node on
# the way down.  Used to restrict map_optimize to continuous parameters.
#
# node: the tree value stored at trace[:tree].
# Returns a DynamicSelection of hierarchical addresses under :tree.
function select_continuous(node)
    selection = DynamicSelection()
    leaves = get_path_to_leaf(node)
    # Branch prefixes already handled; a typed Set gives O(1) membership
    # tests instead of an O(n) scan of an untyped Vector{Any}.
    branches = Set{Vector{Symbol}}()
    for leaf in leaves
        leaf_address = isempty(leaf) ? :tree : (:tree => foldr(=>, leaf))
        push!(selection, leaf_address => :mean)
        push!(selection, leaf_address => :variance)
        # Every *proper* prefix of a leaf path — including the empty
        # prefix, i.e. the root — addresses a branch node carrying a
        # :frac choice.  (The original iterated 1:length(leaf), which
        # skipped the root's :frac and wrongly selected :frac at the
        # leaf itself, where no such choice exists.)
        for i in 0:length(leaf)-1
            prefix = leaf[1:i]
            if prefix ∉ branches
                push!(branches, prefix)
                branch_address = isempty(prefix) ? :tree : (:tree => foldr(=>, prefix))
                push!(selection, branch_address => :frac)
            end
        end
    end
    return selection
end
# Run split/merge MCMC with interleaved MAP optimization of the continuous
# parameters, conditioned on the observed image.
#
# model:  the generative function (e.g. screen_model).
# img:    observed image matrix.
# n_iter: number of MCMC iterations.
# Returns the final trace.
function do_inference(model, img, n_iter)
    # condition on image
    observations = choicemap()
    observations[:img] = vec(img)  # flatten column-major, no splatting
    # generate initial trace consistent with observed data
    (trace, _) = generate(model, (size(img),), observations)
    # do MCMC
    for iter = 1:n_iter
        # do MH move on the subtree
        trace = replace_split_merge_move(trace)
        # Recompute the continuous selection from the *current* tree:
        # split/merge moves change the tree structure, so a selection
        # built once up front goes stale and points at nonexistent
        # addresses.  (The original also misspelled the variable —
        # `continous_variables` was defined but `continuous_variables`
        # was referenced, raising UndefVarError at runtime.)
        continuous_variables = select_continuous(trace[:tree])
        # optimize continuous variables
        trace = map_optimize(trace, continuous_variables)
    end
    return trace
end;
but this runs into lots of different bugs like
ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64,Float64,Nothing} to Float64 is not defined. Please use
ReverseDiff.value instead.
and
ERROR: AssertionError: length(arr) >= start_idx
which makes me think there's something fundamentally wrong with how I'm setting this up.
In any case, thank you for your time!