Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training
As outlined in #18425, this PR changes the signature of BartLearnedPositionalEmbedding's forward method to take the input_ids tensor (and not just its shape). This is needed to enable private training of BART via DP-SGD in Opacus. PR welcomed by @sgugger in the linked issue.
Fixes #18425.
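For reference, a simplified sketch of the kind of change involved (condensed from the actual module; the offset handling follows the existing class). As I understand the linked issue, Opacus registers forward hooks that expect a module's inputs to be tensors with a batch dimension, so a forward that only receives a torch.Size cannot have per-sample gradients computed for it:
import torch
import torch.nn as nn

class BartLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embeddings, offset by 2 as in the existing BART code."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    # Before (sketch): callers passed only the shape, e.g. embed_positions(input_ids.size()),
    # so forward received a torch.Size and Opacus' hooks never saw a tensor input:
    #
    #     def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
    #         bsz, seq_len = input_ids_shape[:2]
    #         ...
    #
    # After (sketch): the tensor itself is passed in and only its shape is read here.
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` is expected to have shape [bsz, seq_len]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_len,
            dtype=torch.long,
            device=self.weight.device,
        )
        return super().forward(positions + self.offset)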
You will also need to apply the same changes to all the models touched by the embedding change (mBART, PLBart, etc.) to get the tests passing.
@sgugger I am iterating through. 👍 thanks for the heads up, though!
Ah, this also looks like it's breaking the conversion to torch.fx. Let's see if @michaelbenayoun can think of an easy solution to that.
Thanks, I was about to ask... any thoughts? I am not well-versed in torch's symbolic tracer (or FX generally), but I'm happy to do the work if you can point me somewhere useful 😄
The current offending line is line 985 of src/transformers/utils/fx.py (HFTracer.trace()):
self.graph = super().trace(root, concrete_args=concrete_args)
where root is:
PLBartModel(
  (shared): Embedding(99, 16, padding_idx=1)
  (encoder): PLBartEncoder(
    (embed_tokens): Embedding(99, 16, padding_idx=1)
    (embed_positions): PLBartLearnedPositionalEmbedding(102, 16)
    (layers): ModuleList(
      (0): PLBartEncoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
      (1): PLBartEncoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layernorm_embedding): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): PLBartDecoder(
    (embed_tokens): Embedding(99, 16, padding_idx=1)
    (embed_positions): PLBartLearnedPositionalEmbedding(102, 16)
    (layers): ModuleList(
      (0): PLBartDecoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
      (1): PLBartDecoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layernorm_embedding): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
)
and concrete_args is:
{'head_mask': None, 'decoder_head_mask': None, 'cross_attn_head_mask': None, 'encoder_outputs': None, 'past_key_values': None, 'inputs_embeds': None, 'decoder_inputs_embeds': None, 'use_cache': None, 'output_attentions': None, 'output_hidden_states': None, 'return_dict': None}
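In case it helps anyone following along, here is a minimal, self-contained example of the Tracer.trace(root, concrete_args=...) pattern that HFTracer builds on; the Toy module below is purely illustrative and not part of transformers:
from torch import fx, nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x, flag=None):
        # Branching on a non-tensor argument only traces cleanly if that
        # argument is fixed to a concrete value at trace time.
        if flag is not None:
            x = x + 1
        return self.linear(x)

tracer = fx.Tracer()
# Arguments listed in concrete_args are baked into the graph instead of
# becoming Proxy placeholders; HFTracer uses the same mechanism for the
# many optional keyword arguments of the model's forward (see the dict above).
graph = tracer.trace(Toy(), concrete_args={"flag": None})
gm = fx.GraphModule(tracer.root, graph)
print(gm.code)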
I will check on Monday and come back to you; I think it should be easily fixable.
Hey @michaelbenayoun, let me know if you have any thoughts to resolve the tracer issue :)
@sgugger all resolved now. Would you mind giving the PR another look?
Thanks a lot for working on this!
No problem, thank you for your support 👍