Functors.jl
Extract common functionality into fold
Traversal functions in Functors.jl currently handle traversal via manual recursion. This isn't the end of the world, but it results in a good amount of code duplication, as well as parameter explosion from having to plumb auxiliary arguments through the call stack. With #31 looking to add a number of additional functions, and Optimisers.jl not getting a lot of mileage out of `fmap`, it's a good time to consider whether we can cut down on the boilerplate.
This PR introduces helper traversal functionality in the form of `fold` (and `Fold`). In the language of recursion schemes, `fold` is a "catamorphism": it performs a generalized structural reduction over a tree (or, in our case, a DAG). Also added are a couple of caching-related helpers that may be useful downstream.
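To make the catamorphism idea concrete, here is a minimal, self-contained sketch in plain Julia. It is not the code from this PR: `fold_tree`, `node_children`, and `is_leaf` are hypothetical names, and only `Tuple`/`NamedTuple` are treated as containers, as a stand-in for what `functor` would return. Children are folded first (post-order), the reducer then sees every node, leaf or not, and mutable nodes are cached by identity so that a subtree shared in a DAG is reduced only once.

```julia
# Hypothetical stand-ins for Functors' notion of children (not its real API):
# only Tuples and NamedTuples are treated as containers here.
node_children(x::Union{Tuple,NamedTuple}) = x
node_children(x) = ()
is_leaf(x) = !(x isa Union{Tuple,NamedTuple})

# Post-order catamorphism: fold each child first, then call
# `f(x, folded_children)` on the node itself (leaves get `()`).
# Mutable nodes are cached by identity, so a shared subtree is reduced once.
function fold_tree(f, x; cache = IdDict{Any,Any}())
    if ismutable(x) && haskey(cache, x)
        return cache[x]
    end
    folded = map(c -> fold_tree(f, c; cache = cache), node_children(x))
    y = f(x, folded)
    ismutable(x) && (cache[x] = y)
    return y
end

# Usage: sum every numeric leaf, counting how often the shared array is reduced.
shared = [1.0, 2.0]
tree = (a = shared, b = (shared, 3.0), c = (d = 4.0,))
calls = Ref(0)
total = fold_tree(tree) do x, kids
    if is_leaf(x)
        x === shared && (calls[] += 1)
        sum(x)
    else
        foldl(+, kids; init = 0.0)
    end
end
```

Here `total` is `13.0`, and `calls[]` is `1` even though `shared` appears twice in the tree.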
This is the first step in an effort to implement the vision of #27 while maintaining as much backwards compatibility as possible. A non-exhaustive list of objectives for future PRs includes:
- Reducing the reliance on custom walks for downstream code.
- Removing the requirement to carry around the `re` (construct) closure. This can lead to some unfortunate gymnastics, so getting rid of it would be great.
- Accommodating multiple inputs, e.g. `fmap(f, x, xs...)`. This would help Optimisers.jl and any other downstream library that has rolled its own `fmap` variants with `functor`.
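For readers unfamiliar with the `re` closure: `functor(x)` returns the children of `x` together with a closure that rebuilds `x` from (possibly modified) children, and every recursive traversal has to thread that closure through. Below is a toy re-implementation of the pattern in plain Julia, not Functors.jl's actual code; `Affine`, `toy_functor`, `toy_isleaf`, and `toy_fmap` are hypothetical names.

```julia
# A hypothetical layer type used only for illustration.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end

# functor-style decomposition: (children, re), where `re` rebuilds the node.
toy_functor(x::Affine) = ((W = x.W, b = x.b), ch -> Affine(ch.W, ch.b))
toy_functor(x::Union{Tuple,NamedTuple}) = (x, identity)
toy_functor(x) = ((), _ -> x)  # leaf: no children, `re` just returns x

toy_isleaf(x) = toy_functor(x)[1] === ()

# Every level of the recursion has to call `re` to rebuild its node.
function toy_fmap(f, x)
    toy_isleaf(x) && return f(x)
    children, re = toy_functor(x)
    return re(map(c -> toy_fmap(f, c), children))
end

model = (layer = Affine([1.0 2.0], [0.5]), scale = 2.0)
doubled = toy_fmap(v -> 2 .* v, model)
```

`doubled.layer` is again an `Affine` (with `W == [2.0 4.0]` and `b == [1.0]`), and `doubled.scale == 4.0`. The rebuild logic lives entirely in the closures returned by `toy_functor`.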
> e.g. `fmap(f, x, xs...)`. This would help Optimisers.jl and any other downstream library that has rolled its own `fmap` variants with `functor`.
Can you sketch how this would work? Optimisers does some tricky things, and it often has to treat the gradient asymmetrically to the model. Is it clear that this is a good fit?
I haven't tested this, but this is a rough sketch of the approach. It probably needs more time in the oven, but the more we take Functors.jl in this direction, the better. Even without Optimisers.jl, I will be able to use these improvements for model pruning.
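```julia
# Rough, untested sketch. It assumes a multi-tree `fmap(f, x, xs...; walk)`,
# which Functors.jl does not currently provide, an optimiser rule `apply`,
# and `children` as in Functors.jl.
function update(tree, x, dxs...)
    function opt_walk(f, leaf, x, dxs...)
        # If any tree has `nothing` here, return the first gradient unchanged.
        any(isnothing, (leaf, x, dxs...)) && return first(dxs)
        stree = children(leaf)
        sx = children(x)
        sdxs = map(children, dxs)  # children of each gradient tree
        # zip(sdxs...) groups the i-th child of every gradient together
        foreach((_tree, _x, _dxs) -> f(_tree, _x, _dxs...), stree, sx, zip(sdxs...))
    end
    fmap(apply, tree, x, dxs...; walk = opt_walk)
end
```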
> Can you sketch how this would work? Optimisers does some tricky things, and it often has to treat the gradient asymmetrically to the model. Is it clear that this is a good fit?
Good questions. `fmap` will never modify the structure of the tree, so using it that way is a non-starter. I intend to rename or alias it to `mapleaves` to clearly reflect that [^1]. `fold`, on the other hand, will traverse every level of the tree, so you can add or prune subtrees at will using it. This is overkill for most of Optimisers, but you can imagine how it might be useful for FluxML/Optimisers.jl#42 with a little tweaking.
Handling parallel traversal (i.e. zip, then traverse) over asymmetric trees will require better walk functions than we currently have in Functors. I started prototyping that as part of #27, but it's going to be a challenge to get working without completely changing core functionality such as the behaviour of `functor` itself. Hence the incremental approach described up top: keep the user-facing interface as stable as possible while updating the internals. If that requires maintaining two parallel sets of functionality (one for back compat and one for the new stuff), so be it.
[^1]: What's old is new again: https://github.com/FluxML/Flux.jl/blob/e7d76b8423818c5a165e388dd3b090cc5bf42cbb/src/treelike.jl#L27. But seriously, removing jargon from Functors is a good thing, especially when some of it (e.g. `fmap`) is not strictly correct.
OK. It might be worth trying to write FluxML/Optimisers.jl#42 using this, as a logic stress-test? That need not hold up this PR, I guess. (I am not happy yet with how to handle transposed tied arrays there, for which the gradient might not have `.parent`, but might be a thunk or a `Broadcasted`...)
Longer term I think we should be open to just scrapping Functors and starting over. The entire library fits on a page, really, it's just a matter of choosing logic which matches the problems we want to solve. And a clean start (or two or three) might be much easier than contortions to modify it while keeping backward compat.
The major breakages will be w.r.t. Flux's uses of Functors. There are probably few users directly using Functors outside of `@functor`. So, if we need to scrap it one day, then the contortions can always live in Flux.jl for a deprecation cycle.
We could contemplate having Flux provide an `@functor` which for now calls this one, to ease transitions.
The entire library fitting on a page is both part of the allure and the fundamental problem with Functors. It makes the problem look so easy, whereas as we've discovered it is anything but.
> Longer term I think we should be open to just scrapping Functors and starting over.
This could be done today by moving to an alternative like https://github.com/chengchingwen/StructWalk.jl. `fold` is basically `postwalk` + cache and overriding of `isleaf`, after all. IMO `WalkStyle` is a cleaner and more idiomatic paradigm than custom walk functions too. In any case, the missing pieces for any contender (including Functors itself) are:
- Walking over similarly structured trees, `walk(f, x, xs...)`
- Walking over multiple subtrees with different branches, `walk(f, (a=1, b=2), (b=3, c=4))`
- Functionality to traverse multiple trees and return a single tree, `fmap(f, x, xs...) -> tree`
- Functionality to traverse multiple trees and return multiple trees, `fmultimap(f, x, xs...) -> tree, trees...`. This could also be done as an `unzip`-like post-processing step; see `jax.tree_util.tree_transpose`.
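As a sketch of the first and third items, here is a multi-tree map over similarly structured trees that returns a single tree shaped like the first argument. It is plain Julia with hypothetical names (`fmap_multi`, `is_container`), handles only `Tuple`/`NamedTuple` containers, and requires containers in every tree to have identical keys; how to align genuinely dissimilar trees (the second item) is left open.

```julia
# Only Tuples and NamedTuples are treated as containers in this sketch.
is_container(x) = x isa Union{Tuple,NamedTuple}

# Map `f` over the leaves of several trees in lockstep. The first tree decides
# where the leaves are; at each container, all trees must have identical keys.
function fmap_multi(f, x, xs...)
    is_container(x) || return f(x, xs...)
    all(y -> is_container(y) && keys(y) == keys(x), xs) ||
        throw(ArgumentError("trees are not similarly structured"))
    return map((c, cs...) -> fmap_multi(f, c, cs...), x, xs...)
end

ps = (a = 1.0, b = (2.0, 3.0))
gs = (a = 0.5, b = (0.5, 1.0))
# A plain descent-style update applied leaf by leaf.
updated = fmap_multi((p, g) -> p - 0.5 * g, ps, gs)
```

`updated` is `(a = 0.75, b = (1.75, 2.5))`, while `fmap_multi(+, (a = 1,), (b = 1,))` throws an `ArgumentError`. Checking `keys` explicitly makes "similarly structured" mean exact structural equality, which is only one of the possible policies discussed below.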
cc @chengchingwen for his thoughts on this.
> We could contemplate having Flux provide an `@functor` which for now calls this one, to ease transitions.
JuliaHub tells me that a number of libraries rely on Functors but not Flux, so it would help a bit but not completely solve the migration issue.
@ToucheSir Do you have some small test cases in mind for those missing pieces? It's not obvious to me how they should work in general. For example, what would be considered similarly structured? Is `1//2` similar to `(a=1,b=2)`? What about `(a, b)` and `(b=3, c=4)` (they can both be viewed as `NTuple{2}`)? Or maybe we want an interface to define what is similar and how things should be aligned? I previously tackled a similar problem when loading weights from a `state_dict`, using the field names (i.e. `hasfield`/`getfield`/`keys`/...) and `functor`, and StructWalk.jl was originally built to rewrite that part of the code. OTOH, what is the difference between "walking over subtrees" (1./2.) and "traversing multiple trees" (3./4.)? Returning multiple trees (4.) is the same as returning a tuple of trees, which is a single tree (3.). And if we are traversing multiple trees that are not similar / cannot be aligned (3. but not 1./2.), what would the traversal algorithm be? Priority first search with custom priority?
@chengchingwen those are all good examples, and honestly I'm not sure either. One idea I had is that structural similarity would be determined by the LHS. So `(1, 2), (a=1,b=2), ...` uses the tuple path, while `(a=1,b=2), (1, 2), ...` uses the `NamedTuple` path. Arrays will be similar to tuples, while structs will act like `NamedTuple`s (I don't want to think about `Dict`s...).
> And if we are traversing multiple trees that are not similar / cannot be aligned (3. but not 1./2.), what would the traversal algorithm be? Priority first search with custom priority?
I'm not sure about that either. Do we error/warn about them, if at all? My thought was to take the intersection of all properties for keyed types and walk over only those. Some thinking out loud with pseudocode:
```julia
# Single return: f(xs...) = y. The output takes the structure of the first tree.
function walk(f, nt::NamedTuple, nts::NamedTuple...)
    common_props = intersect(propertynames(nt), map(propertynames, nts)...)
    patch = map(f, nt[common_props], map(t -> t[common_props], nts)...)
    return merge(nt, patch)
end

# Multiple returns: f(xs...) = ys..., length(ys) == length(xs).
function walk_multi(f, nt::NamedTuple, nts::NamedTuple...)
    common_props = intersect(propertynames(nt), map(propertynames, nts)...)
    # `map` yields a NamedTuple of tuples, so unzip it into one patch per input.
    results = map(f, nt[common_props], map(t -> t[common_props], nts)...)
    patches = ntuple(i -> map(r -> r[i], results), length(nts) + 1)
    return map(merge, (nt, nts...), patches)
end
```
Maybe that's too surprising?
Linearly-indexed types should probably just fail on a length mismatch? I see the 2-arg case as an instance of the multi-arg case, but perhaps that's the wrong way to look at it.
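For the linearly indexed case, failing loudly on a length mismatch could look like the following sketch (plain Julia; `walk_tuple` is a hypothetical name). The explicit check makes the error independent of whatever `map` does with tuples of unequal length.

```julia
# Walk several linearly indexed nodes (here: Tuples) in lockstep,
# refusing to proceed if their lengths differ.
function walk_tuple(f, t::Tuple, ts::Tuple...)
    for s in ts
        length(s) == length(t) || throw(DimensionMismatch(
            "expected length $(length(t)), got $(length(s))"))
    end
    return map(f, t, ts...)
end

walk_tuple(+, (1, 2), (10, 20))  # (11, 22)
```

Calling `walk_tuple(+, (1, 2), (1, 2, 3))` throws a `DimensionMismatch` instead of silently truncating or failing deep inside `map`.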
> Returning multiple trees (4.) is the same as returning a tuple of trees, which is a single tree (3.)
Yes, the question is how to create a tuple of trees given a tuple of trees and a user-specified varargs function. AIUI applying the function at each node will result in a tree of tuples instead, so the library would need to know how to unwrap those tuples either during or after traversal.
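One way to do the unwrapping after traversal is to use the first input tree as a structural guide, so that the tuples produced at the leaves are not mistaken for containers. The sketch below is plain Julia in the spirit of `jax.tree_util.tree_transpose`; `treemap` and `unzip_like` are hypothetical names, and only `Tuple`/`NamedTuple` are treated as containers.

```julia
# Single-tree map over Tuple/NamedTuple containers.
treemap(f, x) = x isa Union{Tuple,NamedTuple} ? map(c -> treemap(f, c), x) : f(x)

# Transpose a tree of n-tuples into an n-tuple of trees, using `guide`
# (a tree with the original structure) to tell containers from leaf results.
function unzip_like(guide, y, n::Int)
    if guide isa Union{Tuple,NamedTuple}
        # Each child becomes an n-tuple of subtrees...
        parts = map((g, c) -> unzip_like(g, c, n), guide, y)
        # ...which are regrouped into n containers shaped like `guide`.
        return ntuple(i -> map(p -> p[i], parts), n)
    else
        return y  # at a leaf of `guide`, `y` is already an n-tuple
    end
end

x = (a = 1, b = (2, 3))
tree_of_tuples = treemap(v -> (v, 10v), x)  # (a = (1, 10), b = ((2, 20), (3, 30)))
trees = unzip_like(x, tree_of_tuples, 2)    # ((a = 1, b = (2, 3)), (a = 10, b = (20, 30)))
```

Without `guide`, the leaf tuples `(2, 20)` and `(3, 30)` would be indistinguishable from the container `(2, 3)`, which is exactly why the library needs to know the original structure.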
@ToucheSir
> One idea I had is that structural similarity would be determined by the LHS
This might be problematic, since the similarity relation would then be asymmetric, and I'm not sure we always want to align the children just according to the first argument (or just the intersection of property names, as in the pseudocode). There are cases where we would want to align children with different names, or even at different depths, to skip over some wrapper.
All in all, I think we need to clarify what data we really need during the walk/transformation. In the "walking over subtrees" paradigm, we want the "zip"ped data from each tree, like running several BFS traversals in parallel. OTOH, in the "traverse multiple trees" paradigm, we might want different parts of the data. So IMO, regardless of whether it's a single- or multiple-input, single- or multiple-output tree function, the real question is: in what order do we want to access those subtrees?
> This might be problematic, since the similarity relation would then be asymmetric, and I'm not sure we always want to align the children just according to the first argument (or just the intersection of property names, as in the pseudocode). There are cases where we would want to align children with different names, or even at different depths, to skip over some wrapper.
That's true. Is there a way to customize this alignment behaviour? To be clear, I only think basing the output structure on the leftmost tree makes sense when a single tree is returned from the multi-tree walk, as it aligns somewhat with how `map` works on collections now. Having it be pluggable with a sane default (e.g. forcing exact structural similarity, as `liftA2` does) would be amazing.
> All in all, I think we need to clarify what data we really need during the walk/transformation. In the "walking over subtrees" paradigm, we want the "zip"ped data from each tree, like running several BFS traversals in parallel. OTOH, in the "traverse multiple trees" paradigm, we might want different parts of the data. So IMO, regardless of whether it's a single- or multiple-input, single- or multiple-output tree function, the real question is: in what order do we want to access those subtrees?
I think "walking over multiple subtrees" was bad wording on my part. A better term would've been "dissimilar trees", to mirror the first point about similarly structured trees. Presumably that narrows down the design space dramatically, as I can't think of a use-case in e.g. Optimisers.jl that wouldn't work with either `prewalk` or `postwalk`.
> as I can't think of a use-case in e.g. Optimisers.jl that wouldn't work with either `prewalk` or `postwalk`.
One thing it wants is the tree equivalent of `map(f, pairs(x))`, with `cache=false`, over only trainable nodes. As a concrete example, can this be written neatly with pre/post-walk? Or without, is it made easier / neater by this PR?
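To make the request concrete, here is a sketch of a tree analogue of `map(f, pairs(x))` in plain Julia, where `f` receives the path of keys leading to each leaf. `map_with_path` is a hypothetical name; it does no caching (so a shared leaf is visited once per path, matching `cache=false`), and it does not attempt the restriction to trainable nodes, which would need a `trainable`-style filter on top.

```julia
# Tree analogue of `map(f, pairs(x))`: `f(path, leaf)` receives the tuple of
# keys/indices leading to each leaf. No cache, so shared leaves are revisited.
function map_with_path(f, x, path::Tuple = ())
    if x isa NamedTuple
        ks = keys(x)
        return NamedTuple{ks}(map(k -> map_with_path(f, x[k], (path..., k)), ks))
    elseif x isa Tuple
        return ntuple(i -> map_with_path(f, x[i], (path..., i)), length(x))
    else
        return f(path, x)
    end
end

shared = [1.0]
model = (layer = (W = shared, b = [0.0]), tied = (shared,))
visits = Ref(0)
paths = map_with_path(model) do path, leaf
    leaf === shared && (visits[] += 1)
    path
end
# paths == (layer = (W = (:layer, :W), b = (:layer, :b)), tied = ((:tied, 1),))
```

Because nothing is cached, the tied array is visited twice (`visits[] == 2`), once under each path.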
Like Peter, I'm a little sceptical that we can design the ideal multi-tree walker without close reference to the problems it's meant to solve, and their weird edge cases. Optimisers.jl gets a lot of mileage out of `functor(typeof(x), y)`, and perhaps that pattern can be further abstracted somewhere here... but whether it can be done without making everything harder to understand I don't know. Knowing all the knobs on a configurable walking machine at some point becomes more complex than just writing the one you need in Julia code.
A side note about StructWalk.jl: it has an interface for defining a custom `WalkStyle`, so we can have different kinds of children for different types depending on the style. This is something I think is missing in the current implementation of Functors.jl, as everything is based on `functor` to define the tree-equivalent `map` (`trainable`/`gpu`/`f32`/...), but those functions might actually need different sets of child nodes. For example, not all array fields are `trainable`, but some should still be `gpu`-able. The `WalkStyle` interface provides a way to modify what is treated as a child node. Besides, `walk` is more general than `fmap`, as it applies the function not only to the leaf nodes but also to the non-leaf nodes. I'm not sure we really need this for the optimiser case, but it can simply be ignored if not needed.
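A rough illustration of the style idea in plain Julia: the set of children depends on a style object passed through the traversal, so a "trainable" walk and a "device" walk can see different fields of the same layer. This is not StructWalk.jl's actual `WalkStyle` API; `ChildStyle`, `AllFields`, `TrainableOnly`, `Norm`, `style_children`, and `collect_leaves` are all hypothetical names.

```julia
abstract type ChildStyle end
struct AllFields <: ChildStyle end      # e.g. for moving arrays to a device
struct TrainableOnly <: ChildStyle end  # e.g. for optimisers

# Hypothetical layer: `running_mean` is state, not a trainable parameter.
struct Norm
    scale::Vector{Float64}
    running_mean::Vector{Float64}
end

# Children depend on the style; anything without children is a leaf.
style_children(::ChildStyle, x) = ()
style_children(::ChildStyle, x::Union{Tuple,NamedTuple}) = x
style_children(::ChildStyle, x::Norm) = (scale = x.scale, running_mean = x.running_mean)
style_children(::TrainableOnly, x::Norm) = (scale = x.scale,)

function collect_leaves(style::ChildStyle, x, out = Any[])
    children = style_children(style, x)
    if isempty(children)
        push!(out, x)
    else
        foreach(c -> collect_leaves(style, c, out), children)
    end
    return out
end

n = Norm([1.0], [0.0])
model = (norm = n, bias = [2.0])
length(collect_leaves(AllFields(), model))      # 3: scale, running_mean, bias
length(collect_leaves(TrainableOnly(), model))  # 2: scale, bias
```

Dispatching on the style keeps one traversal function while letting each use-case decide what counts as a child.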
> One thing it wants is the tree equivalent of `map(f, pairs(x))`, with `cache=false`, over only trainable nodes.
`map(f, pairs(x))` over trainable nodes is doable with a custom walk function on master, I believe. The logic would be similar to `Optimisers._trainable`, but with a callback. One could also envision another walk that does replace non-trainable fields with `nothing`. In fact, I think you could MacGyver one with the current functionality in Optimisers:
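```julia
using Functors: functor
using Optimisers: trainable, _trainable  # `_trainable` is an internal Optimisers helper

# For reference:
# _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) ...
# Map `f` over the trainable children only, fill the rest with `nothing`,
# then rebuild the node with the `re` closure from `functor`.
function _trainable_fillnothing_walk(f, x)
    func, re = functor(x)
    tfunc = trainable(x)
    return re(_trainable(func, map(f, tfunc)))
end
```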
`cache=false` is part of this PR, but could be accomplished on master with a dummy dict type with dummy `setindex!` and `haskey` methods.
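A minimal sketch of such a dummy cache in plain Julia (`NoCache` and `count_visits` are hypothetical names). The toy traversal below only uses `haskey`/`getindex`/`setindex!` on its cache, which is why a type that always reports a miss turns caching off. If `fmap` on master only touches its `cache` keyword through those same methods, the same object could be passed as `fmap(f, x; cache = NoCache())`; that is an assumption worth checking against the source.

```julia
# A cache that never stores anything, so every lookup misses.
struct NoCache end
Base.haskey(::NoCache, _) = false
Base.getindex(::NoCache, k) = throw(KeyError(k))
Base.setindex!(c::NoCache, _v, _k) = c

# Toy memoised traversal over nested Tuples that counts node visits.
function count_visits(x, cache, counter)
    haskey(cache, x) && return cache[x]
    counter[] += 1
    y = x isa Tuple ? map(c -> count_visits(c, cache, counter), x) : x
    cache[x] = y
    return y
end

shared = [1.0]
tree = (shared, shared)

with_cache = Ref(0)
count_visits(tree, IdDict{Any,Any}(), with_cache)  # shared array visited once

without_cache = Ref(0)
count_visits(tree, NoCache(), without_cache)       # shared array visited twice
```

With the `IdDict`, the tuple and the shared array are each visited once (2 visits); with `NoCache`, the shared array is visited again (3 visits).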
> As a concrete example, can this be written neatly with pre/post-walk? Or without, is it made easier / neater by this PR?
`fmap` is already doing an implicit post-order traversal, so I'd say so. This PR makes that implicit part explicit by pulling it into a function. What `fold` buys you over plain `fmap` is that it visits non-leaf nodes too. I think most of Optimisers could get away without such a thing (via packing logic into walk functions, assuming we get multi-tree walk functions), but it has the potential to make things neater.
To Peter's latest point, all this is why I became interested in StructWalk.jl. `fold` is more or less a convergently evolved `StructWalk.postwalk`, but we have no equivalent of `prewalk` in Functors at present. IMO `WalkStyle` is far nicer than custom walk functions (which were a bit of a hack from the beginning; we basically exposed part of the internals). The only thing missing is the ability to walk multiple trees (regardless of semantics), but perhaps that discussion would be better had on the StructWalk issue tracker. Thoughts?