
Fix stateful RNN to raise ValueError on batch size mismatch

Open El3ssar opened this issue 6 months ago • 5 comments

This PR fixes a silent failure when a stateful RNN layer (SimpleRNN, GRU, LSTM) is called with an input whose batch size doesn't match the layer's fixed batch size.

Previously:

  • SimpleRNN would silently broadcast the input across all internal states.
  • GRU and LSTM would crash in CuDNN but still mutate internal state before failing.

Now:

  • A ValueError is raised early in RNN.call() if the batch size of the input doesn't match the expected value from the internal state.
  • This avoids state corruption and aligns the behavior across all RNN variants.
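
A minimal sketch of the kind of early check described above, written in plain Python so it runs without Keras. This is not the PR's actual diff: the helper name `check_stateful_batch_size`, its arguments, and the error message are all hypothetical and only illustrate comparing the input's batch dimension against each state's batch dimension before any state is updated.

```python
def check_stateful_batch_size(input_shape, state_shapes):
    """Raise ValueError if the input batch size differs from a state's batch size.

    Hypothetical helper for illustration only (not a Keras API).
    input_shape:  tuple such as (batch, timesteps, features)
    state_shapes: list of tuples such as (batch, units)
    """
    input_batch = input_shape[0]
    if input_batch is None:
        # Dynamic batch size: nothing can be verified from the shape alone.
        return
    for index, state_shape in enumerate(state_shapes):
        state_batch = state_shape[0]
        if state_batch is not None and state_batch != input_batch:
            raise ValueError(
                f"Stateful RNN expected batch size {state_batch} "
                f"(from state {index}) but received an input with "
                f"batch size {input_batch}."
            )


# Matching batch sizes pass silently.
check_stateful_batch_size((4, 10, 8), [(4, 2), (4, 2)])

# A mismatched batch is rejected before any state would be touched.
try:
    check_stateful_batch_size((1, 10, 8), [(4, 2), (4, 2)])
except ValueError as error:
    print(error)
```

Running the check before the recurrent computation is what prevents the partial state mutation mentioned for GRU and LSTM: the call fails while the stored states are still untouched.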

Fixes #21183

Manually tested by calling a stateful model with incorrect batch sizes and verifying that the new ValueError is raised as expected.

El3ssar avatar May 04 '25 18:05 El3ssar

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] avatar May 04 '25 18:05 google-cla[bot]

Codecov Report

Attention: Patch coverage is 0% with 5 lines in your changes missing coverage. Please review.

Project coverage is 34.56%. Comparing base (f5171b3) to head (3180ed4). Report is 124 commits behind head on master.

Files with missing lines       Patch %   Lines
keras/src/layers/rnn/rnn.py    0.00%     5 Missing :warning:

:exclamation: There is a different number of reports uploaded between BASE (f5171b3) and HEAD (3180ed4). Click for more details.

HEAD has 8 uploads less than BASE

Flag               BASE (f5171b3)   HEAD (3180ed4)
keras              5                1
keras-numpy        1                0
keras-torch        1                0
keras-tensorflow   1                0
keras-jax          1                0

Additional details and impacted files
@@             Coverage Diff             @@
##           master   #21249       +/-   ##
===========================================
- Coverage   82.60%   34.56%   -48.04%     
===========================================
  Files         564      567        +3     
  Lines       54543    56219     +1676     
  Branches     8472     8788      +316     
===========================================
- Hits        45054    19431    -25623     
- Misses       7402    35910    +28508     
+ Partials     2087      878     -1209     

Flag               Coverage Δ
keras              34.56% <0.00%> (-47.86%) :arrow_down:
keras-jax          ?
keras-numpy        ?
keras-openvino     34.56% <0.00%> (+1.56%) :arrow_up:
keras-tensorflow   ?
keras-torch        ?

Flags with carried forward coverage won't be shown.

codecov-commenter avatar May 04 '25 18:05 codecov-commenter

Hi @El3ssar, can you please sign the CLA? Thanks!

keerthanakadiri avatar May 05 '25 08:05 keerthanakadiri

> Thanks for the PR! Please add a simple test case for the fix (using self.assertRaisesRegex)

Gladly, where do I put the test?

El3ssar avatar May 22 '25 07:05 El3ssar

You can put the test in keras/src/layers/rnn/rnn_test.py alongside test_statefulness_two_states. You can target layer = layers.RNN(TwoStatesRNNCell(2), stateful=True) as the RNN layer.

fchollet avatar May 22 '25 15:05 fchollet
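
For readers following the thread, the shape of such a test might look like the sketch below. It is only an illustration under stated assumptions: `ToyStatefulRNN` is a hypothetical pure-Python stand-in so the example runs without Keras, and the test name and error message are invented. The real test would live in keras/src/layers/rnn/rnn_test.py and target `layers.RNN(TwoStatesRNNCell(2), stateful=True)` as suggested above.

```python
import unittest


class ToyStatefulRNN:
    """Hypothetical stand-in for a stateful RNN layer; not Keras code."""

    def __init__(self, batch_size, units):
        self.batch_size = batch_size
        # One state vector per sample, kept across calls.
        self.state = [[0.0] * units for _ in range(batch_size)]

    def __call__(self, batch):
        # Validate the batch size before touching any state.
        if len(batch) != self.batch_size:
            raise ValueError(
                f"Input batch size {len(batch)} does not match "
                f"the stateful batch size {self.batch_size}."
            )
        # Toy update: add the sum of each sample's values to its state units.
        for row, sample in zip(self.state, batch):
            total = sum(sample)
            for i in range(len(row)):
                row[i] += total
        return [list(row) for row in self.state]


class StatefulBatchMismatchTest(unittest.TestCase):
    def test_batch_size_mismatch_raises(self):
        layer = ToyStatefulRNN(batch_size=2, units=2)
        layer([[1.0, 2.0], [3.0, 4.0]])  # valid call that builds up state
        state_before = [list(row) for row in layer.state]

        with self.assertRaisesRegex(ValueError, "batch size"):
            layer([[1.0, 2.0]])  # batch of 1 instead of 2

        # The failed call must not have mutated the stored state.
        self.assertEqual(layer.state, state_before)


# Run the test case programmatically and keep the result object.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(StatefulBatchMismatchTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Checking the state after the failed call covers both halves of the fix described in the PR: the error is raised, and it is raised before any state is corrupted.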

@fchollet I've added the requested test in rnn_test.py using TwoStatesRNNCell. It fails without the fix and passes with it.

El3ssar avatar Jul 25 '25 11:07 El3ssar

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] avatar Aug 23 '25 02:08 github-actions[bot]

This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

github-actions[bot] avatar Sep 07 '25 02:09 github-actions[bot]