
Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training

donebydan opened this pull request 3 years ago · 8 comments

As outlined in #18425, this PR changes the signature of BartLearnedPositionalEmbedding's forward method to take the input_ids tensor (and not just its shape). This is needed to enable private training of BART via DP-SGD in Opacus. The PR was welcomed by @sgugger in the linked issue.

Fixes #18425.
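
In rough terms, the change amounts to the following (a sketch of the idea, not the exact diff):

```python
import torch
from torch import nn


class BartLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embeddings with BART's offset of 2."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    # Before: forward only saw the shape of input_ids.
    # def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
    #     bsz, seq_len = input_ids_shape[:2]
    #     ...

    # After: forward receives the input_ids tensor itself, so the actual inputs
    # flow through the module (which is what Opacus' per-sample hooks need to see).
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_len,
            dtype=torch.long,
            device=self.weight.device,
        ).expand(bsz, -1)  # positions carry an explicit batch dimension
        return super().forward(positions + self.offset)
```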

donebydan avatar Aug 05 '22 12:08 donebydan
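
For context on the Opacus side, the end goal is to wrap a BART model with Opacus' PrivacyEngine for DP-SGD training. A rough sketch of the intended usage; the model name, toy data, and hyperparameters are illustrative, and end-to-end behaviour still depends on Opacus' module validation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Toy batches of token ids; real fine-tuning would come from a tokenizer/dataset.
input_ids = torch.randint(0, model.config.vocab_size, (32, 64))
data_loader = DataLoader(TensorDataset(input_ids), batch_size=8)

# make_private attaches per-sample gradient hooks to the trainable modules; those
# hooks need to see the actual input_ids tensor flowing through
# BartLearnedPositionalEmbedding, which is what this PR enables.
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,  # illustrative DP-SGD settings
    max_grad_norm=1.0,
)
```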

The documentation is not available anymore as the PR was closed or merged.

You will also need to apply the same changes to all the models that are touched by the change in embedding (mBART, PLBart, etc.) so that the tests pass.

sgugger avatar Aug 05 '22 13:08 sgugger
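
For reference, these sibling classes are kept in sync via the repository's "# Copied from ..." convention (checked by utils/check_copies.py and fixable with make fix-copies), so each of them needs the same new signature. The pattern looks roughly like this:

```python
from torch import nn

# In src/transformers/models/mbart/modeling_mbart.py (and similarly for PLBart & co.),
# the class is a maintained copy of the BART one:

# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
class MBartLearnedPositionalEmbedding(nn.Embedding):
    ...
```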

@sgugger I am iterating through the affected models. 👍 Thanks for the heads-up, though!

donebydan avatar Aug 05 '22 14:08 donebydan

Ah, this also looks like it's breaking the conversion to torch.fx. Let's see if @michaelbenayoun can think of an easy solution to that.

sgugger avatar Aug 05 '22 14:08 sgugger

Thanks, I was about to ask... any thoughts? I am not well-versed in torch's symbolic tracer (or FX generally). I'm happy to do the work if you can point me somewhere useful 😄

donebydan avatar Aug 05 '22 14:08 donebydan
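
Some context on why the signature change touches tracing at all (a generic illustration with a toy module, not the exact failure inside HFTracer): under torch.fx, forward's inputs are Proxy objects, so a shape read that used to arrive as a concrete torch.Size now becomes graph nodes the tracer has to handle.

```python
import torch
import torch.fx


class Toy(torch.nn.Module):
    # Hypothetical module that, like the new embedding forward, reads the shape
    # of a tensor argument inside forward.
    def forward(self, input_ids: torch.Tensor):
        # During symbolic tracing, input_ids is a Proxy: the shape access below is
        # recorded as graph nodes rather than evaluated to concrete Python ints.
        seq_len = input_ids.shape[1]
        return torch.arange(seq_len)


traced = torch.fx.symbolic_trace(Toy())
print(traced.graph)  # getattr(shape) -> getitem -> arange all appear as graph nodes
```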

The current offending line is line 985 in src/transformers/utils/fx.py (HFTracer.trace()): self.graph = super().trace(root, concrete_args=concrete_args)

where root is

PLBartModel(
  (shared): Embedding(99, 16, padding_idx=1)
  (encoder): PLBartEncoder(
    (embed_tokens): Embedding(99, 16, padding_idx=1)
    (embed_positions): PLBartLearnedPositionalEmbedding(102, 16)
    (layers): ModuleList(
      (0): PLBartEncoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
      (1): PLBartEncoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layernorm_embedding): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): PLBartDecoder(
    (embed_tokens): Embedding(99, 16, padding_idx=1)
    (embed_positions): PLBartLearnedPositionalEmbedding(102, 16)
    (layers): ModuleList(
      (0): PLBartDecoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
      (1): PLBartDecoderLayer(
        (self_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): PLBartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=16, out_features=4, bias=True)
        (fc2): Linear(in_features=4, out_features=16, bias=True)
        (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layernorm_embedding): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
) 

and concrete_args is:

{'head_mask': None, 'decoder_head_mask': None, 'cross_attn_head_mask': None, 'encoder_outputs': None, 'past_key_values': None, 'inputs_embeds': None, 'decoder_inputs_embeds': None, 'use_cache': None, 'output_attentions': None, 'output_hidden_states': None, 'return_dict': None}
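
For anyone wanting to poke at this locally, the code path above is reached through transformers' FX helper, roughly as follows. The config values are chosen to mirror the tiny test model in the dump above (attention-head counts are guesses), and the input names are illustrative:

```python
from transformers import PLBartConfig, PLBartModel
from transformers.utils.fx import symbolic_trace

# Tiny config mirroring the module dump above: d_model=16, 2 layers, ffn dim 4,
# vocab of 99, and 100 + 2 (offset) learned positions.
config = PLBartConfig(
    vocab_size=99,
    d_model=16,
    max_position_embeddings=100,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=4,
    decoder_ffn_dim=4,
)
model = PLBartModel(config)

# symbolic_trace fills concrete_args (the dict above) for every forward argument
# not listed in input_names, then calls HFTracer.trace(), where the failure surfaces.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "decoder_input_ids"])
```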

donebydan avatar Aug 05 '22 14:08 donebydan

I will check on Monday and come back to you; it should be easily fixable, I think.

michaelbenayoun avatar Aug 05 '22 17:08 michaelbenayoun

Hey @michaelbenayoun, let me know if you have any thoughts to resolve the tracer issue :)

donebydan avatar Aug 08 '22 11:08 donebydan

@sgugger all resolved now. Would you mind giving the PR another look?

donebydan avatar Aug 10 '22 11:08 donebydan

Thanks a lot for working on this!

sgugger avatar Aug 11 '22 13:08 sgugger

No problem, thank you for your support 👍

donebydan avatar Aug 11 '22 13:08 donebydan