
RFC-0023 Unified Memory for Pytorch

Open · jayfurmanek opened this issue 3 years ago · 11 comments

This RFC proposes adding Unified Virtual Memory (UVM), also called “Managed Memory”, support to PyTorch, using the managed memory allocation APIs available in CUDA and ROCm.

The proposed front-end and back-end changes have been kept as small as possible, so that they have a narrowly targeted effect when UVM is enabled and no effect at all when UVM is disabled, which will of course remain the default.

Please note that the details of these proposals are subject to revision based on user feedback and prototype testing. Please feel free to comment on the RFCs with your feedback.

jayfurmanek, Jan 14 '22

Hi @jayfurmanek!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

facebook-github-bot, Jan 14 '22

I signed it!

jayfurmanek, Jan 14 '22

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

facebook-github-bot, Jan 14 '22

Rendered version: https://github.com/ROCmSoftwarePlatform/rfcs/blob/rfc-unified-memory/RFC-0023-Unified-Memory-for-Pytorch.md

albanD, Jan 18 '22

cc @mcarilli

ngimel, Jan 18 '22

@dzhulgakov Do you have any examples of using an allocator extension point? I haven't found much documentation on it, and I may be looking in the wrong places. If you have any docs that clarify this, I'd love to look at them. Thanks!!

dllehr-amd, Jan 27 '22

> Do you have any examples of using an allocator extension point?

For the CPU allocator, the core library has proper wiring that lets you override the allocator used for new tensors: https://github.com/pytorch/pytorch/blob/master/c10/core/Allocator.h#L210. This works because allocations from empty_cpu simply go through GetAllocator (https://github.com/pytorch/pytorch/blob/master/c10/core/CPUAllocator.cpp#L138).
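The override pattern described above can be sketched in a self-contained way. This is a minimal sketch under stated assumptions: `Allocator`, `DeviceKind`, `SetAllocatorFor`, `GetAllocatorFor`, `empty_cpu_like`, and `CountingAllocator` are hypothetical stand-ins that mimic the idea, not c10's actual declarations or signatures. The property that matters is that the tensor-creation path looks the allocator up at call time, so replacing the registry entry redirects every later allocation.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for an allocator interface (not c10::Allocator).
struct Allocator {
  virtual ~Allocator() = default;
  virtual void* allocate(std::size_t nbytes) = 0;
  virtual void deallocate(void* ptr) = 0;
};

// Hypothetical device kinds; only the CPU slot is used in this sketch.
enum class DeviceKind : std::size_t { CPU = 0, GPU = 1, Count = 2 };

// Default CPU allocator backed by malloc/free.
struct MallocAllocator : Allocator {
  void* allocate(std::size_t nbytes) override { return std::malloc(nbytes); }
  void deallocate(void* ptr) override { std::free(ptr); }
};

constexpr std::size_t kSlots = static_cast<std::size_t>(DeviceKind::Count);
using Registry = std::array<Allocator*, kSlots>;

// One slot per device kind; the CPU slot starts with the malloc allocator.
inline Registry& registry() {
  static MallocAllocator default_cpu;
  static Registry slots = {{&default_cpu, nullptr}};
  return slots;
}

inline void SetAllocatorFor(DeviceKind kind, Allocator* alloc) {
  registry()[static_cast<std::size_t>(kind)] = alloc;
}

inline Allocator* GetAllocatorFor(DeviceKind kind) {
  return registry()[static_cast<std::size_t>(kind)];
}

// Analogue of empty_cpu: it never names a concrete allocator, it asks the registry.
inline void* empty_cpu_like(std::size_t nbytes) {
  return GetAllocatorFor(DeviceKind::CPU)->allocate(nbytes);
}

// A user-supplied override that counts allocations before delegating to malloc.
struct CountingAllocator : MallocAllocator {
  int allocations = 0;
  void* allocate(std::size_t nbytes) override {
    ++allocations;
    return MallocAllocator::allocate(nbytes);
  }
};
```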

Unfortunately, I originally forgot that the CUDA allocator is not wired through this interface. Some use cases do reach it via the Allocator interface, but most places call CUDACachingAllocator directly. CUDACachingAllocator is just a namespace, not even a class, at the moment: https://github.com/pytorch/pytorch/blob/72c972e1e1b4ad838de604e35269e200a70db5f2/c10/cuda/CUDACachingAllocator.h#L32 (I remember we wanted to fix that, but sadly never did).

That means that to prototype this RFC right away, we'd need to modify the code in place (i.e., in a temporary fork).

The more extensible course of action would be to turn CUDACachingAllocator into a proper class and allow overriding it with a different implementation, as is already possible for the CPU allocator. Its interface would need to be broader than the base Allocator interface, though, because it also needs recordStream. It's a worthwhile refactoring; would you be interested in taking it on?
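One possible shape for that broader interface, as a hedged sketch: `StreamHandle`, `BaseAllocator`, `DeviceAllocatorInterface`, `record_stream`, and `FakeStreamAwareAllocator` are hypothetical names, and the fake only parks blocks instead of waiting on real stream events. It illustrates why a plain allocate/deallocate interface is not enough: a block that another stream may still be using must not be recycled as soon as its owner frees it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <set>
#include <unordered_map>

// Hypothetical opaque stream identifier (stands in for cudaStream_t).
using StreamHandle = int;

// Hypothetical base interface, mirroring a plain allocator.
struct BaseAllocator {
  virtual ~BaseAllocator() = default;
  virtual void* allocate(std::size_t nbytes, StreamHandle stream) = 0;
  virtual void deallocate(void* ptr) = 0;
};

// The broader interface: adds stream bookkeeping that a plain allocator lacks.
struct DeviceAllocatorInterface : BaseAllocator {
  // Record that `ptr` is also used on `stream`, not only its allocation stream.
  virtual void record_stream(void* ptr, StreamHandle stream) = 0;
};

// Fake implementation: a freed block that was used on other streams is parked
// instead of being released immediately.
class FakeStreamAwareAllocator : public DeviceAllocatorInterface {
 public:
  ~FakeStreamAwareAllocator() override {
    for (auto& kv : blocks_) std::free(kv.first);  // release anything still tracked
  }
  void* allocate(std::size_t nbytes, StreamHandle stream) override {
    void* p = std::malloc(nbytes);
    blocks_[p].alloc_stream = stream;
    return p;
  }
  void record_stream(void* ptr, StreamHandle stream) override {
    Block& b = blocks_.at(ptr);
    if (stream != b.alloc_stream) b.extra_streams.insert(stream);
  }
  void deallocate(void* ptr) override {
    Block& b = blocks_.at(ptr);
    if (!b.extra_streams.empty()) {
      ++parked_;  // a real allocator would wait for work on those streams to finish
      return;     // block stays tracked, so it cannot be handed out again yet
    }
    blocks_.erase(ptr);
    std::free(ptr);
  }
  int parked() const { return parked_; }
  std::size_t live_blocks() const { return blocks_.size(); }

 private:
  struct Block {
    StreamHandle alloc_stream = 0;
    std::set<StreamHandle> extra_streams;
  };
  std::unordered_map<void*, Block> blocks_;
  int parked_ = 0;
};
```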

cc @ezyang

dzhulgakov, Jan 31 '22

https://github.com/pytorch/pytorch/pull/65365 adds an alternative backend for CUDA allocations (and puts the backends in namespaces, THC and CUDAMallocAsync). The design is not finalized (and the THC name definitely has to go), but we can probably decide what we need to do in that PR, since it will need to go into core soon-ish, probably sooner than unified memory.

ngimel, Jan 31 '22

> The more extensible course of action would be to turn CUDACachingAllocator into a proper class

In https://github.com/pytorch/pytorch/pull/65365, I deliberately chose a non-polymorphic design to avoid a vtable lookup and preserve existing inlining opportunities. Each (inline) interface function in CUDACachingAllocator.h initializes a static function pointer to the correct backend on its first call and uses that pointer from then on, for example:

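```cpp
#include <cstring>  // std::strcmp

// The function-local static is initialized on the first call only, so the
// backend is picked once and every later call goes through the cached pointer.
inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
  static auto f = (std::strcmp(allocatorBackend(), "native") == 0) ?
    THC::raw_alloc_with_stream : CudaMallocAsync::raw_alloc_with_stream;
  return f(nbytes, stream);
}
```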

There are only two backends now, but macros could allow extensibility without too much boilerplate. I did it this way because I thought it was neat and accommodated existing implementations cleanly without big refactors. I don't know if avoiding the vtable lookup REALLY makes a performance difference. Probably not, but then again, this code is called often.
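The same dispatch pattern can be exercised without a GPU. In this sketch the two backends are hypothetical stubs that count calls and return sentinel pointers, `Stream` stands in for `cudaStream_t`, and `g_backend_name` is a hypothetical stand-in for however the PR actually obtains the backend name; only the `static auto f = ...` selection mirrors the snippet above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

using Stream = void*;  // hypothetical stand-in for cudaStream_t

// Hypothetical configuration hook feeding allocatorBackend().
const char* g_backend_name = "native";
const char* allocatorBackend() { return g_backend_name; }

// Stub backends: each counts its calls and returns its own sentinel pointer.
namespace THC {
int calls = 0;
char sentinel;
void* raw_alloc_with_stream(std::size_t, Stream) { ++calls; return &sentinel; }
}  // namespace THC

namespace CudaMallocAsync {
int calls = 0;
char sentinel;
void* raw_alloc_with_stream(std::size_t, Stream) { ++calls; return &sentinel; }
}  // namespace CudaMallocAsync

// Non-polymorphic dispatch: the static pointer is chosen on the first call
// and reused for every later call, with no vtable involved.
inline void* raw_alloc_with_stream(std::size_t nbytes, Stream stream) {
  static auto f = (std::strcmp(allocatorBackend(), "native") == 0)
                      ? THC::raw_alloc_with_stream
                      : CudaMallocAsync::raw_alloc_with_stream;
  return f(nbytes, stream);
}
```

In this sketch, changing `g_backend_name` after the first call has no effect, because the pointer was fixed when the static was initialized. That is the trade this form makes for avoiding a vtable: the backend cannot be switched once a process has started allocating.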

That being said, I'm not sure we need a full-blown, separate allocator implementation to accommodate managed memory. My first instinct is that most (or all!) of what a caching managed memory allocator would need is identical to what the native allocator already does for cudaMalloc... so maybe we could just s/cudaMalloc/cudaMallocManaged/g in CUDACachingAllocator.cpp (or add an envvar-based conditional choice between them at every site where cudaMalloc is currently called) and blow imaginary smoke off our finger guns.

We'd need to add a few calls to expose managed memory's prefetch-to-GPU and prefetch-to-CPU functionality, but my point is that I think most of the caching code could be shared with the native allocator.
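A rough sketch of the "share the caching code, swap the raw allocation call" idea, assuming a drastically simplified cache (a sorted free list, no block splitting, no streams, no per-device pools). `HYPOTHETICAL_USE_MANAGED`, `select_raw_alloc`, `CachingAllocator`, and the `*_stub` functions are invented for illustration; both allocation stubs call `std::malloc` so the code runs anywhere, whereas a CUDA build would call `cudaMalloc` and `cudaMallocManaged` respectively. With the variable unset, the non-managed path is used, in line with UVM being off by default.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <map>
#include <unordered_map>

// Raw allocation hooks; both stubs use malloc so the sketch runs without a GPU.
using RawAlloc = void* (*)(std::size_t);
using RawFree = void (*)(void*);

void* device_malloc_stub(std::size_t n) { return std::malloc(n); }   // "cudaMalloc"
void* managed_malloc_stub(std::size_t n) { return std::malloc(n); }  // "cudaMallocManaged"
void raw_free_stub(void* p) { std::free(p); }

// Hypothetical opt-in switch: managed memory only when the variable is "1".
RawAlloc select_raw_alloc() {
  const char* v = std::getenv("HYPOTHETICAL_USE_MANAGED");
  return (v != nullptr && std::strcmp(v, "1") == 0) ? managed_malloc_stub
                                                    : device_malloc_stub;
}

// The caching logic is identical whichever raw allocator it is handed.
class CachingAllocator {
 public:
  CachingAllocator(RawAlloc raw_alloc, RawFree raw_free)
      : raw_alloc_(raw_alloc), raw_free_(raw_free) {}
  ~CachingAllocator() {
    for (auto& kv : free_blocks_) raw_free_(kv.second);
    for (auto& kv : live_sizes_) raw_free_(kv.first);
  }
  void* allocate(std::size_t nbytes) {
    auto it = free_blocks_.lower_bound(nbytes);  // smallest cached block that fits
    if (it != free_blocks_.end()) {
      void* p = it->second;
      live_sizes_[p] = it->first;
      free_blocks_.erase(it);
      return p;
    }
    void* p = raw_alloc_(nbytes);  // cache miss: go to the raw allocator
    ++raw_calls_;
    live_sizes_[p] = nbytes;
    return p;
  }
  void deallocate(void* p) {
    auto it = live_sizes_.find(p);
    free_blocks_.emplace(it->second, p);  // keep the block cached instead of freeing it
    live_sizes_.erase(it);
  }
  int raw_calls() const { return raw_calls_; }

 private:
  RawAlloc raw_alloc_;
  RawFree raw_free_;
  std::multimap<std::size_t, void*> free_blocks_;  // block size -> pointer
  std::unordered_map<void*, std::size_t> live_sizes_;
  int raw_calls_ = 0;
};
```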

mcarilli, Jan 31 '22

Ah, thanks for the insight, @mcarilli. We were just debating whether a proper class for the allocator makes sense and hadn't considered the possible inlining benefits. That's interesting.

> ...so maybe we could just s/cudaMalloc/cudaMallocManaged/g in CUDACachingAllocator.cpp

That was our first instinct as well! There are a few complications, or at least decisions, that have to be resolved. One is that CUDACachingAllocator has a DeviceAllocator for each CUDA device present, and caching happens per device. With managed memory, device placement is abstracted away, so that per-device allocator hierarchy makes less sense. Prefetching and device hints are used to suggest data locality, and those calls need to happen somewhere. Another question is whether we want to allow only some designated tensors to be managed, or just have one big switch that makes all or none of them managed.

Our design proposes a single allocator for all devices (including the CPU). Initially, the caching will look like the existing CUDACachingAllocator, but there would be room for extension if needed (for example, for CPU-focused data loading).
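To illustrate what a single allocator for all devices could mean for caching and prefetch placement, here is a hedged sketch. `UnifiedPool`, `PrefetchLog`, and `kCpu` are hypothetical; `std::malloc` stands in for a managed allocation, a fresh block is treated as host-resident, and `PrefetchLog::prefetch` stands in for a real prefetch call. Each cached block remembers where its pages were last prefetched, so reusing a block on the same device issues no prefetch, while reusing it on another device or the CPU issues exactly one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <map>
#include <unordered_map>
#include <vector>

constexpr int kCpu = -1;  // location id used for host memory in this sketch

// Hypothetical prefetch hook: records target locations instead of migrating pages.
struct PrefetchLog {
  std::vector<int> targets;
  void prefetch(void* /*ptr*/, std::size_t /*nbytes*/, int location) {
    targets.push_back(location);
  }
};

// One pool shared by every device and the CPU, with no per-device sub-allocators.
class UnifiedPool {
 public:
  explicit UnifiedPool(PrefetchLog& log) : log_(log) {}
  ~UnifiedPool() {
    for (auto& kv : free_) std::free(kv.second.ptr);
    for (auto& kv : live_) std::free(kv.first);
  }
  void* allocate(std::size_t nbytes, int location) {
    Block b;
    auto it = free_.lower_bound(nbytes);  // reuse any cached block that fits
    if (it != free_.end()) {
      b = it->second;
      free_.erase(it);
    } else {
      b = Block{std::malloc(nbytes), nbytes, kCpu};  // stand-in for a managed allocation
    }
    if (b.location != location) {  // migrate only when the block lives elsewhere
      log_.prefetch(b.ptr, b.size, location);
      b.location = location;
    }
    live_[b.ptr] = b;
    return b.ptr;
  }
  void deallocate(void* p) {
    auto it = live_.find(p);
    free_.emplace(it->second.size, it->second);  // cache it, remembering its location
    live_.erase(it);
  }

 private:
  struct Block {
    void* ptr = nullptr;
    std::size_t size = 0;
    int location = kCpu;
  };
  PrefetchLog& log_;
  std::multimap<std::size_t, Block> free_;
  std::unordered_map<void*, Block> live_;
};
```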

jayfurmanek, Jan 31 '22

What I'd like to see are some examples or case studies of when you would use UVM as opposed to traditional memory management. What is it good for, and what is a typical user program that uses UVM going to look like? How would you have written this code in straight-line C? If you want to be ROCm-specific, that's fine.

ezyang, Feb 10 '22