
Error when running Ray Tune to launch hyperparameter sweep

Open · Jing-L97 opened this issue 1 year ago · 1 comment

🐛 Describe the bug

Hi, we ran into a DistributedDataParallel error when running the example code with Ray Tune optimization, with the distributed type set to NO:

ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.py

Here is the traceback we encountered:

Traceback (most recent call last):
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/_private/worker.py", line 2524, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::_Inner.train() (pid=1885930, ip=10.20.0.6, actor_id=6d08bc117a6b35cc7647003f01000000, repr=AccelerateTrainer)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 375, in train
    raise skipped from exception_cause(skipped)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/_internal/utils.py", line 54, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(AttributeError): ray::_RayTrainWorker__execute.get_next() (pid=1886047, ip=10.20.0.6, actor_id=dd5dcbaf834905aa00b49be601000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7f3d4c6d20a0>)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/_internal/worker_group.py", line 32, in __execute
    raise skipped from exception_cause(skipped)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/huggingface/accelerate/accelerate_trainer.py", line 411, in _accelerate_train_loop_per_worker
    return train_loop_per_worker(*args, **kwargs)
  File "/scratch2/jliu/CF_RL/scripts/trlx/examples/ppo_sentiments.py", line 47, in main
    trlx.train(
  File "/scratch2/jliu/CF_RL/scripts/trlx/trlx/trlx.py", line 92, in train
    trainer = get_trainer(config.train.trainer)(
  File "/scratch2/jliu/CF_RL/scripts/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 74, in __init__
    if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'peft_type'

The same error occurred when we changed the accelerate config file to the YAML settings below:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Thank you very much!

Which trlX version are you using?

https://github.com/CarperAI/trlx/tree/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92

Additional system and package information

transformers==4.32.0, python==3.9

Jing-L97 · Jul 26 '24 11:07

There seems to be an issue with the if statements on lines 74, 398, and 424 of trlx/trainer/accelerate_ppo_trainer.py.

The check for self.model.peft_type should be made with hasattr like this:

if ... and hasattr(self.model, "peft_type")

arxaqapi · Jul 26 '24 13:07
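The failure happens because, once the model is wrapped by DistributedDataParallel, attributes of the inner model such as peft_type are only reachable through the wrapper's .module attribute. For the same reason, hasattr(self.model, "frozen_head") silently returns False on a wrapped model even if the inner model has a frozen head. The sketch below is a torch-free illustration of a defensive version of the line-74 check. It assumes the wrapper exposes the real model as .module, as DistributedDataParallel does. PolicyModel, FakeDDP, unwrap, and needs_reference_model are hypothetical names made up for this example; inside trlX, Accelerate's Accelerator.unwrap_model could play the role of unwrap. The sketch uses getattr with a default instead of a bare hasattr, so a model without peft_type is treated the same as one whose peft_type is falsy. This keeps the meaning of the original "not self.model.peft_type" test.

```python
from typing import Any


class PolicyModel:
    """Hypothetical stand-in for the trlX policy model."""

    def __init__(self, peft_type=None, frozen_head=None):
        self.peft_type = peft_type
        # Only hydra-style models carry a frozen head attribute.
        if frozen_head is not None:
            self.frozen_head = frozen_head


class FakeDDP:
    """Hypothetical stand-in for DistributedDataParallel: it keeps the real
    model under `.module` and does not forward other attribute lookups."""

    def __init__(self, module):
        self.module = module


def unwrap(model: Any) -> Any:
    """Follow `.module` links until reaching the innermost model."""
    while hasattr(model, "module"):
        model = model.module
    return model


def needs_reference_model(model: Any) -> bool:
    """Defensive version of the line-74 condition: build a separate reference
    model only when there is no frozen head and PEFT is not in use."""
    inner = unwrap(model)
    return not hasattr(inner, "frozen_head") and not getattr(inner, "peft_type", None)


if __name__ == "__main__":
    wrapped = FakeDDP(PolicyModel(peft_type="LORA"))
    try:
        wrapped.peft_type  # mirrors the direct access that fails in the traceback
    except AttributeError as err:
        print("direct access fails:", err)
    print(needs_reference_model(wrapped))  # False: PEFT is in use
```

Unwrapping first means the check behaves the same whether or not the trainer runs under DDP, which also fixes the silent False from the frozen_head test.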