StructWalk.jl
Thoughts on zipped tree traversal
I wanted to pick up from our discussion in https://github.com/FluxML/Functors.jl/pull/32, in particular around your comments in https://github.com/FluxML/Functors.jl/pull/32#issuecomment-1026454917. One concrete example to work from is a function like `torch.nn.Module.load_state_dict`[^1]: something that prefers, but does not require, similarly structured trees (with `strict=False`, it reports the missing and unexpected keys instead of raising an error when the structures differ).
I feel like this is absolutely something StructWalk could excel at, but I agree with your point that a one-size-fits-all mechanism would be difficult to design. My question is whether there could be some extensible mechanism for this, similar to how `WalkStyle` allows highly customizable traversal behaviour over individual nodes.
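To make the `load_state_dict`-style case concrete, here is a minimal sketch in plain Julia, without StructWalk. Everything in it (`zipload`, `zipwalk`, `LoadReport`, `childpath`) is hypothetical, and it assumes both trees are nested `NamedTuple`s matched by field name. It copies leaves from `src` into `dst` wherever the structures agree and records discrepancies, loosely mirroring PyTorch's missing/unexpected key split plus a third bucket for node-vs-leaf mismatches.

```julia
# Hypothetical sketch, NOT StructWalk.jl's API: a zipped walk over two
# nested NamedTuples that prefers matching structure but tolerates drift.

struct LoadReport
    missing_keys::Vector{String}     # present in `dst` but not in `src`
    unexpected_keys::Vector{String}  # present in `src` but not in `dst`
    mismatched::Vector{String}       # a node on one side, a leaf on the other
end
LoadReport() = LoadReport(String[], String[], String[])

# Dotted path for reporting, e.g. "layer1.weight".
childpath(path, k) = isempty(path) ? string(k) : string(path, ".", k)

# Entry point: returns the loaded tree plus a report of discrepancies.
function zipload(dst, src)
    report = LoadReport()
    return zipwalk(dst, src, "", report), report
end

# Node vs node: walk both trees in lockstep, keyed by field name.
function zipwalk(dst::NamedTuple, src::NamedTuple, path, report)
    for k in keys(src)
        haskey(dst, k) || push!(report.unexpected_keys, childpath(path, k))
    end
    vals = map(keys(dst)) do k
        p = childpath(path, k)
        if haskey(src, k)
            zipwalk(dst[k], src[k], p, report)
        else
            push!(report.missing_keys, p)
            dst[k]  # no counterpart in `src`: keep the existing value
        end
    end
    return NamedTuple{keys(dst)}(vals)
end

# Leaf vs leaf: take the value from `src`.
zipwalk(dst, src, path, report) = src

# Structural mismatch in either direction: keep `dst` and record the path.
zipwalk(dst::NamedTuple, src, path, report) = (push!(report.mismatched, path); dst)
zipwalk(dst, src::NamedTuple, path, report) = (push!(report.mismatched, path); dst)
```

For example, with `dst = (layer1 = (weight = 0.0, bias = 0.0), layer2 = (weight = 0.0,))` and `src = (layer1 = (weight = 1.0, bias = 2.0), layer3 = (weight = 3.0,))`, `zipload(dst, src)` fills in `layer1` from `src`, keeps `layer2` as-is, and reports `"layer2"` as missing and `"layer3"` as unexpected. Because the node/leaf decision is made by multiple dispatch, a package could in principle add `zipwalk` methods for its own types; that is roughly the kind of extension point I have in mind, though a `WalkStyle`-like hook may well be a cleaner fit.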
[^1]: This came up in the discussion about restoring saved model weights; see @darsnack's post at https://github.com/FluxML/Flux.jl/issues/1027#issuecomment-1034226504.