Accelerate mixed torch.Tensor and DTensor error when using TE FP8 and FSDP/TP
System Info
- `Accelerate` version: 1.11.0
- Platform: Linux-5.15.0-1071-nvidia-x86_64-with-glibc2.39
- `accelerate` bash location: /app/.venv/bin/accelerate
- Python version: 3.12.12
- Numpy version: 2.3.4
- PyTorch version: 2.8.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 2015.56 GB
- GPU type: NVIDIA H200
- `Accelerate` default config:
Not found
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [x] My own task or dataset (give details below)
Reproduction
When using TE FP8 together with FSDP/TP on a Llama-style model, I get the following error during `accelerator.prepare()`. My code essentially follows the guide here: https://huggingface.co/blog/accelerate-nd-parallel.
FP8 works fine with FSDP and the other types of parallelism; it only falls over when I set TP=2. I need TP=2 because, with the size of the model I am using and the number of GPUs available, it isn't possible to train with FSDP alone.
fsdp_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
state_dict_type="SHARDED_STATE_DICT",
reshard_after_forward=cfg.reshard_after_forward,
)
parallelism_config = ParallelismConfig(
tp_size=2,
dp_replicate_size=1,
dp_shard_size=8,
)
handlers = [
TERecipeKwargs(
fp8_format="HYBRID
amax_history_len=32,
amax_compute_algo="max",
override_linear_precision=(False, False, False),
)
]
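# (Assumption, not shown in the original snippets: the `accelerator` referenced
# below is presumably built from the pieces above, roughly like this.)
from accelerate import Accelerator

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=handlers,
    fsdp_plugin=fsdp_plugin,
    parallelism_config=parallelism_config,
)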
model_kwargs = (
{"tp_size": cfg.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch.device_mesh}
if cfg.tp_size > 1
else {}
)
logger.info(f"Model kwargs: {model_kwargs}")
logger.info(f"Device mesh: {device_mesh}")
model = AutoModelForCausalLM.from_pretrained(
cfg.model.base_model_path,
attn_implementation="flash_attention_3",
dtype=torch.bfloat16,
**model_kwargs,
)
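The failure is triggered once everything is handed to `accelerator.prepare()`. The optimizer, dataloaders, and scheduler are not shown in the snippets above, so the sketch below uses placeholder names; only the `prepare()` call itself is taken from the traceback:
# Placeholder objects -- the original report does not show how these are built.
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

# This call raises the RuntimeError below while converting nn.Linear layers
# to transformer_engine modules (accelerator._prepare_te -> convert_model).
model, optimizer, train_loader, eval_loader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_loader, eval_loader, lr_scheduler
)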
Traceback (most recent call last):
File "/app/src/train.py", line 558, in main
_train_with_accelerate(train_cfg)
File "/app/src/train.py", line 300, in _train_with_accelerate
model, optimizer, train_loader, eval_loader, lr_scheduler = accelerator.prepare(
^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 1547, in prepare
args = self._prepare_te(*args)
^^^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 2055, in _prepare_te
convert_model(model)
File "/app/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py", line 87, in convert_model
convert_model(
File "/app/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py", line 87, in convert_model
convert_model(
File "/app/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py", line 87, in convert_model
convert_model(
[Previous line repeated 1 more time]
File "/app/.venv/lib/python3.12/site-packages/accelerate/utils/transformer_engine.py", line 55, in convert_model
te_module.weight.copy_(module.weight)
File "/app/.venv/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 151, in dispatch
op_info = self.unwrap_to_op_info(op_call, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 350, in unwrap_to_op_info
self._try_replicate_spec_for_scalar_tensor(
File "/app/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 452, in _try_replicate_spec_for_scalar_tensor
raise RuntimeError(
RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
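The root cause appears to be that `convert_model` copies the original (now TP-sharded, i.e. DTensor) weight into the weight of a freshly created, non-distributed `te.Linear`, and PyTorch refuses to mix a plain `torch.Tensor` with a `DTensor` in `copy_`. That restriction can be reproduced in isolation; a minimal single-process sketch (CPU, gloo backend; not from the original report):
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
dtensor_weight = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])  # stands in for module.weight under TP
plain_weight = torch.empty(4, 4)  # stands in for the freshly created te_module.weight

# RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, ...
plain_weight.copy_(dtensor_weight)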
Expected behavior
FP8 with TE + FSDP/TP to work successfully.
For FSDP2 and FP8, I recommend using the torchao integration. You can check the example file fsdp2_fp8.py. Try adding TP to see if this works.
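In case it helps, a rough sketch of what that torchao-based FP8 setup looks like with accelerate (modelled on the fsdp2_fp8.py example; the exact `AORecipeKwargs` arguments, and whether TP composes with this path, are assumptions to verify):
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

# With mixed_precision="fp8" and AORecipeKwargs, accelerate uses torchao's
# float8 training path instead of TransformerEngine during prepare().
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs()],
    fsdp_plugin=fsdp_plugin,                # same FSDP2 plugin as above
    parallelism_config=parallelism_config,  # keep tp_size=2 here to test TP
)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)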
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.