fix mixed precision for `replicate` / pure DDP
Hi. I noticed the following:
- the keyword `autocast` does not exist in the repository
- `MixedPrecisionConfig` is only used in the `fully_shard` code path
- a dummy 1000-step run takes a lot longer with DDP than with FSDP
All of the above indicates that, when `dp_shard_enabled` is false, training runs in pure fp32, regardless of the mixed precision config.
This pull request changes the code to wrap the training forward pass in `torch.autocast`, using the dtype of `mixed_precision_param`, but only when `dp_shard_enabled` is false.
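For illustration, a minimal sketch of the idea (the helper name `maybe_enable_amp` and the surrounding variable names are assumptions for this example, not the exact code in the PR):

```python
import contextlib

import torch


def maybe_enable_amp(dp_shard_enabled: bool, param_dtype: torch.dtype):
    """Return an autocast context for pure-DDP runs, otherwise a no-op context."""
    if not dp_shard_enabled:
        # replicate() / pure DDP does not cast parameters, so autocast the forward
        return torch.autocast("cuda", dtype=param_dtype)
    # fully_shard already handles mixed precision via its MixedPrecisionConfig
    return contextlib.nullcontext()


# usage in the training loop, wrapping only the forward pass and loss:
# with maybe_enable_amp(parallel_dims.dp_shard_enabled, torch.bfloat16):
#     pred = model(inputs)
#     loss = loss_fn(pred, labels)
```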
If `mixed_precision_reduce` ever accepts values other than float32, it may be worth registering an appropriate gradient-compression hook after `replicate` as well, e.g.
```diff
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703..bdbcbb2 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -344,5 +344,7 @@ def apply_ddp(
     torch._dynamo.config.optimize_ddp = "ddp_optimizer"
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import bf16_compress_hook
+    model.register_comm_hook(dp_mesh.get_group(), bf16_compress_hook)
     logger.info("Applied DDP to the model")
```
might be nice to fold this into the `train_context`
@awgu it can't be folded into the `train_context` unless you're certain that autocasting the backward will cause no problems. According to the torch documentation:
> autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.
Personally, in other projects, I have run into issues where `torch.compile`-ing a network with autocast applied over the backward inexplicably produces NaNs. I don't know whether that applies here, but I'm happy to be overruled if people are certain it will work without problems.
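To make the concern concrete, here is a hypothetical train step that follows the documented guidance (forward and loss under autocast, backward outside); the names are illustrative, not the torchtitan training loop:

```python
import torch


def train_step(model, inputs, labels, loss_fn, optimizer, dtype=torch.bfloat16):
    # forward pass and loss computation run under autocast
    with torch.autocast("cuda", dtype=dtype):
        pred = model(inputs)
        loss = loss_fn(pred, labels)
    # backward runs outside autocast; its ops reuse the dtypes autocast
    # chose for the corresponding forward ops
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
```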
cc: @tianyu-l for thoughts on how to handle this, perhaps with separate forward and backward contexts.
Hmm I didn't have much context on this.
@fegin How is DDP supposed to handle mixed precision? Is AMP autocasting suggested?
I think Rohan added a tentative mixed precision API for DDP, but it never became a public feature. I think using AMP is probably the way to go.
autocast is the right way for DDP.
fixed in https://github.com/pytorch/torchtitan/pull/1303