Alternate implementation of LoRA leveraging tensor subclasses and reparametrization.
I thought this might be interesting as an alternate implementation of LoRA leveraging tensor subclasses and reparametrization.
https://gist.github.com/Chillee/a8d2070b1b7b3f97d8c87bac3c366f8e
The main idea here is that we can leverage parametrization to transform our parameter in a way that composes with existing modules (i.e., we don't need a totally new layer).
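Concretely, that first piece might look roughly like this (a simplified sketch, not the code from the gist; the class name, rank, and init scale are all illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoRAParametrization(nn.Module):
    """Owns the low-rank factors A and B; the base weight stays frozen."""

    def __init__(self, out_features, in_features, rank=4):
        super().__init__()
        # Standard LoRA init: A random, B zero, so the initial delta is zero.
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, W):
        # Called whenever `layer.weight` is accessed; W is the original parameter.
        # (Materializing the dense sum is the naive version; see the subclass below.)
        return W + self.B @ self.A

layer = nn.Linear(128, 128)
layer.weight.requires_grad_(False)  # only A and B get gradients
parametrize.register_parametrization(layer, "weight", LoRAParametrization(128, 128))

x = torch.randn(8, 128)
out = layer(x)  # computed against W + B @ A, with no custom layer
opt = torch.optim.Adam([p for p in layer.parameters() if p.requires_grad])
```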
Then, since LoRA also needs to exploit the special low-rank structure for efficiency, we return a tensor subclass that has special handling when we encounter F.linear(x: Tensor, weight: LoraTensor, bias: Tensor). This tensor subclass composes with autograd, so we can still differentiate through our tensor.
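The second piece might then look something like the following (again a rough illustration rather than the gist's actual code; `LoraTensor.from_parts` and the attribute names are made up for this sketch):

```python
import torch
import torch.nn.functional as F

class LoraTensor(torch.Tensor):
    """A weight tensor that also carries low-rank factors A and B."""

    @classmethod
    def from_parts(cls, W, A, B):
        # Shares storage, shape, and dtype with the base weight W, so it still
        # looks like an ordinary weight tensor to existing modules.
        t = W.as_subclass(cls)
        t.base, t.A, t.B = W, A, B
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear and len(args) >= 2 and isinstance(args[1], LoraTensor):
            x, weight = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # Dense path on the base weight plus the low-rank correction
            # (x @ A^T) @ B^T -- the dense delta B @ A is never materialized.
            out = F.linear(x, weight.base, bias)
            return out + (x @ weight.A.t()) @ weight.B.t()
        # Everything else (printing, shape queries, etc.) falls back to the
        # default Tensor behaviour.
        return super().__torch_function__(func, types, args, kwargs)
```

With something like this, the parametrization's `forward` would return `LoraTensor.from_parts(W, self.A, self.B)` instead of the dense sum, and `nn.Linear` would hit the cheap low-rank path without any change to the module itself. Gradients still flow to A and B because the correction term is built from ordinary autograd ops on them.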
Super cool! Thanks for sharing, Horace!
Do you think it might make sense to make something like this available in torch?
@edwardjhu I think it's an interesting thought. LoRA has certainly become quite popular. OTOH, I'm sure that there will be many more variants and extensions of LoRA proposed in the coming months/years.
One of the hard parts about PyTorch is figuring out how to balance 1. following the leading edge closely enough to be useful to folks, and 2. not following the leading edge so closely that we end up committing to maintaining things that nobody uses anymore.
So, if we wanted to build something like this in torch, we'd need to think about how to build something composable and extensible, or about what features we could add that are likely to stay useful as this technology evolves. For instance, this gist came out of me making sure that we had the extensibility features to build LoRA in a nice way where you didn't need custom layers :P If you have any other thoughts I'd be happy to hear them!
For example, if it's inconvenient to save/load LoRA weights, perhaps there's something we could be doing there?
@Chillee That makes sense! Let's revisit after we see if LoRA stands the test of time. (Same for muP.)
Two questions:
- Do you see a straightforward way to merge A and B with W in this implementation? (See the sketch after this list for the kind of merge I have in mind.)
- I assume that save/load works pretty seamlessly with tensor subclasses. If a user tries to load with torch.Tensor, they will get an error. Is that right?
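To make the first question concrete, here's the kind of thing I mean by merging (illustrative only, reusing the hypothetical names from the sketches above and assuming the effective weight is W + B @ A):

```python
import torch

with torch.no_grad():
    lora = layer.parametrizations.weight[0]      # the LoRA parametrization module
    W = layer.parametrizations.weight.original   # the frozen base weight
    merged = W + lora.B @ lora.A                 # one dense weight for inference

# For the dense-sum variant, I'd guess
# torch.nn.utils.parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
# bakes the value in directly; it's less obvious to me what the right story is
# when the parametrization returns a tensor subclass.
```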
Btw, if you are okay with it, I'd love to post this on Twitter and tag you since I think many might find this useful!
(Sorry about the delay.)