[Bug Report] Q cannot be reshaped correctly when model is loaded in 4bit

Open po13on opened this issue 1 year ago • 4 comments

Describe the bug

query_input has shape [batch, pos, n_heads, d_model], and the code where the error occurs is meant to project it and reshape the result to [batch, pos, n_heads, d_head]. I found that the output of bnb.matmul_4bit still has shape [batch, pos, n_heads, d_model], so it cannot be reshaped to [batch, pos, n_heads, d_head].

The reason for this error may be the following code in abstract_attention.py:

        if self.cfg.load_in_4bit:
            nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
            self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            self.W_Q = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_model,
                    self.cfg.d_head,
                    dtype=self.cfg.dtype,
                )
            )
            self.W_O = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads,
                    self.cfg.d_head,
                    self.cfg.d_model,
                    dtype=self.cfg.dtype,
                )
            )

When the model is loaded in 4-bit, W_Q has shape [(d_model * d_head * n_heads) / 2, 1], which leads to the unexpected output shape from bnb.matmul_4bit. When the model is not loaded in 4-bit, W_Q has shape [n_heads, d_model, d_head], which does not trigger the bug described above.
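As a rough check of the packed size, here is a minimal sketch. It assumes LLaMA-7B dimensions (d_model = 4096, n_heads = 32, d_head = 128), which are not stated in this report but are consistent with the 32 and 128 in the error message. The helper name `packed_4bit_rows` is hypothetical and only mirrors the `nq` computation quoted above.

```python
# Assumed LLaMA-7B dimensions (an assumption, not taken from the report).
d_model, n_heads, d_head = 4096, 32, 128


def packed_4bit_rows(d_model, n_heads, d_head):
    """Rows of a uint8 column vector when two 4-bit weights share one byte."""
    n_weights = d_model * n_heads * d_head
    return n_weights // 2


nq = packed_4bit_rows(d_model, n_heads, d_head)
packed_shape = (nq, 1)                       # layout used in the 4-bit branch
unpacked_shape = (n_heads, d_model, d_head)  # layout used otherwise

print(packed_shape)    # (8388608, 1)
print(unpacked_shape)  # (32, 4096, 128)
```

The packed layout keeps no separate n_heads axis, so any per-head structure has to be recovered by the reshape after the matmul.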

Code example

Error message:

File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:364, in AbstractAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
    339        if self.cfg.load_in_4bit:
    340            q = self.hook_q(
    341                # call bitsandbytes method to dequantize and multiply
    342                bnb.matmul_4bit(
    343                    query_input,
    344                    self.W_Q.t(),
    345                    bias=None,
    346                    quant_state=self.W_Q.quant_state,
-->347                ).reshape(
    348                    query_input.shape[0],
    349                    query_input.shape[1],
    350                    self.cfg.n_heads,
    351                    self.cfg.d_head,
    352                )
    353                + self.b_Q
RuntimeError: shape '[20, 22, 32, 128]' is invalid for input of size 57671680
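The numbers in this error can be checked with plain arithmetic. The sketch below assumes d_model = 4096 (LLaMA-7B; an assumption, not stated in the report) and compares the element count of the requested shape with the size PyTorch reports.

```python
from math import prod

requested_shape = (20, 22, 32, 128)  # [batch, pos, n_heads, d_head] from the error
actual_numel = 57671680              # "input of size" from the error

requested_numel = prod(requested_shape)
ratio = actual_numel // requested_numel

# Assumed d_model for LLaMA-7B; note that n_heads * d_head = 32 * 128 = 4096 too.
d_model = 4096
consistent_shape = (20, 22, 32, d_model)

print(requested_numel)                         # 1802240
print(ratio)                                   # 32
print(prod(consistent_shape) == actual_numel)  # True
```

The tensor holds exactly n_heads = 32 times more elements than the reshape expects, which matches an output of shape [batch, pos, n_heads, 4096] rather than [batch, pos, 4096].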

Code:

with torch.inference_mode():
    with model.hooks(fwd_hooks=fwd_hooks_corrupted):
        _ = model(corrupted)

System Info

Describe the characteristics of your environment:

  • git clone
  • Linux
  • python 3.11.4

Checklist

  • [√] I have checked that there is no similar issue in the repo (required)

po13on avatar Sep 28 '24 00:09 po13on

The bug is triggered when I load the model in 4-bit and set model.cfg.use_split_qkv_input = True.

Code example:

model = AutoModelForCausalLM.from_pretrained(
    model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=False,
    low_cpu_mem_usage=True, use_safetensors=False, load_in_4bit=True,
    torch_dtype=torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained(
    "llama-7b-hf", center_unembed=False, fold_ln=False, fold_value_biases=False,
    device='cuda', hf_model=model, tokenizer=tokenizer, hf_model_4bit=True,
    center_writing_weights=False,
)
model.cfg.use_split_qkv_input = True
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)

Error:

File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:195, in AbstractAttention.forward(self, query_input, key_input, value_input, past_kv_cache_entry, additive_attention_mask, attention_mask, position_bias)
    167 def forward(
    168     self,
    169     query_input: Union[
   (...)
    186     position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    187 ) -> Float[torch.Tensor, "batch pos d_model"]:
    188     """
    189     shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
    190     past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
    191     additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
    192     attention_mask is the attention mask for padded tokens. Defaults to None.
    193     """
--> 195     q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
    197     if past_kv_cache_entry is not None:
    198         # Appends the new keys and values to the cached values, and automatically updates the cache
    199         kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)

File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:348, in AbstractAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
    339 if self.cfg.load_in_4bit:
    340     print('In calculate_qkv_matrices: query_input.shape =', query_input.shape)  # XD debug
    341     q = self.hook_q(
    342         # call bitsandbytes method to dequantize and multiply
    343         bnb.matmul_4bit(
    344             query_input,
    345             self.W_Q.t(),
    346             bias=None,
    347             quant_state=self.W_Q.quant_state,
--> 348         ).reshape(
    349             query_input.shape[0],
    350             query_input.shape[1],
    351             self.cfg.n_heads,
    352             self.cfg.d_head,
    353         )
    354         + self.b_Q
    355     )
    356 else:
    357     q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))

RuntimeError: shape '[1, 6, 32, 128]' is invalid for input of size 786432
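To illustrate why only the split-input case fails, here is a shape-only sketch. It assumes that, for shape purposes, bnb.matmul_4bit acts like a linear map over the last axis (this agrees with the reported sizes but is an assumption), and it assumes LLaMA-7B dimensions. `linear_output_shape` and `reshape_is_valid` are hypothetical helpers, not TransformerLens or bitsandbytes functions.

```python
from math import prod

# Assumed LLaMA-7B dimensions (not stated in the report).
D_MODEL, N_HEADS, D_HEAD = 4096, 32, 128


def linear_output_shape(input_shape, in_features, out_features):
    """Shape after applying a linear map to the last axis only."""
    if input_shape[-1] != in_features:
        raise ValueError("last axis must match in_features")
    return (*input_shape[:-1], out_features)


def reshape_is_valid(shape, target):
    """A reshape is valid only when both shapes hold the same element count."""
    return prod(shape) == prod(target)


target = (1, 6, N_HEADS, D_HEAD)

# Without split inputs, query_input is [batch, pos, d_model].
plain = linear_output_shape((1, 6, D_MODEL), D_MODEL, N_HEADS * D_HEAD)
# With use_split_qkv_input, query_input is [batch, pos, n_heads, d_model].
split = linear_output_shape((1, 6, N_HEADS, D_MODEL), D_MODEL, N_HEADS * D_HEAD)

print(plain, reshape_is_valid(plain, target))               # (1, 6, 4096) True
print(split, prod(split), reshape_is_valid(split, target))  # (1, 6, 32, 4096) 786432 False
```

Under these assumptions the extra n_heads axis of the split input passes straight through the matmul, which reproduces the 786432 in the error above.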

po13on avatar Sep 28 '24 06:09 po13on

@po13on To investigate this further, I need to see exactly the code you used to initialize TransformerLens. This could be a wide-ranging bug, but more likely a specific model is causing the issue. Please share the full block of code you ran that puts TransformerLens into the invalid state.

bryce13950 avatar Sep 30 '24 19:09 bryce13950

@bryce13950 I'm sorry for providing incomplete code. The model I loaded is vicuna-7b. The complete code is below:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

# cache_dir and proxies must be defined beforehand (their values are not shown here)
model_name = 'lmsys/vicuna-7b-v1.3'
model = AutoModelForCausalLM.from_pretrained(
    model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=False,
    low_cpu_mem_usage=True, use_safetensors=False, load_in_4bit=True,
    torch_dtype=torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained(
    "llama-7b-hf", center_unembed=False, fold_ln=False, fold_value_biases=False,
    device='cuda', hf_model=model, tokenizer=tokenizer, hf_model_4bit=True,
    center_writing_weights=False,
)
model.cfg.use_split_qkv_input = True
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)

po13on avatar Oct 08 '24 05:10 po13on

No problem! Thanks for providing this. This should be enough for us to recreate it now.

bryce13950 avatar Oct 08 '24 17:10 bryce13950