neural-tangents
Fix masking with more than one input feature
This PR fixes an issue where passing inputs `x1` and `x2` of dimension `(..., n_feat)` with `n_feat > 1` to a `kernel_fn` generates masks in the Kernel which still carry the original feature dimension.
This creates a problem for layers such as `GlobalAvgPool`, since for kernels that represent infinitely wide layers it is generally assumed that the channel dimension is one.
I implemented a reduction over the mask's last dimension using `np.any(.., keepdims=True)`, which assumes that it doesn't matter if any or all of the features are masked. A user warning spells this out.
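To illustrate the reduction described above, here is a minimal sketch using plain NumPy rather than the library's own array module. The helper name `reduce_feature_mask`, the `pad_value` constant, and the toy input are hypothetical and are not taken from the PR diff; the sketch only shows how an any-reduction over the last axis turns a `(..., n_feat)` mask into a `(..., 1)` mask, and where it differs from an all-reduction.

```python
import numpy as np


def reduce_feature_mask(mask):
    """Collapse a boolean mask over its last (feature) axis.

    A position is treated as masked if ANY of its features is masked.
    keepdims=True leaves a trailing axis of size one.
    """
    return np.any(mask, axis=-1, keepdims=True)


# Hypothetical input: 2 samples, 3 positions, n_feat = 2.
# Entries equal to pad_value are considered masked (an assumption of this sketch).
pad_value = 0.0
x = np.array([[[1.0, 2.0], [0.0, 0.0], [3.0, 0.0]],
              [[0.0, 0.0], [4.0, 5.0], [6.0, 7.0]]])

mask = x == pad_value                # shape (2, 3, 2): still carries n_feat
reduced = reduce_feature_mask(mask)  # shape (2, 3, 1): one flag per position

# For comparison: an all-reduction only masks positions whose features are all masked.
reduced_all = np.all(mask, axis=-1, keepdims=True)

print(mask.shape, reduced.shape)
print(reduced[..., 0])
print(reduced_all[..., 0])
```

The two reductions disagree only at sample 0, position 2, where just one of the two features equals `pad_value`. That is the case the user warning is meant to flag.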
I can create an issue with a reproducer shortly, if more detail is needed. Can also add a unit test.
Thanks a lot Jens, could you share a code sample that would fail before this change? I think we should support `n_feat > 1`, so I wonder if this points to a bug that needs to be fixed elsewhere.