Add RWKV-4

sgugger opened this issue 1 year ago

What does this PR do?

This PR is a draft, and while there is a working implementation of the model, there is still a lot to do :-)

This PR adds the RWKV model from BlinkDL/RWKV-LM, which is an RNN-like Transformer: it has an attention layer and a feed-forward layer, but the attention is linear and can be expressed recurrently (more details coming in the model's doc page).
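To make "the attention is linear and can be expressed recurrently" concrete, here is a single-channel sketch in plain Python, based on my reading of the RWKV-4 formulation rather than on the PR's code. `w` is a negative per-channel log-decay and `u` is a bonus for the current token (roughly `time_decay` after its `-exp` transform and `time_first` in the model); `wkv_quadratic` and `wkv_recurrent` are hypothetical helper names. Both compute the same weighted average of past values, but the recurrent version only carries two numbers per channel from one token to the next.

```python
import math


def wkv_quadratic(w, u, keys, values):
    """Attention form for one channel: output t revisits every earlier token."""
    outputs = []
    for t in range(len(keys)):
        num = math.exp(u + keys[t]) * values[t]  # current token gets the bonus u
        den = math.exp(u + keys[t])
        for i in range(t):
            # Older tokens are down-weighted by exp(w) < 1 per step of distance.
            weight = math.exp((t - 1 - i) * w + keys[i])
            num += weight * values[i]
            den += weight
        outputs.append(num / den)
    return outputs


def wkv_recurrent(w, u, keys, values):
    """The same outputs computed as an RNN carrying two numbers (a, b)."""
    a, b = 0.0, 0.0  # decayed sums of exp(k) * v and of exp(k) over past tokens
    outputs = []
    for k, v in zip(keys, values):
        outputs.append((a + math.exp(u + k) * v) / (b + math.exp(u + k)))
        a = math.exp(w) * a + math.exp(k) * v
        b = math.exp(w) * b + math.exp(k)
    return outputs


w, u = -0.7, 0.3  # w < 0, so the per-step decay factor exp(w) is below 1
keys = [0.1, -0.3, 0.5, 0.2]
values = [1.0, 2.0, -1.0, 0.5]
quadratic = wkv_quadratic(w, u, keys, values)
recurrent = wkv_recurrent(w, u, keys, values)
print(all(abs(x - y) < 1e-12 for x, y in zip(quadratic, recurrent)))  # True
```

The recurrent form is what makes RNN-style generation cheap: each new token costs the same amount of work no matter how long the context already is.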

Here is a code snippet to play with the model:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the 7B Pile checkpoint in half precision and let Accelerate spread it over the available devices.
model = AutoModelForCausalLM.from_pretrained("sgugger/rwkv-7b-pile", torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-7b-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

# Send the inputs to the device that holds the model's first parameters.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Sample up to 400 new tokens with nucleus sampling.
output = model.generate(inputs["input_ids"], max_new_tokens=400, top_p=0.8, do_sample=True)
print(tokenizer.decode(output[0].tolist()))

TODO:

  • [x] Write documentation of the model explaining the linear attention and the recurrent formulas in the code
  • [x] Make the model compatible with generate
  • [ ] Add output_attentions/output_hidden_states API
  • [ ] Convert more models and check that the conversion script is compatible
  • [x] Tweak the CUDA kernels so they can use an existing state for initialization
  • [ ] Make the tests pass
  • [ ] Add attention mask to be able to batch sentences (might be in a followup PR)

cc @ArthurZucker and @younesbelkada

sgugger, Apr 16 '23 22:04

The documentation is not available anymore as the PR was closed or merged.

IMO the model is in nice shape! Would love to have a round of review before I transfer the weights to the proper organization!

younesbelkada, May 04 '23 15:05

@younesbelkada In README.md

The name should be "Bo Peng" (Peng is the surname) instead of "Peng Bo" :)

BlinkDL, May 09 '23 04:05

Hi @sgugger, thanks A TON for this merge! I am trying to train a new model of this type and am facing the following error:

Traceback (most recent call last):
  File "train.py", line 229, in <module>
    main(model_args, data_args, training_args)
  File "train.py", line 193, in main
    trainer.train()
  File "transformers/src/transformers/trainer.py", line 1664, in train
    return inner_training_loop(
  File "transformers/src/transformers/trainer.py", line 1940, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "transformers/src/transformers/trainer.py", line 2753, in training_step
    loss.backward()
  File ".conda/envs/rwkv-eval-3.9/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File ".conda/envs/rwkv-eval-3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File ".conda/envs/rwkv-eval-3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
TypeError: backward() takes 2 positional arguments but 3 were given

From what I can see, the backward function of RwkvLinearAttentionBackward does not take a g_state argument. Should gradients be computed for the state? I guess not. Any pointers on how to resolve this would be very much appreciated!

YovaKem, May 11 '23 21:05
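The TypeError above follows from the `torch.autograd.Function` contract: `backward` receives one gradient per output of `forward` (here `output` and `state`, so `ctx` plus two arguments) and must return one value per input of `forward`, with `None` for inputs that need no gradient. A minimal toy sketch of that contract; `ScaleWithState` is a hypothetical class, and marking the returned state as non-differentiable is one option for a state that is never back-propagated through, not necessarily what the RWKV code does.

```python
import torch


class ScaleWithState(torch.autograd.Function):
    """Toy Function that, like the RWKV attention wrapper, returns (output, state)."""

    @staticmethod
    def forward(ctx, x, scale, return_state=False):
        ctx.scale = scale
        output = x * scale
        state = output.sum(dim=-1) if return_state else None
        if state is not None:
            # The state is carried forward but never differentiated through.
            ctx.mark_non_differentiable(state)
        return output, state

    @staticmethod
    def backward(ctx, g_output, g_state):
        # One incoming gradient per forward output (output, state); g_state is ignored.
        # One returned value per forward input (x, scale, return_state).
        return g_output * ctx.scale, None, None


x = torch.randn(2, 3, requires_grad=True)
out, state = ScaleWithState.apply(x, 2.0, True)
out.sum().backward()
print(x.grad)  # every entry equals 2.0
```

Declaring `backward(ctx, g_output)` with only one gradient parameter reproduces exactly the "takes 2 positional arguments but 3 were given" failure.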

I managed to get the code to run with some changes to the forward() and backward() functions:

class RwkvLinearAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):

        batch_size, seq_len, hidden_size = key.size()
        if seq_len > rwkv_cuda_kernel.max_seq_length:
            raise ValueError(
                f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
                f"{rwkv_cuda_kernel.max_seq_length} with this model."
            )
        if batch_size * hidden_size % min(hidden_size, 32) != 0:
            raise ValueError(
                f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
                f"multiple of {min(hidden_size, 32)}."
            )

        ctx.input_dtype = key.dtype

        if (
            time_decay.device.type != "cuda"
            or time_first.device.type != "cuda"
            or key.device.type != "cuda"
            or value.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")

        time_decay = -torch.exp(time_decay.float().contiguous())
        if key.dtype == torch.float16:
            time_first = time_first.float()
            key = key.float()
            value = value.float()
        time_first = time_first.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        # The CUDA kernel will fill this tensor.
        output = torch.empty_like(key, memory_format=torch.contiguous_format)
        if return_state or state is not None:
            if state is None:
                state = torch.zeros(
                    batch_size,
                    hidden_size,
                    3,
                    dtype=torch.float32,
                    device=key.device,
                    memory_format=torch.contiguous_format,
                )
                state[:, :, 2] -= 1e38
            else:
                state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()

            if key.dtype == torch.bfloat16:
                forward_func = rwkv_cuda_kernel.forward_with_state_bf16
            else:
                forward_func = rwkv_cuda_kernel.forward_with_state
            forward_func(time_decay, time_first.to(key.dtype), key, value, output, state)
        else:
            forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
            forward_func(time_decay, time_first.to(key.dtype), key, value, output)
        ctx.save_for_backward(time_decay, time_first, key, value, output)

        if state is not None:
            state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]

        return output.to(ctx.input_dtype), state

    @staticmethod
    def backward(ctx, g_output, g_state):
        input_dtype = ctx.input_dtype

        time_decay, time_first, key, value, output = ctx.saved_tensors
        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty_like(
            time_decay,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_time_first = torch.empty_like(
            time_first,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        g_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        if input_dtype == torch.float16:
            g_output = g_output.float()
        backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
        backward_func(
            time_decay,
            time_first.to(key.dtype),
            key,
            value,
            output,
            g_output.contiguous(),
            g_time_decay,
            g_time_first,
            g_key,
            g_value,
        )
        # g_time_decay = torch.sum(g_time_decay, dim=0)
        # g_time_first = torch.sum(g_time_first, dim=0)

        return (
            g_time_decay.to(input_dtype),
            g_time_first.to(input_dtype),
            g_key.to(input_dtype),
            g_value.to(input_dtype),
            None,  # state: no gradient is propagated into the initial state
            None,  # return_state is a plain bool
        )

One problem I ran into now is that, although I'm trying to train a fairly small model (12 layers, hidden size 256, context size 64), I can only train with a very small batch size (16) on a 40GB A100. For comparison, a RoBERTa model of similar size allows a batch size of 256. This seems counterintuitive to me, but I might be wrong.

Another issue I observed is instability: in some cases, within the first 3 training steps, the loss goes from something normal like 10 to 90543067814198.3 and then to 0.0. This seems to happen more often when bf16 training is disabled, and at higher batch sizes when bf16 training is enabled.

YovaKem, May 12 '23 09:05

@YovaKem Would you mind trying to change this

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty_like(
    time_decay,
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
)
g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)

to

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty(
    key.shape[0], key.shape[2],
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
)
g_time_first = torch.empty(k.shape[0], k.shape[2], memory_format=torch.contiguous_format)

I suspect there's an overflow in the current code, as mentioned in the review comment above, but I have not tested it yet. The binary distribution on PyPI does not include the CUDA kernels XD

Also, the gradient of the state should be computed, but the current kernel does not do it. I'll open a PR later, after I set up the environment.

Blealtan, May 13 '23 13:05
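The suspected overflow comes down to buffer sizes: `torch.empty_like(time_decay)` holds `hidden_size` values, while the suggested replacement allocates `batch_size × hidden_size`. The pure-Python sketch below assumes, as that replacement implies, that the backward kernel writes one partial gradient per (batch, channel) pair; `fake_backward_kernel` and `reduce_over_batch` are hypothetical helpers, not part of transformers. Python raises an IndexError on the undersized buffer, whereas CUDA would silently write past the end of the allocation, which would be consistent with the memory and instability problems described above.

```python
def fake_backward_kernel(batch_size, hidden_size, g_time_decay):
    """Hypothetical stand-in for the CUDA backward kernel.

    Assumption: the thread for (b, c) writes its partial gradient at the
    flat index b * hidden_size + c, i.e. it expects a (batch, hidden) buffer.
    """
    for b in range(batch_size):
        for c in range(hidden_size):
            g_time_decay[b * hidden_size + c] = float(b + 1)  # dummy per-sample value


def reduce_over_batch(flat, batch_size, hidden_size):
    """Sum per-sample partial gradients into one value per channel (like torch.sum(g, dim=0))."""
    return [
        sum(flat[b * hidden_size + c] for b in range(batch_size))
        for c in range(hidden_size)
    ]


batch_size, hidden_size = 16, 256  # the sizes from the report above

# Buffer sized like torch.empty_like(time_decay): hidden_size slots only.
too_small = [0.0] * hidden_size
try:
    fake_backward_kernel(batch_size, hidden_size, too_small)
    overflowed = False
except IndexError:
    # Python stops at the first write past the end; a CUDA kernel would not.
    overflowed = True

# Buffer sized like torch.empty(batch_size, hidden_size): every write fits.
big_enough = [0.0] * (batch_size * hidden_size)
fake_backward_kernel(batch_size, hidden_size, big_enough)
per_channel = reduce_over_batch(big_enough, batch_size, hidden_size)
print(overflowed, per_channel[0])  # True 136.0
```

With these sizes the kernel would attempt 4096 writes into 256 slots, i.e. 3840 out-of-bounds writes per call. The per-sample gradients then still have to be summed over the batch dimension, which is what the commented-out `torch.sum(g_time_decay, dim=0)` lines in the earlier snippet would do.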

Thanks @Blealtan! I guess k was meant to be key? I also added bf16 support for g_time_first (I get an error otherwise) and put the tensors on the CUDA device:

        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty(
            key.shape[0], key.shape[2],
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
            device=key.device,
        )
        g_time_first = torch.empty(
            key.shape[0], key.shape[2],
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
            device=key.device,
        )

This seems to solve both the OOM issue and the instability!

One question regarding your comment on state gradients: I just saw this:

It will also match the _with_state variant of WKV forward.

In what cases is the _with_state variant used? As far as I can see, the model I'm training does not pass states at all during the forward step. Is that something that only becomes relevant at inference time, when the model is used like an RNN?

YovaKem, May 13 '23 20:05

Hey @sgugger, how did you prepare the models? Could you point us to how to convert the original .pth or .safetensors models to your format? Thanks!

PS: Awesome that RWKV joined transformers!

lambdaofgod, May 14 '23 09:05

@lambdaofgod The logic used to convert the RWKV checkpoints from BlinkDL to HF format can be found in the conversion script.

amyeroberts, May 15 '23 15:05

@YovaKem AFAIK, with_state is currently used only for inference (in the existing non-transformers implementations throughout the RWKV community). With a proper implementation, however, it would allow more efficient training on long sequences; that has not been implemented yet.

Blealtan, May 16 '23 10:05
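The `_with_state` idea can be checked on a single channel. The sketch below is my reading of the numerically stable recurrence in the original RWKV-4 kernels, not code from this PR: the state is three numbers per channel (`aa`, `bb` and a running exponent `pp`), which lines up with the three-slot `state` tensor in the forward code above whose last slot starts at -1e38. `wkv_with_state` is a hypothetical helper, and `w` is the negative log-decay that forward computes as `-torch.exp(time_decay)`. Processing a sequence in two chunks while carrying the state reproduces the single-pass output, which is what chunked training or RNN-style inference relies on.

```python
import math

MIN_VALUE = -1e38  # stands in for "minus infinity" in the running exponent


def wkv_with_state(w, u, keys, values, state=None):
    """Numerically stable WKV for one channel; returns (outputs, new_state).

    The state (aa, bb, pp) represents a = aa * exp(pp) and b = bb * exp(pp),
    the decayed weighted sum of past values and the decayed sum of past weights.
    """
    aa, bb, pp = state if state is not None else (0.0, 0.0, MIN_VALUE)
    outputs = []
    for k, v in zip(keys, values):
        # Output: mix the stored past with the current token, which gets the bonus u.
        ww = u + k
        p = max(pp, ww)
        e1, e2 = math.exp(pp - p), math.exp(ww - p)
        outputs.append((e1 * aa + e2 * v) / (e1 * bb + e2))
        # State update: decay the past by exp(w) and add the current token.
        ww = w + pp
        p = max(ww, k)
        e1, e2 = math.exp(ww - p), math.exp(k - p)
        aa, bb, pp = e1 * aa + e2 * v, e1 * bb + e2, p
    return outputs, (aa, bb, pp)


w, u = -0.7, 0.3
keys = [0.1, -0.3, 0.5, 0.2, 1.0]
values = [1.0, 2.0, -1.0, 0.5, 3.0]

full, _ = wkv_with_state(w, u, keys, values)
first, state = wkv_with_state(w, u, keys[:2], values[:2])
second, _ = wkv_with_state(w, u, keys[2:], values[2:], state)
print(all(abs(x - y) < 1e-12 for x, y in zip(full, first + second)))  # True

# Keys this large would overflow a naive math.exp(k), but not the stable form.
big, _ = wkv_with_state(-0.5, 0.0, [1000.0, 999.0], [1.0, 3.0])
```

Keeping every sum relative to `pp` is what lets keys around 1000 go through without `math.exp` overflowing.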

I have no idea why the CUDA kernels all disappeared from the package on PyPI (it's not just RWKV, but all models using custom kernels). I will investigate later today and post a patch release when I find a solution.

sgugger, May 16 '23 13:05

The custom kernels should now be included in 4.29.2; sorry for the inconvenience. We added stronger checks to make sure they don't disappear again in a future release.

sgugger, May 16 '23 19:05
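As far as I know, transformers compiles these kernels at runtime from their `.cu`/`.cpp` sources (via `torch.utils.cpp_extension.load`), so the source files themselves must be present in the installed package. A small sanity check for an installed version could look like the sketch below; the `kernels/rwkv/` location and the file names are assumptions about the package layout, and `missing_kernel_files` is a hypothetical helper.

```python
import importlib.util
from pathlib import Path

# Assumption: the RWKV kernel sources are shipped inside the installed package under
# kernels/rwkv/; these file names are illustrative, not an authoritative list.
RWKV_KERNEL_FILES = (
    "kernels/rwkv/wkv_cuda.cu",
    "kernels/rwkv/wkv_cuda_bf16.cu",
    "kernels/rwkv/wkv_op.cpp",
)


def missing_kernel_files(package_root, relative_paths=RWKV_KERNEL_FILES):
    """Return the expected kernel source files that are absent under package_root."""
    root = Path(package_root)
    return [rel for rel in relative_paths if not (root / rel).is_file()]


if __name__ == "__main__":
    spec = importlib.util.find_spec("transformers")
    if spec is None or spec.origin is None:
        print("transformers is not installed")
    else:
        # spec.origin points at transformers/__init__.py; its parent is the package root.
        print(missing_kernel_files(Path(spec.origin).parent) or "all kernel sources found")
```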

Hi, can I ask a simple question about the RWKV kernel? Without the custom kernel, the RWKV model uses a for loop here: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/models/rwkv/modeling_rwkv.py#L223-L241

I am not familiar with CUDA kernels, so I am not sure whether the custom CUDA kernel still computes sequentially (just a faster for loop) or parallelizes the computation on the GPU.

Wednesday657, May 21 '23 11:05
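As far as I know, the original RWKV-4 CUDA kernel does a bit of both: it launches one thread per (batch, channel) pair, so the batch_size × hidden_size independent recurrences run in parallel (which fits the forward check that batch_size * hidden_size is a multiple of min(hidden_size, 32), presumably the threads per block), while each thread still walks through the time steps sequentially. If I read the linked CPU code correctly, it splits the work the other way: a Python loop over time, vectorized over batch and channels. The plain-Python emulation below mirrors the assumed thread layout; `wkv_kernel_emulation` is a hypothetical name, and it uses the naive rather than the numerically stable update.

```python
import math


def wkv_kernel_emulation(w, u, k, v, batch_size, seq_len, hidden_size):
    """Emulates the assumed CUDA thread layout for the WKV forward pass.

    k, v and the returned y are flat lists in contiguous (batch, time, channel) order;
    w (negative log-decay) and u (current-token bonus) have one entry per channel.
    """
    y = [0.0] * (batch_size * seq_len * hidden_size)
    # On the GPU, each value of idx is a separate thread and they all run in parallel.
    for idx in range(batch_size * hidden_size):
        b, c = divmod(idx, hidden_size)
        offset = b * seq_len * hidden_size + c
        num, den = 0.0, 0.0  # naive (non-stabilized) state, fine for small inputs
        # Inside one thread, the time steps are still processed one after another.
        for t in range(seq_len):
            i = offset + t * hidden_size
            bonus = math.exp(u[c] + k[i])
            y[i] = (num + bonus * v[i]) / (den + bonus)
            num = math.exp(w[c]) * num + math.exp(k[i]) * v[i]
            den = math.exp(w[c]) * den + math.exp(k[i])
    return y


# Each (batch, channel) pair gets a constant value, so its output must stay constant:
# any mixing between threads would show up as a different number.
B, T, C = 2, 3, 2
v = [float(10 * bi + ci) for bi in range(B) for ti in range(T) for ci in range(C)]
k = [0.5] * (B * T * C)
y = wkv_kernel_emulation([-0.3] * C, [0.1] * C, k, v, B, T, C)
print(max(abs(yi - vi) for yi, vi in zip(y, v)) < 1e-9)  # True
```

Because the recurrence at step t needs the state from step t - 1, time cannot be parallelized this way; the speedup comes from running all (batch, channel) recurrences at once inside a single kernel launch.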

Putting this here so it doesn't get lost.

I am trying to run Microsoft guidance (https://github.com/microsoft/guidance) on RWKV through transformers, and I am getting an error:

AttributeError: 'RwkvCausalLMOutput' object has no attribute 'past_key_values'

which can be reproduced here: https://gist.github.com/fullstackwebdev/a6523374e6687825fcb92ca74048c12b

fullstackwebdev, May 21 '23 22:05

@fullstackwebdev I don't think the fix should go inside transformers, as it would mean always outputting past_key_values=None, which is quite misleading: by design, RWKV does not rely on past_key_values for caching, since the tokens are processed one by one. I opened https://github.com/microsoft/guidance/pull/91, which fixed the issue in my local environment.

younesbelkada, May 22 '23 09:05
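The AttributeError comes from calling code that reads `outputs.past_key_values` unconditionally. One way a caller could cope with both kinds of models is to look the cache up defensively, sketched here with hypothetical stand-in dataclasses rather than the real transformers output classes. As far as I know, RWKV's output exposes its recurrent cache as a `state` field, which would be fed back through the model's `state` argument rather than `past_key_values`; this is not necessarily what the linked guidance PR does.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class KVCacheOutput:  # hypothetical stand-in for a Transformer decoder's output
    logits: Any
    past_key_values: Optional[Any] = None


@dataclass
class RecurrentStateOutput:  # hypothetical stand-in for RwkvCausalLMOutput
    logits: Any
    state: Optional[Any] = None


def get_incremental_cache(outputs):
    """Return (field_name, cache) for whatever incremental-decoding cache the output carries.

    Uses getattr with a default instead of assuming past_key_values exists,
    so an output class without that field does not raise AttributeError.
    """
    for name in ("past_key_values", "state"):
        cache = getattr(outputs, name, None)
        if cache is not None:
            return name, cache
    return None, None


print(get_incremental_cache(KVCacheOutput(logits=None, past_key_values=("k", "v"))))  # ('past_key_values', ('k', 'v'))
print(get_incremental_cache(RecurrentStateOutput(logits=None, state=[0.0])))  # ('state', [0.0])
```

Returning the field name alongside the cache lets the caller decide which keyword argument to pass on the next step.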