
[Proposal] Compatibility for OLMo and OLMo2?

Open spaidataiga opened this issue 1 year ago • 4 comments

Proposal

It would be nice to include OLMo (1B and 7B) and their checkpoints as available compatible models for HookedTransformer.

Motivation

OLMo-1B would be a great model for mechanistic interpretability work, especially as it is fully open-source: the training data, training process, and intermediate checkpoints are all released, which lets us relate them to model performance. Its architecture should be fairly similar to models that are already compatible. If it is already possible to get it running, I would really appreciate a link to some information, as I've been looking through the documentation myself in the meantime.

Pitch

Add OLMo-1B and OLMo-7B. Add OLMo2-7B and OLMo2-13B. Possibly also add their intermediate training checkpoints?
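For reference, a minimal sketch of what the requested compatibility could look like from the user side. This is only an illustration: the model name below and the availability of checkpoint selection for OLMo are assumptions, not something TransformerLens currently supports.

from transformer_lens import HookedTransformer

# Hypothetical usage once OLMo support lands; the model name is an assumed
# alias, not an existing TransformerLens model name.
model = HookedTransformer.from_pretrained("allenai/OLMo-1B-hf")

# If intermediate checkpoints were wired up the way they are for Pythia or
# stanford-gpt2 models, selecting one might look like this (also an assumption):
early_model = HookedTransformer.from_pretrained(
    "allenai/OLMo-1B-hf",
    checkpoint_index=10,
)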

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

spaidataiga avatar Nov 28 '24 12:11 spaidataiga

I would just like to express my enthusiastic endorsement of this proposal. I experimented with an implementation a little bit and thought I would share some of what that revealed. It seems to me that OLMo-1 and OLMo-2 follow Llama-2 quite closely. For example, in convert_hf_model_config() inside of loading_from_pretrained.py, something similar to

if official_model_name.startswith(
        ("olmo2-7b", "allenai/OLMo-2-1124-7B")
    ):  # OLMo-2 7B closely follows the Llama-2 7B architecture
        cfg_dict = {
            "d_model": 4096,
            "d_head": 4096 // 32,
            "n_heads": 32,
            "d_mlp": 11008,
            "n_layers": 32,
            "n_ctx": 4096,
            "eps": 1e-6,
            "d_vocab": 100352,
            "act_fn": "silu",
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 4096 // 32,
            "final_rms": True,
            "gated_mlp": True,
        }

might make sense? Then we would probably need a new pretrained/weight_conversions/olmo2.py file. A nuance here seems to be that, when loaded with the newest version of HuggingFace transformers (at the time of writing), an Olmo2ForCausalLM object looks like

Olmo2ForCausalLM(
  (model): Olmo2Model(
    (embed_tokens): Embedding(100352, 4096, padding_idx=100277)
    (layers): ModuleList(
      (0-31): 32 x Olmo2DecoderLayer(
        (self_attn): Olmo2SdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): Olmo2RotaryEmbedding()
          (q_norm): Olmo2RMSNorm((4096,), eps=1e-06)
          (k_norm): Olmo2RMSNorm((4096,), eps=1e-06)
        )
        (mlp): Olmo2MLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (post_attention_layernorm): Olmo2RMSNorm((4096,), eps=1e-06)
        (post_feedforward_layernorm): Olmo2RMSNorm((4096,), eps=1e-06)
      )
    )
    (norm): Olmo2RMSNorm((4096,), eps=1e-06)
  )
  (lm_head): Linear(in_features=4096, out_features=100352, bias=False)
)

Whereas a LlamaForCausalLM object from 'meta-llama/Llama-2-7b-hf' looks like

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
 

So, for example, OLMo-2 has (post_attention_layernorm) and (post_feedforward_layernorm) in every layer, as opposed to (input_layernorm) and (post_attention_layernorm) in Llama-2; the normalization appears to sit on the sublayer outputs rather than the inputs. It also has additional (rotary_emb), (q_norm), and (k_norm) modules inside every self_attn block that Llama-2 lacks, while missing the model-wide (rotary_emb) that Llama-2 has. There is also the vocabulary size of 100352 in OLMo-2 vs 32000 in Llama-2. Finally, Olmo2RMSNorm and LlamaRMSNorm both seem to be equivalent to T5LayerNorm.
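To make the mapping concrete, here is a rough sketch of what a pretrained/weight_conversions/olmo2.py could look like, modelled on the existing Llama conversion. This is only a sketch under the config above: it maps the weights whose placement seems unambiguous (embeddings, attention projections, gated MLP, final norm, unembed) and deliberately leaves out q_norm, k_norm, and the post-attention/post-feedforward norms, since those would need corresponding support in HookedTransformer's block structure.

import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_olmo2_weights(olmo2, cfg: HookedTransformerConfig):
    """Sketch: map Olmo2ForCausalLM weights into HookedTransformer's state dict.

    Norm handling (q_norm, k_norm, post_attention_layernorm,
    post_feedforward_layernorm) is intentionally omitted, because the placement
    differs from the pre-norm blocks HookedTransformer assumes.
    """
    state_dict = {}
    state_dict["embed.W_E"] = olmo2.model.embed_tokens.weight

    for l in range(cfg.n_layers):
        layer = olmo2.model.layers[l]

        # Attention projections: HF stores [(n_heads * d_head), d_model];
        # HookedTransformer wants [n_heads, d_model, d_head].
        for tl_name, hf_proj in [("W_Q", layer.self_attn.q_proj),
                                 ("W_K", layer.self_attn.k_proj),
                                 ("W_V", layer.self_attn.v_proj)]:
            state_dict[f"blocks.{l}.attn.{tl_name}"] = einops.rearrange(
                hf_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
            )
        state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            layer.self_attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads
        )
        # OLMo-2 attention has no biases, so fill with zeros.
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head)
        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head)
        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head)
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model)

        # Gated MLP (SwiGLU), same layout as Llama-2.
        state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = layer.mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp)
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model)

    state_dict["ln_final.w"] = olmo2.model.norm.weight
    state_dict["unembed.W_U"] = olmo2.model.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab)
    return state_dict

On top of that, the official model name would presumably also need to be added to OFFICIAL_MODEL_NAMES in loading_from_pretrained.py so that from_pretrained recognizes it.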

I'm tempted to give a PR a shot but I'm not sure if I know enough about TransformerLens. Is there anyone who could bridge the gap?

Neelectric avatar Dec 13 '24 18:12 Neelectric

Actually, it looks like #718 and https://github.com/jonasrohw/TransformerLens/tree/OLMo are closely related.

Neelectric avatar Dec 13 '24 18:12 Neelectric

It should be in the next release!

bryce13950 avatar Jan 06 '25 22:01 bryce13950

Hi, I'm working on supporting OLMo2 in a quick and dirty way. I've created a PR to @jonasrohw's OLMo branch. I wonder if I missed something, since my logits from utils.test_prompt do not match those of the HF implementation. Help would be much appreciated, thanks!

I did use HookedTransformer.from_pretrained_no_processing, since the fold_layer_norm logic would otherwise need drastic changes.
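For context, this is the kind of minimal comparison I have in mind. It's only a sketch: it assumes the branch registers allenai/OLMo-2-1124-7B as a recognized model name.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

model_name = "allenai/OLMo-2-1124-7B"  # assumes the branch adds this name
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# No folding/centering, so the two models should be numerically comparable.
tl_model = HookedTransformer.from_pretrained_no_processing(
    model_name, hf_model=hf_model, tokenizer=tokenizer
)

tokens = tl_model.to_tokens("The capital of France is")
with torch.no_grad():
    tl_logits = tl_model(tokens)          # [batch, pos, d_vocab]
    hf_logits = hf_model(tokens).logits   # same shape
print("max abs diff:", (tl_logits - hf_logits).abs().max().item())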

Ja1Zhou avatar Feb 02 '25 05:02 Ja1Zhou

> It should be in the next release!

Hi, I don't think it is supported in the newest version yet?

PosoSAgapo avatar May 28 '25 05:05 PosoSAgapo

Is OLMo support still on the roadmap?

taziksh avatar Sep 29 '25 04:09 taziksh