mlx-swift
[Feature] custom extensions
Currently, if I want to write a new custom operation and use it in Swift, I would have to maintain my own fork of MLX. It would be great to be able to bring my own CPU/GPU implementations in my own app but still be able to use them like I can in Python.
Probably your best bet for this is to use Metal custom kernels, which will be added in https://github.com/ml-explore/mlx-swift/pull/137. We also have custom functions (which we plan to add to MLX Swift) to enable custom transformations (like vjp/jvp/vmap). Combining the two, you should be able to do everything you can do with a custom extension (with the exception of custom CPU kernels, which we don't have yet).
@awni Any progress?
It looks like the custom function bridging was implemented in mlx-c but is not exposed in mlx-swift. I am really looking forward to using this feature.
Which item specifically is missing? mlx-swift has custom metal kernels, what else is missing? Thanks!
An API like torch.autograd.Function: a way to define both forward and backward, and to save some state for backward.
For context, I am trying to implement Gaussian Splatting using MLX; more specifically, I am porting the logic in https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/9c5c2028f6fbee2be239bc4c9421ff894fe4fbe0/diff_gaussian_rasterization/init.py#L44-L141 to MLX.
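To make the request concrete, here is a minimal sketch of the shape such an API could take: a forward pass that returns its output together with saved state, and a backward pass that consumes that state. Everything here (the DifferentiableOp protocol, the Square op, plain [Float] arrays instead of MLXArray) is hypothetical and for illustration only; it is not mlx-swift API.

```swift
// Hypothetical sketch: none of these names exist in mlx-swift.
// Plain [Float] arrays stand in for MLXArray to keep this self-contained.
protocol DifferentiableOp {
    associatedtype Saved
    /// Computes the output and returns whatever state backward will need.
    func forward(_ input: [Float]) -> (output: [Float], saved: Saved)
    /// Maps the upstream gradient to the gradient with respect to the input.
    func backward(_ gradOutput: [Float], saved: Saved) -> [Float]
}

/// Example op: elementwise square, y = x * x, so dy/dx = 2x.
struct Square: DifferentiableOp {
    func forward(_ input: [Float]) -> (output: [Float], saved: [Float]) {
        // Save the input, since the gradient depends on it.
        (output: input.map { $0 * $0 }, saved: input)
    }

    func backward(_ gradOutput: [Float], saved: [Float]) -> [Float] {
        // Chain rule: upstream gradient times the local derivative 2x.
        zip(gradOutput, saved).map { $0 * 2 * $1 }
    }
}

let op = Square()
let (y, saved) = op.forward([1, 2, 3])
let dx = op.backward([1, 1, 1], saved: saved)
print(y)   // [1.0, 4.0, 9.0]
print(dx)  // [2.0, 4.0, 6.0]
```

In a rasterizer port, forward would launch the custom Metal kernel and save buffers such as the sorted point list, and backward would launch the gradient kernel using them.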
3DGS uses a PyTorch extension written in C++ and CUDA. The forward and backward functions and the radix sort are implemented in CUDA, and it also includes glm as a third-party dependency.
Assuming you write the Metal compute shader code yourself, I suggest using an MLX C++ custom extension. You can refer to the MLX example in mlx/examples/extension, in particular axpby and CMakeLists.txt. Running pip install . generates build/lib.macos-.../mlx_sample_extensions/mlx_ext.metallib.
Your current options: in mlx-c, mlx-swift, and MLX Python, use the fast (custom Metal) kernel API; in MLX C++, write a custom kernel and bind it to Python with nanobind.
I have tried writing MLX C++ and C++ custom kernel code and then importing it into my iOS app project: modify this mlx-swift repo and add a new target in Package.swift. The only problem is that the Metal code cannot use #include "mlx/backend/metal/kernels/utils.h" the way axpby.metal does.
Directory structure:
    Source/
        Cmlx/
            mlx
        MyExtension/
            Shaders/
                dummy.metal
            SwiftEntry/
                entry.h
                entry.cpp
            Dummy/
                dummy.h
                dummy.cpp
Package.swift:
    .library(name: "MyMLXExtension", targets: ["MyExtension"]),
    // ...
    targets: [
        .target(
            name: "MyExtension",
            dependencies: ["Cmlx"],
            resources: [.process("Shaders")],  // Your metal code under Source/MyExtension/Shaders directory, they will be bundled into mlx-swift_MyExtension.bundle
            publicHeadersPath: "SwiftEntry",  // expose some c++ functions to Swift.
            cSettings: [
                .headerSearchPath("mlx"),
                .headerSearchPath("mlx-c"),
            ],
            cxxSettings: [
                // Copied from Cmlx target, not sure if needed.
                .headerSearchPath("mlx"),
                .headerSearchPath("mlx-c"),
                .headerSearchPath("metal-cpp"),
                .headerSearchPath("json/single_include/nlohmann"),
                .headerSearchPath("fmt/include"),
                .define("MLX_USE_ACCELERATE"),
                .define("ACCELERATE_NEW_LAPACK"),
                .define("_METAL_"),
                .define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
                .define("METAL_PATH", to: "\"default.metallib\""),
                .define("MLX_VERSION", to: "\"0.24.2\""),
            ],
            linkerSettings: [
                .linkedFramework("Foundation"),
                .linkedFramework("Metal"),
                .linkedFramework("Accelerate"),
            ]
        ),
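For completeness, an app package could then depend on the modified checkout roughly as follows. This is only a sketch under assumptions: the app name MyApp, the local path ../mlx-swift, and the tools version are hypothetical, and it assumes the modified package still exports the MLX product alongside the new MyMLXExtension product declared above.

```swift
// swift-tools-version:5.9
// Hypothetical consumer manifest; names and paths are placeholders.
import PackageDescription

let package = Package(
    name: "MyApp",
    dependencies: [
        // Points at the locally modified mlx-swift checkout (assumed location).
        .package(path: "../mlx-swift")
    ],
    targets: [
        .executableTarget(
            name: "MyApp",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                // The product added by the Package.swift change above.
                .product(name: "MyMLXExtension", package: "mlx-swift"),
            ]
        )
    ]
)
```

If the headers in SwiftEntry expose C++ rather than plain C declarations, the consuming target would presumably also need C++ interoperability enabled; that part is not covered here.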
Maybe we need a Swift binding tool, something like nanobind.