Clarification about weight tying mechanism in OLMo 1B: shared modules vs. shared weights
❓ The questions
- whether the current method of tying weights by setting module instances to be the same is the intended behavior and if there are specific benefits to this approach.
- whether it makes sense to consider an alternative implementation that handles weight tying at the weight level and not at the module level, if only for the sake of improving compatibility with the accelerate library (and any other tools that may expect weight tying to follow this pattern).
Context
I kept seeing the warning "The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function" when loading the OLMo 1B model with device_map=auto, and spent some time investigating where it came from:
In the tie_weights method of the OLMo model, the weights are tied by directly setting self.model.transformer.ff_out to self.model.transformer.wte. This ties the two, but it also makes them the same module instance, which, if I understand correctly, isn't the typical expectation for weight tying (distinct modules that share a tensor, not a single shared module instance). This appears to be why find_tied_parameters from Hugging Face Accelerate comes back empty: it expects to find distinct parameters that share the same tensor, not two names for the same module, so under the current implementation of tie_weights it finds no tied weights, and the warning is triggered.
For example, the current implementation:
def tie_weights(self):
    if self.config.weight_tying:
        self.model.transformer.ff_out = self.model.transformer.wte
Leads to
from accelerate.utils.modeling import find_tied_parameters
find_tied_parameters(model)
[]
and results in the warning "The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function." when loading the model with Accelerate with device_map=auto.
Whereas separately initializing ff_out and sharing only the weights:
import hf_olmo
import torch.nn as nn

def tie_weights(self):
    if self.config.weight_tying:
        if isinstance(self.model.transformer.wte, nn.Embedding):
            num_embeddings, embedding_dim = self.model.transformer.wte.weight.shape
            # Give ff_out its own module, then share only the weight tensor with wte.
            self.model.transformer.ff_out = nn.Embedding(num_embeddings, embedding_dim)
            self.model.transformer.ff_out.weight = self.model.transformer.wte.weight

# Monkey-patch the Hugging Face wrapper class with the method above.
hf_olmo.OLMoForCausalLM.tie_weights = tie_weights
Leads to
find_tied_parameters(model)
[['model.transformer.ff_out.weight', 'model.transformer.wte.weight']]
and removes the warning.
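The difference between the two approaches can be reproduced on a toy model without OLMo or Accelerate. The sketch below is only an illustration: ModuleTied, WeightTied, and child_names are hypothetical names, and the idea that a traversal based on named_children() is what makes the shared module invisible is an assumption, not something checked against the find_tied_parameters source. What it does show is plain PyTorch behavior: a module registered under two names is yielded once by named_children(), while two modules sharing one weight are both yielded.

```python
import torch.nn as nn


class ModuleTied(nn.Module):
    """Toy stand-in: ff_out is the very same module object as wte."""

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(10, 4)
        self.ff_out = self.wte  # second name for the same module instance


class WeightTied(nn.Module):
    """Toy stand-in: two distinct modules that share one weight Parameter."""

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(10, 4)
        self.ff_out = nn.Embedding(10, 4)
        self.ff_out.weight = self.wte.weight  # share only the tensor


def child_names(model):
    # named_children() skips a module object it has already yielded,
    # so an aliased module shows up under its first name only.
    return [name for name, _ in model.named_children()]


module_tied = ModuleTied()
weight_tied = WeightTied()

print(child_names(module_tied))   # only the first alias is visited
print(child_names(weight_tied))   # both modules are visited
print(module_tied.ff_out is module_tied.wte)
print(weight_tied.ff_out.weight is weight_tied.wte.weight)
```

In both toy models the weight is stored once, so the number of trainable parameters is the same; only what a module traversal can see differs.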
My questions are:
- whether the current method of tying weights by setting module instances to be the same is the intended behavior and if there are specific benefits to this approach.
- whether it makes sense to consider an alternative implementation that handles weight tying at the weight level and not at the module level, if only for the sake of improving compatibility with the accelerate library (and any other tools that may expect weight tying to follow this pattern).
More than happy to help contribute if you think it makes sense to make any changes based on this! See the updated tie_weights above for a starting point; this did not appear to change any behavior aside from removing the warning.
Thanks for everything! I'm enjoying working with OLMo and appreciate all of your efforts on this fantastic project.
Hi djliden, thank you for raising this concern. When weight tying is turned on, there is no self.model.transformer.ff_out module in the model. Thus, I believe our current implementation is wrong and that the fix is to either remove our tie_weights implementation altogether (if that causes no issues) or make our tie_weights a no-op. @AkshitaB please correct me if I'm wrong.
Regarding why we do weight tying at the module level (using wte for both input and output) instead of weight level, @epwalsh may know more. The one issue I could imagine is that it might not go well with FSDP (Fully Sharded Data Parallel). FSDP splits up the weights in Modules of a model across multiple GPUs, in order to reduce GPU memory usage (and thus enable training at higher scales). If weights were tied at the weight level, I wouldn't be surprised if FSDP would not notice the tying and would have 2 copies of the weights sharded across the GPUs.
I am happy to review a PR fixing this issue.
Thanks for the response on this issue @2015aroras ! I'm not sure I completely follow yet—specifically, what do you mean by "there is no self.model.transformer.ff_out module in the model"?
At least when loading via from_pretrained, and verifying in the config that we have "weight_tying": true, model.model shows:
Olmo(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 2048)
    (emb_drop): Dropout(p=0.0, inplace=False)
    (ln_f): LayerNorm()
    (blocks): ModuleList(
      (0-15): 16 x OlmoSequentialBlock(
        (dropout): Dropout(p=0.0, inplace=False)
        (act): SwiGLU()
        (attn_out): Linear(in_features=2048, out_features=2048, bias=False)
        (ff_out): Linear(in_features=8192, out_features=2048, bias=False)
        (rotary_emb): RotaryEmbedding()
        (attn_norm): LayerNorm()
        (ff_norm): LayerNorm()
        (att_proj): Linear(in_features=2048, out_features=6144, bias=False)
        (ff_proj): Linear(in_features=2048, out_features=16384, bias=False)
      )
    )
    (ff_out): Embedding(50304, 2048)
  )
)
and model.model.transformer.ff_out is model.model.transformer.wte evaluates to True.
When loading with from_pretrained, it looks like tie_weights is called here, which would invoke the method defined here, which is what copies the wte module.
Though from this it looks like the ff_out layer isn't used regardless? And does this mean if one were loading the model via a means other than transformers from_pretrained the ff_out would, indeed, not exist?
In other words, it looks like loading with transformers from_pretrained results in the addition of the ff_out layer, which would not be present otherwise, but it's still not consequential because it's (a) the same module as wte and (b) not used anyway because of how logits are computed when weight tying is enabled. Am I understanding that right? In which case, to your point, the fix should basically involve making sure tie_weights doesn't do anything.
Thanks again!
Yes to all those questions, from my knowledge. Hence my suggestion to remove tie_weights or make it a no-op.
Got it, thanks for confirming my understanding. I'll look into the effects of removing tie_weights or making it a no-op.
Thank you!
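The outcome agreed on in this thread (no separate ff_out when weights are tied, and a tie_weights that does nothing) can be sketched on a toy model. This is a minimal illustration under stated assumptions, not the merged fix: TinyTiedLM and its proj layer are hypothetical, and computing logits with F.linear against the embedding weight is an assumed way of reusing wte for the output projection, consistent with the observation above that ff_out is not used when weight tying is enabled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTiedLM(nn.Module):
    """Hypothetical toy model: wte serves as both input embedding and output head."""

    def __init__(self, vocab_size=32, d_model=8):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)
        # Stand-in for the transformer blocks; not part of the tying logic.
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def tie_weights(self):
        # No-op: there is no separate output module to tie, because the
        # forward pass already reads the embedding weight directly.
        return None

    def forward(self, input_ids):
        hidden = self.proj(self.wte(input_ids))
        # wte.weight has shape (vocab_size, d_model), which is the layout
        # F.linear expects for a projection from d_model to vocab_size.
        return F.linear(hidden, self.wte.weight)


model = TinyTiedLM()
model.tie_weights()
logits = model(torch.tensor([[1, 2, 3]]))
print(logits.shape)
print([name for name, _ in model.named_children()])
```

With this layout there is nothing for a tied-parameter search to find, since only one module holds the shared weight; whether that also silences the Accelerate warning is not verified here.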
Closing this since a fix has been merged. Please reopen as necessary.