[BUG] _reshape_norm() produces an unexpected array shape for 3D arrays with one channel
Describe the bug
In train.py, the function _reshape_norm() behaves unexpectedly.
Specifically, when input data has shape (1, H, W), the current preprocessing logic incorrectly expands it to (3, 1, H, W) instead of the expected (3, H, W). This is due to the following code:
```python
if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
    # a (1, H, W) input also takes this branch and becomes (3, 1, H, W)
    td = np.stack((td, 0*td, 0*td), axis=0)
elif td.ndim == 3 and td.shape[0] < 3:
    td = np.concatenate((td, 0*td[:1]), axis=0)
```
This condition treats 3D arrays with a single channel (e.g. shape (1, H, W)) the same way as 2D images. np.stack() inserts a new leading axis instead of extending the existing channel axis, so the result gains an unintended extra dimension and becomes a 4D array.
To Reproduce
Steps to reproduce the behavior:
- Prepare input data shaped (1, H, W) (e.g., a single-channel image with an explicit channel dimension).
- Call _reshape_norm(data, channel_axis=0).
- Observe that the output shape becomes (3, 1, H, W) instead of (3, H, W).
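A minimal standalone reproduction of the two quoted branches, assuming they are applied per image after the channel axis has already been moved to the front. old_channel_fill is a hypothetical name that wraps only the quoted lines, not the full _reshape_norm():

```python
import numpy as np


def old_channel_fill(td):
    """Hypothetical wrapper around the two branches quoted above (current behavior)."""
    if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
        td = np.stack((td, 0 * td, 0 * td), axis=0)
    elif td.ndim == 3 and td.shape[0] < 3:
        td = np.concatenate((td, 0 * td[:1]), axis=0)
    return td


H, W = 4, 5
for shape in [(H, W), (1, H, W), (2, H, W)]:
    out = old_channel_fill(np.ones(shape, dtype=np.float32))
    print(shape, "->", out.shape)
# (4, 5) -> (3, 4, 5)
# (1, 4, 5) -> (3, 1, 4, 5)
# (2, 4, 5) -> (3, 4, 5)
```

Only the single-channel 3D input is affected. The 2D and two-channel cases already come out as (3, H, W).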
Suggested fix
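```python
if td.ndim == 2:
    # (H, W) -> (3, H, W): image in channel 0, zeros in channels 1 and 2
    td = np.stack((td, 0 * td, 0 * td), axis=0)
elif td.ndim == 3 and td.shape[0] < 3:
    # (C, H, W) with C < 3 -> (3, H, W): zero-pad the existing channel axis
    channel_to_add = 3 - td.shape[0]
    pad_width = [(0, channel_to_add)] + [(0, 0)] * 2
    td = np.pad(td, pad_width=pad_width, mode="constant", constant_values=0)
```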
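A sketch of the suggested fix wrapped in a standalone helper so its effect can be checked in isolation. The name fixed_channel_fill is hypothetical, and the sketch assumes 3D inputs are already channel-first with at most 3 channels, as in the quoted snippet:

```python
import numpy as np


def fixed_channel_fill(td):
    """Hypothetical wrapper around the suggested fix: return a (3, H, W) array.

    Assumes 3D inputs are already channel-first with at most 3 channels.
    """
    if td.ndim == 2:
        # (H, W): add a channel axis, filling channels 1 and 2 with zeros
        td = np.stack((td, 0 * td, 0 * td), axis=0)
    elif td.ndim == 3 and td.shape[0] < 3:
        # (C, H, W) with C in {1, 2}: append zero channels along the existing axis
        channel_to_add = 3 - td.shape[0]
        pad_width = [(0, channel_to_add), (0, 0), (0, 0)]
        td = np.pad(td, pad_width=pad_width, mode="constant", constant_values=0)
    return td


img = np.arange(20, dtype=np.float32).reshape(1, 4, 5)
out = fixed_channel_fill(img)
print(out.shape)  # (3, 4, 5)
# the original channel is kept and the padded channels are all zero
print(np.array_equal(out[0], img[0]), out[1:].any())  # True False
```

Padding along the existing channel axis, rather than stacking, keeps (1, H, W) from gaining a fourth dimension. A (2, H, W) input gets the same result as the old np.concatenate branch.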
Hi @PandaGab would you like to open a PR for this to get credit for the fix?
Please also add a test that fails before the fix and passes once the fix is implemented.
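One possible shape for such a test, sketched against standalone copies of the old and new branches. The helper names are hypothetical. A real test in the repository would call _reshape_norm() from train.py instead, and its exact return type is not shown in this issue:

```python
import numpy as np


def old_channel_fill(td):
    # hypothetical copy of the current branches in _reshape_norm()
    if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
        td = np.stack((td, 0 * td, 0 * td), axis=0)
    elif td.ndim == 3 and td.shape[0] < 3:
        td = np.concatenate((td, 0 * td[:1]), axis=0)
    return td


def fixed_channel_fill(td):
    # hypothetical copy of the suggested fix
    if td.ndim == 2:
        td = np.stack((td, 0 * td, 0 * td), axis=0)
    elif td.ndim == 3 and td.shape[0] < 3:
        pad_width = [(0, 3 - td.shape[0]), (0, 0), (0, 0)]
        td = np.pad(td, pad_width=pad_width, mode="constant", constant_values=0)
    return td


def check_three_channel_output(fill):
    """Assert that every supported input shape ends up as (3, H, W)."""
    H, W = 8, 6
    for shape in [(H, W), (1, H, W), (2, H, W), (3, H, W)]:
        out = fill(np.ones(shape, dtype=np.float32))
        assert out.shape == (3, H, W), f"{shape} -> {out.shape}"


def test_single_channel_3d_input():
    # pytest-discoverable; passes once the fix is in place
    check_three_channel_output(fixed_channel_fill)


if __name__ == "__main__":
    test_single_channel_3d_input()
    try:
        check_three_channel_output(old_channel_fill)
    except AssertionError as err:
        print("old logic fails:", err)  # old logic fails: (1, 8, 6) -> (3, 1, 8, 6)
```

Checking every supported input shape in one place means the (1, H, W) case is caught alongside the 2D and multi-channel cases that already work.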