[FLAX] Add dtype to embedding for gpt2 model

Open merrymercy opened this issue 2 years ago • 3 comments

What does this PR do?

Add dtype to embedding for gpt2 model
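A rough sketch of the kind of change involved (illustrative only; the real class is `FlaxGPT2Module` in `modeling_flax_gpt2.py`, and attribute names such as `self.wte` are my recollection and may not match the actual diff):

```python
import jax.numpy as jnp
import flax.linen as nn

class TinyGPT2Embeddings(nn.Module):
    """Hypothetical stand-in for the GPT-2 token-embedding setup."""
    vocab_size: int
    hidden_size: int
    dtype: jnp.dtype = jnp.float32  # e.g. jnp.bfloat16 for half-precision runs

    def setup(self):
        self.wte = nn.Embed(
            self.vocab_size,
            self.hidden_size,
            dtype=self.dtype,  # <- the argument this PR adds; previously omitted
        )

    def __call__(self, input_ids):
        # With dtype set, the lookup/output follows self.dtype while the stored
        # table keeps nn.Embed's default param_dtype (fp32).
        return self.wte(input_ids)
```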

Who can review?

@patrickvonplaten, @LysandreJik

merrymercy avatar Aug 03 '22 21:08 merrymercy

The documentation is not available anymore as the PR was closed or merged.

Thanks @merrymercy, I've noticed that we omit the dtype arg from all Flax nn.Embed modules! @patil-suraj is there a reason why we do this?

BART: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/bart/modeling_flax_bart.py#L841-L845
BERT: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/bert/modeling_flax_bert.py#L186-L200
T5: https://github.com/huggingface/transformers/blob/ab2006e3d6db88654526a4169e65d4bfc52da2e3/src/transformers/models/t5/modeling_flax_t5.py#L1259-L1263

sanchit-gandhi avatar Aug 09 '22 14:08 sanchit-gandhi

I don't know the reasons, but this dtype is required for half-precision training. I can modify all other classes as well if needed.

merrymercy avatar Aug 09 '22 22:08 merrymercy

Let's wait for @patil-suraj to weigh in on this!

sanchit-gandhi avatar Aug 10 '22 15:08 sanchit-gandhi

Gentle ping @patil-suraj

merrymercy avatar Aug 15 '22 19:08 merrymercy

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Sep 09 '22 15:09 github-actions[bot]

@sanchit-gandhi could you maybe take a look here?

patrickvonplaten avatar Sep 09 '22 15:09 patrickvonplaten

Sorry @patrickvonplaten - as mentioned in my previous comment https://github.com/huggingface/transformers/pull/18462#issuecomment-1209435023, I'm not sure why we omit the dtype from all Flax nn.Embed modules, hence the request for @patil-suraj to weigh in! Maybe you could shed some light on this? It seems like an intrinsic design philosophy, given that we do this for all models.

sanchit-gandhi avatar Sep 09 '22 15:09 sanchit-gandhi

The change looks good to me. T5x also puts the embedding in half precision if necessary: https://github.com/google-research/t5x/blob/1f8cec78b1f28f1955d70741792d7b6e7dd76226/t5x/examples/t5/network.py#L287

@patil-suraj what do you think?

patrickvonplaten avatar Sep 14 '22 18:09 patrickvonplaten

Can we merge this?

merrymercy avatar Sep 23 '22 21:09 merrymercy

It's interesting that we omit the dtype arg in the embedding layer for both PyTorch and Flax: https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/models/gpt2/modeling_gpt2.py#L675-L676

Wondering if this was a deliberate design decision that we're violating in this PR? Otherwise I'm happy with the change for half-precision training!

sanchit-gandhi avatar Oct 10 '22 11:10 sanchit-gandhi

Your conclusion aligns with the earlier observation that embedding dtypes are never down-cast in any Transformers models, for both PyTorch and Flax!

Wondering if you could share the rationale for why embedding weights must not be down-cast to half precision? That would help us all understand why it should be avoided!

sanchit-gandhi avatar Oct 10 '22 16:10 sanchit-gandhi

I think my modification does not conflict with t5x. My PR only changes the dtype of the computation and the output tensor, not the parameter dtype (param_dtype): https://github.com/google/flax/blob/0be6f32582b9acafe1741e8641a748eb99501021/flax/linen/linear.py#L732-L733

This aligns with @patrickvonplaten's reading of the t5x code.
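To make the distinction concrete, here is a small check one can run (a sketch, assuming a reasonably recent Flax version; `param_dtype` is left at its default of fp32):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=50257, features=768, dtype=jnp.bfloat16)
token_ids = jnp.array([[1, 2, 3]])
variables = embed.init(jax.random.PRNGKey(0), token_ids)

# The stored table keeps the default param_dtype...
print(variables["params"]["embedding"].dtype)   # float32
# ...while the computation/output follows `dtype`.
print(embed.apply(variables, token_ids).dtype)  # bfloat16
```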

@patil-suraj Please review. I am working extensively on the flax backend and am happy to contribute more code.

merrymercy avatar Oct 10 '22 20:10 merrymercy

Hey @merrymercy,

I think nn.Embed is an exception in Flax, where providing a dtype actually modifies the embedding weights and not just the computation. @patil-suraj can maybe explain better here :-)

patrickvonplaten avatar Oct 11 '22 18:10 patrickvonplaten

Looking at the code, I don't see why dtype would change the type of the parameters. You can check the code here: https://github.com/google/flax/blob/0be6f32582b9acafe1741e8641a748eb99501021/flax/linen/linear.py#L739-L742. The type of the parameters is controlled by param_dtype.

Could you explain how the "exception" happens?

merrymercy avatar Oct 11 '22 23:10 merrymercy

The way I see it, dtype promotes the whole embedding matrix to bf16 here: https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Embed and then takes a bf16 vector from this tensor -> this is different from just doing the matrix computation in bf16 IMO
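Roughly, the lookup being described is the following (a paraphrase, not the verbatim Flax source; details differ across Flax versions):

```python
import jax.numpy as jnp

def embed_lookup(embedding_fp32, token_ids, dtype=jnp.bfloat16):
    # the whole fp32 table is first viewed in the requested dtype...
    table = jnp.asarray(embedding_fp32, dtype)
    # ...so the gathered rows come back in that dtype
    return jnp.take(table, token_ids, axis=0)
```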

patrickvonplaten avatar Oct 14 '22 17:10 patrickvonplaten

You are right @patrickvonplaten. This is how fp16 mixed-precision training with fp32 master weights works (a short sketch of the pattern follows the list below).

My point is that the current code in Hugging Face is wrong and the code in t5x is correct. My modification makes Hugging Face's code match t5x's.

Reasons:

  1. Regardless of self.dtype, the weights are stored in fp32. This holds for both my PR and t5x.
  2. If dtype is fp16, the computation is in fp16. This holds for my PR and t5x (https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/layers.py#L501), but the original Hugging Face code gets this wrong.
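A minimal sketch of that pattern in plain JAX (a generic illustration, not the PR's code): the master copy stays in fp32, only the computation is cast down, and the gradient comes back in fp32 for the optimizer update.

```python
import jax
import jax.numpy as jnp

w_fp32 = jnp.ones((8, 4), dtype=jnp.float32)   # master weights stay in fp32
x = jnp.ones((2, 8), dtype=jnp.bfloat16)

def loss(w):
    y = x @ w.astype(jnp.bfloat16)             # forward computation in half precision
    return y.sum()

grads = jax.grad(loss)(w_fp32)
print(grads.dtype)  # float32 -> the update is applied to the fp32 master copy
```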

merrymercy avatar Oct 14 '22 22:10 merrymercy

@merrymercy But T5X doesn't actually set dtype=jnp.bfloat16 when instantiating the layer (see: https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/layers.py#L479); instead, it wraps the embedding in dtype=jnp.bfloat16 only during the forward pass: https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/layers.py#L501

Shouldn't we try to match this?

patrickvonplaten avatar Oct 17 '22 16:10 patrickvonplaten

Aha! I think we are talking about two different levels. Hopefully the comment below addresses your concerns.

First, I match the way we call `nn.Embed` with t5x

This PR doesn't modify `nn.Embed` at all; it modifies the way we call `nn.Embed`. What my PR tries to match is this line in t5x: https://github.com/google-research/t5x/blob/ca3d2e43c8db2e6769073ffa98b7689443e3b2b8/t5x/examples/t5/network.py#L287. You can see that it passes dtype to `nn.Embed`.

Then, I match the implementation of `nn.Embed` with t5x

The code you refer to is `layer.Embed` in t5x; its equivalent in our codebase is `flax.nn.Embed`. Both of them are implemented correctly.

In t5x, `Embed` has a single dtype argument that controls the type of the computation, and it hard-codes fp32 for the type of the parameters. In Flax, `nn.Embed` has two dtype arguments: `dtype` for the computation and `param_dtype` for the parameters. I never change `param_dtype`, so it keeps its default value of fp32. This makes `flax.nn.Embed` match `t5x.layer.Embed`.

In summary, after my PR the Hugging Face GPT-2 code should match t5x. Before my PR, the dtype of the computation in mixed-precision training was wrong.

merrymercy avatar Oct 17 '22 18:10 merrymercy

Hey @merrymercy, thanks for clarifying and sorry for not making the connection before! The PR looks good to me then :-)

Just one other thing - it seems there is an issue with your CircleCI permissions, so the tests won't run. Could you try refreshing your permissions as shown here?

patrickvonplaten avatar Oct 18 '22 19:10 patrickvonplaten

I fixed the CircleCI issue, but I don't know how to fix the "Build PR Documentation" test.

merrymercy avatar Oct 18 '22 22:10 merrymercy