`attend_dtype` not used
Here it seems that the hard-coded `bfloat16` is used instead of `attend_dtype`. Also, `query` is not cast. I guess the correct behavior should be casting both `query` and `self.embedding` to `attend_dtype`?
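For concreteness, here is a minimal sketch of the behavior I would expect. It is not the actual MaxText code: `attend` is a hypothetical standalone helper, and the shapes and the `float32` default are assumptions for illustration only. The point is just that both operands are cast to `attend_dtype` before the dot.

```python
import jax
import jax.numpy as jnp


def attend(query, embedding, attend_dtype=jnp.float32):
  """Hypothetical helper: computes query @ embedding.T in attend_dtype."""
  # Cast BOTH operands, so the dot runs in the requested dtype
  # instead of a hard-coded one.
  query = jnp.asarray(query, attend_dtype)
  embedding = jnp.asarray(embedding, attend_dtype)
  return jnp.dot(query, embedding.T)


# Assumed toy shapes: 2 query vectors, vocab of 16, feature dim of 8.
key_q, key_e = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(key_q, (2, 8), dtype=jnp.bfloat16)
embedding = jax.random.normal(key_e, (16, 8), dtype=jnp.float32)

logits_f32 = attend(query, embedding, jnp.float32)
logits_bf16 = attend(query, embedding, jnp.bfloat16)
print(logits_f32.dtype, logits_bf16.dtype)
```

With this shape of code the output dtype follows `attend_dtype` regardless of the dtypes the inputs arrive in, which is what the parameter name suggests.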
Yes, weird. @khatwanimohit can you take a look? I'm not sure what this is meant to represent. The upstream flag is also kind of weird given that it is orphaned: https://github.com/google/maxtext/blob/5353a957594bd6cf316747cd5a327c163caca74f/MaxText/layers/models.py#L341
I think we should figure out: (a) does doing the dot in f32 help convergence (using the 1B runs)? (b) does @ZhiyuLi-goog/MLPerf care? (c) what does Anselm Levskaya think?
We should make the code consistent and as simple as possible. Also, why is our pylint/pytype not raising alarms on this? Unused vars are bad.