[FEATURE REQUEST] mx.grad doesn't alias argnums and argnames
Not a bug, more of a nice-to-have (or at least a better-error-message, please) kind of thing.
If you differentiate a function, you can't pass an argument in as a kwarg, even if the function is differentiated w.r.t. that argument's argnum:
import mlx.core as mx
import mlx.nn as mn
a = mx.array([1., 2, 3, 4])
xent_grad = mx.grad(mn.losses.cross_entropy)
xent_grad(logits=a, targets=mx.array(0))
# ^- doesn't work :-(((
# ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 arguments.
xent_grad(a, targets=mx.array(0))
or calling xent_grad = mx.grad(mn.losses.cross_entropy, argnames='logits')
are both fine, of course.
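For anyone curious why the keyword call fails, the mechanism can be sketched in plain Python. This is a toy stand-in, not MLX's actual implementation: a grad-by-argnums transform indexes into the positional *args tuple, so arguments passed purely by keyword are invisible to it.

```python
# Toy sketch of the failure mode (NOT mlx's real implementation):
# a grad-like transform that differentiates w.r.t. a positional index.
def fake_grad(fun, argnums=0):
    def wrapper(*args, **kwargs):
        # Keyword arguments never land in `args`, so an all-keyword
        # call sees len(args) == 0 and the index lookup fails.
        if argnums >= len(args):
            raise ValueError(
                f"[grad] Can't compute the gradient of argument index "
                f"{argnums} because the function is called with only "
                f"{len(args)} arguments."
            )
        # A real implementation would trace fun and differentiate here;
        # the toy just calls through.
        return fun(*args, **kwargs)
    return wrapper

def loss(logits, targets):
    return sum((l - t) ** 2 for l, t in zip(logits, targets))

g = fake_grad(loss)
g([1.0, 2.0], [0.0, 0.0])  # positional call: index 0 resolves fine
try:
    g(logits=[1.0, 2.0], targets=[0.0, 0.0])  # all-keyword call fails
except ValueError as e:
    print(e)
```

The toy reproduces the same shape of error the issue reports: the index is valid for the function, but not for the (empty) positional-argument tuple.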
An ergonomic solution would be to let me do either -- e.g.
xent_grad = mx.grad(mn.losses.cross_entropy, argnums=0, argnames='logits')
xent_grad(logits=a, targets=mx.array(0))
This doesn't work atm, even though argnums=0 and argnames='logits' refer to the same argument.
I think a good first step is to improve the error message. Even just adding the word positional
would probably help a bit, as in:
# ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 positional arguments.
What the right thing to do for args/kwargs is somewhat ambiguous. I can see an argument for being more relaxed. I can also see an argument for matching the argnums/argnames to the function parameters as we do now.
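One way the "more relaxed" option could work, sketched with only the standard library (a hypothetical helper, not how MLX resolves these internally): normalize both argnums and argnames against the wrapped function's signature, so index 0 and the name 'logits' collapse to the same parameter instead of being treated as distinct.

```python
import inspect

def resolve_params(fun, argnums=(), argnames=()):
    """Hypothetical helper: map argnums and argnames to one canonical
    set of parameter names. An index and the name of the same
    parameter collapse to a single entry."""
    params = list(inspect.signature(fun).parameters)
    names = {params[i] for i in argnums}  # indices -> names
    names |= set(argnames)                # union with explicit names
    return names

# Stand-in with the same leading parameters as mn.losses.cross_entropy.
def cross_entropy(logits, targets, weights=None, axis=-1):
    ...

# argnums=0 and argnames='logits' resolve to the same parameter.
print(resolve_params(cross_entropy, argnums=(0,), argnames=("logits",)))
```

With a normalization step like this, `mx.grad(fn, argnums=0, argnames='logits')` would be a harmless redundancy rather than an error.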
@awni I'd like to start working on core mlx issues, but I think I can start with this first.
I think a good first step is to improve the error message. Even just adding the word positional would probably help a bit, as in:
# ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 positional arguments.
This works for the above message, but what would be a more general way to improve error messages across the package? Should we try matching torch's error messages? Those are tried and tested.
This works for the above message but what would be a more general way to improve error messages across the package?
For now, let's take them on a one-by-one basis. If you see something that you think should be improved, feel free to file an issue and we can discuss. For this issue, I would just look at the grad error message.
Alright then, let's do that! Seeing there's no existing PR for this, I'd like to be assigned. I'll make a PR shortly.
UPDATE: it does what you expect it to do.
(mlx-dev) shubham@Shubhams-MBP mlx % python dummy.py
Traceback (most recent call last):
File "/Users/shubham/Documents/workspace/forks/mlx/dummy.py", line 5, in <module>
xent_grad(logits=a, targets=mx.array(0))
ValueError: [grad] Can't compute the gradient of argument index 0 because the function is called with only 0 positional arguments.
If we make all the arguments have a default value and then just check whether the user has provided a value other than the default, that should allow accepting positional arguments, right?
def cross_entropy(
    logits: mx.array = None,
    targets: mx.array = None,
    weights: Optional[mx.array] = None,
    axis: Optional[int] = -1,
    label_smoothing: Optional[float] = 0.0,
    reduction: Optional[Reduction] = "none",
) -> mx.array:
    if logits is None or targets is None:
        raise ValueError("Both logits and targets must be provided and cannot be None.")