[MPS] torchao low-bit-precision optim does not expose 'backend' argument to torch.compile

Open bghira opened this issue 1 year ago • 10 comments

On Apple MPS platforms, torchao training works great until the AdamW8bit optimiser gets involved:

    assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
    torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
    AssertionError: Device mps not supported

    Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

and conditionally setting:

    import torch

    # Attempted workaround: make Dynamo swallow compile errors on MPS
    if torch.backends.mps.is_available():
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True

to suppress this error and 'fall back to eager mode' does not work in this situation; it merely hides the notice that this parameter could be set.

On MPS, aot_eager has to be used instead of inductor, and hard-coding the backend also rules out third-party backends.

Can we expose the backend parameter, perhaps as compile_backend?
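A minimal sketch of what exposing the parameter could look like, assuming a hypothetical `compile_backend` keyword that is simply forwarded to `torch.compile(..., backend=...)`. `resolve_compile_backend` and `compile_step` are illustrative names, not part of torchao's API, and defaulting MPS to `aot_eager` is one possible design choice rather than anything torchao does today.

```python
from typing import Callable, Optional

# Hypothetical default: the backend the low-bit optimizers currently hard-code.
DEFAULT_COMPILE_BACKEND = "inductor"


def resolve_compile_backend(device_type: str, compile_backend: Optional[str] = None) -> str:
    """Choose the torch.compile backend for a low-bit optimizer step.

    An explicit compile_backend always wins. Otherwise MPS falls back to
    "aot_eager" (inductor has no MPS code generator in this report) and
    every other device keeps the current default, "inductor".
    """
    if compile_backend is not None:
        return compile_backend
    if device_type == "mps":
        return "aot_eager"
    return DEFAULT_COMPILE_BACKEND


def compile_step(step_fn: Callable, device_type: str, compile_backend: Optional[str] = None) -> Callable:
    """Wrap a per-parameter update function with torch.compile using the chosen backend."""
    import torch  # imported lazily so the selection logic above works without torch

    backend = resolve_compile_backend(device_type, compile_backend)
    return torch.compile(step_fn, backend=backend)
```

Keeping the selection in a pure function makes the policy easy to test and lets users pass any registered backend name (including third-party ones) without touching the optimizer code. Note that falling back to a non-inductor backend trades away kernel fusion in exchange for being able to run at all.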

bghira, Sep 26 '24 17:09

This should be an easy PR to make. @bghira, would you be interested in taking a stab at this? If you need any advice, we hang out in #torchao on discord.gg/gpumode.

Also curious: is this a production use case? We haven't taken Mac perf super seriously, but hey, if we have users, maybe we should.

msaroufim, Sep 26 '24 18:09

Actually, using aot_eager gets autograd involved, and then dtype complaints happen: the gradients need to be in fp32 precision... for a low-bit optim? 🤔

bghira, Sep 26 '24 18:09

Yeah, SimpleTuner supports fine-tuning diffusion models via torch-mps, with or without optimum-quanto, up to the 12B-parameter Flux model, which really takes advantage of quantisation: memory drops from 30 GB when training in pure bf16 (with stochastic rounding etc.) to 15 GB with int8 quantisation (MPS doesn't support fp8).
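The halving follows from bytes per parameter. A back-of-the-envelope sketch, counting weight storage only for a 12B-parameter model; activations, gradients, optimizer state and framework overhead are deliberately ignored, so the gap between these numbers and the reported 30 GB / 15 GB is everything the sketch leaves out.

```python
# Rough weight-only memory estimate; everything except the weights is ignored.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "int8": 1}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Memory needed to hold the weights, in decimal gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


flux_params = 12e9  # 12B parameters, as stated for Flux
for dtype in ("bf16", "int8"):
    print(f"{dtype}: {weight_memory_gb(flux_params, dtype):.0f} GB")
# bf16: 24 GB
# int8: 12 GB
```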

bghira, Sep 26 '24 18:09

Oh interesting, you're also looking at diffusion models? We now have a working group dedicated to that.

msaroufim, Sep 26 '24 18:09

Either way, I'm not seeing memory savings with the 8-bit AdamW, since I need the gradients to be upcast to fp32. The 4-bit optim uses some ops that aren't implemented in PyTorch's MPS backend yet, and enabling CPU fallback results in:
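Why the savings can vanish is easiest to see per parameter. A sketch that counts one gradient plus AdamW's two state tensors (`exp_avg`, `exp_avg_sq`); weights and activations are left out. The layout of the low-bit states (codes plus one fp32 scale per 256-element block) is an assumption for illustration, and comparing against a bf16-state baseline assumes that is what pure-bf16 training uses here.

```python
def adamw_bytes_per_param(grad_bits: int, state_bits: int,
                          block_size: int = 256, scale_bytes: int = 4) -> float:
    """Bytes per parameter for one gradient plus AdamW's two optimizer states.

    States with 16 or more bits are plain tensors; narrower states are assumed
    to be block-quantized, with one `scale_bytes` scale per `block_size` values.
    """
    grad_bytes = grad_bits / 8
    if state_bits >= 16:
        per_state = state_bits / 8  # plain fp32/bf16 tensor, no per-block scale
    else:
        per_state = state_bits / 8 + scale_bytes / block_size  # codes + amortized scale
    return grad_bytes + 2 * per_state  # AdamW keeps exp_avg and exp_avg_sq


configs = {
    "bf16 grads, bf16 states": (16, 16),
    "bf16 grads, 8-bit states": (16, 8),
    "fp32 grads, 8-bit states": (32, 8),
}
for name, (grad_bits, state_bits) in configs.items():
    print(f"{name}: {adamw_bytes_per_param(grad_bits, state_bits):.3f} bytes/param")
```

Under these assumptions, 8-bit states alone cut the bf16 baseline from 6 to about 4.03 bytes per parameter, but upcasting the gradients to fp32 adds 2 bytes back, landing at about 6.03, i.e. no net saving.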

    (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
    (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %4 = "mps.multiply"(%2, %arg2) : (tensor<16x3072xf32>, tensor<1xbf16>) -> tensor<*xf32>
    (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
    (mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %4 = "mps.multiply"(%2, %arg2) : (tensor<16x3072xf32>, tensor<1xbf16>) -> tensor<*xf32>
    /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
    Traceback (most recent call last)

A glorious MPS-library-level error and a full application fault.

bghira, Sep 26 '24 18:09

Adding a compile backend flag makes sense, though I'm not sure which other backends are also useful for codegen optim.

If I'm not wrong, aot_eager will run in eager mode, so there will be no memory-saving benefit. The memory saving relies on the fact that we dequant and re-quant inside the kernel, so the dequantized tensors never materialize in global memory.
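The fused dequant, update, re-quant pattern can be sketched in plain Python, with lists standing in for tensors. The momentum buffer lives as int8 codes plus one scale per block, and only one block is ever held in float form at a time, which is the property a fused kernel gets by keeping dequantized values in registers. The plain absmax linear quantization and the tiny block size are simplifying assumptions; torchao's actual 8-bit scheme and block size may differ.

```python
from typing import List, Tuple


def quantize_block(values: List[float]) -> Tuple[List[int], float]:
    """Absmax-quantize one block of floats to signed int8 codes in [-127, 127]."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 127 if absmax > 0 else 1.0
    return [max(-127, min(127, round(v / scale))) for v in values], scale


def dequantize_block(codes: List[int], scale: float) -> List[float]:
    """Map int8 codes back to floats using the block's scale."""
    return [c * scale for c in codes]


def fused_momentum_update(codes: List[int], scales: List[float], grads: List[float],
                          beta: float = 0.9, block_size: int = 4) -> None:
    """Update an int8 momentum buffer in place, one block at a time.

    A non-fused version would dequantize the whole buffer first, materializing
    a full-precision copy; here at most `block_size` floats exist at once.
    """
    for b, start in enumerate(range(0, len(codes), block_size)):
        end = min(start + block_size, len(codes))
        m = dequantize_block(codes[start:end], scales[b])  # dequant (temporary)
        m = [beta * mi + (1 - beta) * g for mi, g in zip(m, grads[start:end])]  # EMA update
        codes[start:end], scales[b] = quantize_block(m)  # re-quant and write back


codes = [0] * 8          # zero-initialized momentum, two blocks of four
scales = [1.0, 1.0]
grads = [1.0, -2.0, 0.5, 0.0, 3.0, 3.0, -3.0, 0.25]
fused_momentum_update(codes, scales, grads)
print(codes, scales)
```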

gau-nernst, Sep 26 '24 21:09

So you would need a custom MPS extension for PyTorch that accomplishes what you currently rely on CUDA kernels for?

E.g., following Apple's example: https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_pytorch_operation

They provide a downloadable, compilable sample:

'''
Copyright © 2023 Apple Inc.

See LICENSE folder for this sample’s licensing information.

Abstract:
The code for compiling the custom pytorch extension.
'''

import torch.utils.cpp_extension

compiled_lib = torch.utils.cpp_extension.load(
    name='CustomSoftshrink',
    sources=['CustomSoftshrink.mm'],
    extra_cflags=['-std=c++17'],
)

and the relevant Objective-C++ example code that links into Metal directly:

/*
See the LICENSE.txt file for this sample’s licensing information.

Abstract:
The code that registers a PyTorch custom operation.
*/


#include <torch/extension.h>
#include "CustomSoftshrink.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

torch::Tensor& dispatchSoftShrinkKernel(const torch::Tensor& input, torch::Tensor& output, float lambda) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
...

bghira, Sep 26 '24 22:09

Yes, if you want 8-bit Adam for MPS, you probably need a custom kernel. (Unless torch.compile or Triton supports emitting MPS kernels in the future :eyes:)

Just curious: if you want memory-saving techniques, would something like LoRA/QLoRA be more suitable? Low-bit optim only makes sense when you train the whole big model. And I don't think it's practical to do a full fine-tune of Flux on MPS (yes, you mentioned developing on MPS first and then moving to CUDA, so there are still some valid reasons to have a working low-bit optim on MPS, though you could probably do that testing on CPU instead).
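The reasoning that low-bit optim mainly pays off for full fine-tuning comes down to optimizer state scaling with the number of trainable parameters. A sketch comparing one square linear layer trained fully versus with LoRA; the hidden size of 3072 and rank 16 are illustrative assumptions, not figures from the thread's setup.

```python
def full_trainable_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when the whole weight matrix is updated."""
    return d_in * d_out


def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA trains B (d_out x rank) and A (rank x d_in); the base weight stays frozen."""
    return rank * (d_in + d_out)


d = 3072   # assumed hidden size, for illustration
rank = 16  # assumed LoRA rank
full = full_trainable_params(d, d)
lora = lora_trainable_params(d, d, rank)
print(full, lora, f"{lora / full:.2%}")
# 9437184 98304 1.04%
```

With LoRA the optimizer state for this layer is roughly 1% of the full fine-tune's, so shrinking it further with an 8-bit optimizer saves little in absolute terms.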

gau-nernst, Sep 27 '24 01:09

We are already quantising the whole model, though, and training a LoRA on top of it; actually, a LyCORIS LoKr in this case.

On my 128 GB unit, I can do a full fine-tune with 57 GB usage when ZeRO-3 storage offload is working, and I was hoping for more options.

tinygrad auto-generates Metal kernels 🤷 so it's possible to achieve.

bghira, Sep 27 '24 03:09

@msaroufim I don't know if there are already efforts on this, but a working group for inductor/Triton + MPS might be interesting :eyes: (totally out of scope for this issue, though)

gau-nernst, Sep 27 '24 03:09