
RWKV - Inference NF4 quantization broken, also Int8 quantization weirdness.

Open · iantbutler01 opened this issue 2 years ago · 12 comments

System Info

  • transformers version: 4.30.0.dev0
  • Platform: Linux-5.15.0-70-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • Huggingface_hub version: 0.14.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: RTX 6000 Ada
  • Using distributed or parallel set-up in script?: Not for inference.
  • bitsandbytes version: 0.39

I'm using the RWKV/rwkv-raven-14b model.

Rescaling is broken for NF4 quantization with RWKV

RuntimeError: result type Float can't be cast to the desired output type Byte

Looks like torch cannot do the conversion in the in-place div_.

And then if I turn rescaling off, it looks like there's a projection issue somewhere: RuntimeError: mat1 and mat2 shapes cannot be multiplied (43x5120 and 1x13107200)

Additionally, with Int8 quantization enabled, RWKV just outputs the endoftext token. I added a logits processor to output the scores and they're all NaN:

tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16)

Who can help?

@sgugger

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

I have a repo with everything set up in generate.py to quickly reproduce this: https://github.com/iantbutler01/rwkv-raven-qlora-4bit-instruct/blob/main/generate.py

pip install -U git+https://github.com/huggingface/transformers.git
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/accelerate.git
pip install --upgrade bitsandbytes

Then run python generate.py in a Python 3.10+ environment. Uncomment the 8-bit or 4-bit bnb config as needed (roughly as sketched below).
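For reference, the two configurations being toggled in generate.py look roughly like this (a sketch; the exact values live in the linked script):

import torch
from transformers import BitsAndBytesConfig

# 8-bit config
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# 4-bit (NF4) config -- uncomment to use instead
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )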

Expected behavior

I would expect NF4-based quantization to work at all, and for Int8 quantization not to produce NaN logits.

iantbutler01 · May 29 '23

Not sure quantization actually works for RWKV, which has quite a few custom layers. cc @younesbelkada

sgugger · May 30 '23

Hmm, I was able to do a 4bit finetuning with qlora last week, at the very least targeting key, value, and receptance in the attention and feed-forward blocks; it just seems like inference is broken.

I confirmed my tuned checkpoints worked fine for inference at full precision, and actually, now that I think of it, the forward call also worked fine in 8bit in Eleuther's lm-evaluation-harness; not sure about 4bit. It just seems to break when calling generate.

iantbutler01 · May 30 '23

Hi @iantbutler01, thanks for the issue! The 8-bit support should be added in https://github.com/huggingface/transformers/pull/23468. From my understanding it seems you have managed to finetune RWKV in 4-bit?

Hmm, I was able to do a 4bit finetuning with qlora last week at the very least targeting key value and receptance in the attention and feed forward blocks

Could you elaborate more on the error?

younesbelkada · May 31 '23

@younesbelkada

In regards to int8: I've been testing on the development branch, which includes the code you've merged there, and it very much just produces tensor([[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0', dtype=torch.float16) for the logits during a generate call, even with the base RWKV 14b model, so I think something is still broken. You can reproduce this easily with the steps I've linked in the issue.

For example, with

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LogitsProcessor,
    LogitsProcessorList,
)
# InstructionTextGenerationPipeline and PROMPT_FOR_GENERATION_FORMAT come from
# instruct_pipeline.py in the linked repo

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)

model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b",
    return_dict=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    context_length=1024,
    # rescale_every=0,
).cuda()

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-raven-14b")

pipeline = InstructionTextGenerationPipeline(
    model=model,
    tokenizer=tokenizer,
    top_p=0.92,
    top_k=50,
    temperature=1.0,
)
instruction = "Write me the steps to make a peanut butter and jelly sandwich"
prompt = PROMPT_FOR_GENERATION_FORMAT.format(
    instruction=instruction,
)

class IsBork(LogitsProcessor):
    def __call__(self, input_ids, scores):
        print(scores)
        return scores
    
prompt = str(prompt)
inputs = tokenizer(prompt, return_tensors="pt")

input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
input_ids, attention_mask = input_ids.to("cuda"), attention_mask.to("cuda")

generated_sequence = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    logits_processor=LogitsProcessorList([IsBork()]),
    pad_token_id=tokenizer.pad_token_id,
    top_p=0.92,
    top_k=50,
    temperature=1.0,
    max_new_tokens=512
)

print(generated_sequence)

The call to generate raises an error:

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 171, in <module>
    gen = pipeline(prompt, max_new_tokens=512)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/pipelines/base.py", line 1118, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/pipelines/base.py", line 1125, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/pipelines/base.py", line 1024, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/instruct_pipeline.py", line 112, in _forward
    generated_sequence = self.model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1568, in generate
    return self.sample(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2651, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0      

Adding a logits processor that just prints out the scores shows, on the first token generated:

tensor([[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0', dtype=torch.float16)

If I then set do_sample=False, the output is:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Write me the steps to make a peanut butter and jelly sandwich

### Response:
<|endoftext|>

It only generates the end-of-text token, whereas the full-precision model generates correctly.

iantbutler01 · May 31 '23

In regards to 4bit: rescaling during inference is broken for NF4 quantization with RWKV. If you try to run inference with a generate call under NF4 quantization:

RuntimeError: result type Float can't be cast to the desired output type Byte, which fails in the else branch of the block your int8 PR touches.

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 181, in <module>
    generated_sequence = model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 781, in forward
    rwkv_outputs = self.rwkv(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 642, in forward
    self._rescale_layers()
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 713, in _rescale_layers
    block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
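
(For context, the rescaling that fails there is roughly the following; a simplified sketch of the non-quantized inference path in modeling_rwkv.py, going by the traceback above, not the exact code:)

import torch
from torch import nn

def rescale_layers(blocks: nn.ModuleList, rescale_every: int) -> None:
    # Every `rescale_every` blocks, the attention output and feed-forward value
    # weights are divided in place by a growing power of two for inference.
    if rescale_every <= 0:
        return
    with torch.no_grad():
        for block_id, block in enumerate(blocks):
            block.attention.output.weight.div_(2 ** int(block_id // rescale_every))
            block.feed_forward.value.weight.div_(2 ** int(block_id // rescale_every))

With 4-bit quantized weights the stored parameter data is packed, so this in-place float division (or the equivalent update on the quantization statistics) is what breaks.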

And then if I turn rescaling off by setting rescale_every=0, it looks like there's a projection issue somewhere: RuntimeError: mat1 and mat2 shapes cannot be multiplied (43x5120 and 1x13107200)

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 181, in <module>
    generated_sequence = model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 781, in forward
    rwkv_outputs = self.rwkv(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 667, in forward
    hidden_states, state, attentions = block(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 384, in forward
    attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 308, in forward
    receptance, key, value, state = self.extract_key_value(hidden, state=state)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 300, in extract_key_value
    key = self.key(key)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 219, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 564, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 512, in forward
    output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (42x5120 and 1x13107200)

But yeah I have this all reproducible in the script I've linked in the issue.

iantbutler01 · May 31 '23

I see, thanks for sharing more details with me. So there are two issues here:

1- int8 RWKV seems to not work for you. From the snippet I am seeing, you are calling .cuda() on the 8-bit model. This might lead to unexpected behavior, because any .to(xxx) call on the 8-bit model will re-compute the quantization statistics. I have managed to reproduce your issue with the snippet below:

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

model_id = "RWKV/rwkv-4-1b5-pile"

model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0}).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)

generation_config = GenerationConfig(max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
question = "Hello my name is"
inputs = tokenizer(question, return_tensors="pt").to(0)
output_int8 = model.generate((inputs["input_ids"]), generation_config=generation_config)
print(tokenizer.decode(output_int8[0], skip_special_tokens=True))

and the model directly predicts the EOS token. The fix is to replace model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0}).cuda() with model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0}) (the corrected call is also sketched below, after point 2). Could you confirm this fixes your issue?

2- RWKV + 4-bit seems to not be supported for now. I will dig into that and let you know as soon as I have a fix.
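For reference, the corrected loading call from point 1 (the snippet above without the trailing .cuda()):

from transformers import AutoModelForCausalLM

model_id = "RWKV/rwkv-4-1b5-pile"

# device_map={"": 0} already places the model on GPU 0, so no .cuda() afterwards
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"": 0})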

younesbelkada · May 31 '23

I just added 4-bit inference support for RWKV in #23910. Please try out the fixes stated above together with #23910 and let us know how it goes.

younesbelkada · May 31 '23

@younesbelkada

Okay so 8bit is working fine now, thank you very much for the workaround!

4-bit, loaded with this configuration:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b",
    return_dict=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    context_length=1024,
    # rescale_every=0,
    device_map={"":0}
)

It is still failing, unfortunately :(

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 182, in <module>
    generated_sequence = model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 789, in forward
    rwkv_outputs = self.rwkv(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 642, in forward
    self._rescale_layers()
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 714, in _rescale_layers
    block.attention.output.weight.quant_state[0].div_(
RuntimeError: result type Float can't be cast to the desired output type Byte

iantbutler01 · May 31 '23

I see, this is because you are using nested quantization (bnb_4bit_use_double_quant=True). Can you try without that while I find a fix for this specific use case? 🙏

younesbelkada · Jun 01 '23

Yes, sorry about that. I had always intended this to be with double quant, and that was in my original repro code, but I should have been more explicit when communicating it to you 👍

I tried it without double quantization and it does work.
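For reference, the 4-bit configuration that works here is the one above with bnb_4bit_use_double_quant removed (a sketch):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 without nested (double) quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b",
    quantization_config=bnb_config,
    context_length=1024,
    device_map={"": 0},
)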

iantbutler01 · Jun 01 '23

No problem, and thanks for double-checking. I will get back once I fix the issue with nested quantization!

younesbelkada · Jun 01 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Jul 30 '23

I think it should not be closed, @younesbelkada

jonataslaw · Aug 09 '23

Correct, it is known that RWKV double-quant 4-bit inference does not work yet. I'm not sure I can propose a fix anytime soon because of the rescale layers operation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py#L722
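For anyone hitting this: a minimal illustration of the underlying cast error (not the actual modeling code) is that with nested quantization the absmax statistics kept in weight.quant_state are themselves stored as a uint8 tensor, and the in-place rescale tries to float-divide it:

import torch

# stand-in for the double-quantized absmax statistics (stored as uint8)
absmax = torch.tensor([12, 34, 56], dtype=torch.uint8)

# this mirrors what _rescale_layers attempts on the quantization statistics
absmax.div_(2)
# RuntimeError: result type Float can't be cast to the desired output type Byte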

younesbelkada · Aug 17 '23