
`attend_dtype` not used

Open · zhixuan-lin opened this issue 11 months ago · 1 comment

Here it seems that a hard-coded `bfloat16` is used instead of `attend_dtype`, and `query` is not cast at all. I guess the correct behavior is to cast both `query` and `self.embedding` to `attend_dtype`?

zhixuan-lin · Mar 18 '24 20:03
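A minimal sketch of the behavior proposed above, written in JAX. It assumes the attend step computes tied-embedding logits as `query @ embedding.T`. The standalone `attend` function, its signature, and the shapes are hypothetical and are not MaxText's actual code. The key point is that both operands are cast, so the dot product runs in whatever `attend_dtype` is configured rather than a hard-coded `bfloat16`.

```python
import jax
import jax.numpy as jnp


def attend(query: jax.Array, embedding: jax.Array, attend_dtype=jnp.float32) -> jax.Array:
    """Hypothetical tied-embedding attend: logits = query @ embedding.T.

    query:     [..., emb_dim]
    embedding: [vocab_size, emb_dim]
    returns:   [..., vocab_size] in attend_dtype
    """
    # Cast BOTH operands to the configured dtype, instead of casting only
    # the embedding table to a hard-coded bfloat16 and leaving query as-is.
    q = query.astype(attend_dtype)
    emb = embedding.astype(attend_dtype)
    # jnp.dot contracts the last axis of q with the first axis of emb.T (emb_dim).
    return jnp.dot(q, emb.T)
```

If the only goal is float32 accumulation, one alternative is to keep bfloat16 inputs and pass `preferred_element_type=jnp.float32` to `jnp.dot`; which of the two is preferable is a judgment call for this codebase.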

Yes, weird. @khatwanimohit, can you take a look? I'm not sure what this is meant to represent. The upstream flag is also odd, given that it is orphaned: https://github.com/google/maxtext/blob/5353a957594bd6cf316747cd5a327c163caca74f/MaxText/layers/models.py#L341

I think we should figure out: (a) does doing the dot product in f32 help convergence (using the 1B runs)? (b) does @ZhiyuLi-goog/MLPerf care? (c) what does Anselm Levskaya think?

We should make the code consistent and as simple as possible. Also, why isn't our pylint/pytype raising alarms on this? Unused variables are bad.

rwitten · Mar 19 '24 03:03
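For question (a), a cheap sanity check before spending the 1B runs is to measure how far bfloat16 logits drift from a high-precision float32 reference. This is a hypothetical JAX sketch with made-up random shapes and a made-up helper name (`logits_error`); it measures per-call numerical error only and says nothing directly about convergence. The reference uses `jax.lax.Precision.HIGHEST` because on some backends (TPUs in particular) a float32 dot at default precision may internally use bfloat16 passes.

```python
import jax
import jax.numpy as jnp


def logits_error(attend_dtype, seed: int = 0) -> float:
    """Max abs error of query @ embedding.T computed in attend_dtype,
    relative to a float32 reference at the highest matmul precision."""
    kq, ke = jax.random.split(jax.random.PRNGKey(seed))
    # Hypothetical shapes: 8 query vectors, vocab of 256, emb_dim of 1024.
    query = jax.random.normal(kq, (8, 1024), dtype=jnp.float32)
    embedding = jax.random.normal(ke, (256, 1024), dtype=jnp.float32)

    reference = jnp.dot(query, embedding.T, precision=jax.lax.Precision.HIGHEST)
    # Cast both operands, mirroring the behavior proposed in the issue.
    approx = jnp.dot(query.astype(attend_dtype), embedding.astype(attend_dtype).T)
    return float(jnp.max(jnp.abs(approx.astype(jnp.float32) - reference)))


print("bfloat16:", logits_error(jnp.bfloat16))
print("float32: ", logits_error(jnp.float32))
```

A large per-call gap does not by itself imply worse convergence, so the 1B runs remain the real test.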