
[FEATURE] A CPU Swapping Runtime

Open merrymercy opened this issue 1 year ago • 6 comments

Background

To train or serve large models with limited GPU memory resources, we can utilize the huge amount of available CPU memory by swapping tensors between CPU and GPU. In this project, we are going to implement a swapping runtime for Alpa. We can start with the easiest case: swapping between 1 CPU and 1 GPU for serving. We can then move to more complicated cases: swapping between distributed CPUs and GPUs for training.
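For reference, the basic mechanism a swapping runtime builds on is an explicit host/device copy. Below is a minimal sketch in plain JAX (not Alpa's runtime API), assuming one visible GPU:

```python
# Minimal sketch, not Alpa's API: swapping is built on explicit
# host<->device copies like these.
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]   # assumes at least one GPU is visible

x = jax.device_put(jnp.ones((4096, 4096)), gpu)  # tensor resident on GPU
x_cpu = jax.device_put(x, cpu)                   # "swap out" to CPU memory
x_gpu = jax.device_put(x_cpu, gpu)               # "swap in" before computing
y = x_gpu @ x_gpu                                # compute uses GPU-resident data
```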

Todo

  • A Local Swapping Runtime
    • [ ] Implement swapping on top of this local runtime. To see how this runtime works, you can run this test case. Currently, all tensors in this env are stored as GPU tensors; to implement swapping, we just need to move some of them to CPU.
    • [ ] Implement necessary optimizations such as overlapping swapping and computation (e.g., pre-fetching); see the sketch after this list.
    • [ ] Swap to disk if CPU memory is not enough
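To illustrate the pre-fetching item above, here is a minimal sketch. The `stages` and `params_cpu` names are hypothetical placeholders, not the local runtime's interface:

```python
# A pre-fetching sketch under assumptions: `stages` is a list of jitted
# per-stage functions and `params_cpu[i]` holds stage i's parameters on the
# CPU device (both names are hypothetical). Because JAX dispatches transfers
# and kernels asynchronously, issuing the next device_put before the current
# stage finishes gives the copy a chance to overlap with computation.
import jax

gpu = jax.devices("gpu")[0]

def run_with_prefetch(stages, params_cpu, x):
    params_gpu = jax.device_put(params_cpu[0], gpu)       # swap in stage 0
    for i, stage in enumerate(stages):
        next_params = None
        if i + 1 < len(stages):
            next_params = jax.device_put(params_cpu[i + 1], gpu)  # prefetch
        x = stage(params_gpu, x)                          # run stage i
        params_gpu = next_params                          # old params can be freed
    return x
```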

References

  • SwapAdvisor: Pushing Deep Learning Beyond the GPU Memory Limit via Smart Swapping
  • Harmony: Overcoming the Hurdles of GPU Memory Capacity to Train Massive DNN Models on Commodity Servers
  • DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale

merrymercy avatar Sep 09 '22 05:09 merrymercy

The key point for swapping in XLA is that all parameters must already be on the GPU when an XlaExecutable is launched. To address this:

  • When the model is not very large, we can split it into more stages so that each stage's parameters can be prepared before the stage starts;
  • When the model is so large that even the parameters of a single transformer layer (or a similar layer) cannot fit in GPU memory at once, we can still split each operator into its own stage, but the auto-sharding pass will then be inefficient. We can
    • split each operator into a stage, but run auto-sharding with multiple stages. To avoid missing optimization opportunities such as fusion, we can split stages not at the JAX level but at the optimized HLO level.
    • modify the HloModule: use custom calls to swap parameters in the HloComputation and replace all parameters with the outputs of such custom calls.
  • When the model is so large that even a single GeMM cannot fit in GPU memory, we need a hand-optimized GeMM kernel that runs GeMM on one sub-matrix while swapping in another (a rough sketch follows this list). The hand-optimized kernel replaces the corresponding HloInstruction. Such a kernel also helps in cases that are not extremely memory-intensive, because it overlaps swapping and computation.
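To make the last point concrete, here is a rough sketch of a blocked GeMM that swaps in the next sub-matrix while the current one is being multiplied. It relies only on JAX's asynchronous dispatch rather than a hand-written kernel, and the names are hypothetical:

```python
# Sketch under assumptions: `a_blocks_cpu` is a list of row blocks of A kept
# on the CPU device, and `b_gpu` already lives on the GPU. Each block is
# copied in while the previous block's matmul is in flight.
import jax
import jax.numpy as jnp

gpu = jax.devices("gpu")[0]

def blocked_matmul(a_blocks_cpu, b_gpu):
    outputs = []
    current = jax.device_put(a_blocks_cpu[0], gpu)        # swap in first block
    for i in range(len(a_blocks_cpu)):
        nxt = None
        if i + 1 < len(a_blocks_cpu):                     # prefetch next block
            nxt = jax.device_put(a_blocks_cpu[i + 1], gpu)
        outputs.append(current @ b_gpu)                   # GeMM on this block
        current = nxt
    return jnp.concatenate(outputs, axis=0)
```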

ZYHowell avatar Sep 09 '22 20:09 ZYHowell

For the infra, you may want to leverage the latest XLA runtime effort. We already implemented a similar solution in TF/PyTorch with heuristic device/data placement and scheduling, and it works well on certain hardware systems.

ff7250 avatar Oct 11 '22 09:10 ff7250

CPU Compute Runtime

  • Add a global configuration option to choose the platform ("cpu" or "gpu") (https://github.com/alpa-projects/alpa/blob/main/alpa/global_env.py)
  • Replace all hard-coded "GPU" with that global configuration (e.g., https://github.com/alpa-projects/alpa/blob/fbcb2abf6cb03215cbfc88a01e4b50199f9dfb7e/alpa/device_mesh.py#L802). There are multiple such places, e.g., alpa/device_mesh.py and alpa/util.py; a sketch follows this list.
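A hypothetical sketch of how the switch could look (the `backend` field name is assumed for illustration; it is not an existing option in alpa/global_env.py):

```python
# Hypothetical sketch: `backend` is an assumed field name, not an existing
# option in alpa/global_env.py.
import jax
from alpa.global_env import global_config

global_config.backend = "cpu"   # or "gpu"

# Wherever the platform is hard-coded today (e.g., in alpa/device_mesh.py or
# alpa/util.py), something like
#     devices = jax.local_devices(backend="gpu")
# would instead read the configured platform:
devices = jax.local_devices(backend=global_config.backend)
```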

merrymercy avatar Oct 11 '22 22:10 merrymercy

I have some similar code in the tpu-support branch

ZYHowell avatar Oct 11 '22 22:10 ZYHowell

@ff7250 Sounds good! Could you give us some pointers to the code and usage?

merrymercy avatar Oct 11 '22 23:10 merrymercy