[Feature] Expose something like custom VJP in Python
Consider allowing a function to have a custom VJP function attached to it.
Oops, I thought I had exposed that. It might be worth having a high-level API for that as well (which is probably why I didn't expose it), maybe implemented in Python using the custom_vjp transform? Either of the following two feels more natural to use, imho:
Option a:

import mlx.core as mx

@mx.custom_function
def my_multiply(a, b):
    return a * b

@my_multiply.vjp
def my_multiply_vjp(primals, outputs, cotangents):
    # d(a*b)/da = b and d(a*b)/db = a, scaled by the incoming cotangent
    a, b = primals
    cotan = cotangents[0]
    return cotan * b, cotan * a
Option b:

import mlx.core as mx

@mx.custom_function
def my_multiply(a, b):
    def my_vjp(primals, outputs, cotangents):
        # d(a*b)/da = b and d(a*b)/db = a, scaled by the incoming cotangent
        a, b = primals
        cotan = cotangents[0]
        return cotan * b, cotan * a
    return a * b, my_vjp
I lean towards the first one. Obviously names and details need some changing.
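For concreteness, here is a minimal sketch of how option (a) might be exercised once it is wired into the existing transforms. This is purely hypothetical: the API above is only a proposal and the behavior of mx.grad routing through the registered VJP is an assumption.

# Hypothetical usage: the standard transforms would be expected to route
# through the VJP registered on my_multiply above.
a, b = mx.array(2.0), mx.array(3.0)
da, db = mx.grad(my_multiply, argnums=[0, 1])(a, b)
# expected: da == 3.0 (i.e. b) and db == 2.0 (i.e. a)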
I think the interface you propose is good, but if it becomes PyTorch compatible, it might be useful when migrating.
Anyway, I am eagerly waiting for this feature to be implemented.
Here is an example of a torch.autograd.Function-compatible interface:
import mlx.core as mx

class MxCustomVjp:
    @classmethod
    def apply(cls, *args, **kwargs):
        # Create a context, run the user-defined forward, and return
        # (outputs, ctx) so the ctx can be used for the backward pass later.
        ctx = cls()
        ctx.save_for_backward(*args)  # default; forward may overwrite this
        r = ctx.forward(ctx, *args, **kwargs)
        return (r if isinstance(r, tuple) else (r,)), ctx

    def save_for_backward(self, *args):
        self.saved_tensors = args

    def vjp(self, grad_output):
        # grad_output is a tuple of cotangents, one per output
        return self.__class__.backward(self, *grad_output)

    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *args, **kwargs):
        raise NotImplementedError


class CustomLinear(MxCustomVjp):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight.T
        grad_weight = x.T @ grad_output
        return grad_x, grad_weight
pair2x = lambda a, b: (a * 2, b * 2)
mxMean = lambda a: a.mean()

def test_forward(x, weight):
    r_pair2x = pair2x(x, weight)
    r_lin, ctx_lin = CustomLinear.apply(*r_pair2x)
    r_mean = mxMean(*r_lin)
    # r_pair2x is not needed for the backward pass
    backward_bucket = (r_lin, ctx_lin)
    return r_mean, backward_bucket

# I want to do the same process as autograd, but manually
def manual_value_and_grad(x, weight):
    # forward
    r, backward_bucket = test_forward(x, weight)
    # backward: chain the VJPs in reverse order
    _, grad_mean = mx.vjp(mxMean, backward_bucket[0], (mx.ones_like(r),))
    grad_lin = backward_bucket[1].vjp(grad_mean)
    _, grad = mx.vjp(pair2x, (x, weight), grad_lin)
    return r, grad

x = mx.random.normal((1, 3))
weight = mx.random.normal((x.shape[-1], 3))

(r, _), grad = mx.value_and_grad(test_forward, [0, 1])(x, weight)
r_m, grad_m = manual_value_and_grad(x, weight)

assert mx.allclose(r, r_m).item()
assert mx.allclose(grad[0], grad_m[0]).item()
assert mx.allclose(grad[1], grad_m[1]).item()
First, I want to thank you for all the incredible work you have done on MLX. It has been an invaluable tool for my projects.
I am currently working on a project that heavily relies on custom gradients, and the custom vjp functionality discussed here would be extremely beneficial for my use case. Are there any plans to implement this feature in the near future?
Or could you please let me know if there are any corresponding interfaces in the underlying C++ layer that I could use to implement a simple version of custom vjp? I am using mlx-swift, and I am considering directly utilizing the C++ interfaces to meet my needs.
Thank you once again for your hard work and dedication.
@kemchenj we have a custom_vjp transformation in C++.
@angeloskath regarding the Python interface: I think option A is nicer since it enables any transformation to be customized, which is preferable.
I'm not so crazy about the name custom_function. Maybe customizable or even extension. Does this make sense:
my_extendable_fun = mx.extension(my_fun)
my_extendable_fun.vjp = ...
my_extendable_fun.eval_cpu = ...
my_extendable_fun.vmap = ...
Presumably it will fall back to the default VJP / other transforms if they are not implemented?
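To illustrate the fallback idea, here is a rough pure-Python sketch; it is not MLX's implementation, and the ExtensionFn name and the use of mx.vjp as the default are assumptions.

import mlx.core as mx

class ExtensionFn:
    # Hypothetical wrapper: use a registered override if present,
    # otherwise fall back to the default transform.
    def __init__(self, fun):
        self.fun = fun
        self._vjp = None

    def __call__(self, *args):
        return self.fun(*args)

    def vjp(self, fun):
        # Decorator used to register a custom VJP.
        self._vjp = fun
        return fun

    def compute_vjp(self, primals, cotangents):
        if self._vjp is not None:
            outputs = self.fun(*primals)
            return self._vjp(primals, outputs, cotangents)
        # Fallback: let MLX compute the default VJP.
        _, vjps = mx.vjp(self.fun, list(primals), list(cotangents))
        return vjps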
This is available in v0.16. See the documentation for more information.
If you have any feedback or run into any issues, please let us know; as this is a new feature, there may be some kinks that need to be ironed out.
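For reference, a minimal usage sketch; the exact signature of the registered VJP callback may differ from what is shown here, so please check the custom_function documentation.

import mlx.core as mx

@mx.custom_function
def my_multiply(a, b):
    return a * b

@my_multiply.vjp
def my_multiply_vjp(primals, cotangent, output):
    # VJP of a * b: d/da = b and d/db = a, scaled by the cotangent.
    # The (primals, cotangent, output) order is assumed; verify in the docs.
    a, b = primals
    return cotangent * b, cotangent * a

a, b = mx.array(2.0), mx.array(3.0)
da, db = mx.grad(my_multiply, argnums=[0, 1])(a, b)
# expected: da == 3.0 and db == 2.0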