
Modified compute_output_shape function to handle broadcasting behavior in layers.Rescaling

Open • sonali-kumari1 opened this pull request (#21351)

  • Fixed compute_output_shape() in Rescaling to correctly handle broadcasting between the input and the scale and offset parameters, so that the shapes reported by model.summary() and model.output_shape match the actual output.
  • Improved the Rescaling layer docstring with a detailed description of the scale and offset parameters.
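
The broadcasting rule behind the first bullet can be illustrated with a small, framework-free sketch. It is not the code from this PR: rescaling_output_shape and broadcast_dim are hypothetical helpers that apply NumPy-style broadcasting to the static input shape and to the shapes of scale and offset, treat None (for example the batch dimension) as unknown, and ignore image_data_format entirely.

```python
import numpy as np


def broadcast_dim(a, b):
    """Broadcast two static dimensions, where None means "unknown"."""
    if a is None or b is None:
        known = b if a is None else a
        # Assumption: an unknown dim broadcast against a known dim > 1 takes
        # that known size; against 1 or another unknown it stays unknown.
        return known if known is not None and known > 1 else None
    if a == 1:
        return b
    if b == 1 or a == b:
        return a
    raise ValueError(f"Incompatible dimensions for broadcasting: {a} and {b}")


def rescaling_output_shape(input_shape, scale, offset):
    """Hypothetical helper: static output shape of `inputs * scale + offset`.

    Uses NumPy-style broadcasting (channels_last only): shapes are
    right-aligned, padded with 1s on the left, and combined dim by dim.
    """
    shapes = [tuple(input_shape), np.shape(scale), np.shape(offset)]
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + tuple(s) for s in shapes]
    output_shape = []
    for dims in zip(*padded):
        dim = dims[0]
        for other in dims[1:]:
            dim = broadcast_dim(dim, other)
        output_shape.append(dim)
    return tuple(output_shape)


print(rescaling_output_shape((None, 1), [1.0, 2.0], [0.2, 0.4]))  # (None, 2)
print(rescaling_output_shape((None, 4, 4, 3), 1.0 / 255, 0.0))  # (None, 4, 4, 3)
```

With scalar scale and offset the input shape passes through unchanged; with the length-2 lists used in the reproduction script later in this thread, a (None, 1) input becomes (None, 2), which is the shape model.output_shape should report.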

Fixes #21319

sonali-kumari1 • Jun 04 '25 06:06

Codecov Report

Attention: Patch coverage is 77.14286% with 8 lines in your changes missing coverage. Please review.

Project coverage is 82.74%. Comparing base (de9cf25) to head (2544af9). Report is 50 commits behind head on master.

Files with missing lines                       Patch %   Lines
keras/src/layers/preprocessing/rescaling.py    77.14%    4 Missing and 4 partials
@@            Coverage Diff             @@
##           master   #21351      +/-   ##
==========================================
+ Coverage   82.65%   82.74%   +0.09%     
==========================================
  Files         565      565              
  Lines       54823    55339     +516     
  Branches     8514     8639     +125     
==========================================
+ Hits        45315    45793     +478     
- Misses       7414     7440      +26     
- Partials     2094     2106      +12     
Flag               Coverage Δ
keras              82.56% <74.28%> (+0.09%)
keras-jax          63.36% <71.42%> (-0.23%)
keras-numpy        58.53% <25.71%> (-0.22%)
keras-openvino     33.76% <2.85%>  (+0.62%)
keras-tensorflow   63.78% <71.42%> (-0.23%)
keras-torch        63.39% <71.42%> (-0.27%)

Flags with carried forward coverage won't be shown.


codecov-commenter • Jun 04 '25 06:06

Please fix the failing torch test:

FAILED keras/src/layers/preprocessing/rescaling_test.py::RescalingTest::test_rescaling_broadcast_output_shape - AssertionError: False is not true : Expected output shapes (2, 2) but received torch.Size([2, 2, 2])

fchollet • Jun 13 '25 18:06


Hi @fchollet, I tested the output shapes locally with all three backends (TensorFlow, JAX, and Torch) and observed consistent shapes: (2, 2) for TensorFlow and JAX, and torch.Size([2, 2]) for Torch. Here's the code snippet I used:

import numpy as np
import keras
from keras import layers

# Two samples with a single feature each.
x = np.ones((2, 1))
y_target = np.ones((2, 1))

# With the default channels_last data format, the length-2 scale and
# offset broadcast the (None, 1) input to (None, 2).
model = keras.Sequential(
    [
        layers.Input(shape=(1,), name="Feature"),
        layers.Rescaling([1.0, 2.0], [0.2, 0.4]),
    ]
)
model.summary()

y_pred = model(x)

# Compare the statically inferred shape with the shape actually produced.
print(f"Expected model output shape: {model.output_shape}")
print(f"True model output shape: {y_pred.shape}")
print(f"Target shape: {y_target.shape}")

model.compile(
    optimizer="adam",
    loss="mse",
)
model.fit(x, y_target, epochs=1, batch_size=8)

print("Result:", y_pred)

print("Actual input: ", x)

Output shapes:

  • TensorFlow: (2, 2)
  • JAX: (2, 2)
  • Torch: torch.Size([2, 2])

However, in the CI environment, specifically with the Torch backend, the output shape has an extra dimension.

sonali-kumari1 • Jun 18 '25 06:06


The key bit is that torch tests run with image_data_format set to channels_first. Please make sure the test takes this into account.
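
The extra dimension can be reproduced in a small NumPy sketch. It assumes, without quoting the actual Rescaling source, that under channels_first a non-scalar scale is reshaped to rank 3 as (C, 1, 1) so it acts per channel on (C, H, W) images, while offset is left unchanged. Under that assumption the (2, 1) input from the test yields (2, 2, 2), matching the CI failure, instead of the (2, 2) produced under channels_last. apply_rescaling is a hypothetical helper, not a Keras function.

```python
import numpy as np


def apply_rescaling(x, scale, offset, data_format):
    """Hypothetical stand-in for `inputs * scale + offset`."""
    scale = np.asarray(scale, dtype=x.dtype)
    offset = np.asarray(offset, dtype=x.dtype)
    if scale.ndim > 0 and data_format == "channels_first":
        # Assumed behavior: pad scale to (C, 1, 1) for channels-first images.
        scale = scale.reshape(scale.shape + (1,) * (3 - scale.ndim))
    return x * scale + offset


x = np.ones((2, 1), dtype="float32")
last = apply_rescaling(x, [1.0, 2.0], [0.2, 0.4], "channels_last")
first = apply_rescaling(x, [1.0, 2.0], [0.2, 0.4], "channels_first")
print(last.shape)   # (2, 2)
print(first.shape)  # (2, 2, 2)
```

In the test itself, one option is to derive the expected shape from keras.config.image_data_format(); another is to pin the format for the duration of the test with keras.config.set_image_data_format("channels_last") and restore the previous value afterwards.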

fchollet • Jun 20 '25 17:06