[FLAX] Add dtype to embedding for gpt2 model
What does this PR do?
Add dtype to embedding for gpt2 model
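A minimal sketch of the kind of change (the names mirror the Flax GPT-2 module, but this is an illustration rather than the exact diff): the module's `dtype` is forwarded to `nn.Embed`, so the embedding lookup follows the model's computation dtype.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyGPT2Embeddings(nn.Module):
    """Illustrative stand-in for the embedding setup in FlaxGPT2Module."""

    vocab_size: int = 50257
    embed_dim: int = 768
    dtype: jnp.dtype = jnp.float32  # computation dtype of the module

    def setup(self):
        self.wte = nn.Embed(
            self.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,  # the argument this PR adds
        )

    def __call__(self, input_ids):
        return self.wte(input_ids.astype("i4"))
```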
Who can review?
@patrickvonplaten, @LysandreJik
Thanks @merrymercy, I've noticed that we omit the `dtype` arg from all Flax `nn.Embed` modules! @patil-suraj is there a reason why we do this?
BART: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/bart/modeling_flax_bart.py#L841-L845
BERT: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/bert/modeling_flax_bert.py#L186-L200
T5: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/t5/modeling_flax_t5.py#L1259-L1263
I don't know the reasons, but this dtype is required for half-precision training. I can modify all other classes as well if needed.
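For context, this is the user-facing setting the fix targets (a sketch; Flax models in transformers accept a `dtype` argument in `from_pretrained`):

```python
import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel

# Request bf16 computation; the checkpoint weights themselves stay fp32
# unless explicitly cast. Before this PR, the embedding lookup ignored the
# module dtype and emitted fp32 activations into an otherwise bf16 forward.
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2", dtype=jnp.bfloat16)
logits = model(input_ids=jnp.array([[464, 3290]])).logits
```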
Let's wait for @patil-suraj to weigh in on this!
Gentle ping @patil-suraj
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@sanchit-gandhi could you maybe take a look here?
Sorry @patrickvonplaten - as mentioned in my previous comment https://github.com/huggingface/transformers/pull/18462#issuecomment-1209435023, I'm not sure why we omit the `dtype` arg from all Flax `nn.Embed` modules, hence the request for @patil-suraj to weigh in! Maybe you could shed some light on this? It seems like an intrinsic design philosophy, given that we do this for all models.
The change looks good to me. T5x also puts the embedding in half precision if necessary: https://github.com/google-research/t5x/blob/1f8cec78b1f28f1955d70741792d7b6e7dd76226/t5x/examples/t5/network.py#L287
@patil-suraj what do you think?
Can we merge this?
It's interesting that we omit the `dtype` arg in the embedding layer for both PyTorch and Flax:
https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/models/gpt2/modeling_gpt2.py#L675-L676
Wondering if this was a deliberate design decision that we're violating in this PR? Otherwise I'm happy with the change for half-precision training!
Your conclusion aligns with the earlier observation that embedding dtypes are never down-cast in any Transformers models, for both PyTorch and Flax!
Would you mind sharing the rationale for why embedding weights must not be down-cast to half precision? That would help us understand why this should be avoided and educate us all!
I think my modification does not conflict with t5x.
My PR only changes the dtype of the computation and the output tensor, not the parameter dtype (`param_dtype`):
https://github.com/google/flax/blob/0be6f32582b9acafe1741e8641a748eb99501021/flax/linen/linear.py#L732-L733
This aligns with @patrickvonplaten's reading of the t5x code.
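A quick way to check this in plain Flax (a minimal, self-contained sketch):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=8, features=4, dtype=jnp.bfloat16)
variables = embed.init(jax.random.PRNGKey(0), jnp.array([1, 2]))

# The stored table follows param_dtype, which defaults to fp32.
print(variables["params"]["embedding"].dtype)           # float32
# The lookup output follows dtype.
print(embed.apply(variables, jnp.array([1, 2])).dtype)  # bfloat16
```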
@patil-suraj Please review. I am working extensively on the Flax backend and am happy to contribute more code.
Hey @merrymercy,
I think `nn.Embed` is an exception in Flax where providing a `dtype` actually modifies the embedding weights and not just the computation. @patil-suraj can maybe explain better here :-)
By looking at the code, I don't see why `dtype` would change the type of the parameters. You can check the code here:
https://github.com/google/flax/blob/0be6f32582b9acafe1741e8641a748eb99501021/flax/linen/linear.py#L739-L742
The type of the parameters is controlled by `param_dtype`.
Could you explain how the "exception" happens?
The way I see it, `dtype` promotes the whole embedding matrix to bf16 here: https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Embed, and then takes a bf16 vector from this tensor -> this is different from just doing the matrix computation in bf16 IMO
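For reference, a paraphrased sketch of that `Embed.__call__` behaviour (not the exact library source): the stored fp32 table is cast to `dtype` first, and the row is then taken from the cast table.

```python
import jax.numpy as jnp

def embed_lookup(embedding_fp32, input_ids, dtype=jnp.bfloat16):
    # The whole table is promoted to the computation dtype...
    table = jnp.asarray(embedding_fp32, dtype)
    # ...and rows are then gathered from the promoted table, so the gather
    # and everything downstream see bf16 values.
    return jnp.take(table, input_ids, axis=0)
```

Note the cast happens inside the forward pass; the stored fp32 parameter itself is never modified.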
You are right @patrickvonplaten. This is how fp16 mixed-precision training with fp32 master weights works.
My point is that the current code in Hugging Face is wrong and the code in t5x is correct. My modification makes Hugging Face's code match t5x's.
Reasons:
- Regardless of `self.dtype`, the weights are stored in fp32. This holds for both my PR and t5x.
- If `dtype` is fp16, the computation is in fp16. This holds for my PR and t5x (https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/layers.py#L501), but the original Hugging Face code gets this wrong.
@merrymercy But T5X precisely doesn't set `dtype=jnp.bfloat16` when instantiating the layer (see https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/layers.py#L479), and instead casts the embedding to `dtype=jnp.bfloat16` only during the forward pass: https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/layers.py#L501
Shouldn't we try to match this?
Aha! I think we are talking at different levels. Could my comment below address your concerns?
First, I match the way we call `nn.Embed` with t5x
This PR doesn't modify `nn.Embed` at all; it modifies the way we call `nn.Embed`. What my PR tries to match is this line in t5x: https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/network.py#L287. You can see it passes `dtype` to `nn.Embed`.
Then, I match the implementation of `nn.Embed` with t5x
The code you refer to is `layers.Embed` in t5x; its equivalent in our code base is `flax.linen.Embed`. Both are implemented correctly.
In t5x, `layers.Embed` has a single `dtype` argument that controls the type of the computation, and it hard-codes fp32 for the type of the parameters. In Flax, `nn.Embed` has two arguments: `dtype` for the computation and `param_dtype` for the parameters. I never change `param_dtype`, so it uses the default value, fp32. This makes `flax.linen.Embed` match t5x's `layers.Embed`.
In summary, after my PR, the Hugging Face GPT-2 should match t5x. Before my PR, the dtype of the computation in mixed-precision training is wrong.
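To make the before/after concrete, here is a minimal sketch with a toy `nn.Embed` standing in for GPT-2's `wte`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

ids = jnp.array([1, 2, 3])
rng = jax.random.PRNGKey(0)

# Before the PR: dtype omitted, so the lookup emits fp32 activations even
# when the rest of the model runs in bf16.
before = nn.Embed(num_embeddings=10, features=4)
print(before.apply(before.init(rng, ids), ids).dtype)  # float32

# After the PR: dtype is passed through; the output is bf16 while the
# stored weights stay fp32, matching t5x.
after = nn.Embed(num_embeddings=10, features=4, dtype=jnp.bfloat16)
variables = after.init(rng, ids)
print(variables["params"]["embedding"].dtype)  # float32
print(after.apply(variables, ids).dtype)       # bfloat16
```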
Hey @merrymercy, thanks for clarifying and sorry for not making the connection before! The PR looks good to me then :-)
Just one other thing - it seems there is an issue with your CircleCI permissions and the tests won't run. Could you try refreshing your permissions as shown here?
I fixed the CircleCI issue, but I don't know how to fix the "Build PR Documentation" check.