`SeamlessM4Tv2ConformerEncoder` does not behave as expected if gradient checkpointing is enabled
System Info
- transformers version: 4.42.0.dev0
- Platform: Linux-5.4.0-172-generic-x86_64-with-glibc2.17
- Python version: 3.8.19
- Huggingface_hub version: 0.23.1
- Safetensors version: 0.4.3
- Accelerate version: 0.30.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Tensorflow version (GPU?): 2.13.1 (True)
- Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
- Jax version: 0.4.13
- JaxLib version: 0.4.13
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
@ArthurZucker @sanchit
Proposed fix:

The gradient-checkpointing branch of `SeamlessM4Tv2ConformerEncoder.forward()` does not forward `output_attentions` (or `conv_attention_mask`) to the checkpointed layer call, so each layer falls back to its default arguments and returns `None` attention weights. Passing the two missing arguments through the checkpointing function fixes this:
```python
class SeamlessM4Tv2ConformerEncoder(...):
    [...]

    def forward(...):
        [...]
        if not skip_the_layer or deepspeed_zero3_is_enabled:
            # under deepspeed zero3 all gpus must run in sync
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,    # <---------- Add this parameter
                    conv_attention_mask,  # <---------- Add this parameter
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    conv_attention_mask=conv_attention_mask,
                )
            hidden_states = layer_outputs[0]
        [...]
```
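For context, here is a minimal, self-contained sketch of why the omission matters; `layer_forward` below is a hypothetical stand-in, not the actual transformers layer. `torch.utils.checkpoint.checkpoint` forwards only the positional arguments it is given, so any argument left out silently falls back to the callee's default:

```python
import torch
from torch.utils.checkpoint import checkpoint

def layer_forward(hidden_states, attention_mask=None, output_attentions=False):
    # hypothetical stand-in for a conformer encoder layer
    attn_weights = torch.ones(1) if output_attentions else None
    return hidden_states * 2, attn_weights

x = torch.randn(4, 8, requires_grad=True)

# Current behavior: output_attentions is not forwarded, so the layer uses its default (False)
_, attn = checkpoint(layer_forward, x, None, use_reentrant=False)
print(attn)  # None

# Proposed fix: pass the flag through positionally, and the attention weights come back
_, attn = checkpoint(layer_forward, x, None, True, use_reentrant=False)
print(attn)  # tensor([1.])
```

Since the checkpointed call passes arguments positionally, `output_attentions` and `conv_attention_mask` in the proposed fix must keep the same order as in the layer's `forward` signature.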
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
- Train a model that has `transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2.SeamlessM4Tv2ConformerEncoder` as a submodule (enable gradient checkpointing while training).
- When calling `SeamlessM4Tv2ConformerEncoder.forward()`, pass `output_attentions=True` and `return_dict=True` (a fuller standalone sketch follows below). For example:

```python
encoder: SeamlessM4Tv2ConformerEncoder = ...
output = encoder(..., output_attentions=True, return_dict=True)
```
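A rough, self-contained reproduction sketch along those lines; the reduced config values and the manual checkpointing setup are assumptions for illustration (in a real setup gradient checkpointing would be enabled via the parent model's `gradient_checkpointing_enable()`):

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint
from transformers import SeamlessM4Tv2Config
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
    SeamlessM4Tv2ConformerEncoder,
)

# Small random config so the encoder builds quickly; layerdrop disabled so every layer runs.
# These kwargs are illustrative assumptions, not taken from the original setup.
config = SeamlessM4Tv2Config(speech_encoder_layers=2, speech_encoder_layerdrop=0.0)
encoder = SeamlessM4Tv2ConformerEncoder(config)
encoder.train()  # the checkpointing branch only runs when self.training is True

# Roughly what gradient_checkpointing_enable() wires up on the parent model.
encoder.gradient_checkpointing = True
encoder._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

hidden_states = torch.randn(2, 50, config.hidden_size, requires_grad=True)
output = encoder(hidden_states, output_attentions=True, return_dict=True)

print([attn is None for attn in output.attentions])  # reported behavior: all True
```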
Expected behavior
`output.attentions` is a tuple of non-`None` tensors, one per encoder layer. Instead, the actual behavior is that `output.attentions = (None, None, ..., None)`.
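Once the checkpointing branch forwards `output_attentions`, a quick sanity check along these lines (reusing `output` from the reproduction sketch above) should pass:

```python
# With the proposed fix, every per-layer attention tensor should be materialized.
assert all(attn is not None for attn in output.attentions)
print(len(output.attentions))  # one entry per encoder layer
```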