Why is `.float()` in `conformer_modules.py`?
https://github.com/NVIDIA/NeMo/blob/eae1684f7f33c2a18de9ecfa42ec7db93d39e631/nemo/collections/asr/parts/submodules/conformer_modules.py#L259
Hi! I'm wondering why `.float()` is here. It seems like it will break the code if both the model parameters and the input are in float16.
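
For context, the pattern in question looks roughly like this (a minimal sketch, not the exact NeMo code; the tensor shapes and the `pad_mask` name are assumptions based on the linked file):

```python
import torch

x = torch.randn(2, 8, 10, dtype=torch.float16)  # fp16 activations (B, C, T)
pad_mask = torch.zeros(2, 10, dtype=torch.bool)
pad_mask[:, 7:] = True  # mark padded frames

# The questioned pattern: .float() silently upcasts fp16 -> fp32,
# so whatever layer consumes x next sees an fp32 input even though
# the model parameters may still be fp16.
x = x.float().masked_fill(pad_mask.unsqueeze(1), 0.0)
print(x.dtype)  # torch.float32, no longer torch.float16
```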
`.masked_fill` works for both float16 and bfloat16, so I still don't see the point of the `.float()` here. Is there any possibility that `forward` will receive a non-float (16/32) dtype?
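
A quick standalone check (independent of NeMo) supports this: `masked_fill` preserves the input dtype in both half-precision formats, so the cast isn't needed for the fill itself:

```python
import torch

mask = torch.tensor([[False, True], [True, False]])

for dtype in (torch.float16, torch.bfloat16):
    t = torch.ones(2, 2, dtype=dtype)
    filled = t.masked_fill(mask, 0.0)  # no .float() needed
    print(dtype, "->", filled.dtype)   # dtype is preserved
```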
I know that these casts were required to make Conformer train properly in fp16, but @bmwshop can provide details on why this specific case was necessary.
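
For anyone hitting the dtype mismatch in the meantime, one hedged workaround (an assumption on my part, not a maintainer-endorsed fix) is to cast back to the original dtype after the fill, which keeps any fp32 benefit local to that one op:

```python
import torch

def masked_zero(x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: do the fill in fp32 if that is what fp16
    # training stability requires, then restore the caller's dtype so
    # downstream fp16 layers see a matching input.
    return x.float().masked_fill(pad_mask.unsqueeze(1), 0.0).to(x.dtype)

x = torch.randn(2, 8, 10, dtype=torch.float16)
pad_mask = torch.zeros(2, 10, dtype=torch.bool)
pad_mask[:, 7:] = True
print(masked_zero(x, pad_mask).dtype)  # torch.float16
```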
@bmwshop, could you take a look?
This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.