Draft: Fix the "cast ping pong" problem when we run AMP inference.

Open · galv opened this issue 9 months ago

This has been tested only for Parakeet-CTC-1.1B right now. This problem certainly exists elsewhere.

It is also not in a mergeable state; this is my initial hack at fixing the problem.

Automatic mixed precision and inference do not play well together.

Frankly, I do not think we should be using AMP in inference in the first place.

First, automatic mixed precision was created back when neural networks were much simpler. In particular, they did not use softmax and layer norm as frequent operations. In the era of transformers, softmax and layer norm are very common, and AMP will unconditionally output fp32 from these operations, even when their inputs are fp16. See here: https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float32

This promotion is no longer necessary, now that layer norm accumulates in fp32 in PyTorch even when its input is fp16: https://github.com/pytorch/pytorch/issues/66707
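To make the promotion concrete, here is a minimal illustration (not part of this PR; it assumes a CUDA device and a recent PyTorch) of what autocast does to these ops:

```python
import torch
import torch.nn.functional as F

linear = torch.nn.Linear(8, 8, device="cuda")
x = torch.randn(4, 8, device="cuda")

with torch.autocast("cuda", dtype=torch.float16):
    h = linear(x)                 # matmul autocasts to float16
    s = F.softmax(h, dim=-1)      # autocast promotes the output to float32
    n = F.layer_norm(h, (8,))     # same: float32 output from a float16 input

print(h.dtype, s.dtype, n.dtype)  # torch.float16 torch.float32 torch.float32
```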

There is a second problem, elaborated on in my long comment in nemo/core/classes/module.py. Basically, setting requires_grad=False on a parameter, confusingly enough, disables AMP's weight-cast caching mechanism, which means we unconditionally waste time re-running cast kernels. Read more here: https://github.com/NVIDIA/NeMo/pull/9086/files#diff-4f86565ac3601ea79395f9d2884c28115c5fd9fe3655894564d71539fe5d8042R58-R90
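A rough sketch of one way to observe this with the PyTorch profiler (again, not part of this PR; the exact aten op name may differ between PyTorch versions):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def count_cast_ops(requires_grad: bool) -> int:
    m = torch.nn.Linear(1024, 1024, device="cuda")
    for p in m.parameters():
        p.requires_grad_(requires_grad)
    # fp16 input, so the only casts left are the fp32 -> fp16 weight casts.
    x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with torch.autocast("cuda", dtype=torch.float16):
            for _ in range(10):
                m(x)
    return sum(e.count for e in prof.key_averages() if "_to_copy" in e.key)

# With requires_grad=True the fp16 copy of the weights is cached after the
# first forward; with requires_grad=False it is re-created on every call.
print(count_cast_ops(True), count_cast_ops(False))
```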

Both of these problems need to be fixed to stop inference from constantly casting its intermediate tensors back and forth between float16 and float32, which wastes time and accomplishes nothing.

This hacky fix does not support bfloat16 for now.

Again, I recommend simply running model.half() or model.bfloat16() for inference. AMP does not make sense for inference; it also wastes memory, since the fp32 weights stay resident alongside their lower-precision copies.
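For reference, a minimal sketch of that recommendation on a toy module (not the NeMo model) looks like this:

```python
import torch

# Cast the model once, feed it matching-precision inputs, skip autocast entirely.
model = torch.nn.Sequential(
    torch.nn.Linear(80, 512),
    torch.nn.LayerNorm(512),
    torch.nn.Linear(512, 128),
).cuda().half().eval()

x = torch.randn(32, 80, device="cuda", dtype=torch.float16)
with torch.inference_mode():
    y = model(x)   # every op, including layer norm, runs in float16
print(y.dtype)     # torch.float16
```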

Results from running:

```bash
batch_size=32
amp=true

echo "GALVEZ: Conformer CTC"
python examples/asr/speech_to_text_eval.py \
  pretrained_name=nvidia/parakeet-ctc-1.1b \
  dataset_manifest=/home/dgalvez/scratch/data/test_other_sorted_downward.json \
  batch_size=$batch_size \
  output_filename=test_clean_decoded.jsonl \
  amp=$amp \
  amp_dtype=float16 \
  use_cer=false \
  num_workers=1
# ctc_decoding.greedy.batched_inference=true
```

RTFx values before my change:

922.302270024654
1072.6919226908824
853.5230428014742
1114.2164255498367
1264.829085006686

RTFx values after my change:

1132.0813676953462
1255.840791955383
838.7445635257388
1389.7852550162668
1479.8061656119771

You can see that RTFx gets as high as 1480 after my change, while the highest RTFx before my change is 1264. That is roughly a 17% speedup from removing the unnecessary cast kernels.

Forgive the large variance from run to run. It is caused by the changes made in https://github.com/NVIDIA/NeMo/pull/8521: there is a huge delay (at least 10 milliseconds) whenever there is a cache miss in the caching pinned host memory allocator. I'm planning to revert that change, since I didn't initially realize how large the cache-miss cost was (enough to undo the benefits on reasonably sized workloads).

galv avatar May 02 '24 00:05 galv

One clean solution could be to add a method such as .downcast(self, dtype=...) to most modules that casts all parameters and buffers to the correct precision for a given dtype. It would have to be called recursively so that every "layer" or "block" can define how it needs to be cast; e.g. rnnt_model.downcast calls {encoder,decoder,joiner}.downcast, which in turn call transformer_block.downcast, and so on.

This is under the assumption that simply calling model.bfloat16 may cast some weights to too low a precision.
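A purely hypothetical sketch of such a method (downcast is not an existing NeMo or PyTorch API), assuming it is mixed into nn.Module subclasses:

```python
import torch
import torch.nn as nn

class DowncastMixin:
    # Parameter/buffer names this particular module wants to keep in float32.
    keep_fp32: frozenset = frozenset()

    def downcast(self, dtype: torch.dtype = torch.float16):
        for name, p in self.named_parameters(recurse=False):
            if name not in self.keep_fp32:
                p.data = p.data.to(dtype)
        for name, b in self.named_buffers(recurse=False):
            if name not in self.keep_fp32:
                b.data = b.data.to(dtype)
        for child in self.children():
            if isinstance(child, DowncastMixin):
                child.downcast(dtype)   # child applies its own casting policy
            else:
                child.to(dtype)         # default: cast everything
        return self
```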

pzelasko avatar May 02 '24 14:05 pzelasko

#9198 supersedes this.

galv avatar May 14 '24 18:05 galv