Made masked losses compatible with masked NaNs
NaNs are a great way to ensure certain values aren't used (e.g. those associated with masked values). This change ensures that masked values are correctly masked (set to zero, even when they are NaN) rather than multiplied by zero (which leaves NaNs as NaNs).
This PR also (IMO) greatly simplifies the masking / weighting loss implementation. Test coverage is also improved.
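To illustrate the failure mode, here is a minimal sketch (not code from the PR); NumPy is used only because the arithmetic is the same on every backend:

```python
# Multiplying by a zero mask does not remove NaNs, while a where-based
# selection does. keras.ops.where behaves analogously.
import numpy as np

losses = np.array([0.5, np.nan, 1.2])  # NaN at a position we want to ignore
mask = np.array([1.0, 0.0, 1.0])       # 0 marks the masked (ignored) entry

print(losses * mask)                              # [0.5 nan 1.2] -> 0 * nan is still nan
print(np.where(mask.astype(bool), losses, 0.0))   # [0.5 0.  1.2] -> NaN is zeroed out
```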
Codecov Report
Attention: 3 lines in your changes are missing coverage. Please review.
Comparison is base (9620d23) 79.30% compared to head (936dd13) 79.37%. Report is 11 commits behind head on master.
| Files | Patch % | Lines |
|---|---|---|
| keras/losses/loss.py | 86.95% | 2 Missing and 1 partial :warning: |
Additional details and impacted files
Coverage Diff

| | master | #18829 | +/- |
|---|---|---|---|
| Coverage | 79.30% | 79.37% | +0.06% |
| Files | 336 | 336 | |
| Lines | 34549 | 34775 | +226 |
| Branches | 6799 | 6841 | +42 |
| Hits | 27400 | 27603 | +203 |
| Misses | 5567 | 5590 | +23 |
| Partials | 1582 | 1582 | |
| Flag | Coverage Δ | |
|---|---|---|
| keras | 79.23% <90.00%> (+0.06%) | :arrow_up: |
| keras-jax | 61.07% <90.00%> (-0.28%) | :arrow_down: |
| keras-numpy | 55.90% <40.00%> (-0.19%) | :arrow_down: |
| keras-tensorflow | 63.26% <90.00%> (-0.08%) | :arrow_down: |
| keras-torch | 63.83% <90.00%> (-0.27%) | :arrow_down: |
Flags with carried forward coverage won't be shown.
@fchollet refactored to use a re-introduced apply_mask. Can you elaborate on the overhead introduced by a scalar cond? Is there something about the specific circumstances here that makes it expensive? I would have thought it would be cheaper than a where, which itself should cost about as much as any simple binary arithmetic operation.
The reason has to do with parallelization. Conditional branches are much harder to handle than a single serial op, even if that op has more FLOPs.
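For concreteness, here is a hedged sketch of the two alternatives being discussed, written against the public keras.ops API. The helper names are illustrative, not the PR's actual implementation: the cond version introduces data-dependent control flow, while the where version is a single straight-line elementwise op.

```python
from keras import ops

def apply_mask_with_cond(losses, mask):
    # Branch on a scalar predicate: only rewrite the losses if anything is masked.
    # mask is assumed to be a boolean tensor broadcastable to losses.
    return ops.cond(
        ops.any(ops.logical_not(mask)),
        lambda: ops.where(mask, losses, ops.zeros_like(losses)),
        lambda: losses,
    )

def apply_mask_with_where(losses, mask):
    # No control flow: always select, which fuses and parallelizes easily.
    return ops.where(mask, losses, ops.zeros_like(losses))
```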
@haifeng-jin can you advise here on how to test the performance impact of this change for a couple of standard models on GPU? I'd like to compare the cond implementation, the where implementation, and the baseline (current code).
@jackd here is what I did for benchmarking a PR. You can find 6 Colab notebooks here: https://github.com/haifeng-jin/keras-benchmarking
Each notebook covers one backend, either before or after the change. You can just swap out the model part of the code and use your own model.
You will see the performance numbers at the end of the notebook.
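The notebooks are the authoritative harness; the following is only a rough sketch of the kind of before/after timing they report. The model, data sizes, and ms/step accounting here are placeholders for illustration.

```python
import time
import numpy as np
import keras

# Placeholder model; in practice you would swap in the model you want to benchmark.
model = keras.Sequential([
    keras.Input(shape=(128,)),
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(4096, 128).astype("float32")
y = np.random.rand(4096, 1).astype("float32")

model.fit(x, y, batch_size=32, epochs=1, verbose=0)  # warm-up / compilation epoch
steps = len(x) // 32
start = time.time()
model.fit(x, y, batch_size=32, epochs=1, verbose=0)  # timed epoch
print(f"training: {(time.time() - start) / steps * 1000:.1f} ms/step")
```

Run the same script on the commit before and after the change (one backend at a time) and compare the ms/step figures.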
Thanks, Haifeng! @jackd can you use the same code to benchmark the impact of this change?
This isn't a priority for me, and I've already spent a lot longer on this than I intended to. If anyone else wants to take this up, feel free; otherwise, would a modified PR with just the simplified version (re-replacing masking with multiplication by zero) be accepted?
Hi @fchollet, can you please assist with the above comments from @jackd? Thank you!
Running on a Colab T4:
- batch_norm_op_jax_after (HEAD -> 936dd1345d794e91b3883bf99dec66dc8021e7fc) 102967424/102967424 ━━━━━━━━━━━━━━━━━━━━ 6s 0us/step 101/101 ━━━━━━━━━━━━━━━━━━━━ 165s 1s/step - loss: 0.5308 training: 1111 ms/step 101/101 ━━━━━━━━━━━━━━━━━━━━ 33s 269ms/step inferencing: 267 ms/step
- batch_norm_op_jax_before (HEAD -> 9620d23ab1d672e0611d3c9bd0f77a11a20b3204) 102967424/102967424 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step 101/101 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - loss: 0.5062 training: 1076 ms/step 101/101 ━━━━━━━━━━━━━━━━━━━━ 41s 373ms/step inferencing: 255 ms/step
- batch_norm_op_oom_torch_after (HEAD -> 936dd1345d794e91b3883bf99dec66dc8021e7fc) 101/101 ━━━━━━━━━━━━━━━━━━━━ 108s 1s/step - loss: 0.6557 414.0576171875
- batch_norm_op_oom_torch_before (HEAD -> 9620d23ab1d672e0611d3c9bd0f77a11a20b3204) 101/101 ━━━━━━━━━━━━━━━━━━━━ 109s 1s/step - loss: 0.5071 414.0576171875
- batch_norm_op_torch_after (HEAD -> 936dd1345d794e91b3883bf99dec66dc8021e7fc) 101/101 ━━━━━━━━━━━━━━━━━━━━ 71s 684ms/step - loss: 0.3952 training: 683 ms/step 101/101 ━━━━━━━━━━━━━━━━━━━━ 19s 189ms/step inferencing: 189 ms/step
- batch_norm_op_torch_before (HEAD -> 9620d23ab1d672e0611d3c9bd0f77a11a20b3204) 101/101 ━━━━━━━━━━━━━━━━━━━━ 93s 545ms/step - loss: 0.5390 training: 543 ms/step 101/101 ━━━━━━━━━━━━━━━━━━━━ 18s 132ms/step inferencing: 132 ms/step
@dugujiujian1999 maybe I'm missing something, but why do runs with lower total times have higher times per step? e.g. batch_norm_op_torch training has 71s & 684ms/step after and 93s & 545ms/step before. Surely they should be proportional?
@jackd I don't know. It takes less time after the patch. I used the code here: https://github.com/haifeng-jin/keras-benchmarking/tree/main/prs
@jackd, can you please rebase the code to follow the latest code structure (e.g. /keras/src/..)?
CBF at this point, someone else can take over if they'd like.