git-re-basin
PyTorch compatibility
Very interesting work! Are there plans to make weight matching applicable to PyTorch models?
No concrete plans yet, but it ought not to be all that hard! Might be a fun idea to live code/stream on Twitch at some point... I certainly won't have time for this until at least a month after the ICLR deadline. If someone can beat me to it, that would be cool as well!
I converted part of the code to PyTorch. I'll push it to a repo and share it here.
https://github.com/themrzmaster/git-re-basin-pytorch
Still working on some things, but it is there :)
How can this be applied to arbitrary PyTorch models?
@PythonNut has been working on a PyTorch tracer that will enable you to automatically get `PermutationSpec`s between most model architectures for free... but that's still a work in progress atm.
There's an attempt at merging Stable Diffusion models here (just linking it for newcomers).
Note: we haven't gotten it to work on GPU yet, and it currently OOMs on 32 GB of RAM.
> @PythonNut has been working on a PyTorch tracer that will enable you to automatically get `PermutationSpec`s between most model architectures for free... but that's still a work in progress atm.
Did this work out in the end?
This is really @pythonnut's thing, but IIUC I don't think it's in a state suitable for open-sourcing yet. That said, it shouldn't be too hard to hack together something that works with the symbolic execution tooling provided in https://pytorch.org/docs/stable/fx.html.
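For anyone who wants to experiment with that route before a proper tracer lands, here is a minimal sketch (my own illustration, not the tracer mentioned above) of how `torch.fx` symbolic tracing can walk a model's graph and list the linear/conv layers whose output-channel axes could each carry a permutation. The `list_permutable_modules` helper is hypothetical; a real spec also has to propagate each permutation onto the input axes of downstream layers.

```python
import torch.nn as nn
from torch.fx import symbolic_trace

def list_permutable_modules(model: nn.Module):
    """Trace `model` and collect the call_module nodes whose output-channel
    axis could, in the simplest case, carry its own permutation."""
    traced = symbolic_trace(model)          # torch.fx symbolic execution
    modules = dict(traced.named_modules())  # qualified name -> module
    permutable = []
    for node in traced.graph.nodes:
        if node.op == "call_module" and isinstance(
            modules[node.target], (nn.Linear, nn.Conv2d)
        ):
            permutable.append(node.target)
    return permutable

# e.g. list_permutable_modules(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)))
# -> ['0', '2']
```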
@markdjwilliams
> Did this work out in the end?
I have an independent implementation of the matching algorithm in PyTorch with some optimizations over our original version. This is definitely something I could release but the existence of themrzmaster/git-re-basin-pytorch makes this less urgent.
For the tracer, I have code that basically works, but right now the utility is somewhat limited. It should be able to reproduce all of the permutation specs used in the paper, but there are not that many new networks that work out of the box. Currently, I assume that any given tensor axis can be represented by a single permutation group, but there are several ways this assumption might not hold:
- Tensor axis which is a concatenation of multiple permutation groups (e.g. DenseNet)
- Tensor axis which has a hierarchical permutation structure (e.g. MultiHeadAttention)
- Tensor axis which is a "tensor product" of permutation groups (e.g. VGGs on ImageNet [CIFAR is fine])
- Probably more!
We've worked out what the matching algorithm would look like in all of these cases. The main missing piece is a structure for the permutation spec which can support all of these cases.
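To make the limitation concrete, here is roughly what the current flat, one-permutation-per-axis structure looks like when written by hand for a plain MLP. The `PermutationSpec` dataclass and `mlp_spec` helper below are illustrative only (loosely modeled on the specs used in the paper, not code from this repo); none of the DenseNet / attention / tensor-product cases above fit into this shape.

```python
from dataclasses import dataclass, field

@dataclass
class PermutationSpec:
    # (parameter name, axis index) -> permutation-group id,
    # or None if that axis is never permuted.
    axes_to_perm: dict = field(default_factory=dict)

def mlp_spec(num_hidden_layers: int = 2) -> PermutationSpec:
    """Spec for an MLP with parameters layer0..layerN, where .weight has
    shape (out, in). Hidden layer i's output axis carries permutation P_i,
    and the same P_i acts on the input axis of layer i+1, which is what
    couples consecutive layers during weight matching."""
    spec = PermutationSpec()
    n = num_hidden_layers
    for i in range(n + 1):
        out_perm = f"P_{i}" if i < n else None     # final outputs stay fixed
        in_perm = f"P_{i - 1}" if i > 0 else None  # network inputs stay fixed
        spec.axes_to_perm[(f"layer{i}.weight", 0)] = out_perm
        spec.axes_to_perm[(f"layer{i}.weight", 1)] = in_perm
        spec.axes_to_perm[(f"layer{i}.bias", 0)] = out_perm
    return spec
```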
If you have a particular use-case in mind, we'd love to hear about it!
Thank you! I'd say my use-case was quite speculative, so I was hoping for a quick way of testing it out without investing the time into writing correct specs for my model.
When implementing this in PyTorch, SciPy's linear_sum_assignment does not run on the GPU, which creates a speed bottleneck. For example, in STE, linear_sum_assignment sits inside the training loop, which causes frequent memory transfers between the CPU and GPU and sacrifices speed.
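To make the bottleneck concrete, the pattern below is roughly what the inner loop ends up doing when the cost matrix lives on the GPU but the assignment is solved with SciPy on the CPU (the names are illustrative):

```python
import torch
from scipy.optimize import linear_sum_assignment

def solve_assignment(cost: torch.Tensor) -> torch.Tensor:
    """Solve a linear assignment for a cost matrix that lives on the GPU.
    SciPy only accepts NumPy arrays, so every call forces a device sync
    and a GPU -> CPU -> GPU round trip."""
    row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(col_ind, device=cost.device)

# In an STE-style training loop this runs every iteration, e.g.:
#   perm = solve_assignment(similarity)  # CPU round trip here
#   w_b_permuted = w_b[perm]             # back on the GPU
```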
I have written a PyTorch implementation of linear_sum_assignment. It accepts a torch.Tensor on the GPU and returns results matching SciPy's. However, the torch version runs slower than SciPy for some reason, and I am stuck on why. If anyone has worked on a similar problem, please advise. https://gist.github.com/MasanoriYamada/72405515264749df02ba392f16810e12
It references the JAX implementation.