
Custom Metal Kernels from Python

barronalex opened this pull request 1 year ago • 7 comments

This PR lets you define custom Metal kernels through the C++/Python API.

The user provides the body of the Metal kernel, and we use the shapes/dtypes of the arrays it's called with, together with the provided output_shapes/output_dtypes, to determine the function signature.

Metal attributes such as [[thread_position_in_grid]] are automatically added to the function signature if present in the body.

Template parameters can be set with kernel.template(**kwargs).

For example:

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""
A = mx.random.normal(shape=(4, 16)).astype(mx.float16)
kernel = mx.fast.MetalKernel(
    name="myexp",
    source=source,
    grid=(A.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes={"out": A.shape},
    output_dtypes={"out": A.dtype},
    verbose=True,  # prints the full generated kernel
)
kernel.template(T=mx.float32)
out = kernel(inp=A)

Generates:

template <typename T>
[[kernel]] void custom_kernel_myexp(
  const device float16_t* inp [[buffer(0)]],
  device float16_t* out [[buffer(1)]],
  uint3 thread_position_in_grid [[thread_position_in_grid]]) {

    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);

}

template [[host_name("custom_kernel_myexp_float_")]] [[kernel]] decltype(custom_kernel_myexp<float>) custom_kernel_myexp<float>;

@angeloskath's nice 3x3 inversion example from #1238 becomes:

def invert_3x3(A):
    source = """
        int elem = thread_position_in_grid.x;
        int index = elem * 9;
        T a11 = a[index];
        T a12 = a[index + 1];
        T a13 = a[index + 2];
        T a21 = a[index + 3];
        T a22 = a[index + 4];
        T a23 = a[index + 5];
        T a31 = a[index + 6];
        T a32 = a[index + 7];
        T a33 = a[index + 8];
        // determinant via the rule of Sarrus
        T det = (
            a11 * a22 * a33
            + a12 * a23 * a31
            + a13 * a21 * a32
            - a11 * a23 * a32
            - a12 * a21 * a33
            - a13 * a22 * a31
        );
        // adjugate entries divided by the determinant give the inverse
        out[index] = (a22 * a33 - a23 * a32) / det;
        out[index + 1] = (a13 * a32 - a12 * a33) / det;
        out[index + 2] = (a12 * a23 - a13 * a22) / det;
        out[index + 3] = (a23 * a31 - a21 * a33) / det;
        out[index + 4] = (a11 * a33 - a13 * a31) / det;
        out[index + 5] = (a13 * a21 - a11 * a23) / det;
        out[index + 6] = (a21 * a32 - a22 * a31) / det;
        out[index + 7] = (a12 * a31 - a11 * a32) / det;
        out[index + 8] = (a11 * a22 - a12 * a21) / det;
    """
    kernel = mx.fast.MetalKernel(
        name="invert3x3",
        source=source,
        output_shapes={"out": A.shape},
        output_dtypes={"out": A.dtype},
        grid=(A.size // 9, 1, 1),
        threadgroup=(256, 1, 1),
        verbose=True,
    )
    kernel.template(T=A.dtype)
    return kernel(a=A)["out"]

A = mx.random.normal(shape=(16, 3, 3))
A = mx.matmul(A, mx.swapaxes(A, -1, -2))
A_inv = mx.linalg.inv(A, stream=mx.cpu)
A_inv_metal = invert_3x3(A)
mx.allclose(A_inv, A_inv_metal, atol=1e-4)

Fusing the mx.split/mx.concatenate gets you a nice speedup, particularly for large batch sizes: [speedup plot]
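For reference, the unfused version being replaced looks roughly like the following (a sketch in plain MLX ops using the same cofactor math as the kernel above, not the exact code from #1238):

def invert_3x3_unfused(A):
    # Nine tiny slices and many small elementwise launches -- exactly the
    # overhead the fused Metal kernel above avoids.
    a11, a12, a13, a21, a22, a23, a31, a32, a33 = mx.split(
        A.reshape(-1, 9), 9, axis=1
    )
    det = (
        a11 * a22 * a33 + a12 * a23 * a31 + a13 * a21 * a32
        - a11 * a23 * a32 - a12 * a21 * a33 - a13 * a22 * a31
    )
    cof = mx.concatenate(
        [
            a22 * a33 - a23 * a32, a13 * a32 - a12 * a33, a12 * a23 - a13 * a22,
            a23 * a31 - a21 * a33, a11 * a33 - a13 * a31, a13 * a21 - a11 * a23,
            a21 * a32 - a22 * a31, a12 * a31 - a11 * a32, a11 * a22 - a12 * a21,
        ],
        axis=1,
    )
    return (cof / det).reshape(A.shape)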

barronalex avatar Aug 13 '24 18:08 barronalex

Wow, this is fantastic! Just wondering: if this PR gets approved and merged, how hard would it be to support Accelerate in a similar fashion? I can see HUGE potential in this PR!

mzy2240 avatar Aug 14 '24 11:08 mzy2240

Thanks for the comments -- very helpful! Let me get the tests to pass and add a few more updates. Then it should be worth having another look.

barronalex avatar Aug 14 '24 23:08 barronalex

This is really neat!

I haven't looked at the internals yet, but one thing I'm wondering about is why the grid (and a few other parameters) is specified when you construct the kernel rather than when you launch it. Is that a constraint for some other reason, or is there a way to make it more flexible?

awni avatar Aug 15 '24 13:08 awni

Really everything is happening in kernel.__call__, since we can't construct the full function without knowing the names and types of the inputs. I've gone back and forth, but I have grid etc. passed to kernel.__init__ mainly because the syntax looked a little prettier.

Another option would be forgoing the object altogether and just passing everything in one function call:

mx.fast.metal_kernel(
    name="myexp",
    source=source,
    grid=(4, 1, 1),
    threadgroup=(2, 1, 1),
    inputs={"a": a},
    template_args={"T": a.dtype},
    output_shapes={"out": out.shape},
    output_dtypes={"out": out.dtype},
)

Definitely open to suggestions

barronalex avatar Aug 15 '24 14:08 barronalex

I guess maybe a better question is what do you do if you want to use the same kernel with outputs with different shapes?

Do you make a new MetalKernel in that case? Under the hood will it recompile if only the grid changes?

If that's the expected workflow I think the single function version is nicer.

awni avatar Aug 15 '24 14:08 awni

Well, when I was thinking about how to do that, the idea was that we'd simply pass a function from the inputs to the grid dimensions.

We could even define this function to map input shapes to grid dimensions, but unless we need the strides (which we shouldn't), I think using the inputs themselves is more general.
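Something like this, say (hypothetical signature, just to illustrate the idea):

kernel = mx.fast.MetalKernel(
    name="invert3x3",
    source=source,
    # hypothetical: grid computed from the actual inputs at launch time
    grid=lambda inputs: (inputs["a"].size // 9, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes={"out": A.shape},
    output_dtypes={"out": A.dtype},
)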

angeloskath avatar Aug 15 '24 14:08 angeloskath

I'm using a hash of the source and the template arguments to get the host_name, so it won't recompile if you change output_shapes, grid, and/or threadgroup but keep the same source and template.

I did consider having grid accept a function, and then you'd need output_shapes/output_dtypes to accept functions as well. It seemed simpler to let the user write the logic directly, as in the invert_3x3 example above. There shouldn't be much overhead from regenerating the kernel, and it happens at graph construction time anyway, but it may be worth measuring to make sure.
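Roughly the caching idea, as a sketch using standard hashlib (not the actual implementation):

import hashlib

def host_name(name, source, template_args):
    # Only the kernel body and the template instantiation affect the
    # compiled code; grid/threadgroup/output_shapes are launch parameters,
    # so they are left out of the cache key.
    key = source + str(sorted(template_args.items()))
    return f"custom_kernel_{name}_{hashlib.md5(key.encode()).hexdigest()}"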

barronalex avatar Aug 15 '24 16:08 barronalex

Closes #1025

awni avatar Aug 19 '24 13:08 awni

Thanks for the review @awni! I think it's in a better state now but let me know if you're OK with the single function approach.

barronalex avatar Aug 19 '24 19:08 barronalex

> I'm using a hash of the source and the template arguments to get the host_name, so it won't recompile if you change output_shapes, grid, and/or threadgroup but keep the same source and template.
>
> I did consider having grid accept a function, and then you'd need output_shapes/output_dtypes to accept functions as well. It seemed simpler to let the user write the logic directly, as in the invert_3x3 example above. There shouldn't be much overhead from regenerating the kernel, and it happens at graph construction time anyway, but it may be worth measuring to make sure.

Coming back to this for a minute. I'm still not sure whether we should separate kernel generation from kernel calling into two steps. Even if most of the work doesn't happen until you call it, the split feels a bit more intuitive from a usage standpoint, based on the way one typically builds and runs a kernel. For example, your exp custom kernel might look like this instead:

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

# Build
kernel = mx.fast.metal_kernel(
    name="myexp",
    source=source,
) 

# Call
output = kernel(
    inputs={"inp": a},
    template={"T": mx.float32},
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes={"out": a.shape},
    output_dtypes={"out": a.dtype},
)  

Would it potentially be more scalable if the source is quite long, or if we can do more at "compile" time?

Does it feel more complicated from a usage standpoint?

FWIW I'm on board with what you have now, but I just want to discuss this a bit since now's the easiest time to make API changes.

awni avatar Aug 21 '24 14:08 awni

I've been going back and forth as well.

I think I agree that the above is more readable, and it lets us cache the generated kernel when the inputs/outputs are the same, which means the overhead can be kept to a minimum if we're calling it in a tight loop.

I'll change it back.
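With the split API, the build-once/call-many pattern would look something like this (assuming the interface sketched above, where batches is some iterable of input arrays):

# Build once; compilation is cached by source + template.
kernel = mx.fast.metal_kernel(name="myexp", source=source)

for a in batches:
    # Same source/template on every iteration, so no recompilation;
    # only the launch parameters vary.
    out = kernel(
        inputs={"inp": a},
        template={"T": mx.float32},
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes={"out": a.shape},
        output_dtypes={"out": a.dtype},
    )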

barronalex avatar Aug 21 '24 16:08 barronalex