
fix mixed precision for `replicate` / pure DDP

Open 152334H opened this issue 1 year ago • 7 comments

Hi. I noticed the following:

  1. the keyword `autocast` does not appear anywhere in the repository
  2. `MixedPrecisionConfig` is only used in the `fully_shard` codepath
  3. a dummy 1000-step run takes considerably longer with DDP than with FSDP

All of the above indicates that, when `dp_shard_enabled` is false, training runs in pure fp32 regardless of the mixed precision config.

This pull request changes the code to wrap the training forward pass in `torch.autocast`, but only when `dp_shard_enabled` is false, using the dtype given by `mixed_precision_param`.
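A minimal sketch of what the fix amounts to. The config names (`dp_shard_enabled`, `mixed_precision_param`) follow the thread; the loop shape is illustrative, not torchtitan's actual training loop, and CPU autocast stands in for the CUDA case:

```python
import contextlib
import torch
import torch.nn as nn

# Hypothetical stand-ins for the config values discussed in the thread.
dp_shard_enabled = False
mixed_precision_param = torch.bfloat16

model = nn.Linear(8, 8)          # parameters stay fp32
inp = torch.randn(4, 8)
target = torch.randn(4, 8)

# Autocast only when FSDP's own MixedPrecision config is not in play.
ctx = (
    torch.autocast(device_type="cpu", dtype=mixed_precision_param)
    if not dp_shard_enabled
    else contextlib.nullcontext()
)

with ctx:
    out = model(inp)             # linear runs in bf16 under autocast
    loss = (out - target).pow(2).mean()

loss.backward()                  # backward stays outside the autocast region
print(out.dtype, model.weight.grad.dtype)
```

The parameters and their gradients remain fp32; only the forward computation runs in the reduced dtype.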

152334H avatar Sep 29 '24 07:09 152334H

If `mixed_precision_reduce` ever accepts values other than float32, it may be worth registering an appropriate gradient-compression hook after `replicate` as well, e.g.

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703..bdbcbb2 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -344,5 +344,7 @@ def apply_ddp(
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import bf16_compress_hook
+    model.register_comm_hook(dp_mesh.get_group(), bf16_compress_hook)

     logger.info("Applied DDP to the model")
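The hook's effect is just a bf16 round trip on each gradient bucket: cast to bfloat16 before the all-reduce, cast back to fp32 after. A single-process sketch of that round trip (no process group involved; purely illustrative of the precision cost):

```python
import torch

# Simulate what bf16_compress_hook does to a gradient bucket:
# compress to bfloat16 for communication, decompress to fp32 after.
grad = torch.randn(1024, dtype=torch.float32)

compressed = grad.to(torch.bfloat16)      # what actually goes over the wire
restored = compressed.to(torch.float32)   # what the optimizer sees

# bf16 keeps only ~8 mantissa bits, so the round trip is lossy.
max_err = (grad - restored).abs().max().item()
print(f"max round-trip error: {max_err:.2e}")
```

The error stays at the level of bf16's relative precision, which is the usual trade against halving the communication volume.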

152334H avatar Sep 29 '24 07:09 152334H

might be nice to fold this as part of the train_context

awgu avatar Sep 29 '24 15:09 awgu

@awgu it can't be folded into the train context unless we're certain that autocasting the backward causes no problems. According to the torch documentation,

> autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

Personally, in other projects, I have run into cases where torch.compiling a network with autocast wrapped around the backward inexplicably produced NaNs. I don't know whether that applies here, but I'll take any overruling if people are certain it will work without problems.

152334H avatar Sep 29 '24 18:09 152334H

cc @tianyu-l for thoughts on how to handle this, perhaps with separate forward and backward contexts.
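The separate contexts could look something like the following hypothetical helper (not torchtitan's API): autocast is handed out for the forward/loss phase only, and the backward always gets a null context, which matches the torch docs' guidance above.

```python
import contextlib
import torch
import torch.nn as nn

def make_train_contexts(enable_amp: bool, dtype: torch.dtype):
    """Hypothetical helper returning (forward_ctx, backward_ctx) factories.

    Autocast covers only the forward/loss phase; the backward phase
    always gets a plain nullcontext, per the torch documentation.
    """
    if enable_amp:
        fwd = lambda: torch.autocast(device_type="cpu", dtype=dtype)
    else:
        fwd = contextlib.nullcontext
    bwd = contextlib.nullcontext  # never autocast the backward
    return fwd, bwd

fwd_ctx, bwd_ctx = make_train_contexts(True, torch.bfloat16)

model = nn.Linear(4, 4)
x = torch.randn(2, 4)

with fwd_ctx():
    loss = model(x).sum()
with bwd_ctx():
    loss.backward()
```

Per the docs, the backward ops still run in the same dtype autocast chose for the corresponding forward ops, even though the backward itself is not wrapped.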

awgu avatar Sep 30 '24 12:09 awgu

Hmm, I don't have much context on this.

@fegin How is DDP supposed to handle mixed precision? Is AMP autocasting suggested?

tianyu-l avatar Sep 30 '24 19:09 tianyu-l

I think Rohan added a tentative mixed precision API for DDP, but it never became a public feature. Using AMP is probably the way to go.

awgu avatar Oct 01 '24 03:10 awgu

autocast is the right way for DDP.

fegin avatar Oct 01 '24 17:10 fegin

fixed in https://github.com/pytorch/torchtitan/pull/1303

tianyu-l avatar Jun 17 '25 00:06 tianyu-l