[Feature] custom extensions

Open · petertsoi opened this issue 1 year ago · 1 comment

Currently, if I want to write a new custom operation and use it from Swift, I have to maintain my own fork of MLX. It would be great to be able to bring my own CPU/GPU implementations into my app and still use them the way I can in Python.

petertsoi · Aug 01 '24 05:08

Probably your best bet is to use Metal custom kernels, which will be added in https://github.com/ml-explore/mlx-swift/pull/137. We also have custom functions (which we plan to add to MLX Swift) that enable custom transformations (like vjp/jvp/vmap). Combining the two, you should be able to do everything you can do with a custom extension, except for custom CPU kernels, which we don't have yet.
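As a rough illustration of what a custom function adds, here is a framework-free Swift sketch that pairs a forward computation with a hand-written VJP. It is not the MLX or mlx-swift API. The VJP closure takes (primals, cotangents, outputs), an argument order assumed to mirror MLX Python's custom_function VJP. The CustomFunction type and the straight-through rounding example are hypothetical.

```swift
// Framework-free sketch: a forward function paired with a user-supplied VJP.
// `CustomFunction` is a hypothetical type, not an MLX API.
struct CustomFunction {
    let forward: ([Double]) -> [Double]
    // (primals, cotangents, outputs) -> gradient with respect to the inputs
    let vjp: (_ primals: [Double], _ cotangents: [Double], _ outputs: [Double]) -> [Double]

    /// Runs forward and returns the outputs plus a pullback that maps
    /// output cotangents to input gradients using the custom VJP.
    func callAsFunction(_ x: [Double]) -> (outputs: [Double], pullback: ([Double]) -> [Double]) {
        let y = forward(x)
        let vjp = self.vjp
        return (y, { cotangents in vjp(x, cotangents, y) })
    }
}

// Rounding has zero derivative almost everywhere, so plain autodiff would return zeros.
// A straight-through estimator overrides the VJP to pass the cotangent through unchanged.
let roundSTE = CustomFunction(
    forward: { xs in xs.map { $0.rounded() } },
    vjp: { _, cotangents, _ in cotangents }
)

let (y, pullback) = roundSTE([0.4, 1.6, -2.5])
let grad = pullback([1.0, 2.0, 3.0])
print(y)     // [0.0, 2.0, -3.0]
print(grad)  // [1.0, 2.0, 3.0]
```

In real MLX code, the forward would be a custom Metal kernel operating on MLXArray values. The point is only that the user, rather than autodiff, supplies the gradient rule.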

awni · Sep 30 '24 16:09

@awni Any progress?

It looks like the custom function bridging was implemented in mlx-c but not exposed in mlx-swift. I am really looking forward to using this feature.

kemchenj · Jul 22 '25 15:07

Which item specifically is missing? mlx-swift has custom Metal kernels; what else is missing? Thanks!

davidkoski · Jul 22 '25 15:07

An API like torch.autograd.Function: a way to define both forward and backward, and to save some state for the backward pass.

For context, I am trying to implement Gaussian Splatting with MLX. More specifically, I am porting the logic in https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/9c5c2028f6fbee2be239bc4c9421ff894fe4fbe0/diff_gaussian_rasterization/__init__.py#L44-L141 to MLX.
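Here is a framework-free Swift sketch of the requested pattern. The forward pass stashes what the backward pass needs in a context object, much as torch.autograd.Function does with ctx.save_for_backward. AutogradContext, AutogradFunction and SoftmaxFunction are hypothetical names, and nothing here is MLX API. Softmax only stands in for a real op such as the rasterizer.

```swift
import Foundation

// Hypothetical context object that carries state from forward to backward.
final class AutogradContext {
    private(set) var savedTensors: [[Double]] = []
    func saveForBackward(_ tensors: [Double]...) { savedTensors = tensors }
}

// Hypothetical protocol mirroring the forward/backward split of torch.autograd.Function.
protocol AutogradFunction {
    static func forward(_ ctx: AutogradContext, _ x: [Double]) -> [Double]
    static func backward(_ ctx: AutogradContext, _ gradOutput: [Double]) -> [Double]
}

enum SoftmaxFunction: AutogradFunction {
    static func forward(_ ctx: AutogradContext, _ x: [Double]) -> [Double] {
        let m = x.max() ?? 0              // subtract the max for numerical stability
        let e = x.map { exp($0 - m) }
        let s = e.reduce(0, +)
        let y = e.map { $0 / s }
        ctx.saveForBackward(y)            // backward only needs the output
        return y
    }

    static func backward(_ ctx: AutogradContext, _ g: [Double]) -> [Double] {
        let y = ctx.savedTensors[0]
        // dL/dx_i = y_i * (g_i - sum_j g_j * y_j)
        let dot = zip(g, y).map { $0 * $1 }.reduce(0, +)
        return zip(g, y).map { gi, yi in yi * (gi - dot) }
    }
}

let ctx = AutogradContext()
let y = SoftmaxFunction.forward(ctx, [1.0, 2.0, 3.0])
let dx = SoftmaxFunction.backward(ctx, [1.0, 1.0, 1.0])
print(y)
print(dx)  // all close to 0: a uniform upstream gradient leaves softmax unchanged
```

Saving the output rather than the input means backward never recomputes the exponentials. A rasterizer port would similarly keep intermediate buffers from forward instead of rebuilding them.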

kemchenj · Jul 26 '25 04:07

3DGS uses a PyTorch extension written in C++ and CUDA. The forward and backward functions and the radix sort are implemented in CUDA. It also depends on GLM as a third-party library.

Assuming you write the Metal compute shader code yourself, I suggest using an MLX C++ custom extension. You can refer to the MLX example in mlx/examples/extension (the axpby op and its CMakeLists.txt). Installing it with pip install . generates build/lib.macos-.../mlx_sample_extensions/mlx_ext.metallib.
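A plain Swift reference of the op is handy for checking a hand-written Metal kernel. The sketch below assumes the semantics of the axpby example, z = alpha * x + beta * y. It spells out the op's VJP and JVP, which are simple because the op is linear in x and y. The function names are hypothetical, and this is not the extension's C++ Primitive interface.

```swift
// CPU reference for axpby: z = alpha * x + beta * y (names are hypothetical).
func axpby(_ x: [Float], _ y: [Float], alpha: Float, beta: Float) -> [Float] {
    precondition(x.count == y.count, "x and y must have the same length")
    return zip(x, y).map { alpha * $0 + beta * $1 }
}

// VJP: z is linear in x and y, so each input's gradient is the scaled cotangent.
func axpbyVJP(cotangent: [Float], alpha: Float, beta: Float) -> (dx: [Float], dy: [Float]) {
    (cotangent.map { alpha * $0 }, cotangent.map { beta * $0 })
}

// JVP: tangents are pushed forward through the same linear map.
func axpbyJVP(tangentX: [Float], tangentY: [Float], alpha: Float, beta: Float) -> [Float] {
    axpby(tangentX, tangentY, alpha: alpha, beta: beta)
}

let z = axpby([1, 2, 3], [10, 20, 30], alpha: 2, beta: 0.5)
let (dx, dy) = axpbyVJP(cotangent: [1, 1, 1], alpha: 2, beta: 0.5)
let dz = axpbyJVP(tangentX: [1, 0, 0], tangentY: [0, 1, 0], alpha: 2, beta: 0.5)
print(z)   // [7.0, 14.0, 21.0]
print(dx)  // [2.0, 2.0, 2.0]
print(dy)  // [0.5, 0.5, 0.5]
print(dz)  // [2.0, 0.5, 0.0]
```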

Your current options are the fast custom Metal kernel API (available in mlx-c, mlx-swift, and MLX Python) or an MLX C++ custom extension exposed to Python through nanobind.

I have tried writing MLX C++ and custom-kernel code and importing it into my iOS app project. To do this, I modified this mlx-swift repo to add a new target in Package.swift. The only problem is that this Metal code cannot use #include "mlx/backend/metal/kernels/utils.h" the way axpby.metal does.

Directory structure:

Source/
    Cmlx/
        mlx/
    MyExtension/
        Shaders/
            dummy.metal
        SwiftEntry/
            entry.h
            entry.cpp
        Dummy/
            dummy.h
            dummy.cpp

Package.swift:

products: [
    // ...
    .library(name: "MyMLXExtension", targets: ["MyExtension"]),
],
// ...
targets: [
    // ...
    .target(
        name: "MyExtension",
        dependencies: ["Cmlx"],
        // Metal code under Source/MyExtension/Shaders is bundled into mlx-swift_MyExtension.bundle.
        resources: [.process("Shaders")],
        // Headers in SwiftEntry/ expose some C++ functions to Swift.
        publicHeadersPath: "SwiftEntry",
        // Header search paths are relative to this target's directory (Source/MyExtension),
        // so reaching Cmlx's copies may require paths such as "../Cmlx/mlx".
        cSettings: [
            .headerSearchPath("mlx"),
            .headerSearchPath("mlx-c"),
        ],
        cxxSettings: [
            // Copied from the Cmlx target; not sure all of these are needed.
            .headerSearchPath("mlx"),
            .headerSearchPath("mlx-c"),
            .headerSearchPath("metal-cpp"),
            .headerSearchPath("json/single_include/nlohmann"),
            .headerSearchPath("fmt/include"),
            .define("MLX_USE_ACCELERATE"),
            .define("ACCELERATE_NEW_LAPACK"),
            .define("_METAL_"),
            .define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
            .define("METAL_PATH", to: "\"default.metallib\""),
            .define("MLX_VERSION", to: "\"0.24.2\""),
        ],
        linkerSettings: [
            .linkedFramework("Foundation"),
            .linkedFramework("Metal"),
            .linkedFramework("Accelerate"),
        ]
    ),
]

Maybe we need a Swift binding tool similar to nanobind.

a1091150 · Oct 25 '25 08:10