pytorch-forecasting

[ENH] `TemporalFusionTransformer` - allow mixed precision training

Open Marcrb2 opened this issue 1 year ago • 2 comments

Description

This PR changes the fill value used for the attention mask in the TFT model from -1e9 to -float("inf") so that the model can be trained with PyTorch mixed precision.

Closes #1325, closes #285
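
As an illustration of why the fill value matters, here is a minimal, stdlib-only sketch. It is not the library's code: the helper names `fits_in_fp16` and `masked_softmax` are hypothetical, and it assumes the mixed precision path stores attention scores as IEEE 754 half-precision floats. A large finite constant such as -1e9 is outside the half-precision range (largest finite magnitude 65504), whereas negative infinity is representable and still gives masked positions zero weight after the softmax.

```python
import math
import struct


def fits_in_fp16(value: float) -> bool:
    """Return True if `value` can be stored as an IEEE 754 half-precision float."""
    try:
        struct.pack("<e", value)  # "e" is the struct code for half precision
        return True
    except OverflowError:
        return False


def masked_softmax(scores, mask, fill_value):
    """Softmax over `scores`; positions where `mask` is True get `fill_value` first."""
    filled = [fill_value if m else s for s, m in zip(scores, mask)]
    peak = max(filled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in filled]
    total = sum(exps)
    return [e / total for e in exps]


# The large finite constant does not fit in half precision, negative infinity does.
finite_ok = fits_in_fp16(-1e9)
inf_ok = fits_in_fp16(-float("inf"))

# With -inf as the fill value, the masked (last) position gets zero weight.
weights = masked_softmax([0.5, 1.5, -0.3], [False, False, True], -float("inf"))

print(finite_ok, inf_ok, weights)
```

One caveat with this sketch: if every position in a row is masked, the softmax divides zero by zero and yields NaN, so the mask should leave at least one position visible.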

Marcrb2, Feb 19 '24 09:02

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (b3fcf86) 90.19% compared to head (d93c1e0) 90.19%. Report is 8 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1518   +/-   ##
=======================================
  Coverage   90.19%   90.19%           
=======================================
  Files          30       30           
  Lines        4724     4724           
=======================================
  Hits         4261     4261           
  Misses        463      463           
Flag Coverage Δ
cpu 90.19% <100.00%> (ø)
pytest 90.19% <100.00%> (ø)


codecov-commenter, Feb 19 '24 09:02