PyTorch FSDP with Accelerate
Hi all,
I was wondering if you could give any input on whether the standard PyTorch FSDP wrapper is compatible with Hugging Face Accelerate's accelerator.prepare()?
For example:
import torch
from accelerate import Accelerator
from torch.distributed.fsdp import (
    FullyShardedDataParallel,
)

# Request bf16 mixed precision via the `mixed_precision` argument.
accelerator = Accelerator(mixed_precision="bf16")

# MyModel, MyOpt, and the dataloaders below are placeholders.
model = MyModel().to(accelerator.device)
fsdp_model = FullyShardedDataParallel(model)
fsdp_model = accelerator.prepare(fsdp_model)

optimizer = MyOpt(fsdp_model.parameters())
optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    optimizer, train_dataloader, val_dataloader
)
Thank you,
Enrico
Yes, if the model is already wrapped in FullyShardedDataParallel, accelerator.prepare will just return the same model.
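To illustrate that behavior, here is a minimal sketch (my own, not from the thread): the toy torch.nn.Linear model stands in for a real one, and it assumes a process group has already been set up, e.g. by running the script under accelerate launch.

import torch
from accelerate import Accelerator
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

accelerator = Accelerator(mixed_precision="bf16")

# Toy stand-in for the real model; FSDP needs an initialized process group.
model = torch.nn.Linear(16, 16).to(accelerator.device)
fsdp_model = FSDP(model)

prepared_model = accelerator.prepare(fsdp_model)
# Per the reply above, no re-wrapping is expected for an already-wrapped model.
print(prepared_model is fsdp_model)  # expected: True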
Thank you for the immediate response.
I currently have the accelerate config as:
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: ''
main_process_port: ''
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Would this configuration require any alterations to work with PyTorch FullyShardedDataParallel?
Does mixed precision need to be set to bf16 in the FSDP wrapper, or does setting it in Accelerate already handle this case? Does anything need to be added to fsdp_config if it is already set in the code, as shown below?
import functools
from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# ParallelTransformerBlock is the model's transformer block class.
check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)

# Non-reentrant activation checkpointing wrapper.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Wrap each transformer block in its own FSDP unit.
palm_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        ParallelTransformerBlock,
    },
)

bf16_fsdp = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

model = FullyShardedDataParallel(
    model,
    auto_wrap_policy=palm_auto_wrap_policy,
    mixed_precision=bf16_fsdp,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    forward_prefetch=True,
    use_orig_params=True,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

model = accelerator.prepare(model)
Thank you,
Enrico
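For comparison, roughly the same settings can also be handed to Accelerate through its FSDP plugin instead of wrapping the model by hand. The sketch below is illustrative only, not code from this thread: the plugin field names follow recent accelerate releases and may differ across versions, and ParallelTransformerBlock refers to the block class from the code above.

import functools

import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    mixed_precision_policy=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={ParallelTransformerBlock},
    ),
    use_orig_params=True,
)

accelerator = Accelerator(mixed_precision="bf16", fsdp_plugin=fsdp_plugin)
# With the plugin, accelerator.prepare performs the FSDP wrapping itself.
model = accelerator.prepare(model)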
Additionally, I am having an issue when trying to resume from a model checkpoint with use_orig_params=True:
if CFG.RESUME_FROM_CHECKPOINT is not None and CFG.RESUME_FROM_CHECKPOINT != "":
    accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
I am receiving this error:
    main()
  File "/train_distributed.py", line 545, in main
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
  File "/lib/python3.9/site-packages/accelerate/accelerator.py", line 2439, in load_state
    self.state.fsdp_plugin.load_optimizer(self, opt, self._models[i], input_dir, i)
  File "/lib/python3.9/site-packages/accelerate/utils/dataclasses.py", line 951, in load_optimizer
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1581, in scatter_full_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1236, in _optim_state_dict_to_load_impl
    flat_osd = _flatten_optim_state_dict(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 359, in _flatten_optim_state_dict
    assert (
AssertionError: If use_orig_params is True, shard_state must be True.
I am unsure where shard_state should be set to True, as I do not see it in the FSDP documentation.
Thank you,
Enrico
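For anyone hitting the same assertion: one possible direction, offered only as an assumption and not as a confirmed fix from this thread, is to save and load the optimizer state with the newer FSDP optimizer-state APIs (PyTorch 2.0+) instead of the scatter_full_optim_state_dict path shown in the traceback. A rough sketch follows; keyword arguments are used because the positional order of optim_state_dict_to_load has changed between PyTorch releases, and the checkpoint path is hypothetical.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# `model` and `optimizer` are the FSDP-wrapped model and its optimizer.

# Saving: gather the optimizer state according to the model's configured
# state-dict type (call on every rank), then write it once.
optim_state = FSDP.optim_state_dict(model, optimizer)
if dist.get_rank() == 0:
    torch.save(optim_state, "optimizer.bin")  # hypothetical path

# Loading: convert the saved state into a form loadable by this rank's optimizer.
optim_state = torch.load("optimizer.bin", map_location="cpu")
loadable = FSDP.optim_state_dict_to_load(
    model=model, optim=optimizer, optim_state_dict=optim_state
)
optimizer.load_state_dict(loadable)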
This has not been resolved and the PR addressing it was automatically closed.
Hello @conceptofmind, accelerate returns the model as is if it is already an instance of FullyShardedDataParallel. Refer to this line: https://github.com/huggingface/accelerate/blob/50eabe5b1d3cde24a9f3b65b6e7e25075b269da4/src/accelerate/accelerator.py#L1302
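In other words (an illustrative paraphrase of the linked check, not the actual accelerate source):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _maybe_wrap_in_fsdp(model, **fsdp_kwargs):
    # Hypothetical helper: only wrap if the model is not already FSDP-wrapped.
    if not isinstance(model, FSDP):
        model = FSDP(model, **fsdp_kwargs)
    return model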
Regarding use_orig_params, I replied on the relevant issues.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.