mlx
Scatter gradient implementations
Not sure if this is an issue or I'm just holding it wrong :-) I've got an alternative approach that builds up the output by concatenating slices of my tensors, and that does seem to work for the gradients.
I have two tensors (token_embeddings and image_features) that I want to interleave in some way. It's for a LLaVA-like model, so my input tokens look like
[
[this, is, an, <image>],
[two, images, <image>, <image>],
]
where <image> is a sentinel token. What I've got so far: my output is a flattened tensor of shape [batch * max_seq, dim], my token and image tensors have shapes [num_tokens, dim] and [num_images, dim], and I form indices to assign to, ending up with something like this:
output = mx.zeros((batch * max_seq, dim))
output[token_ix] = self.embeddings(token_ids)
output[image_ix] = image_features
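For anyone following along, here is a minimal sketch of how token_ix and image_ix could be formed from the token ids. It uses NumPy purely as a stand-in for the index bookkeeping (the same boolean-mask logic should carry over to mx arrays, but that is an assumption, not something tested here). IMAGE_ID, embedding_table, and the toy shapes are hypothetical names and values chosen for illustration.

```python
import numpy as np

IMAGE_ID = -1  # hypothetical sentinel id standing in for <image>

# Two sequences of length 4, matching the example above.
token_ids = np.array([[5, 6, 7, IMAGE_ID],
                      [8, 9, IMAGE_ID, IMAGE_ID]])
batch, max_seq = token_ids.shape
dim = 3

# Flatten so that one index addresses one (batch, position) slot.
flat_ids = token_ids.reshape(-1)
is_image = flat_ids == IMAGE_ID
token_ix = np.flatnonzero(~is_image)  # slots that receive token embeddings
image_ix = np.flatnonzero(is_image)   # slots that receive image features

# Stand-ins for self.embeddings(token_ids) and the vision encoder output.
embedding_table = np.arange(10 * dim, dtype=np.float32).reshape(10, dim)
token_embeddings = embedding_table[flat_ids[token_ix]]        # [num_tokens, dim]
image_features = np.full((len(image_ix), dim), -1.0, dtype=np.float32)

# Scatter both sources into the flattened output, then restore the batch axis.
output = np.zeros((batch * max_seq, dim), dtype=np.float32)
output[token_ix] = token_embeddings
output[image_ix] = image_features
output = output.reshape(batch, max_seq, dim)
```

This assumes image features are ordered the same way the sentinels appear in the flattened token ids, i.e. row by row, left to right.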
It doesn't seem like I can get gradients for this though:
def fn(x):
    output = mx.zeros([4])
    output[mx.array([0, 1])] = x
    return output.sum()

x = mx.array([2,3])
print(fn(x))

grad_fn = mx.grad(fn)
print(grad_fn(x))
gives me
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[310], line 10
      7 print(fn(x))
      9 grad_fn = mx.grad(fn)
---> 10 print(grad_fn(x))

ValueError: Primitive's vjp not implemented.
Like I said, concatenating seems to be a workable alternative, since:
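import mlx.core as mx

def fn(x, y):
    # Insert all rows of y between the first and second rows of x.
    val = mx.concatenate([x[:1, :], y[:, :], x[1:, :]])
    return val.sum()

x = mx.array([[2,3], [4, 5]])
y = mx.array([[2,3], [4, 5]])
print(fn(x, y))

# mx.grad differentiates with respect to the first argument (x) by default.
grad_fn = mx.grad(fn)
print(grad_fn(x, y))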
seems to work.
Scatter vjp is indeed missing from main at the moment. I started an implementation in the branch scatter-vjp so that you can be unblocked. The VJP for assignment is the only one implemented so far and not particularly well tested.
Depending on the number of assignments and the size of the assigned blocks, though, concatenating slices might not be too bad an implementation.
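To make that trade-off concrete, here is a rough sketch of the concatenation route for the interleaving case: the output is rebuilt from contiguous runs of tokens and images, so the number of concatenated pieces equals the number of runs rather than the number of rows. It is written in NumPy only to show the bookkeeping; interleave_by_concat and its arguments are hypothetical names, and it assumes a non-empty input.

```python
import numpy as np

def interleave_by_concat(flat_ids, token_embeddings, image_features, image_id):
    """Rebuild the flattened output from contiguous runs instead of scattering.

    Returns the interleaved array and the number of concatenated pieces.
    Assumes flat_ids is non-empty.
    """
    is_image = flat_ids == image_id
    # Positions where a run of tokens switches to a run of images or back.
    starts = np.flatnonzero(np.diff(is_image.astype(np.int8)) != 0) + 1
    bounds = np.concatenate([[0], starts, [len(flat_ids)]])

    pieces = []
    next_token, next_image = 0, 0  # read cursors into the two sources
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        run_len = hi - lo
        if is_image[lo]:
            pieces.append(image_features[next_image:next_image + run_len])
            next_image += run_len
        else:
            pieces.append(token_embeddings[next_token:next_token + run_len])
            next_token += run_len
    return np.concatenate(pieces), len(pieces)

# Same layout as [[t, t, t, <image>], [t, t, <image>, <image>]], flattened.
flat_ids = np.array([5, 6, 7, -1, 8, 9, -1, -1])
token_embeddings = np.arange(10, dtype=np.float32).reshape(5, 2)
image_features = -np.ones((3, 2), dtype=np.float32)

by_concat, num_pieces = interleave_by_concat(
    flat_ids, token_embeddings, image_features, image_id=-1
)

# Reference result using index assignment.
by_scatter = np.zeros((8, 2), dtype=np.float32)
by_scatter[flat_ids != -1] = token_embeddings
by_scatter[flat_ids == -1] = image_features
```

With only a handful of images per sequence the piece count stays small; with many short alternating runs the scatter version is likely the better fit.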
On that branch the following worked fine:
import mlx.core as mx

def fn(x):
    output = mx.zeros([4, 128])
    output[mx.array([0, 1])] = x
    return output.square().sum()

x = mx.random.normal(shape=(2, 128))
print(fn(x))
# array(274.507, dtype=float32)

grad_fn = mx.grad(fn)
print(grad_fn(x))
# array([[1.08186, -1.9959, 1.18484, ..., 2.52877, 0.80407, 0.529424],
#        [-0.74177, -1.64652, 0.113564, ..., -1.66786, 1.31119, -0.58756]], dtype=float32)
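As a sanity check on what the gradient should be: the output is just x written into two rows of a zero array, so the loss reduces to sum(x**2) and the expected gradient is 2 * x. More generally, for assignment with unique indices, the gradient with respect to the assigned values is the incoming cotangent gathered at those indices. The sketch below checks this with central differences in NumPy; it is an independent check of the math, not the code from the branch.

```python
import numpy as np

def loss(x):
    # Same computation as fn above, written with NumPy.
    output = np.zeros((4, 128))
    output[np.array([0, 1])] = x
    return np.square(output).sum()

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 128))

# Expected gradient: d/dx sum(x**2) = 2 * x.
analytic_grad = 2.0 * x

# Central finite differences, one input element at a time.
eps = 1e-6
numeric_grad = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    step = np.zeros_like(x)
    step[idx] = eps
    numeric_grad[idx] = (loss(x + step) - loss(x - step)) / (2 * eps)

max_abs_error = np.abs(numeric_grad - analytic_grad).max()
```

Comparing grad_fn(x) against 2 * x in the same way is a cheap test for the assignment VJP.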
Closing as this is now merged. Scatter multiply, max, and min are still missing, but assign and add are there.