
Introduce a device-agnostic runtime API design

Open guangyey opened this issue 1 year ago • 10 comments

Stack from ghstack (oldest at bottom):

  • #132371
  • #137493
  • #133572
  • -> #132204

Motivation

According to [RFC] A device-agnostic Python runtime API design for stream-based accelerators, this PR intends to introduce a device-agnostic runtime API design. I personally prefer the Simple Version of the APIs, which no longer accept the device type as an input argument. This means we leverage getAccelerator to fetch the current accelerator, and these APIs remain flexible enough to be extended later to scenarios with multiple kinds of accelerators. The design does NOT break the previous design philosophies. I also believe the namespace torch.acc is better: it lets users know that the APIs they are calling run on an accelerator rather than on the CPU, which is important. Meanwhile, we can follow a simple API design principle:

  1. Device-agnostic APIs should be placed under the torch.acc namespace and not accept a device_type optional parameter.
  2. Device-specific APIs should be placed under device-specific submodules.
  3. APIs required by both CPU and accelerators should be placed under the torch namespace and accept a device_type optional parameter.

Also, I list the pros and cons of the Simple Version here:

Pros:

  • torch.acc.foo will have the same input arguments as torch.xxx.foo, bringing a better user experience;
  • more concise, making it easier for developers to write device-agnostic code.

Cons:

  • no obvious drawbacks.

Additional Context

I list the new APIs here:

torch.acc.is_available() -> bool:
torch.acc.current_accelerator() -> str:
torch.acc.device_count() -> int:
torch.acc.current_device() -> int:
torch.acc.set_device(device: Union[torch.device, str, int, None]) -> None:
torch.acc.current_stream(device: Union[torch.device, str, int, None]) -> torch.Stream:
torch.acc.set_stream(stream: torch.Stream) -> None:
torch.acc.synchronize(device: Union[torch.device, str, int, None]) -> None:
torch.acc.DeviceGuard(device: int) -> context manager:
torch.acc.StreamGuard(stream: torch.Stream) -> context manager:
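For illustration, these APIs can be pictured as a thin dispatch layer over whichever backend getAccelerator reports. The sketch below is a hypothetical stand-in in plain Python — register_backend, FakeCudaBackend, and the module-level registry are illustrative inventions, not PyTorch internals.

```python
# Minimal sketch of the dispatch idea behind the proposed APIs: a single
# device-agnostic entry point forwarding to whichever backend the runtime
# reports as the current accelerator. Hypothetical stand-ins throughout.
_BACKENDS = {}

def register_backend(name, backend):
    """Register a hypothetical accelerator backend under a device-type name."""
    _BACKENDS[name] = backend

def current_accelerator():
    """Stand-in for at::getAccelerator(): the single registered backend, if any."""
    return next(iter(_BACKENDS), None)

def device_count():
    """Device-agnostic facade: dispatch to the current accelerator backend."""
    name = current_accelerator()
    return _BACKENDS[name].device_count() if name else 0

class FakeCudaBackend:
    """Illustrative backend object; a real one would wrap driver calls."""
    def device_count(self):
        return 2

register_backend("cuda", FakeCudaBackend())
print(current_accelerator())  # cuda
print(device_count())         # 2
```

Because only the facade touches the registry, user code written against it stays unchanged when a different backend is registered.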

cc @albanD @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

guangyey avatar Jul 31 '24 02:07 guangyey

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/132204

Note: Links to docs will display an error until the docs builds have been completed.

:x: 1 New Failure, 1 Unrelated Failure

As of commit 202d5de9c85ae6fd953051d3a5b1e87946c07a53 with merge base 0efa590d435d2b4aefcbad9014dd5fa75dcf8405:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Jul 31 '24 02:07 pytorch-bot[bot]

From my point of view, compared to DeviceGuardImplInterface, it may be better to implement this device-agnostic API on top of AcceleratorHooksInterface, which is positioned to provide a common interface for new backends.

cc @albanD

fffrog avatar Sep 02 '24 03:09 fffrog

I agree with @FFFrog that implementing this device-agnostic API on top of AcceleratorHooksInterface is better. Reason 1: all these APIs rest on the precondition of at::getAccelerator(), so gathering them together is more reasonable and readable. Reason 2: DeviceGuardImplInterface also covers non-accelerators, like CPU and Meta, while we expect these device-agnostic APIs to concern accelerators only.

wizzniu avatar Sep 02 '24 07:09 wizzniu

Sorry for being late to the party here. I agree with not going with the device module getter.

One thing that I would like to consider is to actually create a torch.acc.XXX namespace to host all of these. There are a few reasons for this:

  • This makes it clear in the library code that this is accessing the accelerator and NOT the cpu/other device.
  • It will make documenting these a lot easier as we can have all of them in a single namespace with a given page
  • It will make the (hopefully rare) device-specific code easier to differentiate as it unambiguously calls into a different submodule.
  • We don't have to have a "device_type" optional argument for each of these APIs

On the other hand, using torch.XXX allows us to:

  • Continue to make a general stance that "device_type not being specified implies accelerator"
  • Not create a new namespace making this easier to "just use" for end users
  • The general device_type optional argument makes it easy to locally use a different device
  • Any device-specific code lives in a submodule, while all device-generic APIs are on torch.
  • We can (and should) make a special page on the docs page to host all of these.

What do you think?

@albanD Thank you very much for your comments. I also believe that torch.acc is better: it lets users know that the APIs they are calling run on an accelerator rather than on the CPU, which is important. Perhaps we can follow a simple principle:

  1. Device-agnostic APIs should be placed under the torch.acc namespace and not accept a device_type optional parameter.
  2. Device-specific APIs should be placed under device-specific submodules.
  3. APIs required by both CPU and accelerators should be placed under the torch namespace and accept a device_type optional parameter.

Do you think this is reasonable enough?

guangyey avatar Oct 03 '24 03:10 guangyey

That sounds good to me! Captured this in https://dev-discuss.pytorch.org/t/python-c-api-rules-for-device-generic-apis/2511 to be able to easily reference it in the future!

albanD avatar Oct 03 '24 15:10 albanD

When porting existing CUDA-specific code to this API, a good practice is to keep the code backward compatible, e.g.:

import torch

# Prefer the device-agnostic module when it exists (newer PyTorch);
# fall back to torch.cuda on older versions that lack torch.acc.
# (Note: torch.cuda exists on every build, so the check must be on torch.acc.)
if hasattr(torch, 'acc'):
    import torch.acc as acc
else:
    import torch.cuda as acc

# device-agnostic on PyTorch with torch.acc, CUDA-specific otherwise
print(acc.is_available())  # check if an accelerator is available

In this way, user code is compatible with PyTorch versions that provide torch.acc and still works on CUDA devices with older versions of PyTorch. This preserves the experience for users who have not yet upgraded PyTorch and avoids breaking workflows on older versions.

delock avatar Oct 08 '24 08:10 delock

@delock , good point. May I know if you are proposing a practice or a feature that we need to implement on the torch side?

EikanWang avatar Oct 08 '24 12:10 EikanWang

@guangyey , is this PR still WIP? If not, please refine the PR title.

EikanWang avatar Oct 08 '24 12:10 EikanWang

@delock , good point. May I know if you are proposing a practice or a feature that we need to implement on the torch side?

It's a practice in user code. But the names under torch.acc need to be consistent with those under torch.cuda for this practice to work. I see most APIs are consistent, but this should be kept in mind when adding new names under torch.acc.
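The name-consistency requirement can be checked mechanically. Below is a hypothetical sketch of such a parity check; the SimpleNamespace objects stand in for real torch.cuda / torch.acc modules so the snippet runs on any Python, and missing_names is an illustrative helper, not a PyTorch API.

```python
from types import SimpleNamespace

def missing_names(reference_mod, candidate_mod):
    """Public names present on reference_mod but absent from candidate_mod."""
    ref = {n for n in vars(reference_mod) if not n.startswith("_")}
    cand = {n for n in vars(candidate_mod) if not n.startswith("_")}
    return sorted(ref - cand)

# Stand-ins: a cuda-like module and an acc-like module missing one name.
cuda_like = SimpleNamespace(is_available=None, device_count=None, synchronize=None)
acc_like = SimpleNamespace(is_available=None, device_count=None)

print(missing_names(cuda_like, acc_like))  # ['synchronize']
```

Run against the real modules, an empty result would mean the hasattr-based fallback pattern above is safe for that API surface.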

delock avatar Oct 08 '24 13:10 delock

When porting existing CUDA-specific code to this API, a good practice is to keep the code backward compatible, e.g.:

import torch

# Prefer the device-agnostic module when it exists (newer PyTorch);
# fall back to torch.cuda on older versions that lack torch.acc.
if hasattr(torch, 'acc'):
    import torch.acc as acc
else:
    import torch.cuda as acc

# device-agnostic on PyTorch with torch.acc, CUDA-specific otherwise
print(acc.is_available())  # check if an accelerator is available

In this way, user code is compatible with PyTorch versions that provide torch.acc and still works on CUDA devices with older versions of PyTorch. This preserves the experience for users who have not yet upgraded PyTorch and avoids breaking workflows on older versions.

torch.acc has the same API names as torch.cuda for consistency, except for the device guard and stream guard, whose names are still under discussion. We can improve the user experience via alias names.

guangyey avatar Oct 08 '24 15:10 guangyey

Hello, I came here by way of #110080 looking for a device type agnostic API.

TL;DR: this is nice, but why aren't the CPU and accelerators all just 'devices'? That would be more in line with what's there now, extending the API to be device-type agnostic instead of adding a new API focused on a subset of all devices. To me this solves a few problems but leaves a number of related concerns unaddressed, and possibly muddies the water further, making the ideal harder to attain.


There are some good ideas here, but I feel this still fragments devices into 'accelerators' and everything else (read: CPU). From the standpoint of someone writing minimal train scripts that should work on any and all devices without requiring loads of conditionals, that still seems problematic.

Instead of accelerators, can't we build an API that covers ALL devices? From CPU, to CPU with extra instructions, to full-blown stream-enabled accelerators?

Some questions that come to mind with this:

  • How do I get my accelerator in the first place? What if there are multiple options on the system?
  • If I have a CPU, an mkldnn-enabled CPU, or another CPU with extensions, AND a GPU or other dedicated 'accelerator' on the system, do I still have to call torch.is_xxx_available for the 'non-accelerator' devices before I can use them, and then use a different API for 'accelerators'?

In my mind, 'accelerator' as a distinction separate from 'CPU' is a near-term idea. Sure, it's stuck around for a bit so far, but really they're all compute devices, and I want a generic API to pick my compute device (know what's available on my system), determine its capabilities (supports streams or not, requires explicit graph execution (mark_step), etc.), and know what subset of the API is usable...

rwightman avatar Oct 21 '24 17:10 rwightman

Hi @rwightman, thanks for your idea.

Here we deliberately place the accelerator APIs under the namespace torch.accelerator to tell users that the APIs they are using are tailored for accelerators, not the CPU.

Instead of accelerators can't we build an API that covers ALL devices? From cpu, cpu /w extra instructions, to full blown stream enabled accelerators?

Considering that the architectures of CPUs and stream-based accelerators differ, I don't think users can run the same code on both CPU and accelerators and achieve good performance.

How do I get my accelerator in the first place? What if there are multiple options on the system?

You can use torch.accelerator.is_available to check whether the workload runs on CPU or an accelerator, and torch.accelerator.current_accelerator to know the device type of the current accelerator, then add device-specific code based on that type if necessary. Currently, PyTorch doesn't support building with multiple kinds of accelerators (like CUDA and XPU) at once on one host; PyTorch declares this limitation in the definition of Accelerator. In the future, if PyTorch needs to support this scenario, we could still extend these accelerator APIs to accept a device type as input.

If I have a CPU, or mkldnn enabled CPU or other CPU + extensions, etc AND a GPU or other dedicated 'accelerator' on a system, do I still have to call torch.is_xxx_available for some 'non-accelerator' devices before I can use them, and then now a different API for 'accelerators'?

Do you mean whether you still need to call is_cuda_available, is_xpu_available, is_hpu_available, is_npu_available, etc. in your model script? This depends on the situation.

  • If you would like to call some generic API, you can use torch.accelerator.is_available instead, like the code below.

if torch.accelerator.is_available():
    s1 = torch.Stream()
    s2 = torch.Stream()
    e = torch.Event()
    torch.accelerator.set_stream(s1)
    # do some generic thing
    torch.accelerator.synchronize()
else:
    # do something for CPU
    pass

  • Otherwise, you have to call is_xxx_available to write device-specific code. We recommend code like the following:

if torch.accelerator.is_available():
    if torch.accelerator.current_accelerator().type == 'cuda':
        # do cuda-specific something
        pass
    elif torch.accelerator.current_accelerator().type == 'xpu':
        # do xpu-specific something
        pass
    elif torch.accelerator.current_accelerator().type == 'npu':
        # do npu-specific something
        pass
    ...

Here we assume we can categorize all devices into two classes: CPU and accelerators. If a new device is neither CPU nor accelerator, we can use the code below to distinguish them.

if torch.accelerator.is_available():
    # do something for accelerators
    pass
elif is_xxx_available():
    # do something for the non-accelerator xxx device
    pass
else:
    # do something for CPU
    pass

guangyey avatar Oct 22 '24 06:10 guangyey

@guangyey having written quite a few train/val scripts, I feel it is possible to support many different types of devices in one script fairly cleanly.

When I refer to capabilities, I look at the snippet you provided (below) and it suggests there are two capability sets: accelerators, which support streams, events, and synchronize; and CPUs, which support none of that. So yes, this is okay to work with, but it assumes that:

  1. all accelerators support streams, events, and synchronize
  2. CPUs never do

That doesn't seem true right now, but maybe the goal is to make it so? I think Habana supports streams, but does XLA + TPU?

What about those graph-based accelerators that need a .mark_step() or equivalent (TPU, HPU, IPU?)?

if torch.accelerator.is_available():
    s1 = torch.Stream()
    s2 = torch.Stream()
    e = torch.Event()
    torch.accelerator.set_stream(s1)
    # do some generic thing
    torch.accelerator.synchronize()
else:
    # do something for CPU

Setting aside 'accelerator' for a moment and sticking with 'devices'

device = torch.device(my_device)
assert device.is_available()

s1, s2 = None, None
if device.has_streams():
    s1 = torch.Stream()
    s2 = torch.Stream()

...

if s1 is not None:
    device.set_stream(s1)

if device.has_graph():
    device.mark_step()

if device.has_events():
    ...

So there you have more fine-grained capabilities; you can still write a generic script (or fail if your script needs an unsupported capability), but you aren't tied to whether it's an hpu or an npu, or, as a script maintainer / modeller, to whether someone's just released a zpu.

I'd really like to be able to write scripts that work with ALL devices. I feel that's possible, and this is moving in a good direction; I'm trying to convince myself it covers all the major corner cases for devices I might need to support...

rwightman avatar Oct 22 '24 13:10 rwightman

@rwightman From my perspective, it isn't easy to define CPU stream/device behavior, because the programming model of the CPU is deeply ingrained in our minds. For example, how do we define set_device for the CPU if there are two sockets in a host? And how do we define a stream for the CPU? Should it bind some CPU thread resources to a CPU stream, represent a CPU process, or do nothing at all?

Currently, if we assume the CPU never does stream/event/synchronization, code written according to GPU programming conventions may not affect data dependencies or cause accuracy issues on CPU. However, I believe it will affect CPU performance, for example, cache utilization. More importantly, if the CPU later gains a design that binds a CPU stream to hardware resources, like a socket/core/thread, accuracy and performance will become unpredictable.

I couldn't agree more with you about writing simple scripts that work with ALL devices. But at this stage, it is hard to unify the behavior of the CPU and the other accelerators. We can start with simple tasks, such as unifying these APIs for the accelerators, since they all have streams/synchronization and these APIs already exist, and then consider further developments based on this design philosophy. Anyway, this PR unifies these APIs for all the accelerators and aims to simplify user scripts, right?

What about those graph-based accelerators that need a .mark_step() or equivalent (TPU, HPU, IPU?)?

At this stage, I think we should leave mark_step as it is and leverage the common generic APIs (set_device/device_count) to unify the scripts. I believe torch.accelerator.set_stream is also suitable for them. Later, we will also unify the allocator APIs, like torch.accelerator.empty_cache, to simplify user scripts. The fine-grained design you proposed will be studied. Perhaps we could provide a generic, more PyTorch-idiomatic API named torch.accelerator.mark_step for these graph accelerators? I appreciate your insights and feedback. Thanks very much.

guangyey avatar Oct 23 '24 10:10 guangyey

I guess my main point is that I feel life would be easier if we stopped thinking in terms of device type and instead focused on capabilities. Keeping the distinction between 'cpu' and 'accelerator' is keeping a foot in the old way of thinking, and it leaves the gate open to more if device == 'cpu' elif device == 'cuda/accelerator/whatever'-type behaviour, which is brittle because it is more likely to result in device-specific scripts...

If devices are just devices and we use an API that is capabilities-focused (regardless of whether a CPU can ever support any of the other capabilities), we will end up with more device-agnostic code. When new devices and device extensions are added, there will be fewer breaks/issues. The accelerator interface will clearly need extension in the future for new capabilities, so not switching now just pushes that down the road and invites more if cpu, elif accelerator, elif accelerator_with_x, etc.

rwightman avatar Oct 23 '24 16:10 rwightman

@rwightman I agree with you that our final goal is making it easier for users to write device-agnostic code. When a user writes synchronize, they already assume the code is running on an accelerator. I am concerned about whether code oriented towards accelerator programming is efficient on the CPU, so we place the APIs under the namespace torch.accelerator to make users aware of that. On the other hand, if there is simple logic suitable for both CPU and accelerators, the user can write code like the below.

# do something
if torch.accelerator.is_available():
    torch.accelerator.synchronize()
# do something

I think this aligns better with the user's programming mindset. I also agree that, as you said, defining set_device and synchronize as no-ops for the CPU can produce simpler code in some situations.
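The no-op-on-CPU idea can be sketched as a tiny wrapper that makes synchronize safe to call unconditionally. Everything below is hypothetical illustration: make_synchronize and the accelerator_available flag stand in for torch.accelerator.is_available() and a real backend synchronize.

```python
def make_synchronize(accelerator_available, backend_sync):
    """Return backend_sync on accelerators, a harmless no-op on CPU-only hosts."""
    if accelerator_available:
        return backend_sync
    return lambda: None  # no-op: nothing to synchronize on CPU

calls = []

# On an "accelerator", synchronize dispatches to the real backend call.
sync = make_synchronize(True, lambda: calls.append("synced"))
sync()
print(calls)  # ['synced']

# On "CPU", the same call site is safe and does nothing.
noop = make_synchronize(False, lambda: calls.append("synced"))
noop()
print(calls)  # still ['synced']
```

With such a wrapper, the `if torch.accelerator.is_available():` guard around every synchronize in user scripts would no longer be needed.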

guangyey avatar Oct 24 '24 11:10 guangyey

Thanks for taking the time to chime in @rwightman !

These are definitely very valid points! I agree with you that a capability-based device module is a good global solution. I think there are 3 pieces here in my mind:

  1. The any-device shared API available on every "device module", as given by torch.get_device_module().
  2. An accelerator-like device shared API
  3. Per-device API (is the hw feature enabled or disabled for example)

I do think there is a need for each layer. Mainly, 3) will always be needed, as there will always be device-specific functionality. 2) is needed because we have a whole set of devices with a shared set of capabilities, and having a single concept for them allows us to a) unify the stack for all of them all the way down (less code, fewer bugs, better consistency), b) have a common language for each of these devices, and c) expose a simpler PrivateUse1 API for out-of-tree backends that want to opt in to this. I also think that a lot of the library code we have today (fsdp, ac, offloading, etc.) has a strong "cpu vs accelerator" idea, so explicitly defining these concepts helps.

Layer 1) is even one step further, since there we need to enforce consistency across all devices.

I'm sure you've seen part of the discussion in the issues above about device module vs custom namespace. My concern with aiming for 1) straight away is that it will take even more alignment work to convince everyone. 2) is a good intermediate state that reduces the amount of alignment needed while providing a good chunk of the benefits.

albanD avatar Oct 24 '24 20:10 albanD

The change sounds good. Doc build needs fixing

Thanks very much, the doc has been fixed.

guangyey avatar Oct 27 '24 03:10 guangyey

"Unrelated failure" @pytorchbot merge -i

guangyey avatar Oct 27 '24 03:10 guangyey

Merge started

Your change will be merged while ignoring the following 1 checks: xpu / win-vs2022-xpu-py3 / build

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging Check the merge workflow status here

pytorchmergebot avatar Oct 27 '24 03:10 pytorchmergebot

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command. For more information see the pytorch-bot wiki.

pytorchmergebot avatar Oct 27 '24 09:10 pytorchmergebot

@pytorchbot merge -f "unrelated failure, Macos job was queuing"

guangyey avatar Oct 27 '24 10:10 guangyey

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging Check the merge workflow status here

pytorchmergebot avatar Oct 27 '24 10:10 pytorchmergebot