transformers
attention_mask bug when training Wav2Vec2ForCTC with DeepSpeed
System Info
- `transformers` version: 4.19.2
- Platform: Linux-4.15.0-144-generic-x86_64-with-glibc2.27
- Python version: 3.8.13
- Huggingface_hub version: 0.7.0
- PyTorch version (GPU?): 1.7.1+cu110 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
@patrickvonplaten @stas00
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
I experienced a problem when training Wav2Vec2ForCTC: if I preprocess the data to create an attention_mask, its dtype is int32.
Here is a simple example:
import torch
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(return_attention_mask=True)
data = [{'input_values':[0.1,0.1,0.1]},{'input_values':[0.2,0.2,0.2,0.2,0.2]}]
attn_mask = feature_extractor.pad(data, padding="longest", return_tensors="pt")['attention_mask']
print(attn_mask.dtype)
-> torch.int32
This causes a problem when training Wav2Vec2ForCTC with DeepSpeed.
The `_prepare_input` method in `trainer.py` casts int32 to float16 (when training in fp16):
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
    """
    Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
    """
    if isinstance(data, Mapping):
        return type(data)({k: self._prepare_input(v) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        return type(data)(self._prepare_input(v) for v in data)
    elif isinstance(data, torch.Tensor):
        kwargs = dict(device=self.args.device)
        if self.deepspeed and data.dtype != torch.int64:
            # NLP models inputs are int64 and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model
            kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
        return data.to(**kwargs)
    return data
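Here is a minimal standalone sketch of that guard (outside the Trainer; deepspeed_dtype is a stand-in for what hf_deepspeed_config.dtype() would return under fp16), showing why the int32 mask falls through:

import torch

attn_mask = torch.ones(2, 160000, dtype=torch.int32)  # int32 mask from the feature extractor
deepspeed_dtype = torch.float16  # stand-in for hf_deepspeed_config.dtype() under fp16

# only int64 is exempted, so the int32 mask falls through and gets cast
if attn_mask.dtype != torch.int64:
    attn_mask = attn_mask.to(dtype=deepspeed_dtype)

print(attn_mask.dtype)  # torch.float16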
and `forward` in Wav2Vec2ForCTC uses the sum of the attention_mask values:
loss = None
if labels is not None:
    if labels.max() >= self.config.vocab_size:
        raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

    # retrieve loss input_lengths from attention_mask
    attention_mask = (
        attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
    )
    input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)  # Here!
Because the attention_mask's dtype is now float16 (due to DeepSpeed) and the audio sequences are long, attention_mask.sum(-1) produces many 'inf' values, which breaks training.
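A quick way to see the overflow (assuming a typical 10-second clip at 16 kHz, i.e. 160000 samples, which exceeds float16's maximum finite value of 65504):

import torch

# summing a float16 mask over a long clip overflows to inf,
# which then corrupts the input_lengths computed in forward
mask_fp16 = torch.ones(160000, dtype=torch.float16)
print(mask_fp16.sum(-1))  # tensor(inf, dtype=torch.float16)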
Is this a known bug?
I worked around this problem by editing DataCollatorCTCWithPadding in the example like this:
batch['attention_mask'] = batch['attention_mask'].to(torch.long)
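For context, a hedged sketch of where that cast could live in the collator (simplified, field names and label handling may differ from the actual example script):

from dataclasses import dataclass
from typing import Dict, List, Union

import torch
from transformers import Wav2Vec2Processor


@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, List[float]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": f["input_values"]} for f in features]
        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        if "attention_mask" in batch:
            # keep the mask integral so Trainer._prepare_input leaves it alone under DeepSpeed fp16
            batch["attention_mask"] = batch["attention_mask"].to(torch.long)
        return batch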
but I would like to know if there is another solution.
Expected behavior
Maybe the attention_mask's dtype should be changed in the FeatureExtractor, or the logic of the `_prepare_input` method should be adjusted.
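For the second option, a sketch of what a broader guard could look like: cast only floating-point tensors to the DeepSpeed dtype and leave integer tensors (attention_mask, labels) untouched. maybe_cast_for_deepspeed is a hypothetical helper for illustration, not part of the Trainer and not the change that was actually merged.

import torch

def maybe_cast_for_deepspeed(data: torch.Tensor, deepspeed_dtype: torch.dtype) -> torch.Tensor:
    # only float inputs (e.g. wav2vec2 input_values) need to match the model's half-precision dtype
    if torch.is_floating_point(data):
        return data.to(dtype=deepspeed_dtype)
    return data

print(maybe_cast_for_deepspeed(torch.ones(3, dtype=torch.int32), torch.float16).dtype)  # torch.int32
print(maybe_cast_for_deepspeed(torch.randn(3), torch.float16).dtype)                    # torch.float16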
Sorry for being so slow / late here @ddobokki !
I think your solution sounds reasonable:
batch['attention_mask'] = batch['attention_mask'].to(torch.long)
=> attention_mask should be in long, so this is a welcome change. Do you mind opening a PR for this?
BTW, we do the same (casting to long) for similar inputs for pre-training: https://github.com/huggingface/transformers/blob/6268694e27f1fc0192ba24e4bec181061b4a9bf8/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py#L335
@patrickvonplaten Thank you for the comments! It's a small change, but I'm glad to contribute! I'll open a PR.