[Proposal] Compatibility for OLMo and OLMo2?
Proposal
It would be nice to include OLMo (1B and 7B) and their training checkpoints among the models compatible with HookedTransformer.
Motivation
OLMo-1B would be a great model for mechanistic interpretability work, especially as it is fully open-source, letting us relate training data/processes, checkpoints, and model performance. Its architecture should be fairly similar to already-compatible models. If it is already possible to get it running, I would really appreciate a link to some information; I've tried to look through the documentation myself in the meantime.
Pitch
Add OLMo-1B and OLMo-7B. Add OLMo-2-7B and OLMo-2-13B. Possibly also add their training checkpoints?
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
I would just like to express my enthusiastic endorsement of this proposal. I poked at an implementation a little bit and thought I would share some of what that revealed. It seems to me that OLMo-1 and OLMo-2 follow Llama-2 quite closely. For example, in convert_hf_model_config() inside loading_from_pretrained.py, something similar to
if official_model_name.startswith(
    ("olmo2-7b", "allenai/OLMo-2-1124-7B")
):  # architecture closely follows Llama-2 (see the comparison below)
    cfg_dict = {
        "d_model": 4096,
        "d_head": 4096 // 32,
        "n_heads": 32,
        "d_mlp": 11008,
        "n_layers": 32,
        "n_ctx": 4096,
        "eps": 1e-6,
        "d_vocab": 100352,
        "act_fn": "silu",
        "normalization_type": "RMS",
        "positional_embedding_type": "rotary",
        "rotary_adjacent_pairs": False,
        "rotary_dim": 4096 // 32,
        "final_rms": True,
        "gated_mlp": True,
    }
might make sense? Then we would probably need a new pretrained/weight_conversions/olmo2.py file. A nuance here seems to be that, when loaded with the newest version of HuggingFace transformers (at the time of writing), an Olmo2ForCausalLM object looks like
Olmo2ForCausalLM(
  (model): Olmo2Model(
    (embed_tokens): Embedding(100352, 4096, padding_idx=100277)
    (layers): ModuleList(
      (0-31): 32 x Olmo2DecoderLayer(
        (self_attn): Olmo2SdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): Olmo2RotaryEmbedding()
          (q_norm): Olmo2RMSNorm((4096,), eps=1e-06)
          (k_norm): Olmo2RMSNorm((4096,), eps=1e-06)
        )
        (mlp): Olmo2MLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (post_attention_layernorm): Olmo2RMSNorm((4096,), eps=1e-06)
        (post_feedforward_layernorm): Olmo2RMSNorm((4096,), eps=1e-06)
      )
    )
    (norm): Olmo2RMSNorm((4096,), eps=1e-06)
  )
  (lm_head): Linear(in_features=4096, out_features=100352, bias=False)
)
Whereas a LlamaForCausalLM object from 'meta-llama/Llama-2-7b-hf' looks like
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
So, for example, OLMo-2 has (post_attention_layernorm) and (post_feedforward_layernorm) at every layer, as opposed to (input_layernorm) and (post_attention_layernorm) for Llama-2. It also has additional (rotary_emb), (q_norm), and (k_norm) modules in every self_attn module, which Llama-2 does not, while missing the model-wide (rotary_emb) that Llama-2 has. There's also the vocabulary size of 100352 in OLMo-2 vs. 32000 in Llama-2. Finally, Olmo2RMSNorm and LlamaRMSNorm both seem to be equivalent to T5LayerNorm.
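For the weight side, here is a very rough, untested sketch of what pretrained/weight_conversions/olmo2.py could contain, modeled on the existing Llama conversion. To be clear about what is assumed on my part: the function name convert_olmo2_weights is made up, and mapping post_attention_layernorm/post_feedforward_layernorm onto ln1/ln2 is only a placeholder, since those norms actually sit after the sublayers; q_norm/k_norm don't have an obvious home in the current block structure at all.

```python
# Rough, untested sketch of a possible weight conversion for OLMo-2,
# modeled on the existing Llama conversion. Open questions about the
# post-norm placement and the QK-norms are only flagged in comments.
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_olmo2_weights(olmo2, cfg: HookedTransformerConfig):
    state_dict = {}
    state_dict["embed.W_E"] = olmo2.model.embed_tokens.weight

    for l in range(cfg.n_layers):
        layer = olmo2.model.layers[l]

        # NOTE: OLMo-2 has no input_layernorm; these two norms come *after*
        # the attention and MLP sublayers, so this naive mapping onto ln1/ln2
        # is almost certainly where the block structure needs real changes.
        state_dict[f"blocks.{l}.ln1.w"] = layer.post_attention_layernorm.weight
        state_dict[f"blocks.{l}.ln2.w"] = layer.post_feedforward_layernorm.weight

        # Attention weights, split per head like the Llama conversion does
        state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
            layer.self_attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(
            layer.self_attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(
            layer.self_attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            layer.self_attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads
        )
        # q_norm / k_norm have no counterpart in the Llama mapping; they would
        # need new parameters (or new handling) in the attention block.

        # None of the Linear layers have biases, so fill them with zeros
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

        # Gated MLP, same layout as Llama
        state_dict[f"blocks.{l}.mlp.W_gate"] = layer.mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

    state_dict["ln_final.w"] = olmo2.model.norm.weight
    state_dict["unembed.W_U"] = olmo2.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
    return state_dict
```

So the tensor bookkeeping looks easy; the actual work seems to be in the block structure, and in how processing steps like layer-norm folding interact with the post-norm placement.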
I'm tempted to give a PR a shot but I'm not sure if I know enough about TransformerLens. Is there anyone who could bridge the gap?
Actually, it looks like #718 and https://github.com/jonasrohw/TransformerLens/tree/OLMo are closely related.
It should be in the next release!
Hi, I'm working on supporting OLMo2 in a quick-and-dirty way. I've created a PR to @jonasrohw's OLMo branch. I wonder if I missed something, since my logits from utils.test_prompt do not match those of the HF implementation. Help would be much appreciated, thanks!
I did use HookedTransformer.from_pretrained_no_processing, since the fold_layer_norm logic would have to change drastically.
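In case it helps with debugging, a minimal comparison of the two implementations could look something like the sketch below. This assumes the branch registers the model under its HF name ("allenai/OLMo-2-1124-7B") and that hf_model and tokenizer can be passed through from_pretrained_no_processing; the prompt and dtype are arbitrary.

```python
# Minimal logit-comparison sketch (assumes OLMo-2 support from the branch above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer

model_name = "allenai/OLMo-2-1124-7B"  # assumed registered name
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

# No weight processing, so the converted weights should match the HF weights directly
tl_model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    hf_model=hf_model,
    tokenizer=tokenizer,
)

prompt = "The capital of Washington state is"
tokens = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens)  # logits by default

print("max abs diff:", (hf_logits - tl_logits).abs().max().item())
```

Feeding the same token ids to both models rules out tokenization/BOS differences, so any remaining mismatch should come from the weight conversion or block structure.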
It should be in the next release!
Hi, I don't think it's supported in the newest version yet?
Is OLMo support still on the roadmap?