
PyTorch compatibility

Open superkirill opened this issue 2 years ago • 11 comments

Very interesting work! Are there plans to make weights matching applicable to PyTorch models?

superkirill • Sep 17 '22 10:09

No concrete plans yet, but it ought not to be all that hard! Might be a fun idea to live code/stream on Twitch at some point... I certainly won't have time for this until at least a month after the ICLR deadline. If someone can beat me to it, that would be cool as well!

samuela • Sep 17 '22 18:09

I converted part of the code to PyTorch. I will push it to a repo and share it here.

themrzmaster • Sep 19 '22 22:09

https://github.com/themrzmaster/git-re-basin-pytorch

Still working on some things, but it is there :)

themrzmaster • Sep 19 '22 23:09

How can this be applied to arbitrary PyTorch models?

affableroots • Oct 28 '22 20:10

@PythonNut has been working on a PyTorch tracer that will enable you to automatically get PermutationSpecs between most model architectures for free... but that's still a work in progress atm

samuela • Oct 28 '22 20:10

There's an attempt at merging Stable Diffusion models here (just linking it for newcomers).

Note: we haven't gotten it to work on the GPU yet, and it currently runs out of memory (OOM) with 32 GB of RAM.

affableroots • Nov 17 '22 18:11

> @PythonNut has been working on a PyTorch tracer that will enable you to automatically get PermutationSpecs between most model architectures for free... but that's still a work in progress atm

Did this work out in the end?

markdjwilliams • Feb 15 '23 20:02

This is really @PythonNut's thing, but IIUC it isn't in a state suitable for open-sourcing yet. That said, it shouldn't be too hard to hack together something that works with the symbolic tracing provided by torch.fx: https://pytorch.org/docs/stable/fx.html.
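As a rough illustration of the torch.fx route, here is a minimal sketch. It is not @PythonNut's tracer and only handles the easiest case: a straight chain of `nn.Linear` layers with elementwise ops (such as ReLU) between them. `chain_permutation_spec` and the `P_i` names are hypothetical. The sketch uses `torch.fx.symbolic_trace` to recover the order in which the `nn.Linear` modules are called, then ties each layer's output axis to the next layer's input axis.

```python
import torch.fx as fx
import torch.nn as nn

def chain_permutation_spec(model: nn.Module) -> dict:
    """Hypothetical helper: map each hidden permutation to the (param name, axis) pairs it permutes.

    Assumes a straight chain of nn.Linear layers with only elementwise ops in between;
    branches, concatenations, and attention are NOT handled.
    """
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    # nn.Linear modules in the order the traced graph calls them.
    linears = [
        node.target
        for node in gm.graph.nodes
        if node.op == "call_module" and isinstance(modules[node.target], nn.Linear)
    ]
    spec = {}
    for i, (prev, nxt) in enumerate(zip(linears, linears[1:])):
        # Output units of `prev` (weight rows, bias entries) and input units of `nxt`
        # (weight columns) must be permuted together to preserve the function.
        axes = [(f"{prev}.weight", 0)]
        if modules[prev].bias is not None:
            axes.append((f"{prev}.bias", 0))
        axes.append((f"{nxt}.weight", 1))
        spec[f"P_{i}"] = axes
    return spec

mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 3))
spec = chain_permutation_spec(mlp)
# {'P_0': [('0.weight', 0), ('0.bias', 0), ('2.weight', 1)],
#  'P_1': [('2.weight', 0), ('2.bias', 0), ('4.weight', 1)]}
```

Anything beyond a plain chain, such as residual adds or concatenations, would need real dataflow analysis over `gm.graph`. That gap is exactly what the tracer discussion in this thread is about.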

samuela • Feb 15 '23 21:02

@markdjwilliams

> Did this work out in the end?

I have an independent implementation of the matching algorithm in PyTorch with some optimizations over our original version. This is definitely something I could release but the existence of themrzmaster/git-re-basin-pytorch makes this less urgent.
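For the simplest case, a one-hidden-layer MLP, weight matching reduces to a single linear assignment problem. Each hidden unit of model A is matched to the unit of model B whose incoming weights, bias, and outgoing weights have the largest inner product with its own. The following is a minimal NumPy/SciPy sketch, not @PythonNut's or themrzmaster's implementation. It assumes a dict-of-arrays parameter layout; the keys `w1`, `b1`, `w2`, `b2` and the helper names are hypothetical. Deeper networks need one such problem per hidden layer, solved by the alternating, layer-by-layer procedure from the paper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def permute_hidden(params, perm):
    """Reorder hidden units: new unit i is old unit perm[i]."""
    return {
        "w1": params["w1"][perm],     # rows of w1 are hidden units
        "b1": params["b1"][perm],
        "w2": params["w2"][:, perm],  # columns of w2 are hidden units
        "b2": params["b2"],
    }

def match_hidden_units(params_a, params_b):
    """Return perm such that permute_hidden(params_b, perm) aligns B's hidden units with A's."""
    # similarity[i, j]: inner product of A's unit i and B's unit j over all weights touching them
    similarity = (
        params_a["w1"] @ params_b["w1"].T
        + np.outer(params_a["b1"], params_b["b1"])
        + params_a["w2"].T @ params_b["w2"]
    )
    _, perm = linear_sum_assignment(similarity, maximize=True)
    return perm

def forward(params, x):
    return params["w2"] @ np.maximum(params["w1"] @ x + params["b1"], 0.0) + params["b2"]

rng = np.random.default_rng(0)
a = {"w1": rng.normal(size=(16, 4)), "b1": rng.normal(size=16),
     "w2": rng.normal(size=(3, 16)), "b2": rng.normal(size=3)}
q = rng.permutation(16)
b = permute_hidden(a, q)  # same function, shuffled hidden units
x = rng.normal(size=4)
assert np.allclose(forward(a, x), forward(b, x))

perm = match_hidden_units(a, b)
aligned = permute_hidden(b, perm)  # B re-expressed in A's unit order
```

Here B is an exact permutation of A, so matching recovers the inverse shuffle. For two independently trained models, the same code returns the best alignment under this similarity rather than an exact match.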

For the tracer, I have code that basically works, but right now the utility is somewhat limited. It should be able to reproduce all of the permutation specs used in the paper, but there are not that many new networks that work out of the box. Currently, I assume that any given tensor axis can be represented by a single permutation group, but there are several ways this assumption might not hold:

  1. A tensor axis that is a concatenation of multiple permutation groups (e.g. DenseNet)
  2. A tensor axis that has a hierarchical permutation structure (e.g. MultiHeadAttention)
  3. A tensor axis that is a "tensor product" of permutation groups (e.g. VGGs on ImageNet [CIFAR is fine])
  4. Probably more!
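Case 1 in the list above is the easiest to make concrete. Suppose an axis concatenates two groups that are permuted independently by `p` and `q`. Then the whole axis is permuted by `p` followed by `q` shifted by `len(p)`. This is a minimal NumPy sketch; `concat_axis_perm` is a hypothetical helper and not part of any existing PermutationSpec format.

```python
import numpy as np

def concat_axis_perm(*perms):
    """Hypothetical helper: one permutation for an axis built by concatenating groups.

    Each group's permutation is shifted by the total size of the groups before it.
    """
    offsets = np.cumsum([0] + [len(p) for p in perms[:-1]])
    return np.concatenate([np.asarray(p) + off for p, off in zip(perms, offsets)])

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=3), rng.normal(size=4)  # features from two earlier layers
p, q = rng.permutation(3), rng.permutation(4)     # their independent permutation groups
combined = concat_axis_perm(p, q)

# Permuting each group and then concatenating equals permuting the concatenated axis.
assert np.array_equal(np.concatenate([x1[p], x2[q]]), np.concatenate([x1, x2])[combined])

# A downstream weight W reading the concatenated axis absorbs the permutation in its columns.
W = rng.normal(size=(2, 7))
assert np.allclose(W @ np.concatenate([x1, x2]), W[:, combined] @ np.concatenate([x1[p], x2[q]]))
```

Applying a combined permutation is the easy part. Matching is harder because the two groups are shared with other layers, so a spec has to record that this axis is built from both groups rather than owning a single one.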

We've worked out what the matching algorithm would look like in all of these cases. The main missing piece is a structure for the permutation spec which can support all of these cases.

If you have a particular use-case in mind, we'd love to hear about it!

PythonNut • Feb 21 '23 03:02

Thank you! I'd say my use-case was quite speculative, so I was hoping for a quick way of testing it out without investing the time into writing correct specs for my model.

markdjwilliams • Feb 24 '23 17:02

When the matching is implemented in PyTorch, SciPy's linear_sum_assignment does not run on the GPU, which creates a speed bottleneck. For example, with STE (the straight-through estimator), linear_sum_assignment is called inside the training loop, so every iteration moves data between the CPU and the GPU, sacrificing speed.
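A common way to keep this cost contained is to build the similarity matrix on the GPU and cross the device boundary exactly twice per solve. The n×n matrix goes to the host, and only the length-n index vector comes back. The sketch below assumes PyTorch and SciPy are installed; `lap_torch` is a hypothetical wrapper. Each call still synchronizes with the GPU, so this does not remove the bottleneck inside a training loop. It only keeps the traffic minimal.

```python
import torch
from scipy.optimize import linear_sum_assignment

def lap_torch(similarity: torch.Tensor, maximize: bool = True) -> torch.Tensor:
    """Hypothetical wrapper: solve a square assignment problem for a (possibly GPU) tensor.

    Returns col such that row i is assigned to column col[i], on the input's device.
    """
    # Device-to-host copy of the n x n matrix (forces a GPU synchronization).
    cost = similarity.detach().cpu().numpy()
    _, col = linear_sum_assignment(cost, maximize=maximize)
    # Host-to-device copy of just the n assignment indices.
    return torch.as_tensor(col, device=similarity.device)

device = "cuda" if torch.cuda.is_available() else "cpu"
sim = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]], device=device)
perm = lap_torch(sim)  # tensor([0, 2, 1]) on `device`
```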

I have written a PyTorch implementation of linear_sum_assignment. It accepts a torch.Tensor on the GPU and returns results similar to SciPy's. However, it runs slower than SciPy for some reason. If anyone has tackled a similar problem, please advise. https://gist.github.com/MasanoriYamada/72405515264749df02ba392f16810e12

Referenced JAX implementation
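One GPU-friendly alternative is entropic relaxation with Sinkhorn iterations. Note that this is a different technique, not an exact replacement for `linear_sum_assignment`. Every Sinkhorn step is a dense `logsumexp` over the whole matrix, which maps well onto GPUs. The augmenting-path algorithm behind SciPy's solver instead makes many small sequential updates, which is one plausible (unconfirmed) reason a direct port runs slower. Below is a minimal sketch; `sinkhorn_assignment`, `tau`, and `n_iters` are hypothetical names and defaults. The row-wise argmax is not guaranteed to be optimal or even a valid permutation, so check it and fall back to an exact solver when needed.

```python
import torch

def sinkhorn_assignment(similarity: torch.Tensor, tau: float = 0.05, n_iters: int = 200) -> torch.Tensor:
    """Approximate max-weight assignment via log-domain Sinkhorn normalization.

    Returns col with row i -> column col[i]. Approximate: duplicate columns are possible,
    so verify the result is a permutation before using it.
    """
    log_p = similarity / tau  # smaller tau -> closer to a hard assignment
    for _ in range(n_iters):
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)  # normalize rows
        log_p = log_p - torch.logsumexp(log_p, dim=0, keepdim=True)  # normalize columns
    return log_p.argmax(dim=1)

torch.manual_seed(0)
n = 8
target = torch.randperm(n)
sim = 0.1 * torch.rand(n, n)
sim[torch.arange(n), target] += 1.0  # a clearly dominant assignment
col = sinkhorn_assignment(sim)
is_perm = torch.equal(torch.sort(col).values, torch.arange(n))
```

`tau` only makes sense relative to the scale of the similarity matrix. Real weight-matching similarities usually need normalizing first, and near-ties between candidate assignments are where this relaxation is most likely to disagree with an exact solver.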

MasanoriYamada • Aug 02 '23 03:08