
Raise error when user passes bad window/stride for pooling

Open · n2cholas opened this issue 5 years ago · 1 comment

Currently, `MaxPool` and `AvgPool` document `window_shape` and `strides` as "Same rank as value or int." I think the layers should validate both arguments against this requirement and raise an error (or perhaps a warning) when they do not conform.
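A minimal sketch of the proposed check, written in plain Python so it does not depend on Haiku internals. The function name `validate_pool_arg` and its `strict` flag are hypothetical and not part of Haiku's API. With `strict=True` a rank mismatch raises `ValueError`; with `strict=False` it issues a `UserWarning` instead, mirroring the error-or-warning choice above. It only checks rank and leaves the interpretation of an int to the layer.

```python
import warnings
from typing import Sequence, Union


def validate_pool_arg(
    name: str,
    value: Union[int, Sequence[int]],
    ndim: int,
    strict: bool = True,
) -> None:
    """Hypothetical check: `value` must be an int or have exactly `ndim` entries.

    strict=True raises ValueError on a mismatch; strict=False warns instead.
    """
    if isinstance(value, int):
        return  # A scalar is accepted; how it is broadcast is up to the layer.
    rank = len(value)
    if rank == ndim:
        return  # Same rank as the input: nothing to report.
    msg = (
        f"{name}={tuple(value)} has rank {rank}, but the input has rank {ndim}. "
        f"Pass an int or a sequence of length {ndim}."
    )
    if strict:
        raise ValueError(msg)
    warnings.warn(msg)
```

In a pooling layer, a call such as `validate_pool_arg("window_shape", window_shape, x.ndim)` (and likewise for `strides`) would go before any shape inference, so `window_shape=(2, 2)` on a 4D input fails loudly instead of being silently reinterpreted.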

Users coming from other frameworks are used to passing window shapes and strides whose rank matches the number of spatial dimensions. For a 2D max pool, for example, they would naturally write `window_shape=(2, 2)`, which Haiku resolves to `(1, 1, 2, 2)`; on an NHWC input that pools over width and channels rather than height and width, which is almost certainly not what the user intended. Raising an error would force the user to pass `window_shape=(1, 2, 2, 1)` or `window_shape=2`, and would save them debugging time.
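To make the failure concrete, here is a small pure-Python sketch of the shape arithmetic. It assumes the behaviour described above (a short `window_shape`/`strides` is left-padded with 1s up to the input rank), `VALID` padding, and an NHWC input of shape `(8, 32, 32, 3)`. The helpers `pad_with_leading_ones` and `valid_pool_output_shape` are hypothetical illustrations, not Haiku functions.

```python
from typing import Sequence, Tuple


def pad_with_leading_ones(size: Sequence[int], ndim: int) -> Tuple[int, ...]:
    """Hypothetical model of the reported behaviour: left-pad a short shape with 1s."""
    size = tuple(size)
    return (1,) * (ndim - len(size)) + size


def valid_pool_output_shape(
    input_shape: Sequence[int],
    window: Sequence[int],
    strides: Sequence[int],
) -> Tuple[int, ...]:
    """Per-axis output size of a VALID-padded pool: floor((n - w) / s) + 1."""
    return tuple((n - w) // s + 1 for n, w, s in zip(input_shape, window, strides))


x_shape = (8, 32, 32, 3)  # NHWC: batch, height, width, channels

# What the user meant: a 2x2 pool with stride 2 over height and width.
intended = valid_pool_output_shape(x_shape, (1, 2, 2, 1), (1, 2, 2, 1))

# What (2, 2) becomes after left-padding: the pool now runs over width and channels.
window = pad_with_leading_ones((2, 2), len(x_shape))
strides = pad_with_leading_ones((2, 2), len(x_shape))
actual = valid_pool_output_shape(x_shape, window, strides)

print(window)    # (1, 1, 2, 2)
print(intended)  # (8, 16, 16, 3)
print(actual)    # (8, 32, 16, 1)
```

Nothing in this arithmetic fails: the height is left untouched and the three channels are pooled down to one, so any error only surfaces (if at all) further downstream.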

n2cholas · Oct 06 '20 07:10

This did cause a silent error for me as well.

pranavsubramani · Oct 06 '20 20:10