
Made masked losses compatible with masked nans

Open jackd opened this issue 2 years ago • 10 comments

NaNs are a great way to ensure certain values aren't used (e.g. those that are associated with masked values). This change ensures that masked values are correctly masked (set to zero, even when nan) rather than multiplied by zero (which leaves nans as nans).

This PR also (IMO) greatly simplifies the masking / weighting loss implementation. Test coverage is also improved.
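
For readers skimming, here is a minimal NumPy sketch of the behavior difference described above (illustrative only; it is not the PR's actual code, which operates on Keras loss/mask tensors):

```python
import numpy as np

losses = np.array([0.5, np.nan, 2.0])  # nan marks a value that must not be used
mask = np.array([True, False, True])   # False = masked out

# Multiplying by the mask does NOT remove the nan: nan * 0 is still nan,
# so the nan propagates into any reduction over the losses.
print(losses * mask)                        # [0.5 nan 2. ]
print(np.sum(losses * mask))                # nan

# Selecting with `where` replaces masked entries outright, so the nan is dropped.
print(np.where(mask, losses, 0.0))          # [0.5 0.  2. ]
print(np.sum(np.where(mask, losses, 0.0)))  # 2.5
```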

jackd avatar Nov 25 '23 04:11 jackd

Codecov Report

Attention: 3 lines in your changes are missing coverage. Please review.

Comparison: base (9620d23) 79.30% vs. head (936dd13) 79.37%. Report is 11 commits behind head on master.

Files                  Patch %   Lines
keras/losses/loss.py   86.95%    2 missing, 1 partial
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18829      +/-   ##
==========================================
+ Coverage   79.30%   79.37%   +0.06%     
==========================================
  Files         336      336              
  Lines       34549    34775     +226     
  Branches     6799     6841      +42     
==========================================
+ Hits        27400    27603     +203     
- Misses       5567     5590      +23     
  Partials     1582     1582              
Flag               Coverage Δ
keras              79.23% <90.00%> (+0.06%)
keras-jax          61.07% <90.00%> (-0.28%)
keras-numpy        55.90% <40.00%> (-0.19%)
keras-tensorflow   63.26% <90.00%> (-0.08%)
keras-torch        63.83% <90.00%> (-0.27%)

Flags with carried forward coverage won't be shown.

codecov-commenter avatar Nov 25 '23 04:11 codecov-commenter

@fchollet refactored to use a re-introduced apply_mask. Can you elaborate on the overhead introduced by a scalar cond? Is there something about the specific circumstances here that makes it expensive? I would have thought it would be cheaper than a where, which itself should cost about as much as any simple binary arithmetic operation.

jackd avatar Nov 29 '23 13:11 jackd

The reason has to do with parallelization. Conditional branches are much harder to handle than a single serial op, even if that op has more FLOPs.

@haifeng-jin can you advise here on how to test the performance impact of this change for a couple of standard models on GPU? I'd like to compare the cond implementation, the where implementation, and the baseline (current code).

fchollet avatar Nov 29 '23 16:11 fchollet
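For context, a rough sketch of the two shapes of masking being weighed above (hypothetical; the `apply_mask_*` helpers are invented for illustration and are not the PR's actual code, though the keras.ops functions used are real):

```python
from keras import ops

def apply_mask_with_where(losses, mask):
    # One elementwise select: masked entries become exactly 0,
    # even if the underlying loss value is nan.
    return ops.where(mask, losses, ops.zeros_like(losses))

def apply_mask_with_cond(losses, mask):
    # Scalar-cond variant: branch on whether anything is masked at all.
    # The data-dependent branch is what is harder to fuse/parallelize
    # than the single elementwise `where` above.
    return ops.cond(
        ops.all(mask),
        lambda: losses,                                           # nothing masked
        lambda: ops.where(mask, losses, ops.zeros_like(losses)),  # zero out masked entries
    )

losses = ops.convert_to_tensor([1.0, float("nan"), 3.0])
mask = ops.convert_to_tensor([True, False, True])
print(apply_mask_with_where(losses, mask))  # [1. 0. 3.]
print(apply_mask_with_cond(losses, mask))   # [1. 0. 3.]
```

Either way, the open question in the thread is whether the extra cost of the cond shows up in end-to-end training time, which is what the benchmarks below try to measure.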

@jackd here is what I did for benchmarking a PR. You can find 6 colab notebooks here: https://github.com/haifeng-jin/keras-benchmarking

Each of the notebooks is for one of the backends, either before or after the change. You can just swap out the model part of the code and use your own model.

You will see the performance numbers at the end of the notebook.

haifeng-jin avatar Dec 07 '23 04:12 haifeng-jin
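For readers who don't open the notebooks, the general shape of such a benchmark is roughly the following (illustrative sketch only; the model, data, and numbers here are invented, and the actual notebooks in the linked repo differ):

```python
import time

import numpy as np
import keras

# Tiny stand-in model; the real benchmark swaps in the model under test.
model = keras.Sequential(
    [keras.layers.Dense(64, activation="relu"), keras.layers.Dense(10)]
)
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(1024, 32).astype("float32")
y = np.random.rand(1024, 10).astype("float32")
batch_size = 64
steps = len(x) // batch_size

# Warm-up pass absorbs compilation / first-step overhead, then time steady state.
model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)
start = time.perf_counter()
model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)
print(f"training: {(time.perf_counter() - start) / steps * 1000:.0f} ms/step")

model.predict(x, batch_size=batch_size, verbose=0)
start = time.perf_counter()
model.predict(x, batch_size=batch_size, verbose=0)
print(f"inferencing: {(time.perf_counter() - start) / steps * 1000:.0f} ms/step")
```

Running the same script once on the base commit and once on the PR branch gives the before/after ms/step comparison reported later in this thread.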

Thanks, Haifeng! @jackd can you use the same code to benchmark the impact of this change?

fchollet avatar Dec 07 '23 21:12 fchollet

This isn't a priority for me, and I've already spent a lot longer on this than I intended to. If anyone else wants to take this up, feel free; otherwise, would a modified PR with just the simplified version (re-replacing masking with multiplication by zero) be accepted?

jackd avatar Dec 13 '23 07:12 jackd

Hi @fchollet, can you please assist with the above comments from @jackd? Thank you!

gbaned avatar Jan 05 '24 17:01 gbaned

Running on a Colab T4:

  • batch_norm_op_jax_after (HEAD -> 936dd1345d794e91b3883bf99dec66dc8021e7fc)
    102967424/102967424 - 6s 0us/step
    101/101 - 165s 1s/step - loss: 0.5308 - training: 1111 ms/step
    101/101 - 33s 269ms/step - inferencing: 267 ms/step
  • batch_norm_op_jax_before (HEAD -> 9620d23ab1d672e0611d3c9bd0f77a11a20b3204)
    102967424/102967424 - 1s 0us/step
    101/101 - 202s 1s/step - loss: 0.5062 - training: 1076 ms/step
    101/101 - 41s 373ms/step - inferencing: 255 ms/step

  • batch_norm_op_oom_torch_after (HEAD -> 936dd1345d794e91b3883bf99dec66dc8021e7fc)
    101/101 - 108s 1s/step - loss: 0.6557 - 414.0576171875
  • batch_norm_op_oom_torch_before (HEAD -> 9620d23ab1d672e0611d3c9bd0f77a11a20b3204)
    101/101 - 109s 1s/step - loss: 0.5071 - 414.0576171875

  • batch_norm_op_torch_after (HEAD -> 936dd1345d794e91b3883bf99dec66dc8021e7fc)
    101/101 - 71s 684ms/step - loss: 0.3952 - training: 683 ms/step
    101/101 - 19s 189ms/step - inferencing: 189 ms/step
  • batch_norm_op_torch_before (HEAD -> 9620d23ab1d672e0611d3c9bd0f77a11a20b3204)
    101/101 - 93s 545ms/step - loss: 0.5390 - training: 543 ms/step
    101/101 - 18s 132ms/step - inferencing: 132 ms/step

dugujiujian1999 avatar Jan 08 '24 08:01 dugujiujian1999

@dugujiujian1999 maybe I'm missing something, but why do runs with lower total times have higher times per step? e.g. batch_norm_op_torch training has 71s & 684ms/step after and 93s & 545ms/step before. Surely they should be proportional?

jackd avatar Jan 15 '24 01:01 jackd

@jackd I don't know. It takes less time after the patch. I used the code here: https://github.com/haifeng-jin/keras-benchmarking/tree/main/prs

dugujiujian1999 avatar Jan 15 '24 05:01 dugujiujian1999

@jackd, can you please rebase the code to follow the latest code structure, like /keras/src/..

sachinprasadhs avatar May 01 '24 23:05 sachinprasadhs

Can't be bothered at this point; someone else can take over if they'd like.

jackd avatar May 01 '24 23:05 jackd