
[RFC] Add device abstraction to allow devices other than CUDA to be used

delock opened this issue 3 years ago (status: open) • 22 comments

This is a proposal to add a device abstraction to DeepSpeed. Currently DeepSpeed has CUDA hard coded, which means it only works on devices that support the CUDA abstraction. To make more devices work with DeepSpeed, DeepSpeed must stop depending on CUDA directly and instead depend on a device abstraction layer that can support different device types. With this proposal, we support both CUDA devices and Intel GPU devices (through the PyTorch XPU extension). In addition, we support building SYCL kernels for Intel GPU devices through SYCLOpBuilder.

This proposal has the following design goals:

  1. Make DeepSpeed work on both CUDA devices and Intel GPU devices.
  2. Make it easy to extend DeepSpeed to other parties' accelerator devices.
  3. Minimal impact on current DeepSpeed models. Current models still work with DeepSpeed on CUDA devices without modification. Models with CUDA hard coded will need modification to work on both CUDA devices and Intel GPUs.
  4. Use as few if...else... branches as possible when a piece of code needs to support both CUDA devices and Intel GPU devices.

High level design of accelerator abstraction

The high level design and implementation of the accelerator abstraction is based on and extended from #2320:

  1. Use the DeepSpeedAccelerator abstract class to define all accelerator interfaces.
  2. A single global DeepSpeedAccelerator object can be explicitly or lazily initialized and used throughout DeepSpeed code and models to access accelerator functionality. This object is accessed through get_accelerator() and set with set_accelerator().
  3. Concrete accelerator implementations such as CUDA or XPU can live in external modules and be imported by DeepSpeed during initialization.

DeepSpeedAccelerator abstract class

The DeepSpeedAccelerator abstract class defines the interface a concrete accelerator needs to implement. The interfaces fall into the following categories:

  1. Accelerator device name. This covers device strings such as 'cuda', 'cuda:0', etc. The interfaces in this category are device_name() and current_device_name().
  2. Accelerator runtime. This covers torch.cuda.<interface_name> calls such as is_available(), synchronize(), etc.
  3. Tensor operations. This covers tensor operations that depend on the device type. The interfaces in this category are pin_memory() and on_accelerator().
  4. Communication backend. This is used to select the accelerator-specific communication backend, such as 'nccl' for CUDA devices and 'ccl' for XPU devices. The interface in this category is communication_backend_name().
  5. Op builder. This is used to select the op builder for building accelerator kernels. The interface in this category is create_op_builder().
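The five categories can be pictured as one abstract base class. The following is a minimal sketch using Python's `abc` module; the method names follow the list above, but the exact signatures (such as the optional `device_index` argument) and the `ToyCUDAAccelerator` stub are illustrative assumptions, not DeepSpeed's actual code. The stub returns fixed values instead of calling `torch.cuda` so that it runs without PyTorch.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class DeepSpeedAccelerator(ABC):
    """Sketch of the accelerator interface, grouped by the five categories."""

    # 1. Device name
    @abstractmethod
    def device_name(self, device_index: Optional[int] = None) -> str: ...

    @abstractmethod
    def current_device_name(self) -> str: ...

    # 2. Runtime (mirrors torch.cuda.<interface_name>)
    @abstractmethod
    def is_available(self) -> bool: ...

    @abstractmethod
    def synchronize(self, device_index: Optional[int] = None) -> None: ...

    # 3. Tensor operations
    @abstractmethod
    def pin_memory(self, tensor: Any) -> Any: ...

    @abstractmethod
    def on_accelerator(self, tensor: Any) -> bool: ...

    # 4. Communication backend
    @abstractmethod
    def communication_backend_name(self) -> str: ...

    # 5. Op builder
    @abstractmethod
    def create_op_builder(self, class_name: str) -> Any: ...


class ToyCUDAAccelerator(DeepSpeedAccelerator):
    """Hypothetical stub: returns fixed values instead of calling torch.cuda."""

    def __init__(self) -> None:
        self._name = "cuda"
        self._current = 0  # stand-in for torch.cuda.current_device()

    def device_name(self, device_index=None):
        return self._name if device_index is None else f"{self._name}:{device_index}"

    def current_device_name(self):
        return self.device_name(self._current)

    def is_available(self):
        return False  # a real implementation would return torch.cuda.is_available()

    def synchronize(self, device_index=None):
        return None  # a real implementation would call torch.cuda.synchronize()

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def on_accelerator(self, tensor):
        return tensor.device.type == self._name

    def communication_backend_name(self):
        return "nccl"

    def create_op_builder(self, class_name):
        return None  # no op builders are registered in this stub
```

Because every method is abstract, a new accelerator that forgets one interface fails at instantiation time rather than deep inside DeepSpeed.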

Concrete accelerator selection

Concrete accelerator selection is done through deepspeed.accelerator.real_accelerator, which provides two interfaces to set/get the concrete accelerator:

  1. set_accelerator(accel_obj): sets the global accelerator to the given object. This interface can be used at the beginning of a model, before DeepSpeed initialization.
  2. get_accelerator(): gets the global accelerator. If the global accelerator has not been set, it detects whether XPU or CUDA support is present on the system and sets the global accelerator object accordingly. If no accelerator support is detected, it returns the CUDA accelerator object by default.
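The set/get pair above amounts to a module-level global with lazy detection. Below is a minimal sketch under stated assumptions: the `Toy*Accelerator` classes are hypothetical placeholders, XPU support is assumed to be detected by looking for the `intel_extension_for_pytorch` package, and XPU is assumed to be checked before CUDA; the real detection logic may differ.

```python
import importlib.util
from typing import Optional


class ToyCUDAAccelerator:
    """Hypothetical placeholder for the real CUDA accelerator class."""
    name = "cuda"


class ToyXPUAccelerator:
    """Hypothetical placeholder for an XPU accelerator class."""
    name = "xpu"


# Single global accelerator object shared by DeepSpeed code and models.
_accelerator: Optional[object] = None


def set_accelerator(accel_obj) -> None:
    """Explicitly choose the accelerator, e.g. at the start of a model script."""
    global _accelerator
    _accelerator = accel_obj


def _xpu_support_present() -> bool:
    # Assumption: XPU support is detected via Intel Extension for PyTorch.
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None


def get_accelerator():
    """Return the global accelerator, choosing one lazily on first use."""
    global _accelerator
    if _accelerator is None:
        if _xpu_support_present():
            _accelerator = ToyXPUAccelerator()
        else:
            # CUDA detected or nothing detected: fall back to CUDA by default.
            _accelerator = ToyCUDAAccelerator()
    return _accelerator
```

Once an object is set, explicitly or by detection, every later `get_accelerator()` call returns that same object.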

Implement concrete accelerator in external module

A concrete accelerator can be implemented in an external module. The implementation should provide an accelerator class that derives from DeepSpeedAccelerator; an example implementation can be found in cuda_accelerator.py. A model can import this external module, instantiate an accelerator object, and call set_accelerator to make DeepSpeed use this accelerator:

from <external-module-for-accelerator> import <EXTERNAL>_Accelerator
accel = <EXTERNAL>_Accelerator()
from deepspeed.accelerator import set_accelerator
set_accelerator(accel)

Write accelerator specific code in DeepSpeed and model

Accelerator runtime

The accelerator abstraction provides a single entry point for accelerator-specific features, which takes the form:

from deepspeed.accelerator import get_accelerator

<code that uses accelerator functionality> get_accelerator().<interface name>(...)

For existing torch.cuda.<interface name> runtime calls, we convert them as in the following example:

if torch.cuda.is_available():
    ...

-->

if get_accelerator().is_available():
    ...

For CUDA-specific device names such as 'cuda', 'cuda:0', or 'cuda:1', we convert them to get_accelerator().device_name(), get_accelerator().device_name(0), and get_accelerator().device_name(1), respectively.

Converting places where torch.cuda.current_device() is called is a little tricky. torch.cuda.current_device() returns a device index, but if a device index is supplied where PyTorch expects a device, PyTorch interprets it as a CUDA device. To get the current device in a form that can be used as a device name, we call get_accelerator().current_device_name():

my_tensor = torch.empty(3, 4, device=get_accelerator().current_device_name())

We use get_accelerator().current_device() only when an integer is expected:

idx = get_accelerator().current_device()
default_generator = get_accelerator().default_generator(idx)
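The split between the integer and string forms can be sketched as below. `ToyXPUAccelerator` is a hypothetical stand-in that stores its current index instead of querying a device runtime; it shows only the three device-name methods discussed above.

```python
class ToyXPUAccelerator:
    """Hypothetical stand-in; a real one would query the XPU runtime for the index."""

    def __init__(self, current_index: int = 0) -> None:
        self._name = "xpu"
        self._current_index = current_index

    def device_name(self, device_index=None) -> str:
        # No index -> bare device type ('xpu'); with index -> 'xpu:<index>'
        if device_index is None:
            return self._name
        return f"{self._name}:{device_index}"

    def current_device(self) -> int:
        # Integer index, for APIs that expect one (e.g. default_generator(idx))
        return self._current_index

    def current_device_name(self) -> str:
        # Full device string, safe to pass as device=... to PyTorch
        return self.device_name(self._current_index)
```

With this split, a call like `torch.empty(3, 4, device=acc.current_device_name())` receives 'xpu:1' for an accelerator whose current index is 1, whereas passing the bare integer from `current_device()` would be read by PyTorch as a CUDA index.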

Tensor operations

To move a torch tensor to the accelerator device, instead of my_tensor.cuda() we use my_tensor.to(get_accelerator().device_name()).

To check whether a torch tensor is on the accelerator device, instead of my_tensor.is_cuda we use get_accelerator().on_accelerator(my_tensor).

To pin a tensor's host memory (page-locked memory for faster transfer to the device), instead of my_tensor.pin_memory() we use get_accelerator().pin_memory(my_tensor).
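A sketch of the two tensor interfaces is below. It assumes only that a tensor exposes `.device.type` and `.pin_memory()`, as `torch.Tensor` does; `ToyCUDAAccelerator` and `FakeTensor` are hypothetical stand-ins so the example runs without PyTorch.

```python
from types import SimpleNamespace


class ToyCUDAAccelerator:
    """Hypothetical stand-in showing only the tensor-operation interfaces."""

    _name = "cuda"

    def on_accelerator(self, tensor) -> bool:
        # Replaces `tensor.is_cuda`: compare the tensor's device type with ours.
        return tensor.device.type == self._name

    def pin_memory(self, tensor):
        # Replaces `tensor.pin_memory()`. Keeping the call behind the accelerator
        # lets each backend pick the form its PyTorch version supports.
        return tensor.pin_memory()


class FakeTensor:
    """Minimal duck-typed stand-in for torch.Tensor."""

    def __init__(self, device_type: str) -> None:
        self.device = SimpleNamespace(type=device_type)
        self.pinned = False

    def pin_memory(self):
        pinned = FakeTensor(self.device.type)
        pinned.pinned = True
        return pinned
```

Routing `pin_memory` through the accelerator is what lets a backend later change how pinning is requested without touching every call site.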

Communication backend

Wherever a communication backend string is used, the interface get_accelerator().communication_backend_name() is used to get the communication backend name. So instead of torch.distributed.init_process_group('nccl'), we use torch.distributed.init_process_group(get_accelerator().communication_backend_name()).
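The point of this interface is that the call site needs no branch on device type (design goal 4). A minimal sketch, with hypothetical `Toy*Accelerator` classes returning the backend names given in this RFC, and `init_process_group` passed in as a parameter so the example runs without a distributed setup (in real code it would be `torch.distributed.init_process_group`):

```python
from typing import Callable


class ToyCUDAAccelerator:
    def communication_backend_name(self) -> str:
        return "nccl"


class ToyXPUAccelerator:
    def communication_backend_name(self) -> str:
        return "ccl"


def init_distributed(accel, init_process_group: Callable[..., None]) -> None:
    # No if/else on device type: the accelerator object supplies the backend.
    init_process_group(backend=accel.communication_backend_name())
```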

Op builder abstraction

Op builders are abstracted through get_accelerator().create_op_builder(<op builder name>). If the op builder is implemented for the accelerator, an instance of the corresponding OpBuilder subclass is returned. If the op builder is not implemented, None is returned.

For a typical implementation, refer to the CUDA implementation, or to the XPU implementation, which will be released later. A typical call such as CPUAdamBuilder().load() can be converted to get_accelerator().create_op_builder("CPUAdamBuilder").load().
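One way an accelerator could implement `create_op_builder` is a lookup from builder class name to builder class. This is a sketch under assumptions: the dictionary registry, the trimmed `OpBuilder` base, and the stub `CPUAdamBuilder` whose `load()` returns a string are all hypothetical; DeepSpeed's real builders compile or import native extensions.

```python
from typing import Dict, Optional, Type


class OpBuilder:
    """Hypothetical minimal base class; the real OpBuilder does much more."""

    NAME = "base"

    def load(self):
        raise NotImplementedError


class CPUAdamBuilder(OpBuilder):
    """Stub builder; a real one would build or import the compiled extension."""

    NAME = "cpu_adam"

    def load(self):
        return f"{self.NAME} extension loaded"


class ToyCUDAAccelerator:
    # Hypothetical registry of builders available for this accelerator.
    _op_builders: Dict[str, Type[OpBuilder]] = {"CPUAdamBuilder": CPUAdamBuilder}

    def create_op_builder(self, class_name: str) -> Optional[OpBuilder]:
        builder_cls = self._op_builders.get(class_name)
        # None signals that the op is not implemented for this accelerator.
        return builder_cls() if builder_cls is not None else None
```

Since unsupported ops return None, callers that may run on several accelerators should check the result before calling `load()`.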

delock, Aug 16 '22

@jeffra are there any comments on the overall structure, direction, etc. of the device abstraction or selection?

delock, Sep 09 '22

@delock, apologies for the delay. We are still iterating on our thoughts and will sync with you asap. As you might notice, we have linked a PR that builds on yours. Please stay tuned.

tjruwase, Sep 13 '22

@tjruwase Thanks for the reminder. I'll read the PR and raise questions in comments.

delock, Sep 15 '22

The latest accelerator runtime interface adds pin_memory(), which is used to convert the tensor interface: t.pin_memory() --> accel_runtime.pin_memory(t)

In previous code, t.pin_memory() was translated to t.pin_memory(device=accel_runtime.current_device()). However, this only works with the latest PyTorch and does not work with PyTorch 1.8. So instead of selecting the device in the tensor interface, a new runtime interface is added for better backward compatibility.

delock, Sep 22 '22

@delock, I notice you merged #2320. Is the PR already working correctly for you? Thanks!

tjruwase, Sep 26 '22

We merged the class definition and now we are modifying all accel_runtime and literal_device call sites to use get_accelerator(). We are still testing internally before we push the change to this branch.

From what we currently observe, a few more interfaces are needed in DeepSpeedAccelerator:

get_accelerator().device_name(): needed when a device string is supplied, or when DeepSpeed code needs to know the current backend type.

get_accelerator().current_device_name(): will be used to replace most torch.cuda.current_device() calls, unless an integer return value is expected.

Tensor interfaces: get_accelerator().pin_memory(tensor), needed because the PyTorch pin_memory interface has changed since version 1.8; and get_accelerator().on_accelerator(tensor).

Stream should accept 'device' as the first parameter name rather than 'device_index':

-    def Stream(self, device_index=None, priority=0, **kwargs):
+    def Stream(self, device=None, priority=0, **kwargs):

Will push the code when we finish internal testing.

delock, Sep 30 '22

> We merged the class definition and now we are modifying all accel_runtime and literal_device call sites to use get_accelerator(). We are still testing internally before we push the change to this branch.
>
> From what we currently observe, a few more interfaces are needed in DeepSpeedAccelerator:
>
> get_accelerator().device_name(): needed when a device string is supplied, or when DeepSpeed code needs to know the current backend type

We have added .device_name() to obtain the device string, .name() to return the accelerator name, and .communication_backend_name(). Will that be sufficient?

> get_accelerator().current_device_name(): will be used to replace most torch.cuda.current_device() calls, unless an integer return value is expected

Could this be achieved by calling .device_name() with no argument or None, rather than creating a new interface?

> Tensor interfaces: get_accelerator().pin_memory(tensor), needed because the PyTorch pin_memory interface has changed since version 1.8; and get_accelerator().on_accelerator(tensor)
>
> Stream should accept 'device' as the first parameter name rather than 'device_index':
>
> -    def Stream(self, device_index=None, priority=0, **kwargs):
> +    def Stream(self, device=None, priority=0, **kwargs):
>
> Will push the code when we finish internal testing.

Sounds great. Thanks!

tjruwase, Sep 30 '22

@tjruwase if we want to add a workflow for an XPU device that is located outside Azure but remotely accessible, is that technically possible? We want to assess the possibility of gating CUDA-only code with an XPU device.

delock, Oct 13 '22

@tjruwase can I get approval from a maintainer to run workflows for the new changes?

delock, Oct 17 '22

We are working on an OpBuilder abstraction in our internal repo, which allows kernels and SYCLOpBuilder (or any accelerator builder) to be put in a separate extension package. Will add it to this PR when ready.

delock, Oct 17 '22

The OpBuilder abstraction has been added to this PR; we will update the description to explain the mechanism.

delock, Oct 23 '22

Hi, the new accelerator abstraction interface has been integrated, and we also added support for OpBuilder. The new interface definition and its integration code are ready for review now. We will merge with main after internal testing, and after that this PR will be ready for testing.

delock, Nov 08 '22

@delock, thanks for the update. This is awesome!

tjruwase, Nov 08 '22

@tjruwase @jeffra this PR keeps running into conflicts with the master branch. How about merging it in smaller PRs?

Step 1. A PR that merges the interface definition and implementation. This brings the interface into DeepSpeed.
Step 2. In batches of N files, merge the usage of the abstract interface, so each PR is relatively small and easy to test and merge.

This PR will still exist and keep merging with master for two purposes:

  1. Used to validate XPU support from our internal branch before releasing XPU support.
  2. Used as a map showing where we are in the merging process; once every file is merged, this PR will be the same as master and can be closed.

If you are okay with this process, can you review the overall structure and see whether any global changes are still needed before we start to split and submit smaller PRs? Local changes can be reviewed and discussed in each separate PR.

delock, Nov 10 '22

@delock, thanks, your proposal makes sense. Let's do that.

tjruwase, Nov 11 '22

@tjruwase #2504 has been created as the first step of this pull request.

delock, Nov 13 '22

@tjruwase https://github.com/microsoft/DeepSpeed/pull/2560 has been created as step 2 of this PR and is ready for review.

delock, Dec 02 '22

@delock, will review asap. Thanks so much!

tjruwase, Dec 02 '22

@tjruwase Hi, https://github.com/microsoft/DeepSpeed/pull/2677 has been created as step 3 of this PR and is ready for review.

delock, Jan 09 '23

@delock, is this PR still actively developed for merging?

tjruwase, Jan 27 '23

> @delock, is this PR still actively developed for merging?

@tjruwase thanks for asking. We are pulling this PR into our internal repo for testing, to see whether any fixes are needed for the latest main branch changes in tests. Once the tests pass, I'll reply to this PR that it is ready to merge.

delock, Jan 28 '23

> @delock, is this PR still actively developed for merging?

Hi @tjruwase, we have validated this branch in our environment. The latest Intel extension for DeepSpeed works with this PR now. The remaining changes are mainly in benchmarks and tests, so I think we can review this PR directly for the last merge step. This PR will be step 4 of non-CUDA device support in DeepSpeed. After this step, both benchmarks and UTs should be able to run on top of the new accelerator interface in DeepSpeed.

delock, Jan 31 '23

Hi @tjruwase, I'd like to know whether this PR is in the merge queue or still needs some changes.

The DeepSpeed engine is already integrated with the accelerator abstraction; what remains, the integration for tests and benchmarks, is in this PR.

We use UTs to test DeepSpeed functionality on Intel hardware, so it would be better if the test and benchmark integration in this PR is also merged. Thanks!

delock, Mar 03 '23

@delock, this PR is in the merge queue

tjruwase, Mar 07 '23

@delock it feels like we could add a pre-commit hook to ensure that our formatter fails if someone tries to use torch.cuda outside the new get_accelerator api. Similar to this: https://github.com/microsoft/DeepSpeed/blob/80d8fcbdb3f1121bef358dac70476648a3c8ca55/.pre-commit-config.yaml#L36-L43

If you agree, would you want to give this a try?

jeffra, Mar 07 '23
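A check like the one proposed can be as small as a script that scans the files pre-commit passes as arguments and fails when `torch.cuda` appears outside an allow-list. This is a hypothetical sketch, not the hook that was eventually added: the allow-listed path and the message text are assumptions, and a plain regex will also flag mentions of `torch.cuda` in comments and strings.

```python
import re
import sys
from pathlib import Path
from typing import List, Tuple

# Hypothetical allow-list: files that are expected to call torch.cuda directly.
ALLOWED_FILES = {"accelerator/cuda_accelerator.py"}
TORCH_CUDA = re.compile(r"\btorch\.cuda\b")


def find_violations(path: str, text: str) -> List[Tuple[str, int, str]]:
    """Return (path, line number, line) for each torch.cuda use outside the allow-list."""
    if Path(path).as_posix() in ALLOWED_FILES:
        return []
    return [
        (path, lineno, line.strip())
        for lineno, line in enumerate(text.splitlines(), start=1)
        if TORCH_CUDA.search(line)
    ]


def main(paths: List[str]) -> int:
    violations = []
    for path in paths:
        violations.extend(find_violations(path, Path(path).read_text(encoding="utf-8")))
    for path, lineno, line in violations:
        print(f"{path}:{lineno}: use get_accelerator() instead of torch.cuda: {line}")
    # A non-zero exit code makes pre-commit report the hook as failed.
    return 1 if violations else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
```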

> @delock it feels like we could add a pre-commit hook to ensure that our formatter fails if someone tries to use torch.cuda outside the new get_accelerator api. Similar to this:
>
> https://github.com/microsoft/DeepSpeed/blob/80d8fcbdb3f1121bef358dac70476648a3c8ca55/.pre-commit-config.yaml#L36-L43
>
> If you agree, would you want to give this a try?

Thanks @jeffra , Yes, this could be a quick check. Let me give this a try.

Some followups we are thinking of:

  1. Regression check. For now, this could be covered by this hook plus periodic internal testing.
  2. Documentation in DeepSpeed on how to use the get_accelerator() interface to write device-agnostic models, and a guide on how to add a new accelerator.
  3. Clean up variable/function names in DeepSpeed that contain 'cuda'; this item should probably be done in small steps.

delock, Mar 09 '23

https://github.com/microsoft/DeepSpeed/pull/2981 has been created for the pre-commit check. @jeffra

delock, Mar 09 '23