pytorch-forecasting
[ENH] `TemporalFusionTransformer` - allow mixed precision training
Description
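This PR changes the fill value used for the attention mask in the TFT model from `-1e9` to `-float("inf")` so that the model can be trained with PyTorch mixed precision. `-1e9` lies outside the float16 range (largest finite magnitude 65504), so `masked_fill` overflows in half precision. Negative infinity is representable in every floating-point dtype and still gives masked positions zero weight after the softmax.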
Closes #1325, closes #285
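The failure and the fix can be illustrated with a minimal sketch. It assumes a recent PyTorch running on CPU, and `masked_softmax` is a hypothetical helper, not the code in `pytorch_forecasting`. Filling a float16 score tensor with `-1e9` raises an overflow error, while `-inf` works and leaves zero attention weight at the masked positions.

```python
import torch


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax over the last dim, excluding positions where mask is True."""
    # -inf is representable in float16, float32 and bfloat16, unlike -1e9 in float16
    scores = scores.masked_fill(mask, float("-inf"))
    # Normalize in float32; exp(-inf) == 0, so masked positions get zero weight
    return torch.softmax(scores, dim=-1, dtype=torch.float32)


# Half-precision attention scores of shape (batch, query, key)
scores = torch.randn(2, 4, 4).half()
# Causal mask: True above the diagonal marks future positions to hide
mask = torch.ones(4, 4).triu(diagonal=1).bool()

try:
    # Old behaviour: a large finite constant does not fit into float16
    scores.masked_fill(mask, -1e9)
except RuntimeError as err:
    print("finite fill fails in float16:", err)

probs = masked_softmax(scores, mask)
print(probs[0])
```

A row in which every position is masked would still produce NaN with `-inf`, because all of its exponentials are zero. The causal mask above always leaves the diagonal unmasked, so that case does not arise here. Computing the softmax in float32 roughly mirrors CUDA autocast, which runs softmax in float32.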
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Comparison is base (b3fcf86) 90.19% compared to head (d93c1e0) 90.19%. Report is 8 commits behind head on master.
Additional details and impacted files
```diff
@@            Coverage Diff            @@
##           master    #1518   +/-   ##
=======================================
  Coverage   90.19%   90.19%
=======================================
  Files          30       30
  Lines        4724     4724
=======================================
  Hits         4261     4261
  Misses        463      463
```
| Flag | Coverage Δ |
|---|---|
| cpu | 90.19% <100.00%> (ø) |
| pytest | 90.19% <100.00%> (ø) |
Flags with carried forward coverage won't be shown.