xla

Most of the weight gradient all-reduce latency is exposed

Open · xrennvidia opened this issue · 2 comments

With the latest implementation of the Latency Hiding Scheduler, we observe that most of the weight gradient all-reduce latency is still exposed (see slides 6 and 7 here).

Here is a brief summary of our observations.

  • In DP-only runs, dgrads and wgrads are still scheduled separately (see slide 6). This is suboptimal because dgrad computation can then never be used to overlap all-reduce kernels.
  • The wgrad computation order does not match the wgrad all-reduce order. [Screenshot 2023-02-21 at 2:29:57 PM] As the screenshot shows, the encoder wgrad computation never overlaps with any AR kernel. Imposing some order on the AR kernels (for example, through control dependencies) is fine, but that order should match the wgrad computation order, so that each AR can start right after its corresponding wgrads are computed.
  • Currently, TP ARs and wgrad ARs are scheduled on the same stream (all of them run on stream 47). It would be better to put them on two different streams to avoid interference; otherwise they can block each other and create exposed latency. For example, in the 11B T5X run, dgrad and wgrad compute are scheduled together in the backward pass (see slide 7), but all wgrad ARs are scheduled at the very end of the backward pass. If the wgrad ARs ran on a separate stream, they could start earlier, right after their corresponding wgrads are computed.
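The ordering argument in the bullets above can be made concrete with a toy timeline model. This is a hypothetical sketch, not XLA's scheduler: it assumes one compute stream and one dedicated communication stream, that an all-reduce (AR) can start only once its wgrad has finished and the communication stream is idle, and that overlapped ARs do not slow down compute. The `Op` type, the op names, and all durations are made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """A backward-pass op in a hypothetical toy model (not an XLA type)."""
    name: str
    duration: float           # compute time on the compute stream
    ar_duration: float = 0.0  # data-parallel all-reduce time for its output


def exposed_ar_latency(compute_order, ar_order):
    """Return the AR time left over after the compute stream finishes.

    compute_order: ops in the order the compute stream runs them.
    ar_order: names of the ops whose all-reduces run, in the order the
        (dedicated) communication stream issues them.
    """
    ops = {op.name: op for op in compute_order}

    # Finish time of every op on the single compute stream.
    finish = {}
    t = 0.0
    for op in compute_order:
        t += op.duration
        finish[op.name] = t
    compute_end = t

    # Each AR waits for its wgrad and for the previous AR on the stream.
    comm_t = 0.0
    for name in ar_order:
        start = max(comm_t, finish[name])
        comm_t = start + ops[name].ar_duration

    return max(0.0, comm_t - compute_end)


dec_w = Op("dec_wgrad", 2.0, ar_duration=2.0)
dec_d = Op("dec_dgrad", 2.0)
enc_w = Op("enc_wgrad", 2.0, ar_duration=2.0)
enc_d = Op("enc_dgrad", 2.0)

interleaved = [dec_w, dec_d, enc_w, enc_d]  # dgrads and wgrads mixed
separated = [dec_d, enc_d, dec_w, enc_w]    # all wgrads at the end

# AR order matches wgrad order, and dgrads are available to hide the ARs.
print(exposed_ar_latency(interleaved, ["dec_wgrad", "enc_wgrad"]))  # 0.0
# AR order does not match wgrad order: the first AR waits for enc_wgrad.
print(exposed_ar_latency(interleaved, ["enc_wgrad", "dec_wgrad"]))  # 2.0
# Matching order, but no dgrad compute is left to overlap the ARs.
print(exposed_ar_latency(separated, ["dec_wgrad", "enc_wgrad"]))    # 2.0
```

Only the relative numbers matter. A TP AR sharing the same stream could be approximated by giving a dgrad op a nonzero `ar_duration` and adding it to `ar_order`, where it delays every wgrad AR queued behind it.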

Currently, exposed weight gradient all-reduce latency accounts for 15%-25% of the T5X training step runtime, so fixing this issue is critical. Thanks a lot.

xrennvidia · Feb 28 '23 03:02

Tracked internally in b/270401274.

cheshire · Feb 28 '23 19:02

The following list of things to try was suggested:

  1. We need to disable the collective schedule linearizer and prove that it's OK (maybe only for SPMD-partitioned models initially).
  2. We also need to disable the collective combiner.
  3. We may need to tune the latency hiding scheduler.
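One way to see why item 2 matters is a toy timeline model of the collective combiner, which merges several all-reduces into one larger collective. This is a hypothetical sketch with made-up numbers, not XLA's implementation: it assumes a single dedicated communication stream and that overlapped communication does not slow down compute. The function name `exposed_latency` and its parameters are illustrative only.

```python
def exposed_latency(wgrad_ready, ar_durations, compute_end, combine):
    """Toy model of one dedicated communication stream (hypothetical).

    wgrad_ready: time at which each wgrad finishes on the compute stream.
    ar_durations: all-reduce time for each wgrad, same order as wgrad_ready.
    compute_end: time at which the backward-pass compute finishes.
    combine: if True, model all wgrad ARs merged into one collective.
    """
    if combine:
        # A combined all-reduce needs every operand, so it cannot start
        # until the last wgrad is ready.
        end = max(wgrad_ready) + sum(ar_durations)
    else:
        # Separate ARs are issued in wgrad-completion order; each waits for
        # its own wgrad and for the previous AR on the stream.
        end = 0.0
        for ready, duration in sorted(zip(wgrad_ready, ar_durations)):
            end = max(end, ready) + duration
    return max(0.0, end - compute_end)


# Two wgrads finish at t=2 and t=6; compute ends at t=8; each AR takes 2.
print(exposed_latency([2.0, 6.0], [2.0, 2.0], 8.0, combine=False))  # 0.0
print(exposed_latency([2.0, 6.0], [2.0, 2.0], 8.0, combine=True))   # 2.0
```

The model deliberately ignores the per-collective launch and synchronization overhead that combining is meant to reduce, so it only shows the overlap side of the trade-off.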

cheshire · Mar 14 '23 17:03