
Why is `.float()` used in `conformer_modules.py`?

grazder opened this issue 3 years ago • 2 comments

https://github.com/NVIDIA/NeMo/blob/eae1684f7f33c2a18de9ecfa42ec7db93d39e631/nemo/collections/asr/parts/submodules/conformer_modules.py#L259

Hi! I'm wondering why `.float()` is here. It seems like it would break the code if both the model parameters and the input are in float16.

grazder, Aug 30 '22

`.masked_fill` works for both float16 and bfloat16, so I still don't see the purpose of `.float()` here. Is there any case in which `forward` receives a tensor that is not a floating-point (16- or 32-bit) type?

grazder, Aug 31 '22
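A common reason for `.float()` casts in mixed-precision code is the narrow range and precision of fp16: its largest finite value is 65504 and it stores only 10 mantissa bits, whereas a zero fill value is exact in fp16. The sketch below illustrates those limits using only Python's standard `struct` module, whose `'e'` format packs IEEE 754 half precision. It is a hypothetical illustration: it does not reproduce the NeMo code or PyTorch's `masked_fill`, and it assumes that the masked positions are filled with 0.0, which you should check against the linked line.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision ('e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_in_fp16(x: float) -> bool:
    """Return True if x can be packed as fp16; struct raises OverflowError otherwise."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


if __name__ == "__main__":
    # A 0.0 fill value (assumed here) is exactly representable in fp16.
    print(to_fp16(0.0))         # 0.0
    # The largest finite fp16 value survives the round trip unchanged.
    print(to_fp16(FP16_MAX))    # 65504.0
    # Above 2048, fp16 integers are spaced 2 apart; 2049 rounds to even (2048).
    print(to_fp16(2049.0))      # 2048.0
    # Values beyond the fp16 range cannot be stored at all.
    print(fits_in_fp16(1e5))    # False
```

If the masked values stay well inside this range, a zero fill on its own does not appear to require float32. Any need for the cast would then come from elsewhere, such as how neighboring operations behave in fp16 or under autocast, and only the maintainers can confirm that.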

I know that casts like this were required to make Conformer train properly in fp16, but @bmwshop can explain why this specific case was necessary.

titu1994, Aug 31 '22

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot], Oct 06 '22

@bmwshop, could you take a look?

titu1994, Oct 06 '22

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot], Nov 07 '22

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], Nov 15 '22