
Feature masks do not get reduced in the kernel

Open jglaser opened this issue 1 year ago • 2 comments

I am observing an error when passing masked inputs whose feature dimension has size greater than one to a kernel that includes stax.GlobalAvgPool().

Reproducer:

import jax
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

if __name__ == '__main__':
    # input tokens
    X = 3*np.ones((10,512))

    mask_constant = 10
    pad_token = 0

    # pad some elements
    X = X.at[0,4].set(pad_token)
    X = X.at[7,422].set(pad_token)
    print('before encode ',X.shape)

    # vocabulary size
    n_vocab = 5
    def encode(x, mask_constant):
        # zero mean embeddings
        res = jax.nn.one_hot(x, n_vocab)
        res -= np.mean(res, axis=-1, keepdims=True)
        return np.where(x[..., None] == pad_token, mask_constant, res)

    X = encode(X, mask_constant=mask_constant)
    print('after encode ', X.shape)

    # trace over output correlations
    _, _, kernel_fn_avg = stax.GlobalAvgPool()
    input_fn = nt.batch(kernel_fn_avg, batch_size=2)
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
    print('output ', cov.shape)
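One way to see which masks the reproducer actually feeds to the kernel is to recompute them outside neural-tangents. The sketch below is a plain NumPy stand-in for encode: it replaces jax.nn.one_hot with an explicit comparison so it runs without JAX. It treats an entry as masked when it equals mask_constant, which is the assumption behind passing mask_constant to the kernel. Every padded token ends up masked in all 5 feature channels, so these masks are consistent across channels even though the channel dimension is larger than one.

```python
import numpy as np

mask_constant = 10
pad_token = 0
n_vocab = 5

# Same shapes as the reproducer: 10 sequences of 512 tokens, two padded entries.
X = 3 * np.ones((10, 512))
X[0, 4] = pad_token
X[7, 422] = pad_token

def encode(x, mask_constant):
    # NumPy stand-in for jax.nn.one_hot followed by mean-centering.
    res = (x[..., None] == np.arange(n_vocab)).astype(np.float32)
    res -= res.mean(axis=-1, keepdims=True)
    return np.where(x[..., None] == pad_token, mask_constant, res)

E = encode(X, mask_constant)
mask = E == mask_constant  # shape (10, 512, 5); True where an entry is masked

# The mask is consistent if every channel agrees at each (example, position).
consistent = np.all(mask == mask[..., :1])
print(E.shape, mask.sum(), consistent)  # (10, 512, 5) 10 True
```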

Output:

$ python mask_reproducer.py 
Attempting to register factory for plugin cuBLAS when one has already been registered
before encode  (10, 512)
after encode  (10, 512, 5)
Traceback (most recent call last):
  File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
    return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
    _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
    return _f(x_or_kernel, *args_np, **kwargs_np)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/api.py", line 525, in cache_miss
    out_flat = xla.xla_call(
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
    return call_bind(self, fun, *args, **params)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
    compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 248, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 293, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2167, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2117, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
    return f(_x_or_kernel, *_args, **_kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
    return kernel_fn_x1(x1_or_kernel, x2, get,
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
    mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
    return _mask_fn(mask, input_shape)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
    _check_is_implemented(mask, channel_axis)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
    raise NotImplementedError(
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
    return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
    _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
    return _f(x_or_kernel, *args_np, **kwargs_np)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
    return f(_x_or_kernel, *_args, **_kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
    return kernel_fn_x1(x1_or_kernel, x2, get,
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
    mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
    return _mask_fn(mask, input_shape)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
    _check_is_implemented(mask, channel_axis)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
    raise NotImplementedError(
NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new

Expected output (with the fix from #158):

before encode  (10, 512)
after encode  (10, 512, 5)
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
  warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
  warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
output  (10, 10)

I admit the warning is a little noisy; perhaps it could be omitted and the reduction described in the documentation instead.
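The reduction the fix relies on, and the concern raised in the reply below about silent errors, can be illustrated with a small NumPy helper. reduce_channel_mask is hypothetical. It is not part of neural-tangents and not the code from #158. It collapses a channel-wise mask to a single channel but raises an error when channels disagree, instead of only warning. The check needs concrete arrays; inside jax.jit the mask values are traced, so a Python-level raise on them would not work there.

```python
import numpy as np

def reduce_channel_mask(mask, channel_axis=-1):
    """Hypothetical helper: collapse a per-channel mask to one channel.

    A position is masked in the output if it is masked in every channel.
    If some channels are masked and others are not at the same position,
    the masks are inconsistent and an error is raised rather than a warning.
    """
    mask = np.asarray(mask, dtype=bool)
    all_masked = np.all(mask, axis=channel_axis, keepdims=True)
    any_masked = np.any(mask, axis=channel_axis, keepdims=True)
    if not np.array_equal(all_masked, any_masked):
        raise ValueError("Channel-wise masks differ; cannot reduce them safely.")
    return all_masked


# Consistent mask: position 1 of example 0 is masked in all 5 channels.
mask = np.zeros((2, 3, 5), dtype=bool)
mask[0, 1, :] = True
reduced = reduce_channel_mask(mask)
print(reduced.shape)  # (2, 3, 1)

# Inconsistent mask: only channel 0 is masked at that position.
bad = np.zeros((2, 3, 5), dtype=bool)
bad[0, 1, 0] = True
try:
    reduce_channel_mask(bad)
except ValueError as e:
    print("rejected:", e)
```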

jglaser, Aug 15 '22 03:08

Thanks for the detailed repro!

This touches on a fragile part of the library, which may take a while to fix properly: we don't support different channel-wise masks inside the network, and pooling is a layer that can produce different channel-wise masks when given different channel-wise input masks. I'm hesitant to adopt the proposed solution, though, because it could lead to silent errors if a user does pass different channel masks and doesn't notice the warning.
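To make the pooling point concrete, consider a toy masking rule for pooling: an output position counts as masked only when every input position in its window is masked. This rule is an illustrative assumption, not a description of neural-tangents internals. Under it, channels that start with different masks end up with different output masks. A toy 1D case in NumPy:

```python
import numpy as np

# Mask of shape (positions, channels); True means masked.
# Channel 0 masks positions 0 and 1; channel 1 masks only position 0.
mask = np.array([[True,  True],
                 [True,  False],
                 [False, False],
                 [False, False]])

def pooled_mask(mask, window=2):
    """Mask after a non-overlapping pool of size `window` (assumed rule:
    an output is masked only if all inputs in its window are masked)."""
    positions, channels = mask.shape
    windows = mask.reshape(positions // window, window, channels)
    return np.all(windows, axis=1)

out = pooled_mask(mask)
print(out)
# [[ True False]
#  [False False]]
```

Output position 0 is masked in channel 0 but not in channel 1. The error message above refers to exactly this case: different channel-wise masks reaching a pooling layer.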

Two user-side short-term solutions:

  • Start the network with a parametric layer (e.g. Dense, Conv), since these layers produce outputs with identical masks for all channels, even when the input masks differ across channels.
  • Apply any non-parametric layers (e.g. pooling, ReLU) as part of dataset preprocessing rather than inside the network.
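The second bullet can be sketched as a masked global average pool done in NumPy before the data reaches the kernel. masked_global_avg_pool is a hypothetical helper. It treats a position as padding if any of its channels equals mask_constant. In the reproducer all five channels are set together, so "any" and "all" agree. It then averages each example over its remaining positions. Whether this matches the masked normalization used by stax.GlobalAvgPool exactly is an assumption; it has not been checked against the library.

```python
import numpy as np

def masked_global_avg_pool(x, mask_constant):
    """Hypothetical preprocessing step: average over positions, skipping padding.

    x has shape (batch, positions, channels); the result has shape
    (batch, channels) and no longer carries a spatial axis or a mask.
    """
    valid = ~np.any(x == mask_constant, axis=-1, keepdims=True)  # (B, P, 1)
    total = np.sum(np.where(valid, x, 0.0), axis=1)               # (B, C)
    count = np.maximum(np.sum(valid, axis=1), 1)                  # (B, 1); avoid 0/0
    return total / count


mask_constant = 10.0
x = np.array([
    [[1.0, 1.0], [3.0, 3.0], [mask_constant, mask_constant]],  # last token padded
    [[0.0, 2.0], [2.0, 0.0], [4.0, 4.0]],                      # no padding
])
pooled = masked_global_avg_pool(x, mask_constant)
print(pooled)
# [[2. 2.]
#  [2. 2.]]
```

The pooled array can then be passed to a network without a pooling layer, so no channel-wise masks reach the kernel at all.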

Let me know if this works for the short term!

romanngg, Aug 19 '22 18:08

Yup, starting the network with a Dense layer works.

jglaser, Sep 11 '22 16:09