ReLoRA - what we've learned
I'll try to summarize the learnings from this project below; hopefully I don't forget anything important. Here is the WandB project.
- We were unable to reproduce the results in the original paper, even while referencing their GitHub repo. The plan was to reproduce results at the tiny (~300M) scale and then try to scale up to 1B or 7B. We never got to the second stage, as we were unable to reproduce results at the same scale as the paper authors. There are a number of possible reasons why:
- The technique is very sensitive to the specific [data, model architecture, hyperparameters]. While we did try to mirror the original paper's model and hyperparameters as closely as possible, our architecture was slightly different in a few ways, e.g. we used ALiBi and they used RoPE.
- There is a very subtle bug somewhere in my ReLoRA code. Even though I took lots of code from the ReLoRA repo, I haven't ruled this out, since even with a high rank (like r=d_model), our models had flatter learning curves than the standard OLMo tiny model. If there is a bug, it is likely an FSDP quirk, since the authors' repo did not shard their models.
- some third thing
- Our base model's loss curve is better than the authors'. Good job team, if this is the case. It could be that ReLoRA is strong enough to track the loss curve of their weaker model but not strong enough to do the same for our OLMo model.
- I'm not convinced this sort of merging process is even viable for pre-training. I was talking to Ananya, and he said he would be surprised if you can successfully do pre-training in this train-then-merge paradigm, regardless of the rank. Indeed, we did see high-rank training fail in similar ways to the low-rank runs. The only way we were able to track the base model's curve was by replicating the dimensions of W in the lora_A matrix and then making lora_B essentially a square identity matrix (see the sketch after this list).
- For torch < 2.1.0, FSDP doesn't support mixing `requires_grad` values among sharded parameters. This means that when freezing only a subset of params, you need to decide whether to shard the frozen or the trainable params. They officially added support for this type of mixing in torch 2.1.0. Except... not really. The only real difference is that it won't throw an error; it still won't do full sharding unless you wrap each individual module with its own FSDP. That is, every call to the `FSDP` constructor must wrap a module with uniform `requires_grad` to achieve full sharding, and the submodules you wrap separately must be in `ignored_modules` for the top- or higher-level wraps (sketched in code after this list). Bummer.
- We found that torch 2.1 FSDP calls `reset_parameters()` only on modules that directly manage parameters. Importantly, this means it won't call it on the top-level OLMo module. This is an easy fix, but it's good to know these sorts of changes exist and are not well documented, so if/when we upgrade, be on the lookout for this kind of bug (a toy illustration follows this list).
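To make the train-then-merge paradigm concrete, here is a minimal sketch of a ReLoRA-style linear layer. The names (`ReLoRALinear`, `merge_and_reinit`) and the details are my own simplification, not the authors' code, and it omits things like optimizer-state resets and learning-rate restarts:

```python
# Minimal sketch of the ReLoRA-style train-then-merge idea discussed above.
import math
import torch
import torch.nn as nn


class ReLoRALinear(nn.Module):
    """Frozen base weight W plus a trainable low-rank residual B @ A."""

    def __init__(self, d_in: int, d_out: int, rank: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_out, d_in), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # Trainable low-rank factors; B starts at zero so the residual starts at 0.
        self.lora_A = nn.Parameter(torch.randn(rank, d_in) / math.sqrt(d_in))
        self.lora_B = nn.Parameter(torch.zeros(d_out, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.weight + self.lora_B @ self.lora_A).T

    @torch.no_grad()
    def merge_and_reinit(self):
        """End of a ReLoRA cycle: fold the residual into W, then restart the factors."""
        self.weight += self.lora_B @ self.lora_A
        nn.init.normal_(self.lora_A, std=1.0 / math.sqrt(self.lora_A.shape[1]))
        self.lora_B.zero_()
```

The degenerate configuration mentioned above (lora_A shaped like W and lora_B pinned to a square identity) collapses the residual into a single full-rank matrix, i.e. full-rank training in disguise, which is presumably why only that setup tracked the base model's curve.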
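Here is a sketch of the wrapping pattern from the FSDP bullet above, assuming a model where the frozen base weights and the trainable adapter weights live in separate submodules; `model.blocks`, `block.base`, and `block.adapters` are placeholder names, not OLMo attributes:

```python
# Sketch of per-module FSDP wrapping so every wrap has uniform requires_grad.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

wrapped_children = []
for block in model.blocks:
    # Each FSDP call wraps a module whose parameters all share one
    # requires_grad value, so each unit can be fully sharded on its own.
    block.base = FSDP(block.base)          # frozen params only
    block.adapters = FSDP(block.adapters)  # trainable params only
    wrapped_children.extend([block.base, block.adapters])

# The top-level wrap has to ignore the submodules we wrapped ourselves;
# otherwise the remaining flat params mix requires_grad and, per the note
# above, torch 2.1 no longer errors but also doesn't fully shard them.
model = FSDP(model, ignored_modules=wrapped_children)
```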
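Finally, a toy illustration of the `reset_parameters()` behavior from the last bullet. `TinyModel` is a hypothetical stand-in, assuming the custom init lives in a top-level `reset_parameters()` (roughly the OLMo situation):

```python
# Toy stand-in for the torch 2.1 FSDP init gotcha described above.
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical top-level model; only `self.proj` directly owns parameters."""

    def __init__(self, d_model: int = 16):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def reset_parameters(self):
        # Custom init that lives only on the top-level module. Torch 2.1 FSDP's
        # deferred init calls reset_parameters() only on modules that directly
        # manage parameters (here, `self.proj`, which falls back to nn.Linear's
        # default init), so this method is never invoked and the custom init
        # silently doesn't happen.
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)


# Easy fixes: call model.reset_parameters() explicitly before wrapping, or move
# the custom init into reset_parameters() overrides on the parameter-owning
# submodules themselves.
```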
- I don't get what you mean in the comment; there are at least 3 papers pertaining to branch-merge methods. What are you saying about them? (None are mine, just following and interested.)
Marking the items prior to Feb 29th as "closed".