
Scatter gradient implementations

Open c0g opened this issue 1 year ago • 1 comments

Not sure if this is an issue or I'm just holding it wrong :-) I do have an alternative approach that works for gradients: building up my output by concatenating slices of my tensors.

I have two tensors (token_embeddings and image_features) that I want to interleave. It's for a LLaVA-like model, so my input tokens look like

   [
       [this, is, an, <image>],
       [two, images, <image>, <image>],
   ]

where <image> is a sentinel token. What I've got so far is making my output a flattened tensor of shape [batch * max_seq, dim], my token/image tensors [num_tokens, dim] and [num_images, dim], and forming indices to assign to, ending up with something like this:

output = mx.zeros([batch * max_seq, dim])
output[token_ix] = self.embeddings(token_ids)
output[image_ix] = image_features

It doesn't seem like I can get gradients for this though:

import mlx.core as mx

def fn(x):
    output = mx.zeros([4])
    output[mx.array([0, 1])] = x
    return output.sum()

x = mx.array([2, 3])
print(fn(x))

grad_fn = mx.grad(fn)
print(grad_fn(x))

gives me

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[310], line 10
      7 print(fn(x))
      9 grad_fn = mx.grad(fn)
---> 10 print(grad_fn(x))

ValueError: Primitive's vjp not implemented.

Like I said, it seems like there's an alternative in concatenating things, since:

def fn(x, y):
    val = mx.concatenate([x[:1, :], y[:, :], x[1:, :]])
    return val.sum()

x = mx.array([[2, 3], [4, 5]])
y = mx.array([[2, 3], [4, 5]])
print(fn(x, y))

grad_fn = mx.grad(fn)
print(grad_fn(x, y))

seems to work.
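For reference, the concatenation workaround does extend to the interleaving use case above. Here is a minimal NumPy sketch of that pattern (the `interleave` helper and the `IMAGE_SENTINEL` id are mine, purely for illustration, not part of MLX or any model code):

```python
import numpy as np

IMAGE_SENTINEL = -1  # hypothetical sentinel id marking an <image> slot

def interleave(token_ids, token_emb, image_features):
    """Build a [seq, dim] output by concatenating one-row slices, so
    every output row is a slice of one of the inputs and gradients can
    flow back through the concatenation (no scatter needed)."""
    pieces, img = [], 0
    for t in token_ids:
        if t == IMAGE_SENTINEL:
            pieces.append(image_features[img:img + 1])  # next image row
            img += 1
        else:
            pieces.append(token_emb[t:t + 1])           # one token row
    return np.concatenate(pieces, axis=0)

# toy setup: vocabulary of 4 tokens, 2 image features, dim = 3
token_emb = np.arange(12, dtype=np.float32).reshape(4, 3)
image_features = 100 + np.arange(6, dtype=np.float32).reshape(2, 3)

out = interleave([0, 2, IMAGE_SENTINEL, IMAGE_SENTINEL],
                 token_emb, image_features)
print(out.shape)  # (4, 3)
```

The same slicing-and-concatenating structure should carry over to `mx.array` inputs, since MLX supports slicing and `mx.concatenate`.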

c0g avatar Dec 24 '23 22:12 c0g

Scatter vjp is indeed missing from main at the moment. I started an implementation in the branch scatter-vjp so that you can be unblocked. The VJP for assignment is the only one implemented so far, and it is not particularly well tested.

Depending on the number of assignments and the size of the assigned blocks, however, concatenating slices might not be too bad an implementation.
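To see why both routes work, note that the VJP of a pure scatter-assignment is just a gather of the output cotangent at the assigned indices (the untouched zeros contribute nothing). A small NumPy sketch of that math with a finite-difference check, written from scratch here rather than taken from MLX's implementation:

```python
import numpy as np

def fn(x, idx, n=4):
    out = np.zeros(n, dtype=x.dtype)
    out[idx] = x                  # scatter-assign into zeros
    return np.square(out).sum()

def grad_fn(x, idx, n=4):
    out = np.zeros(n, dtype=x.dtype)
    out[idx] = x
    cot = 2.0 * out               # cotangent of sum-of-squares
    return cot[idx]               # VJP of scatter-assign = gather

x = np.array([2.0, 3.0])
idx = np.array([0, 1])
g = grad_fn(x, idx)
print(g)  # [4. 6.]

# finite-difference check of the analytic gradient
eps = 1e-5
fd = np.array([
    (fn(x + eps * np.eye(2)[i], idx) - fn(x - eps * np.eye(2)[i], idx))
    / (2 * eps)
    for i in range(2)
])
```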

On that branch the following worked fine:

import mlx.core as mx

def fn(x):
    output = mx.zeros([4, 128])
    output[mx.array([0, 1])] = x
    return output.square().sum()

x = mx.random.normal(shape=(2, 128))
print(fn(x))
# array(274.507, dtype=float32)

grad_fn = mx.grad(fn)
print(grad_fn(x))
# array([[1.08186, -1.9959, 1.18484, ..., 2.52877, 0.80407, 0.529424],
#        [-0.74177, -1.64652, 0.113564, ..., -1.66786, 1.31119, -0.58756]], dtype=float32)

angeloskath avatar Dec 25 '23 00:12 angeloskath

Closing as this is now merged. Scatter multiply, max, and min VJPs are still missing, but assign and add are there.
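One subtlety these VJPs have to handle is duplicate indices: under scatter-add every write contributes to the output, so a repeated index receives gradient once per write, whereas under scatter-assign only the last write to each index survives and receives gradient. A NumPy illustration of that bookkeeping (my own sketch, not MLX's code):

```python
import numpy as np

idx = np.array([0, 0, 1])           # index 0 is written twice
x = np.array([2.0, 3.0, 4.0])
cot = np.ones(4)                    # cotangent of out.sum()

# scatter-add: out[0] accumulates 2 + 3, so both writes affect the
# output and the VJP is a plain gather of the cotangent at idx
grad_add = cot[idx]

# scatter-assign: out[0] ends up 3 (last write wins), so only the
# last write to each index receives any gradient
last_writer = {int(i): k for k, i in enumerate(idx)}  # index -> last write
grad_assign = np.zeros_like(x)
for i, k in last_writer.items():
    grad_assign[k] = cot[i]

print(grad_add)     # [1. 1. 1.]
print(grad_assign)  # [0. 1. 1.]
```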

angeloskath avatar Jan 09 '24 22:01 angeloskath