
Performance penalty of typecast in `nn.Embed.__call__`

Open MasterSkepticista opened this issue 1 year ago • 0 comments

Hi,

I think the following typecast is either unnecessary or applied at the wrong point.

https://github.com/google/flax/blob/d718655f62edbb1c4227281f158b94b78e886814/flax/linen/linear.py#L1123-L1125

  • First, in LLMs, where num_embeddings is on the order of 50-100K rows, __call__ is generally invoked on a very small subset of tokens/integers (plucking specific rows). A typecast here, however, touches every row of the table on each invocation, which feels wasteful.
  • Second, consistent with the other linen modules, dtype refers to the computation dtype, while param_dtype is for parameters. The large dot product happens in nn.Embed.attend(...), which computes logits against all embedding rows and already casts correctly to self.dtype there.
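To make the cost concrete, here is a small sketch (not the actual Flax source; table sizes, names, and dtypes are illustrative) comparing casting the whole table before the gather against casting only the gathered rows. For a row-wise cast the two orders produce the same values, but the first touches all rows while the second touches only the plucked ones:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical sizes: a 50k-row embedding table, a tiny batch of token ids.
num_embeddings, features = 50_000, 8
rng = np.random.default_rng(0)
table = jnp.asarray(rng.standard_normal((num_embeddings, features)), dtype=jnp.float32)
tokens = jnp.array([1, 5, 42])

# Roughly what __call__ does today: cast all 50k rows, then gather 3 of them.
cast_first = jnp.take(table.astype(jnp.bfloat16), tokens, axis=0)

# Cheaper order: gather the 3 requested rows, then cast only those.
gather_first = jnp.take(table, tokens, axis=0).astype(jnp.bfloat16)

# The values are identical; only the amount of data cast differs.
print(cast_first.shape, cast_first.dtype)
```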

Possible fixes (either one):

  • Remove the type promotion in __call__ (I also observed that keeping the table in lower precision causes no slowdown or convergence issues for 50k rows, so letting embeddings remain in self.param_dtype should be fine IMO), or
  • Keep the typecast, but apply it only after jnp.take, so only the gathered rows are cast.
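The second option could be sketched like this (a hypothetical standalone function, not a patch to the actual module; `embedding`, `inputs`, and `dtype` stand in for self.embedding, the token ids, and self.dtype):

```python
import jax.numpy as jnp

def embed_call_gather_first(embedding, inputs, dtype):
    """Gather rows in the storage (param) dtype, then cast only those rows."""
    # Pluck the requested rows without touching the rest of the table.
    rows = jnp.take(embedding, inputs, axis=0)
    # Cast only the gathered rows to the computation dtype.
    return rows.astype(dtype)

# Usage: a toy 4-row table queried for two rows.
toy_table = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
out = embed_call_gather_first(toy_table, jnp.array([0, 2]), jnp.bfloat16)
```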

MasterSkepticista · Jul 23 '24