
SFT FlashAttention questions

Open vadimkantorov opened this issue 7 months ago • 7 comments

In the official log https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-sft-0.411.log (and in my repro):

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Gemma2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
  1. Is this expected? Does the trainer actually run / train the model in float32, or is it only loading the model in float32?

  2. Is it attempting to run the model on CPU? Where is this warning coming from?

Related:

  • https://github.com/volcengine/verl/issues/1498

vadimkantorov avatar May 13 '25 11:05 vadimkantorov

What about the warning:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').

I believe that for me it leads to an illegal memory access (guessing!):

RuntimeError: CUDA error: an illegal memory access was encountered

yazdanbakhsh avatar May 22 '25 18:05 yazdanbakhsh

Well, I just launched the verl-provided SFT basic example script...

vadimkantorov avatar May 22 '25 20:05 vadimkantorov

@vadimkantorov And then? Do you see this warning? Do you use offloading?

I still get the following warning: You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').

yazdanbakhsh avatar May 22 '25 21:05 yazdanbakhsh

Yes, I then get this warning. I'm worried that something is running suboptimally (e.g. in fp32).
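(One way to check, assuming you can add a quick debug print inside the trainer: inspect the storage dtype of the wrapped module's parameters. This is a generic debugging sketch, not verl code.)

    from torch import nn

    def report_param_dtypes(model: nn.Module) -> None:
        # Generic debugging helper (not part of verl): print each parameter's
        # storage dtype to see whether the weights are held in fp32 or bf16.
        for name, param in model.named_parameters():
            print(name, param.dtype)

Note that with FSDP the storage dtype printed here can still differ from the compute dtype if a mixed-precision policy is set.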

vadimkantorov avatar May 23 '25 19:05 vadimkantorov

From verl/workers/fsdp_workers.py:

        torch_dtype = fsdp_config.get("model_dtype", None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        ...

        actor_module = actor_module_class.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=actor_model_config,
            attn_implementation="flash_attention_2",
            trust_remote_code=trust_remote_code,
        )

If you want to use bf16, add these two options to the training script:

    ++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    ++critic.model.fsdp_config.model_dtype=bfloat16 \

@chenhaiq Does training with bf16 degrade the performance of the model? Is this effect acceptable?

SikaStar avatar May 25 '25 16:05 SikaStar

From verl/workers/fsdp_workers.py:

        torch_dtype = fsdp_config.get("model_dtype", None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        ...

        actor_module = actor_module_class.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=actor_model_config,
            attn_implementation="flash_attention_2",
            trust_remote_code=trust_remote_code,
        )

If you want to use bf16, add these two options to the training script:

    ++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    ++critic.model.fsdp_config.model_dtype=bfloat16 \

@chenhaiq Does training with bf16 degrade the performance of the model? Is this effect acceptable?

This option is only used when generating text with the inference engine. If the base model's tensor type is bf16, there is no side effect on training.
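(A quick way to see what that tensor type is for a given checkpoint, using the Gemma model from the log above as an example; the repo id is an assumption, not taken from this thread:)

    from transformers import AutoConfig

    # Check the dtype a base checkpoint was saved in (the "tensor type"
    # referred to above). "google/gemma-2-2b-it" is just an example id.
    config = AutoConfig.from_pretrained("google/gemma-2-2b-it")
    print(config.torch_dtype)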

chenhaiq avatar May 30 '25 06:05 chenhaiq

So the conclusion is?

Do we add these two lines, or simply ignore the warning?

tjoymeed avatar Jun 08 '25 23:06 tjoymeed

What about the warning:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').

I believe that for me it leads to an illegal memory access (guessing!):

RuntimeError: CUDA error: an illegal memory access was encountered

Same problem here.

SeiunSky0131 avatar Jun 15 '25 14:06 SeiunSky0131

What about the warning: You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda'). I believe that for me it leads to an illegal memory access (guessing!): RuntimeError: CUDA error: an illegal memory access was encountered

Same problem here.

Same problem here.

momofive avatar Jun 17 '25 10:06 momofive

I still get the following warning: You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').

Have you found a solution to this warning?

tobiashaab avatar Jun 29 '25 10:06 tobiashaab

From verl/workers/fsdp_workers.py:

        torch_dtype = fsdp_config.get("model_dtype", None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        ...

        actor_module = actor_module_class.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=actor_model_config,
            attn_implementation="flash_attention_2",
            trust_remote_code=trust_remote_code,
        )

If you want to use bf16, add these two options to the training script:

    ++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    ++critic.model.fsdp_config.model_dtype=bfloat16 \

Could not override 'actor_rollout_ref.actor.fsdp_config.model_dtype'.
To append to your config use +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16
Key 'model_dtype' is not in struct
    full_key: actor_rollout_ref.actor.fsdp_config.model_dtype
    object_type=dict

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
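(Going by the Hydra error text itself, appending the keys with a + prefix may work when model_dtype is not already defined in the config; this is a guess based on the error message, not a fix confirmed in this thread:)

    +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    +critic.model.fsdp_config.model_dtype=bfloat16 \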

sz128 avatar Jul 04 '25 03:07 sz128

Same problem here! Any updates?

yizhouzhao avatar Aug 07 '25 22:08 yizhouzhao

This warning can be ignored. The model should be loaded in fp32 so that the FSDP optimizer runs in fp32.

It is not recommended to use actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16
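(For context, the usual FSDP mixed-precision pattern looks roughly like the sketch below: parameters are kept in fp32 for the optimizer, while compute runs in bf16 via a MixedPrecision policy. This is a generic illustration of the pattern with illustrative dtype choices, not verl's exact configuration.)

    import torch
    from torch.distributed.fsdp import MixedPrecision

    # Generic illustration (not verl's exact settings): the model is loaded in
    # fp32 so the optimizer state stays in full precision, while forward and
    # backward compute is cast to bf16 by the mixed-precision policy passed to
    # FSDP(model, mixed_precision=mp_policy).
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,   # compute dtype for forward/backward
        reduce_dtype=torch.float32,   # gradient-reduction dtype (illustrative choice)
        buffer_dtype=torch.bfloat16,
    )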

chenhaiq avatar Aug 08 '25 09:08 chenhaiq