[REQUEST] Upstream modifications of MiCS

Open zarzen opened this issue 2 years ago • 4 comments

Is your feature request related to a problem? Please describe. ZeRO-3 partitions the model across all devices. Sharding within subgroups of devices and using hierarchical communication instead can improve training performance significantly.
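For illustration, subgroup sharding can be sketched as a rank layout. This is a minimal sketch, not the MiCS implementation: the function name `build_mics_groups` and the parameter `shard_size` are hypothetical, and it assumes that consecutive ranks form a sharding group and that the world size is a multiple of the shard size.

```python
from typing import List, Tuple


def build_mics_groups(world_size: int, shard_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replica_groups) for a hypothetical rank layout.

    shard_groups: each group holds one full copy of the model, partitioned
        across its members (rather than across all ranks, as in ZeRO-3).
    replica_groups: ranks that hold the same partition in different shard groups.
    """
    if shard_size <= 0 or world_size % shard_size != 0:
        raise ValueError("world_size must be a positive multiple of shard_size")

    num_groups = world_size // shard_size
    # Assumption: consecutive ranks belong to the same sharding group.
    shard_groups = [list(range(g * shard_size, (g + 1) * shard_size)) for g in range(num_groups)]
    # Ranks with the same local index across groups hold the same partition.
    replica_groups = [list(range(r, world_size, shard_size)) for r in range(shard_size)]
    return shard_groups, replica_groups


shard_groups, replica_groups = build_mics_groups(world_size=8, shard_size=4)
print(shard_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replica_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With `shard_size` equal to `world_size`, this layout degenerates to a single sharding group, which corresponds to partitioning across all devices.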

Describe the solution you'd like Our team plans to open-source our implementation of MiCS (https://arxiv.org/abs/2205.00119) and upstream the codebase.

Describe alternatives you've considered N/A

Additional context

Thanks for sharing DeepSpeed with the community! It has become important infrastructure as an easy-to-use solution for training large models. Recently, we made some improvements to the sharded data-parallel approach that further improve training performance while keeping the simplicity of ZeRO-based solutions. To share this work with the DeepSpeed community, we would like to integrate our modifications into the DeepSpeed codebase. Our original modifications were based on an older version of DeepSpeed, so we are reworking them for upstreaming to minimize divergence from the master branch. To minimize potential conflicts, we created this issue to discuss the proposed modifications.

We list the tentative code changes below for discussion:

  • We plan to place most of the MiCS-specific implementation in deepspeed/runtime/zero/mics_*.py files.
  • To ease maintenance, we plan to inherit from the class DeepSpeedZeroOptimizer_Stage3 to reuse the existing logic.
    • However, to let the subclasses (MiCS variants) change the behavior of certain functions, we also need to modify some functions in stage3.py and partition_parameters.py. These modifications would mostly turn the existing ZeRO-3-specific implementations into generic logical structures whose behavior can be altered in the subclasses.
      • E.g.
@@ -1217,22 +1218,22 @@ class Init(InsertPostInitMethodToModuleSubClasses):
                         handle = dist.all_gather_base(flat_tensor,
                                                       param.ds_tensor.to(
                                                           get_accelerator().device_name()),
-                                                      group=self.ds_process_group,
+                                                      group=self.get_partition_dp_group(param),
                                                       async_op=async_op)
                     else:
                         partitions = []
-                        for i in range(self.world_size):
+                        for i in range(self.num_partitions):
                             partitions.append(
                                 flat_tensor.narrow(0,
                                                    partition_size * i,
                                                    partition_size))
 
-                            if i == dist.get_rank(group=self.ds_process_group):
+                            if i == dist.get_rank(group=self.get_partition_dp_group(param)):
                                 partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
 
                         handle = dist.all_gather(partitions,
-                                                 partitions[self.rank],
-                                                 group=self.ds_process_group,
+                                                 partitions[self.get_partition_rank()],
+                                                 group=self.get_partition_dp_group(param),
                                                  async_op=async_op)
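
The override pattern behind the diff above can be sketched without DeepSpeed or torch. This is a minimal sketch under stated assumptions, not the proposed implementation: the classes `Stage3LikePartitioner` and `MiCSLikePartitioner`, the method `partition_bounds`, and the `shard_size` argument are hypothetical stand-ins, and a plain list of ranks stands in for a process group. Only the hook names `num_partitions`, `get_partition_rank`, and `get_partition_dp_group` are taken from the diff.

```python
class Stage3LikePartitioner:
    """Stand-in for the ZeRO-3 base logic; NOT the real DeepSpeed class."""

    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank

    # Hooks: the base class answers "all ranks"; subclasses may override.
    @property
    def num_partitions(self) -> int:
        return self.world_size

    def get_partition_rank(self) -> int:
        return self.rank

    def get_partition_dp_group(self, param=None) -> list:
        # A list of ranks stands in for a torch.distributed process group.
        return list(range(self.world_size))

    def partition_bounds(self, numel: int):
        """Shared logic, written only in terms of the hooks above."""
        partition_size = -(-numel // self.num_partitions)  # ceiling division
        start = min(partition_size * self.get_partition_rank(), numel)
        end = min(start + partition_size, numel)
        return start, end


class MiCSLikePartitioner(Stage3LikePartitioner):
    """Hypothetical subclass: shards within a subgroup of `shard_size` ranks."""

    def __init__(self, world_size: int, rank: int, shard_size: int):
        super().__init__(world_size, rank)
        self.shard_size = shard_size

    @property
    def num_partitions(self) -> int:
        return self.shard_size

    def get_partition_rank(self) -> int:
        return self.rank % self.shard_size

    def get_partition_dp_group(self, param=None) -> list:
        first = (self.rank // self.shard_size) * self.shard_size
        return list(range(first, first + self.shard_size))


zero3 = Stage3LikePartitioner(world_size=8, rank=5)
mics = MiCSLikePartitioner(world_size=8, rank=5, shard_size=4)
print(zero3.partition_bounds(1000))   # (625, 750)
print(mics.partition_bounds(1000))    # (250, 500)
print(mics.get_partition_dp_group())  # [4, 5, 6, 7]
```

Because `partition_bounds` never refers to `world_size` or `rank` directly, the subclass changes the partitioning behavior by overriding the three hooks alone, which mirrors the substitutions made in the diff.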

Please let us know your thoughts. Feel free to comment here or via email (zhzhn AT amazon.com). Thanks!

Zhen

zarzen avatar Feb 07 '23 23:02 zarzen

@tjruwase @jeffra45 Please let us know your thoughts. Thanks!

szhengac avatar Feb 10 '23 21:02 szhengac

Hi @zarzen and @szhengac, thanks for reaching out with this proposal! We’re discussing and will respond later next week.

jeffra avatar Feb 11 '23 04:02 jeffra

@zarzen, I just sent you an email; please let me know if you don't receive it for some reason :)

jeffra avatar Feb 14 '23 23:02 jeffra

@stas00, FYI

tjruwase avatar Feb 21 '23 20:02 tjruwase