[Feature request] torch.topk-like ops support
Currently, in order to fine-tune the MoE model, we have to use stop_gradient on mx.argpartition when selecting the top-k experts. Otherwise, it throws an error during fine-tuning because mx.argpartition doesn't support backpropagation. Please correct me if I am wrong, but my understanding is that with this approach we won't be able to effectively fine-tune the experts' MLP layers, since the gradients don't flow through the top-k selection. Wondering if, by any chance, we can implement an operation similar to torch.topk that supports backpropagation, to enable proper fine-tuning of MoE models.
Related issue in mlx-examples -> https://github.com/ml-explore/mlx-examples/issues/394
So, in general, one can't trivially define the gradients with respect to the indices of any operation (argmax, argmin, argpartition, categorical, etc.). We have to decide on the type of approximation we will use.
From the following snippet you will see that the gradient propagates to the "experts" even though they are selected using argpartition.
import mlx.core as mx
import mlx.nn as nn


class SwitchLayer(nn.Module):
    def __init__(self, dims, experts, k):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(dims, experts)
        # lazy way to initialize the weight
        self.experts = nn.Linear(dims, experts * dims, bias=False)
        self.out = nn.Linear(dims, dims)

    def __call__(self, x):
        k = self.k
        g = mx.softmax(self.gate(x), -1)
        gk = mx.stop_gradient(mx.argpartition(g, k, -1)[..., :k])
        dims = x.shape[-1]
        experts = g.shape[-1]
        w = self.experts.weight.reshape(experts, dims, dims)[gk]
        w = w.reshape(-1, k * dims, dims)
        y = x[..., None, :] @ w.swapaxes(-2, -1)
        y = y.reshape(-1, k, dims)
        # This can be omitted or we can renormalize the weights or anything else.
        # Gradients will flow to the experts selected by gk.
        # y = y * mx.take_along_axis(g, gk, -1)[..., None]
        y = y.sum(-2)
        return self.out(y)


m = SwitchLayer(128, 8, 2)
x = mx.random.normal(shape=(16, 128))
mx.eval(m.parameters(), x)


def loss(x):
    return m(x).sum()


loss_and_grad = nn.value_and_grad(m, loss)
s, g = loss_and_grad(x)
mx.eval(s, g)
The above is the most common approximation used I think. Others would be to re-normalize the weights or do some form of principled sampling to estimate the gradient as if we were computing the value for all the experts.
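As a quick sanity check (a small addition of mine, assuming the nested parameter-dict layout that mlx.nn modules expose), you can inspect the returned grads: the expert weights get a non-zero gradient, while the gate gets none because its output is only used through the stop_gradient-ed indices.

# Appended after loss_and_grad above.
print(g["experts"]["weight"].abs().sum())  # non-zero: gradients reach the experts
print(g["gate"]["weight"].abs().sum())     # zero: the gate output is unused downstream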
@angeloskath Thanks for the detailed explanation, but it's kind of beyond my knowledge. Just to clarify: based on the example above, it seems like the issue is not with stop_gradient on argpartition. We may just need to update how we apply the forward pass to the selected experts in the Mixtral example, in order to ensure that the gradient flows correctly to the experts. Please let me know if I have misunderstood this.
Kind of, yes. In the example above, omitting the multiplication actually makes the gradients to the gate zero, but adding it back in makes them flow to the gate as well, so it can be learned.
I would do one of the following:

# Assigning the weights of the selected experts to gw for brevity
gw = mx.take_along_axis(g, gk, -1)[..., None]

# The commented line from the snippet above; uncommenting it makes the
# gradients flow to the gate.
y = y * gw

# The exact output as if the line stayed commented, but the gradients
# propagate to the gate as if it were uncommented.
y = y * (gw - mx.stop_gradient(gw - mx.ones_like(gw)))

# Renormalize the weights using a softmax. A temperature can be added to
# make it more or less selective towards the max weight.
tau = 1.0
y = y * mx.softmax(gw * tau, -2)

# Plain old renormalization
y = y * (gw / gw.sum(-2, keepdims=True))
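To see why the stop_gradient expression above keeps the forward value unchanged while still letting a gradient reach the gate, here is a tiny standalone check (my own sketch, not part of the snippet above):

import mlx.core as mx

def f(gw):
    # The multiplier evaluates to exactly 1 in the forward pass, but its
    # derivative with respect to gw is 1 instead of 0.
    return (gw - mx.stop_gradient(gw - mx.ones_like(gw))).sum()

gw = mx.array([0.7, 0.3])
print(f(gw))           # 2.0 -> each multiplier is exactly 1
print(mx.grad(f)(gw))  # [1, 1] -> the gate still receives gradients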
Another way to think about this is that no matter what you do, the experts will get some gradient because you are using their output. However, for the gate to get gradients, you need to use its output somehow and not just the indices it predicts (via topk, sampling or anything else).
The gradient of something like an argmax would be zeros almost everywhere. I don't think that's really the point of this discussion (which is very nice), but probably we should have that as the VJP for consistency with other functions which have discontinuous gradients (e.g. mx.equal).
Edit: let me ask it as a question, since I'm not certain: should we have zeros as the default gradient for argmax and friends?
This is what I see JAX does, FWIW:
import jax
import jax.numpy as jnp

def fun(x):
    return jnp.argmax(x)

x = jnp.array([1.0, 2.0, 3.0])
out, vjf = jax.vjp(fun, x)
print(vjf(jnp.array(1)))
Gives all zeros.
(Array([0., 0., 0.], dtype=float32),)
Same for argpartition.
Clearly this is not helpful at all for @mzbac's use case, since it doesn't change the flow of gradients to the gates. It is effectively a stop gradient on the argpartition.
I believe that not implementing the gradient and asking people to add a stop gradient is a bit better because
- we don't have to build that part of the backward pass, which would just result in an op full of 0s
- it makes it very clear to the user that this is not possible, so they should wrap the call with stop_gradient (see the sketch below)
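For what it's worth, the user-side wrapper is a one-liner anyway. A hypothetical helper (the name topk_indices and the negation for picking the largest gates are mine) could look like:

import mlx.core as mx

def topk_indices(scores, k):
    # The indices carry no useful gradient, so wrap the selection explicitly.
    # Negate the scores so that argpartition picks the k largest gates.
    return mx.stop_gradient(mx.argpartition(-scores, kth=k, axis=-1)[..., :k])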
Now, on the other hand, implicitly stopping the gradient is the behavior in other frameworks. For instance:

import torch

x = torch.rand(10).requires_grad_(True)
print(x.argmax().requires_grad)  # False
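And torch.topk itself behaves the same way: gradients flow through the returned values, while the indices are integral and carry none (a small sketch of my own):

import torch

x = torch.rand(10, requires_grad=True)
values, indices = x.topk(2)
values.sum().backward()
print(indices.requires_grad)  # False: the indices carry no gradient
print(x.grad)                 # 1.0 at the two selected positions, 0 elsewhere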
While I tend to agree with you, I also find the inconsistency with other zero-grad ops a bit incongruous (e.g. a > b). Maybe a good compromise is to default to a stop gradient on the outputs?
It's a bit implicit but maybe preferable to forcing a stop grad everywhere those ops are used.
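For comparison, this is roughly how the a > b analogy plays out (a small sketch of my own; the boolean mask contributes no gradient, much like an implicit stop_gradient):

import mlx.core as mx

def f(x):
    # The mask from (x > 1) selects values but contributes no gradient of
    # its own; gradients flow only through the selected entries of x.
    return mx.where(x > 1, x, mx.zeros_like(x)).sum()

x = mx.array([0.5, 2.0, 3.0])
print(mx.grad(f)(x))  # [0, 1, 1]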