StructWalk.jl

Thoughts on zipped tree traversal

Open ToucheSir opened this issue 3 years ago • 8 comments

I wanted to pick up from our discussion in https://github.com/FluxML/Functors.jl/pull/32, in particular around your comments in https://github.com/FluxML/Functors.jl/pull/32#issuecomment-1026454917. I think one concrete example to work off of is a function like torch.nn.Module.load_state_dict [^1]. In other words, something that prefers but does not require similarly structured trees (it returns a list of discrepancies if they aren't).
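To make "prefers but does not require similarly structured trees" concrete, here is a minimal sketch over nested NamedTuples. `loadtree` is a hypothetical helper written for this illustration; it is not part of StructWalk.jl, Functors.jl, or PyTorch. It assumes that anything which is not a NamedTuple is a leaf.

```julia
# Hypothetical helper, not part of StructWalk.jl or Functors.jl.
# Walks `dst` and `src` together: matching leaves are taken from `src`,
# while structural mismatches are recorded and the `dst` value is kept.
function loadtree(dst, src, path = (), issues = Tuple[])
    if dst isa NamedTuple && src isa NamedTuple
        vals = map(keys(dst)) do k
            if haskey(src, k)
                first(loadtree(dst[k], src[k], (path..., k), issues))
            else
                push!(issues, (path..., k))   # key missing from `src`
                dst[k]
            end
        end
        for k in keys(src)                    # keys that `dst` does not expect
            haskey(dst, k) || push!(issues, (path..., k))
        end
        return NamedTuple{keys(dst)}(vals), issues
    elseif dst isa NamedTuple || src isa NamedTuple
        push!(issues, path)                   # leaf on one side, subtree on the other
        return dst, issues
    else
        return src, issues                    # both are leaves: take the new value
    end
end

model = (x = 0, y = (w = 0, b = 0))
saved = (x = 1, y = (w = 2,), z = 3)
loaded, issues = loadtree(model, saved)
```

Here `loaded` keeps the shape of `model`, and `issues` lists the key paths that could not be matched, which is roughly the "list of discrepancies" behaviour described above.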

I feel like this is absolutely something StructWalk could excel at, but I agree with your point that creating a one-size-fits-all mechanism would be difficult. My question would be whether there might be some extensible mechanism to handle this, just like how WalkStyle allows for highly customizable traversal behaviour over individual nodes.

[^1]: this came up in the discussion around restoring saved model weights, ref. @darsnack's post at https://github.com/FluxML/Flux.jl/issues/1027#issuecomment-1034226504.

ToucheSir, Feb 11 '22 02:02

Maybe something like:

abstract type AlignedStyle{W<:WalkStyle} end

const Align = Union{AlignedStyle, Type{AlignedStyle}}

WalkStyle(::AlignedStyle) = WalkStyle(AlignedStyle)
WalkStyle(::Type{AlignedStyle}) = WalkStyle

function alignedstyle(style::Align, xs...)
    fns = map(xs) do x
        S = walkstyle(WalkStyle(style), x)
        _, fields = S
        isnontuple = length(S) <= 2 ? false : S[3]
        nchild = isnontuple ? sum(length, fields) : length(fields)
        (isnontuple ? Iterators.flatten(fields) : fields), nchild
    end
    fs = map(x->x[1], fns)
    n = minimum(x->x[2], fns)
    return identity, zip(fs...), n
end

zippedwalk(f, style::Align, inner_walk, xs...) = zippedwalk(f, f, style, inner_walk, xs...)
function zippedwalk(f, g, style::Align, inner_walk, xs...)
    T, a, n = alignedstyle(style, xs...)
    isleaf = iszero(n)
    if isleaf
        return f(xs)
    else
        return g(T(map(inner_walk, a)))
    end
end

zippedpostwalk(f, xs...) = zippedpostwalk(f, AlignedStyle, xs...)
zippedpostwalk(f, style::Align, xs...) = zippedwalk(f, style, x -> zippedpostwalk(f, style, x...), xs...)


julia> a
(x = 0, y = (w = 0, b = 0))

julia> b
(0.5, (0.3, 0.5))

julia> c
(1, 1, 1)

julia> StructWalk.zippedpostwalk(identity, a, b)
2-element Vector{Any}:
 (0, 0.5)
 [(0, 0.3), (0, 0.5)]

julia> StructWalk.zippedpostwalk(identity, a, c)
2-element Vector{Tuple{Any, Int64}}:
 (0, 1)
 ((w = 0, b = 0), 1)

I haven't thought much about the design yet, but basically the extensible mechanism here comes from alignedstyle, which could dispatch on all kinds of argument combinations if needed and fall back on zipping the children extracted with a linked WalkStyle. Things like field names can be taken into account with a different WalkStyle.
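The dispatch idea can be sketched without StructWalk internals. Everything below (`AlignSketch`, `Positional`, `ByName`, `kids`, `aligned`) is a hypothetical name made up for this illustration, and it assumes only Tuples and NamedTuples have children. The generic method pairs children by position and stops at the shortest node, while a more specific method pairs NamedTuple children by shared field name.

```julia
# Hypothetical names throughout; this is not the StructWalk.jl API.
abstract type AlignSketch end
struct Positional <: AlignSketch end   # pair children by position, like `zip`
struct ByName <: AlignSketch end       # pair NamedTuple children by shared field name

# Children of a single node; anything that is not a (Named)Tuple is a leaf.
kids(x::Tuple) = x
kids(x::NamedTuple) = Tuple(x)
kids(x) = ()

# Fallback: pair the i-th children of every node, stopping at the shortest node.
function aligned(::AlignSketch, xs...)
    cs = map(kids, xs)
    n = minimum(length, cs)
    return [map(c -> c[i], cs) for i in 1:n]
end

# Specialised method: only fields present in every NamedTuple are paired.
function aligned(::ByName, x::NamedTuple, xs::NamedTuple...)
    nodes = (x, xs...)
    shared = filter(k -> all(n -> haskey(n, k), nodes), keys(x))
    return [map(n -> n[k], nodes) for k in shared]
end

a = (x = 0, y = (w = 0, b = 0))
d = (y = (w = 1, b = 2), x = 3)
by_position = aligned(Positional(), a, d)   # pairs x with d.y, and y with d.x
by_name = aligned(ByName(), a, d)           # pairs x with d.x, and y with d.y
```

Calling `aligned(ByName(), ...)` on arguments that are not all NamedTuples falls through to the positional method, which mirrors the "fall back on zipping the children" behaviour described above.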

> In other words, something that prefers but does not require similarly structured trees (it returns a list of discrepancies if they aren't).

Could you elaborate on this? Assuming we can load the "alignable" part, what do we want to do with the rest? torch.nn.Module.load_state_dict is actually quite narrow: it only takes two arguments, which have a relatively clear hierarchy (state_dict is a flat dictionary with string keys, and models are all subclasses of torch.nn.Module).

chengchingwen, Feb 11 '22 17:02

alignedstyle and co. look very interesting, thanks for writing this out! I played around and managed to get zippedpostwalk(+, a, b) working just by changing f(xs) to f(xs...) in zippedwalk. Do you foresee any issues with that change?

> Could you elaborate on this? Assuming we can load the "alignable" part, what do we want to do with the rest? torch.nn.Module.load_state_dict is actually quite narrow: it only takes two arguments, which have a relatively clear hierarchy (state_dict is a flat dictionary with string keys, and models are all subclasses of torch.nn.Module).

My mistake, I'd misremembered. A better example would be how JAX frameworks deserialize weights or how https://github.com/deepmind/optax runs over multiple trees much like Optimisers.jl. In both cases, I believe they take the easy path and implicitly enforce structural similarity through flattening.

ToucheSir, Feb 13 '22 02:02

> do you foresee any issues with that change?

Should be fine, but I have a couple of small concerns:

  • Returning a tuple is somewhat more zip-like.
  • In zippedwalk, I use f and g as the leaf function and the node function respectively, and g defaults to f. I use f(xs) instead of f(xs...) because I want to distinguish between being called on leaves and being called on nodes. But maybe I just shouldn't use f as the default for g; identity is probably enough.

A limitation of alignedstyle and zippedwalk is that they cannot handle unaligned children except by ignoring them.
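A standalone sketch of the variant discussed here, with splatted leaves and identity as the default node function, might look like the following. `zipwalk_sketch` and `children` are hypothetical names for this illustration and are not StructWalk.jl functions. The node function is passed as a keyword so that it cannot be confused with the first tree.

```julia
# Hypothetical standalone sketch; `zipwalk_sketch` is not a StructWalk.jl function.
children(x::Tuple) = x
children(x::NamedTuple) = Tuple(x)
children(x) = ()                      # everything else is a leaf

function zipwalk_sketch(leaf, xs...; node = identity)
    cs = map(children, xs)
    n = minimum(length, cs)           # children beyond the shortest node are ignored
    # With no alignable children, `xs` is treated as a leaf position;
    # the arguments are splatted, so a function such as `+` can be used directly.
    n == 0 && return leaf(xs...)
    walked = [zipwalk_sketch(leaf, map(c -> c[i], cs)...; node = node) for i in 1:n]
    return node(walked)               # `identity` keeps the vector of walked children
end

a = (x = 0, y = (w = 0, b = 0))
b = (0.5, (0.3, 0.5))
c = (1, 1, 1)

summed = zipwalk_sketch(+, a, b)                    # leaves are added pairwise
as_tuples = zipwalk_sketch(+, a, b; node = Tuple)   # nodes are rebuilt as tuples
paired = zipwalk_sketch(tuple, a, c)                # the third element of `c` is dropped
```

In the last call, `a.y` is paired with a plain `1`, so the walk stops there and the leaf function receives a whole subtree. That is the "ignore unaligned children" limitation in practice, and it is also why `+` would fail on `a` and `c` while `tuple` does not.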

BTW, will you be making the PR for this?

> A better example would be how JAX frameworks deserialize weights

I'm not familiar with JAX. Are you talking about pytree and jax.tree_multimap? It seems to only allow trees with exactly the same structure:

> For tree_multimap, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.

and they seem to generally work on flattened arrays with a tree structure encoding (treedef). flax.from_state_dict looks more like how torch.nn.Module.load_state_dict works. IIUC that means torch.nn.Module.load_state_dict probably handles the unaligned part more explicitly.
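For comparison, the flatten-then-map approach can be sketched in Julia as follows. This is only an illustration of the idea and not how JAX implements it: `flatten_tree`, `unflatten_tree`, and `tree_map_sketch` are hypothetical names, and the sketch assumes that only Tuples are containers while everything else is a leaf.

```julia
# Hypothetical sketch of flatten-then-map; names are illustrative only.
# `flatten_tree` returns the leaves plus a structure encoding (a "treedef").
flatten_tree(x) = (Any[x], :leaf)
function flatten_tree(x::Tuple)
    parts = map(flatten_tree, x)
    leaves = reduce(vcat, map(first, parts); init = Any[])
    return leaves, map(last, parts)
end

# Rebuild a tree from a leaf vector by consuming leaves from the front, in order.
unflatten_tree(def::Symbol, leaves) = popfirst!(leaves)
unflatten_tree(def::Tuple, leaves) = map(d -> unflatten_tree(d, leaves), def)

function tree_map_sketch(f, trees...)
    flat = map(flatten_tree, trees)
    defs = map(last, flat)
    # Structural similarity is enforced by comparing the encodings up front.
    all(==(first(defs)), defs) || throw(ArgumentError("tree structures differ"))
    leaves = map(f, map(first, flat)...)
    return unflatten_tree(first(defs), leaves)
end

a = (0, (0, 0))
b = (0.5, (0.3, 0.5))
mapped = tree_map_sketch(+, a, b)
```

Because the structure check happens before any leaf is touched, a mismatch such as `tree_map_sketch(+, a, (1, 1, 1))` raises an error instead of silently ignoring the extra child, which is the trade-off against the zipped walk.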

chengchingwen, Feb 13 '22 09:02

@ToucheSir any updates?

chengchingwen, Feb 21 '22 11:02

Sorry, nothing from my end yet. I'll need to take a solid chunk of time to sit down and figure out the details of the default align, and of an align that enforces exactly the same structure, for a PR. Until then, I think we're basically in agreement on your points above.

ToucheSir, Feb 21 '22 21:02

@ToucheSir I slightly changed the code and will merge it in #4

chengchingwen, Jul 20 '22 04:07

So this is implemented by the aligned walk and can be closed?

CarloLucibello, Apr 12 '23 03:04

Aligned walk is one possible implementation, but we are not fully convinced it's the best way to do it, so I'll leave this issue open.

chengchingwen, Apr 12 '23 23:04