
Support FSDP worker and vLLM Ascend

Open sunyi0505 opened this issue 9 months ago • 6 comments

This PR is committed to support the Ascend NPU backend. Co-authored-by: Chendong98 [email protected] Co-authored-by: zheliuyu [email protected] Co-authored-by: celestialli [email protected]. In this PR, we add the capability to detect the NPU device type, and we also add a new script for training on NPU.

Here is the list of changes:

  1. pyproject.toml change the version of vllm
  2. requirements-npu.txt requirements for NPU
  3. verl/bert_padding.py Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
  4. verl/single_controller/ray/base.py
  5. verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py
  6. verl/trainer/fsdp_sft_trainer.py
  7. verl/utils/flops_counter.py
  8. verl/utils/fsdp_utils.py
  9. verl/workers/actor/dp_actor.py
  10. verl/workers/critic/dp_critic.py
  11. verl/workers/fsdp_workers.py
  12. verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
  13. verl/workers/sharding_manager/fsdp_vllm.py
  14. verl/utils/device.py get the device type for different devices (see the sketch after this list)
  15. docs/ascend/ascend.md
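
Regarding item 14, below is a minimal sketch of how such a device-type helper could look. It is illustrative only; the function names are assumptions, not the actual contents of verl/utils/device.py in this PR.

import torch

def is_npu_available() -> bool:
    # torch_npu registers the NPU backend with torch when it can be imported
    try:
        import torch_npu  # noqa: F401
        return torch.npu.is_available()
    except ImportError:
        return False

def get_device_name() -> str:
    # Prefer NPU when present, then CUDA, otherwise fall back to CPU
    if is_npu_available():
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"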

Here is our roadmap:

Roadmap

  • [x] SFT
  • [x] PPO
  • [x] GRPO

News

[2025.03.03] Modify the adaptation method of Ray

[2025.02.25] The PPO algorithm is supported for training on NPU with the FSDP backend.

[2025.02.23] The SFT algorithm is supported for training on NPU with the FSDP backend.

[2025.02.21] The GRPO algorithm is supported for training on NPU with the FSDP backend.

Requirements: We tested this PR on both Ascend NPU and GPU to ensure the same code can run on different devices. The hardware is 8 Atlas 800T A2 NPUs and 8 A100 GPUs. Other software information is shown in the following table.

Software Version
transformers 4.47.1
accelerate 1.3.0
torch_npu 2.5.1.rc1
CANN 8.1.RC1 (Not Released)

About mean error: Due to differences in hardware architecture, we cannot guarantee that the loss on Ascend NPU is exactly the same as that on GPU. In our experience, a loss difference of less than 2% is acceptable; if the difference is greater than 2%, we will try to fix it. The calculation formula is as follows.

[loss_comparison formula image]
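
A plausible form of this comparison, assuming the metric is the relative error of the mean loss over the N training steps (an assumed reconstruction, since the formula itself is only given in the loss_comparison image above):

\[
\text{mean error} \;=\; \frac{\left|\frac{1}{N}\sum_{i=1}^{N}\mathrm{loss}^{\mathrm{NPU}}_{i} \;-\; \frac{1}{N}\sum_{i=1}^{N}\mathrm{loss}^{\mathrm{GPU}}_{i}\right|}{\frac{1}{N}\sum_{i=1}^{N}\mathrm{loss}^{\mathrm{GPU}}_{i}} \times 100\%
\]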

N represents the number of training steps. For more information, please refer to the Calculation accuracy description.

sunyi0505 avatar Feb 21 '25 03:02 sunyi0505

Does this PR work on multiple nodes?

huangk10 avatar Feb 21 '25 06:02 huangk10

Does this PR work on multiple nodes?

I am currently testing on a single node only, and will follow up with multi-node testing results later.

sunyi0505 avatar Feb 21 '25 07:02 sunyi0505

CLA assistant check
All committers have signed the CLA.

CLAassistant avatar Feb 26 '25 00:02 CLAassistant

@as12138 Hello. Thanks for your efforts! Can this PR be run directly on the 910B2C 64GB now?

takagi97 avatar Mar 05 '25 07:03 takagi97

@as12138 Hello. Thanks for your efforts! Can this PR be run directly on the 910B2C 64GB now?

I tested it on the Atlas 800T A2 64G (Ascend + ARM) and it passed. If you're interested, you can verify it, and if you encounter any issues, please feel free to reach out to me.

sunyi0505 avatar Mar 05 '25 07:03 sunyi0505

@as12138 Hello. Thanks for your efforts! Can this PR be run directly on the 910B2C 64GB now?

I tested it on the Atlas 800T A2 64G (Ascend + ARM) and it passed. If you're interested, you can verify it, and if you encounter any issues, please feel free to reach out to me.

Thank you for your quick response! I will try it.

takagi97 avatar Mar 05 '25 09:03 takagi97

@eric-haibin-lin Can you review the PR?

sunyi0505 avatar Mar 13 '25 08:03 sunyi0505

Is CANN 8.1.RC1 (Not Released) mandatory? Have you tested PPO and GRPO on 8.0.RC3?

jianzhnie avatar Mar 25 '25 07:03 jianzhnie

Is CANN 8.1.RC1 (Not Released) mandatory? Have you tested PPO and GRPO on 8.0.RC3?

It is mandatory; I have not tested it on 8.0.RC3.

sunyi0505 avatar Mar 26 '25 01:03 sunyi0505

use_remove_padding is not supported on Ascend NPU at the moment.

sunyi0505 avatar Mar 27 '25 09:03 sunyi0505

Is CANN 8.1.RC1 released now?

It has not been released yet.

sunyi0505 avatar Mar 31 '25 02:03 sunyi0505

We have now tested the SFT and GRPO algorithms on Ascend NPU.

Due to differences in hardware architecture, we cannot guarantee that the loss on Ascend NPU is exactly the same as that on GPU. In our experience, a loss difference of less than 2% is acceptable; if it is greater than 2%, we will try to fix it. Likewise, a critic/rewards/mean difference of less than 4% is acceptable; if it is greater than 4%, we will try to fix it. The calculation formula is as follows.

[loss_comparison formula image]

N represents the number of training steps.

Software Version
transformers 4.49.0
torch_npu 2.5.1.rc1
CANN 8.1.RC1 (Not Released)

Here are the training scripts and loss comparison graphs.

For SFT:

# Tested with 1 & 8 NPUs

set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_qwen_05_peft.sh <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
     -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    data.train_batch_size=512 \
    optim.lr=1e-4 \
    +data.prompt_dict_keys=['question'] \
    +data.response_dict_keys=['answer'] \
    data.micro_batch_size_per_gpu=4 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.default_local_dir=$save_path \
    trainer.project_name=gsm8k-sft \
    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \
    trainer.logger=['console'] \
    trainer.total_epochs=1 \
    trainer.default_hdfs_dir=null $@ \
    model.lora_rank=32 \
    model.lora_alpha=16 \
    model.target_modules=all-linear

    # Or you can do this:
    # model.target_modules=[q_proj,v_proj] \

[SFT loss comparison graph]

For GRPO:

Changed parameters:

  • data.train_batch_size 1024 -> 16
  • actor_rollout_ref.actor.optim.lr 1e-6 -> 5e-7
  • critic.optim.lr 1e-5 -> 9e-6
  • actor_rollout_ref.actor.ppo_max_token_len_per_gpu 16384 -> 2048
  • actor_rollout_ref.model.use_remove_padding True -> False
  • actor_rollout_ref.actor.ppo_mini_batch_size 256 -> 64
  • actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu 80 -> 8
  • actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu 160 -> 80
  • actor_rollout_ref.rollout.tensor_model_parallel_size 2 -> 4
  • actor_rollout_ref.rollout.gpu_memory_utilization 0.6 -> 0.2
  • actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu 160 -> 80
  • actor_rollout_ref.rollout.enable_chunked_prefill True -> False
  • trainer.nnodes 1 -> 2

# Tested with 2 & 8 NPUs
set -x

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=16 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=5e-7 \
    critic.optim.lr=9e-6 \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=2048 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=80 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=2 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

critic/rewards/mean comparison: [GRPO critic/rewards/mean graph]

sunyi0505 avatar Mar 31 '25 11:03 sunyi0505

Get error while saving checkpoint

Traceback (most recent call last):
  File "/third_party/verl/verl/trainer/main_ppo.py", line 55, in main
    run_ppo(config)
  File "/third_party/verl/verl/trainer/main_ppo.py", line 72, in run_ppo
    ray.get(main_task.remote(config))
  File "/usr/local/python3.10/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/python3.10/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/python3.10/lib/python3.10/site-packages/ray/_private/worker.py", line 2782, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/python3.10/lib/python3.10/site-packages/ray/_private/worker.py", line 929, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::main_task() (pid=201983, ip=0.0.0.0)
  File "/third_party/verl/verl/trainer/main_ppo.py", line 172, in main_task
    trainer.fit()
  File "/third_party/verl/verl/trainer/ppo/ray_trainer.py", line 926, in fit
    self._save_checkpoint()
  File "/third_party/verl/verl/trainer/ppo/ray_trainer.py", line 669, in _save_checkpoint
    self.actor_rollout_wg.save_checkpoint(actor_local_path,
  File "/third_party/verl/verl/single_controller/ray/base.py", line 42, in func
    output = ray.get(output)
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.actor_rollout_save_checkpoint() (pid=202532, ip=0.0.0.0, actor_id=a42bf3481fb6488100622daa07000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0xffd0073ae1a0>)
  File "/third_party/verl/verl/single_controller/ray/base.py", line 429, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/third_party/verl/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/third_party/verl/verl/workers/fsdp_workers.py", line 604, in save_checkpoint
    self.checkpoint_manager.save_checkpoint(local_path=local_path,
  File "/third_party/verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py", line 123, in save_checkpoint
    model_state_dict = self.model.state_dict()
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2219, in state_dict
    module.state_dict(
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2219, in state_dict
    module.state_dict(
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2219, in state_dict
    module.state_dict(
  [Previous line repeated 1 more time]
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2225, in state_dict
    hook_result = hook(self, destination, prefix, local_metadata)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 724, in _post_state_dict_hook
    processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 569, in _sharded_post_state_dict_hook
    return _common_unshard_post_state_dict_hook(
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 238, in _common_unshard_post_state_dict_hook
    param_hook(state_dict, prefix, fqn)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 566, in param_hook
    sharded_tensor = sharded_tensor.cpu()
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 166, in dispatch
    op_info = self.unwrap_to_op_info(op_call, args, kwargs)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 371, in unwrap_to_op_info
    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
  File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 470, in _try_replicate_spec_for_scalar_tensor
    raise RuntimeError(
RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

TaoTQ avatar Apr 10 '25 03:04 TaoTQ

Get error while saving checkpoint

Traceback (most recent call last):
  [...]
RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

I think this issue is caused by torch.compile. To avoid this problem, you can add actor_rollout_ref.actor.use_torch_compile=False in the script. Thank you very much for your feedback.

sunyi0505 avatar Apr 10 '25 03:04 sunyi0505

Get error while saving checkpoint

Traceback (most recent call last):
  [...]
RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

I think this issue is caused by torch.compile. To avoid this problem, you can add actor_rollout_ref.actor.use_torch_compile=False in the script. Thank you very much for your feedback.

This does not seem to be the root cause. I just tried. Issue still there

TaoTQ avatar Apr 10 '25 06:04 TaoTQ

Get error while saving checkpoint

Traceback (most recent call last):
  [...]
RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

I think this issue is caused by torch.compile. To avoid this problem, you can add actor_rollout_ref.actor.use_torch_compile=False in the script. Thank you very much for your feedback.

This does not seem to be the root cause. I just tried. Issue still there

This is an adaptation issue on NPU. You can change offload_to_cpu to False on lines 92 and 93 of verl/utils/checkpoint/fsdp_checkpoint_manager.py to avoid this problem.

sunyi0505 avatar Apr 10 '25 09:04 sunyi0505

Get error while saving checkpoint

Traceback (most recent call last):
  [...]
RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

I think this issue is caused by torch.compile. To avoid this problem, you can add actor_rollout_ref.actor.use_torch_compile=False in the script. Thank you very much for your feedback.

This does not seem to be the root cause. I just tried. Issue still there

This is an adaptation issue on NPU. You can change offload_to_cpu to False on lines 92 and 93 of verl/utils/checkpoint/fsdp_checkpoint_manager.py to avoid this problem.

You may have the line numbers slightly off, but I get what you mean. Here is the change I made:

    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
        # record the previous global step
        self.previous_global_step = global_step

        # remove previous local_path
        # TODO: shall we remove previous ckpt every save?
        if remove_previous_ckpt:
            self.remove_previous_save_local_path()
        local_path = self.local_mkdir(local_path)
        torch.distributed.barrier()

        # every rank will save its own model and optim shard
        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=False)  # Change it to False
        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=False)  # Change it to False
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
                model_state_dict = self.model.state_dict()
                if self.optimizer is not None:
                    optimizer_state_dict = self.optimizer.state_dict()
                else:
                    optimizer_state_dict = None
                if self.lr_scheduler is not None:
                    lr_scheduler_state_dict = self.lr_scheduler.state_dict()
                else:
                    lr_scheduler_state_dict = None

                extra_state_dict = {
                    'lr_scheduler': lr_scheduler_state_dict,
                    'rng': self.get_rng_state(),
                }
                model_path = os.path.join(local_path, f'model_world_size_{self.world_size}_rank_{self.rank}.pt')
                optim_path = os.path.join(local_path, f'optim_world_size_{self.world_size}_rank_{self.rank}.pt')
                extra_path = os.path.join(local_path, f'extra_state_world_size_{self.world_size}_rank_{self.rank}.pt')

                print(f'[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}')
                print(f'[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}')
                print(f'[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}')
                torch.save(model_state_dict, model_path)
                torch.save(optimizer_state_dict, optim_path)  # TODO: address optimizer is None
                torch.save(extra_state_dict, extra_path)

        # wait for everyone to dump to local
        torch.distributed.barrier()

It finally works!

TaoTQ avatar Apr 10 '25 10:04 TaoTQ

Hi, I tried this branch with two configurations, 1*8 NPUs and 2*8 NPUs, running Qwen2-7B GRPO, and got an error; vLLM itself works fine. With 2*8 NPUs, it looks like all 8 processes end up on a single card.

torch 2.5.1
torch-npu 2.5.1.dev20250320
verl 0.2.0.dev0
vllm 0.7.1+empty
vllm_ascend 0.7.1rc2.dev0+gf17417f.d20250421
CANN 8.0.0

(TaskRunner pid=467613) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.ref_init_model() (pid=468350, ip=, actor_id=4e3d02723d52e239fe80fe0102000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0xfffc413aebf0>)
(TaskRunner pid=467613)   File "/tmp/ray/session_2025-04-24_16-12-17_465991_446436/runtime_resources/working_dir_files/_ray_pkg_b36dcd4c14bf8643/verl/single_controller/ray/base.py", line 429, in func
(TaskRunner pid=467613)     return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=467613)   File "/tmp/ray/session_2025-04-24_16-12-17_465991_446436/runtime_resources/working_dir_files/_ray_pkg_b36dcd4c14bf8643/verl/single_controller/base/decorator.py", line 404, in inner
(TaskRunner pid=467613)     return func(*args, **kwargs)
(TaskRunner pid=467613)   File "/tmp/ray/session_2025-04-24_16-12-17_465991_446436/runtime_resources/working_dir_files/_ray_pkg_b36dcd4c14bf8643/verl/workers/fsdp_workers.py", line 422, in init_model
(TaskRunner pid=467613)     self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
(TaskRunner pid=467613)   File "/tmp/ray/session_2025-04-24_16-12-17_465991_446436/runtime_resources/working_dir_files/_ray_pkg_b36dcd4c14bf8643/verl/workers/fsdp_workers.py", line 230, in _build_model_optimizer
(TaskRunner pid=467613)     torch.distributed.barrier()
(TaskRunner pid=467613)   File "/home/ma-user/anaconda3/envs/verl/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
(TaskRunner pid=467613)     return func(*args, **kwargs)
(TaskRunner pid=467613)   File "/home/ma-user/anaconda3/envs/verl/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4159, in barrier
(TaskRunner pid=467613)     work = group.barrier(opts=opts)
(TaskRunner pid=467613) RuntimeError: create_config:build/CMakeFiles/torch_npu.dir/compiler_depend.ts:102 HCCL function error: hcclCommInitRootInfoConfig(numRanks, &rootInfo, rank, config, &(comm->hcclComm)), error code is 7
(TaskRunner pid=467613) [ERROR] 2025-04-24-16:14:57 (PID:468350, Device:0, RankID:4) ERR02200 DIST call hccl api failed.
(TaskRunner pid=467613) EJ0001: [PID: 468350] 2025-04-24-16:14:57.113.722 Failed to initialize the HCCP process. Reason: Maybe the last training process is running.
(TaskRunner pid=467613) Solution: Wait for 10s after killing the last training process and try again.

WenderMa avatar Apr 24 '25 08:04 WenderMa

Hi, I tried this branch with two configurations, 1*8 NPUs and 2*8 NPUs, running Qwen2-7B GRPO, and got an error; vLLM itself works fine. With 2*8 NPUs, it looks like all 8 processes end up on a single card. (Environment details and HCCL error log quoted above.)

We have only tested on CANN 8.1.RC1. Please upgrade to that version and run again; if you still run into problems, feel free to give us feedback.

sunyi0505 avatar Apr 25 '25 03:04 sunyi0505

transformers v4.51.4 starts to support enabling flash_attention_2 directly on Ascend NPU. It seems the transformers section of the README needs to be adjusted.
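
For illustration, a minimal sketch of what that could look like. The model name and dtype here are placeholders, and this is an assumption about usage rather than code from this PR.

import torch
import torch_npu  # noqa: F401  # registers the "npu" device with torch
from transformers import AutoModelForCausalLM

# With transformers >= 4.51.4, flash_attention_2 can be requested directly on Ascend NPU
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("npu")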

zheliuyu avatar May 16 '25 09:05 zheliuyu

transformers v4.51.4 starts to support enabling flash_attention_2 directly on Ascend NPU. It seems the transformers section of the README needs to be adjusted.

Thank you for your suggestion. I will make the necessary changes in the future.

sunyi0505 avatar May 16 '25 09:05 sunyi0505