[MPS] torchao low-bit-precision optim does not expose 'backend' argument to torch.compile
on apple mps platforms, torchao training works great until we involve the AdamW8bit optimiser:
assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Device mps not supported
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
and conditionally setting:
if torch.backends.mps.is_available():
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
to suppress this error and 'fallback to eager mode' does not work in this situation; it merely hides the notice that this parameter could be set.
aot_eager is required instead of inductor on MPS, and beyond that, a hard-coded backend also limits third-party backends.
Can we expose the backend parameter, perhaps as compile_backend ?
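For illustration, a minimal sketch of how such a parameter could be resolved. The names `resolve_compile_backend` and `compile_backend` are hypothetical and not existing torchao API, and the MPS default of aot_eager only reflects the suggestion above, which may have its own problems:

```python
from typing import Optional

# Assumption: per-device defaults taken from this thread only.
# inductor fails on MPS here, and aot_eager is the proposed substitute.
_DEFAULT_BACKENDS = {"mps": "aot_eager"}


def resolve_compile_backend(device_type: str,
                            compile_backend: Optional[str] = None) -> str:
    """Pick the backend name to hand to torch.compile.

    An explicit `compile_backend` always wins, so third-party backends
    stay usable; otherwise fall back to a per-device default.
    """
    if compile_backend is not None:
        return compile_backend
    return _DEFAULT_BACKENDS.get(device_type, "inductor")


# Hypothetical use inside an optimizer (requires PyTorch):
#   backend = resolve_compile_backend(param.device.type, compile_backend)
#   step_fn = torch.compile(step_fn, backend=backend)
```

The only real PyTorch surface this relies on is the `backend=` keyword of `torch.compile`; everything else is a placeholder for whatever the optimizer constructor ends up accepting.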
This should be an easy PR to make, @bghira would you be interested in taking a stab at this? If you need any advice we hang out on #torchao on discord.gg/gpumode
Also curious: is this a production use case? We haven't taken Mac perf super seriously, but hey, if we have users maybe we should
actually, using aot_eager gets autograd involved and then dtype complaints happen. the gradients need to be in fp32 precision ... for a low bit optim? 🤔
yeah simpletuner supports finetuning diffusion models via torch-mps w/ or w/o optimum-quanto up to the 12B parameter Flux model, which really takes advantage of quantisation, down from 30G at pure bf16 (stochastic rounding etc) training to 15GB with quantisation to int8 (mps doesn't support fp8)
Oh interesting you're also looking at diffusion models? we have a working group now dedicated towards that
either way not seeing memory savings with the 8bit adamw as i need the gradients to be upcast to fp32. the 4bit optim uses some ops not implemented on MPS pytorch yet, and enabling CPU fallback results in:
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %4 = "mps.multiply"(%2, %arg2) : (tensor<16x3072xf32>, tensor<1xbf16>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %4 = "mps.multiply"(%2, %arg2) : (tensor<16x3072xf32>, tensor<1xbf16>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
Traceback (most recent call last)
glorious MPS-library-level error and full application fault
Adding a compile backend flag makes sense, though I'm not sure which other backends are also useful for codegen optim.
If I'm not wrong, aot_eager will run in eager mode, so there will be no memory saving benefits. The memory saving relies on the fact that we do dequant and re-quant inside the kernel, so the dequant tensors will never materialize in global memory.
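To make the dequant/re-quant point concrete, here is a rough pure-Python sketch of one optimizer-state update on a single quantized block. It assumes a simple absmax linear 8-bit scheme, which is a simplification and not necessarily the quantization torchao uses; all function names are made up for illustration.

```python
from typing import List, Tuple


def quantize_block(values: List[float]) -> Tuple[List[int], float]:
    """Store a block as signed 8-bit codes plus one shared absmax scale."""
    scale = max(abs(v) for v in values) or 1.0  # avoid dividing by zero
    codes = [round(v / scale * 127) for v in values]
    return codes, scale


def dequantize_block(codes: List[int], scale: float) -> List[float]:
    """Recover approximate float values from codes and scale."""
    return [c / 127 * scale for c in codes]


def ema_update_quantized(codes: List[int], scale: float,
                         grads: List[float],
                         beta: float = 0.9) -> Tuple[List[int], float]:
    # Step 1: dequantize. A fused kernel keeps these values in registers;
    # in eager mode this is a full-size float temporary in memory.
    state = dequantize_block(codes, scale)
    # Step 2: the usual exponential-moving-average update of the moment.
    state = [beta * s + (1 - beta) * g for s, g in zip(state, grads)]
    # Step 3: re-quantize so only 8-bit codes and the scale are kept.
    return quantize_block(state)


codes, scale = quantize_block([0.5, -1.0, 0.25, 0.0])
codes, scale = ema_update_quantized(codes, scale, grads=[1.0] * 4)
```

In eager mode the list built in step 1 corresponds to a tensor as large as the unquantized state, which is why the peak memory does not drop even though the stored state is 8-bit.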
so you will require a custom mps extension for pytorch which accomplishes the same thing that you currently rely on cuda kernels for?
eg. following apple's example: https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_pytorch_operation
they provide a downloadable compileable sample:
'''
Copyright © 2023 Apple Inc.
See LICENSE folder for this sample’s licensing information.
Abstract:
The code for compiling the custom pytorch extension.
'''
import torch.utils.cpp_extension
compiled_lib = torch.utils.cpp_extension.load(
    name='CustomSoftshrink',
    sources=['CustomSoftshrink.mm'],
    extra_cflags=['-std=c++17'],
)
and the relevant cpp example code that links into Metal directly:
/*
See the LICENSE.txt file for this sample’s licensing information.
Abstract:
The code that registers a PyTorch custom operation.
*/
#include <torch/extension.h>
#include "CustomSoftshrink.h"
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
torch::Tensor& dispatchSoftShrinkKernel(const torch::Tensor& input, torch::Tensor& output, float lambda) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        ...
Yes, if you want 8-bit Adam for MPS, you probably need a custom kernel. (Unless torch.compile or triton support emitting MPS kernels in the future :eyes:)
Just curious, if you want memory-saving techniques, would something like LoRA/QLoRA be more suitable? Low-bit optim only makes sense when you train the whole big model. And I don't think it's practical to do a full fine-tune of Flux on MPS (yes, you mentioned developing on MPS first and then moving to CUDA, so there are still some valid points about having a working low-bit optim on MPS, though you can probably do testing on CPU instead).
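A back-of-the-envelope sketch of why that is, counting only the two Adam moment buffers and ignoring quantization scales, weights, gradients and activations. The 12B figure is the Flux size mentioned earlier in the thread; the 50M adapter size is a made-up number for illustration.

```python
def adam_state_bytes(n_trainable: int, bits_per_value: int) -> int:
    """Adam/AdamW keeps two moment values per trainable parameter."""
    return n_trainable * 2 * bits_per_value // 8


GIB = 1024 ** 3
cases = {
    "full fine-tune (12B params)": 12_000_000_000,
    "adapter (hypothetical 50M params)": 50_000_000,
}

for name, n in cases.items():
    fp32 = adam_state_bytes(n, 32)
    int8 = adam_state_bytes(n, 8)
    saved = fp32 - int8
    print(f"{name}: fp32 {fp32 / GIB:.2f} GiB, "
          f"8-bit {int8 / GIB:.2f} GiB, saved {saved / GIB:.2f} GiB")
```

With only an adapter being trained, the optimizer state is already small, so shrinking it further barely moves total memory; with the whole model trainable, the optimizer state is one of the largest buffers.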
we are already quantising the whole model though with a LoRA - actually, lycoris LoKr in this case.
for my 128G unit, i can do a full finetune with 57G usage when ZeRO3 is working to use storage offload, and was hoping for more options
tinygrad is doing autogen metal kernels 🤷 it's possible to achieve
@msaroufim I don't know if there are already efforts on this, but a working group for inductor/triton+MPS might be interesting :eyes: (totally out of scope of this issue though)