Performance penalty of typecast in `nn.Embed.__call__`
Hi,
I think the following typecast is either unnecessary or applied at the wrong point.
https://github.com/google/flax/blob/d718655f62edbb1c4227281f158b94b78e886814/flax/linen/linear.py#L1123-L1125
- First, in LLMs where `num_embeddings` is on the order of 50-100K rows, `__call__` is generally invoked on a very small subset of tokens/integers (plucking specific rows). So a typecast here applies to all rows on each invocation, which feels wasteful.
- Second, consistent with other linen modules, `dtype` refers to computation, while `param_dtype` is for params. A large dot product happens in `nn.Embed.attend(...)` while computing logits across all embedding rows, and there we typecast correctly to `self.dtype`.
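To illustrate the contrast, here is a hedged standalone sketch of an `attend`-style computation (the function name and signature are illustrative, not Flax's actual implementation): since the logits involve every embedding row, casting the full table to the computation dtype does no wasted work there.

```python
import jax.numpy as jnp

def attend(embedding, query, dtype):
    """Logits of `query` against ALL embedding rows (illustrative sketch).

    Every row of `embedding` participates in the dot product, so casting
    the whole table to the computation `dtype` here is not wasteful.
    """
    return jnp.dot(query.astype(dtype), embedding.astype(dtype).T)

# Example: a 50k-row table stored in float32 (the param_dtype),
# with the matmul done in bfloat16 (the computation dtype).
table = jnp.ones((50_000, 8), dtype=jnp.float32)
query = jnp.ones((2, 8), dtype=jnp.float32)
logits = attend(table, query, jnp.bfloat16)
print(logits.shape, logits.dtype)  # (2, 50000) bfloat16
```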
Possible fixes (either of):
- Remove the type promotion in `__call__` (I also observed that casting to lower precision gives no speedup and causes no convergence issues for 50k rows, so it should be OK to let embeddings remain in `self.param_dtype`, IMO), or
- Keep the typecast, but apply it only after `jnp.take`.
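A minimal sketch of the second option, assuming an embedding table of shape `(num_embeddings, features)` stored in `param_dtype` (the function and variable names are illustrative, not the actual Flax code): gather the requested rows first, then cast only those rows.

```python
import jax.numpy as jnp

def embed_lookup(embedding, inputs, dtype):
    """Look up embedding rows, casting only the gathered rows (sketch).

    `jnp.take` plucks len(inputs) rows out of the full table, so the
    subsequent astype touches a tiny array instead of all
    (num_embeddings, features) entries.
    """
    rows = jnp.take(embedding, inputs, axis=0)
    return rows.astype(dtype)

# Example: 50k-row float32 table, three tokens, bfloat16 computation dtype.
table = jnp.ones((50_000, 8), dtype=jnp.float32)
tokens = jnp.array([3, 17, 42])
out = embed_lookup(table, tokens, jnp.bfloat16)
print(out.shape, out.dtype)  # (3, 8) bfloat16
```

With this ordering, the per-call cast cost scales with the number of looked-up tokens rather than with `num_embeddings`.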