annotated_deep_learning_paper_implementations
Bug in Switch Transformer when using `torch.bfloat16`
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/25ad4d675039f1eccabb2f7ca6c14b11ee8d02c1/labml_nn/transformers/switch/__init__.py#L139
Here, `final_output.dtype` is `torch.float32` while `expert_output[i].dtype` is `torch.bfloat16`, so the indexed assignment raises a dtype-mismatch error. We should set the dtype of `final_output` to match, e.g. `final_output = x.new_zeros(x.shape, dtype=expert_output[0].dtype)`.
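A minimal standalone sketch of that fix (the shapes and routing here are made up for illustration; `x`, `indexes_list`, and `expert_output` mirror the variable names around the linked line):

```python
import torch

# Stand-in for the forward-pass state around the buggy line:
# x is the float32 input, and the experts produced bfloat16 outputs.
x = torch.randn(8, 16)                                   # [tokens, d_model], float32
indexes_list = [torch.arange(0, 4), torch.arange(4, 8)]  # tokens routed to each expert
expert_output = [x[idx].to(torch.bfloat16) for idx in indexes_list]

# Proposed fix: allocate final_output in the experts' dtype, so the
# indexed assignment below no longer hits the dtype mismatch.
final_output = x.new_zeros(x.shape, dtype=expert_output[0].dtype)
for i in range(len(expert_output)):
    final_output[indexes_list[i], :] = expert_output[i]

assert final_output.dtype == torch.bfloat16
```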
Should we do that, or instead cast the expert outputs back to the input dtype:

`final_output[indexes_list[i], :] = expert_output[i].to(x.dtype)`

It seems like you changed the experts to bfloat16 while the transformer's general processing was in float32, so you would want the rest of the transformer to stay in float32?
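A sketch of this alternative, under the same made-up setup as above (expert outputs in bfloat16, `x` in float32): `final_output` keeps the input's dtype and each expert's output is cast back on write.

```python
import torch

x = torch.randn(8, 16)                                   # float32, like the rest of the model
indexes_list = [torch.arange(0, 4), torch.arange(4, 8)]  # tokens routed to each expert
expert_output = [x[idx].to(torch.bfloat16) for idx in indexes_list]

# Alternative fix: keep final_output in x.dtype and cast each expert's
# output back, so everything downstream of the switch layer stays float32.
final_output = x.new_zeros(x.shape)
for i in range(len(expert_output)):
    final_output[indexes_list[i], :] = expert_output[i].to(x.dtype)

assert final_output.dtype == x.dtype == torch.float32
```

The trade-off: this costs one extra cast per expert, but avoids silently switching everything after the switch layer to bfloat16.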