Regarding experts.w1 and experts.w2 gradients

Open · MukundVarmaT opened this issue 3 years ago · 1 comment

Hi @lucidrains

I am a little confused about how the parameters experts.w1 and experts.w2 are updated. The top1 operation is non-differentiable, so I expected the gradients of these two parameters to be None. To confirm this, I ran the following:

import torch
from torch import nn
from mixture_of_experts import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 512 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # use your preferred activation, will default to GELU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    loss_coef = 1e-2                # multiplier on the auxiliary expert balancing auxiliary loss
).cuda()
inputs = torch.randn(4, 1024, 512).cuda()
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
aux_loss.backward()
for name, param in moe.named_parameters():
    if param.grad is None:
        print(name)

which gave the following output:

experts.w1
experts.w2

It would be really helpful if you could clarify my understanding. Thanks!

MukundVarmaT, Jan 21 '22 05:01

Did you try it on CPU? It works for me. The top1 operation is non-differentiable, but the balance loss is computed from the logits of the gating distribution and the number of tokens routed to each expert, so the gradient of the weights should not actually be None.

rattlesnakey, Jul 13 '23 14:07
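The gradient flow discussed above can be illustrated with a small, self-contained PyTorch sketch. `ToyTop1MoE` is a hypothetical model written only for this illustration, not the `mixture_of_experts` implementation. Its balance loss, built from per-expert token fractions and mean gate probabilities, is an assumption modeled on the reply's description. In this toy model the top-1 `index` is an integer tensor, so the expert weights `w1` and `w2` are reached only through `out`, while the balance loss reaches only the gating weights. Backpropagating `aux_loss` alone therefore leaves the expert weights' `.grad` as `None`, and adding a loss on `out` fills them in.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyTop1MoE(nn.Module):
    """Hypothetical top-1 MoE for illustration; not the mixture_of_experts code."""

    def __init__(self, dim=8, num_experts=4, hidden_dim=16):
        super().__init__()
        self.num_experts = num_experts
        self.w_gating = nn.Parameter(torch.randn(dim, num_experts))
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim) * 0.1)
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim) * 0.1)

    def forward(self, x):
        # x: (tokens, dim)
        probs = F.softmax(x @ self.w_gating, dim=-1)  # (tokens, experts)
        # top-1: `index` is an integer tensor and carries no gradient,
        # but `gate_value` is still a differentiable function of w_gating
        gate_value, index = probs.max(dim=-1)
        # route each token through its chosen expert's weights
        hidden = F.gelu(torch.einsum("td,tdh->th", x, self.w1[index]))
        expert_out = torch.einsum("th,thd->td", hidden, self.w2[index])
        out = gate_value.unsqueeze(-1) * expert_out
        # assumed balance loss: token fraction per expert (constant) times mean gate probability
        density = F.one_hot(index, self.num_experts).float().mean(dim=0)
        density_proxy = probs.mean(dim=0)
        aux_loss = (density * density_proxy).mean() * self.num_experts ** 2
        return out, aux_loss


def params_without_grad(model):
    # names of parameters that received no gradient from the last backward()
    return sorted(name for name, p in model.named_parameters() if p.grad is None)


torch.manual_seed(0)
moe = ToyTop1MoE()
x = torch.randn(32, 8)

# 1) Backpropagate the auxiliary loss alone, as in the issue
out, aux_loss = moe(x)
aux_loss.backward()
print(params_without_grad(moe))  # ['w1', 'w2']

# 2) Backpropagate a loss on `out` together with the auxiliary loss
moe.zero_grad(set_to_none=True)
out, aux_loss = moe(x)
(out.sum() + aux_loss).backward()  # out.sum() stands in for a real task loss
print(params_without_grad(moe))  # []
```

Here `out.sum()` is only a placeholder for a real task loss; in training, the auxiliary loss would typically be added to that task loss before calling `backward()`.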