
parallelize writing of layer checkpoint files across data parallel instances

Open adammoody opened this issue 4 years ago • 10 comments

This is a work in progress, but I wanted to open it early for discussion. Also, I wrote this before MOE was added, and it will need to be updated to account for that. I can help with that if this approach is approved.

In case a pipeline stage has multiple layers, this parallelizes the task of writing the layer checkpoint files across the data parallel group. For example, if one is running with two data parallel instances, and if a pipeline stage has 10 layers, this modifies things so that rank 0 will write 5 layers and rank 1 will write the other 5, rather than have rank 0 do all of the work. On my system, this reduces checkpoint cost. It also better balances the total bytes written across ranks.

The main change is to have all procs call _save_checkpoint, and then in module_state_dict, the list of layers is subdivided among the procs that have the same model parallel rank across all data parallel instances.
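To make the grouping concrete, here is a minimal sketch that lists, for each model-parallel rank, the global ranks that form its data-parallel group (the procs among which the layer list would be subdivided). It assumes a hypothetical rank layout in which the model-parallel index varies fastest (`global_rank = dp_rank * mp_size + mp_rank`); DeepSpeed's own process-topology classes define the real layout, which may differ. `data_parallel_groups` is an illustrative name, not a DeepSpeed function.

```python
def data_parallel_groups(world_size: int, mp_size: int) -> list[list[int]]:
    """Group global ranks that share the same model-parallel rank.

    Assumes (hypothetically) global_rank = dp_rank * mp_size + mp_rank,
    i.e. the model-parallel index varies fastest.
    """
    if world_size % mp_size != 0:
        raise ValueError("world_size must be divisible by mp_size")
    dp_size = world_size // mp_size
    # One group per model-parallel rank; each group has dp_size members.
    return [[dp * mp_size + mp for dp in range(dp_size)] for mp in range(mp_size)]


if __name__ == "__main__":
    # 16 processes with a model-parallel size of 2 -> two groups of 8 ranks.
    for mp_rank, group in enumerate(data_parallel_groups(16, 2)):
        print(f"model-parallel rank {mp_rank}: {group}")
```

Within each group, the layers of the stage would then be split across the group members, so every member writes a disjoint subset.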

adammoody, Sep 30 '21

Hi all, I've got some time to circle back to this. I'm hoping someone on the team can take a look and provide some feedback when they get a chance.

I have rebased this PR on the latest code. It still needs to be adapted to handle MOE. I think the idea should extend to the MOE checkpoint path, as well.

In the meantime, here are some measurements to demonstrate the performance difference.

I am training a model using 16 processes on 4 nodes with the following configuration:

zero stage = 1
tensor parallelism = 2
pipeline parallelism = 1
num layers = 8
hidden size = 5120

This gives 8 data parallel instances of a model that has 1 pipeline stage. There are 8 layers in that single stage.

Before this PR, the processes that have rank 0 within the data parallel group write all 8 layers of the transformer. With this PR, all processes in the job write 1 layer of the transformer (8 layers / 8 data parallel instances).
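The counts above follow from simple arithmetic. This sketch only restates them, assuming the number of data parallel instances is the world size divided by the product of the tensor- and pipeline-parallel sizes; `dp_size` is an illustrative helper name.

```python
def dp_size(world_size: int, tp: int, pp: int) -> int:
    """Number of data-parallel instances (model replicas)."""
    if world_size % (tp * pp) != 0:
        raise ValueError("world size must be divisible by tp * pp")
    return world_size // (tp * pp)


world, tp, pp, layers = 16, 2, 1, 8
dp = dp_size(world, tp, pp)            # 16 // (2 * 1) = 8
layers_on_dp_rank0_before = layers     # before: data-parallel rank 0 writes every layer of the stage
layers_per_rank_after = layers // dp   # after: each data-parallel rank writes 8 // 8 = 1 layer
print(dp, layers_on_dp_rank0_before, layers_per_rank_after)  # 8 8 1
```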

I added timers to the checkpoint path. The two calls that take significant time are the ones that write the non-zero checkpoint files and the zero checkpoint files, so I added timers that report the number of seconds spent in these two lines:

self._save_checkpoint(save_dir, tag, client_state=client_state)
self._save_zero_checkpoint(save_dir, tag)
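One simple way to add such timers is a small context manager wrapped around each call. This is only a sketch: `save_non_zero` and `save_zero` are hypothetical stand-ins for `self._save_checkpoint(...)` and `self._save_zero_checkpoint(...)`, and the printed format merely mimics the `non_zero:` / `zero:` log lines in this thread.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, results: dict):
    """Record the wall-clock seconds spent inside the with-block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


def save_non_zero():  # hypothetical stand-in for self._save_checkpoint(...)
    time.sleep(0.01)


def save_zero():  # hypothetical stand-in for self._save_zero_checkpoint(...)
    time.sleep(0.01)


timings = {}
with timed("non_zero", timings):
    save_non_zero()
with timed("zero", timings):
    save_zero()
for label, secs in timings.items():
    print(f"{label}: {secs}")
```

In a multi-rank run, these per-rank times only tell part of the story; calling a barrier (for example `torch.distributed.barrier()`) before stopping the clock would make the measured time include the slowest rank.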

The changes here reduce the cost of writing the non-zero checkpoint files.

Before this change, I see the following timings for 4 different checkpoints (units of seconds where not labeled):

18: 0: non_zero: 15.127218961715698
18: 0: zero:     4.272188663482666
18: 15: time (ms) | save-checkpoint: 19613.66

18: 0: non_zero: 15.118722915649414
18: 0: zero:     4.241654872894287
18: 15: time (ms) | save-checkpoint: 19414.78

18: 0: non_zero: 14.895861148834229
18: 0: zero:     2.4203057289123535
18: 15: time (ms) | save-checkpoint: 18793.07

18: 0: non_zero: 15.461049795150757
18: 0: zero:     4.352852821350098
18: 15: time (ms) | save-checkpoint: 19825.21

When using the optimization in this PR, I get the following instead:

15: 0: non_zero: 4.961373567581177
15: 0: zero:     6.336402893066406
15: 15: time (ms) | save-checkpoint: 11619.31

15: 0: non_zero: 3.244314193725586
15: 0: zero:     7.434791803359985
15: 15: time (ms) | save-checkpoint: 10897.41

15: 0: non_zero: 4.893619537353516
15: 0: zero:     5.75852370262146
15: 15: time (ms) | save-checkpoint: 14128.98

15: 0: non_zero: 5.354466676712036
15: 0: zero:     5.717851877212524
15: 15: time (ms) | save-checkpoint: 11235.55

The total checkpoint time drops from about 19 seconds to about 11. The gains come from reducing the cost of writing the non_zero files, which drops from about 15 seconds to 5, even though the cost to write the zero files seems to have increased a bit.
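For reference, averaging the four checkpoints in each set of logs gives the figures below. The values are copied verbatim from the logs above; nothing here is new data. By the mean, the total goes from about 19.4 s to about 12.0 s (three of the four runs are near 11 s and one is about 14 s), non_zero goes from about 15.2 s to 4.6 s, and zero rises from about 3.8 s to 6.3 s.

```python
from statistics import mean

# Timings copied from the logs above: non_zero and zero in seconds, total in ms.
base = {
    "non_zero": [15.127218961715698, 15.118722915649414, 14.895861148834229, 15.461049795150757],
    "zero": [4.272188663482666, 4.241654872894287, 2.4203057289123535, 4.352852821350098],
    "total_ms": [19613.66, 19414.78, 18793.07, 19825.21],
}
pr = {
    "non_zero": [4.961373567581177, 3.244314193725586, 4.893619537353516, 5.354466676712036],
    "zero": [6.336402893066406, 7.434791803359985, 5.75852370262146, 5.717851877212524],
    "total_ms": [11619.31, 10897.41, 14128.98, 11235.55],
}

for key in ("non_zero", "zero", "total_ms"):
    b, p = mean(base[key]), mean(pr[key])
    # ratio > 1 means the PR is faster for this phase; < 1 means slower.
    print(f"{key:9s} base={b:10.2f} pr={p:10.2f} base/pr={b / p:.2f}")
```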

adammoody, Dec 18 '21

@awan-10, I see this PR has conflicts again. I'll take a look at refreshing it.

Would someone from the team have some time to look this over?

I think it could be useful to others, but if not, that'd be good to know too.

adammoody, Mar 31 '22

@awan-10, oh, this likely still needs to be updated for MoE. I started this before MoE was merged in. I'm trying to get a MoE example working, so that I can better follow its checkpoint path. I'm not quite there yet.

adammoody, Mar 31 '22

@jeffra, I'm still fighting with my system to get a python+pytorch build that will let me run MoE. That could be a ways off, and it's hard to put a date on it. Aside from that, I updated this for the latest code and verified that it still works and performs as expected for the non-MoE case.

Would someone be willing to review this in its current state?

adammoody, Apr 12 '22

@tjruwase, these changes reduce the checkpoint cost of some models. Is this something that could be worked in?

adammoody, May 12 '22

Can one of the admins verify this patch?

rocm-mici, Jun 09 '22

@stas00, I noted that the checkpoints in the latest bigscience run were taking about 40 seconds.

https://github.com/bigscience-workshop/bigscience/blob/master/train/tr11-176B-ml/chronicles.md#2022-03-21

The changes in this patch may reduce that cost. This PR is out of date with the latest code, though it might still apply cleanly for you depending on the version used in those runs. I also know this may come too late for the current run, but I thought I'd throw it out there in case you are planning more.

I have separate work that effectively helps reduce disk space consumed by the checkpoints, too. Combined with the speed improvement here, those two changes enable one to checkpoint large training runs more frequently.

adammoody, Jul 07 '22

Hi Adam,

Indeed, we have finished training 176B, so hopefully the next version can include your work.

In the case of JeanZay, from my many experiments, IO rather than CPU seems to be the bottleneck. In that case, writing the data from many processes to different places might not be much faster than writing it sequentially from one process. I'd guesstimate your way would probably still be a bit faster, though.

Otherwise yours is definitely a super-smart idea!

I have separate work that effectively helps reduce disk space consumed by the checkpoints, too. Combined with the speed improvement here, those two changes enable one to checkpoint large training runs more frequently.

I'm all ears. Do tell! (But let's probably discuss it in another issue so that we don't derail your PR.)

p.s. Always awesome to read about your performance innovations!

stas00, Jul 08 '22

Hi @adammoody, sorry that I did not respond to this earlier. This is a great contribution and is timely, as we are looking to improve the checkpointing features of DeepSpeed. I will take a closer look at this PR to get a better understanding.

tjruwase, Jul 15 '22

Thanks, @tjruwase. The ideas in here should still apply, but the PR itself has fallen out of sync with main and needs to be refreshed. I have found that this speeds up the I/O on our system, where one of the bottlenecks lies in the cost of calling torch.save().

A second benefit of this approach is that, by spreading the bytes written more evenly across the ranks, one can gain more benefit from checkpoint libraries that write to node-local storage. We have follow-on work that I can describe, but it depends on this first change.

adammoody, Jul 15 '22

@tjruwase, @stas00, here are a couple of charts showing the speedup one can get using this approach. This first chart shows the checkpoint cost in seconds as the node count scales up, holding the hidden size constant while increasing the number of layers.

[Chart ckpt_secs_5120: checkpoint time (seconds) vs. node count; series Base, PR#1419, SCR Single, SCR XOR]

At 128 nodes, the total bytes written in this case is 1.15e12 (~1 TiB). The "Base" plot is the time I get with the original code, and the "PR#1419" plot shows the time for writing the same checkpoint using the changes in this PR. For the node counts shown, I get a speedup of 2-3x.

While this PR improves performance on its own, it also enables one to plug in the Scalable Checkpoint/Restart (SCR) library. Just for reference, the "SCR Single" and "SCR XOR" plots show the cost of writing the checkpoint files to /dev/shm on each node with SCR. This is 3-4x faster, and the checkpoint cost stays constant as the node count increases.

adammoody, Sep 19 '22

The cost to checkpoint the same model, but shown as effective write bandwidth:

[Chart ckpt_bw_5120: effective write bandwidth vs. node count; series Base, PR#1419, SCR Single, SCR XOR]

The peak write bandwidth of the parallel file system is about 160 GiB/s. The "Base" and "PR#1419" plots should both approach that limit asymptotically.

The SCR plots meet or surpass the parallel file system bandwidth at 128 nodes, and they continue to scale linearly with the node count.
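As a back-of-the-envelope check on these numbers: writing the 1.15e12 bytes quoted for 128 nodes at the ~160 GiB/s peak of the parallel file system takes at least about 6.7 s. The sketch below only does that unit arithmetic; it assumes "effective bandwidth" means total bytes written divided by checkpoint seconds, and `effective_bw_gib` is an illustrative helper.

```python
GiB = 2**30
TiB = 2**40

total_bytes = 1.15e12   # bytes written at 128 nodes (from the text)
peak_bw = 160 * GiB     # approximate parallel file system peak, bytes/s

print(f"checkpoint size: {total_bytes / TiB:.2f} TiB")         # 1.05 TiB
print(f"minimum time at peak: {total_bytes / peak_bw:.1f} s")  # 6.7 s


def effective_bw_gib(nbytes: float, seconds: float) -> float:
    """Effective write bandwidth in GiB/s for nbytes written in `seconds`."""
    return nbytes / seconds / GiB
```

Any checkpoint time measured against the parallel file system at that size should therefore stay above roughly 6.7 s, which is why the Base and PR#1419 curves can only approach the limit.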

adammoody, Sep 19 '22

My main goal here is to see if we can get the PR accepted. I think this should be useful to others.

Separately, I can also talk about SCR more if you're interested. It can be helpful at larger scales; however, it also requires MPI (mpi4py).

adammoody, Sep 19 '22

As I'm not part of the Deepspeed team my vote won't count, but your benchmarks are super-impressive and I'd say definitely go for it.

I'll let @tjruwase chime in, and perhaps @jeffra can weigh in as well.

Are there situations where this approach would be slower, in which case this behaviour should be configurable by the user?

stas00, Sep 19 '22

@adammoody, this is certainly impressive and definitely of interest to DeepSpeed. Can you please refresh this PR? Please let me know how I can help. Thanks!

@GuanhuaWang, FYI

tjruwase, Sep 20 '22

@tjruwase, I've rebased this PR on the latest code. That should make it easier to compare. I haven't run new tests with it after rebasing, so I can't say for sure that it still works yet.

adammoody, Sep 20 '22

@adammoody, it would be great to make this a configuration option through ds_config. I am still thinking about the right way to do that, but please let me know if you have any thoughts on this.

tjruwase, Sep 20 '22

Sure. Some of the lines that I have dropped could instead be conditioned on a new ds_config option.

adammoody, Sep 20 '22

To disable this feature, I think we could just set start and end appropriately:

num_layers = len(self.forward_funcs)
if args.dp_parallel_write:
  # divide layers evenly among data parallel ranks
  offsets = ds_utils.partition_uniform(num_layers, dp_size)
  start, end = offsets[dp_rank], offsets[dp_rank+1]
else:
  # assign all layers to data parallel rank 0
  if dp_rank == 0:
    start, end = 0, num_layers
  else:
    start, end = 0, 0
layer_list = self.forward_funcs[start:end]
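A self-contained version of the same selection logic might look like the sketch below. `layer_range` and its `parallel_write` flag are hypothetical names, and `partition_uniform` here is a local stand-in that spreads any remainder over the first ranks; DeepSpeed's own `ds_utils.partition_uniform` may split remainders differently.

```python
def partition_uniform(num_items: int, num_parts: int) -> list[int]:
    """Return num_parts + 1 offsets splitting num_items as evenly as possible.

    Part i covers [offsets[i], offsets[i + 1]); the first
    (num_items % num_parts) parts get one extra item.
    """
    base, extra = divmod(num_items, num_parts)
    offsets = [0]
    for i in range(num_parts):
        offsets.append(offsets[-1] + base + (1 if i < extra else 0))
    return offsets


def layer_range(num_layers: int, dp_rank: int, dp_size: int,
                parallel_write: bool) -> tuple[int, int]:
    """Return the [start, end) slice of layers this data-parallel rank saves."""
    if parallel_write:
        # Divide layers evenly among data-parallel ranks.
        offsets = partition_uniform(num_layers, dp_size)
        return offsets[dp_rank], offsets[dp_rank + 1]
    # Feature disabled: data-parallel rank 0 writes every layer, others none.
    return (0, num_layers) if dp_rank == 0 else (0, 0)


# 10 layers over 2 data-parallel ranks, as in the PR description.
print([layer_range(10, r, 2, True) for r in range(2)])   # [(0, 5), (5, 10)]
print([layer_range(10, r, 2, False) for r in range(2)])  # [(0, 10), (0, 0)]
```

When there are fewer layers than data-parallel ranks, the trailing ranks simply get an empty slice, so the caller's `forward_funcs[start:end]` is empty and they write nothing.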

adammoody, Sep 21 '22

Hi @adammoody, this is a great contribution. Just curious, is there any reason why the zero checkpoint performance is slightly worse compared to the baseline?

Also, for the format check, please follow these instructions here.

GuanhuaWang, Sep 21 '22

Hi @GuanhuaWang, I don't yet know why the cost to write the zero files slowed down with this change. That was a surprise to me.

Thanks for the formatting tip. I'll take a look at cleaning that up.

adammoody, Sep 21 '22

This PR was written before the checkpoint bloat optimization was added. After rebasing to pick up the bloat fix, I see that the layer checkpoint files are much smaller. That likely reduces the absolute performance benefit shown in the graphs above, since the baseline cost should drop significantly. I'd have to queue up a new set of runs.

adammoody, Sep 21 '22

@adammoody, I think the perf benefits are still useful regardless of the bloat fix.

The main issue, as far as I can see right now, is a clean configuration design to enable/disable and potentially specify the parallelism degree of checkpointing. It would be nice if the configuration knobs could also be easily used to parallelize the other non_zero checkpoints. Perhaps that could be a TODO for now, given how old this PR is. Please share your thoughts.

tjruwase, Sep 21 '22

@tjruwase, it should be easy to enable/disable this feature based on a ds_config option, with code like that shown here: https://github.com/microsoft/DeepSpeed/pull/1419#issuecomment-1253134230.

Want to go with something like that?

Maybe as a first step we could define a config option that leaves this disabled by default.

adammoody, Sep 29 '22

@adammoody, one possibility is to modify checkpoint configuration as below:

   "checkpoint": {
      ...,
      "parallel_write": {
         "pipeline_stage": [true|false],
         "tensor_slice": [true|false]
      }
   }

I think the above will address your PR (parallel writes of pipeline stages) and support a future extension to parallelizing tensor slices, all along the data parallel dimension. However, I am open to other suggestions.
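To illustrate how such a nested option could be read with defaults, here is a minimal sketch that parses the proposed `checkpoint.parallel_write` block from a plain dict. The function name, the default of `false` (feature disabled), and the error messages are assumptions for illustration, not DeepSpeed's actual config-parsing code.

```python
import json

PARALLEL_WRITE_KEYS = ("pipeline_stage", "tensor_slice")


def get_parallel_write(ds_config: dict) -> dict:
    """Return {"pipeline_stage": bool, "tensor_slice": bool}, defaulting to False."""
    section = ds_config.get("checkpoint", {}).get("parallel_write", {})
    unknown = set(section) - set(PARALLEL_WRITE_KEYS)
    if unknown:
        raise ValueError(f"unknown checkpoint.parallel_write keys: {sorted(unknown)}")
    result = {}
    for key in PARALLEL_WRITE_KEYS:
        value = section.get(key, False)
        # Reject strings like "yes" so typos fail loudly instead of enabling the feature.
        if not isinstance(value, bool):
            raise ValueError(
                f"checkpoint.parallel_write.{key} must be true or false, got {value!r}")
        result[key] = value
    return result


config = json.loads('{"checkpoint": {"parallel_write": {"pipeline_stage": true}}}')
print(get_parallel_write(config))  # {'pipeline_stage': True, 'tensor_slice': False}
print(get_parallel_write({}))      # {'pipeline_stage': False, 'tensor_slice': False}
```

Defaulting both keys to false keeps existing behaviour (data parallel rank 0 writes everything) unless a user opts in.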

@stas00, FYI

tjruwase, Sep 30 '22

Thanks, @tjruwase. I'm working on adding these config options now. I've pushed a commit for just the pipeline_stage value so far. A couple of questions:

  1. Based on what I have so far, do you have any recommendation on the names, option processing, or error message?
  2. I'll need to pass this setting down to be used in runtime/pipe/module.py. What do you recommend for that? It looks like one method would be to modify the PipelineModule constructor to take another argument, but maybe there is another way?

adammoody, Oct 03 '22

@tjruwase, I found one way to get the value of that new configuration setting. This seems to work for me now.

adammoody, Oct 04 '22

@tjruwase, I've only added the pipeline option so far, but how do things look at this point?

adammoody, Oct 10 '22

@adammoody, FYI, this is in 0.7.4 Patch Release

tjruwase, Oct 21 '22

Great. Thanks!

adammoody, Oct 26 '22