SwitchTransformers

[BUG] Dimension error within SwitchGate

Open · Masaaki-75 opened this issue 8 months ago · 1 comment

Describe the bug
A shape mismatch occurs when SwitchGate computes its auxiliary loss: https://github.com/kyegomez/SwitchTransformers/blob/36a1ea01448e56242222b68201207a7219d72b4b/switch_transformers/model.py#L70-L74

Here load has shape [seq_len, num_experts] while importance has shape [batch_size, num_experts] (for the input below, [64, 3] and [2, 3]), so the two cannot be subtracted elementwise. Calling SwitchGate on its own with an input whose batch_size > 1 raises:

RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 0
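The mismatch can be reproduced with plain PyTorch tensors, without the library. This is a minimal sketch that assumes the gate scores for a 3D input have shape [batch_size, seq_len, num_experts] and that load and importance are their sums over dims 0 and 1; those reductions are inferred from the error message above, not copied from model.py.

```python
import torch

batch_size, seq_len, num_experts = 2, 64, 3

# Stand-in for the gate scores of a [batch_size, seq_len, dim] input (assumed layout)
gate_scores = torch.softmax(torch.randn(batch_size, seq_len, num_experts), dim=-1)

load = gate_scores.sum(0)        # reduce over batch    -> [seq_len, num_experts]
importance = gate_scores.sum(1)  # reduce over sequence -> [batch_size, num_experts]

error_message = None
try:
    # Broadcasting fails: dim 0 is 64 on one side and 2 on the other
    loss = ((load - importance) ** 2).mean()
except RuntimeError as err:
    error_message = str(err)

print(tuple(load.shape), tuple(importance.shape))  # (64, 3) (2, 3)
print(error_message)
```

With batch_size = 1, importance becomes [1, 3] and broadcasts silently against [64, 3], which is why the error only surfaces for batch_size > 1.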

To Reproduce
Run the gate on a sample with batch_size > 1:

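import torch
from switch_transformers.model import SwitchGate

gate = SwitchGate(dim=16, num_experts=3)
x = torch.randn(2, 64, 16)  # (batch_size, seq_len, dim); batch_size = 2 triggers the error
y, loss = gate(x, use_aux_loss=True)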
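For comparison, the Switch Transformer paper (Fedus et al., 2021) defines a load-balancing loss that reduces everything to per-expert vectors before combining them, so it is shape-safe for any number of leading token dimensions. The sketch below follows that paper's formulation (scaled by num_experts, without the paper's alpha coefficient); it is not the repository's code or an official fix. The function name load_balancing_loss is hypothetical, and it takes raw router logits rather than SwitchGate's normalized scores.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Switch-style auxiliary loss N * sum_i f_i * P_i.

    router_logits has shape [..., num_experts]; every leading dimension
    (batch, sequence, ...) is flattened into a single token axis first.
    """
    num_experts = router_logits.size(-1)
    # [num_tokens, num_experts]: batch and sequence axes merged
    probs = F.softmax(router_logits, dim=-1).reshape(-1, num_experts)

    # f_i: fraction of tokens whose top-1 expert is i -> [num_experts]
    top1 = probs.argmax(dim=-1)
    fraction_routed = F.one_hot(top1, num_classes=num_experts).float().mean(dim=0)

    # P_i: mean router probability given to expert i -> [num_experts]
    mean_prob = probs.mean(dim=0)

    # Both factors are [num_experts], so no batch/sequence mismatch can occur
    return num_experts * torch.sum(fraction_routed * mean_prob)


# Works for batch_size 1 and batch_size > 1 alike
for batch_size in (1, 2):
    logits = torch.randn(batch_size, 64, 3)
    print(batch_size, load_balancing_loss(logits).item())
```

Because f_i comes from a hard argmax it carries no gradient; in the paper's formulation the gradient flows through P_i.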


Masaaki-75 · Jun 09 '24 09:06