SwitchTransformers
[BUG] Dimension error within SwitchGate
Describe the bug
A shape mismatch occurs in the computation of the auxiliary loss: https://github.com/kyegomez/SwitchTransformers/blob/36a1ea01448e56242222b68201207a7219d72b4b/switch_transformers/model.py#L70-L74

There, `load` is obtained by summing the gate scores over the batch dimension, giving shape `[seq_len, num_experts]`, while `importance` is summed over the sequence dimension, giving shape `[batch_size, num_experts]`. Running the `SwitchGate` class on its own with an input whose `batch_size > 1` therefore raises an error like `RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 0`.
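For reference, below is a minimal sketch of a shape-consistent auxiliary loss. It uses the load-balancing formulation from the original Switch Transformer paper (num_experts · Σᵢ fᵢ · Pᵢ) rather than the squared-difference form currently in model.py; the function name `switch_aux_loss` is illustrative and not part of this repo:

```python
import torch
import torch.nn.functional as F

def switch_aux_loss(gate_scores: torch.Tensor) -> torch.Tensor:
    """Load-balancing loss reduced per expert, so no batch/sequence broadcast is needed.

    gate_scores: router probabilities of shape [batch_size, seq_len, num_experts].
    Returns a scalar that is minimized (~1.0) under perfectly uniform routing.
    """
    num_experts = gate_scores.size(-1)
    probs = gate_scores.reshape(-1, num_experts)           # [tokens, num_experts]

    # f_i: fraction of tokens whose top-1 expert is i
    top1 = probs.argmax(dim=-1)                            # [tokens]
    f = F.one_hot(top1, num_experts).float().mean(dim=0)   # [num_experts]

    # P_i: mean router probability assigned to expert i
    P = probs.mean(dim=0)                                   # [num_experts]

    # Both f and P are [num_experts], so the shapes always agree.
    return num_experts * torch.sum(f * P)
```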
To Reproduce
Simply run the gate on a sample with `batch_size > 1`:

```python
import torch
from switch_transformers.model import SwitchGate

gate = SwitchGate(dim=16, num_experts=3)
x = torch.randn((2, 64, 16)).float()
y, loss = gate(x, use_aux_loss=True)  # raises the RuntimeError above
```
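For comparison, the same gate (reusing `gate` from the snippet above) appears to run fine when `batch_size == 1`, because the mismatched leading dimension is then silently broadcast, which hides the underlying shape bug:

```python
x1 = torch.randn((1, 64, 16)).float()
y1, loss1 = gate(x1, use_aux_loss=True)  # no error: [64, 3] - [1, 3] broadcasts
```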
Upvote & Fund
- We're using Polar.sh so you can upvote and help fund this issue.
- We receive the funding once the issue is completed & confirmed by you.
- Thank you in advance for helping prioritize & fund our backlog.