keras
Fix stateful RNN to raise ValueError on batch size mismatch
This PR fixes a silent failure when calling stateful RNN layers (SimpleRNN, GRU, LSTM) with an input that doesn't match the fixed batch size.
Previously:

- `SimpleRNN` would silently broadcast the input across all internal states.
- `GRU` and `LSTM` would crash in CuDNN, but only after mutating the internal state.
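The NumPy sketch below shows how a `SimpleRNN`-style step can hide the mismatch. It assumes a hypothetical case where the stored state has a batch dimension of 1 and the input carries 4 samples. `simple_rnn_step` is an illustrative helper, not a Keras function, and the bias term is left out.

```python
import numpy as np


def simple_rnn_step(inputs, state, kernel, recurrent_kernel):
    """One SimpleRNN-style update: h_t = tanh(x_t @ W + h_{t-1} @ U).

    Hypothetical helper for illustration only; bias omitted for brevity.
    """
    return np.tanh(inputs @ kernel + state @ recurrent_kernel)


rng = np.random.default_rng(0)
units, features = 2, 3

# Hypothetical mismatch: state built for a batch of 1, input has 4 samples.
state = np.zeros((1, units))                  # (state_batch, units)
inputs = rng.normal(size=(4, features))       # (input_batch, features)
kernel = rng.normal(size=(features, units))
recurrent_kernel = rng.normal(size=(units, units))

# (4, units) + (1, units) broadcasts to (4, units). No error is raised,
# and the "state" silently changes its batch dimension from 1 to 4.
new_state = simple_rnn_step(inputs, state, kernel, recurrent_kernel)
print(state.shape, "->", new_state.shape)  # (1, 2) -> (4, 2)
```

Broadcasting only hides the problem when one of the two batch dimensions is 1. Other mismatches would instead surface as a shape error somewhere deeper in the computation.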
Now:

- A `ValueError` is raised early in `RNN.call()` if the input's batch size doesn't match the batch size expected by the internal state.
- This prevents state corruption and makes all RNN variants behave the same way.
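The ordering that matters here is to validate the batch size first and mutate the state only afterwards. The plain-Python sketch below illustrates that pattern. `ToyStatefulRNN` and its recurrence are hypothetical stand-ins and do not reflect the PR's actual code.

```python
class ToyStatefulRNN:
    """Hypothetical stand-in for a stateful RNN layer (not the Keras API).

    The hidden state has a fixed batch dimension, like a layer built with
    stateful=True and a fixed batch size.
    """

    def __init__(self, batch_size, units):
        # One hidden-state row per sample in the fixed batch.
        self.states = [[0.0] * units for _ in range(batch_size)]

    def call(self, inputs):
        """inputs: one sequence per sample, each a list of scalar timesteps."""
        expected = len(self.states)
        if len(inputs) != expected:
            # Validate before touching self.states, so a bad call cannot
            # leave the layer with partially updated state.
            raise ValueError(
                f"Stateful RNN expects batch size {expected}, "
                f"but received inputs with batch size {len(inputs)}."
            )
        for i, sequence in enumerate(inputs):
            for x in sequence:
                # Toy recurrence applied to every unit: h <- 0.5 * h + x.
                self.states[i] = [0.5 * h + x for h in self.states[i]]
        return [list(row) for row in self.states]


layer = ToyStatefulRNN(batch_size=2, units=3)
print(layer.call([[1.0], [2.0]]))  # [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]

try:
    layer.call([[1.0], [1.0], [1.0]])  # batch size 3 instead of 2
except ValueError as err:
    print(err)
print(layer.states)  # unchanged: [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
```

Because the check runs before the update loop, the failed call leaves the state exactly as it was. That is the property the PR aims for, as opposed to the `GRU`/`LSTM` behavior of mutating state before crashing.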
Fixes #21183
Manually tested by calling a stateful model with incorrect batch sizes
and verifying that the new ValueError is raised as expected.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
Codecov Report
Attention: Patch coverage is 0% with 5 lines in your changes missing coverage. Please review.
Project coverage is 34.56%. Comparing base (f5171b3) to head (3180ed4). Report is 124 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/layers/rnn/rnn.py | 0.00% | 5 Missing :warning: |
:exclamation: There is a different number of reports uploaded between BASE (f5171b3) and HEAD (3180ed4).
HEAD has 8 fewer uploads than BASE:

| Flag | BASE (f5171b3) | HEAD (3180ed4) |
|---|---|---|
| keras | 5 | 1 |
| keras-numpy | 1 | 0 |
| keras-torch | 1 | 0 |
| keras-tensorflow | 1 | 0 |
| keras-jax | 1 | 0 |
@@ Coverage Diff @@
## master #21249 +/- ##
===========================================
- Coverage 82.60% 34.56% -48.04%
===========================================
Files 564 567 +3
Lines 54543 56219 +1676
Branches 8472 8788 +316
===========================================
- Hits 45054 19431 -25623
- Misses 7402 35910 +28508
+ Partials 2087 878 -1209
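As a sanity check, the two coverage percentages can be recomputed from the hit, miss, and partial counts. This assumes Codecov's ratio is hits / (hits + misses + partials), which means partially covered lines count as uncovered. The numbers above are consistent with that assumption. `codecov_ratio` is a hypothetical helper.

```python
def codecov_ratio(hits, misses, partials):
    # Assumed Codecov formula: partially covered lines count against
    # coverage, so the denominator is every tracked line.
    return 100 * hits / (hits + misses + partials)


base = codecov_ratio(45054, 7402, 2087)   # master (f5171b3)
head = codecov_ratio(19431, 35910, 878)   # PR head (3180ed4)
print(f"{base:.2f}% -> {head:.2f}% ({head - base:+.2f})")
# 82.60% -> 34.56% (-48.04)
```

The drop comes almost entirely from the hit count collapsing. That matches the missing backend uploads at HEAD more closely than it matches the 5 uncovered lines in the patch.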
| Flag | Coverage Δ | |
|---|---|---|
| keras | 34.56% <0.00%> (-47.86%) | :arrow_down: |
| keras-jax | ? | |
| keras-numpy | ? | |
| keras-openvino | 34.56% <0.00%> (+1.56%) | :arrow_up: |
| keras-tensorflow | ? | |
| keras-torch | ? | |
Flags with carried forward coverage won't be shown.
Hi @El3ssar, can you please sign the CLA? Thanks!
Thanks for the PR! Please add a simple test case for the fix (using `self.assertRaisesRegex`).
Gladly! Where should I put the test?
You can put the test in `keras/src/layers/rnn/rnn_test.py`, alongside `test_statefulness_two_states`. You can target `layer = layers.RNN(TwoStatesRNNCell(2), stateful=True)` as the RNN layer.
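Here is a minimal sketch of the requested `assertRaisesRegex` pattern. It uses only the standard library's `unittest`. `FixedBatchLayer` is a hypothetical stand-in for a stateful RNN layer. The real test would build `layers.RNN(TwoStatesRNNCell(2), stateful=True)` instead, and its regex would have to match whatever message the PR actually raises.

```python
import unittest


class FixedBatchLayer:
    """Hypothetical stand-in for an RNN layer built with stateful=True."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.calls = 0  # counts calls that got past validation

    def __call__(self, batch):
        if len(batch) != self.batch_size:
            raise ValueError(
                f"expected batch size {self.batch_size}, got {len(batch)}"
            )
        self.calls += 1
        return batch


class StatefulBatchSizeTest(unittest.TestCase):
    def test_batch_size_mismatch_raises(self):
        layer = FixedBatchLayer(batch_size=2)
        layer([0.0, 0.0])  # a matching batch size works
        # assertRaisesRegex checks both the exception type and its message.
        with self.assertRaisesRegex(ValueError, r"expected batch size 2, got 3"):
            layer([0.0, 0.0, 0.0])
        # The rejected call must not have been processed.
        self.assertEqual(layer.calls, 1)
```

Test collectors such as pytest or `python -m unittest` can pick this up. Pinning the message with a regex, not just the exception type, keeps the test from passing on an unrelated `ValueError`.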
@fchollet I've added the requested test in `rnn_test.py` using `TwoStatesRNNCell`. It fails without the fix and passes with it.
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.