Easy-Transformer icon indicating copy to clipboard operation
Easy-Transformer copied to clipboard

[Proposal] Add support for Baichuan1 and Baichuan2

Open StarrySeas1 opened this issue 1 year ago • 3 comments

Hello, does this analysis tool support Baichuan1 and Baichuan2?

Hello, I want to use this tool to analyze the Baichuan1 and Baichuan2 models. I don’t know if it is supported.

StarrySeas1 avatar Jun 03 '24 06:06 StarrySeas1

TransformerLens does not currently support Baichuan. Looking at their page on HuggingFace, it appears to be pretty similar to LLaMA, which means that it should be relatively easy to add, barring any surprises. @StarrySeas1 if you would like to use this model, I can walk you through how to do it, and you can give it a shot. It should be a matter of adding an alias to llama, and then adding a configuration block to loading_from_pretrained to make sure it matches the config on hugging face https://huggingface.co/baichuan-inc/Baichuan-7B/blob/main/config.json.

bryce13950 avatar Jun 03 '24 17:06 bryce13950

Here are the modifications I made:

  1. Added "Baichuan-13B-Chat" to the "OFFICIAL_MODEL_NAMES".
  2. Added the configuration for Baichuan in the "convert_hf_model_config" function.
elif "Baichuan-13B" in official_model_name:
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": 2048,  # Capped due to HF Tokenizer Constraints
            "d_vocab": hf_config.vocab_size,
            "eps": hf_config.rms_norm_eps,
            "act_fn": hf_config.hidden_act,
            "initializer_range": hf_config.initializer_range,
            "normalization_type": "RMS",
            "positional_embedding_type": "alibi",
            "post_embedding_ln": True,
            "positional_embedding_type": "alibi",
        }
  1. Added a function "convert_baichuan_weights"
def convert_baichuan_weights(baichuan, cfg: HookedTransformerConfig):
    state_dict = {}

    state_dict["embed.W_E"] = baichuan.model.embed_tokens.weight
    
    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        state_dict[f"blocks.{l}.ln1.w"] = baichuan.model.layers[l].input_layernorm.weight

        W = baichuan.model.layers[l].self_attn.W_pack.weight

        W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)

        W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]
        W_Q = einops.rearrange(W_Q, "m n h ->n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "m n h ->n m h", n=cfg.n_heads)
        W_V = einops.rearrange(W_V, "m n h ->n m h", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.W_K"] = W_K
        state_dict[f"blocks.{l}.attn.W_V"] = W_V
        
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
            cfg.n_heads,
            cfg.d_head,
            dtype=cfg.dtype,
            device=W_Q.device,
        )
        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
            cfg.n_heads,
            cfg.d_head,
            dtype=cfg.dtype,
            device=W_Q.device,
        )
        
        W_O = baichuan.model.layers[l].self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = baichuan.model.layers[l].post_attention_layernorm.weight
        
        state_dict[f"blocks.{l}.mlp.W_in"] = baichuan.model.layers[l].mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = baichuan.model.layers[l].mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=W_O.dtype)

        state_dict[f"blocks.{l}.mlp.W_out"] = baichuan.model.layers[l].mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=W_O.dtype)
        
    state_dict["ln_final.w"] = baichuan.model.norm.weight
    state_dict["unembed.W_U"] = baichuan.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=W_O.dtype)
    
    return state_dict

The following method calls can be used to load without error

model = transformer_lens.HookedTransformer.from_pretrained(ckpt_path,fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,device="cuda:6",trust_remote_code=True)

Question:

  1. Are the above additions correct?
  2. I compared the Attention Patterns of Baichuan-13B-Chat and the Baichuan-13B-Chat after SFT on the same sample, and they are basically not much different, with the values on the first few Tokens being very large. I'm not sure if this is a problem with the model training or if my addition of support for Baichuan is incorrect.

StarrySeas1 avatar Jun 04 '24 04:06 StarrySeas1

Thank you very much for doing the work here to get it to work in TransformerLens! Your implementation from a glance seems correct. However, it's hard to say if the discrepancies are a result of something being slightly off, or if it is something of no concern without playing with the code directly. Can you setup a PR with these changes in it, so that we can double check everything together?

bryce13950 avatar Jun 06 '24 15:06 bryce13950