Easy-Transformer
[Bug Report] TransformerLens's use of `einsum` leads to different training dynamics on TPUs
I'm cross-posting an issue from torch_xla here so that TransformerLens users can find it more easily. The short version: don't trust TransformerLens's HookedTransformer if you're training on TPUs. I think most of the responsibility lies with torch_xla, but it might be worth adding a warning message to TransformerLens until the issue is fixed upstream.
Alternatively, a possible fix would be to replace each instance of torch.einsum with other torch operations. I've already done this in my own fork, so let me know if you'd like a PR with this change.
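To illustrate the kind of replacement I mean, here is a hypothetical sketch (not TransformerLens's actual code): an attention-score contraction written with torch.einsum, and an equivalent rewrite using only permute and matmul, which sidesteps einsum lowering on TPU backends. The shapes and function names are my own for the example.

```python
import torch

def scores_einsum(q, k):
    # q: [batch, q_pos, head, d_head], k: [batch, k_pos, head, d_head]
    # Contract over d_head to get [batch, head, q_pos, k_pos].
    return torch.einsum("bqhd,bkhd->bhqk", q, k)

def scores_matmul(q, k):
    # Same contraction expressed with permute + batched matmul,
    # avoiding torch.einsum entirely.
    q = q.permute(0, 2, 1, 3)   # [batch, head, q_pos, d_head]
    k = k.permute(0, 2, 3, 1)   # [batch, head, d_head, k_pos]
    return q @ k                # [batch, head, q_pos, k_pos]

q = torch.randn(2, 5, 4, 8)
k = torch.randn(2, 7, 4, 8)
assert torch.allclose(scores_einsum(q, k), scores_matmul(q, k), atol=1e-5)
```

On CPU/GPU the two versions are numerically identical up to floating-point tolerance; the point of the fork is that on TPU only the second form behaves as expected.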