
About the performance of using NVMe SSDs

Status: Open. luckyq opened this issue 2 years ago • 8 comments

I am curious about how much time the parameter-update step takes if we offload to SSD.

I notice that pulling data from the SSD is overlapped with pushing data back to the SSD. But will data transfer still be the main bottleneck in this step?

Thanks in advance.

luckyq, Jun 21 '22

You may find this discussion helpful: #998

tjruwase, Jun 21 '22

@luckyq, should this issue remain open? Thanks!

tjruwase, Jul 29 '22

> @luckyq, should this issue remain open? Thanks!

Yes, sorry for the late reply. I have read that discussion, but it only contains NVMe bandwidth benchmark results.

Could you share a time breakdown for training 10B or 50B models with ZeRO-Infinity? At those scales, would I/O still be the bottleneck?

If you could share some logs, that would be very helpful. Thanks in advance.

luckyq, Aug 04 '22

> I notice that the data transfer will be overlapped between pulling data from SSD and pushing data into SSD.

Can you be a bit more specific, perhaps point to the code? The only thing that comes to mind is the pipelined implementation of the optimizer step, where we attempt to overlap compute with SSD reads and writes. Is this what you are referring to?

> But will the data transfer still be the big problem in this part?

It depends on the NVMe read/write bandwidth and on how fast the forward and backward passes are, so it could be a problem. It really depends on the system and training configuration.
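As a rough mental model of "it depends" (our own simplification, not DeepSpeed's actual scheduler): when NVMe I/O is overlapped with compute, only the I/O time that cannot hide behind compute shows up in the wall time; without overlap, the two simply add. The function names below are hypothetical, and the model ignores pipeline fill/drain and bandwidth contention.

```python
def phase_seconds(compute_s: float, io_s: float, overlapped: bool) -> float:
    """Idealized wall time of a phase that needs both compute and NVMe I/O.

    Assumes perfect overlap when `overlapped` is True (no pipeline fill or
    drain, no contention) and strict serialization otherwise.
    """
    if overlapped:
        # The slower of the two streams sets the pace.
        return max(compute_s, io_s)
    return compute_s + io_s


def exposed_io_seconds(compute_s: float, io_s: float) -> float:
    """I/O time that remains visible after hiding it behind compute."""
    return max(0.0, io_s - compute_s)


if __name__ == "__main__":
    # Hypothetical numbers: 10 s of compute, 40 s of NVMe traffic.
    print(phase_seconds(10.0, 40.0, overlapped=True))   # 40.0
    print(phase_seconds(10.0, 40.0, overlapped=False))  # 50.0
    print(exposed_io_seconds(10.0, 40.0))               # 30.0
```

In this model, overlap helps only until compute is fully hidden; past that point the step is I/O bound and faster forward/backward passes no longer reduce it.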

Unfortunately, these experiments were run over a year ago, and I may be forgetting some nuances. Anyway, below is a log from 16x V100-32GB GPUs on a DGX-2 training a 50B model with NVMe offload of the optimizer but without pipelining. We did not use the pipelining results in the paper because we did not have time to tune pipelining properly before paper submission. Note that forward takes 31 s, backward 55 s, and the optimizer step 69 s. The NVMe peak rates are 28 GB/s for reads and 25 GB/s for writes. Hope this helps.

worker-11: [2021-04-04 20:55:08,253] [INFO] [utils.py:559:see_memory_usage] before forward 4
worker-11: [2021-04-04 20:55:08,253] [INFO] [utils.py:564:see_memory_usage] MA 7.38 GB         Max_MA 7.39 GB         CA 25.52 GB         Max_CA 26 GB
worker-11: [2021-04-04 20:55:08,253] [INFO] [utils.py:569:see_memory_usage] CPU Virtual Memory:  used = 380.92 GB, percent = 25.2%
worker-11: iteration 4 lm loss = 36.36602783203125 loss_reduced = {'lm loss': tensor(36.9627, device='cuda:0')}
worker-11: [2021-04-04 20:55:40,003] [INFO] [utils.py:559:see_memory_usage] before backward 4
worker-11: [2021-04-04 20:55:40,003] [INFO] [utils.py:564:see_memory_usage] MA 10.71 GB         Max_MA 11.76 GB         CA 25.52 GB         Max_CA 26 GB
worker-11: [2021-04-04 20:55:40,004] [INFO] [utils.py:569:see_memory_usage] CPU Virtual Memory:  used = 536.2 GB, percent = 35.5%
worker-11: [2021-04-04 20:56:35,300] [INFO] [utils.py:559:see_memory_usage] before optimizer 4
worker-11: [2021-04-04 20:56:35,301] [INFO] [utils.py:564:see_memory_usage] MA 7.37 GB         Max_MA 15.95 GB         CA 25.52 GB         Max_CA 26 GB
worker-11: [2021-04-04 20:56:35,302] [INFO] [utils.py:569:see_memory_usage] CPU Virtual Memory:  used = 380.95 GB, percent = 25.2%
worker-11: [2021-04-04 20:56:35,332] [INFO] [async_swapper.py:121:_report_statistics] Swapped out[Before flush] num_elems = 1610612736,  6.00 GB
worker-11: [2021-04-04 20:56:36,083] [INFO] [async_swapper.py:121:_report_statistics] Swapped out[After flush] num_elems = 3146842112, 11.72 GB
worker-11: [2021-04-04 20:57:44,681] [INFO] [logging.py:60:log_dist] [Rank 0] rank=0 time (ms) | async_swap_gradient_wait: 1660.79 | swap_out_gradient: 4347.28 | swap_out_param: 22185.82 | swap_in_gradient: 9642.63 | swap_in_param: 17401.55
worker-11: [2021-04-04 20:57:44,681] [INFO] [logging.py:60:log_dist] [Rank 0] rank=0 time (ms) | optimizer_step: 69122.27 | optimizer_swap_out_state: 22186.27 | optimizer_swap_in_state: 27798.26
worker-11: [2021-04-04 20:57:44,686] [INFO] [logging.py:60:log_dist] [Rank 0] step=5, skipped=0, lr=[0.00014940860259858585, 0.00014940860259858585], mom=[(0.9, 0.999), (0.9, 0.999)]
worker-11: [2021-04-04 20:57:44,686] [INFO] [timer.py:157:stop] 0/5, SamplesPerSec=0.8287545027144215
worker-11: [2021-04-04 20:57:44,686] [INFO] [logging.py:60:log_dist] [Rank 0] rank=0 time (ms) | forward_microstep: 31621.82 | backward_microstep: 55222.81 | backward_inner_microstep: 55166.54 | backward_allreduce_microstep: 56.22 | step_microstep: 69384.40
worker-11: [2021-04-04 20:57:44,686] [INFO] [logging.py:60:log_dist] [Rank 0] rank=0 time (ms) | forward: 31621.79 | backward: 55222.79 | backward_inner: 55166.52 | backward_allreduce: 56.20 | step: 69384.38
worker-11:  iteration        5/       5 | elapsed time per iteration (ms): 156522.2 | learning rate: 1.494E-04 | lm loss: 3.696275E+01 | loss scale: 1.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
worker-11: time (ms) | forward: 31638.83 | backward: 55222.92 | backward-backward: 55222.87 | backward-allreduce: 0.00 | optimizer: 69384.58 | batch generator: 1.29
worker-11: Effective Tera Flops per GPU: 26.36 and total parameters 50.356 B
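The timer lines in the log above can be turned into shares of the iteration and of the optimizer step. A minimal sketch: the two timer lines are copied from the log (with the worker/timestamp prefix trimmed), and the helper names are hypothetical. It assumes the two swap timers are disjoint parts of `optimizer_step`, which seems plausible without pipelining but has not been checked against the code.

```python
import re

# Two timer lines copied (prefix trimmed) from the rank-0 log above.
LOG_LINES = [
    "[Rank 0] rank=0 time (ms) | optimizer_step: 69122.27 | "
    "optimizer_swap_out_state: 22186.27 | optimizer_swap_in_state: 27798.26",
    "[Rank 0] rank=0 time (ms) | forward: 31621.79 | backward: 55222.79 | "
    "backward_inner: 55166.52 | backward_allreduce: 56.20 | step: 69384.38",
]

TIMER_RE = re.compile(r"(\w+): ([\d.]+)")


def parse_timers(lines):
    """Map timer name -> seconds for every `name: ms` pair after 'time (ms)'."""
    timers = {}
    for line in lines:
        # Skip everything before "time (ms)" so only timer fields are parsed.
        _, _, payload = line.partition("time (ms)")
        for name, ms in TIMER_RE.findall(payload):
            timers[name] = float(ms) / 1000.0
    return timers


def breakdown(t):
    """Iteration time plus two shares in [0, 1].

    Assumes the swap-in and swap-out timers are disjoint sub-intervals of
    optimizer_step (not verified against the DeepSpeed code).
    """
    iteration_s = t["forward"] + t["backward"] + t["step"]
    swap_s = t["optimizer_swap_in_state"] + t["optimizer_swap_out_state"]
    return {
        "iteration_s": iteration_s,
        "step_share_of_iteration": t["step"] / iteration_s,
        "swap_share_of_optimizer_step": swap_s / t["optimizer_step"],
    }


if __name__ == "__main__":
    for key, value in breakdown(parse_timers(LOG_LINES)).items():
        print(f"{key}: {value:.3f}")
```

On these numbers, the optimizer step is about 44% of the ~156 s iteration, and swapping optimizer state in and out accounts for about 72% of the optimizer step, which fits the picture of I/O dominating when pipelining is off.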

As you can imagine, the code has changed substantially, but we have not revisited these benchmarks. It might be useful for you to describe what you are trying to achieve and share some experimental results from your side.

tjruwase, Aug 05 '22

That's very helpful, thanks a lot. I have a follow-up question. Does the optimizer time consist only of CPU optimizer compute time plus I/O time (reading the previous parameters, gradients, and Adam optimizer states)? Will I/O take up most of the optimizer time?

Thanks in advance.

luckyq, Aug 08 '22

> I have a follow-up question. Does the optimizer time consist only of CPU optimizer compute time plus I/O time (reading the previous parameters, gradients, and Adam optimizer states)? Will I/O take up most of the optimizer time?

Yes, the optimizer time includes all of those components. The I/O portion depends on the available I/O device bandwidth. A single NVMe device reads and writes at ~3 GB/s, so depending on model size and GPU count, I/O time could dominate unless you scale up I/O throughput with multiple NVMe devices.
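To see why the device count matters, here is a back-of-the-envelope estimate of NVMe time per optimizer step. It is a sketch under stated assumptions, not DeepSpeed's cost model: 12 bytes of Adam state per parameter (fp32 master weights, momentum, and variance, as in mixed-precision Adam), every state byte read once and written once, no read/write overlap, and bandwidth that scales linearly with the number of devices. It ignores gradient reads and CPU compute, so real times would be higher. The function name is hypothetical.

```python
def optimizer_io_seconds(num_params: float, bytes_per_param: float,
                         num_devices: int, read_gbps: float,
                         write_gbps: float) -> float:
    """Rough NVMe time for one optimizer step.

    Assumptions: every optimizer-state byte is read once and written once,
    reads and writes do not overlap, and per-device bandwidth (GB/s, with
    1 GB = 1e9 bytes) scales linearly with the number of devices.
    """
    total_bytes = num_params * bytes_per_param
    read_s = total_bytes / (read_gbps * 1e9 * num_devices)
    write_s = total_bytes / (write_gbps * 1e9 * num_devices)
    return read_s + write_s


if __name__ == "__main__":
    # 50B parameters, 12 bytes/param of Adam state, ~3 GB/s per NVMe device.
    for n in (1, 4, 8):
        print(n, optimizer_io_seconds(50e9, 12, n, 3.0, 3.0))
```

Under these assumptions, a 50B model moves about 600 GB each way per step: roughly 400 s of I/O on one ~3 GB/s device, versus about 100 s with four and 50 s with eight.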

tjruwase, Aug 29 '22

Hi, I'm new to DeepSpeed and particularly interested in NVMe parameter offloading. I have some questions about how it works, and I hope you can answer them, @tjruwase:

  • Does it prefetch data before it is needed, or does it just add a blocking stage when the data is not yet on the GPU?
  • Does it support GPUDirect, or does the data have to pass through the CPU as a forwarder or even a serializer/deserializer?
  • Does it support multiple NVMe directories, or do I have to set up my own ZFS/RAID 0 solution to saturate PCIe 5.0/4.0 bandwidth with multiple NVMe devices? (I presume consecutive chunks/layers would be mapped to the least-used drive?)
  • In inference-only mode (where no parameters are ever updated), does it re-offload (write back to the device) previously offloaded parameters? In other words, does it avoid exhausting the NVMe drive's initial cached-write performance by reusing them? (Most drives slow down pretty quickly and serve …)
  • Does it support partial parameter offloading? (This would be extremely beneficial if both of the above are supported and I only want to fine-tune a small portion of the parameters.)

Willian-Zhang, Aug 31 '22

@Willian-Zhang, thanks for your questions. If you have not done so, please see the ZeRO-Infinity paper for more details.

  1. Yes, layers are prefetched at two levels: computation of the current layer is overlapped with prefetching the next layer from CPU into GPU, and with prefetching the layer after that from NVMe into CPU.
  2. GPUDirect is on our TODO list.
  3. The current implementation assumes a single volume made up of multiple devices. Extending it to multiple separate devices is straightforward.
  4. Only updated parameters are written back, so there are no writes in inference mode.
  5. We recently added support for partial parameter offloading in #2089, but it is not yet documented.
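The two-level prefetch in answer 1 can be pictured as a three-stage pipeline. The sketch below only lists which layer each stage works on at each step; it is an illustration of the description above, not DeepSpeed's implementation, and the function name is hypothetical. It assumes layer 0 is already on the GPU and layer 1 is already in CPU memory before the first step.

```python
from typing import List, Optional, Tuple

# (layer being computed, layer moving CPU -> GPU, layer moving NVMe -> CPU)
Step = Tuple[int, Optional[int], Optional[int]]


def prefetch_schedule(num_layers: int) -> List[Step]:
    """Per-step layer indices for the compute, CPU->GPU and NVMe->CPU stages.

    While the GPU computes layer i, layer i+1 is copied CPU -> GPU and
    layer i+2 is read NVMe -> CPU; None marks an idle stage near the end.
    Assumes layer 0 is on the GPU and layer 1 in CPU memory beforehand.
    """
    schedule = []
    for i in range(num_layers):
        to_gpu = i + 1 if i + 1 < num_layers else None
        to_cpu = i + 2 if i + 2 < num_layers else None
        schedule.append((i, to_gpu, to_cpu))
    return schedule


if __name__ == "__main__":
    for step in prefetch_schedule(4):
        print(step)
```

For four layers this prints (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None). As long as each transfer finishes within one layer's compute time, the GPU never waits on NVMe, which is the "prefetch before actually needed" behavior asked about.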

tjruwase, Sep 02 '22
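Relating to answer 3 (a single volume), here is a minimal sketch of a ZeRO-Infinity NVMe offload config written as a Python dict. The key names follow the `zero_optimization` section of the DeepSpeed config documentation as we understand it; check them against the docs for your DeepSpeed version. The mount point `/local_nvme` is hypothetical, and a real config also needs batch size, optimizer, and precision settings, which are omitted here. A dict like this can be passed as the `config` argument of `deepspeed.initialize`.

```python
import json

# Hypothetical mount point of one filesystem. Per answer 3 above, the
# current implementation expects a single volume, so several NVMe drives
# would be combined (e.g. RAID 0) underneath this one path.
NVME_PATH = "/local_nvme"

ds_config = {
    "zero_optimization": {
        "stage": 3,  # parameter offload is only valid with ZeRO stage 3
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": NVME_PATH,
            "pin_memory": True,
        },
        "offload_param": {
            "device": "nvme",
            "nvme_path": NVME_PATH,
            "pin_memory": True,
        },
    },
}

if __name__ == "__main__":
    # Emit as JSON, e.g. to save as a DeepSpeed config file.
    print(json.dumps(ds_config, indent=2))
```

Pointing both offload targets at the same path mirrors the single-volume assumption; spreading I/O across drives is then left to the filesystem or RAID layer.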