annotated_deep_learning_paper_implementations

bug in switch transformer when using torch.bfloat16

Open · DogeWatch opened this issue 3 years ago · 1 comment

https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/25ad4d675039f1eccabb2f7ca6c14b11ee8d02c1/labml_nn/transformers/switch/__init__.py#L139

Here final_output.dtype is torch.float32 while expert_output[i].dtype is torch.bfloat16, so the indexed assignment fails with a dtype mismatch. The dtype of final_output should be set explicitly, for example:

final_output = x.new_zeros(x.shape, dtype=expert_output[0].dtype)
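For context, a minimal standalone sketch of the failure (not taken from the repo; dummy shapes, and the exact error message may vary across PyTorch versions):

import torch

x = torch.randn(4, 8)                     # transformer activations in float32
final_output = x.new_zeros(x.shape)       # inherits float32 from x
expert_output = x[:2].to(torch.bfloat16)  # an expert that ran in bfloat16
indexes = torch.tensor([0, 2])

# Raises a RuntimeError: index put requires the source and
# destination dtypes to match (Float destination, BFloat16 source)
final_output[indexes, :] = expert_output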

DogeWatch · Aug 24 '22 12:08

Should we do that, or

final_output[indexes_list[i], :] = expert_output[i].to(x.dtype)

Because it seems like you changed the experts to bfloat16 while the transformer's general processing was in float32, and you would want the rest of the transformer to stay in float32?
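In context, that cast would sit in the loop that scatters each expert's output back into final_output. A sketch assuming the loop structure around the linked line, with dummy data; indexes_list and expert_output follow the names in the issue, and n_experts is assumed:

import torch

n_experts = 2
x = torch.randn(4, 8)                                          # float32 activations
indexes_list = [torch.tensor([0, 2]), torch.tensor([1, 3])]    # token routing per expert
expert_output = [x[idx].to(torch.bfloat16) for idx in indexes_list]

final_output = x.new_zeros(x.shape)  # stays in x.dtype (float32)
for i in range(n_experts):
    # Cast each expert's bfloat16 output back to the transformer's dtype
    # so everything downstream keeps running in float32.
    final_output[indexes_list[i], :] = expert_output[i].to(x.dtype)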

vpj · Aug 25 '22 10:08